您的位置:首页 > 其它

【POJ 2777 】 线段树之成段更新+位运算

2012-12-18 20:56 435 查看
题目链接:http://poj.org/problem?id=2777

题目大意: 给你一段区间[1,L] , 给定初始所有节点颜色为1,有下面两种操作:

1、 “C A B D” 将区间[A,B]改变颜色为D。

2、“P A B” 计算区间[A,B]有多少种不同的颜色。

解题思路:

最土的可能100000个节点*100000次运算,暴力肯定TLE。解决这一类题目一般会想到线段树或者树状数组,这里用线段树。

这里每次改变都从根节点到叶子节点都同步更新的话肯定TLE,这里又要用到区间更新。

区间更新: 区间更新指的是当要改变某个区间[tl,tr]的颜色值,当往下递归到这个区间[l,r]包含在[tl,tr]里面就停止,将要改变的颜色值赋给这个区间,并做一次标记。这个标记的作用防止下次不会再更新这里了,我们查询的时候就可以以标记为判断要不要询问这里的值,而如果下次更新又再次来到这里,就会把先前更新的值给覆盖掉,变为这次更新的值,而与之同步的是上次标记的值会继续往下传,往下传完之后此节点的标记清空。往下传的节点会继续往下传,一直传到此次和上次更新的区间不一样就停止,道理同上。

本题还用到了或位运算。我们用1的个数表示区间不同颜色的种类,则tree[u]=tree[2*u] | tree[2*u+1];

搞错了一下位运算,纠结了良久。 1<<val 往哪边进行位运算1就在哪边。

线段树好题。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
using namespace std;

const int maxn=100005;
int flag[4*maxn];
int tree[4*maxn];

void down(int u)   ///往下传标记
{
if(flag[u])
{
flag[2*u]=flag[u];
flag[2*u+1]=flag[u];
tree[2*u]=tree[u];
tree[2*u+1]=tree[u];
flag[u]=0;
}
}

void build(int u, int l, int r)  ///建树
{
flag[u]=0;
tree[u]=(1<<1);
if(l==r)
return ;
int mid=(l+r)>>1;
build(2*u,l,mid);
build(2*u+1,mid+1,r);
}

void update(int u, int l, int r, int tl, int tr, int val)
{
if(tl<=l&&r<=tr)
{
flag[u]=val;
tree[u]=(1<<val);  ///这里搞错了位运算,写成了(val<<1),纠结了良久
return ;
}
down(u);
int mid=(l+r)>>1;
if(tr<=mid)  update(2*u,l,mid,tl,tr,val);
else if(tl>mid) update(2*u+1,mid+1,r,tl,tr,val);
else
{
update(2*u,l,mid,tl,tr,val);
update(2*u+1,mid+1,r,tl,tr,val);
}
tree[u]=tree[2*u]|tree[2*u+1];
}

int getsum(int u, int l, int r, int tl, int tr)
{
if(tl<=l&&r<=tr)
{
return  tree[u];
}
down(u);
int mid=(l+r)>>1;
if(tr<=mid)  return getsum(2*u,l,mid,tl,tr);
else if(tl>mid) return getsum(2*u+1,mid+1,r,tl,tr);
else
{
int t1=getsum(2*u,l,mid,tl,tr);
int t2=getsum(2*u+1,mid+1,r,tl,tr);
return t1|t2;
}
}

int fx(int x)
{
int ans=0;
while(x)
{
if(x&1) ans++;
x>>=1;
}
return ans;
}

int main()
{
int  n, t, Q;
while(~scanf("%d%d%d",&n,&t,&Q))
{
build(1,1,n);
char ch[3];
while(Q--)
{
scanf("%s",ch);
int l, r, val;
if(ch[0]=='C')
{
scanf("%d%d%d",&l,&r,&val);
if(l>r) swap(l,r);
update(1,1,n,l,r,val);
}
else
{
scanf("%d%d",&l,&r);
if(l>r) swap(l,r);
int ans=getsum(1,1,n,l,r);
printf("%d\n",fx(ans));
}
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: