线段树

来源:互联网 发布:自由光导航淘宝 编辑:程序博客网 时间:2024/06/03 17:11

HH神的线段树出神入化,所以跟着HH学习线段树。

风格:

maxn是题目给的最大区间,而节点数要开4倍,确切的说……

lson和rson辨别表示结点的左孩子和右孩子。

PushUp(int rt)是把当前结点的信息更新到父节点

PushDown(int rt)是把当前结点的信息更新给孩子结点。

rt表示当前子树的根(root),也就是当前所在的结点。


思想:

对于每个非叶节点所标示的结点 [a,b],其做孩子表示的区间是[a,(a+b)/2],其右孩子表示[(a+b)/2,b].

构造:



离散化和线段树:

题目:x轴上有若干个线段,求线段覆盖的总长度。

普通解法:设置坐标范围[min,max],初始化为0,然后每一段分别染色为1,最后统计1的个数,适用于线段数目少,区间范围小。

离散化的解法:离散化就是一一映射的关系,即将一个大坐标和小坐标进行一一映射,适用于线段数目少,区间范围大。

例如:[10000,22000],[30300,55000],[44000,60000],[55000,60000].

第一步:排序 10000 22000 30300 44000 55000 60000

第二部:编号 1        2        3         4       5         6

第三部:用编号来代替原数,即小数代大数 。

[10000,22000]~[1,2]

[30300,55000]~[3,5]

[44000,60000]~[4,6]

[55000,60000]~[5,6]

然后再用小数进行普通解法的步骤,最后代换回去。

线段树的解法:线段树通过建立线段,将原来染色O(n)的复杂度减小到 log(n),适用于线段数目多,区间范围小的情况。

离散化的线段树:适用于线段数目多,区间范围大的情况。


构造:

动态数据结构:

struct node{

 node* left;

 node* right;

……

}

静态全局数组模拟(完全二叉树):

struct node{

  int left;

  int right;

……

}Tree[MAXN]

例如:



线段树与点树:

线段树的每一个结点表示一个点,成为点树,比如说用于求第k小数的线段树。

点树结构体:

struct node{

int l, r;

int c;//用于存放次结点的值,默认为0

}T[3*MAXN];

创建:

创建顺序为先序遍历,即先构造根节点,再构造左孩子,再构造右孩子。

[cpp] view plaincopy
  1. void construct(int l, int r, int k){  
  2.     T[k].l = l;  
  3.     T[k].r = r;  
  4.     T[k].c = 0;  
  5.     if(l == r) return ;  
  6.     int m = (l + r) >> 1;  
  7.     construct(l, m, k << 1);  
  8.     construct(m + 1, r, (k << 1) + 1);  
  9.     return ;  
  10. }  
[cpp] view plaincopy
  1. void construct(int l, int r, int k){  
  2.     T[k].l = l;  
  3.     T[k].r = r;  
  4.     T[k].c = 0;  
  5.     if(l == r) return ;  
  6.     int m = (l + r) >> 1;  
  7.     construct(l, m, k << 1);  
  8.     construct(m + 1, r, (k << 1) + 1);  
  9.     return ;  
  10. }  
 


[A,B,C]:A表示左值,B表示右值,C表示在静态数组中的位置,由此可知,n个点的话大约共有2*n个结点,因此开3*n的结构体一定是够的。


更新值:

[cpp] view plaincopy
  1. void insert(int d, int k){  
  2.     //如果找到了就c值+1返回。   
  3.     if(T[k].l == T[k].r && d == T[k].l){  
  4.         T[k].c += 1;  
  5.         return ;  
  6.     }  
  7.     int m = (T[k].l + T[k].r) >> 1;  
  8.     if(d <= m) insert(d, k << 1);  
  9.     else insert(d, (k << 1) + 1);  
  10.     //更新每一个c,向上更新   
  11.     T[k].c = T[k << 1].c + T[(k << 1) + 1].c;  
  12. }  
[cpp] view plaincopy
  1. void insert(int d, int k){  
  2.     //如果找到了就c值+1返回。  
  3.     if(T[k].l == T[k].r && d == T[k].l){  
  4.         T[k].c += 1;  
  5.         return ;  
  6.     }  
  7.     int m = (T[k].l + T[k].r) >> 1;  
  8.     if(d <= m) insert(d, k << 1);  
  9.     else insert(d, (k << 1) + 1);  
  10.     //更新每一个c,向上更新  
  11.     T[k].c = T[k << 1].c + T[(k << 1) + 1].c;  
  12. }  

查找值:

[cpp] view plaincopy
  1. //k表示树根,d表示要查找的值   
  2. void search(int d, int k, int& ans)  
  3. {  
  4.     if(T[k].l == T[k].r){  
  5.         ans = T[k].l;  
  6.         ans = T[k].l;  
  7.     }  
  8.     int m = (T[k].l + T[k].r) >> 1;  
  9.     //不懂   
  10.     if(d > T[(k << 1)].c) search(d - T[k << 1].c, (k << 1) + 1, ans);  
  11.     else search(d, k << 1, ans);  
  12. }  
[cpp] view plaincopy
  1. //k表示树根,d表示要查找的值  
  2. void search(int d, int k, int& ans)  
  3. {  
  4.     if(T[k].l == T[k].r){  
  5.         ans = T[k].l;  
  6.         ans = T[k].l;  
  7.     }  
  8.     int m = (T[k].l + T[k].r) >> 1;  
  9.     //不懂  
  10.     if(d > T[(k << 1)].c) search(d - T[k << 1].c, (k << 1) + 1, ans);  
  11.     else search(d, k << 1, ans);  
  12. }  

search函数的用法不太懂。

例题解:

(待更新)


四类题型:

1.单点更新   只更新叶子结点,然后把信息用PushUp(int r)这个函数更新上来。

hdu1166:敌兵布阵

线段树功能:update:单点替换 query:区间最值




poj2828

树状数组:

[cpp] view plaincopy
  1. #include <iostream>   
  2. #include <cstdio>   
  3. #include <string>   
  4. #include <cstring>   
  5. using namespace std;  
  6.   
  7. typedef pair<intint> PII;  
  8.   
  9. const int maxn = 200000;  
  10.   
  11. int C[maxn + 100];  
  12. int B[maxn + 100];  
  13. int n;  
  14. PII arr[maxn + 100];  
  15.   
  16. int lowbit(int k) { return k & (-k); }  
  17.   
  18. void init() {  
  19.     for(int i = 1; i <= n; i++) C[i] = lowbit(i);  
  20.     memset(B, -1, n + 10);  
  21. }  
  22.   
  23. void update(int i) {  
  24.     while(i <= n) {  
  25.         C[i]--;  
  26.         i += lowbit(i);  
  27.     }  
  28. }  
  29.   
  30.   
  31. int query(int i) {  
  32.     int ret = 0;  
  33.     while(i > 0) {  
  34.         ret += C[i];  
  35.         i -= lowbit(i);  
  36.     }  
  37.     return ret;  
  38. }  
  39.   
  40. void debug() {  
  41.     for(int i = 1; i <= n; i++) cout << i << " " << query(i) << endl;  
  42. }  
  43.   
  44.   
  45. void fun(int a, int v) {  
  46.     int l = 1, r = n;  
  47.     while(l < r) {  
  48.         int m = (l + r) >> 1;  
  49.         if(query(m) >= a) r = m;  
  50.         else l = m + 1;  
  51.     }  
  52.     //cout << "here  " << l << endl;   
  53.     update(l);  
  54.     //cout << "here2 " << endl;   
  55.     //debug();   
  56.     B[l] = v;  
  57.     //return l;   
  58. }  
  59.   
  60.   
  61.   
  62.   
  63. int main() {  
  64.     while(~scanf("%d", &n)) {  
  65.         init();  
  66.         int a, b;  
  67.         for(int i = 1; i <= n; i++) {  
  68.             scanf("%d%d", &a, &b);  
  69.             a++;  
  70.             arr[i].first = a;  
  71.             arr[i].second = b;  
  72.         }  
  73.         for(int i = n; i > 0; i--) fun(arr[i].first, arr[i].second);  
  74.         //debug2();   
  75.         //bool flag = false;   
  76.         for(int i = 1; i <= n; i++) {  
  77.             i == 1 ? printf("%d", B[i]) : printf(" %d", B[i]);  
  78.             //if(B[i] != -1 && !flag) { printf("%d", B[i]); flag = true; }  
  79.             //else if(B[i] != -1) printf(" %d", B[i]);  
  80.         }  
  81.         puts("");  
  82.     }  
  83.     return 0;  
  84. }  
[cpp] view plaincopy
  1. #include <iostream>  
  2. #include <cstdio>  
  3. #include <string>  
  4. #include <cstring>  
  5. using namespace std;  
  6.   
  7. typedef pair<intint> PII;  
  8.   
  9. const int maxn = 200000;  
  10.   
  11. int C[maxn + 100];  
  12. int B[maxn + 100];  
  13. int n;  
  14. PII arr[maxn + 100];  
  15.   
  16. int lowbit(int k) { return k & (-k); }  
  17.   
  18. void init() {  
  19.     for(int i = 1; i <= n; i++) C[i] = lowbit(i);  
  20.     memset(B, -1, n + 10);  
  21. }  
  22.   
  23. void update(int i) {  
  24.     while(i <= n) {  
  25.         C[i]--;  
  26.         i += lowbit(i);  
  27.     }  
  28. }  
  29.   
  30.   
  31. int query(int i) {  
  32.     int ret = 0;  
  33.     while(i > 0) {  
  34.         ret += C[i];  
  35.         i -= lowbit(i);  
  36.     }  
  37.     return ret;  
  38. }  
  39.   
  40. void debug() {  
  41.     for(int i = 1; i <= n; i++) cout << i << " " << query(i) << endl;  
  42. }  
  43.   
  44.   
  45. void fun(int a, int v) {  
  46.     int l = 1, r = n;  
  47.     while(l < r) {  
  48.         int m = (l + r) >> 1;  
  49.         if(query(m) >= a) r = m;  
  50.         else l = m + 1;  
  51.     }  
  52.     //cout << "here  " << l << endl;  
  53.     update(l);  
  54.     //cout << "here2 " << endl;  
  55.     //debug();  
  56.     B[l] = v;  
  57.     //return l;  
  58. }  
  59.   
  60.   
  61.   
  62.   
  63. int main() {  
  64.     while(~scanf("%d", &n)) {  
  65.         init();  
  66.         int a, b;  
  67.         for(int i = 1; i <= n; i++) {  
  68.             scanf("%d%d", &a, &b);  
  69.             a++;  
  70.             arr[i].first = a;  
  71.             arr[i].second = b;  
  72.         }  
  73.         for(int i = n; i > 0; i--) fun(arr[i].first, arr[i].second);  
  74.         //debug2();  
  75.         //bool flag = false;  
  76.         for(int i = 1; i <= n; i++) {  
  77.             i == 1 ? printf("%d", B[i]) : printf(" %d", B[i]);  
  78.             //if(B[i] != -1 && !flag) { printf("%d", B[i]); flag = true; }  
  79.             //else if(B[i] != -1) printf(" %d", B[i]);  
  80.         }  
  81.         puts("");  
  82.     }  
  83.     return 0;  
  84. }  

poj-3468

[cpp] view plaincopy
  1. #include <cstdio>   
  2. #include <cstring>   
  3. #include <iostream>   
  4. using namespace std;  
  5.   
  6. #define lson l, m, rt<<1   
  7. #define rson m+1, r, rt<<1|1  
  8.   
  9. typedef long long LL;  
  10.   
  11. const int maxn = 111111;  
  12.   
  13. LL col[maxn<<2];  
  14. LL sum[maxn<<2];  
  15.   
  16. void PushUp(LL rt) {  
  17.     sum[rt] = sum[rt<<1] + sum[rt<<1|1];  
  18. }  
  19.   
  20. //pushdown的作用是如果此点可以更新。   
  21. //也就是更新到下一层   
  22. //如果是底层,那么是不用pushdown的。   
  23. void PushDown(LL rt, LL m) {  
  24.     if(col[rt]) {  
  25.         //col[rt<<1] = col[rt<<1|1] = col[rt];  
  26.         col[rt<<1] += col[rt];  
  27.         col[rt<<1|1] += col[rt];  
  28.         sum[rt<<1] += col[rt] * (m - (m>>1));  
  29.         sum[rt<<1|1] += col[rt] * (m>>1);  
  30.         col[rt] = 0;  
  31.     }  
  32. }  
  33.   
  34. void build(LL l, LL r, LL rt) {  
  35.     col[rt] = 0;  
  36.     //cout << l << " " << r << endl;   
  37.     if(l == r) {  
  38.         scanf("%I64d", &sum[rt]);  
  39.         //cout << rt << " " << sum[rt] << endl;  
  40.         return ;  
  41.     }  
  42.     int m = (l + r) >> 1;  
  43.     build(lson);  
  44.     build(rson);  
  45.     PushUp(rt);  
  46. }  
  47.   
  48. LL query(LL L, LL R, LL l, LL r, LL rt) {  
  49.     LL ret = 0;  
  50.     if(L <= l && r <= R) {  
  51.         //if(col[rt]) return sum[rt] + (r - l + 1) * col[rt];  
  52.         return sum[rt];  
  53.     }  
  54.     PushDown(rt, r - l + 1);  
  55.     int m = (l + r) >> 1;  
  56.     if(L <= m) ret += query(L, R, lson);  
  57.     if(R > m) ret += query(L, R, rson);  
  58.     return ret;  
  59. }  
  60.   
  61. void update(LL L, LL R, LL c, LL l, LL r, LL rt) {  
  62.     if(L <= l && r <= R) {  
  63.         sum[rt] += c * (r - l + 1);  
  64.         col[rt] += c;//子节点没有更新   
  65.         return ;  
  66.     }  
  67.     PushDown(rt, r - l + 1);  
  68.     int m = (l + r) >> 1;  
  69.     if(L <= m) update(L, R, c, lson);  
  70.     if(R > m) update(L, R, c, rson);  
  71.     PushUp(rt);  
  72. }  
  73.   
  74. void debug(int n) {  
  75.     for(int i = 1; i <= (n*3); i++) {  
  76.         cout << i << " ";  
  77.     }  
  78.     cout << endl;  
  79.     for(int i = 1; i <= (n*3); i++) {  
  80.         cout << col[i] << " ";  
  81.     }  
  82.     cout << endl << endl;  
  83.     for(int i = 1; i <= (n*3); i++) {  
  84.         cout << i << " ";  
  85.     }  
  86.     cout << endl;  
  87.     for(int i = 1; i <= (n*3); i++) {  
  88.         cout << sum[i] << " ";  
  89.     }  
  90.     cout << endl;  
  91. }  
  92.   
  93. int main() {  
  94.     LL N, Q;  
  95.     while(~scanf("%I64d%I64d", &N, &Q)) {  
  96.         //cout << "N = " << N << endl;   
  97.         memset(sum, 0, sizeof(sum));  
  98.         memset(col, 0, sizeof(col));  
  99.         build(1, N, 1);  
  100.         //debug(N);   
  101.         for(int i = 0; i < Q; i++) {  
  102.             char ch[3];  
  103.             LL a, b, c;  
  104.             scanf("%s", ch);  
  105.             if(ch[0] == 'Q') {  
  106.                 scanf("%I64d%I64d", &a, &b);  
  107.                 printf("%I64d\n", query(a, b, 1, N, 1));  
  108.             }  
  109.             else {  
  110.                 scanf("%I64d%I64d%I64d", &a, &b, &c);  
  111.                 update(a, b, c, 1, N, 1);  
  112.             }  
  113.             //debug(N);   
  114.         }  
  115.     }  
  116.     return 0;  
  117. }  
[cpp] view plaincopy
  1. #include <cstdio>  
  2. #include <cstring>  
  3. #include <iostream>  
  4. using namespace std;  
  5.   
  6. #define lson l, m, rt<<1  
  7. #define rson m+1, r, rt<<1|1  
  8.   
  9. typedef long long LL;  
  10.   
  11. const int maxn = 111111;  
  12.   
  13. LL col[maxn<<2];  
  14. LL sum[maxn<<2];  
  15.   
  16. void PushUp(LL rt) {  
  17.     sum[rt] = sum[rt<<1] + sum[rt<<1|1];  
  18. }  
  19.   
  20. //pushdown的作用是如果此点可以更新。  
  21. //也就是更新到下一层  
  22. //如果是底层,那么是不用pushdown的。  
  23. void PushDown(LL rt, LL m) {  
  24.     if(col[rt]) {  
  25.         //col[rt<<1] = col[rt<<1|1] = col[rt];  
  26.         col[rt<<1] += col[rt];  
  27.         col[rt<<1|1] += col[rt];  
  28.         sum[rt<<1] += col[rt] * (m - (m>>1));  
  29.         sum[rt<<1|1] += col[rt] * (m>>1);  
  30.         col[rt] = 0;  
  31.     }  
  32. }  
  33.   
  34. void build(LL l, LL r, LL rt) {  
  35.     col[rt] = 0;  
  36.     //cout << l << " " << r << endl;  
  37.     if(l == r) {  
  38.         scanf("%I64d", &sum[rt]);  
  39.         //cout << rt << " " << sum[rt] << endl;  
  40.         return ;  
  41.     }  
  42.     int m = (l + r) >> 1;  
  43.     build(lson);  
  44.     build(rson);  
  45.     PushUp(rt);  
  46. }  
  47.   
  48. LL query(LL L, LL R, LL l, LL r, LL rt) {  
  49.     LL ret = 0;  
  50.     if(L <= l && r <= R) {  
  51.         //if(col[rt]) return sum[rt] + (r - l + 1) * col[rt];  
  52.         return sum[rt];  
  53.     }  
  54.     PushDown(rt, r - l + 1);  
  55.     int m = (l + r) >> 1;  
  56.     if(L <= m) ret += query(L, R, lson);  
  57.     if(R > m) ret += query(L, R, rson);  
  58.     return ret;  
  59. }  
  60.   
  61. void update(LL L, LL R, LL c, LL l, LL r, LL rt) {  
  62.     if(L <= l && r <= R) {  
  63.         sum[rt] += c * (r - l + 1);  
  64.         col[rt] += c;//子节点没有更新  
  65.         return ;  
  66.     }  
  67.     PushDown(rt, r - l + 1);  
  68.     int m = (l + r) >> 1;  
  69.     if(L <= m) update(L, R, c, lson);  
  70.     if(R > m) update(L, R, c, rson);  
  71.     PushUp(rt);  
  72. }  
  73.   
  74. void debug(int n) {  
  75.     for(int i = 1; i <= (n*3); i++) {  
  76.         cout << i << " ";  
  77.     }  
  78.     cout << endl;  
  79.     for(int i = 1; i <= (n*3); i++) {  
  80.         cout << col[i] << " ";  
  81.     }  
  82.     cout << endl << endl;  
  83.     for(int i = 1; i <= (n*3); i++) {  
  84.         cout << i << " ";  
  85.     }  
  86.     cout << endl;  
  87.     for(int i = 1; i <= (n*3); i++) {  
  88.         cout << sum[i] << " ";  
  89.     }  
  90.     cout << endl;  
  91. }  
  92.   
  93. int main() {  
  94.     LL N, Q;  
  95.     while(~scanf("%I64d%I64d", &N, &Q)) {  
  96.         //cout << "N = " << N << endl;  
  97.         memset(sum, 0, sizeof(sum));  
  98.         memset(col, 0, sizeof(col));  
  99.         build(1, N, 1);  
  100.         //debug(N);  
  101.         for(int i = 0; i < Q; i++) {  
  102.             char ch[3];  
  103.             LL a, b, c;  
  104.             scanf("%s", ch);  
  105.             if(ch[0] == 'Q') {  
  106.                 scanf("%I64d%I64d", &a, &b);  
  107.                 printf("%I64d\n", query(a, b, 1, N, 1));  
  108.             }  
  109.             else {  
  110.                 scanf("%I64d%I64d%I64d", &a, &b, &c);  
  111.                 update(a, b, c, 1, N, 1);  
  112.             }  
  113.             //debug(N);  
  114.         }  
  115.     }  
  116.     return 0;  
  117. }  
0 0
原创粉丝点击