【JZOJ5055】【GDOI2017模拟二试4.12】树上路径

来源:互联网 发布:ps如何把图片矩阵排列 编辑:程序博客网 时间:2024/05/17 02:51

Description

给定一颗n个结点的无根树,树上的每个点有一个非负整数点权,定义一条路径的价值为路径上的点权和-路径的点权最大值。
给定参数p,我们想知道,有多少不同的树上简单路径,满足它的价值恰好是p的倍数。
注意:单点算作一个路径;u ≠ v时,(u,v)和(v,u)只算一次。

Data Constraint

对所有测试点,我们有:
n≤10^5,p≤10^7,val_i≤10^9
这里写图片描述

Solution

这是道树分治的题。我们找出重心的位置,每次从重心往四周遍历,找出每条到重心的路径的点权和%p和路径的点权最大值,然后将路径按点权最大值从小大大排序,用个桶维护当前的路径的点权和,每次在桶中查找路径的点权和-路径的点权最大值的数量。由于可能会算重,所以要先重心的每颗子树自己先搞一下,减去重复。

Code

#include<iostream>#include<cmath>#include<cstring>#include<cstdio>#include<algorithm>using namespace std;const int maxn=2e5+5,maxn1=1e7+5;struct code{    int mx,sum;}b[maxn];int first[maxn],last[maxn],next[maxn],a[maxn],size[maxn],mx[maxn];int n,i,t,j,k,l,m,x,y,z,num,p,ans,ln,s,cnt[maxn1][2],bz[maxn],fa[maxn];void lian(int x,int y){    last[++num]=y;next[num]=first[x];first[x]=num;}bool cmp(code x,code y){    return x.mx<y.mx;}void dg1(int x,int y){    int t,p=num;size[x]=1;mx[x]=0;    for (t=first[x];t;t=next[t]){        if (last[t]==y || bz[last[t]])continue;        b[++num].sum=(b[p].sum+a[last[t]])%m;        b[num].mx=max(a[last[t]],b[p].mx);        dg1(last[t],x);size[x]+=size[last[t]];mx[x]=max(mx[x],size[last[t]]);    }}int find(int x,int y){    int t,k;mx[x]=max(mx[x],p-size[x]);    if (mx[x]*2<=p||p==1) return x;    for(t=first[x];t;t=next[t]){        if (last[t]==y || bz[last[t]]) continue;        k=find(last[t],x);        if (k) return k;    }    return 0;}void dg(int x){    int t,k;    bz[x]=1;num=0;    for (t=first[x];t;t=next[t]){        if (bz[last[t]]) continue;        k=num+1;        b[++num].mx=max(a[x],a[last[t]]);b[num].sum=(a[x]+a[last[t]])%m;        dg1(last[t],0);        sort(b+k,b+num+1,cmp);        for (i=k;i<=num;i++){            k=((b[i].sum-b[i].mx)%m+m)%m;            if (k) k=m-k;            if (cnt[k][0]==last[t]) ans-=cnt[k][1];            l=((b[i].sum-a[x])%m+m)%m;            if (cnt[l][0]!=last[t]) cnt[l][0]=last[t],cnt[l][1]=0;            cnt[l][1]++;        }    }    sort(b+1,b+num+1,cmp);    for (i=1;i<=num;i++){        k=((b[i].sum-b[i].mx)%m+m)%m;        if (k) k=m-k;else ans++;        if (cnt[k][0]==x) ans+=cnt[k][1];        l=((b[i].sum-a[x])%m+m)%m;        if (cnt[l][0]!=x) cnt[l][0]=x,cnt[l][1]=0;        cnt[l][1]++;    }    for (t=first[x];t;t=next[t]){        if (bz[last[t]])continue;p=size[last[t]];        k=find(last[t],x);        dg(k);    }}int main(){    freopen("path.in","r",stdin);freopen("path.out","w",stdout);    scanf("%d%d",&n,&m);    for (i=1;i<n;i++)        scanf("%d%d",&x,&y),lian(x,y),lian(y,x);    for (i=1;i<=n;i++)        scanf("%d",&a[i]);num=0;    dg1(1,0);p=n;    k=find(1,0);    dg(k);    ans+=n;    printf("%d\n",ans);}
1 0