#include <bits/stdc++.h>
#include <ext/rope>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//#pragma GCC optimize("Ofast")
//#pragma GCC optimize("unroll-loops")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#define ll long long
#define pb push_back
#define x first
#define ld long double
#define y second
#define mk(a,b) make_pair(a,b)
#define rr return 0
#define uid uniform_int_distribution
#define urd uniform_real_distribution
#define sqr(a) ((a)*(a))
#define all(a) a.begin(),a.end()

using namespace std;

//using namespace __gnu_cxx;
//using namespace __gnu_pbds;
//template<class value, class cmp = less<value> >
//using ordered_set = tree<value, null_type, cmp, rb_tree_tag, tree_order_statistics_node_update>;
//template<class key, class value, class cmp = less<key> >
//using ordered_map = tree<key, value, cmp, rb_tree_tag, tree_order_statistics_node_update>;
//
///// find_by_order()
///// order_of_key()
//
//mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
//inline int randll(int l = LONG_MIN, int r = LONG_MAX) {
//    return uid<int>(l, r)(rng);
//}

const int N = 1e5 + 11;
// Working buffer shared by main() and solve(). solve() expects a[0..n-1] to be
// sorted in non-decreasing order, with a[0] acting as the floor value (the
// caller stores 0 there).
long long a[N];

// Subtracts up to `m` units in total from a[0..n-1], always taking from the
// current largest values first, and returns the sum of squares of the result.
//
// Parameters:
//   n - number of entries in use (including the floor entry a[0]).
//   m - total number of units to subtract; values <= 0 leave the array as is.
//
// The array `a` is modified in place. No element is lowered below a[0]; if `m`
// is larger than what can be removed, the surplus is ignored.
//
// The sum of squares is accumulated in unsigned arithmetic so that a result
// exceeding the signed 64-bit range wraps modulo 2^64 instead of being
// undefined behaviour.
inline long long solve(long long n, long long m) {
    if (n >= 2) {
        const long long top = n - 1;

        // Elements in (pos, top] form the "plateau": they all share the value
        // `level`. Only `level` is updated while the plateau is being lowered;
        // the array cells are written back once, after the loop.
        long long level = a[top];
        long long pos = top - 1;
        while (pos > 0 && a[pos] == level)
            --pos;

        long long k = 0;    // full units removed from every plateau element at the end
        long long rem = 0;  // plateau elements that lose one extra unit

        while (m > 0 && level > a[0]) {
            const long long width = top - pos;
            const long long cost = width * (level - a[pos]);
            if (cost < m) {
                // Enough units to bring the whole plateau down to a[pos]:
                // pay for it and absorb the equal neighbours below.
                m -= cost;
                level = a[pos];
                while (pos > 0 && a[pos] == level)
                    --pos;
            }
            else {
                // Not enough to reach a[pos]: spread what is left evenly.
                k = m / width;
                rem = m % width;
                m = 0;
            }
        }

        // Write the plateau back. This must happen on every exit path;
        // otherwise the interior cells keep their stale pre-lowering values
        // when the loop stops because the floor was reached.
        for (long long i = top; i > pos; --i) {
            a[i] = level - k - (rem > 0 ? 1 : 0);
            if (rem > 0)
                --rem;
        }
    }

    unsigned long long total = 0;
    for (long long i = 0; i < n; ++i) {
        const unsigned long long v = static_cast<unsigned long long>(a[i]);
        total += v * v;
    }
    return static_cast<long long>(total);
}
main()
{
    ios::sync_with_stdio(0);
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    ll m, n;
    cin >> m >> n;
    a[0] = 0;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    ++n;
    sort(a, a + n);
    cout << solve(n, m);
}
