hdu 5221 Occupation (树剖+线段树)

来源:互联网 发布:淘宝店铺监控多久解除 编辑:程序博客网 时间:2024/06/06 16:00

题意:

一棵树每个点都有权值,有三种操作

1 从x节点走到y节点,并将路径中的点的权值都取出来

2 将x节点的权值减去

3 将以x为根节点的子树的所有节点的值取出来。


每次操作后查询一次现在取出来的值为多少。


解题思路:


入门树剖+线段树题吧

第一次写树剖,写错了一个地方wa了好几发。

 while(ty!=tx)    {        if(deep[tx]>deep[ty]) //比较深度的应该是tx和ty不是x和y        {           update(1, 1, n, tid[top[x]], tid[x], 1);                 x=fa[top[x]],  tx=top[x];        }        else        {            update(1, 1, n, tid[top[y]], tid[y], 1);            y=fa[top[y]],  ty=top[y];        }    }

dfs序可以把树上的节点编号,而树剖可以让一条链上的节点编号都是连续的,然后当我们需要两个点路径的权值的时候,我们就可以通过让两个点所在的链不断上升,直到他们在同一条链上,由于树剖的性质,这个过程最多只需要logn次,这是时间上的优化。

第一种操作就是用树剖实现了,通过树剖找到我们需要更新的连续的编号区间, 然后用线段树维护。具体的我就不讲了,看下卿学姐的视频就清楚了。

第二种操作就是普通的线段树单点更新

第三种的话,由于一颗子树上的dfs序也一定是连续的,所以维护也简单了。


代码:

#include <bits/stdc++.h>#define ps push_back#define lson o<<1#define rson o<<1|1#define LL long long using namespace std;const int maxn=1e5+5;struct p{    LL x;    int lazy;    LL  sum;    void init()    {        x=0;         lazy=-1;    }}t[maxn<<4];int re[maxn];int son[maxn];int fa[maxn];int top[maxn];int deep[maxn];int tid[maxn];int siz[maxn];vector<int>edg[maxn];LL val[maxn];int cnt, n;void dfs1(int x, int f){    int i, j;    fa[x]=f;    son[x]=-1;    siz[x]=1;    for(i=0; i<(int)edg[x].size(); i++)    {        if(edg[x][i]!=f)        {            dfs1(edg[x][i], x);            siz[x]+=siz[edg[x][i]];            if(son[x]==-1||siz[son[x]]<siz[edg[x][i]])            {                son[x]=edg[x][i];            }        }     }    return;}void dfs2(int x,  int TOP, int de){    re[cnt]=x, tid[x]=cnt++, top[x]=TOP, deep[x]=de;       int i;    if(son[x]!=-1)    {        dfs2(son[x], TOP, de+1);    }    for(i=0; i<(int)edg[x].size(); i++)    {        if(edg[x][i]!=fa[x] && edg[x][i]!=son[x])        {            dfs2(edg[x][i], edg[x][i], de+1);               }    }    return;}void update(int o, int l, int r, int ll, int rr, int x){    if(ll<=l && r<=rr)    {        if(x==0)t[o].x=0;        else t[o].x=t[o].sum,t[o].lazy=x;                return;    }    if(t[o].lazy!=-1)    {        t[lson].x=t[lson].sum;        t[rson].x=t[rson].sum;        t[lson].lazy=t[rson].lazy=t[o].lazy;        t[o].lazy=-1;     }        int mid=(l+r)>>1;        if(ll<=mid)update(lson, l, mid, ll, rr, x);    if(rr>mid)update(rson, mid+1, r, ll, rr, x);    t[o].x=t[lson].x+t[rson].x; //   printf("%d %d %d\n", l, r, t[o].x);    return;}void UPD(int x, int y){    int ty=top[y], tx=top[x];    while(ty!=tx)    {//        printf("%d %d %d %d\n", x, tx, y, ty);        if(deep[tx]>deep[ty])        {        //    printf("l,r %d %d %d\n", top[x], tid[top[x]], tid[x]);           update(1, 1, n, tid[top[x]], tid[x], 1);                 x=fa[top[x]],  tx=top[x];        }        else        {            update(1, 1, n, tid[top[y]], tid[y], 1);            y=fa[top[y]],  ty=top[y];        }    }      //  printf("%d %d\n", x, y);        if(deep[x]<deep[y])        {           update(1, 1, n, tid[x], tid[y], 1);              }        else        {     //       printf("l,r %d %d\n", tid[y], tid[x]);            update(1, 1, n, tid[y], tid[x], 1);        }        return;}void build(int o, int l, int r){    t[o].init();    if(l==r)    {        t[o].sum=val[re[l]];        return;    }    int mid=(l+r)>>1;    build(lson, l, mid);    build(rson, mid+1, r);    t[o].sum=t[lson].sum+t[rson].sum;    return;}int main(){    int m, i, j;    cin>>m;    while(m--)    {        scanf("%d", &n);        for(i=1; i<=n; i++)        {            scanf("%lld", &val[i]);            edg[i].clear();        }        int x, y;        for(i=1; i<n; i++)        {            scanf("%d%d", &x, &y);            edg[x].ps(y);            edg[y].ps(x);        }        cnt=1;        dfs1(1, 0);        dfs2(1, 1, 1);//        for(i=1; i<=n; i++)printf("%d %d\n", son[i], tid[i]);printf("\n");        build(1, 1, n);        int q, op;        scanf("%d", &q);                while(q--)        {            scanf("%d", &op);            if(op==1)            {                scanf("%d%d", &x, &y);                UPD(x, y);                        printf("%lld\n", t[1].x);            }              else if(op==2)            {                scanf("%d", &x);                update(1, 1, n, tid[x], tid[x], 0);                printf("%lld\n", t[1].x);            }            else             {                scanf("%d", &x);                update(1, 1, n, tid[x], tid[x]+siz[x]-1, 1);                printf("%lld\n", t[1].x);            }        }    }    }


0 0
原创粉丝点击