動機

個數、i<j,使用divide and conquer的時機!!

Problem

Given an integer array nums, return the number of reverse pairs in the array.

A reverse pair is a pair (i, j) where 0 <= i < j < nums.length and nums[i] > 2 * nums[j].

 

Example 1:

Input: nums = [1,3,2,3,1]Output: 2

Example 2:

Input: nums = [2,4,3,5,1]Output: 3

 

Constraints:

  • 1 <= nums.length <= 5 * 104
  • -231 <= nums[i] <= 231 - 1

Sol: divide and conquer

關於divide and conquer的時機,看LC327複習

這裡想記錄的是lower_bound與upper_bound到底算出什麼?

lower_bound: 第一個 大於等於(不小於) target的位置 upper_bound: 第一個 大於 target的位置

class Solution:
    def reversePairs(self, nums: List[int]) -> int:
        self.arr = nums
        return self.daq(0,len(nums))
    def daq(self,l,r):
        if (r-l) <= 1:
            return 0
        
        mid = l + (r-l) // 2
        ret = self.daq(l,mid) + self.daq(mid,r)
        
        # nums[j] < nums[i]/2 是我們要的
        # 要最右邊,也就是第一個不符合nums[j] < nums[i]/2
        # !(nums[j] < nums[i]/2) => nums[j] >= nums[i]/2
        for i in range(l,mid):
            a = bisect_left(self.arr,(self.arr[i]+1)//2,mid,r)
            ret += a-mid
        self.arr[l:r] = sorted(self.arr[l:r])
        return ret

case study: binary index tree

這解法來自General principles behind problems similar to “Reverse Pairs”

簡單總結一下裡面的內容,

  • 基本是break down into subproblems,有很多種方式,但這裡只看
    1. T(i, j) = T(i, j - 1) + C
    2. T(i, j) = T(i, m) + T(m + 1, j) + C
  • 對於第一種,因為每次都要看有多少是符合的,所以要search
    • 只用linear search會變O(n^2)
    • 如果search space是
      1. 不會變動的(static),用binary search,但是這裡每次處理完一個就要加到search space,所以不能用
      2. 變動的,用binary seach tree,平衡加入search space與search的成本,但是不會自平衡就會退化成O(n^2)
      3. (這我自己加的) 因為求的是個數,如果後面的個數會包含前面的個數就可以用bit
  • 第二種就看前面
private int search(int[] bit, int i) {
    int sum = 0;
    
    while (i < bit.length) {
        sum += bit[i];
        i += i & -i;
    }

    return sum;
}

private void insert(int[] bit, int i) {
    while (i > 0) {
        bit[i] += 1;
        i -= i & -i;
    }
}
public int reversePairs(int[] nums) {
    int res = 0;
    int[] copy = Arrays.copyOf(nums, nums.length);
    int[] bit = new int[copy.length + 1];
    
    Arrays.sort(copy);
    
    for (int ele : nums) {
        res += search(bit, index(copy, 2L * ele + 1));
        insert(bit, index(copy, ele));
    }
    
    return res;
}

private int index(int[] arr, long val) {
    int l = 0, r = arr.length - 1, m = 0;
    	
    while (l <= r) {
    	m = l + ((r - l) >> 1);
    		
    	if (arr[m] >= val) {
    	    r = m - 1;
    	} else {
    	    l = m + 1;
    	}
    }
    
    return l + 1;
}