您的位置:首页 > 其它

luogu #3391 【模板】文艺平衡树(splay)

2017-11-30 20:06 267 查看
题目背景

这是一道经典的Splay模板题——文艺平衡树。

题目描述

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:翻转一个区间,例如原有序序列是5 4 3 2 1,翻转区间是[2,4]的话,结果是5 2 3 4 1

输入输出格式

输入格式:

第一行为n,m n表示初始序列有n个数,这个序列依次是(1,2,⋯n−1,n),m表示翻转操作次数

接下来m行每行两个数[l,r]数据保证 1≤l≤r≤n

输出格式:

输出一行n个数字,表示原始序列经过m次变换后的结果

模板题。

这道题与以往splay题的不同之处在于,这道题的BST(二叉搜索树)维护的是点的编号,也就是说,用BST来维护点的编号,splay就能实现线段树的所有功能。

(前两道题都可以用map+lower_bound水过去)

这里讲一下如何实现区间翻转操作。

首先我们需要一个find_kth(找第k大节点的树上的点的编号)函数,在rotate的时候维护一个size就可以了,比较好写,脑补一下就好。

我们翻转区间l,r的时候,用find_kth找到l,r对应的节点,然后splay(l,0),splay(r,l),将l节点的右儿子的左儿子的翻转标记异或上1即可。



(很明显,编号小于r而大于l的节点都在x节点及其子树中)

大概就是这样了。

#include<cstdio>
#include<algorithm>
#include<cstring>
#define maxn 100050
using namespace std;
struct node
{
int f,w,s[2],sz,lz;
void init (int x,int fa) {f=s[0]=s[1]=lz=0,sz=1,w=x,f=fa;}
}a[maxn*2];
int root,dfn,n,m,l,r;
inline void pushup(int x) {a[x].sz=a[a[x].s[0]].sz+a[a[x].s[1]].sz+1;}
inline void pushdown(int x) {if (a[x].lz) a[a[x].s[0]].lz^=1,a[a[x].s[1]].lz^=1,swap(a[x].s[0],a[x].s[1]),a[x].lz=0;}
inline void rotate(int x,int k)
{
int y=a[x].f,z=a[y].f;
a[y].s[!k]=a[x].s[k];
if (a[y].s[!k]) a[a[y].s[!k]].f=y;
a[x].f=z;
if (z) a[z].s[y==a[z].s[1]]=x;
a[y].f=x,a[x].s[k]=y;
pushup(y),pushup(x);
}
inline void splay(int x,int g)
{
while (a[x].f!=g)
{
int y=a[x].f,z=a[y].f;
if (z==g) {rotate(x,a[y].s[0]==x);continue;}
if (y==a[z].s[0])
{
if (x==a[y].s[0]) rotate(y,1),rotate(x,1);
else rotate(x,0),rotate(x,1);
}
else
{
if (x==a[y].s[1]) rotate(y,0),rotate(x,0);
else rotate(x,1),rotate(x,0);
}
}
if (!g) root=x;
}
inline void insert(int x)
{
int u=root,fa=0;
while (u) fa=u,u=a[u].s[a[u].w<x];
u=++dfn;
if (fa) a[fa].s[a[u].w<x]=u;
a[u].init(x,fa);
splay(u,0);
}
inline int kth(int x)
{
int u=root;
while (1)
{
pushdown(u);
if (a[a[u].s[0]].sz>=x) u=a[u].s[0];
else if (a[a[u].s[0]].sz+1==x) return u;
else x-=a[a[u].s[0]].sz+1,u=a[u].s[1];
}
}
inline void rever(int l,int r)
{
l=kth(l),r=kth(r+2);
splay(l,0),splay(r,l);
a[a[a[root].s[1]].s[0]].lz^=1;
}
inline void print(int x)
{
pushdown(x);
if (a[x].s[0]) print(a[x].s[0]);
if(a[x].w!=1&&a[x].w!=n+2) printf("%d ",a[x].w-1);
if (a[x].s[1]) print(a[x].s[1]);
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n+2;i++) insert(i);
while (m--) scanf("%d%d",&l,&r),rever(l,r);
print(root);
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: