DSU on tree 学习笔记

DSU on tree 学习笔记

参考博客:1 2 3

简述

DSU on tree 其实和 DSU(并查集)没有太大关系,唯一的联系大概就是借用了“启发式合并”的思想,所以它的中文名称为“树上启发式合并”。

这是一个看起来很暴力的算法。我们拿到一道树上针对顶点询问且无修改的问题,首先考虑如何暴力求解——我们用全局变量存储需要维护的信息,整体 $dfs$ 一遍枚举顶点,对每个枚举的顶点 $dfs$ 它的子树计算这个顶点的答案,当然计算答案前先把全局变量里之前的答案清空。暴力求解的时间复杂度显然是 $O(n^2)$,但我们只需要调节 $dfs$ 的顺序和适时保留全局变量里存的答案就能降低复杂度。

每枚举到一个顶点时,先依次计算它的各个轻儿子的答案(每算完一个轻儿子就清空全局变量),最后再计算重儿子的答案。重儿子的答案计算完毕后不再清空全局变量,而是沿用里面的信息去更新它的父节点的答案——这个时候只需要暴力统计父节点自身以及各轻儿子子树对父节点的贡献。

由于任意顶点到根的路径上轻边数量不超过 $\lg n$,而除了处理它自身时以外,一个顶点只会在这条路径上的每条轻边处被暴力统计常数次,所以每个点因暴力统计被访问的次数是 $O(\lg n)$ 的,DSU on tree 的总复杂度是 $O(n\lg n)$。

模板

流程:

  • 预处理重儿子
  • $dfs$ 遍历节点
    • 先遍历轻儿子,求出答案后消除影响
    • 最后遍历重儿子,求出答案后不消除影响
  • 回溯后,暴力统计当前节点自身和各轻儿子子树对当前节点答案的贡献
  • 如果当前节点对于其父节点是轻儿子,则消除影响

代码模板以 Codeforces 600E. Lomsat gelral 为例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Per-vertex data filled by dfs(): parent, subtree size, heavy child (0 = none).
int fa[N], sz[N], son[N];

// First pass over the tree: visits the subtree of x (whose parent is f) and
// fills fa[], sz[] and son[] for every vertex in it.
void dfs(int x, int f){
    fa[x] = f;
    sz[x] = 1;
    son[x] = 0;
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == f) continue;  // never walk back up to the parent
        dfs(child, x);
        sz[x] += sz[child];
        // the first child seen, or any strictly larger one, becomes the heavy child
        if(son[x] == 0 || sz[son[x]] < sz[child])
            son[x] = child;
    }
}

LL ans[N], mx, sum, cnt[N]; // GLOBAL variants to store the answer
int mark; // mark the heavy son which needs to be ignored
void getData(int x, int val){ // get data with brute-force

cnt[c[x]] += val;
if(mx < cnt[c[x]]) mx = cnt[c[x]], sum = c[x];
else if(mx == cnt[c[x]]) sum += c[x];

for(int i = head[x]; i; i = edge[i].nxt){
if(edge[i].to == fa[x]) continue;
if(edge[i].to == mark) continue; // ignore the marked subtree
getData(edge[i].to, val);
}
}
// Computes ans[] for every vertex in the subtree of x.
// opt == true: the global counters must be wiped again once x is finished.
void dsu(int x, bool opt){
    // 1) light children first; each of them cleans up after itself
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child != fa[x] && child != son[x])
            dsu(child, true);
    }

    // 2) heavy child last; its contribution stays in the global counters
    if(son[x]){
        dsu(son[x], false);
        mark = son[x];
    }

    // 3) add x and its light subtrees on top (the heavy subtree is skipped via mark)
    getData(x, 1);
    mark = 0;  // must be cleared so that later walks do not skip that subtree

    // the global counters now describe exactly the subtree of x
    ans[x] = sum;

    // 4) on request, remove the whole subtree again and reset the maxima
    if(opt){
        getData(x, -1);
        mx = 0;
        sum = 0;
    }
}

练习

Codeforces 600E. Lomsat gelral

题目链接

开全局变量 mx,sum,cnt[] 来记录最大数量、最大数量的颜色和、每种颜色的数量。

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int N = 100005;  // capacity of the per-vertex arrays

int n, c[N];  // n: number of vertices read by main(); c[v]: colour of vertex v

// Linked adjacency lists: head[v] is the id of the newest edge leaving v
// (0 = none) and edge[id].nxt is the edge added to the same list just before it.
struct Edge{
    int nxt, to;
}edge[N<<1];  // twice N: main() stores every tree edge in both directions
int head[N], edgeNum;

// Pushes the directed edge from -> to onto the front of from's list.
void addEdge(int from, int to){
    // Standard aggregate initialisation; a `(Edge){...}` compound literal is
    // only a GNU extension in C++.
    edge[++edgeNum] = Edge{head[from], to};
    head[from] = edgeNum;
}

// Per-vertex data filled by dfs(): parent, subtree size, heavy child (0 = none).
int fa[N], sz[N], son[N];

// First pass over the tree: visits the subtree of x (whose parent is f) and
// fills fa[], sz[] and son[] for every vertex in it.
void dfs(int x, int f){
    fa[x] = f;
    sz[x] = 1;
    son[x] = 0;
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == f) continue;  // never walk back up to the parent
        dfs(child, x);
        sz[x] += sz[child];
        // the first child seen, or any strictly larger one, becomes the heavy child
        if(son[x] == 0 || sz[son[x]] < sz[child])
            son[x] = child;
    }
}

LL ans[N], mx, sum, cnt[N]; // GLOBAL variants to store the answer
int mark; // mark the heavy son which needs to be ignored
void getData(int x, int val){ // get data with brute-force
cnt[c[x]] += val;
if(mx < cnt[c[x]]) mx = cnt[c[x]], sum = c[x];
else if(mx == cnt[c[x]]) sum += c[x];
for(int i = head[x]; i; i = edge[i].nxt){
if(edge[i].to == fa[x]) continue;
if(edge[i].to == mark) continue; // ignore the marked subtree
getData(edge[i].to, val);
}
}
// Computes ans[] for every vertex in the subtree of x.
// opt == true: the global counters must be wiped again once x is finished.
void dsu(int x, bool opt){
    // light children first; each of them cleans up after itself
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child != fa[x] && child != son[x])
            dsu(child, true);
    }
    // heavy child last; its contribution stays in the global counters
    if(son[x]){
        dsu(son[x], false);
        mark = son[x];
    }
    // add x and its light subtrees on top (the heavy subtree is skipped via mark)
    getData(x, 1);
    mark = 0;
    // the global counters now describe exactly the subtree of x
    ans[x] = sum;
    // on request, remove the whole subtree again and reset the maxima
    if(opt){
        getData(x, -1);
        mx = 0;
        sum = 0;
    }
}

// Reads n, the n vertex colours and the n-1 edges, runs both passes from
// vertex 1 and prints the answer of every vertex separated by spaces.
int main(){
    scanf("%d", &n);
    for(int v = 1; v <= n; ++v){
        scanf("%d", &c[v]);
    }
    for(int k = 1; k < n; ++k){
        int u, v;
        scanf("%d%d", &u, &v);
        // undirected edge: store both directions
        addEdge(u, v);
        addEdge(v, u);
    }

    dfs(1, 0);     // sizes and heavy children, rooted at vertex 1
    dsu(1, true);  // fills ans[]

    for(int v = 1; v <= n; ++v){
        printf("%lld ", ans[v]);
    }
    return 0;
}

Codeforces 570D. Tree Requests

题目链接

开全局变量——bitset<26>b[] 记录每一层的字母数量奇偶性。

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int N = 500005;  // capacity of the per-vertex / per-query arrays

int n, m;            // first two integers read by main(); m is the number of queries
char s[N];           // s[v]: letter written on vertex v (1-based)
vector<int> vec[N];  // vec[v]: ids of the queries asked about vertex v

// One query: vertex v, depth h, and the computed verdict.
struct Query{
    int v;
    int h;
    bool ans;
};
Query q[N];

// Linked adjacency lists: head[v] is the id of the newest edge leaving v
// (0 = none) and edge[id].nxt is the edge added to the same list just before it.
struct Edge{
    int nxt, to;
}edge[N<<1];  // twice N: main() stores every tree edge in both directions
int head[N], edgeNum;

// Pushes the directed edge from -> to onto the front of from's list.
void addEdge(int from, int to){
    // Standard aggregate initialisation; a `(Edge){...}` compound literal is
    // only a GNU extension in C++.
    edge[++edgeNum] = Edge{head[from], to};
    head[from] = edgeNum;
}

// Per-vertex data filled by dfs(): parent, subtree size, heavy child (0 = none), depth.
int fa[N], sz[N], son[N], dep[N];

// First pass over the tree: visits the subtree of x (parent f, depth `depth`)
// and fills fa[], sz[], son[] and dep[] for every vertex in it.
void dfs(int x, int f, int depth){
    fa[x] = f;
    sz[x] = 1;
    son[x] = 0;
    dep[x] = depth;
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == f) continue;  // never walk back up to the parent
        dfs(child, x, depth + 1);
        sz[x] += sz[child];
        // the first child seen, or any strictly larger one, becomes the heavy child
        if(son[x] == 0 || sz[son[x]] < sz[child])
            son[x] = child;
    }
}

// b[d]: parity mask for depth d - bit k is set when letter 'a'+k currently
// occurs an odd number of times at that depth among the counted vertices.
bitset<26> b[N];
// Child whose subtree getData() must skip (0 = skip nothing).
int mark;

// Toggles the letter bit of every vertex in the subtree of x (except inside the
// marked subtree). Being a toggle, a second identical walk undoes the first.
void getData(int x){
    const int letter = s[x] - 'a';
    b[dep[x]].flip(letter);
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == fa[x] || child == mark) continue;
        getData(child);
    }
}
// Answers every query attached to a vertex in the subtree of x.
// opt == true: the parity masks must be cleared again once x is finished.
void dsu(int x, bool opt){
    // light children first; each of them cleans up after itself
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child != fa[x] && child != son[x])
            dsu(child, true);
    }
    // heavy child last; its contribution is kept
    if(son[x]){
        dsu(son[x], false);
        mark = son[x];
    }
    getData(x);  // add x and its light subtrees (heavy subtree skipped via mark)
    mark = 0;
    // verdict is true when at most one letter has odd parity at the asked depth
    for(int id : vec[x]){
        const auto &layer = b[q[id].h];
        q[id].ans = (layer.count() <= 1);
    }
    if(opt){
        getData(x);  // toggling the whole subtree once more removes it
    }
}

// Reads the tree as a parent list for vertices 2..n, the letters and the m
// queries, runs both passes from vertex 1 and prints one verdict per query.
int main(){
    scanf("%d%d", &n, &m);
    for(int v = 2; v <= n; ++v){
        int p;
        scanf("%d", &p);
        // undirected edge: store both directions
        addEdge(v, p);
        addEdge(p, v);
    }
    dfs(1, 0, 1);  // the root has depth 1

    scanf("%s", s+1);  // letters are stored 1-based
    for(int id = 1; id <= m; ++id){
        scanf("%d%d", &q[id].v, &q[id].h);
        vec[q[id].v].emplace_back(id);  // handle the query when its vertex is solved
    }

    dsu(1, true);

    for(int id = 1; id <= m; ++id){
        if(q[id].ans) puts("Yes");
        else puts("No");
    }
    return 0;
}

SGU 507. Treediff

题目链接

开全局变量——set<int> s 来记录现在有哪些叶节点的值,更新差值最小值时在 set 里面查前驱后继即可。

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int INF = 2147483647;  // initial value of the running minimum
const int N = 50005;         // capacity of the per-vertex arrays

// n, m: the two integers read first by main(); a[v]: value stored at vertex v
// (main() reads it for the last m vertices only); ans[v]: answer for vertex v.
int n, m, a[N], ans[N];

// Linked adjacency lists: head[v] is the id of the newest edge leaving v
// (0 = none) and edge[id].nxt is the edge added to the same list just before it.
struct Edge{
    int nxt, to;
}edge[N<<1];  // twice N: main() stores every tree edge in both directions
int head[N], edgeNum, deg[N];  // deg[v]: number of stored edges that end at v

// Pushes the directed edge from -> to onto the front of from's list and
// counts it in deg[to].
void addEdge(int from, int to){
    // Standard aggregate initialisation; a `(Edge){...}` compound literal is
    // only a GNU extension in C++.
    edge[++edgeNum] = Edge{head[from], to};
    head[from] = edgeNum;
    deg[to]++;
}

// Per-vertex data filled by dfs(): parent, heavy child (0 = none), subtree size.
int fa[N], son[N], sz[N];

// First pass over the tree: visits the subtree of x (whose parent is f) and
// fills fa[], son[] and sz[] for every vertex in it.
void dfs(int x, int f){
    fa[x] = f;
    son[x] = 0;
    sz[x] = 1;
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == f) continue;  // never walk back up to the parent
        dfs(child, x);
        sz[x] += sz[child];
        // the first child seen, or any strictly larger one, becomes the heavy child
        if(son[x] == 0 || sz[son[x]] < sz[child])
            son[x] = child;
    }
}

// Values of the leaves currently counted.
set<int> s;
// mn: smallest difference found so far between two counted leaf values;
// mark: child whose subtree getData() must skip (0 = skip nothing).
int mn = INF, mark;

// Walks the subtree of x (except the marked subtree). Every leaf met is compared
// with its nearest stored neighbours to tighten mn and is then stored itself.
void getData(int x){
    if(deg[x] == 1 && x != 1){  // a non-root vertex with a single edge is a leaf
        const int val = a[x];
        const auto at = s.lower_bound(val);  // first stored value >= val
        if(at != s.begin()){
            auto below = at;
            --below;                         // largest stored value < val
            mn = min(mn, val - *below);
        }
        if(at != s.end()){
            mn = min(mn, *at - val);
        }
        s.emplace(val);
    }
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == fa[x] || child == mark) continue;
        getData(child);
    }
}
// Computes ans[] for every vertex in the subtree of x.
// opt == true: the stored leaf values must be dropped once x is finished.
void dsu(int x, bool opt){
    // light children first; each of them cleans up after itself
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child != fa[x] && child != son[x])
            dsu(child, true);
    }
    // heavy child last; its leaves and its minimum are kept
    if(son[x]){
        dsu(son[x], false);
        mark = son[x];
    }
    getData(x);  // add x and its light subtrees (heavy subtree skipped via mark)
    ans[x] = mn;
    mark = 0;
    if(opt){
        // nothing of this subtree may leak into a sibling: start from scratch
        s.clear();
        mn = INF;
    }
}

// Reads the tree as a parent list for vertices 2..n and the values of the last
// m vertices, runs both passes from vertex 1 and prints ans[] for the first
// n - m vertices.
int main(){
    scanf("%d%d", &n, &m);
    for(int v = 2; v <= n; ++v){
        int p;
        scanf("%d", &p);
        // undirected edge: store both directions
        addEdge(v, p);
        addEdge(p, v);
    }
    for(int v = n - m + 1; v <= n; ++v){
        scanf("%d", &a[v]);
    }

    dfs(1, 0);
    dsu(1, true);

    for(int v = 1; v <= n - m; ++v){
        printf("%d ", ans[v]);
    }
    return 0;
}

Codeforces 208E. Blood Cousins

题目链接

开全局变量——cnt[] 来记录每个深度的节点数。

另外,这道题还有一种巧妙并且更好写的做法:记录下每个深度有哪些 $dfs$ 序,然后询问就是在某个深度的 $dfs$ 序列中二分找某个区间内的数的数量。

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int N = 100005;  // capacity of the per-vertex / per-query arrays

int n, m;  // number of vertices and number of queries, as read by main()

// One query: vertex v, distance p, and the computed answer.
struct Query{
    int v;
    int p;
    int ans;
};
Query q[N];

vector<int> vec[N];  // vec[u]: ids of the queries that are answered at vertex u
vector<int> rt;      // vertices whose parent was given as 0 (roots of the forest)

// Linked adjacency lists: head[v] is the id of the newest edge leaving v
// (0 = none) and edge[id].nxt is the edge added to the same list just before it.
struct Edge{
    int nxt, to;
}edge[N<<1];  // twice N: main() stores every edge in both directions
int head[N], edgeNum, deg[N];  // deg[v]: number of stored edges that end at v

// Pushes the directed edge from -> to onto the front of from's list and
// counts it in deg[to].
void addEdge(int from, int to){
    // Standard aggregate initialisation; a `(Edge){...}` compound literal is
    // only a GNU extension in C++.
    edge[++edgeNum] = Edge{head[from], to};
    head[from] = edgeNum;
    deg[to]++;
}

// fa[v][j]: ancestor table of v (column 0 is the direct parent, further columns
// are filled by init()); son[v]: heavy child (0 = none); sz[v]: subtree size;
// dep[v]: depth.
int fa[N][25], son[N], sz[N], dep[N];

// First pass over one tree: visits the subtree of x (parent f, depth `depth`)
// and fills fa[][0], son[], sz[] and dep[] for every vertex in it.
void dfs(int x, int f, int depth){
    fa[x][0] = f;
    son[x] = 0;
    sz[x] = 1;
    dep[x] = depth;
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == f) continue;  // never walk back up to the parent
        dfs(child, x, depth + 1);
        sz[x] += sz[child];
        // the first child seen, or any strictly larger one, becomes the heavy child
        if(son[x] == 0 || sz[son[x]] < sz[child])
            son[x] = child;
    }
}
// Completes the ancestor table by doubling: fa[v][j] is reached by taking the
// (j-1)-th entry twice. Entries without such an ancestor are left untouched.
void init(){
    for(int lvl = 1; (1 << lvl) <= n; ++lvl){
        for(int v = 1; v <= n; ++v){
            const int half = fa[v][lvl-1];
            if(half != 0)
                fa[v][lvl] = fa[half][lvl-1];
        }
    }
}
// Climbs p steps up from x using the ancestor table, taking the largest
// power-of-two jumps first (bits 20 down to 0), and returns the vertex reached.
int getFa(int x, int p){
    for(int bit = 20; bit >= 0; --bit){
        const int step = 1 << bit;
        if(step > p) continue;  // this jump would overshoot
        x = fa[x][bit];
        p -= step;
    }
    return x;
}

// mark: child whose subtree getData() must skip (0 = skip nothing);
// cnt[d]: number of currently counted vertices at depth d.
int mark, cnt[N];

// Walks the subtree of x (except the marked subtree) and adds val (+1 or -1)
// to the depth counter of every vertex met.
void getData(int x, int val){
    cnt[dep[x]] += val;
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == fa[x][0] || child == mark) continue;
        getData(child, val);
    }
}
// Answers every query attached to a vertex in the subtree of x.
// opt == true: the depth counters must be cleared again once x is finished.
void dsu(int x, bool opt){
    // light children first; each of them cleans up after itself
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child != fa[x][0] && child != son[x])
            dsu(child, true);
    }
    // heavy child last; its contribution is kept
    if(son[x]){
        dsu(son[x], false);
        mark = son[x];
    }
    getData(x, 1);  // add x and its light subtrees (heavy subtree skipped via mark)
    for(int id : vec[x]){
        // vertices of this subtree that lie q[id].p levels below x ...
        const int sameLevel = cnt[dep[x] + q[id].p];
        // ... minus one (the count includes the asking vertex) unless it is zero
        q[id].ans = (sameLevel != 0) ? sameLevel - 1 : 0;
    }
    mark = 0;
    if(opt){
        getData(x, -1);
    }
}

int main(){
scanf("%d", &n);
for(int i = 1; i <= n; i++){
int p; scanf("%d", &p);
if(p == 0) rt.emplace_back(i);
addEdge(i, p), addEdge(p, i);
}
scanf("%d", &m);
for(auto r: rt)
dfs(r, 0, 0);
init();
for(int i = 1; i <= m; i++){
scanf("%d%d", &q[i].v, &q[i].p);
vec[getFa(q[i].v, q[i].p)].emplace_back(i);
}
for(auto r: rt)
dsu(r, true);
for(int i = 1; i <= m; i++)
printf("%d ", q[i].ans);
return 0;
}

Codeforces 246E. Blood Cousins Return

题目链接

开全局变量——map<string, int>cnt[N] 来记录每一层某个名字的出现次数。

map<> 的使用:map<string, int> 中某个字符串的计数减到 $0$ 之后就 erase() 掉它,这样用 size() 就能得到 map 里还剩多少个键,即不同的字符串数。

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int N = 100005;  // capacity of the per-vertex / per-query arrays

int n, m;        // number of vertices and number of queries, as read by main()
string name[N];  // name[v]: name read for vertex v

// One query: vertex v, distance p, and the computed answer.
struct Query{
    int v;
    int p;
    int ans;
};
Query q[N];

vector<int> vec[N];  // vec[v]: ids of the queries asked about vertex v
vector<int> rt;      // vertices whose parent was given as 0 (roots of the forest)

// Linked adjacency lists: head[v] is the id of the newest edge leaving v
// (0 = none) and edge[id].nxt is the edge added to the same list just before it.
struct Edge{
    int nxt, to;
}edge[N<<1];  // twice N: main() stores every edge in both directions
int head[N], edgeNum;

// Pushes the directed edge from -> to onto the front of from's list.
void addEdge(int from, int to){
    // Standard aggregate initialisation; a `(Edge){...}` compound literal is
    // only a GNU extension in C++.
    edge[++edgeNum] = Edge{head[from], to};
    head[from] = edgeNum;
}

// Per-vertex data filled by dfs(): parent, heavy child (0 = none), subtree size, depth.
int fa[N], son[N], sz[N], dep[N];

// First pass over one tree: visits the subtree of x (parent f, depth `depth`)
// and fills fa[], son[], sz[] and dep[] for every vertex in it.
void dfs(int x, int f, int depth){
    fa[x] = f;
    son[x] = 0;
    sz[x] = 1;
    dep[x] = depth;
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == f) continue;  // never walk back up to the parent
        dfs(child, x, depth + 1);
        sz[x] += sz[child];
        // the first child seen, or any strictly larger one, becomes the heavy child
        if(son[x] == 0 || sz[son[x]] < sz[child])
            son[x] = child;
    }
}

int mark;
map<string, int> cnt[N<<1];
void getData(int x, int val){
cnt[dep[x]][name[x]] += val;
if(cnt[dep[x]][name[x]] == 0) cnt[dep[x]].erase(name[x]);
for(int i = head[x]; i; i = edge[i].nxt){
if(edge[i].to == fa[x]) continue;
if(edge[i].to == mark) continue;
getData(edge[i].to, val);
}
}
// Answers every query attached to a vertex in the subtree of x.
// opt == true: the name counters must be cleared again once x is finished.
void dsu(int x, bool opt){
    // light children first; each of them cleans up after itself
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child != fa[x] && child != son[x])
            dsu(child, true);
    }
    // heavy child last; its contribution is kept
    if(son[x]){
        dsu(son[x], false);
        mark = son[x];
    }
    getData(x, 1);  // add x and its light subtrees (heavy subtree skipped via mark)
    for(int id : vec[x]){
        // number of distinct names q[id].p levels below x
        const auto &layer = cnt[dep[x] + q[id].p];
        q[id].ans = layer.size();
    }
    mark = 0;
    if(opt){
        getData(x, -1);
    }
}

int main(){
scanf("%d", &n);
for(int i = 1; i <= n; i++){
cin >> name[i];
int p; scanf("%d", &p);
if(p == 0) rt.emplace_back(i);
addEdge(i, p), addEdge(p, i);
}
scanf("%d", &m);
for(int i = 1; i <= m; i++){
scanf("%d%d", &q[i].v, &q[i].p);
vec[q[i].v].emplace_back(i);
}
for(auto r: rt) dfs(r, 0, 0);
for(auto r: rt) dsu(r, true);
for(int i = 1; i <= m; i++)
printf("%d\n", q[i].ans);
return 0;
}

Codeforces 1009F. Dominant Indices

题目链接

开全局变量——cnt[],mx,mxDep 分别记录每个深度的节点数、最大节点数和最大节点数所在层数。

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include<bits/stdc++.h>

using namespace std;

const int N = 1000005;  // capacity of the per-vertex arrays

int n, ans[N];  // n: number of vertices read by main(); ans[v]: answer for vertex v

// Linked adjacency lists: head[v] is the id of the newest edge leaving v
// (0 = none) and edge[id].nxt is the edge added to the same list just before it.
struct Edge{
    int nxt, to;
}edge[N<<1];  // twice N: main() stores every tree edge in both directions
int head[N], edgeNum;

// Pushes the directed edge from -> to onto the front of from's list.
void addEdge(int from, int to){
    // Standard aggregate initialisation; a `(Edge){...}` compound literal is
    // only a GNU extension in C++.
    edge[++edgeNum] = Edge{head[from], to};
    head[from] = edgeNum;
}

// Per-vertex data filled by dfs(): subtree size, heavy child (0 = none), depth, parent.
int sz[N], son[N], dep[N], fa[N];

// First pass over the tree: visits the subtree of x (parent f, depth `depth`)
// and fills fa[], son[], dep[] and sz[] for every vertex in it.
void dfs(int x, int f, int depth){
    fa[x] = f;
    son[x] = 0;
    dep[x] = depth;
    sz[x] = 1;
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == f) continue;  // never walk back up to the parent
        dfs(child, x, depth+1);
        sz[x] += sz[child];
        // the first child seen, or any strictly larger one, becomes the heavy child
        if(son[x] == 0 || sz[son[x]] < sz[child])
            son[x] = child;
    }
}

// mark:  child whose subtree getData() must skip (0 = skip nothing)
// cnt[d]: number of currently counted vertices at depth d
// mx:    largest cnt[] value seen since the last reset
// mxDep: smallest depth at which that value was seen
int mark, mxDep, mx, cnt[N];

// Walks the subtree of x (except the marked subtree), adds val (+1 or -1) to the
// depth counter of every vertex met and keeps mx / mxDep up to date.
void getData(int x, int val){
    const int d = dep[x];
    int &here = cnt[d];
    here += val;
    if(here > mx){
        mx = here;
        mxDep = d;
    }
    // on a tie prefer the shallower level
    if(here == mx && d < mxDep){
        mxDep = d;
    }
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == fa[x] || child == mark) continue;
        getData(child, val);
    }
}
// Computes ans[] for every vertex in the subtree of x.
// opt == true: the depth counters must be cleared again once x is finished.
void dsu(int x, bool opt){
    // light children first; each of them cleans up after itself
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child != fa[x] && child != son[x])
            dsu(child, true);
    }
    // heavy child last; its contribution is kept
    if(son[x]){
        dsu(son[x], false);
        mark = son[x];
    }
    getData(x, 1);  // add x and its light subtrees (heavy subtree skipped via mark)
    ans[x] = mxDep - dep[x];  // depth of the fullest level, measured from x
    mark = 0;
    if(opt){
        getData(x, -1);
        mxDep = 0;
        mx = 0;
    }
}

// Reads n and the n-1 edges, runs both passes from vertex 1 and prints the
// answer of every vertex on its own line.
int main(){
    scanf("%d", &n);
    for(int k = 1; k < n; ++k){
        int u, v;
        scanf("%d%d", &u, &v);
        // undirected edge: store both directions
        addEdge(u, v);
        addEdge(v, u);
    }

    dfs(1, 0, 1);  // the root has depth 1
    dsu(1, true);

    for(int v = 1; v <= n; ++v){
        printf("%d\n", ans[v]);
    }
    return 0;
}

Codeforces 375D. Tree and Queries

题目链接

开全局变量——num[],d[].

num[c] 记录颜色 $c$ 的数量,d[k] 表示数量大于等于 $k$ 的颜色种数。

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<bits/stdc++.h>

using namespace std;

const int N = 100005;  // capacity of the per-vertex / per-query arrays

int n, m, col[N];  // vertex count, query count, col[v]: colour of vertex v

// One query: vertex v, threshold k, and the computed answer.
struct Query{
    int v;
    int k;
    int ans;
};
Query q[N];

vector<int> a[N];  // a[v]: ids of the queries asked about vertex v

// Linked adjacency lists: head[v] is the id of the newest edge leaving v
// (0 = none) and edge[id].nxt is the edge added to the same list just before it.
struct Edge{
    int nxt, to;
}edge[N<<1];  // twice N: main() stores every tree edge in both directions
int head[N], edgeNum;

// Pushes the directed edge from -> to onto the front of from's list.
void addEdge(int from, int to){
    // Standard aggregate initialisation; a `(Edge){...}` compound literal is
    // only a GNU extension in C++.
    edge[++edgeNum] = Edge{head[from], to};
    head[from] = edgeNum;
}

// Per-vertex data filled by dfs(): parent, subtree size, heavy child (0 = none).
int fa[N], sz[N], son[N];

// First pass over the tree: visits the subtree of x (whose parent is f) and
// fills fa[], sz[] and son[] for every vertex in it.
void dfs(int x, int f){
    fa[x] = f;
    sz[x] = 1;
    son[x] = 0;
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == f) continue;  // never walk back up to the parent
        dfs(child, x);
        sz[x] += sz[child];
        // the first child seen, or any strictly larger one, becomes the heavy child
        if(son[x] == 0 || sz[son[x]] < sz[child])
            son[x] = child;
    }
}

// Child whose subtree getData() must skip (0 = skip nothing).
int mark;
// num[c]: occurrences of colour c among the counted vertices;
// d[k]:   number of colours whose count is at least k.
int d[N], num[N];

// Walks the subtree of x (except the marked subtree) and adds (kind == 1) or
// removes (kind == -1) one occurrence of every vertex's colour.
void getData(int x, int kind){
    int &have = num[col[x]];
    // when removing, the colour leaves the bucket of its old count ...
    if(kind == -1) --d[have];
    have += kind;
    // ... when adding, it enters the bucket of its new count
    if(kind == 1) ++d[have];
    for(int e = head[x]; e; e = edge[e].nxt){
        const int child = edge[e].to;
        if(child == fa[x] || child == mark) continue;
        getData(child, kind);
    }
}
void dsu(int x, bool opt){
for(int i = head[x]; i; i = edge[i].nxt){
if(edge[i].to == fa[x] || edge[i].to == son[x]) continue;
dsu(edge[i].to, true);
}
if(son[x]) dsu(son[x], false), mark = son[x];
getData(x, 1);
mark = 0;
for(auto id: a[x])
q[id].ans = d[q[id].k];
if(opt)
getData(x, -1);
}

// Reads n, m, the vertex colours, the n-1 edges and the m queries, runs both
// passes from vertex 1 and prints one answer per line.
int main(){
    scanf("%d%d", &n, &m);
    for(int v = 1; v <= n; ++v){
        scanf("%d", &col[v]);
    }
    for(int k = 1; k < n; ++k){
        int u, v;
        scanf("%d%d", &u, &v);
        // undirected edge: store both directions
        addEdge(u, v);
        addEdge(v, u);
    }
    for(int id = 1; id <= m; ++id){
        scanf("%d%d", &q[id].v, &q[id].k);
        a[q[id].v].emplace_back(id);  // handle the query when its vertex is solved
    }

    dfs(1, 0);
    dsu(1, true);

    for(int id = 1; id <= m; ++id){
        printf("%d\n", q[id].ans);
    }
    return 0;
}
作者

xyfJASON

发布于

2020-04-18

更新于

2021-02-25

许可协议

评论