🔗 🔴 3171. Find Subarray With Bitwise OR Closest to K 2163

題意

給定一個整數陣列 numsnums 和一個整數 kk。你需要找到 numsnums 的一個子陣列,使得該子陣列元素的按位 OR\text{OR}kk 之間的 絕對差 盡可能 。換句話說,選擇一個子陣列 nums[l..r]nums[l..r],使得 k(nums[l] OR nums[l+1] OR ... OR nums[r])|k - (nums[l] \text{ OR } nums[l + 1] \text{ OR } ... \text{ OR } nums[r])| 最小。

返回絕對差的 最小 可能值。

一個 子陣列(Subarray) 是一個陣列內連續的 非空 元素序列。

約束條件:

  • 1nums.length1051 \leq \text{nums.length} \leq 10^5
  • 1nums[i]1091 \leq \text{nums[i]} \leq 10^9
  • 1k1091 \leq k \leq 10^9

思路:線段樹上二分 / LogTrick

首先思考暴力做法,也就是枚舉所有子陣列,計算其 OR 值,並更新最小值。其中枚舉子陣列需要 O(n2)O(n^2) 的時間複雜度,計算 OR 值需要 O(n)O(n) 的時間複雜度,因此總時間複雜度為 O(n3)O(n^3),會 TLE。

而為了快速計算任意子陣列的 OR 值,我們可以使用 線段樹(Segment Tree) ,關於線段樹的介紹可以參考线段树从入门到急停。令線段樹的每個節點代表其對應區間內所有元素的按位 OR 值。這樣,我們可以在 O(logn)O(\log n) 的時間內查詢任意區間的 OR 值,如此時間複雜度可以降低為 O(n2logn)O(n^2 \log n)

但這樣的時間複雜度還是不能接受,因此我們需要進一步優化。由於 OR 的性質,若固定左端點,隨著右端點的遞增,OR 值是遞增的,因此我們可以透過二分搜尋找到最小的右端點 idxidx,使得 OR 值大於等於 kk。但需注意,由於答案是取絕對值,因此其符合條件的右端點會有兩個,即 idxidxidx1idx - 1,因此需要分別計算其與 kk 的差值,並更新最小值。如此一來,我們可以將時間複雜度降低為 O(nlognlogn)=O(nlog2n)O(n \log n \log n) = O(n \log^2 n)

複雜度分析

  • 時間複雜度:O(nlog2n)\mathcal{O}(n \log^2 n),其中 nn 是陣列長度。
    • 建構線段樹需要 O(nlogn)\mathcal{O}(n \log n) 的時間。
    • 對每個左端點進行二分搜尋,每次搜尋需要 O(logn)\mathcal{O}(\log n) 的時間。
    • 每次二分搜尋中,查詢線段樹需要 O(logn)\mathcal{O}(\log n) 的時間。
    • 總體時間複雜度為 O(nlog2n)\mathcal{O}(n \log^2 n)
  • 空間複雜度:O(n)\mathcal{O}(n)
    • 線段樹需要 O(n)\mathcal{O}(n) 的空間。

Python 會 TLE,這裡只作為參考。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class SegmentTree:
def __init__(self, nums: List[int], k: int):
n = len(nums)
self.k = k
self.nums = [0] + nums # 讓 index 從 1 開始
self.tree = [0 for _ in range(4 * n)] # (OR)
self.build(1, 1, n)

def build(self, o, left, right): # node, left, right
if left == right: # Leaf node initialization
self.tree[o] = self.nums[left]
return
mid = (left + right) // 2
self.build(2*o, left, mid)
self.build(2*o+1, mid + 1, right)
self.tree[o] = self.merge(2*o, 2*o+1)

def merge(self, left_child, right_child):
return self.tree[left_child] | self.tree[right_child]

def query(self, o, left, right, l, r):
if l <= left and right <= r:
return self.tree[o]
mid = (left + right) // 2
ans = 0
if l <= mid:
ans |= self.query(2*o, left, mid, l, r)
if r > mid:
ans |= self.query(2*o+1, mid + 1, right, l, r)
return ans

class Solution:
def minimumDifference(self, nums: List[int], k: int) -> int:
n = len(nums)
seg = SegmentTree(nums, k)
ans = float('inf')
for i in range(1, n+1): # 枚舉左端點

left, right = i, n
while left <= right:
mid = (left + right) // 2
if seg.query(1, 1, n, i, mid) >= k:
right = mid - 1
else:
left = mid + 1
ans = min(ans, abs(seg.query(1, 1, n, i, left) - k))
if right >= i:
ans = min(ans, abs(k - seg.query(1, 1, n, i, right)))
return ans
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class SegmentTree {
private:
vector<int> nums;
vector<int> tree;
public:
SegmentTree(vector<int>& nums) {
int n = nums.size();
this->nums = vector<int>(n + 1, 0);
for (int i = 1; i <= n; i++) this->nums[i] = nums[i - 1]; // 1-indexed
this->tree = vector<int>(4 * n, 0);
build(1, 1, n);
}

void build(int o, int left, int right) {
if (left == right) {
tree[o] = nums[left];
return;
}
int mid = (left + right) / 2;
build(2 * o, left, mid);
build(2 * o + 1, mid + 1, right);
tree[o] = merge(2 * o, 2 * o + 1);
}

int merge(int left_child, int right_child) {
return tree[left_child] | tree[right_child];
}

int query(int o, int left, int right, int l, int r) {
if (l <= left && right <= r) return tree[o];
int mid = (left + right) / 2;
int ans = 0;
if (l <= mid) ans |= query(2 * o, left, mid, l, r);
if (r > mid) ans |= query(2 * o + 1, mid + 1, right, l, r);
return ans;
}
};

class Solution {
public:
int minimumDifference(vector<int>& nums, int k) {
int n = nums.size();
SegmentTree seg(nums);
int ans = INT_MAX;
for (int i = 1; i <= n; i++) {
int left = i, right = n;
while (left <= right) {
int mid = left + (right - left) / 2;
if (seg.query(1, 1, n, i, mid) >= k) right = mid - 1;
else left = mid + 1;
}
ans = min(ans, abs(seg.query(1, 1, n, i, left) - k));
if (right >= i) ans = min(ans, abs(k - seg.query(1, 1, n, i, right)));
}
return ans;
}
};

方法二:LogTrick

注意到在方法一中,我們提到 OR 的性質,若固定左端點,隨著右端點的遞增,OR 值是遞增的。而再每次 OR 操作後,若操作後的值與操作前不同,則至少會將其中一個位元的 00 變為 11,因此最多只會有 O(max_bit)=O(logU)O(\text{max\_bit}) = O(\log U) 種不同的 OR 值,其中 UU 是陣列中最大元素的值。

為了方便,我們改成枚舉右端點,並用一個堆疊或列表 stst 維護以右端點為結果的 OR 值。

  • 每次枚舉到新的元素 xx,我們新建一個堆疊 st2st2,並將 xx 與堆疊 stst 中的每個元素進行 OR 操作,若結果與堆疊 st2st2 中的最後一個元素不同,則將其添加到堆疊 st2st2 中。如此一來,堆疊 st2st2 中的元素是遞增的,並且不重複。
  • 最後,我們遍歷堆疊 st2st2,更新最小值,並將堆疊 stst 更新為 st2st2

如此一來,我們可以將時間複雜度降低為 O(nlogU)O(n \log U),其中 nn 是陣列長度,logU\log U 是陣列中最大元素的位數。

這種方法之所以被稱為 LogTrick ,是因為它巧妙地利用了 OR 運算的性質,並且其時間複雜度與數字的位數(logarithm of the number)相關。

複雜度分析

  • 時間複雜度:O(nlogU)\mathcal{O}(n \log U),其中 nn 是陣列長度,logU\log U 是陣列中最大元素的位數。
  • 空間複雜度:O(logU)\mathcal{O}(\log U),因為 stst 最多存儲 logU\log U 個元素。
1
2
3
4
5
6
7
8
9
10
11
12
13
class Solution:
def minimumDifference(self, nums: List[int], k: int) -> int:
st = []
ans = float('inf')
for x in nums: # 枚舉右端點
st2 = [x] # 保存以 x 為右端點的所有 OR 結果,注意由於 OR 的性質,這裡的 st2 是遞增的
for y in st:
if x | y != st2[-1]:
st2.append(x | y)
st = st2
for y in st:
ans = min(ans, abs(k - y))
return ans
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
public:
int minimumDifference(vector<int>& nums, int k) {
int n = nums.size();
int ans = INT_MAX;
vector<int> st;
for (int x : nums) {
vector<int> st2 = {x};
for (int y : st) {
if ((y | x) != st2.back()) {
st2.push_back(y | x);
}
}
st = st2;
for (int y : st) {
ans = min(ans, abs(y - k));
}
}
return ans;
}
};

類題:LogTrick

參考資料


寫在最後

PROMPT

masterpiece, best quality, high quality,extremely detailed CG unity 8k wallpaper, extremely detailed, High Detail, colors,
(1girl, solo), (idol, idol costume), long hair, black hair, dress, bow, standing, detached sleeves, white dress, hand on hip, curtains, pointing, pointing at self, stage, on stage,
A young girl wearing a lavish purple dress with puffy sleeves and a layered skirt, Her hair is styled in twin tails with purple bows, The background is a dark blue curtain, She is smiling and posing cutely,

賽時是用 C++ 寫線段樹上二分,但實際上可以不用這麼麻煩,直接用 LogTrick 就可以了。

另外,雖然賽時的題目是 AND 而不是 OR,但解法還是相同的,差別只在於 AND 會使數字變小,而 OR 會使數字變大。