fork download
  1. from itertools import chain, groupby
  2. from random import random
  3.  
  4. upperBound = 512
  5.  
  6. #numbers = [n for n in range(0, upperBound) if random() < 0.5]
  7. numbers = [n for n in range(0, upperBound) if n % 3 != 0]
  8.  
  9. bits = ''.join([str(1 if n in numbers else 0) for n in range(0, upperBound)])
  10.  
  11. def flatten(l):
  12. return list(chain.from_iterable(l))
  13.  
  14. def encode(l):
  15. return flatten([[k, len(list(g))] for k, g in groupby(l)])
  16.  
  17. def decode(l):
  18. return flatten([[l[i * 2]] * l[i * 2 + 1] for i in range(int(len(l) / 2))])
  19.  
  20. def bwt(s):
  21.  
  22. assert "^" not in s, "Input string cannot contain '^'"
  23.  
  24. s += "^" # Add end of file marker
  25. table = sorted(s[i:] + s[:i] for i in range(len(s))) # Table of rotations of string
  26. last_column = [row[-1:] for row in table] # Last characters of each row
  27.  
  28. return "".join(last_column) # Convert list of characters into string
  29.  
  30. def ibwt(r, *args):
  31.  
  32. firstCol = "".join(sorted(r))
  33. count = [0]*256
  34. byteStart = [-1]*256
  35. output = [""] * len(r)
  36. shortcut = [None]*len(r)
  37.  
  38. #Generates shortcut lists
  39. for i in range(len(r)):
  40. shortcutIndex = ord(r[i])
  41. shortcut[i] = count[shortcutIndex]
  42. count[shortcutIndex] += 1
  43. shortcutIndex = ord(firstCol[i])
  44. if byteStart[shortcutIndex] == -1:
  45. byteStart[shortcutIndex] = i
  46.  
  47. localIndex = (r.index("^") if not args else args[0])
  48.  
  49. for i in range(len(r)):
  50. #takes the next index indicated by the transformation vector
  51. nextByte = r[localIndex]
  52. output [len(r)-i-1] = nextByte
  53. shortcutIndex = ord(nextByte)
  54. #assigns localIndex to the next index in the transformation vector
  55. localIndex = byteStart[shortcutIndex] + shortcut[localIndex]
  56.  
  57. return "".join(output).rstrip("^")
  58.  
  59. ppbits = [bits[i : i + 8] for i in range(0, len(bits), 8)]
  60. ppbits = [ppbits[i : i + 8] for i in range(0, len(ppbits), 8)]
  61.  
  62. print '\n'.join(map(lambda l : ' '.join(l), ppbits)), '\n'
  63.  
  64. compressed = encode(list(bwt(bits)))
  65. print compressed, '\n'
  66.  
  67. bits_ = ibwt(''.join(decode(compressed)))
  68. numbers_ = [n for n in range(len(bits_)) if bits_[n] == '1']
  69.  
  70. print bits == bits_ and numbers == numbers_
Success #stdin #stdout 0.02s 10056KB
stdin
Standard input is empty
stdout
01101101 10110110 11011011 01101101 10110110 11011011 01101101 10110110
11011011 01101101 10110110 11011011 01101101 10110110 11011011 01101101
10110110 11011011 01101101 10110110 11011011 01101101 10110110 11011011
01101101 10110110 11011011 01101101 10110110 11011011 01101101 10110110
11011011 01101101 10110110 11011011 01101101 10110110 11011011 01101101
10110110 11011011 01101101 10110110 11011011 01101101 10110110 11011011
01101101 10110110 11011011 01101101 10110110 11011011 01101101 10110110
11011011 01101101 10110110 11011011 01101101 10110110 11011011 01101101 

['^', 1, '1', 340, '0', 171, '1', 1] 

True