HDU 3341 Lost's revenge(AC自动机+DP+变进制优化)

来源:互联网 发布:淘宝用不了怎么回事 编辑:程序博客网 时间:2024/06/05 23:39
题目是给一个DNA重新排列使其包含最多的数论基因。

DNA串的长度最多是40,
只需要记录ACGT出现的次数。
如果直接使用5维数组(自动机结点,加上A、C、G、T各自已用的个数),显然超内存了。
假设ACGT的总数分别为num[0],num[1],num[2],num[3]
那么对于ACGT的数量分别为ABCD的状态可以记录为:
A*(num[1]+1)*(num[2]+1)*(num[3]+1) + B*(num[2]+1)*(num[3]+1)+ C*(num[3]+1) +D
这样状态总数最多就是11*11*11*11=14641(四种字符各出现10次时取到),
复杂度也可以承受了。
字符串可能有重复的,用int来记录数量

忘了处理dp[j][i]小于0时直接continue的情况,WA了好多次。
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <deque>
#include <cmath>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

#define L(i) i<<1
#define R(i) i<<1|1
#define INF  0x3f3f3f3f
#define pi acos(-1.0)
#define eps 1e-9
#define maxn 60010
#define MOD 1000000007

int n,m;
// dp[node][code]: best number of pattern occurrences seen so far when the
// automaton is at `node` and the letters already placed are encoded as `code`
// (see Trie::solve for the encoding). -1 marks a state that was never reached.
int dp[510][20010];
// bit[k] is the weight of letter k in the encoding, num[k] is how many times
// letter k occurs in the text to rearrange.
int bit[13],num[13];

// Automaton over the alphabet {A, C, G, T} built from the inserted patterns,
// plus the DP that rearranges a text to maximise pattern occurrences.
struct Trie
{
    int next[510][5],fail[510],en[510];
    int root,L;

    // Drops all nodes and creates a fresh root.
    void init()
    {
        L = 0;
        root = newnode();
    }

    // Allocates the next node with no children and no pattern ending on it;
    // returns its index.
    int newnode()
    {
        for(int i = 0; i < 4; i++)
            next[L][i] = -1;
        en[L++] = 0;
        return L-1;
    }

    // Maps 'A','C','G' to 0,1,2; every other character maps to 3.
    int getch(char ch)
    {
        if(ch == 'A')
            return 0;
        else if(ch == 'C')
            return 1;
        else if(ch == 'G')
            return 2;
        return 3;
    }

    // Adds one pattern. en[] is a counter, so inserting the same pattern
    // twice makes it count twice.
    void Insert(char buf[])
    {
        int now = root;
        int len = strlen(buf);
        for(int i = 0; i < len; i++)
        {
            int c = getch(buf[i]);
            if(next[now][c] == -1)
                next[now][c] = newnode();
            now = next[now][c];
        }
        en[now]++;
    }

    // Computes fail links breadth-first and fills every missing child with
    // the transition of the fail node, so next[][] becomes a total function.
    void build()
    {
        queue<int> Q;
        fail[root] = root;
        for(int i = 0; i < 4; i++)
        {
            if(next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        }
        while(!Q.empty())
        {
            int now = Q.front();
            Q.pop();
            // The fail node was dequeued earlier, so its total is final:
            // en[now] now also counts patterns ending at proper suffixes.
            en[now] += en[fail[now]];
            for(int i = 0; i < 4; i++)
            {
                if(next[now][i] == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    fail[next[now][i]] = next[fail[now]][i];
                    Q.push(next[now][i]);
                }
            }
        }
    }

    // Decodes x into the four letter counts a,b,c,d using bit[].
    // Returns 1 when every count is within num[], 0 otherwise (the outputs
    // are then only partially written).
    int ok(int x,int &a,int &b,int &c,int &d)
    {
        a = x/bit[0];
        if(a > num[0])
            return 0;
        x -= a*bit[0];
        b = x/bit[1];
        if(b > num[1])
            return 0;
        x -= b*bit[1];
        c = x/bit[2];
        if(c > num[2])
            return 0;
        x -= c*bit[2];
        d = x/bit[3];
        if(d > num[3] || x-d*bit[3] != 0)
            return 0;
        return 1;
    }

    // Reads the text from stdin and returns the maximum number of pattern
    // occurrences over all rearrangements of its letters.
    int solve()
    {
        char s[1010];
        // A failed read is treated as an empty text instead of running
        // strlen on an uninitialised buffer.
        if(scanf("%1009s",s) != 1)
            s[0] = '\0';
        int len = strlen(s);
        memset(num,0,sizeof(num));
        for(int i = 0; i < len; i++)
            num[getch(s[i])]++;

        // Mixed-radix encoding: digit k ranges over 0..num[k].
        bit[0] = (num[1]+1)*(num[2]+1)*(num[3]+1);
        bit[1] = (num[2]+1)*(num[3]+1);
        bit[2] = num[3]+1;
        bit[3] = 1;
        // Code of the state in which every letter has been placed.
        int ss = num[0]*bit[0] + num[1]*bit[1] + num[2]*bit[2] + num[3]*bit[3];
        // NOTE(review): ss must stay below 20010; a text of length <= 40
        // gives at most 11^4 codes. Longer input is not guarded against.

        // Only rows [0,L) and columns [0,ss] are read or written below, so
        // resetting that rectangle is enough (the full array is ~40 MB).
        for(int j = 0; j < L; j++)
            fill(dp[j],dp[j]+ss+1,-1);
        dp[root][0] = 0;

        for(int i = 0; i <= ss; i++)
        {
            int used[4];
            if(!ok(i,used[0],used[1],used[2],used[3]))
                continue;
            for(int j = 0; j < L; j++)
            {
                // Unreached states must not propagate.
                if(dp[j][i] < 0)
                    continue;
                for(int k = 0; k < 4; k++)
                {
                    if(used[k] >= num[k])
                        continue;
                    int to = next[j][k];
                    int &cell = dp[to][i+bit[k]];
                    cell = max(cell,dp[j][i]+en[to]);
                }
            }
        }

        int ans = 0;
        for(int i = 0; i < L; i++)
            ans = max(ans,dp[i][ss]);
        return ans;
    }
} ac;

int main()
{
    int C = 1;
    // scanf returns EOF (-1, which is truthy) at end of input, so the result
    // is compared against 1 to stop instead of looping on a stale n.
    while(scanf("%d",&n) == 1 && n)
    {
        ac.init();
        char s[1010];
        for(int i = 0; i < n; i++)
        {
            if(scanf("%1009s",s) != 1)
                return 0;
            ac.Insert(s);
        }
        ac.build();
        printf("Case %d: %d\n",C++,ac.solve());
    }
    return 0;
}

大牛代码:
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <deque>
#include <cmath>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

#define L(i) i<<1
#define R(i) i<<1|1
#define INF  0x3f3f3f3f
#define pi acos(-1.0)
#define eps 1e-9
#define maxn 60010
#define MOD 1000000007

int n,m;
// dp[node][code]: best number of pattern occurrences when the automaton is at
// `node` and the letters placed so far are encoded as `code`; -1 = unreached.
int dp[510][20010];
// bit[k]: weight of letter k in the code; num[k]: occurrences of letter k in
// the text that is being rearranged.
int bit[13],num[13];

// Automaton over {A, C, G, T} for the inserted patterns together with the
// DP that finds the best rearrangement of a text.
struct Trie
{
    int next[510][5],fail[510],en[510];
    int root,L;

    // Resets the automaton to a single root node.
    void init()
    {
        L = 0;
        root = newnode();
    }

    // Creates a node without children or pattern endings; returns its index.
    int newnode()
    {
        for(int i = 0; i < 4; i++)
            next[L][i] = -1;
        en[L++] = 0;
        return L-1;
    }

    // 'A' -> 0, 'C' -> 1, 'G' -> 2, anything else -> 3.
    int getch(char ch)
    {
        if(ch == 'A')
            return 0;
        else if(ch == 'C')
            return 1;
        else if(ch == 'G')
            return 2;
        return 3;
    }

    // Inserts one pattern; duplicates are counted separately through en[].
    void Insert(char buf[])
    {
        int now = root;
        int len = strlen(buf);
        for(int i = 0; i < len; i++)
        {
            int c = getch(buf[i]);
            if(next[now][c] == -1)
                next[now][c] = newnode();
            now = next[now][c];
        }
        en[now]++;
    }

    // Breadth-first construction of fail links; missing children are
    // replaced by the fail node's transition so next[][] is always defined.
    void build()
    {
        queue<int> Q;
        fail[root] = root;
        for(int i = 0; i < 4; i++)
        {
            if(next[root][i] == -1)
                next[root][i] = root;
            else
            {
                fail[next[root][i]] = root;
                Q.push(next[root][i]);
            }
        }
        while(!Q.empty())
        {
            int now = Q.front();
            Q.pop();
            // Accumulate endings of the fail node, which is already final
            // because it was dequeued before `now`.
            en[now] += en[fail[now]];
            for(int i = 0; i < 4; i++)
            {
                if(next[now][i] == -1)
                    next[now][i] = next[fail[now]][i];
                else
                {
                    fail[next[now][i]] = next[fail[now]][i];
                    Q.push(next[now][i]);
                }
            }
        }
    }

    // Reads the text from stdin and returns the maximum number of pattern
    // occurrences achievable by rearranging its letters.
    int solve()
    {
        char text[1010];
        // Fall back to an empty text if nothing could be read, rather than
        // calling strlen on an uninitialised buffer.
        if(scanf("%1009s",text) != 1)
            text[0] = '\0';
        int len = strlen(text);
        memset(num,0,sizeof(num));
        for(int i = 0; i < len; i++)
            num[getch(text[i])]++;

        // Mixed-radix weights: digit k takes the values 0..num[k].
        bit[0] = (num[1]+1)*(num[2]+1)*(num[3]+1);
        bit[1] = (num[2]+1)*(num[3]+1);
        bit[2] = num[3]+1;
        bit[3] = 1;
        // Code of the final state (all letters used); also the largest code.
        int ss = num[0]*bit[0] + num[1]*bit[1] + num[2]*bit[2] + num[3]*bit[3];
        // NOTE(review): ss must stay below 20010; a text of length <= 40
        // gives at most 11^4 codes. Longer input is not guarded against.

        // Reset only the part of dp this test case can touch: rows [0,L),
        // columns [0,ss]. Clearing the whole array costs ~40 MB per case.
        for(int i = 0; i < L; i++)
            fill(dp[i],dp[i]+ss+1,-1);
        dp[root][0] = 0;

        int cnt[4];
        for(cnt[0] = 0; cnt[0] <= num[0]; cnt[0]++)
            for(cnt[1] = 0; cnt[1] <= num[1]; cnt[1]++)
                for(cnt[2] = 0; cnt[2] <= num[2]; cnt[2]++)
                    for(cnt[3] = 0; cnt[3] <= num[3]; cnt[3]++)
                    {
                        int code = cnt[0]*bit[0] + cnt[1]*bit[1] + cnt[2]*bit[2] + cnt[3]*bit[3];
                        for(int i = 0; i < L; i++)
                        {
                            // Unreached states must not propagate.
                            if(dp[i][code] < 0)
                                continue;
                            for(int k = 0; k < 4; k++)
                            {
                                // Letter k is already used up.
                                if(cnt[k] == num[k])
                                    continue;
                                int to = next[i][k];
                                int &cell = dp[to][code+bit[k]];
                                cell = max(cell,dp[i][code]+en[to]);
                            }
                        }
                    }

        int ans = 0;
        for(int i = 0; i < L; i++)
            ans = max(ans,dp[i][ss]);
        return ans;
    }
} ac;

int main()
{
    int C = 1;
    // Compare against 1: scanf yields EOF (-1, truthy) at end of input, which
    // would otherwise keep the loop running on the previous value of n.
    while(scanf("%d",&n) == 1 && n)
    {
        ac.init();
        char s[1010];
        for(int i = 0; i < n; i++)
        {
            if(scanf("%1009s",s) != 1)
                return 0;
            ac.Insert(s);
        }
        ac.build();
        printf("Case %d: %d\n",C++,ac.solve());
    }
    return 0;
}


0 0