您的位置:首页 > 其它

树链剖分 树的统计

2016-01-15 20:12 225 查看
/*题目描述 Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

输入描述 Input Description
输入文件的第一行为一个整数n,表示节点的个数。

接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。

接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。

接下来1行,为一个整数q,表示操作的总数。

接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。

输出描述 Output Description
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
树链剖分模版题,支持树上的查询与修改,没有什么新东西,也就是2遍DFS加线段树加lca。

*/
#include<cstdio>
#include<iostream>
using namespace std;
int zhi[30002],n,m,head[30002],next[60002],u[60002],dui[30002],deep[30002];
int size[30002],f[30002],lc[30002][15],n1,lian[30002];
struct shu
{
int l,r,max,sum;
}a[100005];
void dfs1(int a1)
{
size[a1]=1;
f[a1]=1;
for(int i=1;i<=14;i++)
{
if(deep[a1]<(1<<i))
break;
lc[a1][i]=lc[lc[a1][i-1]][i-1];
}
for(int b=head[a1];b;b=next[b])
if(!f[u[b]])
{
deep[u[b]]=deep[a1]+1;
lc[u[b]][0]=a1;
dfs1(u[b]);
size[a1]+=size[u[b]];
}
return;
}
void dfs2(int a1,int a2)
{
int k=0;
n1++;
dui[a1]=n1;
lian[a1]=a2;
for(int i=head[a1];i;i=next[i])
if(size[k]<size[u[i]]&&deep[u[i]]>deep[a1])
k=u[i];
if(k==0)
return;
dfs2(k,a2);
for(int i=head[a1];i;i=next[i])
if(u[i]!=k&&deep[u[i]]>deep[a1])
dfs2(u[i],u[i]);
return;
}
void build(int a1,int a2,int a3)
{
a[a1].l=a2;
a[a1].r=a3;
if(a2+1==a3)
return;
int mid=(a2+a3)>>1;
build(a1*2,a2,mid);
build(a1*2+1,mid,a3);
return;
}
void cha(int a1,int a2,int a3)
{
int lr=a[a1].l,rr=a[a1].r,mid=(lr+rr)>>1;
if(lr+1==rr)
{
a[a1].max=a3;
a[a1].sum=a3;
return;
}
if(a2<mid)
cha(a1*2,a2,a3);
else
cha(a1*2+1,a2,a3);
a[a1].sum=a[a1*2].sum+a[a1*2+1].sum;
a[a1].max=max(a[a1*2].max,a[a1*2+1].max);
return;
}
int lca(int a1,int a2)
{
if(deep[a1]<deep[a2])
swap(a1,a2);
int t=deep[a1]-deep[a2];
for(int i=0;i<=14;i++)
if(t&1<<i)
a1=lc[a1][i];
for(int i=14;i>=0;i--)
if(lc[a1][i]!=lc[a2][i])
{
a1=lc[a1][i];
a2=lc[a2][i];
}
if(a1==a2)
return a1;
else
return lc[a1][0];
}
int xisu(int a1,int a2,int a3)
{
int sum1=0;
if(a[a1].l>=a2&&a3>=a[a1].r)
return a[a1].sum;
int mid=(a[a1].l+a[a1].r)>>1;
if(a2<mid)
sum1+=xisu(a1*2,a2,a3);
if(a3>mid)
sum1+=xisu(a1*2+1,a2,a3);
return sum1;
}
int lihe(int a1,int a2)
{
int sum1=0;
for(;lian[a1]!=lian[a2];)
{
sum1+=xisu(1,dui[lian[a1]],dui[a1]+1);
a1=lc[lian[a1]][0];
}
sum1+=xisu(1,dui[a2],dui[a1]+1);
return sum1;
}
int xima(int a1,int a2,int a3)
{
int sum1=-10000000;
if(a[a1].l>=a2&&a3>=a[a1].r)
return a[a1].max;
int mid=(a[a1].l+a[a1].r)>>1;
if(a2<mid)
sum1=max(sum1,xima(a1*2,a2,a3));
if(a3>mid)
sum1=max(sum1,xima(a1*2+1,a2,a3));
return sum1;
}
int lima(int a1,int a2)
{
int sum1=-10000000;
for(;lian[a1]!=lian[a2];)
{
sum1=max(sum1,xima(1,dui[lian[a1]],dui[a1]+1));
a1=lc[lian[a1]][0];
}
sum1=max(sum1,xima(1,dui[a2],dui[a1]+1));
return sum1;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int a1,a2;
scanf("%d%d",&a1,&a2);
next[i]=head[a1];
head[a1]=i;
u[i]=a2;
next[i+n]=head[a2];
head[a2]=i+n;
u[i+n]=a1;
}
for(int i=1;i<=n;i++)
scanf("%d",&zhi[i]);
dfs1(1);
dfs2(1,1);
build(1,1,n+1);
scanf("%d",&m);
for(int i=1;i<=n;i++)
cha(1,dui[i],zhi[i]);
for(int i=0;i<m;i++)
{
int a1,a2;
char ch[10];
scanf("%s%d%d",ch,&a1,&a2);
if(ch[0]=='C')
{
zhi[a1]=a2;
cha(1,dui[a1],a2);
}
else
{
int t=lca(a1,a2);
if(ch[1]=='S')
printf("%d\n",lihe(a1,t)+lihe(a2,t)-zhi[t]);
else
printf("%d\n",max(lima(a1,t),lima(a2,t)));
}
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: