fork download
  1. #include <bits/stdc++.h>
  2.  
  3. using namespace std;
  4.  
  5. typedef long long ll;
  6. typedef pair<int, int> ii;
  7.  
  8. const int INF = 1e9;
  9. const ll LINF = 1e18;
  10.  
  11. const int N = 1e6 + 5;
  12.  
  13. struct Edge {
  14. int u, v, w;
  15. bool operator<(const Edge& other) const {
  16. return w < other.w;
  17. }
  18. };
  19.  
  20. int n;
  21. int w[N];
  22. vector<Edge> edges;
  23.  
  24. struct dsu {
  25. int n;
  26. vector<int> p, sz;
  27.  
  28. dsu(int n): n(n) {
  29. p.resize(n);
  30. sz.resize(n);
  31. for (int u = 1; u < n; u++) {
  32. p[u] = u;
  33. sz[u] = 1;
  34. }
  35. }
  36.  
  37. int findSet(int u) {
  38. if (p[u] == u) return u;
  39. return p[u] = findSet(p[u]);
  40. }
  41.  
  42. ll unionSet(int u, int v) {
  43. u = findSet(u);
  44. v = findSet(v);
  45. if (u == v) return 0;
  46. if (sz[u] < sz[v]) swap(u, v);
  47. ll cnt = 1ll * sz[u] * sz[v];
  48. p[v] = u;
  49. sz[u] += sz[v];
  50. return cnt;
  51. }
  52. };
  53.  
  54. // tổng max - tổng min
  55. // Có thể chuyển từ trọng số của đỉnh về trọng số của cạnh cho dễ đếm
  56. // Cách đếm hoàn toàn tương tự với bài G. Path Queries
  57.  
  58. ll sumMax() {
  59. vector<Edge> cur;
  60. for (Edge e : edges) {
  61. int cur_w = max(w[e.u], w[e.v]);
  62. cur.push_back({e.u, e.v, cur_w});
  63. }
  64.  
  65. sort(cur.begin(), cur.end());
  66.  
  67. ll ans = 0;
  68. dsu DSU(n + 1);
  69.  
  70. for (Edge e : cur) {
  71. ans += 1ll * e.w * DSU.unionSet(e.u, e.v);
  72. }
  73.  
  74. return ans;
  75. }
  76.  
  77. ll sumMin() {
  78. vector<Edge> cur;
  79. for (Edge e : edges) {
  80. int cur_w = min(w[e.u], w[e.v]);
  81. cur.push_back({e.u, e.v, cur_w});
  82. }
  83.  
  84. sort(cur.rbegin(), cur.rend());
  85.  
  86. ll ans = 0;
  87. dsu DSU(n + 1);
  88.  
  89. for (Edge e : cur) {
  90. ans += 1ll * e.w * DSU.unionSet(e.u, e.v);
  91. }
  92.  
  93. return ans;
  94. }
  95.  
  96. int main() {
  97. ios::sync_with_stdio(false);
  98. cin.tie(nullptr);
  99. cin >> n;
  100. for (int u = 1; u <= n; u++) cin >> w[u];
  101.  
  102. for (int i = 0; i < n - 1; i++) {
  103. int u, v;
  104. cin >> u >> v;
  105. edges.push_back({u, v, 0});
  106. }
  107.  
  108. ll ans = sumMax() - sumMin();
  109.  
  110. cout << ans << '\n';
  111. }
  112.  
Success #stdin #stdout 0.01s 5284KB
stdin
4
1 1 2 3
1 2
1 3
1 4
stdout
8