关于扫描线算法中线段树标记的理解

关于扫描线算法中线段树标记的理解

由于扫描线算法的询问只需询问线段树根节点信息,我们没有必要写 pushdown 函数,但同时也引入了一些其他问题。本文记录我做题过程中的理解,水平有限,如有错误,烦请指正。

算法简述

扫描线算法用于解决矩形面积并的问题,当然其思想也可用于其他形状的图形和其他问题。

在矩形面积并问题中:

用一条平行于 $x$ 轴的直线自下而上地扫描平面,每扫描到一条矩形的边,就计算这条边和上一条边之间的面积。线段树在这个过程中维护当前扫描线被覆盖的长度。具体地,每扫描到一个矩形的下边,其对应区间就加 $1$;每扫描到一个矩形的上边,其对应区间减 $1$,询问时统计整个区间非零的长度。

为完成上述操作,每个线段树节点储存一个标记 cnt 和一个信息 lengthcnt 标记表示该节点对应区间被覆盖了几次; length 表示该节点对应区间的被覆盖长度,也即非零长度。

pushdowncnt 标记的影响

由于没有 pushdown 的存在,这里的 cnt 标记比有 pushdown 的线段树的标记的情形复杂一些。

想一想 pushdown 操作的线段树,我们如果要对某节点进行操作,会先将其祖先节点的标记一路下放下来,所以我们在操作这个节点时,其祖先节点时没有标记的;随后我们就可以顺理成章地把信息一路 pushup 回去——祖先节点都没有标记了,它的信息只依赖于子节点的信息。

而对于没有 pushdown 操作的线段树,某节点的信息不能只靠子节点的信息决定,还由其自身的标记决定。如下图所示,本题中,线段树某节点的 cnt 加 $1$ 后,由于没有 pushdown,其子节点完全不知道自己已经被覆盖了,子节点的信息——比如 length ——仍然是之前的状态。用这个子节点的 length 去更新父节点的 length,显然是错误的。但是结合父节点的 cnt 标记,我们可以知道父节点整个都被覆盖了,于是我们就可以正确地把父节点的 length 赋值为整个区间长度。当然此时此刻,子节点的存储信息是错误信息——但这又有什么关系呢?反正我们只询问根节点信息,只需要保证父节点信息的正确性即可。(自然,这个“父节点”也是某节点的“子节点”,它的信息也可能是错误的;不过,根节点不作为任何节点的子节点,信息一定正确。)

上一段阐述的是自身标记的重要性,但是事实上子节点的信息也是不可缺少的。考虑我们把某节点的两个子节点分两次覆盖住,那么这个节点被覆盖住了,但是 cnt 却为 0。更一般的,cnt == 0 的节点完全有可能整个都被覆盖了,甚至覆盖了好几次,甚至分段覆盖次数还不一样。所以,当父节点的 cnt 标记为 0 时,它的信息需要用其子节点的信息来更新(如下图所示)。当然,这时子节点的信息是局部(相对于父节点)正确的(尽管可能是整体错误的)。

综上所述,cnt == 0 时,length 为左右子节点的 length 之和;而 cnt > 0 时,length 是该节点的对应区间长度。

于是,我们的 pushup 应运而生:

1
2
3
4
5
6
7
inline void pushup(int id){
if(tr[id].cnt > 0) tr[id].length = x[tr[id].r + 1] - x[tr[id].l];
else{
if(tr[id].l == tr[id].r) tr[id].length = 0;
else tr[id].length = tr[lid].length + tr[rid].length;
}
}

同时,使用pushup 的时机也要微调——在打了标记之后也需要立刻 pushup

矩形面积并的模板如下:

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<cstdio>
#include<algorithm>

using namespace std;

const int N = 200005;

int n, xid;
double tx1, ty1, tx2, ty2, x[N], ans;
struct ScanLine{
double x1, x2, y;
int k; // k == 1 or -1
int dx1, dx2; // after discretization
bool operator < (const ScanLine &A) const{
return y == A.y ? k > A.k : y < A.y;
}
}a[N];

inline void disc(){
sort(x+1, x+xid+1);
xid = unique(x+1, x+xid+1) - (x+1);
for(int i = 1; i <= n; i++){
a[i].dx1 = lower_bound(x+1, x+xid+1, a[i].x1) - x;
a[i].dx2 = lower_bound(x+1, x+xid+1, a[i].x2) - x;
}
}

struct SegTree{
int l, r, cnt;
double length;
}tr[N<<2];
#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l + tr[id].r) >> 1)
#define len(id) (tr[id].r - tr[id].l + 1)
inline void pushup(int id){
if(tr[id].cnt > 0) tr[id].length = x[tr[id].r + 1] - x[tr[id].l];
else{
if(tr[id].l == tr[id].r) tr[id].length = 0;
else tr[id].length = tr[lid].length + tr[rid].length;
}
}
void build(int id, int l, int r){
tr[id].l = l, tr[id].r = r;
tr[id].cnt = 0, tr[id].length = 0;
if(tr[id].l == tr[id].r) return;
build(lid, l, mid);
build(rid, mid+1, r);
pushup(id);
}
void add(int id, int l, int r, int val){
if(tr[id].l == l && tr[id].r == r){
tr[id].cnt += val;
pushup(id);
return;
}
if(r <= mid) add(lid, l, r, val);
else if(l > mid) add(rid, l, r, val);
else add(lid, l, mid, val), add(rid, mid+1, r, val);
pushup(id);
}

int main(){
scanf("%d", &n);
for(int i = 1; i <= n; i++){
scanf("%lf%lf%lf%lf", &tx1, &ty1, &tx2, &ty2);
a[i] = (ScanLine){tx1, tx2, ty1, 1};
a[i+n] = (ScanLine){tx1, tx2, ty2, -1};
x[++xid] = tx1, x[++xid] = tx2;
}
n <<= 1;
disc();
sort(a+1, a+n+1);
build(1, 1, xid-1);
for(int i = 1; i < n; i++){
add(1, a[i].dx1, a[i].dx2 - 1, a[i].k);
ans += tr[1].length * (a[i+1].y - a[i].y);
}
printf("%.2f\n", ans);
return 0;
}

进一步

与其说 cnt 标记表示“该节点对应区间被覆盖了几次”,不如说:

  • cnt == 0 表示我们对该节点对应区间一无所知——它可能没被覆盖,可能部分覆盖,可能全被覆盖,甚至覆盖次数还不同……我们要知道这个节点的 length 信息,只能从其子节点 pushup 上来;
  • cnt == 1 表示这个节点对应区间被覆盖至少 $1$ 次——当然也可能覆盖了多次,或者分成好几段覆盖了不同次——但总之被完全覆盖了,length 就是对应区间长度。同时它的子节点并不知道它的覆盖情况。

这样的理解是有好处的,比如说我们遇到了加强版的题目:hdu 1255 覆盖的面积

题目要求求出被矩形覆盖过至少两次的区域面积。我们在线段树节点中维护一个标记 cnt,两个信息 length1,length2,分别表示区间被覆盖至少 $1$ 次的长度和被覆盖至少 $2$ 次的长度。于是:

  • cnt == 0 表示我们对该节点对应区间一无所知,它的 length1,length2 信息由子节点决定;
  • cnt == 1 表示这个节点对应区间被覆盖至少 $1$ 次,它的 length1 就是区间长度,length2 是左右子节点的 length1 之和——该节点对应区间整个已经被覆盖至少一次了(子节点并不知道),只需要子节点再覆盖一次就好;
  • cnt >= 2 表示这个节点对应区间被覆盖至少 $2$ 次,它的 length1,length2 都是区间长度。

pushup 如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
inline void pushup(int id){
if(tr[id].cnt >= 2){ // interval is covered at least twice
tr[id].length2 = x[tr[id].r + 1] - x[tr[id].l];
tr[id].length1 = x[tr[id].r + 1] - x[tr[id].l];
}
else if(tr[id].cnt == 1){ // interval is covered at least once
if(tr[id].l == tr[id].r) tr[id].length2 = 0;
else tr[id].length2 = tr[lid].length1 + tr[rid].length1;
tr[id].length1 = x[tr[id].r + 1] - x[tr[id].l];
}
else{ // do not know the infomation of this interval
if(tr[id].l == tr[id].r) tr[id].length1 = tr[id].length2 = 0;
else{
tr[id].length1 = tr[lid].length1 + tr[rid].length1;
tr[id].length2 = tr[lid].length2 + tr[rid].length2;
}
}
}

AC 代码如下:

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include<cstdio>
#include<cstring>
#include<algorithm>

using namespace std;

const int N = 2005;

int T, n, xid;
double tx1, ty1, tx2, ty2, x[N], ans;
struct ScanLine{
double x1, x2, y;
int k; // k == 1 or -1
int dx1, dx2; // after discretization
bool operator < (const ScanLine &A) const{
return y == A.y ? k > A.k : y < A.y;
}
}a[N];

inline void disc(){
sort(x+1, x+xid+1);
xid = unique(x+1, x+xid+1) - (x+1);
for(int i = 1; i <= n; i++){
a[i].dx1 = lower_bound(x+1, x+xid+1, a[i].x1) - x;
a[i].dx2 = lower_bound(x+1, x+xid+1, a[i].x2) - x;
}
}

struct SegTree{
int l, r, cnt;
double length1, length2;
}tr[N<<2];
#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l + tr[id].r) >> 1)
#define len(id) (tr[id].r - tr[id].l + 1)
inline void pushup(int id){
if(tr[id].cnt >= 2){ // interval is covered at least twice
tr[id].length2 = x[tr[id].r + 1] - x[tr[id].l];
tr[id].length1 = x[tr[id].r + 1] - x[tr[id].l];
}
else if(tr[id].cnt == 1){ // interval is covered at least once
if(tr[id].l == tr[id].r) tr[id].length2 = 0;
else tr[id].length2 = tr[lid].length1 + tr[rid].length1;
tr[id].length1 = x[tr[id].r + 1] - x[tr[id].l];
}
else{ // do not know the infomation of this interval
if(tr[id].l == tr[id].r) tr[id].length1 = tr[id].length2 = 0;
else{
tr[id].length1 = tr[lid].length1 + tr[rid].length1;
tr[id].length2 = tr[lid].length2 + tr[rid].length2;
}
}
}
void build(int id, int l, int r){
tr[id].l = l, tr[id].r = r;
tr[id].cnt = 0, tr[id].length1 = tr[id].length2 = 0;
if(tr[id].l == tr[id].r) return;
build(lid, l, mid);
build(rid, mid+1, r);
pushup(id);
}
void add(int id, int l, int r, int val){
if(tr[id].l == l && tr[id].r == r){
tr[id].cnt += val;
pushup(id);
return;
}
if(r <= mid) add(lid, l, r, val);
else if(l > mid) add(rid, l, r, val);
else add(lid, l, mid, val), add(rid, mid+1, r, val);
pushup(id);
}

int main(){
scanf("%d", &T);
while(T--){
scanf("%d", &n);
xid = 0;
ans = 0;
memset(tr, 0, sizeof tr);
for(int i = 1; i <= n; i++){
scanf("%lf%lf%lf%lf", &tx1, &ty1, &tx2, &ty2);
a[i] = (ScanLine){tx1, tx2, ty1, 1};
a[i+n] = (ScanLine){tx1, tx2, ty2, -1};
x[++xid] = tx1, x[++xid] = tx2;
}
n <<= 1;
disc();
sort(a+1, a+n+1);
build(1, 1, xid-1);
for(int i = 1; i < n; i++){
add(1, a[i].dx1, a[i].dx2 - 1, a[i].k);
ans += tr[1].length2 * (a[i+1].y - a[i].y);
}
printf("%.2f\n", ans);
}
return 0;
}

再加强

hdu 3642 Get the Treasury

三维空间,首先我们循环 $z$ 坐标,就可以转化成平面上求被覆盖至少三次的矩形面积。有了上述理解,就可以容易地写出代码了:

>folded
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include<cstdio>
#include<cstring>
#include<algorithm>

using namespace std;

typedef long long LL;

const int N = 2005;

int T, n, xid, aid, zid;
LL x[N], ans, z[N];
struct Node{
LL x1, y1, z1, x2, y2, z2;
}node[N];
struct ScanLine{
LL x1, x2, y;
int k; // k == 1 or -1
int dx1, dx2; // after discretization
bool operator < (const ScanLine &A) const{
return y == A.y ? k > A.k : y < A.y;
}
}a[N];

