fork(1) download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3. using ll = long long;
  4. const int LOG = 20;
  5.  
  6. struct BIT {
  7. int n; vector<int> b;
  8. void init(int _n){ n=_n; b.assign(n+1,0); }
  9. void add(int i,int v){ for(; i<=n; i+=i&-i) b[i]+=v; }
  10. int sum(int i){ int r=0; for(; i>0; i-=i&-i) r+=b[i]; return r; }
  11. };
  12.  
  13. int main(){
  14. ios::sync_with_stdio(false);
  15. cin.tie(nullptr);
  16. int n,q;
  17. if(!(cin>>n>>q)) return 0;
  18. vector<int> val(n+1);
  19. for(int i=1;i<=n;i++) cin>>val[i];
  20. vector<vector<int>> g(n+1);
  21. for(int i=1;i<n;i++){
  22. int u,v; cin>>u>>v;
  23. g[u].push_back(v); g[v].push_back(u);
  24. }
  25. vector<array<int,LOG>> up(n+1);
  26. vector<int> depth(n+1,0);
  27. vector<ll> S(n+1,0);
  28. function<void(int,int)> dfs0 = [&](int u,int p){
  29. up[u][0]=p;
  30. for(int j=1;j<LOG;j++) up[u][j]=up[ up[u][j-1] ][j-1];
  31. for(int v:g[u]) if(v!=p){
  32. depth[v]=depth[u]+1;
  33. S[v]=S[u]+val[v];
  34. dfs0(v,u);
  35. }
  36. };
  37. S[1]=val[1];
  38. dfs0(1,0);
  39. auto lca = [&](int a,int b){
  40. if(depth[a]<depth[b]) swap(a,b);
  41. int k=depth[a]-depth[b];
  42. for(int j=LOG-1;j>=0;j--) if(k>>j&1) a=up[a][j];
  43. if(a==b) return a;
  44. for(int j=LOG-1;j>=0;j--) if(up[a][j]!=up[b][j]){
  45. a=up[a][j]; b=up[b][j];
  46. }
  47. return up[a][0];
  48. };
  49. vector<vector<tuple<int,int,int>>> gu(n+1);
  50. vector<ll> needv(q+1);
  51. vector<int> U(q+1),V(q+1),W(q+1);
  52. vector<ll> vals;
  53. vals.reserve(n+q+5);
  54. for(int i=1;i<=n;i++) vals.push_back(S[i]);
  55. for(int i=1;i<=q;i++){
  56. int u,v; cin>>u>>v;
  57. U[i]=u; V[i]=v;
  58. int w = lca(u,v);
  59. W[i]=w;
  60. needv[i]=S[u]-2*S[w];
  61. vals.push_back(needv[i]);
  62. }
  63. sort(vals.begin(), vals.end());
  64. vals.erase(unique(vals.begin(), vals.end()), vals.end());
  65. int M = (int)vals.size();
  66. vector<int> Spos(n+1);
  67. for(int i=1;i<=n;i++) Spos[i]= (int)(lower_bound(vals.begin(), vals.end(), S[i]) - vals.begin()) + 1;
  68. for(int i=1;i<=q;i++){
  69. int idxNeed = (int)(upper_bound(vals.begin(), vals.end(), needv[i]) - vals.begin());
  70. gu[U[i]].push_back({i, idxNeed, +1});
  71. gu[V[i]].push_back({i, idxNeed, +1});
  72. gu[W[i]].push_back({i, idxNeed, -2});
  73. }
  74. vector<ll> ans(q+1,0);
  75. for(int i=1;i<=q;i++){
  76. if(S[ W[i] ] > needv[i]) ans[i]++;
  77. }
  78. BIT bit; bit.init(M);
  79. function<void(int,int)> dfs = [&](int u,int p){
  80. bit.add(Spos[u],1);
  81. for(auto &t: gu[u]){
  82. int id, idxNeed, sign;
  83. tie(id, idxNeed, sign) = t;
  84. int greater = bit.sum(M) - bit.sum(idxNeed);
  85. ans[id] += 1LL*sign*greater;
  86. }
  87. for(int v: g[u]) if(v!=p) dfs(v,u);
  88. bit.add(Spos[u],-1);
  89. };
  90. dfs(1,0);
  91. for(int i=1;i<=q;i++) cout<<ans[i]<<"\n";
  92. return 0;
  93. }
  94.  
Success #stdin #stdout 0.01s 5280KB
stdin
8 5
-1 1 1 1 -1 1 1  -1
1 5
5 6
3 6
4 5
4 7
4 8
1 2
3 8
2 2
1 7
2 7
6 4
stdout
0
0
0
0
0