雖然方法一的 樹形DP
比較直覺,但在定義狀態跟轉移方程式時會比較複雜,看了好幾次才懂;知道結論後用方法二的 狀態機DP
會容易許多。本篇文章在敘述上可能有許多不完整之處,建議閱讀參考資料。
題意
有一棵包含 n 個節點的無向樹,其中節點編號從 0 到 n−1。並給定一個長度為 n−1 的二維整數陣列 edges,其中 edges[i]=[ui,vi] 表示樹中節點 ui 和 vi 之間有一條邊。同時給定一個正整數 k,以及一個長度為 n 的非負整數陣列 nums,其中 nums[i] 表示節點 i 的值。
你的目標是 最大化 樹中所有節點值之和。為了實現這一目標,你可以執行以下操作 任意 次(包括零次):
- 選擇任何一條連接節點 u 和 v 的邊 [u,v] ,並將 u 和 v 的值更新為:
- nums[u]=nums[u]⊕k
- nums[v]=nums[v]⊕k
返回通過執行以上操作 任意次 後,可以得到的 最大 可能總和。
思路
首先整理狀態,雖然每個點可以被操作任意次,但是由於 XOR
的性質 x⊕x=0 和 x⊕0=x ,所以每個點的其實只有兩種狀態:操作偶數次 、操作奇數次 。
此外,由於是無向樹,所以任何點都可以當作根節點 ,因此我們可以直接假定 0 為根節點。
方法一:樹形DP
因此,這題就變成了「選」或「不選」每條邊的動態規劃(DP)問題。我們可以用 樹形DP
的方式來解決這個問題。
- 定義 f(u,0) 和 f(u,1) 分別表示在節點 u 操作偶數次和奇數次時,其子樹 (不包括 u) 的最大值。
- 注意:這裡的 f(u,0) 和 f(u,1) 並 不等於 dfs(u,fa,0) 和 dfs(u,fa,1),後者的 0 和 1 分別表示不操作 (u,fa) 和操作 (u,fa) 這條邊,所能得到的 u 的子樹(包含 u)的最大值。
- 初始化 f(u,0)=0 和 f(u,1)=−∞ ,因為一開始不考慮任何邊時, u 無法被操作,故用 −∞ 來表示不合法的情況。
接著考慮 (u,fa) 這條邊是否被操作,所能得到的 u 的子樹 (包含 u) 的最大值,也就是 dfs(u,fa,0) 和 dfs(u,fa,1) 的值,0 和 1 分別表示不操作 (u,fa) 和操作 (u,fa) 這條邊,可以由 f(u,0) 和 f(u,1) 的值得到:
- dfs(u,fa,0)=max(f(u,0)+nums[u],f(u,1)+(nums[u]⊕k)),表示不操作 (u,fa) 這條邊。
- 由於 f(u,0) 表示在 u 的子樹中, u 操作偶數次時,不包含 u 的最大值;而 f(u,1) 表示在 u 的子樹中, u 操作奇數次時,不包含 u 的最大值。
- 所以在不操作 (u,fa) 這條邊時,直接從 f(u,0)+nums[u] 和 f(u,1)+(nums[u]⊕k) 中取最大值即可。
- dfs(u,fa,1)=max(f(u,1)+nums[v],f(u,0)+(nums[u]⊕k)),表示操作 (u,fa) 這條邊。
- 在 u 的子樹中操作奇數次,加上操作 (u,fa) 這條邊後,等同 u 操作偶數次,故貢獻為 nums[u] ;反之若在 u 的子樹中操作偶數次,加上操作 (u,fa) 這條邊後,等同 u 操作奇數次,故貢獻為 nums[u]⊕k。
- 所以在操作 (u,fa) 這條邊時,從 f(u,1)+nums[v] 和 f(u,0)+(nums[u]⊕k) 中取最大值即可。
則 f(u,0) 和 f(u,1) 的轉移方程式如下:
- f(u,0)=max(f(u,0)+dfs(v,u,0),f(u,1)+dfs(v,u,1)),表示不操作 (u,v) 或操作 (u,v)。
- f(u,1)=max(f(u,1)+dfs(v,u,0),f(u,0)+dfs(v,u,1)),表示不操作 (u,v) 或操作 (u,v)。
在寫出所有遞迴關係後,我們可以用 DFS 的方式來遍歷整棵樹,並在遍歷的過程中更新 f(u,0) 和 f(u,1) 的值, dfs 函數返回 dfs(u,fa,0) 和 dfs(u,fa,1) 的值。
由於是無向樹,可以令根節點為 0,設其父節點為 −1 。由於根節點沒有父節點,顯然無法操作 (0,−1) 這條邊,故最後返回的值需為 dfs(0,−1,0) ,而 dfs(0,−1,1) 為非法的情況。
複雜度分析
- 時間複雜度:O(n),其中 n 為節點數。
- 空間複雜度:O(n)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| class Solution: def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int: n = len(nums) g = [[] for _ in range(n)] for u, v in edges: g[u].append(v) g[v].append(u) def dfs(u: int, fa: int) -> Tuple[int, int]: f0, f1 = 0, -float('inf') for v in g[u]: if v == fa: continue r0, r1 = dfs(v, u) t0, t1 = f0, f1 f0 = max(t0 + r0, t1 + r1) f1 = max(t0 + r1, t1 + r0) return max(f0 + nums[u], f1 + (nums[u] ^ k)), max(f1 + nums[u], f0 + (nums[u] ^ k)) return dfs(0, -1)[0]
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
| using LL = long long; class Solution { public: LL maximumValueSum(vector<int>& nums, int k, vector<vector<int>>& edges) { int n = nums.size(); vector<vector<int>> g(n); for (auto &e : edges) { g[e[0]].push_back(e[1]); g[e[1]].push_back(e[0]); } function<pair<LL, LL>(int, int)> dfs = [&](int u, int fa) -> pair<LL, LL> { LL f0 = 0, f1 = LLONG_MIN; for (int v : g[u]) { if (v == fa) continue; pair<LL, LL> p = dfs(v, u); LL r0 = p.first, r1 = p.second; LL t0 = f0, t1 = f1; f0 = max(t0 + r0, t1 + r1); f1 = max(t0 + r1, t1 + r0); } return make_pair(max(f0 + nums[u], f1 + (nums[u] ^ k)), max(f1 + nums[u], f0 + (nums[u] ^ k))); }; return dfs(0, -1).first; } };
|
方法二:狀態機DP
對於樹上的任兩點 u,v,樹上兩點間必存在一條路徑,而沿著這條路徑操作,可以將 u 和 v 之間的所有點(除了 u 和 v 之外)的操作次數 +2 ,但由於 XOR
的性質,操作 +2 次和操作 0 次是等價的,只有 u 和 v 會被真正操作各 1 次。故其實可以直接選兩點操作,不用建圖 。
令 dp[i][0/1] 表示前 i 個點操作偶數/奇數次時的最大值,則轉移方程式如下:
- dp[i][0]=max(dp[i−1][0]+nums[i],dp[i−1][1]+(nums[i]⊕k))
前 i 個點操作偶數次,可以從前 i−1 個點操作偶數次且當前點不操作,或者前 i−1 個點操作奇數次且當前點操作轉移而來,兩者取最大值。
- dp[i][1]=max(dp[i−1][1]+nums[i],dp[i−1][0]+(nums[i]⊕k))
前 i 個點操作奇數次,可以從前 i−1 個點操作奇數次且當前點不操作,或者前 i−1 個點操作偶數次且當前點操作轉移而來,兩者取最大值。
複雜度分析
- 時間複雜度:O(n),其中 n 為節點數,也可以說是 nums 的長度。
- 空間複雜度:O(n)。
1 2 3 4 5 6 7 8
| class Solution: def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int: n = len(nums) dp = [[0, -float("inf")] for _ in range(n + 1)] for i, x in enumerate(nums): dp[i+1][0] = max(dp[i][0] + x, dp[i][1] + (x ^ k)) dp[i+1][1] = max(dp[i][1] + x, dp[i][0] + (x ^ k)) return dp[n][0]
|
1 2 3 4 5 6 7 8 9 10 11 12 13
| using LL = long long; class Solution { public: LL maximumValueSum(vector<int>& nums, int k, vector<vector<int>>& edges) { int n = nums.size(); vector<vector<LL>> dp(n + 1, {0, LLONG_MIN}); for (int i = 0; i < n; i++) { dp[i + 1][0] = max(dp[i][0] + nums[i], dp[i][1] + (nums[i] ^ k)); dp[i + 1][1] = max(dp[i][1] + nums[i], dp[i][0] + (nums[i] ^ k)); } return dp[n][0]; } };
|
方法二:狀態機DP,空間優化
由於轉移只和前一個狀態有關,故可以進一步優化空間為 O(1) 。
複雜度分析
- 時間複雜度:O(n)。
- 空間複雜度:O(1)。
1 2 3 4 5 6
| class Solution: def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int: f0, f1 = 0, -float('inf') for x in nums: f0, f1 = max(f0 + x, f1 + (x ^ k)), max(f1 + x, f0 + (x ^ k)) return f0
|
1 2 3 4 5 6 7 8 9 10 11 12 13
| using LL = long long; class Solution { public: LL maximumValueSum(vector<int>& nums, int k, vector<vector<int>>& edges) { LL f0 = 0, f1 = LLONG_MIN; for (int x : nums) { LL t0 = f0, t1 = f1; f0 = max(t0 + x, t1 + (x ^ k)); f1 = max(t1 + x, t0 + (x ^ k)); } return f0; } };
|
參考資料
寫在最後
Cover photo is generated by @たろたろ, thanks for their work!