洛谷P3384【模板】树链剖分 (树链剖分)

来源:互联网 发布:炒股记账软件 编辑:程序博客网 时间:2024/05/22 18:55

题目

题目传送门：https://www.luogu.com.cn/problem/P3384


题解

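树链剖分模板题，积累一下模板。思路：第一遍 DFS 求出每个结点的深度、父亲、子树大小和重儿子；第二遍 DFS 优先走重儿子，得到重链链顶和 DFS 序。这样每条重链、每棵子树在 DFS 序上都是连续区间，可以用带懒标记的线段树维护：路径操作沿重链向上跳，子树操作直接对区间 [start[x], end[x]] 操作。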


代码

// Luogu P3384 [Template] Heavy-light decomposition.
//
// Input: n m Root MOD, then n node values, then n-1 undirected edges, then m
// operations (all arithmetic modulo MOD):
//   1 x y z : add z to every node on the path x -> y
//   2 x y   : print the sum of the values on the path x -> y
//   3 x z   : add z to every node in the subtree of x
//   4 x     : print the sum of the values in the subtree of x
//
// Two DFS passes split the tree into heavy chains and lay the nodes out in
// DFS order (heavy child first). Every heavy chain and every subtree is then
// a contiguous index range, served by a lazy-propagation segment tree.
//
// Values, tags and MOD are long long on purpose: MOD may be as large as
// 2^31-1, so the sum of two residues (or a residue times a segment length)
// does not fit in int.
//
// Note: the globals are deliberately not named `size` / `end`; together with
// `using namespace std;` those names are ambiguous with std::size / std::end.
#include <algorithm>
#include <cstdio>
#include <cstring>

typedef long long ll;

const int N = 100005;

int n, m, Root, cur, Tim;
ll MOD;
int head_p[N];
// fa: parent, top: head of the node's heavy chain, dep: depth (root = 1),
// son: heavy child (-1 if none), siz: subtree size,
// start/finish: first/last DFS index inside the node's subtree,
// fw: inverse of start (DFS index -> node).
int fa[N], top[N], dep[N], son[N], siz[N], start[N], finish[N], fw[N];
ll num[N];

struct Tadj {
    int next, obj;
} Edg[N << 1];

struct Tnode {
    ll sum, add;
} tree[N << 2];

// Reduce x into [0, MOD), including negative x.
ll norm(ll x) {
    x %= MOD;
    return x < 0 ? x + MOD : x;
}

// Reset the adjacency lists and heavy-child markers.
void Init() {
    cur = -1;
    Tim = 0;
    memset(head_p, -1, sizeof(head_p));
    memset(son, -1, sizeof(son));
}

// Prepend directed edge a -> b to a's adjacency list.
void Insert(int a, int b) {
    cur++;
    Edg[cur].next = head_p[a];
    Edg[cur].obj = b;
    head_p[a] = cur;
}

// First pass: depth, parent, subtree size and heavy child of every node.
void dfs1(int root, int dad) {
    dep[root] = dep[dad] + 1;
    fa[root] = dad;
    siz[root] = 1;
    for (int i = head_p[root]; ~i; i = Edg[i].next) {
        int v = Edg[i].obj;
        if (v == dad) continue;
        dfs1(v, root);
        siz[root] += siz[v];
        if (son[root] == -1 || siz[v] > siz[son[root]]) son[root] = v;
    }
}

// Second pass: DFS indices (heavy child visited first) and chain heads.
void dfs2(int root, int tp) {
    start[root] = ++Tim;
    fw[Tim] = root;
    top[root] = tp;
    if (~son[root]) dfs2(son[root], tp);
    for (int i = head_p[root]; ~i; i = Edg[i].next) {
        int v = Edg[i].obj;
        if (v == fa[root] || v == son[root]) continue;
        dfs2(v, v);  // every light child starts a new chain
    }
    finish[root] = Tim;
}

// Build the segment tree over DFS positions [L, R].
void build(int root, int L, int R) {
    if (L == R) {
        tree[root].sum = norm(num[fw[L]]);
        return;
    }
    int mid = (L + R) >> 1;
    build(root << 1, L, mid);
    build(root << 1 | 1, mid + 1, R);
    tree[root].sum = (tree[root << 1].sum + tree[root << 1 | 1].sum) % MOD;
}

// Add val (already in [0, MOD)) to every position covered by node `root`
// ([L, R]) and remember it as a pending tag for the children.
void addTag(int root, int L, int R, ll val) {
    tree[root].sum = (tree[root].sum + val * (R - L + 1) % MOD) % MOD;
    tree[root].add = (tree[root].add + val) % MOD;
}

// Push the pending tag of `root` down to its two children.
void down(int root, int L, int R) {
    if (tree[root].add == 0) return;
    int mid = (L + R) >> 1;
    addTag(root << 1, L, mid, tree[root].add);
    addTag(root << 1 | 1, mid + 1, R, tree[root].add);
    tree[root].add = 0;
}

// Add val to DFS positions [x, y].
void update(int root, int L, int R, int x, int y, ll val) {
    if (x > R || y < L) return;
    if (x <= L && y >= R) {
        addTag(root, L, R, val);
        return;
    }
    down(root, L, R);
    int mid = (L + R) >> 1;
    update(root << 1, L, mid, x, y, val);
    update(root << 1 | 1, mid + 1, R, x, y, val);
    tree[root].sum = (tree[root << 1].sum + tree[root << 1 | 1].sum) % MOD;
}

// Sum of DFS positions [x, y], modulo MOD.
ll query(int root, int L, int R, int x, int y) {
    if (x > R || y < L) return 0;
    if (x <= L && y >= R) return tree[root].sum;
    down(root, L, R);
    int mid = (L + R) >> 1;
    return (query(root << 1, L, mid, x, y) +
            query(root << 1 | 1, mid + 1, R, x, y)) % MOD;
}

// Add val to every node on the path x -> y by climbing heavy chains.
void work1(int x, int y, ll val) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) std::swap(x, y);
        update(1, 1, n, start[top[y]], start[y], val);
        y = fa[top[y]];
    }
    if (dep[x] > dep[y]) std::swap(x, y);
    update(1, 1, n, start[x], start[y], val);
}

// Sum of the values on the path x -> y, modulo MOD.
ll work2(int x, int y) {
    ll ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) std::swap(x, y);
        ans = (ans + query(1, 1, n, start[top[y]], start[y])) % MOD;
        y = fa[top[y]];
    }
    if (dep[x] > dep[y]) std::swap(x, y);
    ans = (ans + query(1, 1, n, start[x], start[y])) % MOD;
    return ans;
}

int main() {
    if (scanf("%d%d%d%lld", &n, &m, &Root, &MOD) != 4) return 0;
    Init();
    for (int i = 1; i <= n; i++) scanf("%lld", &num[i]);
    int a, b;
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &a, &b);
        Insert(a, b);
        Insert(b, a);
    }
    dfs1(Root, 0);  // node 0 is a virtual parent with dep[0] == 0
    dfs2(Root, Root);
    build(1, 1, n);
    int op;
    ll k;
    for (int i = 1; i <= m; i++) {
        scanf("%d", &op);
        if (op == 1) {
            scanf("%d%d%lld", &a, &b, &k);
            work1(a, b, norm(k));
        } else if (op == 2) {
            scanf("%d%d", &a, &b);
            printf("%lld\n", work2(a, b));
        } else if (op == 3) {
            scanf("%d%lld", &a, &k);
            update(1, 1, n, start[a], finish[a], norm(k));
        } else {
            scanf("%d", &a);
            printf("%lld\n", query(1, 1, n, start[a], finish[a]));
        }
    }
    return 0;
}

Smile.

0 0
原创粉丝点击