inline void disc(){
sort(x+1, x+xid+1);
xid = unique(x+1, x+xid+1) - (x+1);
for(int i = 1; i <= aid; i++){
a[i].dx1 = lower_bound(x+1, x+xid+1, a[i].x1) - x;
a[i].dx2 = lower_bound(x+1, x+xid+1, a[i].x2) - x;
}
}

struct SegTree{
int l, r, cnt;
LL length1, length2, length3;
}tr[N<<2];
#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l + tr[id].r) >> 1)
#define len(id) (tr[id].r - tr[id].l + 1)
inline void pushup(int id){
if(tr[id].cnt >= 3)
tr[id].length3 = tr[id].length2 = tr[id].length1 = x[tr[id].r + 1] - x[tr[id].l];
else if(tr[id].cnt == 2){
if(tr[id].l == tr[id].r) tr[id].length3 = 0;
tr[id].length3 = tr[lid].length1 + tr[rid].length1;
tr[id].length2 = tr[id].length1 = x[tr[id].r + 1] - x[tr[id].l];
}
else if(tr[id].cnt == 1){
if(tr[id].l == tr[id].r) tr[id].length3 = tr[id].length2 = 0;
else{
tr[id].length3 = tr[lid].length2 + tr[rid].length2;
tr[id].length2 = tr[lid].length1 + tr[rid].length1;
}
tr[id].length1 = x[tr[id].r + 1] - x[tr[id].l];
}
else{
if(tr[id].l == tr[id].r) tr[id].length3 = tr[id].length2 = tr[id].length1 = 0;
else{
tr[id].length3 = tr[lid].length3 + tr[rid].length3;
tr[id].length2 = tr[lid].length2 + tr[rid].length2;
tr[id].length1 = tr[lid].length1 + tr[rid].length1;
}
}
}
void build(int id, int l, int r){
tr[id].l = l, tr[id].r = r;
tr[id].cnt = 0, tr[id].length1 = tr[id].length2 = 0;
if(tr[id].l == tr[id].r) return;
build(lid, l, mid);
build(rid, mid+1, r);
pushup(id);
}
void add(int id, int l, int r, int val){
if(tr[id].l == l && tr[id].r == r){
tr[id].cnt += val;
pushup(id);
return;
}
if(r <= mid) add(lid, l, r, val);
else if(l > mid) add(rid, l, r, val);
else add(lid, l, mid, val), add(rid, mid+1, r, val);
pushup(id);
}

inline void init(){
xid = aid = 0;
memset(tr, 0, sizeof tr);
memset(a, 0, sizeof a);
memset(x, 0, sizeof x);
memset(a, 0, sizeof a);
}

int main(){
scanf("%d", &T);
for(int CASES = 1; CASES <= T; CASES++){
ans = 0;
zid = 0;
scanf("%d", &n);
for(int i = 1; i <= n; i++){
scanf("%lld%lld%lld%lld%lld%lld", &node[i].x1, &node[i].y1, &node[i].z1, &node[i].x2, &node[i].y2, &node[i].z2);
z[++zid] = node[i].z1, z[++zid] = node[i].z2;
}
sort(z+1, z+zid+1);
zid = unique(z+1, z+zid+1) - (z+1);
for(int j = 1; j < zid; j++){
init();
for(int i = 1; i <= n; i++){
if(node[i].z1 <= z[j] && node[i].z2 > z[j]){
x[++xid] = node[i].x1, x[++xid] = node[i].x2;
a[++aid] = (ScanLine){node[i].x1, node[i].x2, node[i].y1, 1};
a[++aid] = (ScanLine){node[i].x1, node[i].x2, node[i].y2, -1};
}
}
disc();
sort(a+1, a+aid+1);
build(1, 1, xid-1);
LL res = 0;
for(int i = 1; i < aid; i++){
add(1, a[i].dx1, a[i].dx2 - 1, a[i].k);
res += tr[1].length3 * (a[i+1].y - a[i].y);
}
ans += res * (z[j+1] - z[j]);
}
printf("Case %d: %lld\n", CASES, ans);
}
return 0;
}
作者

xyfJASON

发布于

2020-02-12

更新于

2021-02-25

许可协议

评论