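# Count the undirected paths of length k (unordered node pairs at distance k)
# in a heap-shaped tree with nodes 1..n, where node i has children 2i and 2i+1.
# n: int (<= 1e18), k: non-negative int
# Returns: the number of paths (int)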
def count_pairs(n: int, k: int) -> int:
    """Count the undirected paths of length ``k`` in a heap-shaped tree.

    The tree has nodes ``1..n``; node ``i`` has children ``2i`` and ``2i+1``
    whenever those are ``<= n``. A path of length ``k`` is an unordered pair
    of nodes at distance exactly ``k`` edges. For ``k == 0`` every node is
    its own path, so the result is ``n``.

    Each path is counted once at its top-most node ``a``. Either it runs
    straight down ``k`` levels from ``a``, or it bends at ``a``. A bent path
    goes ``l`` edges into the left child ``2a`` and ``k - l`` edges into the
    right child ``2a + 1``, for ``1 <= l <= k - 1``.

    Args:
        n: Number of nodes (``n >= 0``; designed for ``n`` up to ~1e18).
        k: Path length in edges (``k >= 0``).

    Returns:
        The number of paths of length ``k``.

    Raises:
        ValueError: If ``n`` or ``k`` is negative.
    """
    if n < 0 or k < 0:
        raise ValueError("n and k must be non-negative")
    if k == 0:
        return n  # every node is a path of length 0
    # The deepest node is at depth n.bit_length() - 1, so no two nodes are
    # farther apart than twice that. Exiting early also keeps the loop below
    # bounded (~120 iterations for n <= 1e18) even when k is astronomically large.
    if n == 0 or k > 2 * (n.bit_length() - 1):
        return 0

    def depth_count(c: int, h: int) -> int:
        """Return how many nodes in [1, n] lie exactly h levels below node c.

        Those nodes are the range [c * 2**h, c * 2**h + 2**h - 1] clipped to n.
        """
        start = c << h
        if start > n:
            return 0
        return min(n, start + (1 << h) - 1) - start + 1

    # Straight-down paths: every node x >= 2**k has exactly one ancestor k
    # levels up (x >> k >= 1), so these paths are in bijection with x in [2**k, n].
    straight = max(0, n - (1 << k) + 1)

    # Bent paths. Summed over the even left children c = 2a (c >= 2, so the root
    # is never treated as a child): depth_count(c, h1) * depth_count(c + 1, h2).
    bent = 0
    for left_len in range(1, k):
        h1 = left_len - 1  # levels below the left child 2a
        h2 = k - left_len - 1  # levels below the right child 2a + 1
        m1 = n >> h1  # largest c with depth_count(c, h1) > 0
        m2 = n >> h2  # largest c with depth_count(c, h2) > 0
        last = min(m1, m2 - 1)  # need c <= m1 and c + 1 <= m2
        if last < 2:
            continue
        full = 1 << (h1 + h2)  # product when both subtrees are complete
        split_total = (last // 2) * full  # even c in [2, last], all assumed full
        # depth_count(c, h1) is full for c < m1 and depth_count(c+1, h2) is full
        # for c + 1 < m2, so only c == m1 and c == m2 - 1 need correcting.
        for c in {m1, m2 - 1}:
            if c % 2 == 0 and 2 <= c <= last:
                split_total += depth_count(c, h1) * depth_count(c + 1, h2) - full
        bent += split_total

    return straight + bent