您的位置:首页 > 其它

[BZOJ3611] [Heoi2014]大工程(DP + 虚树)

2018-01-08 19:08 423 查看

传送门

 

$dp[i][0]$表示节点i到子树中的所有点的距离之和

$dp[i][1]$表示节点i到子树中最近距离的点的距离

$dp[i][2]$表示节点i到子树中最远距离的点的距离

建好虚树后dp即可。

因为对于虚树掌握的还不是很熟,有些细节还是要注意。

虚树中可能会加入一些lca节点,这些节点在dp的时候是不应该统计的。

对于本题来说,别忘记考虑某一节点不同子树中点对的组合。

 

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 2000010
#define LL long long

using namespace std;

LL ans1, ans2, dp
[3];
int n, cnt, rp, m, top, T;
int head
, to
, nex
, val
, dis
, size
, dfn
, deep
, f
[21], q
, s
, flag
;

inline int read()
{
int x = 0, f = 1;
char ch = getchar();
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -1;
for(; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0';
return x * f;
}

inline void add(int x, int y)
{
to[cnt] = y;
nex[cnt] = head[x];
head[x] = cnt++;
}

inline void dfs1(int u)
{
int i, v;
dfn[u] = ++rp;
deep[u] = deep[f[u][0]] + 1;
for(i = 0; f[u][i]; i++) f[u][i + 1] = f[f[u][i]][i];
for(i = head[u]; ~i; i = nex[i])
{
v = to[i];
if(!dfn[v])
{
f[v][0] = u;
dis[v] = dis[u] + 1;
dfs1(v);
}
}
head[u] = -1;
}

inline int calc_lca(int x, int y)
{
int i;
if(deep[x] < deep[y]) swap(x, y);
for(i = 20; i >= 0; i--)
if(deep[f[x][i]] >= deep[y]) x = f[x][i];
if(x == y) return x;
for(i = 20; i >= 0; i--)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}

inline bool cmp(int x, int y)
{
return dfn[x] < dfn[y];
}

inline void dfs2(int u)
{
int i, v;
size[u] = flag[u];
dp[u][1] = 1e9, dp[u][0] = dp[u][2] = 0;
for(i = head[u]; ~i; i = nex[i])
{
v = to[i];
dfs2(v);
size[u] += size[v];
ans1 = min(ans1, dp[u][1] + dp[v][1] + dis[v] - dis[u]);
ans2 = max(ans2, dp[u][2] + dp[v][2] + dis[v] - dis[u]);
dp[u][0] += dp[v][0] + 1ll * size[v] * (m - size[v]) * (dis[v] - dis[u]);
dp[u][1] = min(dp[u][1], dis[v] - dis[u] + dp[v][1]);
dp[u][2] = max(dp[u][2], dis[v] - dis[u] + dp[v][2]);
}
if(flag[u])
{
ans1 = min(ans1, dp[u][1]);
ans2 = max(ans2, dp[u][2]);
dp[u][1] = 0;
}
head[u] = -1;
}

inline void solve()
{
int i, lca;
m = read();
top = cnt = 0;
for(i = 1; i <= m; i++) q[i] = read(), flag[q[i]] = 1;
sort(q + 1, q + m + 1, cmp);
for(i = 1; i <= m; i++)
{
if(!top)
{
s[++top] = q[i];
continue;
}
lca = calc_lca(s[top], q[i]);
while(dfn[lca] < dfn[s[top]])
{
if(dfn[lca] >= dfn[s[top - 1]])
{
add(lca, s[top]);
if(s[--top] != lca) s[++top] = lca;
break;
}
add(s[top - 1], s[top]), top--;
}
s[++top] = q[i];
}
while(top > 1) add(s[top - 1], s[top]), top--;
ans2 = 0;
ans1 = 1ll * 1e9 * 1e9;
dfs2(s[1]);
printf("%lld %lld %lld\n", dp[s[1]][0], ans1, ans2);
for(i = 1; i <= m; i++) flag[q[i]] = 0;
}

int main()
{
int i, x, y;
n = read();
memset(head, -1, sizeof(head));
for(i = 1; i < n; i++)
{
x = read();
y = read();
add(x, y);
add(y, x);
}
dfs1(1);
T = read();
while(T--) solve();
return 0;
}

  

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: