splay 的普通平衡树功能

来源:互联网 发布:淘宝信用度查询 编辑:程序博客网 时间:2024/05/18 01:51

普通平衡树的功能主要有   

 插入 删除 一个数

 找 前驱  后继  第k大 小

求大于等于或小于等于某个数的个数(可以求逆序数)

确定一个数的排名

我把所有的功能整合在了一起了,代码如下

#include<cstdio>#include<cstdlib>const int inf  = ~0u>>2;#define L ch[x][0]#define R ch[x][1]#define KT (ch[ ch[rt][1] ][0])const int maxn = 500010;int lim;struct SplayTree {int sz[maxn];int ch[maxn][2];int pre[maxn];int rt,top;inline void up(int x){sz[x]  = cnt[x]  + sz[ L ] + sz[ R ];}inline void Rotate(int x,int f){int y=pre[x];ch[y][!f] = ch[x][f];pre[ ch[x][f] ] = y;pre[x] = pre[y];if(pre[x]) ch[ pre[y] ][ ch[pre[y]][1] == y ] =x;ch[x][f] = y;pre[y] = x;up(y);}inline void Splay(int x,int goal){//将x旋转到goal的下面while(pre[x] != goal){if(pre[pre[x]] == goal) Rotate(x , ch[pre[x]][0] == x);else   {int y=pre[x],z=pre[y];int f = (ch[z][0]==y);if(ch[y][f] == x) Rotate(x,!f),Rotate(x,f);else Rotate(y,f),Rotate(x,f);}}up(x);if(goal==0) rt=x;}inline void RTO(int k,int goal){//将第k位数旋转到goal的下面int x=rt;while(sz[ L ] != k-1) {if(k < sz[ L ]+1) x=L;else {k-=(sz[ L ]+1);x = R;}}Splay(x,goal);}inline void vist(int x){if(x){printf("结点%2d : 左儿子  %2d   右儿子  %2d   val:%2d sz=%d  cnt:%d\n",x,L,R,val[x],sz[x],cnt[x]);vist(L);vist(R);}}void debug() {puts("");vist(rt);puts("");}inline void Newnode(int &x,int c,int f){x=++top;L = R = 0;pre[x] = f;sz[x]=1; cnt[x]=1;val[x] = c;}inline void init(){ch[0][0]=ch[0][1]=pre[0]=sz[0]=0;rt=top=0; cnt[0]=0;}inline void Insert(int &x,int key,int f){if(!x) {Newnode(x,key,f);Splay(x,0);//注意插入完成后splayreturn ;}if(key==val[x]){cnt[x]++;sz[x]++;Splay(x,0);//注意插入完成后splayreturn ;}else if(key<val[x]) {Insert(L,key,x);} else {Insert(R,key,x);}up(x);}void Del_root(){//删除根节点int t=rt;if(ch[rt][1]) {rt=ch[rt][1];RTO(1,0);ch[rt][0]=ch[t][0];if(ch[rt][0]) pre[ch[rt][0]]=rt;}else rt=ch[rt][0];pre[rt]=0;up(rt);}void findpre(int x,int key,int &ans){//找前驱节点if(!x)  return ;if(val[x] <= key){ans=x;findpre(R,key,ans);} elsefindpre(L,key,ans);}void findsucc(int x,int key,int &ans){//找后继节点if(!x) return ;if(val[x]>=key) {ans=x;findsucc(L,key,ans);} elsefindsucc(R,key,ans);}inline int find_kth(int x,int k){ //第k小的数if(k<sz[L]+1) {return find_kth(L,k);}else if(k > sz[ L ] + cnt[x] ) return find_kth(R,k-sz[L]-cnt[x]);else{ Splay(x,0);return val[x];}}int find(int x,int key){if(!x) return 0;else if(key < val[x])  return find(L,key);else if(key > val[x])  return find(R,key);else return x;}int getmin(int x){while(L) x=L;    return val[x];}int getmax(int x){while(R) x=R;   return val[x];}//确定key的排名int getrank(int x,int key,int cur){//cur:当前已知比要求元素(key)小的数的个数if(key == val[x])  return sz[L] + cur + 1;else if(key < val[x])getrank(L,key,cur);else getrank(R,key,cur+sz[L]+cnt[rt]);}int get_lt(int x,int key){//小于key的数的个数 lt:less than if(!x) return 0;if(val[x]>=key) return get_lt(L,key);return cnt[x]+sz[L]+get_lt(R,key);}int get_mt(int x,int key){//大于key的数的个数 mt:more thanif(!x) return 0;if(val[x]<=key) return get_mt(R,key) ;return cnt[x]+sz[R]+get_mt(L,key);}void del(int &x,int f){//删除小于lim的所有的数所在的节点if(!x) return ;if(val[x]>=lim){del(L,x);} else {x=R; pre[x]=f;if(f==0)  rt=x;del(x,f);}if(x)  up(x);}inline void update(){del(rt,0);}int get_mt(int key) {return get_mt(rt,key);}int get_lt(int key) {return get_lt(rt,key);}void insert(int key) {Insert(rt,key,0);    }void Delete(int key) {int node=find(rt,key);Splay(node,0);cnt[rt]--;if(!cnt[rt])Del_root();}int kth(int k) {return find_kth(rt,k);}int cnt[maxn];int val[maxn];int lim;}spt;


原创粉丝点击