[JSOI2015][JZOJ4061]字符串树

来源:互联网 发布:mysql 字段原子累加 编辑:程序博客网 时间:2024/05/01 18:34

题目大意

一棵有n个节点的树,每条边有一个长度l不大于10的字符串。有q个询问,形如(x,y,s)的询问,查询点x到点y的路径上,前缀为s的边的数量。
1n,q100000


题目分析

唉,太水了太水了。
对整棵树建可持续化字典树,每个节点上树的副本是以其父亲节点为根的字典树。
查询时用两个点都直接走一遍,再用最近公共祖先的两边去减掉重复的即可。
时间复杂度O((n+q)l),空间复杂度O(nl)
具体细节详见代码实现。


代码实现

#include <iostream>#include <cstring>#include <cstdio>#include <cctype>#include <cmath>using namespace std;int read(){    int x=0,f=1;    char ch=getchar();    while (!isdigit(ch))    {        if (ch=='-')            f=-1;        ch=getchar();    }    while (isdigit(ch))    {        x=x*10+ch-'0';        ch=getchar();    }    return x*f;}const int N=100005;const int M=(N-1)<<1;const int E=N<<1;const int LGE=17;const int L=10;const int S=N*L+N;struct TRIE{    int tov[S][26],size[S];    int root[N+1];    int tot;    int newnode()    {        ++tot;        for (int i=0;i<26;i++)            tov[tot][i]=0;        size[tot]=0;        return tot;    }    int insert(char str[],int rt0)    {        int len=strlen(str);        int rt=newnode(),ret=rt;        for (int c=0;c<26;c++)            tov[rt][c]=tov[rt0][c];        size[rt]=size[rt0]+1;        for (int i=0;i<len;i++)        {            tov[rt][str[i]-'a']=newnode();            rt=tov[rt][str[i]-'a'];            rt0=tov[rt0][str[i]-'a'];            for (int c=0;c<26;c++)                tov[rt][c]=tov[rt0][c];            size[rt]=size[rt0]+1;        }        return ret;    }    int query(char str[],int rt)    {        int len=strlen(str);        for (int i=0;i<len;i++)            rt=tov[rt][str[i]-'a'];        return size[rt];    }}trie;struct TREE{    int last[N+1],tov[M+1],next[M+1],pos[N+1],high[N+1],fa[N+1],root[N+1];    int rmq[E+1][LGE+1];    char st[M+1][L+5];    int euler[E+1];    int tot,e,lge;    void insert(int x,int y,char s[L+5])    {        tov[++tot]=y;        strcpy(st[tot],s);        next[tot]=last[x];        last[x]=tot;    }    void calc(int x)    {        int i=last[x],y;        euler[++e]=x;        pos[x]=e;        while (i)        {            y=tov[i];            if (y!=fa[x])            {                fa[y]=x;                high[y]=high[x]+1;                root[y]=trie.insert(st[i],root[x]);                calc(y);                euler[++e]=x;            }            i=next[i];        }    }    void build()    {        lge=(int)trunc(log(e)/log(2));        for (int i=1;i<=e;i++)            rmq[i][0]=euler[i];        for (int j=1;j<=lge;j++)            for (int i=1;i<=e-(1<<j)+1;i++)                if (high[rmq[i][j-1]]<high[rmq[i+(1<<j-1)][j-1]])                    rmq[i][j]=rmq[i][j-1];                else                    rmq[i][j]=rmq[i+(1<<j-1)][j-1];    }    int RMQ(int l,int r)    {        int lgr=(int)trunc(log(r-l+1)/log(2));        if (high[rmq[l][lgr]]<high[rmq[r-(1<<lgr)+1][lgr]])            return rmq[l][lgr];        else            return rmq[r-(1<<lgr)+1][lgr];    }    int lca(int x,int y)    {        x=pos[x],y=pos[y];        if (x>y)            x^=y^=x^=y;        return RMQ(x,y);    }}t;char r[L+5];int n,q;int main(){    freopen("strings.in","r",stdin);    freopen("strings.out","w",stdout);    n=read();    for (int i=1,x,y;i<n;i++)    {        x=read(),y=read();        scanf("%s",r);        t.insert(x,y,r);        t.insert(y,x,r);    }    t.root[1]=0;    t.calc(1);    t.build();    q=read();    for (int i=1,x,y,z;i<=q;i++)    {        x=read(),y=read();        scanf("%s",r);        z=t.lca(x,y);        int ans=trie.query(r,t.root[x])+trie.query(r,t.root[y]);        ans-=trie.query(r,t.root[z])*2;        printf("%d\n",ans);    }    fclose(stdin);    fclose(stdout);}
0 0