hdu - 2222 - Keywords Search(AC自动机)

来源:互联网 发布:淘宝评价有什么好处 编辑:程序博客网 时间:2024/05/01 03:55

题意:给出N个由小写字母组成的关键词,再给一个描述,问有多少个关键词在这个描述中出现(N <= 10000, 描述的长度 <= 1000000, 关键词的长度 <= 50)。

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2222

——>>AC自动机第二题。

题目没说明描述是否都是小写字母,实践中以假设描述为小写字母可以AC。

注意:关键词有重复!(因此而WA无数)

仿RJ《训练指南》:

可能有多个last指向同一个单词结点,但这个单词结点只能被计数1次。

#include <cstdio>#include <cstring>#include <queue>using namespace std;const int maxn = 1000000 + 10;const int maxw = 50 + 10;const int maxnode = 500000 + 10;char qus[maxn], kw[maxw];int ch[maxnode][26], val[maxnode], f[maxnode], last[maxnode], num[maxnode];struct AC{    int sz;    int cnt;    AC(){        sz = 1;        cnt = 0;        memset(ch[0], 0, sizeof(ch[0]));    }    int idx(char c){        return c - 'a';    }    void insert(char *s){        int len = strlen(s), i, u = 0;        for(i = 0; i < len; i++){            int c = idx(s[i]);            if(!ch[u][c]){                memset(ch[sz], 0, sizeof(ch[sz]));                val[sz] = num[sz] = 0;                ch[u][c] = sz++;            }            u = ch[u][c];        }        num[u]++;        val[u] = 1;    }    void getFail(){        queue<int> qu;        f[0] = last[0] = 0;        for(int c = 0; c < 26; c++){            int u = ch[0][c];            if(u){                f[u] = last[u] = 0;                qu.push(u);            }        }        while(!qu.empty()){            int r = qu.front(); qu.pop();            for(int c = 0; c < 26; c++){                int u = ch[r][c];                if(!u) continue;                qu.push(u);                int v = f[r];                while(v && !ch[v][c]) v = f[v];                f[u] = ch[v][c];                last[u] = val[f[u]] ? f[u] : last[f[u]];            }        }    }    void dfs(int u){        if(u){            cnt += num[u];            num[u] = 0;            dfs(last[u]);        }    }    void find(char *T){        getFail();        int len = strlen(T), i, j = 0;        for(i = 0; i < len; i++){            int c = idx(T[i]);            while(j && !ch[j][c]) j = f[j];            j = ch[j][c];            if(val[j]) dfs(j);            else if(last[j]) dfs(last[j]);        }    }    void solve(){        printf("%d\n", cnt);    }};int main(){    int T, N;    scanf("%d", &T);    while(T--){        AC ac;        scanf("%d", &N);        while(N--){            scanf("%s", kw);            ac.insert(kw);        }        scanf("%s", qus);        ac.find(qus);        ac.solve();    }    return 0;}

改用指针:

(内存很紧,不要随便创建新结点。)

#include <cstdio>#include <cstring>#include <queue>using namespace std;const int maxn = 1000000 + 10;const int maxw = 50 + 10;char qus[maxn], kw[maxw];struct node{    int cnt;    node *f;    node *next[26];    node(){        cnt = 0;        memset(next, 0, sizeof(next));        f = NULL;    }};struct AC{    node *root;    int ret;    AC(){        root = new node;        ret = 0;    }    int idx(char c){        return c - 'a';    }    void insert(char *s){        node *p = root;        int len = strlen(s), i;        for(i = 0; i < len; i++){            int c = idx(s[i]);            if(!p->next[c]) p->next[c] = new node;            p = p->next[c];        }        p->cnt++;    }    void getFail(){        queue<node*> qu;        root->f = NULL;     //不要用自己,否则下面死循环        qu.push(root);        while(!qu.empty()){            node *r = qu.front(); qu.pop();            for(int c = 0; c < 26; c++) if(r->next[c]){                node *u = r->next[c];                qu.push(u);                node *v = r->f;                while(v && !v->next[c]) v = v->f;                if(v) u->f = v->next[c];                else u->f = root;            }        }    }    void find(char *T){        getFail();        int len = strlen(T), i;        node *j = root;        for(i = 0; i < len; i++){            int c = idx(T[i]);            while(j && !j->next[c]) j = j->f;            if(j) j = j->next[c];            else j = root;            node *p = j;            while(p != root && p->cnt != -1){                ret += p->cnt;                p->cnt = -1;                p = p->f;            }        }    }    void solve(){        printf("%d\n", ret);    }};int main(){    int T, N;    scanf("%d", &T);    while(T--){        AC ac;        scanf("%d", &N);        while(N--){            scanf("%s", kw);            ac.insert(kw);        }        scanf("%s", qus);        ac.find(qus);        ac.solve();    }    return 0;}


原创粉丝点击