您的位置:首页 > 其它

[hackerrank]Counting on a tree

2016-11-14 17:33 323 查看

题目大意

给你一棵树,每个点有一个颜色

若干询问,每次询问两条树路径上,存在多少点对(i,j)满足以下条件:

1、i不等于j

2、i和j颜色相同

3、i在第一条树路径上,j在第二条树路径上

第一种算法

我们先将颜色离散化,然后统计每种颜色有多少个点。

先不考虑第一个限制,假设对于一种颜色c,我处理出了d数组d[x]表示x到根路径上有多少个颜色为c的,那么一个询问x->y和u->v答案为(d[x]+d[y]−d[lca(x,y)]−d[fa[lca(x,y)]])∗(d[a]+d[b]−d[lca(a,b)]−d[fa[lca(a,b)]])

因此考虑暴力枚举c,然后计算d,然后计算对询问的贡献。

第二种算法

考虑(d[x]+d[y]−d[lca(x,y)]−d[fa[lca(x,y)]])∗(d[a]+d[b]−d[lca(a,b)]−d[fa[lca(a,b)]])

可以拆成至多16项,每一项形如d[x]*d[y]。

意义是什么?那就是x和y到根路径上有多少等颜色点对。

考虑把询问挂在点上,然后进行一次dfs。

我们设a[k]表示k与当前搜索点到根路径上有多少等颜色点对。

假设dfs到x点,那么就枚举所有与x颜色相同的点y,显然y子树内的点a值都应该+1,子树修改可以使用线段树+dfs序解决。

退出x点时再枚举一次-1去除影响即可。

平衡结合

第一种算法好暴力,颜色种数太多就炸了!

第二种算法好暴力,同种颜色点数量太多就炸了!

思考平衡结合算法,设置阈值B(可取根号n),对于同种颜色点数量<=B的颜色做第二种算法,否则做第一种算法,易证复杂度为n根号n log n。

减重

考虑i可能等于j,事实上多算的答案一定是两条树路径路径交的长度。

可以考虑树剖,也可以考虑下面这个把所有情况几乎都讨论完的简单算法,详见代码

void getans3(int id){
int u=b[id][0],v=b[id][1],x=b[id][2],y=b[id][3],w=lca(u,v),z=lca(x,y);
lc[1]=lca(u,x);
lc[2]=lca(u,y);
lc[3]=lca(v,x);
lc[4]=lca(v,y);
sort(lc+1,lc+5);
if (lc[1]==lc[2]&&lc[1]==lc[3]&&lc[1]==lc[4]){
if (lca(u,v)==lc[1]&&lca(x,y)==lc[1]) ans[id]-=1;
return;
}
if (lc[1]==lc[2]&&lc[1]==lc[3]){
ans[id]-=(ll)getdis(lc[1],lc[4]);
return;
}
if (lc[4]==lc[2]&&lc[4]==lc[3]){
ans[id]-=(ll)getdis(lc[1],lc[4]);
return;
}
if (lc[1]==lc[2]&&lc[3]==lc[4]){
if (w==lc[1]&&z==lc[3]) ans[id]-=1;
else if (w==lc[3]&&z==lc[1]) ans[id]-=1;
return;
}
if (lc[1]==lc[2]){
ans[id]-=(ll)getdis(lc[3],lc[4]);
return;
}
if (lc[2]==lc[3]){
ans[id]-=(ll)getdis(lc[1],lc[4]);
return;
}
if (lc[3]==lc[4]){
ans[id]-=(ll)getdis(lc[1],lc[2]);
return;
}
}


最终整道题的代码

#include<cstdio>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
typedef long long ll;
const int maxn=100000+10,maxq=50000+10,B=400;
ll ans[maxq];
int h[maxn],go[maxn*2],nxt[maxn*2],cnt[maxn],d[maxn],dep[maxn],a[maxn],gjx[maxn];
int ask[maxq*16][4],b[maxq][4],fa[maxn][25],size[maxn],dfn[maxn],zjy[maxn],cg[maxn];
//ask:0~1 point 2 which ask 3 xi shu
int ad[maxn*4];
int h2[maxn],g2[maxq*32],n2[maxq*32];
int h3[maxn],g3[maxn],n3[maxn];
int lc[5];
bool bz[maxn];
int i,j,k,l,t,n,m,q,tot,top,euler,num,now;
int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
void add(int x,int y){
go[++tot]=y;
nxt[tot]=h[x];
h[x]=tot;
}
void add2(int x,int y){
g2[++tot]=y;
n2[tot]=h2[x];
h2[x]=tot;
}
void add3(int x,int y){
g3[++tot]=y;
n3[tot]=h3[x];
h3[x]=tot;
}
void dfs(int x,int y){
dfn[x]=++top;
dep[x]=dep[y]+1;
fa[x][0]=y;
int t=h[x];
size[x]=1;
while (t){
if (go[t]!=y){
dfs(go[t],x);
size[x]+=size[go[t]];
}
t=nxt[t];
}
}
int lca(int x,int y){
if (dep[x]<dep[y]) swap(x,y);
if (dep[x]!=dep[y]){
int j=zjy[dep[x]];
while (j>=0){
if (dep[fa[x][j]]>=dep[y]) x=fa[x][j];
j--;
}
}
if (x==y) return x;
int j=zjy[dep[x]];
while (j>=0){
if (fa[x][j]!=fa[y][j]){
x=fa[x][j];
y=fa[y][j];
}
j--;
}
return fa[x][0];
}
void cr(int x,int y,int f){
ask[++num][0]=x;
ask[num][1]=y;
ask[num][2]=i;
ask[num][3]=f;
add2(x,num);
if (x!=y) add2(y,num);
}
void work(int u,int v,int x,int y){
int w=lca(u,v),z=lca(x,y);
cr(u,x,1);
cr(u,y,1);
cr(u,z,-1);
if (fa[z][0]) cr(u,fa[z][0],-1);
cr(v,x,1);
cr(v,y,1);
cr(v,z,-1);
if (fa[z][0]) cr(v,fa[z][0],-1);
cr(w,x,-1);
cr(w,y,-1);
cr(w,z,1);
if (fa[z][0]) cr(w,fa[z][0],1);
if (fa[w][0]){
cr(fa[w][0],x,-1);
cr(fa[w][0],y,-1);
cr(fa[w][0],z,1);
if (fa[z][0]) cr(fa[w][0],fa[z][0],1);
}
}
void dg(int x,int y){
d[x]=d[y]+(a[x]==now);
int t=h[x];
while (t){
if (go[t]!=y) dg(go[t],x);
t=nxt[t];
}
}
void getans1(int id){
int u=b[id][0],v=b[id][1],x=b[id][2],y=b[id][3];
ans[id]+=(ll)(d[u]+d[v]-2*d[lca(u,v)])*(d[x]+d[y]-2*d[lca(x,y)]);
}
void mark(int p,int v){
ad[p]+=v;
}
void down(int p){
if (ad[p]){
mark(p*2,ad[p]);
mark(p*2+1,ad[p]);
ad[p]=0;
}
}
void change(int p,int l,int r,int a,int b,int v){
if (l==a&&r==b){
mark(p,v);
return;
}
down(p);
int mid=(l+r)/2;
if (b<=mid) change(p*2,l,mid,a,b,v);
else if (a>mid) change(p*2+1,mid+1,r,a,b,v);
else{
change(p*2,l,mid,a,mid,v);
change(p*2+1,mid+1,r,mid+1,b,v);
}
}
int query(int p,int l,int r,int a){
if (l==r) return ad[p];
down(p);
int mid=(l+r)/2;
if (a<=mid) return query(p*2,l,mid,a);else return query(p*2+1,mid+1,r,a);
}
void getans2(int x,int z,int id,int v,int f){
//if (id!=5) return;
ans[id]+=(ll)v*f;
}
void solve(int x,int y){
//bz[x]=1;
int t,l,z;
if (cnt[a[x]]<=B){
t=h3[a[x]];
while (t){
change(1,1,n,dfn[g3[t]],dfn[g3[t]]+size[g3[t]]-1,1);
t=n3[t];
}
}
t=h[x];
while (t){
if (go[t]!=y) solve(go[t],x);
t=nxt[t];
}
t=h2[x];
while (t){
if (ask[g2[t]][0]==x) z=ask[g2[t]][1];else z=ask[g2[t]][0];
if (!bz[z]){
l=query(1,1,n,dfn[z]);
getans2(x,z,ask[g2[t]][2],l,ask[g2[t]][3]);
}
t=n2[t];
}
bz[x]=1;
if (cnt[a[x]]<=B){
t=h3[a[x]];
while (t){
change(1,1,n,dfn[g3[t]],dfn[g3[t]]+size[g3[t]]-1,-1);
t=n3[t];
}
}
}
int getdis(int x,int y){
return dep[x]+dep[y]-2*dep[lca(x,y)]+1;
}
void getans3(int id){ int u=b[id][0],v=b[id][1],x=b[id][2],y=b[id][3],w=lca(u,v),z=lca(x,y); lc[1]=lca(u,x); lc[2]=lca(u,y); lc[3]=lca(v,x); lc[4]=lca(v,y); sort(lc+1,lc+5); if (lc[1]==lc[2]&&lc[1]==lc[3]&&lc[1]==lc[4]){ if (lca(u,v)==lc[1]&&lca(x,y)==lc[1]) ans[id]-=1; return; } if (lc[1]==lc[2]&&lc[1]==lc[3]){ ans[id]-=(ll)getdis(lc[1],lc[4]); return; } if (lc[4]==lc[2]&&lc[4]==lc[3]){ ans[id]-=(ll)getdis(lc[1],lc[4]); return; } if (lc[1]==lc[2]&&lc[3]==lc[4]){ if (w==lc[1]&&z==lc[3]) ans[id]-=1; else if (w==lc[3]&&z==lc[1]) ans[id]-=1; return; } if (lc[1]==lc[2]){ ans[id]-=(ll)getdis(lc[3],lc[4]); return; } if (lc[2]==lc[3]){ ans[id]-=(ll)getdis(lc[1],lc[4]); return; } if (lc[3]==lc[4]){ ans[id]-=(ll)getdis(lc[1],lc[2]); return; } }
int main(){
freopen("data.in","r",stdin);freopen("data.out","w",stdout);
n=read();q=read();
fo(i,1,n) gjx[i]=a[i]=read();
sort(gjx+1,gjx+n+1);
l=unique(gjx+1,gjx+n+1)-gjx-1;
fo(i,1,n) a[i]=lower_bound(gjx+1,gjx+l+1,a[i])-gjx;
fo(i,1,n-1){
j=read();k=read();
add(j,k);add(k,j);
}
dfs(1,0);
fo(i,1,n) zjy[i]=floor(log(i)/log(2));
fo(j,1,zjy
)
fo(i,1,n)
fa[i][j]=fa[fa[i][j-1]][j-1];
fo(i,1,n) cnt[a[i]]++;
top=0;
fo(i,1,n)
if (cnt[i]>B) cg[++top]=i;
tot=0;
fo(i,1,n) add3(a[i],i);
tot=0;
fo(i,1,q){
b[i][0]=read();b[i][1]=read();b[i][2]=read();b[i][3]=read();
work(b[i][0],b[i][1],b[i][2],b[i][3]);
}
fo(i,1,top){
now=cg[i];
dg(1,0);
fo(j,1,q) getans1(j);
}
solve(1,0);
fo(i,1,q) getans3(i);
fo(i,1,q) printf("%lld\n",ans[i]);
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: