#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
using namespace std;
// Forward declarations of every class in this file, so the definitions that
// follow can refer to one another regardless of the order they appear in.
class EA;
class RSA;
class RPG;
class SimpleRPG;
class RBG;
class SimpleRBG;
class Sieve;
class PT;
class MRBPT;
class HA;
class SHA1;
class BigMath;
class BigInt;
class BitField;
class ByteInputStream;
class ByteOutputStream;
class OctetStringByteInputStream;
class OctetStringByteOutputStream;
class StringByteInputStream;
class StringByteOutputStream;
class BitInputStream;
class BitInputStreamConnector;
template <class X> class PrimitiveBitInputStream;
template <class X> class Pool;
template <class X> class ExArray;
template <class X> class PrimitiveExArray;
template <class X> class ClassObjectExArray;
// An octet string is an expandable array of bytes.
typedef ExArray<unsigned char> OctetString;
// Abstract interface of an expandable (dynamically resizable) array.
// Concrete storage strategies: PrimitiveExArray and ClassObjectExArray.
template <class X>
class ExArray {
public:
  // Virtual so that deleting a concrete array through an ExArray* runs the
  // derived destructor instead of being undefined behavior.
  virtual ~ExArray() {}
  // Element access by index.
  virtual X& operator [](const unsigned int& i) = 0;
  virtual X& at(const unsigned int& i) = 0;
  // The element with the highest index.
  virtual X& last() = 0;
  // Number of stored elements.
  virtual unsigned int getSize() const = 0;
  virtual bool isEmpty() const = 0;
  // Append / remove at the end.
  virtual void push(const X& x) = 0;
  virtual void pop() = 0;
  virtual void clear() = 0;
  // Change the element count to sz.
  virtual void resize(const unsigned int& sz) = 0;
};
// ExArray implementation backed by a raw buffer. The name suggests it is
// intended for primitive / trivially copyable element types.
// NOTE(review): member definitions are not in view — confirm buffer ownership
// and growth policy there.
template <class X>
class PrimitiveExArray : public ExArray<X> {
public:
// Creates an array with sz elements (default: empty).
PrimitiveExArray(const unsigned int& sz = 0);
PrimitiveExArray(const PrimitiveExArray& a);
virtual ~PrimitiveExArray();
PrimitiveExArray& operator =(const PrimitiveExArray& a);
virtual X& operator [](const unsigned int& i);
virtual X& at(const unsigned int& i);
virtual X& last();
virtual unsigned int getSize() const;
virtual bool isEmpty() const;
virtual void push(const X& x);
virtual void pop();
virtual void clear();
virtual void resize(const unsigned int& sz);
protected:
// Element storage; buf_sz is presumably the allocated capacity and sz the
// number of elements in use — TODO confirm against the definitions.
X* buf;
unsigned int buf_sz;
unsigned int sz;
};
// ExArray implementation for class-type elements (as opposed to
// PrimitiveExArray). BigInt::multiply uses it to hold MColPtr columns.
// NOTE(review): member definitions are not in view — confirm how elements are
// constructed/destroyed on push/pop/resize.
template <class X>
class ClassObjectExArray : public ExArray<X> {
public:
// Creates an array with sz elements (default: empty).
ClassObjectExArray(const unsigned int& sz = 0);
ClassObjectExArray(const ClassObjectExArray& a);
virtual ~ClassObjectExArray();
ClassObjectExArray& operator =(const ClassObjectExArray& a);
virtual X& operator [](const unsigned int& i);
virtual X& at(const unsigned int& i);
virtual X& last();
virtual unsigned int getSize() const;
virtual bool isEmpty() const;
virtual void push(const X& x);
virtual void pop();
virtual void clear();
virtual void resize(const unsigned int& sz);
protected:
// Element storage; buf_sz is presumably the capacity and sz the element
// count — TODO confirm.
X* buf;
unsigned int buf_sz;
unsigned int sz;
};
// Object pool for X. Objects are handed out through Pool::Ptr handles.
// Copying a Pool is disabled (copy constructor/assignment are protected).
template <class X>
class Pool {
public:
// Handle to a pooled X. x_cnt presumably pairs the object with a shared
// reference count, and release() presumably returns the object to the pool
// (putBack) when the last handle goes away — definitions not in view, confirm.
class Ptr {
public:
Ptr(Pool* pool);
Ptr(const Ptr& ptr);
virtual ~Ptr();
Ptr& operator =(const Ptr& ptr);
X& operator *() const;
X* operator ->() const;
// Raw pointer to the pooled object (BigInt passes it to BitField streams).
X* get() const;
protected:
Pool* pool;
pair<X*, unsigned int*> x_cnt;
void release();
};
Pool();
virtual ~Pool();
protected:
// Objects currently available for reuse (presumably) — TODO confirm.
PrimitiveExArray<pair<X*, unsigned int*> > xa;
Pool(const Pool& pool);
Pool& operator =(const Pool& pool);
// Take an object (with its counter) from the pool / give it back.
pair<X*, unsigned int*> get();
void putBack(const pair<X*, unsigned int*>& x_cnt);
friend Ptr;
};
// Arbitrary-precision signed integer stored as sign + magnitude:
// sign is -1, 0 or +1 and the magnitude lives in a pooled BitField (fld).
class BigInt {
public:
// Construction: zero, a machine int, a digit string in the given base, a raw
// magnitude with an explicit sign, or a copy.
BigInt();
BigInt(const int& num);
BigInt(const string& str, const unsigned int& base = 10);
BigInt(const BitField& fld, const int& sign = 1);
BigInt(const BigInt& num);
// Arithmetic. / truncates toward zero (quotient of magnitudes, sign is the
// product of the signs) and % takes the sign of the dividend.
BigInt operator -() const;
BigInt operator +(const BigInt& ade) const;
BigInt operator -(const BigInt& sub) const;
BigInt operator *(const BigInt& mer) const;
BigInt operator /(const BigInt& dsr) const;
BigInt operator %(const BigInt& nrm) const;
// (quotient, remainder) with the same sign rules as / and %.
pair<BigInt, BigInt> divide(const BigInt& dsr) const;
// Shift the magnitude by num bits, keeping the sign. The out-of-class
// definitions must use exactly these parameter types.
BigInt operator <<(const size_t& num) const;
BigInt operator >>(const size_t& num) const;
BigInt& operator =(const BigInt& num);
BigInt& operator +=(const BigInt& ade);
BigInt& operator -=(const BigInt& sub);
BigInt& operator *=(const BigInt& mer);
BigInt& operator /=(const BigInt& dsr);
BigInt& operator %=(const BigInt& nrm);
BigInt& operator <<=(const unsigned int& num);
BigInt& operator >>=(const unsigned int& num);
BigInt& operator ++();
BigInt operator ++(int);
BigInt& operator --();
BigInt operator --(int);
bool operator ==(const BigInt& num) const;
bool operator <(const BigInt& num) const;
bool operator >(const BigInt& num) const;
bool operator !=(const BigInt& num) const;
bool operator <=(const BigInt& num) const;
bool operator >=(const BigInt& num) const;
BigInt abs() const;
bool isZero() const;
int getSign() const;
// Accessors on the magnitude. ls_pos appears to be a bit offset from the
// least significant end (callers pass byte_index << 3).
bool getLSBit() const;
unsigned char getByte(const unsigned int ls_pos) const;
unsigned char getLSByte() const;
unsigned long getLSLong() const;
void setInt(const int& num);
string toString(const unsigned int& base = 10) const;
// Bit length of the magnitude.
unsigned int getLength() const;
protected:
// Handle to a BitField drawn from one pool shared by all BigInts.
class BFPtr : public Pool<BitField>::Ptr {
public:
BFPtr();
BFPtr(const unsigned long& num);
BFPtr(const BitField& fld);
protected:
static Pool<BitField> pool;
};
// Declaration order matters: sign is initialized before fld.
int sign;
BFPtr fld;
// Adopts an existing pooled magnitude.
BigInt(const BFPtr& fld, const int& sign);
// Magnitude-only helpers (no sign handling).
static BFPtr add(const BFPtr& age, const BFPtr& ade);
static BFPtr subtract(const BFPtr& min, const BFPtr& sub);
static BFPtr multiply(const BFPtr& mca, const BFPtr& mer);
static pair<BFPtr, BFPtr> divide(const BFPtr& dde, const BFPtr& dsr);
// Digit character <-> numeric value conversion for string I/O.
static unsigned int digitToNumber(const char& dig);
static char numberToDigit(const unsigned int& num);
};
// Abstract encryption algorithm operating on byte streams.
class EA {
public:
  // Virtual so an algorithm owned through an EA pointer is destroyed completely.
  virtual ~EA() {}
  // Reads plaintext bytes from msg_in and writes ciphertext bytes to cip_out.
  virtual void encrypt(const shared_ptr<ByteInputStream>& msg_in, const shared_ptr<ByteOutputStream>& cip_out) = 0;
  // Reads ciphertext bytes from cip_in and writes the recovered plaintext to msg_out.
  virtual void decrypt(const shared_ptr<ByteInputStream>& cip_in, const shared_ptr<ByteOutputStream>& msg_out) = 0;
};
// RSA with PKCS#1 v1.5-style block padding (0x00 0x02 PS 0x00 M).
// Copying is disabled (protected copy constructor/assignment).
class RSA : public EA {
public:
// Modulus n and exponent e. For the private key produced by generateKeys,
// e holds the private exponent d.
struct Key {
BigInt n;
BigInt e;
};
RSA(const Key& key);
virtual void encrypt(const shared_ptr<ByteInputStream>& msg_in, const shared_ptr<ByteOutputStream>& cip_out);
virtual void decrypt(const shared_ptr<ByteInputStream>& cip_in, const shared_ptr<ByteOutputStream>& msg_out);
// Returns (public key, private key) for a len-bit modulus.
static pair<Key, Key> generateKeys(const unsigned int& len);
static Key makeKey(const BigInt& n, const BigInt& e);
protected:
Key key;
// Modulus length in bytes (bit length rounded up).
unsigned int key_len;
// Random bit source for the padding bytes.
shared_ptr<RBG> rbg;
RSA(const RSA& rsa);
RSA& operator =(const RSA& rsa);
// Single-block transforms used by the stream versions above.
shared_ptr<OctetString> encrypt(const shared_ptr<OctetString>& msg);
shared_ptr<OctetString> decrypt(const shared_ptr<OctetString>& cip);
};
// Abstract random prime generator.
class RPG {
public:
  // Virtual so generators owned through an RPG pointer are destroyed completely.
  virtual ~RPG() {}
  // Returns a prime with exactly len bits.
  virtual BigInt generatePrime(const unsigned int& len) = 0;
protected:
  // True when trial division by a small odd prime divides num.
  static bool findSmallPrimeFactor(const BigInt& num);
};
// Prime generator: random odd candidate with the top bit set, then an upward
// search through odd numbers using trial division plus a primality test.
class SimpleRPG : public RPG {
public:
SimpleRPG();
virtual BigInt generatePrime(const unsigned int& len);
protected:
// Source of candidate bits.
shared_ptr<RBG> rbg;
// Primality test applied to candidates that survive trial division.
shared_ptr<PT> pt;
SimpleRPG(const SimpleRPG& simple_rpg);
SimpleRPG& operator =(const SimpleRPG& simple_rpg);
};
// Abstract random bit generator.
class RBG {
public:
  // Virtual so generators owned through an RBG pointer are destroyed completely.
  virtual ~RBG() {}
  // Returns the next random bit.
  virtual bool generateBit() = 0;
};
// Random bit generator that hashes time/clock/rand() with SHA-1 and hands
// out the digest one bit at a time.
class SimpleRBG : public RBG {
public:
SimpleRBG();
virtual bool generateBit();
protected:
// Current digest whose bits are being handed out.
shared_ptr<BitField> hash;
// Index of the next bit of hash; 160 or more forces a fresh digest.
unsigned int pos;
SimpleRBG(const SimpleRBG& simple_rbg);
SimpleRBG& operator =(const SimpleRBG& simple_rbg);
};
// Growable sieve of Eratosthenes, shared through a single static instance.
class Sieve {
public:
// Smallest prime strictly greater than prev_num (grows the table as needed).
unsigned int getNextPrime(const unsigned int& prev_num);
static Sieve instance;
protected:
// tab->at(i) is true when i is prime within the current table size.
shared_ptr<PrimitiveExArray<bool> > tab;
Sieve();
};
// Abstract primality test.
class PT {
public:
  // Virtual so tests owned through a PT pointer are destroyed completely.
  virtual ~PT() {}
  // True when p is (probably) prime.
  virtual bool isPrime(const BigInt& p) = 0;
};
// Miller-Rabin probable-prime test with evenly spaced (deterministic) witnesses.
class MRBPT : public PT {
public:
MRBPT();
virtual bool isPrime(const BigInt& p);
protected:
MRBPT(const MRBPT& mrbpt);
MRBPT& operator =(const MRBPT& mrbpt);
};
// Abstract hash algorithm over a bit stream.
class HA {
public:
  // Virtual so algorithms owned through an HA pointer are destroyed completely.
  virtual ~HA() {}
  // Consumes msg until EOF and returns the digest bits.
  virtual shared_ptr<BitField> computeHash(const shared_ptr<BitInputStream>& msg) = 0;
};
// SHA-1 over a bit stream.
// NOTE(review): SHA-1 works on 32-bit words but these tables/helpers use
// unsigned long, which is 64 bits on LP64 platforms — callers must keep only
// the low 32 bits.
class SHA1 : public HA {
public:
SHA1();
virtual shared_ptr<BitField> computeHash(const shared_ptr<BitInputStream>& msg);
protected:
// Signature of the per-round logical functions.
typedef unsigned long (*FUNCTION)(const unsigned long& x, const unsigned long& y, const unsigned long& z);
SHA1(const SHA1& sha1);
SHA1& operator =(const SHA1& sha1);
// Round function and constant for rounds 20*i .. 20*i+19.
static const FUNCTION FUNCTIONS[4];
static const unsigned long K[4];
// H0..H4 at the start of hashing.
static const unsigned long INITIAL_HASH_VALUES[5];
static unsigned long ch(const unsigned long& x, const unsigned long& y, const unsigned long& z);
static unsigned long parity(const unsigned long& x, const unsigned long& y, const unsigned long& z);
static unsigned long maj(const unsigned long& x, const unsigned long& y, const unsigned long& z);
};
// Number-theory helpers on BigInt. Static only; the protected constructor
// prevents instantiation.
class BigMath {
public:
// base^exp.
static BigInt power(const BigInt& base, const BigInt& exp);
// base^exp mod nrm (square-and-multiply with reduction at each step).
static BigInt power(const BigInt& base, const BigInt& exp, const BigInt& nrm);
static BigInt gcd(const BigInt& a, const BigInt& b);
static BigInt lcm(const BigInt& a, const BigInt& b);
// Multiplicative inverse of num modulo nrm.
static BigInt inverse(const BigInt& num, const BigInt& nrm);
protected:
BigMath();
};
// Variable-length bit string. Positions count from the least significant end
// ("LS"); "MS" is the most significant end.
// NOTE(review): the meaning of buf/buf_sz/beg/len is inferred from names only —
// the member definitions are not in view.
class BitField {
public:
// Reads the field as a sequence of unsigned long words. BigInt's arithmetic
// treats each word as a 32-bit limb, least significant first — confirm
// against the definition.
class LongInputStream {
public:
LongInputStream(const BitField* fld);
bool isEOF() const;
unsigned long getLong();
protected:
const BitField* fld;
unsigned int end;
unsigned int cur;
unsigned int low_len;
unsigned int high_len;
unsigned long low_mask;
unsigned long high_mask;
unsigned long rem;
unsigned long num;
};
// Appends unsigned long words to a field. BigInt's callers call flush()
// before using the result (presumably it writes a buffered partial word).
class LongOutputStream {
public:
LongOutputStream(BitField* fld);
virtual ~LongOutputStream();
void putLong(const unsigned long& num);
void flush();
protected:
BitField* fld;
unsigned int cur;
unsigned int low_len;
unsigned int high_len;
unsigned long low_mask;
unsigned long high_mask;
unsigned long high_buf;
bool bufed;
};
// Word streams that presumably walk from the most significant end;
// BigInt::divide writes its quotient through ReverseLongOutputStream.
class ReverseLongInputStream {
public:
ReverseLongInputStream(const BitField* fld);
bool isEOF() const;
unsigned long getLong();
protected:
const BitField* fld;
unsigned int end;
unsigned int cur;
unsigned int low_len;
unsigned int high_len;
unsigned long low_mask;
unsigned long high_mask;
unsigned long rem;
unsigned long num;
};
class ReverseLongOutputStream {
public:
ReverseLongOutputStream(BitField* fld);
virtual ~ReverseLongOutputStream();
void putLong(const unsigned long& num);
void flush();
protected:
BitField* fld;
unsigned int low_len;
unsigned int high_len;
unsigned long low_mask;
unsigned long high_mask;
unsigned long low_buf;
bool bufed;
};
BitField();
BitField(const unsigned long& num);
BitField(const BitField& fld);
virtual ~BitField();
BitField& operator =(const BitField& fld);
// Length in bits.
unsigned int getLength() const;
bool isEmpty() const;
// Value predicates.
bool isZero() const;
bool isOne() const;
bool isPowerOf2() const;
bool getBit(const unsigned int& pos) const;
bool getLSBit() const;
bool getMSBit() const;
// ls_pos is a bit offset from the LS end.
unsigned char getByte(const unsigned int& ls_pos) const;
unsigned char getLSByte() const;
unsigned long getLong(const unsigned int& ls_pos) const;
unsigned long getLSLong() const;
void setBit(const unsigned int& pos, const bool& bit);
// Push one bit, or all bits of x, onto the MS or LS end.
void pushMS(const bool& bit);
template <class X> void pushMS(const X& x);
void pushLS(const bool& bit);
template <class X> void pushLS(const X& x);
void setLong(const unsigned long& num);
// Negative counts shift toward the MS end (multiply by 2^-num), positive
// counts toward the LS end — see the uses in BigInt::multiply/divide.
void shift(const int& num);
// Removes leading (MS) zero bits.
void trim();
void clear();
// Numeric comparison: <0, 0 or >0.
int compare(const BitField& fld) const;
string toString() const;
BitField getSubfield(const unsigned int& ls_pos, const unsigned int& len) const;
protected:
unsigned char* buf;
unsigned int buf_sz;
int beg;
unsigned int len;
void extendIfNecessary(const int& len);
friend LongInputStream;
friend LongOutputStream;
friend ReverseLongInputStream;
friend ReverseLongOutputStream;
};
// Abstract source of bytes.
class ByteInputStream {
public:
  // Virtual so streams owned through a ByteInputStream pointer are destroyed completely.
  virtual ~ByteInputStream() {}
  // Returns the next byte; call only while !isEOF().
  virtual unsigned char getByte() = 0;
  virtual bool isEOF() = 0;
};
// Abstract sink of bytes.
class ByteOutputStream {
public:
  // Virtual so streams owned through a ByteOutputStream pointer are destroyed completely.
  virtual ~ByteOutputStream() {}
  virtual void putByte(const unsigned char& byte) = 0;
  // Called by RSA once all bytes of a message have been written.
  virtual void flush() = 0;
};
// ByteInputStream that reads from an OctetString.
class OctetStringByteInputStream : public ByteInputStream {
public:
OctetStringByteInputStream(const shared_ptr<OctetString>& os);
virtual unsigned char getByte();
virtual bool isEOF();
protected:
shared_ptr<OctetString> os;
// Presumably the index of the next byte to read — TODO confirm.
unsigned int i;
OctetStringByteInputStream(const OctetStringByteInputStream& os_in);
OctetStringByteInputStream& operator =(const OctetStringByteInputStream& os_in);
};
// ByteOutputStream that collects bytes into an OctetString (see getResult).
class OctetStringByteOutputStream : public ByteOutputStream {
public:
OctetStringByteOutputStream();
virtual void putByte(const unsigned char& byte);
virtual void flush();
shared_ptr<OctetString> getResult();
protected:
shared_ptr<OctetString> os;
OctetStringByteOutputStream(const OctetStringByteOutputStream& os_out);
OctetStringByteOutputStream& operator =(const OctetStringByteOutputStream& os_out);
};
// ByteInputStream over the characters of a std::string (held by value).
class StringByteInputStream : public ByteInputStream {
public:
StringByteInputStream(const string& str);
virtual unsigned char getByte();
virtual bool isEOF();
protected:
string str;
// Presumably the index of the next character to read — TODO confirm.
size_t pos;
StringByteInputStream(const StringByteInputStream& str_in);
StringByteInputStream& operator =(const StringByteInputStream& str_in);
};
// ByteOutputStream that collects bytes into a std::string (see getResult).
class StringByteOutputStream : public ByteOutputStream {
public:
StringByteOutputStream();
virtual void putByte(const unsigned char& byte);
virtual void flush();
string getResult();
protected:
string str;
StringByteOutputStream(const StringByteOutputStream& str_out);
StringByteOutputStream& operator =(const StringByteOutputStream& str_out);
};
// Abstract source of bits.
class BitInputStream {
public:
  // Virtual so streams owned through a BitInputStream pointer are destroyed completely.
  virtual ~BitInputStream() {}
  // Returns the next bit; call only while !isEOF().
  virtual bool getBit() = 0;
  virtual bool isEOF() = 0;
};
// Concatenates several BitInputStreams into one. SimpleRBG uses it to feed
// time, clock and rand() into SHA-1.
class BitInputStreamConnector : public BitInputStream {
public:
BitInputStreamConnector(const shared_ptr<vector<shared_ptr<BitInputStream> > >& ins);
virtual bool getBit();
virtual bool isEOF();
protected:
shared_ptr<vector<shared_ptr<BitInputStream> > > ins;
// Presumably the index of the stream currently being read — TODO confirm.
unsigned int i;
};
// BitInputStream over the bits of a single primitive value x (held by value).
template <class X>
class PrimitiveBitInputStream : public BitInputStream {
public:
PrimitiveBitInputStream(const X& x);
virtual bool getBit();
virtual bool isEOF();
protected:
X x;
// Presumably the index of the next bit of x — TODO confirm bit order.
unsigned int pos;
PrimitiveBitInputStream(const PrimitiveBitInputStream& pri_in);
PrimitiveBitInputStream& operator =(const PrimitiveBitInputStream& pri_in);
};
// Rotates x left by n bit positions. w is presumably the rotation width in
// bits, defaulting to the full width of X (definition not in view — confirm).
template <class X> X rotateLeft(const X& x, const unsigned int& n, const unsigned int& w = (sizeof(X) << 3));
////////////////////////////////////////////////////////////////////////////////
// Binds the cipher to a key; key_len is the modulus size in whole bytes and
// a SimpleRBG supplies the padding randomness.
RSA::RSA(const Key& key)
  : key(key),
    key_len((key.n.getLength() + 7) / 8),
    rbg(make_shared<SimpleRBG>()) {}
void RSA::encrypt(const shared_ptr<ByteInputStream>& msg_in, const shared_ptr<ByteOutputStream>& cip_out) {
unsigned int max_msg_len = this->key_len - 3 - 8;
shared_ptr<OctetString> msg(new PrimitiveExArray<unsigned char>);
while (!msg_in->isEOF()) {
msg->clear();
while (!msg_in->isEOF() && msg->getSize() < max_msg_len)
msg->push(msg_in->getByte());
shared_ptr<OctetString> cip = encrypt(msg);
for (unsigned int i = 0; i < cip->getSize(); i++)
cip_out->putByte(cip->at(i));
}
cip_out->flush();
}
void RSA::decrypt(const shared_ptr<ByteInputStream>& cip_in, const shared_ptr<ByteOutputStream>& msg_out) {
unsigned int max_cip_len = this->key_len;
shared_ptr<OctetString> cip(new PrimitiveExArray<unsigned char>);
while (!cip_in->isEOF()) {
cip->clear();
while (!cip_in->isEOF() && cip->getSize() < max_cip_len)
cip->push(cip_in->getByte());
shared_ptr<OctetString> msg = decrypt(cip);
for (unsigned int i = 0; i < msg->getSize(); i++)
msg_out->putByte(msg->at(i));
}
msg_out->flush();
}
// Generates an RSA key pair with a len-bit modulus.
// Returns (public key {n, e}, private key {n, d}).
// Throws string if len < 96 or len is not a multiple of 8.
pair<RSA::Key, RSA::Key> RSA::generateKeys(const unsigned int& len) {
  if (len < 96) throw string("鍵の長さが96ビット未満。");
  if (len % 8 != 0) throw string("鍵の長さが8の倍数ではない。");
  shared_ptr<RPG> rpg(new SimpleRPG());
  BigInt p = rpg->generatePrime(len / 2);
  BigInt q = rpg->generatePrime(len / 2);
  // p == q would make n a perfect square and trivially factorable.
  while (q == p) q = rpg->generatePrime(len / 2);
  BigInt n = p * q;
  BigInt lambda = BigMath::lcm(p - 1, q - 1);
  // A prime e still has to be coprime to lambda (it may divide p-1 or q-1);
  // otherwise no private exponent exists.
  BigInt e = rpg->generatePrime(len / 4);
  while (BigMath::gcd(e, lambda) != 1) e = rpg->generatePrime(len / 4);
  BigInt d = BigMath::inverse(e, lambda);
  return make_pair(makeKey(n, e), makeKey(n, d));
}
// Packs a modulus and an exponent into a Key (an aggregate: n first, e second).
RSA::Key RSA::makeKey(const BigInt& n, const BigInt& e) {
  Key built = {n, e};
  return built;
}
// Encrypts one chunk: encodes 0x00 || 0x02 || PS || 0x00 || msg (most
// significant byte first), raises it to key.e mod key.n, and returns the
// result as exactly key_len bytes, most significant first.
shared_ptr<OctetString> RSA::encrypt(const shared_ptr<OctetString>& msg) {
  BitField block;
  block.pushLS<unsigned char>(0x00);
  block.pushLS<unsigned char>(0x02);
  const unsigned int msg_len = msg->getSize();
  // Each padding byte has its top bit forced to 1, so it can never be 0x00.
  for (unsigned int left = this->key_len - 3 - msg_len; left > 0; --left) {
    block.pushLS(true);
    for (int bit = 0; bit < 7; ++bit)
      block.pushLS(this->rbg->generateBit());
  }
  block.pushLS<unsigned char>(0x00);
  for (unsigned int idx = 0; idx < msg_len; ++idx)
    block.pushLS<unsigned char>(msg->at(idx));
  const BigInt sealed = BigMath::power(BigInt(block), this->key.e, this->key.n);
  shared_ptr<OctetString> cip(new PrimitiveExArray<unsigned char>);
  for (unsigned int byte_no = this->key_len; byte_no > 0; --byte_no)
    cip->push(sealed.getByte((byte_no - 1) << 3));
  return cip;
}
// Decrypts one block: raises the ciphertext to key.e mod key.n (for the
// private key e holds d), checks the 0x00 0x02 PS 0x00 framing and returns
// the bytes after the separator.
// Throws string("復号化できない。") when the framing is invalid: wrong header,
// fewer than 8 padding bytes, or no 0x00 separator at all.
shared_ptr<OctetString> RSA::decrypt(const shared_ptr<OctetString>& cip) {
  BitField fld;
  for (unsigned int i = 0; i < cip->getSize(); i++)
    fld.pushLS<unsigned char>(cip->at(i));
  BigInt cip_num(fld);
  BigInt msg_num = BigMath::power(cip_num, this->key.e, this->key.n);
  shared_ptr<OctetString> msg(new PrimitiveExArray<unsigned char>);
  bool pad_end = false;
  unsigned int pad_len = 0;
  for (unsigned int i = 0; i < this->key_len; i++) {
    unsigned char byte = msg_num.getByte((this->key_len - 1 - i) << 3);
    if (i < 2) {
      if ((i == 0 && byte != 0x00) || (i == 1 && byte != 0x02))
        throw string("復号化できない。");
    }
    else if (!pad_end) {
      if (byte == 0x00) {
        // PKCS#1 v1.5 requires at least 8 padding bytes; encrypt() always writes >= 8.
        if (pad_len < 8) throw string("復号化できない。");
        pad_end = true;
      }
      else pad_len++;
    }
    else msg->push(byte);
  }
  // Without a separator the block is corrupt; don't report it as an empty message.
  if (!pad_end) throw string("復号化できない。");
  return msg;
}
// Trial division of num by the odd primes below (bitlen^2 / 256).
// 2 is never tried. Returns true as soon as one of them divides num.
bool RPG::findSmallPrimeFactor(const BigInt& num) {
  const unsigned int bits = num.getLength();
  const unsigned int limit = (bits * bits) >> 8;
  for (unsigned int prime = 3; prime < limit; prime = Sieve::instance.getNextPrime(prime))
    if ((num % prime).isZero()) return true;
  return false;
}
// Uses a SimpleRBG for candidate bits and Miller-Rabin (MRBPT) as the final test.
SimpleRPG::SimpleRPG()
  : rbg(make_shared<SimpleRBG>()),
    pt(make_shared<MRBPT>()) {}
