import math
import sys
import time
# The balance check below recurses once per tree level, and the linked-list
# input built later in this script is N nodes deep, far past the default limit.
# NOTE(review): this only raises Python's frame-count guard; the `sys` docs warn
# that a too-high limit can crash the interpreter — confirm on the target
# CPython version.
sys.setrecursionlimit(500_000)
class TreeNode:
    """Minimal binary-tree node that stores only its two child links (no value)."""

    def __init__(self, left=None, right=None):
        # Each child is expected to be another TreeNode, or None when absent.
        self.left, self.right = left, right
class Solution:
    """Top-down balanced-binary-tree check; this is the algorithm the script times."""

    def isBalanced(self, root):
        """Return True when every node's two subtrees differ in height by at most 1.

        At each visited node both subtree heights are recomputed from scratch
        with ``height`` before descending, so a subtree's height may be
        recomputed once for each of its ancestors. The left child is checked
        before the right one, and the search stops at the first imbalance.
        An empty (falsy) root counts as balanced.
        """
        if root:
            # Left height is evaluated before right height, as subtraction
            # evaluates its operands left to right.
            gap = self.height(root.left) - self.height(root.right)
            if -1 <= gap <= 1:
                return self.isBalanced(root.left) and self.isBalanced(root.right)
            return False
        return True

    def height(self, node):
        """Return the number of nodes on the longest downward path (0 for no node)."""
        if not node:
            return 0
        left_height = self.height(node.left)
        right_height = self.height(node.right)
        return max(left_height, right_height) + 1
def build_linked_list(n):
    """Return the root of an ``n``-node chain whose nodes link only via ``left``.

    Nodes are created bottom-up, so the most recently created node is the
    root. Returns ``None`` when ``n <= 0``.
    """
    chain = None
    for _step in range(n):
        # Push a new node on top of the chain built so far.
        chain = TreeNode(left=chain)
    return chain
def build_perfect_tree(n):
    """Return a size-balanced binary tree containing ``n`` nodes.

    After the root, the remaining ``n - 1`` nodes are split between the two
    subtrees, with the right side taking the extra node when that count is
    odd. Sibling subtree sizes therefore never differ by more than one, and
    the result is a perfect tree exactly when ``n == 2**h - 1``. Returns
    ``None`` for ``n <= 0``.
    """
    if n <= 0:
        return None
    below_root = n - 1
    left_size = below_root // 2
    # Build the left subtree first, then the right, then attach both.
    left_subtree = build_perfect_tree(left_size)
    right_subtree = build_perfect_tree(below_root - left_size)
    return TreeNode(left=left_subtree, right=right_subtree)
def build_worst_case_tree(k):
    """Build the adversarial input used for the "Worst-Case Tree" benchmark.

    The tree has a left "spine" of ``k`` nodes. Counting spine levels from
    the bottom (the bottom spine node is level 1), the spine node at level
    ``j`` gets as its right child a left-leaning chain of ``j - 1`` nodes
    from ``build_linked_list``. Both children of that spine node therefore
    have height ``j - 1``. The tree holds ``k * (k + 1) // 2`` nodes in
    total. Returns ``None`` when ``k <= 0``.
    """
    root = None
    # Grow the spine bottom-up: each new spine node adopts the tree built so
    # far as its left child.
    for level in range(1, k + 1):
        chain = build_linked_list(level - 1) if level > 1 else None
        root = TreeNode(left=root, right=chain)
    return root
N = 200_000

# build_worst_case_tree(k) produces k * (k + 1) / 2 nodes. The positive root
# of k**2 + k - 2N = 0 is k = (sqrt(1 + 8N) - 1) / 2; int() truncates it to the
# largest whole spine length whose tree does not exceed N nodes (up to
# floating-point rounding).
k = int((math.sqrt(1 + 8 * N) - 1) / 2)

# Benchmark inputs; insertion order fixes the order in which results print.
trees = {}
trees["Linked List"] = build_linked_list(N)
trees["Perfect Tree"] = build_perfect_tree(N)
trees["Worst-Case Tree"] = build_worst_case_tree(k)
def count_nodes(node):
    """Return the number of nodes in the binary tree rooted at ``node``.

    Children are read from each node's ``left`` and ``right`` attributes; a
    falsy value (e.g. ``None``) counts as an empty subtree.

    The traversal uses an explicit stack rather than recursion. This lets
    chain-shaped trees, such as the linked-list benchmark input, be counted
    without depending on Python's recursion limit or on C-stack depth.
    """
    total = 0
    pending = [node]
    while pending:
        current = pending.pop()
        if not current:
            continue
        total += 1
        pending.append(current.left)
        pending.append(current.right)
    return total
print("Node counts:")
for label, root in trees.items():
    print(f"{label: <16} : {count_nodes(root)} nodes")

# Time one isBalanced call per tree with time.perf_counter; the boolean
# result is not used, only the elapsed time is reported.
solver = Solution()
print("\nTime it takes:")
for label, root in trees.items():
    started = time.perf_counter()
    solver.isBalanced(root)
    elapsed = time.perf_counter() - started
    print(f"{label: <16} : {elapsed:.5f} seconds")