題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。

🔗 🟢 AT_dp_p Independent Set

Problem Statement

題目簡述

給定一棵 NN 個節點的樹,對每個節點塗黑色或白色,要求相鄰節點不能同時塗黑色。求合法塗色方案數,答案對 109+710^9 + 7 取模。

Constraints

約束條件

  • 1N1051 \leq N \leq 10^5
  • 1xi,yiN1 \leq x_i, y_i \leq N
  • 給定的圖保證是一棵樹

思路:樹形 DP

本題是經典的樹上獨立集計數問題。相鄰節點不能同時塗黑,等價於黑色節點構成樹的一個獨立集。

狀態定義

f[u][c]f[u][c] 表示以 uu 為根的子樹中,節點 uu 塗成顏色 cc 時的合法方案數:

  • f[u][0]f[u][0]:節點 uu白色(不選入獨立集)
  • f[u][1]f[u][1]:節點 uu黑色(選入獨立集)

狀態轉移

對於節點 uu 的每個子節點 vv

  • uu 塗黑色vv 必須塗白色(相鄰不能同黑)

    f[u][1]=vchildren(u)f[v][0]f[u][1] = \prod_{v \in \text{children}(u)} f[v][0]

  • uu 塗白色vv 可以塗黑色或白色

    f[u][0]=vchildren(u)(f[v][0]+f[v][1])f[u][0] = \prod_{v \in \text{children}(u)} (f[v][0] + f[v][1])

乘法原理

各子樹的塗色方案互相獨立,因此總方案數為各子樹方案數的乘積。

初始條件與答案

  • 初始條件:對於葉節點,f[u][0]=f[u][1]=1f[u][0] = f[u][1] = 1
  • 最終答案f[root][0]+f[root][1]f[\text{root}][0] + f[\text{root}][1]
根的選擇

可以任選一個節點作為根,答案不變。因為樹是無環連通圖,從任意節點出發都能遍歷整棵樹;而 DP 最終統計的是整棵樹的合法方案數,與「從哪個節點開始看」無關。

實現方式

程式碼採用 BFS 拓撲序 + 逆序迭代 的方式避免遞迴深度過大的問題:

  1. 從根節點 BFS 建立拓撲序
  2. 逆序遍歷(從葉子到根),確保處理 uu 時所有子節點都已計算完畢

複雜度分析

  • 時間複雜度:O(N)\mathcal{O}(N),每個節點和邊各遍歷常數次
  • 空間複雜度:O(N)\mathcal{O}(N),儲存鄰接表、父節點陣列和 DP 陣列

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from collections import deque

MOD = int(1e9 + 7)

def solve():
N = int(input())
g = [[] for _ in range(N)]
for _ in range(N - 1):
u, v = map(lambda x: int(x) - 1, input().split())
g[u].append(v)
g[v].append(u)

# @cache
# def dfs(u: int, fa: int) -> list[int]:
# res = [1, 1]
# for v in g[u]:
# if v == fa:
# continue
# white, black = dfs(v, u)
# res[1] = (res[1] * white) % MOD
# res[0] = (res[0] * (white + black)) % MOD
# return res
# print(sum(dfs(0, -1)) % MOD)

fa = [-1] * N
order = []
q = deque([0])
while q:
u = q.popleft()
order.append(u)
for v in g[u]:
if v == fa[u]:
continue
fa[v] = u
q.append(v)

# f[u][0/1]: u 塗白/黑色的方案數
f = [[1, 1] for _ in range(N)]
for u in reversed(order):
for v in g[u]:
if v == fa[u]:
continue
f[u][1] = (f[u][1] * f[v][0]) % MOD
f[u][0] = (f[u][0] * (f[v][0] + f[v][1])) % MOD
print(sum(f[0]) % MOD)

if __name__ == "__main__":
solve()