您的位置:首页 > 其它

BZOJ 3224 普通平衡树 平衡树 ( Treap , SBT , Splay ,替罪羊树 ,非旋转 Treap )

2018-03-01 23:14 549 查看
题目描述:
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)

Input第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)
Output对于操作3,4,5,6每行输出一个数,表示对应答案
Sample Input101 1064654 11 3177211 4609291 6449851 841851 898516 819681 4927375 493598Sample Output10646584185492737Hint1.n的数据范围:n<=100000
2.每个数的数据范围:[-2e9,2e9]
这道题目,就是考察平衡树的基本操作:插入,删除,求排名,前驱后继 
Treap  是最好写的,复杂度也不错,不容易出错。#include <bits/stdc++.h>
using namespace std ;
int n , cnt , root , ans , opt , x ;
struct data {
int ch[2] , data , size , prior , times ; // ch 孩子 , data 数据,size 子树大小 , prior 优先级 , times 出现次数
void init( int _data ) {
data = _data ; size = times = 1 ; prior = rand() ; // 初始化函数
}
int cmp( int x ) const {
return x == data ? -1 : x < data ? 0 : 1 ; // 决定当前数据该往哪个子树走
}
} tr[100005] ;

void update( int u ) { // 更新 u 处的子树大小 ,次数 + 左子树 + 右子树
tr[u].size = tr[u].times + tr[tr[u].ch[0]].size + tr[tr[u].ch[1]].size ;
}

void Rotate( int &u , int d ) { // u 传引用 , d = 0 左转, d = 1 右转
int t = tr[u].ch[d^1] ;
tr[u].ch[d^1] = tr[t].ch[d] ;
tr[t].ch[d] = u ;
update( u ) , update( t ) ; // 更新 u , t 的大小 , u 先 t 后
u = t ;
}

void insert( int &u , int x ) {
if( !u ) {
u = ++cnt ; tr[u].init( x ) ; return ;
}
++tr[u].size ;
int d = tr[u].cmp( x ) ;
if( d < 0 ) ++tr[u].times ; // 如果有重复元素 , 更新 times 次数,不建立新节点
else {
insert( tr[u].ch[d] , x ) ;
if( tr[tr[u].ch[d]].prior < tr[u].prior ) // 小顶堆
Rotate( u , d^1 ) ; // 插左边就往右边转 , 插右边就往左边转
}
}

void erase( int &u , int x ) {
if( !u ) return ;
int d = tr[u].cmp( x ) ;
if( d < 0 ) {
if( tr[u].times > 1 ) { // 还有重复元素 , 删除重复元素就行了
--tr[u].times ; --tr[u].size ; return ;
}
if( tr[u].ch[0] * tr[u].ch[1] == 0 ) // 左右子树有一个是空的
u = tr[u].ch[0] + tr[u].ch[1] ;
else {
d = tr[tr[u].ch[0]].prior < tr[tr[u].ch[1]].prior ? 0 : 1 ;
Rotate( u , d^1 ) ;
erase( u , x ) ;
}
}
else --tr[u].size , erase( tr[u].ch[d] , x ) ; // 沿途更新 size
}

int Query_rank( int u , int x ) { // 查询 x 在数据中的排名
if( !u ) return 0 ;
int d = tr[u].cmp( x ) , l = tr[tr[u].ch[0]].size ; // 找到 x
if( d < 0 ) return l + 1 ; // 左子树 + 1 就是在当前树的排名
if( d ) return l + tr[u].times + Query_rank( tr[u].ch[1] , x ) ; // 当前根的左子树 + 根 + 右子树的排名,都比 x 小
return Query_rank( tr[u].ch[0] , x ) ;
}

int Query_kth( int u , int x ) { // 查询第 x 个数
if( !u ) return 0 ;
int l = tr[tr[u].ch[0]].size ; // l 是左子树有多少个比自己小的
if( x <= l )
return Query_kth( tr[u].ch[0] , x ) ;
if( x > l + tr[u].times )
return Query_kth( tr[u].ch[1] , x - l - tr[u].times ) ; // 减去前面的排名,在右子树继续找剩下的排名
return tr[u].data ;
}

void Get_pre( int u , int x ) { // 找前驱
if( !u ) return ;
if( tr[u].data < x ) ans = u , Get_pre( tr[u].ch[1] , x ) ;
else Get_pre( tr[u].ch[0] , x ) ;
}

void Get_nex( int u , int x ) { // 找后继
if( !u ) return ;
if( tr[u].data > x ) ans = u , Get_nex( tr[u].ch[0] , x ) ;
else Get_nex( tr[u].ch[1] , x ) ;
}

int main() {
scanf( "%d" , &n ) ;
while( n-- ) {
scanf( "%d %d" , &opt , &x ) ;
switch( opt ) {
case 1: insert( root , x ) ; break ;
case 2: erase( root , x ) ; break ;
case 3: printf( "%d\n" , Query_rank( root , x ) ) ; break ;
case 4: printf( "%d\n" , Query_kth( root , x ) ) ; break ;
case 5: ans = 0 ; Get_pre( root , x ) ; printf( "%d\n" , tr[ans].data ) ; break ;
case 6: ans = 0 ; Get_nex( root , x ) ; printf( "%d\n" , tr[ans].data ) ; break ;
}
}
return 0 ;
}SBT , size balanced tree , 翻译成 “大小平衡树” 吧 

 ,是二叉搜索树,高度平衡。基本操作复杂度 log(n) 
基本性质是,每棵子树的大小不小于其兄弟的子树大小。
除了自身建树,删除元素,维护性质,其他功能和二叉搜索树都一致。#include <bits/stdc++.h>
using namespace std ;
int n , cnt , root , ans , opt , x ;
struct data {
int ch[2] , data , size , times ;
void init( int _data ) {
data = _data ; size = times = 1 ;
}
int cmp( int x ) const {
return x == data ? -1 : x < data ? 0 : 1 ;
}
} tr[100005] ;

void update( int u ) {
tr[u].size = tr[u].times + tr[tr[u].ch[0]].size + tr[tr[u].ch[1]].size ;
}

void Rotate( int &u , int d ) {
int t = tr[u].ch[d^1] ;
tr[u].ch[d^1] = tr[t].ch[d] ;
tr[t].ch[d] = u ;
tr[t].size = tr[u].size ;
update( u ) ;
u = t ;
}

void maintain( int &u , int d ) {
if( !u ) return ;
update( u ) ; // 以插入左边为例 , d = 0
int l = tr[tr[tr[u].ch[d]].ch[d]].size ; // 左子树的左子树大小
int r = tr[tr[tr[u].ch[d]].ch[d^1]].size ; // 左子树的右子树大小
int cur = tr[tr[u].ch[d^1]].size ; // 当前右子树的大小,因为插入左边,就可能左边的比右边的大
if( cur < l ) Rotate( u , d^1 ) ; // u 的左子树的左子树 比 u 的右子树大
else if( cur < r ) // u 的左子树的右子树 比 u 的右子树大
Rotate( tr[u].ch[d] , d ) , Rotate( u , d^1 ) ;
else return ;
maintain( tr[u].ch[0] , 0 ) ; // 更新新的根的左子树,
maintain( tr[u].ch[1] , 1 ) ; // 更新新的根的右子树
maintain( u , 0 ) ;
maintain( u , 1 ) ; // 更新新的根,左右两边都试试
}

void insert( int &u , int x ) {
if( !u ) {
u = ++cnt ; tr[u].init( x ) ; return ;
}
++tr[u].size ;
int d = tr[u].cmp( x ) ;
if( d < 0 ) ++tr[u].times ;
else insert( tr[u].ch[d] , x ) , maintain( u , d ) ;
}

void erase( int &u , int x ) {
if( !u ) return ;
int d = tr[u].cmp( x ) ;
if( d < 0 ) {
if( tr[u].times > 1 ) {
--tr[u].times ; --tr[u].size ; return ;
}
if( tr[u].ch[0] * tr[u].ch[1] == 0 )
u = tr[u].ch[0] + tr[u].ch[1] ;
else {
d = tr[tr[u].ch[0]].size > tr[tr[u].ch[1]].size ; // 哪边 size 更大,删哪边
Rotate( u , d ) ; // 先转过来
erase( tr[u].ch[d] , x ) ;
maintain( u , d^1 ) ; // 在一边删除,相当于在另一边插入
}
}
else --tr[u].size , erase( tr[u].ch[d] , x ) , maintain( u , d^1 ) ; // 回溯的 maintain 保持性质
}

int Query_rank( int u , int x ) {
if( !u ) return 0 ;
int d = tr[u].cmp( x ) , l = tr[tr[u].ch[0]].size ;
if( d < 0 ) return l + 1 ;
if( d ) return l + tr[u].times + Query_rank( tr[u].ch[1] , x ) ;
return Query_rank( tr[u].ch[0] , x ) ;
}

int Query_kth( int u , int x ) {
if( !u ) return 0 ;
int l = tr[tr[u].ch[0]].size ;
if( x <= l )
return Query_kth( tr[u].ch[0] , x ) ;
if( x > l + tr[u].times )
return Query_kth( tr[u].ch[1] , x - l - tr[u].times ) ;
return tr[u].data ;
}

void Get_pre( int u , int x ) {
if( !u ) return ;
if( tr[u].data < x ) ans = u , Get_pre( tr[u].ch[1] , x ) ;
else Get_pre( tr[u].ch[0] , x ) ;
}

void Get_nex( int u , int x ) {
if( !u ) return ;
if( tr[u].data > x ) ans = u , Get_nex( tr[u].ch[0] , x ) ;
else Get_nex( tr[u].ch[1] , x ) ;
}

int main() {
scanf( "%d" , &n ) ;
while( n-- ) {
scanf( "%d %d" , &opt , &x ) ;
switch( opt ) {
case 1: insert( root , x ) ; break ;
case 2: erase( root , x ) ; break ;
case 3: printf( "%d\n" , Query_rank( root , x ) ) ; break ;
case 4: printf( "%d\n" , Query_kth( root , x ) ) ; break ;
case 5: ans = 0 ; Get_pre( root , x ) ; printf( "%d\n" , tr[ans].data ) ; break ;
case 6: ans = 0 ; Get_nex( root , x ) ; printf( "%d\n" , tr[ans].data ) ; break ;
}
}
return 0 ;
}

Splay 把插入的,找到的节点旋转到根。也是平衡树的一种,更难写,不好找错。#include <bits/stdc++.h>
using namespace std ;
int root , cnt , n , ans ;
struct Node {
int data , times , size , ch[2] , f ;
void init( int _data , int fa ) {
data = _data ; f = fa ; times = size = 1 ;
}
int cmp( int x ) const {
return x == data ? -1 : x < data ? 0 : 1 ;
}
} tr[100005] ;

int rson( int f , int u ) { // 判断 u 是不是 f 的右孩子
return tr[f].ch[1] == u ;
}

void update( int r ) {
tr[r].size = tr[r].times + tr[tr[r].ch[0]].size + tr[tr[r].ch[1]].size ;
}

void Rotate( int u ) { // 带父节点的旋转
int fa = tr[u].f , grand = tr[fa].f ;
int d = rson( fa , u ) ;
tr[fa].ch[d] = tr[u].ch[d^1] ;
if( tr[u].ch[d^1] )
tr[tr[u].ch[d^1]].f = fa ;
tr[u].ch[d^1] = fa ;
tr[fa].f = u ;
tr[u].f = grand ;
if( grand )
tr[grand].ch[rson( grand , fa )] = u ;
update( fa ) ;
update( u ) ;
}

void Splay( int u , int tar ) { // 把 u 旋转到 tar 的某个孩子 , tar 为 NULL 意味着旋转到根
while( tr[u].f != tar ) {
int fa = tr[u].f ;
int grand = tr[fa].f ;
if( grand == tar )
Rotate( u ) ;
else
if( rson( grand , fa ) ^ rson( fa , u ) ) // 一致的旋转
Rotate( u ) , Rotate( u ) ;
else
Rotate( fa ) , Rotate( u ) ;
}
if( !tar ) root = u ;
// update( u ) ;
}

void insert( int u , int x , int pre ) {
if( !u ) {
u = ++cnt ; tr[u].init( x , pre ) ;
if( pre ) tr[pre].ch[x > tr[pre].data] = u ;
Splay( u , 0 ) ; // 转到根
return ;
}
++tr[u].size ;
int d = tr[u].cmp( x ) ;
if( d < 0 ) ++tr[u].times , Splay( u , 0 ) ; // 转到根
else insert( tr[u].ch[d] , x , u ) ;
}

void Find( int u , int x ) {
if( !u ) return ;
while( tr[u].ch[x > tr[u].data] && tr[u].data != x )
u = tr[u].ch[x > tr[u].data] ;
Splay( u , 0 ) ; // 找得到就转到根,找不到就旋转第一个大于 x 的数为根
}

int Query_rank( int x ) { // 先转到根
Find( root , x ) ;
return tr[tr[root].ch[0]].size + 1 ; // 左子树大小 + 1
}

int Query_kth( int u , int x ) {
if( !u ) return 0 ;
int l = tr[tr[u].ch[0]].size ;
if( x <= l )
return Query_kth( tr[u].ch[0] , x ) ;
if( x > l + tr[u].times )
return Query_kth( tr[u].ch[1] , x - l - tr[u].times ) ;
return tr[u].data ;
}

void Get_pre( int u , int x ) {
if( !u ) return ;
if( tr[u].data < x ) ans = u , Get_pre( tr[u].ch[1] , x ) ;
else Get_pre( tr[u].ch[0] , x ) ;
}

void Get_nex( int u , int x ) {
if( !u ) return ;
if( tr[u].data > x ) ans = u , Get_nex( tr[u].ch[0] , x ) ;
else Get_nex( tr[u].ch[1] , x ) ;
}

void erase( int R , int x ) {
if( !R ) return ;
int d = tr[R].cmp( x ) ;
if( d != -1 )
--tr[R].size , erase( tr[R].ch[d] , x ) ; // 沿途更新 size
else {
Splay( R , 0 ) ; // R 是要删除的,先转到根
if( tr[R].times > 1 ) {
--tr[R].times , --tr[R].size ; return ;
}
if( tr[R].ch[0] == 0 ) { // R 只有右孩子
root = tr[R].ch[1] ;
if( root ) tr[root].f = 0 ; // R 的右孩子为根
}
else {
int p = tr[R].ch[0] ;
while( tr[p].ch[1] )
p = tr[p].ch[1] ; // 找 R 的前驱
Splay( p , R ) ; // 把 p 转到 R 的孩子上
root = p ; // p 为根,删除 R
tr[root].f = 0 ;
tr[p].ch[1] = tr[R].ch[1] ; // p 之前没有右孩子
if( tr[p].ch[1] )
tr[tr[p].ch[1]].f = p ;
}
// update( root ) ;
}
}

int main() {
scanf( "%d" , &n ) ;
int opt , x ;
while( n-- ) {
scanf( "%d%d" , &opt , &x ) ;
switch( opt ) {
case 1: insert( root , x , 0 ) ; break ;
case 2: erase( root , x ) ; break ;
case 3: printf( "%d\n" , Query_rank( x ) ) ; break ;
case 4: printf( "%d\n" , Query_kth( root , x ) ) ; break ;
case 5: ans = 0 , Get_pre( root , x ) , printf( "%d\n" , tr[ans].data ) ; break ;
case 6: ans = 0 , Get_nex( root , x ) , printf( "%d\n" , tr[ans].data ) ; break ;
}
}
return 0 ;
}以上的 Treap , SBT , Splay 复杂度都差不多。
当然,网上关于这道题目,有更好的解法——替罪羊树,好像 200 ms 左右就可以。我的 Treap . Splay , SBT 都是三四百ms 的样子。
不过我写的替罪羊树,没那么快,360ms , emm 。不过学习了“暴力”的优雅。#include <bits/stdc++.h>
using namespace std ;
const double alpla = 0.75 ; // 比例 , 子树超过这个比例 , 就重建
int n , cnt , root , ans , opt , x ;
struct data {
in
d4e1
t ch[2] , data , size , times ; // times 出现次数 , = 0 说明没有或者全部被删除了。
void init( int _data ) {
data = _data ; size = times = 1 ;
}
int cmp( int x ) const {
return x == data ? -1 : x < data ? 0 : 1 ;
}
} tr[100005] ;
vector<int> One ; // 储存重建的中序遍历的序列,从小到大

void update( int u ) {
int l = tr[u].ch[0] , r = tr[u].ch[1] ;
tr[u].size = tr[u].times + tr[l].size + tr[r].size ;
}

int is_bad( int u ) {
int l = tr[u].size * alpla < tr[tr[u].ch[0]].size ;
int r = tr[u].size * alpla < tr[tr[u].ch[1]].size ;
return l || r ; // 左右有一边偏大
}

void DFS( int u ) {
if( !u ) return ;
DFS( tr[u].ch[0] ) ;
if( tr[u].times ) One.push_back( u ) ; // 如果还存在,没被全部删掉
DFS( tr[u].ch[1] ) ;
}

int Divide( int l , int r ) {
if( l >= r ) return 0 ;
int mid = ( l + r ) >> 1 ;
int u = One[mid] ; // 以中间为根,这样就可以达到高度平衡。
tr[u].ch[0] = Divide( l , mid ) ;
tr[u].ch[1] = Divide( mid+1 , r ) ;
update( u ) ; // 建好左右两边,更新根的大小
return u ;
}

void Re_build( int &u ) {
One.clear() ;
DFS( u ) ; // 获得从小到大的有效序列
u = Divide( 0 , (int)One.size() ) ; // 根据 One 序列从 Mid 递归建立
}

void insert( int &u , int x ) {
if( !u ) {
u = ++cnt ; tr[u].init( x ) ; return ;
}
++tr[u].size ;
int d = tr[u].cmp( x ) ;
if( d < 0 ) ++tr[u].times ; // 早就出现过了
else {
insert( tr[u].ch[d] , x ) ;
if( is_bad( u ) ) // 检验是否严重失衡
Re_build( u ) ;
}
}

void erase( int &u , int x ) {
if( !u ) return ;
int d = tr[u].cmp( x ) ;
if( d < 0 ) {
if( tr[u].times > 1 ) {
--tr[u].times ; --tr[u].size ; return ;
}
if( tr[u].times == 0 ) return ; // 目标全部被删除了
tr[u].times = 0 ; // 把目标标记,正要删除最后一次
Re_build( u ) ; // 重建以 u 为根 的树( 本来不想写的,不写就 WA , 至今没找到错误, 还是写上吧)
}
else --tr[u].size , erase( tr[u].ch[d] , x ) ;
}

int Query_rank( int u , int x ) {
if( !u ) return 0 ;
int d = tr[u].cmp( x ) , l = tr[tr[u].ch[0]].size ;
if( d < 0 ) return l + 1 ;
if( d ) return l + tr[u].times + Query_rank( tr[u].ch[1] , x ) ;
return Query_rank( tr[u].ch[0] , x ) ;
}

int Query_kth( int u , int x ) {
if( !u ) return 0 ;
int l = tr[tr[u].ch[0]].size ;
if( x <= l )
return Query_kth( tr[u].ch[0] , x ) ;
if( x > l + tr[u].times )
return Query_kth( tr[u].ch[1] , x - l - tr[u].times ) ;
return tr[u].data ;
}

void Get_pre( int u , int x ) {
if( !u ) return ;
if( tr[u].data < x ) ans = u , Get_pre( tr[u].ch[1] , x ) ;
else Get_pre( tr[u].ch[0] , x ) ;
}

void Get_nex( int u , int x ) {
if( !u ) return ;
if( tr[u].data > x ) ans = u , Get_nex( tr[u].ch[0] , x ) ;
else Get_nex( tr[u].ch[1] , x ) ;
}

int main() {
scanf( "%d" , &n ) ;
while( n-- ) {
scanf( "%d %d" , &opt , &x ) ;
switch( opt ) {
case 1: insert( root , x ) ; break ;
case 2: erase( root , x ) ; break ;
case 3: printf( "%d\n" , Query_rank( root , x ) ) ; break ;
case 4: printf( "%d\n" , Query_kth( root , x ) ) ; break ;
case 5: ans = 0 ; Get_pre( root , x ) ; printf( "%d\n" , tr[ans].data ) ; break ;
case 6: ans = 0 ; Get_nex( root , x ) ; printf( "%d\n" , tr[ans].data ) ; break ;
}
}
return 0 ;
}
听说还有非旋转的 treap , 学了一下,感觉和左偏树好像,这里是把小于和大于两部分拆开来,然后在中间插入,删除,或者在左边找前驱,在右边找后继。插入删除和一般的平衡树很不一样...... 代码很短,很好理解。就是 split 拆分的操作,画棵树比划比划就明白了。
参考这篇博客  非旋转 Treap#include <bits/stdc++.h>
using namespace std ;
int n , opt , x ;

class Split_Treap {
private:
int root , cnt ;
struct Node {
int data , size , ch[2] ;
void init( int _data ) {
data = _data ; size = 1 ; ch[0] = ch[1] = 0 ;
}
} tr[100005] ;
public:
Split_Treap() {
root = cnt = 0 ;
memset( tr , 0 , sizeof( tr ) ) ;
}
~Split_Treap() {}

void update( int u ) {
tr[u].size = 1 + tr[tr[u].ch[0]].size + tr[tr[u].ch[1]].size ;
}

void split( int u , int x , int &l , int &r ) {
if( !u ) { l = r = 0 ; return ; }
if( tr[u].data <= x )
l = u , split( tr[l].ch[1] , x , tr[l].ch[1] , r ) ; // 找出 <= x 的部分, 后面找到的肯定插在右子树上
else
r = u , split( tr[r].ch[0] , x , l , tr[r].ch[0] ) ; // 找出 > x 的部分
update( u ) ;
}

int merge( int l , int r ) {
if( l * r == 0 ) return l + r ;
if( rand() % 2 ) {
tr[l].ch[1] = merge( tr[l].ch[1] , r ) ; // 一半的概率会归并到另一边
update( l ) ; return l ;
} else {
tr[r].ch[0] = merge( l , tr[r].ch[0] ) ;
update( r ) ; return r ;
}
}

void insert( int x ) {
int l = 0 , r = 0 ;
split( root , x , l , r ) ; // 先把原有的二叉树分出 <= x 的部分, 新的 x 合并 l , r
tr[++cnt].init( x ) ;
root = merge( merge( l , cnt ) , r ) ;
}

void erase( int x ) {
if( !root ) return ;
int l = 0 , r = 0 , mid = 0 ;
split( root , x-1 , l , r ) ; // 分离出 <= x-1 的部分
split( r , x , mid , r ) ; // 在 > x-1 的部分中, 分离出 <= x 的部分 mid, 和 > x 的部分 r
root = merge( l , merge( merge( tr[mid].ch[0] , tr[mid].ch[1] ) , r ) ) ;
}
// 把 mid 的子树合并, 然后再合并 l , r
void rank( int x ) {
int l = 0 , r = 0 ;
split( root , x-1 , l , r ) ; // 分离出 <= x-1 的部分
printf( "%d\n" , tr[l].size + 1 ) ; // <= x-1 的部分的数 + 1 就是 x 的排名
root = merge( l , r ) ;
}

void Kth( int x ) {
int p = root ;
while( p ) {
int l = tr[tr[p].ch[0]].size ;
if( l + 1 == x )
break ;
if( l >= x ) p = tr[p].ch[0] ;
else x -= l + 1 , p = tr[p].ch[1] ;
}
printf( !p ? "0\n" : "%d\n" , tr[p].data ) ;
}

int Get( int u , int d ) {
if( !u ) return 0 ;
while( tr[u].ch[d] ) u = tr[u].ch[d] ;
return tr[u].data ;
}

void pre( int x ) {
int l , r ;
split( root , x-1 , l , r ) ; // 先分离出 <= x-1 的部分, 在这里找 x 的前驱
printf( "%d\n" , Get( l , 1 ) ) ; // 不用 split 也可以, 就按照二叉树找前驱的方法也行
root = merge( l , r ) ;
}

void nex( int x ) {
int l , r ;
split( root , x , l , r ) ; // 先分离出 > x 的部分, 在这里找后继
printf( "%d\n" , Get( r , 0 ) ) ;
root = merge( l , r ) ;
}
} ;

int main() {
Split_Treap One ;
scanf( "%d" , &n ) ;
while( n-- ) {
scanf( "%d %d" , &opt , &x ) ;
switch( opt ) {
case 1: One.insert( x ) ; break ;
case 2: One.erase( x ) ; break ;
case 3: One.rank( x ) ; break ;
case 4: One.Kth( x ) ; break ;
case 5: One.pre( x ) ; break ;
case 6: One.nex( x ) ; break ;
}
}
return 0 ;
}
我“脑洞小开” , 感觉 Treap 直接写 rand 函数可能也行,几分之一( 我选的 1 / 4 )的概率会往插入的反方向旋转,调整一下,  其实大概 1 / 3 的概率会旋转,因为每次插入一个元素,都会比较孩子和自己的优先级,大小情况也就三种情况,孩子优先级高;自己优先级高;优先级一样——1 / 3 。删除元素,也是用 rand 函数来控制,偶数就往左转, 奇数就往右转,其实 Treap 删除元素的旋转,是看左右子树的优先级的,优先级是随机的,肯定会转一次,所以,1 / 2 的概率 。其实和 Treap 是差不多的,都是随机。测试的结果和 Treap ,SBT 差不多。
至于复杂度正确性,有待考证,哈哈,不过,倒是很简单。#include <bits/stdc++.h>
using namespace std ;
int n , cnt , root , ans , opt , x ;
struct data {
int ch[2] , data , size , times ;
void init( int _data ) {
data = _data ; size = times = 1 ;
}
int cmp( int x ) const {
return x == data ? -1 : x < data ? 0 : 1 ;
}
} tr[100005] ;

void update( int u ) {
tr[u].size = tr[u].times + tr[tr[u].ch[0]].size + tr[tr[u].ch[1]].size ;
}

void Rotate( int &u , int d ) {
int t = tr[u].ch[d^1] ;
tr[u].ch[d^1] = tr[t].ch[d] ;
tr[t].ch[d] = u ;
update( u ) ;
update( t ) ;
u = t ;
}

void insert( int &u , int x ) {
if( !u ) {
u = ++cnt ; tr[u].init( x ) ; return ;
}
++tr[u].size ;
int d = tr[u].cmp( x ) ;
if( d < 0 ) ++tr[u].times ;
else {
insert( tr[u].ch[d] , x ) ;
if( rand() % 4 == 3 ) // 1 / 4 的概率会旋转
Rotate( u , d^1 ) ;
}
}

void erase( int &u , int x ) {
if( !u ) return ;
int d = tr[u].cmp( x ) ;
if( d < 0 ) {
if( tr[u].times > 1 ) {
--tr[u].times ; --tr[u].size ; return ;
}
if( tr[u].ch[0] * tr[u].ch[1] == 0 )
u = tr[u].ch[0] + tr[u].ch[1] ;
else {
d = rand() % 2 ; // 偶数右转,奇数左转
Rotate( u , d^1 ) ;
erase( u , x ) ;
}
}
else --tr[u].size , erase( tr[u].ch[d] , x ) ;
}

int Query_rank( int u , int x ) {
if( !u ) return 0 ;
int d = tr[u].cmp( x ) , l = tr[tr[u].ch[0]].size ;
if( d < 0 ) return l + 1 ;
if( d ) return l + tr[u].times + Query_rank( tr[u].ch[1] , x ) ;
return Query_rank( tr[u].ch[0] , x ) ;
}

int Query_kth( int u , int x ) {
if( !u ) return 0 ;
int l = tr[tr[u].ch[0]].size ;
if( x <= l )
return Query_kth( tr[u].ch[0] , x ) ;
if( x > l + tr[u].times )
return Query_kth( tr[u].ch[1] , x - l - tr[u].times ) ;
return tr[u].data ;
}

void Get_pre( int u , int x ) {
if( !u ) return ;
if( tr[u].data < x ) ans = u , Get_pre( tr[u].ch[1] , x ) ;
else Get_pre( tr[u].ch[0] , x ) ;
}

void Get_nex( int u , int x ) {
if( !u ) return ;
if( tr[u].data > x ) ans = u , Get_nex( tr[u].ch[0] , x ) ;
else Get_nex( tr[u].ch[1] , x ) ;
}

int main() {
// freopen( "BZOJ 3224.txt" , "r" , stdin ) ;
scanf( "%d" , &n ) ;
while( n-- ) {
scanf( "%d %d" , &opt , &x ) ;
switch( opt ) {
case 1: insert( root , x ) ; break ;
case 2: erase( root , x ) ; break ;
case 3: printf( "%d\n" , Query_rank( root , x ) ) ; break ;
case 4: printf( "%d\n" , Query_kth( root , x ) ) ; break ;
case 5: ans = 0 ; Get_pre( root , x ) ; printf( "%d\n" , tr[ans].data ) ; break ;
case 6: ans = 0 ; Get_nex( root , x ) ; printf( "%d\n" , tr[ans].data ) ; break ;
}
}
return 0 ;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: