3252: 攻略 dfs序+线段树

来源:互联网 发布:人力资源分析软件 编辑:程序博客网 时间:2024/05/22 04:47

首先维护一个根到底路径的前缀和,选某个点代表选了此点到根的路径。
那么每次选择了一个点x,这个点所在子树内的每个点都要减少vx的收益,那么对于子树的区间减可以用dfs序来放到连续的一段并用线段树来实现,记录每个区间的最大值及最大值来自哪个点。每个点删除后打一个标记,并在线段树中赋值为inf,代表不可再次选到。
由于每个点只会被删除一次,而删除一次的复杂度为O(logn),所以总的复杂度为O(nlogn)

#include<iostream>#include<cstdio>#define ll long long #define inf 1e18using namespace std;int n,k,cnt,dfn;ll ans;ll v[200005],mx[800005],tag[800005];bool vis[200005];int pos[200005],in[200005],out[200005],fa[200005],head[200005],list[200005],next[200005],l[800005],r[800005],from[800005];inline ll read(){    ll a=0,f=1; char c=getchar();    while (c<'0'||c>'9') {if (c=='-') f=-1; c=getchar();}    while (c>='0'&&c<='9') {a=a*10+c-'0'; c=getchar();}    return a*f;}inline void insert(int x,int y){    next[++cnt]=head[x];    head[x]=cnt;    list[cnt]=y;}void dfs(int x){    in[x]=++dfn; pos[dfn]=x;    for (int i=head[x];i;i=next[i])     {        v[list[i]]+=v[x];        fa[list[i]]=x;        dfs(list[i]);    }    out[x]=dfn;}inline void update(int k){    mx[k]=from[k]=0;    if (mx[k<<1]<0&&mx[k<<1|1]<0) return;    mx[k]=max(mx[k<<1],mx[k<<1|1]);    from[k]=mx[k<<1]>mx[k<<1|1]?from[k<<1]:from[k<<1|1];}inline void pushdown(int k){    if (!tag[k]) return;    mx[k<<1]+=tag[k]; mx[k<<1|1]+=tag[k];    tag[k<<1]+=tag[k]; tag[k<<1|1]+=tag[k];    tag[k]=0;}void build(int k,int x,int y){    l[k]=x; r[k]=y;    if (l[k]==r[k]) {mx[k]=v[pos[l[k]]]; from[k]=pos[l[k]]; return;}    int mid=l[k]+r[k]>>1;    build(k<<1,x,mid); build(k<<1|1,mid+1,y);    update(k);}void change(int k,int x,int y,ll val){    if (x>y) return;    if (l[k]==x&&r[k]==y)     {        mx[k]+=val;        tag[k]+=val;        return;    }    pushdown(k);    int mid=l[k]+r[k]>>1;    if (y<=mid) change(k<<1,x,y,val);     else if (x>mid) change(k<<1|1,x,y,val);    else change(k<<1,x,mid,val),change(k<<1|1,mid+1,y,val);    update(k);}int main(){    n=read(); k=read();    for (int i=1;i<=n;i++) v[i]=read();    for (int i=1;i<n;i++)    {        int u=read(),v=read();        insert(u,v);    }    dfs(1);    build(1,1,n);    for (int i=1;i<=k;i++)    {        if (mx[1]==0) break;        ans+=mx[1];         for (int i=from[1];i&&!vis[i];i=fa[i])            vis[i]=1,change(1,in[i],in[i],-inf),change(1,in[i]+1,out[i],v[fa[i]]-v[i]);    }    cout << ans << endl;    return 0;}
0 0
原创粉丝点击