保存一份splay模板,慢慢学习。

来源:互联网 发布:闭口粉刺怎么去除知乎 编辑:程序博客网 时间:2024/05/21 06:30
    #include<cstdio>      #include<cstdlib>      const int inf  = ~0u>>2;      #define L ch[x][0]      #define R ch[x][1]      #define KT (ch[ ch[rt][1] ][0])      const int maxn = 500010;      int lim;      struct SplayTree {          int sz[maxn];          int ch[maxn][2];          int pre[maxn];          int rt,top;          inline void up(int x){              sz[x]  = cnt[x]  + sz[ L ] + sz[ R ];          }          inline void Rotate(int x,int f){              int y=pre[x];              ch[y][!f] = ch[x][f];              pre[ ch[x][f] ] = y;              pre[x] = pre[y];              if(pre[x]) ch[ pre[y] ][ ch[pre[y]][1] == y ] =x;              ch[x][f] = y;              pre[y] = x;              up(y);          }          inline void Splay(int x,int goal){//将x旋转到goal的下面              while(pre[x] != goal){                  if(pre[pre[x]] == goal) Rotate(x , ch[pre[x]][0] == x);                  else   {                      int y=pre[x],z=pre[y];                      int f = (ch[z][0]==y);                      if(ch[y][f] == x) Rotate(x,!f),Rotate(x,f);                      else Rotate(y,f),Rotate(x,f);                  }              }              up(x);              if(goal==0) rt=x;          }          inline void RTO(int k,int goal){//将第k位数旋转到goal的下面              int x=rt;              while(sz[ L ] != k-1) {                  if(k < sz[ L ]+1) x=L;                  else {                      k-=(sz[ L ]+1);                      x = R;                  }              }              Splay(x,goal);          }          inline void vist(int x){              if(x){                  printf("结点%2d : 左儿子  %2d   右儿子  %2d   val:%2d sz=%d  cnt:%d\n",x,L,R,val[x],sz[x],cnt[x]);                  vist(L);                  vist(R);              }          }          void debug() {              puts("");              vist(rt);              puts("");          }          inline void Newnode(int &x,int c,int f){              x=++top;              L = R = 0;              pre[x] = f;              sz[x]=1; cnt[x]=1;              val[x] = c;          }          inline void init(){              ch[0][0]=ch[0][1]=pre[0]=sz[0]=0;              rt=top=0; cnt[0]=0;          }          inline void Insert(int &x,int key,int f){              if(!x) {                  Newnode(x,key,f);                  Splay(x,0);//注意插入完成后splay                  return ;              }              if(key==val[x]){                  cnt[x]++;                  sz[x]++;                  Splay(x,0);//注意插入完成后splay                  return ;              }else if(key<val[x]) {                  Insert(L,key,x);              } else {                  Insert(R,key,x);              }              up(x);          }          void Del_root(){//删除根节点              int t=rt;              if(ch[rt][1]) {                  rt=ch[rt][1];                  RTO(1,0);                  ch[rt][0]=ch[t][0];                  if(ch[rt][0]) pre[ch[rt][0]]=rt;              }              else rt=ch[rt][0];              pre[rt]=0;              up(rt);          }          void findpre(int x,int key,int &ans){//找前驱节点              if(!x)  return ;              if(val[x] <= key){                  ans=x;                  findpre(R,key,ans);              } else                  findpre(L,key,ans);          }          void findsucc(int x,int key,int &ans){//找后继节点              if(!x) return ;              if(val[x]>=key) {                  ans=x;                  findsucc(L,key,ans);              } else                  findsucc(R,key,ans);          }          inline int find_kth(int x,int k){ //第k小的数              if(k<sz[L]+1) {                  return find_kth(L,k);              }else if(k > sz[ L ] + cnt[x] )                   return find_kth(R,k-sz[L]-cnt[x]);              else{                   Splay(x,0);                  return val[x];              }          }          int find(int x,int key){              if(!x) return 0;              else if(key < val[x])  return find(L,key);              else if(key > val[x])  return find(R,key);              else return x;          }          int getmin(int x){              while(L) x=L;    return val[x];          }          int getmax(int x){              while(R) x=R;   return val[x];          }          //确定key的排名          int getrank(int x,int key,int cur){//cur:当前已知比要求元素(key)小的数的个数              if(key == val[x])                    return sz[L] + cur + 1;              else if(key < val[x])                  getrank(L,key,cur);              else                   getrank(R,key,cur+sz[L]+cnt[rt]);          }          int get_lt(int x,int key){//小于key的数的个数 lt:less than               if(!x) return 0;              if(val[x]>=key) return get_lt(L,key);              return cnt[x]+sz[L]+get_lt(R,key);          }          int get_mt(int x,int key){//大于key的数的个数 mt:more than              if(!x) return 0;              if(val[x]<=key) return get_mt(R,key) ;              return cnt[x]+sz[R]+get_mt(L,key);          }          void del(int &x,int f){//删除小于lim的所有的数所在的节点              if(!x) return ;              if(val[x]>=lim){                  del(L,x);              } else {                  x=R;                   pre[x]=f;                  if(f==0)  rt=x;                  del(x,f);              }              if(x)  up(x);          }          inline void update(){              del(rt,0);          }          int get_mt(int key) {              return get_mt(rt,key);          }          int get_lt(int key) {              return get_lt(rt,key);          }          void insert(int key) {              Insert(rt,key,0);          }          void Delete(int key) {              int node=find(rt,key);              Splay(node,0);              cnt[rt]--;              if(!cnt[rt])Del_root();          }          int kth(int k) {              return find_kth(rt,k);          }          int cnt[maxn];          int val[maxn];          int lim;      }spt;  


0 0
原创粉丝点击