#!/usr/bin/env python3
#
#  pi.py - Calculate Pi
#
import sys
import time
import math
import gmpy2
from gmpy2 import mpfr
from gmpy2 import mpz

#
# Global Variables
#
# Shared state of the progress bar, all starting at zero:
#   count - number of s() calls finished so far
#   total - number of s() calls expected for the whole run
#   grad  - value count must exceed before the bar is redrawn again
#   step  - amount grad is advanced after each redraw
count = total = grad = step = 0

#
# Show Progress
#
def progress_init(max):
        """Reset the progress-bar state for a run of ``max`` expected steps.

        Sets the module globals ``total`` (to ``max``), ``count`` (to 0),
        ``step`` (one thousandth of ``max``, truncated to an integer) and
        ``grad`` (half of ``step``, truncated), which progress() reads and
        updates. For ``max`` below 1000 both ``step`` and ``grad`` are 0.
        """
        global count, total, grad, step

        redraw_interval = int(max / 1000)

        count = 0
        total = max
        step = redraw_interval
        # First redraw is due half an interval into the run.
        grad = int(redraw_interval / 2)

def progress():
        """Redraw the one-line text progress bar when a redraw is due.

        Reads the module globals ``count``, ``total``, ``grad`` and ``step``.
        Nothing is printed unless ``count`` has passed the threshold ``grad``;
        when it has, ``grad`` is advanced by ``step`` and a 72-column bar of
        "H" and "-" followed by the percentage (one decimal place) is written
        to standard output, ending in a carriage return so that the next
        redraw overwrites it. The line is terminated with a newline on the
        last redraw, i.e. when no further threshold fits below ``total`` or
        when ``count`` has reached ``total``.
        """
        global count, total, grad, step

        # Not due yet; a non-positive total has no meaningful ratio (and a
        # total of zero would divide by zero below).
        if (count <= grad or total <= 0):
                return

        grad += step

        # The scale factors are slightly above 72 and 1000, so the rounded
        # values can exceed a full bar as count approaches total; clamp them
        # so the bar never shows 73 columns or "100.1%".
        g = min(72, int(math.floor(72.5*count/total+0.5)))
        p = min(1000, int(math.floor(1000.5*count/total+0.5)))
        msg = "H" * g + "-" * (72-g) + " " + str(p/10) + "%\r"

        # Finish the line exactly once: either no later redraw can happen
        # (grad is now beyond total) or this is the final step. Without the
        # second test a small step (0 or 1) would never end the line.
        if (grad > total or count >= total):
                msg += "\n"

        print(msg, sep="", end="", flush=True)

#
# Write digit string
#
def write_string(digit_string, path="pi-py.txt"):
        """Write a string of decimal digits to a text file as a formatted table.

        The first character of ``digit_string`` is written as the integer
        part after the prefix " pi = ", followed by "."; the remaining
        characters are written 50 per line, every line after the first being
        indented with a tab. Lines end in CR LF. A line that completes a
        multiple of 1000 fractional digits is tagged with " << " and that
        digit count; a line that completes any other multiple of 500 is
        tagged with " <". An empty line closes the file.

        Parameters:
            digit_string: non-empty string; an empty string raises IndexError.
            path: name of the file to create or overwrite (default
                "pi-py.txt").
        """
        # newline="" turns off newline translation so the explicit "\r\n"
        # below is stored as written on every platform (translation would
        # produce "\r\r\n" on Windows). The with-block closes the file even
        # if a write fails.
        with open(path, mode="w", newline="") as fd:
                fd.write(" pi = ")
                fd.write(digit_string[0])
                fd.write(".")

                # c is the index of the first digit of each 50-digit line;
                # index 0 is the integer part, so c+49 is the number of
                # fractional digits written once the line is complete.
                for c in range(1, len(digit_string), 50):
                        if (c != 1):
                                fd.write("\t")

                        fd.write(digit_string[c:c+50])

                        if ((c % 1000) == 951):
                                fd.write(" << ")
                                fd.write(str(c+49))
                                fd.write("\r\n")
                        elif ((c % 500) == 451):
                                fd.write(" <\r\n")
                        else:
                                fd.write("\r\n")

                # Final new-line
                fd.write("\r\n")

#
# Recursive function.
#
def s(a, b, max):
        """Return the triple (p, q, r) for the range of series terms a..b.

        A range of width one is evaluated directly; the range starting at
        a == 0 has its own constants. Any wider range is cut at its midpoint
        (rounded up), both halves are evaluated recursively and the results
        are merged. ``max`` is the upper end of the outermost range: for a
        range that ends there the product r is not computed and 0 is returned
        in its place, because r of such a range is never read by a caller.

        Every call, leaf or not, increments the global ``count`` and calls
        progress() once before returning.
        """
        global count

        if (b - a != 1):
                # Wider than one term: split, recurse, merge.
                m = math.ceil((a + b) / 2)
                left_p, left_q, left_r = s(a, m, max)
                right_p, right_q, right_r = s(m, b, max)

                p = gmpy2.add( gmpy2.mul(left_p, right_q),
                        gmpy2.mul(right_p, left_r) )
                q = gmpy2.mul(left_q, right_q)
                r = gmpy2.mul(left_r, right_r) if (b != max) else 0
        elif (a != 0):
                # Single term away from the start of the series; the sign of
                # p alternates with the parity of b.
                r = mpz(8 * (a*6+1) * (a*6+3) * (a*6+5))
                q = mpz((b*640320)**3)
                magnitude = 13591409 + b*545140134
                signed = magnitude if ((b%2) == 0) else -magnitude
                p = gmpy2.mul(mpz(signed), r)
        else:
                # First term of the series.
                r = 120         # 6!
                q = mpz(640320**3)
                p = gmpy2.sub( gmpy2.mul(q, 13591409),
                        gmpy2.mul(r, 13591409+545140134) )

        count += 1
        progress()

        return p, q, r

#
# Calculate pi
#
def _timed(label, func):
        """Call func(), print label with the elapsed seconds, return the result."""
        start_time = time.monotonic_ns()
        result = func()
        elapsed = (time.monotonic_ns() - start_time) / 1000000000
        print(label, elapsed, "seconds.")
        return result

def calc_pi(digits):
        """Compute pi to ``digits`` decimal places and write it to a file.

        The number of series terms and the binary working precision are
        derived from ``digits``; the gmpy2 context is set to that precision,
        the series is summed as exact integers by s(), and the result is
        obtained as 426880 * sqrt(10005) * q / p in floating point. The first
        ``digits`` + 1 significant digits are then passed to write_string().
        Library versions, limits and the duration of every stage are printed
        to standard output.

        Parameters:
            digits: number of digits wanted after the decimal point; must
                not be negative.

        Returns None. If ``digits`` is negative, or the required precision
        exceeds what gmpy2 supports, an error message is printed and nothing
        is computed or written.
        """
        # A negative request yields a term count of zero or less, for which
        # the recursion in s() has no base case.
        if (digits < 0):
                print("Error! Number of digits must not be negative! Program terminated.")
                return

        d = digits+1                    # Significant digits, integer part included
        n_terms = math.ceil(d*math.log(10)/(3*math.log(53360)))
        precision = math.ceil(d * math.log2(10)) + 4
        print("d = ", d, ", n = ", n_terms, ", precision = ", precision, sep="")

        print("gmpy2 version:", gmpy2.version())
        print("MP version:", gmpy2.mp_version())
        print("MPFR version:", gmpy2.mpfr_version())

        max_precision = gmpy2.get_max_precision()
        print("max_precision =", max_precision)
        max_emax = gmpy2.get_emax_max()
        print("max_emax =", max_emax)

        if (max_precision < precision):
                print("Error! Max precision is too small! Program terminated.")
                return

        gmpy2.get_context().precision = precision
        gmpy2.get_context().emax = max_emax
        print("Real precision = ", gmpy2.get_context().precision)
        progress_init(n_terms * 2 - 1)                          # Initialize progress bar

        # Recursion; the third element of the result is not needed here.
        p, q, _ = _timed("Recursion:", lambda: s(0, n_terms, n_terms))

        q = _timed("Multiply by 426880:", lambda: gmpy2.mul(q, 426880))

        # Switch from exact integers to floating point at context precision.
        quotient = _timed("Grand Division:",
                lambda: gmpy2.div(mpfr(q), mpfr(p)))

        pi_value = _timed("Multiply by sqrt(10005):",
                lambda: gmpy2.mul(quotient, gmpy2.sqrt(10005)))

        # Convert to decimal digits, keeping only the d requested ones
        pi_digits = _timed("Convert to decimal digits:",
                lambda: mpfr.digits(pi_value)[0][0:d])

        # Write to file
        _timed("Write to file:", lambda: write_string(pi_digits))

#
#  main program
#
if __name__ == '__main__':
        # The digit count is the optional first command-line argument;
        # without it 100000 digits are computed.
        digits = int(sys.argv[1]) if (len(sys.argv) >= 2) else 100000

        calc_pi(digits)

# End of pi.py
