您的位置:首页 > 其它

文章标题 HDU 5977 : Garden of Eden (树分治)

2017-10-22 19:29 399 查看
参考自:http://blog.csdn.net/bahuia/article/details/53070036

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <queue>
#include <set>
#include <map>
#include <algorithm>
#include <math.h>
#include <vector>
using namespace std;
typedef long long ll;

const int inf = 0x3f3f3f3f;
const int mod=1e9+7;
const int maxn=5e4+10;

int n,k,a[maxn];
vector<int>sta;
ll ha[1100],ans;

struct Edge{
int to,nex;
}edge[maxn*2];
int tot,head[maxn],vis[maxn];
void init(){
tot=0;
memset (head,-1,sizeof (head));
memset (vis,0,sizeof (vis));
}
void addedge(int u,int v){
edge[tot]=Edge{v,head[u]};
head[u]=tot++;
}

int sz[maxn],maxv[maxn],rt,Max,root;

void dfs_size(int u,int pre){//求出每个子树的大小,以及每个节点的最大儿子
sz[u]=1;
maxv[u]=0;
for (int i=head[u];i!=-1;i=edge[i].nex){
int v=edge[i].to;
if (v==pre||vis[v])continue;
dfs_size(v,u);
sz[u]+=sz[v];
maxv[u]=max(maxv[u],sz[v]);
}
}

void dfs_root(int r,int u,int pre){//找出以u为根的子树的重心
maxv[u]=max(maxv[u],sz[r]-maxv[u]);
if (Max>maxv[u]){
Max=maxv[u];root=u;
}
for (int i=head[u];i!=-1;i=edge[i].nex){
int v=edge[i].to;
if (v==pre||vis[v])continue;
dfs_root(r,v,u);
}
}

void dfs_sta(int u,int fa,int s){
sta.push_back(s);
for (int i=head[u];i!=-1;i=edge[i].nex){
int v=edge[i].to;
if (v==fa||vis[v])continue;
dfs_sta(v,u,s|(1<<a[v]));
}
}

ll solve(int u,int s){//计算当前子树中合法的点对数
ll res=0;
sta.clear();
dfs_sta(u,-1,s);
memset (ha,0,sizeof (ha));
for (int i=0;i<sta.size();i++)ha[sta[i]]++;
for (int i=0;i<sta.size();i++){
ha[sta[i]]--;
res+=ha[(1<<k)-1];
for (int s0=sta[i];s0;s0=(s0-1)&sta[i]){
res+=ha[((1<<k)-1)^s0];
}
ha[sta[i]]++;
}
return res;
}

void dfs(int u,int num){//总的dfs求解
Max=n;
dfs_size(u,-1);
dfs_root(u,u,-1);
int rt=root;//一定要注意这样里的root是全局变量,在递归之后可能改变,需要提前保存下来。
vis[rt]=1;
ans+=solve(rt,1<<(a[rt]));
for (int i=head[rt];i!=-1;i=edge[i].nex){
int v=edge[i].to;
if (vis[v])continue;
ans-=solve(v,(1<<a[v])|(1<<a[rt]));
}
for (int i=head[rt];i!=-1;i=edge[i].nex){
int v=edge[i].to;
if (vis[v])continue;
dfs(v,sz[v]);
}
}

int main()
{
while (scanf ("%d%d",&n,&k)!=EOF){
init();
for (int i=1;i<=n;i++){
scanf ("%d",&a[i]);
a[i]--;
}
int u,v;
for (int i=0;i<n-1;i++){
scanf ("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
if (k==1){
printf ("%lld\n",(ll)n*n);
continue;
}
ans=0;
dfs(1,n);
printf ("%lld\n",ans);
}
return 0;
}
/*
3 2
1 2 2
1 2
1 3
*/
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: