題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。

🔗 🟢 AT_dp_j Sushi

Problem Statement

題目簡述

NN 個盤子,每個盤子上最初放有 ai{1,2,3}a_i \in \{1, 2, 3\} 個壽司。
每次操作擲一枚 1N1 \sim N 的公平骰子,若骰到的盤子上有壽司,則吃掉其中一個;若沒有則不進行任何動作。
求將所有壽司吃完所需的期望操作次數。

Constraints

約束條件

  • 1N3001 \le N \le 300
  • 1ai31 \le a_i \le 3

思路:期望值 DP

這是一道典型的期望值 DP 問題。

狀態壓縮的關鍵觀察

由於每個盤子的壽司數量只有 1,2,31, 2, 3 三種,我們不需要記錄每個盤子的具體狀態,只需記錄各類盤子的數量即可。

狀態定義

f(i,j,k)f(i, j, k) 表示當前盤面上有 ii 個剩 1 個壽司的盤子、jj 個剩 2 個壽司的盤子、kk 個剩 3 個壽司的盤子時,吃完所有壽司所需的期望操作次數

  • 目標:求 f(cnt1,cnt2,cnt3)f(\text{cnt}_1, \text{cnt}_2, \text{cnt}_3),其中 cntx\text{cnt}_x 為初始時有 xx 個壽司的盤子數量。
  • 邊界f(0,0,0)=0f(0, 0, 0) = 0(沒有壽司時不需操作)。

轉移方程

在狀態 (i,j,k)(i, j, k) 下擲一次骰子,根據結果分為四種情況:

結果 機率 後續狀態
骰到空盤子 NijkN\dfrac{N - i - j - k}{N} 狀態不變
骰到剩 1 個壽司的盤子 iN\dfrac{i}{N} (i1,j,k)(i-1, j, k)
骰到剩 2 個壽司的盤子 jN\dfrac{j}{N} (i+1,j1,k)(i+1, j-1, k)
骰到剩 3 個壽司的盤子 kN\dfrac{k}{N} (i,j+1,k1)(i, j+1, k-1)

根據期望的定義,每次操作消耗 1 步:

f(i,j,k)=1+NijkNf(i,j,k)+iNf(i1,j,k)+jNf(i+1,j1,k)+kNf(i,j+1,k1)f(i,j,k) = 1 + \frac{N-i-j-k}{N} f(i,j,k) + \frac{i}{N} f(i-1,j,k) + \frac{j}{N} f(i+1,j-1,k) + \frac{k}{N} f(i,j+1,k-1)

移項化簡

將含 f(i,j,k)f(i,j,k) 的項移至左邊:

f(i,j,k)i+j+kN=1+iNf(i1,j,k)+jNf(i+1,j1,k)+kNf(i,j+1,k1)f(i,j,k) \cdot \frac{i+j+k}{N} = 1 + \frac{i}{N} f(i-1,j,k) + \frac{j}{N} f(i+1,j-1,k) + \frac{k}{N} f(i,j+1,k-1)

兩邊同乘 NN 再除以 (i+j+k)(i+j+k),得到最終遞推式:

f(i,j,k)=N+if(i1,j,k)+jf(i+1,j1,k)+kf(i,j+1,k1)i+j+k\boxed{f(i,j,k) = \frac{N + i \cdot f(i-1,j,k) + j \cdot f(i+1,j-1,k) + k \cdot f(i,j+1,k-1)}{i+j+k}}

實現細節

迭代順序

計算 f(i,j,k)f(i, j, k) 時需要用到三個子問題:

依賴項 與當前狀態的關係 如何確保已計算
f(i1,j,k)f(i-1, j, k) ii 減少,j,kj, k 不變 ii 從小到大迭代
f(i+1,j1,k)f(i+1, j-1, k) jj 減少(ii 增加無妨) jj 從小到大迭代,同層 ii 全部算完
f(i,j+1,k1)f(i, j+1, k-1) kk 減少(jj 增加無妨) kk 從小到大迭代,同層 jj 全部算完

關鍵觀察:雖然後兩項中 iijj 會增加,但對應的 jjkk 會減少。由於我們按 kjik \to j \to i 的順序從小到大迭代,當計算 (i,j,k)(i, j, k) 時:

  • 所有 k<kk' < k 的狀態都已算完 → f(i,j+1,k1)f(i, j+1, k-1)
  • 在同一 kk 下,所有 j<jj' < j 的狀態都已算完 → f(i+1,j1,k)f(i+1, j-1, k)
  • 在同一 (j,k)(j, k) 下,所有 i<ii' < i 的狀態都已算完 → f(i1,j,k)f(i-1, j, k)
迭代範圍優化

實現中使用 k in range(cnt[3] + 1) 而非 range(n + 1),這是一種空間剪枝優化:

  • 從初始狀態 (cnt1,cnt2,cnt3)(cnt_1, cnt_2, cnt_3) 出發,kk 最多為 cnt3cnt_3(不可能憑空產生有 3 個壽司的盤子)
    同理,jcnt2+cnt3j \le cnt_2 + cnt_3icnt1+cnt2+cnt3i \le cnt_1 + cnt_2 + cnt_3
  • 此外,當有 kk 個盤子剩 3 個壽司時,不可能有 j>nkj > n - k 個盤子剩 2 個壽司,因為 j+knj + k \le n(總盤子數不變)
    同理,不可能有 i>njki > n - j - k 個盤子剩 1 個壽司。

這樣只計算可達狀態,避免浪費計算資源

複雜度分析

  • 時間複雜度O(N3)\mathcal{O}(N^3)
    • 狀態數為滿足 i+j+kNi + j + k \le N 的非負整數三元組數量。
    • =Nijk\ell = N - i - j - k(空盤子數),則等價於求 i+j+k+=Ni + j + k + \ell = N 的非負整數解數,即重複組合數 HN4=(N+33)N36H^{4}_{N} = \binom{N+3}{3} \approx \frac{N^3}{6}
    • 每個狀態的轉移為 O(1)O(1)
  • 空間複雜度O(N3)\mathcal{O}(N^3),用於存儲 DP 表。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def solve():
n = int(input())
A = list(map(int, input().split()))
assert len(A) == n

cnt = [0] * 4
for a in A:
cnt[a] += 1

f = [[[0] * (n + 1) for _ in range(n + 1)] for _ in range(n + 1)]
# for k in range(n + 1):
# for j in range(n - k + 1):
# for i in range(n - k - j + 1):
s1 = cnt[1] + cnt[2] + cnt[3]
s2 = cnt[2] + cnt[3]
for k in range(cnt[3] + 1):
for j in range(min(s2, n - k) + 1):
for i in range(min(s1, n - k - j) + 1):
if i == 0 and j == 0 and k == 0:
continue
v = n
if i > 0:
v += i * f[i-1][j][k]
if j > 0:
v += j * f[i+1][j-1][k]
if k > 0:
v += k * f[i][j+1][k-1]
f[i][j][k] = v / (i + j + k)
print(f[cnt[1]][cnt[2]][cnt[3]])

if __name__ == "__main__":
solve()