ZOJ3494 BCD Code (AC自动机+数位DP)

来源:互联网 发布:java 线程执行顺序 编辑:程序博客网 时间:2024/04/30 15:20

用AC自动机构造出病毒串的trie图,然后设状态dp[i][j]表示长度为i且位于j节点时的符合要求的数的数量,然后按照普通数位DP做即可

递推式数位DP统计[1,x]内符合条件的数只需要考虑三种情况:

1,位数比x短的数

2,位数和x一样,但是某一位比x小的数

3,x本身是否符合条件

0比较特殊,在一般的数位DP中需要特殊处理,一般以特判为主

AC自动机的话,即trie树+fail树,理解了fail边的构造,则写出AC自动机就不是什么困难的事了

#include <stdio.h>#include <string.h>#include <algorithm>#include <queue>using namespace std;struct AC {        struct Node {                Node *ch[2],*fail;                bool ed,vis;        }memo[2010],*root;        int tot;        void New_node(Node *&o) {                o = &memo[tot++];                o->ch[0] = o->ch[1] = NULL;                o->fail = root;                o->ed = o->vis = 0;        }        void init() {                tot = 0;                New_node(root);        }        void ins(char *s) {                Node *p = root;                for ( ; *s; s ++) {                        int c = *s-'0';                        if (p->ch[c]==NULL) New_node(p->ch[c]);                        p = p->ch[c];                }                p->ed = 1;        }        void dfs(Node *p) {                p->vis = true;                if (p==root) return ;                if (p->fail->vis==false) dfs(p->fail);                p->ed |= p->fail->ed;        }        void build() {                queue<Node*> que;                for (int i = 0; i < 2; i ++)                        if (root->ch[i]!=NULL) que.push(root->ch[i]);                while (!que.empty()) {                        Node *f = que.front(); que.pop();                        for (int i = 0; i < 2; i ++) {                                if (f->ch[i]!=NULL) {                                        Node *p = f->fail;                                        for ( ; p->ch[i]==NULL && p!=root; p = p->fail);                                        f->ch[i]->fail = p->ch[i]==NULL ? root : p->ch[i];                                        que.push(f->ch[i]);                                }                        }                }                for (int i = 0; i < tot; i ++) if (!memo[i].vis) dfs(memo+i);        }        int match(int v,char *s) {                Node *p = &memo[v];                if (p->ed) return -1;                for ( ; *s; s ++) {                        int c = *s-'0';                        while (p->ch[c]==NULL && p!=root) p = p->fail;                        if (p->ch[c]!=NULL) p = p->ch[c];                        if (p->ed) return -1;                }                return p-memo;        }}ac;typedef long long lld;const int MOD = (int)1e9+9;char s[222],snum[10][6] = {"0000","0001","0010","0011","0100","0101","0110","0111","1000","1001"};lld dp[222][2010];int g[2010][10];void add(lld &a,lld b) { a += b; if (a>=MOD) a -= MOD; }void init() {        for (int i = 0; i < 222; i ++)                for (int j = 0; j < ac.tot; j ++)                        dp[i][j] = 0;        for (int i = 0; i < ac.tot; i ++)                for (int j = 0; j < 10; j ++)                        g[i][j] = ac.match(i,snum[j]);        for (int i = 0; i < ac.tot; i ++)                dp[1][i] = 1;        for (int i = 1; i < 222-1; i ++)                for (int j = 0; j < ac.tot; j ++)                        for (int k = 0; k < 10; k ++) if (g[j][k]!=-1)                                add(dp[i+1][j],dp[i][g[j][k]]);}lld calc(bool mark) {        lld ret = 0;        int p = 0,len = strlen(s+1);        for (int i = 1; i < len-i+1; i ++) swap(s[i],s[len-i+1]);        for (int i = len; i >= 1; i --) {                if (p==-1) break;                for (int j = s[i]-'0'-1; j > 0; j --) {                        int v = g[p][j];                        if (v==-1) continue;                        add(ret,dp[i][v]);                }                if (s[i]-'0'>0 && g[p][0]!=-1 && i!=len)                         add(ret,dp[i][g[p][0]]);                p = g[p][s[i]-'0'];        }        add(ret,(p!=-1)*mark);        for (int i = len-1; i >= 1; i --)                 for (int j = 1; j < 10; j ++) if (g[0][j]!=-1)                        add(ret,dp[i][g[0][j]]);        return ret;}int main() {        int cas;        scanf("%d",&cas);        while (cas--) {                int n;                scanf("%d",&n);                ac.init();                for (int i = 0; i < n; i ++) {                        scanf("%s",s);                        ac.ins(s);                }                ac.build();                init();                lld ans = 0;                scanf("%s",s+1);                ans -= calc(0);                scanf("%s",s+1);                (((ans += calc(1)) %= MOD) += MOD) %= MOD;                printf("%lld\n",ans);        }        return 0;}


0 0
原创粉丝点击