題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。
Problem Statement
題目簡述
給定兩個長度為 n 的整數陣列 A 和 B。我們可以任意重排 B。
你的目標是最大化 P=i=1∏nmin(Ai,Bi)。
題目共有 q 次修改操作,每次操作會給定 (o,x),若 o=1 則將 Ax 加 1,若 o=2 則將 Bx 加 1。
請輸出初始狀態以及每次修改後,可以得到的最大 P 值(對 998244353 取模)。
Constraints
約束條件
- 1≤n,q≤2⋅105
- ∑n,∑q≤4⋅105
- 1≤ai,bi≤5⋅108
思路:貪心 + 排序 + 二分查找
貪心:排序後配對最優
將 A 和 B 分別從小到大排序後,設排序後的陣列為 C 和 D,則最大乘積為:
P=i=1∏nmin(Ci,Di)
為什麼排序後配對最優?
假設 A,B 排序後分別為 C,D。若存在逆序對 (i,j),即 Ci≤Cj 但 Di>Dj。
令 a=Ci,b=Cj(故 a≤b),d=Di,c=Dj(故 c<d)。比較兩種配對:
- 交換前(逆序):Vpre=min(a,d)×min(b,c)
- 交換後(同序):Vpost=min(a,c)×min(b,d)
此時 d>c≥a,故 min(a,d)=a,min(a,c)=a。
- Vpre=a×min(b,c)
- Vpost=a×min(b,d)
因為 c<d,所以 min(b,c)≤min(b,d)。故 Vpost≥Vpre。
此時 c<a≤b,故 min(b,c)=c,min(a,c)=c。
- Vpre=min(a,d)×c
- Vpost=c×min(b,d)
因為 a≤b,所以 min(a,d)≤min(b,d)。故 Vpost≥Vpre。
綜上,將逆序對交換成同序後,乘積不會變小。反覆交換直到兩陣列同序即得最優解。
二分搜尋維護動態更新
由於題目涉及 q 次單點增加(每次 +1),我們不能每次重新排序(O(nlogn) 太慢)。需要在 O(logn) 時間內維護排序數組和答案。
當原陣列 Ai 的值 v 增加 1 時,相當於將排序陣列 C 中任意一個值為 v 的位置變成 v+1。為了保持 C 的有序性,我們選擇 C 中最右邊的那個 v 進行增加:
- 設 idx 為 C 中值為 v 的最後一個元素位置
- 將 C[idx] 從 v 更新為 v+1
- 由於 C[idx+1]≥v+1(若存在),更新後 C[idx]≤C[idx+1],陣列依然有序
由於 C 是有序的,使用二分搜尋即可在 O(logn) 內找到該位置。
利用模意義下的乘法反元素(逆元)更新乘積
我們維護當前的總乘積 P。更新 C[idx] 時,只有當它是瓶頸時,乘積才會改變:
- 若 C[idx]<D[idx]:該位置對乘積的貢獻從 C[idx] 變為 C[idx]+1
P′=P×C[idx](C[idx]+1)
- 若 C[idx]≥D[idx]:該位置貢獻為 D[idx],更新 C[idx] 不影響最小值,P 不變
由於 MOD=998244353 是質數,根據費馬小定理:
a−1≡aMOD−2(modMOD)
在 Python 中可以直接用 pow(B[x], -1, MOD) 計算。
這樣每次查詢只需 O(logn)(二分搜尋)+ O(logMOD)(快速冪求乘法反元素)。
- 維護兩個排序數組
sl1(對應 C)和 sl2(對應 D)
- 同時保留原陣列 A,B 以得知每次操作增加的具體數值
- 操作 Ax 時:取 v=A[x],找
bisect_right(sl1, v) - 1 作為更新點,判斷是否需更新 prod,然後將 sl1[idx] 和 A[x] 都加 1
複雜度分析
- 時間複雜度:O(nlogn+qlogn)
- 初始排序 O(nlogn)
- 每次查詢:二分搜尋 O(logn),乘法反元素 O(logMOD)
- 空間複雜度:O(n),用於存儲排序陣列
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
| from bisect import bisect_right
MOD = 998244353
def solve(): n, q = map(int, input().split()) A = list(map(int, input().split())) B = list(map(int, input().split())) assert len(A) == n and len(B) == n
sl1 = sorted(A) sl2 = sorted(B)
prod = 1 for x, y in zip(sl1, sl2): prod = (prod * min(x, y)) % MOD ans = [prod] for _ in range(q): o, x = map(int, input().split()) x -= 1 if o == 1: idx = bisect_right(sl1, A[x]) - 1 if sl1[idx] < sl2[idx]: prod = (prod * pow(A[x], MOD - 2, MOD)) % MOD prod = (prod * (A[x] + 1)) % MOD sl1[idx] += 1 A[x] += 1 else: idx = bisect_right(sl2, B[x]) - 1 if sl2[idx] < sl1[idx]: prod = (prod * pow(B[x], -1, MOD)) % MOD prod = (prod * (B[x] + 1)) % MOD sl2[idx] += 1 B[x] += 1 ans.append(prod) print(*ans)
if __name__ == "__main__": t = int(input()) for _ in range(t): solve()
|
寫在最後
PROMPT