fork(1) download
  1. #include <bits/stdc++.h>
  2. #define vi vector<int>
  3. #define pb push_back
  4. #define fo(i, n) for(i=0; i<n; i++)
  5. using namespace std;
  // Array capacities: N sizes the per-vertex arrays, M sizes the query array.
  const int N = 40011;
  const int M = 100011;

  // Bucket width used by the comparator f(); assigned in main() before sorting.
  int BLOCK;

  // One path query, rewritten as a window over the Euler-tour array `id`.
  struct node {
      int l;    // first Euler position of the window (inclusive)
      int r;    // one past the last Euler position of the window
      int i;    // position of the query in the input, used to restore output order
      int sp;   // vertex that must be counted separately, or -1 when there is none
      int ans;  // answer stored for this query
  } Q[M];

  int lvl[N];               // depth of each vertex (main() sets lvl[1] = 0)
  int p[N][16];             // p[v][k]: 2^k-th ancestor of v, -1 when it does not exist
  int st[N];                // timestamp written when dfs() enters a vertex
  int en[N];                // timestamp written when dfs() leaves a vertex
  int id[2 * N];            // vertex recorded at each dfs() timestamp
  int occ[N];               // per-vertex occurrence counter maintained by add()/del()
  int ans;                  // running answer maintained by add()/del()
  int w[N];                 // vertex weights, replaced by compressed ids in main()
  std::map<int, int> HASH;  // original weight -> compressed id
  int cnt[N];               // per-compressed-weight counter maintained by add()/del()
  std::vector<int> g[N];    // adjacency lists of the tree
  int ti;                   // dfs() timestamp counter
  16. bool f(node a, node b){
  17. if (a.l/BLOCK != b.l/BLOCK)
  18. return a.l < b.l;
  19. return a.r < b.r;
  20. }
  21. bool gg(node a, node b){
  22. return a.i < b.i;
  23. }
  24. void dfs(int u, int p){
  25.  
  26. st[u] = ++ti;
  27. id[ti] = u;
  28. int v;
  29. for(auto v: g[u]){
  30. if (v == p)
  31. continue;
  32. lvl[v] = lvl[u]+1;
  33. ::p[v][0] = u;
  34. dfs(v, u);
  35. }
  36. en[u] = ++ti;
  37. id[ti] = u;
  38. }
  39. int lca(int u, int v){
  40. int lg, i;
  41. for (lg = 0; (1<<lg) <= lvl[u]; lg++);
  42. lg--;
  43. for(i=lg; i>=0; i--)
  44. if ( lvl[u] - (1<<i) >= lvl[v])
  45. u = p[u][i];
  46. if (u == v)
  47. return u;
  48. for(i = lg; i >= 0; i--){
  49. if (p[u][i] != -1 && p[u][i] != p[v][i])
  50. u = p[u][i], v = p[v][i];
  51. }
  52. return p[u][0];
  53. }
  54. void add(int node){
  55. occ[node]++;
  56. cnt[w[node]]++;
  57. if (occ[node] == 2){
  58. cnt[w[node]] -= 2;
  59. if (cnt[w[node]] == 0)
  60. ans--;
  61. }
  62. else if (cnt[w[node]] == 1) ans++;
  63. }
  64. void del(int node){
  65. int wt = w[node];
  66. occ[node]--;
  67.  
  68. if (occ[node] == 1){
  69. cnt[wt]++;
  70. if (cnt[wt] == 1)
  71. ans++;
  72. return;
  73. }
  74. cnt[wt]--;
  75. if (cnt[wt] == 0) ans--;
  76. }
  77. int main() {
  78. ios_base::sync_with_stdio(false);
  79. int n, m, i, j, u, v;
  80. ans = ti = 0;
  81. cin>>n>>m;
  82. BLOCK = sqrt(n);
  83. int no = 0;
  84. HASH.clear();
  85. fo(i, n){
  86. cin>>w[i+1];
  87. if (HASH.find(w[i+1]) == HASH.end())
  88. HASH[w[i+1]] = ++no;
  89. w[i+1] = HASH[w[i+1]];
  90.  
  91. }
  92. fo(i, n-1){
  93. cin>>u>>v;
  94. g[u].pb(v);
  95. g[v].pb(u);
  96. }
  97. lvl[1] = 0;
  98. memset(cnt, 0, sizeof(cnt));
  99. memset(occ, 0, sizeof(occ));
  100. memset(p, -1, sizeof(p));
  101. dfs(1, 0);
  102. for(i=1; i<16; i++)
  103. for(j=1; j<=n; j++)
  104. if( p[j][i-1] != -1)
  105. p[j][i] = p[p[j][i-1]][i-1];
  106. fo(i, m){
  107. Q[i].i = i;
  108. Q[i].sp = -1;
  109. cin>>u>>v;
  110. if (lvl[u] < lvl[v])
  111. swap(u, v);
  112. int w = lca(u, v);
  113. if (w == v){
  114. Q[i].l = st[v];
  115. Q[i].r = st[u]+1;
  116. }else{
  117. if (st[v] > en[u]){
  118. Q[i].l = en[u];
  119. Q[i].r = st[v]+1;
  120. }
  121. else{
  122. Q[i].l = en[v];
  123. Q[i].r = st[u]+1;
  124. }
  125. // Special case: We have to consider 'w' separately.
  126. Q[i].sp = w;
  127. }
  128. }
  129. sort(Q, Q+m, f);
  130. int currL = 0, currR = 0, L, R;
  131. fo(i, m){
  132. L = Q[i].l, R = Q[i].r;
  133. while (currL < L){
  134.  
  135. del(id[currL]);
  136. currL++;
  137. }
  138. while (currL > L){
  139. add(id[currL-1]);
  140. currL--;
  141. }
  142. while (currR < R){
  143. add(id[currR]);
  144. currR++;
  145. }
  146. while (currR > R){
  147. del(id[currR-1]);
  148. currR--;
  149. }
  150. Q[i].ans = ans;
  151. if (Q[i].sp != -1){
  152. if (cnt[w[Q[i].sp]] == 0)
  153. Q[i].ans = ans+1;
  154. }
  155. }
  156. sort(Q, Q+m, gg);
  157. fo(i, m)
  158. cout<<Q[i].ans<<endl;
  159.  
  160. return 0;
  161. }
Success #stdin #stdout 0s 9592KB
stdin
8 2
105 2 9 3 8 5 7 7
1 2        
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8
stdout
4
4