您的位置:首页 > 其它

Poj 1741——treap的启发式合并

2017-03-10 18:47 253 查看

Tree

Description

Give a tree with n vertices,each edge has a length(positive integer less than 1001).

Define dist(u,v)=The min distance between node u and v.

Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.

Write a program that will count how many pairs which are valid for a given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.

The last test case is followed by two zeros.

Output

For each test case output the answer on a single line.

题目大意

给定一棵树,求出满足距离不大于k的点对的个数。

解题思想

最开始的想法就是做n次DFS然后累计答案,但是显然超时。

于是乎,我们想到如果能维护一个数据结构能够快速的知道小于k-disy的x的个数(x,y是枚举的点对,disx表示x到x,y最近祖先的距离,要满足disx+disy<=k,就是disx<=k-disy),马上就想到treap,treap中维护的数值显然是到最近公共祖先的距离,但是每次进行合并时又涉及到+w,比较麻烦,所以维护到根的距离比较方便(累计答案的时候注意加两倍的最近祖先到根的距离)。合并其实很简单,把小树往大树里放就可以了,启发式合并其实就是暴力的一个一个放。也许有人会问为什么不会超时,很明显,最坏的情况是两个一样大的树进行合并,所以对于每个节点最多就合并log(n)次,询问也是log(n)的,所以总复杂度是O(nlog^2(n))。操作时要注意先询问答案后合并,否则会算重。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=10005,maxm=20005;
struct jz{
int x,s,w,ran,l,r;
}a[maxn*14];
int lnk[maxn],nxt[maxm],son[maxm],w[maxm],ro[maxn],tot,n,K,m,ans;
bool vis[maxn];
void add(int x,int y,int z){nxt[++tot]=lnk[x];lnk[x]=tot;son[tot]=y;w[tot]=z;}
void Putdata(int k){a[k].s=a[a[k].l].s+a[a[k].r].s+a[k].w;}
void rturn(int &k){
int t=a[k].l;a[k].l=a[t].r;a[t].r=k;
a[t].s=a[k].s;Putdata(k);k=t;
}
void lturn(int &k){
int t=a[k].r;a[k].r=a[t].l;a[t].l=k;
a[t].s=a[k].s;Putdata(k);k=t;
}
void Insert(int &k,int x){
if (k==0){k=++m;a[k].s=a[k].w=1;a[k].ran=rand();a[k].x=x;return;}
a[k].s++;
if (x==a[k].x) a[k].w++;else
if (x<a[k].x){
Insert(a[k].l,x);
if (a[a[k].l].ran<a[k].ran) rturn(k);
}else{
Insert(a[k].r,x);
if (a[a[k].r].ran<a[k].ran) lturn(k);
}
}
void Join(int &k1,int k2){
if (k2==0) return;
for (int i=1;i<=a[k2].w;i++) Insert(k1,a[k2].x);
Join(k1,a[k2].l);Join(k1,a[k2].r);
}
int Asksum(int k,int x){
if (k==0) return 0;
if (x==a[k].x) return a[k].w+a[a[k].l].s;else
if (x<a[k].x) return Asksum(a[k].l,x);else
return a[k].w+a[a[k].l].s+Asksum(a[k].r,x);
}
int Count(int k1,int k2,int x){
if (k2==0) return 0;
return a[k2].w*Asksum(k1,x-a[k2].x)+Count(k1,a[k2].l,x)+Count(k1,a[k2].r,x);
}
void DFS(int x,int dep){
vis[x]=1;
for (int j=lnk[x];j;j=nxt[j])if (!vis[son[j]]){
DFS(son[j],dep+w[j]);
if (a[ro[x]].s<a[ro[son[j]]].s) swap(ro[son[j]],ro[x]);
ans+=Count(ro[x],ro[son[j]],K+2*dep);Join(ro[x],ro[son[j]]);
}
ans+=Asksum(ro[x],K+dep);
Insert(ro[x],dep);
}
int main(){
freopen("exam.in","r",stdin);
freopen("exam.out","w",stdout);
while (1){
scanf("%d%d",&n,&K);
if (n==0&&K==0) return 0;
memset(a,0,sizeof(a));
memset(ro,0,sizeof(ro));
memset(vis,0,sizeof(vis));
memset(lnk,0,sizeof(lnk));
tot=ans=m=0;int x,y,z;
for (int i=1;i<n;i++){scanf("%d%d%d",&x,&y,&z);add(x,y,z);add(y,x,z);}
DFS(1,0);
printf("%d\n",ans);
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: