线段树

来源:互联网 发布:丑陋的中国人 知乎 编辑:程序博客网 时间:2024/06/05 04:08

线段树英文叫做segment tree。最近研究了一下,发现它非常有用,面试中也考得比较多。那什么样的题目可以使用线段树呢?当题目具有以下几个特点时,可以考虑用线段树。

  • 求一组区间值
  • 原始数据会发生变化

什么是线段树?

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,因此有时需要离散化让空间压缩。
上面是百度百科对线段树的说明。由于线段树是一棵平衡的二叉树,树高为O(log n),所以修改单个元素和查询一个区间的时间复杂度都是O(log n)。

构建线段树

比如对于含有n个元素的数组,构建线段树时取mid=(n-1)/2(整数除法),左分支范围是[0,mid],右分支范围是[mid+1,n-1]。
那对于一个含有4个元素的数组,构建完之后样子如下:

              [0, 3]
             /      \
        [0, 1]      [2, 3]
        /    \      /    \
    [0, 0] [1, 1] [2, 2] [3, 3]

下面是C++的实现:

/**
 * Node of a segment tree that records only the index range [start, end]
 * it covers. Leaves have start == end; internal nodes have two children.
 */
class SegmentTreeNode {
public:
    int start, end;
    SegmentTreeNode *left, *right;
    SegmentTreeNode(int start, int end)
        : start(start), end(end), left(nullptr), right(nullptr) {}
};  // a class definition must end with ';'

class Solution {
public:
    /**
     * Builds a segment tree over the index range [start, end].
     *
     * @param start first index of the segment (inclusive)
     * @param end   last index of the segment (inclusive)
     * @return root of the new tree, or nullptr when start > end.
     *         Nodes are heap-allocated with new and owned by the caller.
     */
    SegmentTreeNode *build(int start, int end) {
        if (start > end) {
            return nullptr;
        }
        SegmentTreeNode *node = new SegmentTreeNode(start, end);
        if (start == end) {
            return node;  // leaf
        }
        // Overflow-safe midpoint: left child gets [start, mid], right gets [mid + 1, end].
        const int mid = start + ((end - start) >> 1);
        node->left = build(start, mid);
        node->right = build(mid + 1, end);
        return node;
    }
};
更改某个元素

假设线段树每个节点保存的是其区间(range)内的最大值。当修改数组中的某个元素时,如何更新这棵线段树?做法是从根节点往下找到该元素对应的叶节点,修改它的值,然后自底向上更新查找路径上各节点的最大值。

/** * Definition of SegmentTreeNode: * class SegmentTreeNode { * public: *     int start, end, max; *     SegmentTreeNode *left, *right; *     SegmentTreeNode(int start, int end, int max) { *         this->start = start; *         this->end = end; *         this->max = max; *         this->left = this->right = NULL; *     } * } */class Solution {public:    /**     *@param root, index, value: The root of segment tree and      *@ change the node's value with [index, index] to the new given value     *@return: void     */    void modify(SegmentTreeNode *root, int index, int value) {        updateMax(root, index, value);    }    int updateMax(SegmentTreeNode *root, int index, int value){        if(!root || index < root->start || index > root->end){            return -1;//error        }        if(index == root->start && index == root->end){            root->max = value;            return value;        } else {            auto mid = root->left->end;            if(index <= mid){                auto leftMax = updateMax(root->left, index, value);                root->max = max(leftMax, root->right->max);            }else{                auto rightMax = updateMax(root->right, index, value);                root->max = max(root->left->max, rightMax);            }            return root->max;        }            }};
线段树查找

假设线段树中节点存放的是区间最大值,如何求区间内的最大值?
其实本质上就是在线段树中找出若干个节点,它们的区间恰好拼成查询区间,然后取这些节点最大值中的最大者。这个最大值可以在递归查找这些节点的过程中顺便算出。

/** * Definition of SegmentTreeNode: * class SegmentTreeNode { * public: *     int start, end, max; *     SegmentTreeNode *left, *right; *     SegmentTreeNode(int start, int end, int max) { *         this->start = start; *         this->end = end; *         this->max = max; *         this->left = this->right = NULL; *     } * } */class Solution {public:    /**     *@param root, start, end: The root of segment tree and     *                         an segment / interval     *@return: The maximum number in the interval [start, end]     */    int query(SegmentTreeNode *root, int start, int end) {        if(start > end || !root) {            return -1;//error        }        if (root->start == start && root->end == end){            return root->max;        }        auto mid = root->left->end;        if(end <= mid){            return query(root->left, start, end);        } else if(start > mid){            return query(root->right, start, end);        } else {            return max(query(root->left, start, mid), query(root->right, mid+1, end));        }    }};

常见题型

求区间和

求区间和时,第一反应是用前缀和(preSum)方法。但这种方法只适用于数据不发生修改的情况:一旦某个元素被修改,其后的所有前缀和都要更新,复杂度为O(n),在修改频繁时这是不可接受的。采用线段树能很好地解决这个问题,因为修改一个元素和查询一个区间的复杂度都是O(log n)。

/**
 * Segment tree node holding the sum of A[start..end].
 */
class CNode {
public:
    int start;
    int end;
    long long sum;
    CNode* pLeft;
    CNode* pRight;
    // The sum parameter is long long to match the member and avoid narrowing.
    CNode(int _start, int _end, long long _sum)
        : start(_start), end(_end), sum(_sum), pLeft(nullptr), pRight(nullptr) {}
};

/**
 * Range-sum structure over an integer array that supports point assignment
 * and inclusive range-sum queries, each in O(log n).
 */
class Solution {
public:
    /**
     * @param A: the initial values; an empty vector yields an empty tree.
     */
    Solution(std::vector<int> A) : m_pRoot(nullptr) {
        if (!A.empty()) {
            m_pRoot = build(A, 0, static_cast<int>(A.size()) - 1);
        }
    }

    // The tree is owned exclusively by this object; copying would double-free it.
    Solution(const Solution&) = delete;
    Solution& operator=(const Solution&) = delete;

    ~Solution() { destroy(m_pRoot); }

    /**
     * @param start, end: inclusive indices.
     * @return the sum of the elements whose index lies in [start, end];
     *         0 when start > end, the range misses the array, or it is empty.
     */
    long long query(int start, int end) {
        return querySum(m_pRoot, start, end);
    }

    /**
     * @param index, value: set A[index] to value; out-of-range indices are ignored.
     */
    void modify(int index, int value) {
        updateValue(m_pRoot, index, value);
    }

private:
    CNode* m_pRoot;

    // Builds the subtree for A[start..end] (start <= end) and fills in its sums.
    CNode* build(const std::vector<int>& A, int start, int end) {
        if (start == end) {
            return new CNode(start, end, A[start]);
        }
        const int mid = start + ((end - start) >> 1);
        CNode* pCur = new CNode(start, end, 0);
        pCur->pLeft = build(A, start, mid);
        pCur->pRight = build(A, mid + 1, end);
        pCur->sum = pCur->pLeft->sum + pCur->pRight->sum;
        return pCur;
    }

    // Sums the part of [start, end] that overlaps pRoot's range.
    long long querySum(CNode* pRoot, int start, int end) {
        if (!pRoot || start > end || start > pRoot->end || end < pRoot->start) {
            return 0;  // empty query or no overlap
        }
        if (start <= pRoot->start && end >= pRoot->end) {
            return pRoot->sum;  // node fully covered
        }
        if (pRoot->start == pRoot->end) {
            return 0;  // defensive: an overlapping leaf is always fully covered
        }
        const int mid = pRoot->pLeft->end;
        if (end <= mid) {
            return querySum(pRoot->pLeft, start, end);
        }
        if (start > mid) {
            return querySum(pRoot->pRight, start, end);
        }
        return querySum(pRoot->pLeft, start, mid) + querySum(pRoot->pRight, mid + 1, end);
    }

    // Sets the leaf at index to value and returns the change in sum so every
    // ancestor can be adjusted on the way back up; returns 0 if index is out
    // of pRoot's range. The delta must be long long: new - old for two ints
    // can exceed the int range, and truncating it corrupts ancestor sums.
    long long updateValue(CNode* pRoot, int index, int value) {
        if (!pRoot || index < pRoot->start || index > pRoot->end) {
            return 0;
        }
        if (pRoot->start == pRoot->end) {
            const long long diff = value - pRoot->sum;
            pRoot->sum = value;
            return diff;
        }
        const int mid = pRoot->pLeft->end;
        const long long diff = (index <= mid)
                                   ? updateValue(pRoot->pLeft, index, value)
                                   : updateValue(pRoot->pRight, index, value);
        pRoot->sum += diff;
        return diff;
    }

    // Frees the subtree rooted at pNode (children before parent).
    static void destroy(CNode* pNode) {
        if (!pNode) {
            return;
        }
        destroy(pNode->pLeft);
        destroy(pNode->pRight);
        delete pNode;
    }
};

现在区间节点存放的是sum,其实也可以是max、min,或者其他能由左右子区间合并得到的复合区间值。

求小于自身数的个数

对于线段树的题目一般是数组,求这个数组的某个range的值。但是作为range的不光可以是下标,还有可能是数值范围。

给定一个数组,其中数组元素范围是0~10000(这个信息非常重要),对于每个元素A[i],计算下标i之前小于A[i]的元素个数。例如输入[1,2,7,8,5],返回[0,1,2,3,2]。

初步看可以采用扫描的方法:计算A[i]时遍历A[0]~A[i-1],统计其中小于A[i]的元素个数,这样复杂度为O(n^2)。
因为数组元素有范围,可以把值域[0,10000]当作线段树的区间:每个叶节点对应一个数值,每个节点保存落在其值域范围内的已插入元素个数。从左到右扫描数组,对每个A[i],先查询值域[0,A[i]-1]内的元素个数(即它之前小于A[i]的元素个数),再把A[i]插入线段树。每次插入一个元素,相当于沿路径把各节点的计数加一。

/**
 * Segment tree node over the value range [low, high]; count is the number of
 * inserted values that fall inside that range.
 */
class CNode {
public:
    int low;
    int high;
    int count;
    CNode* pLeft;
    CNode* pRight;
    CNode(int _low, int _high, int _count)
        : low(_low), high(_high), count(_count), pLeft(nullptr), pRight(nullptr) {}
};

class Solution {
public:
    // Value range covered by the segment tree (the problem guarantees 0..10000).
    static constexpr int kMinValue = 0;
    static constexpr int kMaxValue = 10000;

    /**
     * @param A: An integer array
     * @return: for each A[i], the number of earlier elements A[j] (j < i) with
     *          A[j] < A[i]. Values outside [kMinValue, kMaxValue] get 0 and are
     *          not inserted, so they never count toward later answers.
     */
    std::vector<int> countOfSmallerNumberII(std::vector<int> &A) {
        std::vector<int> res;
        if (A.empty()) {
            return res;
        }
        res.reserve(A.size());
        CNode* pRoot = buildSegmentTree(kMinValue, kMaxValue);
        for (int value : A) {
            if (value < pRoot->low || value > pRoot->high) {
                res.push_back(0);
            } else {
                // Query before inserting so the value does not count itself.
                res.push_back(lowerCount(pRoot, value));
                updateSegmentTree(pRoot, value);
            }
        }
        // The tree is scratch space for this call only; free it so repeated
        // calls do not leak the whole value-range tree.
        destroySegmentTree(pRoot);
        return res;
    }

    // Builds an all-zero counting tree over [low, high].
    CNode* buildSegmentTree(int low, int high) {
        CNode* pCur = new CNode(low, high, 0);
        if (low == high) {
            return pCur;
        }
        const int mid = low + ((high - low) >> 1);
        pCur->pLeft = buildSegmentTree(low, mid);
        pCur->pRight = buildSegmentTree(mid + 1, high);
        return pCur;
    }

    // Inserts one occurrence of value; returns 1 if it was counted, else 0.
    int updateSegmentTree(CNode* pCur, int value) {
        if (!pCur || value < pCur->low || value > pCur->high) {
            return 0;
        }
        if (pCur->low == pCur->high) {
            if (pCur->low == value) {
                pCur->count++;
                return 1;
            }
            return 0;
        }
        const int mid = pCur->low + ((pCur->high - pCur->low) >> 1);
        const int diff = (value <= mid) ? updateSegmentTree(pCur->pLeft, value)
                                        : updateSegmentTree(pCur->pRight, value);
        pCur->count += diff;
        return diff;
    }

    // Counts inserted values strictly smaller than value within pCur's range.
    int lowerCount(CNode* pCur, int value) {
        if (!pCur || pCur->low > value || pCur->count == 0) {
            return 0;
        }
        if (pCur->high < value) {
            return pCur->count;  // whole range is below value
        }
        const int mid = pCur->low + ((pCur->high - pCur->low) >> 1);
        if (value <= mid) {
            return lowerCount(pCur->pLeft, value);
        }
        return lowerCount(pCur->pLeft, value) + lowerCount(pCur->pRight, value);
    }

private:
    // Frees the subtree rooted at pCur (children before parent).
    static void destroySegmentTree(CNode* pCur) {
        if (!pCur) {
            return;
        }
        destroySegmentTree(pCur->pLeft);
        destroySegmentTree(pCur->pRight);
        delete pCur;
    }
};

总结

线段树是一种基于区间二分划分的二叉树,用来解决区间问题。当某个问题可以转化为若干区间上的查询,而且单个元素的变化只会影响少数(O(log n)个)区间而不是所有区间时,就可以考虑用线段树。最常见的情况是求数组某个区间上的值。

1 0