fork download
  1. #include <bits/stdc++.h>
  2.  
  3. using namespace std;
  4.  
  5. typedef long long ll;
  6. typedef pair<int, int> ii;
  7.  
  8. const int INF = 1e9;
  9. const ll LINF = 1e18;
  10.  
  11. const int MOD = 998244353;
  12.  
  13. const int N = 3e5 + 5;
  14.  
  15. // Gọi dp[i] = Đáp án khi xét các phần tử từ a[1] đến a[i]
  16. // dp[i] = tổng{dp[j] * (max{a[j + 1..i]} - min{a[j + 1..i]})} (với j < i)
  17. // = tổng{dp[j] * max{a[j + 1..i]}} - tổng{dp[j] * min{a[j + 1..i]}}
  18. // = A - B
  19. // Ở đây ta sẽ bàn về cách tính A, cách tính B hoàn toàn tương tự
  20. // Khi xét đến i, để tính A ta xét 2 trường hợp:
  21. // Trường hợp 1: a[i] đóng vai trò là max
  22. // - Gọi l[i] = phần tử gần nhất bên trái > a[i]
  23. // - a[i] sẽ đóng vai trò là max trong đoạn [j + 1, i] với l[i] <= j < i
  24. // => A = tổng{dp[j]} * a[i] với j thuộc đoạn [l[i], i - 1]
  25. // => Tối ưu bằng prefix sum
  26. // Trường hợp 2: a[i] không đóng vai trò là max
  27. // => A = sum_max[l[i]]
  28. // sum_max[i] là tổng A ở trường hợp 1 khi xét dãy ..., l[l[i]], l[i], i (đọc code để hiểu rõ hơn sum_max[])
  29. void add(int& a, int b) {
  30. a += b;
  31. if (a >= MOD) a -= MOD;
  32. if (a < 0) a += MOD;
  33. }
  34.  
  35. int n;
  36. int a[N];
  37.  
  38. int l[2][N]; // l[0][i] = Phần tử gần nhất bên trái > a[i]
  39. // l[1][i] = Phần tử gần nhất bên trái < a[i]
  40.  
  41. int dp[N];
  42. int pref[N];
  43. int sum_max[N];
  44. int sum_min[N];
  45.  
  46. int getSum(int l, int r) {
  47. int ans = pref[r];
  48. if (l > 0) add(ans, -pref[l - 1]);
  49. return ans;
  50. }
  51.  
  52. int main() {
  53. ios::sync_with_stdio(false);
  54. cin.tie(nullptr);
  55. cin >> n;
  56. for (int i = 1; i <= n; i++) cin >> a[i];
  57.  
  58. vector<int> st;
  59. for (int i = 1; i <= n; i++) {
  60. while (!st.empty() && a[st.back()] <= a[i]) st.pop_back();
  61. l[0][i] = st.empty() ? 0 : st.back();
  62. st.push_back(i);
  63. }
  64.  
  65. st.clear();
  66. for (int i = 1; i <= n; i++) {
  67. while (!st.empty() && a[st.back()] >= a[i]) st.pop_back();
  68. l[1][i] = st.empty() ? 0 : st.back();
  69. st.push_back(i);
  70. }
  71.  
  72. dp[0] = 1;
  73. pref[0] = 1;
  74. sum_max[0] = sum_min[0] = 0;
  75.  
  76. for (int i = 1; i <= n; i++) {
  77. // Tính A
  78. int j = l[0][i];
  79. int cur = 1ll * getSum(j, i - 1) * a[i] % MOD;
  80. add(dp[i], cur);
  81. sum_max[i] = cur;
  82. if (j > 0) {
  83. add(dp[i], sum_max[j]);
  84. add(sum_max[i], sum_max[j]);
  85. }
  86.  
  87. // Tính B
  88. j = l[1][i];
  89. cur = 1ll * getSum(j, i - 1) * a[i] % MOD;
  90. add(dp[i], -cur);
  91. sum_min[i] = cur;
  92. if (j > 0) {
  93. add(dp[i], -sum_min[j]);
  94. add(sum_min[i], sum_min[j]);
  95. }
  96.  
  97. // Cập nhật prefix sum
  98. pref[i] = pref[i - 1];
  99. add(pref[i], dp[i]);
  100. }
  101.  
  102. cout << dp[n] << '\n';
  103. }
Success #stdin #stdout 0.01s 9744KB
stdin
3
1 2 3
stdout
2