3262: 陌上花开 树状数组套平衡树/CDQ分治

来源:互联网 发布:命运 知乎 编辑:程序博客网 时间:2024/06/04 18:54

就是解决一个三维偏序问题,可以用树套树来解决。
第一维排序,第二维放到树状数组里,每个树状数组下建平衡树维护第三维。
我们按第一维排序后从小到大加入,每次到一个点时就把树状数组中小于等于第二维的扫一遍,再找平衡树里小于等于第三维的,这样就可以了。


我们还可以用强大的CDQ分治来解决。
同样是第一维排序,然后第二维分治,第三维树状数组。
定义solve(l,r)为求出按照第一维排序后的区间[l,r]对答案的贡献。
那么每次我们首先递归调用solve(l,mid)solve(mid+1,r),然后考虑区间[l,mid]对区间[mid+1,r]的影响,我们回溯的时候按照第二维进行归并排序,这样保证调用完两个子过程时,区间[l,mid]和区间[mid+1,r]中第二维是有序的,而区间[l,mid]中的任何一点比区间[mid+1,r]中的任何一点的第一维也小,这样保证了前两维都有序。我们就用第三维维护树状数组,每次扫到区间[mid+1,r]中的一点p2时,就把区间[l,mid]中第二维比p2第二维小的p1的第三维加入树状数组,然后再统计当中比p2第三维小的答案。就线性求出了区间[l,mid]对区间[mid+1,r]的贡献。

树状数组套平衡树

#include<iostream>#include<cstdio>#include<algorithm>#define M 5000005#define N 100005#define lowbit(i) (i&(-i))using namespace std;int n,k,cnt,tmp;int root[N<<1],ans[N],sum[N];struct node {int a,b,c;} a[N];int size[M],same[M],ls[M],rs[M],rnd[M],val[M];inline int read(){    int a=0,f=1; char c=getchar();    while (c<'0'||c>'9') {if (c=='-') f=-1; c=getchar();}    while (c>='0'&&c<='9') {a=a*10+c-'0'; c=getchar();}    return a*f;}inline bool cmp(node a,node b){    return (a.a==b.a&&a.b==b.b)?a.c<b.c:(a.a==b.a?a.b<b.b:a.a<b.a);}inline void pushup(int k){    size[k]=size[ls[k]]+size[rs[k]]+same[k];}inline void lturn(int &k){    int t=rs[k]; rs[k]=ls[t]; ls[t]=k; pushup(k); pushup(t); k=t;}inline void rturn(int &k){    int t=ls[k]; ls[k]=rs[t]; rs[t]=k; pushup(k); pushup(t); k=t;}void insert(int &k,int v){    if (!k)     {        k=++cnt;        size[k]=same[k]=1;        rnd[k]=rand();        val[k]=v;        return;    }    size[k]++;    if (v==val[k]) same[k]++;    else if (v<val[k])    {        insert(ls[k],v);        if (rnd[ls[k]]<rnd[k]) rturn(k);    }    else    {        insert(rs[k],v);        if (rnd[rs[k]]<rnd[k]) lturn(k);    }}void getrank(int k,int v){    if (!k) return;    if (val[k]==v) {tmp+=size[ls[k]]+same[k]; return;}    else if (v<val[k]) getrank(ls[k],v);    else tmp+=size[ls[k]]+same[k],getrank(rs[k],v);}inline void query(int x,int val){    for (int i=x;i;i-=lowbit(i))        getrank(root[i],val);}inline void add(int x,int val){    for (int i=x;i<=k;i+=lowbit(i))        insert(root[i],val);}int main(){    n=read(); k=read();    for (int i=1;i<=n;i++)        a[i].a=read(),a[i].b=read(),a[i].c=read();    sort(a+1,a+n+1,cmp);    for (int i=1;i<=n;i++)    {        if (i<n&&a[i].a==a[i+1].a&&a[i].b==a[i+1].b&&a[i].c==a[i+1].c) sum[i+1]=sum[i]+1;        else        {            tmp=0;            query(a[i].b,a[i].c);            ans[tmp]+=sum[i]+1;        }        add(a[i].b,a[i].c);    }    for (int i=0;i<n;i++) printf("%d\n",ans[i]);    return 0;}

CDQ分治:

#include<iostream>#include<cstdio>#include<algorithm>#define N 200005#define lowbit(i) (i&(-i))using namespace std;int n,k,cnt;struct node {int a,b,c,ans,total;} a[100005],p[100005],np[100005];int tree[200005],ans[100005];inline int read(){    int a=0,f=1; char c=getchar();    while (c<'0'||c>'9') {if (c=='-') f=-1; c=getchar();}    while (c>='0'&&c<='9') {a=a*10+c-'0'; c=getchar();}    return a*f;}inline bool cmp1(node a,node b){    return (a.a==b.a&&a.b==b.b)?a.c<b.c:(a.a==b.a?a.b<b.b:a.a<b.a);}inline bool cmp0(node a,node b){    return a.b==b.b?a.c<b.c:a.b<b.b;}inline void add(int x,int val){    for (int i=x;i<=k;i+=lowbit(i)) tree[i]+=val;}inline int query(int x){    int tmp=0;    for (int i=x;i;i-=lowbit(i)) tmp+=tree[i];    return tmp;}void solve(int l,int r){    if (l==r) return;    int mid=l+r>>1;    solve(l,mid); solve(mid+1,r);    int p1=l,p2=mid+1;    while (p2<=r)    {        while (p1<=mid&&p[p1].b<=p[p2].b)        {            add(p[p1].c,p[p1].total);            p1++;        }        p[p2].ans+=query(p[p2].c);        p2++;    }    for (int j=l;j<p1;j++) add(p[j].c,-p[j].total);    p1=l; p2=mid+1;    for (int i=l;i<=r;i++)        if ((cmp0(p[p1],p[p2])||p2>r)&&p1<=mid) np[i]=p[p1++]; else np[i]=p[p2++];    for (int i=l;i<=r;i++) p[i]=np[i];}int main(){    n=read(); k=read();    for (int i=1;i<=n;i++)        a[i].a=read(),a[i].b=read(),a[i].c=read();    sort(a+1,a+n+1,cmp1);    int num=0;    for (int i=1;i<=n;i++)    {        num++;        if (a[i].a!=a[i+1].a||a[i].b!=a[i+1].b||a[i].c!=a[i+1].c)        {            p[++cnt]=a[i];            p[cnt].total=num;            num=0;        }    }    solve(1,cnt);    for (int i=1;i<=cnt;i++) ans[p[i].ans+p[i].total-1]+=p[i].total;    for (int i=0;i<n;i++) printf("%d\n",ans[i]);    return 0;}
0 1