poj 3321 Apple Tree(线段树)

来源:互联网 发布:mac b站直播 编辑:程序博客网 时间:2024/05/30 04:08

题目大意:

        有一棵树,有n个节点,开始时每个分叉点有1个苹果,现在有m个操作,Q x代表查询x节点及x节点的子节点的苹果总和,C x代表将节点x的状态翻转,有苹果变成没苹果,没苹果变成有苹果。

解题思路:

        从节点1开始,进行一次深搜,根据遍历时间为每个节点编号,进入一个新节点时时间加1,离开时,时间标为从上一个节点离开的时间。那么每个节点的进入时间和离开时间就组成了一个区间[l,r],且当前节点的有无苹果状态存在第l个点上,这个节点及子节点的苹果总和就是[l,r]上的苹果总和。所以问题就变成了一个线段树问题。

注意点:

        本题会卡vector,不建议使用。

代码:

#include <iostream>#include <stdio.h>#include <vector>using namespace std;struct range{    int l,r;};struct node{    int l,r;    int sum;};struct tNode{    int to,next;    tNode(){to = -1;next = -1;}};int n,m,nowPos = 0;;tNode son[100005];range a[100005];node tree[400005];void dfs(int x){    a[x].l = ++nowPos;//进入新节点时时间加1    if(son[x].to != -1){        int now = son[x].to;        while(now != -1){            dfs(now);            now = son[now].next;        }    }    a[x].r = nowPos;//离开时时间为从上一个节点离开的时间}void buildTree(int root,int l,int r){    tree[root].l = l;    tree[root].r = r;    tree[root].sum = r - l + 1;    if(l != r){        int mid = (l+r)>>1;        buildTree(2*root,l,mid);        buildTree(2*root+1,mid+1,r);    }}void updata(int root,int pos){    if(tree[root].l == pos && tree[root].r == pos){        tree[root].sum = 1 - tree[root].sum;        return;    }    int mid = (tree[root].l + tree[root].r)>>1;    if(pos <= mid){        updata(2*root,pos);    }    else{        updata(2*root+1,pos);    }    tree[root].sum = tree[2*root].sum + tree[2*root+1].sum;}int query(int root,int l,int r){    if(tree[root].l == l && tree[root].r == r){        return tree[root].sum;    }    int mid = (tree[root].l + tree[root].r)>>1;    if(r <= mid){        return query(2*root,l,r);    }    else if(l > mid){        return query(2*root+1,l,r);    }    else{        return query(2*root,l,mid) + query(2*root+1,mid+1,r);    }}int main(){    char s[5];    int x,y;    scanf("%d",&n);    for(int i = 0; i < n-1; i ++){        scanf("%d%d",&x,&y);        son[y].next = son[x].to;        son[x].to = y;    }    dfs(1);    buildTree(1,1,n);    scanf("%d",&m);    for(int i = 0; i < m; i ++){        scanf("%s %d",s,&x);        if(s[0] == 'C'){            updata(1,a[x].l);        }        else{            printf("%d\n",query(1,a[x].l,a[x].r));        }    }    return 0;}


0 0
原创粉丝点击