fork download
  1. #include <bits/stdc++.h>
  2.  
  3. using namespace std;
  4.  
  5. typedef long long ll;
  6. typedef pair<int, int> ii;
  7.  
  8. const int INF = 1e9;
  9. const ll LINF = 1e18;
  10.  
  11. const int MOD = 998244353;
  12.  
  13. void add(int& a, int b) {
  14. a += b;
  15. if (a >= MOD) a -= MOD;
  16. }
  17.  
  18. int a[3];
  19.  
  20. vector<int> digit;
  21.  
  22. vector<int> getDigit(ll N) {
  23. vector<int> ans;
  24. for (; N > 0; N /= 2) ans.push_back(N % 2);
  25. return ans;
  26. }
  27.  
  28. int memo[60][2][2][2][2][2][2][10][10][10];
  29.  
  30. // l(i) là điều kiện larger của số thứ i
  31. // s(i) là điều kiện smaller của số thứ i
  32. // r(i) là số dư của phần prefix của số thứ i khi chia cho a(i)
  33. int dp(int idx, int l0, int l1, int l2, int s0, int s1, int s2, int r0, int r1, int r2) {
  34. if (idx == -1) {
  35. return (r0 == 0 && r1 == 0 && r2 == 0);
  36. }
  37.  
  38. int& ans = memo[idx][l0][l1][l2][s0][s1][s2][r0][r1][r2];
  39. if (ans != -1) return ans;
  40.  
  41. ans = 0;
  42. int digit_one = (idx == 0) ? 1 : 0;
  43. int min_digit_x0 = (l0) ? 0 : digit_one, max_digit_x0 = (s0) ? 1 : digit[idx];
  44. int min_digit_x1 = (l1) ? 0 : digit_one, max_digit_x1 = (s1) ? 1 : digit[idx];
  45. int min_digit_x2 = (l2) ? 0 : digit_one, max_digit_x2 = (s2) ? 1 : digit[idx];
  46.  
  47. for (int i = min_digit_x0; i <= max_digit_x0; i++) {
  48. for (int j = min_digit_x1; j <= max_digit_x1; j++) {
  49. for (int k = min_digit_x2; k <= max_digit_x2; k++) {
  50. if ((i ^ j ^ k) != 0) continue;
  51. add(ans, dp(idx - 1, l0 | (i > digit_one), l1 | (j > digit_one), l2 | (k > digit_one),
  52. s0 | (i < digit[idx]), s1 | (j < digit[idx]), s2 | (k < digit[idx]),
  53. (r0 * 2 + i) % a[0], (r1 * 2 + j) % a[1], (r2 * 2 + k) % a[2]));
  54. }
  55. }
  56. }
  57.  
  58. return ans;
  59. }
  60.  
  61. int solve(ll N) {
  62. digit = getDigit(N);
  63. memset(memo, -1, sizeof memo);
  64. return dp(digit.size() - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0);
  65. }
  66.  
  67. int main() {
  68. ios::sync_with_stdio(false);
  69. cin.tie(nullptr);
  70. ll N;
  71. cin >> N;
  72. for (int i = 0; i < 3; i++) cin >> a[i];
  73. int ans = solve(N);
  74. cout << ans << '\n';
  75. }
Success #stdin #stdout 0.01s 18596KB
stdin
13 2 3 5
stdout
4