Splay 如何维护区间信息

发布于 2019-02-02  432 次阅读


如何维护区间

我们都知道平衡树是一种很神奇的东西,可以用于维护动态的序列总信息,但是如果想维护其中的某一个区间的信息,我们该怎么做好呢?
我们考虑换一种维护方式:平衡树的结构不是按照节点的值的大小来进行建树,而是以这个节点在序列中的位置为关键字来维护。也就是说对于平衡树中任意一个节点 $x$ ,左子树中的元素永远在它前面,右子树的元素永远在它后面。显然一个区间可以对应一个子树,所以我们要使用能提取子树的数据结构,比如 Splay 和 非旋 Treap。
比如我们考虑提取 $[l,r]$ 的区间信息:我们首先将 $l-1$ 节点旋转到根,Splay 变成了这样:

然后我们接着把 $r+1$ 旋转到 $l-1$ 的下方,也就变成了这样:

我们发现 $R+1$ 节点的左子树就是我们要的区间!所以我们直接对这个子树进行想要的操作即可,复杂度是旋转的复杂度即均摊$O(logn)$。
接下来我们来看一道例题。

「BZOJ1500」「NOI2005」维修数列

题目链接
题目大意:要求你维护一个数列,支持以下操作:

任何时刻数列中最多含有 $5\times 10^5$ 个数,数列中任何一个数字均在 $[-10^3, 10^3]$ 内。
插入的数字总数不超过 $4 \times 10^6$个,输入文件大小不超过 $\text{20M}$。
显然这就是用 Splay 维护区间信息的裸题,接下来我们对于每个操作都分析一遍。

Insert 操作

看起来暴力插入所有的数非常不可取。。。随便就能卡掉的样。
我们不妨考虑一下如果我们有了一个序列里的所有的树,想得到这个序列对应的 Splay 可以有什么方便的方法。显然我们可以直接暴力递归建树,这样不仅快 ($O(nlogn)$) 得到的树也十分平衡。具体建树过程就是每次选择中点作为当前的根,然后递归处理中点左边和右边。
我们不妨把输入的序列先搞个 Splay 出来,然后我们发现在 $x$ 后插入,如果我们按照上面的方法提取出区间 $[x,x+1]$,那么 $x+1$ 的左子树不久应该是我们新建出来的 Splay 吗。。。直接接上去就可以了。

Delete 操作

和 Insert 操作相同的思路:我们直接把这段区间提取出来,不过我们发现现在我们是要删除,所以直接断开这个区间代表的子树与父节点的连接就可以了。但是在这题中为了节省运行空间建议写一个内存回收池来回收这些被删除的节点,这个操作甚至比插入还简单。

MAKE-SAME 操作

考虑把这一段区间修改成同一个数。在线段树上区间操作我们有 lazy-tag 可用,在平衡树上也可以。我们提取出这个区间,然后在根节点打个区间覆盖的标记就行了。注意每次访问这个区间前标记必须要下放。

翻转

学过 LCT 的都会,没学过的参考一下文艺平衡树的做法就可以了。
我们发现建出的 Splay 同时保留了原来序列的性质:如果我们中序遍历这棵 Splay,那么得到的序列就是原序列!我们考虑区间翻转意味着什么:是不是就是对于这个区间里的每一个节点左右儿子都互换啊(因为互换后现在先遍历的是原来后遍历的,并且子树也都交换了)。所以我们同时维护一下翻转标记就可以了。
注意到下放标记的时候如果又有区间覆盖的标记又有翻转的标记,这时候我们可以忽略翻转标记(因为你都是一个数了翻转没有意义),这样能优化不少常数。

求和

我们对于 Splay 的每一个节点,维护它及其它子树中所有节点代表的原序列中的元素的和。查询的时候提取区间直接输出就可以了。

求最大子列

我们考虑一种基于分治的求最大子列方法:对于每个区间都维护从左边开始的最大子列,从右边开始的最大子列和该区间的最大子列,分别记为 $ls,rs,ans$。
对于 ls 的维护:我们考虑从左边开始的最大子列是否覆盖了整个左儿子代表的区间,rs 同理。
对于 ans 就有多种可能了:全都在左儿子区间的,全都在右儿子区间的,跨过中间点的。
pushup 的时候分别维护即可,可以 $O(1)$ 维护。
于是这题的代码也就可以轻松写出来啦:

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

#define fi first
#define lc (ch[x][0])
#define se second
#define U unsigned
#define rc (ch[x][1])
#define Re register
#define LL long long
#define MP std::make_pair
#define CLR(i,a) memset(i,a,sizeof(i))
#define FOR(i,a,b) for(Re int i = a;i <= b;++i)
#define ROF(i,a,b) for(Re int i = a;i >= b;--i)
#define SFOR(i,a,b,c) for(Re int i = a;i <= b;i+=c)
#define SROF(i,a,b,c) for(Re int i = a;i >= b;i-=c)
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl

const int MAXN = 500000+5;

int f[MAXN],ch[MAXN][2],size[MAXN],val[MAXN],sum[MAXN],ls[MAXN],rs[MAXN],max[MAXN];
int rev[MAXN],tag[MAXN];

int a[MAXN],bin[MAXN],cnt,top,N,M;

inline void pushup(int x){
    size[x] = size[lc] + size[rc] + 1;
    sum[x] = sum[lc] + sum[rc] + val[x];
    max[x] = std::max(std::max(max[lc],max[rc]),rs[lc]+ls[rc]+val[x]);
    ls[x] = std::max(ls[lc],sum[lc]+val[x]+ls[rc]);
    rs[x] = std::max(rs[rc],sum[rc]+val[x]+rs[lc]);
}

inline void pushdown(int x){
    if(tag[x]){
        tag[x] = rev[x] = false;
        if(lc){
            tag[lc] = true;val[lc] = val[x];sum[lc] = val[x]*size[lc];
        }
        if(rc){
            tag[rc] = true;val[rc] = val[x];sum[rc] = val[x]*size[rc];
        }
        if(val[x] >= 0){
            if(lc) ls[lc] = rs[lc] = max[lc] = sum[lc];
            if(rc) ls[rc] = rs[rc] = max[rc] = sum[rc];
        }
        else{
            if(lc) ls[lc] = rs[lc] = 0,max[lc] = val[x];
            if(rc) ls[rc] = rs[rc] = 0,max[rc] = val[x];
        }
    }
    if(rev[x]){
        rev[x] = 0;rev[lc] ^= 1;rev[rc] ^= 1;
        std::swap(ls[lc],rs[lc]);std::swap(ls[rc],rs[rc]);
        std::swap(ch[lc][0],ch[lc][1]);std::swap(ch[rc][0],ch[rc][1]);
    }
}

inline void rotate(int x){
    int y = f[x],z = f[y],k = ch[y][1] == x,w = ch[x][!k];
    ch[z][ch[z][1] == y] = x;f[x] = z;
    ch[x][!k] = y;f[y] = x;
    ch[y][k] = w;f[w] = y;
    pushup(y);pushup(x);
}
int root;
inline void splay(int x,int v){
    int y,z;
    while((y = f[x]) != v){
        if((z = f[y]) != v) rotate((ch[z][1] == y)^(ch[y][1] == x) ? x : y);
        rotate(x);
    }
    if(!v) root = x;
}

int getkth(int rk){
    int x = root;
    while(233){
        pushdown(x);
        if(rk <= size[lc]) x = lc;
        else if(rk == size[lc]+1) return x;
        else rk -= size[lc]+1,x = rc;
    }
}

inline int build(int l,int r,int fa){
    if(l > r) return 0;
    int mid = (l + r) >> 1,x;
    x = top ? bin[top--] : ++cnt;
    f[x] = fa;val[x] = a[mid];
    rev[x] = tag[x] = 0;
    lc = build(l,mid-1,x);rc = build(mid+1,r,x);
    pushup(x);return x;
}

inline void insert(int l,int tot){
    int r = l+1;
    l = getkth(r);r = getkth(r+1);
    splay(l,0);splay(r,l);
    FOR(i,1,tot) scanf("%d",a+i);
    ch[r][0] = build(1,tot,r);
    N += tot;pushup(r);pushup(l);
}

inline void erase(int x){
    if(!x) return;
    bin[++top] = x;
    erase(lc);erase(rc);
}

inline void del(int l,int r){
    N -= r-l+1;
    l = getkth(l);r = getkth(r+2);
    splay(l,0);splay(r,l);
    erase(ch[r][0]);ch[r][0] = 0;
    pushup(r);pushup(l);
}

inline void modify(int l,int r,int v){
    l = getkth(l);r = getkth(r+2);
    splay(l,0);splay(r,l);
    int x = ch[r][0];
    val[x] = v;sum[x] = v*size[x];
    if(v <= 0) ls[x] = rs[x] = 0,max[x] = v;
    else ls[x] = rs[x] = max[x] = sum[x];
    tag[x] = 1;
    pushup(r);pushup(l);
}

inline void reverse(int l,int r){
    l = getkth(l);r = getkth(r+2);
    splay(l,0);splay(r,l);
    int x = ch[r][0];
    if(!tag[x]){
        rev[x] ^= 1;std::swap(lc,rc);std::swap(ls[x],rs[x]);
        pushup(r);pushup(l);
    }
}

inline int querysum(int l,int r){
    l = getkth(l);r = getkth(r+2);
    splay(l,0);splay(r,l);
    return sum[ch[r][0]];
}

inline int calc(){
    int l = getkth(1),r = getkth(N+2);
    splay(l,0);splay(r,l);
    return max[ch[r][0]];
}

int main(){
    scanf("%d%d",&N,&M);
    max[0] = a[1] = a[N+2] = INT_MIN;
    FOR(i,1,N) scanf("%d",a+i+1);
    root = build(1,N+2,0);
    while(M--){
        char opt[10];
        int x,y,z;
        scanf("%s",opt);
        if(opt[0] == 'I'){
            scanf("%d%d",&x,&y);
            insert(x,y);
        }
        if(opt[0] == 'D'){
            scanf("%d%d",&x,&y);
            del(x,x+y-1);
        }
        if(opt[0] == 'M' && opt[2] == 'K'){
            scanf("%d%d%d",&x,&y,&z);
            modify(x,x+y-1,z);
        }
        if(opt[0] == 'R'){
            scanf("%d%d",&x,&y);
            reverse(x,x+y-1);
        }
        if(opt[0] == 'G'){
            scanf("%d%d",&x,&y);
            printf("%d\n",querysum(x,x+y-1));
        }
        if(opt[0] == 'M' && opt[2] == 'X'){
            printf("%d\n",calc());
        }
    }
    return 0;
}

写在最后

至于我为什么不写听起来更简单的 fhq treap,因为我们机房里的人天天吹它,于是我就不想学了。。。可能以后接触到可持久化的时候才会学QAQ


一个OIer。