
Appleman and Tree - CodeForces 461B (tree DP)

2014-08-27 13:16
Appleman and Tree

time limit per test: 2 seconds
memory limit per test: 256 megabytes
input: standard input
output: standard output

Appleman has a tree with n vertices. Some of the vertices (at least one) are colored black and other vertices are colored white.

Consider a set consisting of k (0 ≤ k < n) edges of Appleman's tree. If Appleman deletes these edges from the tree, it splits into (k + 1) parts. Note that each part is a tree with colored vertices.

Now Appleman wonders: how many sets split the tree so that each resulting part has exactly one black vertex? Find this number modulo 1000000007 (10^9 + 7).

Input

The first line contains an integer n (2 ≤ n ≤ 10^5), the number of tree vertices.

The second line contains the description of the tree: n - 1 integers p_0, p_1, ..., p_{n-2} (0 ≤ p_i ≤ i), where p_i means that there is an edge connecting vertex (i + 1) and vertex p_i. The tree vertices are numbered from 0 to n - 1.

The third line contains the colors of the vertices: n integers x_0, x_1, ..., x_{n-1} (each x_i is either 0 or 1). If x_i equals 1, vertex i is colored black; otherwise, vertex i is colored white.

Output

Output a single integer, the number of ways to split the tree modulo 1000000007 (10^9 + 7).

Sample test(s)

input
3
0 0
0 1 1


output
2


input
6
0 1 1 0 4
1 1 0 0 1 0


output
1


input
10
0 1 2 1 4 4 4 0 8
0 0 0 1 0 1 1 0 0 1


output
27
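A brute force that tries all 2^(n-1) edge subsets is handy for checking the sample answers, or for testing a faster solution on small random trees. The sketch below is not from the original post. The names bruteForceSplits and findRoot are hypothetical helpers of my own. It is only practical for n up to roughly 20.

```cpp
#include <cassert>   // for the sample checks in the usage note
#include <numeric>   // std::iota
#include <vector>
using namespace std;

// Hypothetical helper: disjoint-set "find" with path halving.
static int findRoot(vector<int>& dsu, int x) {
    while (dsu[x] != x) {
        dsu[x] = dsu[dsu[x]];
        x = dsu[x];
    }
    return x;
}

// Brute force (hypothetical name): bit i of mask set means edge i, which joins
// vertex i + 1 and parent[i], is deleted. The kept edges are unioned, then
// every component must contain exactly one black vertex.
long long bruteForceSplits(int n, const vector<int>& parent, const vector<int>& color) {
    long long ways = 0;
    for (long long mask = 0; mask < (1LL << (n - 1)); mask++) {
        vector<int> dsu(n);
        iota(dsu.begin(), dsu.end(), 0);
        for (int i = 0; i < n - 1; i++) {
            if (mask >> i & 1) continue;          // edge i deleted
            int a = findRoot(dsu, i + 1);
            int b = findRoot(dsu, parent[i]);
            dsu[a] = b;                           // keep edge i
        }
        vector<int> blackCount(n, 0);
        for (int v = 0; v < n; v++)
            blackCount[findRoot(dsu, v)] += color[v];
        bool ok = true;
        for (int v = 0; v < n; v++)
            if (findRoot(dsu, v) == v && blackCount[v] != 1) ok = false;
        if (ok) ways++;
    }
    return ways;
}
```

For example, `assert(bruteForceSplits(3, {0, 0}, {0, 1, 1}) == 2);` reproduces the first sample. On the first sample only the two sets that delete exactly one of the edges qualify.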


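Problem summary: delete some edges of the tree so that every resulting component contains exactly one black vertex, and count the number of ways to do so.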

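Approach: root the tree at vertex 0. dp[u][0] is the number of ways to delete edges inside the subtree of u so that the component containing u has no black vertex. dp[u][1] is the number of ways so that this component has exactly one black vertex. In both cases every other component inside the subtree must already contain exactly one black vertex. The answer is dp[0][1].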
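Written out as formulas, the transitions used by the code are as follows. This is my own reconstruction from the code; v ranges over the children of u. For a black child dp[v][0] = 0, so its factor dp[v][0] + dp[v][1] equals dp[v][1], which is why the code's two branches per child agree.

```latex
\begin{aligned}
\text{black } u:&\quad dp[u][0] = 0, \qquad dp[u][1] = \prod_{v} \bigl(dp[v][0] + dp[v][1]\bigr) \\
\text{white } u:&\quad dp[u][0] = \prod_{v} \bigl(dp[v][0] + dp[v][1]\bigr), \\
&\quad dp[u][1] = \sum_{j} dp[j][1] \prod_{v \ne j} \bigl(dp[v][0] + dp[v][1]\bigr)
\end{aligned}
```

On the first sample, vertex 0 is white and its black children 1 and 2 both have (dp[v][0], dp[v][1]) = (0, 1). So dp[0][1] = 1·1 + 1·1 = 2.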

The accepted (AC) code is as follows:

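```cpp
#include <cstdio>
#include <vector>
using namespace std;
typedef long long ll;

// Extended Euclid: d = gcd(a, b) and a*x + b*y = d.
void exgcd(ll a, ll b, ll& d, ll& x, ll& y)
{
	if (!b) d = a, x = 1LL, y = 0LL;
	else exgcd(b, a % b, d, y, x), y -= x * (a / b);
}

// Inverse of a modulo m, or -1 if it does not exist (gcd(a, m) != 1).
ll inv(ll a, ll m)
{
	ll d, x, y;
	exgcd(a, m, d, x, y);
	return d == 1LL ? (x + m) % m : -1LL;
}

vector<int> vc[100010];              // vc[u]: children of u
ll dp[100010][2], MOD = 1000000007;  // dp[u][0/1]: as defined in the approach
int col[100010];                     // col[u] == 1 if u is black

void dfs(int u)
{
	int len = vc[u].size();
	if (col[u] == 1)
	{
		// Black u: its part already has a black vertex, so dp[u][0] stays 0.
		dp[u][1] = 1;
		for (int i = 0; i < len; i++)
		{
			int v = vc[u][i];
			dfs(v);
			if (col[v] == 1)            // black child: the edge must be cut
				dp[u][1] *= dp[v][1];
			else                        // white child: keep the edge if v's part has no black, cut it if it has one
				dp[u][1] *= (dp[v][0] + dp[v][1]);
			dp[u][1] %= MOD;
		}
	}
	else
	{
		// White u: dp[u][0] is the product of the per-child factors.
		dp[u][0] = 1;
		for (int i = 0; i < len; i++)
		{
			int v = vc[u][i];
			dfs(v);
			if (col[v] == 1)
				dp[u][0] *= dp[v][1];
			else
				dp[u][0] *= (dp[v][0] + dp[v][1]);
			dp[u][0] %= MOD;
		}
		// dp[u][1]: choose the single child v whose black vertex joins u's part,
		// i.e. replace v's factor in dp[u][0] by dp[v][1]. The factor is divided
		// out with a modular inverse. This breaks if a factor is a multiple of
		// MOD (inv() then returns -1); merging children one at a time with
		// dp[u][1] = dp[u][1]*(dp[v][0]+dp[v][1]) + dp[u][0]*dp[v][1] avoids that.
		for (int i = 0; i < len; i++)
		{
			int v = vc[u][i];
			if (col[v] == 1)
				dp[u][1] += dp[u][0] * inv(dp[v][1], MOD) % MOD * dp[v][1];
			else
				dp[u][1] += dp[u][0] * inv(dp[v][0] + dp[v][1], MOD) % MOD * dp[v][1];
			dp[u][1] %= MOD;
		}
	}
}

void solve()
{
	int n, u, k = 0;
	scanf("%d", &n);
	for (int i = 1; i < n; i++)
	{
		scanf("%d", &u);                // p_{i-1}: parent of vertex i
		vc[u].push_back(i);
	}
	for (int i = 0; i < n; i++)
	{
		scanf("%d", &col[i]);
		k += col[i];                    // number of black vertices
	}
	if (k == 0)                         // excluded by the statement (at least one black vertex)
	{
		printf("0\n");
		return;
	}
	if (k == 1)                         // only the empty edge set works
	{
		printf("1\n");
		return;
	}
	// Recursion depth can reach n on a path, which may overflow a small default stack.
	dfs(0);
	printf("%lld\n", dp[0][1]);
}

int main()
{
	solve();
	return 0;
}
```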
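The modular inverse, and with it the corner case where a factor is a multiple of 10^9 + 7, can be avoided by merging children into their parent one at a time. The recursion can be dropped too: since p_i ≤ i, every child has a larger index than its parent, so scanning vertices from n - 1 down to 1 finishes each child before its parent. This is a minimal sketch of that variant, not the author's code. The function name countSplits is hypothetical, and it takes the parent array p and the colors x as vectors.

```cpp
#include <cassert>   // for the sample checks in the usage note
#include <vector>
using namespace std;

const long long MOD = 1000000007LL;

// Hypothetical division-free, non-recursive variant.
// dp0[u]: ways in u's (partially merged) subtree where u's part has no black vertex.
// dp1[u]: ways where u's part has exactly one black vertex.
long long countSplits(int n, const vector<int>& parent, const vector<int>& color) {
    vector<long long> dp0(n), dp1(n);
    for (int u = 0; u < n; u++) {
        dp0[u] = color[u] ? 0 : 1;
        dp1[u] = color[u] ? 1 : 0;
    }
    // parent[v - 1] < v, so all children of v are merged before v is merged upward.
    for (int v = n - 1; v >= 1; v--) {
        int u = parent[v - 1];
        // Ways where v adds no black vertex to u's part: v's part has no black
        // (edge kept) or v's part has one black (edge cut).
        long long noBlackAdded = (dp0[v] + dp1[v]) % MOD;
        // u's part gains v's black vertex only if it had none (edge kept).
        long long new1 = (dp1[u] * noBlackAdded + dp0[u] * dp1[v]) % MOD;
        long long new0 = dp0[u] * noBlackAdded % MOD;
        dp0[u] = new0;
        dp1[u] = new1;
    }
    return dp1[0];
}
```

For example, `countSplits(3, {0, 0}, {0, 1, 1})` returns 2, matching the first sample. To submit it, read n, the n - 1 parents and the n colors into vectors and print the result with `%lld`. With no black vertex at all it naturally returns 0, so the special cases in the code above become unnecessary.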