Yuhang's Blog

莫队算法、带修莫队和树上莫队

2021-08-31 Coding

  1. 1. 莫队算法
    1. 1.1. 理论
    2. 1.2. 例题
    3. 1.3. 实现
  2. 2. 带修莫队
    1. 2.1. 理论
  3. 3. 树上莫队

莫队算法是基于分块思想的区间查询离线算法。带修莫队支持单点修改。树上莫队则是使用欧拉序将树序列化后,套用莫队算法。

莫队算法

理论

  • 离线算法;
  • 不支持修改;
  • 区间查询。

如果可以将区间[l, r]的答案容易地(即以常数时间复杂度)推到区间[l - 1, r], [l, r + 1], [l + 1, r], [l, r - 1]的答案,就可以使用莫队算法。(前两个区间分别是在左端和右端增加一个元素;后两个区间分别是在左端和右端删去一个元素。)

考虑单纯的暴力:我们一般可以容易得到得到空区间(常常用l = 1, r = 0表示)的答案。此视为第0次询问。那么对于所有询问,我们只要基于上次询问的区间,不断地移动左右区间的端点(体现为不断在区间的左侧或右侧进行增加或删去元素的操作),我们就可以由上次询问的答案推到当前询问的答案。

当然,此算法的复杂度是过高的。如果数组大小是$N$,询问次数是$M$,那么每次询问最坏情况下要左右端点分别都要进行$O(N)$的移动,从而总体的时间复杂度是$O(MN)$。

如何优化呢?我们考虑如下这种特殊情况:假定所有的查询的区间的左端点都在长度不超过$B$的区间内,而右端点没有限制;即我们可以认为左端点之间都非常近,或者任意两个左端点互相进行转移都是容易抵达的,此时应主要考虑优化转移右侧端点所需的时间复杂度。可以想见,如果我们按右侧端点的大小对所有询问进行排序,那么处理所有询问的过程中我们用在对右侧的端点进行转移的时间就不超过$O(N)$;另外每次询问都可能有$O(B)$的转移左侧端点的时间复杂度,所有询问加起来就是$O(MB)$。从而在这种特殊情况之中,总的时间复杂度是$O(N+MB)$。

这种情况的处理就是分块的主要灵感。虽然我们的查询区间的左端点事实上分散在整个长度为$N$的区间,但我们可以人为地将整个区间分成长度约是$B$的块,共计约$N/B$块。将询问区间的左端点按所在块划分,在每个块中我们分别进行上述算法,即对询问的右侧端点进行排序后依次处理。这样,总体时间复杂度是$O(N\times N/B + M B)=O(N^2/B+MB)$,从而应取$B=N/\sqrt M$,使得我们得到$O(N \sqrt M)$复杂度的算法。加上排序的时间,总时间是$O(N\sqrt M + M \log M )$。 注意,我们这里对时间复杂度的讨论忽略了左端点需要换块的情况,但事实上这不影响时间复杂度上界。

上述基于分块的离线算法就是莫队算法。

例题

例题:P1494 [国家集训队]小Z的袜子 - 洛谷 | 计算机科学教育新生态。考虑在每个区间内,拿到两只颜色相同的袜子的概率是:
$$
\frac{\sum \binom {x_i} 2}{\binom n 2}
$$
其中$n$是区间长度,$x_i$表示出现在区间中的每种袜子的数量。为求这个概率,只需分别求分子和分母,对每个询问分别再约分即可。

对于分母来说,每次增加或删去一个元素时,维护区间长度就容易得到递推关系。对于分子来说,需要一个数组维护当前区间内所有出现的袜子的数量。在添加或删去一个元素的时候,计入这个元素对应的$x_i$的变化对答案的影响即可。

实现

首先是要将询问写在一个结构体中,方便排序:

1
2
3
4
5
6
7
8
struct Query {
int l, r, id;
} q[N];

ll B;
bool cmp(const Query &a, const Query &b) {
return a.l/B == b.l/B ? (((a.l/B) & 1) ? a.r > b.r : a.r < b.r) : a.l/B < b.l/B;
}

这里的id属性是指该询问的输入顺序。注意这里的$B$可以预设为一个常数,但更好做法是在输入数据后通过$B=N/\sqrt M$计算。

这里的cmp函数的意义是:首先判断两个询问的左端点是否在同一个块中:

  • 如果是,按右端点大小排序。但这里采取奇偶优化:对于第奇数个块,右端点从大到小排序;对于第偶数个块,右端点从小到大排序。这样可以使得换块时右端点不用一路退回到区间另一侧才得到答案。这个优化对性能的影响显著。
  • 如果不是,按左端点所在块排序。

考虑区间转移的部分。维护当前区间的答案所需的全部变量都设为全局变量(在工程上看这并不非常优雅,请注意变量名不要冲突;如果想改进可以考虑结构体封装),并定义两个函数实现对这些变量的修改:

1
2
3
4
5
6
7
8
9
ll nu, de, len, cnt[N];
void add(int x) {
nu += cnt[x]; cnt[x]++;
de += len; len++;
}
void del(int x) {
cnt[x]--; nu -= cnt[x];
len--; de -= len;
}

主函数内则设两个变量lr表示当前所在区间,并依次处理所有询问:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
sort(q + 1, q + 1 + m, cmp);

ll l = 1, r = 0;
rep(i, 1, m) {
ll ql = q[i].l, qr = q[i].r;
if (ql == qr) {
ans[q[i].id].de = 1;
ans[q[i].id].nu = 0;
continue;
}
while (l > ql) add(c[--l]);
while (r < qr) add(c[++r]);
while (l < ql) del(c[l++]);
while (r > qr) del(c[r--]);
ll g = gcd(de, nu);
ans[q[i].id].de = de / g;
ans[q[i].id].nu = nu / g;
}

rep(i, 1, m) {
printf("%lld/%lld\n", ans[i].nu, ans[i].de);
}

注意在移动左右端点的时候,需要先增加再删去,防止答案出现负数等情况;然后采取先左后右即可。四个while之顺序不可随意调换。一般将答案按q[i].id写在一个数组内,最后再将该数组输出。

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<bits/stdc++.h>
using namespace std;
#define rep(i,from,to) for(int i=(int)(from);i<=(int)(to);++i)
#define rev(i,from,to) for(int i=(int)(from);i>=(int)(to);--i)
#define For(i,to) for(int i=0;i<(int)(to);++i)
typedef long long ll;
inline ll read(){
ll x=0;bool s=1;char c=getchar();
while(c>'9'||c<'0'){if(c=='-')s=0;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<3)+(x<<1)+c-'0';c=getchar();}
return s?x:~x+1;
}
ll gcd(ll a, ll b) {
return b ? gcd(b, a % b) : a;
}

const int N = 50000 + 100;

struct Query {
int l, r, id;
} q[N];

struct Ans {
ll de, nu;
} ans[N];

ll B;
bool cmp(const Query &a, const Query &b) {
return a.l/B == b.l/B ? (((a.l/B) & 1) ? a.r > b.r : a.r < b.r) : a.l/B < b.l/B;
}

ll c[N];

ll nu, de, len, cnt[N];
void add(int x) {
nu += cnt[x]; cnt[x]++;
de += len; len++;
}
void del(int x) {
cnt[x]--; nu -= cnt[x];
len--; de -= len;
}


void solve() {
ll n = read(), m = read();
B = max(1.0, (double)n / sqrt(m));

rep(i, 1, n) c[i] = read();
rep(i, 1, m) {
q[i].l = read(); q[i].r = read();
q[i].id = i;
}
sort(q + 1, q + 1 + m, cmp);

ll l = 1, r = 0;
rep(i, 1, m) {
ll ql = q[i].l, qr = q[i].r;
if (ql == qr) {
ans[q[i].id].de = 1;
ans[q[i].id].nu = 0;
continue;
}
while (l > ql) add(c[--l]);
while (r < qr) add(c[++r]);
while (l < ql) del(c[l++]);
while (r > qr) del(c[r--]);
ll g = gcd(de, nu);
ans[q[i].id].de = de / g;
ans[q[i].id].nu = nu / g;
}

rep(i, 1, m) {
printf("%lld/%lld\n", ans[i].nu, ans[i].de);
}
}

int main() {
solve();
return 0;
}

带修莫队

理论

  • 离线算法;
  • 支持单点修改;
  • 区间查询。

对于每个询问,记录它是在第几个修改后进行的询问。那么这就相当于给每个询问区间增加了一个时间维度。和上面的思路类似,我们考虑:对于区间[l, r],已知它在第t个版本时的答案,求它在第t-1个版本或t+1个版本时的答案。这种转移如能也在常数时间内完成,就成了我们的带修莫队。

首先我们需要一个结构体表示一次修改:修改的位置和修改后的结果。注意,为了使得修改是可逆的,当当前时间小于此次修改时间的时候,这个修改结果就是下一个版本的对此位置的修改值;但是,当当前时间大于此次修改时间的时候,这个修改结果应当是上一个版本的修改值。另外需要注意的是,对于当前的固定区间[l, r],第t个修改未必影响当前区间的答案,但仍应当互换当前数组中的此位置值和此次修改的修改结果。代码中i表示当前区间和第i个询问中的区间一致;t表示进行第t个修改。

1
2
3
4
5
6
7
void modify(int i, int t) {
if (c[t].p >= q[i].l && c[t].p <= q[i].r) {
del(a[c[t].p]);
add(c[t].x);
}
swap(a[c[t].p], c[t].x);
}

这里的排序则先考虑左端点所在块,再考虑右端点所在块,相当于每块左右端点在大小为$B \times B$的矩形中;然后对$t$进行排序,依然采取奇偶优化。

1
2
3
4
int B;
bool cmp(const Query &a, const Query &b) {
return a.l/B == b.l/B ? (a.r/B == b.r/B ? ((a.r/B & 1) ? a.t < b.t : a.t > b.t) : a.r < b.r) : a.l < b.l;
}

当$N,M$同阶,带修莫队中的$B$应取约$N^{\frac 2 3}$,证明从略。

请看代码:Machine Learning - CodeForces 940F - Virtual Judge。代码中有一个用来离散化的结构体。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include<bits/stdc++.h>
using namespace std;
#define rep(i,from,to) for(int i=(int)(from);i<=(int)(to);++i)
#define rev(i,from,to) for(int i=(int)(from);i>=(int)(to);--i)
#define For(i,to) for(int i=0;i<(int)(to);++i)
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
#define printCase(i) printf("Case %d: ", i)
#define endl '\n'
#define coutP(i) cout<<fixed<<setprecision(i)
void dbg() {cout << "\n";}
template<typename T, typename... A> void dbg(T a, A... x) {cout << a << ' '; dbg(x...);}
#define logs(x...) {cout << #x << " -> "; dbg(x);}
#define mmst(a,x) memset(a, x, sizeof(a))
typedef int ll;
typedef long double ld;
typedef double db;
inline ll read(){
ll x=0;bool s=1;char c=getchar();
while(c>'9'||c<'0'){if(c=='-')s=0;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<3)+(x<<1)+c-'0';c=getchar();}
return s?x:~x+1;
}

const int N = 100010;
struct Query {
int l, r, t, id;
} q[N]; int qcnt;

struct Change {
int p, x;
} c[N]; int ccnt;

int a[N];
int res[N];
int cnt[N << 1], cnx[N];

void add(int x) {
cnx[cnt[x]]--;
cnt[x]++;
cnx[cnt[x]]++;
}

void del(int x) {
cnx[cnt[x]]--;
cnt[x]--;
cnx[cnt[x]]++;
}

void modify(int i, int t) {
if (c[t].p >= q[i].l && c[t].p <= q[i].r) {
del(a[c[t].p]);
add(c[t].x);
}
swap(a[c[t].p], c[t].x);
}

int B;
bool cmp(const Query &a, const Query &b) {
return a.l/B == b.l/B ? (a.r/B == b.r/B ? ((a.r/B & 1) ? a.t < b.t : a.t > b.t) : a.r < b.r) : a.l < b.l;
}


struct Dis {
int b[N << 1], n;
void add(int x) { b[n++] = x; }
void init() {
sort(b, b + n);
n = unique(b, b + n) - b;
}

int get(int x) {
return lower_bound(b, b + n, x) - b + 1;
}
} dis;

void solve() {
int n = read(), m = read();
rep(i, 1, n) a[i] = read(), dis.add(a[i]);

rep(i, 1, m) {
int o = read();
if (o == 1) {
qcnt++;
q[qcnt].l = read();
q[qcnt].r = read();
q[qcnt].id = qcnt;
q[qcnt].t = ccnt;
} else {
ccnt++;
c[ccnt].p = read();
c[ccnt].x = read();
dis.add(c[ccnt].x);
}
}
dis.init();

rep(i, 1, n) {
a[i] = dis.get(a[i]);
}
rep(i, 1, ccnt) {
c[i].x = dis.get(c[i].x);
}

B = max(2.0, pow(n, 2.0 / 3));
sort(q + 1, q + 1 + qcnt, cmp);

int l = 1, r = 0, t = 0;
rep(i, 1, qcnt) {
int &ans = res[q[i].id];
if (q[i].l == q[i].r) { ans = 2; continue; }
while (l > q[i].l) add(a[--l]);
while (r < q[i].r) add(a[++r]);
while (l < q[i].l) del(a[l++]);
while (r > q[i].r) del(a[r--]);
while (t < q[i].t) modify(i, ++t);
while (t > q[i].t) modify(i, t--);

for (ans = 1; cnx[ans] > 0; ++ans);
}

rep(i, 1, qcnt) {
printf("%d\n", res[i]);
}
}

int main() {

solve();

return 0;
}

树上莫队

例题:Count on a tree II - SPOJ COT2 - Virtual Judge

如果可以将树上的路径序列化,就可以套用莫队算法。

考虑欧拉序:对整棵树进行DFS遍历时,每次进入和退出一个节点,都将此节点编号记录在序列最后,就能形成欧拉序。欧拉序的一些性质:

  • 显然每个点都在序列中出现了2次(进入和退出),因此整个列的长度是$2n$。我们可以用两个数组sted分别表示每个节点的进入和退出在序列中的位置下标。
  • 节点u的子树在序列中存在且仅存在于st[u]ed[u]之间。这由递归的意义可以得到。
  • u的儿子为v1, v2, ..., vk,那么u的欧拉序列应该形如:
1
[u [v1 ... v1] [v2 ... v2] ... [vk ... vk] u]
  • 从而st[u]之后若不是ed[u],一定是u的第一个儿子的st,即st[v1]ed[u]之后可能是u的一个兄弟,也可能是u的父亲。
  • uv的祖先,考虑从st[u]st[v]这段序列:这个序列可以看作是从u走到v的一条可能的“路线”。注意并非我们一般说的树上的路径:因为这条路线可能经过某些点两次。例如,如果vu遍历到的第2个儿子,那么u的第一个儿子的整棵子树都会出现在这段序列里。为了从这个“路线”转换到路径,我们只需要去掉所有经过两次的点就可以了。结论:若uv的祖先,从st[u]st[v]这段序列,去掉所有出现2次的点后,就是uv的路径。
  • u不是v的祖先,由于v并不在u的子树内,而经过st[u]之后我们的行走方向是向下、即向u的子树内,整个从st[u]ed[u]的一段都是无功而返。所以我们不如从ed[u]开始,到st[v]结束。这条路径可以被概括为:若干次(若干次(经过兄弟的子树),0次或1次(到父亲上去)),若干次(若干次(经过兄弟的子树),0次或1次进入儿子去)。这条路线上将包含uv的路径上除了lca(u, v)之外的所有点(因为可以从兄弟之间跳转,所以不用经过lca;整条路线事实上在st[lca]ed[lca]之间)。同样我们去掉出现了两次的点,就得到了路径。

所以这里出现0次和出现2次是一样的,也就可以用0和1来表示路径上某个元素是否出现。因此这里adddel是相同的操作。

这里面写了一个倍增求LCA,也可以用其他方法写。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include<bits/stdc++.h>
using namespace std;
#define rep(i,from,to) for(int i=(int)(from);i<=(int)(to);++i)
#define rev(i,from,to) for(int i=(int)(from);i>=(int)(to);--i)
#define For(i,to) for(int i=0;i<(int)(to);++i)
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
#define printCase(i) printf("Case %d: ", i)
#define endl '\n'
#define coutP(i) cout<<fixed<<setprecision(i)
void dbg() {cout << "\n";}
template<typename T, typename... A> void dbg(T a, A... x) {cout << a << ' '; dbg(x...);}
#define logs(x...) {cout << #x << " -> "; dbg(x);}
#define mmst(a,x) memset(a, x, sizeof(a))
typedef long long ll;
typedef long double ld;
typedef double db;
inline ll read(){
ll x=0;bool s=1;char c=getchar();
while(c>'9'||c<'0'){if(c=='-')s=0;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<3)+(x<<1)+c-'0';c=getchar();}
return s?x:~x+1;
}
const int N = 1e5 + 100;
const int M = 1e5 + 100;
struct Dis {
int b[N]; int n;
void add(int x) { b[n++] = x; }
void init() { sort(b, b + n); n = unique(b, b + n) - b; }
int get(int x) {
return lower_bound(b, b + n, x) - b + 1;
}
} dis;

int a[N];

vector<int> son[N]; int fa[N][18], dpth[N];
int tot, st[N], ed[N], eu[N << 1];
int F;
void dfs1(int u) {
eu[++tot] = u; st[u] = tot;
for (int v : son[u]) if (v != fa[u][0]) {
fa[v][0] = u;
dpth[v] = dpth[u] + 1;
rep(j, 1, F) fa[v][j] = fa[fa[v][j - 1]][j - 1];
dfs1(v);
}
eu[++tot] = u; ed[u] = tot;
}

int get_lca(int u, int v) {
if (dpth[u] < dpth[v]) swap(u, v);
rev(i, F, 0) if (dpth[fa[u][i]] >= dpth[v]) u = fa[u][i];
if (u == v) return u;
rev(i, F, 0) if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
return fa[u][0];
}


struct Query {
int x, y, l, r, lca, id;
short type;

void debug() {
logs(x, y, l, r, lca, id, type);
rep(j, l, r) {
cout << eu[j] << ' ';
}
cout << endl;
}
} q[M];

int B;
bool cmp(const Query &a, const Query &b) {
return (a.l/B == b.l/B) ? ((a.l/B & 1) ? a.r > b.r : a.r < b.r) : a.l < b.l;
}

int cnt[N];
int col[N];
int ans;
void add(int x) {
cnt[x] ^= 1;
if (cnt[x] == 0) {
col[a[x]]--;
if (col[a[x]] == 0) ans--;
} else {
if (col[a[x]] == 0) ans++;
col[a[x]]++;
}
}

int res[N];

int main() {
int n, m;
n = read(); m = read();
rep(i, 1, n) a[i] = read(), dis.add(a[i]);
dis.init();
rep(i, 1, n) a[i] = dis.get(a[i]);

rep(i, 1, n - 1) {
int u = read(), v = read();
son[u].push_back(v);
son[v].push_back(u);
}
F = log2(n) + 1;
dpth[1] = 1;
dfs1(1);

rep(i, 1, m) {
int x = read(), y = read();
if (st[x] > st[y]) swap(x, y);
int lca = get_lca(x, y);
q[i].x = x;
q[i].y = y;
q[i].id = i;
q[i].lca = lca;
if (lca == x) q[i].l = st[x], q[i].r = st[y], q[i].type = 0;
else q[i].l = ed[x], q[i].r = st[y], q[i].type = 1;
}

B = max(1.0, (db)n / sqrt(m));
sort(q + 1, q + 1 + m, cmp);

int l = 1, r = 0;
rep(i, 1, m) {
int &cur = res[q[i].id];
if (q[i].l == q[i].r) { cur = 1; continue; }
while (l > q[i].l) add(eu[--l]);
while (r < q[i].r) add(eu[++r]);
while (l < q[i].l) add(eu[l++]);
while (r > q[i].r) add(eu[r--]);
if (q[i].type) {
add(q[i].lca);
cur = ans;
add(q[i].lca);
} else cur = ans;
}

rep(i, 1, m) {
printf("%d\n", res[i]);
}

return 0;
}
This article was last updated on days ago, and the information described in the article may have changed.