Motivation

Practice problem for segment trees and binary indexed trees.

Problem

Given an integer array nums, handle multiple queries of the following types:

  1. Update the value of an element in nums.
  2. Calculate the sum of the elements of nums between indices left and right inclusive where left <= right.

Implement the NumArray class:

  • NumArray(int[] nums) Initializes the object with the integer array nums.
  • void update(int index, int val) Updates the value of nums[index] to be val.
  • int sumRange(int left, int right) Returns the sum of the elements of nums between indices left and right inclusive (i.e. nums[left] + nums[left + 1] + ... + nums[right]).


Example 1:

Input
[NumArray, sumRange, update, sumRange]
[[[1, 3, 5]], [0, 2], [1, 2], [0, 2]]
Output
[null, 9, null, 8]

Explanation
NumArray numArray = new NumArray([1, 3, 5]);
numArray.sumRange(0, 2); // return 1 + 3 + 5 = 9
numArray.update(1, 2);   // nums = [1, 2, 5]
numArray.sumRange(0, 2); // return 1 + 2 + 5 = 8
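To replay Example 1 mechanically, here is a minimal sketch with a hypothetical brute-force class `NaiveNumArray` and a hypothetical driver `run_ops` (neither comes from the problem or from the solutions below). The baseline makes `update` O(1) but `sumRange` O(n); at the limits in the constraints (3 * 10^4 elements, 3 * 10^4 calls) that can reach roughly 9 * 10^8 element reads, which is the cost the tree structures avoid.

```python
from typing import Any, List


class NaiveNumArray:
    """Hypothetical brute-force baseline: O(1) update, O(n) sumRange."""

    def __init__(self, nums: List[int]):
        self.nums = list(nums)

    def update(self, index: int, val: int) -> None:
        self.nums[index] = val

    def sumRange(self, left: int, right: int) -> int:
        # Inclusive on both ends, as the problem specifies.
        return sum(self.nums[left:right + 1])


def run_ops(cls, ops: List[str], args: List[List[Any]]) -> List[Any]:
    """Replay a LeetCode-style operation list; the first op is the constructor."""
    obj = cls(*args[0])
    out: List[Any] = [None]  # the constructor call produces null
    for op, a in zip(ops[1:], args[1:]):
        out.append(getattr(obj, op)(*a))
    return out


result = run_ops(NaiveNumArray,
                 ["NumArray", "sumRange", "update", "sumRange"],
                 [[[1, 3, 5]], [0, 2], [1, 2], [0, 2]])
print(result)  # [None, 9, None, 8]
```

Any class exposing the same three methods, such as the `NumArray` solutions below, can be passed to `run_ops` in place of the baseline.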


Constraints:

  • 1 <= nums.length <= 3 * 10^4
  • -100 <= nums[i] <= 100
  • 0 <= index < nums.length
  • -100 <= val <= 100
  • 0 <= left <= right < nums.length
  • At most 3 * 10^4 calls will be made to update and sumRange.

Sol1: segment tree

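Each node of a balanced binary tree stores a range together with the sum of the elements in that range; the tree is built recursively (DFS) by splitting each range in half.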
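A small sketch of the same half-open split that `build` uses (left child `[i, mid)`, right child `[mid, j)`). The hypothetical helper `node_sums` does not create `Node` objects; it only lists each node's depth, range and sum in DFS preorder, so the tree for `[1, 3, 5]` can be read off directly. It assumes a non-empty range.

```python
from typing import List, Tuple


def node_sums(arr: List[int], i: int, j: int, depth: int = 0) -> List[Tuple[int, int, int, int]]:
    """Return (depth, i, j, sum) for every node over arr[i:j], in DFS preorder.

    Assumes j - i >= 1. Same split as build(): mid = i + (j - i) // 2.
    """
    if j - i == 1:
        return [(depth, i, j, arr[i])]  # leaf: a single element
    mid = i + (j - i) // 2
    left = node_sums(arr, i, mid, depth + 1)
    right = node_sums(arr, mid, j, depth + 1)
    # left[0] and right[0] are the two child roots; the parent sums them.
    total = left[0][3] + right[0][3]
    return [(depth, i, j, total)] + left + right


for depth, i, j, s in node_sums([1, 3, 5], 0, 3):
    print("  " * depth + f"[{i}, {j}) sum={s}")
# [0, 3) sum=9
#   [0, 1) sum=1
#   [1, 3) sum=8
#     [1, 2) sum=3
#     [2, 3) sum=5
```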

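A query compares the requested interval with each node's range and recurses only into the children that overlap it; an update changes the leaf at the bottom and then propagates the difference upward through every ancestor.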
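The sketch below traces which node ranges a query adds up and which nodes an update touches, assuming the same half-open split as the code that follows. The helpers `covering_nodes` and `update_path` are hypothetical and work on ranges only, without `Node` objects; they assume `lo <= i <= j <= hi` and `lo <= idx < hi`.

```python
from typing import List, Tuple


def covering_nodes(lo: int, hi: int, i: int, j: int) -> List[Tuple[int, int]]:
    """Node ranges whose stored sums add up to sum(nums[i:j]), for the tree over [lo, hi)."""
    if i >= j:
        return []                                # empty query range contributes nothing
    if i == lo and j == hi:
        return [(lo, hi)]                        # exact match: use this node's sum
    mid = lo + (hi - lo) // 2
    if j <= mid:
        return covering_nodes(lo, mid, i, j)     # entirely in the left child
    if i >= mid:
        return covering_nodes(mid, hi, i, j)     # entirely in the right child
    # Straddles mid: split the query at mid.
    return covering_nodes(lo, mid, i, mid) + covering_nodes(mid, hi, mid, j)


def update_path(lo: int, hi: int, idx: int) -> List[Tuple[int, int]]:
    """Node ranges, root to leaf, whose sums change when nums[idx] changes."""
    path = [(lo, hi)]
    while hi - lo > 1:
        mid = lo + (hi - lo) // 2
        if idx < mid:
            hi = mid
        else:
            lo = mid
        path.append((lo, hi))
    return path


print(covering_nodes(0, 8, 1, 6))  # [(1, 2), (2, 4), (4, 6)]
print(update_path(0, 8, 5))        # [(0, 8), (4, 8), (4, 6), (5, 6)]
```

For an array of length 8, `sumRange(1, 5)` becomes the half-open query `[1, 6)` and needs only three stored sums, while updating index 5 touches one node per level, four in total.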

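from typing import List

class Node:
    def __init__(self, i, j, val):
        self.i = i                     # node covers the half-open range [i, j); i < j for every real node
        self.j = j
        self.mid = i + (j - i) // 2    # left child covers [i, mid), right child covers [mid, j)
        self.val = val                 # sum of nums[i:j]
        self.left = None
        self.right = None

def query(root, i, j):
    # Sum of nums[i:j] (half-open) within the subtree rooted at root.
    if not root or i < 0 or j < 0 or i >= j:
        return 0
    if i == root.i and root.j == j:
        return root.val                # exact match: use the stored sum
    elif j <= root.mid:
        return query(root.left, i, j)  # entirely in the left half
    elif i >= root.mid:
        return query(root.right, i, j) # entirely in the right half
    else:
        # straddles mid: split into [i, mid) and [mid, j)
        return query(root.left, i, root.mid) + query(root.right, root.mid, j)

def update(root, i, val):
    # Set nums[i] = val and return the change so each ancestor can add it on the way back up.
    if root.i == i and root.j - root.i == 1:
        diff = val - root.val          # new - old = diff => new = old + diff
        root.val = val
        return diff
    if i < root.mid:
        diff = update(root.left, i, val)
    else:
        diff = update(root.right, i, val)
    root.val += diff
    return diff

def build(arr, i, j):
    # Build the tree for arr[i:j] recursively (DFS).
    if i >= j:
        return Node(i, j, 0)           # empty range (only reachable for an empty input)
    elif j - i == 1:
        return Node(i, j, arr[i])      # leaf
    else:
        mid = i + (j - i) // 2
        root = Node(i, j, 0)           # placeholder; filled in from the children below
        root.left = build(arr, i, mid)
        root.right = build(arr, mid, j)
        root.val = root.left.val + root.right.val
        return root

class NumArray:
    def __init__(self, nums: List[int]):
        self.root = build(nums, 0, len(nums))
        self.size = len(nums)

    def update(self, index: int, val: int) -> None:
        if 0 <= index < self.size:
            update(self.root, index, val)

    def sumRange(self, left: int, right: int) -> int:
        # sumRange is inclusive, the tree uses half-open ranges, hence right + 1
        return query(self.root, left, right + 1)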

Sol2: binary indexed tree (BIT)

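Despite the name, a binary indexed tree is stored in a flat array and traversed by jumping from index to index, much like walking a list; the tree structure only shows up in how the indices relate to each other.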
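A short sketch of those index jumps, using the same `i & -i` step as the BIT code below. The helpers `lowbit`, `bit_query_path` and `bit_update_path` are hypothetical names introduced only for this illustration; indices are 1-based.

```python
from typing import List


def lowbit(i: int) -> int:
    # Python ints use two's-complement semantics for &, so i & -i keeps only the lowest set bit.
    return i & -i


def bit_query_path(n: int) -> List[int]:
    """Indices of bit[] summed by a prefix query over [1..n]."""
    path = []
    while n > 0:
        path.append(n)
        n -= lowbit(n)   # drop the lowest set bit
    return path


def bit_update_path(i: int, size: int) -> List[int]:
    """Indices of bit[] changed when element i changes, for an array of length size."""
    path = []
    while i <= size:
        path.append(i)
        i += lowbit(i)   # jump to the next node whose range also covers i
    return path


print(lowbit(6))               # 6 = 0b110 -> 2
print(bit_query_path(7))       # [7, 6, 4]
print(bit_update_path(5, 8))   # [5, 6, 8]
```

Both loops visit at most about log2(n) + 1 indices, which is where the O(log n) cost of each operation comes from.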

I'll cover BIT in a separate note later; here I only point out where it differs from a segment tree.

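A BIT query returns a prefix sum over [1..n], whereas a segment tree query returns the sum over an arbitrary range [a..b]. A range sum from a BIT therefore takes two prefix queries: sum(a..b) = query(b) - query(a-1).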
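The prefix-difference identity can be checked with a plain static prefix-sum array standing in for `BIT.query(k)`. This is only a sketch: `prefix_sums` and `range_sum` are hypothetical helpers, and a static array cannot handle updates, which is exactly what the BIT adds.

```python
from itertools import accumulate
from typing import List


def prefix_sums(arr: List[int]) -> List[int]:
    # p[k] = arr[0] + ... + arr[k-1], with p[0] = 0 (plays the role of BIT.query(k)).
    return [0] + list(accumulate(arr))


def range_sum(p: List[int], left: int, right: int) -> int:
    # arr[left] + ... + arr[right] = prefix of length right + 1 minus prefix of length left.
    return p[right + 1] - p[left]


p = prefix_sums([1, 3, 5])
print(p)                   # [0, 1, 4, 9]
print(range_sum(p, 1, 2))  # 3 + 5 = 8
```

With 0-based `left`/`right`, this is the same subtraction `sumRange` performs on the BIT: `query(right + 1) - query(left)`.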

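from typing import List

class BIT:
    def __init__(self, arr):
        self.n = len(arr)
        self.bit = [0] * (1 + self.n)   # 1-based; bit[0] is unused

        # O(n) build: add each value, then push the finished node into its parent.
        for i, v in enumerate(arr):
            i += 1                      # BIT indices start from 1
            self.bit[i] += v
            nextI = i + (i & (-i))      # parent: next node whose range covers i
            if nextI <= self.n:
                self.bit[nextI] += self.bit[i]

    def query(self, n):
        # Prefix sum of elements 1..n (1-based).
        ret = 0
        while n > 0:
            ret += self.bit[n]
            n -= (n & (-n))             # drop the lowest set bit
        return ret

    def update(self, i, diff):
        # Add diff to element i (1-based) and to every node that covers it.
        while i <= self.n:
            self.bit[i] += diff
            i += (i & (-i))

class NumArray:
    def __init__(self, nums: List[int]):
        self.nums = list(nums)          # current values, needed to compute update diffs
        self.bit = BIT(nums)

    def update(self, index: int, val: int) -> None:
        diff = val - self.nums[index]
        self.nums[index] = val
        self.bit.update(index + 1, diff)

    def sumRange(self, left: int, right: int) -> int:
        # prefix(right + 1) - prefix(left) = nums[left] + ... + nums[right]
        return self.bit.query(right + 1) - self.bit.query(left)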
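The BIT constructor builds the array in O(n) by pushing each finished node into its parent at `i + (i & -i)`. A sketch that compares this against the simpler O(n log n) build (start from zeros and apply one point update per element) can make the design choice easier to trust; `build_linear` and `build_by_updates` are hypothetical standalone versions of the two approaches, returning the 1-based `bit` list.

```python
from typing import List


def build_linear(arr: List[int]) -> List[int]:
    """O(n) build: each node adds its finished total into its parent i + (i & -i)."""
    n = len(arr)
    bit = [0] * (n + 1)
    for i, v in enumerate(arr, start=1):
        bit[i] += v                  # children (smaller indices) were already added
        parent = i + (i & -i)
        if parent <= n:
            bit[parent] += bit[i]
    return bit


def build_by_updates(arr: List[int]) -> List[int]:
    """O(n log n) build: one point update per element."""
    n = len(arr)
    bit = [0] * (n + 1)
    for i, v in enumerate(arr, start=1):
        j = i
        while j <= n:
            bit[j] += v
            j += j & -j
    return bit


arr = [1, 3, 5, 7, 9, 11]
print(build_linear(arr))                           # [0, 1, 4, 5, 16, 9, 20]
print(build_linear(arr) == build_by_updates(arr))  # True
```

The linear build is correct because a node's children all have smaller indices, so by the time index `i` is visited its total is final and can be forwarded once to its parent.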