#!/usr/bin/env python3

import sys
import numpy as np
from random import randrange
from matplotlib import pyplot as plt
from matplotlib import image
from scipy.special import expit
import gzip
import struct

# Hidden-layer size constant. It is not referenced in the visible code;
# presumably it is the number of hidden units in the network (TODO confirm).
HIDDEN = 200

# Gzipped IDX data files. The relative paths are resolved against the
# current working directory when opened.
# Entries are looked up by position (train() opens files[1]), so the order
# below is part of the contract and must not change.
files = [
    "../t10k-images-idx3-ubyte.gz",   # [0] t10k (test) images
    "../train-images-idx3-ubyte.gz",  # [1] training images
    "../t10k-labels-idx1-ubyte.gz",   # [2] t10k (test) labels
    "../train-labels-idx1-ubyte.gz",  # [3] training labels
]

def train():

    with gzip.open(files[1],'rb') as fd:
        all = fd.read()

    eoh = 16
    header = struct.unpack('>iiii',all[:eoh])
    print("header " + str(header))

    imagesize = header[2]*header[3]
    #imagesize = 28*28  # 784 pixels
    print("imagesize " + str(imagesize) )


    #initializing
    pixno = 0
    startpos = eoh + imagesize * pixno
    endpos = startpos + imagesize
    trainImages = np.empty([60000, 28, 28])


    for i in range(0, header[1]):
        a1 = all[startpos:endpos]
        a2 = struct.unpack('784B', a1)
        a3 = np.array(a2).reshape(28,28)
        trainImages[i] = a3
        startpos = endpos + 1
        endpos = startpos + imagesize
       # print("count = =" + str(count))