您的位置:首页 > 其它

HDU 4747 Mex (线段树)

2015-05-02 10:05 387 查看
本题关键是想到如何转换成区间求和操作。

用mex[i]表示第1到第i个数这个区间的mex值,这样对mex[1]~mex
求和,就是所有包含第一个数的区间的mex值和。

包含第一个数的都求完了,可以把第一个删掉。这样mex[i]就变成了第2到第i个数这区间的mex值,同样对2到N这个区间的mex[i]求和,然后一直一个一个删除到最后一个。

删除a[i]会影响哪些区间:

假设i位置后下一次出现a[i]的位置是j,影响的区间就是i~j-1,因此需要修改的区间的结尾为j-1.

j之前的mex[i]如果小于a[i]就修改成a[i],由于mex数组是单调递增的,求出第一个小于a[i]的pos,后面的mex都大于a[i],因此修改区间的开头为pos。

最终需要修改的区间就是pos~j-1

区间修改,区间求和,求第一个小于a[i]的位置,这三个操作都可以用线段树完成。

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#include <algorithm>
#define maxn 200005
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define LL long long
#include <map>
int Next[maxn];
int a[maxn];
int N;
map <int,int> pre;
int Max[maxn*4];
LL sum[maxn*4];
int lazy[maxn*4];
int mex[maxn];
void pushdown(int rt,int l,int r){
	if(lazy[rt]!=-1){
		int v=lazy[rt];
		lazy[rt<<1]=lazy[rt<<1|1]=v;
		Max[rt<<1]=Max[rt<<1|1]=v;
		int mid=(l+r)>>1;
		sum[rt<<1]=(mid-l+1)*v;
		sum[rt<<1|1]=(r-mid)*v;
		lazy[rt]=-1;
	}
}

void pushup(int rt){
	Max[rt]=max(Max[rt<<1],Max[rt<<1|1]);
	sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}

void build(int l,int r,int rt){
	lazy[rt]=-1;
	if(l==r){
		Max[rt]=sum[rt]=mex[l];
		return ;
	}

	int mid=(l+r)/2;
	build(lson);
	build(rson);
	pushup(rt);
}

int GetPos(int l,int r,int rt,int v){
	if(l==r) return l;
	pushdown(rt,l,r);
	int mid=(l+r)>>1;
	if(Max[rt<<1]>v) return GetPos(lson,v);
	else return GetPos(rson,v);
}

void update(int l,int r,int rt,int L,int R,int v){
	if(l>=L&&r<=R){
		sum[rt]=v*(r-l+1);
		Max[rt]=v;
		lazy[rt]=v;
		return ;
	}
	pushdown(rt,l,r);
	int mid=(l+r)/2;
	if(mid>=L) update(lson,L,R,v);
	if(mid<R) update(rson,L,R,v);
	pushup(rt);
}

LL query(int l,int r,int rt,int L,int R){
	if(l>=L&&r<=R) return sum[rt];
	pushdown(rt,l,r);
	LL res=0;
	int mid=(l+r)/2;
	if(mid>=L) res+=query(lson,L,R);
	if(mid<R) res+=query(rson,L,R);
	return res;
}

int main(){
	while(~scanf("%d",&N)){
		if(!N) break;
		pre.clear();
		memset(Next,-1,sizeof(Next));
		int Min=0;
		for(int i=1;i<=N;i++){
			scanf("%d",&a[i]);
			int cur=a[i];
			if(!pre.count(cur)) pre[cur]=i;
			else{
				Next[pre[cur]]=i-1;
				pre[cur]=i;
			}
			while(pre.count(Min)) Min++;
			mex[i]=Min;
		}
		for(int i=1;i<=N;i++){
			if(Next[i]==-1) Next[i]=N;
		}
		build(1,N,1);
		LL res=0;
		for(int i=1;i<=N;i++){
			res+=query(1,N,1,i,N);
			if(Max[1]<a[i]) continue;
			int pos=GetPos(1,N,1,a[i]);
			pos=max(pos,i);
			update(1,N,1,pos,Next[i],a[i]);
		}
		printf("%I64d\n",res);
	}
	return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: