動機

在dp中看到mutual recursion

這裡不管效率,都是用top-down的方式寫,除了最終版本

這應該是繼burst ballon之後看過最有趣的DP了

Problem

You are given an integer array prices where prices[i] is the price of a given stock on the ith day, and an integer k.

Find the maximum profit you can achieve. You may complete at most k transactions.

Note: You may not engage in multiple transactions simultaneously (i.e., you must sell the stock before you buy again).

 

Example 1:

Input: k = 2, prices = [2,4,1]Output: 2Explanation: Buy on day 1 (price = 2) and sell on day 2 (price = 4), profit = 4-2 = 2.

Example 2:

Input: k = 2, prices = [3,2,6,5,0,3]Output: 7Explanation: Buy on day 2 (price = 2) and sell on day 3 (price = 6), profit = 6-2 = 4. Then buy on day 5 (price = 0) and sell on day 6 (price = 3), profit = 3-0 = 3.

 

Constraints:

  • 0 <= k <= 100
  • 0 <= prices.length <= 1000
  • 0 <= prices[i] <= 1000

Sol Ver1: DP O(N^2) (TLE)

定義f(i,k)為 在第i天還有k次交易額度下的最大利潤

這樣每次都要往[0,i)去找最大利潤,有沒有方法把這個壓成常數

class Solution:
    @functools.cache
    def f(self, i, k):
        if k == 0 or i <= 0:
            return 0
        else:
            return max([max(self.f(j,k-1)+max(0,self.ps[i]-self.ps[j]),self.f(j,k))  for j in range(i)])
    def maxProfit(self, k: int, prices: List[int]) -> int:
        self.ps = prices
        return self.f(len(prices)-1, k)

Sol Ver2: DP O(N*k)

起終點分離

如果每次都要算出每段的利潤,自然要每個點都要試

但是如果把每段拆成負的起點與終點的加總,就不用每個點都試

能拆的理由是每一段不會overlap,一段一定是發生在一段之後

Ver1

定義 f(i,k,can_buy)為 在第i天還有k次交易額度 can_buy(可以買股票) 下的最大收益

    @functools.cache
    def f(self, i, k, can_buy):
        if k == 0 or i == 0:
            return 0 if not can_buy else -self.ps[i]
        elif i < 0 or k < 0:
            return -math.inf
        else:
            if can_buy:
                return max(self.f(i-1, k, True), self.f(i-1, k-1, False) - self.ps[i])
            else:
                return max(self.f(i-1, k, False), self.f(i-1, k, True) + self.ps[i])
    def maxProfit(self, k: int, prices: List[int]) -> int:
        if not prices or k is 0:
            return 0
        self.ps = prices
        return max([self.f(len(prices)-1,k,False) for j in range(k)])

Ver2

不可以買股票 可以當成要算 售出 的最大收益

可以買股票 可以當成要算 購入 的最大收益

故可以拆成兩個function做mutual recursion

    @functools.cache
    def buy(self,i,k):
        #print("buy", i, k)
        if k <= 0 or i < 0:
            ret = -math.inf
        else:
            ret = max(self.buy(i-1,k), self.sell(i-1,k-1)-self.ps[i])
        #print("buy", i, k, ret)
        return ret
    @functools.cache
    def sell(self,i,k):
        #print("sell", i, k)
        if i <= 0 or k <= 0:
            ret = 0 
        else:
            ret = max(self.sell(i-1,k), self.buy(i,k)+self.ps[i])
        #print("sell", i, k, ret)
        return ret
    def maxProfit(self, k: int, prices: List[int]) -> int:
        if not prices or k is 0:
            return 0
        if k*2 > len(prices):
            return self.infTimes(prices)
        self.ps = prices
        return self.sell(len(prices)-1, k)

挑區段

a~b + b~c = a~c 所以可以一直把區段接起來,直到最大 用這種方式來看,可以把k看成 最多可以挑(保留)幾次到目前為止已經接好的區段

定義 localF 為 在第i天還有k次保留機會時接上i-1~i的最大收益 定義 globalF 為 在第i天還有k次保留機會時的最大收益

    @functools.cache
    def localF(self,i,k):
        if i == 0 and k == 0:
            return 0 
        elif k == 0:
            return -self.ps[i]
        elif i == 0:
            return 0 
        else:
            return max(self.globalF(i-1,k-1), self.localF(i-1,k))+(self.ps[i]-self.ps[i-1])
    @functools.cache
    def globalF(self,i,k):
        if i < 0 or k < 0:
            return 0
        else:
            return max(self.globalF(i-1,k), self.localF(i,k))
    def maxProfit(self, k: int, prices: List[int]) -> int:
        self.ps = prices
        return self.globalF(len(prices)-1,k)

湊區段 (最終版本)

到現在,其實這題的重點是在第i天的最大收益是可以來自

  • 同一個前一天的區段 (可以接)
  • 歷史最大收益 (可能是可以接,也可能是不能接)

所以可以用k一次一次擴大區段的嘗試

定義 val 為

  • 前一天的最大收益(不能接)或連續區段的最大收益(可以接) 加上 當前的區段 (當次) 或是
  • 歷史上當天的最大收益 (前幾次的)

之間的最大值

定義 pnl 為 歷史上第i天的最大收益

class Solution:
    def maxProfit(self, k: int, prices: List[int]) -> int:
        if 2*k >= len(prices): 
            return sum(max(0, prices[i]-prices[i-1]) for i in range(1, len(prices)))
        
        pnl = [0]*len(prices)
        for _ in range(k):
            val = 0
            for i in range(1, len(pnl)): 
                val = max(pnl[i], val + prices[i] - prices[i-1]) # val is pnl[i-1] or val
                pnl[i] = max(pnl[i-1], val)
        return pnl[-1]