树状数组(BIT)专题

Binary Indexed Tree

参考blog:

写的真是太好啦!!!

树状数组是一个比线段树常数小,码量小的数据结构,可在$O(logn)$时间内解决一写简单的单点/区间修改 和 单点/区间查询的问题

树状数组里最核心的操作:lowbit操作,作用是求出一个数最低位的1,即从右往左数第一个1所对应的二进制数,也即2^k其中k为从右往左数碰到第一个1之前的连续0的个数,那么它可以用一下操作实现

1
2
3
inline lowbit(int x){
return x&-x;
}

单点更新+单点查询

啊、这。。。传统数组不就行了嘛

单点更新+区间查询

【模板】树状数组1

这是最简单的树状数组应用,构建树状数组的方法就是:输入的时候同时利用原始数组的值完成单点更新

code
1
2
3
4
5
6
7
8
9
10
11
inline int getsum(int i){//求1~i区间的和
int res=0;
while(i>0) res+=c[i],i-=lowbit(i);
}
inline void update(int i,int k){//在原数组i位置加上k
while(i<=n) c[i]+=k,i+=lowbit(i);
}
int main(){
for(int i=1;i<=n;i++)
scanf("%d",&a[i]),update(i,a[i]);//输入时同时更新
}

除此之外,单点更新+区间查询的异或和乘积也是可以维护的,这里就不贴代码了(注意乘积的区间查询和单点更新需要把除法转换成乘法逆元

【模板题】单点更新+区间乘积查询

区间更新+单点查询

【模板】树状数组2

这是一个简单的变式,只需要利用原数组的差分构建树状数组,这样一来对于区间[l,r]的更新就只需要update(l,k),update(r+1,-k)就行了(类似于前缀和差分),而单点查询就是求一次和啦~

只需要把上面的输入时更新换为:update(i,a[i]-a[i-1])

更新时,区间[l,r]+kupdate(l,k),update(r+1,-k)

查询x点的值:getsum(x)

区间更新+区间查询

可以很快地过了这道弱化的线段树模板题:【模板】线段树1

这是进阶版的变式,根据blog里面的推算公式,我们需要维护两个由差分构成的树状数组,满足下面的关系式,最后的答案就是x*getsum1(x)-getsum2(x)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
inline void update(int i,int k){//单点更新
int x=i;
//这里需要维护两个树状数组
while(i<=n) c1[i]+=k,c2[i]+=k*(x-1),i+=lowbit(i);
}
inline ll getsum(int i){//前缀和
ll res=0;
int x=i;
while(i>0) res+=x*c1[i]-c2[i],i-=lowbit(i);
return res;
}

update(l,k),update(r+1,-k);//区间更新
getsum(r)-getsum(l-1);//区间查询

括号法维护种类数

例题:校门外的树

题意:

区间更新:使[l,r]区间内增加一种新品种

区间查询:询问[l,r]区间内种类总数

所谓括号法,就是使用一个BIT维护区间更新时的左端点(称为左括号)个数,另一个BIT维护右端点(称为右括号)个数,这样一来上面的区间更新就变为单点更新(更新左右括号),区间查询时就是右端点左侧(包括右端点)的左括号总数-左端点左侧(不包括左端点)的右括号总数(因为只要左端点左侧的左括号没有右括号匹配,就说明匹配他的右括号在左端点右边,那么种类数+1)

code
1
2
update(c1,l,1),update(c2,r,1);//c1维护左括号,c2维护右括号
ans=getsum(c1,r)-getsum(c2,l-1);//得到[l,r]区间种类总数

树状数组求逆序对

模板题:POJ 2299 Ultra-QuickSort

首先是离散化操作三部曲:排序、去重、二分查找更新原数组

离散化之后,我们就得到了每个数的相对大小位置,从1~n遍历整个数组,利用一个树状数组存储每个数出现的次数,每次利用i-getsum(a[i])更新答案(其中getsum(a[i])是从1a[i]中的数出现的次数,即小于等于a[i]的数出现的次数,i是当前取出的数的个数,那么i-getsum(a[i])就是到目前为止大于a[i]的数出现的次数,当然这里的a[i]是已经离散化过后的,只保留了相对大小)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include<cstdio> 
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long ll;
int n;
ll a[500005],c[500005];
inline ll lowbit(ll x){return x&-x;}
inline void update(int i,ll k){
while(i<=n) c[i]+=k,i+=lowbit(i);
}
inline ll getsum(int i){
ll res=0;
while(i>0) res+=c[i],i-=lowbit(i);
return res;
}

int main(){
while(~scanf("%d",&n)){
memset(c,0,sizeof(c));
if(n==0) break;
ll ans=0;
vector<ll> vec;
for(int i=1;i<=n;i++)
scanf("%lld",&a[i]),vec.push_back(a[i]);
sort(vec.begin(),vec.end());//离散化三部曲
vec.erase(unique(vec.begin(),vec.end()),vec.end());
for(int i=1;i<=n;i++){
int pos=lower_bound(vec.begin(),vec.end(),a[i])-vec.begin();
update(pos+1,1);
ans+=i-getsum(pos+1);//更新答案
}
printf("%lld\n",ans);
}
return 0;
}

练习题

POJ 3067 Japan(结构体排序+逆序对)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
int n,m,k;
ll c[500005];
struct node{
int x,y;
}a[500005];
bool cmp(node x1,node x2){
if(x1.x==x2.x) return x1.y<x2.y;
return x1.x<x2.x;
}
inline void update(int i){
while(i<=m) c[i]++,i+=i&-i;
}
inline ll getsum(int i){
ll res=0;
while(i>0) res+=c[i],i-=i&-i;
return res;
}
int main(){
int t;
scanf("%d",&t);
for(int o=1;o<=t;o++){
memset(c,0,sizeof(c));
scanf("%d%d%d",&n,&m,&k);
for(int i=1;i<=k;i++){
scanf("%d%d",&a[i].x,&a[i].y);
}
sort(a+1,a+1+k,cmp);
ll ans=0;
for(int i=1;i<=k;i++){
update(a[i].y);
ans+=i-getsum(a[i].y);
}
printf("Test case %d: %lld\n",o,ans);
}
return 0;
}

POJ 2352 Stars(裸逆序对)

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;
int n,maxh;
int a[20005],c[50005];
ll ans[20005];
inline int lowbit(int i){
return i&-i;
}
inline void update(int i){
while(i<=maxh) c[i]++,i+=lowbit(i);
}
inline ll getsum(int i){
ll res=0;
while(i) res+=c[i],i-=lowbit(i);
return res;
}
int main(){
scanf("%d",&n);
int x;
for(int i=1;i<=n;i++) scanf("%d%d",&a[i],&x),a[i]++,maxh=max(maxh,a[i]);
for(int i=1;i<=n;i++){
ans[getsum(a[i])]++;
// printf("%d level:%d\n",a[i],getsum(a[i]));
update(a[i]);
}
for(int i=0;i<n;i++) printf("%lld\n",ans[i]);
return 0;
}

树状数组维护区间最值

参考blog:

最值不具有结合律,因此不能用类似于维护前缀和一样简单的方法,那么该怎么维护区间最值呢?

首先,介绍一下单点更新(直接更改单点的值,同样维护的是单点最值,需配合查询最值使用),由BIT的特性可知,当要更改下标为x的数时,当遍历到i时至多只有i-lowbit(i)+1 ~ i这么多数受到影响,因此只需要修改这么多数即可

1
2
3
4
5
6
7
8
9
inline void update(int x,int y){//将下标为x的值更改为y
a[x]=y;//原数组也得改
for(int i=x;i<=n;i+=lowbit(i)){//枚举所有跟x有关的树状数组
c[i]=y;//
for(int j=1;j<lowbit(i);j<<=1){//枚举所有当前i值能影响到的所有树状数组
c[i]=max(c[i],c[i-j]);
}
}
}

接下来是区间最值查询:由BIT特性可知,对于要查询的区间[l,r],我们不能直接用c[r]作为答案,因为c[r]的左端点不一定就恰好是l,因此我们要分类讨论一下:

  1. c[r]左端点i<=r:我们取a[r]更新答案,再令r-=1重复这个讨论过程
  2. c[r]左端点i>r:可以直接用c[r]更新答案,然后用[l,i]区间最值重复这个讨论过程

上述的c[r]左端点i=r-lowbit(r),至此,可以得出以下代码:

1
2
3
4
5
6
7
8
9
10
inline int getmax(int x,int y){//求区间[x,y]最值
int ans=0;
while(x<=y){//直到x==y就退出
ans=max(ans,a[y]);//拿a[y]更新
for(y--;y-lowbit(y)>=x;y-=lowbit(y)){
ans=max(ans,c[y]);//拿c[y]更新
}
}
return ans;
}

模板题:HDU 1754 I Have It

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include<bits/stdc++.h>
using namespace std;
int n,m,a[200005],c[200005];
inline int lowbit(int x){return x&-x;}
inline void update(int x,int y){
a[x]=y;
for(int i=x;i<=n;i+=lowbit(i)){
c[i]=y;
for(int j=1;j<lowbit(i);j<<=1){
c[i]=max(c[i],c[i-j]);
}
}
}
inline int getmax(int x,int y){
int ans=0;
while(x<=y){
ans=max(ans,a[y]);
for(--y;y-lowbit(y)>=x;y-=lowbit(y)){
ans=max(ans,c[y]);
}
}
return ans;
}
int main() {
while(~scanf("%d%d",&n,&m)){
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
update(i,a[i]);
}
char c[2];int x,y;
for(int i=1;i<=m;i++){
scanf("%s%d%d",c,&x,&y);
if(c[0]=='U') update(x,y);
else printf("%d\n",getmax(x,y));
}
}
return 0;
}

二维树状数组

参考blog:

类似于二维前缀和,我们可以维护一个二维树状数组来实现单点更新和区间查询的操作,同理,可以通过差分操作来实现区间更新和单点查询操作

单点更新+区间查询

将二维数组中[x,y]的值+k

1
2
3
4
5
6
inline void update(int x,int y,int k){
for(int i=x;i<=n;i+=lowbit(i))
for(int j=y;j<=m;j+=lowbit(j))
c[i][j]+=k;
}
update(x,y,k);

求二维数组从[1,1][x,y]的元素和

1
2
3
4
5
6
7
inline int getsum(int x,int y){
int res=0;
for(int i=x;i>0;i-=lowbit(i))
for(int j=y;j>0;j-=lowbit(j))
res+=c[i][j];
}
getsum(x,y);

求二维数组从[x1,y1][x2,y2]的和

1
getsum(x2,y2)-getsum(x2,y1-1)-getsum(x1-1,y2)+getsum(x1-1,y1-1)

区间更新+单点查询

模板题:POJ 2155 Matrix

维护差分数组后:

将二维数组从[x1,y1][x2,y2]的值全部+k

1
2
3
4
update(x1,y1,k);
update(x2+1,y1,-k);//注意两负两正
update(x1,y2+1,-k);
update(x2+1,y2+1,k)

求二维数组在[x,y]处的值

1
getsum(x,y)

区间更新+区间查询

首先放一个很长的公式:

于是我们需要维护四个树状数组:

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
inline void update(int x,int y,int k){//单点更新
for(int i=x;i<=n;i+=lowbit(i))
for(int j=y;j<=m;j+=lowbit(j)){
c1[i][j]+=k;
c2[i][j]+=k*x;
c3[i][j]+=k*y;
c4[i][j]+=k*x*y;
}
}
inline int getsum(int x,int y){//前缀和
int res=0;
for(int i=x;i;i-=lowbit(i))
for(int j=y;j;j-=lowbit(j))
res+=(x+1)*(y+1)*c1[i][j]-(y+1)*c2[i][j]-(x+1)*c3[i][j]+c4[i][j];
return res;
}

//区间更新
update(x1,y1,k);
update(x1,y2+1,-k);
update(x2+1,y1,-k);
update(x2+1,y2+1,k);

//区间查询
getsum(x2,y2)-getsum(x1-1,y2)-getsum(x2,y1-1)+getsum(x1-1,y1-1);

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!