動機

其實線段樹的花樣很多

基本款

每個點都放區間的

  • 起點
  • 終點
  • 總值
  • 左右child
class Node:
    def __init__(self,i,j,val):
        self.i = i
        self.j = j # i < j, always
        self.mid = i + (j-i) // 2
        self.val = val
        self.left = None
        self.right = None

建構就是top-down去二分

這邊可以注意的是,只要是有結合律的都可以用在segment tree!! 在此是用sum來示範

def build(arr,i,j):
    if i >= j:
        return Node(i,j,0)
    elif j-i == 1:
        return Node(i,j,arr[i])
    else:
        mid = i + (j-i) // 2
        root = Node(i,j,'wait for left and right')
        root.left = build(arr, i, mid)
        root.right = build(arr, mid, j)
        root.val = root.left.val + root.right.val
        return root

query根據node的中點分成3個case

  1. 完全在左
  2. 完全在右
  3. 卡在中間

剩下就是非法range與剛好是這個node的range

def query(root,i,j):
    if not root or i < 0 or j < 0:
        return 0
    if i == root.i and root.j == j:
        return root.val
    elif j < root.mid:
        return query(root.left, i, j)
    elif i > root.mid:
        return query(root.right, i ,j)
    else:
        return query(root.left, i, root.mid)+query(root.right, root.mid, j)

如果說要改其中一個index的值就有點麻煩

因為要把修改傳上去,所以要回傳差值

def update(root,i,val):
    if root.i == i and root.j-root.i == 1:
        diff = val - root.val # new - old = diff => new = old + diff
        root.val = val
        return diff
    else:
        if i < root.mid:
            diff = update(root.left, i, val)
        else:
            diff = update(root.right, i, val)
        root.val += diff
        return diff

懶人標記

如果要改一個range,對每個點用單點修改,效率超低,所以要用懶人標記

懶人標記有兩種做法,區分是會不會改到原本的資料

原本資料不變

一次替整個range加上某個值

因為原本的值沒有變,所以可以用一個cache存差值,之後就是在查詢時把差值加上去

class Node:
    def __init__(self,i,j,val):
        self.i = i
        self.j = j # i < j, always
        self.mid = i + (j-i) // 2
        self.val = val
        self.left = None
        self.right = None

        self.cache = 0

def query(root,i,j):
    if not root or i < 0 or j < 0:
        return 0
    else:
      ret = root.cache*(j-i)
      if i == root.i and root.j == j:
          ret += root.val
      elif j < root.mid:
          ret += query(root.left, i, j)
      elif i > root.mid:
          ret += query(root.right, i ,j)
      else:
          ret += query(root.left, i, root.mid)+query(root.right, root.mid, j)
def update(root,i,j,val):
    if root.i == i and root.j == j:
        root.cache += val
    else:
        if i < root.mid:
            update(root.left, i, val)
        else:
            update(root.right, i, val)

改原本資料

現在變成要一次把一個range的資料都改成同一個數字

但總不能一個一個改,所以可以在node留訊息,等之後有人經過再真的去改自己的資料,並把改變往下推

class Node:
    def __init__(self,i,j,val):
        self.i = i
        self.j = j # i < j, always
        self.mid = i + (j-i) // 2
        self.val = val
        self.left = None
        self.right = None

        self.changed = False
        self.newVal = 0
    
    def push(self):
        self.val = self.newVal*(j-i)
        self.left.changed = self.right.change = True
        self.left.newVal = self.right.newVal = self.newVal

        self.newVal, self.changed = 0, False

def query(root,i,j):
    if not root or i < 0 or j < 0:
        return 0
    root.push()
    if i == root.i and root.j == j:
        return root.val
    elif j < root.mid:
        return query(root.left, i, j)
    elif i > root.mid:
        return query(root.right, i ,j)
    else:
        return query(root.left, i, root.mid)+query(root.right, root.mid, j)

def update(root,i,j,val):
    if root.i == i and root.j == j:
        root.newVal, root.changed = val, True
    else:
        root.push()
        if i < root.mid:
            update(root.left, i, val)
        else:
            update(root.right, i, val)

持久化

保留過去的版本

在update時把經過的node都重生一遍!!

def update(root,i,val):
    if root.i == i and root.j-root.i == 1:
        return Node(root.i, root.j, val)
    else:
        ret = Node(root.i, root.j, 'waiting')
        if i < root.mid:
            ret.right, ret.left = update(root.right, i, val), root.left
        else:
            ret.left, ret.right = update(root.left, i, val), root.right
        ret.val = ret.left.val + ret.right.val
        return ret

segment tree 與 binary index tree的差異

segment treebinary index tree
LoC
query rangea~b1~n
time complexity of buildingO(lg(n))O(N)
time complexity of queryO(lg(n))O(lg(n))
time complexity of updateO(lg(n))O(lg(n))

Ref

Segment Tree