【题目描述】
给你一棵数,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于$K$

【输入格式】
第一行 一个整数$n$ 表示树上有多少个点
接下来$n-1$行 每行三个整数 表示一条无向边的两端和权值
最后一行一个整数$k$

【输出格式】
一个整数 表示有多少对点之间的距离小于等于$K$

题解

点分治模板题

什么情况下可以用点分治?
当询问形如这种形式:树上所有路径中最…的/所有路径有多少条满足…
就可以使用点分治将暴力枚举两点的那个$O(n^2)$优化到$O(n\log n)$

假如此题只需要你求**经过点$1$**的距离小于等于$K$的路径有多少条
我们可以这么做:把所有点(包括$1$)到$1$的距离全部存进一个数组$q$里,然后计算$q$中有多少对数字两两加起来小于等于$k$这个是可以通过双指针$O(n)$解决的
这么做显然是正确的
当然是有问题的了!看下面这个例子:

$q$数组里有$0,2,5$,那么$0+2,0+5,2+5$都小于等于$k$,求得的答案有三个,但是实际上合法的经过$1$的路径只有$1$到$2$,$1$到$3$两条
我们注意到$2+5$是不合法的 你不能这么走2->1->3 所以只有(不在$1$的同一个儿子的子树中)的两个点才能相加配对
怎么办?也很简单 按上面的方法处理完后 再分别处理$1$的每一个儿子 把儿子的子树里的所有点进行配对 这些配对都是不合法的 将答案减去这些配对的贡献即可 这个显然也是$O(n)$的 就可以求出正确的答案了 至此我们已经把此题的答案统计方法讲完了 下面讲淀粉质

点分治的思想是这样的:

看这个分叉的菊花图 显然它的重心是$1$
我们先用上面的方法 计算出经过点$1$的路径的贡献
然后递归进$1$的每个儿子的子树 对于每个子树找到它的重心$x$然后如法炮制计算该子树内经过$x$的路径的贡献 注意是该子树内 也就是说 点$1$已经和我们无关了
然后再继续往下递归
这里有一个示意图

先统计$1$的贡献 这是第一层 然后找到子树${2,3,4}$的重心 这里也就是$2$然后递归进入$2$继续统计 这就是第二层
由于树的重心的性质:它每个儿子的子树大小都不会超过$\frac{n}{2}$所以这里最多有$\log n$层 再加上上面统计答案的方法是$O(n)$的 此题复杂度$O(n\log n)$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;

inline int read() {
int x = 0, f = 1; char ch = getchar();
for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1;
for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ '0');
return x * f;
}

const int inf = 0x7fffffff;
int n, k, rt, nowsz, mx;
int head[40005], pre[80005], to[80005], val[80005], sz;
int tot[40005], dis[40005];
ll ans;
bool vis[40005];

void init() {
memset(head, 0, sizeof(head));
memset(vis, 0, sizeof(vis));
sz = 0; ans = 0;
}

inline void addedge(int u, int v, int w) {
pre[++sz] = head[u]; head[u] = sz; to[sz] = v; val[sz] = w;
pre[++sz] = head[v]; head[v] = sz; to[sz] = u; val[sz] = w;
}

void getrt(int x, int fa) { //找到当前子树的重心
tot[x] = 1;
int nowmx = -inf;
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (y == fa || vis[y]) continue;
getrt(y, x);
tot[x] += tot[y];
nowmx = max(nowmx, tot[y]);
}
nowmx = max(nowmx, nowsz - tot[x]);
if (nowmx < mx) {
mx = nowmx, rt = x;
}
}

int l, r, q[40005];

void getdis(int x, int fa) { //统计答案部分:找到当前子树中每个点到重心的距离
q[++r] = dis[x];
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (y == fa || vis[y]) continue;
dis[y] = dis[x] + val[i];
getdis(y, x);
}
}

ll calc(int x, int d) { //统计答案的函数
l = 1, r = 0;
dis[x] = d;
getdis(x, 0); //计算当前子树中每个点到重心的距离
sort(q + 1, q + r + 1); //双指针算出多少对数字满足q[i]+q[j]<=k
ll ret = 0;
while (l < r) {
if (q[l] + q[r] <= k) {
ret += r - l, l++;
} else r--;
}
return ret;
}

void divide(int x) { //点分治函数
ans += calc(x, 0); //这是在统计x的答案
vis[x] = 1;
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (vis[y]) continue;
ans -= calc(y, val[i]); //这是在减去x的答案中不合法的部分
mx = inf;
nowsz = tot[y];
getrt(y, 0); //这是在找y的子树的重心
divide(rt); //递归进入y
}
}

int main() {
n = read();
init();
for (int i = 1; i < n; i++) {
int u = read(), v = read(), w = read();
addedge(u, v, w);
}
k = read();
mx = inf;
nowsz = n;
getrt(1, 0);
divide(rt);
printf("%lld\n", ans);
return 0;
}

评论