fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  // 64-bit signed integer used for the final count.
  using i64 = std::int64_t;

  // Capacity of the per-vertex and per-edge arrays (exactly 2.1e5 each).
  constexpr int MAXN = 210000;
  constexpr int MAXM = 210000;

  // Number of vertices and number of input pairs, read in main().
  int N;
  int M;

  // Endpoints of the i-th input pair.
  int A[MAXM];
  int B[MAXM];

  // Neighbour lists; dfs_par() removes each vertex's parent from its list, so
  // afterwards they hold children only.
  std::vector<int> adj[MAXN];

  // par[v][i] is the 2^i-th ancestor of v, or 0 when there is none.
  int par[MAXN][21];
  // Distance from the root; the root has depth 0.
  int depth[MAXN];
  15.  
  16. void dfs_par(int cur, int prv) {
  17. if (prv) {
  18. adj[cur].erase(find(adj[cur].begin(), adj[cur].end(), prv));
  19. }
  20. par[cur][0] = prv;
  21. for (int i = 0; par[cur][i]; i++) {
  22. par[cur][i+1] = par[par[cur][i]][i];
  23. }
  24.  
  25. depth[cur] = prv ? depth[prv] + 1 : 0;
  26. for (int nxt : adj[cur]) {
  27. dfs_par(nxt, cur);
  28. }
  29. }
  30.  
  31. int getAnc(int cur, int k) {
  32. assert(k >= 0);
  33. assert(depth[cur] >= k);
  34. while (k > 0) {
  35. int i = __builtin_ctz(k);
  36. cur = par[cur][i];
  37. k -= 1 << i;
  38. }
  39. return cur;
  40. }
  41.  
  42. int lca(int a, int b) {
  43. if (depth[a] > depth[b]) swap(a, b);
  44. b = getAnc(b, depth[b] - depth[a]);
  45. assert(depth[a] == depth[b]);
  46. if (a == b) return a;
  47. int i = 0;
  48. while (par[a][i] != par[b][i]) {
  49. i++;
  50. }
  51. while (i--) {
  52. if (par[a][i] != par[b][i]) {
  53. a = par[a][i], b = par[b][i];
  54. }
  55. }
  56. return par[a][0];
  57. }
  58.  
  59. int numPass[MAXN];
  60.  
  61. void dfs_sub(int cur) {
  62. for (int nxt : adj[cur]) {
  63. dfs_sub(nxt);
  64. numPass[cur] += numPass[nxt];
  65. }
  66. }
  67.  
  68. int bad[MAXN];
  69. map<pair<int, int>, int> badPairs;
  70.  
  /** Returns n * (n - 1) / 2, the number of unordered pairs among n items. */
  i64 c2(i64 n) {
    const i64 product = n * (n - 1);
    return product / 2;
  }
  74.  
  75. int main() {
  76. ios::sync_with_stdio(0), cin.tie(0);
  77.  
  78. cin >> N >> M;
  79. for (int i = 0; i < M; i++) {
  80. cin >> A[i] >> B[i];
  81. }
  82.  
  83. for (int i = 0; i < N-1; i++) {
  84. adj[A[i]].push_back(B[i]);
  85. adj[B[i]].push_back(A[i]);
  86. }
  87. dfs_par(1, 0);
  88.  
  89. for (int i = N-1; i < M; i++) {
  90. int a = A[i], b = B[i];
  91. int c = lca(a, b);
  92. numPass[a] ++;
  93. numPass[b] ++;
  94. numPass[c] -= 2;
  95. }
  96. dfs_sub(1);
  97.  
  98. i64 ans = 0;
  99. for (int i = N-1; i < M; i++) {
  100. int a = A[i], b = B[i];
  101. int c = lca(a, b);
  102. if (depth[a] > depth[b]) swap(a, b);
  103. assert(a != b);
  104.  
  105. int pb = getAnc(b, depth[b] - depth[c] - 1);
  106. bad[pb] ++;
  107. ans += numPass[pb];
  108. ans --;
  109. if (a != c) {
  110. int pa = getAnc(a, depth[a] - depth[c] - 1);
  111. bad[pa] ++;
  112. badPairs[minmax(pa, pb)] ++;
  113. ans += numPass[pa];
  114. ans --;
  115. }
  116. }
  117. for (int i = 1; i <= N; i++) {
  118. ans -= c2(bad[i]);
  119. }
  120. for (auto it : badPairs) {
  121. ans -= c2(it.second);
  122. }
  123. cout << ans << '\n';
  124.  
  125. return 0;
  126. }
  127.  
Success #stdin #stdout 0s 8360KB
stdin
5 8
1 2
1 3
1 4
1 5
2 3
3 4
4 5
5 2
stdout
4