您的位置:首页 > 其它

BZOJ 1500|NOI 2005|维修数列|Splay

2017-12-21 20:46 465 查看
这题没什么好说的，主要是用来贴模板（运行时间 3.6s）。

尝试把 Splay 抽象成一个类，不过卡在了内存池该怎么写的问题上……感觉自己对 C++ 的理解还不够深刻。

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <new>
#include <stdexcept>
using namespace std;
// Magnitude of the sentinel value: the two boundary nodes of the sequence
// store -inf, as does the best-subarray field of the shared empty node.
const int inf = 1000000000;
// Size of the node pool and of the input buffer used by main().
const int N = 1000005;

// Reads the next decimal integer from stdin, skipping any non-digit
// characters in front of it.
// A '-' is taken as the sign only when it is the character immediately
// before the first digit (so "x-5" -> -5, but "- 5" -> 5).
// Returns 0 if stdin reaches EOF before any digit is seen.
int read() {
    int s = 0, f = 1;
    // int, not char: getchar() returns EOF (-1), which a plain (possibly
    // unsigned) char cannot represent reliably.
    int ch = getchar();
    for (; ch < '0' || ch > '9'; ch = getchar()) {
        // EOF is below '0', so without this check the loop would spin forever.
        if (ch == EOF) return 0;
        // Only the character right before the digits decides the sign.
        f = (ch == '-') ? -1 : 1;
    }
    for (; '0' <= ch && ch <= '9'; ch = getchar()) s = s * 10 + ch - '0';
    return s * f;
}

/// Sequence container backed by an implicit-key splay tree. It supports range
/// insert, erase, assign ("cover"), reverse, range sum, and maximum subarray sum.
///
/// Positions are 1-based. Internally the tree also holds one -inf sentinel
/// node at each end (added by init()), so user position p is tree rank p + 1.
/// Nodes come from a class-specific pool/free-list allocator and are never
/// returned to the system.
template<typename T>
class Splay {
public:
    /// Builds the tree from a[0..n-1].
    /// Side effect: a[] is shifted right by one in place, and a[0] and a[n + 1]
    /// are overwritten with -inf sentinels, so a must have room for n + 2
    /// elements.
    Splay(T a[], int n) {
        init(a, n);
    }

    /// Assigns val to every element at positions
    /// [startIndex, startIndex + length - 1]. No-op when length <= 0.
    void cover(int startIndex, int length, T val) {
        // With length 0 the selected subtree is Node::null, whose fa is
        // whatever rotate() last stored there. "Updating its parent" would
        // then recompute an unrelated (possibly lazily tagged) node.
        if (length <= 0) return;
        Node *x = subsequence(startIndex, length), *y = x->fa;
        cover_impl(x, val);
        update(y); update(y->fa);
    }

    /// Reverses the elements at positions [startIndex, startIndex + length - 1].
    /// No-op when length <= 0.
    void reverse(int startIndex, int length) {
        if (length <= 0) return;
        Node *x = subsequence(startIndex, length), *y = x->fa;
        // A pending assignment tag means every element in the range is equal,
        // so reversing cannot change anything.
        if (!x->tag) {
            reverse_impl(x);
            update(y); update(y->fa);
        }
    }

    /// Removes the elements at positions [startIndex, startIndex + length - 1]
    /// and returns their nodes to the free list. No-op when length <= 0.
    void erase(int startIndex, int length) {
        // Most important here: with length 0, y would be an arbitrary node
        // and `y->left() = Node::null` would detach a live subtree.
        if (length <= 0) return;
        Node *x = subsequence(startIndex, length), *y = x->fa;
        delete_recursively(x); y->left() = Node::null;
        update(y); update(y->fa);
    }

    /// Sum of the elements at positions [startIndex, startIndex + length - 1].
    /// Returns T() for an empty range.
    T getSum(int startIndex, int length) {
        if (length <= 0) return T();
        return subsequence(startIndex, length)->sum;
    }

    /// Maximum sum of a non-empty contiguous run of the tree's nodes.
    /// The tree includes both -inf sentinels. Assuming element magnitudes are
    /// far below inf, a run touching a sentinel never wins while the sequence
    /// is non-empty.
    T getMaxSubsequence() {
        return root->mx;
    }

    /// Inserts a[0..length-1] right after position startIndex
    /// (startIndex == 0 inserts at the front). No-op when length <= 0.
    void insert(T a[], int startIndex, int length) {
        if (length <= 0) return;
        Node *z = build(a, 0, length - 1);
        // Tree ranks startIndex + 1 and startIndex + 2 are the user positions
        // startIndex and startIndex + 1; rank 1 is the left sentinel.
        Node *x = find(root, startIndex + 1);
        Node *y = find(root, startIndex + 2);
        splay(x, root);
        splay(y, x->right());
        // y is now x's in-order successor and x's right child, so y's left
        // slot is empty.
        z->fa = y; y->left() = z;
        update(y); update(x);
    }
private:

    struct Node {
        // c[0] / c[1] are the left / right children. fa is the parent; for a
        // node sitting on the free list, fa is reused as the "next" link.
        Node *c[2], *fa;
        int size;   // number of nodes in this subtree
        // sum: subtree total; val: this node's value; mx: best non-empty
        // contiguous sum in the subtree; lx / rx: best prefix / suffix sum,
        // where the empty prefix/suffix (0) is allowed.
        T sum, val, mx, lx, rx;
        // tag: this node was assigned val, and its children still need it.
        // rev: this node's children were swapped, and their subtrees still
        // need to be reversed.
        bool tag, rev;

        Node(T v) {
            c[0] = c[1] = fa = Node::null;
            size = 1;
            sum = val = mx = v;
            if (v >= 0) lx = rx = v;
            else lx = rx = 0;
            tag = rev = 0;
        }

        Node *&left() { return c[0]; }
        Node *&right() { return c[1]; }

        /// Class-specific allocator. It hands out recycled nodes from the free
        /// list first. Otherwise it carves nodes from a lazily malloc'ed block
        /// of N nodes, allocating a fresh block when the current one is used up.
        void *operator new(size_t size) {
            static Node *pool = NULL;
            static int remaining = 0;
            // free_pointer is normally terminated by `null`. However, the
            // dynamic initialization order of implicitly instantiated template
            // static members is unspecified. If free_pointer was initialized
            // before `null` existed, it holds NULL instead. Treat both as
            // "free list empty".
            if (free_pointer == NULL || free_pointer == null) {
                if (pool == NULL || remaining == 0) {
                    pool = static_cast<Node *>(malloc(size * N));
                    if (pool == NULL) throw std::bad_alloc();
                    remaining = N;
                }
                return &pool[--remaining];
            }
            // Reuse recycled memory.
            Node *r = free_pointer;
            free_pointer = free_pointer->fa;
            return r;
        }

        /// Pushes the node onto the free list instead of releasing memory.
        /// The shared `null` node is never recycled.
        void operator delete(void *p) {
            if (p == NULL || p == null) return;
            Node *q = static_cast<Node *>(p);
            q->fa = free_pointer;
            free_pointer = q;
        }

        // Shared "empty subtree" node: size 0, sum 0, mx -inf, with all links
        // pointing at itself.
        static Node *null;
    private:
        Node() {
            c[0] = c[1] = fa = this;
            sum = size = val = lx = rx = tag = rev = 0;
            mx = -inf;
        }
        static Node *free_pointer;
    } *root;

    /// Recomputes x's aggregates from its children; assumes both children's
    /// fields are already up to date.
    void update(Node *x) {
        if (x == Node::null) return;
        Node *l = x->left(), *r = x->right();
        x->sum = l->sum + r->sum + x->val;
        x->size = l->size + r->size + 1;
        x->mx = max(max(l->mx, r->mx), l->rx + x->val + r->lx);
        x->lx = max(l->lx, l->sum + x->val + r->lx);
        x->rx = max(r->rx, r->sum + x->val + l->rx);
    }

    /// Assigns val to the whole subtree c. Only c's own fields are updated
    /// now; the children are updated later by pushdown() via the tag.
    void cover_impl(Node *c, T val) {
        if (c == Node::null || c == NULL)
            return;
        c->tag = 1;
        c->val = val;
        c->sum = val * c->size;
        if (val >= 0) {
            c->lx = c->rx = c->mx = c->sum;
        } else {
            // All elements are negative: the best non-empty run is one element.
            c->lx = c->rx = 0;
            c->mx = val;
        }
    }

    /// Reverses subtree x lazily: it swaps x's children and prefix/suffix
    /// maxima now, and toggles rev so pushdown() continues into the children.
    void reverse_impl(Node *x) {
        if (x == Node::null || x == NULL)
            return;
        x->rev ^= 1;
        swap(x->left(), x->right());
        swap(x->lx, x->rx);
    }

    /// Propagates x's pending assign/reverse tags to its children.
    void pushdown(Node *x) {
        Node *l = x->left(), *r = x->right();
        if (x->tag) {
            cover_impl(l, x->val);
            cover_impl(r, x->val);
            // Children now hold identical values, so a pending reverse is moot.
            x->rev = x->tag = 0;
        }
        if (x->rev) {
            x->rev ^= 1;
            reverse_impl(x->left());
            reverse_impl(x->right());
        }
    }

    /// Rotates x above its parent. If the parent is dest (the slot being
    /// splayed into), dest is repointed to x.
    void rotate(Node *x, Node *&dest) {
        Node *y = x->fa, *z = y->fa;
        int l = y->right() == x, r = l ^ 1;
        if (y == dest) dest = x;
        else z->c[z->right() == y] = x;
        x->c[r]->fa = y; y->fa = x; x->fa = z;
        y->c[l] = x->c[r]; x->c[r] = y;
        update(y); update(x);
    }

    /// Splays x up until it occupies the pointer slot dest (root or a child
    /// link). Uses zig-zig / zig-zag double rotations.
    void splay(Node *x, Node *&dest) {
        while (x != dest) {
            Node *y = x->fa, *z = y->fa;
            if (y != dest) {
                if ((y->left() == x) ^ (z->left() == y)) rotate(x, dest);
                else rotate(y, dest);
            }
            rotate(x, dest);
        }
    }

    /// Returns the node of 1-based in-order rank `rank` within subtree x,
    /// pushing tags down along the path. Returns Node::null if rank is out of
    /// range. Iterative so that a deep (unbalanced) tree cannot overflow the
    /// call stack.
    Node *find(Node *x, int rank) {
        while (x != Node::null) {
            pushdown(x);
            int leftSize = x->left()->size;
            if (rank == leftSize + 1) return x;
            if (rank <= leftSize) {
                x = x->left();
            } else {
                rank -= leftSize + 1;
                x = x->right();
            }
        }
        return x;
    }

    /// Returns every node of subtree x to the free list (post-order).
    void delete_recursively(Node *x) {
        if (x == Node::null) return;
        delete_recursively(x->left());
        delete_recursively(x->right());
        delete x;
    }

    /// Isolates positions [startIndex, startIndex + length - 1] as a single
    /// subtree and returns a reference to the link holding it.
    /// Tree rank startIndex is the element just before the range, and rank
    /// startIndex + length + 1 the element just after it. The offset comes
    /// from the left -inf sentinel added in init().
    Node *&subsequence(int startIndex, int length) {
        Node *x = find(root, startIndex), *y = find(root, startIndex + length + 1);
        splay(x, root); splay(y, x->right());
        return y->left();
    }

    /// Shifts a[0..n-1] to a[1..n], puts -inf sentinels in a[0] and a[n + 1],
    /// and builds a balanced tree over a[0..n+1].
    void init(T a[], int n) {
        for (int i = n; i; --i)
            a[i] = a[i - 1];
        a[0] = a[n + 1] = -inf;
        root = build(a, 0, n + 1);
    }

    /// Builds a perfectly balanced subtree over a[l..r] with parent fa.
    /// Returns Node::null when the range is empty.
    Node *build(T a[], int l, int r, Node *fa = Node::null) {
        if (l > r) return Node::null;
        int mid = (l + r) / 2;
        Node *now = new Node(a[mid]);
        now->fa = fa;
        if (l != r) {
            now->left() = build(a, l, mid - 1, now);
            now->right() = build(a, mid + 1, r, now);
        }
        update(now);
        return now;
    }
};
// The shared empty node. Its allocation goes through Node::operator new.
template<typename T> typename Splay<T>::Node *Splay<T>::Node::null = new Node();
// The free list starts empty. These two definitions are dynamically
// initialized in unspecified relative order, so free_pointer may end up NULL
// rather than `null`. Node::operator new accepts either as "empty".
template<typename T> typename Splay<T>::Node *Splay<T>::Node::free_pointer = Node::null;

int main() {
static int a
;
int i, n, m, k, tot, val;
char ch[10];

n = read(); m = read();
for (i = 0; i < n; ++i) a[i] = read();
Splay<int> *splay = new Splay<int>(a, n);

while(m--) {
scanf("%s", ch);
if (ch[0] != 'M' || ch[2] != 'X') k = read(), tot = read();
if (ch[0] == 'I') {
for (i = 0; i < tot; ++i) a[i] = read();
splay->insert(a, k, tot);
}
if (ch[0] == 'D') splay->erase(k, tot);
if (ch[0] == 'M') {
if (ch[2] == 'X') printf("%d\n", splay->getMaxSubsequence());
else val = read(), splay->cover(k, tot, val);
}
if (ch[0] == 'R') splay->reverse(k, tot);
if (ch[0] == 'G') printf("%d\n", splay->getSum(k, tot));
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: