您的位置:首页 > 其它

poj 2750 经典线段树

2011-10-29 19:57 232 查看
题意:给出一数组,数组首尾是可以相接的,要求求出最大连续序列值,并且不可以包括所有元素值。

遇到求这种最大连续序列值,看了网上大牛的思想才知道一般分为两种情况:1、不同时包括两端点的情况,这时直接求整个序列中的最大连续序列值;2、同时包括两端点元素,这时求整个序列中的最小连续序列值,然后用总和减去它,就是所要求的值。

具体怎么求法呢?

假设我们将整个序列分成两个连续的序列a,b;与整个序列设成A。假如我们知道a,b序列各个的从左向右最大连续序列值lmax,从右往左的最大连续序列值rmax,和从左向右最小连续序列值lmin,从右往左的最小连续序列值rmin,和每个序列的最大连续序列值nmax和最小连续序列值nmin,和最大元素值max,和最小元素值min;

那么,就有:

A.nmax=max(a.nmax,b.max,a.rmax+b.lmax,);

A.lmax=max(a.sum+b.lmax,a.max);

A.rmax=max(b.sum+a.rmax,b.rmax);

A.lmin=min(a.sum+b.lmin,a.lmin);

A.rmin=min(b.sum+a.rmin,b.rmin);

A.min=min(a.min,b.min) ;

A.max=max(a.max,b.max) ;

现在的A.nmax并不是最终结果,因为还没考虑最大连续序列值存在两端的情况,

A.nmax=max(A.nmax,A.sum-(a.rmin+b.lmin))

这是不是最终的答案呢? 不是,因为没有考虑全为正数和全为负数的情况,所以这种情况下

A.nmax=A.namx-A.min 或 A.nmax=A.max ;

#include<iostream>
#include<cstdio>
using namespace std;
#define Max(a,b) a>b ? a:b
#define Min(a,b) a<b ? a:b
#define MAX_INT 100000
struct node
{
int sum;
int nmax,nmin;
int min,max;
int lmax,lmin;
int rmax,rmin;
int left;
int right;
};
node interval[10*MAX_INT];
int data[MAX_INT];
int init(int k,int i)
{
interval[i].left=interval[i].right=k;
interval[i].sum=data[k];
interval[i].nmax=interval[i].nmin=data[k];
interval[i].max=interval[i].min=data[k];
interval[i].lmax=interval[i].rmax=data[k];
interval[i].lmin=interval[i].rmin=data[k];
return 0;
}
int modify(int i)
{
int k=i<<1;
interval[i].left=interval[k].left;
interval[i].right=interval[k+1].right;

interval[i].sum=interval[k].sum+interval[k+1].sum;

interval[i].lmax = Max(interval[k].sum+interval[k+1].lmax , interval[k].lmax);
interval[i].rmax= Max(interval[k+1].sum+interval[k].rmax , interval[k+1].rmax);

interval[i].lmin= Min(interval[k].sum+interval[k+1].lmin,interval[k].lmin);
interval[i].rmin= Min(interval[k+1].sum+interval[k].rmin,interval[k+1].rmin);

interval[i].nmax= Max(interval[k].nmax , interval[k+1].nmax);
interval[i].nmax= Max(interval[i].nmax , interval[k].rmax+interval[k+1].lmax);

interval[i].nmin= Min(interval[k].nmin,interval[k+1].nmin);
interval[i].nmin= Min(interval[i].nmin,interval[k].rmin+interval[k+1].lmin);

interval[i].min=Min(interval[k].min , interval[k+1].min);
interval[i].max=Max(interval[k].max , interval[k+1].max);
return 0;
}

int create(int left,int right,int i)
{
int mid;
if(left==right)
{
init(left,i);
return 0;
}
mid=(left+right)>>1;
create(left,mid,i<<1);
create(mid+1,right,(i<<1)+1);
modify(i);
return 0;
}
int update(int root,int k,int w)
{
int i,mid;
i=root;
while(interval[i].left!=interval[i].right)
{
mid=(interval[i].left+interval[i].right)>>1;
if(mid>=k)
i=i<<1;
else
i=(i<<1)+1;
}
interval[i].lmax=interval[i].rmax=w;
interval[i].lmin=interval[i].rmin=w;
interval[i].nmax=interval[i].nmin=w;
interval[i].sum=w;
interval[i].min=interval[i].max=w;
while(i!=root)
{
i=i>>1;   modify(i);
}
if(interval[root].nmax<interval[root].sum-interval[root].nmin)
interval[root].nmax= interval[root].sum-interval[root].nmin;
if(interval[root].nmax==interval[root].sum)
return interval[root].sum-interval[root].min;
if(interval[root].nmin==interval[root].sum)
return interval[root].max;
return interval[root].nmax;
}
int main()
{
int i,k,m,n,w;
while(scanf("%d",&n)!=EOF)
{
for(i=1;i<=n;i++)
scanf("%d",&data[i]);
create(1,n,1);
scanf("%d",&m);
for(i=0;i<m;i++)
{
scanf("%d%d",&k,&w);
printf("%d\n",update(1,k,w));
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: