可持久化线段树

来源:互联网 发布:语音对话机器人app源码 编辑:程序博客网 时间:2024/06/05 06:20

可持久化线段树

简介

可持久化数据结构又称函数式数据结构,其思路来自于函数式编程。在函数式编程中,变量的值是不允许改变的,因而每一次插入元素都必须创建一个新的版本。

设想一棵二叉树:

        [1]    [2]     [3] [4]  [5] [6]  

现在为了插入一个新节点,我们必须新建一棵树

        (1)    (2)     (3) (4)  (5) (6)  (7)

不难发现,很多元素被重复使用了。如果将重复的元素合并,就得到这样一棵树:

        [1]    --->  (1)    [2]     [3]   [2]   (3) [4]  [5] [6]         [6]  (7)

新建的元素其实只有O(h),如果是一棵平衡树或线段树,新建元素就是O(lgn)

应用

可持久化线段树是解决区间问题的锐利武器。考虑第i棵和第j棵线段树Ti,Tj,如果他们的对应元素相减得到一棵新树TjTi,这棵树其实就是区间 [i+1,j] 所对应的线段树。

例如vijos1459车展一题。用反证法不难证明题目中要求的即是

i=lr|ximid|

其中,mid为区间 x[l,r] 的中位数。

由于涉及了区间中位数,可以考虑使用树套树实现。但树套树代码复杂度较高且不宜于调试,可以考虑用可持久化线段树代替。

将输入的 xi 按顺序建立一棵可持久化线段树,分别维护sumnum_sum,第一个为区间内元素的和,第二个为区间内元素出现的次数。利用 TrTl1 得到区间 [l,r] 内的线段树来计算。

Code

// 可持久化线段树// 维护两个值#include <bits/stdc++.h>using namespace std;#define maxn 1005struct node {    int l, r, lc, rc;    long long sum;    int num_sum;    node(){l = r = lc = rc = sum = num_sum = 0; }}tree[15*maxn];int root[200005], top = 0;int n, m;inline long long read(){    long long a = 0; int c;    do c = getchar(); while(!isdigit(c));    while (isdigit(c)) {        a = a*10 + c - '0';        c = getchar();    }    return a;}int sorted[1005]; // 离散化int dat[1005]; // 原始数据inline void update(int i) {    tree[i].sum = tree[tree[i].lc].sum + tree[tree[i].rc].sum;    tree[i].num_sum = tree[tree[i].lc].num_sum + tree[tree[i].rc].num_sum;}inline int new_node(int l, int r) {    tree[++top].l = l;    tree[top].r = r;    return top;}void build(int &nd, int l, int r) {    if (l > r) return;    if (l == r) {nd = new_node(l, r);return;}    int mid = (l+r)>>1;    nd = new_node(l, r);    build(tree[nd].lc, l, mid);    build(tree[nd].rc, mid+1, r);}void insert(int pre, int &now, int k, long long dat) {    if (tree[pre].l == tree[pre].r) {        now = new_node(k, k);        tree[now].sum = dat;        tree[now].num_sum = 1;    } else {        now = new_node(tree[pre].l, tree[pre].r);        tree[now] = tree[pre];        if (k <= tree[tree[pre].lc].r) insert(tree[pre].lc, tree[now].lc, k, dat);        else insert(tree[pre].rc, tree[now].rc, k, dat);        update(now);    }}// 查找区间和(sum)long long get_sum(int pre, int now, int l, int r){    if (l > r || !pre || !now) return 0;    if (l == tree[pre].l && r == tree[now].r) return tree[now].sum - tree[pre].sum;    return get_sum(tree[pre].lc, tree[now].lc, l, min(r, tree[tree[pre].lc].r)) +           get_sum(tree[pre].rc, tree[now].rc, max(tree[tree[pre].rc].l, l), r);}// 区间内数字个数的和int get_num_sum(int pre, int now, int l, int r){    if (l > r || !pre || !now) return 0;    if (l == tree[pre].l && r == tree[now].r) return tree[now].num_sum - tree[pre].num_sum;    return get_num_sum(tree[pre].lc, tree[now].lc, l, min(r, tree[tree[pre].lc].r)) +           get_num_sum(tree[pre].rc, tree[now].rc, max(tree[tree[pre].rc].l, l), r);}int find_mid(int pre, int now, int k) // 查找中位数(第k个数)的位置{    if (tree[now].l == tree[now].r) return tree[now].l;    if (tree[tree[now].lc].num_sum - tree[tree[pre].lc].num_sum >= k)        return find_mid(tree[pre].lc, tree[now].lc, k);    else        return find_mid(tree[pre].rc, tree[now].rc, k-(tree[tree[now].lc].num_sum - tree[tree[pre].lc].num_sum));}// 查询区间long long query(int l, int r) {    int pos = find_mid(root[l-1], root[r], ((l+r)>>1)-l+1);    long long lft = get_sum(root[l-1], root[r], 1, pos);int ln = get_num_sum(root[l-1], root[r], 1, pos);    long long rgt = get_sum(root[l-1], root[r], pos+1, n);int rn = get_num_sum(root[l-1], root[r], pos+1, n);    return rgt - rn*sorted[pos] + ln*sorted[pos] - lft;}void dfs(int rt, int tab = 0) {    if (rt) {        for (size_t i = 0; i < tab; i++) putchar(' ');        cout << tree[rt].l << "->" << tree[rt].r << " " << tree[rt].sum << " " << tree[rt].num_sum << endl;        dfs(tree[rt].lc, tab+2);        dfs(tree[rt].rc, tab+2);    }}int main(){    n = read(); m = read();    build(root[0], 1, n);    long long a, l, r;    for (size_t i = 1; i <= n; i++)        sorted[i] = dat[i] = read();    sort(sorted+1, sorted+n+1);    for (size_t i = 1; i <= n; i++) {        insert(root[i-1], root[i], lower_bound(sorted+1, sorted+n+1, dat[i])-sorted, dat[i]);    }    long long ans = 0;    for (size_t i = 1; i <= m; i++) {        l = read(); r = read();        ans += query(l, r);    }    cout << ans << endl;    return 0;}
0 0
原创粉丝点击