from math import log, exp, sqrt, floor, ceil, log2
import itertools as IT
import random
from collections import Counter

# Modulus shared by the 31-bit generators below: RANDU reduces modulo MOD,
# Park-Miller modulo MOD - 1 (the Mersenne prime 2**31 - 1), and seeds for
# both are drawn from [0, MOD).
MOD = 1 << 31

def rint(m):
    """Return a uniformly random integer ``r`` with ``0 <= r < m``.

    Draws from the module-level ``random`` generator; used to pick seeds.
    """
    return random.randrange(m)

def ansic(seed=None):
    """Yield outputs of the sample ``rand()`` LCG from the ANSI C standard.

    The state advances as ``state = (state * 1103515245 + 12345) mod 2**31``
    and each step yields ``state >> 16``, a 15-bit value in ``[0, 32767]``.

    Args:
        seed: initial state. If None, a random seed in ``[0, 2**31)`` is drawn.

    Yields:
        int: successive 15-bit outputs (the seed itself is not yielded).
    """
    M = 2**31
    if seed is None:
        # rint takes a single exclusive upper bound (rint(0, M) is a TypeError).
        seed = rint(M)
    while True:
        seed = (seed * 1103515245 + 12345) % M
        yield seed >> 16

def randu(seed=None):
    """Yield the RANDU sequence ``x -> 65539 * x mod MOD``.

    Args:
        seed: starting state. If None, a random state in ``[0, MOD)`` is drawn.

    Yields:
        int: successive states. The first output is ``65539 * seed % MOD``;
        the seed itself is not yielded.

    NOTE: 65539 is odd, so an even seed keeps every state even and seed 0
    yields 0 forever. Odd seeds are not enforced here.
    """
    state = rint(MOD) if seed is None else seed
    while True:
        state = state * 65539 % MOD
        yield state

def park_miller(seed=None):
    """Yield the multiplicative LCG ``x -> 16807 * x mod (MOD - 1)``.

    With ``MOD == 2**31`` the modulus is the prime ``2**31 - 1``
    (Park-Miller "minimal standard" constants).

    Args:
        seed: starting state. If None, a random state in ``[0, MOD)`` is drawn.

    Yields:
        int: successive states (the seed itself is not yielded).

    NOTE(review): a seed of 0 or ``MOD - 1`` reduces to state 0, and the
    stream then stays at 0 forever. Such seeds are not rejected here.
    """
    state = seed if seed is not None else rint(MOD)
    while True:
        state = (16807 * state) % (MOD - 1)
        yield state

def java(seed=None):
    """Yield 32-bit outputs of the ``java.util.Random`` linear congruential generator.

    The 48-bit state advances as
    ``state = (state * 0x5DEECE66D + 0xB) mod 2**48`` and each step yields
    ``state >> 16``. That is Java's ``next(32)`` read as an unsigned value.

    Args:
        seed: initial 48-bit state, used as-is. Unlike
            ``java.util.Random.setSeed`` it is NOT XOR-scrambled with
            0x5DEECE66D. If None, a random state in ``[0, 2**48)`` is drawn.

    Yields:
        int: successive 32-bit outputs in ``[0, 2**32)``.
    """
    F = 2**48
    if seed is None:
        seed = rint(F)
    while True:
        # Java masks the state with (1 << 48) - 1, i.e. reduces modulo 2**48.
        # Reducing modulo 2**48 - 1 (the mask value) would be a different
        # generator whose outputs diverge from Java's after the first step.
        seed = (seed * 25214903917 + 11) % F
        yield seed >> 16

def prng_to_bits(rng, nbits):
    """Flatten a stream of non-negative integers into a stream of bits.

    Each integer drawn from *rng* is emitted least-significant bit first,
    then zero-padded up to *nbits* bits. Integers wider than *nbits* are
    emitted in full and never truncated.

    Args:
        rng: iterator of non-negative ints.
        nbits: minimum number of bits emitted per integer.

    Yields:
        int: 0 or 1.
    """
    while True:
        digits = bin(next(rng))[2:]
        for digit in digits[::-1]:
            yield int(digit)
        for _ in range(nbits - len(digits)):
            yield 0

def create_bits(rng, nbits, m):
    """Build a zero-argument factory that yields a freshly seeded bit stream.

    Each call of the returned generator function seeds a new ``rng``
    instance with ``rint(m)``, a random integer in ``[0, m)``. It then
    flattens that instance's output through ``prng_to_bits`` at width
    *nbits*. Seeding happens lazily, on the first ``next()``.

    The factory's ``__name__`` is ``<rng name>_bits``, so reports can
    identify it.
    """
    def bits():
        seeded = rng(rint(m))
        yield from prng_to_bits(seeded, nbits)

    bits.__name__ = "{}_bits".format(rng.__name__)
    return bits

# Zero-argument factories producing fresh, randomly seeded bit streams.
# Widths match each generator's output range: Park-Miller and RANDU yield
# values below 2**31 (31 bits), the ANSI C generator yields state >> 16 of a
# value mod 2**31 (15 bits), and the Java generator yields state >> 16 of a
# 48-bit state (32 bits).
park_miller_bits, randu_bits, ansic_bits, java_bits = (
    create_bits(park_miller, 31, MOD),
    create_bits(randu, 31, MOD),
    create_bits(ansic, 15, MOD),
    create_bits(java, 32, 2**48),
)

def python_bits():
    """Yield an endless stream of single bits from the global ``random`` generator.

    Serves as the reference source alongside the hand-written LCGs.
    """
    for _ in IT.count():
        yield random.getrandbits(1)

def count_bits(n):
    """Return the number of 1 digits in the binary representation of *n*.

    For negative *n* the sign is ignored, so ``count_bits(-5) == 2``.
    """
    return format(n, "b").count("1")

def ones_test(prng, n=200):
    """Check that the count of set bits among *n* draws from *prng* is plausible.

    Totals ``count_bits`` over the next *n* values of *prng*. The test passes
    if the total lies strictly inside the two-standard-deviation band around
    ``n / 2``, with both band edges rounded. That band is the one for
    Binomial(n, 1/2); it assumes *prng* yields single bits (0/1), as the
    ``*_bits`` generators do.

    Returns:
        int: 1 on pass, 0 on fail.
    """
    total = sum(count_bits(next(prng)) for _ in range(n))
    center = n / 2
    spread = 2 * sqrt(n / 4)
    lower = round(center - spread)
    upper = round(center + spread)
    return 1 if lower < total < upper else 0

def count_bit_runs(prng, n=1000):
    """Histogram the lengths of runs of 1s among the next *n* values of *prng*.

    Only runs of values equal to ``1`` are counted. Runs of any other value
    (normally 0) are ignored.

    Returns:
        Counter: run length -> number of runs with that length.
    """
    window = IT.islice(prng, n)
    run_lengths = (
        sum(1 for _ in group)
        for value, group in IT.groupby(window)
        if value == 1
    )
    return Counter(run_lengths)

def runs_test(prng, nbits=100000):
    """Check that the longest run of 1-bits among *nbits* draws is plausible.

    Uses ``count_bit_runs`` on the next *nbits* values of *prng* and compares
    the longest run of 1s against the band ``log2(nbits / 2) +/- 2 / ln 2``.
    Band edges are rounded inward to integers, and the comparison is
    inclusive.

    A stream containing no 1-bits at all (e.g. a generator stuck at state 0)
    has a longest run of 0 and fails the test; it no longer raises
    ``ValueError`` from ``max`` of an empty histogram.

    Returns:
        int: 1 on pass, 0 on fail.
    """
    runs = count_bit_runs(prng, nbits)
    # default=0: an empty Counter (no 1s seen) must fail, not crash.
    longest_run = max(runs, default=0)
    # Expected longest run of heads in nbits fair flips is about log2(nbits / 2).
    eruns = log2(nbits / 2)
    # NOTE(review): the asymptotic sd of the longest run is about 1.87;
    # 1/ln 2 (about 1.44) gives a narrower band. Confirm this is intended.
    esdev = 1 / log(2)
    rlo, rhi = ceil(eruns - 2 * esdev), floor(eruns + 2 * esdev)
    return 1 if rlo <= longest_run <= rhi else 0

def tests(prng, nbits=100000):
    """Run the ones test, then the runs test, on consecutive output of *prng*.

    Both tests draw from the same stream. The runs test therefore sees the
    values that follow the *nbits* consumed by the ones test.

    Returns:
        int: number of tests passed (0, 1 or 2).
    """
    passed = ones_test(prng, n=nbits)
    passed += runs_test(prng, nbits)
    return passed

# Score every bit source. Each of `ncases` fresh streams runs both tests on
# N bits apiece, so the maximum score per source is 2 * ncases.
N = 10000
ncases = 50
for rng in (park_miller_bits, randu_bits, java_bits, python_bits, ansic_bits):
    score = sum(tests(rng(), N) for _ in range(ncases))
    print("{:16s}  {}/{}".format(rng.__name__, score, 2*ncases))
