from itertools import chain, groupby
from random import random

# Exclusive upper limit of the integer range covered by the bitmap.
upperBound = 512

# Sample set: every integer in [0, upperBound) that is not a multiple of 3.
# Alternative random sample, kept for experimentation:
#numbers = [n for n in range(0, upperBound) if random() < 0.5]
numbers = [n for n in range(0, upperBound) if n % 3 != 0]

# Membership bitmap as a string of '0'/'1' characters, one per integer in
# [0, upperBound).  The lookup goes through a set built once so that each
# test is O(1) instead of a linear scan of the `numbers` list.
_number_set = set(numbers)
bits = ''.join(['1' if n in _number_set else '0' for n in range(0, upperBound)])

def flatten(l):
    """Return one flat list holding the items of every iterable in ``l``.

    Only a single level of nesting is removed.
    """
    flat = []
    for inner in l:
        flat.extend(inner)
    return flat

def encode(l):
    """Run-length encode ``l``.

    Returns a flat list ``[value, run_length, value, run_length, ...]`` with
    one pair per run of consecutive equal items.
    """
    runs = []
    for value, members in groupby(l):
        # The group iterator has no length, so materialise it to count it.
        run_length = len(list(members))
        runs.append([value, run_length])
    return flatten(runs)

def decode(l):
    """Expand a flat ``[value, count, value, count, ...]`` list into its runs.

    Each value is repeated ``count`` times.  A trailing element without a
    matching count is ignored.
    """
    expanded = []
    # Stopping at len(l) - 1 visits only the even positions that are
    # followed by a count.
    for pos in range(0, len(l) - 1, 2):
        value = l[pos]
        repeat = l[pos + 1]
        expanded.append([value] * repeat)
    return flatten(expanded)

def bwt(s):

    assert "^" not in s, "Input string cannot contain '^'"

    s += "^"  # Add end of file marker
    table = sorted(s[i:] + s[:i] for i in range(len(s)))  # Table of rotations of string
    last_column = [row[-1:] for row in table]  # Last characters of each row

    return "".join(last_column)  # Convert list of characters into string

def ibwt(r, *args):

        firstCol = "".join(sorted(r))
        count = [0]*256
        byteStart = [-1]*256
        output = [""] * len(r)
        shortcut = [None]*len(r)

        #Generates shortcut lists
        for i in range(len(r)):
                shortcutIndex = ord(r[i])
                shortcut[i] = count[shortcutIndex]
                count[shortcutIndex] += 1
                shortcutIndex = ord(firstCol[i])
                if byteStart[shortcutIndex] == -1:
                        byteStart[shortcutIndex] = i

        localIndex = (r.index("^") if not args else args[0])

        for i in range(len(r)):
                #takes the next index indicated by the transformation vector
                nextByte = r[localIndex]
                output [len(r)-i-1] = nextByte
                shortcutIndex = ord(nextByte)
                #assigns localIndex to the next index in the transformation vector
                localIndex = byteStart[shortcutIndex] + shortcut[localIndex]

        return "".join(output).rstrip("^")

# Pretty-print the bitmap: cut it into 8-character groups, then lay the
# groups out eight per row.
ppbits = [bits[i : i + 8] for i in range(0, len(bits), 8)]
ppbits = [ppbits[i : i + 8] for i in range(0, len(ppbits), 8)]

# A single parenthesised argument prints identically under the Python 2
# print statement and the Python 3 print function.  The trailing ' \n'
# reproduces the separator space and blank line that the former
# `print x, '\n'` statement emitted.
print('\n'.join(' '.join(row) for row in ppbits) + ' \n')

# Compress: Burrows-Wheeler transform followed by run-length encoding.
compressed = encode(list(bwt(bits)))
print(str(compressed) + ' \n')

# Decompress and rebuild the number list from the positions of the '1' bits.
bits_ = ibwt(''.join(decode(compressed)))
numbers_ = [n for n in range(len(bits_)) if bits_[n] == '1']

# Round-trip check: prints True when both representations survived intact.
print(bits == bits_ and numbers == numbers_)