ZOJ Problem Set - 3661 Palindromic Substring

来源:互联网 发布:查士丁尼瘟疫 知乎 编辑:程序博客网 时间:2024/05/01 11:50

The 2012 ACM-ICPC Asia Changchun Regional Contest-G

题目大意就是求被hash出来的第k小的回文串。


首先有一个结论是一个长度为n的串,它的不相同的回文串的个数不超过n。那么可以试图暴力求出每一种回文串的个数。然后排个序就可以算出第k大的。


然后用之前写过的manacher算法求以每个位置为中心的回文串。如果[l,r]是一个回文串,我们则记录下来,然后判断[l+1,r-1]这一种回文串之前有没有找过,如果找过了,那么显然[l+2,r-2]等等的子回文串都找到过了,就不要找了。


这样能在找出所有的回文串用hash的话,可以认为是线性的,然后我是用了map,并且了解决冲突,用了两个hash数组。然后我们统计每种回文串。

若[l,r]是回文串,则[l-1,r+1]也是回文串,我们可以认为[l-1,r-1]是[l,r]的儿子,然后我们如果算出父亲回文串的方案,则可以把父亲回文串的方案数累加到儿子回文串上(这里指的回文串都是一种回文串)。比如已经统计了bacab方案为8,则可以给aca这种回文串的个数加8。这样的话我们可以按回文串的长度依次统计每种回文串出现的次数,这样这个问题就解决了。由于用了map,我写的程序非常慢。但是勉强过了~。

#include <iostream>#include <cstdio>#include <cstdlib>#include <cstring>#include <map>#include <set>#include <algorithm>#define MAXN 400020#define MOD 777777777LL#define Mod 1000000007LLusing namespace std;typedef long long LL;LL k;map<pair<int,int> ,int> mp;struct     node{      int l,r;      int value,cnt;      pair<int,int> valuePair;}pal[MAXN];char s[MAXN],ch[MAXN];int len[MAXN],R1[MAXN],R2[MAXN],v[50],n,m,Test;LL sum[MAXN],Pow1[MAXN],Pow2[MAXN],hash1[MAXN],hash2[MAXN];bool       cmp_len(const node &a,const node &b){           return ((a.r-a.l)>(b.r-b.l));}bool       cmp_value(const node&a,const node &b){           return (a.value<b.value);}void  manacher(char s[], int len[], int n){      for (int i=0;i< n;++i) len[i]=0;      int id=0,mx=1;      for (int i=0,j=0;i<n;++i)      {          len[i]=min(mx-i,len[2*id-i]);          for (j=len[i]; (i-j>=0 && i+j<n) && (s[i-j]==s[i+j]); j++);            len[i]=j;            if (j+i>mx) mx=j+i,id=i;        }}void  make_hash(){      sum[0]=0;      for (int i=0;i<n;++i)           sum[i+1]=(sum[i]*26+v[s[i]-'a'])%MOD;}int   getValue(int L,int R){      LL ans = sum[R]-sum[L-1]*Pow1[R-L+1]%MOD;      if(ans<0) ans+=MOD;      return int(ans);}pair<int,int>   getHash(int L,int R){      LL ans1=hash1[R+1]-hash1[L]*Pow1[R-L+1]%MOD;      LL ans2=hash2[R+1]-hash2[L]*Pow2[R-L+1]%MOD;      if (ans1<0) ans1+=Mod;      if (ans2<0) ans2+=Mod;      return make_pair(int(ans1),int(ans2));}int main(){    Pow1[0]=Pow2[0]=1;    for (int i=1;i<=MAXN;++i){        Pow1[i]=Pow1[i-1]*26%MOD;        Pow2[i]=Pow2[i-1]*27%MOD;    }         cin>>Test;        while (Test--){                            cin>>n>>m;              scanf("%s",s);              n = strlen(s);              memset(len,0,sizeof(len));              int L=0;              hash1[0]=hash2[0]=0;              for (int i=0;i<n;++i){                  ch[L++]='#';                  ch[L++]=s[i];                  hash1[i+1]=(hash1[i]*26+s[i]-'a'+1)%MOD;                  hash2[i+1]=(hash2[i]*27+s[i]-'a'+1)%MOD;              }              ch[L++]='#';              manacher(ch,len,L);              memset(R1,0,sizeof(R1));              memset(R2,0,sizeof(R2));              for (int i=0;i<L;++i)              {                  if (i%2) R1[i/2]=max(R1[i/2],len[i]/2); else                           R2[i/2-1]=max(R2[i/2-1],len[i]/2);              }              mp.clear();              int nPs=0;              for (int i=0;i<n;++i){                  bool flag= true;                  while (i-R1[i]+1>=0 && i+R1[i]-1<n && R1[i]>=1)                   {                        pair<int,int> hash_value = getHash(i-R1[i]+1,i+R1[i]-1);                        if (mp.find(hash_value)!=mp.end()){                           if(flag)pal[mp[hash_value]].cnt++;                            break;                        }                        mp[hash_value]=nPs;                        pal[nPs].l=i-R1[i]+1;                        pal[nPs].r=i+R1[i]-1;                        pal[nPs].valuePair = hash_value;                        pal[nPs++].cnt=flag;                         flag=false;                        R1[i]--;                  }                      flag=true;                  while (i-R2[i]+1>=0 && i+R2[i]<n && R2[i]>=1)                   {                        pair<int,int> hash_value = getHash(i-R2[i]+1,i+R2[i]);                        if (mp.find(hash_value)!=mp.end()){                           if (flag)pal[mp[hash_value]].cnt++;                            break;                        }                        mp[hash_value]=nPs;                        pal[nPs].l=i-R2[i]+1;                        pal[nPs].r=i+R2[i];                        pal[nPs].valuePair = hash_value;                        pal[nPs++].cnt=flag;                         flag=false;                        R2[i]--;                  }                     }              sort(pal,pal+nPs,cmp_len);                            for (int i=0;i<nPs;++i)              mp[pal[i].valuePair]=i;                            for (int i=0;i<nPs;++i)              {                  if(pal[i].l+1<=pal[i].r-1)                  {                     pair<int,int> hash_value = getHash(pal[i].l+1,pal[i].r-1);                     pal[mp[hash_value]].cnt+=pal[i].cnt;                  }              }              for (int i=0;i<m;++i)              {                  cin>>k;                  for (int j=0;j<26;++j) scanf("%d",&v[j]);                  make_hash();                  for (int j=0;j<nPs;++j){                      pal[j].value = getValue(pal[j].l+1,(pal[j].l+pal[j].r)/2+1);                      //cout<<pal[j].l<<" "<<pal[j].r<<" "<<pal[j].value<<" "<<pal[j].cnt<<endl;                      }                                    sort(pal,pal+nPs,cmp_value);                  for (int j=0;j<nPs;++j){                      k-=pal[j].cnt;                      if (k<=0) {cout<< pal[j].value<<endl; break;}                  }                     }              cout<<endl;                      }     return 0;}