【bzoj2738】 矩阵乘法

来源:互联网 发布:linux 中文方格 编辑:程序博客网 时间:2024/06/05 07:35

Description

  给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。

Input

 
  第一行两个数N,Q,表示矩阵大小和询问组数;
  接下来N行N列一共N*N个数,表示这个矩阵;
  再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。

Output

  对于每组询问输出第K小的数。

Sample Input

2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3

Sample Output

1
3

HINT

  矩阵中数字是109以内的非负整数;

  20%的数据:N<=100,Q<=1000;

  40%的数据:N<=300,Q<=10000;

  60%的数据:N<=400,Q<=30000;

  100%的数据:N<=500,Q<=60000。


Solve

整体二分,二分答案,二位树状数组求区间数的个数。

#include<algorithm>#include<iostream>#include<cstdio>#define lowbit(x) (x&(-x))using namespace std;int sum[505][505],ss,n,m,cnt,xh[60005],ans[60005],step,tmp[60005],maxn;bool mark[60005];struct node{    int val,x,y;    friend bool operator < (node i,node j){        return i.val<j.val;    }}t[250005];struct orz{int x1,x2,y1,y2,k;}q[60005];inline void add(int x,int y,int key){    for (int i=x;i<=n;i+=lowbit(i))        for (int j=y;j<=n;j+=lowbit(j))            sum[i][j]+=key;}inline int get_sum(int x,int y){    ss=0;    for (int i=x;i;i-=lowbit(i))        for (int j=y;j;j-=lowbit(j))            ss+=sum[i][j];    return ss;}inline int query(orz i){    return get_sum(i.x2,i.y2)+get_sum(i.x1-1,i.y1-1)-get_sum(i.x1-1,i.y2)-get_sum(i.x2,i.y1-1);}inline void solve(int l,int r,int L,int R){    if (l>r || L==R)return;    int mid=(L+R)>>1,len=0,ll[2];    for (;step<cnt && t[step+1].val<=mid;++step)add(t[step+1].x,t[step+1].y,1);    for (;step>0 && t[step].val>mid;--step)add(t[step].x,t[step].y,-1);    for (int i=l;i<=r;++i)        if (query(q[xh[i]])>=q[xh[i]].k)            ans[xh[i]]=mid,mark[i]=++len;        else mark[i]=0;    ll[1]=l;ll[0]=l+len;    for (int i=l;i<=r;++i)tmp[ll[mark[i]]++]=xh[i];    for (int i=l;i<=r;++i)xh[i]=tmp[i];    solve(l,ll[1]-1,L,mid);solve(ll[1],r,mid+1,R);}int main (){    scanf ("%d%d",&n,&m);    for (int i=1;i<=n;++i)        for (int j=1;j<=n;++j){            scanf ("%d",&t[++cnt].val);            t[cnt].x=i;t[cnt].y=j;            maxn=max(maxn,t[cnt].val);        }    sort(t+1,t+cnt+1);    for (int i=1;i<=m;++i)        scanf ("%d%d%d%d%d",&q[(xh[i]=i)].x1,&q[i].y1,&q[i].x2,&q[i].y2,&q[i].k);    solve(1,m,0,maxn+1);    for (int i=1;i<=m;++i)printf ("%d\n",ans[i]);    return 0;}