您的位置:首页 > 其它

spoj_cot2 Count on a tree II(树上莫队+离散化)

2018-03-18 21:24 218 查看
http://www.elijahqi.win/archives/376

You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.

We will ask you to perform the following operation:

u v : ask for how many different integers that represent the weight of nodes there are on the path from u to v.

Input

In the first line there are two integers N and M. (N <= 40000, M <= 100000)

In the second line there are N integers. The i-th integer denotes the weight of the i-th node.

In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v).

In the next M lines, each line contains two integers u v, which means an operation asking for how many different integers that represent the weight of nodes there are on the path from u to v.

Output

For each operation, print its result.

Example

Input:

8 2

105 2 9 3 8 5 7 7

1 2

1 3

1 4

3 5

3 6

3 7

4 8

2 5

7 8

Output:

4

4

找错加调试加学习整整7个小时啊 后来无意把N 多加一个0就过了,我也不知道这个数据是要闹哪样啊。

树上莫队仍然分块,只不过根据dfs序分块

首先将树分块,然后以所属块的编号为第一关键字,以dfs序为第二关键字对询问排序,下面只需要考虑如何由(u,v)链->(u’,v’)链了

令S(u,v)表示u~v的点的集合

S(u,v)=S(root,u) xor S(root,v) xor lca(u,v)

令 T(u,v)=S(root,u) xor S(root,v)

考虑T(u,v)->T(u,v’)

T(u,v) xor T(u,v’)=S(root,v) xor S(root,v’)

T(u,v’)=T(u,v) xor S(root,v) xor S(root,v’)

T(u,v’)=T(u,v) xor T(v,v’)

对lca单独考虑即可

之前没有学习过倍增的lca 洛谷的Lca模板还是用lct+o2优化卡过

倍增lca:首先两个不同深度的节点要先通过倍增给他们上升到同一深度,然后再采取倍增的方法,注意一定从大到小如2^3–>2^2

如何倍增,倍增使用了一个fa[哪个节点][2的i次方]

深搜时:要特殊判断,避免搜索到自己的父节点 dfs返回当前子树搜索到多少个节点,如果超过sqrt(n)那么,给这些节点分在一块内

开一个栈存储这些还没有编号的节点,最后深搜完毕,把剩下没有编号的统一再编号

处理询问的时候,我们按照已经排好序的询问来做,先处理第一个询问

剩下的询问都是在第一个询问的基础上扩展,可以由前面的公式推导得到,针对存在性取反的时候,类似以前做过的莫队题目,如果这个节点出现过,就去看这个节点的权值

最后 提醒本题需要离散化,存边的结构体开两倍

#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<map>
#define N 550000
#define M 110000
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9') {ch=getchar();}
while (ch<='9'&&ch>='0') {x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline void swap(int &x,int &y){
int t=x;x=y;y=t;
}
struct node{
int y,next;
}data[N<<1];
struct node1{
int l,r,id;
}q[M];
int n,m,n1,c
,c1
,dfn
,block_num,top,f
,ans
,num,Log
,low
,fa
[20],h
,bl
;
map<int,int> mm;
//int mm
;
int stack
,ans1;bool visit
;
int dfs(int x){
dfn[x]=++num;int size=0;
for (int i=1;i<=Log[low[x]];++i) fa[x][i]=fa[fa[x][i-1]][i-1];
for (int i=h[x];i;i=data[i].next){
int y=data[i].y;
if (fa[x][0]==y) continue;
fa[y][0]=x;low[y]=low[x]+1;
size+=dfs(y);
if (size>=n1){
block_num++;
for (int i=1;i<=size;++i){
bl[stack[top--]]=block_num;
}
size=0;
}
}
stack[++top]=x;
return size+1;
}
inline bool cmp(node1 a,node1 b){
return bl[a.l]==bl[b.l]?dfn[a.r]<dfn[b.r]:bl[a.l]<bl[b.l];
}
inline int lca(int x,int y){
if (low[x]<low[y]) swap(x,y);
int dis=low[x]-low[y];
for (int i=0;i<=Log[dis];++i) if (dis&(1<<i)) x=fa[x][i];
if (x==y) return x;
for (int i=Log
;i>=0;--i){
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
}
return fa[x][0];
}
inline void reserve(int x){ //针对存在性取反  并且统计答案
if (visit[x]){
f[c[x]]--;if (!f[c[x]]) ans1--;visit[x]=false;
}else{
f[c[x]]++;if (f[c[x]]==1) ans1++;visit[x]=true;
}
//visit[x]^=1;
}
inline void solve(int x,int y){
while (x!=y) if (low[x]<low[y]) reserve(y),y=fa[y][0];else reserve(x),x=fa[x][0];
}
int main(){
freopen("10707.in","r",stdin);
//freopen("10707.out","w",stdout);
n=read();m=read();n1=sqrt(n);
for (int i=1;i<=n;++i) c[i]=read(),c1[i]=c[i];
sort(c1+1,c1+n+1);
//  for (int i=1;i<=n;++i) printf("%d ",c[i]);
int tmp=std::unique(c1+1,c1+n+1)-c1-1;
for (int i=1;i<=tmp;++i) mm[c1[i]]=i;
for (int i=1;i<=n;++i) c[i]=mm[c[i]];
memset(h,0,sizeof(h));Log[0]=-1;
int tmp1,tmp2,num=0;for (int i=1;i<=n;++i) Log[i]=Log[i>>1]+1;
for (int i=1;i<n;++i) {
tmp1=read();tmp2=read();
data[++num].y=tmp2;data[num].next=h[tmp1];h[tmp1]=num;
data[++num].y=tmp1;data[num].next=h[tmp2];h[tmp2]=num;
}
num=0;dfs(1);block_num++;
while (top) bl[stack[top--]]=block_num;
//for (int i=1;i<=n;++i) printf("%d ",bl[i]);
for (int i=1;i<=m;++i){q[i].l=read();q[i].r=read();q[i].id=i;if (bl[q[i].l]>bl[q[i].r]) swap(q[i].l,q[i].r);}
sort(q+1,q+m+1,cmp);
//for (int i=1;i<=m;++i) printf("%d %d\n",q[i].l,q[i].r);
/*  for (int i=1;i<=n;++i){
for (int j=1;j<=4;++j) printf("%d ",fa[i][j]);
printf("\n");
}*/
tmp=lca(q[1].l,q[1].r);
//  printf("%d ",tmp);
solve(q[1].l,q[1].r);
ans[q[1].id]=ans1+!(f[c[tmp]]);
for (int i=2;i<=m;++i){
solve(q[i-1].l,q[i].l);solve(q[i-1].r,q[i].r);
tmp=lca(q[i].l,q[i].r);
ans[q[i].id]=ans1+!(f[c[tmp]]);
}
for (int i=1;i<=m;++i) printf("%d\n",ans[i]);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: