51nod 1601 完全图的最小生成树计数 Trie+kruskal

来源:互联网 发布:网络发信息到手机 编辑:程序博客网 时间:2024/06/06 01:42

题意:给定一个长度为n的数组a[1..n],有一幅完全图,满足(u,v)的边权为a[u] xor a[v]
求边权和最小的生成树,你需要输出边权和还有方案数对1e9+7取模的值。

由于边权是xor得到,容易想到用trie统计。。
按照当前最高位0/1将当前区间内的点分成两个部分s/t,那么答案肯定是s的最小生成树+t的最小生成树+s-t的最小边,s-t最小边用trie统计,最小生成树递归处理。
那么方案数的话就是每次那个连接两个块之间的最小边的数量,所以trie树统计一下节点个数就好。
字典树那个地方每次查询一个数,尽量使得当前位置相同就好,最后记得记录一下,每个位可能有多个点(多个数),会对方案造成贡献。
好像是经典套路?

#include<cstdio>#include<algorithm>#include<cstring>#define fo(i,a,b) for(int i=a;i<=b;i++)#define fd(i,a,b) for(int i=a;i>=b;i--)using namespace std;const int N=1e5+5;const int mo=1e9+7;const int inf=0x3f3f3f3f;int n,cnt,tot,a[N],s[N],t[N],fac[N];typedef long long ll;ll sum;struct node{    int cnt,next[2];}ch[N*31];inline void clear(){    fo(i,0,tot)        ch[i].next[0]=ch[i].next[1]=ch[i].cnt=0;    tot=0;}inline int pow(int a,int b){    int ret=1;    while (b)    {        if (b&1)ret=1ll*ret*a%mo;        a=1ll*a*a%mo;        b>>=1;    }    return ret;}inline void ins(int x){    int p=0;    fd(i,30,0)    {        int y=(x>>i)&1;        if (!ch[p].next[y])            ch[p].next[y]=++tot;        p=ch[p].next[y];    }    ch[p].cnt++;}inline pair<int,int> find(int x){    int p=0,ans=0,y;    fd(i,30,0)    {        y=(x>>i)&1;        if (ch[p].next[y])p=ch[p].next[y],ans|=y<<i;        else p=ch[p].next[y^1],ans|=(y^1)<<i;    }    return make_pair(ans^x,ch[p].cnt);}inline void solve(int l,int r,int dep){    if (l>=r)return;    if (dep<0)    {        if (r-l+1>=2)cnt=1ll*cnt*pow(r-l+1,r-l-1)%mo;        return;    }    int cnt1=0,cnt2=0;    fo(i,l,r)        if ((a[i]>>dep)&1)s[cnt1++]=a[i];        else t[cnt2++]=a[i];    fo(i,0,cnt1-1)a[l+i]=s[i];    fo(i,0,cnt2-1)a[l+cnt1+i]=t[i];    clear();    pair<int,int>tmp;    int ans=inf,tot=0;    fo(i,0,cnt2-1)ins(t[i]);    fo(i,0,cnt1-1)    {        tmp=find(s[i]);        if (tmp.first<ans)ans=tmp.first,tot=tmp.second;        else if (tmp.first==ans)            tot+=tmp.second;    }    if (sum!=inf&&tot)sum+=ans,cnt=1ll*tot*cnt%mo;    solve(l,l+cnt1-1,dep-1);    solve(l+cnt1,r,dep-1);}int main(){    scanf("%d",&n);    fac[0]=cnt=1;    fo(i,1,n)fac[i]=1ll*fac[i-1]*i%mo;    fo(i,1,n)scanf("%d",&a[i]);    solve(1,n,30);    printf("%lld\n%d\n",sum,cnt);}
阅读全文
0 0
原创粉丝点击