HDU 6035 Colorful Tree (tree DP, DFS)

2017-07-28 14:20
Problem Description

There is a tree with n nodes. Each node has a color represented by an integer; the color of node i is c_i.

The path between any two distinct nodes is unique. The value of a path is the number of distinct colors appearing on it.

Calculate the sum of the values of all n(n−1)/2 paths in the tree.

 

Input

The input contains multiple test cases.

For each test case, the first line contains a positive integer n, the number of nodes (2 ≤ n ≤ 200000).

The next line contains n integers; the i-th integer is c_i, the color of node i (1 ≤ c_i ≤ n).

Each of the next n−1 lines contains two positive integers x and y (1 ≤ x, y ≤ n, x ≠ y), meaning there is an edge between nodes x and y.

It is guaranteed that these edges form a tree.

 

Output

For each test case, output "Case #x: y" on one line (without quotes), where x is the case number starting from 1 and y is the answer for that case.

 

Sample Input

3
1 2 1
1 2
2 3
6
1 2 1 3 2 1
1 2
1 3
2 4
2 5
3 6

 

Sample Output

Case #1: 6
Case #2: 29

For details, see http://blog.csdn.net/Bahuia/article/details/76141574, which explains the approach thoroughly.
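In brief (the symbols C, p and F_c are notation introduced here, not taken from the linked post): instead of counting colors per path, count paths per color. A path contains color c unless it lies entirely inside one connected component of the forest left after deleting every node of color c. With C the set of colors that occur, p = |C|, and F_c the set of components left after deleting color c:

```latex
\mathrm{Answer}
  = \sum_{c \in C}\Bigl[\binom{n}{2} - \sum_{K \in F_c}\binom{|K|}{2}\Bigr]
  = p\binom{n}{2} - \sum_{c \in C}\sum_{K \in F_c}\binom{|K|}{2}
```

For Sample 2 (p = 3, n = 6), deleting color 1 leaves {2, 4, 5}, deleting color 2 leaves {1, 3, 6} and {4}, and deleting color 3 leaves {1, 2, 3, 5, 6}, so the answer is 3·15 − (3 + 3 + 0 + 10) = 29. In the posted code, `p` is the number of colors and `ans` accumulates the double sum. When `dfs` returns from child `v` of `u`, `b` counts the nodes of `v`'s subtree that lie in the subtree of some node of color `color[u]` (possibly `v` itself), so `siz[v]-b` is the size of the component hanging directly below `u`. Components containing the root have no such node above them; after the DFS, `sum[c]` is the total size of the subtrees rooted at the topmost nodes of color c, so `n-sum[c]` is the size of that root component.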

Code:

#include <bits/stdc++.h>
#define mem(p,k) memset(p,k,sizeof(p));
#define pb push_back
#define ll long long
using namespace std;
const int N=2e5+10;
int n;
int color[N],vis[N];   // color of each node; vis[c]=1 if color c occurs
ll siz[N],sum[N],ans;  // subtree sizes; per-color covered totals; paths avoiding some color
vector<int> vec[N];    // adjacency lists
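// Returns the size of u's subtree. sum[c] accumulates, over the finished
// part of the tree, the sizes of subtrees rooted at topmost nodes of color c.
ll dfs(int u,int pre){
    siz[u]=1;
    ll last,b,all=0;
    int cnt=vec[u].size();
    for(int i=0;i<cnt;i++){
        int v=vec[u][i];
        if(v==pre)continue;
        last=sum[color[u]];        // snapshot before visiting v's subtree
        siz[u]+=dfs(v,u);
        b=sum[color[u]]-last;      // nodes of v's subtree already under a color[u] node
        // the remaining siz[v]-b nodes form the component below u that avoids color[u]
        ans+=(siz[v]-b)*(siz[v]-b-1)/2;
        all+=siz[v]-b;
    }
    sum[color[u]]+=all+1;          // now u's whole subtree counts as covered for color[u]
    return siz[u];
}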
int main()
{
    int T=1;
    while(~scanf("%d",&n)){
        mem(vis,0);
        mem(sum,0);
        ll p=0;                    // number of distinct colors
        for(int i=1;i<=n;i++){
            scanf("%d",color+i);
            if(!vis[color[i]])vis[color[i]]=1,p++;
            vec[i].clear();
        }
        for(int i=1;i<n;i++){
            int x,y;
            scanf("%d%d",&x,&y);
            vec[x].pb(y);
            vec[y].pb(x);
        }
        if(p==1){                  // a single color: every path has value 1
            printf("Case #%d: %lld\n",T++,(ll)(n-1)*n/2);
            continue;
        }
        ans=0;
        dfs(1,-1);                 // recursion depth can reach n on a path-shaped tree
        for(int i=1;i<=n;i++){
            if(vis[i]){
                // component containing the root, for color i
                ans+=(ll)(n-sum[i])*(n-sum[i]-1)/2;
            }
        }
        // every color counts on every path, minus the paths that avoid it
        printf("Case #%d: %lld\n",T++,p*(ll)(n-1)*n/2-ans);
    }
    return 0;
}