fork download
  1. import sys
  2.  
  # Generous recursion limit. No function visible in this file recurses (the
  # traversal in solve() is a loop), so this is only a safety margin.
  sys.setrecursionlimit(2000000)

  MOD = 998244353  # prime modulus, so Fermat-based inverses are valid
  MAXN = 500005    # length of the factorial tables (indices 0 .. MAXN - 1)

  # Both tables start as all ones; precompute() turns them into
  # fact[i] = i! mod MOD and invFact[i] = (i!)^-1 mod MOD.
  fact = [1] * MAXN
  invFact = [1 for _ in range(MAXN)]
  10.  
  def power(a, b, mod=None):
      """Return ``a ** b`` reduced modulo ``mod``.

      Args:
          a: base; any integer (negative values are fine).
          b: exponent; must be non-negative.
          mod: optional modulus. When omitted the module-level ``MOD`` is
              used, so existing two-argument calls behave as before.

      Returns:
          An integer in ``[0, mod)`` (``1`` when ``b == 0`` and ``mod > 1``).

      Raises:
          ValueError: if ``b`` is negative. A plain square-and-multiply loop
              would silently return 1 here, which is not ``a ** b``.
      """
      if b < 0:
          raise ValueError("power() requires a non-negative exponent")
      if mod is None:
          mod = MOD
      # Built-in three-argument pow performs the same square-and-multiply,
      # in C, and normalises a negative base into [0, mod).
      return pow(a, b, mod)
  20.  
  def modInverse(n, mod=None):
      """Return the multiplicative inverse of ``n`` modulo a prime ``mod``.

      Uses Fermat's little theorem, ``n ** (mod - 2) == n ** -1 (mod mod)``.
      This identity holds only when ``mod`` is prime. ``mod`` defaults to the
      module-level ``MOD``.

      Raises:
          ValueError: if ``n`` is a multiple of ``mod``. No inverse exists,
              and Fermat's formula would otherwise silently yield 0.
      """
      if mod is None:
          mod = MOD
      n %= mod
      if n == 0:
          raise ValueError("0 has no inverse modulo %d" % mod)
      return pow(n, mod - 2, mod)
  23.  
  def precompute():
      """Populate the module-level fact[] and invFact[] tables modulo MOD.

      After the call, fact[i] == i! and invFact[i] == (i!)^-1 (both mod MOD)
      for every 0 <= i < MAXN. Only the largest entry's inverse is taken with
      a modular exponentiation. Each smaller inverse follows from
      1/i! = (i + 1) * 1/(i + 1)!, walking downward.
      """
      acc = fact[0]
      for k in range(1, MAXN):
          acc = acc * k % MOD
          fact[k] = acc

      inv = modInverse(fact[MAXN - 1])
      invFact[MAXN - 1] = inv
      for k in range(MAXN - 1, 0, -1):
          inv = inv * k % MOD
          invFact[k - 1] = inv
  30.  
  def nCr(n, r):
      """Return the binomial coefficient C(n, r) modulo MOD.

      Values are read from the fact / invFact tables, so precompute() must
      have run and n must be below MAXN. Returns 0 when r is outside 0..n.
      """
      if not 0 <= r <= n:
          return 0
      inv_part = invFact[r] * invFact[n - r] % MOD
      return fact[n] * inv_part % MOD
  37.  
  def solve():
      """Read one test case from stdin and print its answer modulo MOD.

      Input: a line "n c", then n - 1 lines "u v" giving the edges of a tree
      on vertices 1..n. Let dep[v] be the edge distance of v from vertex 1
      and M = n - 1. For k >= 0 define

          S(k) = sum over v in 2..n of C(dep[v] - 1, k - 1)
          N(k) = S(k) * k! * (M - k)!

      The printed value is N(c) - N(c + 1) mod MOD, with N(c + 1) taken as 0
      when c + 1 > M. If c >= n, the value printed is 0.
      """
      n, c = map(int, sys.stdin.readline().split())

      graph = [[] for _ in range(n + 1)]
      for _ in range(n - 1):
          a, b = map(int, sys.stdin.readline().split())
          graph[a].append(b)
          graph[b].append(a)

      # The edge lines are consumed above even on this early exit, so the
      # next test case starts reading at the correct line.
      if c >= n:
          print(0)
          return

      # Breadth-first search from vertex 1; depth is assigned when a vertex
      # is first discovered.
      dep = [0] * (n + 1)
      seen = [False] * (n + 1)
      seen[1] = True
      frontier = [1]
      pos = 0
      while pos < len(frontier):
          cur = frontier[pos]
          pos += 1
          for nxt in graph[cur]:
              if seen[nxt]:
                  continue
              seen[nxt] = True
              dep[nxt] = dep[cur] + 1
              frontier.append(nxt)

      edges = n - 1

      def weighted_count(k):
          # N(k) from the docstring, reduced modulo MOD at every step.
          total = 0
          for v in range(2, n + 1):
              total = (total + nCr(dep[v] - 1, k - 1)) % MOD
          return total * fact[k] % MOD * fact[edges - k] % MOD

      at_c = weighted_count(c)
      at_next = weighted_count(c + 1) if c + 1 <= edges else 0
      print((at_c - at_next + MOD) % MOD)
  84.  
  # Fill the factorial tables at module load so nCr() is usable even when
  # this file is imported rather than executed.
  precompute()

  # Only consume stdin when run as a script. Importing the module must not
  # block on input or print test-case answers.
  if __name__ == "__main__":
      line = sys.stdin.readline()
      # An empty read means there is no input at all; print nothing.
      if line:
          t_cases = int(line)
          for _ in range(t_cases):
              solve()
Success #stdin #stdout 0.34s 52044KB
stdin
3
2 1
1 2
5 1
1 2
1 3
1 4
1 5
5 3
1 2
1 3
1 4
1 5
stdout
1
24
0