您的位置:首页 > 其它

【BZOJ3110】K大数查询(ZJOI2013)-整体二分+线段树

2017-05-07 18:15 525 查看
测试地址:K大数查询

做法:这题需要用到整体二分和线段树(这题也可以用树套树做,然而复杂度就很恶心了)。

这一题由于一个位置可以有多个数,所以看上去束手无策,然而这一题并不强制在线,所以我们自然想到整体二分。

因为一个区间内比一个数大的数单调,所以这个性质是可二分的,所以函数solve(s,t,l,r)的作用就是处理操作区间[s,t]内的所有询问,处理过程就是统计每一个询问的区间中有多少大于等于mid的数,然后根据这个结果将操作分为两个部分,然后递归处理即可。统计可以用线段树区间修改来完成。总复杂度O(Nlog2N)。

以下是本人代码(91分WA,原因待探究):

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define inf 1000000000
#define ll long long
using namespace std;
int n,m,ans[50010],tmp[50010],qcnt=0;
ll p[200010]={0},seg[200010]={0},Min=inf,Max=-inf;
struct query
{
int op,a,b,id;
ll cur,c;
}q[50010],a1[50010],a2[50010];

void add(int no,int l,int r,int s,int t,int val)
{
if (l>=s&&r<=t)
{
p[no]+=val;
seg[no]+=(r-l+1)*val;
return;
}
int mid=(l+r)>>1;
if (p[no]!=0)
{
p[no<<1]+=p[no],p[no<<1|1]+=p[no];
seg[no<<1]+=(mid-l+1)*p[no],seg[no<<1|1]+=(r-mid)*p[no];
p[no]=0;
}
if (s<=mid) add(no<<1,l,mid,s,t,val);
if (t>mid) add(no<<1|1,mid+1,r,s,t,val);
seg[no]=seg[no<<1]+seg[no<<1|1];
}

ll query(int no,int l,int r,int s,int t)
{
if (l>=s&&r<=t) return seg[no];
int mid=(l+r)>>1;ll tot=0;
if (p[no]!=0)
{
p[no<<1]+=p[no],p[no<<1|1]+=p[no];
seg[no<<1]+=(mid-l+1)*p[no],seg[no<<1|1]+=(r-mid)*p[no];
p[no]=0;
}
if (s<=mid) tot+=query(no<<1,l,mid,s,t);
if (t>mid) tot+=query(no<<1|1,mid+1,r,s,t);
return tot;
}

void solve(int s,int t,int l,int r)
{
if (s>t||l>r) return;
if (l==r)
{
for(int i=s;i<=t;i++)
if (q[i].op==2) ans[q[i].id]=l;
return;
}

int mid=(l+r)>>1;
mid++;
for(int i=s;i<=t;i++)
{
if (q[i].op==1&&q[i].c>=mid) add(1,1,n,q[i].a,q[i].b,1);
if (q[i].op==2) tmp[q[i].id]=query(1,1,n,q[i].a,q[i].b);
}
for(int i=s;i<=t;i++)
if (q[i].op==1&&q[i].c>=mid) add(1,1,n,q[i].a,q[i].b,-1);

int n1=0,n2=0;
for(int i=s;i<=t;i++)
{
if (q[i].op==2)
{
if (q[i].cur+tmp[q[i].id]>=q[i].c) a2[++n2]=q[i];
else
{
q[i].cur+=tmp[q[i].id];
a1[++n1]=q[i];
}
}
else
{
if (q[i].c<mid) a1[++n1]=q[i];
else a2[++n2]=q[i];
}
}

for(int i=1;i<=n1;i++) q[s+i-1]=a1[i];
for(int i=1;i<=n2;i++) q[s+n1+i-1]=a2[i];
solve(s,s+n1-1,l,mid-1);
solve(s+n1,t,mid,r);
}

int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=m;i++)
{
scanf("%d%d%d%lld",&q[i].op,&q[i].a,&q[i].b,&q[i].c);
if (q[i].op==2) q[i].cur=0,q[i].id=++qcnt;
else Min=min(Min,q[i].c),Max=max(Max,q[i].c);
}

solve(1,m,Min,Max);

for(int i=1;i<=qcnt;i++)
printf("%d\n",ans[i]);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: