点分治是什么:
首先选取一个点将无根树转为有根树,再递归处理每一棵以根节点的儿子为根的子树。
这样得到的分治结构有何用处?
通俗地讲,对任一个点,树的任一条路径要么经过这个点,要么不经过这个点,在这个点的子树里面。这对于树路径的计数等问题能够更容易的解决。
我们先详细介绍如何点分治,再来继续考虑使用。
考虑到我们每一次选择分治中心的时候,我们希望选出的这个根分出的子树的大小都尽量的小,我们可以选择使用这棵树的重心。
大致算法流程:
- 求出当前树的重心,钦定重心为根
- 统计当前树的答案
- 对与当前树相邻的,未成为任何一棵树的重心的节点执行第一步。
注意到树的重心有一个性质:树的重心每棵子树的大小一定小于等于 $\frac{n}{2}$,也就说明了每一次分治过程都会减少至少一半的问题规模,所以时间复杂度是优秀的一个 $log$。
接下来我们看一道简单的经典例题:
POJ1741
给出一棵边权树,询问有多少对点的距离小于等于 $k$。
$2 \leq n \leq 10000,k \leq 2^{31}$。
我们考虑每一次点分治的时候对于分治重心求答案,以 $v$ 为根,统计所有经过 $v$ 的路径。
我们可以对于每一个点,求出到根的路径长度,设为 $dis$,这样我们就是求满足 $dis_x + dis_y \leq k$ ,且 $x$ 和 $y$ 不在以 $v$ 为根的同一个子树内的 $(x,y)$ 的方案数(同一个子树两点路径不一定经过根)。
我们首先不考虑子树限制,问题简化为给你一个 $dis$ 数组,求里面和小于等于 $k$ 的数对个数。我们可以对其排个序,然后双指针扫一遍统计答案。那我们首先不注意子树限制求出来一个答案,在单独减去每个子树对答案的贡献就可以了。
时间复杂度是 $O(nlog^2n)$ 的。
#include <algorithm>
#include <iostream>
#include <cstring>
#include <climits>
#include <cstdio>
#include <vector>
#include <cstdlib>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#define fi first
#define lc (x<<1)
#define se second
#define U unsigned
#define rc (x<<1|1)
#define Re register
#define LL long long
#define MP std::make_pair
#define CLR(i,a) memset(i,a,sizeof(i))
#define FOR(i,a,b) for(Re int i = a;i <= b;++i)
#define ROF(i,a,b) for(Re int i = a;i >= b;--i)
#define SFOR(i,a,b,c) for(Re int i = a;i <= b;i+=c)
#define SROF(i,a,b,c) for(Re int i = a;i >= b;i-=c)
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl
const int MAXN = 40000+5;
struct Edge{
int to,w,next;
}e[MAXN<<1];
int head[MAXN],dis[MAXN],size[MAXN],max[MAXN],cnt,root,sum;
bool vis[MAXN];
int N,K,st[MAXN];
inline void add(int u,int v,int w){
e[++cnt] = (Edge){v,w,head[u]};head[u] = cnt;
}
void getroot(int v,int fa){
size[v] = 1;max[v] = 0;
for(int to,i = head[v];i;i = e[i].next){
if((to = e[i].to) == fa || vis[to]) continue;
getroot(to,v);size[v] += size[to];
max[v] = std::max(max[v],size[to]);
}
max[v] = std::max(max[v],sum-max[v]);
root = max[root] > max[v] ? v : root;
}
void getdis(int v,int fa){
st[++st[0]] = dis[v];
for(int to,i = head[v];i;i = e[i].next){
if((to = e[i].to) == fa || vis[to]) continue;
dis[to] = dis[v] + e[i].w;getdis(to,v);
}
}
int calc(int v,int val){
int res = 0;
st[0] = 0;dis[v] = val;
getdis(v,0);std::sort(st+1,st+st[0]+1);
for(int l = 1,r = st[0];l < r;){
if(st[l] + st[r] <= K) res += r-l++;
else --r;
}
return res;
}
int ans = 0;
void divide(int v){
ans += calc(v,0);vis[v] = true;
for(int to,i = head[v];i;i = e[i].next){
if(vis[(to = e[i].to)]) continue;
ans -= calc(to,e[i].w);
max[root = 0] = sum = size[to];
getroot(to,0);divide(root);
}
}
int main(){
scanf("%d",&N);
FOR(i,1,N-1){
int u,v,w;scanf("%d%d%d",&u,&v,&w);
add(u,v,w);add(v,u,w);
}
scanf("%d",&K);
max[root = 0] = sum = N;getroot(1,0);
divide(root);
printf("%dn",ans);
return 0;
}