[2019 ICPC 南昌K]Tree

题目链接

Solution

DSU on tree.

DSU on tree 就是加上了启发式合并的暴力,只需要想好怎么暴力做就行了。

考虑如何求以 $rt$ 为 $\text{lca}$ 的满足条件的点对 $(x,y)$ 的数量,即 $v_x+v_y=2\cdot v_{rt}$ 及 $dep_x+dep_y-2\cdot dep_{rt}\leqslant k$. 考虑对每一个值 $v$ 都用数据结构维护值为 $v$ 的点的深度,那么依次扫描 $rt$ 的子树,枚举到 $x$ 时,数据结构里已经存入了之前子树的信息,只需要在维护 $2\cdot v_{rt}-v_x$ 的数据结构里找到满足 $d\leqslant k+2\cdot dep_{rt}-dep_x$ 的深度 $d$ 有多少个,然后扫描完这棵子树后将它上面的信息存入数据结构。可以对每个值开动态开点线段树完成这个操作,线段树的下标表示深度。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include<bits/stdc++.h>
using namespace std;

template<typename T>void read(T&x){x=0;int fl=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')
fl=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}x*=fl;}
template<typename T,typename...Args>inline void read(T&t,Args&...args){read(t);read(args...);}

typedef long long LL;
typedef vector<int> vi;
typedef pair<int, int> pii;
#define mp(x, y) make_pair(x, y)
#define pb(x) emplace_back(x)

const int N = 100005;

int n, v[N], k, root[N];
LL ans;
vector<int> edge[N];

int sz[N], son[N], dep[N];
void dfs(int x, int d){
sz[x] = 1, son[x] = 0, dep[x] = d;
for(auto &to : edge[x]){
dfs(to, d+1);
sz[x] += sz[to];
if(!son[x] || sz[to] > sz[son[x]]) son[x] = to;
}
}

int tot;
struct segTree{
int lson, rson, cnt;
}tr[N*200];
inline void pushup(int rt){
tr[rt].cnt = tr[tr[rt].lson].cnt + tr[tr[rt].rson].cnt;
}
void add(int rt, int l, int r, int p, int val){
if(l == r){ tr[rt].cnt += val; return; }
int mid = (l + r) >> 1;
if(p <= mid){
if(!tr[rt].lson) tr[rt].lson = ++tot;
add(tr[rt].lson, l, mid, p, val);
}
else{
if(!tr[rt].rson) tr[rt].rson = ++tot;
add(tr[rt].rson, mid+1, r, p, val);
}
pushup(rt);
}
LL query(int rt, int l, int r, int L, int R){
if(rt == 0) return 0;
if(L > R) return 0;
if(l == L && r == R) return tr[rt].cnt;
int mid = (l + r) >> 1;
if(R <= mid) return query(tr[rt].lson, l, mid, L, R);
else if(L > mid) return query(tr[rt].rson, mid+1, r, L, R);
else return query(tr[rt].lson, l, mid, L, mid) + query(tr[rt].rson, mid+1, r, mid+1, R);
}

void getAns(int x, int rtx){
if(2*v[rtx]-v[x] >= 0 && 2*v[rtx]-v[x] <= n)
ans += query(root[2*v[rtx]-v[x]], 1, n, 1, min(n, k + 2 * dep[rtx] - dep[x]));
// dep[x] + d - 2 * dep[rtx] <= k
for(auto &to : edge[x]) getAns(to, rtx);
}
void getData(int x){
add(root[v[x]], 1, n, dep[x], 1);
for(auto &to : edge[x]) getData(to);
}
void delData(int x){
add(root[v[x]], 1, n, dep[x], -1);
for(auto &to : edge[x]) delData(to);
}
void dsu(int x, bool opt){
for(auto &to : edge[x]){
if(to == son[x]) continue;
dsu(to, true);
}
if(son[x]) dsu(son[x], false);
for(auto &to : edge[x]){
if(to == son[x]) continue;
getAns(to, x);
getData(to);
}
if(opt){
for(auto &to : edge[x]) delData(to);
}
else add(root[v[x]], 1, n, dep[x], 1);
}

int main(){
read(n, k);
for(int i = 1; i <= n; i++) read(v[i]);
for(int i = 2; i <= n; i++){
int f; read(f);
edge[f].emplace_back(i);
}
dfs(1, 1);
for(int i = 0; i <= n; i++) root[i] = ++tot;
dsu(1, true);
printf("%lld\n", ans * 2);
return 0;
}
作者

xyfJASON

发布于

2021-05-01

更新于

2021-05-05

许可协议

评论