您的位置:首页 > 其它

hdu 4747(区间更新)

2015-08-31 23:36 253 查看
题意:一个长度为n的序列,然后求任意左右区间l,r中没有出现过的最小的数字的和。

题解:如果固定区间的左端点得到的所有区间的解是从左到右发现是一个递增序列,用线段树维护当前固定左端点的区间的解的和和最大值,然后更新下一个左端点a[i+1]的区间,那么a[i]就要删除,发现以a[i+1]为左端点的区间所有mex值要把之前第一个mex大于a[i]的位置到下一个a[i]的位置所有值设置为a[i]。

举个例子: 3 2 1 0 2 3 1

左端点是a[1] = 3 , mex: 0 0 0 4 4 4 4

左端点是a[2] = 2 , mex: 0 0 0 3 3 4 4



可以发现当左端点是a[2] = 2时,[2,3]的mex都是0< a[1] = 3,然后mex[4] = 4 > a[1] = 3的位置4到 a[6] = a[1] = 3的位置6的解都是a[1],所以这里直接区间更新。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#define ll long long
using namespace std;
const int N = 200005;
int n, a
, f
, pre
, maxx[N << 2], flag[N << 2];
ll sum[N << 2];
map<int, int> mp;

void pushup(int k) {
sum[k] = sum[k * 2] + sum[k * 2 + 1];
maxx[k] = max(maxx[k * 2], maxx[k * 2 + 1]);
}

void pushdown(int k, int left, int right) {
if (flag[k] >= 0) {
int mid = (left + right) / 2;
flag[k * 2] = flag[k * 2 + 1] = flag[k];
sum[k * 2] = (mid - left + 1) * flag[k];
sum[k * 2 + 1] = (right - mid) * flag[k];
maxx[k * 2] = maxx[k * 2 + 1] = flag[k];
flag[k] = -1;
}
}

void build(int k, int left, int right) {
flag[k] = -1;
if (left == right) {
sum[k] = maxx[k] = f[left];
return;
}
int mid = (left + right) / 2;
build(k * 2, left, mid);
build(k * 2 + 1, mid + 1, right);
pushup(k);
}

void modify(int k, int left, int right, int l, int r, ll x) {
if (l <= left && right <= r) {
flag[k] = x;
sum[k] = (right - left + 1) * x;
maxx[k] = x;
return;
}
pushdown(k, left, right);
int mid = (left + right) / 2;
if (r <= mid)
modify(k * 2, left, mid, l, r, x);
else if (l > mid)
modify(k * 2 + 1, mid + 1, right, l, r, x);
else {
modify(k * 2, left, mid, l, mid, x);
modify(k * 2 + 1, mid + 1, right, mid + 1, r, x);
}
pushup(k);
}

int query(int k, int left, int right, int x) {
if (left == right)
return left;
pushdown(k, left, right);
int mid = (left + right) / 2;
if (maxx[k * 2] > x)
return query(k * 2, left, mid, x);
return query(k * 2 + 1, mid + 1, right, x);
}

int main() {
while (scanf("%d", &n) == 1 && n) {
mp.clear();
int temp = 0;
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
mp[a[i]] = 1;
while (mp.find(temp) != mp.end())
temp++;
f[i] = temp;
}
mp.clear();
for (int i = n; i >= 1; i--) {
if (mp.find(a[i]) == mp.end())
pre[i] = n + 1;
else pre[i] = mp[a[i]];
mp[a[i]] = i;
}
build(1, 1, n);
ll res = 0;
for (int i = 1; i <= n; i++) {
res += sum[1];
if (maxx[1] > a[i]) {
int l = query(1, 1, n, a[i]), r = pre[i] - 1;
if (l <= r)
modify(1, 1, n, l, r, a[i]);
}
modify(1, 1, n, i, i, 0);
}
printf("%lld\n", res);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: