bzoj 3129 [Sdoi2013]方程 数论 容斥

来源:互联网 发布:网络龙虎概率分析软件 编辑:程序博客网 时间:2024/06/06 00:56

对于>= 的条件从答案中直接减去。
对于<= 的条件容斥转成 >=
注意是正整数要转成非负,然后直接跑礼物就行啦。

#include <bits/stdc++.h>using namespace std;#define ll long long#define PA pair<int,int> int T,p,n,n1,n2,m;int ans,num;int pm[11],pn[11],a[11],v[11],phi[11],pt[11];int pre[11][11000];int qpow(int x,int y,int mod){    int ret=1;    while(y)    {        if(y&1)ret=(ll)ret*x%mod;        x=(ll)x*x%mod;y>>=1;    }    return ret;}PA get(int x,int tp){    if(x==0)return make_pair(1,0);    int r1=qpow(pre[tp][pt[tp]-1],x/pt[tp],pt[tp]);    r1=r1*pre[tp][x%pt[tp]]%pt[tp];    PA r2=get(x/pm[tp],tp);    return make_pair(r1*r2.first%pt[tp],x/pm[tp]+r2.second);}int ny(int x,int tp){return qpow(x,phi[tp]-1,pt[tp]);}void exgcd(int &x,int &y,int a,int b){    if(b==0){y=0;x=1;return;}    exgcd(y,x,b,a%b);y-=a/b*x;}int CRT(){    int ret=0,x,y;    for(int i=1;i<=num;i++)    {        exgcd(x,y,p/pt[i],pt[i]);        x=(x%pt[i]+pt[i])%pt[i];        ret=(ret+v[i]*x%pt[i]*(p/pt[i]))%p;    }    return ret;}int C(int x,int y){    for(int i=1;i<=num;i++)    {        PA r1=get(x,i),r2=get(y,i),r3=get(x-y,i);        int t1=r1.second-r2.second-r3.second;        if(t1>=pn[i])v[i]=0;        else v[i]=r1.first*ny(r2.first,i)%pt[i]*            ny(r3.first,i)%pt[i]*qpow(pm[i],t1,pt[i])%pt[i];    }    return CRT();}int main(){    scanf("%d%d",&T,&p);    for(int i=2,p1=p;i<=p1;i++)        if(p1%i==0)        {            pm[++num]=i;pt[num]=1;            while(p1%i==0)                p1/=i,pn[num]++,pt[num]*=i;            phi[num]=pt[num]-pt[num]/i;            pre[num][0]=1;            for(int j=1;j<pt[num];j++)            {                pre[num][j]=pre[num][j-1];                if(j%i)pre[num][j]=pre[num][j]*j%pt[num];            }        }    while(T--)    {        ans=0;        scanf("%d%d%d%d",&n,&n1,&n2,&m);        for(int i=1;i<=n1;i++)scanf("%d",&a[i]);        for(int i=1,x;i<=n2;i++)            scanf("%d",&x),m-=x;        m-=n-n2;        for(int i=0;i<1<<n1;i++)        {            int cnt=0,m1=m;            for(int j=0;j<n1;j++)                if(i>>j&1)cnt++,m1-=a[j+1];            if(m1<0)continue;            if(cnt&1)ans=(ans-C(n+m1-1,n-1)+p)%p;            else ans=(ans+C(n+m1-1,n-1))%p;        }        printf("%d\n",ans);    }    return 0;}
0 0