您的位置:首页 > 其它

hdu4747 Mex (线段树 好题)

2015-09-07 18:45 204 查看
(这次的代码写得非常的丑,因为敲代码时的环境非常的乱,大家可以不用看了。。。。。。)

题目大意:

给一个数字串。

然后定义一个函数Mex({A}) = A中没有的最小的非负整数。

即若A = {0、1、3},则Mex(A) = 2

然后要求数字串的所有连续子串的Mex值之和。

总结:

这一题我想了很久。这一题的解放真是太有意思了,处于我思维的一个盲点。

我一开始想过处理出Mex = 1、2、3、.......、n-1、n 的子串数,然后统计答案。

然后我从另外一个方向想:

依次去统计

[n, n]

[n-1, n-1] , [n-1, n]

[n-1, n-2] , [n-2, n-1] , [n-2, n]

......

[1 , 1] , [1 , 2] , [1 , 3] , [1 , 4] , [1 , 5] , ..... , [1, n-1] , [1 , n]

然后[i, x] 到[i-1, x] ,用什么数据结构去维护。

从这个方向去想也一直没有结果,后面看了网上的题解,才发现原来还能够这么做。

做法是先用O(N)的时间计算出[1, 1] , [1 , 2] , [1 , 3] , ...... , [1 , n-1] , [1 , n] 的值

然后在其基础上,用O(logn)的时间复杂度计算出

[2, 2] , [2, 3] , [2 , 4] , ....... , [2 , n-1] , [2 , n]的值。

这样转换n次,就能算出所有连续子串的合,并且时间复杂度是O(N*logN)

具体做法:

先用O(n)的时间计算出[1, 1] , [1, 2] , [1, 3] , [1, 4], ..... , [1, n-1] , [1, n]的Mex值。

具体代码:

value = 0;
for (i = 1; i <= n; i ++)
{
if (a[i] <= n)
flag[a[i]] = 1;
while (flag[value] == 1)
value ++;
h[i] = value;
}


h[i]表示的就是[1, i]的Mex值

由于Mex的性质,h[]的值是单调不下降序列。

然后考虑将h[i]从表示[1, i]的Mex值转换到[2, i]的Mex值。

考虑所有的连续子串减去a[1]以后,对h[]有什么影响。

若[1, i]去掉a[1]以后,[2, i]内仍然含有等于a[1]的值,则h[i]的值不发生改变。

若[1, i]去掉a[1]以后,[2, i]内不含有等于a[1]的值,则h[i]的值不能够大于a[i],即若此时h[i] > a[i] , 令h[i] = a[i]

这样就能将h[i]从表示[1, i]的Mex值转换到[2, i]的Mex值,此时的h[i]仍然是单调不下降的。

直接转换的时间复杂度是O(n)的,但由于h[]是一个单调不下降序列,所以可以使用线段树,使转换在O(logn)的时间内完成。

这个转换需要完成两个操作。

操作一,找出满足h[k] > a[i] 的最小的k

这个可以在线段树中加了一个max标记,表示这个区间内最大的h[]值,利用这个标记,可以在线段树上在O(logn)的时间内找出k值

int find(int x, int y, int value, int t)	//在[x, y]中查询a[k] > value的k
{
int ans1, ans2;
if (x > lt[t].y || y < lt[t].x)
return n + 1;
if (lt[t].max <= value)
return n + 1;
if (lt[t].flag == 1 && lt[t].x != lt[t].y)
{
lt[t*2].get_property(lt[t]);
lt[t*2+1].get_property(lt[t]);
}
if (x <= lt[t].x && lt[t].y <= y)
{
if (lt[t].x == lt[t].y)
return lt[t].x;
else if (lt[t*2].max > value)
return find(x, y, value, t*2);
else return find(x, y, value, t*2+1);
}
else if (lt[t].x != lt[t].y)
{
ans1 = find(x, y, value, t*2);
ans2 = find(x, y, value, t*2+1);
return min(ans1, ans2);
}
}


操作二:在操作一的基础上,若k < a[i].next (a[i].next表示下一个等于a[i]的数字的下标),则将[k, a[i].next-1]上的h[i]值赋值为a[i]

这个因为每个的h[i]值只会下降不会上升,所以可以利用flag标记(表示整个区间是否等于同一个值)来进行线段树节点值传递。

这样就可以在O(logn)的时间内更新线段树的的max,sum,value,flag。(最大值、总和、当flag==1时,整个区间中的h[]值,整段区间相同标志)

我的更新操作代码:

void updata(int x, int y, int value, int t)
{
if (x > lt[t].y || y < lt[t].x)
return;
if (x <= lt[t].x && lt[t].y <= y)
{
lt[t].value = value;
lt[t].flag = 1;
lt[t].sum = 1ll * lt[t].value * (lt[t].y - lt[t].x + 1);
lt[t].max = value;
return;
}
if (lt[t].x != lt[t].y)
{
if (lt[t].flag == 1)
{
lt[t*2].get_property(lt[t]);
lt[t*2+1].get_property(lt[t]);
lt[t].flag = 0;
}
updata(x, y, value, t*2);
updata(x, y, value, t*2+1);
lt[t].max = max(lt[t*2].max, lt[t*2+1].max);
lt[t].sum = lt[t*2].sum + lt[t*2+1].sum;
}
}


最后贴上我全部的代码:

#include <iostream>
#include <stdio.h>
#include <stdlib.h>
using namespace std;

const int MAXN = 2e5 + 100;

class node
{
public:
int x, y, value, flag, max;
long long sum;
void get_property(node &b)
{
value = b.value;
flag = b.flag;
max = b.max;
sum = 1ll * (y - x + 1) * value;
}
};
node lt[MAXN*5];
int a[MAXN], flag[MAXN], h[MAXN];
int p_next[MAXN], head[MAXN];
int n;

void build(int x, int y, int t)
{
lt[t].x = x;
lt[t].y = y;
lt[t].value = 0;
lt[t].flag = 0;
lt[t].sum = 0;
lt[t].max = 0;
if (x != y)
{
int mid = (x + y) / 2;
build(x, mid, t*2);
build(mid+1, y, t*2+1);
}
}
void updata(int x, int y, int value, int t) { if (x > lt[t].y || y < lt[t].x) return; if (x <= lt[t].x && lt[t].y <= y) { lt[t].value = value; lt[t].flag = 1; lt[t].sum = 1ll * lt[t].value * (lt[t].y - lt[t].x + 1); lt[t].max = value; return; } if (lt[t].x != lt[t].y) { if (lt[t].flag == 1) { lt[t*2].get_property(lt[t]); lt[t*2+1].get_property(lt[t]); lt[t].flag = 0; } updata(x, y, value, t*2); updata(x, y, value, t*2+1); lt[t].max = max(lt[t*2].max, lt[t*2+1].max); lt[t].sum = lt[t*2].sum + lt[t*2+1].sum; } }
int find(int x, int y, int value, int t) //在[x, y]中查询a[k] > value的k { int ans1, ans2; if (x > lt[t].y || y < lt[t].x) return n + 1; if (lt[t].max <= value) return n + 1; if (lt[t].flag == 1 && lt[t].x != lt[t].y) { lt[t*2].get_property(lt[t]); lt[t*2+1].get_property(lt[t]); } if (x <= lt[t].x && lt[t].y <= y) { if (lt[t].x == lt[t].y) return lt[t].x; else if (lt[t*2].max > value) return find(x, y, value, t*2); else return find(x, y, value, t*2+1); } else if (lt[t].x != lt[t].y) { ans1 = find(x, y, value, t*2); ans2 = find(x, y, value, t*2+1); return min(ans1, ans2); } }
void init()
{
int i;
for (i = 0; i <= n; i ++)
{
flag[i] = 0;
head[i] = n + 1;
}
}
int main()
{
int i, pos_x, pos_y, value;
long long ans;
while (scanf("%d",&n))
{
if (n == 0)
break;
init();
for (i = 1; i <= n; i ++)
{
scanf("%d", &a[i]);
if (a[i] > n)
a[i] = n + 10;
}
value = 0;
for (i = 1; i <= n; i ++)
{
if (a[i] <= n)
flag[a[i]] = 1;
while (flag[value] == 1)
value ++;
h[i] = value;
}
for (i = n; i >= 1; i --)
if (a[i] <= n)
{
p_next[i] = head[a[i]];
head[a[i]] = i;
}
build(1, n, 1);
for (i = 1; i <= n; i ++)
updata(i, i, h[i], 1);
ans = 0;
for (i = 1; i <= n; i ++)
{
ans += lt[1].sum;
if (a[i] <= n)
{
pos_y = p_next[i];
pos_x = find(i, pos_y, a[i], 1);
if (pos_x <= pos_y)
updata(pos_x, pos_y-1, a[i], 1);
}
}
printf("%lld\n", ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: