hdu 5318 The Goddess Of The Moon 矩阵快速幂+dp

来源:互联网 发布:淘宝禁止好评返现2016 编辑:程序博客网 时间:2024/04/30 17:38

这题意也有够坑的。。

题意:

给你n(<=50)个串,每个串长度<=10。如果a串的后缀和b串的前缀相等,并且长度>=2,则b串可以连在a串后面(注意,不用合并a,b串相同的位置)。

每个串的个数都是无穷个,现在让你选m(<=1e9)个串,问你有多少个不同的串(不是有多少个不同长度的串= =)。

11  111

111  11

*上面两个串并不相同,因此算两个。


思路:

注意字符串去重。还好m不会等于0。= =||

先暴力求出每个串能转移的位置。a[i][j]为1,第j个字符串能连在第i个字符串后面;反之为0,则不能。

定义dp[i][j]:选了i个串,最后以j串结尾的方案数。

则dp[i][j] += dp[i-1][k]*a[k][j]; (1<=k<= n)。


由于i太大,有1e9个,然后就是矩阵快速幂登场了。

首先res矩阵第一行保存以每个串结尾的方案数,一开始第一行全为1.,其他位置全为0。

然后a矩阵快速幂m-1次方(为什么要-1,因为res第一行一开始就设为1,因此看做已经选了一个串了)。

最后把res的第一行加起来便是要求的答案。

code:

#include <bits/stdc++.h>using namespace std;const int N = 55;const int MOD = 1000000007;typedef long long LL;int n, m;char ch[N][15];struct PP {    int tp[N][N];}a;PP mul(const PP &a, const PP &b) {    PP c;    memset(c.tp, 0, sizeof(c.tp));    for(int i = 1;i <= n; i++)        for(int j = 1;j <= n; j++)            for(int k = 1;k <= n; k++) {                c.tp[i][j] += (LL)a.tp[i][k]*b.tp[k][j]%MOD;                c.tp[i][j] %= MOD;            }    return c;}bool check(int x, int y) {    int lenx = strlen(ch[x]), leny = strlen(ch[y]);    if(lenx == 1 || leny == 1) return false;    for(int i = lenx-2;i >= 0; i--) {        int j = 0, ti = i;        while(ti < lenx && j < leny && ch[x][ti] == ch[y][j]) ti++, j++;        if(ti == lenx) return true;    }    return false;}            void solve() {    memset(a.tp, 0, sizeof(a.tp));    for(int i = 1;i <= n; i++)         for(int j = 1;j <= n; j++)            if(check(i, j)) a.tp[i][j] = 1;        int tm = m;    PP tmp = a, res;    memset(res.tp, 0, sizeof(res.tp));    for(int i = 1;i <= n; i++) res.tp[1][i] = 1;    tm--;    while(tm) {        if(tm&1) {            res = mul(res, tmp);        }        tmp = mul(tmp, tmp);        tm >>= 1;    }    int ans = 0;    for(int i = 1;i <= n; i++) ans = (ans+res.tp[1][i])%MOD;    printf("%d\n", ans);}    set <string> st;int main() {    int T;    scanf("%d", &T);    while(T--) {        st.clear();        scanf("%d%d", &n, &m);        for(int i = 1;i <= n; i++) {            scanf("%s", ch[i]);            st.insert(ch[i]);        }        int tn = 0;        for(auto &it:st) strcpy(ch[++tn], it.c_str());        n = tn;        solve();    }    return 0;}




0 0