fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  // Shorthand macros used throughout the solution. The function-like helpers
  // must stay macros; the scalar/type shorthands are real aliases below.
  #define gc getchar_unlocked
  #define fo(i,n) for(i=0;i<n;i++)
  #define Fo(i,k,n) for(i=k;i<n;i++)
  #define si(x) scanf("%d",&x)
  #define sl(x) scanf("%I64d",&x)
  #define ss(s) scanf("%s",s)
  #define pb push_back
  #define mp make_pair
  #define F first
  #define S second
  #define all(x) x.begin(), x.end()
  #define clr(x) memset(x, 0, sizeof(x))
  #define sortall(x) sort(all(x))
  #define tr(it, a) for(auto it = a.begin(); it != a.end(); it++)
  #define PI 3.1415926535897932384626

  // Type shorthands as aliases rather than macros: they obey scope and also
  // work in functional-cast form such as ll(x).
  using ll = long long;
  using pii = pair<int, int>;
  using pll = pair<ll, ll>;
  using vi = vector<int>;
  using vl = vector<ll>;
  using vpii = vector<pii>;
  using vpll = vector<pll>;
  using vvi = vector<vi>;
  using vvl = vector<vl>;
  27. int mod ;
  28. const int N = 3e5;
  29. vpll g[N];
  30. vi cur;
  31. int part;
  32. int gvis[N], lvl[N], stree[N];
  33. ll up[N], down[N], sz[N];
  34. int a[N];
  35. int L, R;
  36. void dfs(int u, int par){
  37. //add u to current tree
  38. if (par == 0) cur.clear();
  39. cur.pb(u);
  40. lvl[u] = 1+lvl[par];
  41. sz[u] = 1;
  42. for(pii it: g[u]){
  43. int v = it.F;
  44. if (gvis[v] or v == par) continue;
  45. dfs(v, u);
  46. sz[u] += sz[v];
  47. }
  48. }
  49. int centroid(int u, int par){
  50.  
  51. for(auto it: g[u]){
  52. int v = it.F;
  53. if ( v == par or gvis[v]) continue;
  54. if (2*sz[v] > (int)cur.size()) return centroid(v, u);
  55. }
  56. return u;
  57. }
  58. vi dp[N][2];
  59. void go(int u, int par, int cen){
  60. lvl[u] = 1+lvl[par];
  61. if (par == 0) up[u] = 0;
  62. if(par == cen){
  63. stree[u] = part++;
  64. // dr[u] = a[u] <= a[cen]; //0 for lo
  65. }
  66. else {
  67. stree[u] = stree[par];
  68. }
  69. //push u in hi(u)
  70. dp[u][1].pb(0);
  71.  
  72. int i, dr = 0;
  73. for(auto it: g[u]){
  74. int v = it.F, w = it.S;
  75. if (v == par or gvis[v]) continue;
  76. go(v, u, cen);
  77. dr = a[v] >= a[u];
  78. int ex = a[v] > a[u];
  79. for(int val: dp[v][0])
  80. dp[u][dr].pb(val + ex);
  81. for(int val: dp[v][1])
  82. dp[u][dr].pb(val);
  83. }
  84. sortall(dp[u][0]);
  85. sortall(dp[u][1]);
  86. // cout<<"At " << u <<endl;
  87. // for(int x: dp[u][0]) cout<<x<<" "; cout<<endl;
  88. // for(int x: dp[u][1]) cout<<x<<" "; cout<<endl;
  89. // cout<<"done\n";
  90. }
  91. map<ll, ll> cnt, st[N];
  92. bool f(ll val, ll x){
  93. return val >= x;
  94. }
  // Counts the elements of the ascending-sorted vector `a` that are <= val,
  // i.e. returns the index one past the last element not exceeding val.
  // [lo, hi] (inclusive) optionally limits the searched window; hi == -1 means
  // "through the last element". Returns 0 for an empty vector or a negative val.
  // std::upper_bound replaces the previous hand-rolled recursive search, which
  // also computed a.size() - 1 into an int before checking for emptiness.
  int query(vector<int> &a, long long val, int lo = 0, int hi = -1){
      if (a.empty() || val < 0) return 0;
      if (hi == -1) hi = static_cast<int>(a.size()) - 1;
      const auto first = a.begin() + lo;
      const auto last = a.begin() + hi + 1;
      return static_cast<int>(upper_bound(first, last, val) - a.begin());
  }
  110. ll solvefor(int cen){
  111. int i;
  112. //solve for centroid
  113. //init current parts of subtree
  114. part = 0;
  115.  
  116.  
  117. //make centroid as root
  118. //find levels
  119. for(int u: cur) dp[u][0].clear(), dp[u][1].clear();
  120. fo(i, part) st[i].clear();
  121.  
  122. stree[0] = N-1;
  123. lvl[0] = -1;
  124. go(cen, 0, cen);
  125.  
  126. //precalculate the maps
  127.  
  128. //traverse only the current nodes in present tree
  129. set<ll> val;
  130. val.clear();
  131. // for(int u: cur) cnt[up[u]]++, st[stree[u]][up[u]]++, val.insert(up[u]);
  132. ll pre = 0;
  133. //dp so that cnt[x] gives no of nodes with
  134. //distance to root <=x
  135. //calculate ans
  136. ll ans = 0;
  137. int u = cen;
  138. // cout<<"Find answer for centorids "<<u<<endl;
  139. // for(int val: dp[u][0]) cout<<val<<" ";cout<<endl;
  140. // for(int val: dp[u][1]) cout<<val<<" ";cout<<endl;
  141. for(auto it: g[u]){
  142. int v = it.F, w = it.S;
  143. if (gvis[v]) continue;
  144. if(a[v] < a[u]){
  145. // cout<<" v "<<v<<" : "<<dp[v][0].size()<<" "<<dp[v][1].size()<<endl;
  146. for(int val: dp[v][0]){
  147. int lo = L-val, hi = R-val;
  148. if(hi<0) break;
  149. // cout<<val<<": "<<lo<<" "<<hi<<" :: ";
  150. ans += query(dp[u][0], hi-1) - query(dp[u][0], lo-2);
  151. // cout<<ans<<" ";
  152. ans += query(dp[u][1], hi) - query(dp[u][1], lo-1);
  153. // cout<<ans<<" ";
  154. ans -= query(dp[v][0], hi-1) - query(dp[v][0], lo-2);
  155. // cout<<ans<<" ";
  156. ans -= query(dp[v][1], hi-1) - query(dp[v][1], lo-2);
  157. // cout<<ans<<endl;
  158. }
  159. for(int val: dp[v][1]){
  160. int lo = L-val, hi = R-val;
  161. if(hi<0) break;
  162. ans += query(dp[u][0], hi-1) - query(dp[u][0], lo-2);
  163. ans += query(dp[u][1], hi) - query(dp[u][1], lo-1);
  164. ans -= query(dp[v][0], hi-1) - query(dp[v][0], lo-2);
  165. ans -= query(dp[v][1], hi-1) - query(dp[v][1], lo-2);
  166. }
  167. // cout<<"cen "<<cen<<" "<<v<<" "<<ans<<"<"<<endl;
  168. }
  169. else{
  170. // cout<<" v "<<v<<" : "<<dp[v][0].size()<<" "<<dp[v][1].size()<<endl;
  171. for(int val: dp[v][0]){
  172. // cout<<val<<" ";
  173. int ch = 0;
  174. if(a[v]>a[cen])val++, ch = 1;
  175. // cout<<val<<" :: ";
  176. int lo = L-val, hi = R-val;
  177. if(hi<0) break;
  178. ans += query(dp[u][0], hi) - query(dp[u][0], lo-1);
  179. // cout<<ans<<" ";
  180. ans += query(dp[u][1], hi) - query(dp[u][1], lo-1);
  181. // cout<<ans<<" ";
  182. ans -= query(dp[v][0], hi-1-ch) - query(dp[v][0], lo-2-ch);
  183. // cout<<ans<<" ";
  184. ans -= query(dp[v][1], hi-1) - query(dp[v][1], lo-2);
  185. // cout<<ans<<" \n";
  186. }
  187. for(int val: dp[v][1]){
  188. int lo = L-val, hi = R-val;
  189. if(hi<0) break;
  190. ans += query(dp[u][0], hi) - query(dp[u][0], lo-1);
  191. ans += query(dp[u][1], hi) - query(dp[u][1], lo-1);
  192. ans -= query(dp[v][0], hi-1) - query(dp[v][0], lo-2);
  193. ans -= query(dp[v][1], hi-1) - query(dp[v][1], lo-2);
  194. }
  195. }
  196. }
  197. if(u==cen){
  198. ans += query(dp[u][0], R) - query(dp[u][0], L-1);
  199. ans += query(dp[u][1], R) - query(dp[u][1], L-1);
  200. }
  201. // cout<<ans<<endl;
  202. return ans/2;
  203. }
  204.  
  205. ll solve(int u){
  206. //dfs to calculate centroid
  207. dfs(u, 0);
  208. //find centroid of current tree
  209. int cen = centroid(u, 0);
  210. // cout<<cen<<" * \n";
  211. ll ans = 0;
  212. ans += solvefor(cen);
  213. //mark cen done in global visited
  214. gvis[cen] = 1;
  215. for(auto it: g[cen])
  216. if (!gvis[it.F])
  217. ans += solve(it.F);
  218. // Fo(i, 1, n+1) gvis[i] = 1;
  219. return ans;
  220. }
  221. int main()
  222. {
  223. ios_base::sync_with_stdio(false);
  224. cin.tie(NULL);
  225. ll i,n,k,j, u, v, w, q;
  226. int t;
  227. cin>>t;
  228. while(t--){
  229. cin>>n>>L>>R;
  230. fo(i, n) cin>>a[i+1];
  231. Fo(i, 1, n+1) g[i].clear(), gvis[i] = 0;
  232. fo(i, n-1){
  233. cin>>u>>v;
  234. g[u].pb({v, 1});
  235. g[v].pb({u, 1});
  236. }
  237.  
  238. lvl[0] = -1;
  239. cout<<solve(1)<<endl;
  240. }
  241.  
  242. return 0;
  243. }
  244.  
  245. int mpow(int base, int exp) {
  246. base %= mod;
  247. int result = 1;
  248. while (exp > 0) {
  249. if (exp & 1) result = ((ll)result * base) % mod;
  250. base = ((ll)base * base) % mod;
  251. exp >>= 1;
  252. }
  253. return result;
  254. }
Success #stdin #stdout 0s 32776KB
stdin
3
3
1 2
1 3 1
1 2
2 3
7
2 2
1 8 2 3 4 8 1
1 2
2 3
3 4
3 5
3 6
6 7
9
1 7
5 1 1 1 1 1 1 1 1
1 2
2 3
1 4
4 5
1 6
6 7
1 8
8 9
stdout
1
1
24