[JZOJ5077]树的难题

来源:互联网 发布:越南 共产 知乎 编辑:程序博客网 时间:2024/04/29 22:38

题目大意

给定一棵n个点的无根树,树上每一条边都有颜色。一共m种颜色,编号从1m。第i种颜色权值为ci
对于树上的一条简单路径,路径上经过的所有边按照顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每一个同颜色段的颜色权值之和。
你要计算出边数在[l,r]之内的所有简单路径中,路径权值的最大值。

1mn2×105,1lrn,|ci|104


题目分析

点分治,对于一个分治重心,把子树按照到重心边的颜色排序,然后同种颜色同种颜色转移,用线段树记录长度为某个值的最大答案,开两棵来进行不同颜色和同颜色的转移。在切换颜色时使用线段树合并将第二棵并到第一棵上可以卡常。
时间复杂度O(nlog2n)
然而,如果我们同颜色按深度排排序,不同颜色也排排序,姿势好一点的话用单调队列可以做到O(nlogn)


代码实现

蒟蒻在模拟赛时只想到了log2做法。

#include <algorithm>#include <iostream>#include <climits>#include <cstdio>#include <cctype>using namespace std;int read(){    int x=0,f=1;    char ch=getchar();    while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();    while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();    return x*f;}const int INF=INT_MAX;const int N=200050;const int M=200050;const int E=N<<1;const int LGN=18;const int S=N*LGN*2;struct segment_tree{    int v[S],son[S][2];    int tot;    void init(){tot=0,v[0]=-INF;}    int newnode(){return v[++tot]=-INF,son[tot][0]=son[tot][1]=0,tot;}    void insert(int &rt,int x,int y,int l,int r)    {        if (!rt) rt=newnode();        v[rt]=max(v[rt],y);        if (l==r) return;        int mid=l+r>>1;        if (x<=mid) insert(son[rt][0],x,y,l,mid);        else insert(son[rt][1],x,y,mid+1,r);    }    int query(int rt,int st,int en,int l,int r)    {        if (!rt) return -INF;        if (st==l&&en==r) return v[rt];        int mid=l+r>>1;        if (en<=mid) return query(son[rt][0],st,en,l,mid);        else if (mid+1<=st) return query(son[rt][1],st,en,mid+1,r);        else return max(query(son[rt][0],st,mid,l,mid),query(son[rt][1],mid+1,en,mid+1,r));    }    int merge(int rt1,int rt2)    {        if (!(rt1&&rt2)) return rt1^rt2;        v[rt1]=max(v[rt1],v[rt2]);        son[rt1][0]=merge(son[rt1][0],son[rt2][0]),son[rt1][1]=merge(son[rt1][1],son[rt2][1]);        return rt1;    }}t;struct data{    int x,c;    bool operator<(data const d)const{return c<d.c;}}son[N];int last[N],fa[N],size[N],que[N];int tov[E],nxt[E],col[E];bool vis[N];int val[M];int n,m,tot,head,tail,ans,L,R,root1,root2,dif;void insert(int x,int y,int z){tov[++tot]=y,nxt[tot]=last[x],col[tot]=z,last[x]=tot;}int core(int og){    int rets=n,ret,x,y,i,tmp;    for (head=0,fa[que[tail=1]=og]=0;head<tail;)        for (size[x=que[++head]]=1,i=last[x];i;i=nxt[i])            if ((y=tov[i])!=fa[x]&&!vis[y]) fa[que[++tail]=y]=x;    for (head=tail;head>1;--head) size[fa[que[head]]]+=size[que[head]];    for (head=1;head<=tail;++head)    {        for (tmp=size[og]-size[x=que[head]],i=last[x];i;i=nxt[i])            if ((y=tov[i])!=fa[x]&&!vis[y]) tmp=max(tmp,size[y]);        if (tmp<rets) ret=x,rets=tmp;    }    return ret;}void calc(int x,int len,int lst,int cur){    if (len>R) return;    if (L<=len) ans=max(ans,cur);    if (len<R)    {        int ret=t.query(root1,max(L-len,1),R-len,1,n);        if (ret!=-INF) ans=max(ans,ret+cur);        ret=t.query(root2,max(L-len,1),R-len,1,n);        if (ret!=-INF) ans=max(ans,ret+cur-dif);    }    for (int i=last[x],y;i;i=nxt[i])        if ((y=tov[i])!=fa[x]&&!vis[y])            fa[y]=x,calc(y,len+1,col[i],cur+(lst!=col[i])*val[col[i]]);}void change(int x,int len,int lst,int cur){    if (len>R) return;    t.insert(root2,len,cur,1,n);    for (int i=last[x],y;i;i=nxt[i])        if ((y=tov[i])!=fa[x]&&!vis[y])            change(y,len+1,col[i],cur+(lst!=col[i])*val[col[i]]);}void solve(int x){    int c=core(x),cnt=0;    for (int i=last[c],y;i;i=nxt[i])        if (!vis[y=tov[i]]) son[++cnt].x=y,son[cnt].c=col[i];    sort(son+1,son+1+cnt),root1=root2=0,t.init();    for (int j=1;j<=cnt;++j)    {        if (son[j].c!=son[j-1].c) root1=t.merge(root1,root2),root2=0,dif=val[son[j].c];        fa[son[j].x]=c,calc(son[j].x,1,son[j].c,val[son[j].c]),change(son[j].x,1,son[j].c,val[son[j].c]);    }    vis[c]=1;    for (int i=last[c],y;i;i=nxt[i])        if (!vis[y=tov[i]]) solve(y);}int main(){    freopen("journey.in","r",stdin),freopen("journey.out","w",stdout);    n=read(),m=read(),L=read(),R=read();    for (int i=1;i<=m;++i) val[i]=read();    for (int i=1,x,y,z;i<n;++i) x=read(),y=read(),z=read(),insert(x,y,z),insert(y,x,z);    ans=-INF,solve(1),printf("%d\n",ans);    fclose(stdin),fclose(stdout);    return 0;}
0 0
原创粉丝点击