UVALive 3710 Interconnect(记忆化搜索 + hash)

来源:互联网 发布:shodan入侵网络摄像头 编辑:程序博客网 时间:2024/06/06 08:47

题目大意:一个地方有n个城市,先开始有m条边,每个边连接连个不同的城市,双向边,可以有重边,现在要建边,每年加一条边,建每条边的概率都是一样的,问你能让这n个城市都连通的数学期望是多少?

思路:能相互连接起来的点把他们搞成一个个集合,然后根据先开始的m条边就有几个集合了,然后加边,那么这条边就有两种情况,一种是加了某两个集合之间的一条,一个是连接某个集合里的两个点,即没有用的。假设我们设加了某一条两个集合之间的边的概率是 p = a[ i ] * a[ j ]/nn,a[ i ] 和 a[ j ] 分别是i 集合和 j 集合的点数,nn = n*(n-1)/2,表示总的方案数,然后设 q 表示加了没用的边的概率 q = SIGMA(a[ i ] * (a[ i ] - 1) / 2),加了这条边之后,图形就变了,假设那个变了之后的图形的数学期望是 r ,那么原来图的数学期望就是 p*(1+r) + q*p*(1+r) + q*q*p*(1+r) + ...,然后就是求这个公式的极限了,结果是 p*(1/(1-q) + r)/(1-q)。因为是记忆化搜索,肯定要记录状态,就需要hash 。我用的是BDKR hash ,然后 hash 值一样时,再 for 一遍判断

自己写的时候,第一遍用的是map + set 写的(公式是看别人的。。囧),然后 LA 上过了,1.5s,然后 POJ 就是 TLE,看了下别人的代码,然后知道了这个 BDKR hash,果然这样比较快。。 = =   关于BDKR hash 感觉不是很懂,那个 seed 的取值,怎么定?网上一查全是模板一类的东西,希望各位路过的大牛能提点下。。

还要这道题用 trie 能不能做,我自己敲了个 , TLE 了,估计是写挫了。。 也希望大牛能指点一二。。

代码如下:

#include<cstdio>#include<cstring>#include<map>#include<vector>#include<algorithm>using namespace std;const int MOD = 10007;const int MAXN = 33;struct State{    int num;    int a[33];    double ans;    void order()    {        sort(a,a+num);    }    unsigned int get_hash()//BKDR hash    {        unsigned int hash = 0,seed = 131;//seed:   31 、131 、1313、 13131 、131313 etc..        for(int i = 0;i<num;i++)            hash = hash*seed + a[i];        return (hash&0x7fffffff)%MOD;    }    bool operator == (const State &tmp) const    {        if(tmp.num != num) return 0;        for(int i = 0;i<tmp.num;i++)        {            if(tmp.a[i]!=a[i]) return 0;        }        return 1;    }};vector < State > hh[MOD];double check(State &s){    int id = s.get_hash();    for(int i = 0 ;i<hh[id].size();i++)        if(hh[id][i] == s)            return hh[id][i].ans;    return -1.0;}int nn;double dfs(State s){    if(s.num == 1) return 0;    double x = check(s);    if(x != -1)    {        return x;    }    double q = 0;    for(int i = 0;i<s.num;i++)    {        q += s.a[i]*(s.a[i] - 1)/2;    }    q = q/nn;    State t;    double ans = 0;    for(int i = 0;i<s.num;i++)        for(int j = i+1;j<s.num;j++)        {            double p = s.a[i]*s.a[j]*1.0/nn;            t = s;            t.a[i] += t.a[j];            swap(t.a[j],t.a[s.num-1]);            t.num--;            t.order();            ans += p*(1/(1-q) + dfs(t))/(1-q);        }    int id = s.get_hash();    s.ans = ans;    hh[id].push_back(s);    return ans;}int fa[MAXN],size[MAXN];int find_fa(int x){    if(x == fa[x]) return x;    return fa[x] = find_fa(fa[x]);}int main(){    int n,m;    while(~scanf("%d%d",&n,&m))    {        nn = n*(n-1)/2;        for(int i = 1;i<=n;i++)        {            fa[i] = i;            size[i] = 1;        }        for(int i = 1;i<=m;i++)        {            int a,b;            scanf("%d%d",&a,&b);            int fx = find_fa(a);            int fy = find_fa(b);            if(fx == fy) continue;            fa[fx] = fy;            size[fy] += size[fx];            size[fx] = 0;        }        State beg;        beg.num = 0 ;        for(int i = 1;i<=n;i++)        {            if(size[i])            {                beg.a[beg.num++] = size[i];            }        }        beg.order();        printf("%.6f\n",dfs(beg));    }    return 0;}


set + map 代码如下:

#include<cstdio>#include<cstring>#include<set>#include<map>#include<algorithm>using namespace std;multiset<int> beg;map < multiset<int> , double > mm;int nn;double dfs(multiset<int> &s){    if(mm.find(s) != mm.end()) return mm[s];    double q = 0;    for(multiset<int> ::iterator it = s.begin();it!=s.end();it++)        q += (*it)*((*it)-1)/2;    q = q/nn;    multiset<int> t = s;    double ans = 0;    for(multiset<int> :: iterator it = s.begin();it!=s.end();it++)    {        multiset<int> :: iterator it2 = ++it;        it--;        for(;it2!=s.end();it2++)        {            double p = (*it)*(*it2);            p = p/nn;            //printf("it = %d,it2 = %d,p = %f,q = %f,level = %d\n",*it,*it2,p,q,level);            t = s;            int sum = (*it) + (*it2);            t.erase(t.lower_bound(*it));            t.erase(t.lower_bound(*it2));            t.insert(sum);            //for(multiset<int> :: iterator it3 = t.begin();it3!=t.end();it3++)                //printf("t = %d ",*it3);            //puts("");            ans += p*(1.0/(1-q) + dfs(t))/(1-q);        }    }    return mm[s] = ans;}int fa[33];int size[33];int find_fa(int x){    if(fa[x] == x) return x;    return fa[x] = find_fa(fa[x]);}int main(){    int n,m;    while(~scanf("%d%d",&n,&m))    {        nn = n*(n-1)/2;        for(int i = 1;i<=n;i++)        {            fa[i] = i;            size[i] = 1;        }        for(int i = 1;i<=m;i++)        {            int a,b;            scanf("%d%d",&a,&b);            int fx = find_fa(a);            int fy = find_fa(b);            if(fx == fy) continue;            fa[fx] = fy;            size[fy] += size[fx];            size[fx] = 0;        }        beg.clear();        beg.insert(n);        mm[beg] = 0;        beg.clear();        for(int i = 1;i<=n;i++)        {            if(size[i])                beg.insert(size[i]);        }        //for(multiset<int> :: iterator it = beg.begin();it!=beg.end();it++)            //printf("%d ",*it);        //puts("");        printf("%.6f\n",dfs(beg));    }    return 0;}