二进制枚举子集与容斥

来源:互联网 发布:360网络电视直播 编辑:程序博客网 时间:2024/06/05 10:47

一般的3^p的统计子集答案的方法
时间复杂度O(3^p),空间上需要两个数组。 (p为位数)

void sumup(){       for(int i=1;i<=top;i++)    {        for(int s=(i-1)&i;s>0;s=(s-1)&i)         {cal(i,s); fix();}      }}

用容斥的方式统计该集合的子集答案对该集合的影响
时间复杂度O(p* 2^p)),空间只需要一个数组。(p为位数)

void sumup(){       for(int i=0;i<P;i++)     {        int S=top^(1<<i);        for(int ss=S;ss>0;ss=(ss-1)&S)        {            cal(ss|(1<<i),ss);            fix();        }        cal(1<<i,0);        fix();    }}

非常重要的就是利用已经求出的值。

以今天kor一题的代码为例,

题意是给出n个数ai(n<=1e5,ai<=2^20),选出k个数(k<=n),使他们或起来的值为r(r<=2^20)的方案数有多少。
原本的cnt[x]表示值为x的数的个数,

经过sumup( ) 统计0~(1<< P)-1 这些集合中是这个数或1其子集的数有多少,
cnt[x]变为 值为x或x的子集的数的个数。
再通过排列数 cnt[x]=comb(cnt[x],k),cnt[x]变为,选k个数或起来是x或x的子集的方案数。
sumdown( )是个容斥,cnt[x]最终变为选k个数或起来是的方案数。

代码:

#include<cstdio>#include<iostream>#include<algorithm>#include<cstring>using namespace std;const int P=20;const int N=100010;const int mod=1000000000+7;const int top=(1<<P)-1;int T,n,k,r,cnt[1<<P],fac[N],inv[N];int modpow(int a,int b) {    int ans=1; int base=a;    for(;b;b>>=1)    {        if(b&1) ans=(1LL*ans*base)%mod;        base=(1LL*base*base)%mod;    }       return ans;}void init(){    fac[1]=inv[1]=fac[0]=inv[0]=1;    for(int i=2;i<=100000;i++)    {        fac[i]=(1LL*fac[i-1]*i)%mod;        inv[i]=modpow(fac[i],mod-2)%mod;    }}void fix(int &x){    while(x>=mod) x-=mod;    while(x<0) x+=mod;}void sumup(){       for(int i=0;i<P;i++)    {        int S=top^(1<<i);        for(int ss=S;ss>0;ss=(ss-1)&S)        {            cnt[ss|(1<<i)]=cnt[ss|(1<<i)]+cnt[ss];            fix(cnt[ss|(1<<i)]);        }        cnt[1<<i]=cnt[1<<i]+cnt[0];        fix(cnt[1<<i]);    }}void sumdown(){    for(int i=0;i<P;i++)    {        int S=top^(1<<i);        for(int ss=S;ss>0;ss=(ss-1)&S)        {            cnt[ss|(1<<i)]=cnt[ss|(1<<i)]-cnt[ss];            fix(cnt[ss|(1<<i)]);        }        cnt[1<<i]=cnt[1<<i]-cnt[0];        fix(cnt[1<<i]);    }}int comb(int a,int b){    if(a<b) return 0;    int iv=(1LL*inv[a-b]*inv[b])%mod;    int ans=(1LL*fac[a]*iv)%mod;    return ans;}int main(){    freopen("kor.in","r",stdin);    freopen("kor.out","w",stdout);    init();    scanf("%d",&T);    while(T--)    {        memset(cnt,0,sizeof(cnt));        scanf("%d%d%d",&n,&k,&r);        for(int i=1;i<=n;i++)        {            int x;            scanf("%d",&x);            cnt[x]++;             }        sumup();        for(int i=0;i<=top;i++)        cnt[i]=comb(cnt[i],k);        sumdown();        printf("%d\n",cnt[r]);    }    return 0;}
原创粉丝点击