树链剖分学习笔记

发布于 2018-06-18  395 次阅读


树链剖分是一种树路径信息维护算法。

把整棵树划分成许多条,使每个节点都在唯一的链上,对每一条链维护一棵线段树,把在树上的操作转移到线段树上。

ž将一棵树划分成若干条链,用数据结构去维护每条链,保证每个点在且仅在一条链上,通过数据结构维护这些链的信息,复杂度为$ O(logN) $

接下来我们以一道例题为例。

P3384-树链剖分

如题,已知一棵包含 $ N $ 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

  • 操作1: 格式: $1\ x\ y\ z $ 表示将树从x到y结点最短路径上所有节点的值都加上 $ z $
  • 操作2: 格式: $2\ x\ y $ 表示求树从x到y结点最短路径上所有节点的值之和
  • 操作3: 格式: $3\ x\ z $表示将以x为根节点的子树内所有节点值都加上 $ z $
  • 操作4: 格式: $4\ x\ $表示求以x为根节点的子树内所有节点值之和

有 $ M $ 次操作,其中 $N \leq 10^5,M \leq 10^5$

树链剖分

介绍

我们先明确一下概念:

  • 重儿子:在一个点 $ u $ 的儿子中,拥有最大的 $size$ 值的点 ,就叫做是点 $ u $ 的重儿子。
  • 轻儿子:点 $ u $ 的儿子中除了 $ u $ 的重儿子外的所有儿子。
  • 重边:点 $ u $ 于其重儿子的连边。
  • 轻边:点 $ u $ 与其轻儿子的连边。
  • 重链:仅由重边构成的路径

那么我们不难发现,树链剖分有一些保证复杂度的优秀的性质。

  • 性质1:如果 $ v $ 是 $ u $ 的轻儿子,那么 $ size_v \leq \frac{size_u}{2} $ 。
  • 性质2:任意点 $ u $ 到根的路径上轻边,重链条数都不大于 $ log_2n $ 。

定义

树链剖分需要定义三种结构体,树的节点,树的边,链。

代码定义如下:

struct Node{
    int dfn,size,depth; // dfn 表示 dfs 序,size 表示以该节点为根的子树的大小,depth 表示该节点的深度
    bool vis; // 是否在第一次 dfs 中被访问过
    Node *fa,*maxchild; // fa 表示该节点的父节点,maxchild 表示该节点的重儿子
    struct Edge *firstEdge;
    struct Chain *chain; // 该节点所在的链
}node[MAXN];

struct Edge{
    Node *s,*t;
    Edge *next;
}pool1[MAXN << 1],*frog1 = pool1;

Edge *New1(Node *s,Node *t){
    Edge *ret = ++frog1;
    ret->s = s;ret->t = t;
    ret->next = s->firstEdge;
    return ret;
}

inline void add(int u,int v){
    node[u].firstEdge = New1(&node[u],&node[v]);
    node[v].firstEdge = New1(&node[v],&node[u]);
}

struct Chain{
    Node *top; // 表示链的顶部节点
    // 这里还可以维护链的更多信息,例如长度
}pool2[MAXN],*frog2 = pool2;

Chain *New2(Node *top){
    Chain *ret = ++frog2;
    ret->top = top;
    return ret;
}

剖分过程

这个过程是划分轻重链的过程。

我们用两遍 dfs 来实现。

对于第一遍 dfs,求出每个节点的 $ size,depth,fa,maxchild $ ,对于第二遍 dfs,求出 $ dfn,chain $ (定义见上文)

剖分过程代码如下:

void dfs1(Node *v){ // 第一遍 dfs
    v->size = 1; // 初始化 size
    v->vis = true; // 标记已被访问
    for(Edge *e = v->firstEdge;e;e = e->next){
        if(!e->t->vis){
            e->t->fa = v;
            e->t->depth = v->depth + 1;
            dfs1(e->t);
            v->size += e->t->size; // 累加 size
            if(!v->maxchild || v->maxchild->size < e->t->size)
                v->maxchild = e->t; 
            // 更新重儿子
        }
    }
}

void dfs2(Node *v){
    static int ts = 0;
    // 这里的 static 可以理解为全局变量,不会被重复定义
    v->dfn = ++ts; // 获得 dfn 序列
    if(!v->fa || v->fa->maxchild != v)
        v->chain = New2(v);
    else v->chain = v->fa->chain;
    if(v->maxchild)
        dfs2(v->maxchild); // 优先遍历重儿子
    for(Edge *e = v->firstEdge;e;e = e->next){
        if(e->t->fa == v && e->t != v->maxchild)
            dfs2(e->t);
    }
}

inline void split(int root){
    node[root].depth = 1;
    dfs1(&node[root]);
    dfs2(&node[root]);
}

维护链上信息

我们使用线段树来维护信息,每个线段树节点的编号是树上节点所对应的 $ dfn $ 值。

线段树(或其他数据结构)的定义要随着题目的改变而改变。

SegmentTree *segt = SegmentTree::build(1,N)

查询&修改

对于树链剖分后的查询和修改操作,只需要考虑如何转换到线段树上的区间修改操作即可。

我们以操作 $ 1 $ 和 $ 2 $ 为例,设 $ u $ 的深度更深,我们先让 $ u $ 跳到 $ v $ 所在的链上,途中进行 修改/查询 ,最后对 $ u $ 和 $ v $ 剩下的区域进行 修改/查询 即可。

inline void modify1(int x,int y,LL delta){
    Node *u = &node[x],*v = &node[y]; 
    while(u->chain != v->chain){   // 调到同一个链上,
        if(u->chain->top->depth < v->chain->top->depth) std::swap(u,v);
        segt->modify(u->chain->top->dfn,u->dfn,delta);
        u = u->chain->top->fa;
    }
    if(u->depth > v->depth) std::swap(u,v);
    segt->modify(u->dfn,v->dfn,delta);
}

inline LL query1(int x,int y){
    Node *u = &node[x],*v = &node[y];
    LL ret = 0;
    while(u->chain != v->chain){
        if(u->chain->top->depth < v->chain->top->depth) std::swap(u,v);
        ret = (ret + segt->query(u->chain->top->dfn,u->dfn)) % ha;
        u = u->chain->top->fa;
    }
    if(u->depth > v->depth) std::swap(u,v);
    ret = (ret + segt->query(u->dfn,v->dfn)) % ha;
    return ret;
}

操作二留给读者作为思考题,要考虑修改的内容和 $ dfn $ 序的关系,进而才能确定与线段树修改的关系。

最近公共祖先

树剖求最近公共祖先,也是先让 $ u $ 跳到 $ v $ 所在的链上,然后返回 $ depth $ 小的点。

Node *lca(Node *u,Node *v){
    while(u->chain != v->chain){
        if(u->chain->top->depth < v->chain->top->depth) std::swap(u,v);
        u = u->chain->top->fa;
    }
    if(u->depth > v->depth) std::swap(u,v);
    return u;
}

全部代码

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

const int MAXN = 1000000 + 5;

#define LL long long
#define ULL unsigned long long
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl

int N,M,R,ha;
LL dist[MAXN];

struct Node{
    int dfn,size,depth;
    bool vis;
    Node *fa,*maxchild;
    struct Edge *firstEdge;
    struct Chain *chain;
}node[MAXN];

struct Edge{
    Node *s,*t;
    Edge *next;
}pool1[MAXN << 1],*frog1 = pool1;

Edge *New1(Node *s,Node *t){
    Edge *ret = ++frog1;
    ret->s = s;ret->t = t;
    ret->next = s->firstEdge;
    return ret;
}

inline void add(int u,int v){
    node[u].firstEdge = New1(&node[u],&node[v]);
    node[v].firstEdge = New1(&node[v],&node[u]);
}

struct Chain{
    Node *top;
}pool2[MAXN],*frog2 = pool2;

Chain *New2(Node *top){
    Chain *ret = ++frog2;
    ret->top = top;
    return ret;
}

void dfs1(Node *v){
    v->size = 1;
    v->vis = true;
    for(Edge *e = v->firstEdge;e;e = e->next){
        if(!e->t->vis){
            e->t->fa = v;
            e->t->depth = v->depth + 1;
            dfs1(e->t);
            v->size += e->t->size;
            if(!v->maxchild || v->maxchild->size < e->t->size)
                v->maxchild = e->t;
        }
    }
}

void dfs2(Node *v){
    static int ts = 0;
    
    v->dfn = ++ts;
    if(!v->fa || v->fa->maxchild != v)
        v->chain = New2(v);
    else v->chain = v->fa->chain;
    if(v->maxchild)
        dfs2(v->maxchild);
    for(Edge *e = v->firstEdge;e;e = e->next){
        if(e->t->fa == v && e->t != v->maxchild)
            dfs2(e->t);
    }
}

inline void split(int root){
    node[root].depth = 1;
    dfs1(&node[root]);
    dfs2(&node[root]);
}

struct SegmentTree;
SegmentTree *New3(int ,int ,SegmentTree *,SegmentTree *);

struct SegmentTree {
    int l,r;
    LL sum,tag;
    SegmentTree *lc,*rc;
    
    static SegmentTree *build(int l,int r){
        int mid = (l + r) >> 1;
        return (l == r) ? New3(l,r,NULL,NULL) : New3(l,r,build(l,mid),build(mid + 1,r));
    }
    
    inline void pushup(){
        sum = (lc->sum + rc->sum) % ha;
    }
    
    inline void cover(LL delta){
        sum = (sum + ((r - l + 1) * delta) % ha) % ha;
        tag = (tag + delta) % ha;
    }
    
    inline void pushdown(){
        if(tag){
            lc->cover(tag);
            rc->cover(tag);
            tag = 0;
        }
    }
    
    void update(int pos,LL x){
        if(l == r){
            sum = x;
            return;
        }
        LL mid = (l + r) >> 1;
        if(pos <= mid) lc->update(pos,x);
        else rc->update(pos,x);
        pushup();
    }
    
    void modify(int left,int right,LL delta){
        if(left == l && right == r){
            cover(delta);
            return;
        }
        if(left > r || right < l) return;
        pushdown();
        int mid = (l + r) >> 1;
        if(right <= mid) lc->modify(left,right,delta);
        else if(left > mid) rc->modify(left,right,delta);
        else{
            lc->modify(left,mid,delta);
            rc->modify(mid + 1,right,delta);
        }
        pushup();
    }
    
    LL query(int left,int right){
        if(left == l && right == r) return sum%ha;
        if(left > r || right < l) return 0;
        pushdown();
        int mid = (l + r) >> 1;
        if(right <= mid) return lc->query(left,right);
        else if(left > mid) return rc->query(left,right);
        return (lc->query(left,mid) + rc->query(mid + 1,right))%ha;
    }
}pool3[MAXN << 2],*frog3 = pool3,*segt;

SegmentTree *New3(int l,int r,SegmentTree *lc,SegmentTree *rc){
    SegmentTree *ret = ++frog3;
    ret->l = l;ret->r = r;
    ret->lc = lc;ret->rc = rc;
    ret->sum = ret->tag = 0;
    return ret;
}

inline void update(int x,LL delta){
    Node *v = &node[x];
    segt->update(v->dfn,delta);
}

inline void modify1(int x,int y,LL delta){
    Node *u = &node[x],*v = &node[y];
    while(u->chain != v->chain){
        if(u->chain->top->depth < v->chain->top->depth) std::swap(u,v);
        segt->modify(u->chain->top->dfn,u->dfn,delta);
        u = u->chain->top->fa;
    }
    if(u->depth > v->depth) std::swap(u,v);
    segt->modify(u->dfn,v->dfn,delta);
}

inline LL query1(int x,int y){
    Node *u = &node[x],*v = &node[y];
    LL ret = 0;
    while(u->chain != v->chain){
        if(u->chain->top->depth < v->chain->top->depth) std::swap(u,v);
        ret = (ret + segt->query(u->chain->top->dfn,u->dfn)) % ha;
        u = u->chain->top->fa;
    }
    if(u->depth > v->depth) std::swap(u,v);
    ret = (ret + segt->query(u->dfn,v->dfn)) % ha;
    return ret;
}

inline void modify2(int x,LL delta){
    Node *v = &node[x];
    segt->modify(v->dfn,v->dfn + v->size - 1,delta);
}

inline LL query2(int x){
    Node *v = &node[x];
    return segt->query(v->dfn,v->dfn + v->size - 1) % ha;
}

int main(){
    scanf("%d%d%d%d",&N,&M,&R,&ha);
    segt = SegmentTree::build(1,N);
    for(int i = 1;i <= N;i++)
        scanf("%lld",dist + i);
    for(int u,v,i = 1;i < N;i++){
        scanf("%d%d",&u,&v);
        add(u,v);
    }
    split(R);
    for(int i = 1;i <= N;i++)
        update(i,dist[i]);
    int opt,x,y;LL delta;
    while(M--){
        scanf("%d",&opt);
        switch(opt){
            case 1:
                scanf("%d%d%lld",&x,&y,&delta);
                modify1(x,y,delta);
                break;
            case 2:
                scanf("%d%d",&x,&y);
                printf("%lld\n",query1(x,y)%ha);
                break;
            case 3:
                scanf("%d%lld",&x,&delta);
                modify2(x,delta);
                break;
            case 4:
                scanf("%d",&x);
                printf("%lld\n",query2(x)%ha);
                break;
        }
    }
    return 0;
}

思考题答案:我们考虑修改一个以 $ v $ 为根子树的时候,最小的 $ dfn $ 的节点是 $ v $,$ dfn $ 最大值就是 $ dfn_v + size_v - 1 $。


一个OIer。