AC自动机+DP小结 (一)

来源:互联网 发布:淘宝店铺产品全部下架 编辑:程序博客网 时间:2024/06/05 04:37

    好久没有更新博客了,最近真是懒到家了,南京赛前重点复习了下AC自动机+DP方面的题,写下来总结一下。

HDU 2457  DNA repair

http://acm.hdu.edu.cn/showproblem.php?pid=2457

题目大意:给你一个字符主串和很多病毒串,要求更改最少的字符使得没有一个病毒串是主串的子串。

 

思路:本人的第一道AC自动机DP,很初等的题目了,首先根据病毒串建立AC自动机,病毒串的尾节点设置一个标记,表示这个节点“危险”不能走,问题转化到求用主串在建立后的AC自动机上走len步(len为主串的长度)且不能经过“危险”节点需要更改的最小字符个数,设dp[i][j]表示长度为i的字符串到达j节点时所需改变的最小字符个数,然后根据AC自动机进行转移。建立失败指针的时候要注意的是如果一个节点的失败指针指向的节点是“危险的”,那么我们也要把该节点设为“危险”的,然后就直接状态转移即可。

 

#include <iostream>#include <string.h>#include <stdio.h>#include <algorithm>#include <queue>#define maxn 1010using namespace std;struct node{    int next[4];    int fail;    int flag;    void init()    {        memset(next,0,sizeof(next));        fail=0;        flag=0;    }}trie[maxn];int n,tot,inf;int dp[maxn][maxn];void init(){    tot=0;    trie[0].init();    memset(dp,1,sizeof(dp));    inf=dp[0][0];}int getnum(char x){    switch(x)    {        case 'A':return 0;        case 'C':return 1;        case 'G':return 2;        case 'T':return 3;    }    return 0;}void insert(char *str){    int p=0,index;    for(;*str!='\0';str++)    {        index=getnum(*str);        if(trie[p].next[index]==0)        {            trie[++tot].init();            trie[p].next[index]=tot;        }        p=trie[p].next[index];    }    trie[p].flag=1;}void build_fail(){    queue<int>Q;    int p,son,cur,i;    Q.push(0);    while(!Q.empty())    {        p=Q.front();        Q.pop();        for(i=0;i<4;i++)        {            if(trie[p].next[i]!=0)            {                son=trie[p].next[i];                cur=trie[p].fail;                if(p==0)                    trie[son].fail=0;                else                {                    while(cur&&trie[cur].next[i]==0)                        cur=trie[cur].fail;                    trie[son].fail=trie[cur].next[i];                    if(trie[trie[son].fail].flag)                    trie[son].flag=1;                }                Q.push(son);            }            else                {                  //  if(p==0)                   // trie[p].next[i]=0;                  //  else                    trie[p].next[i]=trie[trie[p].fail].next[i];                }        }    }}char str[maxn],tmp[25];void solve(){    dp[0][0]=0;    int i,j,k,len=strlen(str+1);    for(i=1;i<=len;i++)    {        for(j=0;j<=tot;j++)        {            if(dp[i-1][j]!=inf)            {                for(k=0;k<4;k++)                {                    int son=trie[j].next[k];                    if(trie[son].flag==0)                    {                        dp[i][son]=min(dp[i][son],dp[i-1][j]+(getnum(str[i])!=k));                    }                }            }        }    }    int ans=inf;    for(i=0;i<=tot;i++)    {        if(!trie[i].flag)        ans=min(ans,dp[len][i]);    }    if(ans==inf)    printf("-1\n");    else    printf("%d\n",ans);}int main(){    // freopen("dd.txt","r",stdin);     int T=0;     while(scanf("%d",&n)&&n)     {         printf("Case %d: ",++T);         init();         for(int i=1;i<=n;i++)         {             scanf("%s",tmp);             insert(tmp);         }         build_fail();         scanf("%s",str+1);         solve();     }    return 0;}


 

 hdu 2825

http://acm.hdu.edu.cn/showproblem.php?pid=2825

题目大意:给你一些密码片段字符串,让你求长度为n,且至少包含k个不同密码片段串的字符串的数量。

 

思路:因为密码串的数量不多,所以这里可以用状压解决,设dp[i][j][flag]表示长度为i且在节点j状态为flag的字符串有多少个,flag表示了这个状态包含密码串的状态,flag的第x位为1表示包含了第x个密码串,否则表示没有包含,我们构建密码串的AC自动机,然后在每个节点也维护一个flag表示该节点所表示字符串包含密码穿的状态,并且在构建失败指针的时候,当前节点的flag要与其失败指针所指结点的flag做一个并集.

即  trie[son].flag=trie[son].flag|trie[trie[son].fail].flag;之后在AC自动机上DP,转移,最后求至少包含k个1的状态flag,长度为len的串有多少个即为答案。

代码如下

 

#include <iostream>#include <string.h>#include <stdio.h>#include <algorithm>#include <queue>#define ll long long#define maxn 1010#define knum 26#define mod 20090717using namespace std;struct node{    int next[knum];    int fail;    int flag;    void init()    {        memset(next,0,sizeof(next));        fail=0;        flag=0;    }}trie[maxn];int n,m,k,tot;ll dp[30][110][1100];void init(){    tot=0;    trie[0].init();}void insert(char *str,int val){    int p=0,index;    for(;*str!='\0';str++)    {        index=*str-'a';        if(trie[p].next[index]==0)        {            trie[++tot].init();            trie[p].next[index]=tot;        }        p=trie[p].next[index];    }    trie[p].flag|=(1<<val);}void build_fail(){    queue<int>Q;    int p,son,cur,i;    Q.push(0);    while(!Q.empty())    {        p=Q.front();        Q.pop();        for(i=0;i<knum;i++)        {            if(trie[p].next[i]!=0)            {                son=trie[p].next[i];                cur=trie[p].fail;                if(p==0)                    trie[son].fail=0;                else                {                    while(cur&&trie[cur].next[i]==0)                    cur=trie[cur].fail;                    trie[son].fail=trie[cur].next[i];                }                trie[son].flag=trie[son].flag|trie[trie[son].fail].flag;                Q.push(son);            }            else            {                trie[p].next[i]=trie[trie[p].fail].next[i];            }        }    }}int check(int x){    int sum=0;    while(x)    {        sum++;        x-=x&(-x);    }    return sum;}void solve(){    int i,j,kk,limit=(1<<m);    for(i=0;i<=n;i++)    {        for(j=0;j<=tot;j++)        {            for(kk=0;kk<limit;kk++)            dp[i][j][kk]=0;        }    }    dp[0][0][0]=1;    for(i=0;i<n;i++)    {        for(j=0;j<=tot;j++)        {            for(kk=0;kk<limit;kk++)            {                if(dp[i][j][kk])                {                    //cout<<dp[i][j][k]<<" ";                    for(int l=0;l<knum;l++)                    {                        int son=trie[j].next[l];                        int tt=kk|trie[son].flag;                        dp[i+1][son][tt]=(dp[i+1][son][tt]+dp[i][j][kk])%mod;                    }                }            }        }    }    ll ans=0;    for(i=0;i<=tot;i++)    {        for(j=0;j<limit;j++)        {            if(check(j)>=k)            {                {                    ans=(ans+dp[n][i][j])%mod;                }            }        }    }    printf("%I64d\n",ans);}int main(){   // freopen("dd.txt","r",stdin);    int i;    while(scanf("%d%d%d",&n,&m,&k)&&(n||m||k))    {        init();        char tmp[15];        for(i=0;i<m;i++)        {            scanf("%s",tmp);            insert(tmp,i);        }        build_fail();        solve();    }    return 0;}


 

hdu 4057 Rescue the Rabbit

 

http://acm.hdu.edu.cn/showproblem.php?pid=4057

题目大意:给你一些基因片段,每个片段有一个权值,现在要你找到一个长度为l的基因,使得它的权值最大,基因的权值计算方法是,如果有一个基因片段是该基因的子串,则加上该基因片段的权值,但是每种基因片段的权值只计算一次。

 

思路:还是很容易看出是AC自动机加dp的,还是用状态压缩来表示字符串包含基因片段的状态,用dp[i][j][flag]来表示长度为i,在j节点状态为flag是否可以达(因为每个基因片段只计算一次,所以对于给定的flag其权值一定,所以只要判断flag是否可达即可)。因为i*j*flag太大内存不够,所以这里要用滚动数组来优化。类似于上一题,在每个节点维护一个状态flag,然后建立失败指针的时候与上一题一样的处理。然后在AC自动机上转移判断各个状态是否可达,最后对于每个可达的状态计算其相应权值取最大值即可。如果最大权值小于0,不要忘了输出 No Rabbit after 2012!

 

#include <iostream>#include <stdio.h>#include <string.h>#include <algorithm>#include <queue>#define maxn 1010#define inf 2100000000using namespace std;struct node{    int next[4];    int fail;    int flag;    void init()    {        memset(next,0,sizeof(next));        fail=0;        flag=0;    }}a[maxn];int n,len,tot;int weight[15];char keyword[110];int dp[2][maxn][1<<10];void init(){    tot=0;    a[0].init();    memset(dp,0,sizeof(dp));}int getnum(char x){    switch(x)    {        case 'A':return 0;        case 'C':return 1;        case 'G':return 2;        case 'T':return 3;    }    return 0;}void insert(char *str,int val){    int p=0,index;    for(;*str!='\0';str++)    {        index=getnum(*str);        if(a[p].next[index]==0)        {            a[++tot].init();            a[p].next[index]=tot;        }        p=a[p].next[index];    }    a[p].flag=a[p].flag|(1<<val);}void build_fail(){    queue<int>Q;    int p,son,cur,i;    Q.push(0);    while(!Q.empty())    {        p=Q.front();        Q.pop();        for(i=0;i<4;i++)        {            if(a[p].next[i]!=0)            {                son=a[p].next[i];                cur=a[p].fail;                if(p==0)                    a[son].fail=0;                else                {                    while(cur&&a[cur].next[i]==0)                        cur=a[cur].fail;                    a[son].fail=a[cur].next[i];                }                a[son].flag=a[son].flag|a[a[son].fail].flag;                Q.push(son);            }            else                a[p].next[i]=a[a[p].fail].next[i];        }    }}int getweight(int x){    int i,sum=0;    for(i=0;i<n;i++)    {        if(x&(1<<i))        sum+=weight[i];    }    return sum;}void solve(){    int i,j,k,l,son,ans,tmp;    dp[0][0][0]=1;    for(i=1;i<=len;i++)    {        memset(dp[i&1],0,sizeof(dp[i&1]));        for(j=0;j<=tot;j++)        {            for(l=0;l<(1<<n);l++)            {                if(dp[(i+1)&1][j][l]!=1)                continue;                for(k=0;k<4;k++)                {                    son=a[j].next[k];                    dp[i&1][son][l|a[son].flag]=1;                }            }        }    }    ans=-inf;    for(j=0;j<(1<<n);j++)    {        for(i=0;i<=tot;i++)        {            if(dp[len&1][i][j]==1)            {                tmp=getweight(j);                ans=max(ans,tmp);            }        }    }    if(ans<0)        printf("No Rabbit after 2012!\n");    else        printf("%d\n",ans);}int main(){    //freopen("dd.txt","r",stdin);    int i;    while(scanf("%d%d",&n,&len)!=EOF)    {        init();        for(i=0;i<n;i++)        {            scanf("%s%d",keyword,&weight[i]);            insert(keyword,i);        }        build_fail();        solve();    }    return 0;}

 

hdu 4758  Walk Through Squares

 

http://acm.hdu.edu.cn/showproblem.php?pid=4758

 

题目大意:给你两个串A和B,它们都只由R和L组成,问你含有n个R和m个L且既包含A也包含B的字符串有多少个。

 

思路:很明显的AC自动机+DP啊。。。为什么多校的时候就没想到呢。。。

状态很好设计,就是dp[i][j][k][flag]表示长度为i,在j节点且其中有k个R,状态为flag的字符串数量。和上两道题一样构造AC自动机后,在AC自动机上走,转移的时候分别讨论当前位是L或R往下走即可,没啥trick,然后内存开不够要用滚动数组处理,将长度那一维去掉就行。

代码仅供参考。。。

#include <iostream>#include <string.h>#include <algorithm>#include <queue>#include <stdio.h>#define maxn 2010#define knum 2#define inf 2100000000#define mod 1000000007#define ll long longusing namespace std;struct node{    int next[knum];    int fail;    int flag;    void init()    {        memset(next,0,sizeof(next));        fail=0;        flag=0;    }}trie[maxn];int n,tot,m;int getnum(char x){    if(x=='R')    return 0;    return 1;}void init(){    tot=0;    trie[0].init();}void insert(char *str,int val){    int p=0,index;    for(;*str!='\0';str++)    {        index=getnum(*str);        if(trie[p].next[index]==0)        {            trie[++tot].init();            trie[p].next[index]=tot;        }        p=trie[p].next[index];    }    trie[p].flag|=(1<<val);}void build_fail(){    queue<int>Q;    int p,son,cur,i;    Q.push(0);    while(!Q.empty())    {        p=Q.front();        Q.pop();        for(i=0;i<knum;i++)        {            if(trie[p].next[i]!=0)            {                son=trie[p].next[i];                cur=trie[p].fail;                if(p==0)                    trie[son].fail=0;                else                {                    while(cur&&trie[cur].next[i]==0)                        cur=trie[cur].fail;                    trie[son].fail=trie[cur].next[i];                }                trie[son].flag|=trie[trie[son].fail].flag;                Q.push(son);            }            else            {                trie[p].next[i]=trie[trie[p].fail].next[i];            }        }    }}ll dp[2][210][110][4];void solve(){    int i,j,k,l;    int len=n+m;    for(i=0;i<2;i++)     for(j=0;j<=tot;j++)      for(k=0;k<=m;k++)       for(l=0;l<4;l++)        dp[i][j][k][l]=0;        dp[0][0][0][0]=1;        for(i=0;i<len;i++)        {            int t1=i%2,t2=(i+1)%2;            for(j=0;j<=tot;j++)             for(k=0;k<=m;k++)              for(l=0;l<4;l++)               dp[t2][j][k][l]=0;            for(j=0;j<=tot;j++)            {                int mi=min(m,i);                for(k=0;k<=mi;k++)                {                    for(l=0;l<4;l++)                    {                        if(dp[t1][j][k][l])                        {                            for(int tt=0;tt<2;tt++)                            {                                int son=trie[j].next[tt];                                int kk=l|trie[son].flag;                                if(tt)                                {                                    dp[t2][son][k][kk]=(dp[t2][son][k][kk]+dp[t1][j][k][l])%mod;                                }                                else                                {                                    dp[t2][son][k+1][kk]=(dp[t2][son][k+1][kk]+dp[t1][j][k][l])%mod;                                }                            }                        }                    }                }            }        }        int tmp=i%2;        ll ans=0;        for(i=0;i<=tot;i++)        {            ans+=dp[tmp][i][m][3];        }        ans%=mod;        cout<<ans<<endl;}int main(){  //  freopen("dd.txt","r",stdin);    int ncase;    scanf("%d",&ncase);    while(ncase--)    {        init();        char tmp[110];        scanf("%d%d",&m,&n);        scanf("%s",tmp);        insert(tmp,0);        scanf("%s",tmp);        insert(tmp,1);        build_fail();        solve();    }    return 0;}