[JZOJ4841] 平衡的子集

来源:互联网 发布:帖木儿帝国知乎 编辑:程序博客网 时间:2024/04/29 06:22

Description

夏令营有N个人,每个人的力气为M(i)。请大家从这N个人中选出若干人,如果这些人可以分成两组且两组力气之和完全相等,则称为一个合法的选法,问有多少种合法的选法?(注意是选法而不是分法)
40%的数据满足:1<=M(i)<=1000;
对于100%的数据满足:2<=N<=20,1<=M(i)<=100000000

Solution

设选的人中分为两个集合A,B
那么每个人有三种状态,不选,A集,B

我们知道iAa[i]=jBa[j]

然后移到左边来
iAa[i]jBa[j]=0

设每个人的三种状态分别为0,1,1,也就是系数

然后就搜索,得数和状态丢到哈希表里,最后再做一次统计答案。

然而这样会T,因为我们计算了大量状态是有重叠部分的,这一部分状态的计算非常浪费。

因此我们只需要先搜左边10个,再搜右边10个判断即可

注意有一个优化,就是有可能出现状态和得数都相等的情况

例如2,2,2
1,-1,1和1,1,-1的得数都为2,因此这一部分暴力判断一下是否在该位置中存在,又可以删去大量冗余状态。

我之前不判重就被卡了1s~

Code

#include <cstdio>#include <cstdlib>#include <algorithm>#include <iostream>#include <cstring>#include <cmath>#include <vector>#define fo(i,a,b) for(int i=a;i<=b;i++)#define fod(i,a,b) for(int i=a;i>=b;i--)#define N 22#define M 1100000#define mo 10000007#define INF 10000000000#define LL long longusing namespace std;int sum[M],n,hnum[mo],fs[mo],n1;LL a[N];struct node{    int p,nt,lt;    LL s;}h[1<<21];void dfs(int k,int s,LL sm){    if(k>n/2) return;    fo(i,-1,1)    {        if(i!=0)         {            s+=1<<k-1;            sm+=i*a[k];            if(sm==INF) sum[s]=1;            int j=fs[sm%mo],bz=1;            while(j>0)            {                if(h[j].s==sm&&h[j].p==s)                 {                    bz=0;                    break;                }                j=h[j].nt;            }            if(bz)            {                h[++n1].p=s;                h[n1].s=sm;                h[n1].nt=fs[sm%mo];                h[fs[sm%mo]].lt=n1;                fs[sm%mo]=n1;                hnum[sm%mo]++;                dfs(k+1,s,sm);            }            s-=1<<k-1;            sm-=i*a[k];        }        else dfs(k+1,s,sm);    }}void get(int k,int s,LL sm){    if(k>n) return;    fo(i,-1,1)    {        if(i!=0)        {            s+=1<<k-1;            sm+=i*a[k];            if(sm==INF) sum[s]=1;             LL pt=sm%mo;            int j=fs[pt];            while(j>0){if(h[j].s==sm) sum[h[j].p+s]=1;j=h[j].nt;}            get(k+1,s,sm);            s-=1<<k-1;            sm-=i*a[k];        }        else get(k+1,s,sm);    }}int main(){    cin>>n;    fo(i,1,n) scanf("%lld",&a[i]);    dfs(1,0,INF);    get(n/2+1,0,INF);    int ans=0;    fo(i,1,(1<<n)-1) ans+=sum[i];    cout<<ans;}
1 0
原创粉丝点击