fork(2) download
  1. from bisect import bisect_left
  2. from collections import defaultdict
  3.  
  4. class Cache(defaultdict):
  5. def __init__(self, method):
  6. self.method = method
  7. def __missing__(self, key):
  8. return self.method(key)
  9.  
  10. def memoized(f):
  11. cache = Cache(lambda args: f(*args))
  12. def ret(*args):
  13. return cache[args]
  14. return ret
  15.  
  16. class MappedList(object):
  17. def __init__(self, method, input):
  18. self.method = memoized(method)
  19. self.input = input
  20. def __iter__(self):
  21. return iter((self.method(x) for x in self.input))
  22. def __len__(self):
  23. return len(self.input)
  24. def __getitem__(self, i):
  25. return self.method(i)
  26.  
  27. def find_closest(data, target, key = lambda x:x):
  28. s = sorted(data)
  29. evaluated = MappedList(key, s)
  30. index = bisect_left(evaluated, target)
  31. if index == 0:
  32. return data[0]
  33. if index == len(data):
  34. return data[index-1]
  35. if target - evaluated[index-1] <= evaluated[index] - target:
  36. return data[index-1]
  37. else:
  38. return data[index]
  39.  
  40. count = 0
  41.  
  42. def calc(x):
  43. global count
  44. count += 1
  45. return x
  46.  
  47. count = 0
  48.  
  49. data = [0, 2, 6, 10, 15, 17]
  50.  
  51. print(list(MappedList(lambda x: x+1, data)))
  52. count = 0
  53. result = find_closest(data, 10, calc)
  54. assert result == 10
  55. print(count)
  56. count = 0
  57. result = find_closest(data, 9, calc)
  58. assert result == 10
  59. print(count)
  60. count = 0
  61. result = find_closest(data, 7, calc)
  62. assert result == 6
  63. print(count)
  64. count = 0
  65. result = find_closest(data, -1, calc)
  66. assert result == 0
  67. print(count)
  68. count = 0
  69. result = find_closest(data, 0, calc)
  70. assert result == 0
  71. print(count)
  72. count = 0
  73. result = find_closest(data, 17, calc)
  74. assert result == 17
  75. print(count)
  76. count = 0
  77. result = find_closest(data, 19, calc)
  78. assert result == 17
  79. print(count)
  80.  
  81.  
  82.  
Success #stdin #stdout 0.04s 9488KB
stdin
Standard input is empty
stdout
[1, 3, 7, 11, 16, 18]
2
2
2
3
3
2
2