🔗 🟡 1442. Count Triplets That Can Form Two Arrays of Equal XOR 1525

tags: Weekly Contest 188 前綴和(Prefix Sum) 雜湊表(Hash Table) 位運算(Bit Manipulation)

題意

給定一個整數陣列 arrarr

從陣列中選擇三個索引 iijjkk,其中 (0i<jk<arr.length)(0 \leq i < j \leq k < \text{arr.length})

定義 aabb 如下:

  • a=arr[i]arr[i+1]arr[j1]a = arr[i] \oplus arr[i + 1] \oplus \ldots \oplus arr[j - 1]
  • b=arr[j]arr[j+1]arr[k]b = arr[j] \oplus arr[j + 1] \oplus \ldots \oplus arr[k]

返回滿足 a=ba = b 的三元組 (i,j,k)(i, j, k) 的數量。

限制

  • 1arr.length3001 \leq arr.length \leq 300
  • 1arr[i]1081 \leq arr[i] \leq 10^8

思路:前綴和(Prefix Sum)

首先,可以注意到 aabb 都是子陣列的 XOR 和,所以可以使用 前綴和(Prefix Sum) 的方式來預處理 arrarr ,使得可以在 O(1)O(1) 的時間內計算 aabb

pre[i]pre[i] 表示 arr[0]arr[1]arr[i1]arr[0] \oplus arr[1] \oplus \ldots \oplus arr[i - 1] ,則

  • a=pre[j]pre[i]a = pre[j] \oplus pre[i]
  • b=pre[k+1]pre[j]b = pre[k + 1] \oplus pre[j]

方法一:三重迴圈

根據題意,枚舉三個下標 iijjkk,其中 i<jki < j \leq k,並檢查 a=ba = b 是否成立,若成立則累加答案即可。

由於 n300n \leq 300,所以 O(n3)O(n^3) 的時間複雜度是可以通過的。

而在檢查 pre[i]pre[j]=pre[j]pre[k+1]pre[i] \oplus pre[j] = pre[j] \oplus pre[k + 1] 時,可以發現等號兩邊都有 pre[j]pre[j],因此在做判斷的時候,可以直接比較 pre[i]pre[i]pre[k+1]pre[k + 1] 是否相等即可。這個性質可以幫助我們優化時間複雜度,會在之後的方法中提到。

複雜度分析

  • 時間複雜度 O(n3)O(n^3)
  • 空間複雜度 O(n)O(n)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Solution:
def countTriplets(self, arr: List[int]) -> int:
n = len(arr)
pre = [0] * (n + 1) # Prefix Sum
for i, x in enumerate(arr):
pre[i + 1] = pre[i] ^ x
ans = 0
for i in range(n): # i < j <= k
for j in range(i + 1, n):
for k in range(j, n):
# if pre[i] ^ pre[j] == pre[j] ^ pre[k + 1]:
if pre[i] == pre[k + 1]: # pre[j] is not important
ans += 1
return ans
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
public:
int n;
vector<int> pre;
int countTriplets(vector<int>& arr) {
n = arr.size();
pre = vector<int>(n + 1);
for (int i = 0; i < n; i++) pre[i + 1] = pre[i] ^ arr[i];
return solve1(arr);
}
int solve1(vector<int>& arr) {
int ans = 0;
for (int i = 0; i < n; i++) // i < j <= k
for (int j = i+1; j < n; j++)
for (int k = j; k < n; k++)
// if ((pre[i] ^ pre[j]) == (pre[j] ^ pre[k+1])) ans++;
if (pre[i] == pre[k+1]) ans++;
return ans;
}
};

方法二:二重迴圈

從方法一中,我們注意到了在判斷 a=ba = b 的時候,與 jj 有關的部分可以省略,因此可以將三重迴圈的 jj 部分省略,只保留兩重迴圈。

再來需要考慮的是當 pre[i]==pre[k+1]pre[i] == pre[k + 1] 時,有多少組 (i,j,k)(i, j, k) 滿足條件。這裡可以從方法一中思考,此時滿足條件的 (i,j,k)(i, j, k)(i,i+1,k),(i,i+2,k),...,(i,k,k)(i, i+1, k), (i, i+2, k), ..., (i, k, k),總共有 kik - i 種組合,因此可以直接計算 kik - i 並累加到答案中。

複雜度分析

  • 時間複雜度 O(n2)O(n^2)
  • 空間複雜度 O(n)O(n)
1
2
3
4
5
6
7
8
9
10
11
12
class Solution:
def countTriplets(self, arr: List[int]) -> int:
n = len(arr)
pre = [0] * (n + 1) # Prefix Sum
for i, x in enumerate(arr):
pre[i + 1] = pre[i] ^ x
ans = 0
for i in range(n): # i < j <= k
for k in range(i + 1, n):
if pre[i] == pre[k + 1]: # j is not important
ans += k - i # (i, i+1, k), (i, i+2, k), ..., (i, k, k)
return ans
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution {
public:
int n;
vector<int> pre;
int countTriplets(vector<int>& arr) {
n = arr.size();
pre = vector<int>(n + 1);
for (int i = 0; i < n; i++) pre[i + 1] = pre[i] ^ arr[i];
return solve2(arr);
}
int solve2(vector<int>& arr) {
int ans = 0;
for (int i = 0; i < n; i++) // i < j <= k
for (int k = i+1; k < n; k++)
if (pre[i] == pre[k+1]) ans += k - i; // (i, i+1, k), (i, i+2, k), ..., (i, k, k)
return ans;
}
};

方法三:雜湊表(Hash Table)

由於能夠滿足條件的 (i,j,k)(i, j, k) 皆滿足 pre[i]==pre[k+1]pre[i] == pre[k + 1] ,對於每個 kk ,我們只在乎能夠使得 pre[i]==pre[k+1]pre[i] == pre[k + 1]ii 有哪些,因此我們可以建立一個雜湊表 pospos ,保存每個 pre[i]pre[i] 出現的位置。

接著使用一重迴圈枚舉 kk,對於每個 kk,檢查 pre[k+1]pre[k + 1] 是否在 pospos 中,若在則 pos[pre[k+1]]pos[pre[k + 1]] 中的每個 ii ,其對答案的貢獻為 kik - i ,累加到答案中,這部分和方法二中的思路相同。最後再把當前的 kk 加入 pos[pre[k]]pos[pre[k]] 中,以便後續的計算。

複雜度分析

  • 時間複雜度 O(n2)O(n^2)
  • 空間複雜度 O(n)O(n)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Solution:
def countTriplets(self, arr: List[int]) -> int:
n = len(arr)
pre = [0] * (n + 1) # Prefix Sum
for i, x in enumerate(arr):
pre[i + 1] = pre[i] ^ x
pos = defaultdict(list)
ans = 0
for k in range(n):
if pre[k + 1] in pos:
for i in pos[pre[k + 1]]:
ans += k - i
pos[pre[k]].append(k)
return ans
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution {
public:
int n;
vector<int> pre;
int countTriplets(vector<int>& arr) {
n = arr.size();
pre = vector<int>(n + 1);
for (int i = 0; i < n; i++) pre[i + 1] = pre[i] ^ arr[i];
return solve3a(arr);
}
int solve3a(vector<int>& arr) {
int ans = 0;
unordered_map<int, vector<int>> pos;
for (int k = 0; k < n; k++) {
if (pos.count(pre[k+1])) {
for (int i : pos[pre[k+1]]) {
ans += k - i;
}
}
pos[pre[k]].push_back(k);
}
return ans;
}
};

方法三優化:雜湊表(Hash Table) + 一重迴圈

方法三中,雖然已經節省了枚舉 i,ji, j 的時間,但是在枚舉 kk 的時候,仍然需要遍歷 pos[pre[k+1]]pos[pre[k + 1]] 中的每個 ii,在最壞情況下,時間複雜度仍然是 O(n2)O(n^2)

而是否存在 O(n)O(n) 的解法呢?不妨從貢獻的角度來思考,對於 pos[pre[k+1]]pos[pre[k+1]] 中的每個 i,其對答案的貢獻為 kik - i,可以將其拆分為兩部分:

  • kk 對答案的總貢獻為 k×mk \times m,其中 mmpos[pre[k+1]]pos[pre[k+1]] 的長度,即下標的個數。
  • ii 對答案的總貢獻為 ipos[pre[k+1]]i\sum_{i \in pos[pre[k+1]]} i

因此,可以將 pos[pre[k+1]]pos[pre[k+1]] 的長度和下標的總和都保存下來,這樣在枚舉 kk 的時候,只需要計算 k×mk \times mipos[pre[k+1]]i\sum_{i \in pos[pre[k+1]]} i 即可。

這裡使用兩個 雜湊表(Hash Table) cntcnttottot 來保存 pos[pre[k+1]]pos[pre[k+1]] 的長度以及總和,這樣就可以在一重迴圈中完成計算。

複雜度分析

  • 時間複雜度 O(n)O(n)
  • 空間複雜度 O(n)O(n)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Solution:
def countTriplets(self, arr: List[int]) -> int:
n = len(arr)
pre = [0] * (n + 1) # Prefix Sum
for i, x in enumerate(arr):
pre[i + 1] = pre[i] ^ x
cnt = Counter()
tot = Counter()
ans = 0
for k in range(n):
if pre[k + 1] in cnt:
ans += cnt[pre[k + 1]] * k - tot[pre[k + 1]]
cnt[pre[k]] += 1
tot[pre[k]] += k
return ans
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Solution {
public:
int n;
vector<int> pre;
int countTriplets(vector<int>& arr) {
n = arr.size();
pre = vector<int>(n + 1);
for (int i = 0; i < n; i++) pre[i + 1] = pre[i] ^ arr[i];
return solve3b(arr);
}
int solve3b(vector<int>& arr) {
int ans = 0;
unordered_map<int, int> cnt, tot;
for (int k = 0; k < n; k++) {
if (cnt.count(pre[k+1])) {
ans += cnt[pre[k+1]] * k - tot[pre[k+1]];
}
cnt[pre[k]]++;
tot[pre[k]] += k;
}
return ans;
}
};

寫在最後

Cover photo is generated by @ゴリラの素材屋さん, thanks for their work!

在這個問題中,可以將問題從暴力解法逐步優化到最優解法,這也是解題的樂趣之一。