您的位置:首页 > 其它

【BZOJ3196】【Tyvj1730】二逼平衡树,第一次的树套树(线段树+splay)

2016-04-07 16:24 417 查看
传送门1

传送门2

写在前面:创造迄今最长的正常代码的记录

思路:个人感觉这个树套树就是对线段树的每个区间建一棵splay来维护,最初觉得这个方法会爆T爆M……(实际上真的可能会爆)。对于5个操作,我们有如下策略

对于操作1,我们比较容易想到,寻找k在[l,r]上的排名就是求[l,r]中比k小的数的数量+1,这等价于找出它在[l,mid]和[mid+1,r]上比他小的数的总数量+1,然后就可以线段树一层层套下去,再用Splay的rank函数查找了

对于操作2,这是一个比较麻烦的,因为它不能像1一样在区间中合并,但数据范围是[0,10^8],所以我们可以令l=0,r=10^8,二分查找mid的排名,最后得到正解

对于操作3,这是单点修改,所以直接一直放下去,并修改所在的个区间的splay(先del原数值再insert新数值)

对于操作4,5,显然答案也是可以在区间上合并的,前驱找出最大的,后继找出最小的,一层层下放,如果查询区间覆盖了当前线段树的节点区间,就直接调用Splay的前驱后继函数

注意:

1.记录每个splay的根节点并需要实时修改,推荐通过记录其下标来修改(代码中rt全部为当前splay的根在数组中的下标),毕竟取地址符什么看起来就很不舒服= =

2.废物利用,每次修改操作时记录下原数在splay中的下标,到时候插入的时候直接用就行了(记得初始化),防止下标加的过多导致RE(如果原数出现次数大于1则不能这么做,只能再开一个下标)

3.代码在BZOJ上测试通过,但截至发文时间,Tyvj服务器一直处于崩溃,无法评测,po主在cogs上评测T了两个点……开O2加inline也不管用……

#include<bits/stdc++.h>
#define pd(i) (i>='0'&&i<='9')
using namespace std;
int n,m,tot;
int num[50003],roots[2000003];
struct Splay
{
int fa,ch[2],siz,data,occ;
}a[2000003];
int in()
{
int t=0,f=1;
char ch=getchar();
while (!pd(ch))
{
if (ch=='-') f=-1;
ch=getchar();
}
while (pd(ch)) t=(t<<3)+(t<<1)+ch-'0',ch=getchar();
return f*t;
}
void ct(int x)
{
a[x].siz=a[a[x].ch[0]].siz+a[a[x].ch[1]].siz+a[x].occ;
}
void made(int x,int id)
{
a[id].data=x,
a[id].occ=a[id].siz=1,
a[id].ch[0]=a[id].ch[1]=a[id].fa=0;
}
void rorate(int now,bool mk)
{
int pa=a[now].fa;
a[a[now].ch[mk]].fa=pa;
a[pa].ch[!mk]=a[now].ch[mk];
a[now].fa=a[pa].fa;
if (a[pa].fa)
{
if (a[a[pa].fa].ch[0]==pa) a[a[pa].fa].ch[0]=now;
else a[a[pa].fa].ch[1]=now;
}
a[now].ch[mk]=pa;
a[pa].fa=now;
ct(pa);ct(now);
}
void splay(int rt,int now,int goal)
{
int pa;
while (a[now].fa!=goal)
{
pa=a[now].fa;
if (a[pa].fa==goal)
{
if (a[pa].ch[0]==now) rorate(now,1);
else rorate(now,0);
}
else if (a[a[pa].fa].ch[0]==pa)
{
if (a[pa].ch[0]==now) rorate(pa,1);
else rorate(now,0);
rorate(now,1);
}
else
{
if (a[pa].ch[1]==now) rorate(pa,0);
else rorate(now,1);
rorate(now,0);
}
}
if (!goal) roots[rt]=now;
}
void insert(int rt,int x,int id)
{
if (!roots[rt]) {made(x,id);roots[rt]=id;return;}
int now=roots[rt];
while (now)
{
if (a[now].data==x) {a[now].occ++;a[now].siz++;splay(rt,now,0);return;}
if (a[now].data>x)
{
if (!a[now].ch[0]) {made(x,id);a[now].ch[0]=id;a[id].fa=now;break;}
else now=a[now].ch[0];
}
else
{
if (!a[now].ch[1]) {made(x,id);a[now].ch[1]=id;a[id].fa=now;break;}
else now=a[now].ch[1];
}
}
splay(rt,id,0);
}
int find(int root,int x)
{
int now=root;
while (now)
{
if (a[now].data==x) return now;
if (a[now].data>x) now=a[now].ch[0];
else now=a[now].ch[1];
}
}
int findmax(int now)
{
while (a[now].ch[1]) now=a[now].ch[1];
return now;
}
int find_next_min(int rt,int x)
{
int now=roots[rt],t=0,ans=-0x7fffffff;
while (now)
{
if (a[now].data<x)
{
if (ans<a[now].data)ans=a[now].data,t=now;
now=a[now].ch[1];
}
else now=a[now].ch[0];
}
return ans;
}
int find_next_max(int rt,int x)
{
int now=roots[rt],t=0,ans=0x7fffffff;
while (now)
{
if (a[now].data>x)
{
if (ans>a[now].data) ans=a[now].data,t=now;
now=a[now].ch[0];
}
else now=a[now].ch[1];
}
return ans;
}
void replace(int rt,int x,int k)
{
int now=find(roots[rt],x);
splay(rt,now,0);
if (a[now].occ>1) {a[now].occ--;a[now].siz--;}
else if (a[now].siz==1) roots[rt]=0;
else if (!a[now].ch[0])
{
roots[rt]=a[now].ch[1];
a[a[now].ch[1]].fa=0;
}
else if (!a[now].ch[1])
{
roots[rt]=a[now].ch[0];
a[a[now].ch[0]].fa=0;
}
else
{
splay(rt,findmax(a[now].ch[0]),now);
a[a[now].ch[0]].ch[1]=a[now].ch[1];
a[a[now].ch[1]].fa=a[now].ch[0];
a[a[now].ch[0]].fa=0;
roots[rt]=a[now].ch[0];
ct(a[now].ch[0]);
}
if (!a[now].occ)insert(rt,k,now);
else insert(rt,k,++tot);
}
int find_rank(int rt,int x)//这里的findrank实际上是在splay里找比x小的数的数量
{
int now=roots[rt],ans=0;
while (now)
{
if (a[now].data>x) now=a[now].ch[0];
else if (a[now].data<x)
ans+=(a[now].occ+a[a[now].ch[0]].siz),
now=a[now].ch[1];
else {ans+=a[a[now].ch[0]].siz;break;}
}
return ans;
}
void build(int now,int begin,int end)
{
for (int i=begin;i<=end;i++) insert(now,num[i],++tot);
if (begin==end) return;
int mid=(begin+end)>>1;
build(now<<1,begin,mid);
build(now<<1|1,mid+1,end);
}
int solve1(int now,int begin,int end,int l,int r,int k)
{
if (l<=begin&&end<=r) return find_rank(now,k);
int mid=(begin+end)>>1,rank=0;
if (mid>=l) rank+=solve1(now<<1,begin,mid,l,r,k);
if (mid<r) rank+=solve1(now<<1|1,mid+1,end,l,r,k);
return rank;
}
int solve2(int l,int r,int k)
{
int begin=0,end=1e8+1,mid;
while (begin<end)
{
mid=(begin+end)>>1;
if (solve1(1,1,n,l,r,mid)<k)
begin=mid+1;
else end=mid;
}
return begin-1;
}
void solve3(int now,int begin,int end,int pos,int k)
{
replace(now,num[pos],k);
if (begin==end) {num[pos]=k;return;}
int mid=(begin+end)>>1;
if (mid>=pos) solve3(now<<1,begin,mid,pos,k);
else solve3(now<<1|1,mid+1,end,pos,k);
}
int solve4(int now,int begin,int end,int l,int r,int k)
{
if (l<=begin&&end<=r) return find_next_min(now,k);
int mid=(begin+end)>>1,ans=-0x7fffffff;
if (mid>=l) ans=max(ans,solve4(now<<1,begin,mid,l,r,k));
if (mid<r) ans=max(ans,solve4(now<<1|1,mid+1,end,l,r,k));
return ans;
}
int solve5(int now,int begin,int end,int l,int r,int k)
{
if (l<=begin&&end<=r) return find_next_max(now,k);
int mid=(begin+end)>>1,ans=0x7fffffff;
if (mid>=l) ans=min(ans,solve5(now<<1,begin,mid,l,r,k));
if (mid<r) ans=min(ans,solve5(now<<1|1,mid+1,end,l,r,k));
return ans;
}
main()
{
n=in();m=in();
int opt,x,y,k;
for (int i=1;i<=n;i++) num[i]=in();
build(1,1,n);
while (m--)
{
opt=in();
if (opt!=3)x=in(),y=in(),k=in();
else x=in(),y=in();
if (opt==1) printf("%d\n",solve1(1,1,n,x,y,k)+1);
else if (opt==2) printf("%d\n",solve2(x,y,k));
else if (opt==3) solve3(1,1,n,x,y);
else if (opt==4) printf("%d\n",solve4(1,1,n,x,y,k));
else printf("%d\n",solve5(1,1,n,x,y,k));
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: