您的位置:首页 > 其它

hdu5593/ZYB's Tree 树形dp

2015-12-06 16:47 302 查看

ZYB's Tree

Memory Limit: 131072/131072 K (Java/Others)

问题描述
ZYBZYB有一颗NN个节点的树,现在他希望你对于每一个点,求出离每个点距离不超过KK的点的个数.

两个点(x,y)(x,y)在树上的距离定义为两个点树上最短路径经过的边数,

为了节约读入和输出的时间,我们采用如下方式进行读入输出:

读入:读入两个数A,BA,B,令fa_ifa​i​​为节点ii的父亲,fa_1=0fa​1​​=0;fa_i=(A*i+B)\%(i-1)+1fa​i​​=(A∗i+B)%(i−1)+1 i \in [2,N]i∈[2,N] .

输出:输出时只需输出NN个点的答案的xorxor和即可。

输入描述
第一行一个整数TT表示数据组数。

接下来每组数据:

一行四个正整数N,K,A,BN,K,A,B.

最终数据中只有两组N \geq 100000N≥100000。

1 \leq T \leq 51≤T≤5,1 \leq N \leq 5000001≤N≤500000,1 \leq K \leq 101≤K≤10,1 \leq A,B \leq 10000001≤A,B≤1000000

输出描述
TT行每行一个整数表示答案.

输入样例
1
3 1 1 1

输出样例
3

题解:定义dp[i][j]为以i为根距离为j的点的个数

定义dp2[i][j] 在除去i的子树的点中,与点i距离为j的点的个数

在遍历图求出dp[][]后

对于fa,son

我们求dp2的转移方程就是

dp2[son][h]=dp[fa][h-1]-dp[son][h-2]+dp[fa][h-1];

//meek
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <queue>
#include <map>
#include <set>
#include <stack>
#include <sstream>
#include <vector>
using namespace std ;
typedef long long ll;
#define mem(a) memset(a,0,sizeof(a))
#define pb push_back
#define fi first
#define se second

inline ll read()
{
ll x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
//****************************************

const int N=500000+100;
const ll inf = 1ll<<61;
const int mod= 1000000007;

int a,b,K,n;
int vis
;
vector<int >G
;
int nex;
int dp
[11],dp2
[11];
void dfs(int x) {
dp[x][0]=1;
for(int i=0;i<G[x].size();i++) {
dfs(G[x][i]);
for(int j=1;j<=K;j++) {
dp[x][j]+=dp[G[x][i]][j-1];
}
}

}
int main() {
int T;
scanf("%d",&T);
while(T--) {
mem(dp),mem(dp2);
scanf("%d%d%d%d",&n,&K,&a,&b);
for(int i=0;i<=N;i++) G[i].clear();
for(int i=2;i<=n;i++) {
ll fa=(a+b)%(i-1)+1;
G[fa].pb(i);
}int A=0,ans;
dfs(1);
for(int i=1;i<=n;i++) {
for(int j=0;j<G[i].size();j++) {
dp2[G[i][j]][1]=dp[i][0];
for(int h=2;h<=K;h++)
dp2[G[i][j]][h]=dp[i][h-1]-dp[G[i][j]][h-2]+dp2[i][h-1];
}
ans=0;
for(int j=0;j<=K;j++) {
ans+=dp[i][j]+dp2[i][j];
}
A^=ans;
}
printf("%d\n",A);
}
return 0;
}


代码
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: