bzoj3160 万径人踪灭 FFT+manacher

来源:互联网 发布:淘宝卡宾代购店铺 编辑:程序博客网 时间:2024/05/22 16:49

       一道挺不错的题目。关键是想到卷积(相信大神看到这就会做了,不对,大神还需要来看我的博客吗)

       首先我们可以求出所有的回文子序列,然后减去回文子串的数量,就可以得到答案了。回文子串的数量可以用manacher算法O(N)得到,那么就看怎么得到回文子序列了。

       不妨来看以一个点i为中心有多少回文子序列(以夹缝为中心的回文子序列同理)。我们发现关键是统计有多少个k,满足ch[i-k]=ch[i+k],那么以i为中心的点回文子序列的个数为2^k-1(自己不能作为回文子序列)。另外我们发现不同的字母得到的k是相互独立的。因此我们求出有多少ch[i-k]=ch[i+k]=a,以及有多少ch[i-k]=ch[i+k]=b,然后把两个k加起来就行了。

       那么考虑字母'a'带来的影响。可以设一个数组a[i],当ch[i]='a'时a[i]=1。那么对于点i就是统计Σa[i-k]*a[i+k],如果这样还不是很明显,那么我们令点i的答案为2i(i和i+1的夹缝为2i+1),可以看到b[2i]=Σa[i-k]*a[i+k]=Σ(j=0,i)a[j]*a[2i-j]!这就是一个卷积的形式即:b[i]=Σa[j]a[i-j],可以看到对于夹缝这个也是同样成立的!

       然后用FFT加速卷积计算即可。注意可以不需要求出'a'和'b'的两个卷积,而可以求出两个点值表达式后合并到一起,然后就只需要求一次插值即可。

AC代码如下:

#include<iostream>#include<cstdio>#include<cmath>#include<cstring>#define pi acos(-1.0)#define mod 1000000007#define N 300005using namespace std;int n,m,pos[N],f[N],bin[N]; char ch[N],s[N];struct cpx{ double r,i; }a[N],b[N];cpx operator +(cpx x,cpx y){ x.r+=y.r; x.i+=y.i; return x; }cpx operator -(cpx x,cpx y){ x.r-=y.r; x.i-=y.i; return x; }cpx operator *(cpx x,cpx y){cpx z; z.r=x.r*y.r-x.i*y.i; z.i=x.r*y.i+x.i*y.r; return z;}void dft(cpx *a,int p){int i,j,k,mid; cpx w,wn,u,v;for (k=2; k<=m; k<<=1){wn.r=cos(pi*2.0/k*p); wn.i=sin(pi*2.0/k*p); mid=k>>1;for (i=0; i<m; i+=k){w.r=1; w.i=0;for (j=i; j<i+mid; j++){u=a[j]; v=a[j+mid]*w;a[j]=u+v; a[j+mid]=u-v; w=w*wn;}}}if (p<0) for (i=0; i<m; i++) a[i].r/=m;}int main(){scanf("%s",ch+1); n=strlen(ch+1); int i,j,k,cnt=0;m=n<<1|1; for (i=1; i<m; i<<=1) cnt++; m=i;for (i=0; i<m; i++)for (k=i,j=cnt; j; j--,k>>=1) pos[i]=pos[i]<<1|(k&1);for (i=1; i<=n; i++)if (ch[i]=='a') a[pos[i]].r=1; else b[pos[i]].r=1;dft(a,1); dft(b,1);for (i=0; i<m; i++) b[i]=a[i]*a[i]+b[i]*b[i];for (i=0; i<m; i++) a[pos[i]]=b[i];dft(a,-1); int ans=0;bin[0]=1; for (i=1; i<=n; i++) bin[i]=(bin[i-1]<<1)%mod;for (i=0; i<m; i++)ans=(ans+bin[((int)(a[i].r+0.5)+1)>>1]-1)%mod;for (i=1; i<=n; i++){s[i<<1]=ch[i]; s[i<<1|1]='#';}int len=n<<1|1,mx=0;s[1]='#'; s[0]='$';  s[len+1]='@';for (i=2; i<len; i++){f[i]=(mx>i)?min(mx-i,f[(k<<1)-i]):1;while (s[i-f[i]]==s[i+f[i]]) f[i]++;if (i+f[i]>mx){ mx=i+f[i]; k=i; }ans=(ans-(f[i]>>1)+mod)%mod;}printf("%d\n",ans);return 0;}


by lych

2016.3.9 

0 0
原创粉丝点击