Splay 学习笔记 · 续

发布于 2018-10-28  311 次阅读


之前我写的 平衡树学习笔记 大体介绍了一下两种平衡树。这里详细的介绍 Splay 。
主要是从代码层面详细介绍,默认读者已经掌握了基本原理。

Splay 代码详解

结构体定义

struct Node{
    Node *ch[2];
    int val,cnt,size;
}

$ch$ 指针指向这个节点的左右儿子,$val$ 表示该节点表示的元素的值,$cnt$ 表示该元素出现了几次,$size$ 表示以该节点为子数的值

一些结构体函数

int cmp(int v){ // 在寻找元素 v 的时候判断在左子树还是右子树:0 左子树,1 右子树,-1 当前节点
    if(v == val) return -1;
    return v < val ? 0 : 1;
}

int cmpk(int k){ // 在寻找第 k 大元素的时候判断在左子树还是右子树:0 左子树,1 右子树,-1 当前节点
    if(k <= ch[0]->size) return 0;
    if(k <= ch[0]->size+cnt) return -1;
    return 1;
}

void pushup(){ // 维护 size
    size = cnt + ch[0]->size + ch[1]->size;
}

Rotate 操作

void rotate(Node *&v,int d){ // d: 0 左旋 1 右旋
    Node *t = v->ch[d^1];
    v->ch[d^1] = t->ch[d];
    t->ch[d] = v;
    v->pushup();t->pushup();
    v = t;
}

这里我们使用异或来把左旋和右旋合并成一个代码,如果对该过程有疑问清参考我之前的平衡树文章。

Splay 操作

void splay(Node *&v,int val){ // 查询 val 并伸展为父节点
    int d = v->cmp(val);
    if(d != -1 && v->ch[d] != nil){
        int d2 = v->ch[d]->cmp(val);
        if(d2 != -1 && v->ch[d]->ch[d2] != nil){ // 如果目标节点在儿子的儿子以后了就递归伸展上来
            splay(v->ch[d]->ch[d2],val);
            if(d == d2) // 一字型旋转
                rotate(v,d2^1),rotate(v,d^1);
            else rotate(v->ch[d],d2^1),rotate(v,d^1); // 工字型旋转
        }
        else rotate(v,d^1); // 在儿子的儿子节点上直接单旋上来
    }
}
// splayk 的大部分操作同理
void splayk(Node *&v,int k){ // 查询排名 k 并伸展为父节点
    int d = v->cmpk(k);
    if(d == 1) k -= v->ch[0]->size + v->cnt; // 如果在右子树里就计算在右子树的排名,类似于权值线段树查询思想qaq
    if(d != -1){
        int d2 = v->ch[d]->cmpk(k);
        int k2 = (d2 == 1) ? k - (v->ch[d]->ch[0]->size + v->ch[d]->cnt) : k; // 找到在第二个子树里的排名
        if(d2 != -1){
            splayk(v->ch[d]->ch[d2],k2);
            if(d == d2)
                rotate(v,d2^1),rotate(v,d^1);
            else rotate(v->ch[d],d2^1),rotate(v,d^1);
        }
        else rotate(v,d^1);
    }
}

这里注意的是每个指针都要引用和判断一下空指针。

Split 操作和 Merge 操作

Node *split(Node *&v,int val){ // 分成两棵树,返回的树中值大于val,v树中的值小于等于val
    if(v == nil) return nil;
    splay(v,val); // 将这个元素伸展到根
    Node *t1,*t2;
    if(v->val <= val) // 按照该元素划分成两个树
        t1 = v;t2 = v->ch[1];v->ch[1] = nil;
    else
        t1 = v->ch[0];t2 = v;v->ch[0] = nil;
    v->pushup(); // 更新size
    v = t1;
    return t2;
}

void merge(Node *&t1,Node *&t2){ // t2合并到t1上默认t2较大
    if(t1 == nil) std::swap(t1,t2);
    splay(t1,INT_MAX); // 都放到左子树上
    t1->ch[1] = t2;t2 = nil; // 将较大的 t2 放在 t1 右子树上
    t1->pushup();
}

Split 操作就是首先将 $val$ 节点旋转到根然后分为左子树和右子树。
Merge 直接合并,比线段树合并简单 $qwq$。

Insert 操作和 Erase 操作

void insert(Node *&v,int val){
    Node *t2 = split(v,val);
    if(v->val == val) v->cnt++;
    else{
        Node *cur = New(val);
        merge(v,cur);
    }
    merge(v,t2);
}

void erase(Node *&v,int val){
    Node *t2 = split(v,val);
    if(v->val == val && !--v->cnt){
        v = v->ch[0];
    }
    merge(v,t2);
}

Insert 操作原理:如果已经存在就将这个元素旋转到根并且更新 $cnt$,否则按照 Split 出来的两棵子树中间合并上一个新节点。
Erase 操作原理:将元素旋转到根。如果删除一次后仍存在则更新cnt,否则将其左右子树合并并且删除该节点(可选)。

其他查询操作

int pre(Node *&v,int val){ // 查询 val 的前驱
    splay(v,val);
    if(v->val >= val){
        if(v->ch[0] == nil) return INT_MIN;
        splay(v->ch[0],INT_MAX);
        return v->ch[0]->val;
    }
    return v->val;
}


int succ(Node *&v,int val){ // 查询 val 的后继
    splay(v,val);
    if(v->val <= val){
        if(v->ch[1] == nil) return INT_MAX;
        splay(v->ch[1],INT_MIN);
        return v->ch[1]->val;
    }
    return v->val;
}

int get(Node *&v,int value){ // 数value的排名
    splay(v,value);
    return v->ch[0]->size+1;
}

int getk(Node *&v,int rank){ // 排名为rank的数
    splayk(v,rank);
    return v->val;
}

Pre 操作:等价于将序列取反后进行 upper_bound。这里只需要将 $val $ 伸展到根并且判断一下根节点是不是 $val$ 就可以了。
Succ 操作:等价于 upper_bound。和 Pre 操作同理,就是判断有一些小细节不一样。
GetRank 操作:查询一个数的排名,将 $val$ 旋转到根,这个数的排名就是小于它的数的数量 $+1$,即该节点左子树的 $size$ 加上 $1$。
GetKth 操作:查询排名为 $k$ 的数,用另一种伸展方式将排名为 $k$ 的数伸展到根即可。

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

#define Re register
#define LL long long
#define U unsigned
#define FOR(i,a,b) for(Re int i = a;i <= b;++i)
#define ROF(i,a,b) for(Re int i = a;i >= b;--i)
#define SFOR(i,a,b,c) for(Re int i = a;i <= b;i+=c)
#define SROF(i,a,b,c) for(Re int i = a;i >= b;i-=c)
#define CLR(i,a) memset(i,a,sizeof(i))
#define BR printf("--------------------\n")
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl

const int MAXN = 100000+5;

int N,root,tot,fa[MAXN],ch[MAXN][2],size[MAXN],cnt[MAXN],val[MAXN];

inline void pushup(int x){
    size[x] = size[ch[x][0]]+size[ch[x][1]]+cnt[x];
}

inline bool get(int x){ // 是父节点的哪里
    return x == ch[fa[x]][1];
}

inline void rotate(int x){
    int y = fa[x],z = fa[y],k = get(x);
    ch[z][get(y)] = x;fa[x] = z;
    ch[y][k] = ch[x][k^1];fa[ch[x][k^1]] = y;
    ch[x][k^1] = y;fa[y] = x;
    pushup(y);pushup(x);
}

inline void splay(int x,int v){
    int y;
    while((y = fa[x]) != v){
        int y = fa[x];
        if(fa[y] != v) rotate(get(x) == get(y) ? y : x);
        rotate(x);
    }
    if(!v) root = x;
}

void find(int x){
    if(!root) return;
    int v = root;
    while(x != val[v] && ch[v][x>val[v]]) v = ch[v][x>val[v]];
    splay(v,0);
}

int pre(int x){
    find(x);
    if(val[root] < x) return root;
    int v = ch[root][0];
    while(ch[v][1]) v = ch[v][1];
    return v;
}

int suf(int x){
    find(x);
    if(val[root] > x) return root;
    int v = ch[root][1];
    while(ch[v][0]) v = ch[v][0];
    return v;
}

int rank(int x){
    find(x);
    return size[ch[root][0]];
}

int kth(int x){
    ++x;int v = root;
    while(true){
        if(x > size[ch[v][0]]+cnt[v]) x -= size[ch[v][0]]+cnt[v],v = ch[v][1];
        else if(x <= size[ch[v][0]]) v = ch[v][0];
        else return val[v];
    }
}

void insert(int x){
    int v = root,t = 0;
    while(x != val[v] && v) t = v,v = ch[v][x>val[v]];
    if(v) ++cnt[v];
    else{
        v = ++tot;
        if(t) ch[t][x>val[t]] = v;
        ch[v][0] = ch[v][1] = 0;fa[v] = t;val[v] = x;size[v] = cnt[v] = 1;
    }
    splay(v,0);
}

void del(int x){
    int last = pre(x),next = suf(x);
    splay(last,0);splay(next,last);
    int v = ch[next][0];
    if(cnt[v] > 1) --cnt[v],splay(v,0);
    else ch[next][0] = 0;
}

int main(){
    insert(INT_MIN);insert(INT_MAX);
    scanf("%d",&N);
    while(N--){
        int opt,x;scanf("%d%d",&opt,&x);
        if(opt == 1) insert(x);
        if(opt == 2) del(x);
        if(opt == 3) printf("%d\n",rank(x));
        if(opt == 4) printf("%d\n",kth(x));
        if(opt == 5) printf("%d\n",val[pre(x)]);
        if(opt == 6) printf("%d\n",val[suf(x)]);
    }
    return 0;
}

终于写完了 $qaq$ 其实 Splay 也不是那么的难....并且还能适用于一些其他的数据结构例如 LCT
那我们来看另一道用 Splay 维护区间信息的题目吧 $qwq$.

「BZOJ3223」文艺平衡树

题目链接

题目描述

您需要写一种数据结构(可不参考题目标题),来维护一个有序数列,其中需要提供以下操作:
- 翻转一个区间,
例如原有序序列是5 4 3 2 1,翻转区间是 $[2,4]$ 的话,结果是5 2 3 4 1
设序列长度为 $n$,有 $n \leq 10^5$。

题解

其实区间操作也不是那么难 $qwq$
首先我们维护这些数的位置。如果把它们扔进 Splay 之后我们发现查询第 $i$ 个数实际上就是查询 Splay 里排名第 $i$ 的数。
区间翻转相当于交换了左右子树。
所以维护一个翻转标记然后每次访问节点前交换左右子树并且下方标记就好啦 $qwq$

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

#define fi first
#define lc (ch[x][0])
#define se second
#define U unsigned
#define rc (ch[x][1])
#define Re register
#define LL long long
#define MP std::make_pair
#define CLR(i,a) memset(i,a,sizeof(i))
#define FOR(i,a,b) for(Re int i = a;i <= b;++i)
#define ROF(i,a,b) for(Re int i = a;i >= b;--i)
#define SFOR(i,a,b,c) for(Re int i = a;i <= b;i+=c)
#define SROF(i,a,b,c) for(Re int i = a;i >= b;i-=c)
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl

const int MAXN = 100000+5;
int ch[MAXN][2],f[MAXN],tag[MAXN],val[MAXN],size[MAXN];

inline void pushup(int x){
    size[x] = size[lc] + size[rc] + 1;
}

inline void pushdown(int x){
    if(tag[x]){
        std::swap(lc,rc);
        if(lc) tag[lc] ^= 1;
        if(rc) tag[rc] ^= 1;
        tag[x] = 0;
    }
}

inline void rotate(int x){
    int y = f[x],z = f[y],k = ch[y][1] == x,w = ch[x][!k];
    ch[z][ch[z][1] == y] = x;f[x] = z;
    ch[x][!k] = y;f[y] = x;
    ch[y][k] = w;f[w] = y;
    pushup(y);pushup(x);
}
int root;
inline void splay(int x,int v){
    int y,z;
    while((y = f[x]) != v){
        if((z = f[y]) != v) rotate((ch[y][1] == x) ^ (ch[z][1] == y) ? x : y);
        rotate(x);
    }
    if(!v) root = x;
}
int cnt = 0;
inline void insert(int v){
    int x = root,fa = 0;
    while(x) fa = x,x = v > val[x] ? rc : lc;
    x = ++cnt;if(fa) ch[fa][v>val[fa]] = x;
    size[x] = 1;val[x] = v;f[x] = fa;lc = rc = 0;
    splay(x,0);
}

inline int getkth(int rak){
    int x = root;
    while(233){
        pushdown(x);
        if(size[lc] >= rak) x = lc;
        else if(size[lc]+1 == rak) return x;
        else rak -= size[lc]+1,x = rc;
    }
}

inline void change(int l,int r){
    l = getkth(l);r = getkth(r+2);
    splay(l,0);splay(r,l);
    tag[ch[ch[root][1]][0]] ^= 1;
}
int N,M;
inline void dfs(int x){
    pushdown(x);
    if(lc) dfs(lc);
    if(val[x] > 1 && val[x] < N+2) printf("%d ",val[x]-1);
    if(rc) dfs(rc);
}

int main(){
    scanf("%d%d",&N,&M);
    FOR(i,1,N+2) insert(i);
    FOR(i,1,M){
        int l,r;scanf("%d%d",&l,&r);
        change(l,r);
    }
    dfs(root);
    return 0;
}

一个OIer。