#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
using namespace std;
class UBigInt;
// Probabilistic (Miller-Rabin) primality test: returns true when p is probably prime.
// Witness bases are taken from [2, p-1] at a spacing derived from cnt, so about cnt
// bases are tried when p is large.
extern bool primalityTestMillerRabin(const UBigInt& p, const unsigned long& cnt = 100);
// Arbitrary-precision unsigned integer.
//
// Representation: `bits` holds the binary digits most significant first (front() is
// the MSB, back() is the LSB). Zero is the single bit {false}. isZero/equal/less
// compare vector sizes first, so values are expected to carry no leading zero bits
// (trim() produces that form).
//
// Errors are reported by throwing std::string messages: division/modulo by zero,
// subtraction with a negative result, a zero modulus in modPower, and invalid digits
// or digit values in fromString/toString.
class UBigInt {
public:
// Constructs 0 (default) or 1 from a single bit.
UBigInt(const bool& bit = false);
// Builds a value from an MSB-first bit vector.
UBigInt(const shared_ptr<vector<bool> >& bits);
// Builds a value from a machine integer (see makeBits).
UBigInt(const unsigned long& val);
// Copy constructor.
UBigInt(const UBigInt& num);
// Arithmetic returning new values. '-' throws if the result would be negative;
// '/' and '%' throw if the divisor is zero.
UBigInt operator +(const UBigInt& ade) const;
UBigInt operator -(const UBigInt& sub) const;
UBigInt operator *(const UBigInt& mer) const;
UBigInt operator /(const UBigInt& dsr) const;
UBigInt operator %(const UBigInt& nrm) const;
// Bit shifts by wid positions.
UBigInt operator <<(const size_t& wid) const;
UBigInt operator >>(const size_t& wid) const;
// Assignment and compound assignment (each compound form is `*this = *this op rhs`).
UBigInt& operator =(const UBigInt& num);
UBigInt& operator +=(const UBigInt& ade);
UBigInt& operator -=(const UBigInt& sub);
UBigInt& operator *=(const UBigInt& mer);
UBigInt& operator /=(const UBigInt& dsr);
UBigInt& operator %=(const UBigInt& nrm);
UBigInt& operator <<=(const size_t& wid);
UBigInt& operator >>=(const size_t& wid);
// Increment/decrement by one; decrementing zero throws (negative result).
UBigInt& operator ++();
UBigInt operator ++(int);
UBigInt& operator --();
UBigInt operator --(int);
// Value comparisons.
bool operator ==(const UBigInt& num) const;
bool operator <(const UBigInt& num) const;
bool operator >(const UBigInt& num) const;
bool operator !=(const UBigInt& num) const;
bool operator <=(const UBigInt& num) const;
bool operator >=(const UBigInt& num) const;
// Square-and-multiply exponentiation: *this raised to exp.
UBigInt power(const UBigInt& exp);
// (*this)^exp mod nrm; throws if nrm is zero.
UBigInt modPower(const UBigInt& exp, const UBigInt& nrm);
// In-place versions of power/modPower.
void powerAssign(const UBigInt& exp);
void modPowerAssign(const UBigInt& exp, const UBigInt& nrm);
// True iff the value is zero.
bool isZero() const;
// Least significant bit (true iff the value is odd).
bool getLowestBit() const;
// Value as unsigned long; exact for values that fit in 32 bits.
unsigned long getLowestDWord() const;
// Formats the value in `base` using the digits 0-9 then A-Z.
string toString(const unsigned int& base = 10) const;
// Parses `str` in `base`; digits are 0-9, A-Z or a-z (letters are case-insensitive).
// Throws on a non-digit character or a digit >= base.
static UBigInt fromString(const string& str, const unsigned int& base = 10);
protected:
// MSB-first binary digits; see the class comment.
shared_ptr<vector<bool> > bits;
// Bit-vector helpers: operate on MSB-first vectors and return freshly allocated ones.
static shared_ptr<vector<bool> > makeBits(const unsigned long& val);
static shared_ptr<vector<bool> > add(const shared_ptr<vector<bool> >& age, const shared_ptr<vector<bool> >& ade);
// Throws if sub > min.
static shared_ptr<vector<bool> > subtract(const shared_ptr<vector<bool> >& min, const shared_ptr<vector<bool> >& sub);
static shared_ptr<vector<bool> > multiply(const shared_ptr<vector<bool> >& mca, const shared_ptr<vector<bool> >& mer);
// Returns the quotient dde / dsr; if rem is non-null the remainder is stored in *rem.
// Throws if dsr is zero.
static shared_ptr<vector<bool> > divide(const shared_ptr<vector<bool> >& dde, const shared_ptr<vector<bool> >& dsr, shared_ptr<vector<bool> >* rem = NULL);
static shared_ptr<vector<bool> > leftShift(const shared_ptr<vector<bool> >& bits, const size_t& wid);
static shared_ptr<vector<bool> > rightShift(const shared_ptr<vector<bool> >& bits, const size_t& wid);
static bool less(const shared_ptr<vector<bool> >& left, const shared_ptr<vector<bool> >& right);
static bool equal(const shared_ptr<vector<bool> >& left, const shared_ptr<vector<bool> >& right);
static bool isZero(const shared_ptr<vector<bool> >& bits);
static unsigned long getLowestDWord(const shared_ptr<vector<bool> >& bits);
// Digit <-> character conversion for bases up to 36.
static unsigned int charToFigure(const char& ch);
static char figureToChar(const unsigned int& fig);
// Removes leading zero bits, keeping at least one bit.
static shared_ptr<vector<bool> > trim(const shared_ptr<vector<bool> >& bits);
};
// Test-assertion helper: passes the stringified predicate and its value to the
// assert() function defined later in this file. The expansion already ends with ';',
// so call sites write ASSERT(expr) without a trailing semicolon.
// The function name is parenthesized so that a function-like `assert` macro (from
// <cassert>/<assert.h>, if it is ever included) cannot expand at these points; the
// parenthesized form still names the same function.
#define ASSERT(pred) (assert)(#pred, (pred));
extern void (assert)(const char* pred_str, const bool& pred);
bool primalityTestMillerRabin(const UBigInt& p, const unsigned long& cnt) {
bool res;
if (p == 1UL) res = false;
else if (p == 2UL) res = true;
else if (p.getLowestBit() == false) res = false;
else {
UBigInt q = p - 1UL;
UBigInt k(true);
while (q.getLowestBit() == false) {
q >>= 1;
k++;
}
UBigInt a(2UL);
UBigInt gap(((p - 1UL) - 2UL) / cnt);
if (gap.isZero()) gap++;
for (; a < p; a += gap) {
UBigInt b = a.modPower(q, p);
if (b != 1UL && b != p - 1UL) {
UBigInt i(1UL);
for (; i < k; i++) {
b = (b * b) % p;
if (b == 1UL || b == p - 1UL) break;
}
if (i == k) {
res = false;
break;
}
}
}
if (a >= p) res = true;
}
return res;
}
// Constructs 0 (bit == false) or 1 (bit == true).
UBigInt::UBigInt(const bool& bit) : bits(new vector<bool>(1, bit)) {}
// Builds a value from an MSB-first bit vector. The input is normalized: leading zero
// bits are removed, and a null or empty vector becomes zero. isZero/equal/less compare
// vector sizes first, so a zero-padded vector would otherwise compare incorrectly.
UBigInt::UBigInt(const shared_ptr<vector<bool> >& bits)
    : bits(bits && !bits->empty() ? trim(bits) : shared_ptr<vector<bool> >(new vector<bool>(1, false))) {}
// Builds a value from a machine integer via makeBits.
UBigInt::UBigInt(const unsigned long& val) : bits(makeBits(val)) {}
// Deep copy: the new object owns its own bit vector.
UBigInt::UBigInt(const UBigInt& num) : bits(new vector<bool>(*num.bits)) {}
// Binary arithmetic and shift operators. Each delegates to the matching static
// bit-vector helper and wraps the freshly built vector in a new UBigInt.
UBigInt UBigInt::operator +(const UBigInt& ade) const {
    const shared_ptr<vector<bool> > sum = add(bits, ade.bits);
    return UBigInt(sum);
}
UBigInt UBigInt::operator -(const UBigInt& sub) const {
    const shared_ptr<vector<bool> > difference = subtract(bits, sub.bits);
    return UBigInt(difference);
}
UBigInt UBigInt::operator *(const UBigInt& mer) const {
    const shared_ptr<vector<bool> > product = multiply(bits, mer.bits);
    return UBigInt(product);
}
UBigInt UBigInt::operator /(const UBigInt& dsr) const {
    const shared_ptr<vector<bool> > quotient = divide(bits, dsr.bits);
    return UBigInt(quotient);
}
UBigInt UBigInt::operator %(const UBigInt& nrm) const {
    // Only the remainder is wanted; the quotient returned by divide is discarded.
    shared_ptr<vector<bool> > remainder;
    (void)divide(bits, nrm.bits, &remainder);
    return UBigInt(remainder);
}
UBigInt UBigInt::operator <<(const size_t& wid) const {
    const shared_ptr<vector<bool> > shifted = leftShift(bits, wid);
    return UBigInt(shifted);
}
UBigInt UBigInt::operator >>(const size_t& wid) const {
    const shared_ptr<vector<bool> > shifted = rightShift(bits, wid);
    return UBigInt(shifted);
}
// Copy assignment: always gives *this its own copy of num's bit vector. This is
// self-assignment safe because the copy is built before the old vector is released.
UBigInt& UBigInt::operator =(const UBigInt& num) {
    bits.reset(new vector<bool>(*num.bits));
    return *this;
}
// Compound assignments: compute the result with the binary operator, then assign it.
UBigInt& UBigInt::operator +=(const UBigInt& ade) {
    const UBigInt result = *this + ade;
    *this = result;
    return *this;
}
UBigInt& UBigInt::operator -=(const UBigInt& sub) {
    const UBigInt result = *this - sub;
    *this = result;
    return *this;
}
UBigInt& UBigInt::operator *=(const UBigInt& mer) {
    const UBigInt result = *this * mer;
    *this = result;
    return *this;
}
UBigInt& UBigInt::operator /=(const UBigInt& dsr) {
    const UBigInt result = *this / dsr;
    *this = result;
    return *this;
}
UBigInt& UBigInt::operator %=(const UBigInt& nrm) {
    const UBigInt result = *this % nrm;
    *this = result;
    return *this;
}
UBigInt& UBigInt::operator <<=(const size_t& wid) {
    const UBigInt result = *this << wid;
    *this = result;
    return *this;
}
UBigInt& UBigInt::operator >>=(const size_t& wid) {
    const UBigInt result = *this >> wid;
    *this = result;
    return *this;
}
// Prefix increment: adds one and returns *this.
UBigInt& UBigInt::operator ++() {
    *this += UBigInt(true);
    return *this;
}
// Postfix increment: returns the value held before incrementing.
UBigInt UBigInt::operator ++(int) {
    const UBigInt before(*this);
    ++*this;
    return before;
}
// Prefix decrement: subtracts one (throws if *this is zero) and returns *this.
UBigInt& UBigInt::operator --() {
    *this -= UBigInt(true);
    return *this;
}
// Postfix decrement: returns the value held before decrementing.
UBigInt UBigInt::operator --(int) {
    const UBigInt before(*this);
    --*this;
    return before;
}
// Comparisons, all expressed through the static equal/less helpers.
// a > b is b < a, and a <= b is !(b < a).
bool UBigInt::operator ==(const UBigInt& num) const {
    return UBigInt::equal(bits, num.bits);
}
bool UBigInt::operator <(const UBigInt& num) const {
    return UBigInt::less(bits, num.bits);
}
bool UBigInt::operator >(const UBigInt& num) const {
    return UBigInt::less(num.bits, bits);
}
bool UBigInt::operator !=(const UBigInt& num) const {
    return !UBigInt::equal(bits, num.bits);
}
bool UBigInt::operator <=(const UBigInt& num) const {
    return !UBigInt::less(num.bits, bits);
}
bool UBigInt::operator >=(const UBigInt& num) const {
    return !UBigInt::less(bits, num.bits);
}
// Returns *this raised to exp, using right-to-left square-and-multiply.
// 0^0 is 1 (the usual convention); 0^n is 0 for n > 0.
UBigInt UBigInt::power(const UBigInt& exp) {
    UBigInt res(true);
    UBigInt coef = *this;
    UBigInt exp2 = exp;
    while (!exp2.isZero()) {
        if (exp2.getLowestBit()) res *= coef;
        exp2 >>= 1;
        // Skip the final squaring, whose result would never be used.
        if (!exp2.isZero()) coef *= coef;
    }
    return res;
}
// Returns (*this)^exp mod nrm, using right-to-left square-and-multiply and reducing
// every intermediate product modulo nrm.
// Throws std::string if nrm is zero. 0^0 is taken as 1 (then reduced mod nrm).
UBigInt UBigInt::modPower(const UBigInt& exp, const UBigInt& nrm) {
    if (nrm.isZero()) throw string("法がゼロ。");
    // Reduce the initial 1 too, so that nrm == 1 yields 0 even when exp == 0.
    UBigInt res = UBigInt(true) % nrm;
    UBigInt coef = *this % nrm;
    UBigInt exp2 = exp;
    while (!exp2.isZero()) {
        if (exp2.getLowestBit()) res = (res * coef) % nrm;
        exp2 >>= 1;
        // Skip the final squaring, whose result would never be used.
        if (!exp2.isZero()) coef = (coef * coef) % nrm;
    }
    return res;
}
void UBigInt::powerAssign(const UBigInt& exp) {
*this = power(exp);
}
void UBigInt::modPowerAssign(const UBigInt& exp, const UBigInt& nrm) {
*this = modPower(exp, nrm);
}
bool UBigInt::isZero() const {
return isZero(this->bits);
}
bool UBigInt::getLowestBit() const {
return this->bits->back();
}
unsigned long UBigInt::getLowestDWord() const {
return getLowestDWord(this->bits);
}
// Formats the value in `base`, most significant digit first, using the digits 0-9
// then A-Z (see figureToChar). Zero is rendered as "0".
// Throws std::string if base < 2: dividing by 1 never shrinks the value, so the loop
// would not terminate, and base 0 is a division by zero.
string UBigInt::toString(const unsigned int& base) const {
    if (base < 2) throw string("値が範囲を逸脱している。");
    string res;
    shared_ptr<vector<bool> > wrk = this->bits;
    const shared_ptr<vector<bool> > base_bits(makeBits(base));
    // Repeated division produces the digits least significant first.
    while (!isZero(wrk)) {
        shared_ptr<vector<bool> > rem;
        wrk = divide(wrk, base_bits, &rem);
        res += figureToChar(getLowestDWord(rem));
    }
    if (res.empty()) res += '0';
    reverse(res.begin(), res.end());
    return res;
}
// Parses `str` as a number in `base`. Digits are 0-9, A-Z or a-z (letters are
// case-insensitive). An empty string yields zero.
// Throws std::string on a non-digit character or a digit value >= base.
UBigInt UBigInt::fromString(const string& str, const unsigned int& base) {
    shared_ptr<vector<bool> > bits(new vector<bool>(1, false));
    const shared_ptr<vector<bool> > base_bits(makeBits(base));
    for (string::const_iterator iter = str.begin(); iter != str.end(); ++iter) {
        const unsigned int fig = charToFigure(*iter);
        if (fig >= base) throw string("値が範囲を逸脱している。");
        // 0 * base is 0. Skipping the multiply while the accumulator is zero keeps
        // leading '0' digits from producing a zero-padded (non-canonical) vector.
        if (!isZero(bits)) bits = multiply(bits, base_bits);
        bits = add(bits, makeBits(fig));
    }
    return UBigInt(bits);
}
// Builds the MSB-first bit vector for val with no leading zero bits; zero becomes {false}.
// Bits are peeled off the low end until nothing is left, so every bit of
// unsigned long is used whatever its width (64 bits on LP64 platforms).
shared_ptr<vector<bool> > UBigInt::makeBits(const unsigned long& val) {
    shared_ptr<vector<bool> > res(new vector<bool>);
    for (unsigned long rest = val; rest != 0; rest >>= 1) res->push_back((rest & 1UL) != 0);
    if (res->empty()) res->push_back(false);
    // Bits were collected LSB first; flip them into MSB-first order.
    reverse(res->begin(), res->end());
    return res;
}
// Adds two MSB-first bit vectors and returns a new MSB-first vector.
// Operands are read from their least significant ends (the back of each vector).
// Result bits are collected LSB first and reversed at the end.
shared_ptr<vector<bool> > UBigInt::add(const shared_ptr<vector<bool> >& age, const shared_ptr<vector<bool> >& ade) {
    const vector<bool>& lhs = *age;
    const vector<bool>& rhs = *ade;
    const size_t width = max(lhs.size(), rhs.size());
    shared_ptr<vector<bool> > total(new vector<bool>);
    total->reserve(width + 1);
    unsigned int carry = 0;
    for (size_t k = 0; k < width; ++k) {
        const unsigned int x = k < lhs.size() && lhs[lhs.size() - 1 - k] ? 1U : 0U;
        const unsigned int y = k < rhs.size() && rhs[rhs.size() - 1 - k] ? 1U : 0U;
        const unsigned int column = x + y + carry;
        total->push_back((column & 1U) != 0);
        carry = column >> 1;
    }
    if (carry != 0) total->push_back(true);
    if (total->empty()) total->push_back(false);
    reverse(total->begin(), total->end());
    return total;
}
// Computes min - sub on MSB-first bit vectors and returns the trimmed difference.
// Throws std::string if sub > min (the final borrow is still set).
shared_ptr<vector<bool> > UBigInt::subtract(const shared_ptr<vector<bool> >& min, const shared_ptr<vector<bool> >& sub) {
    const vector<bool>& lhs = *min;
    const vector<bool>& rhs = *sub;
    const size_t width = max(lhs.size(), rhs.size());
    shared_ptr<vector<bool> > dif(new vector<bool>);
    dif->reserve(width);
    int borrow = 0;
    // Walk from the least significant end, collecting result bits LSB first.
    for (size_t k = 0; k < width; ++k) {
        const int x = k < lhs.size() && lhs[lhs.size() - 1 - k] ? 1 : 0;
        const int y = k < rhs.size() && rhs[rhs.size() - 1 - k] ? 1 : 0;
        int digit = x - y - borrow;
        borrow = digit < 0 ? 1 : 0;
        if (borrow != 0) digit += 2;
        dif->push_back(digit == 1);
    }
    if (borrow != 0) throw string("被減数が減数より小さい。");
    reverse(dif->begin(), dif->end());
    return trim(dif);
}
shared_ptr<vector<bool> > UBigInt::multiply(const shared_ptr<vector<bool> >& mca, const shared_ptr<vector<bool> >& mer) {
shared_ptr<vector<bool> > pro(new vector<bool>(1, false));
vector<bool>::const_reverse_iterator mer_iter = mer->rbegin();
size_t wid = 0;
while (mer_iter != mer->rend()) {
if (*mer_iter) pro = add(pro, leftShift(mca, wid));
mer_iter++;
wid++;
}
return pro;
}
// Binary long division of MSB-first bit vectors.
// Returns the quotient dde / dsr; if rem is non-null, the remainder is stored in *rem.
// Throws std::string if dsr is zero.
shared_ptr<vector<bool> > UBigInt::divide(const shared_ptr<vector<bool> >& dde, const shared_ptr<vector<bool> >& dsr, shared_ptr<vector<bool> >* rem) {
    if (isZero(dsr)) throw string("除数がゼロ。");
    shared_ptr<vector<bool> > quotient(new vector<bool>);
    shared_ptr<vector<bool> > window(new vector<bool>);
    for (size_t idx = 0; idx < dde->size(); ++idx) {
        // Bring down the next dividend bit (MSB first) and drop leading zeros.
        window->push_back((*dde)[idx]);
        window = trim(window);
        const bool fits = !less(window, dsr);
        if (fits) window = subtract(window, dsr);
        // Leading zero quotient bits are suppressed until the first 1 is produced.
        if (fits || !quotient->empty()) quotient->push_back(fits);
    }
    if (quotient->empty()) quotient->push_back(false);
    if (rem != NULL) *rem = window;
    return quotient;
}
shared_ptr<vector<bool> > UBigInt::leftShift(const shared_ptr<vector<bool> >& bits, const size_t& wid) {
shared_ptr<vector<bool> > res(new vector<bool>(*bits));
for (int i = 0; i < wid; i++) res->push_back(false);
return res;
}
// Returns a new vector holding bits shifted right by wid positions. Because bits are
// stored MSB first, this drops the last wid entries. Shifting out every bit yields zero.
shared_ptr<vector<bool> > UBigInt::rightShift(const shared_ptr<vector<bool> >& bits, const size_t& wid) {
    if (wid >= bits->size()) return shared_ptr<vector<bool> >(new vector<bool>(1, false));
    const size_t kept = bits->size() - wid;
    return shared_ptr<vector<bool> >(new vector<bool>(bits->begin(), bits->begin() + kept));
}
// True iff both bit vectors have the same length and the same bits.
bool UBigInt::equal(const shared_ptr<vector<bool> >& left, const shared_ptr<vector<bool> >& right) {
    const vector<bool>& lhs = *left;
    const vector<bool>& rhs = *right;
    return lhs == rhs;
}
// Strict less-than on MSB-first bit vectors. A shorter vector is smaller, which
// assumes no leading zero bits. Equal-length vectors are compared bit by bit from
// the MSB, with false < true.
bool UBigInt::less(const shared_ptr<vector<bool> >& left, const shared_ptr<vector<bool> >& right) {
    if (left->size() != right->size()) return left->size() < right->size();
    return lexicographical_compare(left->begin(), left->end(), right->begin(), right->end());
}
// True iff bits is the canonical zero, i.e. exactly one false bit.
bool UBigInt::isZero(const shared_ptr<vector<bool> >& bits) {
    if (bits->size() != 1) return false;
    return !(*bits)[0];
}
// Returns the lowest 32 bits of the MSB-first bit vector as an unsigned long.
// Bits are stored most significant first, so the low dword is the LAST
// (up to) 32 entries. Reading from the front would return the highest bits
// for any value wider than 32 bits.
unsigned long UBigInt::getLowestDWord(const shared_ptr<vector<bool> >& bits) {
    const size_t size = bits->size();
    const size_t first = size > 32 ? size - 32 : 0;
    unsigned long res = 0;
    for (size_t i = first; i < size; ++i) {
        res = (res << 1) | ((*bits)[i] ? 1UL : 0UL);
    }
    return res;
}
// Maps a digit character to its value: '0'-'9' -> 0-9, and 'A'-'Z' or 'a'-'z' -> 10-35.
// Throws std::string for any other character.
unsigned int UBigInt::charToFigure(const char& ch) {
    if ('0' <= ch && ch <= '9') return static_cast<unsigned int>(ch - '0');
    if ('A' <= ch && ch <= 'Z') return static_cast<unsigned int>(ch - 'A' + 10);
    if ('a' <= ch && ch <= 'z') return static_cast<unsigned int>(ch - 'a' + 10);
    throw string("文字が数ではない。");
}
// Maps a digit value to its character: 0-9 -> '0'-'9', 10-35 -> 'A'-'Z'.
// Throws std::string for values >= 36. 36 used to be accepted and mapped to '[',
// one past 'Z'.
char UBigInt::figureToChar(const unsigned int& fig) {
    if (fig <= 9) return static_cast<char>('0' + fig);
    if (fig <= 35) return static_cast<char>('A' + (fig - 10));
    throw string("値が範囲を逸脱している。");
}
// Returns a copy of bits with leading (most significant) zero bits removed.
// An all-zero or empty input yields the canonical zero {false}. The empty case
// previously decremented a size_t past zero, which is undefined behavior.
shared_ptr<vector<bool> > UBigInt::trim(const shared_ptr<vector<bool> >& bits) {
    const vector<bool>::const_iterator first_one = find(bits->cbegin(), bits->cend(), true);
    if (first_one == bits->cend()) return shared_ptr<vector<bool> >(new vector<bool>(1, false));
    return shared_ptr<vector<bool> >(new vector<bool>(first_one, bits->cend()));
}
// Reports the outcome of one test assertion. On success it prints the predicate
// text to stdout. On failure it prints the text to stderr and exits with status 1.
// The declarator is parenthesized so that a function-like `assert` macro (from
// <cassert>/<assert.h>) cannot expand here; the function is still named assert.
void (assert)(const char* pred_str, const bool& pred) {
    if (!pred) {
        cerr << "アサート失敗: " << pred_str << endl;
        exit(1);
    }
    cout << "アサート成功: " << pred_str << endl;
}
// Regression checks for primalityTestMillerRabin on small primes and composites,
// including the Carmichael number 561. Each ASSERT prints its predicate on success
// and exits with status 1 on failure. A std::string thrown by UBigInt (e.g. on bad
// input or division by zero) is printed to stderr and turned into exit status 1.
int main() {
try {
ASSERT(primalityTestMillerRabin(UBigInt::fromString("0")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("1")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("2")) == true)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("3")) == true)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("4")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("5")) == true)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("6")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("7")) == true)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("8")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("9")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("87")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("89")) == true)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("329")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("331")) == true)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("561")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("563")) == true)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("807")) == false)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("809")) == true)
ASSERT(primalityTestMillerRabin(UBigInt::fromString("2251")) == true)
// NOTE(review): the larger primes below are disabled, presumably because the
// vector<bool>-based arithmetic makes them slow. Confirm before re-enabling.
// ASSERT(primalityTestMillerRabin(UBigInt::fromString("59561")) == true)
// ASSERT(primalityTestMillerRabin(UBigInt::fromString("181243")) == true)
// ASSERT(primalityTestMillerRabin(UBigInt::fromString("14414443")) == true)
}
catch (const string& msg) {
cerr << msg << endl;
return 1;
}
return 0;
}