您的位置:首页 > 其它

【BZOJ3224】【TYVJ1728】普通平衡树

2018-01-14 19:53 399 查看
【题目链接】

点击打开链接

【思路要点】

本题包含了平衡树最基本的操作,是每一个学习平衡树的人都应当先做一遍的题。
笔者实现了四种平衡树:Splay、Treap、替罪羊树和非旋转式Treap(下文给出的是其可持久化版本)。

【代码】

Splay
/*Splay Tree Version*/
#include<bits/stdc++.h>
using namespace std;
#define MAXN	100005
// Reads one integer from stdin: skips non-digit characters (every '-' seen
// while skipping flips the sign), then accumulates decimal digits.
// The character is held in an int so EOF stays distinguishable and isdigit()
// never receives a negative char value; at end of input x is left as 0
// instead of looping forever.
template <typename T> void read(T &x) {
    x = 0; int f = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
    x *= f;
}
// Splay tree over a multiset of ints. Node 0 is the null sentinel; equal
// values share one node and cnt[] is their multiplicity.
struct Splay {
    int root, total;                         // current root; nodes allocated so far
    int child[MAXN][2], father[MAXN];
    int index[MAXN], size[MAXN], cnt[MAXN];  // key; values in subtree (with multiplicity); multiplicity
    // True iff x is the right child of its father.
    bool get(int x) {
        return x == child[father[x]][1];
    }
    // Recomputes size[x] from its own count and its children.
    void update(int x) {
        size[x] = cnt[x];
        size[x] += size[child[x][0]];
        size[x] += size[child[x][1]];
    }
    // Rotates x above its father; does nothing if x is already the root.
    void rotate(int x) {
        int f = father[x], g = father[f];
        if (f == 0) return;
        int tmp = get(x), tnp = get(f);
        child[f][tmp] = child[x][tmp ^ 1];
        if (child[x][tmp ^ 1]) father[child[x][tmp ^ 1]] = f;
        child[x][tmp ^ 1] = f;
        father[f] = x;
        father[x] = g;
        if (g) child[g][tnp] = x;
        update(f);
        update(x);
    }
    // Moves x to the root: rotates the father first when x and its father
    // hang on the same side, otherwise rotates x twice.
    void splay(int x) {
        for (int f = father[x]; (f = father[x]); rotate(x))
            if (get(f) == get(x)) rotate(f);
            else rotate(x);
        root = x;
    }
    // Adds one copy of x and splays its node to the root.
    void insert(int x) {
        if (root == 0) {
            root = ++total;
            index[root] = x;
            cnt[root] = size[root] = 1;
            return;
        }
        int now = root;
        while (true) {
            if (x == index[now]) {
                cnt[now]++;
                // If now is already the root, splay() performs no rotation and
                // hence no update(); keep size[now] exact here.
                size[now]++;
                splay(now);
                return;
            }
            bool tmp = index[now] < x;
            if (child[now][tmp]) now = child[now][tmp];
            else {
                child[now][tmp] = ++total;
                father[total] = now;
                index[total] = x;
                cnt[total] = size[total] = 1;
                splay(total);
                return;
            }
        }
    }
    // 1 + number of stored values < x. Valid whether or not x is stored.
    // Splays the node holding x (or, if x is absent, the last node visited).
    int rank(int x) {
        int now = root, last = 0, ans = 1;
        while (now) {
            last = now;
            if (index[now] <= x) {
                ans += size[child[now][0]];
                if (index[now] == x) {
                    splay(now);
                    return ans;
                }
                ans += cnt[now];
                now = child[now][1];
            } else now = child[now][0];
        }
        if (last) splay(last);
        return ans;
    }
    // Rightmost node of the root's left subtree (0 if there is none).
    int pre() {
        int now = child[root][0];
        while (child[now][1])
            now = child[now][1];
        return now;
    }
    // Leftmost node of the root's right subtree (0 if there is none).
    int suc() {
        int now = child[root][1];
        while (child[now][0])
            now = child[now][0];
        return now;
    }
    // Removes one copy of x; does nothing if x is not stored.
    void del(int x) {
        rank(x);
        if (root == 0 || index[root] != x) return;  // absent: rank() splayed some other node
        if (cnt[root] >= 2) {
            cnt[root]--;
            size[root]--;
            return;
        }
        if (child[root][0] == 0 && child[root][1] == 0) {
            root = 0;
            return;
        }
        if (child[root][0] == 0) {
            root = child[root][1];
            father[root] = 0;
            return;
        }
        if (child[root][1] == 0) {
            root = child[root][0];
            father[root] = 0;
            return;
        }
        // Splaying the predecessor leaves the old root as its right child with
        // no left subtree; splice the old root out.
        splay(pre());
        child[root][1] = child[child[root][1]][1];
        father[child[root][1]] = root;
        update(root);
    }
    // x-th smallest stored value (1-based); x is expected to be in range.
    int find(int x) {
        int now = root;
        while (true) {
            if (x <= size[child[now][0]]) now = child[now][0];
            else {
                x -= size[child[now][0]];
                if (x <= cnt[now]) {
                    splay(now);
                    return index[now];
                }
                x -= cnt[now];
                now = child[now][1];
            }
        }
    }
    // Largest stored value < x (a temporary copy of x is inserted and removed).
    int pred(int x) {
        insert(x);
        int ans = index[pre()];
        del(x);
        return ans;
    }
    // Smallest stored value > x (a temporary copy of x is inserted and removed).
    int succ(int x) {
        insert(x);
        int ans = index[suc()];
        del(x);
        return ans;
    }
} T;
// Reads n operations "opt value" and applies each to the splay tree T,
// printing the answer of every query operation (3..6) on its own line.
int main() {
    int n;
    read(n);
    for (int step = 0; step < n; ++step) {
        int opt, value;
        read(opt), read(value);
        switch (opt) {
        case 1: T.insert(value); break;
        case 2: T.del(value); break;
        case 3: printf("%d\n", T.rank(value)); break;
        case 4: printf("%d\n", T.find(value)); break;
        case 5: printf("%d\n", T.pred(value)); break;
        case 6: printf("%d\n", T.succ(value)); break;
        default: break;
        }
    }
    return 0;
}

Treap
/*Treap Version*/
#include<bits/stdc++.h>
using namespace std;
#define MAXN	100005
// Reads one integer from stdin: skips non-digit characters (every '-' seen
// while skipping flips the sign), then accumulates decimal digits.
// The character is held in an int so EOF stays distinguishable and isdigit()
// never receives a negative char value; at end of input x is left as 0
// instead of looping forever.
template <typename T> void read(T &x) {
    x = 0; int f = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
    x *= f;
}
// Rotating treap over a multiset of ints, kept as a min-heap on random
// weights. a[0] is the null sentinel; equal values share one node (cnt).
struct Treap {
    struct Node {
        int index, weight, size, cnt;   // key, heap priority, values in subtree, multiplicity
        int father, child[2];
    } a[MAXN];
    int root, size;                     // root node; nodes allocated so far
    // Recomputes a[root].size from its count and its children.
    void update(int root) {
        a[root].size = a[root].cnt;
        if (a[root].child[0]) a[root].size += a[a[root].child[0]].size;
        if (a[root].child[1]) a[root].size += a[a[root].child[1]].size;
    }
    // Allocates a detached node holding one copy of value.
    int new_node(int value) {
        size++;
        a[size].index = value;
        a[size].cnt = a[size].size = 1;
        a[size].weight = rand();
        return size;
    }
    // True iff x is the right child of its father.
    bool get(int x) {
        return x == a[a[x].father].child[1];
    }
    // Rotates x above its father, keeping sizes and the root pointer current.
    void rotate(int x) {
        int f = a[x].father, g = a[f].father;
        bool tmp = get(x), tnp = get(f);
        a[f].child[tmp] = a[x].child[tmp ^ 1];
        a[a[x].child[tmp ^ 1]].father = f;
        a[f].father = x;
        a[x].child[tmp ^ 1] = f;
        a[x].father = g;
        if (g) a[g].child[tnp] = x;
        update(f); update(x);
        if (f == root) root = x;
    }
    // Rotates x upward while its weight is smaller than its father's.
    void pushup(int x) {
        int f = a[x].father;
        while (f && a[x].weight < a[f].weight) {
            rotate(x);
            f = a[x].father;
        }
    }
    // Adds one copy of value (sizes on the search path are bumped on the way down).
    void insert(int value) {
        if (root == 0) {
            root = new_node(value);
            return;
        }
        int now = root;
        while (true) {
            a[now].size++;
            if (a[now].index == value) {
                a[now].cnt++;
                return;
            }
            bool tmp = value > a[now].index;
            if (a[now].child[tmp]) now = a[now].child[tmp];
            else {
                a[now].child[tmp] = new_node(value);
                a[size].father = now;
                pushup(size);
                return;
            }
        }
    }
    // Removes one copy of value from the subtree whose link is root (fa is its
    // parent). Precondition: value is stored there. Sizes are decremented on
    // the way down; a node with two children is rotated down below the child
    // of smaller weight until it can be unlinked.
    void del(int &root, int fa, int value) {
        a[root].size--;
        if (a[root].index == value) {
            if (a[root].cnt >= 2) {
                a[root].cnt--;
                return;
            }
            if (a[root].child[0] == 0) {
                root = a[root].child[1];
                a[root].father = fa;
                return;
            }
            if (a[root].child[1] == 0) {
                root = a[root].child[0];
                a[root].father = fa;
                return;
            }
            if (a[a[root].child[0]].weight < a[a[root].child[1]].weight) {
                int tmp = a[root].child[0];
                rotate(tmp); a[tmp].size--;
                del(a[tmp].child[1], tmp, value);
            } else {
                int tmp = a[root].child[1];
                rotate(tmp); a[tmp].size--;
                del(a[tmp].child[0], tmp, value);
            }
        } else if (value < a[root].index) del(a[root].child[0], root, value);
        else del(a[root].child[1], root, value);
    }
    // Node holding value, or 0 if value is not stored.
    int locate(int value) {
        int now = root;
        while (now && a[now].index != value) now = a[now].child[value > a[now].index];
        return now;
    }
    // Removes one copy of value; does nothing if value is not stored (the
    // recursive overload assumes presence and would otherwise descend into
    // the sentinel a[0]).
    void del(int value) {
        if (locate(value) == 0) return;
        del(root, 0, value);
    }
    // 1 + number of stored values < value; valid whether or not value is stored.
    int rank(int value) {
        int now = root, ans = 1;
        while (now) {
            if (value < a[now].index) now = a[now].child[0];
            else {
                ans += a[a[now].child[0]].size;
                if (value == a[now].index) return ans;
                ans += a[now].cnt;
                now = a[now].child[1];
            }
        }
        return ans;
    }
    // rank-th smallest stored value (1-based); rank is expected to be in range.
    int find(int rank) {
        int now = root;
        while (true) {
            if (rank <= a[a[now].child[0]].size) now = a[now].child[0];
            else {
                rank -= a[a[now].child[0]].size;
                if (rank <= a[now].cnt) return a[now].index;
                rank -= a[now].cnt;
                now = a[now].child[1];
            }
        }
    }
    // Number of stored values <= value; valid whether or not value is stored.
    int Rank(int value) {
        int now = root, ans = 0;
        while (now) {
            if (value < a[now].index) now = a[now].child[0];
            else {
                ans += a[a[now].child[0]].size;
                ans += a[now].cnt;
                if (value == a[now].index) return ans;
                now = a[now].child[1];
            }
        }
        return ans;
    }
    // Largest stored value < value (a pure query; the tree is not modified).
    int pred(int value) {
        return find(rank(value) - 1);
    }
    // Smallest stored value > value (a pure query; the tree is not modified).
    int succ(int value) {
        return find(Rank(value) + 1);
    }
} T;
// Reads n operations "opt value" and applies each to the treap T,
// printing the answer of every query operation (3..6) on its own line.
int main() {
    int n;
    read(n);
    for (int step = 0; step < n; ++step) {
        int opt, value;
        read(opt), read(value);
        switch (opt) {
        case 1: T.insert(value); break;
        case 2: T.del(value); break;
        case 3: printf("%d\n", T.rank(value)); break;
        case 4: printf("%d\n", T.find(value)); break;
        case 5: printf("%d\n", T.pred(value)); break;
        case 6: printf("%d\n", T.succ(value)); break;
        default: break;
        }
    }
    return 0;
}

替罪羊树
/*Scapegoat Tree Version*/
#include<bits/stdc++.h>
using namespace std;
#define MAXN	100005
#define ALPHA	0.75
// Reads one integer from stdin: skips non-digit characters (every '-' seen
// while skipping flips the sign), then accumulates decimal digits.
// The character is held in an int so EOF stays distinguishable and isdigit()
// never receives a negative char value; at end of input x is left as 0
// instead of looping forever.
template <typename T> void read(T &x) {
    x = 0; int f = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
    x *= f;
}
// Scapegoat tree over a multiset of ints with lazy deletion: a value whose
// count drops to 0 keeps its node until the next rebuild of that subtree.
// Node ids come from the free list mem[1..top], which init(n) fills with 1..n.
struct Scapegoat_Tree {
    int root, tmp, top, pos, mem[MAXN];   // tmp: subtree being rebuilt; pos: scapegoat of the last update
    int len, tindex[MAXN], tcnt[MAXN];    // (key, count) pairs flattened by dfs(), 1-based
    int child[MAXN][2], father[MAXN];
    int index[MAXN], size[MAXN], cnt[MAXN], rsize[MAXN];  // key; values in subtree; multiplicity; nodes in subtree
private:
    // Pops a node id from the free list and resets all of its fields.
    int new_node() {
        index[mem[top]] = 0;
        size[mem[top]] = cnt[mem[top]] = 0;
        rsize[mem[top]] = 0;
        father[mem[top]] = 0;
        child[mem[top]][0] = 0;
        child[mem[top]][1] = 0;
        return mem[top--];
    }
    // Returns node x to the free list.
    void clear(int x) {
        mem[++top] = x;
    }
    // Recomputes size[] (values) and rsize[] (nodes) of x from its children.
    void update(int x) {
        size[x] = cnt[x];
        size[x] += size[child[x][0]];
        size[x] += size[child[x][1]];
        rsize[x] = 1;
        rsize[x] += rsize[child[x][0]];
        rsize[x] += rsize[child[x][1]];
    }
    // In-order walk: appends every live (cnt > 0) key/count to tindex/tcnt
    // and frees every visited node except the subtree root tmp.
    void dfs(int x) {
        if (child[x][0]) dfs(child[x][0]);
        if (cnt[x]) {
            len++; tindex[len] = index[x];
            tcnt[len] = cnt[x];
        }
        if (child[x][1]) dfs(child[x][1]);
        if (tmp != x) clear(x);
    }
    // Builds tindex/tcnt[l..r] under x, always taking the middle entry as root.
    void rebuild(int x, int l, int r) {
        if (l == r) {
            rsize[x] = 1;
            size[x] = cnt[x] = tcnt[l];
            index[x] = tindex[l];
            return;
        }
        int mid = (l + r) / 2;
        rsize[x] = 1;
        size[x] = cnt[x] = tcnt[mid];
        index[x] = tindex[mid];
        if (mid > l) {
            child[x][0] = new_node();
            father[child[x][0]] = x;
            rebuild(child[x][0], l, mid - 1);
        }
        if (mid < r) {
            child[x][1] = new_node();
            father[child[x][1]] = x;
            rebuild(child[x][1], mid + 1, r);
        }
        update(x);
    }
    // Flattens and rebuilds the subtree rooted at x in place (x stays its root).
    void rebuild(int x) {
        len = 0; tmp = x;
        dfs(x);
        child[x][0] = child[x][1] = 0;
        if (len == 0) {
            // Every value below x was lazily deleted. Keep x as an empty leaf
            // with its old key (already consistent with its ancestors) instead
            // of loading the never-written slot tindex[0]/tcnt[0].
            size[x] = cnt[x] = 0;
            rsize[x] = 1;
            return;
        }
        rebuild(x, 1, len);
    }
    // True iff one child of x holds more than ALPHA of x's nodes (plus one of slack).
    bool unbalance(int x) {
        return max(rsize[child[x][0]], rsize[child[x][1]]) > rsize[x] * ALPHA + 1;
    }
    // Recursive insert; leaves in pos the topmost unbalanced node on the path.
    void insert(int root, int value) {
        if (index[root] == value) {
            cnt[root]++;
            size[root]++;
            return;
        }
        bool t = value > index[root];
        if (child[root][t]) insert(child[root][t], value);
        else {
            int tmp = new_node();
            child[root][t] = tmp;
            father[tmp] = root;
            index[tmp] = value;
            size[tmp] = cnt[tmp] = 1;
            rsize[tmp] = 1;
        }
        update(root);
        if (unbalance(root)) pos = root;
    }
    // Recursive removal of one copy. Precondition: a live copy of value exists.
    void del(int root, int value) {
        if (index[root] == value) {
            cnt[root]--;
            size[root]--;
            return;
        }
        bool tmp = value > index[root];
        del(child[root][tmp], value);
        update(root);
        if (unbalance(root)) pos = root;
    }
public:
    // Fills the free list with node ids 1..n; call once before any update.
    void init(int n) {
        top = 1;
        while (top <= n) {
            mem[top] = n - top + 1;
            top++;
        }
        top--;
    }
    // Adds one copy of x, rebuilding the scapegoat subtree if one was found.
    void insert(int x) {
        if (root == 0) {
            root = new_node();
            index[root] = x;
            rsize[root] = 1;
            size[root] = cnt[root] = 1;
            return;
        }
        pos = 0;
        insert(root, x);
        if (pos) rebuild(pos);
    }
    // True iff at least one copy of x is stored (lazily deleted nodes do not count).
    bool contains(int x) {
        int now = root;
        while (now && index[now] != x) now = child[now][x > index[now]];
        return now != 0 && cnt[now] > 0;
    }
    // Removes one copy of x; does nothing if x is not stored (the recursive
    // overload would otherwise drive a count negative or walk off the tree).
    void del(int x) {
        if (!contains(x)) return;
        pos = 0;
        del(root, x);
        if (pos) rebuild(pos);
    }
    // Minimum rank: 1 + number of stored values < x; valid for absent x too.
    int rank(int x) {
        int now = root, ans = 1;
        while (now) {
            if (x < index[now]) now = child[now][0];
            else {
                ans += size[child[now][0]];
                if (index[now] == x) return ans;
                ans += cnt[now];
                now = child[now][1];
            }
        }
        return ans;
    }
    // Maximum rank: number of stored values <= x; valid for absent x too.
    int rbnk(int x) {
        int now = root, ans = 0;
        while (now) {
            if (x < index[now]) now = child[now][0];
            else {
                ans += size[child[now][0]];
                ans += cnt[now];
                if (index[now] == x) return ans;
                now = child[now][1];
            }
        }
        return ans;
    }
    // x-th smallest stored value (1-based); x is expected to be in range.
    int find(int x) {
        int now = root;
        while (true) {
            if (size[child[now][0]] >= x) now = child[now][0];
            else {
                x -= size[child[now][0]];
                if (x <= cnt[now]) return index[now];
                x -= cnt[now];
                now = child[now][1];
            }
        }
    }
    // Largest stored value < x (a pure query; no node is allocated).
    int pred(int x) {
        return find(rank(x) - 1);
    }
    // Smallest stored value > x (a pure query; no node is allocated).
    int succ(int x) {
        return find(rbnk(x) + 1);
    }
} T;
// Reads n operations "opt value", sizes the node pool for n operations, then
// applies each operation to the scapegoat tree T, printing every query answer.
int main() {
    int n;
    read(n);
    T.init(n);
    for (int step = 0; step < n; ++step) {
        int opt, value;
        read(opt), read(value);
        switch (opt) {
        case 1: T.insert(value); break;
        case 2: T.del(value); break;
        case 3: printf("%d\n", T.rank(value)); break;
        case 4: printf("%d\n", T.find(value)); break;
        case 5: printf("%d\n", T.pred(value)); break;
        case 6: printf("%d\n", T.succ(value)); break;
        default: break;
        }
    }
    return 0;
}

非旋转式Treap

/*Persistent Treap Version*/
#include<bits/stdc++.h>
using namespace std;
#define MAXN	100005
#define MAXP	3000005
// Reads one integer from stdin: skips non-digit characters (every '-' seen
// while skipping flips the sign), then accumulates decimal digits.
// The character is held in an int so EOF stays distinguishable and isdigit()
// never receives a negative char value; at end of input x is left as 0
// instead of looping forever.
template <typename T> void read(T &x) {
    x = 0; int f = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
    x *= f;
}
// Persistent split/merge treap over a multiset of ints. Every update copies
// the nodes it changes and stores the new root in root[++version]; queries
// read root[version]. a[0] is the null sentinel; equal values share one node.
struct Persistent_Treap {
    struct Node {
        int lc, rc;
        int weight, index, cnt, size;   // heap priority, key, multiplicity, values in subtree
    } a[MAXP];
    int root[MAXN];                     // root[v]: root of version v (root[0] is the empty tree)
    int size, version;                  // nodes allocated so far; latest version
    // Allocates a leaf holding one copy of value.
    int new_node(int value) {
        size++;
        a[size].cnt = 1;
        a[size].size = 1;
        a[size].index = value;
        a[size].weight = rand();
        return size;
    }
    // Allocates a fresh node with x's key, count and weight. Slots are never
    // reused, so lc, rc and size start at 0; the caller sets them.
    int copy(int x) {
        size++;
        a[size].cnt = a[x].cnt;
        a[size].index = a[x].index;
        a[size].weight = a[x].weight;
        return size;
    }
    // Recomputes a[x].size from its count and its children.
    void update(int x) {
        a[x].size = a[x].cnt;
        if (a[x].lc) a[x].size += a[a[x].lc].size;
        if (a[x].rc) a[x].size += a[a[x].rc].size;
    }
    // Number of stored values less than value in the tree rooted at root.
    int rank(int root, int value) {
        int ans = 0;
        while (root) {
            if (value < a[root].index) root = a[root].lc;
            else {
                ans += a[a[root].lc].size;
                if (value == a[root].index) return ans;
                ans += a[root].cnt;
                root = a[root].rc;
            }
        }
        return ans;
    }
    // Number of stored values less than or equal to value in the tree rooted at root.
    int rbnk(int root, int value) {
        int ans = 0;
        while (root) {
            if (value < a[root].index) root = a[root].lc;
            else {
                ans += a[a[root].lc].size;
                ans += a[root].cnt;
                if (value == a[root].index) return ans;
                root = a[root].rc;
            }
        }
        return ans;
    }
    int rank(int value) {
        return rank(root[version], value);
    }
    int rbnk(int value) {
        return rbnk(root[version], value);
    }
    // Merges trees x and y (every key of x ordered before every key of y),
    // copying the nodes on the merge path so x and y stay intact.
    int merge(int x, int y) {
        if (x == 0) return y;
        if (y == 0) return x;
        if (a[x].weight < a[y].weight) {
            int tmp = copy(x);
            a[tmp].lc = a[x].lc;
            a[tmp].rc = merge(a[x].rc, y);
            update(tmp);
            return tmp;
        } else {
            int tmp = copy(y);
            a[tmp].rc = a[y].rc;
            a[tmp].lc = merge(x, a[y].lc);
            update(tmp);
            return tmp;
        }
    }
    // Splits tree x into (first cnt values, remaining values) without
    // modifying x. If the cut falls inside one node's multiplicity, that node
    // is copied into two halves.
    pair <int, int> split(int x, int cnt) {
        if (cnt == 0) return make_pair(0, x);
        if (a[x].size == cnt) return make_pair(x, 0);
        if (cnt <= a[a[x].lc].size) {
            pair <int, int> tmp = split(a[x].lc, cnt);
            int tnp = copy(x);
            a[tnp].rc = a[x].rc;
            a[tnp].lc = tmp.second;
            update(tnp);
            tmp.second = tnp;
            return tmp;
        }
        cnt -= a[a[x].lc].size;
        if (cnt < a[x].cnt) {
            int tmp = copy(x);
            a[tmp].cnt = cnt;
            a[tmp].lc = a[x].lc;
            update(tmp);
            int tnp = copy(x);
            a[tnp].cnt = a[x].cnt - cnt;
            a[tnp].rc = a[x].rc;
            update(tnp);
            return make_pair(tmp, tnp);
        }
        cnt -= a[x].cnt;
        pair <int, int> tmp = split(a[x].rc, cnt);
        int tnp = copy(x);
        a[tnp].lc = a[x].lc;
        a[tnp].rc = tmp.first;
        update(tnp);
        tmp.first = tnp;
        return tmp;
    }
    // Adds one copy of value as a new version.
    void insert(int value) {
        int cnt = rank(value);
        int cmt = rbnk(value);
        if (cnt == cmt) {
            pair <int, int> tmp = split(root[version], cnt);
            int tnp = new_node(value);
            root[++version] = merge(merge(tmp.first, tnp), tmp.second);
        } else {
            // The middle part is the single node holding value.
            pair <int, int> tmp = split(root[version], cnt);
            pair <int, int> tnp = split(tmp.second, cmt - cnt);
            int New = copy(tnp.first);
            a[New].cnt++; update(New);
            root[++version] = merge(merge(tmp.first, New), tnp.second);
        }
    }
    // Removes one copy of value as a new version. If value is not stored the
    // new version equals the previous one (instead of merging a copy of the
    // sentinel with cnt = -1 into the tree).
    void del(int value) {
        int cnt = rank(value);
        int cmt = rbnk(value);
        if (cnt == cmt) {
            root[version + 1] = root[version];
            version++;
            return;
        }
        pair <int, int> tmp = split(root[version], cnt);
        pair <int, int> tnp = split(tmp.second, cmt - cnt);
        if (cmt - cnt == 1) {
            // Last copy: drop the node entirely, no need to copy it first.
            root[++version] = merge(tmp.first, tnp.second);
            return;
        }
        int New = copy(tnp.first);
        a[New].cnt--; update(New);
        root[++version] = merge(merge(tmp.first, New), tnp.second);
    }
    // cnt-th smallest value (1-based) of the tree rooted at root; cnt is
    // expected to be in range.
    int find(int root, int cnt) {
        while (true) {
            if (cnt <= a[a[root].lc].size) root = a[root].lc;
            else {
                cnt -= a[a[root].lc].size;
                if (cnt <= a[root].cnt) return a[root].index;
                cnt -= a[root].cnt;
                root = a[root].rc;
            }
        }
    }
    int find(int cnt) {
        return find(root[version], cnt);
    }
    // Largest stored value < value.
    int pred(int value) {
        return find(rank(value));
    }
    // Smallest stored value > value.
    int succ(int value) {
        return find(rbnk(value) + 1);
    }
} PT;
// Reads n operations "opt value" and applies each to the persistent treap PT,
// printing every query answer; PT.rank counts smaller values, so the printed
// rank adds one.
int main() {
    int n;
    read(n);
    for (int step = 0; step < n; ++step) {
        int opt, value;
        read(opt), read(value);
        switch (opt) {
        case 1: PT.insert(value); break;
        case 2: PT.del(value); break;
        case 3: printf("%d\n", PT.rank(value) + 1); break;
        case 4: printf("%d\n", PT.find(value)); break;
        case 5: printf("%d\n", PT.pred(value)); break;
        case 6: printf("%d\n", PT.succ(value)); break;
        default: break;
        }
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: