您的位置:首页 > 其它

线段树典型例题--poj2777

2012-07-23 17:37 253 查看
这到题我认为网上有些人的算法是不对的。

void solve(int l,int r,int root) //询问
{

if(tree[root].col>=0) //如果父节点有单一的颜色,就直接更新,不需要找到子节点更新
{
flag[tree[root].col]=1;//统计哪些颜色出现过
return;
}
if(tree[root].left==tree[root].right) return;
int mid=(tree[root].left+tree[root].right)>>1;
if(l>mid) solve(l,r,(root<<1)+1);
else if(r<=mid) solve(l,r,root<<1);
else
{
solve(l,mid,root<<1);
solve(mid+1,r,(root<<1)+1);
}
}


这是某些人线段树中询问的算法,仔细想想,这个算法是会退化到线性复杂度的。只要相邻的点颜色都不一样即可。

我出了一组数据:

#include <cstdio>
#include <iostream>
using namespace std;

int main()
{
freopen("in","w",stdout);
int i;
printf("%d %d %d\n",100000,30,100000);
for (i=1;i<=50000;i++)
printf("C %d %d %d\n",i,i,i%30+1);
for (i=1;i<=50000;i++)
printf("P %d %d\n",1,100000);
}
上述程序产生的数据会让错误程序超时(大约两分钟才能跑出结果)

那么真正的算法究竟如何??

应当使用延迟标记。同时可以使用二进制记录节点中的颜色情况。

具体维护,我想显而易见。

我说那么多的目的,希望大家在刷题的时候不要被poj弱数据所迷惑,要严格要求自己,才会进步。同时也提醒自己。

【代码】

#include <iostream>
#include <cstring>
#include <string>
#include <cstdio>
#include <algorithm>
using namespace std;

const int N=300000;

int col
;
int n,m,t,sum;
bool pp
,ans[33];

void down(int i)
{
if (!pp[i]) return;
col[i*2]=col[i*2+1]=col[i];
pp[i*2]=pp[i*2+1]=true;
pp[i]=false;
}

void update(int i)
{
col[i]=col[i*2]|col[i*2+1];
}

void ins(int i,int l,int r,int x,int y,int k)
{
if (x<=l && y>=r)
{
col[i]=1<<(k-1);
pp[i]=true;
return;
}
down(i);
int mid=(l+r)/2;
if (x<=mid) ins(i*2,l,mid,x,y,k);
if (y>mid) ins(i*2+1,mid+1,r,x,y,k);
update(i);
}

void find(int i,int l,int r,int x,int y)
{
if (x<=l && y>=r)
{
int tmp=col[i];
for (int p=1;p<=t;p++)
{
ans[p]|=tmp&1;
tmp>>=1;
if (tmp==0) break;
}
return;
}
down(i);
int mid=(l+r)/2;
if (x<=mid) find(i*2,l,mid,x,y);
if (y>mid) find(i*2+1,mid+1,r,x,y);
update(i);
}

void build(int i,int l,int r)
{
col[i]=1;
if (l==r) return;
int mid=(l+r)/2;
build(i*2,l,mid);
build(i*2+1,mid+1,r);
}

int main()
{
int i,u,v,c;
char ch;

freopen("in","r",stdin);
scanf("%d%d%d\n",&n,&t,&m);
build(1,1,n);
while (m--)
{
scanf("%c",&ch);
if (ch=='P')
{
scanf("%d%d\n",&u,&v);
if (u>v) swap(u,v);
memset(ans,0,sizeof(ans));
find(1,1,n,u,v);
sum=0;
for (i=1;i<=t;i++)
sum+=ans[i];
printf("%d\n",sum);
}
else
{
scanf("%d%d%d\n",&u,&v,&c);
if (u>v) swap(u,v);
ins(1,1,n,u,v,c);
}
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: