您的位置:首页 > 其它

BZOJ - 3611 大工程 【虚树+LCA[二分法]+树形DP】

2017-08-28 21:22 375 查看
HYSBZ - 3611传送门

3611: [Heoi2014]大工程

Time Limit: 60 Sec  Memory Limit:
512 MB
Submit: 1634  Solved: 694

[Submit][Status][Discuss]

Description

国家有一个大工程,要给一个非常大的交通网络里建一些新的通道。 
我们这个国家位置非常特殊,可以看成是一个单位边权的树,城市位于顶点上。 
在 2 个国家 a,b 之间建一条新通道需要的代价为树上 a,b 的最短路径。
 现在国家有很多个计划,每个计划都是这样,我们选中了 k 个点,然后在它们两两之间 新建 C(k,2)条 新通道。
现在对于每个计划,我们想知道:
 1.这些新通道的代价和
 2.这些新通道中代价最小的是多少 
3.这些新通道中代价最大的是多少

Input

第一行 n 表示点数。

 接下来 n-1 行,每行两个数 a,b 表示 a 和 b 之间有一条边。
点从 1 开始标号。 接下来一行 q 表示计划数。
对每个计划有 2 行,第一行 k 表示这个计划选中了几个点。
 第二行用空格隔开的 k 个互不相同的数表示选了哪 k 个点。

Output

输出 q 行,每行三个数分别表示代价和,最小代价,最大代价。

题意:见上文。

(下一篇博客会介绍虚树和LCA)

虚树:是根据一棵树上必要的点加上他们的LCA而建立的一棵树。能够节省进行树上操作的时间。

LCA:最近公共祖先,这里是基于二分法,实现的,详见挑程P328.

树形DP(基于虚树):维护一个最小次小最大次大值,这样最大距离最小距离就求出来啦。

然后用sum存对于x子树(包含x),包含的关键点数。对于x到他的儿子y,他们之间的距离要乘上sum[y]*(总关键点-sum[y]).(想一想为什么)

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int INF=0x3f3f3f3f;
const int N=1001000;
int head[N*2],fa[22]
,dep
,id
,state
,TAT,h
;
int n,m,cnt,num;
LL ans,ans1,ans2,max1
,max2
,min1
,min2
,k,sum
,sig
;
struct edge
{
int to,nex;
}e[N*2];
bool cmp(int x,int y)
{
return id[x]<id[y];
}
void add(int from,int to)
{
num++;
e[num].to=to;
e[num].nex=head[from];
head[from]=num;
}
void dfs(int x,int ffa,int d)
{
//cout<<x<<endl;
dep[x]=d;id[x]=cnt++;fa[0][x]=ffa;
for(int i=head[x];i!=-1;i=e[i].nex)
{
int y=e[i].to;
if(y==ffa)continue;
dfs(y,x,d+1);
}
head[x]=-1;
}
int LCA(int u,int v)
{
if(dep[u]<dep[v])swap(u,v);
for(int i=0;i<=m;i++)
{
if((dep[u]-dep[v])&(1<<i))
u=fa[i][u];
}
if(u==v)return u;
for(int i=m;i>=0;i--)
{
if(fa[i][u]!=fa[i][v])
{
u=fa[i][u];
v=fa[i][v];
}
}
return fa[0][u];
}
void init()
{
cnt=1;
dfs(1,-1,1);
for(int i=0;i+1<=m;i++)
{
for(int j=1;j<=n;j++)
{
if(fa[i][j]<0)fa[i+1][j]=-1;
else fa[i+1][j]=fa[i][fa[i][j]];
}
}
}
void bt()
{
sort(h+1,h+k+1,cmp);
int cut=1;
state[cut]=1;
for(int i=1;i<=k;i++)
{
if(h[i]==1)continue;
int lca=LCA(h[i],state[cut]);
if(lca==state[cut])state[++cut]=h[i];
else
{
while(true)
{
int x=state[cut],xx=state[--cut];
if(xx==lca)
{
add(xx,x);
break;
}
if(dep[xx]<dep[lca])
{
add(lca,x);
state[++cut]=lca;
break;
}
add(xx,x);
}
state[++cut]=h[i];
}
}
for(int i=1;i<cut;i++)
add(state[i],state[i+1]);
}
void dfs1(int x)
{

if(sig[x]==TAT)
{ max1[x]=max2[x]=0;
sum[x]=1;
min1[x]=0;
min2[x]=INF;
}
else
{
max1[x]=max2[x]=-INF;
sum[x]=0;
min1[x]=min2[x]=INF;
}
for(int i=head[x];i!=-1;i=e[i].nex)
{
int y=e[i].to;
dfs1(y);
sum[x]+=sum[y];
int tem=dep[y]-dep[x]+min1[y];
if(tem<min1[x])
{
min2[x]=min1[x];
min1[x]=tem;
}
else if(tem<min2[x])
min2[x]=tem;
tem=dep[y]-dep[x]+max1[y];
if(tem>max1[x])
{
max2[x]=max1[x];
max1[x]=tem;
}
else if(tem>max2[x])
max2[x]=tem;
ans=ans+(dep[y]-dep[x])*sum[y]*(k-sum[y]);
ans1=min(ans1,min1[x]+min2[x]);
ans2=max(ans2,max1[x]+max2[x]);
}
head[x]=-1;
}
int main()
{
int x,y;
while(scanf("%d",&n)!=EOF)
{
m=0;
while(n>=(1<<m))m++;
num=0;
memset(head,-1,sizeof(head));
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
init();
memset(sig,-1,sizeof(sig));
scanf("%d",&TAT);
while(TAT--)
{
scanf("%d",&k);
num=0;

for(int i=1;i<=k;i++)
{
scanf("%d",&h[i]);
sig[h[i]]=TAT;
}
bt();
ans=0,ans1=INF,ans2=0;
dfs1(1);
printf("%lld %lld %lld\n",ans,ans1,ans2);
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: