
HDU 6035 Colorful Tree (tree DP)

2017-07-26 17:35


Colorful Tree

Time Limit: 6000/3000 MS (Java/Others)    Memory Limit: 131072/131072 K (Java/Others)

Total Submission(s): 1194    Accepted Submission(s): 477


Problem Description

There is a tree with $n$ nodes, each of which has a color represented by an integer; the color of node $i$ is $c_i$.

The path between any two different nodes is unique. Define the value of a path as the number of distinct colors appearing on it.

Calculate the sum of the values of all $\frac{n(n-1)}{2}$ paths in the tree.


Input

The input contains multiple test cases.

For each test case, the first line contains one positive integer $n$ ($2 \le n \le 200000$), the number of nodes.

The next line contains $n$ integers, where the $i$-th integer $c_i$ is the color of node $i$ ($1 \le c_i \le n$).

Each of the next $n-1$ lines contains two positive integers $x, y$ ($1 \le x, y \le n$, $x \ne y$), meaning there is an edge between node $x$ and node $y$.

It is guaranteed that these edges form a tree.


Output

For each test case, output "Case #x: y" on one line (without quotes), where x is the case number starting from 1 and y is the answer for that case.


Sample Input

3
1 2 1
1 2
2 3
6
1 2 1 3 2 1
1 2
1 3
2 4
2 5
3 6


Sample Output

Case #1: 6
Case #2: 29

Solution
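In this reading, the code below counts by complement. Write $K$ for the number of distinct colors that occur (`color` in `main`) and $\mathcal{C}_x$ for the set of connected components left after deleting every node of color $x$. A path misses color $x$ exactly when it lies entirely inside one component of $\mathcal{C}_x$, so, with $x$ ranging over the $K$ colors that occur, the answer is the expression below; the double sum is what `cnt` accumulates.

```latex
\text{ans} = \sum_{x} \#\{\text{paths containing color } x\}
           = K \cdot \frac{n(n-1)}{2} \;-\; \sum_{x} \sum_{S \in \mathcal{C}_x} \frac{|S|\,(|S|-1)}{2}

% Sample 2: K = 3, n = 6; components {2,4,5} (x = 1), {1,3,6} and {4} (x = 2), {1,2,3,5,6} (x = 3)
3 \cdot 15 - \bigl(3 + (3 + 0) + 10\bigr) = 45 - 16 = 29
```

In `dfs`, `size1[v]-col[c[u]]` is the size of the component of $\mathcal{C}_{c[u]}$ that hangs below `u` through child `v` (0 when `v` itself has color `c[u]`), and the final loop in `main` adds, for each color `i`, the component containing the root, whose size is `n-col[i]`.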

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<vector>
#include<cmath>
//#include <bits/stdc++.h>
using namespace std;
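const int N = 200000+7;
typedef long long LL;
int vis[N], c[N];       // vis[x] = 1 if color x occurs; c[i] = color of node i
LL size1[N], col[N];    // size1[u] = subtree size of u; col[x]: see dfs
vector<int> p[N];       // adjacency lists
LL cnt;                 // sum over colors x of the number of paths avoiding x

// col[x] is the total size of the finished subtrees whose roots are the
// topmost nodes of color x in the region explored so far.
// Note: the recursion depth can reach n on a path-shaped tree.
void dfs(int u,int f)
{
    size1[u]=1;
    LL tmp=col[c[u]];   // save what was accumulated before entering u
    col[c[u]]=0;
    for(int i=0;i<(int)p[u].size();i++)
    {
        int v=p[u][i];
        if(v==f) continue;
        dfs(v,u);
        size1[u]+=size1[v];
        // Size of the color-c[u]-free component hanging below u through v
        // (0 if v itself has color c[u]).
        LL s=size1[v]-col[c[u]];
        cnt+=s*(s-1)/2;
        col[c[u]]=0;
    }
    // u's whole subtree now blocks color c[u] for its ancestors.
    col[c[u]]=tmp+size1[u];
}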

int main()
{
    LL n;
    int ncase=1;
    while(scanf("%lld", &n)!=EOF)
    {
        LL color=0;     // number of distinct colors
        memset(vis,0,sizeof(vis));
        memset(col,0,sizeof(col));
        memset(c,0,sizeof(c));
        for(int i=1;i<=n;i++)
        {
            scanf("%d", &c[i]);
            if(vis[c[i]]==0) color++;
            vis[c[i]]=1;
            p[i].clear();
        }
        for(int i=1;i<n;i++)
        {
            int x, y;
            scanf("%d %d", &x, &y);
            p[x].push_back(y);
            p[y].push_back(x);
        }
        cnt=0;
        dfs(1,-1);
        // Add the component containing the root, for every color that occurs.
        for(int i=1;i<=n;i++)
        {
            if(!vis[i]) continue;
            cnt+=(n-col[i])*(n-col[i]-1)/2;
        }
        // answer = color * n(n-1)/2 - (paths avoiding each color, summed)
        LL x=n*(n-1)/2*color-cnt;
        printf("Case #%d: %lld\n",ncase++,x);
    }
    return 0;
}
