HDU 6153 A Secret KMP

来源:互联网 发布:淘宝发布宝贝主图尺寸 编辑:程序博客网 时间:2024/06/05 09:30

简略题意:给出两个长度为1e6的S,T串,问T串的每个后缀在S中作为子串出现的次数记为Si, 长度为Li,问len1i=0SiTi的答案。

首先将两串都反转,从而变成前缀的处理关系。
对串2建立fail数组,用串1去进行匹配,再在fail树上累加一下后缀和即可。
复杂度是O(n)的。

#include <bits/stdc++.h>using namespace std;typedef long long LL;const LL mod = 1e9+7;const LL maxn = 1100000;LL t;char s1[maxn], s2[maxn];LL fail[maxn], sum[maxn];int main() {    scanf("%lld", &t);    while(t--) {        memset(fail, 0, sizeof fail);        memset(sum, 0, sizeof sum);        scanf("%s%s", s1, s2);        LL len1 = strlen(s1), len2 = strlen(s2);        reverse(s1, s1 + len1);        reverse(s2, s2 + len2);        LL j = 0, k = -1;        fail[0] = -1;        while(j < len2) {            if(k == -1 || s2[j] == s2[k])                fail[++j] = ++k;            else                k = fail[k];        }        j = 0;        for(LL i = 0; i < len1; i++) {            while(j > 0 && s1[i] != s2[j])                j = fail[j];            if(s2[j] == s1[i])                j++;            sum[j]++;            if(j == len2) j = fail[j];        }        LL ans = 0;        for(int i = len2; i >= 1; i--)            sum[fail[i]] += sum[i];        for(int i = 1; i <= len2; i++)            ans += (LL)i*sum[i], ans %= mod;        printf("%lld\n", ans);    }    return 0;}

这里提供一个解法为O(nlogn)的,但是被卡掉的算法。
其实我们不反转串,将两个串拼接在一起之后可以发现,对每个后缀进行一次在rank数组上的二分,即可得到答案,不过有可能有T串匹配自己的部分,因此每次需要建立两次后缀数组,答案即为(S+T)的部分减去T的部分。
本地跑了几组极限数据对拍都是对的,不过时限的确紧了。
代码姑且也放出一份…

#include <bits/stdc++.h>using namespace std;typedef long long LL;LL ans = 0;const LL mod = 1e9+7;struct SuffixArray{    static const int maxn = 1e6+7;    int s[maxn], n, m;    int sa[maxn], rank[maxn], height[maxn];    int t[maxn], t2[maxn], c[maxn];    int MIN[maxn][30];    void build_sa(){        s[n++] = 0;        int *x = t, *y = t2;        for(int i=0; i<m; i++) c[i] = 0;        for(int i=0; i<n; i++) c[x[i]=s[i]]++;        for(int i=1; i<m; i++) c[i]+=c[i-1];        for(int i=n-1; i>=0; i--) sa[--c[x[i]]] = i;        for(int k=1; k<=n; k<<=1){            int p = 0;            for(int i=n-k; i<n; i++) y[p++] = i;            for(int i=0; i<n; i++) if(sa[i]>=k) y[p++] = sa[i]-k;            for(int i=0; i<m; i++) c[i] = 0;            for(int i=0; i<n; i++) c[x[y[i]]]++;            for(int i=1; i<m; i++) c[i]+=c[i-1];            for(int i=n-1; i>=0; i--) sa[--c[x[y[i]]]] = y[i];            swap(x, y);            p = 1; x[sa[0]] = 0;            for(int i=1; i<n; i++)                x[sa[i]] = y[sa[i-1]]==y[sa[i]] && y[sa[i-1]+k]==y[sa[i]+k]?p-1:p++;            if(p >= n) break;            m=p;        }        n--;    }    void getheight(){        int k = 0;        for(int i=n; i>=1; i--) rank[sa[i]] = i;        for(int i=0; i<n; i++){            if(k) k --;            int j = sa[rank[i]-1];            while(s[i+k] == s[j+k]) k++;            height[rank[i]] = k;        }    }    void RMQ_init(){        for(int i=1; i<=n; i++) MIN[i][0] = height[i];        for(int j=1; (1<<j)<=n; j++){            for(int i=1; i+(1<<j)<=n; i++){                MIN[i][j] = min(MIN[i][j-1], MIN[i+(1<<(j-1))][j-1]);            }        }    }    int RMQ(int L, int R){        int k = 0;        while(1<<(k+1) <= R-L+1) k++;        return min(MIN[L][k], MIN[R-(1<<k)+1][k]);    }    int LCP(int i, int j){        return RMQ(i+1, j);    }    void init(string str){        m = 130;        n = str.size();        for(int i=0; i<n; i++)            s[i] = (int)str[i];        build_sa();        getheight();        RMQ_init();    }    void solve(int st) {        int ed = n;        int len = 1;        for(int i = n-1; i >= st; i--, len++) {            int l = 1, r = 2*n;            int now = rank[i];            while(l < r) {                int m = (l + r) >> 1;                if(now + m > n - 1) {                    r = m;                    continue;                }                if(LCP(now, now+m) >= len) l = m + 1;                else r = m;            }            ans += len*(l)%mod;            ans %= mod;        }    }}suffix;int main() {    int t;    scanf("%d", &t);    while(t--) {        ans = 0;        string s1, s2;        s1 = char('z'+1) + s1;        cin >> s1 >> s2;        s2 = char('z'+1) + s2;        int len = s1.size();        s1 += s2;        suffix.init(s1);        suffix.solve(len+1);        LL tmp = ans;        ans = 0;        suffix.init(s2);        suffix.solve(1);        tmp -= ans;        tmp %= mod;        tmp += mod;        tmp %= mod;        printf("%lld\n", tmp);    }    return 0;}