您的位置:首页 > 其它

[bzoj3611][Heoi2014]大工程

2016-02-12 20:06 369 查看
  看题目感觉应该就是传说中的虚树?

  然后跑去学了一发。。。自己YY了一下然后挂飞。。于是就只好抄模板了T_T

  建完虚树就是个树形dp。。。

  对于询问总和:每条边对答案的贡献是边权*一端的节点数*另一端的节点数。(这里的节点不包括建虚树时添上去的点)

  对于询问最小值最大值,每次计算出经过这个节点的最长||最短路径长度就好了。。

  大概这种题条件都有一个sigma(K)<=n之类的。。而且题目求的东西得符合区间加法。。。不然你把边合在一起也没用>_<

  链剖求lca果然快。。。速度能进前10.。。然而代码长度实在感人= =

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=1000023;
const int inf=1002333333;
struct zs{
int too,pre;
}e[maxn<<1];
struct zs1{
int too,pre,dis;
}e1[maxn<<2];
int tot,tot1,last[maxn],last1[maxn];
int sz[maxn],mn[maxn],mx[maxn];
int intree[maxn],poi[maxn],rt;
int dfn[maxn],bel[maxn],size[maxn],dep[maxn],fa[maxn],tim;
int st[maxn],top;
int ansmn,ansmx;
ll anssum;
int i,j,k,n,m,K,a,b,lca;

int ra;char rx;
inline int read(){
rx=getchar(),ra=0;
while(rx<'0'||rx>'9')rx=getchar();
while(rx>='0'&&rx<='9')ra*=10,ra+=rx-48,rx=getchar();return ra;
}

inline void insert(int a,int b){
e[++tot].too=b,e[tot].pre=last[a],last[a]=tot;
e[++tot].too=a,e[tot].pre=last[b],last[b]=tot;
}
inline void ins(int a,int b){
//  printf("   %d-->%d\n",a,b);
e1[++tot1].too=b,e1[tot1].dis=dep[b]-dep[a],e1[tot1].pre=last1[a],last1[a]=tot1;
}

void dfs1(int x){
size[x]=1;
for(int i=last[x];i;i=e[i].pre)
if(e[i].too!=fa[x])
fa[e[i].too]=x,
dfs1(e[i].too),
size[x]+=size[e[i].too];
}
void dfs2(int x,int chain){
bel[x]=chain,dfn[x]=++tim,dep[x]=dep[fa[x]]+1;int mxpos=0,i,to;
for(to=e[i=last[x]].too;i;to=e[i=e[i].pre].too)
if(to!=fa[x]&&size[to]>size[mxpos])mxpos=to;
if(!mxpos)return;
dfs2(mxpos,chain);
for(to=e[i=last[x]].too;i;to=e[i=e[i].pre].too)
if(to!=fa[x]&&to!=mxpos)dfs2(to,to);
}
inline int getlca(int a,int b){
if(dep[bel[a]]<dep[bel[b]])swap(a,b);
while(bel[a]!=bel[b]){
a=fa[bel[a]];
if(dep[bel[a]]<dep[bel[b]])swap(a,b);
}
return dep[a]<dep[b]?a:b;
}

bool cmp(int a,int b){return dfn[a]<dfn[b];}

inline int min(int a,int b){return a<b?a:b;}
inline int max(int a,int b){return a>b?a:b;}
void dp(int x){
register int i,to;
if(intree[x]==m)sz[x]=1,mn[x]=mx[x]=0;
else sz[x]=0,mn[x]=inf,mx[x]=-inf;
for(to=e1[i=last1[x]].too;i;to=e1[i=e1[i].pre].too){
dp(to),sz[x]+=sz[to],mn[to]+=e1[i].dis,mx[to]+=e1[i].dis;
anssum+=(ll)sz[to]*(K-sz[to])*e1[i].dis;
if(mn[x]+mn[to]<ansmn)ansmn=mn[x]+mn[to];
if(mx[to]+mx[x]>ansmx)ansmx=mx[x]+mx[to];
if(mn[to]<mn[x])mn[x]=mn[to];if(mx[to]>mx[x])mx[x]=mx[to];
}
last1[x]=0;
}

char s[23];int len;
inline void outll(ll x){
if(!x){putchar('0');return;}
for(len=0;x;s[++len]=x%10,x/=10);
while(len)putchar(s[len--]+48);
}
inline void outint(int x){
if(!x){putchar('0');return;}
for(len=0;x;s[++len]=x%10,x/=10);
while(len)putchar(s[len--]+48);
}

int main(){
register int i;
n=read();
for(i=1;i<n;i++)a=read(),b=read(),insert(a,b);
dfs1(1),dfs2(1,1);
//  for(i=1;i<=n;i++)printf("  %d  %d\n",dep[i],dfn[i]);
//  for(i=1;i<n;i++)for(j=i+1;j<=n;j++)printf("  %d&&%d %d\n",i,j,getlca(i,j));
for(m=read();m;m--){
tot1=0;
for(K=read(),i=1;i<=K;i++)intree[poi[i]=read()]=m;
sort(poi+1,poi+1+K,cmp);top=1,st[1]=poi[1];
//      for(i=1;i<=K;i++)printf("  %d\n",poi[i]);
for(i=2;i<=K;i++){
lca=getlca(poi[i],st[top]);
while(dfn[lca]<dfn[st[top]]&&top)
if(dfn[st[top-1]]<=dfn[lca]){
ins(lca,st[top--]);
if(st[top]!=lca)st[++top]=lca;
}else ins(st[top-1],st[top]),top--;
st[++top]=poi[i];
}
while(top>1)ins(st[top-1],st[top]),top--;
ansmn=inf,ansmx=anssum=0,rt=st[1],
dp(rt);
outll(anssum),putchar(' '),outint(ansmn),putchar(' '),outint(ansmx),putchar('\n');
}
return 0;
}


View Code
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: