您的位置:首页 > 其它

【XSY1536】【BZOJ3522】【BZOJ4543】【POI2014】Hotel 树形DP 长链剖分 启发式合并

2017-08-12 12:47 465 查看

题目大意

​  给你一棵树,求有多少个组点满足x≠y,x≠z,y≠z,distx,y=distx,z=disty,z

​  1≤n≤100000

题解

​  问题转换为有多少个组点满足disti,x=disti,y=disti,z

​  我们考虑树形DP

​  fi,j=以i为根的子树中与i的距离为j的节点数

​  gi,j=以i为根的子树外选择一个点s满足s到i的距离为j,能新增的的方案数

​  若v是u的重儿子,则:fu,j+=fv,j−1,gu,j+=gv,j+1,这样就可以由u的重儿子转移到u

​  否则:gu,j+=gv,j+1+fv,j−1×fu,j,fu,j+=fv,j−1

​  答案为∑fx,j×gy,j,其中x是y的兄弟

​  可以用长链剖分辅助转移

​  时间复杂度:O(n)

​  gjs大爷的长链剖分讲解

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
struct list
{
int v[200010];
int t[200010];
int h[100010];
int n;
void clear()
{
n=0;
memset(h,0,sizeof h);
}
void add(int x,int y)
{
n++;
v
=y;
t
=h[x];
h[x]=n;
}
};
list l;
ll ans;
ll f[100010];
ll g[200010];
int d[100010];
int bg[100010];
int ed[100010];
int ch[100010];
int t[100010];
int w[100010];
int ti;
void dfs(int x,int fa)
{
d[x]=1;
ch[x]=0;
int i;
for(i=l.h[x];i;i=l.t[i])
if(l.v[i]!=fa)
{
dfs(l.v[i],x);
if(d[l.v[i]]+1>d[x])
{
d[x]=d[l.v[i]]+1;
ch[x]=l.v[i];
}
}
}
void dfs2(int x,int fa,int top)
{
t[x]=top;
w[x]=++ti;
if(x==top)
bg[top]=ti;
ed[top]=ti;
if(ch[x])
dfs2(ch[x],x,top);
int i;
for(i=l.h[x];i;i=l.t[i])
if(l.v[i]!=ch[x]&&l.v[i]!=fa)
dfs2(l.v[i],x,l.v[i]);
}
ll& getf(int x,int y)
{
return f[w[x]+y];
}
ll& getg(int x,int y)
{
return g[2*(w[t[x]]-1)+2*d[t[x]]-d[x]+1-y];
}
void solve(int x,int fa)
{
if(ch[x])
solve(ch[x],x);
int i,j;
for(i=l.h[x];i;i=l.t[i])
if(l.v[i]!=fa&&l.v[i]!=ch[x])
{
int v=l.v[i];
solve(v,x);
for(j=0;j<d[v];j++)
ans+=getf(v,j)*getg(x,j+1);
for(j=1;j<d[v];j++)
ans+=getg(v,j)*getf(x,j-1);
for(j=0;j<d[v];j++)
getg(x,j+1)+=getf(v,j)*getf(x,j+1);
for(j=1;j<d[v];j++)
getg(x,j-1)+=getg(v,j);
for(j=0;j<d[v];j++)
getf(x,j+1)+=getf(v,j);
}
ans+=getg(x,0);
getf(x,0)++;
}
int main()
{
int n;
scanf("%d",&n);
l.clear();
memset(bg,0,sizeof bg);
memset(ed,0,sizeof ed);
memset(f,0,sizeof f);
memset(g,0,sizeof g);
memset(d,0,sizeof d);
memset(ch,0,sizeof ch);
memset(t,0,sizeof t);
memset(w,0,sizeof w);
ans=0;
ti=0;
int i,x,y;
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
l.add(x,y);
l.add(y,x);
}
dfs(1,0);
dfs2(1,0,1);
solve(1,0);
printf("%lld\n",ans);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: