#include <bits/stdc++.h>
using namespace std;

// Competitive-programming shorthands used throughout the file.
#define ll long long
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())

// Built with -DLOCAL: include a local debugging header (presumably it
// provides trace(); it is not part of this file). Otherwise trace(...)
// expands to nothing, and endl becomes a plain "\n" so output is not
// flushed after every line (flushing is required in interactive problems).
#ifdef LOCAL
#include <print.h>
#else
#define trace(...)
#define endl "\n" // remove in interactive
#endif
// Legendre's formula: returns the exponent of p in n! (for prime p), i.e.
//   floor(n/p) + floor(n/p^2) + floor(n/p^3) + ...
// Despite the name, this does NOT raise n to a power.
// Requires p >= 2: with p == 1 (and n != 0) the loop never terminates,
// and p == 0 divides by zero.
int power(int n, int p){
    int exponent = 0;
    for(int term = n / p; term != 0; term /= p){
        exponent += term;
    }
    return exponent;
}

// Reads N and m, then prints, modulo 998244353, the sum of
//   * 2^(a-1) for every a with 1 <= a <= v_2(m!) and 2^a <= N, plus
//   * 3^(b-1) * get(N / 3^b) for every b with 1 <= b <= v_3(m!) and 3^b <= N,
// where v_p(m!) is the exponent of p in m! (see power()) and get() is
// documented below.
int main(){
    ios_base::sync_with_stdio(false); 
    cin.tie(NULL); // Remove in interactive problems
    int m; ll N;
    cin >> N >> m;
    // Only primes <= N can divide a y <= N, and p^e <= N already implies
    // e <= floor(N/p) <= v_p(N!), so capping m at N shrinks the sieve
    // without changing any constraint used below.
    m = min((ll)m, N);

    // Sieve of Eratosthenes on [0, m] (entries 0 and 1 are never read).
    vector<bool> isprime(m+1, true);
    for(int p = 2; p * p <= m; p++) if(isprime[p]){
        for(int q = p * p; q <= m; q += p) isprime[q] = false;
    }

    // For the primes p <= m with p % 3 == 2, in increasing order:
    //   primes[i] = p,  tot[i] = v_p(m!)  (largest exponent allowed in y),
    //   Sum[k] = sum of those primes that are <= k.
    vector<ll> Sum(m + 1);
    vector<int> tot;
    vector<ll> primes;
    for(int p = 2; p <= m; p++){
        Sum[p] = Sum[p - 1];
        if(p % 3 == 2 && isprime[p]){
            primes.push_back(p);
            tot.push_back(power(m, p));
            Sum[p] += p;
        }
    }
    int nump = primes.size();
    const int mod = 998244353;
    // Lazy-reduction threshold for get()'s accumulator. Each DFS step adds
    // < mod + mod*mod, so the accumulator stays below mod2 + mod + mod*mod,
    // which is < 2^63.
    const ll mod2 = mod * 8LL * mod;

    // get(n): sum modulo `mod` of y * w(y) over every divisor y of m! with
    // y <= n whose prime factors are all == 2 (mod 3). Here w(y) = 2 when y
    // has an even number of distinct prime factors and w(y) = 1 when odd.
    // Only called when m >= 3, so primes[0] == 2 exists.
    function<ll(ll)> get = [&](ll n){
        ll ret = 2; // y = 1
        // DFS over valid x: primes[i] is the largest prime of x, e is its
        // exponent (e == 0 only for x == 1) and t = w(x). Each call adds
        //   y = x * primes[i]         (same distinct primes -> weight t), and
        //   y = x * q, prime q > p    (one new distinct prime -> weight 3 - t).
        function<void(ll, int, int,int)> recurse = [&](ll x, int i, int t, int e){
            int p = x == 1 ? 1 : primes[i];
            // x * p can be as large as n, so reduce before adding; an unreduced
            // add on top of an accumulator near mod2 overflows for large N.
            if(e > 0 && e < tot[i]) ret += (t * x % mod) * p % mod; // y = x * p
            ret += ((3 - t) * x % mod) * ((Sum[min((ll)m, n / x)] - Sum[p]) % mod); // y = x * q, q > p
            if(ret >= mod2) ret -= mod2;
            ll mx = n / x;
            // A child x' = x * p' only yields y >= x' * p', so it is visited
            // only when p'^2 <= n / x. Primes are increasing, so stop early.
            if(primes[i] *  primes[i] > mx) return;
            // Raise the exponent of the current prime (for x == 1 this adds
            // the first prime, which flips the parity of distinct primes).
            if(e < tot[i]) recurse(x * primes[i], i, e == 0 ? 3 - t: t, e + 1);
            // Append a new, larger prime with exponent 1.
            for(int j = i + 1; j < nump && primes[j] * primes[j] <= mx; j++){
                recurse(x * primes[j], j, 3 - t, 1);
            }
        };
        recurse(1, 0, 2, 0);
        return ret % mod;
    };

    ll ans = 0;
    // Pure powers of two: x = 2^a contributes x / 2 = 2^(a-1).
    int maxp2 = power(m, 2);
    for(ll x = 2, i = 1; x <= N && i <= maxp2; x *= 2, i++) ans = (ans + x / 2) % mod;
    // x = 3^b contributes 3^(b-1) times the get() sum for the cofactor N / x.
    int maxp = power(m, 3);

    for(ll x = 3, i = 1; x <= N && i <= maxp; x *= 3, i++){
        ans += (get(N / x)) * ( (x / 3) % mod);
        ans %= mod;
    }
    cout << ans << endl;
}