您的位置:首页 > 其它

[BZOJ3611][Heoi2014]大工程(虚树+树形dp)

2017-03-12 12:30 393 查看

题目描述

传送门

题解

令size(i)表示i子树里有多少个关键点

令sum(i)表示i子树中所有关键点到i的距离和

令Max(i)表示i子树中所有关键点到它的最长链,_Max(i)次长链,Min(i)最短链,_Min(i)次短链

这些都非常好维护,第二问和第三问也很好计算,用最和次拼一下就行了

对于第一问的话,在dp的时候维护一下当前size和sum的乘积就行了

将所有的关键点和它们的lca建出一棵虚树,边权为两点之间的距离

然后按照上面的dp就行了

dp的时候要格外注意子树的根是否是关键点以及儿子的个数

代码

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
#define LL long long
#define N 1000005
#define sz 20

int n,q,k,dfs_clock,top;
int tot,point
,nxt[N*2],v[N*2],c[N*2];
int pt
,key
,flag
,stack
,h
,in
,out
,f
[sz+3],size
;
LL sum
,Max
,_Max
,Min
,_Min
,ans1,ans2,ans3;

const LL inf=1e18;

void add(int x,int y,int z)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;
}
void build(int x,int fa)
{
h[x]=h[fa]+1;in[x]=++dfs_clock;
for (int i=1;i<sz;++i) f[x][i]=f[f[x][i-1]][i-1];
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa)
{
f[v[i]][0]=x;
build(v[i],x);
}
out[x]=++dfs_clock;
}
int cmp(int a,int b)
{
return in[a]<in[b];
}
int lca(int x,int y)
{
if (h[x]<h[y]) swap(x,y);
int cha=h[x]-h[y];
for (int i=0;i<sz;++i)
if ((cha>>i)&1) x=f[x][i];
if (x==y) return x;
for (int i=sz-1;i>=0;--i)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
void treedp(int x)
{
size[x]=0;
sum[x]=0;
Max[x]=_Max[x]=0;
Min[x]=_Min[x]=inf;
if (key[x]==q) size[x]=1,Min[x]=0;
int cnt=0;
for (int i=point[x];i;i=nxt[i])
{
++cnt;
treedp(v[i]);
ans1+=(LL)size[v[i]]*sum[x]+(LL)size[x]*(sum[v[i]]+(LL)c[i]*(LL)size[v[i]]);
size[x]+=size[v[i]];
sum[x]+=sum[v[i]]+(LL)c[i]*(LL)size[v[i]];
if (Max[v[i]]+(LL)c[i]>Max[x])
{
_Max[x]=Max[x];
Max[x]=Max[v[i]]+(LL)c[i];
}
else _Max[x]=max(_Max[x],Max[v[i]]+(LL)c[i]);
if (Min[v[i]]+(LL)c[i]<Min[x])
{
_Min[x]=Min[x];
Min[x]=Min[v[i]]+(LL)c[i];
}
else _Min[x]=min(_Min[x],Min[v[i]]+(LL)c[i]);
}
if (key[x]==q||cnt>1)
{
ans2=min(ans2,Min[x]+_Min[x]);
ans3=max(ans3,Max[x]+_Max[x]);
}
point[x]=0;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<n;++i)
{
int x,y;scanf("%d%d",&x,&y);
add(x,y,1),add(y,x,1);
}
build(1,0);
memset(point,0,sizeof(point));
scanf("%d",&q);
while (q)
{
scanf("%d",&k);
for (int i=1;i<=k;++i)
{
scanf("%d",&pt[i]);
key[pt[i]]=flag[pt[i]]=q;
}
sort(pt+1,pt+k+1,cmp);pt[0]=k;
for (int i=2;i<=k;++i)
{
int r=lca(pt[i-1],pt[i]);
if (flag[r]!=q)
{
flag[r]=q;
pt[++pt[0]]=r;
}
}
if (flag[1]!=q) flag[1]=q,pt[++pt[0]]=1;
sort(pt+1,pt+pt[0]+1,cmp);
tot=0;stack[top=1]=1;
for (int i=2;i<=pt[0];++i)
{
while (in[pt[i]]<in[stack[top]]||in[pt[i]]>out[stack[top]])
--top;
add(stack[top],pt[i],h[pt[i]]-h[stack[top]]);
stack[++top]=pt[i];
}
ans1=0;ans2=inf;ans3=0;
treedp(1);
printf("%lld %lld %lld\n",ans1,ans2,ans3);
--q;
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: