Motivation
A practice problem for segment trees and binary indexed trees.
Problem
Given an integer array nums, handle multiple queries of the following types:
- Update the value of an element in nums.
- Calculate the sum of the elements of nums between indices left and right inclusive where left <= right.
Implement the NumArray class:
- NumArray(int[] nums) Initializes the object with the integer array nums.
- void update(int index, int val) Updates the value of nums[index] to be val.
- int sumRange(int left, int right) Returns the sum of the elements of nums between indices left and right inclusive (i.e. nums[left] + nums[left + 1] + ... + nums[right]).
Example 1:
Input
[NumArray, sumRange, update, sumRange]
[[[1, 3, 5]], [0, 2], [1, 2], [0, 2]]
Output
[null, 9, null, 8]
Explanation
NumArray numArray = new NumArray([1, 3, 5]);
numArray.sumRange(0, 2); // return 1 + 3 + 5 = 9
numArray.update(1, 2);   // nums = [1, 2, 5]
numArray.sumRange(0, 2); // return 1 + 2 + 5 = 8
Constraints:
- 1 <= nums.length <= 3 * 10^4
- -100 <= nums[i] <= 100
- 0 <= index < nums.length
- -100 <= val <= 100
- 0 <= left <= right < nums.length
- At most 3 * 10^4 calls will be made to update and sumRange.
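Before the two tree solutions, a brute-force version is handy as a baseline and as an oracle when testing. This is my own sketch, not part of the solutions below, and the class name NaiveNumArray is made up. Each sumRange call is O(n), so under the constraints above (up to 3 * 10^4 calls on an array of up to 3 * 10^4 elements) the worst case is on the order of 9 * 10^8 additions, which is why both solutions aim for O(log n) per operation.

```python
from typing import List


class NaiveNumArray:
    """Hypothetical baseline: O(1) update, O(n) sumRange."""

    def __init__(self, nums: List[int]):
        self.nums = list(nums)  # copy so the caller's list is not mutated

    def update(self, index: int, val: int) -> None:
        self.nums[index] = val

    def sumRange(self, left: int, right: int) -> int:
        # both ends inclusive, as the problem specifies
        return sum(self.nums[left:right + 1])


arr = NaiveNumArray([1, 3, 5])
print(arr.sumRange(0, 2))  # 9
arr.update(1, 2)
print(arr.sumRange(0, 2))  # 8
```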
Sol1: segment tree
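Each node of a balanced binary tree stores a half-open index range [i, j) together with the sum over that range; the tree is built recursively (DFS) by splitting every range at its midpoint.
A query compares the requested range with the node's range: if the request lies entirely on one side of mid it recurses into that child, otherwise it splits the request at mid and adds the two halves. An update walks down to the leaf for the index, then adds the difference (new - old) to every node on the way back up.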
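from typing import List


class Node:
    def __init__(self, i, j, val):
        self.i = i  # the node covers the half-open range [i, j)
        self.j = j
        self.mid = i + (j - i) // 2  # left child covers [i, mid), right child [mid, j)
        self.val = val  # sum of nums[i:j]
        self.left = None
        self.right = None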
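def query(root, i, j):
    """Sum over the half-open range [i, j), which must lie inside root's range."""
    if not root or i < 0 or j <= i:  # missing child or empty range
        return 0
    if i == root.i and root.j == j:  # exact match: use the stored sum
        return root.val
    elif j <= root.mid:  # request lies entirely in the left child
        return query(root.left, i, j)
    elif i >= root.mid:  # request lies entirely in the right child
        return query(root.right, i, j)
    else:  # request straddles mid: split it into two parts
        return query(root.left, i, root.mid) + query(root.right, root.mid, j)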
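def update(root, i, val):
    """Set nums[i] = val; return the difference (new - old) applied along the path."""
    if root.i == i and root.j - root.i == 1:  # reached the leaf for index i
        diff = val - root.val  # new - old = diff  =>  new = old + diff
        root.val = val
        return diff
    else:
        if i < root.mid:
            diff = update(root.left, i, val)
        else:
            diff = update(root.right, i, val)
        root.val += diff  # propagate the change upward
        return diff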
def build(arr, i, j):
    """Build the subtree covering arr[i:j]."""
    if i >= j:  # empty range (only happens for an empty array)
        return Node(i, j, 0)
    elif j - i == 1:  # leaf
        return Node(i, j, arr[i])
    else:
        mid = i + (j - i) // 2
        root = Node(i, j, 0)  # value is filled in once both children exist
        root.left = build(arr, i, mid)
        root.right = build(arr, mid, j)
        root.val = root.left.val + root.right.val
        return root
class NumArray:
    def __init__(self, nums: List[int]):
        self.root = build(nums, 0, len(nums))
        self.size = len(nums)

    def update(self, index: int, val: int) -> None:
        if 0 <= index < self.size:
            update(self.root, index, val)  # module-level update() above

    def sumRange(self, left: int, right: int) -> int:
        # the tree uses half-open ranges, so [left, right] becomes [left, right + 1)
        return query(self.root, left, right + 1)
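The same idea can also be stored in a flat list instead of linked Node objects, which is closer to the "array-shaped binary tree" picture and makes the leaf-to-root update a simple loop. This is a common bottom-up (iterative) variant, not the solution above; the names IterSegTree and sum_range are made up for illustration. Leaves sit at tree[n .. 2n-1] and internal node k holds tree[2k] + tree[2k + 1].

```python
from typing import List


class IterSegTree:
    """Hypothetical bottom-up segment tree stored in a list of size 2n."""

    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.tree = [0] * (2 * self.n)
        self.tree[self.n:] = nums  # leaves
        for k in range(self.n - 1, 0, -1):  # fill parents from right to left
            self.tree[k] = self.tree[2 * k] + self.tree[2 * k + 1]

    def update(self, index: int, val: int) -> None:
        k = index + self.n
        self.tree[k] = val
        while k > 1:  # walk from the leaf up to the root
            k //= 2
            self.tree[k] = self.tree[2 * k] + self.tree[2 * k + 1]

    def sum_range(self, left: int, right: int) -> int:
        """Sum of nums[left..right], both ends inclusive."""
        lo, hi = left + self.n, right + self.n + 1  # half-open [lo, hi) over leaves
        total = 0
        while lo < hi:
            if lo & 1:  # lo is a right child: take it and step past it
                total += self.tree[lo]
                lo += 1
            if hi & 1:  # hi is a right child: the node just left of it is in range
                hi -= 1
                total += self.tree[hi]
            lo //= 2
            hi //= 2
        return total


t = IterSegTree([1, 3, 5])
print(t.sum_range(0, 2))  # 9
t.update(1, 2)
print(t.sum_range(0, 2))  # 8
```

Both versions do O(log n) work per call; the flat layout trades the readable recursion for no per-node objects.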
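Sol2: binary indexed tree (BIT, also called a Fenwick tree)
Despite the name, a BIT is stored and traversed like a flat list (array); the tree structure exists only implicitly, in how the indices relate to each other.
I will write a separate note on BIT later; here I only point out how it differs from the segment tree.
A BIT query returns a prefix sum over [1..n], whereas a segment tree query returns the sum over an arbitrary range [a..b]. A BIT therefore gets a range sum as the difference of two prefix sums, query(b) - query(a - 1).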
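To see why the traversal feels like walking a list, here is a small sketch that records which 1-based BIT indices a prefix query and a point update visit, using the lowest-set-bit trick i & -i that the code below relies on. The helper names lowbit, query_path and update_path are my own, for illustration only.

```python
def lowbit(i: int) -> int:
    """Lowest set bit of i, e.g. lowbit(12) == 4 (0b1100 -> 0b0100)."""
    return i & -i


def query_path(n: int) -> list:
    """1-based BIT indices summed by a prefix query over [1..n]."""
    path = []
    while n > 0:
        path.append(n)
        n -= lowbit(n)  # drop the lowest set bit
    return path


def update_path(i: int, size: int) -> list:
    """1-based BIT indices touched when element i changes (array length = size)."""
    path = []
    while i <= size:
        path.append(i)
        i += lowbit(i)  # jump to the next index whose range covers i
    return path


print(query_path(7))      # [7, 6, 4]
print(update_path(3, 8))  # [3, 4, 8]
```

Each path has at most about log2(n) + 1 entries, which is where the O(log n) cost of both operations comes from.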
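from typing import List


class BIT:
    def __init__(self, arr):
        self.n = len(arr)
        self.bit = [0] * (1 + self.n)
        for i, val in enumerate(arr):
            i += 1  # BIT indices start from 1
            self.bit[i] += val
            # bit[i] is complete here (every smaller index feeding into it was
            # processed earlier), so push it to the next index that covers i
            next_i = i + (i & (-i))
            if next_i <= self.n:
                self.bit[next_i] += self.bit[i]

    def query(self, n):  # sum of elements [1..n], 1-based
        ret = 0
        while n > 0:
            ret += self.bit[n]
            n -= n & (-n)  # drop the lowest set bit
        return ret

    def update(self, i, diff):  # add diff to element i, 1-based
        while i <= self.n:
            self.bit[i] += diff
            i += i & (-i)  # move to the next index that covers i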
class NumArray:
    def __init__(self, nums: List[int]):
        self.nums = list(nums)  # current values, needed to compute update diffs
        self.bit = BIT(nums)

    def update(self, index: int, val: int) -> None:
        diff = val - self.nums[index]
        self.nums[index] = val
        self.bit.update(index + 1, diff)  # the BIT is 1-based

    def sumRange(self, left: int, right: int) -> int:
        # prefix(right + 1) - prefix(left) in 1-based indices = nums[left..right]
        return self.bit.query(right + 1) - self.bit.query(left)
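The BIT constructor builds the array in O(n) by pushing each finished node into the next index that covers it, rather than calling update once per element (O(n log n)). A quick sketch to check that the two give the same array; build_linear and build_by_updates are hypothetical helper names used only for this check.

```python
from typing import List


def build_linear(arr: List[int]) -> List[int]:
    """O(n) build, same idea as BIT.__init__: push each finished node onward."""
    n = len(arr)
    bit = [0] * (n + 1)
    for i, val in enumerate(arr, start=1):
        bit[i] += val
        nxt = i + (i & -i)  # next index whose range covers i
        if nxt <= n:
            bit[nxt] += bit[i]
    return bit


def build_by_updates(arr: List[int]) -> List[int]:
    """O(n log n) build: start from zeros and apply one point update per element."""
    n = len(arr)
    bit = [0] * (n + 1)
    for i, val in enumerate(arr, start=1):
        j = i
        while j <= n:
            bit[j] += val
            j += j & -j
    return bit


print(build_linear([1, 3, 5]))      # [0, 1, 4, 5]
print(build_by_updates([1, 3, 5]))  # [0, 1, 4, 5]
```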