Yuhang's Blog

树链剖分

2019-11-12 Coding

这个题解把树链剖分的模板进行了结构体封装,并说说对树链剖分的一些理解。

让我一步一步讲。

我们知道数组是什么:一个节点含有一个数字,许多结点线性地列成一排。

当我们对某一段子数组的求和和修改操作感兴趣的时候,我们就想出了线段树。线段树虽然说是一颗树,但终究是为“线段”——线性数组服务的。

有一天,所有的节点都不再线性地排列,而构成了一颗树。当我们又对树上某些节点集合的求和和修改感兴趣的时候,我们希望能利用已有的经验,我们希望把问题化归成从前的情形。是的,我们想把这棵树变成一条链,变成一个数组。

似乎仅仅是把树变成一个数组并不是什么困难的问题。

其实树上节点的权值本来就储存在数组里面。当遍历我们所需要的节点集合的时候,把对应的权值一个一个加起来,这就是最简单,最暴力的方法。显然,太慢。问题在于,我们没有很好地利用我们已有的快速的数据结构:线段树。

为了利用线段树,我们需要尽量把需要求和的节点连续地排在一起。所以,我们需要重新考虑我们的需求,我们需要考虑我们想去求和和修改的树上节点的集合具有什么样的特征。请看题目要求,有两种需要的节点集合:

  1. 两个节点u, v最短路径上的所有节点
  2. 某个节点u的子树

请想一下,如果只有第2个要求该多好啊。直接对树上的节点进行深搜排序即可。显然,深搜排序的序列之中,u子树上节点的位置都连在u的位置后面。结合线段树的功能,我们完全可以解决第2个需求。

但是我们需要仔细考虑需求1.

如果直接利用深搜序,由于对子节点遍历顺序的随意性,在深搜过程中向上回溯的次数可能很多,而这就导致在深搜序列中可能会有很多“断点”:某个点在序列中的下一个点不是它的子节点。

而如果要完成需求1,我们当然要(显性或隐性地)找到u和v的LCA,然后对LCA分别到u和v的两条自顶向下的链进行求和。我们谈起自顶向下的链的时候应该是高兴的,因为深搜序确实也是自顶向下依次排开的。问题是,刚刚说了,我们可能会有太多的断点,那样一来我们需要更多次数对线段树进行查询,并且每次查询的区间长度变小。显然,不划算。

所以我们需要确定一个更加合理的遍历子节点的顺序。其实这个时候会有不同的想法。你可能会觉得,我们可以优先遍历叶子深度最大的那个子节点。好,但不够好。有一个更好的想法:每次优先遍历子树大小最大的那个子节点。我们现在可以把这个子节点叫做重子节点,把其他节点叫做轻子节点。

为什么这个想法更好?因为我们可以证明,像这样优先选取,从一个点走到自己的叶子最多只需经过log子树大小的个数的链。这是因为,最坏的情形是,每一次都走到一个轻子节点上去。但是,走到轻子节点,也意味着子树的大小至少会缩小一半(否则这棵子树一定会变成最大的子树)。

好了。我们已经找到了想要的把树剖分成链的方法。明白了这些,实现起来就不再困难。下面我们需要跑两次深搜:

  • 第一次是建树的过程,
    • 我们当然要首先把每个节点的父亲fa[N]标出来,
    • 为了确定重子节点是谁,需要首先记录所有节点子树的大小siz[N]
    • 把重子节点记录在hson[N]内,
    • 再记录一下每个节点的深度deep[N],后面会用得到。
  • 第二次把顺序排出来。
    • dfn[N]就是我们的重子节点优先的深搜序,
    • rnk[N]可以看成是dfn[N]的反函数,也就是记录某个节点被排在哪个位置了,
    • 在我们的深搜序的序列里,有若干条自顶向下的链(所谓的重链)。显然我们的所有操作都应该在某一个链中进行,所以,我们当然要记录一下序列中每个点所在的那个链条是从谁开始的。这就是top[N]

成功把树剖分成链,就可以把它扔进线段树中了。

下面的问题是怎么利用线段树把我们的问题解决了。

先说刚刚的第2个需求。这依然是比较容易的。像刚刚说的,如果要对u的子树上所有节点进行查询和修改,我们发现这些节点在对应的深搜序列中是以u为首项的,项数是siz[u]的一个子序列,所以可以很方便地调用线段树的功能了。

再考虑第1个需求。刚刚我们也提及了解决它的思路。我们寻找u和v的LCA的过程,其实也就是让u或者v向上跳跃,最后会和的过程。

有一种很简单的情形,我们发现top[u] == top[v],这就表明u和v在同一段链中,那我们就可以直接进行求和或修改了。

但事情不会总是这么凑巧,也许u或者v需要先各自进行几次跳跃,才能出现上面的这种情况(这种情况一定会出现,因为这棵树只有一个根节点嘛)。那么如何处理各自进行的跳跃?想利用我们已经存储的信息,也想最大限度地利用线段树的功能,我们当然希望直接算出u或者v到它的top的路径上的所有和,然后u或者v跳跃的时候直接跳到它的top的父亲的位置。但危险的是,如果某一次不慎跳到了LCA上面的某个节点,答案就会出错。如何避免?可以注意到一点:top[u]top[v]不可能同时在LCA上面。否则,u所在的那条链和v所在的那条链就会有重合的元素(LCA),而考虑到我们构造链的方法,这是不可能的。这也就是说,top[u]top[v]中深度较深的那个,一定不会在LCA的上面。明白了这些,如何进行跳跃就十分明了了:每次选择deep[top[u/v]]较大的那个,跳跃到fa[top[u/v]]处即可。

剩下的,就是具体实现的一些处理了。请看代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#include<bits/stdc++.h>
#define N 112345
using namespace std;
inline int read(){
int x=0; int sign=1; register char c=getchar();
while(c>'9'||c<'0'){if(c=='-')sign=-1; c=getchar();}
while(c>='0'&&c<='9'){x=(x<<3)+(x<<1)+c-'0'; c=getchar();}
return x*sign;
}
long long P; int n, m, r;

struct TREE{
vector<int> g[N]; int wght[N];
int fa[N], deep[N], siz[N], hson[N];
int top[N], dfn[N], rnk[N], tot;

void ini(){
for(int i = 1; i <= n; i++) wght[i] = read();
for(int i = 1; i <= n-1; i++){
int a = read(), b = read();
g[a].push_back(b);
g[b].push_back(a);
}
}

void build_tree(int u, int dep){
siz[u] = 1;
deep[u] = dep;
for(unsigned int i = 0; i < g[u].size(); ++i){
int v = g[u][i];
if (v == fa[u]) continue;
fa[v] = u;
build_tree(v, dep + 1);
siz[u] += siz[v];
if (siz[v] > siz[hson[u]]) hson[u] = v;
}
}

void tree_decom(int u, int t){
top[u] = t;
dfn[u] = ++tot;
rnk[tot] = u;
if (hson[u]){
tree_decom(hson[u], t);
for(unsigned int i = 0; i < g[u].size(); i++){
int v = g[u][i];
if (v == fa[u] || v == hson[u]) continue;
tree_decom(v, v);
}
}
}
};

struct NODE{
int l, r; long long sum, add;
};


struct SEGTREE{
NODE SegTree[N << 2];
long long a[N];

#define lft(u) SegTree[u].l
#define rgt(u) SegTree[u].r
#define len(u) (rgt(u)-lft(u)+1)
#define sum(p) SegTree[p].sum
#define add(p) SegTree[p].add
#define mid(p) ((lft(p)+rgt(p))>>1)
#define lson(p) (p<<1)
#define rson(p) (p<<1|1)

void build(int u, int l, int r){
lft(u) = l; rgt(u) = r; add(u) = 0;
if (l == r){
sum(u) = a[l];
} else {
build(lson(u), l, mid(u));
build(rson(u), mid(u)+1, r);
sum(u) = sum(lson(u)) + sum(rson(u));
sum(u) %= P;
}
}

void spread(int u){
if (!add(u)) return;
sum(lson(u)) += add(u) * len(lson(u)); sum(lson(u)) %= P;
sum(rson(u)) += add(u) * len(rson(u)); sum(rson(u)) %= P;
add(lson(u)) += add(u); add(lson(u)) %= P;
add(rson(u)) += add(u); add(rson(u)) %= P;
add(u) = 0;
}

void change(int u, int l, int r, long long d){
if (l <= lft(u) && rgt(u) <= r) {
sum(u) += d * len(u); sum(u) %= P;
add(u) += d; add(u) %= P;
} else {
spread(u);
if (l <= mid(u)) change(lson(u), l, r, d);
if (mid(u) + 1 <= r) change(rson(u), l, r, d);
sum(u) = sum(rson(u)) + sum(lson(u)); sum(u) %= P;
}
}

long long query(int u, int l, int r){
if (l <= lft(u) && rgt(u) <= r){
return sum(u) % P;
} else {
spread(u);
long long ans = 0;
if (l <= mid(u)) {ans += query(lson(u), l, r); ans %= P;}
if (mid(u) + 1 <= r) {ans += query(rson(u), l, r); ans %= P;}
return ans % P;
}
}
};

struct TREECHAIN{
SEGTREE Sgt;
TREE Tree;

#define fa(u) Tree.fa[u]
#define deep(u) Tree.deep[u]
#define top(u) Tree.top[u]
#define dfn(u) Tree.dfn[u]
#define rnk(u) Tree.rnk[u]
#define btm(u) (Tree.dfn[u]+Tree.siz[u]-1)
#define wght(u) Tree.wght[u]
#define rnk(u) Tree.rnk[u]

void ini(){
Tree.ini();
Tree.build_tree(r, 1);
Tree.tree_decom(r, r);
for(int i = 1; i <= n; i++){
Sgt.a[i] = wght(rnk(i)) % P;
}
Sgt.build(1, 1, n);
}

void tree_path_change(int u, int v, long long d){
while(top(u) != top(v)){
if (deep(top(u)) < deep(top(v))) swap(u, v);
Sgt.change(1, dfn(top(u)), dfn(u), d);
u = fa(top(u));
}
if (dfn(u) > dfn(v)) swap(u, v);
Sgt.change(1, dfn(u), dfn(v), d);
}

long long tree_path_sum(int u, int v){
long long tot = 0;
while(top(u) != top(v)){
if (deep(top(u)) < deep(top(v))) swap(u, v);
tot += Sgt.query(1, dfn(top(u)), dfn(u)); tot %= P;
u = fa(top(u));
}
if (dfn(u) > dfn(v)) swap(u, v);
tot += Sgt.query(1, dfn(u), dfn(v)); tot %= P;
return tot;
}

void subtree_change(int u, long long d){
Sgt.change(1, dfn(u), btm(u), d);
}

long long subtree_sum(int u){
return Sgt.query(1, dfn(u), btm(u));
}

void work(){
while(m--){
int mode = read();

if (mode == 1){
int u = read(), v = read(); long long d = read();
tree_path_change(u, v, d);
}

if (mode == 2){
int u = read(), v = read();
printf("%lld\n", tree_path_sum(u, v));
}

if (mode == 3){
int u = read(); long long d = read();
subtree_change(u, d);
}

if (mode == 4){
int u = read();
printf("%lld\n", subtree_sum(u));
}
}
}
}Tc;


int main(){
n = read(), m = read(), r = read(), P = read();
Tc.ini();
Tc.work();
}
This article was last updated on days ago, and the information described in the article may have changed.