您的位置:首页 > 其它

[HDOJ 4894] Mart Master [树形DP]

2014-08-04 19:12 302 查看
给定一棵树,把树上的节点分给三个人,然后去掉连在属于不同人的节点之间的边,这样每个人都有若干个森林。每个人都有一个能量值,为他所有的奇数个点的树的个数减去偶数个点的树的个数,和0的最大值。问这三个人的能量值的和的期望。为了得到一个整数,我们求期望与3^n的乘积(n为点的个数),模10^9的值。

数据范围:点数不超过300。

首先由于问题是对称的,我们可以只计算一个人的能量值的期望。然后乘以3即为结果。

随便将一个点提为根。然后我们定义状态dp[i][j][k],i表示在以i为根的子树上,k表示未和0取最大值之前的能量值(即他所有的奇数个点的树的个数减去偶数个点的树的个数)为k,j=0表示不选该点,j=1表示选该点且该点所在的树的点的个数为奇数,j=2表示该点所在的树的个数为偶数,dp[i][j][k]表示在i,j,k的条件下的方案数。

复杂度分析:在树上进行分组背包,若每个节点的组内个数为以该节点为根的子树的规模大小,则总复杂度为n^2而不是n^3。

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int mod=1000000007;

struct Node {
int fe,min,max;
long long dp[3][610];
bool visited;
};
struct Edge {
int t,ne;
};

Node a[301];
Edge b[610];
long long tmpdp[3][610];
int n,bp;

void putedge(int x,int y) {
b[bp].t=y;
b[bp].ne=a[x].fe;
a[x].fe=bp++;
}

void getans(Node &x) {
int j,k,l;
x.visited=true;
x.dp[0][300]=1;
x.dp[1][301]=1;
x.min=300;
x.max=301;
for (j=x.fe;j!=-1;j=b[j].ne) {
if (!a[b[j].t].visited) {
Node &y=a[b[j].t];
getans(y);
memcpy(tmpdp,x.dp,sizeof(tmpdp));
memset(x.dp,0,sizeof(tmpdp));
for (k=x.min;k<=x.max;k++) {
for (l=y.min;l<=y.max;l++) {
x.dp[0][k+l-300]+=tmpdp[0][k]*y.dp[0][l]%mod;
x.dp[0][k+l-300]+=tmpdp[0][k]*y.dp[1][l]%mod;
x.dp[0][k+l-300]+=tmpdp[0][k]*y.dp[2][l]%mod;
x.dp[1][k+l-300]+=tmpdp[1][k]*y.dp[0][l]%mod;
x.dp[2][k+l-303]+=tmpdp[1][k]*y.dp[1][l]%mod;
x.dp[1][k+l-299]+=tmpdp[1][k]*y.dp[2][l]%mod;
x.dp[2][k+l-300]+=tmpdp[2][k]*y.dp[0][l]%mod;
x.dp[1][k+l-299]+=tmpdp[2][k]*y.dp[1][l]%mod;
x.dp[2][k+l-299]+=tmpdp[2][k]*y.dp[2][l]%mod;
}
}
for (k=0;k<610&&x.dp[0][k]==0&&x.dp[1][k]==0&&x.dp[2][k]==0;k++);
x.min=k;
//printf("%d\n",x.min);
for (k=609;k>=0&&x.dp[0][k]==0&&x.dp[1][k]==0&&x.dp[2][k]==0;k--);
x.max=k;
//printf("%d\n",x.max);
for (k=x.min;k<=x.max;k++) {
x.dp[0][k]%=mod;
x.dp[1][k]%=mod;
x.dp[2][k]%=mod;
//printf("%lld %lld %lld\n",x.dp[0][k],x.dp[1][k],x.dp[2][k]);
}
}
}
//printf("Node %d:\n",(int)(&x-a));
for (k=x.min;k<=x.max;k++) {
x.dp[0][k]<<=1;
if (x.dp[0][k]>=mod) x.dp[0][k]-=mod;
//printf("%d %d\n",k,x.dp[0][k]);
}
}

int main() {
int x,y,i;
while (scanf("%d",&n)!=EOF) {
bp=0;
for (i=1;i<=n;i++) {
memset(a[i].dp,0,sizeof(tmpdp));
a[i].visited=false;
a[i].fe=-1;
}
for (i=1;i<n;i++) {
scanf("%d%d",&x,&y);
putedge(x,y);
putedge(y,x);
}
getans(a[1]);
int ans=0;
for (i=300;i<610;i++) {
ans=(ans+(i-300)*(a[1].dp[0][i]+a[1].dp[1][i]+a[1].dp[2][i]))%mod;
}
ans=(long long)ans*3%mod;
printf("%d\n",ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: