【JZOJ 3872】圣诞树

来源:互联网 发布:海洋cms怎么更换模板 编辑:程序博客网 时间:2024/04/28 21:36

Description

圣诞节到了,小可可送给小薰一棵圣诞树。这棵圣诞树很奇怪,它是一棵多叉树,有n个点,n-1条边。它的每个结点都有一个权值。小可可和小薰想用这棵树玩一个游戏。
定义(s,e)为树上从s到e的简单路径,我们可以记下在这条路径上经过的结点,定义这个结点序列为S(s,e)。
我们按照如下方法定义这个序列S(s,e)的权值G(S(s,e)):假设这个序列中结点的权值为Z0,Z1,…,Z(L-1),其中L为序列的长度,我们定义G(S(s,e))=Z0 × k^0 + Z1 × k^1 + … + Z(L-1) × k^(L-1)。
如果路径(s,e)满足G(S(s,e)) ≡ x (mod y) ,那么这条路径属于小可可,否则这条路径属于小薰。小可可和小薰很显然不希望这个游戏变得那么简单。如果路径(p1,p2)和(p2,p3)都属于小薰,那么路径(p1,p3)也属于他 或 如果路径(p1,p2)和(p2,p3)都属于小可可,那么路径(p1,p3)也属于小可可。然而这个性质并不总是正确的。所以小薰想知道到底有多少三元组(p1,p2,p3)满足这个性质。

对于100%的数据,1 ≤ n ≤ 10^5,2 ≤ y ≤ 10^9,1 ≤ k ≤ y,0 ≤ x < y。

Analysis

对于图中的路径,如果G(s(i,j))mod y=x,则构造边权为1的边(i,j),否则边权为0
那么现在求的是所有满足(i,j) (j,k) (i,k) 权值相等的(i,j,k)三元组个数
定义in0[i]表示进入i 的边中权值为 0 的个数。类似地定义 in1[i],out0[i],out1[i]

p=ni=12out0[i]out1[i]+2in0[i]in1[i]+out0[i]in1[i]+out1[i]in0[i]
则我们知道三条边权值不全相同三元组个数被计算了两遍
ans=n3p/2
则问题变成快速求出in,out,因为in0[i]+in1[i]=n,所以只需求出in1和out1就行了
剩下的思路就是用点分治,求满足等式的点对可以通过哈希记录查询实现。具体细节请读者自行思考

ps:点分治常犯错误:在判断b[i].S!=b[i1].S,且i循环到num+1时,记得清空b[num+1].S

Code

#include<cstdio>#include<algorithm>#define fo(i,a,b) for(ll i=a;i<=b;i++)#define efo(i,v) for(int i=last[v];i;i=next[i])using namespace std;typedef long long ll;const int N=100010,M=N*2;const ll hx=2000000;ll n,K,X,mo,_k[N],ny[N],a[N],in[N],out[N],hf[hx][2],hg[hx][2];ll num,rt,tot,to[M],next[M],last[N],size[N];bool bz[N];struct node{    ll v,f,g,S,l;}b[N];bool cmp(node a,node b){    return a.S<b.S;}void link(int u,int v){    to[++tot]=v,next[tot]=last[u],last[u]=tot;}ll qmi(ll x,ll n){    ll t=1;    for(;n;n>>=1)    {        if(n&1) t=t*x%mo;        x=x*x%mo;    }    return t;}void getnum(int v,int fr){    num++;    efo(i,v)    {        int u=to[i];        if(u==fr || bz[u]) continue;        getnum(u,v);    }}void getrt(int v,int fr){    size[v]=1;    efo(i,v)    {        int u=to[i];        if(u==fr || bz[u]) continue;        getrt(u,v);        size[v]+=size[u];    }    if(size[v]>num-size[v] && !rt) rt=v;}void dfs(int v,int fr,ll d,ll f,ll g,ll S){    b[++num].v=v,b[num].S=S,b[num].l=d,b[num].f=f,b[num].g=g;     efo(i,v)    {        int u=to[i];        if(u==fr || bz[u]) continue;        dfs(u,v,d+1,(f+a[u]*_k[d+1])%mo,(g*K+a[u])%mo,S);    }}ll hashf(ll x){    ll pos=x%hx;    while(hf[pos][0] && hf[pos][0]!=x) pos=(pos+1)%hx;    return pos;}ll hashg(ll x){    ll pos=x%hx;    while(hg[pos][0] && hg[pos][0]!=x) pos=(pos+1)%hx;    return pos;}void divide(int v,int fr){    num=0;getnum(v,fr);    rt=0;getrt(v,fr);    if(a[rt]%mo==X) in[rt]++,out[rt]++;    num=0;    efo(i,rt)    {        int u=to[i];        if(u==fr || bz[u]) continue;        dfs(u,rt,1,a[u]*K%mo,(a[u]+a[rt]*K)%mo,u);    }    fo(i,1,num)    {        if(b[i].g==X) in[rt]++,out[b[i].v]++;        if((a[rt]+b[i].f)%mo==X) out[rt]++,in[b[i].v]++;        b[i].g=(X-b[i].g+mo)%mo*ny[b[i].l]%mo;    }    sort(b+1,b+num+1,cmp);    fo(i,1,num)    {        ll pos=hashf(b[i].f);        hf[pos][0]=b[i].f,hf[pos][1]++;        pos=hashg(b[i].g);        hg[pos][0]=b[i].g,hg[pos][1]++;    }    ll st=1;    b[num+1].S=0;    fo(i,1,num+1)        if(i>1 && b[i].S!=b[i-1].S)        {            fo(j,st,i-1)            {                ll pos=hashf(b[j].f);hf[pos][1]--;                pos=hashg(b[j].g);hg[pos][1]--;            }            fo(j,st,i-1)            {                ll pos=hashf(b[j].g);                out[b[j].v]+=hf[pos][1];                pos=hashg(b[j].f);                in[b[j].v]+=hg[pos][1];            }            fo(j,st,i-1)            {                ll pos=hashf(b[j].f);hf[pos][1]++;                pos=hashg(b[j].g);hg[pos][1]++;            }            st=i;        }    fo(i,1,num)    {        ll pos=hashf(b[i].f);        hf[pos][0]=hf[pos][1]=0;        pos=hashg(b[i].g);        hg[pos][0]=hg[pos][1]=0;    }    bz[rt]=1;    efo(i,rt)    {        int u=to[i];        if(u==fr || bz[u]) continue;        divide(u,rt);    }}int main(){    int u,v;    scanf("%lld %d %lld %lld",&n,&mo,&K,&X);    _k[0]=1;    fo(i,1,n) _k[i]=_k[i-1]*K%mo;    fo(i,0,n) ny[i]=qmi(_k[i],mo-2);    fo(i,1,n) scanf("%lld",&a[i]),a[i]%=mo;    fo(i,1,n-1)    {        scanf("%d %d",&u,&v);        link(u,v),link(v,u);    }    divide(1,0);    ll sum=0;    fo(i,1,n) sum+=2*in[i]*(n-in[i])+in[i]*(n-out[i])+out[i]*(n-in[i])+2*out[i]*(n-out[i]);    printf("%lld\n",n*n*n-sum/2);    return 0;}
0 0
原创粉丝点击