目标
给你一棵二叉树的根节点 root ,二叉树中节点的值 互不相同 。另给你一个整数 start 。在第 0 分钟,感染 将会从值为 start 的节点开始爆发。
每分钟,如果节点满足以下全部条件,就会被感染:
- 节点此前还没有感染。
- 节点与一个已感染节点相邻。
返回感染整棵树需要的分钟数。
示例 1:
输入:root = [1,5,3,null,4,10,6,9,2], start = 3
输出:4
解释:节点按以下过程被感染:
- 第 0 分钟:节点 3
- 第 1 分钟:节点 1、10、6
- 第 2 分钟:节点5
- 第 3 分钟:节点 4
- 第 4 分钟:节点 9 和 2
感染整棵树需要 4 分钟,所以返回 4 。
示例 2:
输入:root = [1], start = 1
输出:0
解释:第 0 分钟,树中唯一一个节点处于感染状态,返回 0 。
说明:
- 树中节点的数目在范围 [1, 10^5] 内
- 1 <= Node.val <= 10^5
- 每个节点的值 互不相同
- 树中必定存在值为 start 的节点
思路
从树中任意节点开始,每过一分钟感染会向周围扩散,问感染整棵树需要多久。
首先我们要找到感染开始的节点。从这个节点出发,向左右子树点以及父节点扩散。可以将树转换为以感染节点为起点的有向无环连通图,这样问题被转换为求起点到图中任意节点的最长路径。
如果不想建图可以考虑扩散的具体路径,刚开始很难把各种情况都考虑到。我们需要计算以开始节点为根的子树高度 h(start)
,并依次比较开始节点到祖先节点路径长度加上祖先另一子树高度的最大值,即max(d(start) - d(ancestor) + h(anotherAncestorSubtree))
,再取二者的最大值即可。特别需要注意的是,不能使用子树高度之差来计算祖先与开始节点的路径长度。例如,E是开始节点,E到B的路径长度为d(E) - d(B) = 2 - 1 = 1
,而如果使用子树高度相减的话就得到了h(B) - h(E) = 3 - 0 = 3
。
A
/ \
B C
/ \
D E
|
F
|
G
在具体实现的时候如何判断祖先节点的哪个子树包含开始节点困扰了我半天。刚开始我选择了一个标志位,分别在左右子树递归结束的时候检测该标志,发现找到之后立即重置该标志,这样父节点就知道了是左子树还是右子树包含开始节点。但问题是再向上返回的时候就无法判断了。
可以考虑返回二维数组,也有网友的题解使用返回值的符号来标识是否找到开始节点。
代码
/**
* @date 2024-04-24 8:56
*/
public class AmountOfTime2385 {
int startToParentToLeaf = 0;
int startToLeaf = 0;
int cnt = 0;
public int amountOfTime(TreeNode root, int start) {
dfs(root, start);
return Math.max(startToLeaf, startToParentToLeaf);
}
/**
* 返回子树深度
*/
public int[] dfs(TreeNode root, int start) {
if (root == null) {
return new int[]{0, 0};
}
int[] l = dfs(root.left, start);
int[] r = dfs(root.right, start);
boolean lfind = l[1] == 1;
boolean rfind = r[1] == 1;
int max = Math.max(r[0], l[0]);
if (lfind || rfind) {
startToParentToLeaf = Math.max(startToParentToLeaf, l[0] + r[0]);
// 这里的返回值不是max,而是祖先节点到开始节点的路径长度
return new int[]{(lfind ? l[0] : r[0]) + 1, 1};
}
if (root.val == start) {
startToLeaf = max;
// 这里直接返回1,不加max
// 视为将开始节点的左右子树删掉,后面回溯时直接相加左右子树高度即可
return new int[]{1, 1};
}
return new int[]{max + 1, 0};
}
/**
* 返回深度
*/
public int[] dfs_v1(TreeNode root, int start, int depth) {
if (root == null) {
return new int[]{depth - 1, 0};
}
int[] l = dfs_v1(root.left, start, depth + 1);
int[] r = dfs_v1(root.right, start, depth + 1);
boolean lfind = l[1] == 1;
boolean rfind = r[1] == 1;
int max = Math.max(r[0], l[0]);
if (lfind) {
cnt++;
startToParentToLeaf = Math.max(r[0] - depth + cnt, startToParentToLeaf);
}
if (rfind) {
cnt++;
startToParentToLeaf = Math.max(l[0] - depth + cnt, startToParentToLeaf);
}
if (root.val == start) {
startToLeaf = max - depth;
return new int[]{max, 1};
}
return new int[]{max, l[1] + r[1]};
}
}