// Returns a prime of exactly len bits. Builds a random candidate with the top
// bit set (and, for len > 2, the low bit set), then walks upward in steps of
// 2 while the bit length stays len. If no prime is found, it starts over with
// a new candidate. Throws string if len < 2.
BigInt SimpleRPG::generatePrime(const unsigned int& len) {
  if (len < 2) throw string("生成する素数の長さが2ビット未満。");
  BigInt num;
  for (;;) {
    // Bits are pushed from the least significant end upward.
    BitField bits;
    if (len != 2) {
      bits.pushMS(true);
      for (unsigned int left = len - 2; left > 0; --left)
        bits.pushMS(this->rbg->generateBit());
    }
    else bits.pushMS(this->rbg->generateBit());
    bits.pushMS(true);
    for (num = BigInt(bits); num.getLength() == len; num += 2)
      if (!findSmallPrimeFactor(num) && this->pt->isPrime(num)) return num;
  }
}
// Seeds rand() from both halves of the wall-clock time and the CPU clock.
// pos starts at 160 so the first generateBit() computes a fresh digest.
SimpleRBG::SimpleRBG() : pos(160) {
  const unsigned long long now = time(nullptr);
  const unsigned long long ticks = clock();
  const unsigned long long folded =
      (now >> 32) ^ (now & 0xffffffff) ^ (ticks >> 32) ^ (ticks & 0xffffffff);
  srand(folded);
}
bool SimpleRBG::generateBit() {
if (this->pos >= 160) {
shared_ptr<HA> ha(new SHA1);
shared_ptr<vector<shared_ptr<BitInputStream> > > ins(new vector<shared_ptr<BitInputStream> >);
ins->push_back(shared_ptr<BitInputStream>(new PrimitiveBitInputStream<unsigned long long>(time(nullptr))));
ins->push_back(shared_ptr<BitInputStream>(new PrimitiveBitInputStream<unsigned long long>(clock())));
ins->push_back(shared_ptr<BitInputStream>(new PrimitiveBitInputStream<int>(rand())));
this->hash = ha->computeHash(shared_ptr<BitInputStream>(new BitInputStreamConnector(ins)));
this->pos = 0;
}
bool res = this->hash->getBit(this->pos);
this->pos++;
return res;
}
// Returns the smallest prime greater than prev_num. When the table is too
// small, it doubles the table (new slots start as candidates) and re-sieves
// the whole table before searching again.
unsigned int Sieve::getNextPrime(const unsigned int& prev_num) {
  for (;;) {
    for (unsigned int cand = prev_num + 1; cand < this->tab->getSize(); ++cand)
      if (this->tab->at(cand)) return cand;
    unsigned int grown = this->tab->getSize() << 1;
    if (grown == 0) grown = 1;
    while (this->tab->getSize() < grown) this->tab->push(true);
    for (unsigned int base = 0; base < this->tab->getSize(); ++base) {
      if (!this->tab->at(base)) continue;
      for (unsigned int mult = base + base; mult < this->tab->getSize(); mult += base)
        this->tab->at(mult) = false;
    }
  }
}
Sieve Sieve::instance;
// Starts with the two non-primes 0 and 1 marked false.
Sieve::Sieve() : tab(make_shared<PrimitiveExArray<bool> >()) {
  for (int n = 0; n < 2; ++n) this->tab->push(false);
}
// Stateless test: nothing to initialize.
MRBPT::MRBPT() = default;
bool MRBPT::isPrime(const BigInt& p) {
bool res = true;
if (p == 1) res = false;
else if (p == 2) res = true;
else if (p.getLSBit() == false) res = false;
else {
BigInt o = p - 1;
BigInt q = o;
BigInt k(1);
while (q.getLSBit() == false) {
q >>= 1;
k++;
}
BigInt cnt(3072 / p.getLength());
if (cnt == 0) cnt = 1;
BigInt a(2);
BigInt gap(p / cnt);
if (gap.isZero()) gap = 1;
for (; a < p; a += gap) {
BigInt b = BigMath::power(a, q, p);
if (b != 1 && b != o) {
BigInt i(1);
for (; i < k; i++) {
b = (b * b) % p;
if (b == o) break;
}
if (i == k) {
res = false;
break;
}
}
}
}
return res;
}
// Logical function for rounds 0-19, 20-39, 40-59 and 60-79 (FIPS 180-4 4.1.1).
const SHA1::FUNCTION SHA1::FUNCTIONS[4] = {ch, parity, maj, parity};
// Round constants for the same four round groups (FIPS 180-4 4.2.1).
const unsigned long SHA1::K[4] = {0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6};
// Initial hash value H(0) (FIPS 180-4 5.3.1).
const unsigned long SHA1::INITIAL_HASH_VALUES[5] = {0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0};
// Stateless hasher: nothing to initialize.
SHA1::SHA1() = default;
// Computes the SHA-1 digest of the bit stream msg (consumed until EOF).
// The result holds the five 32-bit words H0..H4 pushed in that order at the
// MS end, so H0 occupies the least significant 32 bits.
shared_ptr<BitField> SHA1::computeHash(const shared_ptr<BitInputStream>& msg) {
  // SHA-1 is defined on 32-bit words with arithmetic modulo 2^32. The state
  // is held in uint32_t because `unsigned long` is 64 bits on LP64 platforms.
  // There, it would break the wrap-around and the rotations, sign-extend the
  // int expression 1 << 31, and emit 64-bit words from pushMS.
  const auto rotl = [](uint32_t x, unsigned int n) -> uint32_t {
    // n is always 1, 5 or 30 here, so neither shift count reaches 32.
    return static_cast<uint32_t>((x << n) | (x >> (32 - n)));
  };
  array<uint32_t, 5> hash;
  for (unsigned int i = 0; i < 5; i++) hash[i] = static_cast<uint32_t>(INITIAL_HASH_VALUES[i]);
  array<uint32_t, 16> block;
  array<uint32_t, 80> sche;
  unsigned long long len = 0;  // message length in bits
  bool put_end = false;        // the single '1' padding bit has been written
  bool put_len = false;        // the 64-bit length has been written (last block)
  while (!put_len) {
    block.fill(0);
    unsigned int block_pos = 0;
    // Pack message bits big-endian: the first bit goes into bit 31 of word 0.
    for (; block_pos < 512 && !msg->isEOF(); block_pos++) {
      if (msg->getBit()) block[block_pos >> 5] |= uint32_t(1) << (31 - (block_pos & 0x1f));
      len++;
    }
    if (block_pos < 512) {
      // Message exhausted: append the '1' bit once, then the 64-bit length as
      // soon as a block has room for it (otherwise one more block follows).
      if (!put_end) {
        block[block_pos >> 5] |= uint32_t(1) << (31 - (block_pos & 0x1f));
        block_pos++;
        put_end = true;
      }
      if (block_pos + 64 <= 512) {
        block[14] = static_cast<uint32_t>(len >> 32);
        block[15] = static_cast<uint32_t>(len & 0xffffffffULL);
        put_len = true;
      }
    }
    // Message schedule W[0..79].
    for (unsigned int i = 0; i < 80; i++) {
      if (i < 16) sche[i] = block[i];
      else sche[i] = rotl(sche[i - 3] ^ sche[i - 8] ^ sche[i - 14] ^ sche[i - 16], 1);
    }
    uint32_t a = hash[0];
    uint32_t b = hash[1];
    uint32_t c = hash[2];
    uint32_t d = hash[3];
    uint32_t e = hash[4];
    for (unsigned int i = 0; i < 80; i++) {
      // FUNCTIONS and K use unsigned long; only their low 32 bits are meaningful.
      const uint32_t f = static_cast<uint32_t>(FUNCTIONS[i / 20](b, c, d));
      const uint32_t k = static_cast<uint32_t>(K[i / 20]);
      const uint32_t t = static_cast<uint32_t>(rotl(a, 5) + f + e + k + sche[i]);
      e = d;
      d = c;
      c = rotl(b, 30);
      b = a;
      a = t;
    }
    hash[0] = static_cast<uint32_t>(hash[0] + a);
    hash[1] = static_cast<uint32_t>(hash[1] + b);
    hash[2] = static_cast<uint32_t>(hash[2] + c);
    hash[3] = static_cast<uint32_t>(hash[3] + d);
    hash[4] = static_cast<uint32_t>(hash[4] + e);
  }
  shared_ptr<BitField> res(new BitField);
  for (unsigned int i = 0; i < 5; i++) res->pushMS(hash[i]);
  return res;
}
// Ch: bitwise select — where x has a 1 take y's bit, otherwise z's bit.
unsigned long SHA1::ch(const unsigned long& x, const unsigned long& y, const unsigned long& z) {
  return z ^ (x & (y ^ z));
}
// Parity: bitwise XOR of the three words.
unsigned long SHA1::parity(const unsigned long& x, const unsigned long& y, const unsigned long& z) {
  return z ^ (y ^ x);
}
// Maj: bitwise majority vote of the three words.
unsigned long SHA1::maj(const unsigned long& x, const unsigned long& y, const unsigned long& z) {
  return (x & y) | (z & (x | y));
}
// base^exp by right-to-left binary exponentiation.
// A zero base yields 0 (even for exp == 0). Negative exponents yield
// 1 / base^|exp| with truncating integer division.
BigInt BigMath::power(const BigInt& base, const BigInt& exp) {
  if (base.isZero()) return BigInt(0);
  if (exp.isZero()) return BigInt(1);
  if (exp < 0) return BigInt(1) / power(base, -exp);
  BigInt acc(1);
  BigInt square = base;
  BigInt bits = exp;
  for (;;) {
    if (bits.getLSBit()) acc *= square;
    bits >>= 1;
    if (bits.isZero()) return acc;
    square *= square;
  }
}
// base^exp mod nrm by square-and-multiply, reducing after every product.
// A zero base yields 0 and exp == 0 yields 1. A negative exponent falls back
// to (1 / base^|exp|) % nrm with truncating division (not a modular inverse).
BigInt BigMath::power(const BigInt& base, const BigInt& exp, const BigInt& nrm) {
  if (base.isZero()) return BigInt(0);
  if (exp.isZero()) return BigInt(1);
  if (exp < 0) return (BigInt(1) / power(base, -exp)) % nrm;
  BigInt acc(1);
  BigInt square = base % nrm;
  BigInt bits = exp;
  while (true) {
    if (bits.getLSBit()) acc = (acc * square) % nrm;
    bits >>= 1;
    if (bits.isZero()) break;
    square = (square * square) % nrm;
  }
  return acc;
}
// Greatest common divisor by Euclid's algorithm. The result is |gcd|, negated
// when a is negative (the convention lcm relies on). gcd(x, 0) == gcd(0, x)
// == |x| (with that sign rule), and gcd(0, 0) == 0.
BigInt BigMath::gcd(const BigInt& a, const BigInt& b) {
  BigInt x(a.abs());
  BigInt y(b.abs());
  // Loop while the divisor is nonzero, so zero arguments never reach a division.
  while (!y.isZero()) {
    BigInt r = x % y;
    x = y;
    y = r;
  }
  if (a.getSign() < 0) x = -x;
  return x;
}
// Least common multiple a * b / gcd(a, b). Dividing first gives the same value,
// because gcd divides a exactly, but keeps the intermediate product smaller.
BigInt BigMath::lcm(const BigInt& a, const BigInt& b) {
  return a / gcd(a, b) * b;
}
// Multiplicative inverse of num modulo nrm via the extended Euclidean
// algorithm. For positive nrm, the result is in [0, nrm). For negative nrm,
// it is shifted into (nrm, 0].
// Throws string("逆元が存在しない。") if gcd(num, nrm) != 1. A zero nrm throws
// from the division.
BigInt BigMath::inverse(const BigInt& num, const BigInt& nrm) {
  const BigInt m = nrm.abs();
  // Reduce num into [0, m); this also handles a negative num.
  BigInt r1 = num % m;
  if (r1.getSign() < 0) r1 += m;
  // Invariant: r_i == s_i * num (mod m), starting from r0 = m, r1 = num mod m.
  BigInt r0 = m;
  BigInt s0(0);
  BigInt s1(1);
  while (!r1.isZero()) {
    const pair<BigInt, BigInt> q_r = r0.divide(r1);
    const BigInt s2 = s0 - s1 * q_r.first;
    r0 = r1;
    r1 = q_r.second;
    s0 = s1;
    s1 = s2;
  }
  // r0 is now gcd(num, m) and s0 its coefficient.
  if (r0 != 1) throw string("逆元が存在しない。");
  BigInt res = s0 % m;
  if (res.getSign() < 0) res += m;
  if (nrm.getSign() < 0 && !res.isZero()) res -= m;
  return res;
}
// Pooled handle to one column of BigInt::multiply's partial-product table:
// the 32-bit words (and carries) that belong to one limb position.
class MColPtr : public Pool<PrimitiveExArray<unsigned long> >::Ptr {
public:
MColPtr();
protected:
static Pool<PrimitiveExArray<unsigned long> > pool;
};
// Pooled handle to BigInt::multiply's partial-product table, indexed by limb
// position (one MColPtr per column).
class MTabPtr : public Pool<ClassObjectExArray<MColPtr> >::Ptr {
public:
MTabPtr();
protected:
static Pool<ClassObjectExArray<MColPtr> > pool;
};
// Zero: sign 0 and a freshly cleared pooled field (see BFPtr()).
BigInt::BigInt() : sign(0), fld() {}
// Converts a machine int: sign is -1/0/+1 and the magnitude is |num|.
BigInt::BigInt(const int& num) {
  this->sign = num == 0 ? 0 : (num > 0 ? 1 : -1);
  // Negate in unsigned arithmetic: -num overflows (undefined behavior) for INT_MIN.
  const unsigned long mag = num < 0 ? 0UL - static_cast<unsigned long>(num)
                                    : static_cast<unsigned long>(num);
  this->fld->setLong(mag);
}
// Parses an optionally signed ('+'/'-') digit string in the given base.
// A string with no digits (empty, or a lone sign) parses as zero.
// Throws string("値が範囲を逸脱している。") for a digit >= base.
BigInt::BigInt(const string& str, const unsigned int& base) {
  this->sign = 1;
  this->fld->setLong(0);
  BFPtr base_fld(base);
  auto iter = str.begin();
  // Check end() first: dereferencing begin() of an empty string is undefined.
  if (iter != str.end() && *iter == '+') {
    this->sign = 1;
    iter++;
  }
  else if (iter != str.end() && *iter == '-') {
    this->sign = -1;
    iter++;
  }
  for (; iter != str.end(); iter++) {
    // Horner's scheme: value = value * base + digit.
    if (iter != str.begin()) {
      this->fld = multiply(this->fld, base_fld);
    }
    unsigned int num = digitToNumber(*iter);
    if (num >= base) throw string("値が範囲を逸脱している。");
    this->fld = add(this->fld, BFPtr(num));
  }
  if (this->fld->isZero()) this->sign = 0;
}
// Builds a value from a raw magnitude: the bits are copied into this object's
// own pooled field and leading zeros are trimmed. A zero magnitude forces sign 0.
BigInt::BigInt(const BitField& fld, const int& sign) {
  BitField& own = *this->fld;
  own = fld;
  own.trim();
  if (fld.isZero()) this->sign = 0;
  else this->sign = sign;
}
// Copy: same sign. The bits are copied into this object's own pooled field
// through BFPtr(const BitField&).
BigInt::BigInt(const BigInt& num)
  : sign(num.sign),
    fld(*num.fld) {}
// Negation: a copy of the magnitude with the sign flipped.
BigInt BigInt::operator -() const {
  const BFPtr magnitude(*this->fld);
  return BigInt(magnitude, -this->sign);
}
// Signed addition. Equal signs add magnitudes. Opposite signs subtract the
// smaller magnitude from the larger, and the result takes the larger
// operand's sign.
BigInt BigInt::operator +(const BigInt& ade) const {
  if (this->sign == 0) return ade;
  if (ade.sign == 0) return *this;
  if (this->sign == ade.sign) return BigInt(add(this->fld, ade.fld), this->sign);
  const int cmp_res = this->fld->compare(*ade.fld);
  if (cmp_res == 0) return BigInt(0);
  const bool this_larger = cmp_res > 0;
  const BigInt& larger = this_larger ? *this : ade;
  const BigInt& smaller = this_larger ? ade : *this;
  return BigInt(subtract(larger.fld, smaller.fld), larger.sign);
}
// a - b is computed as a + (-b).
BigInt BigInt::operator -(const BigInt& sub) const {
  const BigInt negated = -sub;
  return *this + negated;
}
// Product of magnitudes; the sign is the product of the signs.
BigInt BigInt::operator *(const BigInt& mer) const {
  const int res_sign = this->sign * mer.sign;
  return BigInt(multiply(this->fld, mer.fld), res_sign);
}
// Quotient of magnitudes (truncation toward zero); sign is the product of the signs.
BigInt BigInt::operator /(const BigInt& dsr) const {
  const pair<BFPtr, BFPtr> quo_rem = divide(this->fld, dsr.fld);
  return BigInt(quo_rem.first, this->sign * dsr.sign);
}
// Remainder of magnitudes carrying the dividend's sign.
BigInt BigInt::operator %(const BigInt& nrm) const {
  const pair<BFPtr, BFPtr> quo_rem = divide(this->fld, nrm.fld);
  return BigInt(quo_rem.second, this->sign);
}
// Quotient and remainder from one division, using the same sign rules as / and %.
pair<BigInt, BigInt> BigInt::divide(const BigInt& dsr) const {
  const pair<BFPtr, BFPtr> quo_rem = divide(this->fld, dsr.fld);
  const BigInt quotient(quo_rem.first, this->sign * dsr.sign);
  const BigInt remainder(quo_rem.second, this->sign);
  return pair<BigInt, BigInt>(quotient, remainder);
}
// Shifts the magnitude left (multiply by 2^num) or right (divide by 2^num,
// truncating); the sign is kept. The parameter type matches the in-class
// declaration (const size_t&). With unsigned int, these definitions declare
// no existing member on platforms where size_t is wider.
BigInt BigInt::operator <<(const size_t& num) const {
  BFPtr fld(*this->fld);
  // BitField::shift takes negative counts as shifts toward the MS end. Convert
  // to int before negating so the count is not wrapped as an unsigned value.
  fld->shift(-static_cast<int>(num));
  return BigInt(fld, this->sign);
}
BigInt BigInt::operator >>(const size_t& num) const {
  BFPtr fld(*this->fld);
  fld->shift(static_cast<int>(num));
  return BigInt(fld, this->sign);
}
// Copy assignment: copies the bits into the field this object already holds.
// Self-assignment is a no-op.
BigInt& BigInt::operator =(const BigInt& num) {
  if (this == &num) return *this;
  sign = num.sign;
  *fld = *num.fld;
  return *this;
}
// Compound assignments: compute with the binary operator, then assign.
BigInt& BigInt::operator +=(const BigInt& ade) {
  *this = *this + ade;
  return *this;
}
BigInt& BigInt::operator -=(const BigInt& sub) {
  *this = *this - sub;
  return *this;
}
BigInt& BigInt::operator *=(const BigInt& mer) {
  *this = *this * mer;
  return *this;
}
BigInt& BigInt::operator /=(const BigInt& dsr) {
  *this = *this / dsr;
  return *this;
}
BigInt& BigInt::operator %=(const BigInt& nrm) {
  *this = *this % nrm;
  return *this;
}
// In-place shifts. The parameter type matches the in-class declaration
// (const unsigned int&). With size_t, these definitions declare no existing
// member on platforms where size_t is wider than unsigned int.
BigInt& BigInt::operator <<=(const unsigned int& num) {
  return *this = *this << num;
}
BigInt& BigInt::operator >>=(const unsigned int& num) {
  return *this = *this >> num;
}
// Prefix forms return the updated object. Postfix forms return the previous value.
BigInt& BigInt::operator ++() {
  *this += BigInt(1);
  return *this;
}
BigInt BigInt::operator ++(int) {
  const BigInt before(*this);
  ++*this;
  return before;
}
BigInt& BigInt::operator --() {
  *this -= BigInt(1);
  return *this;
}
BigInt BigInt::operator --(int) {
  const BigInt before(*this);
  --*this;
  return before;
}
// Equal when both the sign and the magnitude match.
bool BigInt::operator ==(const BigInt& num) const {
  if (this == &num) return true;
  if (this->sign != num.sign) return false;
  return this->fld->compare(*num.fld) == 0;
}
// Ordering by sign first, then by magnitude (reversed for two negatives).
bool BigInt::operator <(const BigInt& num) const {
  if (this == &num) return false;
  const int cmp_res = this->fld->compare(*num.fld);
  if (this->sign < 0) return num.sign >= 0 || cmp_res > 0;
  return num.sign >= 0 && cmp_res < 0;
}
bool BigInt::operator >(const BigInt& num) const {
  if (this == &num) return false;
  const int cmp_res = this->fld->compare(*num.fld);
  if (this->sign >= 0) return num.sign < 0 || cmp_res > 0;
  return num.sign < 0 && cmp_res < 0;
}
// The remaining relations are defined as negations of the ones above.
bool BigInt::operator !=(const BigInt& num) const {
  return !(*this == num);
}
bool BigInt::operator <=(const BigInt& num) const {
  return !(*this > num);
}
bool BigInt::operator >=(const BigInt& num) const {
  return !(*this < num);
}
BigInt BigInt::abs() const {
return BigInt(BFPtr(*this->fld), 1);
}
bool BigInt::isZero() const {
return this->fld->isZero();
}
int BigInt::getSign() const {
return this->sign;
}
bool BigInt::getLSBit() const {
return this->fld->getLSBit();
}
unsigned char BigInt::getByte(const unsigned int ls_pos) const {
return this->fld->getByte(ls_pos);
}
unsigned char BigInt::getLSByte() const {
return this->fld->getLSByte();
}
unsigned long BigInt::getLSLong() const {
return this->fld->getLSLong();
}
// Replaces the value with a machine int (sign -1/0/+1, magnitude |num|).
void BigInt::setInt(const int& num) {
  this->sign = num == 0 ? 0 : (num > 0 ? 1 : -1);
  // Negate in unsigned arithmetic: -num overflows (undefined behavior) for INT_MIN.
  const unsigned long mag = num < 0 ? 0UL - static_cast<unsigned long>(num)
                                    : static_cast<unsigned long>(num);
  this->fld->setLong(mag);
}
// Formats the value in the given base, with a leading '-' for negatives.
// Throws string("値が範囲を逸脱している。") for base < 2: base 0 would divide
// by zero, and base 1 never reduces the quotient, so the loop would not end.
string BigInt::toString(const unsigned int& base) const {
  if (base < 2) throw string("値が範囲を逸脱している。");
  string res = "";
  // Repeated division by base; the remainders are the digits, least significant first.
  pair<BFPtr, BFPtr> quo_rem(BFPtr(*this->fld), BFPtr());
  BFPtr base_fld;
  base_fld->setLong(base);
  while (!quo_rem.first->isZero()) {
    quo_rem = divide(quo_rem.first, base_fld);
    res += numberToDigit(quo_rem.second->getLSLong());
  }
  if (res.empty()) res += '0';
  if (this->sign < 0) res += '-';
  reverse(res.begin(), res.end());
  // Strip leading zeros but always keep at least one character.
  size_t pos = 0;
  for (; pos < res.length() - 1; pos++) if (res[pos] != '0') break;
  return res.substr(pos, res.length() - pos);
}
unsigned int BigInt::getLength() const {
return this->fld->getLength();
}
// Every BFPtr takes its BitField from the pool shared by all BigInts, then
// sets the contents explicitly: cleared, a machine word, or a copy.
BigInt::BFPtr::BFPtr() : Pool<BitField>::Ptr(&BFPtr::pool) {
  BitField& bits = *this->x_cnt.first;
  bits.clear();
}
BigInt::BFPtr::BFPtr(const unsigned long& num) : Pool<BitField>::Ptr(&BFPtr::pool) {
  BitField& bits = *this->x_cnt.first;
  bits.setLong(num);
}
BigInt::BFPtr::BFPtr(const BitField& fld) : Pool<BitField>::Ptr(&BFPtr::pool) {
  BitField& bits = *this->x_cnt.first;
  bits = fld;
}
Pool<BitField> BigInt::BFPtr::pool;
// Adopts an existing pooled magnitude (handle assignment, presumably sharing
// rather than copying the bits — NOTE(review): confirm Pool::Ptr::operator=).
// Then trims leading zeros and derives the sign (0 for a zero magnitude).
BigInt::BigInt(const BFPtr& fld, const int& sign) {
  this->fld = fld;
  this->fld->trim();
  if (fld->isZero()) this->sign = 0;
  else this->sign = sign;
}
// Adds two magnitudes limb by limb (32-bit limbs, least significant first)
// and returns a trimmed, non-empty result.
BigInt::BFPtr BigInt::add(const BFPtr& age, const BFPtr& ade) {
  BFPtr sum;
  if (age->isZero()) *sum = *ade;
  else if (ade->isZero()) *sum = *age;
  else {
    BitField::LongOutputStream sum_out(sum.get());
    BitField::LongInputStream age_in(age.get());
    BitField::LongInputStream ade_in(ade.get());
    // Add in 64-bit arithmetic: the carry is bit 32 of the wide sum. Masking
    // keeps each output limb at 32 bits even where unsigned long is 64 bits
    // wide; there, an int expression like 1 << 31 would sign-extend.
    unsigned long long carry = 0;
    while (!age_in.isEOF() || !ade_in.isEOF()) {
      const unsigned long long age_num = !age_in.isEOF() ? age_in.getLong() : 0;
      const unsigned long long ade_num = !ade_in.isEOF() ? ade_in.getLong() : 0;
      const unsigned long long sum_num = age_num + ade_num + carry;
      carry = sum_num >> 32;
      sum_out.putLong(static_cast<unsigned long>(sum_num & 0xffffffffULL));
    }
    sum_out.flush();
    // A carry out of the top limb becomes one extra most significant bit.
    if (carry != 0) sum->pushMS(true);
  }
  if (sum->isEmpty()) sum->pushMS(false);
  else sum->trim();
  return sum;
}
// Subtracts magnitudes limb by limb (32-bit limbs, least significant first).
// Throws string("被減数が減数より小さい。") when min < sub.
BigInt::BFPtr BigInt::subtract(const BFPtr& min, const BFPtr& sub) {
  BFPtr dif;
  if (sub->isZero()) *dif = *min;
  else {
    BitField::LongOutputStream dif_out(dif.get());
    BitField::LongInputStream min_in(min.get());
    BitField::LongInputStream sub_in(sub.get());
    bool borrow = false;
    while (!min_in.isEOF() || !sub_in.isEOF()) {
      const unsigned int lhs = !min_in.isEOF() ? min_in.getLong() : 0;
      const unsigned int rhs = !sub_in.isEOF() ? sub_in.getLong() : 0;
      // Subtract in wide arithmetic; the limb is the low 32 bits of the difference.
      const unsigned long long taken = static_cast<unsigned long long>(rhs) + (borrow ? 1 : 0);
      borrow = lhs < taken;
      const unsigned int limb = static_cast<unsigned int>(lhs - taken);
      dif_out.putLong(limb);
    }
    dif_out.flush();
    if (borrow) throw string("被減数が減数より小さい。");
  }
  if (dif->isEmpty()) dif->pushMS(false);
  else dif->trim();
  return dif;
}
// Multiplies two magnitudes and returns a trimmed, non-empty result.
// Powers of two are handled by shifting. Otherwise, all 32x32-bit partial
// products are collected per limb column, and each column is then summed
// with its carries pushed into the next column.
BigInt::BFPtr BigInt::multiply(const BFPtr& mca, const BFPtr& mer) {
  BFPtr pro;
  if (mca->isZero() || mer->isZero()) pro->setLong(0);
  else if (mca->isPowerOf2()) {
    // x * 2^k is a shift by k toward the MS end (negative count). Convert to
    // int before negating so the count is not wrapped as an unsigned value.
    *pro = *mer;
    pro->shift(-static_cast<int>(mca->getLength() - 1));
  }
  else if (mer->isPowerOf2()) {
    *pro = *mca;
    pro->shift(-static_cast<int>(mer->getLength() - 1));
  }
  else {
    BitField::LongOutputStream pro_out(pro.get());
    BitField::LongInputStream mer_in(mer.get());
    // tab->at(col) holds every 32-bit word that must be added into limb col.
    MTabPtr tab;
    for (unsigned int base_col = 0; !mer_in.isEOF(); base_col++) {
      unsigned long mer_num = mer_in.getLong();
      BitField::LongInputStream mca_in(mca.get());
      for (unsigned int col = base_col; !mca_in.isEOF(); col++) {
        unsigned long mca_num = mca_in.getLong();
        unsigned long long pro_num = (unsigned long long)mer_num * mca_num;
        unsigned long rem_num = pro_num & (((unsigned long long)1 << 32) - 1);
        unsigned long carry_num = pro_num >> 32;
        if (tab->getSize() <= col) tab->push(MColPtr());
        tab->at(col)->push(rem_num);
        if (carry_num != 0) {
          if (tab->getSize() <= col + 1) tab->push(MColPtr());
          tab->at(col + 1)->push(carry_num);
        }
      }
    }
    // The table may grow while this loop runs (carries into the next column).
    for (unsigned int col = 0; col < tab->getSize(); col++) {
      // Sum the column modulo 2^32 in 64-bit arithmetic and count the
      // wrap-arounds. Each addend is < 2^32, so one addition wraps at most once.
      unsigned long long acc = 0;
      unsigned long carry_num = 0;
      for (unsigned int i = 0; i < tab->at(col)->getSize(); i++) {
        acc += tab->at(col)->at(i);
        if (acc > 0xffffffffULL) {
          acc &= 0xffffffffULL;
          carry_num++;
        }
      }
      pro_out.putLong(static_cast<unsigned long>(acc));
      if (carry_num != 0) {
        if (tab->getSize() <= col + 1) tab->push(MColPtr());
        tab->at(col + 1)->push(carry_num);
      }
    }
    pro_out.flush();
  }
  if (pro->isEmpty()) pro->pushMS(false);
  else pro->trim();
  return pro;
}
// Long division of non-negative magnitudes: returns {dde / dsr, dde % dsr}.
// Both results are trimmed and never empty.
// Throws string("除数がゼロ。") when dsr is zero.
// Fast paths: dde == dsr, dde < dsr, and a power-of-two divisor (shift + low bits).
// General case: both operands are shifted left by k bits so the divisor's
// length becomes a multiple of 32. Assuming dsr is trimmed, its top 32-bit word
// then has its MS bit set. Quotient words are then produced MS-first: each is
// estimated from the top one or two remainder words and corrected downwards.
pair<BigInt::BFPtr, BigInt::BFPtr> BigInt::divide(const BFPtr& dde, const BFPtr& dsr) {
if (dsr->isZero()) throw string("除数がゼロ。");
pair<BFPtr, BFPtr> quo_rem;
int cmp_res = dde->compare(*dsr);
if (cmp_res == 0) {
quo_rem.first->setLong(1);
quo_rem.second->setLong(0);
}
else if (cmp_res < 0) {
quo_rem.first->setLong(0);
*quo_rem.second = *dde;
}
else if (dsr->isPowerOf2()) {
// dsr == 2^(len-1): the quotient is dde shifted right by len-1 bits and the
// remainder is dde's low len-1 bits.
*quo_rem.first = *dde;
quo_rem.first->shift(dsr->getLength() - 1);
*quo_rem.second = dde->getSubfield(0, dsr->getLength() - 1);
}
else {
// Each putLong() prepends below the previous one, so quotient words are written MS-first.
BitField::ReverseLongOutputStream quo_out(quo_rem.first.get());
// Normalization shift: after shifting left by k, k_dsr's length is a multiple of 32.
unsigned int k = (32 - (dsr->getLength() & (32 - 1))) & (32 - 1);
BFPtr k_dde(*dde);
k_dde->shift(-k);
BFPtr k_dsr(*dsr);
k_dsr->shift(-k);
BitField::ReverseLongInputStream dde_in(k_dde.get());
BitField::ReverseLongInputStream dsr_in(k_dsr.get());
// Top 32-bit word of the normalized divisor, used to estimate quotient words.
unsigned long dsr_ms_num = dsr_in.getLong();
while (!dde_in.isEOF()) {
// Bring down the next dividend word: rem = rem * 2^32 + word.
BitField::ReverseLongOutputStream rem_out(quo_rem.second.get());
rem_out.putLong(dde_in.getLong());
rem_out.flush();
quo_rem.second->trim();
// NOTE(review): this cmp_res shadows the outer one declared above; harmless, but confusing.
int cmp_res = quo_rem.second->compare(*k_dsr);
if (cmp_res == 0) {
quo_out.putLong(1);
quo_rem.second->clear();
}
else if (cmp_res < 0) quo_out.putLong(0);
else {
// Estimate the quotient word from the top remainder word, or the top two
// words when the top one is smaller than the divisor's top word.
// The estimate is capped at 2^32 - 1.
BitField::ReverseLongInputStream rem_in(quo_rem.second.get());
unsigned long long rem_ms_num = rem_in.getLong();
if (rem_ms_num < dsr_ms_num)
rem_ms_num = (rem_ms_num << 32) | rem_in.getLong();
unsigned long long quo_num = min(rem_ms_num / dsr_ms_num, ((unsigned long long)1 << 32) - 1);
BFPtr pro = multiply(k_dsr, BFPtr(quo_num));
// Correct an over-estimate. Instead of recomputing pro, add k_dsr to the
// remainder once per decrement; the subtraction below then yields
// rem - quo_num * k_dsr.
while (quo_rem.second->compare(*pro) < 0) {
quo_num--;
quo_rem.second = add(quo_rem.second, k_dsr);
}
quo_out.putLong(quo_num);
quo_rem.second = subtract(quo_rem.second, pro);
}
}
quo_out.flush();
// Undo the normalization shift on the remainder (a positive count shifts right).
quo_rem.second->shift(k);
}
if (quo_rem.first->isEmpty()) quo_rem.first->pushMS(false);
else quo_rem.first->trim();
if (quo_rem.second->isEmpty()) quo_rem.second->pushMS(false);
else quo_rem.second->trim();
return quo_rem;
}
// Maps an alphanumeric digit character to its value:
// '0'-'9' -> 0-9, and 'a'-'z' / 'A'-'Z' (case-insensitive) -> 10-35.
// Throws string("文字が数ではない。") for any other character.
unsigned int BigInt::digitToNumber(const char& dig) {
    if ('0' <= dig && dig <= '9') return dig - '0';
    if ('a' <= dig && dig <= 'z') return 10 + (dig - 'a');
    if ('A' <= dig && dig <= 'Z') return 10 + (dig - 'A');
    throw string("文字が数ではない。");
}
// Converts a digit value in [0, 35] to its character: '0'-'9', then 'a'-'z'
// (the inverse of digitToNumber, lowercase).
// Throws string("数が範囲を逸脱している。") for values >= 36. The previous upper
// bound `num <= 36` accepted 36 and produced '{', which is not a digit.
char BigInt::numberToDigit(const unsigned int& num) {
    char res;
    if (num <= 9) res = static_cast<char>('0' + num);
    else if (num <= 35) res = static_cast<char>('a' + (num - 10));
    else throw string("数が範囲を逸脱している。");
    return res;
}
// Handle to a pooled, reference-counted PrimitiveExArray<unsigned long>.
// Objects handed back to the pool keep their old contents, so the constructor
// empties the (possibly recycled) array before use.
MColPtr::MColPtr() : Pool<PrimitiveExArray<unsigned long> >::Ptr(&pool) {
    PrimitiveExArray<unsigned long>* column = x_cnt.first;
    column->clear();
}
// Pool shared by every MColPtr.
Pool<PrimitiveExArray<unsigned long> > MColPtr::pool;
// Handle to a pooled, reference-counted table of MColPtr columns.
// A recycled table may still hold columns from its previous user; clearing it
// releases those column handles.
MTabPtr::MTabPtr() : Pool<ClassObjectExArray<MColPtr> >::Ptr(&pool) {
    ClassObjectExArray<MColPtr>* table = x_cnt.first;
    table->clear();
}
// Pool shared by every MTabPtr. Keep this definition after MColPtr::pool.
// Statics are destroyed in reverse order, so this pool is destroyed first, and
// deleting its tables releases MColPtrs into MColPtr::pool while that pool still exists.
Pool<ClassObjectExArray<MColPtr> > MTabPtr::pool;
// Reads a BitField from its least-significant end, 32 bits per getLong().
// The field starts `high_len` bits into a buffer word. Each output word is
// therefore assembled from the upper `low_len` bits of one buffer word and the
// lower `high_len` bits of the next; the word index wraps because the buffer is
// circular.
//
// `cur`/`end` count 4-byte words (beg >> 5, buf_sz >> 2), so memory is accessed
// as 32-bit unsigned int. Indexing through `unsigned long*` read 8-byte words
// (out of bounds) on LP64 targets.
BitField::LongInputStream::LongInputStream(const BitField* fld) {
    static_assert(sizeof(unsigned int) == 4, "BitField word streams assume a 32-bit unsigned int");
    this->fld = fld;
    this->rem = this->fld->len;
    if (this->rem != 0) {
        this->high_len = this->fld->beg & (32 - 1);
        this->low_len = 32 - this->high_len;
        this->high_mask = (1u << this->high_len) - 1u;
        this->low_mask = ~this->high_mask;
        this->end = this->fld->buf_sz >> 2;
        this->cur = this->fld->beg >> 5;
        this->num = reinterpret_cast<const unsigned int*>(this->fld->buf)[this->cur];
    }
}
// True once every bit of the field has been returned.
bool BitField::LongInputStream::isEOF() const {
    return this->rem == 0;
}
// Returns the next (up to) 32 bits, LS first. The last call returns only the
// remaining bits, zero-extended. Masks use unsigned shifts: (1 << 31) - 1 on
// int overflowed.
unsigned long BitField::LongInputStream::getLong() {
    unsigned int res = (this->num & this->low_mask) >> this->high_len;
    if (this->rem >= this->low_len) {
        this->rem -= this->low_len;
        if (this->rem != 0) {
            // Move to the next buffer word (wrapping) for the upper high_len bits.
            this->cur++;
            if (this->cur == this->end) this->cur = 0;
            this->num = reinterpret_cast<const unsigned int*>(this->fld->buf)[this->cur];
            if (this->high_len != 0) {
                res |= (this->num & this->high_mask) << this->low_len;
                if (this->rem >= this->high_len)
                    this->rem -= this->high_len;
                else {
                    res &= (1u << (this->low_len + this->rem)) - 1u;
                    this->rem = 0;
                }
            }
        }
    }
    else {
        res &= (1u << this->rem) - 1u;
        this->rem = 0;
    }
    return res;
}
// Appends 32-bit words at the most-significant end of a BitField.
// The field may end `high_len` bits into a buffer word, so each putLong() works
// as follows:
//  - it writes the current word (the preserved lower bits plus the low part of num);
//  - it keeps num's top high_len bits in high_buf for the following word;
//  - flush() stores those pending bits.
// Memory is accessed as 32-bit unsigned int because `cur`/`end` count 4-byte
// words; `unsigned long*` addressed 8-byte words on LP64 targets.
BitField::LongOutputStream::LongOutputStream(BitField* fld) {
    static_assert(sizeof(unsigned int) == 4, "BitField word streams assume a 32-bit unsigned int");
    this->fld = fld;
    // Absolute bit position just past the MS bit, wrapped into the circular buffer.
    const unsigned int fld_end = this->fld->beg + this->fld->len;
    const unsigned int buf_len = this->fld->buf_sz << 3;
    const unsigned int act_fld_end = fld_end <= buf_len ? fld_end : fld_end - buf_len;
    this->high_len = act_fld_end & (32 - 1);
    this->low_len = 32 - this->high_len;
    this->high_mask = (1u << this->high_len) - 1u;
    this->low_mask = ~this->high_mask;
    this->cur = act_fld_end >> 5;
    const unsigned int end = this->fld->buf_sz >> 2;
    // Preserve the field's existing bits that share the first output word.
    if (this->fld->len != 0 && this->cur != end)
        this->high_buf = reinterpret_cast<const unsigned int*>(this->fld->buf)[this->cur] & this->high_mask;
    else this->high_buf = 0;
    this->bufed = false;
}
BitField::LongOutputStream::~LongOutputStream() {
    flush();
}
// Appends the low 32 bits of `num` above the current MS bit.
void BitField::LongOutputStream::putLong(const unsigned long& num) {
    // Truncate first: on LP64 stray upper bits would be shifted into high_buf.
    const unsigned int word = static_cast<unsigned int>(num);
    this->fld->extendIfNecessary(32);
    const unsigned int end = this->fld->buf_sz >> 2;
    if (this->cur == end) this->cur = 0;  // circular wrap
    reinterpret_cast<unsigned int*>(this->fld->buf)[this->cur] =
        static_cast<unsigned int>(this->high_buf | (word << this->high_len));
    if (this->high_len != 0) {
        this->high_buf = word >> this->low_len;
        this->bufed = true;
    }
    this->cur++;
}
// Writes the pending high_buf bits into the lowest high_len bits of the word
// containing the field's current end, preserving that word's upper bits.
// The previous version patched the word *before* it, which overwrote earlier
// field bits whenever the stream started at a non-word-aligned end.
void BitField::LongOutputStream::flush() {
    if (!this->bufed) return;
    const unsigned int fld_end = this->fld->beg + this->fld->len;
    const unsigned int buf_len = this->fld->buf_sz << 3;
    const unsigned int act_fld_end = fld_end <= buf_len ? fld_end : fld_end - buf_len;
    // bufed implies high_len != 0, so act_fld_end is not word-aligned and
    // act_fld_end >> 5 is a valid index of the word holding the pending bits.
    const unsigned int cur = act_fld_end >> 5;
    unsigned int* words = reinterpret_cast<unsigned int*>(this->fld->buf);
    words[cur] = static_cast<unsigned int>(this->high_buf | (words[cur] & this->low_mask));
    this->bufed = false;
}
// Reads a BitField from its most-significant end. The first getLong() returns
// the partial top chunk (len % 32 bits, or 32); every later call returns a full
// 32-bit word. Words are aligned to the field's bit 0.
// Memory is accessed as 32-bit unsigned int to match the 4-byte word indices
// (`unsigned long*` addressed 8-byte words on LP64 targets).
BitField::ReverseLongInputStream::ReverseLongInputStream(const BitField* fld) {
    static_assert(sizeof(unsigned int) == 4, "BitField word streams assume a 32-bit unsigned int");
    this->fld = fld;
    this->rem = this->fld->len;
    if (this->rem != 0) {
        this->end = this->fld->buf_sz >> 2;
        this->high_len = this->fld->beg & (32 - 1);
        this->low_len = 32 - this->high_len;
        this->high_mask = (1u << this->high_len) - 1u;
        this->low_mask = ~this->high_mask;
        const unsigned int fld_end = this->fld->beg + this->fld->len;
        const unsigned int buf_len = this->fld->buf_sz << 3;
        const unsigned int act_fld_end = fld_end <= buf_len ? fld_end : fld_end - buf_len;
        // Position `cur` one buffer word above the lower part of the top chunk;
        // getLong() consumes the high part from `num`, then steps down.
        const unsigned int last_len = act_fld_end & (32 - 1);
        if (last_len <= this->high_len) this->cur = act_fld_end >> 5;
        else this->cur = (act_fld_end + (32 - 1)) >> 5;
        if (this->cur != this->end)
            this->num = reinterpret_cast<const unsigned int*>(this->fld->buf)[this->cur];
    }
}
// True once every bit of the field has been returned.
bool BitField::ReverseLongInputStream::isEOF() const {
    return this->rem == 0;
}
// Returns the next chunk from the MS end (see the class comment for sizes).
unsigned long BitField::ReverseLongInputStream::getLong() {
    unsigned int res = 0;
    unsigned int len = this->rem & (32 - 1);
    if (len == 0) len = 32;
    if (len > this->low_len) {
        // Upper part of the chunk: the low high_len bits of the current word.
        res |= this->num & this->high_mask;
        const unsigned int top_len = len - this->low_len;
        if (top_len < this->high_len)
            res &= (1u << top_len) - 1u;
        res <<= this->low_len;
    }
    if (this->cur == 0) this->cur = this->end;  // circular wrap
    this->cur--;
    this->num = reinterpret_cast<const unsigned int*>(this->fld->buf)[this->cur];
    res |= (this->num & this->low_mask) >> this->high_len;
    // Unsigned mask: (1 << 31) - 1 on int overflowed.
    if (len < this->low_len) res &= (1u << len) - 1u;
    this->rem -= len;
    return res;
}
// Prepends 32-bit words at the least-significant end of a BitField (each call
// shifts the existing value up by 32 bits).
// The field may begin `high_len` bits into a buffer word. Each putLong()
// therefore writes one word, and keeps num's low part in low_buf for the word
// below; flush() stores that pending part.
// Memory is accessed as 32-bit unsigned int to match the 4-byte word indices.
BitField::ReverseLongOutputStream::ReverseLongOutputStream(BitField* fld) {
    static_assert(sizeof(unsigned int) == 4, "BitField word streams assume a 32-bit unsigned int");
    this->fld = fld;
    this->high_len = this->fld->beg & (32 - 1);
    this->low_len = 32 - this->high_len;
    this->high_mask = (1u << this->high_len) - 1u;
    this->low_mask = ~this->high_mask;
    const unsigned int cur = this->fld->beg >> 5;
    const unsigned int end = this->fld->buf_sz >> 2;
    // Preserve the field's existing LS bits that share the first output word.
    if (this->fld->len != 0 && cur != end)
        this->low_buf = reinterpret_cast<const unsigned int*>(this->fld->buf)[cur] & this->low_mask;
    else this->low_buf = 0;
    this->bufed = false;
}
BitField::ReverseLongOutputStream::~ReverseLongOutputStream() {
    flush();
}
// Prepends the low 32 bits of `num` below the current LS bit.
void BitField::ReverseLongOutputStream::putLong(const unsigned long& num) {
    // Truncate first: on LP64 stray upper bits of num would land in bits
    // high_len..31 of the written word.
    const unsigned int word = static_cast<unsigned int>(num);
    this->fld->extendIfNecessary(-32);
    // The word just above the new beg receives word's top bits plus the pending low_buf.
    unsigned int cur = (this->fld->beg >> 5) + 1;
    const unsigned int end = this->fld->buf_sz >> 2;
    if (cur == end) cur = 0;
    unsigned int* words = reinterpret_cast<unsigned int*>(this->fld->buf);
    if (this->high_len != 0) {
        words[cur] = static_cast<unsigned int>((word >> this->low_len) | this->low_buf);
        this->low_buf = word << this->high_len;
    }
    else {
        words[cur] = static_cast<unsigned int>(this->low_buf);
        this->low_buf = word;
    }
    this->bufed = true;
}
// Stores the pending low_buf bits into bits high_len..31 of the word holding
// beg, keeping that word's lower high_len bits.
void BitField::ReverseLongOutputStream::flush() {
    if (!this->bufed) return;
    const unsigned int cur = this->fld->beg >> 5;
    unsigned int* words = reinterpret_cast<unsigned int*>(this->fld->buf);
    words[cur] = static_cast<unsigned int>((words[cur] & this->high_mask) | this->low_buf);
    this->bufed = false;
}
// Creates an empty field with no buffer; storage is allocated on demand by extendIfNecessary().
BitField::BitField()
    : buf(nullptr), buf_sz(0), beg(0), len(0) {
}
// Creates a field holding the low 32 bits of `num`, trimmed of leading zero
// bits (zero becomes a single 0 bit).
// Exactly one 32-bit word is stored, via memcpy, matching setLong() and the
// 32-bit word model of the streams. The previous `*(unsigned long*)buf = num`
// wrote 8 bytes on LP64 targets and accessed the byte buffer through an
// unrelated pointer type.
BitField::BitField(const unsigned long& num) : buf(nullptr), buf_sz(0), beg(0), len(0) {
    extendIfNecessary(32);
    const unsigned int word = static_cast<unsigned int>(num);
    memcpy(this->buf, &word, sizeof(word));
    trim();
}
// Deep copy. An empty source yields an empty field without a buffer. Otherwise
// the whole circular buffer is duplicated together with its beg offset.
BitField::BitField(const BitField& fld) : buf(nullptr), buf_sz(0), beg(0), len(0) {
    if (fld.len == 0) return;
    this->buf = new unsigned char[fld.buf_sz];
    this->buf_sz = fld.buf_sz;
    memcpy(this->buf, fld.buf, fld.buf_sz);
    this->beg = fld.beg;
    this->len = fld.len;
}
// Releases the backing buffer (delete[] on nullptr is a no-op).
BitField::~BitField() {
    delete[] this->buf;
}
// Deep-copy assignment.
// The previous implementation allocated a new buffer without freeing the old
// one, leaking it on every non-empty assignment. The new buffer is allocated
// and filled before the old one is released, so a failed allocation leaves
// *this unchanged.
BitField& BitField::operator =(const BitField& fld) {
    if (&fld != this) {
        if (fld.len == 0) {
            // Become empty but keep the current buffer for reuse.
            this->len = 0;
        }
        else {
            unsigned char* new_buf = new unsigned char[fld.buf_sz];
            memcpy(new_buf, fld.buf, fld.buf_sz);
            delete[] this->buf;
            this->buf = new_buf;
            this->buf_sz = fld.buf_sz;
            this->beg = fld.beg;
            this->len = fld.len;
        }
    }
    return *this;
}
// Number of bits currently held by the field.
unsigned int BitField::getLength() const { return this->len; }
// True when the field holds no bits at all (distinct from holding the value 0).
bool BitField::isEmpty() const { return getLength() == 0; }
// True when no bit is set (also true for an empty field).
bool BitField::isZero() const {
    for (unsigned int pos = 0; pos < this->len; ++pos)
        if (getBit(pos)) return false;
    return true;
}
// True when bit 0 is set and every other bit is clear (false for an empty field).
bool BitField::isOne() const {
    if (this->len == 0 || !getBit(0)) return false;
    for (unsigned int pos = 1; pos < this->len; ++pos)
        if (getBit(pos)) return false;
    return true;
}
// True when exactly one bit is set.
bool BitField::isPowerOf2() const {
    bool seen = false;
    for (unsigned int pos = 0; pos < this->len; ++pos) {
        if (!getBit(pos)) continue;
        if (seen) return false;  // a second set bit
        seen = true;
    }
    return seen;
}
// Reads bit `pos` (0 = least significant). The field starts at absolute bit
// `beg` of a circular byte buffer, so positions past the buffer end wrap to its
// start. `pos` is not checked against len.
bool BitField::getBit(const unsigned int& pos) const {
    const unsigned int buf_len = this->buf_sz << 3;
    unsigned int abs_pos = this->beg + pos;
    if (abs_pos >= buf_len) abs_pos -= buf_len;
    const unsigned int byte_idx = abs_pos >> 3;
    const unsigned int bit_idx = abs_pos & 7;
    return ((this->buf[byte_idx] >> bit_idx) & 1) != 0;
}
bool BitField::getLSBit() const {
return getBit(0);
}
bool BitField::getMSBit() const {
return getBit(this->len - 1);
}
// Returns the 8 bits starting at `ls_pos`; bits beyond the field end read as 0.
unsigned char BitField::getByte(const unsigned int& ls_pos) const {
    unsigned char res = 0;
    for (unsigned int pos = 0; pos < 8; ++pos) {
        if (!(ls_pos + pos < this->len)) break;
        if (getBit(ls_pos + pos)) res = static_cast<unsigned char>(res | (1u << pos));
    }
    return res;
}
// The lowest 8 bits of the field.
unsigned char BitField::getLSByte() const {
    const unsigned int from = 0;
    return getByte(from);
}
// Returns the 32 bits starting at `ls_pos`; bits beyond the field end read as 0.
// The bit is built with `1UL << pos`. The previous int shift `1 << 31` produced
// a negative int, which sign-extended into the upper half of a 64-bit
// unsigned long.
unsigned long BitField::getLong(const unsigned int& ls_pos) const {
    unsigned long res = 0;
    for (unsigned int pos = 0; pos < 32 && ls_pos + pos < this->len; pos++)
        if (getBit(ls_pos + pos)) res |= 1UL << pos;
    return res;
}
// The lowest 32 bits of the field.
unsigned long BitField::getLSLong() const {
    return getLong(0);
}
// Sets or clears bit `pos` (0 = least significant), wrapping around the
// circular buffer like getBit(). `pos` is not checked against len.
void BitField::setBit(const unsigned int& pos, const bool& bit) {
    const unsigned int buf_len = this->buf_sz << 3;
    unsigned int abs_pos = this->beg + pos;
    if (abs_pos >= buf_len) abs_pos -= buf_len;
    unsigned char& byte = this->buf[abs_pos >> 3];
    const unsigned char mask = static_cast<unsigned char>(1u << (abs_pos & 7));
    if (bit) byte = static_cast<unsigned char>(byte | mask);
    else byte = static_cast<unsigned char>(byte & ~mask);
}
// Appends one bit above the current MS bit.
void BitField::pushMS(const bool& bit) {
    extendIfNecessary(1);
    const unsigned int top = this->len - 1;
    setBit(top, bit);
}
// Appends every bit of `x` above the current MS bit, with x's LSB placed
// immediately above the old MSB.
template <class X> void BitField::pushMS(const X& x) {
    const unsigned int width = sizeof(X) << 3;
    for (unsigned int pos = 0; pos < width; ++pos)
        pushMS(((x >> pos) & 1) == 1);
}
// Prepends one bit below the current LS bit (the existing value moves up by one).
void BitField::pushLS(const bool& bit) {
    extendIfNecessary(-1);
    const unsigned int bottom = 0;
    setBit(bottom, bit);
}
// Prepends every bit of `x` below the current LS bit. Bits are pushed MS-first,
// so x keeps its own bit order and its LSB becomes the new bit 0.
template <class X> void BitField::pushLS(const X& x) {
    const unsigned int width = sizeof(X) << 3;
    for (unsigned int pos = width; pos-- > 0;)
        pushLS(((x >> pos) & 1) == 1);
}
// Replaces the contents with the low 32 bits of `num`, then trims leading
// zero bits (zero becomes a single 0 bit).
void BitField::setLong(const unsigned long& num) {
    clear();                 // beg = 0: the value occupies the first buffer word
    extendIfNecessary(32);
    const unsigned int word = static_cast<unsigned int>(num);
    memcpy(this->buf, &word, sizeof word);
    trim();
}
// Shifts the value by `num` bits.
//  - num < 0: shifts left; the field grows by -num zero bits at the LS end.
//  - num > 0: shifts right; up to `num` LS bits are dropped by advancing beg.
//  - num == 0: no effect.
void BitField::shift(const int& num) {
    if (num == 0) return;
    if (num < 0) {
        extendIfNecessary(num);
        const unsigned int fill = static_cast<unsigned int>(-num);
        for (unsigned int pos = 0; pos < fill; ++pos) setBit(pos, false);
        return;
    }
    const unsigned int requested = static_cast<unsigned int>(num);
    const unsigned int drop = requested <= this->len ? requested : this->len;
    this->beg += drop;
    const unsigned int buf_len = this->buf_sz << 3;
    if (this->beg >= buf_len) this->beg -= buf_len;  // circular wrap
    this->len -= drop;
}
// Drops leading (MS) zero bits, always keeping at least one bit.
void BitField::trim() {
    while (this->len > 1) {
        if (getMSBit()) break;
        --this->len;
    }
}
// Empties the field and resets beg to 0; the buffer is kept for reuse.
void BitField::clear() {
    this->len = 0;
    this->beg = 0;
}
// Compares the numeric values of two fields; bits beyond either length count as 0.
// Returns -1 if *this < fld, 1 if *this > fld, and 0 if they are equal.
int BitField::compare(const BitField& fld) const {
    if (&fld == this) return 0;
    const unsigned int width = this->len < fld.len ? fld.len : this->len;
    for (unsigned int pos = width; pos-- > 0;) {
        const bool mine = pos < this->len && getBit(pos);
        const bool theirs = pos < fld.len && fld.getBit(pos);
        if (mine != theirs) return mine ? 1 : -1;
    }
    return 0;
}
// Renders the bits as '0'/'1' characters, least-significant bit first.
string BitField::toString() const {
    string res;
    res.reserve(this->len);
    for (unsigned int pos = 0; pos < this->len; ++pos)
        res.push_back(getBit(pos) ? '1' : '0');
    return res;
}
// Returns a new field containing bits [ls_pos, ls_pos + len) of this one,
// copied in order (bit ls_pos becomes bit 0). The range is not bounds-checked.
BitField BitField::getSubfield(const unsigned int& ls_pos, const unsigned int& len) const {
    BitField res;
    for (unsigned int off = 0; off < len; ++off) {
        const bool bit = getBit(ls_pos + off);
        res.pushMS(bit);
    }
    return res;
}
// Grows the field by |len| bits:
//  - len > 0 grows it at the MS end;
//  - len < 0 grows it at the LS end, by moving beg downwards around the
//    circular buffer.
// The newly exposed bits are left uninitialized.
// When the buffer is too small, it is replaced by one twice the required size,
// rounded to whole 32-bit words. Existing bits are copied so the circular
// layout is preserved: a wrapped field keeps its MS part at the start and moves
// its LS part to the end of the new buffer.
//
// Fixes:
//  - The old buffer used to be freed only when len != 0, so growing a cleared
//    field that still owned a buffer leaked it.
//  - The LS-side wrap is computed without going below zero, so it no longer
//    depends on beg being a signed type.
void BitField::extendIfNecessary(const int& len) {
    const unsigned int abs_len = len < 0 ? -len : len;
    // Bytes needed for the grown field, rounded up to a multiple of 4.
    const unsigned int need_sz = ((((this->len + abs_len + (8 - 1)) >> 3) + (4 - 1)) >> 2) << 2;
    if (need_sz > this->buf_sz) {
        const unsigned int new_buf_sz = need_sz << 1;
        unsigned char* new_buf = new unsigned char[new_buf_sz];
        unsigned int new_beg = 0;
        if (this->len != 0) {
            const unsigned int fld_beg = this->beg >> 3;
            const unsigned int fld_end = (this->beg + this->len + (8 - 1)) >> 3;
            if (fld_end <= this->buf_sz) {
                // Contiguous field: copy it to the same offset.
                memcpy(new_buf + fld_beg, this->buf + fld_beg, fld_end - fld_beg);
                new_beg = this->beg;
            }
            else {
                // Wrapped field: the LS part (beg .. old end) moves to the end of
                // the new buffer; the MS part stays at the start.
                const unsigned int low_sz = this->buf_sz - fld_beg;
                memcpy(new_buf + (new_buf_sz - low_sz), this->buf + fld_beg, low_sz);
                memcpy(new_buf, this->buf, fld_end - this->buf_sz);
                const unsigned int low_len = (this->buf_sz << 3) - this->beg;
                new_beg = (new_buf_sz << 3) - low_len;
            }
        }
        delete[] this->buf;  // also releases a buffer kept by an empty (cleared) field
        this->buf_sz = new_buf_sz;
        this->buf = new_buf;
        this->beg = new_beg;
    }
    if (len < 0) {
        // Move beg down by abs_len, wrapping around the buffer end.
        // abs_len <= buf_len is guaranteed by the capacity check above.
        const unsigned int buf_len = this->buf_sz << 3;
        const unsigned int cur_beg = this->beg;
        this->beg = cur_beg >= abs_len ? cur_beg - abs_len : cur_beg + buf_len - abs_len;
    }
    this->len += abs_len;
}
OctetStringByteInputStream::OctetStringByteInputStream(const shared_ptr<OctetString>& os) : os(os), i(0) {}
unsigned char OctetStringByteInputStream::getByte() {
return this->os->at(this->i++);
}
bool OctetStringByteInputStream::isEOF() {
return this->i >= this->os->getSize();
}
OctetStringByteOutputStream::OctetStringByteOutputStream() : os(new PrimitiveExArray<unsigned char>) {}
void OctetStringByteOutputStream::putByte(const unsigned char& byte) {
this->os->push(byte);
}
void OctetStringByteOutputStream::flush() {}
shared_ptr<OctetString> OctetStringByteOutputStream::getResult() {
return this->os;
}
// Byte source over a copy of a std::string.
StringByteInputStream::StringByteInputStream(const string& str) : str(str), pos(0) {}
// Returns the next byte and advances; string::at throws out_of_range past the end.
unsigned char StringByteInputStream::getByte() {
    const auto idx = this->pos;
    ++this->pos;
    return this->str.at(idx);
}
bool StringByteInputStream::isEOF() {
    return !(this->pos < this->str.length());
}
// Byte sink that accumulates bytes into a std::string.
StringByteOutputStream::StringByteOutputStream() {}
void StringByteOutputStream::putByte(const unsigned char& byte) {
    this->str.push_back(static_cast<char>(byte));
}
void StringByteOutputStream::flush() {
    // Nothing is buffered.
}
// Returns a copy of the bytes written so far.
string StringByteOutputStream::getResult() {
    return this->str;
}
BitInputStreamConnector::BitInputStreamConnector(const shared_ptr<vector<shared_ptr<BitInputStream> > >& ins) : ins(ins), i(0) {}
bool BitInputStreamConnector::getBit() {
bool res = this->ins->at(this->i)->getBit();
if (this->ins->at(this->i)->isEOF()) this->i++;
return res;
}
bool BitInputStreamConnector::isEOF() {
return this->i >= this->ins->size();
}
// Bit stream over the bits of a primitive value, most-significant bit first.
template <class X>
PrimitiveBitInputStream<X>::PrimitiveBitInputStream(const X& x) : x(x), pos(0) {}
template <class X>
bool PrimitiveBitInputStream<X>::getBit() {
    const unsigned int shift = (sizeof(X) << 3) - 1 - this->pos;
    this->pos++;
    return ((this->x >> shift) & 1) == 1;
}
// True after all sizeof(X) * 8 bits have been read.
template <class X>
bool PrimitiveBitInputStream<X>::isEOF() {
    const unsigned int width = sizeof(X) << 3;
    return !(this->pos < width);
}
// Reference-counted handle to an object owned by a Pool. When the last handle
// is released, the object (with its counter) is returned to the pool it came
// from, for reuse.
template <class X>
Pool<X>::Ptr::Ptr(Pool* pool) : pool(pool) {
    this->x_cnt = this->pool->get();
    if (this->x_cnt.second) (*this->x_cnt.second)++;
}
// Shares ptr's object (and remembers its pool) with one more reference.
template <class X>
Pool<X>::Ptr::Ptr(const Ptr& ptr) {
    this->pool = ptr.pool;
    this->x_cnt = ptr.x_cnt;
    if (this->x_cnt.second) (*this->x_cnt.second)++;
}
template <class X>
Pool<X>::Ptr::~Ptr() {
    release();
}
// Rebinds this handle to ptr's object.
// Fixes:
//  - The owning pool is now copied as well. Previously the object was later
//    returned to this handle's *old* pool.
//  - The new reference is taken before the old one is dropped, which is safe
//    even when both handles already share the same object.
template <class X>
typename Pool<X>::Ptr& Pool<X>::Ptr::operator =(const Ptr& ptr) {
    if (&ptr != this) {
        if (ptr.x_cnt.second) (*ptr.x_cnt.second)++;
        release();
        this->pool = ptr.pool;
        this->x_cnt = ptr.x_cnt;
    }
    return *this;
}
template <class X>
X& Pool<X>::Ptr::operator *() const {
    return *this->x_cnt.first;
}
template <class X>
X* Pool<X>::Ptr::operator ->() const {
    return this->x_cnt.first;
}
template <class X>
X* Pool<X>::Ptr::get() const {
    return this->x_cnt.first;
}
// Drops one reference; the last one hands the object back to the pool.
template <class X>
void Pool<X>::Ptr::release() {
    if (this->x_cnt.second) {
        (*this->x_cnt.second)--;
        if (*this->x_cnt.second == 0)
            pool->putBack(this->x_cnt);
    }
}
// Free list of reusable (object, reference counter) pairs.
template <class X>
Pool<X>::Pool() {}
// Deletes the objects currently parked in the pool. Objects still held by live
// Ptr handles are not tracked here and are not deleted.
template <class X>
Pool<X>::~Pool() {
    const unsigned int n = this->xa.getSize();
    for (unsigned int k = 0; k < n; ++k) {
        delete this->xa.at(k).first;
        delete this->xa.at(k).second;
    }
}
// Takes an object from the free list, creating a new one (counter 0) if it is empty.
template <class X>
pair<X*, unsigned int*> Pool<X>::get() {
    if (this->xa.isEmpty()) {
        X* fresh = new X;
        this->xa.push(make_pair(fresh, new unsigned int(0)));
    }
    pair<X*, unsigned int*> res = this->xa.last();
    this->xa.pop();
    return res;
}
// Returns an object to the free list without resetting its contents.
template <class X>
void Pool<X>::putBack(const pair<X*, unsigned int*>& x_cnt) {
    this->xa.push(x_cnt);
}
// Growable array for trivially copyable element types: elements are moved
// with memcpy. Capacity doubles past the requested size on growth.
template <class X>
PrimitiveExArray<X>::PrimitiveExArray(const unsigned int& sz) : buf(nullptr), buf_sz(0), sz(0) {
    resize(sz);
}
template <class X>
PrimitiveExArray<X>::PrimitiveExArray(const PrimitiveExArray& a) : buf(nullptr), buf_sz(0), sz(0) {
    if (a.sz != 0) {
        this->buf = new X[a.buf_sz];
        this->buf_sz = a.buf_sz;
        this->sz = a.sz;
        memcpy(this->buf, a.buf, sizeof(X) * this->sz);
    }
}
template <class X>
PrimitiveExArray<X>::~PrimitiveExArray() {
    delete[] this->buf;
}
// Copy assignment.
// Fixes:
//  - Assigning an empty array now empties this one. It used to be a no-op that
//    kept the old elements.
//  - The new buffer is allocated before the old one is freed, so a failed
//    allocation leaves *this intact.
template <class X>
PrimitiveExArray<X>& PrimitiveExArray<X>::operator =(const PrimitiveExArray& a) {
    if (&a == this) return *this;
    if (a.sz == 0) {
        this->sz = 0;  // capacity is kept for reuse
        return *this;
    }
    X* new_buf = new X[a.buf_sz];
    memcpy(new_buf, a.buf, sizeof(X) * a.sz);
    delete[] this->buf;
    this->buf = new_buf;
    this->buf_sz = a.buf_sz;
    this->sz = a.sz;
    return *this;
}
// Unchecked element access.
template <class X>
X& PrimitiveExArray<X>::operator [](const unsigned int& i) {
    return this->buf[i];
}
// Unchecked element access (no bounds check, unlike std::vector::at).
template <class X>
X& PrimitiveExArray<X>::at(const unsigned int& i) {
    return this->buf[i];
}
// Last element. Precondition: not empty.
template <class X>
X& PrimitiveExArray<X>::last() {
    return this->buf[this->sz - 1];
}
template <class X>
unsigned int PrimitiveExArray<X>::getSize() const {
    return this->sz;
}
template <class X>
bool PrimitiveExArray<X>::isEmpty() const {
    return this->sz == 0;
}
template <class X>
void PrimitiveExArray<X>::push(const X& x) {
    resize(this->sz + 1);
    this->buf[this->sz - 1] = x;
}
// Removes the last element. Precondition: not empty.
template <class X>
void PrimitiveExArray<X>::pop() {
    this->sz--;
}
// Empties the array, keeping its capacity.
template <class X>
void PrimitiveExArray<X>::clear() {
    this->sz = 0;
}
// Sets the size to `sz`. Growing beyond capacity reallocates to 2 * sz;
// new elements are uninitialized.
template <class X>
void PrimitiveExArray<X>::resize(const unsigned int& sz) {
    if (sz > this->buf_sz) {
        const unsigned int new_buf_sz = sz << 1;
        X* new_buf = new X[new_buf_sz];
        if (this->buf_sz != 0) {
            memcpy(new_buf, this->buf, sizeof(X) * this->sz);
            delete[] this->buf;
        }
        this->buf = new_buf;
        this->buf_sz = new_buf_sz;
    }
    this->sz = sz;
}
// Growable array for class-type elements: elements are copied with
// assignment, and released slots are reset to X() so held resources (e.g.
// pooled handles) are dropped promptly.
template <class X>
ClassObjectExArray<X>::ClassObjectExArray(const unsigned int& sz) : buf(nullptr), buf_sz(0), sz(0) {
    resize(sz);
}
template <class X>
ClassObjectExArray<X>::ClassObjectExArray(const ClassObjectExArray& a) : buf(nullptr), buf_sz(0), sz(0) {
    if (a.sz != 0) {
        unique_ptr<X[]> new_buf(new X[a.buf_sz]);
        copy(a.buf, a.buf + a.sz, new_buf.get());
        this->buf = new_buf.release();
        this->buf_sz = a.buf_sz;
        this->sz = a.sz;
    }
}
// Copy assignment.
// Fixes:
//  - Assigning an empty array now clears this one. It used to be a no-op that
//    kept the old elements.
//  - The new buffer is fully built before the old one is freed, so a throwing
//    allocation or element copy leaves *this intact and leaks nothing.
template <class X>
ClassObjectExArray<X>& ClassObjectExArray<X>::operator =(const ClassObjectExArray& a) {
    if (&a == this) return *this;
    if (a.sz == 0) {
        clear();
        return *this;
    }
    unique_ptr<X[]> new_buf(new X[a.buf_sz]);
    copy(a.buf, a.buf + a.sz, new_buf.get());
    delete[] this->buf;
    this->buf = new_buf.release();
    this->buf_sz = a.buf_sz;
    this->sz = a.sz;
    return *this;
}
template <class X>
ClassObjectExArray<X>::~ClassObjectExArray() {
    delete[] this->buf;
}
// Unchecked element access.
template <class X>
X& ClassObjectExArray<X>::operator [](const unsigned int& i) {
    return this->buf[i];
}
// Unchecked element access (no bounds check, unlike std::vector::at).
template <class X>
X& ClassObjectExArray<X>::at(const unsigned int& i) {
    return this->buf[i];
}
// Last element. Precondition: not empty.
template <class X>
X& ClassObjectExArray<X>::last() {
    return this->buf[this->sz - 1];
}
template <class X>
unsigned int ClassObjectExArray<X>::getSize() const {
    return this->sz;
}
template <class X>
bool ClassObjectExArray<X>::isEmpty() const {
    return this->sz == 0;
}
template <class X>
void ClassObjectExArray<X>::push(const X& x) {
    resize(this->sz + 1);
    this->buf[this->sz - 1] = x;
}
// Removes the last element (resetting its slot). Precondition: not empty.
template <class X>
void ClassObjectExArray<X>::pop() {
    this->buf[this->sz - 1] = X();
    this->sz--;
}
// Resets every element and empties the array, keeping its capacity.
template <class X>
void ClassObjectExArray<X>::clear() {
    for (unsigned int i = 0; i < this->sz; i++) this->buf[i] = X();
    this->sz = 0;
}
// Sets the size to `sz`. Growing beyond capacity reallocates to 2 * sz;
// shrinking resets the dropped slots.
template <class X>
void ClassObjectExArray<X>::resize(const unsigned int& sz) {
    if (sz > this->buf_sz) {
        const unsigned int new_buf_sz = sz << 1;
        X* new_buf = new X[new_buf_sz];
        if (this->buf_sz != 0) {
            copy(this->buf, this->buf + this->sz, new_buf);
            delete[] this->buf;
        }
        this->buf = new_buf;
        this->buf_sz = new_buf_sz;
    }
    else if (sz < this->sz) for (unsigned int i = sz; i < this->sz; i++) this->buf[i] = X();
    this->sz = sz;
}
// Rotates the low `w` bits of x left by n positions (n is taken modulo w).
// A rotation by 0 (or a multiple of w) returns x unchanged. The previous
// version evaluated x >> w in that case, which is undefined when w equals
// the bit width of X.
template <class X> X rotateLeft(const X& x, const unsigned int& n, const unsigned int& w) {
    const unsigned int s = w != 0 ? n % w : 0;
    if (s == 0) return x;
    return (x << s) | (x >> (w - s));
}
////////////////////////////////////////////////////////////////////////////////
class Receiver;
class Sender;
// Demo party that owns an RSA key pair, hands out its public key, and decrypts
// ciphertexts sent to it, logging each step to cout.
class Receiver {
public:
    string name;  // display name used in log lines
    Receiver(const string& name);
    // Generates a new key pair; `len` is forwarded to RSA::generateKeys.
    void createKeys(const unsigned int& len);
    // Public half of the key pair.
    RSA::Key getPublicKey() const;
    // Stores a ciphertext received from `sen`.
    void receivedCipherText(const Sender& sen, const shared_ptr<OctetString>& cip_text);
    // Decrypts the stored ciphertext with the private key and prints the message.
    void decryptReceivedCipherText() const;
protected:
    pair<RSA::Key, RSA::Key> pub_pri_keys;  // {public key, private key}
    shared_ptr<OctetString> rec_cip_text;   // last ciphertext received
};
// Demo party that fetches a receiver's public key, encrypts a message with it
// and sends the ciphertext.
class Sender {
public:
    string name;  // display name used in log lines
    Sender(const string& name);
    // Copies rec's public key for later encryption.
    void getPublicKeyFrom(const Receiver& rec);
    // Encrypts `msg` with the stored public key.
    void encryptMessage(const string& msg);
    // Delivers the last ciphertext to `rec`.
    void sendCipherTextTo(Receiver& rec) const;
protected:
    RSA::Key rec_pub_key;              // receiver's public key
    shared_ptr<OctetString> cip_text;  // last produced ciphertext
};
Receiver::Receiver(const string& name) : name(name) {}
void Receiver::createKeys(const unsigned int& len) {
this->pub_pri_keys = RSA::generateKeys(len);
cout << this->name << "「" << "鍵のペアを作ります" << "」" << endl;
cout << this->name << "「" << "公開鍵=(n=" << this->pub_pri_keys.first.n.toString(16) << ", e=" << this->pub_pri_keys.first.e.toString(16) << ")" << "」" << endl;
cout << this->name << "「" << "秘密鍵=(n=" << this->pub_pri_keys.second.n.toString(16) << ", e=" << this->pub_pri_keys.second.e.toString(16) << ")" << "」" << endl;
cout << endl;
}
void Receiver::receivedCipherText(const Sender& sen, const shared_ptr<OctetString>& cip_text) {
cout << this->name << "「" << sen.name << "から暗号文を受信しました" << "」" << endl;
this->rec_cip_text = cip_text;
cout << this->name << "「" << "暗号文=";
for (unsigned int i = 0; i < this->rec_cip_text->getSize(); i++) {
cout << hex << setw(2) << setfill('0') << (unsigned int)this->rec_cip_text->at(i);
}
cout << "」" << endl;
cout << endl;
}
void Receiver::decryptReceivedCipherText() const {
cout << this->name << "「" << "暗号文を復号化します" << "」" << endl;
cout << this->name << "「" << "使用する鍵=(n=" << this->pub_pri_keys.second.n.toString(16) << ", e=" << this->pub_pri_keys.second.e.toString(16) << ")" << "」" << endl;
shared_ptr<EA> de_rsa(new RSA(this->pub_pri_keys.second));
shared_ptr<ByteInputStream> cip_in(new OctetStringByteInputStream(this->rec_cip_text));
shared_ptr<StringByteOutputStream> msg_out(new StringByteOutputStream());
de_rsa->decrypt(cip_in, msg_out);
string msg = msg_out->getResult();
cout << this->name << "「" << "メッセージ=" << msg << "」" << endl;
cout << endl;
}
RSA::Key Receiver::getPublicKey() const {
return this->pub_pri_keys.first;
}
Sender::Sender(const string& name) : name(name) {}
void Sender::getPublicKeyFrom(const Receiver& rec) {
cout << this->name << "「" << rec.name << "の公開鍵を取得します" << "」" << endl;
this->rec_pub_key = rec.getPublicKey();
cout << endl;
}
void Sender::encryptMessage(const string& msg) {
cout << this->name << "「" << "メッセージを暗号化します" << "」" << endl;
cout << this->name << "「" << "メッセージ=" << msg << "」" << endl;
cout << this->name << "「" << "使用する鍵=(n=" << this->rec_pub_key.n.toString(16) << ", e=" << this->rec_pub_key.e.toString(16) << ")" << "」" << endl;
shared_ptr<EA> en_rsa(new RSA(this->rec_pub_key));
shared_ptr<ByteInputStream> msg_in(new StringByteInputStream(msg));
shared_ptr<OctetStringByteOutputStream> cip_out(new OctetStringByteOutputStream());
en_rsa->encrypt(msg_in, cip_out);
this->cip_text = cip_out->getResult();
cout << this->name << "「" << "暗号文=";
for (unsigned int i = 0; i < this->cip_text->getSize(); i++) {
cout << hex << setw(2) << setfill('0') << (unsigned int)this->cip_text->at(i);
}
cout << "」" << endl;
cout << endl;
}
void Sender::sendCipherTextTo(Receiver& rec) const {
cout << this->name << "「" << rec.name << "に暗号文を送信します" << "」" << endl;
cout << endl;
rec.receivedCipherText(*this, this->cip_text);
}
// Demo: 太郎 creates a key pair (len 512 passed to createKeys); 次郎 fetches the
// public key, encrypts "テスト" and sends it; 太郎 then decrypts it.
// Errors thrown as std::string are printed to stderr and give exit status 1.
int main() {
    int status = 0;
    try {
        Receiver rec("太郎");
        Sender sen("次郎");
        rec.createKeys(512);
        sen.getPublicKeyFrom(rec);
        sen.encryptMessage("テスト");
        sen.sendCipherTextTo(rec);
        rec.decryptReceivedCipherText();
    }
    catch (const string& str) {
        cerr << str << endl;
        status = 1;
    }
    return status;
}