fork(3) download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  // Capacity limits. MAX_N bounds node ids (1-based) and compressed weights,
  // MAX_M bounds the number of queries, and LG_MAX_N is the number of
  // binary-lifting levels (2^19 > MAX_N, so 20 levels cover any depth).
  constexpr int MAX_N = 150005;
  constexpr int MAX_M = 100005;
  constexpr int LG_MAX_N = 20;

  // Problem size and Mo's-algorithm running state.
  int N;              // number of nodes
  int M;              // number of queries
  int time_counter;   // last Euler-tour timestamp handed out by dfs
  int block_size;     // Mo bucket width over Euler-tour positions
  int distinct;       // number of distinct weights currently in the window

  // Per-node data, indexed by node id (wt holds the compressed weights).
  int wt[MAX_N];
  int entry_time[MAX_N];
  int exit_time[MAX_N];
  int depth[MAX_N];
  int freq[MAX_N];       // freq[w] = how many in-window nodes carry weight w
  bool inrange[MAX_N];   // node toggled an odd number of times, i.e. counted in the window

  // Euler tour: time_log[t] is the node stamped at time t (each node appears twice).
  int time_log[MAX_N << 1];

  // Binary-lifting table: anc[u][i] is the 2^i-th ancestor of u.
  int anc[MAX_N][LG_MAX_N];

  // Adjacency lists of the undirected tree.
  std::vector<int> g[MAX_N];

  // A path query rewritten as a range [l, r] over Euler-tour positions.
  struct query {
      int id;    // position in the input, used to place the answer
      int l;
      int r;
      int lca;   // lowest common ancestor of the two endpoints
  };
  query q[MAX_M];
  int ans[MAX_M];   // ans[id] = answer to the id-th input query
  19.  
  20. void dfs(int u, int par){
  21. time_counter++;
  22. entry_time[u] = time_counter;
  23. time_log[time_counter] = u;
  24.  
  25. for(int i=1; i<LG_MAX_N; i++)
  26. anc[u][i] = anc[anc[u][i-1]][i-1];
  27.  
  28. for(int i=0; i<g[u].size(); i++){
  29. int v = g[u][i];
  30. if(v==par) continue;
  31. depth[v] = depth[u] + 1;
  32. anc[v][0] = u;
  33. dfs(v, u);
  34. }
  35.  
  36. time_counter++;
  37. exit_time[u] = time_counter;
  38. time_log[time_counter] = u;
  39. }
  40.  
  41. int lca(int u, int v){
  42. if(depth[u]>depth[v])
  43. swap(u, v);
  44.  
  45. for(int i=LG_MAX_N-1; i>=0; i--)
  46. if(depth[v]-depth[u]>=(1<<i))
  47. v = anc[v][i];
  48. if(u==v) return u;
  49.  
  50. for(int i=LG_MAX_N-1; i>=0; i--){
  51. if(anc[u][i]!=anc[v][i]){
  52. u = anc[u][i];
  53. v = anc[v][i];
  54. }
  55. }
  56. return anc[u][0];
  57. }
  58.  
  59. bool qcomp(query &q1, query &q2){
  60. int b1 = q1.l/block_size;
  61. int b2 = q2.l/block_size;
  62. return b1==b2 ? q1.r<q2.r : b1<b2;
  63. }
  64.  
  65. void add_rem(int i){
  66. int u = time_log[i];
  67. if(inrange[u]){
  68. freq[wt[u]]--;
  69. if(freq[wt[u]]==0)
  70. distinct--;
  71. }
  72. else{
  73. if(freq[wt[u]]==0)
  74. distinct++;
  75. freq[wt[u]]++;
  76. }
  77. inrange[u] ^= 1;
  78. }
  79.  
  80. void mo(){
  81.  
  82. block_size = (int)(sqrt(N<<1)+0.5);
  83. sort(q, q+M, qcomp);
  84.  
  85. int curr_l = 0, curr_r = -1;
  86. distinct = 0;
  87. for(int i=0; i<M; i++){
  88.  
  89. while(curr_l<q[i].l){
  90. add_rem(curr_l);
  91. curr_l++;
  92. }
  93. while(curr_l>q[i].l){
  94. curr_l--;
  95. add_rem(curr_l);
  96. }
  97.  
  98. while(curr_r<q[i].r){
  99. curr_r++;
  100. add_rem(curr_r);
  101. }
  102. while(curr_r>q[i].r){
  103. add_rem(curr_r);
  104. curr_r--;
  105. }
  106.  
  107. if(time_log[q[i].l]!=q[i].lca)
  108. add_rem(entry_time[q[i].lca]);
  109. ans[q[i].id] = distinct;
  110. if(time_log[q[i].l]!=q[i].lca)
  111. add_rem(entry_time[q[i].lca]);
  112. }
  113. }
  114.  
  115.  
  116. int main(){
  117.  
  118. scanf("%d%d", &N, &M);
  119. for(int i=1; i<=N; i++)
  120. scanf("%d", &wt[i]);
  121.  
  122. map<int, int> cc;
  123. int ccval = 1;
  124. for(int i=1; i<=N; i++){
  125. if(cc.count(wt[i])==0)
  126. cc[wt[i]] = ccval++;
  127. wt[i] = cc[wt[i]];
  128. }
  129.  
  130. for(int i=1; i<=N-1; i++){
  131. int u, v;
  132. scanf("%d%d", &u, &v);
  133. g[u].push_back(v);
  134. g[v].push_back(u);
  135. }
  136.  
  137. anc[1][0] = 1;
  138. time_counter = 0;
  139. dfs(1, -1);
  140.  
  141. for(int i=0; i<M; i++){
  142. int l, r;
  143. scanf("%d%d", &l, &r);
  144. if(entry_time[l]>entry_time[r])
  145. swap(l, r);
  146. q[i].id = i;
  147. q[i].lca = lca(l, r);
  148. if(q[i].lca==l){
  149. q[i].l = entry_time[l];
  150. q[i].r = entry_time[r];
  151. }
  152. else{
  153. q[i].l = exit_time[l];
  154. q[i].r = entry_time[r];
  155. }
  156. }
  157.  
  158. mo();
  159.  
  160. for(int i=0; i<M; i++)
  161. printf("%d\n", ans[i]);
  162.  
  163. return 0;
  164. }
Success #stdin #stdout 0s 37504KB
stdin
8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8
stdout
4
4