您的位置:首页 > 其它

[BZOJ]1036: [ZJOI2008]树的统计Count

2016-11-25 12:40 302 查看
一道树链剖分的裸题,要注意找最大值的时候ans要设得很小,我因此WA了一次。。。。。。

/**************************************************************
Problem: 1036
User: 200815147
Language: C++
Result: Accepted
Time:2792 ms
Memory:7816 kb
****************************************************************/

#include<cstdio>
#include<cstring>
const int Q=50005;
struct edge
{
int y,next;
}b[Q*2];
int len=0,last[Q];
void ins(int x,int y)
{
int t=++len;
b[t].y=y;b[t].next=last[x];last[x]=t;
}
int tot[Q],son[Q],dep[Q],ys[Q],fa[Q],top[Q];
void pre_node(int x)
{
tot[x]=1;son[x]=0;
for(int i=last[x];i!=-1;i=b[i].next)
{
int y=b[i].y;
if(fa[x]!=y)
{
fa[y]=x;
dep[y]=dep[x]+1;
pre_node(y);
if(tot[y]>tot[son[x]]) son[x]=y;
tot[x]+=tot[y];
}
}
}
int z=0;
void pre_edge(int x,int tp)
{
ys[x]=++z;top[x]=tp;
if(son[x]!=0) pre_edge(son[x],tp);
for(int i=last[x];i!=-1;i=b[i].next)
{
int y=b[i].y;
if(son[x]!=y && fa[x]!=y) pre_edge(y,y);
}
}
struct tree
{
int l,r,Max,Sum,lc,rc;
}a[Q*3];
int trlen=0;
void build(int l,int r)
{
int t=++trlen;
a[t].l=l;a[t].r=r;
if(l<r)
{
int mid=(l+r)>>1;
a[t].lc=trlen+1;build(l,mid);
a[t].rc=trlen+1;build(mid+1,r);
}
}
int mymax(int x,int y) {return x>y?x:y;}
void change(int now,int x,int c)
{
if(a[now].l==a[now].r)
{
a[now].Max=a[now].Sum=c;
return;
}
int lc=a[now].lc,rc=a[now].rc,mid=(a[now].l+a[now].r)>>1;
if(x<=mid) change(lc,x,c);
else change(rc,x,c);
a[now].Sum=a[lc].Sum+a[rc].Sum;
a[now].Max=mymax(a[lc].Max,a[rc].Max);
}
int findmax(int now,int l,int r)
{
if(a[now].l==l && a[now].r==r) return a[now].Max;
int lc=a[now].lc,rc=a[now].rc,mid=(a[now].l+a[now].r)>>1;
if(r<=mid) return findmax(lc,l,r);
else if(l>mid) return findmax(rc,l,r);
else return mymax(findmax(lc,l,mid),findmax(rc,mid+1,r));
}
int findsum(int now,int l,int r)
{
if(a[now].l==l && a[now].r==r) return a[now].Sum;
int lc=a[now].lc,rc=a[now].rc,mid=(a[now].l+a[now].r)>>1;
if(r<=mid) return findsum(lc,l,r);
else if(l>mid) return findsum(rc,l,r);
else return findsum(lc,l,mid)+findsum(rc,mid+1,r);
}
int n,m;
void solve1(int x,int y)
{
int tx=top[x],ty=top[y],ans=-300000;
while(tx!=ty)
{
//printf("x=%d y=%d tx=%d ty=%d\n",x,y,tx,ty);
if(dep[tx]>dep[ty])
{
int t=tx;tx=ty;ty=t;
t=x;x=y;y=t;
}
//printf("x=%d y=%d tx=%d ty=%d\n",x,y,tx,ty);
ans=mymax(ans,findmax(1,ys[ty],ys[y]));
//printf("%d %d %d\n",ys[ty],ys[y],ans);
y=fa[
4000
ty];ty=top[y];
}
if(dep[x]>dep[y]) {int t=x;x=y;y=t;}
ans=mymax(ans,findmax(1,ys[x],ys[y]));
printf("%d\n",ans);
}
void solve2(int x,int y)
{
int tx=top[x],ty=top[y],ans=0;
while(tx!=ty)
{
if(dep[tx]>dep[ty])
{
int t=tx;tx=ty;ty=t;
t=x;x=y;y=t;
}
ans+=findsum(1,ys[ty],ys[y]);
y=fa[ty];ty=top[y];
}
if(dep[x]>dep[y]) {int t=x;x=y;y=t;}
/*printf("%d %d %d %d\n",x,y,ys[x],ys[y]);
printf("%d\n",ans);*/
ans+=findsum(1,ys[x],ys[y]);
printf("%d\n",ans);
}
int main()
{
//freopen("1.txt","w",stdout);
memset(last,-1,sizeof(last));
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
ins(x,y);ins(y,x);
}//printf("ok");
tot[0]=dep[1]=fa[1]=0;pre_node(1);
pre_edge(1,1);
build(1,n);
for(int i=1;i<=n;i++)
{
int d;
scanf("%d",&d);
change(1,ys[i],d);
}
scanf("%d",&m);//for(int i=1;i<=len;i++)
//printf("%d %d %d %d\n",i,a[i].l,a[i].r,a[i].Sum);
for(int i=1;i<=m;i++)
{
char s[10];
int x,y;
scanf("%s%d%d",s,&x,&y);
if(s[0]=='C') change(1,ys[x],y);
else if(s[1]=='M') solve1(x,y);
else solve2(x,y);
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: