您的位置:首页 > 其它

【BZOJ】3991 [SDOI2015]寻宝游戏 树形DP+虚树+set

2017-12-24 19:17 351 查看
题目传送门

其实这题并没有真正的用到虚树,只是用到了虚树的思想。

首先考虑暴力树形DP,时间复杂度还是O(n×m),必须要优化。

然后我们把思路转移到虚树上,发现问题转化为改变一个节点是否为关键点,答案就是虚树上所有边权*2。

我们考虑一个节点加入虚树产生的贡献,就是DFS序中和当前节点相邻的节点的路径长度*2,删除同理。那么我们每次维护改变的节点的贡献即可。

可以用一个set来维护当前虚树中的节点,然后就变成了求前驱后继之类的操作。

注意:树根不是1,但也不能每次减当前树根的深度,所以我们放在最后让树链的并减去树根的深度。

p.s.十年OI一场空,不开long long见祖宗……记得开long long啊……

附上AC代码:

#include <cstdio>
#include <cctype>
#include <algorithm>
#include <set>
using namespace std;

typedef long long ll;
const int N=1e5+10;
struct side{
ll to,w,nt;
}s[N<<1];
int n,m,h
,num,a
,t;
int d
,f
,sz
,hs
,top
,wz
,size,rl
,x,y,w;
bool b
;
set <int> u;
set <int> :: iterator it;
ll ans,tmp,dis
;

inline char nc(void){
static char ch[100010],*p1=ch,*p2=ch;
return p1==p2&&(p2=(p1=ch)+fread(ch,1,100010,stdin),p1==p2)?EOF:*p1++;
}

inline void read(int &a){
static char c=nc();int f=1;
for (;!isdigit(c);c=nc()) if (c=='-') f=-1;
for (a=0;isdigit(c);a=(a<<3)+(a<<1)+c-'0',c=nc());
return (void)(a*=f);
}

inline void add(int x,int y,ll w){
s[++num]=(side){y,w,h[x]},h[x]=num;
s[++num]=(side){x,w,h[y]},h[y]=num;
}

inline void so1(int x,int fa){
d[x]=d[f[x]=fa]+1,sz[x]=1;
for (int i=h[x]; i; i=s[i].nt)
if (s[i].to!=fa){
dis[s[i].to]=dis[x]+s[i].w,so1(s[i].to,x),sz[x]+=sz[s[i].to];
if (sz[s[i].to]>sz[hs[x]]) hs[x]=s[i].to;
}
return;
}

inline void so2(int x,int fa){
top[x]=fa,wz[x]=++size,rl[size]=x;
if (hs[x]) so2(hs[x],fa);
for (int i=h[x]; i; i=s[i].nt)
if (s[i].to!=f[x]&&s[i].to!=hs[x]) so2(s[i].to,s[i].to);
return;
}

inline int lca(int x,int y){
for (int fx=top[x],fy=top[y]; fx!=fy; x=f[fx],fx=top[x])
if (d[fx]<d[fy]) swap(fx,fy),swap(x,y);
return d[x]<d[y]?x:y;
}

inline ll calc(int x,int y){return dis[x]+dis[y]-(dis[lca(x,y)]<<1);}

int main(void){
read(n),read(m);
for (int i=1; i<n; ++i) read(x),read(y),read(w),add(x,y,1ll*w);
so1(1,0),so2(1,1),u.insert(-1e9),u.insert(1e9);
while (m--){
read(x);
if (a[x]) u.erase(wz[x]),t=-1;
else u.insert(wz[x]),t=1;
a[x]^=1,it=u.upper_bound(wz[x]);
int r=*it,l=*(--it); if (l>=wz[x]) l=*(--it);
if (l!=-1e9) ans+=calc(rl[l],x)*t; if (r!=1e9) ans+=calc(rl[r],x)*t;
if (l!=-1e9&&r!=1e9) ans-=calc(rl[l],rl[r])*t;
tmp=(u.size()>3)?calc(rl[*u.upper_bound(-1e9)],rl[*--u.lower_bound(1e9)]):0;
printf("%lld\n",ans+tmp);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: