您的位置:首页 > 其它

Splay伸展树模板总结

2015-09-04 16:20 483 查看
1.基本点操作

//Splay 基本操作 均摊复杂度O(lgN)

//POJ 1442
//基本点操作

// sp.init()	初始化

#include <iostream>
#include <algorithm>
#include <cmath>

using namespace std;

// Capacity of the fixed node pool (Splay::nod is declared with this size).
#define maxn 31000
// Function-like minimum. NOTE: each argument may be evaluated twice, so do not
// pass expressions with side effects.
#define min(a,b) ((a)<(b)?(a):(b))
// Large int constant (0x3f3f3f3f); presumably an "infinity" sentinel — it is
// not referenced by the code visible in this file.
const int oo = 0x3f3f3f3f;
// One tree node. `nil` is the shared sentinel used in place of null links and
// `NIL` is a pointer to it.
struct Node{
    int key;        // stored value
    int sz;         // as maintained by rs(): cnt plus the sz of both children
    int cnt;        // counter that rs() adds to the children's sizes
    Node *ch[2];    // ch[0] = left child, ch[1] = right child
    Node *pnt;      // parent link

    // Leaves every field untouched (no initialisation is performed).
    Node(){}

    // Sets key/sz/cnt only; the child and parent links are not assigned here.
    Node(int x, int y, int z) : key(x), sz(y), cnt(z) {}

    // Recomputes sz from both children and this node's own cnt.
    // Both child pointers must be dereferenceable.
    void rs(){
        sz = cnt + ch[0]->sz + ch[1]->sz;
    }
}nil(0, 0, 0), *NIL = &nil;
struct Splay{//伸展树结构体类型
Node *root;
int ncnt;//计算key值不同的结点数,注意已经去重了
Node nod[maxn];
void init(){// 首先要初始化
root = NIL;
ncnt = 0;
}
void rotate(Node *x, bool d){//旋转操作,d为true表示右旋
Node *y = x->pnt;
y->ch[!d] = x->ch[d];
if (x->ch[d] != NIL)
x->ch[d]->pnt = y;
x->pnt = y->pnt;
if (y->pnt != NIL){
if (y == y->pnt->ch[d])
y->pnt->ch[d] = x;
else
y->pnt->ch[!d] = x;
}
x->ch[d] = y;
y->pnt = x;
y->rs();
x->rs();
}
void splay(Node *x, Node *target){//将x伸展到target的儿子位置处
Node *y;
while (x->pnt != target){
y = x->pnt;
if (x == y->ch[0]){
if (y->pnt != target && y == y->pnt->ch[0])
rotate(y, true);
rotate(x, true);
}
else{
if (y->pnt != target && y == y->pnt->ch[1])
rotate(y, false);
rotate(x, false);
}
}
if (target == NIL)
root = x;
}
/************************以上一般不用修改************************/
void insert(int key){//插入一个值
if (root == NIL){
ncnt = 0;
root = &nod[++ncnt];
root->ch[0] = root->ch[1] = root->pnt = NIL;
root->key = key;
root->sz = root->cnt = 1;
return;
}
Node *x = root, *y;
while (1){
x->sz++;
if (key == x->key){
x->cnt++;
x->rs();
y = x;
break;
}
else if (key < x->key){
if (x->ch[0] != NIL)
x = x->ch[0];
else{
x->ch[0] = &nod[++ncnt];
y = x->ch[0];
y->key = key;
y->sz = y->cnt = 1;
y->ch[0] = y->ch[1] = NIL;
y->pnt = x;
break;
}
}
else{
if (x->ch[1] != NIL)
x = x->ch[1];
else{
x->ch[1] = &nod[++ncnt];
y = x->ch[1];
y->key = key;
y->sz = y->cnt = 1;
y->ch[0] = y->ch[1] = NIL;
y->pnt = x;
break;
}
}
}
splay(y, NIL);
}
Node* search(int key){//查找一个值,返回指针
if (root == NIL)
return NIL;
Node *x = root, *y = NIL;
while (1){
if (key == x->key){
y = x;
break;
}
else if (key > x->key){
if (x->ch[1] != NIL)
x = x->ch[1];
else
break;
}
else{
if (x->ch[0] != NIL)
x = x->ch[0];
else
break;
}
}
splay(x, NIL);
return y;
}
Node* searchmin(Node *x){//查找最小值,返回指针
Node *y = x->pnt;
while (x->ch[0] != NIL){//遍历到最左的儿子就是最小值
x = x->ch[0];
}
splay(x, y);
return x;
}
Node* searchmax(Node *x){
Node *y = x->pnt;
while(x->ch[1] != NIL){
x = x->ch[1];
}
splay(x,y);
return x;
}
void del(int key){//删除一个值
if (root == NIL)
return;
Node *x = search(key), *y;
if (x == NIL)
return;
if (x->cnt > 1){
x->cnt--;
x->rs();
return;
}
else if (x->ch[0] == NIL && x->ch[1] == NIL){
init();
return;
}
else if (x->ch[0] == NIL){
root = x->ch[1];
x->ch[1]->pnt = NIL;
return;
}
else if (x->ch[1] == NIL){
root = x->ch[0];
x->ch[0]->pnt = NIL;
return;
}
y = searchmin(x->ch[1]);
y->pnt = NIL;
y->ch[0] = x->ch[0];
x->ch[0]->pnt = y;
y->rs();
root = y;
}
int rank(int key){//求结点高度
Node *x = search(key);
if (x == NIL)
return 0;
return x->ch[0]->sz + 1/* or x->cnt*/;
}
Node* findk(int kth){//查找第k小的值
if (root == NIL || kth > root->sz)
return NIL;
Node *x = root;
while (1){
if (x->ch[0]->sz +1 <= kth && kth <= x->ch[0]->sz + x->cnt)
break;
else if (kth <= x->ch[0]->sz)
x = x->ch[0];
else{
kth -= x->ch[0]->sz + x->cnt;
x = x->ch[1];
}
}
splay(x, NIL);
return x;
}

Node* proc(int key)
{
Node *x = search(key);
if(x==NIL || x->cnt == 1 && x->ch[0]==NIL)
{
return NIL;
}
if(x->cnt > 1)
{
return x;
}
else
{
return searchmax(x->ch[0]);
}
}

Node* succ(int key)
{
Node *x = search(key);
if(x==NIL || x->cnt == 1 && x->ch[1]==NIL)
{
return NIL;
}
if(x->cnt > 1)
{
return x;
}
else
{
return searchmin(x->ch[1]);
}
}

int size()
{
return root->sz;
}
}sp;
int num[30010];  // values to insert, 0-based
int cmd[30010];  // cmd[j] (1-based): tree size at which the j-th query fires

// Driver for POJ 1442. For every test case: reads N values and M query
// thresholds, inserts the values in order, and each time the tree size equals
// the next pending threshold cmd[j], prints the j-th smallest element.
int main(){
    ios::sync_with_stdio(false);
    int N, M;
    while (cin >> N >> M)
    {
        sp.init();
        for (int i = 0; i < N; i++)
        {
            cin >> num[i];
        }
        for (int i = 1; i <= M; i++)
        {
            cin >> cmd[i];
        }
        int j = 1;
        for (int i = 0; i < N; i++)
        {
            sp.insert(num[i]);
            // j <= M keeps the loop from reading cmd[M + 1], which may still
            // hold a threshold left over from an earlier, longer test case.
            while (j <= M && sp.size() == cmd[j])
            {
                Node *node = sp.findk(j++);
                cout << node->key << '\n';
            }
        }
    }

    return 0;
}


2.基本区间操作

/*
http://acm.hdu.edu.cn/showproblem.php?pid=3487
1.区间切割 区间翻转
原序列:	1、2、3 ……n
rangeCut(a,b,c)	将[ath,bth]剪切掉 然后黏贴到新序列的第cth后面
rangeFlip(a,b)	翻转[ath,bth]

http://acm.hdu.edu.cn/showproblem.php?pid=1754
2.单点更新 区间询问最值
query(a,b)	询问[a,b]最大值
update(a,c)	更新学生a的成绩为c

http://poj.org/problem?id=3468
3.区间更新 区间求和
querySum(l,r)	询问[l,r]的和
updateInterval(l,r,c)	区间[l,r]的值+c

*/
#include<iostream>
#include<cstdio>
#include<cstring>
// Shorthand for the 64-bit integer type used for sums and lazy add tags.
#define LL long long
using namespace std;
// Capacity of every per-node array in SplayTree (index 0 is the null node).
const int MAXN = 333333;
// memset wrapper: fills `size` elements of type `type` starting at `ptr` with byte value `v`.
#define m_set(ptr,v,type,size) memset(ptr,v,sizeof(type) * size)
// Expands to a for-loop header over [begin, end) that declares its own `int i`.
#define loop(begin,end) for(int i=begin;i<end;i++)
// Prints the line "here!" to stdout.
#define debug	puts("here!")
class SplayTree
{
#define l(x) (ch[x][0])
#define r(x) (ch[x][1])
#define mid(x,y)	((x+y)>>1)
public:
int ch[MAXN][2],pre[MAXN];
int sz[MAXN],val[MAXN],rev[MAXN],a[MAXN];
int root,tot;

/**************/

int mx[MAXN];

/**************/

LL sum[MAXN],add[MAXN];

/**************/

void init()
{
m_set(ch,0,int,MAXN*2);
m_set(pre,0,int,MAXN);
m_set(sz,0,int,MAXN);
m_set(val,0,int,MAXN);
m_set(rev,0,int,MAXN);

/*************/
m_set(mx,-1,int,MAXN);
/*************/
root = tot = 0;
}

void read(int n)		//0 1 2 3 .... n 0
{
a[1] = a[n+2] = -1;
//			loop(2,n+2){
//				a[i] = i - 1;
//			}

/*************/
for(int i=2;i<=n+1;i++)
{
scanf("%d",&a[i]);
}
/*************/
}

void push_up(int rt)
{
sz[rt] = sz[l(rt)] + sz[r(rt)] + 1;
/*************/
mx[rt] = max(max(mx[l(rt)],mx[r(rt)]),val[rt]);
/*************/
sum[rt] = sum[l(rt)] + sum[r(rt)] + val[rt];
}

void push_down(int rt)
{
//			if(rt&&rev[rt])
//			{
//				swap(l(rt),r(rt));
//				if(l(rt))	rev[l(rt)] ^= 1;
//				if(r(rt))	rev[r(rt)] ^= 1;
//				rev[rt] = 0;
//			}

if(add[rt])
{
if(l(rt))
{
val[l(rt)] += add[rt];
add[l(rt)] += add[rt];
sum[l(rt)] += add[rt] * sz[l(rt)];
}
if(r(rt))
{
val[r(rt)] += add[rt];
add[r(rt)] += add[rt];
sum[r(rt)] += add[rt] * sz[r(rt)];
}
add[rt] = 0;
}
}

void swap(int &x,int &y)
{
int tmp = x;
x = y;
y = tmp;
}

void rotate(int x,int f)	//将x旋转到其直接父节点 右旋 左旋
{
int y  = pre[x];
push_down(x);
push_down(y);
ch[y][!f] = ch[x][f];
if(ch[y][!f])	pre[ch[y][!f]] = y;
push_up(y);
if(pre[y])	ch[pre[y]][r(pre[y])==y]= x;
pre[x] = pre[y];
ch[x][f] = y;
pre[y] = x;
}

void splay(int x,int goal)	//将x旋转到以goal节点为父节点
{
push_down(x);
while(pre[x]!=goal)
{
int y = pre[x] , z = pre[pre[x]];
if(z==goal)
{
rotate(x,l(y)==x);
}
else
{
int f = (l(z) == y);
if(ch[y][!f] == x)
{
rotate(y,f);
rotate(x,f);
}
else
{
rotate(x,!f);
rotate(x,f);
}
}
}
push_up(x);
if(goal==0)	root = x;
}

void rotateTo(int k,int goal)
{
int x = root;
while(true)
{
push_down(x);
int tmp = sz[l(x)] + 1;
if(k==tmp)	break;
else if(k<tmp)	x = l(x);
else
{
k -= tmp;
x = r(x);
}
}
splay(x,goal);
}

void buildTree(int l,int r,int &rt,int f)
{
if(l>r)	return;
int m = mid(l,r);
rt = ++tot;
val[rt] = a[m];
pre[rt] = f;
buildTree(l,m-1,l(rt),rt);
buildTree(m+1,r,r(rt),rt);
push_up(rt);
}

void rangeCut(int l,int r,int goal)
{
rotateTo(l-1,0);
rotateTo(r+1,root);
push_down(r(root));
int x = l(r(root));	//x子树代表[l,r]区间
//截取
l(r(root)) = 0;
pre[x] = 0;
push_up(r(root));
push_up(root);
rotateTo(goal,0);
rotateTo(goal+1,root);
push_down(r(root));
l(r(root)) = x;
pre[x] = r(root);
push_up(r(root));
push_up(root);
}

void rangeFlip(int l,int r)
{
rotateTo(l-1,0);
rotateTo(r+1,root);
push_down(r(root));
int x = l(r(root));
rev[x] ^= 1;
}

void dfs(int rt,int &size)
{
if(!rt)	return;
push_down(rt);
dfs(l(rt),size);
a[size++] = val[rt];
dfs(r(rt),size);
}

/********************/
int query(int l,int r)
{
rotateTo(l-1,0);
rotateTo(r+1,root);
return mx[l(r(root))];
}

void update(int pos,int c)
{
rotateTo(pos,0);
val[root] = c;
push_up(root);
}

LL querySum(int l,int r)
{
rotateTo(l-1,0);
rotateTo(r+1,root);
return sum[l(r(root))];
}

void updateInterval(int l,int r,int c)
{
rotateTo(l-1,0);
rotateTo(r+1,root);
val[l(r(root))] += c;
sum[l(r(root))] += c * sz[l(r(root))];
add[l(r(root))] += c;
}

}spt;

//int main()
//{
//	int n,m;
//	while(~scanf("%d%d",&n,&m) && (n>=0 || m>=0))
//	{
//		spt.init();
//		spt.read(n);
//		spt.buildTree(1,n+2,spt.root,0);
//		char op[5];
//		int a,b,c;
//		while(m--)
//		{
//			scanf("%s%d%d",op,&a,&b);
//			if(op[0]=='C')
//			{
//				scanf("%d",&c);
//				spt.rangeCut(a+1,b+1,c+1);	//因为哨兵0 多加1
//			}
//			else
//			{
//				spt.rangeFlip(a+1,b+1);
//			}
//		}
//		n = 0;
//		spt.dfs(spt.root,n);	//中序遍历即为序列结果
//		loop(1,n-1)
//		{
//			if(i!=1)	printf(" ");
//			printf("%d",spt.a[i]);
//		}
//		printf("\n");
//	}
//	return 0;
//}

//int main()
//{
//	int n,m;
//	while(scanf("%d%d",&n,&m) == 2)
//	{
//		spt.init();
//		spt.read(n);
//		spt.buildTree(1,n+2,spt.root,0);
//		char op[2];
//		int a,b;
//		while(m--)
//		{
//			scanf("%s%d%d",op,&a,&b);
//			if(op[0]=='Q')
//			{
//				printf("%d\n",spt.query(a+1,b+1));
//			}
//			else
//			{
//				spt.update(a+1,b);
//			}
//		}
//	}
//	return 0;
//}

// Driver for POJ 3468 (range add, range sum). For every test case: reads n and
// m, the n initial values, then m operations of the form "Q a b" (print the sum
// of elements a..b) or "C a b c" (add c to elements a..b).
int main()
{
    int n,m;
    // Require both numbers; a bare `~scanf(...)` also accepts a partial read
    // and would continue with an uninitialised m.
    while(scanf("%d%d",&n,&m) == 2)
    {
        spt.init();
        spt.read(n);
        spt.buildTree(1,n+2,spt.root,0);
        char op[2];
        int a,b,c;
        while(m--)
        {
            // %1s limits the token to the one character op[] can hold.
            if(scanf("%1s%d%d",op,&a,&b) != 3)	break;
            if(op[0] == 'Q')
            {
                // +1 converts to sentinel-padded positions. %lld is the
                // standard conversion for long long (%I64d is MSVC-specific).
                printf("%lld\n",spt.querySum(a+1,b+1));
            }
            else
            {
                if(scanf("%d",&c) != 1)	break;
                spt.updateInterval(a+1,b+1,c);
            }
        }
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: