最近在刷題的時候,遇到一個涉及到線段樹的問題。之前沒接觸過,看了幾遍題解才看懂。這裡簡單介紹下入門的過程。
高級數據結構,線段樹入門
一、線段樹的基本思想線段樹是一種常用來維護區間信息的數據結構,它適用於對區間內進行單點查詢、更新、求最值等操作,且時間複雜度能控制到 O(logN)。它的構建過程用到了二分的思想,通過不斷的二分將區間分成兩段,並分別對應左孩子和右孩子。
下面舉例來說明:比如有一個數組 [1, 2, 5, 7, 8, 10,12,18],它的長度是 8,所以範圍是 [1, 8]。如果用二分的思想來分解構造出的線段樹如下所示:
接下來我們來看看怎麼定義線段樹的數據結構。通常有兩種方式,一種方式是定義一個 class,一種方式是使用連續的數組。首先我們來看下自定義 class 的方式,這裡使用 Python 代碼:
class SegTree:
def __init__(self, left, right):
# 當前結點的左邊界
self.lo = left
# 當前結點的右邊界
self.hi = right
# 記錄額外的信息,這裡通常可以是最值或是區間和,根據題目需求來定義
self.other_inf = 0
# 左、右孩子
self.left = None
self.right = None
這種定義方式比較直接,但遍歷起來稍麻煩一點。而第二種方式是使用連續的數組。從上圖我們構造的線段樹可以看出,拋開葉子結點,樹是一個滿二叉樹,所以可以使用連續的數組來存儲,且父子結點的關係為
parent[i]
parent.left = parent[i*2]
parent.right = parent[i*2+1]
# 對於 i * 2 和 i*2+1,使用位運算可快速得到
i * 2 = i << 1
i * 2 + 1 = i << 1 | 1
使用數組時需要注意,葉子結點其實就是對應的給定數組的值,但數組的長度不一定能滿足滿二叉樹葉子結點的個數,這個時候代碼編寫上就比較靈活了,一般有兩種方式:
這種思路的其實相對比較好理解,因為給定的數組都需要放到葉子結點,那如果想要樹是一棵滿二叉樹,則葉子結點的個數必須是 2^n。所以我們需要找到第一個大於等於數組長度的 2 的 n 次冪。對於求第一個大於等於數組長度的 2 的 n 次冪的方法有很多,通過幾個位運算就能實現的,可以參考 Java HashMap 的源碼,也可以看 Integer 的 highestOneBit 方法,代碼如下(這裡不解釋具體原因):
public static int highestOneBit(int i) {
// HD, Figure 3-1
i |= (i >> 1);
i |= (i >> 2);
i |= (i >> 4);
i |= (i >> 8);
i |= (i >> 16);
return i - (i >>> 1);
}
而用一個我們比較好理解的方法,如下:
n = 1
while n < len(nums):
n <<= 1
找到這個數值之後,就可以進行初始化:
# 因為 n 是第一個大於或等於 len(nums) 的 2 次冪,它是等於它之前所有結點和 + 1 的
# 而一般在線段樹中第 0 位通常不用
# 因此 [0] * n 即所有非葉結點的初始化
# [nums] 則是初始化數組,並將其分配到葉子結點
# [0] * (n - len(nums)) 葉子結點未被分配到值的用 0 補全
self.seg_tree = [0] * n + nums + [0] * (n - len(nums))
# 初始化賦值, 根據父子關係的公式
for k in range(n - 1, 0, -1):
self.seg_tree[k] = self.seg_tree[2 * k] + self.seg_tree[2 * k + 1]
# 這裡的做法參考:https://leetcode-cn.com/problems/range-sum-query-mutable/solution/python-shu-zhuang-shu-zu-binary-indexed-tree-by-ze/
因為如果長度為 n 的數組都需要放到葉子結點上,則它的上層有 n/2 個結點,再上層 n/4…,根據等比求和公式很容易得出所有結點個數一定小於 2n。所以我們整個線段樹數組的值設置為 2n 就足夠使用了。代碼如下:
n = len(nums)
self._n = n
self._tree = [0] * (n << 1)
# 將最後 n 位放置到葉子結點,也就是數組的最後 n 位
for i in range(n, len(self._tree)):
self._tree[i] = nums[i - n]
for i in range(n - 1, 0, -1):
# 父結點 = 左結點(父結點序號*2) + 右結點(父結果序號*2+1)
self._tree[i] = self._tree[i << 1] + self._tree[i << 1 | 1]
此外,線段樹常用的方法有:
# 將第 i 個位置的數值更新為 val
def update(i, val)
# 將第 i 個位置的數值加上 val
def add(i, val)
# 查詢區間[i, j]上的區間和或最值,根據具體需求來具體分析
def query(i, j)
這裡就不一一實現這些方法,通過兩個題目來具體實踐下。
二、實踐首先來看一個簡單點的題目:
給定一個整數數組 nums,求出數組從索引 i 到 j (i ≤ j) 範圍內元素的總和,包含 i, j 兩點。
update(i, val) 函數可以通過將下標為 i 的數值更新為 val,從而對數列進行修改。
示例:
Given nums = [1, 3, 5]
sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
說明:
數組僅可以在 update 函數下進行修改。
你可以假設 update 函數與 sumRange 函數的調用次數是均勻分布的。
來源:力扣(LeetCode)
連結:https://leetcode-cn.com/problems/range-sum-query-mutable
著作權歸領扣網絡所有。商業轉載請聯繫官方授權,非商業轉載請註明出處。
這個題目其實很簡單,首先拋開線段樹的思想,其實直接通過 Python list 就可以實現,也是可以 AC 的:
from typing import List
class NumArray:
def __init__(self, nums: List[int]):
if not nums:
self._nums = []
else:
self._nums = nums
def update(self, i: int, val: int) -> None:
if i >= len(self._nums) or i < 0:
return
self._nums[i] = val
def sumRange(self, i: int, j: int) -> int:
return sum(self._nums[i:j + 1])
那如果使用線段樹呢?這裡我們使用數組存儲的方式來實現。首先線段數數組初始化可以直接套用上面說到的兩種方式,關鍵是 update 與 sumRange。而對於上面提到的兩種方式其實 update 和 sumRange 在解決的的時候實際代碼是一樣的。這裡我們以第二種方式初始化方式為例(畢竟會減少空間的消耗):
class NumArray:
def __init__(self, nums: List[int]):
if not nums:
self._tree = []
return
n = len(nums)
self._n = n
self._tree = [0] * (n << 1)
# 將最後 n 位放置到葉子結點,也就是數組的最後 n 位
for i in range(n, len(self._tree)):
self._tree[i] = nums[i - n]
for i in range(n - 1, 0, -1):
# 父結點 = 左結點(父結點序號*2) + 右結點(父結果序號*2+1)
self._tree[i] = self._tree[i << 1] + self._tree[i << 1 | 1]
# print(self._tree)
接下摟我們來看下 update 方法:
def update(self, i: int, val: int) -> None:
if i < 0 or i >= self._n:
return
# 更新數組的第 i 個位置,即更新 self._tree 的第 n + i 個位置,i 是從 0 開始
# 記錄和改變了多少
change = val - self._tree[self._n + i]
self._tree[self._n + i] = val
# 更新 n+i 結點的所有父結點
parent = (self._n + i) // 2
while parent:
self._tree[parent] += change
parent //= 2
# print(self._tree)
update 方法其實思路也比較簡單,先去更新葉子結點上的數值,並記錄改變差值,然後依次更新父結點的記錄和。我們再來看下 sumRange:
def sumRange(self, l: int, r: int) -> int:
# 做一些特殊邊界情況判斷
if l > r:
return 0
if r < 0:
return 0
if l > self._n:
return 0
if r < 0:
i = 0
if l >= self._n:
j = self._n - 1
l += self._n
r += self._n
result = 0
# 當 l <= r 時
while l <= r:
# 如果左邊界是右孩子,則說明不能加它的父結點的值,所以它的值需要單獨加
if l % 2 == 1:
result += self._tree[l]
# 加完之後,l向後移動,則移到了父結點右孩子的左孩子結點
l += 1
# 如果右邊界在左孩子,則左孩子需要單獨加
if r % 2 == 0:
result += self._tree[r]
r -= 1
l //= 2
r //= 2
return result
我個人認為 sumRange 比 update 要難理解一點,它的主要思想在於如果當前要求值的範圍比當前結點記錄的範圍要大(即既需要左孩子,也需要右孩子),則找父結點,如果只需要當前結點,就加上當前結點。
至此,這個題目就解決了。使用數組的話,代碼在理解上會複雜一點,主要是要對父子關係的靈活運用。
接下來,我們來看另外一個題目,我也是在刷這個題目時了解到線段樹這個數據結構:
給定一個整數數組 nums,返回區間和在 [lower, upper] 之間的個數,包含 lower 和 upper。
區間和 S(i, j) 表示在 nums 中,位置從 i 到 j 的元素之和,包含 i 和 j (i ≤ j)。
說明:
最直觀的算法複雜度是 O(n2) ,請在此基礎上優化你的算法。
示例:
輸入: nums = [-2,5,-1], lower = -2, upper = 2,
輸出: 3
解釋: 3個區間分別是: [0,0], [2,2], [0,2],它們表示的和分別為: -2, -1, 2。
來源:力扣(LeetCode)
連結:https://leetcode-cn.com/problems/count-of-range-sum
著作權歸領扣網絡所有。商業轉載請聯繫官方授權,非商業轉載請註明出處。
這個題目在分析時,我們需要將式子做一個轉換,一旦做了這個轉換,基本就成功一半了。轉換關係如下:
求:lower <= sum(i, j) <= upper
而 sum(i, j) = nums[0] + nums[1] + ... nums[j] - (nums[0] + nums[1] + ... + nums[i-1])
即 sum(i, j) = prefixSum[j] - prefixSum[i-1] , 這裡 prefixSum 為 nums 的前綴和數組
所以題目可以轉換為求:
lower <= prefixSum[j] - prefixSum[i-1] <= upper
->
# lower + prefixSum[i-1] <= prefixSum[j] <= upper + prefixSum[i-1]
# 或是
# prefixSum[j] - upper <= prefixSum[i-1] <= prefixSum[j] - lower
# 如果是第一種移動方法,則表示當給定一個 i 位置的前綴和時,需要找從 i+1 位置往後,滿足前綴和在 lower + prefixSum[i] 和 upper + prefixSum[i] 的個數
# 如果是第二種移動方法,則表示當給定一個 j 位置的前綴各時,需要找從 0 到 j-1 位置的前綴各,在 prefixSum[j] - upper 和 prefixSum[j] - lower 範圍內的
當我們分析到這一步時,我們可以發現,其實我們要求的就是當給定一個數值,然後求在一個範圍內的數值中,在指定範圍的數值有多少個。比如題目中的例子:nums = [-2,5,-1], lower = -2, upper = 2,前綴和數組為 [-2, 3, 2],因此我們可以遍歷前綴和數組,如果是從後往前遍歷,則是利用 lower + prefixSum[i-1] <= prefixSum[j] <= upper + prefixSum[i-1] 這個轉換,如果是從前往後遍歷,則是使用 prefixSum[j] - upper <= prefixSum[i-1] <= prefixSum[j] - lower 轉換(主要是需要確定範圍,所以固定的是式子左右兩邊的變量)。
不管使用哪種方式,其實思路都是一樣的。我們首先找到一個基準的前綴和,然後從當前這個基準向前(或向後)找在範圍內的個數,找到之後,將當前這個基準加入到某種數據結構中,在這個數據結構裡記錄的就是當前基準以前所有的前綴和。而且我們需要這個數據結構來保證,在這個數據結構中查詢在指定範圍內的數值個數時,性能很高,此外因為還會不斷做插入,也要保證插入的性能。
因此,我們明確了,解題需要前綴和數組,和一個能在區間內快速做查詢和插入(也可以是更新)的數據結構。顯然和我們線段樹的適用範圍是很相似的。直接看代碼吧。
# 使用線段樹,第一種移動方式,即 lower + prefixSum[i-1] <= prefixSum[j] <= upper + prefixSum[i-1]
class SegTree:
def __init__(self, left, right):
# 當前結點的左邊界
self.lo = left
# 當前結點的右邊界
self.hi = right
# 記錄在當前範圍內的數有多少個
self.count = 0
# 左、右孩子
self.left = None
self.right = None
class Solution:
def buildSegTree(self, left: int, right: int) -> SegTree:
# 感覺也可以用數組來代替
node = SegTree(left, right)
if left == right:
return node
mid = (left + right) // 2
# 左邊一半
left = self.buildSegTree(left, mid)
# 右邊一半
right = self.buildSegTree(mid + 1, right)
node.left = left
node.right = right
return node
def countOfSegTree(self, node: SegTree, left: int, right: int) -> int:
"""統計在線段樹中,在 [left, right] 範圍內的數值
"""
# 如果範圍比當前 lo hi 的範圍大,則直接返回 count 值
if left <= node.lo and node.hi <= right:
return node.count
# 如果沒在當前範圍內
if left > node.hi or right < node.lo:
return 0
return self.countOfSegTree(node.left, left, right) + self.countOfSegTree(node.right, left, right)
def insertToSegTree(self, node: SegTree, val: int) -> None:
"""在線段樹中,插入一個 val 的值
"""
node.count += 1
if node.lo == node.hi == val:
return
mid = (node.lo + node.hi) // 2
if val <= mid:
self.insertToSegTree(node.left, val)
else:
self.insertToSegTree(node.right, val)
def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
if not nums:
return 0
# 第一步,求前綴和,注意第一個元素為0,這樣在遍歷時,第一個元素到最後一個元素的情況也會考慮進去
prefix_sum = [0]
for n in nums:
prefix_sum.append(prefix_sum[-1] + n)
# 第二步,求出所有 lower + prefixSum[i-1] 和 upper + prefixSum[i-1],以及 prefix_sum 本身
allNumbers = set()
for n in prefix_sum:
allNumbers.add(n)
allNumbers.add(lower + n)
allNumbers.add(upper + n)
# 將 allNumbers 通過 hash 離散到一個連續的數組中
nums_map = {}
for i, n in enumerate(sorted(allNumbers)):
nums_map[n] = i
root = self.buildSegTree(0, len(nums_map))
res = 0
# 因為這裡是看 lower + prefixSum[i-1] <= prefixSum[j] <= upper + prefixSum[i-1],每次都是從當前位置往後看所有的前綴和,所以在遍歷時,應該從最後一個前綴和往前遍歷
for n in prefix_sum[::-1]:
left, right = nums_map[lower + n], nums_map[upper + n]
res += self.countOfSegTree(root, left, right)
self.insertToSegTree(root, nums_map[n])
return res
s = Solution()
print(s.countRangeSum([-2, 5, -1], -2, 2))
# 代碼參考了官方題解的java版本,官方版本用的是第二種移動方式
代碼中需要注意 allNumber 和 nums_map 的理解,這裡線段樹主要記錄的是在 left,right 範圍內的數值個數,因為前綴和可能比較散亂,所以對數值做了映射處理,將它映射到一個連接的數組中。
在題解中也看到了另外一種解決方法,使用的是有序數組+二分來代替的這種線段樹這種數據結構。代碼如下:
# 作者:fan-cai
# 連結:https://leetcode-cn.com/problems/count-of-range-sum/solution/python3-6xing-dai-ma-jian-ji-qian-zhui-he-er-fen-c/
# 來源:力扣(LeetCode)
# 著作權歸作者所有。商業轉載請聯繫作者獲得授權,非商業轉載請註明出處。
class Solution:
def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
"""按照上面的思路,一種解法就是
計算累積和數組 sums 的,其中 sum[i] = nums[0] + nums[1] + ... + nums[i],對於某個i來說,只有那些滿足 lower <= sum[j] - sum[i] <= upper 的 j 能形成一個區間 [i, j] 滿足題意,則有:sum[i] + lower =< sum[j] <= sum[i] + upper,目標就是來找到有多少個這樣的 j滿足上述條件。從後向前遍歷累加和數組,相當於固定sum[i]後,算出有多少的sum[j]滿足左右條件。因為必須滿足0 =< i <= j,所以sum[j]的範圍一定是由sum[i]之後的元素組成的數組。對sum[j]的範圍數組排序,l是找數組中第一個大於等於給定值(左條件)的數,而 r 是找數組中最後一個小於等於給定值(右條件)的數,那麼兩者相減,就是j的個數。
"""
import bisect
res, pre, now = 0, [0], 0
for n in nums:
# now 相當於前綴和
now += n
# 這種解法是針對上面的第二種移動方法
# pre記錄了從當前前綴和位置往前所有的前綴和,而且是排好序的,只需要在 pre 裡找到對應的左邊界和右邊界即可
res += bisect.bisect_right(pre, now - lower) - \
bisect.bisect_left(pre, now - upper)
bisect.insort(pre, now)
return res
線段樹採用了二分的思想,適用在區間範圍內做查詢、更新,見到類似在區間內獲取和、最值等問題,都可以使用線段樹
個人認為線段樹問題難點在於如果構造線段樹。而如果採用連續數組的方式來存儲,要充分利用數組要存放在葉子結點這一特性
leetcode官方題解比較難理解(可能因為都是高手寫的),關鍵還是需要多看代碼,多 debug
看過關於線段樹的其實實現版本,有做懶更新與懶插入,後續有機會再詳細總結下
靈活使用位運行,2*n與2*n+1,以及快速求大於 n 的第一個 2 的 n 次冪等
四、參考資料為了解決上面說的第二個問題,花了幾天時間,主要是連別人的題解都看不懂,參考了一些別人的解決思路,最後 debug 官方的 java 實現版本,才終於理解:
https://oi-wiki.org/ds/seg/
https://leetcode-cn.com/problems/count-of-range-sum/solution/qu-jian-he-de-ge-shu-by-leetcode-solution/
https://leetcode-cn.com/problems/count-of-range-sum/solution/xian-ren-zhi-lu-ru-he-xue-xi-ke-yi-jie-jue-ben-ti-/
https://leetcode-cn.com/problems/count-of-range-sum/solution/qian-zhui-he-xian-duan-shu-by-halfrost-2/ (其實這個人是寫得比較清楚的,但他在畫圖解釋插入的過程沒有太說清楚,看完官方代碼之後回過頭來看就很清晰了)
https://blog.csdn.net/qq_28468707/article/details/103284027 (這篇對於為什麼可以用二分,講得比較清楚)