AC自动机算法简介(洛谷P3808)

来源:互联网 发布:2017php应聘要求 编辑:程序博客网 时间:2024/06/05 04:27

算法用途

一个常见的例子就是给出n个单词,再给出一段包含m个字符的文章,让你找出有多少个单词在文章里出现过。如果直接跑n遍KMP的话时间复杂度会较高(O(nm)),这时AC自动机算法便应运而生。

算法先决条件

要学会AC自动机,首先需要掌握Trie树和KMP算法。

KMP算法入门 Trie树入门

算法思想

把所有的模板串存到Trie中,预处理出每个节点的失配指针(和KMP中的next数组很像),在查询时实现快速跳转。(其实就是在Trie上KMP)

算法要点

以洛谷P3808为例,完整代码请见模板1

构造Trie树

这个就是Trie树的插入操作,根本没有任何区别啊!
根据题目所求可以适当修改所需记录的量。

void nsrt(char s[]){    int node=0,len=strlen(s);//一定要先算好s的长度,不然会T掉的(别问我怎么知道的)    for (int i=0;i<len;i++){        int x=s[i]-'a';        if (!a[node][x])//如果没有这个儿子            a[node][x]=++k;//给这个儿子编号        node=a[node][x];//继续循环    }    f[node]++;//该字符串以这个字符为结尾}

失配指针

在处理失配指针时,要使用队列扩展(想想为什么)。其他具体见代码注释

void asknxt(){    queue<int>que;//队列    //先把root的子节点加入队列中    for (int i=0;i<26;i++)        if (a[0][i]){            que.push(a[0][i]);            nxt[a[0][i]]=0;        }    while (!que.empty()){        int x=que.front();        que.pop();        int node=nxt[x];//node为x的失配指针所指向的节点        //枚举x的所有子节点        for (int i=0;i<26;i++)            if (a[x][i]){//如果存在该节点                que.push(a[x][i]);//加入队列                nxt[a[x][i]]=a[node][i];//把该节点的失配指针指向其父节点的失配指针指向的节点的相同子节点            }            else//如果不存在该节点                a[x][i]=a[node][i];//直接把x的失配指针指向的节点的这个儿子拿过来当自己的儿子(很重要的一步,查询 时会很方便)         }}

查询

把文本串扫一遍,与Trie树进行匹配,记录有多少字符串出现,具体见代码注释。

int srch(char s[]){    int sum=0,node=0,len=strlen(s);//这里也要算好    for (int i=0;i<len;i++){        node=a[node][s[i]-'a'];//更新node        for (int j=node;j&&f[j]!=-1;j=nxt[j]){//一直到失配指针指向root且以该节点为结尾的字符串并未被查找过,如果当前节点已经找过,就说明之后的肯定也被找过            sum+=f[j];            f[j]=-1;//避免重复计算        }    }    return sum;}

模板1

洛谷P3808

#include<cstdio>#include<cstring>#include<algorithm>#include<queue>#define MAXN 1000000using namespace std;char s[MAXN+5];int a[MAXN+5][26],nxt[MAXN+5];int f[MAXN+5];int n,k=0;void nsrt(char s[]){    int node=0,len=strlen(s);    for (int i=0;i<len;i++){        int x=s[i]-'a';        if (!a[node][x])            a[node][x]=++k;        node=a[node][x];    }    f[node]++;}void asknxt(){    queue<int>que;    for (int i=0;i<26;i++)        if (a[0][i]){            que.push(a[0][i]);            nxt[a[0][i]]=0;        }    while (!que.empty()){        int x=que.front();        que.pop();        int node=nxt[x];        for (int i=0;i<26;i++)            if (a[x][i]){                que.push(a[x][i]);                nxt[a[x][i]]=a[node][i];            }            else                a[x][i]=a[node][i];            }}int srch(char s[]){    int sum=0,node=0,len=strlen(s);    for (int i=0;i<len;i++){        node=a[node][s[i]-'a'];        for (int j=node;j&&f[j]!=-1;j=nxt[j]){            sum+=f[j];            f[j]=-1;        }    }    return sum;}int main(){    scanf("%d",&n);    for (int i=1;i<=n;i++){        scanf("%s",s);        nsrt(s);    }    asknxt();    scanf("%s",s);    int ans=srch(s);    printf("%d\n",ans);    return 0;}

模板2

洛谷P3796

#include<cstdio>#include<cstring>#include<algorithm>#include<queue>#define MAXN 7000#define MAXM 1500#define MAXL 1000000using namespace std;int px[MAXL+5];char s[MAXL+5],ms[MAXM+5][MAXN+5];int a[MAXL+5][26],nxt[MAXL+5];int f[MAXL+5];int n,k=0;int num[MAXM+5];void nsrt(int num,char s[]){    int node=0,len=strlen(s);    for (int i=0;i<len;i++){        int x=s[i]-'a';        if (!a[node][x])            a[node][x]=++k;        node=a[node][x];    }    f[node]++;    px[node]=num;}void asknxt(){    queue<int>que;    for (int i=0;i<26;i++)        if (a[0][i]){            que.push(a[0][i]);            nxt[a[0][i]]=0;        }    while (!que.empty()){        int x=que.front();        que.pop();        int node=nxt[x];        for (int i=0;i<26;i++)            if (a[x][i]){                que.push(a[x][i]);                nxt[a[x][i]]=a[node][i];            }            else                a[x][i]=a[node][i];            }}void srch(char s[]){    int node=0,len=strlen(s);    for (int i=0;i<len;i++){        node=a[node][s[i]-'a'];        for (int j=node;j;j=nxt[j])            if (f[j])                num[px[j]]++;    }}int main(){    while (~scanf("%d",&n)&&n){        memset(nxt,0,sizeof(nxt));        memset(a,0,sizeof(a));        memset(ms,' ',sizeof(ms));        memset(f,0,sizeof(f));        memset(num,0,sizeof(num));        memset(px,0,sizeof(px));        for (int i=1;i<=n;i++){            scanf("%s",ms[i]);            nsrt(i,ms[i]);        }        scanf("%s",s);        asknxt();        srch(s);        int ma=0,nm[MAXM+5]={0},node=0;        for (int i=1;i<=n;i++)            ma=max(ma,num[i]);        for (int i=1;i<=n;i++)            if (num[i]==ma)                nm[++node]=i;        printf("%d\n",ma);        for (int i=1;i<=node;i++)            printf("%s\n",ms[nm[i]]);    }    return 0;}

还没懂的小伙伴可以多看几遍,或者参考这位大佬的blog

原创粉丝点击