BZOJ 1036 [ZJOI2008]树的统计Count 树链剖分练手题

来源:互联网 发布:以爆字开头的网络语言 编辑:程序博客网 时间:2024/06/05 17:20

这是我写过的第一道树链剖分题。之前看过几个课件，但一直没写过，这次直接 YY 出来的竟然能 A。维护区间信息时没用线段树，而是用了 Splay，只为练练手，结果导致代码长得没有下限，真是心酸。还好这次写完之后基本算 1A（233，其实因为输出优化没处理负号还 WA 了一次），不然还真不敢调。

// BZOJ 1036 [ZJOI2008] Tree Statistics (Count): point updates plus path
// maximum / path sum queries on a tree.
//
// Approach: heavy-light decomposition. Every heavy chain is stored
// contiguously (chains in creation order, each chain ordered by increasing
// depth) inside one implicit-key splay tree, so a path query becomes a few
// range queries on that sequence.
#include <cstdio>
#include <cstring>
#include <climits>
#include <algorithm>
#include <vector>

// No `using namespace std;` on purpose: with it, a global array named `size`
// is ambiguous with C++17's std::size and the program fails to compile.

const int MAXN = 30005;       // arrays hold node ids 1..MAXN-1
const int NEG_INF = INT_MIN;  // identity element for max (empty range)

// One splay-tree node: `va` is this element's own value, while sz/su/mx are
// the size, sum and maximum of the subtree rooted here.
struct node
{
    int va, fa, lc, rc, sz, su, mx;

    // Reset to the empty-subtree identity (used for the index-0 sentinel).
    void set0() { va = fa = lc = rc = sz = su = 0; mx = NEG_INF; }
};

// Implicit-key splay tree over a sequence of ints. Positions are 1-based and
// follow in-order traversal; index 0 is the null sentinel.
struct SplayTree
{
    node P[MAXN];
    int root, cnt;

    // Recompute the aggregates of `rt` from its children. The sentinel is
    // reset first so that empty children contribute neutral values.
    void maintain(int rt)
    {
        P[0].set0();
        int lc = P[rt].lc, rc = P[rt].rc;
        P[rt].sz = 1 + P[lc].sz + P[rc].sz;
        P[rt].su = P[rt].va + P[lc].su + P[rc].su;
        P[rt].mx = std::max(P[rt].va, std::max(P[lc].mx, P[rc].mx));
    }

    // Rotate `rt` above its parent; `root` follows when the parent was a root.
    void rotate(int rt)
    {
        int fa = P[rt].fa, gfa = P[fa].fa;
        if (P[fa].lc == rt)
        {
            int mid = P[rt].rc;          // subtree that changes parent
            P[fa].lc = mid;
            if (mid) P[mid].fa = fa;
            P[rt].rc = fa;
        }
        else
        {
            int mid = P[rt].lc;
            P[fa].rc = mid;
            if (mid) P[mid].fa = fa;
            P[rt].lc = fa;
        }
        P[fa].fa = rt;
        P[rt].fa = gfa;
        if (!gfa) root = rt;             // never write into the sentinel
        else if (P[gfa].lc == fa) P[gfa].lc = rt;
        else P[gfa].rc = rt;
        maintain(fa);
        maintain(rt);
    }

    // Node holding the k-th element (1-based) of the subtree rooted at rt.
    // Precondition: 1 <= k <= P[rt].sz.
    int find(int rt, int k)
    {
        for (;;)
        {
            int rk = P[P[rt].lc].sz + 1;
            if (k == rk) return rt;
            if (k < rk) rt = P[rt].lc;
            else k -= rk, rt = P[rt].rc;
        }
    }

    // Splay `rt` up to `root` (zig-zig rotates the parent first).
    void splay(int rt)
    {
        while (rt != root)
        {
            int fa = P[rt].fa, gfa = P[fa].fa;
            if (gfa && ((P[fa].lc == rt) == (P[gfa].lc == fa))) rotate(fa);
            rotate(rt);
        }
    }

    // Splay `rt` to the root, then detach and return its left subtree
    // (op != 0) or its right subtree (op == 0); returns 0 if that side is empty.
    int split(int rt, int op)
    {
        splay(rt);
        int &side = op ? P[root].lc : P[root].rc;
        int res = side;
        side = 0;
        if (res) P[res].fa = 0;
        maintain(root);
        return res;
    }

    // Concatenate tree L (all of its elements first) with tree R; the
    // combined tree becomes `root`.
    void merge(int L, int R)
    {
        if (!L) { root = R; return; }
        if (!R) { root = L; return; }
        root = L;
        splay(find(root, P[root].sz));   // L's last element now has no right child
        P[root].rc = R;
        P[R].fa = root;
        maintain(root);
    }

    // Append `va` as the new last element of the sequence.
    void insert(int va)
    {
        int nd = ++cnt;
        P[nd].va = P[nd].su = P[nd].mx = va;
        P[nd].sz = 1;
        P[nd].lc = P[nd].rc = P[nd].fa = 0;
        if (!root)
        {
            P[0].set0();                 // sentinel must be neutral before any query
            root = nd;
            return;
        }
        splay(find(root, P[root].sz));   // current last element
        P[root].rc = nd;
        P[nd].fa = root;
        maintain(root);
    }

    // Set the element at position `pos` (1-based) to c.
    void update(int pos, int c)
    {
        int rt = find(root, pos);
        splay(rt);
        P[rt].va = c;
        maintain(rt);
    }

    // Aggregate (sz/su/mx) of positions [pl, pr], 1 <= pl <= pr <= size.
    // Returned by value because the tree is restructured afterwards.
    node range(int pl, int pr)
    {
        int total = P[root].sz;
        if (pl == 1)
        {
            if (pr == total) return P[root];
            splay(find(root, pr + 1));   // left subtree is exactly [1, pr]
            return P[P[root].lc];
        }
        // Locate both boundary nodes before the first split renumbers ranks.
        int rtl = find(root, pl), rtr = find(root, pr);
        int op1 = split(rtl, 1);         // detach [1, pl-1]
        int op2 = split(rtr, 0);         // detach [pr+1, total]
        node res = P[root];              // remaining tree is exactly [pl, pr]
        merge(root, op2);
        merge(op1, root);
        return res;
    }

    int qmax(int pl, int pr) { return range(pl, pr).mx; }
    int qsum(int pl, int pr) { return range(pl, pr).su; }
} Solve;

// Read a decimal int from stdin. A '-' counts only when it directly precedes
// the digits. At EOF, x is left as 0 instead of looping forever.
void get(int &x)
{
    int c = getchar();                   // int, so EOF is distinguishable
    bool neg = false;
    x = 0;
    while (c != EOF && (c < '0' || c > '9'))
    {
        neg = (c == '-');
        c = getchar();
    }
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = getchar();
    if (neg) x = -x;
}

// Write x in decimal followed by '\n' (correct for INT_MIN as well).
void put(int x)
{
    char s[12];
    int num = 0;
    unsigned u = static_cast<unsigned>(x);
    if (x < 0)
    {
        putchar('-');
        u = 0u - u;                      // magnitude without signed overflow
    }
    do s[num++] = static_cast<char>('0' + u % 10); while (u /= 10);
    while (num) putchar(s[--num]);
    putchar('\n');
}

// Read the next word made of 'A'..'Z' (and '-') into s and NUL-terminate it.
// At most cap-1 characters are stored; any excess characters are consumed.
void get(char *s, int cap)
{
    int c = getchar(), num = 0;
    while (c != EOF && (c < 'A' || c > 'Z')) c = getchar();
    while ((c >= 'A' && c <= 'Z') || c == '-')
    {
        if (num < cap - 1) s[num++] = static_cast<char>(c);
        c = getchar();
    }
    s[num] = '\0';
}

// Tree (adjacency lists) and heavy-light decomposition state; nodes are 1..n.
int n, q, e, cnt;
int h[MAXN], nx[2 * MAXN], to[2 * MAXN];  // edge list: first edge, next edge, target
int val[MAXN];       // current weight of each tree node
int fa[MAXN];        // parent in the DFS tree rooted at node 1 (0 for the root)
int son[MAXN];       // heavy child (0 if none)
int head[MAXN];      // top node of the chain containing each node
int subSize[MAXN];   // subtree size
int Lsize[MAXN];     // chain lengths; after Initialize, prefix sums over chains 1..i
int deep[MAXN];      // depth (root has depth 1)
int belong[MAXN];    // chain index of each node
std::vector<int> clude[MAXN];  // nodes of each chain, in DFS (= increasing depth) order

// First DFS: parents, depths, subtree sizes and heavy children.
// NOTE(review): recursive, depth can reach n; fine for n ~ 3e4 on typical stacks.
void Dfs1(int rt, int f)
{
    fa[rt] = f;
    subSize[rt] = 1;
    deep[rt] = deep[f] + 1;
    for (int i = h[rt]; i; i = nx[i])
    {
        int v = to[i];
        if (v == f) continue;
        Dfs1(v, rt);
        subSize[rt] += subSize[v];
        if (subSize[v] > subSize[son[rt]]) son[rt] = v;
    }
}

// Second DFS: a heavy child continues its parent's chain; any other node
// starts a new chain numbered ++cnt.
void Dfs2(int rt, int f)
{
    if (son[f] == rt) belong[rt] = belong[f], head[rt] = head[f];
    else belong[rt] = ++cnt, head[rt] = rt;
    clude[belong[rt]].push_back(rt);
    Lsize[belong[rt]]++;
    for (int i = h[rt]; i; i = nx[i])
        if (to[i] != f) Dfs2(to[i], rt);
}

// Read the tree and weights, decompose it, and load the splay tree with all
// chains laid out back to back (chain 1 first).
void Initialize()
{
    get(n);
    for (int i = 1; i < n; i++)
    {
        int a, b;
        get(a);
        get(b);
        nx[++e] = h[a], h[a] = e, to[e] = b;
        nx[++e] = h[b], h[b] = e, to[e] = a;
    }
    for (int i = 1; i <= n; i++) get(val[i]);
    Dfs1(1, 0);
    Dfs2(1, 0);
    for (int i = 1; i <= cnt; i++)
    {
        Lsize[i] += Lsize[i - 1];        // positions used by chains 1..i
        for (size_t j = 0; j < clude[i].size(); j++) Solve.insert(val[clude[i][j]]);
    }
}

// Splay-tree position (1-based) of tree node t.
int id(int t)
{
    return Lsize[belong[t] - 1] + (deep[t] - deep[head[t]] + 1);
}

// Combined aggregate (sz/su/mx) of all nodes on the tree path l..r.
node pathAgg(int l, int r)
{
    node acc;
    acc.set0();
    while (belong[l] != belong[r])
    {
        // Lift the endpoint whose chain head is deeper (ties: l).
        if (deep[head[l]] < deep[head[r]]) std::swap(l, r);
        node part = Solve.range(id(head[l]), id(l));
        acc.sz += part.sz;
        acc.su += part.su;
        acc.mx = std::max(acc.mx, part.mx);
        l = fa[head[l]];
    }
    // Same chain: positions grow with depth, so order the endpoints.
    int pl = id(l), pr = id(r);
    if (pr < pl) std::swap(pl, pr);
    node part = Solve.range(pl, pr);
    acc.sz += part.sz;
    acc.su += part.su;
    acc.mx = std::max(acc.mx, part.mx);
    return acc;
}

// Set the weight of tree node rt to c.
void Change(int rt, int c)
{
    val[rt] = c;
    Solve.update(id(rt), c);
}

// Print the maximum weight on the path l..r.
void Getmax(int l, int r)
{
    put(pathAgg(l, r).mx);
}

// Print the total weight on the path l..r.
void Getsum(int l, int r)
{
    put(pathAgg(l, r).su);
}

// Process q commands, each "WORD a b". The command is chosen by the word's
// second letter: 'H' -> Change, 'M' -> Getmax, 'S' -> Getsum (presumably
// CHANGE / QMAX / QSUM in the input format).
void Work()
{
    get(q);
    while (q--)
    {
        char s[16] = "";                 // zero-filled, so s[1] is always defined
        get(s, 16);
        int a, b;
        get(a);
        get(b);
        if (s[1] == 'H') Change(a, b);
        else if (s[1] == 'M') Getmax(a, b);
        else if (s[1] == 'S') Getsum(a, b);
    }
}

int main()
{
    Initialize();
    Work();
    return 0;
}
0 0
原创粉丝点击