您的位置:首页 > 其它

bzoj3653 可持久化线段树分析(通俗易懂的个人理解)

2016-11-14 22:17 423 查看
学习了黄学长的文章才有此顿悟,感谢   黄学长bzoj3653传送门:http://hzwer.com/5444.html

写一下自己的理解

可持久化是什么,就是每个时间点都建一颗线段树,

并且,当前时间的这颗线段树里中的数据肯定包含当前时间之前(从0开始)的所有数据

既然这样,实际上就没必要每个时间点建树,只对于当前时间点新加入的数据新建立点,

其余的点就直接用原来的就行,也就是把指针都指向前一个时间(注意是前一个)的那颗树的点就行,也就是说

实际上每个时间新加入的就是一条链。

结合例子就很好分析了:bzoj3653(无视题面!)

这道题是在一颗树中,求三元组(a,b,c)个数,其中a已知,b和a在树上的距离不超过k,且a和b都是c的祖先

那分两种情况好了,(1)b是a的祖先,答案为min(d[a],k)*sz[a](d[i]表示节点i的深度,sz[i]表示节点i子树有多少个点,下同)

(2)a是b的祖先,那么答案就是 ∑sz[i](d[i]-d[a]<k 且 a是i的祖先)

用文字解释,就是:求在dfs序出现a的最早和最晚端点之间的,且d[i]-d[a]<k的点的

(注:dfs序中,a左右端点间节点为a的子树,不解释)

我们会发现,这个查询到的点需要同时满足这两条性质,所以一般的线段树就不行了

所以就直接把dfs序的时间戳做为上文提到的时间,把原树的深度做下标,记录节点的sz就可以了

在dfs序中不同时间遍历到的节点是不同的,所以每棵线段树的跟对应原树的节点也是不同的,自然,

dfs序中,l[a]时间到r[a]时间之间的线段树,肯定包含了a所有子树的sz

每颗线段树记录的是dfs序从根一直到它经过所有点的sz,查询时就求差分一下就是结果

具体实现请看代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
const int N=300100;
int cnt,idx,mx,n,q,h
,d
,ln
,rn
,sz
,dfn
;
struct edge{int y,next;}mp[N*2];
struct node{
ll x;node *l,*r;
node(){l=r=NULL;}
}mem[N*20],*root
;

void adde(int id,int x,int y){
mp[id].y=y,mp[id].next=h[x],h[x]=id;
}
void dfs(int x,int fa){
ln[x]=++idx,dfn[idx]=x;
for(int i=h[x];i;i=mp[i].next)
if(mp[i].y!=fa)
d[mp[i].y]=d[x]+1,dfs(mp[i].y,x),sz[x]+=sz[mp[i].y]+1;
rn[x]=idx;
}
void build(node *&a,node *b,int l,int r,int pos,int val){
a=mem+cnt++;
a->x=val;	if(b)	a->x+=b->x;
if(l==r)	return ;
if(b){if(b->l)a->l=b->l; if(b->r) a->r=b->r;}
int mid=l+r>>1;
if(pos<=mid)	build(a->l,b?b->l:NULL,l,mid,pos,val);
else build(a->r,b?b->r:NULL,mid+1,r,pos,val);
}
ll query(node *rt,int l,int r,int x,int y){
if(!rt)	return 0;if(y>r)	y=r;
if(x<=l && r<=y)	return rt->x;
int mid=l+r>>1;ll ans=0;
if(x<=mid)	ans+=query(rt->l,l,mid,x,y);
if(y>mid)	ans+=query(rt->r,mid+1,r,x,y);
return ans;
}
int main(){
freopen("3653.in","r",stdin);
freopen("3653.out","w",stdout);
scanf("%d%d",&n,&q);
for(int i=1;i<n;i++){
int x,y;scanf("%d%d",&x,&y);
adde(i,x,y),adde(i+n-1,y,x);
}
dfs(1,0);
for(int i=1;i<=n;i++)	mx=max(d[i],mx);
for(int i=1;i<=n;i++)	build(root[i],root[i-1],0,mx,d[dfn[i]],sz[dfn[i]]);
while(q--){
int p,k;scanf("%d%d",&p,&k);
ll f1=1,ans=sz[p]*f1*min(k,d[p]);
ans+=query(root[rn[p]],0,mx,d[p]+1,d[p]+k);
ans-=query(root[ln[p]-1],0,mx,d[p]+1,d[p]+k);
printf("%lld\n",ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息