POJ_2778 DNA Sequence AC自动机+dp

来源:互联网 发布:星际淘宝主 编辑:程序博客网 时间:2024/05/20 13:41

http://poj.org/problem?id=2778

题意:

给你M个最多只有10个字符的字符串,问长度为N的不含这些字符串的字符串的

个数有多少个。N<=2000000000,M<=10

思路:

字符串匹配的问题,用AC自动机是最好的选择,因为N的范围很大, 直觉告诉

我们要用二分矩阵乘法。接着就是列状态转移方程,用F(i , j )表示 i 个字符

的字符串,最后一个串的状态为j时候的种数,状态转移方程就可以表示为:

 F(i ,j )= sum{ F( i-1 , j' ) },其中j'到j一个合法的转移。这样我们就可以转

换为矩阵相乘。具体的注释见代码吧。

代码:

#include<stdio.h>#include<string.h>#include<queue>const __int64 Mod = 100000 ;int N ,M, Root ,cnt ;char ch[20] ;struct Node{    int fail ;                      //AC自动机的失败指针    bool danger ;                   //标记以该结点为结尾的字符串是否合法    int next[4] ;                   //next数组    void init(){                    //初始化各个变量        fail = -1 ;        danger =  0 ;        memset(next , -1 ,sizeof(next));    }}p[110] ;std::queue<int> que ;__int64 g[110][110] ;               //g[i][j]表示由状态i转移到状态j的种数inline int get(char c){    switch(c){        case 'A' :  return 0 ;        case 'C' :  return 1 ;        case 'T' :  return 2 ;        case 'G' :  return 3 ;        default  :  return -1 ;    }}void build_trie(char *ch){    int len = strlen(ch) ;    int idx ,loc = Root;    for(int i=0;i<len;i++){        idx = get(ch[i]) ;        if( p[loc].next[idx] == -1 ){            ++cnt ; p[cnt].init() ;            p[loc].next[idx] = cnt ;        }        loc = p[loc].next[idx] ;        if( p[loc].danger ) return ;        //如果当前字符串是前面某个字符串的前缀,则该字符串的接下去的状态都不合法    }    p[loc].danger = 1 ;             //最后的状态标记为不合法}void build_ac_automation(){    int loc = Root ;    p[loc].fail = -1 ;    while(!que.empty()) que.pop() ;    que.push(loc) ;    while(!que.empty()){        int u = que.front() ; que.pop() ;        for(int i=0;i<4;i++){            if( p[u].next[i] == -1 )    continue ;            int v = p[u].next[i] ;            if( u==Root ){                p[v].fail = Root ;            }            else{                int temp = p[u].fail ;                while( temp!= -1 ){                    if( p[temp].next[i] != -1){                        p[v].fail = p[temp].next[i] ;                        if( p[ p[temp].next[i] ].danger ){      //和普通的建失败指针唯一不同的地方,请注意。                            p[v].danger = 1 ;                        }                        break ;                    }                    temp = p[temp].fail ;                }                if( temp == -1 )                    p[v].fail = Root ;            }            que.push(v) ;        }    }}void cal(){    memset(g, 0,sizeof(g));    for(int i=0;i<=cnt;i++){        if( p[i].danger )   continue;        for(int j=0;j<4;j++){            if( p[i].next[j]!=-1 && p[ p[i].next[j] ].danger==0){                g[i][ p[i].next[j] ] ++ ;            }            else if( p[i].next[j] == -1 ){                int temp = p[i].fail ;                while( temp != -1 ){                    if( p[temp].next[j] != -1 ){                        break ;                    }                    temp = p[temp].fail ;                }                if(temp == -1)                    g[i][Root] ++ ;                else{                    if( p[ p[temp].next[j] ].danger == 0 )                        g[i][ p[temp].next[j] ] ++ ;                }            }        }    }}__int64 res[110][110] ;void calc(__int64 a[110][110],__int64 b[110][110]){    __int64 c[110][110] ;    for(int i=0;i<=cnt;i++){        for(int j=0;j<=cnt;j++){            c[i][j] = 0 ;            for(int k=0;k<=cnt;k++){                c[i][j] = ( c[i][j] + a[i][k] * b[k][j] % Mod ) % Mod ;            }        }    }    for(int i=0;i<=cnt;i++)        for(int j=0;j<=cnt;j++)            a[i][j] = c[i][j] ;}void solve(int k){    while(k){        if( k & 1 ){            calc(res, g);        }        calc(g,g) ;        k >>= 1 ;    }}int main(){    Root = 0 ;    while(scanf("%d%d",&M,&N) == 2){        p[Root].init() ;    cnt = 0 ;        for(int i=1;i<=M;i++){            scanf("%s",ch);            build_trie(ch) ;             //构建字典树        }        build_ac_automation() ;        cal() ;        memset(res,0,sizeof(res));        for(int i=0;i<110;i++)  res[i][i]= 1 ;        solve(N) ;        __int64 ans = 0 ;        for(int i=0;i<=cnt;i++){            ans += res[0][i] ;            if( ans > Mod ) ans %= Mod ;        }        printf("%I64d\n",ans);    }    return 0 ;}

原创粉丝点击