fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  4. #define ll long long
  5. #define all(c) ((c).begin()), ((c).end())
  6. #define sz(x) ((int)(x).size())
  7.  
  8. #ifdef LOCAL
  9. #include <print.h>
  10. #else
  11. #define trace(...)
  12. #define endl "\n" // remove in interactive
  13. #endif
  /// Returns floor(n/p) + floor(n/p^2) + floor(n/p^3) + ..., stopping at the
  /// first zero term. For a prime p this is Legendre's formula: the exponent
  /// of p in the prime factorisation of n!.
  /// Callers here always pass p >= 2 (p == 1 would loop forever for n != 0,
  /// p == 0 would divide by zero).
  int power(int n, int p){
      int exponent = 0;
      for (int term = n / p; term != 0; term /= p) {
          exponent += term;
      }
      return exponent;
  }
  23.  
  24. int main(){
  25. ios_base::sync_with_stdio(false);
  26. cin.tie(NULL); // Remove in interactive problems
  27. int m; ll N;
  28. cin >> N >> m;
  29. m = min((ll)m, N);
  30. vector<bool> isprime(m+1, true);
  31.  
  32. for(int p = 2; p * p <= m; p++) if(isprime[p]){
  33. for(int q = p * p; q <= m; q += p) isprime[q] = false;
  34. }
  35. vector<ll> Sum(m + 1);
  36. vector<int> tot;
  37. vector<ll> primes;
  38. for(int p = 2; p <= m; p++){
  39. Sum[p] = Sum[p - 1];
  40. if(p % 3 == 2 && isprime[p]){
  41. primes.push_back(p);
  42. tot.push_back(power(m, p));
  43. Sum[p] += p;
  44. }
  45. }
  46. int nump = primes.size();
  47. const int mod = 998244353;
  48. const ll mod2 = mod * 8LL * mod;
  49.  
  50. function<ll(ll)> get = [&](ll n){
  51. ll ret = 2; // y = 1
  52. function<void(ll, int, int,int)> recurse = [&](ll x, int i, int t, int e){
  53. int p = x == 1 ? 1 : primes[i];
  54. if(e > 0 && e < tot[i]) ret += t * x * p; // y = x * p
  55. ret += ((3 - t) * x % mod) * ((Sum[min((ll)m, n / x)] - Sum[p]) % mod); // y = x * q, q > p
  56. if(ret >= mod2) ret -= mod2;
  57. ll mx = n / x;
  58. if(primes[i] * primes[i] > mx) return;
  59. if(e < tot[i]) recurse(x * primes[i], i, e == 0 ? 3 - t: t, e + 1);
  60. for(int j = i + 1; j < nump && primes[j] * primes[j] <= mx; j++){
  61. recurse(x * primes[j], j, 3 - t, 1);
  62. }
  63. };
  64. recurse(1, 0, 2, 0);
  65. return ret % mod;
  66. };
  67. ll ans = 0;
  68. int maxp2 = power(m, 2);
  69. for(ll x = 2, i = 1; x <= N && i <= maxp2; x *= 2, i++) ans += x / 2;
  70. int maxp = power(m, 3);
  71.  
  72. for(ll x = 3, i = 1; x <= N && i <= maxp; x *= 3, i++){
  73. ans += (get(N / x)) * ( (x / 3) % mod);
  74. ans %= mod;
  75. }
  76. cout << ans << endl;
  77. }
Success #stdin #stdout 0s 5300KB
stdin
Standard input is empty
stdout
0