您的位置:首页 > 其它

树状数组点更新,区间更新理解

2016-10-02 20:30 309 查看
对于一个数列A1A2A3…An,要求支持两种操作:

1.查询[x,y]区间的区间和

2.把[x,y]区间每个元素加val

事实上线段树也可以解决这样的问题,用上一点lazy的思想,每次只更新小区间的区间和,查询的时候加上祖先节点的影响就可以了。

我们用树状数组也可以解决这样的问题,并且效率会更高一些,复杂度都是nlog(n),但树状数组的常数更小,空间占用也更少。

首先回忆一下树状数组维护前缀和的过程,首先对于任意一个整数x,比如11,表达成二进制为1011,而1011 = 1000 + 10 + 1,这样的划分不超过log(11)次,考虑让一个数组C[]维护一小段区间和,让C[11]维护以A[11]作为结尾,长度为1的区间(这个时候只有A[11]一个值),让C[11-1]维护以A[11-1]作为结尾,长度为10(2进制)的区间和,同理让C[11-1-2]维护以A[11-1-2]作为结尾,长度为1000(2进制)的区间和,这样我们的小区间就覆盖了整个[1,11]的区间,前缀和就可以累加得到了,基于这样的思想我们设定了lowbit函数,他是这样的:

int lowbit(int x) {
return x & -x;
}


这个函数利用位运算技巧返回了整数x的最低位的1和后续的0组成的整数,比如10 = 1010(2进制),那么lowbit(10) = 2 = 10(2进制),根据上面的讨论,这个函数实际上用来分解整数达到划分区间的目的。

因此我们用C[i]维护A[i-lowbit(i)+1]A[i-lowbit(i)+2]…A[i]的区间和

假设C数组已经构造好了,那么我们查询从[1,x]区间的前缀和函数是这样的:

int query(int x)
{
int res = 0;
while(x > 0)
{
C[x] += lowbit(x); x -= lowbit(x);
}
return res;
}


联系上面11的例子,11 = 1000 + 10 + 1,所以C[11]只用管A[11]就可以了,剩余的10个不管,C[10]只用管A[9],A[10]就可以了,C[8]管剩下所有的。

当某个A[i]加上一个值val,如何更新呢

void update(int x, int val)
{
while(x <= n)
{
C[x] += val; x += lowbit(x);
}
}


首先可以明确的是,x + lowbit(x)必定会使x的二进制数发生进位,而且是最小的进位,事实上很容易发现进位之后的数一定会管到原来的A[i],画出树状图就可以发现

C数组也可以递推求得,就不介绍了,代码如下,很好懂

memset(c, 0, sizeof(c));
for(int i = 1; i <= n; i++)
{
c[i] += a[i];
int father = i + lowbit(i);
if(father <= n) c[father] += c[i];
}


对一段区间每个值加上一个val如何处理呢?

假设存在一个数组add[],add[x] = val 表示把[x,n]这个区间每个元素+val.

这样当我们把[x,y]区间每个元素都+val的时候,把问题转化为把[x,n]的区间每个元素+val,把[y+1,n]每个元素-val

这样当我们查询[1,x]区间的区间和时,实际的

sum[x] = A[1] + A[2] +… + A[x] + add[1] * (x+1-1) + add[2] * (x+1-2) + … + add[x] * (x+1-x) = (A[1] + A[2] + … + A[x]) + (x+1)(add[1] + add[2] + …+ add[x] ) - (1 * add[1] + 2*add[2] + …+x *add[x])

做到这里问题就转化为了求前缀和了,第一项是A[x]的前缀和直接维护即可,第二项和第三项分别用树状数组维护即可。代码很容易看懂。结合题目poj3468,以下是ac代码

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

const int maxn = 100005;
typedef long long LL;
int n,q,a[maxn];
LL sum[maxn];
LL c1[maxn],c2[maxn]; //c1维护add[i]的前缀和,c2维护add[i]*i的前缀和

int lowbit(int x) { return x & (-x); }

void update1(int x,int val)
{
while(x <= n)
{
c1[x] += val; x += lowbit(x);
}
}
LL query1(int x)
{
LL res = 0;
while(x > 0)
{
res += c1[x]; x -= lowbit(x);
}
return res;
}
void update2(int x,int val)
{
while(x <= n)
{
c2[x] += val; x += lowbit(x);
}
}
LL query2(int x)
{
LL res = 0;
while(x > 0)
{
res += c2[x]; x -= lowbit(x);
}
return res;
}

int main()
{
scanf("%d%d",&n,&q);
memset(a,0,sizeof(a));
memset(c1,0,sizeof(c1));
memset(c2,0,sizeof(c2));
memset(sum, 0, sizeof(sum));
for(int i = 1; i <= n; i++) scanf("%d" ,&a[i]);
for(int i = 1; i <= n; i++) sum[i] = sum[i-1] + a[i];
for(int i = 1; i <= q; i++)
{
char cmd[2];
scanf("%s",cmd);
if(cmd[0] == 'Q')
{
int left,right;
scanf("%d%d" ,&left,&right);
LL x = sum[right] + (right + 1) * query1(right) - query2(right);
LL y = sum[left-1] + left * query1(left-1) - query2(left-1);
printf("%I64d\n" ,x-y);
}
else if(cmd[0] == 'C')
{
int left,right,val;
scanf("%d%d%d" ,&left,&right,&val);
update1(left,val); //转化为点更新
update1(right+1,-val);
update2(left,val*left);
update2(right+1,-val*(right+1));
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: