您的位置:首页 > 编程语言 > C语言/C++

【模板】最近公共祖先(LCA)

2017-08-16 17:28 465 查看
题自洛谷

题目描述

如题,给定一棵有根多叉树,请求出指定两个点直接最近的公共祖先。

输入输出格式

输入格式:

第一行包含三个正整数N、M、S,分别表示树的结点个数、询问的个数和树根结点的序号。

接下来N-1行每行包含两个正整数x、y,表示x结点和y结点之间有一条直接连接的边(数据保证可以构成树)。

接下来M行每行包含两个正整数a、b,表示询问a结点和b结点的最近公共祖先。

输出格式:

输出包含M行,每行包含一个正整数,依次为每一个询问的结果。

输入输出样例

输入样例#1:

5 5 4

3 1

2 4

5 1

1 4

2 4

3 2

3 5

1 2

4 5

输出样例#1:

4

4

1

4

4

说明

时空限制:1000ms,128M

数据规模:

对于30%的数据:N<=10,M<=10

对于70%的数据:N<=10000,M<=10000

对于100%的数据:N<=500000,M<=500000

样例说明:

该树结构如下:



第一次询问:2、4的最近公共祖先,故为4。

第二次询问:3、2的最近公共祖先,故为4。

第三次询问:3、5的最近公共祖先,故为1。

第四次询问:1、2的最近公共祖先,故为4。

第五次询问:4、5的最近公共祖先,故为4。

故输出依次为4、4、1、4、4

LCA的两种常见算法就是树上倍增和tarjan

两者的区别就是前者是在线算法后者是在线算法

(一个直接出答案一个是最后一起出答案)

这里都提供一下吧,但是树上倍增的代码不一定对(洛谷上只有30分)

看了下别人的题解说是卡常,但是懒得去改了



tarjan

#include<cstdio>
using namespace std;
const int M=2000005,N=1000005; //数据要多开一倍
int fa
,head
,qhead
,vis
;
struct edge{
int to,next;
}edge
;
struct qedge{
int to,next,ans;
}q[M];
int read()
{
int sum=0;
char ch=getchar();
while(ch>'9'||ch<'0') ch=getchar();
while(ch<='9'&&ch>='0')
{
sum=sum*10+ch-48;
ch=getchar();
}
return sum;
}
int cnt;
void add(int x,int y)
{
edge[++cnt].to=y;
edge[cnt].next=head[x];
head[x]=cnt;
}
void aadd(int x,int y,int z)
{
q[z].to=y;
q[z].next=qhead[x];
qhead[x]=z;
}
int find(int x)
{
if(fa[x]!=x) fa[x]=find(fa[x]);
return fa[x];
}
void unity(int x,int y)
{
x=find(x),y=find(y);
fa[y]=x;
}
void tarjan(int x)
{
vis[x]=1;
for(int i=head[x];i;i=edge[i].next)
{
int to=edge[i].to;
if(!vis[to])
{
tarjan(to);
unity(x,to);
}
}
for(int i=qhead[x];i;i=q[i].next)
{
if(vis[q[i].to]==2)
{
q[i].ans=find(q[i].to);
if(i%2)q[i+1].ans=q[i].ans;
else q[i-1].ans=q[i].ans;
}
}
vis[x]=2;
}
int main()
{
int n=read(),m=read(),s=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read();
add(y,x),add(x,y);
fa[i]=i;
}
fa
=n;
for(int i=1;i<=m;i++)
{
int x=read(),y=read();
aadd(x,y,i*2-1);
aadd(y,x,i*2);
}
tarjan(s);
for(int i=1;i<=n;i++)
printf("%d\n",q[i*2].ans);
return 0;
}


倍增

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
struct node{
int to,next;
}e[1000005];
int lca[30][600005],deep[600005],h[1000005];
int n,m,t,root;
void add(int u,int v)
{
e[++t].to=v;
e[t].next=h[u];
h[u]=t;
}
int read()
{
int sum=0;
char ch=getchar();
while(ch>'9'||ch<'0') ch=getchar();
while(ch<='9'&&ch>='0')
{
sum=sum*10+ch-48;
ch=getchar();
}
return sum;
}
void dfs(int n,int f,int d)
{
lca[0]
=f;
deep
=d;
for(int i=h
;i!=-1;i=e[i].next)
{
if(e[i].to!=f)
dfs(e[i].to,n,d+1);
}
}
void init()
{
dfs(root,-1,0);
for(int k=0;k<t-1;k++)
{
for(int i=1;i<=n;i++)
{
if(lca[k][i]<0) lca[k+1][i]=-1;
else lca[k+1][i]=lca[k][lca[k][i]];
}
}
}
int LCA(int x,int y)
{
if(deep[x]>deep[y]) swap(x,y);
for(int i=0;i<t;i++)
{
if(((deep[y]-deep[x])>>i)&1) y=lca[i][y];
}
if(x==y) return x;
for(int i=t;i>=0;i--)
{
if(lca[i][x]!=lca[i][y])
{
x=lca[i][x];
y=lca[i][y];
}
}

return lca[0][x];
}
int main()
{
int x,y;
memset(h,-1,sizeof(h));
n=read(),m=read(),root=read();
t=int(log10(n)/log10(2))+1;

for(int i=1;i<n;i++)
{
x=read();y=read();
add(x,y);add(y,x);
}
init();
for (int i=1;i<=m;i++)
{
x=read();y=read();
printf("%d\n",LCA(x,y));
}
return 0;
}


更新 诈尸(倍增 v2.0)

还是要用BFS初始化好一些啊

毕竟时间上BFS还是保守一点的,这次倍增也能过洛谷所有点了

下面贴代码

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
struct node{
int to,next;
}e[1000005];
int lca[500005][25],deep[600005],h[500005],q[500005];
bool vis[500005];
int n,m,t,root;
void add(int u,int v)
{
e[++t].to=v;
e[t].next=h[u];
h[u]=t;
}
int read()
{
int sum=0;
char ch=getchar();
while(ch>'9'||ch<'0') ch=getchar();
while(ch<='9'&&ch>='0')
{
sum=sum*10+ch-48;
ch=getchar();
}
return sum;
}
void bfs()
{
int hd=0,tl=1;
q[1]=root;
vis[root]=1;
deep[root]=1;
lca[root][0]=1;
while(hd<tl)
{
int u=q[++hd];
for(int i=1;i<=20;i++)
lca[u][i]=lca[lca[u][i-1]][i-1];
for(int i=h[u];i!=0;i=e[i].next)
{
int v=e[i].to;
if(vis[v]==0)
{
vis[v]=1;
q[++tl]=v;
deep[v]=deep[u]+1;
lca[v][0]=u;
}
}
}
}
int LCA(int x,int y)
{
if(deep[x]<deep[y]) swap(x,y);
for(int i=20;i>=0;i--)
if(deep[y]+(1<<i)<=deep[x])
x=lca[x][i];

if(x==y) return x;
for(int i=20;i>=0;i--)
if(lca[x][i]!=lca[y][i])
{
x=lca[x][i];
y=lca[y][i];
}
return lca[x][0];
}
int main()
{
int x,y;
n=read(),m=read(),root=read();
for(int i=1;i<n;i++)
{
x=read();y=read();
add(x,y);add(y,x);
}
bfs();
for (int i=1;i<=m;i++)
{
x=read();y=read();
printf("%d\n",LCA(x,y));
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  c++