您的位置:首页 > 其它

ZOJ 3863 Paths on the Tree 树分治

2015-04-17 12:44 459 查看
题目链接:点击打开链接

题意:

给定n个点的树。 常量k

问:对于一对路径,如果公共点<=k则为合法。

问有多少个合法的路径。

{1-3, 2-4} 和 {2-4,1-3} 视为2个不同的路径对。

1-3, 3-1视为相同路径。

思路:

首先来得到一个O(n^3)的算法:

把问题转成=> 总方案数 - 公共点>k个的路径对数

显然公共点是连续的,所以公共点会组成一条路径,我们设为 x-y,则枚举x和y,就能得到公共的部分(当然要保证x-y的公共点数>k)

那么现在的问题是 以公共路径为x-y 的路径对有多少条。



x有很多子树: x1, x2, x3 ···xi 图中为(1, 3, 3) 设sumx = x_1 + x_2 + ··+ x_i ( 这里sumx = 7

y有很多子树: y1, y2, y3···yi 图中为(1, 3, 1) 设sumy = y_1 + y_2 + ··+ y_i ( 这里sumy = 5

在x子树中选2个点排列的方案数 ans_x = (sum_x - x_i) * x_i (for any i) + (sum_x-1)

(为何加上sum_x-1, 因为不同子树间的方案已经计算过2次,但一个点是x,另一点是子树节点的方案只计算了一次, 所以+ x_1 + x_2 +···+x_i = sum_x-1)

这样就能求出公共路径一端是x,选择2个点的方法数。

化简一下ansx = sumx * (sumx-1) - xi*xi + (sumx-1);

我们设 fang = xi*xi;

则ansx = sumx*(sumx-1) - fang + (sumx-1);

进一步:

我们若要求出删除一个子树w后选2个点的方法数也就能简单地得到:

ansx' = (sumx - xi - w) * xi + (sumx-w-1) { i!=w } = (sumx-w) * (sumx-w-1) - (fang - w*w) + sumx-w-1;

剩下就是树分治。计算公共路径经过重心的方法数。

sum[cur][j] 表示对于当前枚举的重心的子树 ,子树中公共路径端点距离重心的距离恰好为 j 的个数。 相当于上述中的公共路径端点为X时,X端的方法数(即sumx)

sum[old][j] 表示以前枚举的重心的子树,子树中公共路径端点距离重心的距离>= j 的个数。同理相当于上述中公共路径端点为Y时

注意:

1、公共路径外的部分(即X子树中选的2个点,这两个点可以任意)可以经过重心,不能经过重心的只有公共部分的路径。

2、注意在找重心时算出的树的最大深度并不是 重心的最大深度。所以深度要持续更新。

3、清空“后缀和”要多清一点,因为第二条的原因。

done..

/*
by:http://blog.csdn.net/acmmmm
*/
#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <stack>
#include <time.h>
#include <queue>
template <class T>
inline bool rd(T &ret) {
	char c; int sgn;
	if (c = getchar(), c == EOF) return 0;
	while (c != '-' && (c<'0' || c>'9')) c = getchar();
	sgn = (c == '-') ? -1 : 1;
	ret = (c == '-') ? 0 : (c - '0');
	while (c = getchar(), c >= '0'&&c <= '9') ret = ret * 10 + (c - '0');
	ret *= sgn;
	return 1;
}
template <class T>
inline void pt(T x) {
	if (x <0) {
		putchar('-');
		x = -x;
	}
	if (x>9) pt(x / 10);
	putchar(x % 10 + '0');
}
using namespace std;
typedef unsigned long long ll;
const int N = 100005;

struct Edge{
	int from, to, nex;
}edge[N << 1];
int head
, edgenum;
void add(int u, int v){ Edge E = { u, v, head[u] }; edge[edgenum] = E; head[u] = edgenum++; }
int size
, parent
;
void dfs_init(int u, int fa){
	size[u] = 1; parent[u] = fa;
	for (int i = head[u]; ~i; i = edge[i].nex){
		int v = edge[i].to; if (v == fa)continue;
		dfs_init(v, u);
		size[u] += size[v];
	}
}
int n, k, maxdep;

int dp
, num
;//num[i]表示 以i为根的树 节点数  
//树重心的定义:dp[i]表示 将i点删去后 最大联通块的点数  
int root;
bool vis
;
int siz;//** 表示当前 计算的树的节点数   
int G
, top;
void getroot(int u, int fa, int deep){//找树的重心  
	dp[u] = 0; num[u] = 1;
	maxdep = max(maxdep, deep);
	for (int i = head[u]; ~i; i = edge[i].nex){
		int v = edge[i].to; if (v == fa || vis[v])continue;
		getroot(v, u, deep + 1);
		num[u] += num[v];
		dp[u] = max(dp[u], num[v]);
	}
	dp[u] = max(dp[u], siz - num[u]);
	if (dp[u] < dp[root])root = u;
}

ll ans, sum[2]
, w
;
int dep
;
ll Siz(int u, int v){
	if (v == parent[u])return size[u];
	else return n - size[v] ;
}
void dfs(int u, int fa, int deep){
	dep[u] = deep; maxdep = max(maxdep, deep);
	w[u] = Siz(u, fa) * (Siz(u, fa) - 1);
	num[u] = 1;
	G[top++] = u;
	for (int i = head[u]; ~i; i = edge[i].nex){
		int v = edge[i].to; if (v == fa)continue;
		w[u] -= Siz(v, u) * Siz(v, u);
		if (vis[v])continue;
		dfs(v, u, deep + 1);
		num[u] += num[v];
	}
	w[u] += Siz(u, fa);
}
void work(int u){
	siz = num[u];
	root = maxdep = 0;
	getroot(u, u, 0);
	if (maxdep * 2 < k)return;
	int old = 1, cur = 0;
	fill(sum[cur], sum[cur] + maxdep + 10, 0);
	sum[cur][0] = 1;
	ll all = n, fang = 0;
	for (int i = head[root]; ~i; i = edge[i].nex){
		int v = edge[i].to;
		fang += Siz(v, root) * Siz(v, root);
	}

	for (int i = head[root], j; ~i; i = edge[i].nex){
		int V = edge[i].to; if (vis[V])continue;
		top = 0;
		dfs(V, root, 1);
		swap(old, cur);
		fill(sum[cur], sum[cur] + maxdep + 10, 0);
		
		for (j = 0; j < top; j++) sum[cur][dep[G[j]]] += w[G[j]];
		for (j = 0; j <= maxdep; j++)
		{
			if (k-j <= maxdep)
			ans += sum[cur][j] * sum[old][max(0, k-j)];
		}
		for (j = maxdep-1; j >= 0; j--) sum[cur][j] += sum[cur][j + 1];
		if (k <= maxdep)
			ans += sum[cur][k] * (all - Siz(V, root) - 1 + (all - Siz(V, root)) * (all - Siz(V, root) - 1) - (fang - Siz(V, root)*Siz(V, root)));
		for (j = maxdep; j >= 0; j--)sum[cur][j] += sum[old][j];
	}
	vis[root] = true;
	for (int i = head[root]; ~i; i = edge[i].nex)
	if (false == vis[edge[i].to]) work(edge[i].to);
}

int main(){
	dp[0] = N;
	int T; rd(T);
	while (T--){
		rd(n); rd(k);
		memset(head, -1, sizeof head); edgenum = 0;
		for (int i = 1, u, v; i < n; i++){
			rd(u); rd(v); add(u, v); add(v, u);
		}
		dfs_init(1, 1);		
		ans = 0;
		num[1] = n;
		memset(vis, 0, sizeof vis);
		work(1);
		ll all = (ll)n*(n + 1) / 2;
		cout << (all * all - ans) << endl;
	}
	return 0;
}
/*
991

6 3
1 2
2 3
2 4
1 5
5 6

4 1
1 2
2 3
3 4

6 1
1 2
1 3
1 4
2 5
5 6

6 1
1 2
1 3
1 4
3 5
3 6

5 1
1 2
1 3
3 4
4 5

6 2
1 2
2 3
3 4
3 5
4 6

3 2
1 2
1 3

5 1
1 2
1 3
1 4
2 5

*/
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: