您的位置:首页 > 其它

[BZOJ 3757]苹果树:树上莫队

2017-05-23 11:43 453 查看
点击这里查看原题

建议先看这里的讲解

树上莫队大体思路就是把树分成块,区间与区间之间实现O(sqrt(n))的转换,然后当成莫队去做。

/*
User:Small
Language:C++
Problem No.:BZOJ 3757
*/
#include<bits/stdc++.h>
#define ll long long
#define inf 999999999
using namespace std;
const int M=1e5+5;
int n,m,cnt,num[M],pos[M],dfn[M],stk[M],tp,rt,fir[M],tot,anc[M][20],lg[M],c[M],t,ans[M],dep[M],dfs_clock;
bool vis[M];
struct edge{
int v,nex;
}e[M];
struct no{
int u,v,a,b,id;
bool operator<(const no b)const{
return pos[u]==pos[b.u]?dfn[v]<dfn[b.v]:pos[u]<pos[b.u];
}
}q[M];
void add(int u,int v){
e[++tot]=(edge){v,fir[u]};
fir[u]=tot;
}
int dfs(int u){
int siz=0;
dfn[u]=++dfs_clock;
for(int i=fir[u];i;i=e[i].nex){
int v=e[i].v;
if(anc[u][0]==v) continue;
anc[v][0]=u;
dep[v]=dep[u]+1;
siz+=dfs(v);
if(siz>=t){
tot++;
for(int i=1;i<=siz;i++) pos[stk[tp--]]=tot;
siz=0;
}
}
stk[++tp]=u;
return siz+1;
}
int LCA(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
int d=dep[u]-dep[v];
for(int i=lg[d];i>=0;i--)
if((d>>i)&1) u=anc[u][i];
if(u==v) return u;
for(int i=lg
;i>=0;i--)
if(anc[u][i]!=anc[v][i]) u=anc[u][i],v=anc[v][i];
return anc[u][0];
}
void change(int x){
if(!vis[x]){
vis[x]=1;
num[c[x]]++;
if(num[c[x]]==1) cnt++;
}
else{
vis[x]=0;
num[c[x]]--;
if(num[c[x]]==0) cnt--;
}
}
void work(int u,int v,int lca){
while(u!=lca){
change(u);
u=anc[u][0];
}
while(v!=lca){
change(v);
v=anc[v][0];
}
}
void solve(){
for(int i=1;i<=m;i++){
int lca=LCA(q[i].u,q[i].v);
work(q[i-1].u,q[i].u,LCA(q[i-1].u,q[i].u));
work(q[i-1].v,q[i].v,LCA(q[i-1].v,q[i].v));
change(lca);
ans[q[i].id]=cnt;
if(num[q[i].a]&&num[q[i].b]&&q[i].a!=q[i].b) ans[q[i].id]--;
change(lca);
}
}
int main(){
freopen("data.in","r",stdin);//
scanf("%d%d",&n,&m);
t=sqrt(n);
for(int i=1;i<=n;i++) scanf("%d",&c[i]);
for(int i=1;i<=n;i++){
int u,v;
scanf("%d%d",&u,&v);
if(!u||!v) rt=u+v;
else add(u,v),add(v,u);
}
for(int i=1;i<=m;i++){
scanf("%d%d%d%d",&q[i].u,&q[i].v,&q[i].a,&q[i].b);
q[i].id=i;
}
tot=0;
int rm=dfs(rt);
for(int i=1;i<=rm;i++) pos[stk[tp--]]=tot;
sort(q+1,q+m+1);
lg[0]=-1;
for(int i=1;i<=n;i++) lg[i]=lg[i>>1]+1;
for(int i=1;i<=lg
;i++)
for(int j=1;j<=n;j++) anc[j][i]=anc[anc[j][i-1]][i-1];
q[0]=(no){rt,rt,0,0,0};
solve();
for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: