fork(4) download
  #include <algorithm>
  #include <cassert>
  #include <cstddef>
  #include <utility>
  #include <vector>
  3.  
  // ------------ Modulo Class ------------ //
  /// Residue class modulo the compile-time constant `mod`.
  ///
  /// The stored representative is always canonical, i.e. in [0, mod).
  /// Intermediates are widened to 64 bits, so any modulus in [1, 2^32) is safe.
  template<unsigned mod>
  class modulo {
      static_assert(mod >= 1, "modulo<mod>: modulus must be positive");

  private:
      unsigned x;  // canonical representative in [0, mod)

  public:
      modulo() : x(0) {}
      // Implicit on purpose so plain integers (0, 1, ...) can be used as values;
      // the argument is reduced so the invariant x < mod always holds.
      modulo(unsigned x_) : x(x_ % mod) {}

      // Canonical representative in [0, mod). const so const residues can be read too.
      operator unsigned() const { return x; }

      bool operator==(const modulo& m) const { return x == m.x; }
      bool operator!=(const modulo& m) const { return x != m.x; }

      modulo& operator+=(const modulo& m) {
          // x + m.x can exceed 2^32 - 1 when mod > 2^31, so add in 64 bits.
          const unsigned long long s = 1ULL * x + m.x;
          x = static_cast<unsigned>(s >= mod ? s - mod : s);
          return *this;
      }
      modulo& operator-=(const modulo& m) {
          // When x < m.x, x + (mod - m.x) < mod, so the result is already canonical.
          x = (x < m.x ? x + (mod - m.x) : x - m.x);
          return *this;
      }
      modulo& operator*=(const modulo& m) {
          // Both factors are < 2^32, so the product fits in 64 bits.
          x = static_cast<unsigned>(1ULL * x * m.x % mod);
          return *this;
      }
      modulo operator+(const modulo& m) const { return modulo(*this) += m; }
      modulo operator-(const modulo& m) const { return modulo(*this) -= m; }
      modulo operator*(const modulo& m) const { return modulo(*this) *= m; }
  };
  22.  
  23. // ------------ Matrix Functions ------------ //
  24. typedef std::vector<modulo<998244353> > matrix_base;
  25. typedef std::vector<matrix_base> matrix;
  26. matrix mul(const matrix& a, const matrix& b) {
  27. assert(a[0].size() == b.size());
  28. matrix ret(a.size(), matrix_base(b[0].size(), 0));
  29. for (int i = 0; i < a.size(); i++) {
  30. for (int j = 0; j < b[0].size(); j++) {
  31. for (int k = 0; k < b.size(); k++) ret[i][j] += a[i][k] * b[k][j];
  32. }
  33. }
  34. return ret;
  35. }
  36. matrix unit(int n) {
  37. matrix ret(n, matrix_base(n, 0));
  38. for (int i = 0; i < n; i++) ret[i][i] = 1;
  39. return ret;
  40. }
  41. matrix power(const matrix& a, long long b) {
  42. assert(a.size() == a[0].size());
  43. matrix f = a, ret = unit(a.size());
  44. while (b) {
  45. if (b & 1) ret = mul(ret, f);
  46. f = mul(f, f);
  47. b >>= 1;
  48. }
  49. return ret;
  50. }
  51.  
  // ------------ Modpower Algorithm ------------ //
  /// Returns a^b mod m as a value in [0, m).
  ///
  /// Requires m >= 1 and b >= 0. `a` may be negative or >= m; it is reduced first.
  inline int modpow(int a, int b, int m) {
      assert(m >= 1 && b >= 0);  // a negative b would loop forever under >>=
      long long base = a % m;
      if (base < 0) base += m;   // canonical residue for negative a
      long long ret = 1 % m;     // 1 % m handles m == 1, where every power is 0
      while (b > 0) {
          if (b & 1) ret = ret * base % m;
          base = base * base % m;
          b >>= 1;
      }
      return static_cast<int>(ret);
  }
  62.  
  63. // ------------ Number Theoretic Transform ------------ //
  64. inline static std::vector<int> FastModuloTransform(std::vector<int> v, int base, int root) {
  65. int n = v.size();
  66. for (int i = 0, j = 1; j < n - 1; j++) {
  67. for (int k = n >> 1; k >(i ^= k); k >>= 1);
  68. if (i < j) std::swap(v[i], v[j]);
  69. }
  70. for (int b = 1; b <= n / 2; b *= 2) {
  71. int x = modpow(root, (base - 1) / (b << 1), base);
  72. for (int i = 0; i < n; i += (b << 1)) {
  73. int p = 1;
  74. for (int j = i; j < i + b; j++) {
  75. int t1 = v[j], t2 = 1LL * v[j + b] * p % base;
  76. v[j] = t1 + t2; v[j] = (v[j] < base ? v[j] : v[j] - base);
  77. v[j + b] = t1 - t2 + base; v[j + b] = (v[j + b] < base ? v[j + b] : v[j + b] - base);
  78. p = 1LL * p * x % base;
  79. }
  80. }
  81. }
  82. return v;
  83. }
  84. inline static std::vector<int> FastConvolutionMod(std::vector<int> v1, std::vector<int> v2, int mod, int tr) {
  85. int n = v1.size() * 2; // v1 and v2 must be the same size!!
  86. v1.resize(n);
  87. v2.resize(n);
  88. v1 = FastModuloTransform(v1, mod, tr);
  89. v2 = FastModuloTransform(v2, mod, tr);
  90. for (int i = 0; i < n; i++) v1[i] = 1LL * v1[i] * v2[i] % mod;
  91. v1 = FastModuloTransform(v1, mod, modpow(tr, mod - 2, mod));
  92. int t = modpow(n, mod - 2, mod);
  93. for (int i = 0; i < n; i++) v1[i] = 1LL * v1[i] * t % mod;
  94. return v1;
  95. }
  96.  
  // ------------ Lagrange Interpolation ------------ //
  /// Barycentric weights of the polynomial through the points (i, v[i]), i = 0..n.
  ///
  /// Returns w with w[i] = v[i] / prod_{j != i} (i - j)  (mod m), so that
  /// P(x) = sum_i w[i] * prod_{j != i} (x - j). Requires m prime and
  /// v.size() <= m so that 1..n are invertible. Empty input yields an empty result.
  /// Runs in O(n).
  std::vector<int> lagrange_interpolation(const std::vector<int> &v, int m) {
      const int n = static_cast<int>(v.size()) - 1;
      if (n < 0) return std::vector<int>();
      // inv[i] = i^{-1} mod m via the linear recurrence inv[i] = -(m / i) * inv[m % i].
      std::vector<int> inv(n + 2, 0);
      inv[1] = 1;
      for (int i = 2; i <= n; i++) inv[i] = static_cast<int>(1LL * inv[m % i] * (m - m / i) % m);
      // q = 1 / prod_{j != 0} (0 - j) = (-1)^n / n!
      long long q = 1;
      for (int i = 1; i <= n; i++) q = q * inv[i] % m;
      if (n % 2 == 1) q = (m - q) % m;
      std::vector<int> ret(n + 1);
      for (int i = 0; i <= n; i++) {
          ret[i] = static_cast<int>(1LL * v[i] * q % m);
          // Move the denominator from node i to node i + 1: multiply by (i - n) / (i + 1).
          if (i < n) q = q * (m - n + i) % m * inv[i + 1] % m;
      }
      return ret;
  }
  /// Evaluates at x (mod m) the polynomial whose barycentric weights over nodes
  /// 0..n are v (as produced by lagrange_interpolation).
  ///
  /// Computes P(x) = sum_i v[i] * prod_{j != i} (x - j) with prefix/suffix products.
  /// This needs no modular inverses, runs in O(n), and stays correct when x mod m
  /// coincides with a node; a formula that divides by (x - i) breaks down there.
  /// x may be any int (it is reduced mod m). Empty v yields 0.
  int lagrange_function(int x, const std::vector<int> &v, int m) {
      const int n = static_cast<int>(v.size()) - 1;
      if (n < 0) return 0;
      const long long xm = ((static_cast<long long>(x) % m) + m) % m;
      // (x - j) mod m as a value in [0, m).
      auto sub = [&](int j) -> long long {
          const long long d = (xm - j) % m;
          return d < 0 ? d + m : d;
      };
      // suffix[i] = prod_{j > i} (x - j); suffix[n] = 1.
      std::vector<long long> suffix(n + 1, 1);
      for (int i = n; i > 0; i--) suffix[i - 1] = suffix[i] * sub(i) % m;
      long long ret = 0;
      long long prefix = 1;  // prod_{j < i} (x - j)
      for (int i = 0; i <= n; i++) {
          const long long term = prefix * suffix[i] % m;
          ret = (ret + v[i] * term) % m;
          prefix = prefix * sub(i) % m;
      }
      if (ret < 0) ret += m;  // only reachable with negative weights
      return static_cast<int>(ret);
  }
  120.  
  121. // ------------ Fibonacci Number ------------ //
  122. int nth_fibonacci(long long x) {
  123. matrix e(2, matrix_base(2));
  124. e[0][0] = e[0][1] = e[1][0] = 1;
  125. matrix p(2, matrix_base(1));
  126. p[0][0] = 1; p[1][0] = 0;
  127. return mul(power(e, x), p)[1][0];
  128. }
  129.  
  130. #include <iostream>
  131. using namespace std;
  // Modulus used by the driver code below.
  const int mod = 998244353;
  // Factorials, inverses of small integers, and inverse factorials modulo `mod`;
  // main() fills them in before combination() is called.
  int fact[1000009];
  int inv[1000009];
  int factinv[1000009];

  /// Binomial coefficient C(a, b) mod `mod`. Returns 0 when a < 0, b < 0 or b > a.
  /// Requires fact/factinv to be filled at least up to index a.
  int combination(int a, int b) {
      const bool outside = a < 0 || b < 0 || b > a;
      if (outside) return 0;
      const long long partial = 1LL * fact[a] * factinv[b] % mod;
      return static_cast<int>(partial * factinv[a - b] % mod);
  }
  138. int n; long long m;
  139. int main() {
  140. cin >> n >> m; n--; m--;
  141. if (n == 0) {
  142. cout << nth_fibonacci(m + 1) << endl;
  143. }
  144. else {
  145. fact[0] = 1;
  146. for (int i = 1; i < 2 * n; i++) fact[i] = 1LL * fact[i - 1] * i % mod;
  147. inv[1] = 1;
  148. for (int i = 2; i < 2 * n; i++) inv[i] = 1LL * inv[mod % i] * (mod - mod / i) % mod;
  149. factinv[0] = 1;
  150. for (int i = 1; i < 2 * n; i++) factinv[i] = 1LL * factinv[i - 1] * inv[i] % mod;
  151. vector<int> a(3 * n); a[0] = a[1] = 1;
  152. for (int i = 2; i < 3 * n; i++) a[i] = (a[i - 1] + a[i - 2]) % mod;
  153. int e = 1; while (e < n) e <<= 1;
  154. vector<int> b(e), bt(a.begin(), a.begin() + e);
  155. for (int i = 0; i < n; i++) b[i] = combination((n - 1) + i, i);
  156. vector<int> res = FastConvolutionMod(b, bt, mod, 3);
  157. if (m < n) {
  158. cout << res[m] << endl;
  159. }
  160. else {
  161. vector<int> diff(n);
  162. for (int i = 0; i < n; i++) diff[i] = (a[2 * n + i] - res[i] + mod) % mod;
  163. vector<int> poly = lagrange_interpolation(diff, mod);
  164. int f = nth_fibonacci(m + 2 * n + 1);
  165. int ret = lagrange_function(m % mod, poly, mod);
  166. cout << (f - ret + mod) % mod << endl;
  167. }
  168. }
  169. return 0;
  170. }
Success #stdin #stdout 0s 15200KB
stdin
3 20
stdout
46345