BZOJ3224 普通平衡树

来源:互联网 发布:jquery 1.2.6.min.js 编辑:程序博客网 时间:2024/04/27 14:33

Description

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)

Input

第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)

Output

对于操作3,4,5,6每行输出一个数,表示对应答案

Sample Input

10

1 106465

4 1

1 317721

1 460929

1 644985

1 84185

1 89851

6 81968

1 492737

5 493598

Sample Output

106465

84185

492737

HINT

1.n的数据范围:n<=100000

2.每个数的数据范围:[-1e7,1e7]

平衡树裸题,当做SBT写法练手了。最终成品很糟糕,有时间再完善吧。
SBT的删除操作要注意,final_erase的节点数目更改要处理好,因为这个WA了几次,具体参见代码。
自己的代码:

#include<cstdio> struct node {     node *s[2];     int v,n;     node(int v):v(v),n(1)     {s[0]=s[1]=0;}     int cmp(int a)     {return a==v?-1:a<v?0:1;} }*sbt; void rot(node* &o,bool k) {     node *x=o->s[1^k];     o->s[1^k]=x->s[k];     x->s[k]=o;     int b=x->n;     x->n=o->n;     o->n+=(o->s[1^k]?o->s[1^k]->n:0)-b;     o=x; } void mt(node* &x,bool k) {     int a=x->s[k]?x->s[k]->n:0,         b=x->s[1^k]?x->s[1^k]->s[1^k]?x->s[1^k]->s[1^k]->n:0:0,         c=x->s[1^k]?x->s[1^k]->s[k]?x->s[1^k]->s[k]->n:0:0;         if(a<b) rot(x,k);         else if(a<c) {rot(x->s[1^k],1^k);rot(x,k);}         else return;         if(x->s[1^k]) mt(x->s[1^k],k);         if(x->s[k]) mt(x->s[k],1^k);         mt(x,k);         mt(x,1^k); } int v,p,r; void insert(node* &x) {     if(!x)     {         x=new node(v);         return;     }     x->n++;     int k=x->cmp(v);     if(k==-1) return;     insert(x->s[k]);     mt(x,1^k); } inline void insert(int vx) {     v=vx;     insert(sbt); } void final_erase(node* &x) {     if(!x->s[1])     {         v=x->v;         node *y=x->s[0];         r=x->n-(y?y->n:0);        delete x;         x=y;         return;     }      final_erase(x->s[1]);    x->n-=r;    mt(x,1); } void erase(node* &x) {     int k=x->cmp(v);     if(k==-1)     {         int a=x->n-(x->s[0]?x->s[0]->n:0)-(x->s[1]?x->s[1]->n:0);         if(a>1)         {             x->n--;             return;         }         if(!x->s[0]) k=1;         if(!x->s[1]) k=0;         if(k!=-1)         {             node *y=x;             x=x->s[k];             delete y;         }         else        {             final_erase(x->s[0]);             x->n--;             x->v=v;         }         return;     }     x->n--;     erase(x->s[k]);     mt(x,k); } inline void erase(int vx) {     v=vx;     erase(sbt); } int rank(int vx) {     int ans=0,k;     node *x=sbt;     while((k=x->cmp(vx))!=-1)     {         if(k==1) ans+=x->n-x->s[1]->n;         x=x->s[k];     }     return ans+(x->s[0]?x->s[0]->n+1:1); } int kth(int vx) {     node *x=sbt;     int a;     while(1)     {         a=x->s[0]?x->s[0]->n:0;         if(vx<=a) {x=x->s[0];continue;}         a=x->n-(x->s[1]?x->s[1]->n:0);         if(vx>a) {vx-=a;x=x->s[1];continue;}         return x->v;     } } void pre(node* x) {     int k=x->cmp(v);     if(k==1) {p=x->v;if(x->s[1]) pre(x->s[1]);}     else if(x->s[0]) pre(x->s[0]); } inline int pre(int vx) {     v=vx;     pre(sbt);     return p; } void nex(node* x) {     int k=x->cmp(v);     if(!k) {p=x->v;if(x->s[0]) nex(x->s[0]);}     else if(x->s[1]) nex(x->s[1]); } inline int nex(int vx) {     v=vx;     nex(sbt);     return p; } int main() {     int m;     scanf("%d",&m);     int r;     for(int i=0;i<m;i++)     {         scanf("%d",&r);         switch(r)         {             case 1:                 {                     scanf("%d",&r);                     insert(r);                     break;                 }             case 2:                 {                     scanf("%d",&r);                     erase(r);                     break;                 }             case 3:                 {                     scanf("%d",&r);                     printf("%d\n",rank(r));                     break;                 }             case 4:                 {                     scanf("%d",&r);                     printf("%d\n",kth(r));                     break;                 }             case 5:                 {                     scanf("%d",&r);                     printf("%d\n",pre(r));                     break;                 }             default:                 {                     scanf("%d",&r);                     printf("%d\n",nex(r));                 }         }     }     return 0; }
0 0