fork download
  1. // I can't tell you what it really is,
  2. // I can only tell you what it feels like.
  3. #include "bits/stdc++.h"
  4. using namespace std;
  5.  
  // Competitive-programming shorthands.
  // NOTE: `int` is redefined as `long long` for every line after this point,
  // which is why the entry point must be declared `signed main()`.
  #define int long long
  // Inclusive loop: i runs from a up to n (n is re-evaluated each iteration).
  #define rep(i, a, n) for (int i = (a); i <= (n); ++i)
  #define pb push_back
  #define P pair<int, int>
  // WARNING: single-letter macros -- any later identifier spelled `s` or `f`
  // is silently rewritten to `second` / `first`.
  #define s second
  #define f first
  // Arguments and results are parenthesized so that expressions such as
  // all(*p) or 2 * lb(v, x) expand as intended.
  #define all(v) (v).begin(), (v).end()
  #define lb(v, x) (lower_bound(all(v), (x)) - (v).begin())
  #define up(v, x) (upper_bound(all(v), (x)) - (v).begin())
  15.  
  // Modulus constant (not referenced by the visible code).
  constexpr int mod = 1000000007;
  // Array capacity; presumably the maximum node count plus slack.
  constexpr int N = 200005;

  // Adjacency lists of the tree, indexed by node id.
  vector<int> v[N];
  // val[x]  : value attached to node x.
  // idx[x]  : position of node x in the DFS preorder.
  // ar[pos] : node found at preorder position pos (inverse of idx).
  // ok      : next free preorder position, advanced by dfs(); starts at 1.
  int val[N];
  int idx[N];
  int ar[N];
  int ok = 1;
  // sz[x]   : number of nodes in the subtree rooted at x.
  int sz[N];
  22.  
  23. void dfs(int x, int p) {
  24. ar[ok] = x;
  25. idx[x] = ok++;
  26. sz[x] = 1;
  27. for (int i : v[x]) {
  28. if (i != p) {
  29. dfs(i, x);
  30. sz[x] += sz[i];
  31. }
  32. }
  33. }
  34.  
  35.  
  36. int a[N];
  37. vector < int > tree[N << 2], sum[N<<2];
  38.  
  39. class merge_sort_tree{
  40. public:
  41. void build(int l, int r, int node){
  42. if(l == r){
  43. tree[node].pb(a[l]);
  44. return ;
  45. }
  46. int mid = l + r >> 1, lc = node + node, rc = 1 + lc;
  47. build(l, mid, lc); build(mid + 1, r, rc);
  48.  
  49. merge(all(tree[lc]), all(tree[rc]), back_inserter(tree[node]));
  50. }
  51. void done() {
  52. int kit = N<<2;
  53. for (int i = 0; i < kit; ++i) {
  54. sum[i].resize(tree[i].size());
  55. if (!tree[i].empty())
  56. sum[i][0] = tree[i][0];
  57. for (int j = 1; j < tree[i].size(); ++j) {
  58. sum[i][j] = sum[i][j-1] + tree[i][j];
  59. }
  60. }
  61. }
  62. int query(int l, int r, int ql, int qr, int val, int node){
  63. if(qr < l || r < ql)
  64. return 0;
  65. if(ql <= l and r <= qr){
  66. int L = lower_bound(all(tree[node]), val) - tree[node].begin();
  67. if (!L) {
  68. return 0;
  69. }
  70. return sum[node][L-1];
  71. }
  72. int mid = l + r >> 1, lc = node + node, rc = 1 + lc;
  73.  
  74. return (query(l, mid, ql, qr, val, lc)+ query(mid + 1, r, ql, qr, val, rc));
  75. }
  76. };
  77.  
  78. inline void solve() {
  79. int n, l, r;
  80. cin >> n;
  81. rep(i,2,n) {
  82. cin >> l >> r;
  83. v[l].pb(r);
  84. v[r].pb(l);
  85. }
  86. dfs(1, 1);
  87. rep(i,1,n) {
  88. cin >> val[i];
  89. a[idx[i]] = val[i];
  90. }
  91. merge_sort_tree obj;
  92. obj.build(1, n, 1);
  93. obj.done();
  94. int ans = 0;
  95. rep(i,1,n) {
  96. if (sz[ar[i]] == 1) {
  97. continue;
  98. }
  99. int l = i+1, r = i + sz[ar[i]] - 1;
  100. if (l <= r) {
  101. ans += obj.query(1, n, l, r, val[ar[i]], 1);
  102. }
  103. }
  104. cout << ans;
  105. }
  106. signed main() {
  107. ios_base::sync_with_stdio(0);
  108. cin.tie(NULL);
  109. cout.tie(NULL);
  110. int t = 1;
  111. solve();
  112. return 0;
  113. }
Success #stdin #stdout 0.02s 45752KB
stdin
6
6 1
3 6
5 6
4 3
2 3
8 7 6 7 10 8 
stdout
40