您的位置:首页 > 其它

【BZOJ2738】矩阵乘法 整体二分

2017-05-22 16:40 399 查看

【BZOJ2738】矩阵乘法

Description

  给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。

Input

  第一行两个数N,Q,表示矩阵大小和询问组数;
  接下来N行N列一共N*N个数,表示这个矩阵;
  再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。

Output

  对于每组询问输出第K小的数。

Sample Input

2 2

2 1

3 4

1 2 1 2 1

1 1 2 2 3

Sample Output

1

3

HINT

  矩阵中数字是109以内的非负整数;
  20%的数据:N<=100,Q<=1000;
  40%的数据:N<=300,Q<=10000;
  60%的数据:N<=400,Q<=30000;
  100%的数据:N<=500,Q<=60000。

题解:根据整体二分的思想,我们将所有数排序,然后二分。我们将[1,mid]中的所有数扔到二维树状数组中去,然后看一看那些矩阵中的元素个数≥K。我们将满足条件的放在左边,不满足的放在右边,然后继续递归下去,直至出解。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
int n,m,n2,tot,now;
struct node
{
int x,y,val;
}v[500*510];
int q1[60010],q2[60010],q3[60010],q4[60010],qk[60010],ans[60010];
int s[510][510],p[60010],q[60010],sum[60010];
int rd()
{
int ret=0,f=1;	char gc=getchar();
while(gc<'0'||gc>'9')	{if(gc=='-')f=-f;	gc=getchar();}
while(gc>='0'&&gc<='9')	ret=ret*10+gc-'0',gc=getchar();
return ret*f;
}
bool cmp(node a,node b)
{
return a.val<b.val;
}
void updata(int x,int y,int val)
{
int i,j;
for(i=x;i<=n;i+=i&-i)
for(j=y;j<=n;j+=j&-j)
s[i][j]+=val;
}
int query(int x,int y)
{
int ret=0,i,j;
for(i=x;i;i-=i&-i)
for(j=y;j;j-=j&-j)
ret+=s[i][j];
return ret;
}
void solve(int l,int r,int L,int R)
{
if(l>r)	return ;
if(L==R)
{
for(int i=l;i<=r;i++)	ans[p[i]]=v[L].val;
return ;
}
int MID=L+R>>1,i,mid=l-1;
while(now<MID)	now++,updata(v[now].x,v[now].y,1);
while(now>MID)	updata(v[now].x,v[now].y,-1),now--;
for(i=l;i<=r;i++)
{
sum[p[i]]=query(q1[p[i]]-1,q2[p[i]]-1)+query(q3[p[i]],q4[p[i]])-query(q1[p[i]]-1,q4[p[i]])-query(q3[p[i]],q2[p[i]]-1);
if(sum[p[i]]>=qk[p[i]])	mid++;
}
int l1=l,l2=mid+1;
for(i=l;i<=r;i++)
{
if(sum[p[i]]>=qk[p[i]])	q[l1++]=p[i];
else	q[l2++]=p[i];
}
for(i=l;i<=r;i++)	p[i]=q[i];
solve(l,mid,L,MID),solve(mid+1,r,MID+1,R);
}
int main()
{
n=rd(),m=rd();
int i,j;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
v[++n2].val=rd(),v[n2].x=i,v[n2].y=j;
sort(v+1,v+n2+1,cmp);
for(i=1;i<=m;i++)	q1[i]=rd(),q2[i]=rd(),q3[i]=rd(),q4[i]=rd(),qk[i]=rd(),p[i]=i;
solve(1,m,1,n2);
for(i=1;i<=m;i++)	printf("%d\n",ans[i]);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: