字典树与AC自动机

最近正直夏令营季,抽时间学了一下字符串的重要算法——“字典树Trie”和“AC自动机”,在此处记录一下,以便后续温故和复习

字典树 Trie

普通字典树

字典树,也叫做前缀树、Trie,是一种用于存储大量字符串信息的树形数据结构,通常情况下形如下图

u=3442947730,1854572131&fm=26&gp=0

这张形象的图告诉我们,字典树都具有一个“空”的根节点,从根节点出发的一条简单路径上的字符按顺序组成的字符串就是存储的字符串,如上图中的这棵字典树,就存储了 “美利坚”、“美丽”、“金币”、“金子”、“帝王” 这五个字符串,然而有人会问:那 “美利” 这个字符串在不在这棵字典树内呢?答案是:既可以在,也可以不在。显然应该想到,确定一个字符串是否有效的方式,是在这个字符串的最后一个字符处打上标记,表示这是当前字符串的结尾,如下图中绿色节点代表了一个字符串的结尾,这样一来,就可以完美地实现字符串的插入存储操作了

u=2725490024,3402601292&fm=26&gp=0

插入操作

对于具体的插入操作,可以按照如下方式实现

1
2
3
4
5
6
7
8
9
10
inline void insert(string ss){
int len=ss.size();// 获取字符串长度
int rt=0;// 从根节点0开始
for(int i=0;i<len;i++){
int id=ss[i]-'a';// 获取儿子节点编号
if(!trie[rt][id]) trie[rt][id]=++tot;// 如果当前字符没有分配节点,分配给他新节点
rt=trie[rt][id];// 将根节点切换成当前节点的儿子节点
}
flag[rt]=1;// 标记这个节点为当前字符串结尾
}

询问操作

最简单的情况下,如果需要询问一个字符串是否在之前的字符串中出现过,则只需要按照类似插入操作相同的方法处理即可,时间复杂度为$O(log(L)),L为字符串长度$(你会发现还不如哈希)

1
2
3
4
5
6
7
8
9
10
11
inline bool query(string ss){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[rt][id]) return false;// 如果当前节点不存在,字符串肯定不存在,直接返回false
rt=trie[rt][id];
}
if(flag[rt]) return true;// 出现过,返回true
return false;// 否则返回false
}

当然这是最简单的应用(不如hash_map快呢。。)但除此之外,还可以衍生出很多操作,具体可以通过刷题得知,有的操作可能就是哈希表无法胜任的了(当然很少)

习题

洛谷 P2580 于是他错误的点名开始了(裸Trie,用哈希表应该也行)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include<bits/stdc++.h>
using namespace std;
const int maxn=2000005;
int n,m,tot,trie[maxn][30],flag[maxn];

inline void insert(string ss){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[rt][id]) trie[rt][id]=++tot;
rt=trie[rt][id];
}
flag[rt]=1;
}
inline int query(string ss){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[rt][id]) return 0;
rt=trie[rt][id];
}
if(flag[rt]==1) return flag[rt]++,1;
else return 2;
}
int main(){
ios::sync_with_stdio(false);
cin>>n;
for(int i=1;i<=n;i++){
string s;
cin>>s;
insert(s);
}
cin>>m;
while(m--){
string s;
cin>>s;
int ans=query(s);
if(ans==0) cout<<"WRONG"<<endl;
else if(ans==1) cout<<"OK"<<endl;
else cout<<"REPEAT"<<endl;
}
return 0;
}//https://www.luogu.com.cn/problem/P2580

HDU1251 统计难题(Trie裸体,统计以某字符串为前缀的字符串数量)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include<bits/stdc++.h>
using namespace std;
const int maxn=2000005;
char s[maxn];
int trie[maxn][30];
int ci[maxn],tot;

inline void insert(char *ss){
int len=strlen(ss);
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[rt][id])
trie[rt][id]=++tot;
ci[trie[rt][id]]++;
rt=trie[rt][id];
}
}
inline int query(char *ss){
int len=strlen(ss);
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[rt][id]) return 0;
rt=trie[rt][id];
}
return ci[rt];
}
int main(){
freopen("test.in","r",stdin);
while(gets(s)){
if(s[0]=='\0') break;
insert(s);
}
while(~scanf("%s",s)){
printf("%d\n",query(s));
}
return 0;
}//http://acm.hdu.edu.cn/showproblem.php?pid=1251

LOJ1224 DNA Prefix(找到字符串集合中 最长公共前缀长度 乘以 字符串集合数量 最大的 子集)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include<bits/stdc++.h>
using namespace std;
const int maxn=2000005;
int _,n;
int trie[maxn][5],ci[maxn],tot,ans;
inline void _insert(string ss){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'A';
if(!trie[rt][id])
trie[rt][id]=++tot;
ci[trie[rt][id]]++;
ans=max(ans,ci[trie[rt][id]]*(i+1));
rt=trie[rt][id];
}
}

inline void init(){
ans=0;
for(int i=0;i<=tot;i++){
ci[i]=0;
for(int j=0;j<5;j++){
trie[i][j]=0;
}
}
}

int main(){
ios::sync_with_stdio(false);
cin>>_;
for(int o=1;o<=_;o++){
cin>>n;
string tp;
for(register int i=1;i<=n;i++){
cin>>tp;
_insert(tp);
}
cout<<"Case "<<o<<": "<<ans<<endl;
init();
}
return 0;
}//http://lightoj.com/volume_showproblem.php?problem=1224

洛谷 P3879 [TJOI2010]阅读理解(统计某些生词在哪几篇短文中出现过)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include<bits/stdc++.h>
using namespace std;
int n,m,tot[1005],trie[1005][5005][30],flag[1005][5005];

inline void insert(string ss,int num){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[num][rt][id]) trie[num][rt][id]=++tot[num];
rt=trie[num][rt][id];
}
flag[num][rt]=1;
}
inline bool query(string ss,int num){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[num][rt][id]) return 0;
rt=trie[num][rt][id];
}
if(flag[num][rt]) return true;
return false;
}

int main(){
ios::sync_with_stdio(false);
cin>>n;
string tp;
int x;
for(int i=1;i<=n;i++){
cin>>x;
for(int j=1;j<=x;j++){
cin>>tp;
insert(tp,i);
}
}
cin>>m;
while(m--){
cin>>tp;
vector<int> vec;
for(int i=1;i<=n;i++){
if(query(tp,i)){
vec.push_back(i);
}
}
for(auto t:vec){
cout<<t<<" ";
}
cout<<endl;
}
return 0;
}

POJ2001 Shortest Prefix(为每个字符串找到能唯一表示其的最短前缀,可暴力)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<cstring>
#include<iostream>
using namespace std;
const int maxn=2000005;
string s[2005],tp;
int trie[maxn][30],tot,ci[maxn],num;
inline void insert(string ss){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[rt][id])
trie[rt][id]=++tot;
ci[trie[rt][id]]++;
rt=trie[rt][id];
}
}
inline string query(string ss){
int len=ss.size();
int rt=0;
string ans="";
for(int i=0;i<len;i++){
int id=ss[i]-'a';
ans+=ss[i];
if(ci[trie[rt][id]]==1) return ans;
rt=trie[rt][id];
}
return ans;
}
int main(){
// freopen("test.in","r",stdin);
ios::sync_with_stdio(false);
while(cin>>tp){
s[++num]=tp;
insert(tp);
}
for(int i=1;i<=num;i++){
cout<<s[i]<<" "<<query(s[i])<<endl;
}
return 0;
}//http://poj.org/problem?id=2001

POJ 2513 Colored Sticks(好题,Trie+欧拉路经转化+并査集)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include<cstring>
#include<cstdio>
using namespace std;
const int maxn=2000005;
char s1[15],s2[15];
int trie[maxn][30],tot,vis[maxn],deg[maxn];
int fa[maxn];
int cnt;

int find(int x){
return x==fa[x]?x:fa[x]=find(fa[x]);
}

inline void unite(int x,int y){
x=find(x);
y=find(y);
if(x!=y){
fa[x]=y;
}
}

inline int _insert(char *ss){
int len=strlen(ss);
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[rt][id]) trie[rt][id]=++tot;
rt=trie[rt][id];
}
if(!vis[rt]) vis[rt]=++cnt;
if(!fa[cnt]) fa[cnt]=cnt;
return vis[rt];
}
int main(){
// freopen("test.in","r",stdin);
// for(int i=0;i<=maxn;i++) fa[i]=i;
while(~scanf("%s %s",s1,s2)){
int num1=_insert(s1);
deg[num1]++;
int num2=_insert(s2);
deg[num2]++;
// printf("num1=%d num2=%d\n",num1,num2);
unite(num1,num2);
}
int ff=0;
for(int i=1;i<=cnt;i++){
// printf("fa[%d]=%d\n",i,find(i));
if(deg[i]&1){
ff++;
}
if(find(1)!=find(i)){
return printf("Impossible\n"),0;
}
}
if(ff==0||ff==2) printf("Possible\n");
else printf("Impossible\n");
return 0;
}//http://poj.org/problem?id=2513

POJ3630 Phone List(判断字符串集合是否有字符串是其他字符串前缀)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include<cstring>
#include<iostream>
using namespace std;
const int maxn=2e6+5;
int _,n,trie[maxn][10],tot,ci[maxn],flag[maxn];
string s[10005];

inline void insert(string ss){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'0';
if(!trie[rt][id]) trie[rt][id]=++tot;
ci[trie[rt][id]]++;
rt=trie[rt][id];
}
flag[rt]=1;
}
inline bool query(string ss){
int len=ss.size();
int rt=0;
for(int i=0;i<len-1;i++){
int id=ss[i]-'0';
if(flag[trie[rt][id]]) return true;
rt=trie[rt][id];
}
return false;
}

inline void init(){
for(int i=0;i<=tot;i++){
flag[i]=0;
for(int j=0;j<=10;j++){
trie[i][j]=0;
}
}
}

int main(){
ios::sync_with_stdio(false);
cin>>_;
while(_--){
init();
cin>>n;
bool f=true;
for(int i=0;i<n;i++){
cin>>s[i];
insert(s[i]);
}
for(int i=0;i<n;i++){
if(query(s[i])){
f=false;
break;
}
}
if(f) cout<<"YES"<<endl;
else cout<<"NO"<<endl;
}
return 0;
}//http://poj.org/problem?id=3630

洛谷U83324 The Power of Face Rolling Keyboard(动态插入和询问字符串是否存在,卡map)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include<bits/stdc++.h>
using namespace std;
const int maxn=2000005;
int m,tot,trie[maxn][30],flag[maxn];

inline void insert(string ss){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[rt][id]) trie[rt][id]=++tot;
rt=trie[rt][id];
}
flag[rt]=1;
}
inline bool query(string ss){
int len=ss.size();
int rt=0;
for(int i=0;i<len;i++){
int id=ss[i]-'a';
if(!trie[rt][id]) return 0;
rt=trie[rt][id];
}
if(flag[rt]) return true;
return false;
}

int main(){
ios::sync_with_stdio(false);
cin>>m;
while(m--){
int x;
string tp;
cin>>x>>tp;
if(x==0) insert(tp);
else cout<<(query(tp)?"Yes":"No")<<endl;
}
return 0;
}//https://www.luogu.com.cn/problem/U83324

01字典树

与常规处理字符串的字典树不同,01字典树用来巧妙地处理二进制的位运算操作,最常见的是异或操作的处理,01字典树的每个节点(或者说边)代表了0或者1,剩余的模板和普通字典树并无差别

典型例题:最大异或对

题意:从N个整数中找出两个数做异或,求得出的最大异或值

思路:构建出01字典树,遍历每个数,对于当前数而言,从最高位开始考虑,如果当前数最高位为0,则找字典树中1的点,否则找0(因为0 xor 1 = 1是最优的),如果当前位需要找0时字典树中没有1这个节点存在,则只能找0了,这样贪心地找下去,最后的答案就是最优解

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
int n,a[maxn],tot,trie[maxn<<5][2];
bool flag[maxn<<5];

inline void insert(int x){
int rt=0;
for(int i=31;i>=0;--i){
int id=(x>>i)&1;
if(!trie[rt][id]) trie[rt][id]=++tot;
rt=trie[rt][id];
}
flag[rt]=1;
}

inline int query(int x){
int rt=0,res=0;
for(int i=31;i>=0;--i){
int id=(x>>i)&1;
if(trie[rt][!id]) rt=trie[rt][!id],res|=(1<<i);
else if(trie[rt][id]) rt=trie[rt][id];
else break;
}
return res;
}

int main(){
read(n);
rep(i,1,n) read(a[i]),insert(a[i]);
int ans=-inf;
rep(i,1,n) ans=max(ans,query(a[i]));
printf("%d\n",ans);
return 0;
}

练习题:洛谷P4551 最长异或路径

题意:给定一棵带边权的树,寻找两个节点使得这两个节点的简单路径异或和最大,求出最大异或值

思路:首先,需要想到$a,b$节点的异或路径为 根节点到a的异或和 异或上 根尖点到b的异或和,想到了这一点这题就转化成了裸的上一题,直接跑01字典树就行了

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
vector<PII> e[maxn];
int n,trie[maxn*50][2],tot;
ll wt[maxn],ans;

void dfs(int u,int fath){
for(int i=0;i<(int)e[u].size();++i){
int v=e[u][i].first,w=e[u][i].second;
if(v!=fath){
wt[v]=wt[u]^(ll)w;
dfs(v,u);
}
}
}

inline void insert(ll x){
int rt=0;
for(int i=50;i>=0;--i){
int id=(x>>i)&1;
if(!trie[rt][id]) trie[rt][id]=++tot;
rt=trie[rt][id];
}
}

inline ll query(ll x){
int rt=0;
ll res=0;
for(int i=50;i>=0;--i){
int id=(x>>i)&1;
if(trie[rt][!id]) rt=trie[rt][!id],res|=(1<<i);
else if(trie[rt][id]) rt=trie[rt][id];
else break;
}
return res;
}

int main(){
read(n);
rep(i,1,n-1){
int u,v,w;
read(u),read(v),read(w);
e[u].pb(mp(v,w));
e[v].pb(mp(u,w));
}
dfs(1,-1);
rep(i,1,n) insert(wt[i]);
rep(i,1,n) ans=max(ans,query(wt[i]));
printf("%lld\n",ans);
return 0;
}