fork download
  1. from itertools import count, accumulate, chain, cycle
  2. from random import randrange, sample
  3.  
  4. empty = ("_empty", "_empty", 0)
  5. deleted = ("_deleted", "_deleted", 0)
  6.  
  7. def is_prime(x):
  8. for f in accumulate(chain((2, 1, 2, 2), cycle((4, 2, 4, 2, 4, 6, 2, 6)))):
  9. if f * f > x:
  10. return True
  11. if x % f == 0:
  12. return False
  13.  
  14. def next_prime(x):
  15. return next(p for p in count(x|1, 2) if is_prime(p))
  16.  
  17. def inc_linprob(hashval, M):
  18. """ increment for linear probing"""
  19. return 1
  20.  
  21. def inc_double_hashing(hashval, M):
  22. """ increment for double hashing - make sure it is not in range [1, M-1]"""
  23. return hashval % (M-1) + 1
  24.  
  25. class dictlh(object):
  26. """ Store (key, value, hash(key))
  27. """
  28. def __init__(self, mode=inc_linprob, load_factor=0.5):
  29. self.M = 997
  30. self.buckets = [empty] * self.M
  31. self.load_factor = load_factor
  32. self.max_load = load_factor * self.M
  33. self.n_elements = 0
  34. self.n_deleted = 0
  35. self._increment = mode
  36. self.searches = 0
  37.  
  38. def _find_insert_bucket(self, key):
  39. """ result: find the key OR find an empty bucket to insert """
  40. hashval = hash(key)
  41. b, M = self.buckets, self.M
  42. inc = self._increment(hashval, M)
  43. h = hashval % M
  44. bh = b[h]
  45. while bh != empty and bh[0] != key:
  46. h = (h + inc) % M
  47. bh = b[h]
  48. self.searches += 1
  49. return h
  50.  
  51. def _insert(self, key, value, hashval=None):
  52. if hashval is None:
  53. hashval = hash(key)
  54. bi = self._find_insert_bucket(key)
  55. isempty = self.buckets[bi] == empty
  56. self.buckets[bi] = (key, value)
  57. if isempty:
  58. self.n_elements += 1
  59. if self.n_elements > self.max_load:
  60. self._rehash()
  61.  
  62. def _find_key(self, key):
  63. bi = self._find_insert_bucket(key)
  64. if self.buckets[bi][0] == key:
  65. return bi
  66. raise KeyError("key not found")
  67.  
  68. def _lookup(self, key):
  69. h = self._find_key(key)
  70. return self.buckets[h][1]
  71.  
  72. def _rehash(self):
  73. load = self.n_elements - self.n_deleted
  74. self.M = next_prime(4 * load)
  75. self.max_load = self.load_factor * self.M
  76. self.n_elements = self.n_deleted = 0
  77. old_buckets = list(self.buckets)
  78. self.buckets = [empty] * self.M
  79. for item in old_buckets:
  80. if item not in (deleted, empty):
  81. self._insert(*item)
  82.  
  83. def _delete(self, key):
  84. h = self._find_key(key)
  85. self.buckets[h] = deleted
  86. self.n_deleted += 1
  87.  
  88. def __setitem__(self, key, value):
  89. self._insert(key, value)
  90.  
  91. def __getitem__(self, key):
  92. return self._lookup(key)
  93.  
  94. def __delitem__(self, key):
  95. self._delete(key)
  96.  
  97. def get(self, key, default=None):
  98. try:
  99. return self._lookup(key)
  100. except KeyError:
  101. return default
  102.  
  103. def __len__(self):
  104. return self.n_elements - self.n_deleted
  105.  
  106. def __contains__(self, key):
  107. bi = self._find_insert_bucket(key)
  108. return self.buckets[bi][0] == key
  109.  
  110. def __iter__(self):
  111. yield from (b[0] for b in self.buckets if b not in (empty, deleted))
  112.  
  113. def keys(self):
  114. """a copy of all keys"""
  115. return list((b[0] for b in self.buckets if b not in (empty, deleted)))
  116.  
  117. def values(self):
  118. """a copy of all values"""
  119. return list((b[1] for b in self.buckets if b not in (empty, deleted)))
  120.  
  121.  
  122. if __name__ == '__main__':
  123. from time import clock
  124. from glob import glob
  125. from itertools import starmap, repeat
  126.  
  127. def repeatfunc(func, times=None, *args):
  128. if times is None:
  129. return starmap(func, repeat(args))
  130. return starmap(func, repeat(args, times))
  131.  
  132. load_factor = 0.50
  133. # Tests
  134. for mode in (inc_linprob, inc_double_hashing):
  135. D = dictlh(mode=mode, load_factor=load_factor)
  136. PD = {}
  137. for k, v in zip("abcdefghijklmnopqrstuvwxyz", range(26)):
  138. D[k] = v
  139. PD[k] = v
  140.  
  141. assert D["v"] == PD["v"]
  142.  
  143. del D["v"]
  144. del PD["v"]
  145. try:
  146. D["v"]
  147. except KeyError:
  148. pass
  149.  
  150. for key in "abcdefghijklmnopqrstuwxyz":
  151. assert D[key] == PD[key]
  152.  
  153. D["a"] = 99
  154. PD["a"] = 99
  155. assert D["a"] == PD["a"]
  156.  
  157. def random_ints(N, start, stop):
  158. return list(repeatfunc(randrange, N, *(start, stop)))
  159.  
  160. N = 1000
  161. rand1, rand2 = random_ints(N, 0, 50), random_ints(N, 0, 1000000)
  162. for key, val in zip(rand1, rand2):
  163. D[key] = val
  164. PD[key] = val
  165. assert D[key] == PD[key]
  166.  
  167. for key in sample(rand1, N // 2):
  168. try:
  169. del D[key]
  170. except KeyError:
  171. pass
  172. try:
  173. del PD[key]
  174. except KeyError:
  175. pass
  176.  
  177.  
  178. for key, val in PD.items():
  179. assert D[key] == val, "{} {} {}".format(key, val, D[key])
  180.  
  181. print(D.searches)
  182.  
  183. print("tests OK")
  184.  
Success #stdin #stdout 0.03s 37400KB
stdin
Standard input is empty
stdout
14314
tests OK
2364
tests OK