您的位置:首页 > 其它

bzoj 1500 维修数列

2013-03-28 19:36 253 查看
Splay的BOSS级别的题目,几乎包含了所有区间维护的操作,对于想要提高代码熟练度的人来说这个题非常值得一写。

首先我们要维护的域是:size,sum,la(左max),ra(右max),ma(整体max)更新的时候用这些域更新就行,注意下传的时候也要更新响应域值。

对于那些恶心操作,我就挨个说一下吧:

注:标准结构:区间右端点+1(R)为根,区间左端点-1(L)为根的左儿子,这样目标区间就是L的右儿子,这种形式以后都用"标准结构"代替。

插入操作:先把需要插入的序列建成一个小平衡树(递归),转出标准结构,插到L的右儿子上就行了。

删除操作:转出标准结构,把L的右儿子切下来就行了(注意因为要回收空间,所以还是把要切的子树遍历了一遍,把这颗树上的节点标号入栈)。

覆盖操作:转出标准结构,把L的右儿子打上覆盖标记cov(以后下传的时候把节点的值改为cov的值,sum变为cov*size,la=ra=ma变为cov和sum中较大的一个,因为有负数的情况)

翻转操作:转出标准结构,把L的右儿子打上翻转标记rev(以后下传的时候要交换左右儿子并且交换la和ra)

求和操作:转出标准结构,答案就是L的右儿子的sum

最大值操作:转出标准结构,答案就是L的右儿子的ma

区间操作的时候一定要明白一点,就是打标记的同时做修改,就是说当一个点带了标记的时候,它已经被修改过了。

这个题还是做了我不少时间的。

repairnum

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#define maxn 1000000
#define inf 2147483646
using namespace std;
int c[maxn][2],fa[maxn];
int a[maxn],key[maxn],sum[maxn],la[maxn],ra[maxn],ma[maxn],cov[maxn],size[maxn];
int q[maxn];
bool rev[maxn];
int n,m,tot,num,rot,st,ed,tail;

void update(int x)
{
if (!x) return;
la[x]=max(la[c[x][0]],sum[c[x][0]]+key[x]+max(0,la[c[x][1]]));
ra[x]=max(ra[c[x][1]],sum[c[x][1]]+key[x]+max(0,ra[c[x][0]]));
ma[x]=max(max(ma[c[x][0]],ma[c[x][1]]),key[x]+max(0,ra[c[x][0]])+max(0,la[c[x][1]]));
sum[x]=sum[c[x][0]]+sum[c[x][1]]+key[x];
size[x]=size[c[x][0]]+size[c[x][1]]+1;
}

void reverse(int x)
{
if (!x) return;
swap(c[x][0],c[x][1]);
swap(la[x],ra[x]);
rev[x]^=1;
}

void recover(int x,int z)
{
if (!x) return ;
key[x]=cov[x]=z;
sum[x]=size[x]*z;
la[x]=ra[x]=ma[x]=max(z,sum[x]);
}

void down(int x)
{
if (!x) return;
if (rev[x])
{
reverse(c[x][0]);
reverse(c[x][1]);
rev[x]=0;
}
if (cov[x]!=-inf)
{
recover(c[x][0],cov[x]);
recover(c[x][1],cov[x]);
cov[x]=-inf;
}
}

void relax(int x,int rot)
{
if (x!=rot) relax(fa[x],rot);
down(x);
}

void rotate(int x,int &rot)
{
int y=fa[x],z=fa[y];
int p=(c[y][1]==x),q=p^1;
if (y==rot) rot=x;
else if (c[z][0]==y) c[z][0]=x; else c[z][1]=x;
fa[x]=z; fa[y]=x; fa[c[x][q]]=y;
c[y][p]=c[x][q]; c[x][q]=y;
update(y);
}

void splay(int x,int &rot)
{
relax(x,rot);
while (x!=rot)
{
int y=fa[x], z=fa[y];
if (y!=rot)
if ((c[y][0]==x)xor(c[z][0]==y)) rotate(x,rot); else rotate(y,rot);
rotate(x,rot);
}
update(x);
}

int pick()
{
if (tail) return q[tail--];
else return ++num;
}

int setup(int x)
{
int t=pick();
key[t]=a[x];
cov[t]=-inf;
rev[t]=0;
la[t]=ra[t]=ma[t]=-inf;
return t;
}

int build(int l,int r)
{
int mid=(l+r)>>1,left=0,right=0;
if (l<mid)      left=build(l,mid-1);
int t=setup(mid);
if (r>mid)      right=build(mid+1,r);
if (left)       c[t][0]=left,fa[left]=t;
if (right)c[t][1]=right,fa[right]=t;
update(t);
return t;
}

int find(int t,int k)
{
down(t);
if (k==size[c[t][0]]+1) return t;
if (k<size[c[t][0]]+1) return find(c[t][0],k);
if (k>size[c[t][0]]+1) return find(c[t][1],k-size[c[t][0]]-1);
}

void del(int &x)
{
if (!x) return;
q[++tail]=x;
fa[x]=0;
del(c[x][0]);
del(c[x][1]);
la[x]=ra[x]=ma[x]=-inf;
x=0;
}

int main()
{
//freopen("build.in","r",stdin);
//freopen("build.out","w",stdout);
scanf("%d %d",&n,&m);
for (int i=2;i<=n+1;i++)
{
scanf("%d",&a[i]);
}
a[st=1]=0; a[ed=n+2]=0;
ra[0]=la[0]=ma[0]=-inf;
rot=build(1,n+2);
char sign[20];
int x,y,l,r,z,ans;
for (int i=1;i<=m;i++)
{
scanf("%s",&sign);
if (sign[0]=='I')
{
scanf("%d %d",&x,&y);
l=find(rot,x+1); r=find(rot,x+2);
splay(r,rot); splay(l,c[rot][0]);
for (int j=1;j<=y;j++)
scanf("%d",&a[j]);
int tmp=build(1,y);
fa[tmp]=l; c[l][1]=tmp;
update(l); update(r);
}
if (sign[0]=='D')
{
scanf("%d %d",&x,&y);
l=find(rot,x);  r=find(rot,x+y+1);
splay(r,rot); splay(l,c[rot][0]);
del(c[l][1]);
update(l); update(r);
}
if (sign[0]=='M'&&sign[2]=='K')
{
scanf("%d %d %d",&x,&y,&z);
l=find(rot,x); r=find(rot,x+y+1);
splay(r,rot); splay(l,c[rot][0]);
recover(c[l][1],z);
}
if (sign[0]=='R')
{
scanf("%d %d",&x,&y);
l=find(rot,x); r=find(rot,x+y+1);
splay(r,rot); splay(l,c[rot][0]);
reverse(c[l][1]);
}
if (sign[0]=='G')
{
scanf("%d %d",&x,&y);
l=find(rot,x); r=find(rot,x+y+1);
splay(r,rot); splay(l,c[rot][0]);
ans=sum[c[l][1]];
printf("%d\n",ans);
}
if (sign[0]=='M'&&sign[2]=='X')
{
splay(ed,rot); splay(st,c[rot][0]);
ans=ma[c[st][1]];
printf("%d\n",ans);
}
//for(int i=1;i<=num;i++) cout<<find(rot,i)<<' ';cout<<"\n";
//for (int i=1;i<=num;i++) cout<<i<<' '<<size[i]<<' '<<fa[i]<<' '<<c[i][0]<<' '<<c[i][1]<<' '<<la[i]<<' '<<ra[i]<<' '<<ma[i]<<' '<<sum[i]<<endl;
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: