划分树模板

来源:互联网 发布:编程主要学什么 编辑:程序博客网 时间:2024/04/29 17:11


#include<cstdio>#include<cmath>#include<cstring>#include<algorithm>using namespace std;#define MAXN 100001#define L(u) (u<<1)#define R(u) (u<<1|1)#define MID(l, r) ((l+r)>>1)struct SegTree{    int l, r;} node[MAXN*4];int sortA[MAXN];int toLeft[20][MAXN]; //toLeft[i]表示[node[u].l, i]区域里有多少个数分到左子树中int val[20][MAXN];void build(int u, int l, int r, int h){node[u].l = l;node[u].r = r;if(node[u].l == node[u].r)return;int mid = MID(l, r);int eqlToLeft = mid - l + 1;  //eqlToLeft统计区间[l,r]中与中位数相等且分到左子树中的数的个数for(int i = l; i <= r ; i++)if(val[h][i] < sortA[mid])eqlToLeft--;  //先假设左子树中的(mid-l+1)个数都等于中位数,然后把实际上小于中位数的减去int lpos = l;int rpos = mid + 1;int cnt = 0;  //统计已经进入左子树的数个数(对于所有等于中位数的数)for(int i = l ; i <= r ; i++)    {if(i == l)toLeft[h][i] = 0;elsetoLeft[h][i] = toLeft[h][i-1];if(val[h][i] < sortA[mid]){toLeft[h][i]++;val[h+1][lpos++] = val[h][i];}        else if(val[h][i] > sortA[mid])val[h+1][rpos++] = val[h][i];else        {            //对于等于中位数的数,一部分分到左子树,一部分分到右子树if(cnt < eqlToLeft)            {cnt++;toLeft[h][i]++;val[h+1][lpos++] = val[h][i];}            elseval[h+1][rpos++] = val[h][i];}}build(L(u), l, mid, h + 1);build(R(u), mid + 1, r, h + 1);}int query(int u, int l, int r, int h, int k){if(l == r) return val[h][l];int cnt1;  //cnt1表示[node[u].l, l-1]有多少个数分到左子树中int cnt2;  //cnt2表示[l,r]有多少个数分到当前区间的左子树中//[node[u].l, l-1] + [l,r] = [node[u].l, r]if(l == node[u].l)    {        cnt1 = 0;cnt2 = toLeft[h][r];}else    {        cnt1 = toLeft[h][l-1];cnt2 = toLeft[h][r] - toLeft[h][l-1];}if(cnt2 >= k) //[l,r]区间上有多于k个分到左边,显然去左子树找第k个    {        //计算出新的映射区间,注意:划分树上保证下标的顺序不变int newl = node[u].l + cnt1; //[node[u].l, l-1]int newr = node[u].l + cnt1 + cnt2 - 1; //[l,r]return query(L(u), newl, newr, h + 1, k);}else    {int mid = MID(node[u].l, node[u].r);int cnt3 = l - node[u].l - cnt1;  //cnt3记录node[u].l, l-1]有多少个分到右子树中int cnt4 = r - l + 1 - cnt2;  //cnt4记录[l,r]有多少个分到右子树中int newl = mid + cnt3 + 1;int newr = mid + cnt3 + cnt4;return query(R(u), newl, newr, h+1, k - cnt2);}}int main(){    int n, m;    while(scanf("%d%d",&n,&m) != EOF)    {        for(int i = 1; i <= n; i++)        {            scanf("%d",&val[0][i]);            sortA[i] = val[0][i];        }        sort(sortA + 1, sortA + 1 + n);        build(1, 1, n, 0);        int l, r, k;        while(m--)        {            scanf("%d%d%d",&l,&r,&k);            int ret = query(1, l, r, 0, k);            printf("%d\n",ret);        }    }}


原创粉丝点击