您的位置:首页 > 其它

BZOJ2588: Spoj 10628. Count on a tree

2015-12-21 14:11 330 查看
题目:http://www.lydsy.com/JudgeOnline/problem.php?id=2588

lca+可持久化线段树

在树上建一棵可持久化线段树就可以了。

#include<cstring>
#include<iostream>
#include<cstdio>
#include<algorithm>
#define rep(i,l,r) for (int i=l;i<=r;i++)
#define down(i,l,r) for (int i=l;i>=r;i--)
#define clr(x,y) memset(x,y,sizeof(x))
#define maxn 100500
#define inf int(1e9)
using namespace std;
struct data{int obj,pre;
}e[maxn*2];
int head[maxn],pos[maxn],sum[maxn*22],ls[maxn*22],rs[maxn*22],dep[maxn],root[maxn*20];
int fa[maxn][22],v[maxn],tmp[maxn],hash[maxn],num[maxn];
int n,m,ans,tot,cnt,cnt2,idx,bin[22];
void insert(int x,int y){
e[++tot].obj=y; e[tot].pre=head[x]; head[x]=tot;
}
int read(){
int x=0,f=1; char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();}
while (isdigit(ch)) {x=x*10+ch-'0'; ch=getchar();}
return x*f;
}
int find(int x){
int l=1,r=cnt;
while (l<r){
int mid=(l+r)/2;
if (hash[mid]==x) return mid;
if (x<hash[mid]) r=mid-1; else l=mid+1;
}
return l;
}
void dfs(int u){
pos[u]=++idx; num[idx]=u;
rep(i,1,20) if (dep[u]>bin[i]) fa[u][i]=fa[fa[u][i-1]][i-1];
for (int j=head[u];j;j=e[j].pre){
int v=e[j].obj;
if (v!=fa[u][0]){
fa[v][0]=u;
dep[v]=dep[u]+1;
dfs(v);
}
}
}
void add(int l,int r,int x,int &y,int val){
y=++cnt2;
sum[y]=sum[x]+1;
if (l==r) return;
ls[y]=ls[x]; rs[y]=rs[x];
int mid=(l+r)/2;
if (val<=mid) add(l,mid,ls[x],ls[y],val);
else add(mid+1,r,rs[x],rs[y],val);
}
int lca(int x,int y){
if (dep[x]<dep[y]) swap(x,y);
int t=dep[x]-dep[y];
rep(i,0,20) if (t&bin[i]) x=fa[x][i];
down(i,20,0) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
if (x!=y) return fa[x][0];
return x;
}
int ask(int x,int y,int k){
int t=lca(x,y);
int a=root[pos[x]],b=root[pos[y]],c=root[pos[t]],d=root[pos[fa[t][0]]];
int l=1,r=cnt;
while (l<r){
int mid=(l+r)/2;
int tmp=sum[ls[a]]+sum[ls[b]]-sum[ls[c]]-sum[ls[d]];
if (k<=tmp) {a=ls[a],b=ls[b],c=ls[c],d=ls[d];r=mid;}
else {k-=tmp; a=rs[a],b=rs[b],c=rs[c],d=rs[d]; l=mid+1;}
}
return hash[l];
}
int main(){
bin[0]=1; rep(i,1,20) bin[i]=bin[i-1]*2;
n=read(); m=read();
rep(i,1,n) v[i]=read(),tmp[i]=v[i];
sort(tmp+1,tmp+1+n);
hash[cnt=1]=tmp[1];
rep(i,2,n) if (tmp[i]!=tmp[i-1]) hash[++cnt]=tmp[i];
rep(i,1,n) v[i]=find(v[i]);
rep(i,1,n-1){
int x=read(),y=read();
insert(x,y); insert(y,x);
}
dep[1]=1; dfs(1);
rep(i,1,n){
int t=num[i];
add(1,cnt,root[pos[fa[t][0]]],root[i],v[t]);
}
rep(i,1,m){
int x=read(),y=read(),k=read();
x=x^ans;
ans=ask(x,y,k);
if (i!=m) printf("%d\n",ans);
else printf("%d",ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: