您的位置:首页 > 其它

bzoj1036: [ZJOI2008]树的统计Count(树链剖分)

2017-09-16 10:12 507 查看
题目传送门

树链剖分模板吧。

不过用LCA好像也是可以的。(表示懒得用)

树链剖分:

dep[x]:表示x的深度,深度越大离根越远。

fa[x]:表示x的父亲节点。

tot[x]:表示以x为根的子树的家族数量。

son[x]:x的重儿子,选取家族数量最多的子节点作为重儿子。

每一个节点连接自己的重儿子,那么就会出现链,称为重链。

对于同一条重链上的点,赋予每一个节点新的编号,而且在同一条重链中节点的新编号都是连续的。

(注意!相邻的两条重链新编号不一定连续哦!!)

ys[x]:表示节点x的新编号。

top[x]:表示x所在的重链的起始位置是谁!

那么既然同一条重链上的节点的新编号都是连续的,那么我们就可以用各种数据结构去管理区特征值,这里我们用的是线段树。

在本题中,需要求的是路径上的最大点值。那么线段树中的特征值就为当前这一段的最大值!

这道题的解法:

相对于修改操作,把x节点的值修改为y,很显然,在线段树里修改即可。

操作为:change(1,ys[x],y)

相对于求值操作,从x到y路径上的最大点值,那么我们只要使得x和y跳到同一条重链中,然后直接用线段树求解答案即可!!

跳法:

每一次把祖先(top)深度较低的节点往上跳!为什么要比较祖先呢??

因为你使得x和y到同一条重链中,那么肯定是祖先相同。

如果你直接比较两个点的深度的话,那么在跳的过程中有可能祖先跳过头了!

代码实现:

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<cmath>
#include<algorithm>
using namespace std;
struct node {
int x,y,next;
}a[510000];int len,last[210000];
void ins(int x,int y) {
len++;
a[len].x=x;a[len].y=y;
a[len].next=last[x];last[x]=len;
}
int n,fa[210000],dep[210000],son[210000],tot[210000]; //详情见分析
void pre_tree_node(int x) {
son[x]=0;tot[x]=1;
for(int k=last[x];k;k=a[k].next) {
int y=a[k].y;
if(y!=fa[x]) {
fa[y]=x;
dep[y]=dep[x]+1;
pre_tree_node(y);
if(tot[son[x]]<tot[y]) //如果当前的子节点的家族数大于原来重儿子的家族数,那么更新重儿子!
son[x]=y;
tot[x]+=tot[y];
}
}
}
int z,top[210000],ys[210000]; //z表示当前节点的新编号
void pre_tree_edge(int x,int tp) { //表示x的祖先为tp(也就是重链起始端)
ys[x]=++z;top[x]=tp;
if(son[x]!=0)  //重儿子优先递归,使得同一条重链中节点的新编号都是连续的
pre_tree_edge(son[x],tp);
for(int k=last[x];k;k=a[k].next) {
int y=a[k].y;
if(y!=fa[x]&&y!=son[x]) //如果我不是我父亲的重儿子那么说明我是重链的起始端
pre_tree_edge(y,y);
}
}
struct trnode {
int l,r,lc,rc,c,cc; //c值要根据每道题的问题来变换含义,在此题中c值为这一段的最大值,cc为这一段的和
}tr[510000];int trlen;
void bt(int l,int r) {
trlen++;int now=trlen;
tr[now].l=l;tr[now].r=r;tr[now].c=0;
tr[now].lc=tr[now].rc=-1;
if(l<r) {
int mid=(l+r)/2;
tr[now].lc=trlen+1;bt(l,mid);
tr[now].rc=trlen+1;bt(mid+1,r);
}
}
void change(int now,int x,int k) {
if(tr[now].l==tr[now].r) {
tr[now].c=tr[now].cc=k;return ;
}
int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)/2;
if(x<=mid)
change(lc,x,k);
else
change(rc,x,k);
tr[now].c=max(tr[lc].c,tr[rc].c); //维护线段树
tr[now].cc=tr[lc].cc+tr[rc].cc;
}
int findmax(int now,int l,int r) {
if(tr[now].l==l&&tr[now].r==r)
return tr[now].c;
int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)/2;
if(r<=mid)
return findmax(lc,l,r);
else if(l>mid)
return findmax(rc,l,r);
else
return max(findmax(lc,l,mid),findmax(rc,mid+1,r));
}
int findsum(int now,int l,int r) {
if(tr[now].l==l&&tr[now].r==r)
return tr[now].cc;
int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)/2;
if(r<=mid)
return findsum(lc,l,r);
else if(l>mid)
return findsum(rc,l,r);
else
return findsum(lc,l,mid)+findsum(rc,mid+1,r);
}
int solve1(int x,int y) { //用来处理x到y的最大值
int tx=top[x],ty=top[y],ans=-99999999;
while(tx!=ty) { //tx!=ty说明x和y不在同一条重链中,那么就要进来跳
if(dep[tx]>dep[ty]) {  //将祖先深度较小的节点存在y,然后让y往上跳
swap(tx,ty);swap(x,y);
}
ans=max(ans,findmax(1,ys[ty],ys[y])); //跳过的路径相当于走过的路径,要记录答案
y=fa[ty];ty=top[y]; //往上跳一条重链
}
//最后跳出循环后x和y在同一条重链中了。最后再处理一下答案即可
if(x==y)
return max(ans,findmax(1,ys[x],ys[x]));
if(dep[x]>dep[y])
swap(x,y);
return max(ans,findmax(1,ys[x],ys[y]));
}
int solve2(int x,int y) { //同理,用来处理x到y的距离
int tx=top[x],ty=top[y],ans=0;
while(tx!=ty) { //tx!=ty说明x和y不在同一条重链中,那么就要进来跳
if(dep[tx]>dep[ty]) {  //将祖先深度较小的节点存在y,然后让y往上跳
swap(tx,ty);swap(x,y);
}
ans=ans+findsum(1,ys[ty],ys[y]); //跳过的路径相当于走过的路径,要记录答案
y=fa[ty];ty=top[y]; //往上跳一条重链
}
//最后跳出循环后x和y在同一条重链中了。最后再处理一下答案即可
if(x==y)
return ans+findsum(1,ys[x],ys[x]);
if(dep[x]>dep[y])
swap(x,y);
return ans+findsum(1,ys[x],ys[y]);
}
int s[210000];
int main() {
int n,m;scanf("%d",&n);
len=0;memset(last,0,sizeof(last));
for(int i=1;i<n;i++) {
int x,y;scanf("%d%d",&x,&y);
ins(x,y);ins(y,x);
}
fa[1]=0;dep[1]=0;pre_tree_node(1);
z=0;pre_tree_edge(1,1);
trlen=0;bt(1,z);
for(int i=1;i<=n;i++)
scanf("%d",&s[i]);
for(int i=1;i<=n;i++)
change(1,ys[i],s[i]);
scanf("%d",&m);
for(int i=1;i<=m;i++) {
char s[11];int x,y;
scanf("%s%d%d",s+1,&x,&y);
if(s[1]=='C')
change(1,ys[x],y);
else if(s[2]=='M')
printf("%d\n",solve1(x,y));
else
printf("%d\n",solve2(x,y));
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: