POJ 3415 Common Substrings 后缀数组 + 线段树 + dfs

来源:互联网 发布:北京java培训哪家好 编辑:程序博客网 时间:2024/06/05 09:03

题意:给出两个字符串A,B和一个数字K。

计算集合 $S = \{(i,j,k) \mid k \ge K,\ A(i,k) = B(j,k)\}$ 的元素个数，其中 $A(i,k)$ 表示 A 中从第 $i$ 个字符开始、长度为 $k$ 的子串，$B(j,k)$ 同理。

思路：首先用后缀数组求出 high[]（high[i] 为排名第 i-1 与第 i 的两个后缀的最长公共前缀长度），然后对 high[] 建立线段树，线段树的每个节点记录两个信息：对应区间的最小值和最小值的位置（下标）。之后在 high[] 上递归分治（dfs）：每次用线段树查询当前区间的最小值及其位置，并从该位置把区间分成两部分继续处理。

dfs的过程都写在注释里了。

void dfs(int l,int r,LL &anw,int len,int Low){    if(l > r)        return ;            N tmp = Query(1,1,len,l,r);//询问此段区间内的最小值及其位置。    if(tmp.Min >= Low)    {        //此段区间内分属A,B串的个数相乘然后与可取长度的个数相乘。        anw += (ans1[r]-ans1[l-2])*(ans2[r]-ans2[l-2])*(tmp.Min-Low+1);        dfs(l,tmp.site-1,anw,len,tmp.Min+1);//从最小值处分开继续dfs,[l,r]内的tmp.Min均以计算,故加一。        dfs(tmp.site+1,r,anw,len,tmp.Min+1);        return ;    }    dfs(l,tmp.site-1,anw,len,Low);    dfs(tmp.site+1,r,anw,len,Low);}

下面为全部代码

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <queue>
#include <cmath>
#include <stack>
#include <map>
#include <ctime>
#include <iomanip>

// The MSVC-only "#pragma comment(linker, /STACK...)" is not needed: dfs() below keeps its
// recursion depth logarithmic, so the default stack suffices on every compiler.
#define EPS (1e-6)
#define LL long long
#define ULL unsigned long long
#define INF 0x3f3f3f3f
#define Mod 1000000007

using namespace std;

const int MAXN = 200510;

// s[1..len] holds A, a separator character, then B (1-indexed).
char s[MAXN];
// Rank/sa/tr are sized 2*MAXN so that Rank[sa[i]+k] with sa[i]+k > n stays in bounds
// (those entries are never written after the memset, so they read as 0).
int Rank[2*MAXN],sa[2*MAXN],tr[2*MAXN],high[MAXN];

// Singly linked bucket lists used by the radix passes in Get_SA:
// edge[u] is the dummy head of bucket u, tail[u] is its last node.
struct EDGE
{
    int v,next;
}edge[2*MAXN];

int tail[MAXN],Top;

// Append value v to the end of bucket u (insertion order is kept, so each pass is stable).
inline void Link(int u,int v)
{
    edge[Top].v = v;
    edge[Top].next = -1;
    edge[tail[u]].next = Top;
    tail[u] = Top++;
}

// Build the 1-indexed suffix array sa[1..n], Rank[] and high[] for s[1..n].
// m bounds the initial character buckets: s[i]-'A' must lie in [0, m].
// high[i] (i >= 2) is the LCP of suffixes sa[i-1] and sa[i]; high[1] is set to the
// length of suffix sa[1].
void Get_SA(char *s,int n,int m)
{
    memset(Rank,0,sizeof(Rank));
    memset(sa,0,sizeof(sa));

    int i,j,k,ans,site;

    // Reset bucket heads 0..max(n,m); list nodes are allocated right after them.
    for(i = max(n,m);i >= 0; --i)
        tail[i] = i,edge[i].next = -1;
    Top = max(n,m)+1;

    // Initial bucket sort by the first character.
    for(i = 1; i <= n; ++i)
        Link(s[i]-'A',i);

    ans = 1,site = 1;
    for(i = 0; i <= m; ++i)
    {
        for(j = edge[i].next; j != -1; j = edge[j].next)
            sa[site++] = edge[j].v,Rank[edge[j].v] = ans;
        if(edge[i].next != -1)
            ans++;
        tail[i] = i,edge[i].next = -1;
    }

    // Prefix doubling: order by (Rank[i], Rank[i+k]) using two stable bucket passes.
    for(k = 1;k <= n; k <<= 1)
    {
        // Pass 1: by the second key Rank[sa[i]+k] (0 past the end of the string).
        Top = n+1;
        for(i = 1;i <= n; ++i)
            Link(Rank[sa[i]+k],sa[i]);
        site = 1;
        for(i = 0;i <= n; ++i)
        {
            for(j = edge[i].next;j != -1; j = edge[j].next)
                sa[site++] = edge[j].v;
            tail[i] = i,edge[i].next = -1;
        }

        // Pass 2: stable by the first key Rank[sa[i]].
        Top = n+1;
        for(i = 1;i <= n; ++i)
            Link(Rank[sa[i]],sa[i]);
        site = 1;
        for(i = 1;i <= n; ++i)
        {
            for(j = edge[i].next;j != -1; j = edge[j].next)
                sa[site++] = edge[j].v;
            tail[i] = i,edge[i].next = -1;
        }

        // Re-rank: suffixes with equal (first, second) key pairs share a rank.
        for(tr[sa[1]] = 1,i = 2,ans = 1;i <= n; ++i)
        {
            if(Rank[sa[i]] != Rank[sa[i-1]] || Rank[sa[i]+k] != Rank[sa[i-1]+k])
                ans++;
            tr[sa[i]] = ans;
        }
        for(i = 1;i <= n; ++i)
            Rank[i] = tr[i];
        // All ranks distinct: the order is final.
        if(ans >= n)
            break;
    }

    // high[] via Kasai's method: walk suffixes in text order, reusing the previous LCP minus one.
    for(i = 1,k = 1;i <= n; ++i)
    {
        if(k) k--;
        if(Rank[i] == 1) {k = 0;high[1] = n-sa[1]+1;continue;}
        j = sa[Rank[i]-1];
        while(i+k <= n && j+k <= n && s[i+k] == s[j+k])
            k++;
        high[Rank[i]] = k;
    }

    // Debug dump:
    // for(i = 1;i <= n; ++i)
    //     printf("i = %2d SA = %2d Rank = %2d high = %2d\n",i,sa[i],Rank[i],high[i]);
}

// ans1[i] / ans2[i]: how many of sa[1..i] start inside A / inside B (filled in main).
LL ans1[MAXN],ans2[MAXN];

// Segment-tree node: minimum of high[] over the node's range and the index where it occurs.
struct N
{
    int site,Min;
}st[4*MAXN];

// Build the range-minimum tree for high[l..r] at node `site` (ties pick the right child).
void Init(int site,int l,int r)
{
    if(l == r)
    {
        st[site].Min = high[l],st[site].site = l;
        return ;
    }
    int mid = (l+r)>>1;
    Init(site<<1,l,mid);
    Init(site<<1|1,mid+1,r);
    st[site] = st[site<<1].Min < st[site<<1|1].Min ? st[site<<1] : st[site<<1|1];
}

// Minimum (and its index) of high[l..r]; node `site` covers [L, R], requires L <= l <= r <= R.
N Query(int site,int L,int R,int l,int r)
{
    if(L == l && R == r)
        return st[site];
    int mid = (L+R)>>1;
    if(r <= mid)
        return Query(site<<1,L,mid,l,r);
    if(mid < l)
        return Query(site<<1|1,mid+1,R,l,r);
    N t1 = Query(site<<1,L,mid,l,mid);
    N t2 = Query(site<<1|1,mid+1,R,mid+1,r);
    if(t1.Min < t2.Min)
        return t1;
    return t2;
}

// Divide and conquer over high[l..r].
//   l, r : endpoints of the current range of high[]; the suffixes involved are sa[l-1..r]
//   anw  : accumulated answer
//   len  : size of high[]
//   Low  : smallest common-prefix length not yet counted for this range (initially K)
// Only the smaller side of each split is handled recursively and the larger side by the
// loop, so the recursion depth is O(log len) instead of O(len) for monotone high[].
void dfs(int l,int r,LL &anw,int len,int Low)
{
    while(l <= r)
    {
        N tmp = Query(1,1,len,l,r); // minimum of high[l..r] and its position
        int nextLow = Low;
        if(tmp.Min >= Low)
        {
            // (#A-suffixes) * (#B-suffixes) * (number of admissible lengths Low..tmp.Min).
            anw += (ans1[r]-ans1[l-2])*(ans2[r]-ans2[l-2])*(tmp.Min-Low+1);
            // Lengths up to tmp.Min are now counted for [l,r]; subranges continue from tmp.Min+1.
            nextLow = tmp.Min+1;
        }
        if(tmp.site - l < r - tmp.site)
        {
            dfs(l,tmp.site-1,anw,len,nextLow);
            l = tmp.site+1;
        }
        else
        {
            dfs(tmp.site+1,r,anw,len,nextLow);
            r = tmp.site-1;
        }
        Low = nextLow;
    }
}

int main()
{
    int n,k,len,i;
    // Input ends with K = 0. Also stop at EOF: scanf returns EOF (non-zero), so testing
    // only "scanf(...) && k" would loop forever on input without the terminating 0.
    while(scanf("%d",&k) == 1 && k)
    {
        if(scanf("%s",s+1) != 1)
            break;
        n = strlen(s+1);
        if(scanf("%s",s+n+2) != 1)
            break;
        s[n+1] = 'z'+1;   // separator character, assumed not to occur in A or B
        Get_SA(s,len = strlen(s+1),200);
        // Prefix counts of A-suffixes (start <= n) and B-suffixes (start > n+1);
        // the suffix starting at the separator belongs to neither.
        for(ans1[0] = 0,ans2[0] = 0, i = 1;i <= len; ++i)
        {
            ans1[i] = ans1[i-1],ans2[i] = ans2[i-1];
            if(sa[i] <= n) ans1[i]++;
            if(sa[i] > n+1) ans2[i]++;
        }
        Init(1,1,len);
        LL anw = 0;
        dfs(2,len,anw,len,k);   // high[1] is not the LCP of two suffixes, so start at 2
        // "%I64d" is MSVC-specific (glibc reads it as a flag plus a 64-wide field); cout is portable.
        cout << anw << '\n';
    }
    return 0;
}


0 0