LA 3942 Remember the Word——DP + 字典树

来源:互联网 发布:淘宝网手动碎 菜机 编辑:程序博客网 时间:2024/05/16 04:37

蓝书例题,提示状态转移方程为dp【i】 = sum{dp【i + len(s)】},其中s是S【i……n-1】的前缀

然而一开始并不明白这个状态转移方程,看了很多博客还是云里雾里,最后在纸上手动演算了一下才明白,推荐不明白的朋友根据下面的方法在草纸上演算一下。

不考虑字典树,我们像蓝书介绍的那样定义dp【i】为【从第i个字符到最后一个字符的子串】的方案数,明显dp【i】是和原问题规模相同的子问题,现在我们来解决这个子问题:

对于dp【i】代表的字符串s【i~n-1】(字符串最后一位是n-1),可以看做【前缀+更小的子问题】,它能由更小的子问题转移的条件是前缀在字典中存在(这里不明白的话演算一下下面的例子),所以我们要得到子问题dp【i】的解,就要枚举它的所有前缀,从中找出合法的前缀,然后统计前缀后面的【更小的子问题】, 求和,dp【i】的解就是统计出来的和,这就是“dp【i】 = sum{dp【i + len(s)】},其中s是S【i……n-1】的前缀”的由来(注意s是dp【i】代表的子串的所有前缀)

枚举前缀的复杂的为O(n),算上递推时的复杂度,一共O(n^2)(这还没算判断前缀是否合法的复杂度),一定会超时,现在就要考虑字典树优化,对于dp【i】代表的子串,我们在字典树中寻找该子串路径上val为1的节点,说白了就是这个子串包含的前缀,每到一个这样的节点(设为x),就让dp【i】 += dp【x+1】,最终结果为dp【0】

理解了上述过程后,不难想到边界为dp【n】 = 1;

例子:

对于样例abcd,字典a、b、ab、cd,从a开始,合法前缀为a、ab,所以dp【0】 = dp【1】 + dp【2】,产生了dp【1】和dp【2】两个子问题;

先看dp【1】,它代表的字符串为bcd,合法前缀为b,所以dp【1】 = dp【2】, 产生了一个子问题dp【2】;

然后看dp【2】,合法前缀为cd,所以dp【2】 = dp【4】,dp【4】是我们设置的边界,值为1,所以dp【2】 = 1;

最终dp【0】 = 2

#include <cstdio>#include <cstring>#include <iostream>#include <algorithm>#define SIZE 26using namespace std;const int mod = 20071027;const int maxn = 5 * 1e5 + 10;char S[maxn], P[maxn];int flag = 0, n, m, total, dp[maxn];struct Trie {    int val, child[SIZE];    void init() {        val = 0;        memset(child, 0, sizeof(child));    }}trie[SIZE * maxn];void init() {    total = 0;    dp[n] = 1;    trie[0].init();}void update(char *str) {    int len = strlen(str), root = 0;    for (int i = 0; i < len; i++) {        int t = str[i] - 'a', pos = trie[root].child[t];        if (!pos) {            pos = ++total;            trie[pos].init();            trie[root].child[t] = pos;        }        root = pos;    }    trie[root].val = 1;}int query(int L, int R) {    int root = 0, sum = 0;    for (int i = L; i <= R; i++) {        int t = S[i] - 'a', pos = trie[root].child[t];        if (!pos) break;        if (trie[pos].val) {            sum = (sum + dp[i + 1]) % mod;        }        root = pos;    }    return sum;}int main() {    while (~scanf("%s", S)) {        n = strlen(S);        init();        scanf("%d", &m);        while (m--) {            scanf("%s", P);            update(P);        }        for (int i = n - 1; i >= 0; i--) {            dp[i] = query(i, n - 1);        }        printf("Case %d: %d\n", ++flag, dp[0]);    }    return 0;}


原创粉丝点击