树上启发式合并

关于树上启发式合并的学习


参考blog:


启发式合并

问:何为启发式合并?

答:根据人们的经验和直觉,对算法进行的优化,最常见的就是并查集中的按秩合并(讲size小的合并到size大的集合中)操作了。


一个比较典的一般启发式合并题:P3201 HNOI2009 梦幻布丁

题意:一个序列,每个值代表一个颜色,m次操作,操作分为两类:

  1. 将值为x的数全部变成y
  2. 询问整段序列中有多少段颜色?(一段颜色指的是一段相同的值,例如1221有3段颜色)

题解:可以证明每次操作后,ans只减不增,考虑维护一个数组col[i]表示第$i$个数的值,再维护一个邻接表vec[x] = vector<int> 表示值为$x$的数的下标。需要注意的是,此题只有x变y的操作,因此还需要维护一个颜色映射,以便启发式合并。最后每次1操作将小的vec[x]合并到大的vec[y]上就行了。

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int,int>;

void solve(){

int n,m;
cin >> n >> m;
vector<int> a(n+1), col(n+1), now(1000005);
vector<vector<int>> b(1000005);
int ans = 0;
for(int i=1;i<=n;++i){
cin >> a[i];
if(a[i] != a[i-1]) ans++;
col[i] = a[i];
b[col[i]].emplace_back(i);
now[col[i]] = col[i];
}
for(int op,x,y,i=1;i<=m;++i){
cin >> op;
if(op&1){
cin >> x >> y;
if(x == y) continue;
if(b[now[x]].size() > b[now[y]].size()) swap(now[x], now[y]);
for(auto &&t: b[now[x]]){
if(col[t-1] == now[y]) ans--;
if(col[t+1] == now[y]) ans--;
}
for(auto &&t: b[now[x]])
col[t] = now[y];
for(auto &&t: b[now[x]])
b[now[y]].emplace_back(t);
b[now[x]].clear();
}else{
cout << ans << '\n';
}
}

}

int main(){

cin.tie(nullptr)->sync_with_stdio(false);

int _ = 1;
// cin>>_;
while(_--){

solve();

}

return 0;
}

回归到我们的主题,对于一系列不带修改的树上统计问题,我们不难想到一个$O(n^2)$的暴力算法,即遍历每棵子树,对它进行一次$O(n)$的查询操作。为了优化这种做法,便有了树上启发式合并,通过人为设定一种启发式的暴力思路,使得将一些暴力的$O(n^2)$的算法优化至$O(n\log n)$。

树上启发式合并的前置姿势:(也正是启发式的来源)

  • 轻重链剖分

对于一种询问统计所有子树的贡献一类题,首先考虑一种暴力的思路,对于每一棵子树$O(n)$进行统计,再将统计的结果清空(因为我们要保证空间复杂度也是$O(n)$),这样算下来,总的时间复杂度是$O(n^2)$的,考虑如何优化这个暴力过程。

我们发现,对于一个节点$u$的1次统计来说,我们需要对$u$的所有儿子$v$对应子树进行一次统计 + 一次清空操作,然后将$u$节点的贡献算上,再统计所有子节点一次,从而更新$u$这棵子树的贡献。

实际上,当我们统计完了$u$的最后一个子节点$v_l$子树的贡献后,并不需要将其贡献清空,而是可以直接继承这个贡献。显然,我们需要让这个$v_l$对应子树越大越好!自然而然地,我们便想到了重儿子这个概念。

通常情况下,我们定义启发式合并的思路都是源自于树链剖分(重链剖分 or 长链剖分),原因是它预处理出了每个节点的重儿子!我们利用这个性质,在暴力时每次将重儿子所在子树的贡献保留并继承,对于轻儿子子树仍然采取暴力统计+清空的策略,于是便有了如下递归步骤:

  1. 统计轻儿子子树贡献 + 清空贡献
  2. 统计重儿子子树贡献 + 保留贡献
  3. 暴力统计当前节点以及它的所有轻儿子,记录贡献,并更新答案

我们可以简单证明一下这样操作的复杂度:

假设一棵树有$n$个节点,对于一个节点来说,它要么是重儿子,要么是轻儿子,在一个节点被暴力地统计贡献时,这个节点只会是轻儿子,因此一个点被统计一次当且仅当它到根的路径上有一条轻边。根据树链剖分的性质,一个节点到其根的轻边数量为$\log n$,因此总时间复杂度为$O(n\log n)$。

通用模板代码(重链剖分版 + map版)

依据上方的描述,我这里自己总结了一套树上启发式合并通用模板:(针对常规题型)

树链剖分 + 树上启发式合并

如果需要轻重链剖分, 则模板如下:(通常不会被卡常)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
/*
此处首先轻重练剖分,得到以下变量:
dfn: dfs序
bigson[u]: 节点u的重儿子(没有则为0)
csid[u]: 节点u的dfs序
sz[u]: 节点u的
其他... // 例如需要维护的值
*/

// 参数含义:当前节点u,父节点fath,当前节点是否是bigson(根节点不是)
auto dfs2 = [&](auto &&self, int u, int fath, bool isbig){
// 先暴力统计轻儿子,并清除贡献
for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
self(self, v, u, false);
// 清除v子树贡献,这里因为与处理了dfs序,所以可以直接遍历
for(int i = csid[v]; i <= csid[v] + sz[v] - 1; ++i){
/* 清除i贡献操作 */
}
/* 清除v子树贡献操作 */
}

// 最后统计重儿子,保留贡献
if(bigson[u]) self(self, bigson[u], u, true);

// 前面属于递归逻辑,这里开始正式处理
// 暴力统计所有轻儿子,因为只有轻儿子之前被清除了
for(auto &&V: e[u]){
if(v == fath || v == bigson[u]) continue;
/* 添加v子树贡献操作 */
}

// 计算当前节点u贡献
/* 添加u贡献操作 */
/* 记录答案操作 */
};

map + 树上启发式合并

当然,如果不怕卡常的话,某些题目也可以写成更简单的std::map的递归形式,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
auto dfs = [&](auto &&self, int u, int fath) -> map<int,int> {
map<int,int> ma;
ma[a[u]]++; /* 一些更新根节点贡献的操作 */
for(auto &&v: e[u]){
if(v == fath) continue;
auto ma2 = self(self, v, u);
if(ma2.size() > ma.size()) swap(ma, ma2); // 启发式合并
for(auto &&[x,y]: ma2){
/* 合并 + 更新操作 */
}
}
/* 更新答案操作 */
return ma;
};

例题

U41492 树上数颜色

题意:一棵有根树,带点权,m次询问,每次问一棵子树有多少不同权值。

题解:记录每种颜色的出现次数cnt[]和总共有多少种颜色tot,然后套模板。

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int,int>;

void solve(){

int n;
cin >> n;
vector<vector<int>> e(n+1);
for(int x,y,i=0;i<n-1;++i){
cin >> x >> y;
e[x].emplace_back(y);
e[y].emplace_back(x);
}
vector<int> a(n+1);
for(int i=1;i<=n;++i) cin >> a[i];

int dfn = 0;
vector<int> bigson(n+1), sz(n+1), csid(n+1), csw(n+1);
auto dfs1 = [&](auto &&self, int u, int fath) -> void{
csid[u] = ++dfn;
csw[dfn] = a[u];
sz[u] = 1;
for(auto &&v: e[u]){
if(v == fath) continue;
self(self, v, u);
sz[u] += sz[v];
if(sz[bigson[u]] < sz[v]){
bigson[u] = v;
}
}
};

dfs1(dfs1, 1, -1);

vector<int> cnt(n+1); // 每种颜色的出现次数
vector<int> ans(n+1);
int res = 0; // 有多少种颜色

auto add = [&](int val) -> void{
if(cnt[val]++ == 0) res ++;
};

auto dfs2 = [&](auto &&self, int u, int fath, bool isbig) -> void{
for(auto &&v: e[u]){ // 先统计轻儿子,不记贡献
if(v == fath || v == bigson[u]) continue;
self(self, v, u, false);

// 对于轻儿子对答案的贡献,直接清空就行了,不保留任何影响
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
cnt[csw[i]] = 0;
}
res = 0;
}

if(bigson[u]){ // 再统计重儿子,记录贡献
self(self, bigson[u], u, true);
}

// 暴力更新轻儿子贡献
for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
add(csw[i]);
}
}

// 处理u这个点的贡献
add(csw[csid[u]]);
ans[u] = res; // 更新答案

};

dfs2(dfs2, 1, -1, false);

int m;
cin >> m;
for(int x,i=0;i<m;++i){
cin >> x;
cout << ans[x] << '\n';
}

}

int main(){

cin.tie(nullptr)->sync_with_stdio(false);

int _ = 1;
// cin>>_;
while(_--){

solve();

}

return 0;
}

CF600E Lomsat gelral

题意:一个有根树,带点权,求出每一棵子树内,点权所有不同的众数之和。

题解:记录每种权值的出现次数cnt[]、出现次数最大的权值maxh、点权所有不同众数之和res

对于添加贡献操作,只需要每次将出现次数++,然后判断是否与最大值相等并对应更新maxh, res即可

对于删除贡献操作,因为我们先继承了重儿子贡献,再遍历轻儿子,因此对于轻儿子子树来说,这些贡献直接清零就行了

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/*
https://www.luogu.com.cn/problem/CF600E

树上启发式合并模板题

题意:一个有根树,带点权,求出每一棵子树内,点权众数之和。


*/


#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int,int>;

void solve(){

int n;
cin >> n;
vector<int> a(n+1);
vector<vector<int>> e(n+1);

for(int i=1;i<=n;++i) cin >> a[i];

for(int x,y,i=0;i<n-1;++i){
cin >> x >> y;
e[x].emplace_back(y);
e[y].emplace_back(x);
}

int dfn = 0;
vector<int> sz(n+1), bigson(n+1), csid(n+1), csw(n+1);

auto dfs1 = [&](auto &&self, int u, int fath) -> void{
csid[u] = ++dfn;
csw[dfn] = a[u];
sz[u] = 1;
for(auto &&v: e[u]){
if(v == fath) continue;
self(self, v, u);
sz[u] += sz[v];
if(sz[bigson[u]] < sz[v]){
bigson[u] = v;
}
}
};

dfs1(dfs1, 1, -1);

vector<int> cnt(n+1); // cnt[i] 表示 i 颜色出现多少次
vector<ll> ans(n+1);
int maxh = 0;
ll res = 0;

auto add = [&](int val){
cnt[val]++;
if(cnt[val] > maxh){
maxh = cnt[val];
res = val;
}else if(cnt[val] == maxh){
res += val;
}
};

auto dfs2 = [&](auto &&self, int u, int fath, bool isbig) -> void{
for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
self(self, v, u, false);
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
cnt[csw[i]] = 0;
}
maxh = 0, res = 0;
}

if(bigson[u]){
self(self, bigson[u], u, true);
}

for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
add(csw[i]);
}
}

add(csw[csid[u]]);
ans[u] = res;
};

dfs2(dfs2, 1, -1, false);

for(int i=1;i<=n;++i) cout << ans[i] << " \n"[i==n];

}

int main(){

cin.tie(nullptr)->sync_with_stdio(false);

int _ = 1;
// cin>>_;
while(_--){

solve();

}

return 0;
}

CF375D Tree and Queries

题意:一棵有根树,带点权,m次询问,每次问一棵子树中出现次数>=k的权值有多少种

题解:首先我们将询问离线到每个节点,考虑对每个子树求贡献,我们注意到

又是出现次数,老套路了,对这题来说,我们需要记录每个权值出现次数cnt[],以及出现次数>=x的权值的种类数sum[x],对于前者来说非常容易维护,对于后者其实也很简单:

我们只需要在每次添加贡献时,将对应的sum[++cnt[u]]++

对于清除贡献操作,将对应的sum[cnt[u]--]--

实际上,我们维护的sum[]数组是一个以出现次数为下标的种类数的后缀和

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int,int>;

void solve(){

int n,m;
cin >> n >> m;
vector<int> a(n+1);
vector<vector<int>> e(n+1);
for(int i=1;i<=n;++i) cin >> a[i];
for(int x,y,i=0;i<n-1;++i){
cin >> x >> y;
e[x].emplace_back(y);
e[y].emplace_back(x);
}
vector<vector<pii>> q(n+1);
for(int x,k,i=1;i<=m;++i){
cin >> x >> k;
q[x].emplace_back(pair(i, k));
}

int dfn = 0;
vector<int> sz(n+1), bigson(n+1), csid(n+1), csw(n+1);

auto dfs1 = [&](auto &&self, int u, int fath) -> void{
csid[u] = ++dfn;
csw[dfn] = a[u];
sz[u] = 1;
for(auto &&v: e[u]){
if(v == fath) continue;
self(self, v, u);
sz[u] += sz[v];
if(sz[v] > sz[bigson[u]]){
bigson[u] = v;
}
}
};

dfs1(dfs1, 1, -1);

vector<int> ans(m+1);
vector<int> cnt(100005), sum(100005);

auto add = [&](int x){
cnt[x]++, sum[cnt[x]]++;
};
auto del = [&](int x){
sum[cnt[x]]--, cnt[x]--;
};

auto dfs2 = [&](auto &&self, int u, int fath, bool isbig) -> void{
for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
self(self, v, u, false);
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
del(csw[i]);
}
}
if(bigson[u]){
self(self, bigson[u], u, true);
}
for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
add(csw[i]);
}
}
add(csw[csid[u]]);
for(auto &&[id, val]: q[u]){
ans[id] = sum[val];
}
};
dfs2(dfs2, 1, -1, false);
for(int i=1;i<=m;++i) cout << ans[i] << '\n';

}

int main(){

cin.tie(nullptr)->sync_with_stdio(false);

int _ = 1;
// cin>>_;
while(_--){

solve();

}

return 0;
}

CF1009F Dominant Indices

题意:一棵有根树,对每个节点,求出其子树中具有最多节点数的深度,且该深度最小

题解:同样的套路,维护每个深度的节点个数cnt[]、最大节点个数maxh、最大节点个数对应深度res,那么考虑添加和删除贡献操作

添加贡献时,只需要把对应cnt[dep[u]]++,然后判断是否超出或者等于最大节点数,注意,如果等于了,则需要更新最小深度

删除贡献时,只需要把对应cnt[dep[v]]--,直接将maxh, res清零即可

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int,int>;

void solve(){

int n;
cin >> n;
vector<vector<int>> e(n+1);
for(int x,y,i=0;i<n-1;++i){
cin >> x >> y;
e[x].emplace_back(y);
e[y].emplace_back(x);
}
int dfn = 0;
vector<int> csid(n+1), dep(n+1), sz(n+1), bigson(n+1), csw(n+1);
vector<int> cnt(n+1), ans(n+1);
int maxh = 0, res = 0;

auto dfs1 = [&](auto &&self, int u, int fath, int depth) -> void{
csid[u] = ++dfn;
csw[dfn] = u;
sz[u] = 1;
dep[u] = depth;
for(auto &&v: e[u]){
if(v == fath) continue;
self(self, v, u, depth+1);
sz[u] += sz[v];
if(sz[v] > sz[bigson[u]]){
bigson[u] = v;
}
}
};
dfs1(dfs1, 1, -1, 0);

auto add = [&](int u){
cnt[dep[u]]++;
if(cnt[dep[u]] > maxh){
maxh = cnt[dep[u]];
res = u;
}else if(cnt[dep[u]] == maxh && dep[u] < dep[res]){
res = u;
}
};

auto dfs2 = [&](auto &&self, int u, int fath, bool isbig) -> void{
for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
self(self, v, u, false);
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
cnt[dep[csw[i]]]--;
}
maxh = 0, res = 0;
}
if(bigson[u]) self(self, bigson[u], u, true);
for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
add(csw[i]);
}
}
add(csw[csid[u]]);
ans[u] = dep[res] - dep[u];
};

dfs2(dfs2, 1, -1, false);

for(int i=1;i<=n;++i) cout << ans[i] << '\n';

}

int main(){

cin.tie(nullptr)->sync_with_stdio(false);

int _ = 1;
// cin>>_;
while(_--){

solve();

}

return 0;
}

CF570D Tree Requests

题意:一棵有根树,每个节点代表一个小写字母

m次询问格式为[x,y],每次问节点x为根的子树中,真实深度(到原树根的距离)为y的所有节点字母是否能构成回文串。

题解:首先把询问离线到每个不同的节点上,考虑对每个子树求贡献,我们注意到

能构成回文串 == 出现次数为奇数次的字母不能超过1个 == bitmask里至多有一个1

f(i)表示深度为i的节点字母构成的bitmask,则popcont<=1即是"Yes",对于暴力添加和删除操作,只需要异或一下就行了,然后树上启发式合并就行了

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int,int>;

void solve(){

int n,m;
cin >> n >> m;

vector<vector<int>> e(n+1);
for(int x, i=2;i<=n;++i){
cin >> x;
e[x].emplace_back(i);
e[i].emplace_back(x);
}

vector<int> a(n+1);
char c;
for(int t, i=1;i<=n;++i){
cin >> c;
t = (c - 'a');
a[i] = 1 << t;
}

vector<vector<pii>> q(n+1);
for(int x,y,i=1;i<=m;++i){
cin >> x >> y;
q[x].emplace_back(pair(i, y));
}


int dfn = 0;
vector<int> csid(n+1), csw(n+1), sz(n+1), bigson(n+1), dep(n+1);
vector<int> stat(n+1);
vector<string> ans(m+1);

auto dfs1 = [&](auto &&self, int u, int fath, int depth) -> void{
csid[u] = ++dfn;
csw[dfn] = u;
dep[u] = depth;
sz[u] = 1;
for(auto &&v: e[u]){
if(v == fath) continue;
self(self, v, u, depth+1);
sz[u] += sz[v];
if(sz[v] > sz[bigson[u]]){
bigson[u] = v;
}
}
};
dfs1(dfs1, 1, -1, 1);

auto dfs2 = [&](auto &&self, int u, int fath, bool isbig) -> void{
for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
self(self, v, u, false);
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
stat[dep[csw[i]]] ^= a[csw[i]];
}
}
if(bigson[u]) self(self, bigson[u], u, true);
for(auto &&v: e[u]){
if(v == fath || v == bigson[u]) continue;
for(int i=csid[v];i<=csid[v]+sz[v]-1;++i){
stat[dep[csw[i]]] ^= a[csw[i]];
}
}
stat[dep[u]] ^= a[u];
for(auto &&[id, val]: q[u]){
ans[id] = (__builtin_popcount(stat[val])<=1 ? "Yes" : "No");
}
};

dfs2(dfs2, 1, -1, false);

for(int i=1;i<=m;++i) cout << ans[i] << '\n';
}

int main(){

cin.tie(nullptr)->sync_with_stdio(false);

int _ = 1;
// cin>>_;
while(_--){

solve();

}

return 0;
}


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!