您的位置:首页 > 其它

BZOJ2733 [HNOI2012]永无乡 平衡树启发式合并

2014-05-31 20:30 323 查看
首先因为题目中涉及到查询第K小值,所以用平衡树来维护每个连通分支的信息。

那么加边这个操作怎么实现呢?其实就是将任意的两个平衡树合并。给我们的直观感受是把小的树合并到大的树里比较高效。

事实上,这样做的话,所有合并操作可以在O(nlog^2n)之内解决。

为什么呢?可以这样来分析。每个节点经过一次合并操作以后,它所在的树的大小至少要加倍,那么也就是说至多一个节点被合并操作影响logn次,每次合并后的插入操作要O(logn)时间,共有n个节点,就得到了O(nlog^2n)的时间复杂度。

吐槽一下数据……刚开始我没判断加边操作的两边是否已经在同一个连通分支内,就直接把树复制了一遍……竟然也AC了。下面的代码是改正以后的代码。

//BZOJ2733
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<vector>
#include<queue>
#include<ctime>
#include<cstdlib>
using namespace std;
const int MAXN=100010;
struct Treap_Node
{
int ch[2],key,dat,size,sub;
}Treap[MAXN<<5];
int p[MAXN],sz[MAXN],root[MAXN],ip[MAXN],n,m,in1,in2,q,tot;
char op[10];
int find(int x)
{
if(p[x]==x) return x;
p[x]=find(p[x]);
return p[x];
}
inline void uni(int i,int j)
{
sz[find(j)]+=sz[find(i)];
p[find(i)]=find(j);
}
inline int cmp(int x,int tar)
{
return (Treap[x].dat>tar)?0:1;
}
inline void maintain(int x)
{
Treap[x].size=Treap[Treap[x].ch[0]].size+1+Treap[Treap[x].ch[1]].size;
}
inline void rotate(int &x,int d)
{
int p=Treap[x].ch[d^1];
Treap[x].ch[d^1]=Treap[p].ch[d];
Treap[p].ch[d]=x;
maintain(x);
maintain(p);
x=p;
}
void ins(int &x,int tar,int s)
{
if(!x)
{
Treap[++tot].dat=tar,Treap[tot].sub=s,Treap[tot].key=rand();
Treap[tot].size=1,x=tot;
return;
}
int d=cmp(x,tar);
ins(Treap[x].ch[d],tar,s);
if(Treap[Treap[x].ch[d]].key>Treap[x].key) rotate(x,d^1);
maintain(x);
}
int getKth(int x,int k)
{
if(k<=Treap[Treap[x].ch[0]].size) return getKth(Treap[x].ch[0],k);
k-=Treap[Treap[x].ch[0]].size+1;
if(k<=0) return Treap[x].sub;
else return getKth(Treap[x].ch[1],k);
}
void mergeto(int x,int &y)
{
ins(y,Treap[x].dat,Treap[x].sub);
if(Treap[x].ch[0]) mergeto(Treap[x].ch[0],y);
if(Treap[x].ch[1]) mergeto(Treap[x].ch[1],y);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&ip[i]);
for(int i=1;i<=n;i++) p[i]=i,sz[i]=1;
for(int i=1;i<=m;i++)
{
scanf("%d%d",&in1,&in2);
if(find(in1)!=find(in2)) uni(in1,in2);
}
for(int i=1;i<=n;i++) ins(root[find(i)],ip[i],i);
scanf("%d",&q);
for(int i=1;i<=q;i++)
{
scanf("%s%d%d",op,&in1,&in2);
if(op[0]=='B'&&find(in1)!=find(in2))
{
int s1=sz[find(in1)],s2=sz[find(in2)];
if(s1>s2)
{
mergeto(root[find(in2)],root[find(in1)]);
uni(in2,in1);
}
else
{
mergeto(root[find(in1)],root[find(in2)]);
uni(in1,in2);
}
}
else if(op[0]=='Q')
{
if(sz[find(in1)]<in2) puts("-1");
else printf("%d\n",getKth(root[find(in1)],in2));
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: