Codeforces 452E Three strings 后缀数组 + 并查集

来源:互联网 发布:uc聊天软件 编辑:程序博客网 时间:2024/05/17 00:00

题目大意:

就是现在给出三个总长度不超过3*10^5的字符串, 每个字符串只包含字母'a' ~ 'z', 现在对于每一个L, (1 <= L <= minLength(s1, s2, s3))也就是L不超过s1, s2, s3中最短长度, 求出存在多少个i, j, k使得s1[ i ~ i + L - 1] == s2[ j ~ j + L - 1] == s3[ k ~ k + L - 1], 结果对于10^9 + 7取模之后输出


大致思路:

首先不难想到后缀数组处理三个串拼接起来的总串, 记录每一个字符的来源, 也就是记录每个后缀的来源, 然后需要根据height数组从大到小来利用并查集标记区间进行计算, 注意两个区间合并的时候之后 (i, j, k)三者不来自同一个原来的区间才能算, 所以稍微容斥一下即可

由于事先对height数组排序了, 所以也不需要树状数组之类的来辅助更新答案, 直接利用排序好的height数组的单调性即可


之前想过一个从height由小到大切割区间进行分治dfs的方法, 然后利用树状数组辅助更新答案, 但是复杂度还是太高了...果然还是需要用并查集


代码如下:

Result  :  Accepted     Memory  :  27500 KB     Time  :  202 ms

/* * Author: Gatevin * Created Time:  2015/3/18 16:10:35 * File Name: Kotori_Itsuka.cpp */#include<iostream>#include<sstream>#include<fstream>#include<vector>#include<list>#include<deque>#include<queue>#include<stack>#include<map>#include<set>#include<bitset>#include<algorithm>#include<cstdio>#include<cstdlib>#include<cstring>#include<cctype>#include<cmath>#include<ctime>#include<iomanip>using namespace std;const double eps(1e-8);typedef long long lint;#define maxn 300010int wa[maxn], wb[maxn], wv[maxn], Ws[maxn];int cmp(int *r, int a, int b, int l){    return r[a] == r[b] && r[a + l] == r[b + l];}void da(int *r, int *sa, int n, int m){    int *x = wa, *y = wb, *t, i, j, p;    for(i = 0; i < m; i++) Ws[i] = 0;    for(i = 0; i < n; i++) Ws[x[i] = r[i]]++;    for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];    for(i = n - 1; i >= 0; i--) sa[--Ws[x[i]]] = i;    for(j = 1, p = 1; p < n; j *= 2, m = p)    {        for(p = 0, i = n - j; i < n; i++) y[p++] = i;        for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;        for(i = 0; i < n; i++) wv[i] = x[y[i]];        for(i = 0; i < m; i++) Ws[i] = 0;        for(i = 0; i < n; i++) Ws[wv[i]]++;        for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];        for(i = n - 1; i >= 0; i--) sa[--Ws[wv[i]]] = y[i];        for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i++)            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;    }    return;}int rank[maxn], height[maxn];void calheight(int *r, int *sa, int n){    int i, j, k = 0;    for(i = 1; i <= n; i++) rank[sa[i]] = i;    for(i = 0; i < n; height[rank[i++]] = k)        for(k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; k++);    return;}int f[maxn];int find(int x){    return x == f[x] ? x : f[x] = find(f[x]);}bool cmp2(int a, int b){    return height[a] > height[b];}char in[maxn];int s[maxn], sa[maxn], p[maxn], belong[maxn], N;lint cnt[maxn][3], ans[maxn];const lint mod = 1e9 + 7;int main(){    int mlen = 1e9;    N = 0;    for(int i = 0; i < 3; i++)    {        scanf("%s", in);        int len = strlen(in);        mlen = min(len, mlen);        for(int j = 0; j < len; j++)        {            belong[N] = i;            s[N++] = in[j] - 'a' + 1;        }        belong[N] = -1;        s[N++] = 27 + i;    }    N--;    s[N] = 0;    da(s, sa, N + 1, 30);    calheight(s, sa, N);    for(int i = 0; i <= N; i++) p[i] = f[i] = i;    for(int i = 0; i <= N; i++)        if(belong[i] != -1) cnt[i][belong[i]]++;    sort(p + 1, p + N + 1, cmp2);    lint result = 0;    for(int i = 1; i <= N; i++)    {        if(i > 1 && height[p[i]] != height[p[i - 1]])            for(int j = height[p[i]] + 1; j <= height[p[i - 1]]; j++)                ans[j] = result;        int bl = find(sa[p[i]]), br = find(sa[p[i] - 1]);        result = (result - cnt[bl][0]*cnt[bl][1]*cnt[bl][2] % mod + mod) % mod;        result = (result - cnt[br][0]*cnt[br][1]*cnt[br][2] % mod + mod) % mod;        for(int j = 0; j < 3; j++)            cnt[bl][j] = (cnt[bl][j] + cnt[br][j]) % mod;        f[br] = bl;        result = (result + cnt[bl][0]*cnt[bl][1]*cnt[bl][2]) % mod;    }    for(int i = 1; i <= mlen; i++)        printf("%I64d ", ans[i]);    return 0;}



顺带祭奠一下以前写的TLE了的方法...

Result  :  Time Limit Exceeded on test 42

/* * Author: Gatevin * Created Time:  2015/3/12 22:32:12 * File Name: Kotori_Itsuka.cpp */#include<iostream>#include<sstream>#include<fstream>#include<vector>#include<list>#include<deque>#include<queue>#include<stack>#include<map>#include<set>#include<bitset>#include<algorithm>#include<cstdio>#include<cstdlib>#include<cstring>#include<cctype>#include<cmath>#include<ctime>#include<iomanip>using namespace std;const double eps(1e-8);typedef long long lint;const lint mod = 1000000007LL;#define maxn 300100int wa[maxn], wb[maxn], wv[maxn], Ws[maxn];int cmp(int *r, int a, int b, int l){    return r[a] == r[b] && r[a + l] == r[b + l];}void da(int *r, int *sa, int n, int m){    int *x = wa, *y = wb, *t, i, j, p;    for(i = 0; i < m; i++) Ws[i] = 0;    for(i = 0; i < n; i++) Ws[x[i] = r[i]]++;    for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];    for(i = n - 1; i >= 0; i--) sa[--Ws[x[i]]] = i;    for(j = 1, p = 1; p < n; j *= 2, m = p)    {        for(p = 0, i = n - j; i < n; i++) y[p++] = i;        for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;        for(i = 0; i < n; i++) wv[i] = x[y[i]];        for(i = 0; i < m; i++) Ws[i] = 0;        for(i = 0; i < n; i++) Ws[wv[i]]++;        for(i = 1; i < m; i++) Ws[i] += Ws[i - 1];        for(i = n - 1; i >= 0; i--) sa[--Ws[wv[i]]] = y[i];        for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i++)            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;    }    return;}int rank[maxn], height[maxn];void calheight(int *r, int *sa, int n){    int i, j, k = 0;    for(i = 1; i <= n; i++) rank[sa[i]] = i;    for(i = 0; i < n; height[rank[i++]] = k)        for(k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; k++);    return;}char in[maxn];int s[maxn], sa[maxn], belong[maxn], N;lint ans[maxn];lint C[maxn];int lowbit(int x){    return -x & x;}void add(int L, lint value){    while(L <= N)        C[L] = (C[L] + value) % mod, L += lowbit(L);    return;}void update(int L, int R, lint value)//区间更新[L, R] += value{    add(L, value), add(R + 1, (-value + mod) % mod);}lint query(int pos)//单点查询{    lint ret = 0;    while(pos)        ret = (ret + C[pos]) % mod, pos -= lowbit(pos);    return ret;}void dfs(int L, int R, int h){    int i = L;    while(i <= R)    {        while(i <= R && height[i] == h) i++;        if(i > R) break;        lint cnt[4]; memset(cnt, 0, sizeof(cnt));        int j = i;        cnt[belong[sa[j - 1]]]++;        int nexh = height[i];        while(j <= R && height[j] > h)            cnt[belong[sa[j]]]++, nexh = min(nexh, height[j]), j++;    //    for(int k = h + 1; k <= nexh; k++)    //        ans[k] = (ans[k] + cnt[1]*cnt[2]*cnt[3] % mod) % mod;        update(h + 1, nexh, cnt[1]*cnt[2]*cnt[3] % mod);        dfs(i, j - 1, nexh);        i = j;    }    return;}void solve(int mlen){    dfs(1, N, 0);    for(int i = 1; i <= mlen; i++)        printf("%I64d ", query(i));}int main(){    int minlen = 1e9;    N = 0;    for(int i = 1; i <= 3; i++)    {        scanf("%s", in);        int len = strlen(in);        minlen = min(minlen, len);        for(int j = 0; j < len; j++)        {            belong[N] = i;            s[N++] = in[j] - 'a' + 1;        }        belong[N] = -1;        s[N++] = 26 + i;    }    N--;    s[N] = 0;    da(s, sa, N + 1, 30);    calheight(s, sa, N);    solve(minlen);    return 0;}



0 0
原创粉丝点击