您的位置:首页 > 其它

zoj 3649 lca与倍增dp

2015-10-30 11:47 253 查看
参考:http://www.xuebuyuan.com/609502.html

先说题意:

给出一幅图,求最大生成树,并在这棵树上进行查询操作:给出两个结点编号x和y,求从x到y的路径上,由每个结点的权值构成的序列中的极差大小——要求,被减数要在减数的后面,即形成序列{a1,a2…aj …ak…an},求ak-aj (k>=j)的最大值。

求路径,显然用到lca。

太孤陋寡闻,才知道原来倍增dp能用来求LCA。

用p[u][i]表示结点u的第1<< i 个祖先结点,则有递推如下:

for(int i=0;i<POW;i++) p[u][i]=p[p[u][i-1]][i-1]。

在对图dfs的时候即完成递推。

要想求两个结点的lca,首先使得两结点高度相同,若二者的父亲结点不同,则一直向上查找。dep数组表示结点的深度。

int LCA(int a,int b){

if(dep[a]>dep[b]) swap(a,b);

if(dep[a]<dep[b]){

//这一部分使得dep[a]==dep[b]

int tmp=dep[b]-dep[a];

for(int i=0;i<POW;i++) if(tmp&(1<<i))

//这里从POW-1到0来遍历也是一样的

b=p[b][i];

}

if(a!=b){

for(int i=POW-1;i>=0;i--) if(p[a][i]!=p[b][i])

a=p[a][i],b=p[b][i];

a=p[a][0],b=p[b][0];

}

return a;

}

如此即返回结点的lca。

用倍增遍历的思路:

因为一段路被二进制分成了一截一截,或者说路径长度被用二进制表示了出来。而两个结点的深度差即为“路径长度”,所以只要tmp&(1<<i),则表示这是“路径”的其中一个结点,以此类推,从而得到两个深度相同的结点。

有了这个基础之后,用相同的方式构建——

mx数组,mx[u][i]表示从u到其第1<<i个祖先结点路径上的最大值

mn数组,mn[u][i]表示从u到其第1<<i个祖先结点路径上的最小值

dp数组,dp[u][i],表示从u到其第1<<i个祖先结点路径上的最大差值

dp2数组,dp2[u][i],表示从其第1<<i个祖先结点到u路径上的最大差值

构建好后是查询部分。给出结点x和y,获得lca。

则路径被分成两段—— x->lca->y。则有三种可能性:

x到lca上的最大差值;lca到y上的最大差值;x到y上的最大差值(即lca到y的最大值减去x到lca的最小值)。比较一下即可。

这题真心涨姿势。代码如下:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=3e4+10,M=N<<1,POW=16,inf=21e8;
int mx
[POW],mn
[POW],p
[POW],dp
[POW],dp2
[POW];
int head
,nxt[M],to[M],cnt,val
,vis
,dep
;
int n,m,q,fa
;
struct Edge{
int u,v,w;
bool operator < (const Edge e) const{
return w>e.w;
}
}E[M];
void ini(int n){
memset(head,-1,sizeof(head));
cnt=0;
memset(vis,0,sizeof(vis));
fill(p[0],p[n+1],0);
fill(mx[0],mx[n+1],-inf);
fill(mn[0],mn[n+1],inf);
fill(dp[0],dp[n+1],-inf);
fill(dp2[0],dp2[n+1],-inf);
dep[0]=0;
}
int find_(int x){
return x==fa[x]?x:fa[x]=find_(fa[x]);
}
void addedge(int u,int v){
to[cnt]=v;
nxt[cnt]=head[u];
head[u]=cnt++;
}
int Kruskal(){
for(int i=0;i<=n;i++) fa[i]=i;
sort(E,E+m);
int sum=0;
for(int i=0;i<m;i++){
int a=find_(E[i].u),b=find_(E[i].v);
if(a!=b){
fa[a]=b;
addedge(E[i].u,E[i].v);
addedge(E[i].v,E[i].u);
sum+=E[i].w;
}
}
return sum;
}
void dfs(int u,int f){
dep[u]=dep[f]+1;
vis[u]=1;
for(int i=head[u];~i;i=nxt[i]) if(!vis[to[i]]){
int v=to[i];
p[v][0]=u;
mx[v][0]=max(val[u],val[v]);
mn[v][0]=min(val[u],val[v]);
dp[v][0]=val[u]-val[v];
dp2[v][0]=val[v]-val[u];
for(int j=1;j<POW;j++){
p[v][j]=p[p[v][j-1]][j-1];
mx[v][j]=max(mx[v][j-1],mx[p[v][j-1]][j-1]);
mn[v][j]=min(mn[v][j-1],mn[p[v][j-1]][j-1]);

dp[v][j]=max(dp[v][j-1],dp[p[v][j-1]][j-1]);
dp[v][j]=max(dp[v][j],mx[p[v][j-1]][j-1]-mn[v][j-1]);

dp2[v][j]=max(dp2[v][j-1],dp2[p[v][j-1]][j-1]);
dp2[v][j]=max(dp2[v][j],mx[v][j-1]-mn[p[v][j-1]][j-1]);
}
dfs(v,u);
}
}
int LCA(int a,int b){
//第一次看到这样的LCA,holy high
//有点不明觉厉
if(dep[a]>dep[b]) swap(a,b);
if(dep[a]<dep[b]){
//这一部分使得dep[a]==dep[b]
int tmp=dep[b]-dep[a];
for(int i=POW-1;i>=0;i--) if(tmp&(1<<i))
b=p[b][i];
}
if(a!=b){
//如果高度相等,而a!=b
for(int i=POW-1;i>=0;i--) if(p[a][i]!=p[b][i])
a=p[a][i],b=p[b][i];
a=p[a][0],b=p[b][0];
}
return a;
}
int getmax(int x,int lca){
int ans=0,tmp=dep[x]-dep[lca];
for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){
ans=max(ans,mx[x][i]);
x=p[x][i];
}
return ans;
}
int getmin(int x,int lca){
int ans=inf,tmp=dep[x]-dep[lca];
for(int  i=POW-1;i>=0;i--) if(tmp&(1<<i)){
ans=min(ans,mn[x][i]);
x=p[x][i];
}
return ans;
}
int getleft(int x,int lca){
int ans=0,minn=inf;
int tmp=dep[x]-dep[lca];
for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){
ans=max(ans,dp[x][i]);
ans=max(ans,mx[x][i]-minn);
minn=min(minn,mn[x][i]);
x=p[x][i];
}
return ans;
}
int getright(int x,int lca){
int ans=0,maxx=0;
int tmp=dep[x]-dep[lca];
for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){
ans=max(ans,dp2[x][i]);
ans=max(ans,maxx-mn[x][i]);
maxx=max(maxx,mx[x][i]);
x=p[x][i];
}
return ans;
}
int main(){
freopen("in.txt","r",stdin);
while(~scanf("%d",&n)){
for(int i=1;i<=n;i++)
scanf("%d",&val[i]);
ini(n);
scanf("%d",&m);
for(int i=0;i<m;i++)
scanf("%d%d%d",&E[i].u,&E[i].v,&E[i].w);
printf("%d\n",Kruskal());
dfs(1,0);
scanf("%d",&q);
int x,y;
while(q--){
scanf("%d%d",&x,&y);
int lca=LCA(x,y);
int ans=getmax(y,lca)-getmin(x,lca);
ans=max(ans,getleft(x,lca));
ans=max(ans,getright(y,lca));
printf("%d\n",ans);
}
}
return 0;
}


View Code
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: