堆与优先队列
一、堆的定义与性质
堆(heap)是一棵完全二叉树,且满足堆性质:每个节点的值与其父节点满足大小关系。若「父 ≥ 子」则称为大根堆(max-heap),根为最大值;若「父 ≤ 子」则称为小根堆(min-heap),根为最小值。本章默认讨论大根堆,小根堆对称即可。
完全二叉树可以用数组按层序存储:根在下标 0,节点 i 的左子下标 2*i+1、右子 2*i+2、父 (i-1)//2。这样无需指针,且父子关系可 O(1) 计算。
二、堆化与基本操作
堆化(heapify)指让某一节点满足堆性质的过程。若该节点比子节点小(大根堆),需要把它「下沉」(sift-down):与较大的那个子交换,直到没有子或已 ≥ 子。若在末尾插入新元素,则需「上浮」(sift-up):与父比较,若大于父则交换并继续上浮,直到根或已 ≤ 父。
建堆:从最后一个非叶节点(下标 (n-2)//2)开始,从右到左、从下到上依次做下沉,这样每个子树先成堆,最后整棵成堆。可以证明总比较次数为 O(n),即线性建堆。
插入:将新元素放在数组末尾,再对其做上浮。O(log n)。
取最大值 / 弹出根:根即最大值。弹出时用末尾元素覆盖根,长度减一,再对根做一次下沉。O(log n)。
三、堆的 Python 实现
下面用列表实现大根堆,包含:父子下标函数、下沉、上浮、建堆、插入、弹出根。类型标注与注释完整,便于理解。
from __future__ import annotations
from typing import List
def _parent(i: int) -> int:
"""节点 i 的父节点下标;根 0 的父可视为 -1。"""
return (i - 1) // 2
def _left(i: int) -> int:
"""节点 i 的左子下标。"""
return 2 * i + 1
def _right(i: int) -> int:
"""节点 i 的右子下标。"""
return 2 * i + 2
def _sift_down_max(heap: List[int], n: int, i: int) -> None:
"""
大根堆下沉:以 i 为根的子树中,若 i 不满足堆性质则与较大子交换并递归。
n 为当前堆有效长度(下标 [0, n))。
"""
left = _left(i)
right = _right(i)
largest = i
if left < n and heap[left] > heap[largest]:
largest = left
if right < n and heap[right] > heap[largest]:
largest = right
if largest != i:
heap[i], heap[largest] = heap[largest], heap[i]
_sift_down_max(heap, n, largest)
def _sift_up_max(heap: List[int], i: int) -> None:
"""大根堆上浮:节点 i 若大于父则与父交换并继续上浮。"""
while i > 0:
p = _parent(i)
if heap[i] <= heap[p]:
break
heap[i], heap[p] = heap[p], heap[i]
i = p
def heapify_max(arr: List[int]) -> None:
"""将数组原地建为大根堆。从最后一个非叶开始自底向上下沉,O(n)。"""
n = len(arr)
# 最后一个非叶节点下标为 (n-2)//2
for i in range((n - 2) // 2, -1, -1):
_sift_down_max(arr, n, i)
def heap_push_max(heap: List[int], x: int) -> None:
"""大根堆插入:末尾追加后上浮,O(log n)。"""
heap.append(x)
_sift_up_max(heap, len(heap) - 1)
def heap_pop_max(heap: List[int]) -> int:
"""大根堆弹出最大值:根与末尾交换,删末尾,再对根下沉,O(log n)。"""
if not heap:
raise IndexError("pop from empty heap")
n = len(heap)
heap[0], heap[n - 1] = heap[n - 1], heap[0]
val = heap.pop()
if heap:
_sift_down_max(heap, len(heap), 0)
return val
四、优先队列与 heapq
优先队列(priority queue)是抽象数据类型:支持「插入带优先级的元素」和「取出当前优先级最高(或最低)的元素」。用大根堆可实现「取最大」;用小根堆可实现「取最小」。Python 标准库 heapq 提供的是小根堆(基于 list),因此 heapq.heappush / heapq.heappop 得到的是最小值。若要「取最大」,可对数值取负再入堆,或使用 heapq.nlargest 等。
heapq.heapify(list) 可原地将 list 变为小根堆;heapq.heappush(h, x) 插入;heapq.heappop(h) 弹出最小。均摊 O(log n)。
import heapq
from typing import List
def priority_queue_demo() -> None:
"""使用 heapq 实现「取最小」的优先队列。"""
h: List[int] = []
heapq.heappush(h, 5)
heapq.heappush(h, 2)
heapq.heappush(h, 8)
# 弹出顺序:2, 5, 8(每次 pop 为当前最小)
first: int = heapq.heappop(h) # 2
second: int = heapq.heappop(h) # 5
third: int = heapq.heappop(h) # 8
def k_smallest_with_heapq(nums: List[int], k: int) -> List[int]:
"""取前 k 小:使用 heapq.nsmallest(k, nums),内部用堆实现,O(n log k)。"""
if k <= 0 or not nums:
return []
return heapq.nsmallest(k, nums)
五、Top-K 问题
「在 n 个元素里找最大的(或最小的)K 个」是典型应用。若 K 很小,可以用大小为 K 的小根堆维护「当前最大的 K 个」:遍历时若堆未满则入堆;若已满且当前元素大于堆顶(最小),则替换堆顶并下沉,这样堆里始终是已扫描过的元素里最大的 K 个。最终堆顶为第 K 大,堆内为前 K 大。时间 O(n log K),空间 O(K)。同理,找最小的 K 个可用大小为 K 的大根堆。
六、调度与扩展
任务调度中,按截止时间或优先级排序后,用优先队列每次取「当前最紧急」任务执行,是典型用法。合并 K 个有序链表时,可把每个链表的当前头节点放入小根堆,每次弹出最小、将其后继入堆,得到全局最小顺序。堆排序即先建堆再不断弹出根,得到从大到小(或从小到大)的序列,时间 O(n log n),空间 O(1) 若原地。
heapq 只支持小根堆;需要大根堆时对数值取负即可。多键比较可用 (优先级, 计数器, 元素) 元组入堆,避免直接比较不可比对象。
七、小结
堆是满足「父 ≥ 子」(大根堆)或「父 ≤ 子」(小根堆)的完全二叉树,用数组存,父子下标可公式计算。堆化有下沉与上浮;建堆 O(n),插入与弹出 O(log n)。优先队列用堆实现;Top-K 用大小为 K 的堆在 O(n log K) 内解决。下一章进入图:顶点、边与 DFS、BFS。
八、例题
以下例题使用 Python 3 类型标注与注释;涉及堆时可用手写大根/小根堆或 heapq。
例题 1:数组中的第 K 个最大元素
题目描述:给定整数数组 nums 和整数 k,返回数组中第 k 个最大的元素(非降序排列后第 k 大,即第 n-k+1 小)。
输入:1 ≤ k ≤ len(nums)。
输出:第 k 大的值。
import heapq
from typing import List
def find_kth_largest(nums: List[int], k: int) -> int:
"""
用大小为 k 的小根堆维护「当前最大的 k 个」。
遍历时若堆未满则入堆;若已满且当前数大于堆顶,则替换堆顶并堆化。
最终堆顶即为第 k 大。时间 O(n log k),空间 O(k)。
"""
heap: List[int] = []
for x in nums:
if len(heap) < k:
heapq.heappush(heap, x)
elif x > heap[0]:
heapq.heapreplace(heap, x) # 等价于 pop 再 push,但一次完成
return heap[0]
讲解:小根堆堆顶是堆内最小。堆内始终保持「已扫描过的元素里最大的 k 个」,因此堆顶就是这 k 个里最小的,即全局第 k 大。heapreplace(heap, x) 先 pop 再 push(x),比两次调用更高效。
例题 2:前 K 个高频元素
题目描述:给定整数数组 nums 和整数 k,返回出现频率前 k 高的元素。答案顺序不限。
输入:1 ≤ k ≤ 不同元素个数。
输出:前 k 个高频元素的列表。
import heapq
from typing import List
from collections import Counter
def top_k_frequent(nums: List[int], k: int) -> List[int]:
"""
先统计频次,再用「按频次为优先级」的小根堆维护前 k 个高频。
堆内元素为 (频次, 数值);堆顶为当前 k 个里频次最小的,新元素频次更高则替换。
"""
cnt: Counter[int] = Counter(nums)
heap: List[tuple] = [] # (freq, num),heapq 按元组第一项比较
for num, freq in cnt.items():
if len(heap) < k:
heapq.heappush(heap, (freq, num))
elif freq > heap[0][0]:
heapq.heapreplace(heap, (freq, num))
return [item[1] for item in heap]
讲解:Counter 统计每个数出现次数。用大小为 k 的小根堆,以 (频次, 数值) 为元素,堆顶是「当前 k 个里频次最低的」。遍历所有 (num, freq) 时,若 freq 大于堆顶频次则替换,保证堆内为频次最高的 k 个。最后从堆中取出数值即可。时间 O(n + M log k),M 为不同元素个数。
例题 3:合并 K 个升序链表
题目描述:给定 k 个升序链表,将其合并为一个升序链表并返回头节点。
输入:lists[i] 为第 i 条链表的头;链表节点含 val、next。
输出:合并后的链表头。
import heapq
from typing import List, Optional
class ListNode:
def __init__(self, val: int = 0, next: Optional["ListNode"] = None) -> None:
self.val = val
self.next = next
def merge_k_lists(lists: List[Optional[ListNode]]) -> Optional[ListNode]:
"""
小根堆维护每条链表当前未合并的最小节点。
堆中存 (node.val, 链表编号或 id(node), node),避免节点直接比较。
每次弹出最小节点,将其 next 入堆,接到结果链尾。
"""
dummy = ListNode(0)
tail = dummy
# (val, idx, node) 用 idx 避免两节点 val 相同时比较 node
heap: List[tuple] = []
for i, head in enumerate(lists):
if head is not None:
heapq.heappush(heap, (head.val, i, head))
while heap:
_val, _idx, node = heapq.heappop(heap)
tail.next = node
tail = tail.next
if node.next is not None:
heapq.heappush(heap, (node.next.val, _idx, node.next))
return dummy.next
讲解:每条链表的当前头节点中,全局最小必是某条链的头。用堆维护 k 个链头,每次取最小接到结果链上,并把该链的下一个节点入堆。用 (val, idx, node) 可避免 Python 中 ListNode 无法比较的问题。总节点数 N 时,时间 O(N log k),空间 O(k)。
例题 4:数据流的中位数
题目描述:支持不断加入整数,并随时查询当前已加入数字的中位数(若个数为偶则取中间两个的平均)。
思路:用一个大根堆存「较小一半」的最大值,一个小根堆存「较大一半」的最小值。保持两堆大小相差不超过 1,则中位数由堆顶可得。
import heapq
from typing import List
class MedianFinder:
"""
较小一半放大根堆 lo,较大一半放小根堆 hi。
约定 lo 允许比 hi 多 1(总数为奇时中位在 lo 顶)。
加入数时先入 lo,再把 lo 顶移到 hi;若 hi 更多则把 hi 顶移回 lo,保持平衡。
"""
def __init__(self) -> None:
self.lo: List[int] = [] # 大根堆:存负值,heapq 当小根堆用
self.hi: List[int] = [] # 小根堆
def add_num(self, num: int) -> None:
heapq.heappush(self.lo, -num)
heapq.heappush(self.hi, -heapq.heappop(self.lo))
if len(self.hi) > len(self.lo):
heapq.heappush(self.lo, -heapq.heappop(self.hi))
def find_median(self) -> float:
if len(self.lo) > len(self.hi):
return float(-self.lo[0])
return (-self.lo[0] + self.hi[0]) / 2.0
讲解:新数先入大根堆 lo(用负值实现),再把 lo 的堆顶移到 hi,保证 lo 中任意元素 ≤ hi 中任意元素。若 hi size 大于 lo,则把 hi 顶移回 lo,使 lo 大小 ≥ hi。中位数:若 lo 多一个则为 lo 顶;否则为两堆顶平均。单次加入 O(log n),查询 O(1)。