树上启发式合并学习笔记

发布于 2019-01-29  322 次阅读


首先什么是启发式算法?启发式算法是基于人类的经验和直观感觉,对一些算法的优化。
最简单的就是并查集的按秩合并,每次把小的合并到大的上面,这样找父亲的复杂度就小了很多。
树上启发式合并(dsu on tree)是一种我也不知道为什么叫做这个名字的奇怪算法。
特点是可以在 $O(nlogn)$ 的时间内完成对无修改的子树的统计。
我们考虑这样一个问题:给树上的每一个节点一个颜色,询问每个节点上的子树有多少个颜色为 $k$ 的节点。
暴力十分简单,枚举一下就可以了,复杂度是 $O(n^2)$。
我们来考虑优化一下暴力,首先对于这个树重链剖分一下。
考虑一下每个节点的答案是由其子树得来的,所以可以用桶来合并这些答案。
我们开一个全局的桶,每次统计 $x$ 这个节点时,首先暴力递归统计轻儿子的答案,注意每次退出轻儿子的时候要清空贡献。然后统计重儿子,此时不用清空重儿子的贡献,最后把轻儿子再统计一遍就能得到 $x$ 节点的桶了。
发现重链剖分将一棵树分割为不超过 $logn$ 条重链,所以每一个节点最多向上合并 $logn$ 次,时间复杂度是 $O(nlogn)$ 的。
接下来我们来看一些例题深入的了解一下这种做法。
CF600E
题目大意:$n$ 个点的有根树,以 $1$ 为根,每个点有一种颜色。我们称一种颜色占领了一个子树当且仅当没有其他颜色在这个子树中出现得比它多。求占领每个子树的所有颜色之和。
这个题目就是开个全局桶统计一下每个颜色出现的次数,用上面的做法直接做就可以了。

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

#define Re register
#define LL long long
#define U unsigned
#define FOR(i,a,b) for(Re int i = a;i <= b;++i)
#define ROF(i,a,b) for(Re int i = a;i >= b;--i)
#define SFOR(i,a,b,c) for(Re int i = a;i <= b;i+=c)
#define SROF(i,a,b,c) for(Re int i = a;i >= b;i-=c)
#define CLR(i,a) memset(i,a,sizeof(i))
#define BR printf("--------------------\n")
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl
#define int LL
const int MAXN = 100000+5;

int head[MAXN],cnt,size[MAXN],son[MAXN],dep[MAXN],fa[MAXN];
int ans[MAXN],val[MAXN],c[MAXN],maxv,sum;
int N;

struct Edge{
    int to,next;
}e[MAXN<<1];

inline void add(int u,int v){
    e[++cnt] = (Edge){v,head[u]};head[u] = cnt;
}

void dfs1(int v){
    size[v] = 1;
    for(int i = head[v];i;i = e[i].next){
        if(e[i].to == fa[v]) continue;
        dep[e[i].to] = dep[v]+1;fa[e[i].to] = v;
        dfs1(e[i].to);size[v] += size[e[i].to];
        son[v] = (size[son[v]] < size[e[i].to]) ? e[i].to : son[v];
    }
}

bool vis[MAXN];

void change(int v,int fa,int k){
    c[val[v]] += k;
    if(k > 0 && c[val[v]] >= maxv){
        if(c[val[v]] > maxv) sum = 0,maxv = c[val[v]];
        sum += val[v];
    }
    for(int i = head[v];i;i = e[i].next){
        if(vis[e[i].to] || e[i].to == fa) continue;
        change(e[i].to,v,k);
    }
}

void dfs2(int v,int fa=0,bool used=false){
    for(int i = head[v];i;i = e[i].next){
        if(e[i].to == fa || e[i].to == son[v]) continue;
        dfs2(e[i].to,v);
    }
    if(son[v]) dfs2(son[v],v,1),vis[son[v]] = true;
    change(v,fa,1);ans[v] = sum;
    if(son[v]) vis[son[v]] = false;
    if(!used) change(v,fa,-1),maxv = sum = 0;
}

signed main(){
    scanf("%I64d",&N);
    FOR(i,1,N) scanf("%I64d",val+i);
    FOR(i,1,N-1){
        int u,v;scanf("%I64d%I64d",&u,&v);
        add(u,v);add(v,u);
    }
    dfs1(1);dfs2(1);
    FOR(i,1,N) printf("%I64d ",ans[i]);puts("");
    return 0;
}

CF741D
给一棵树,每个节点的权值是'a'到'v'的字母,每次询问要求在一个子树找一条路径,使该路径包含的字符排序后成为回文串。
考虑重新排列后是回文,也就是说至多只有一个字母出现了奇数次,所以我们可以用二进制来处理这个事情。
首先来考虑暴力做法:对于每一个节点进行 dfs ,每到一个节点就强行枚举所有字母找到和它异或后结果为 $1$ 的个数 $<1$ 的路径,再取最大值,这样是 $O(n^2logn)$,可以用 dsu on tree 优化到 $O(nlogn)$。

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

#define fi first
#define lc (x<<1)
#define se second
#define U unsigned
#define rc (x<<1|1)
#define Re register
#define LL long long
#define MP std::make_pair
#define CLR(i,a) memset(i,a,sizeof(i))
#define FOR(i,a,b) for(Re int i = a;i <= b;++i)
#define ROF(i,a,b) for(Re int i = a;i >= b;--i)
#define SFOR(i,a,b,c) for(Re int i = a;i <= b;i+=c)
#define SROF(i,a,b,c) for(Re int i = a;i >= b;i-=c)
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl
#define int LL
const int MAXN = 500000+5;

struct Edge{
    int to,next;
}e[MAXN<<1];

int head[MAXN<<2],son[MAXN],size[MAXN],val[MAXN],dep[MAXN],fa[MAXN],a[MAXN],cnt,N;

inline void add(int u,int v){
    e[++cnt] = (Edge){v,head[u]};head[u] = cnt;
}

void dfs1(int v){
    size[v] = 1;if(v != 1) val[v] = val[fa[v]]^(1<<a[v]);
    for(int i = head[v];i;i = e[i].next){
        //if(e[i].to == fa[v]) continue;
        dep[e[i].to] = dep[v] + 1;fa[e[i].to] = v;
        dfs1(e[i].to);size[v] += size[e[i].to];
        son[v] = size[son[v]] < size[e[i].to] ? e[i].to : son[v];
    }
}

int maxv,f[MAXN<<5];

inline void calc(int v,int lca){
    int now = val[v];
    maxv = std::max(maxv,f[now]+dep[v]-2dep[lca]);
    if(!(val[v]^val[lca])) maxv = std::max(maxv,dep[v]-dep[lca]);
    FOR(i,0,21){
        now = (1<<i)^val[v];
        maxv = std::max(maxv,f[now]+dep[v]-2dep[lca]);
        if((val[v]^val[lca]) == (1<<i)) maxv = std::max(maxv,dep[v]-dep[lca]);
    }
    for(int i = head[v];i;i = e[i].next) calc(e[i].to,lca);
}

inline void change(int v,int k){
    if(k) f[val[v]] = std::max(f[val[v]],dep[v]);
    else f[val[v]] = INT_MIN;
    for(int i = head[v];i;i = e[i].next) change(e[i].to,k);
}
int ans[MAXN];
void dfs2(int v,int k){
    //DEBUG(v);
    for(int i = head[v];i;i = e[i].next){
        if(e[i].to == son[v]) continue;
        dfs2(e[i].to,0);
    }
    if(son[v]) dfs2(son[v],1);
    maxv = 0;int now = val[v];
    maxv = std::max(maxv,f[now]-dep[v]);
    FOR(i,0,21){
        now = (1<<i)^val[v];
        maxv = std::max(maxv,f[now]-dep[v]);
    }
    for(int i = head[v];i;i = e[i].next){
        if(e[i].to == son[v]) continue;
        calc(e[i].to,v);change(e[i].to,1);
    }
    ans[v] = maxv;
    if(!k){
        for(int i = head[v];i;i = e[i].next) change(e[i].to,0);
        f[val[v]] = INT_MIN;
    }else f[val[v]] = std::max(f[val[v]],dep[v]);
}

void erase(int v){
    for(int i = head[v];i;i = e[i].next){
        erase(e[i].to);
        ans[v] = std::max(ans[v],ans[e[i].to]);
    }
}

char str[20];

signed main(){
    scanf("%I64d",&N);
    FOR(i,2,N){
        int u;scanf("%I64d%s",&u,str+1);
        add(u,i);a[i] = str[1]-'a';
    }
    dep[1] = 1;dfs1(1);
    CLR(f,128);dfs2(1,0);
    erase(1);
    FOR(i,1,N) printf("%I64d ",ans[i]);puts("");
    return 0;
}

一个OIer。