您的位置:首页 > 其它

51nod 1766 树上的最远点对——线段树

2017-10-10 21:54 344 查看
n个点被n-1条边连接成了一颗树,给出a~b和c~d两个区间,表示点的标号请你求出两个区间内各选一点之间的最大距离,即你需要求出max{dis(i,j) |a<=i<=b,c<=j<=d}
(PS 建议使用读入优化) Input
第一行一个数字 n n<=100000。
第二行到第n行每行三个数字描述路的情况, x,y,z (1<=x,y<=n,1<=z<=10000)表示x和y之间有一条长度为z的路。
第n+1行一个数字m,表示询问次数 m<=100000。
接下来m行,每行四个数a,b,c,d。
Output
共m行,表示每次询问的最远距离
Input示例
5
1 2 1
2 3 2
1 4 3
4 5 4
1
2 3 4 5
Output示例
10

————————————————————————————

这道题可以证明两个区间并起来的最远点对 一定是两个区间单独最远点对中的四个点

然后我们就可以利用线段树来维护辣

#include<cstdio>
#include<cstring>
#include<algorithm>
using std::swap;
const int M=2e5+7,inf=0x3f3f3f3f;
int read(){
int ans=0,f=1,c=getchar();
while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();}
while(c>='0'&&c<='9'){ans=ans*10+(c-'0'); c=getchar();}
return ans*f;
}
int max(int x,int y){return x>y?x:y;}
int n,m;
int first[M],cnt;
struct node{int to,next,w;}e[2*M];
void ins(int a,int b,int w){e[++cnt]=(node){b,first[a],w}; first[a]=cnt;}
void insert(int a,int b,int w){ins(a,b,w); ins(b,a,w);}
int sz[M],son[M],dep[M],fa[M],top[M],id[M],idp=1,dis[M];
void f1(int x){
sz[x]=1;
for(int i=first[x];i;i=e[i].next){
int now=e[i].to;
if(now==fa[x]) continue;
fa[now]=x;
dep[now]=dep[x]+1;
dis[now]=dis[x]+e[i].w;
f1(now); sz[x]+=sz[now];
if(sz[now]>sz[son[x]]) son[x]=now;
}
}
void f2(int x,int tp){
top[x]=tp; id[x]=idp++;
if(son[x]) f2(son[x],tp);
for(int i=first[x];i;i=e[i].next){
int now=e[i].to;
if(now!=fa[x]&&now!=son[x]) f2(now,now);
}
}
int cntq;
struct pos{int mx,p1,p2;}tr[2*M+1007];
int find(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
int calc(int x,int y){
int lca=find(x,y);
return dis[x]+dis[y]-2*dis[lca];
}
void up(int x,int ls,int rs){
int k;
tr[x].mx=tr[ls].mx; tr[x].p1=tr[ls].p1; tr[x].p2=tr[ls].p2;
if(tr[rs].mx>=tr[x].mx) tr[x].mx=tr[rs].mx,tr[x].p1=tr[rs].p1,tr[x].p2=tr[rs].p2;
int x1=tr[ls].p1,y1=tr[ls].p2,x2=tr[rs].p1,y2=tr[rs].p2;
if(x1!=-1){
if(x2!=-1&&(k=calc(x1,x2))>=tr[x].mx) tr[x].mx=k,tr[x].p1=x1,tr[x].p2=x2;
if(y2!=-1&&(k=calc(x1,y2))>=tr[x].mx) tr[x].mx=k,tr[x].p1=x1,tr[x].p2=y2;
}
if(y1!=-1){
if(x2!=-1&&(k=calc(y1,x2))>=tr[x].mx) tr[x].mx=k,tr[x].p1=x2,tr[x].p2=y1;
if(y2!=-1&&(k=calc(y1,y2))>=tr[x].mx) tr[x].mx=k,tr[x].p1=y1,tr[x].p2=y2;
}
}
void build(int x,int l,int r){
if(l==r){
tr[x].p1=l; tr[x].p2=l;
tr[x].mx=0; return ;
}
int mid=(l+r)>>1;
build(x<<1,l,mid);
build(x<<1^1,mid+1,r);
up(x,x<<1,x<<1^1);
}
int L,R;
int push_ans(int x,int l,int r){
if(L<=l&&r<=R) return x;
int mid=(l+r)>>1,ly=++cntq;
tr[ly]=(pos){0,-1,-1};
int s1=0,s2=0;
if(L<=mid) s1=push_ans(x<<1,l,mid);
if(R>mid) s2=push_ans(x<<1^1,mid+1,r);
up(ly,s1,s2);
return ly;
}
int ans,ly,a,b,c,d,s1,s2;
int main(){
int x,y,w;
n=read(); tr[0].p1=tr[0].p2=-1; tr[0].mx=-inf;
for(int i=1;i<n;i++) x=read(),y=read(),w=read(),insert(x,y,w);
f1(1); f2(1,1); build(1,1,n);
m=read();
for(int i=1;i<=m;i++){
a=read(); b=read();
c=read(); d=read();
cntq=2*M;
L=a; R=b; int s1=push_ans(1,1,n);
L=c; R=d; int s2=push_ans(1,1,n);
int x1=tr[s1].p1,y1=tr[s1].p2; //printf("[%d %d]\n",x1,y1);
int x2=tr[s2].p1,y2=tr[s2].p2; //printf("[%d %d]\n",x2,y2);
ans=max(max(calc(x1,y2),calc(x1,x2)),max(calc(y1,y2),calc(y1,x2)));
printf("%d\n",ans);
}
return 0;

}
View Code

 

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: