树状数组倍增学习笔记

树状数组倍增学习笔记

Codeforces 上的教程:link

基本思想

问题描述

我们考虑一个问题:

给定某个序列,要求维护的操作有:单点修改,求前缀和,搜索某个前缀和(类似于在前缀和数组上求 lower_bound)。

这个事实上容易用线段树完成,但是由于树状数组具有空间更小、常数小、代码简单等优点,我们想用树状数组完成这个操作。

$O(\lg^2n)$ 的实现

单点修改和求前缀和都可以用树状数组完成,问题主要在于如何搜索某个前缀和。

由于前缀和这玩意儿具有单调性,一个简单的想法就是二分查找。代码如下:

>folded
1
2
3
4
5
6
7
8
9
int search(int val){
int l = 1, r = n;
while(l < r){
int mid = (l + r) >> 1;
if(sum(mid) < val) l = mid + 1;
else r = mid;
}
return l;
}

二分是 $O(\lg n)$ 的,在树状数组上求前缀和是 $O(\lg n)$ 的,所以总复杂度是 $O(\lg^2n)$ 的。

$O(\lg n)$ 的实现——倍增思想

倍增思想有很多重要的应用,例如 $ST$ 表、倍增求 lca 等,这里可以帮助我们在树状数组上完成 lower_bound 操作。

假设我们想要搜索前缀和为 $val$ 的地方,设定一个 pos 指针,它初始为 $0$,最终将指向最大的前缀和小于 $val$ 的位置;再设置一个变量 sum​,存储 pos​ 处的前缀和;设置倍增的长度 i,最初为 $\lg n$(为了代码方便,一般取 $20$ 即可),在倍增的过程中不断减小至 $0$。每一个状态(possumi)表示我们现在考虑的是位置 pos+(1<<i) 的前缀和,这个前缀和的值是 sum+c[pos+(1<<i)],如果它大于等于了 $val$,那么我们减小倍增的长度 i;否则,我们把 pos 提到 pos+(1<<i) 处。

我们用例子来更直观地说明【以下例子和图片均来源于 Codeforces 的教程】:

给定数组 a[]

它的树状数组 c[] 长这样:

我们想搜索 $val=27$ 的位置,那么算法过程如下:

最后 pos 值为 $13$,是最大的前缀和小于 $27$ 的位置。所以我们的目标位置就是 pos+1.

代码如下:

>folded
1
2
3
4
5
6
7
int search(int val){
int pos = 0, sum = 0;
for(int i = 20; i >= 0; i--)
if(pos + (1<<i) <= n && sum + c[pos+(1<<i)] < val)
pos += (1<<i), sum += c[pos];
return pos + 1;
}

进一步

容易发现,只要我们维护的信息具有单调性,就可以用这个方法。

练习

CF1354D Multiset

题目链接

其实这道题是我学树状数组倍增的原因。

开一个值域树状数组,维护前缀个数,这玩意儿是单调增加的,所以查询第 $k$ 个数可以用上树状数组倍增。

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include<bits/stdc++.h>

using namespace std;

template<typename T>void read(T&x){x=0;int fl=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')
fl=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}x*=fl;}
template<typename T,typename...Args>inline void read(T&t,Args&...args){read(t);read(args...);}

typedef long long LL;
typedef vector<int> vi;
typedef pair<int, int> pii;
#define mp(x, y) make_pair(x, y)
#define pb(x) emplace_back(x)

const int N = 1000005;

int n, q;

int c[N];
inline int lowbit(int x){ return x & -x; }
inline void add(int x, int val){
while(x <= n){
c[x] += val;
x += lowbit(x);
}
}
inline int search(int val){
int pos = 0, sum = 0;
for(int i = 20; i >= 0; i--)
if(pos + (1<<i) <= n && sum + c[pos+(1<<i)] < val)
pos += (1<<i), sum += c[pos];
return pos + 1;
}

int main(){
read(n, q);
for(int i = 1; i <= n; i++){
int x; read(x);
add(x, 1);
}
while(q--){
int x; read(x);
if(x > 0) add(x, 1);
else add(search(-x), -1);
}
int ans = search(1);
if(ans == n + 1) puts("0");
else printf("%d\n", ans);
return 0;
}

CF992E Nastya and King-Shamans

题目链接

由 $a_i=s_{i-1}$ 可以推出 $s_i=2s_{i-1}$,那么我们用树状数组维护前缀和,每次询问时从 $sum=0$ 开始查找第一个前缀和大于等于 $2sum$ 的位置,由于每次乘 $2$,所以最多查找 $\lg 10^{14}$ 次。

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include<bits/stdc++.h>

using namespace std;

template<typename T>void read(T&x){x=0;int fl=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')
fl=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}x*=fl;}
template<typename T,typename...Args>inline void read(T&t,Args&...args){read(t);read(args...);}

typedef long long LL;
typedef vector<int> vi;
typedef pair<int, int> pii;
#define mp(x, y) make_pair(x, y)
#define pb(x) emplace_back(x)

const int N = 200005;

int n, q, a[N];

LL c[N];
inline int lowbit(int x){ return x & -x; }
inline void add(int x, LL val){
while(x <= n){
c[x] += val;
x += lowbit(x);
}
}
inline LL sum(int x){
LL res = 0;
while(x){
res += c[x];
x -= lowbit(x);
}
return res;
}
inline int search(LL val){
int pos = 0;
LL sum = 0;
for(int i = 20; i >= 0; i--)
if(pos + (1<<i) <= n && sum + c[pos + (1<<i)] < val)
pos += (1<<i), sum += c[pos];
return pos + 1;
}

int main(){
read(n, q);
for(int i = 1; i <= n; i++){
read(a[i]);
add(i, a[i]);
}
while(q--){
int p, x; read(p, x);
add(p, x - a[p]);
a[p] = x;
LL s = 0;
while(1){
int pos = search(s << 1);
if(pos == n + 1){
puts("-1");
break;
}
if(sum(pos) == sum(pos-1) << 1){
printf("%d\n", pos);
break;
}
s = sum(pos);
}
}
return 0;
}
作者

xyfJASON

发布于

2020-05-19

更新于

2021-02-25

许可协议

评论