【清华集训2017模拟12.10】回文串

来源:互联网 发布:生存模型 软件 编辑:程序博客网 时间:2024/05/12 09:07

Description

NYG 很喜欢研究回文串问题,有一天他想到了这样一个问题:
给出一个字符串 S,现在有 4 种操作:
• addl c :在当前字符串的左端加入字符 c;
• addr c :在当前字符串的右端加入字符 c;
• transl l 1 r 1 l 2 r 2 :取出 S 的两个子串 S[l 1 …r 1 ],S[l 2 …r 2 ],现在 NYG想把前一个字符串变换为后一个字符串,每次操作他可以在前一个字符串的左端插入或删除一个字符,保证 NYG 会使用尽量少的步数进行操作,你需要输出整个操作的优美度。
• transr l 1 r 1 l 2 r 2 :取出 S 的两个子串 S[l 1 …r 1 ],S[l 2 …r 2 ],现在 NYG想把前一个字符串变换为后一个字符串,每次操作他可以在前一个字符串的右端插入或删除一个字符,保证 NYG 会使用尽量少的步数进行操作,你需要输出整个操作的优美度。
设字符串 S 长为 n,且从 1 开始标号,那么 nyg 这样定义一次变换的优美度 p:
定义 S[i…j] 是好的当且仅当 S[i…j] 是回文的而且在变换中出现过.
这里写图片描述
例如 S = abaabac,现在执行 transr 1 5 4 7,其变换过程为:
abaab -> abaa -> aba -> abac
注意上述四个串均视为在此次变换中出现,其中 aba 为回文串,且S[1…3] = S[4…6] = aba,故此次变换的优美度为 6。
由于 NYG 还要忙着出题,这个任务就交给你了。
n,q<=1e5

Solution

先离线把回文树建出来,然后可以发现这个操作相当于求回文树中一段路径的len[x]*cnt[x]之和。
用LCT/树链剖分维护即可,注意LCA有时候不能计入答案。
感觉写的比纯数据结构还长,还恶心QwQ
代码特丑~~

Code

#include <cstdio>#include <cstring>#include <algorithm>#define fo(i,a,b) for(int i=a;i<=b;i++)#define fd(i,a,b) for(int i=a;i>=b;i--)#define rep(i,a) for(int i=lst[a];i;i=nxt[i])using namespace std;typedef long long ll;int read() {    char ch;    for(ch=getchar();ch<'0'||ch>'9';ch=getchar());    int x=ch-'0';    for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';    return x;}void write(ll x) {    if (!x) {puts("0");return;}    char ch[20];int tot=0;    for(;x;x/=10) ch[++tot]=x%10+'0';    fd(i,tot,1) putchar(ch[i]);    puts("");}const int N=4*1e5+5;int n,q,tot,pl,pr,len,S[N],x[N],c[N],now[N],l1[N],r1[N],l2[N],r2[N],posl[N],posr[N];bool ask[N],bz[N];struct palindrome_tree{    int fail[N],to[N][27],len[N],cnt[N],lst;    void build() {        len[0]=0;fail[0]=1;        len[tot=1]=-1;    }    int get(int n,int x) {        while (S[n-len[x]-1]!=S[n]) x=fail[x];        return x;    }    int add(int n,int x) {        int now=get(n,lst);        if (!to[now][x]) {            len[++tot]=len[now]+2;            fail[tot]=to[get(n,fail[now])][x];            to[now][x]=tot;        }        lst=to[now][x];        cnt[lst]++;        return lst;    }    void count() {fd(i,tot,1) cnt[fail[i]]+=cnt[i];}}tr;char st[10];void init() {    n=read();q=read();    int L=n+1,R=2*n;    fo(i,L,R) x[i]=read();    fo(i,1,q) {        scanf("%s",st+1);        if (st[1]=='a') {            bz[i]=(st[4]=='r');            c[i]=read();            if (bz[i]) {x[++R]=c[i];now[i]=R;}            else {x[--L]=c[i];now[i]=L;}        } else {            ask[i]=1;            bz[i]=(st[6]=='r');            l1[i]=read();r1[i]=read();            l2[i]=read();r2[i]=read();        }    }    len=R-L+1;    fo(i,1,len) S[i]=x[i+L-1]+1;    int tmp=0;    fd(i,q,1)         if (ask[i]) l1[i]+=tmp,r1[i]+=tmp,l2[i]+=tmp,r2[i]+=tmp;        else tmp+=!bz[i];    fo(i,1,q) if (!ask[i]) now[i]-=L-1;    pl=n-L+2;pr=pl+n-1;}int t[N<<1],nxt[N<<1],lst[N],l;void add(int x,int y) {    t[++l]=y;nxt[l]=lst[x];lst[x]=l;}int fa[N][17],dep[N],dfn[N],w[N],top[N],size[N],son[N],tmp;void dfs(int x,int y) {    fa[x][0]=y;dep[x]=dep[y]+1;    fo(j,1,16) fa[x][j]=fa[fa[x][j-1]][j-1];    size[x]=1;int k=0;    rep(i,x)         if (t[i]!=y) {            dfs(t[i],x);            size[x]+=size[t[i]];            if (size[t[i]]>k) k=size[t[i]],son[x]=t[i];        }}void make(int x,int y) {    top[x]=y;dfn[++tmp]=x;w[x]=tmp;    if (!son[x]) return;    make(son[x],y);    rep(i,x) if (t[i]!=fa[x][0]&&t[i]!=son[x]) make(t[i],t[i]);}void prepare() {    tr.build();    fo(i,1,len) posr[i]=tr.add(i,S[i]);    tr.lst=0;    fo(i,1,len/2) swap(S[i],S[len-i+1]);    fo(i,1,len) posl[i]=tr.add(i,S[i]);    fo(i,1,len/2) swap(S[i],S[len-i+1]),swap(posl[i],posl[len-i+1]);    tr.len[1]=0;    fo(i,2,tot) add(tr.fail[i]?tr.fail[i]:1,i);    dfs(1,0);make(1,1);}int lca(int x,int y) {    if (dep[x]<dep[y]) swap(x,y);    fd(j,16,0) if (dep[fa[x][j]]>dep[y]) x=fa[x][j];    if (dep[x]!=dep[y]) x=fa[x][0];    fd(j,16,0) if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];    return (x==y)?x:fa[x][0];}int find(int x,int L) {    fd(j,16,0) if (tr.len[fa[x][j]]>L) x=fa[x][j];    if (tr.len[x]>L) x=fa[x][0];    return x;}ll Tree[N<<2],Sum[N<<2];int lazy[N<<2];void build(int v,int l,int r) {    if (l==r) {Sum[v]=tr.len[dfn[l]];return;}    int mid=l+r>>1;    build(v<<1,l,mid);build(v<<1|1,mid+1,r);    Sum[v]=Sum[v<<1]+Sum[v<<1|1];}void back(int v,int z) {    Tree[v]+=(ll)Sum[v]*z;    lazy[v]+=z;}void down(int v) {    if (lazy[v]) {        back(v<<1,lazy[v]);        back(v<<1|1,lazy[v]);        lazy[v]=0;    }}void modify(int v,int l,int r,int x,int y) {    if (l==x&&r==y) {back(v,1);return;}    int mid=l+r>>1;down(v);    if (y<=mid) modify(v<<1,l,mid,x,y);    else if (x>mid) modify(v<<1|1,mid+1,r,x,y);    else modify(v<<1,l,mid,x,mid),modify(v<<1|1,mid+1,r,mid+1,y);    Tree[v]=Tree[v<<1]+Tree[v<<1|1];}ll get_sum(int v,int l,int r,int x,int y) {    if (l==x&&r==y) return Tree[v];    int mid=l+r>>1;down(v);    if (y<=mid) return get_sum(v<<1,l,mid,x,y);    else if (x>mid) return get_sum(v<<1|1,mid+1,r,x,y);    else return get_sum(v<<1,l,mid,x,mid)+get_sum(v<<1|1,mid+1,r,mid+1,y);}void ins(int x) {    int f=top[x];    while (f) {        modify(1,1,tmp,w[f],w[x]);        x=fa[f][0];f=top[x];    }}ll query(int x,int y) {    ll ans=0;    int f1=top[x],f2=top[y];    while (f1!=f2) {        if (dep[f1]<dep[f2]) swap(x,y),swap(f1,f2);        ans+=get_sum(1,1,tmp,w[f1],w[x]);        x=fa[f1][0];f1=top[x];    }    if (dep[x]>dep[y]) swap(x,y);    ans+=get_sum(1,1,tmp,w[x],w[y]);    return ans;}void solve() {    build(1,1,tmp);    fo(i,pl,pr) {        int now=find(posr[i],i-pl+1);        ins(now);    }    fo(i,1,q)         if (ask[i]) {            ll ans=0;            if (bz[i]) {                int x=find(posl[l1[i]],r1[i]-l1[i]+1);                int y=find(posl[l2[i]],r2[i]-l2[i]+1);                int z=lca(x,y);                ans=query(x,y);                if (l1[i]+tr.len[z]<=r1[i]&&l2[i]+tr.len[z]<=r2[i])                     if (S[l1[i]+tr.len[z]]==S[l2[i]+tr.len[z]])                         ans-=(ll)get_sum(1,1,tmp,w[z],w[z]);            } else {                int x=find(posr[r1[i]],r1[i]-l1[i]+1);                int y=find(posr[r2[i]],r2[i]-l2[i]+1);                int z=lca(x,y);                ans=query(x,y);                if (r1[i]-tr.len[z]>=l1[i]&&r2[i]-tr.len[z]>=l2[i])                     if (S[r1[i]-tr.len[z]]==S[r2[i]-tr.len[z]])                         ans-=(ll)get_sum(1,1,tmp,w[z],w[z]);            }            write(ans);        } else {            if (bz[i]) {                pr++;                int now=find(posr[pr],pr-pl+1);                ins(now);            } else {                pl--;                int now=find(posl[pl],pr-pl+1);                ins(now);            }        }}int main() {    freopen("string.in","r",stdin);    freopen("string.out","w",stdout);    init();    prepare();    solve();    return 0;}