[AC自动机+数位DP] ZOJ3494 BCD Code

来源:互联网 发布:磁贴数据库已损坏 编辑:程序博客网 时间:2024/04/30 15:05

ZOJ3494

题意:先理解BCD编码,不是普通的二进制,设一个n位的整数x= An*10^n-1 + An-1*10^n-2 +...+ A1*10^0,那么x的BCD编码为An的二进制拼接上An-1的二进制拼接上An-2...直到A1。 所以 127的BCD编码为 0001 0010 0111。一个n位的整数,其BCD编码有n*4位。

然后题目有一个只包含01串的集合S,要求[A, B]范围内所有的BCD编码中不包含S中任意一个串的数的个数。

例如S={ "00"}, A=1, B=10, 那么答案只有"0101", "0110", "0111", 也就是十进制的5, 6, 7,答案是3个。

1 < A <= B < 1e200

解法:先对S建立AC自动机,由于S是01串集合,每次转移只能加1位二进制位,但是[A,B]都是十进制的,加一个十进制位相当于转移了4次(4个2进制位),所以再预处理出所有转移4次的结果,这样就知道了所有DP状态的转移,然后就可以愉快地数位DP啦,AC自动机在KB菊苣博客学的写法,这题里面自动机的含义就很明显了。


#include<bits/stdc++.h>#define ll long long intusing namespace std;const int mod = 1e9+9;char s[205], s1[205];ll dp[205][2005][2];int bcd[2005][10];int num[205];struct Trie{    int son[2005][2], fail[2005], end[2005];    int root, alloc;    int newnode(){        son[alloc][0] = son[alloc][1] = -1;        end[alloc] = 0;        return alloc++;    }    void init(){        alloc = 0;        root = newnode();    }    void insert(char *s){        int now = root;        for(int i = 0; s[i]; ++i){            if(son[now][s[i]-'0'] == -1) son[now][s[i]-'0'] = newnode();            now = son[now][s[i]-'0'];        }        end[now] = 1;    }    void build(){        queue<int>q;        fail[root] = root;        for(int i = 0; i < 2; ++i){            if(son[root][i] == -1) son[root][i] = root;            else{                fail[son[root][i]] = root;                q.push(son[root][i]);            }        }        while(!q.empty()){            int now = q.front(); q.pop();            if(end[fail[now]]) end[now] = 1;            for(int i = 0; i < 2; ++i){                if(son[now][i] == -1) son[now][i] = son[fail[now]][i];                else{                    fail[son[now][i]] = son[fail[now]][i];                    q.push(son[now][i]);                }            }        }    }    int add(int now, int k){        if(end[now]) return -1; //当前状态不合法        for(int i = 3; i >= 0; --i){ // 4次转移            if(end[son[now][(k>>i)&1]]) return -1; // end为true表示包含S串, -1表示转移后状态不合法            now = son[now][(k>>i)&1];        }        return now;    }    void getbcd(){ // 预处理所有状态转移4次后的状态        memset(bcd, -1, sizeof(bcd));        for(int i = 0; i < alloc; ++i){            for(int j = 0; j < 10; ++j){                bcd[i][j] = add(i,j); // bcd[i][j]表示状态i加上了十进制j之后的状态            }        }    }    ll dfs(int pos, int st, int zr, int f){ // zr控制前导0 st为当前状态        if(pos < 0) return 1;        if(!f && dp[pos][st][zr] != -1) return dp[pos][st][zr];        ll ans = 0;        if(zr) ans += dfs(pos-1, st, 1, f && num[pos] == 0); //由于前导0相当于没有,不能去匹配,单独考虑加0的情况        else if(bcd[st][0] != -1) ans += dfs(pos-1, bcd[st][0], 0, f && num[pos] == 0);        if(ans >= mod) ans -= mod;        int end = f? num[pos] : 9;        for(int i = 1; i <= end; ++i){            if(bcd[st][i] != -1){                ans += dfs(pos-1, bcd[st][i], 0, f && num[pos] == i); // 已经预处理了所有状态的转移情况 bcd[st][i] 就表示st状态加上十进制位i之后的状态                if(ans >= mod) ans -= mod;            }        }        if(!f) dp[pos][st][zr] = ans;        return ans;    }    void rv(int *n, char *s){ //翻转数位,从高位到低位DP,顺便转成int类型        int len = strlen(s), cnt = 0;        for(int i = len-1; i >= 0; --i){            n[cnt++] = s[i] -'0';        }    }    ll solve(char *s){        int len = strlen(s);        rv(num, s);        return dfs(len-1, 0, 1, 1);    }    int query(char *s){ // 由于AB范围太大,又不想写大数减法,利用AC自动机的多串匹配直接查询A是否合法        int now = root;        for(int i = 0; s[i]; ++i){            for(int j = 3; j >= 0; --j){ // 1个十进制位匹配4个2进制位                now = son[now][((s[i]-'0')>>j)&1];                int tmp = now;                while(tmp != root){                    if(end[tmp]) return 0;                    tmp = fail[tmp];                }            }        }        return 1;    }};Trie ac;int main(){    int T;    scanf("%d", &T);    while(T--){        ac.init();        int n;        scanf("%d", &n);        for(int i = 0; i < n; ++i){            scanf("%s", s);            ac.insert(s);        }        ac.build();        ac.getbcd();        scanf("%s%s", s, s1);        memset(dp, -1, sizeof(dp));        ll ans1 = ac.solve(s1);        ll ans2 = ac.solve(s);        int ans3 = ac.query(s);        printf("%lld\n", (ans1-ans2+ans3+mod)%mod);    }}


0 0