題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。

🔗 🌈 ABC465E Digit Circus

Problem Statement

題目簡述

給定整數 NN,求 1xN1 \le x \le N 中有多少個整數恰好滿足下列三個條件中的一個:

  • xx33 的倍數。
  • xx 的十進位表示包含數字 33
  • xx 的十進位表示恰好使用了三種不同的數字。

答案對 998244353998244353 取模。

Constraints

約束條件

  • NN 為整數。
  • 1N<105001 \le N < 10^{500}

思路:數位 DP

因為 NN 最多有 500500 位,直接枚舉 11NN 顯然是不可行的。這時候可以考慮數位 DP 的方法,按位構造數字,並記錄相同狀態的結果,避免重複計算。

前置知識:數位 DP

數位 DP 的基本做法是從高位到低位填數字,並用「是否仍貼著上下界」限制當前可選範圍。這裡下界固定是 11,上界是 NN,所以每一位只需要知道目前前綴是否仍等於下界或上界前綴,以及其他狀態資訊(例如數位和、已出現數字集合等)。

額外狀態

最後要判斷三件事:是否為 33 的倍數、是否包含數字 33、是否剛好用了三種不同數字,因此除了數位DP基本的狀態外,還需要維護以下狀態:

  • 是否為 33 的倍數,需要維護目前數位和對 33 的餘數 s{0,1,2}s \in \{0, 1, 2\}
  • 是否包含數字 33,可以從「出現過的數字集合 mskmsk」判斷。
  • 是否剛好用了三種不同數字,也可以從同一個集合的大小判斷。

條件判斷

由於我們需要統計恰好滿足一個條件的數字,看似需要使用排容原理,但其實可以直接在每個數字構造完成時檢查三個條件的布林值,並計算它們的總和是否為 11。這樣就不需要額外的排容計算。

複雜度分析

  • 時間複雜度:O(10n3210)\mathcal{O}(10 \cdot n \cdot 3 \cdot 2^{10}),其中 nnNN 的位數。這是因為可記憶化的狀態由 (i,s,msk)(i, s, msk) 決定,總共有 n3210n \cdot 3 \cdot 2^{10} 種,而每個狀態最多枚舉 00991010 個數字作為下一位。
  • 空間複雜度:O(n3210)\mathcal{O}(n \cdot 3 \cdot 2^{10})memo 只需要存所有 (i,s,msk)(i, s, msk) 狀態的答案。

Code

注意 @cache 的開銷

Python 的 @cache 會把完整參數 (i,s,msk,limit_low,limit_high)(i, s, msk, limit\_low, limit\_high) 當作 key,連貼著上下界的一次性狀態也一起快取,而且 dict 的雜湊存取常數較大。本題 nn 可達 500500,直接套 @cache 會 TLE。

解決方式是用三維 list 手動維護 memo,並只在 not limit_low and not limit_high 時寫入——此時答案只由 (i,s,msk)(i, s, msk) 決定,用 index 直接存取比 dict 快很多,常數與記憶體開銷都大幅降低。

為什麼只有 not limit_low and not limit_high 時才記憶化?

記憶化的前提是:之後還有可能再次走到同一個狀態,這樣才有重用的價值。數位 DP 的本質是逐位枚舉填入什麼數字,而 lowhigh 分別對應目前枚舉區間的下界與上界。

limit_lowlimit_hightrue 時,目前前綴正貼著某個邊界。這種狀態只會沿著邊界路徑走一次——例如「前綴恰好等於 low」或「前綴恰好等於 high」——之後不會再有第二條路徑回到這裡,存進 memo 也用不上。

只有當兩個限制都解除後,後續的每一位才能自由填 0099,此時遞迴結果才真正只由 (i,s,msk)(i, s, msk) 決定,記憶化才有重用價值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
MOD = 998244353


def solve():
r = input().strip()
l = 1

n = len(r)
diff = n - len(str(l))
high = list(map(int, str(r)))
low = list(map(int, str(l).zfill(n))) # 補前導零,使 low 和 high 對齊

memo = [[[-1] * (1 << 10) for _ in range(3)] for _ in range(n)]

# @cache
def dfs(i: int, s: int, msk: int, limit_low: bool, limit_high: bool) -> int:
if i == n:
return ((s == 0) + (msk.bit_count() == 3) + ((msk >> 3) & 1)) == 1

if not limit_low and not limit_high and memo[i][s][msk] != -1:
return memo[i][s][msk]

# 第 i 個數位可以從 lo 枚舉到 hi
# 如果對數位還有其它約束,應該只在下面的 for 迴圈做限制,不應修改 lo 或 hi
st = lo = low[i] if limit_low else 0
hi = high[i] if limit_high else 9

res = 0
if i < diff and limit_low:
res += dfs(i + 1, 0, 0, True, False) # 前導 0
st = 1
for d in range(st, hi + 1):
res += dfs(i + 1, (s + d) % 3, msk | (1 << d), limit_low and d == lo, limit_high and d == hi)
if not limit_low and not limit_high:
memo[i][s][msk] = res % MOD
return res % MOD

ans = dfs(0, 0, 0, True, True)
print(ans)


if __name__ == "__main__":
solve()