fork download
  1. """intervals
  2.  
  3. Union, intersection, set difference and symmetric difference
  4. of possibly overlapping or touching integer intervals.
  5. Intervals are defined right-open. (1, 4) -> 1, 2, 3
  6.  
e.g. (operands are Intervals objects; results are shown as returned by .get())
union(Intervals([(1, 4), (7, 9)]), Intervals([(3, 5)])) -> [(1, 5), (7, 9)]
intersections(Intervals([(1, 4), (7, 9)]), Intervals([(3, 5)])) -> [(3, 4)]
set_difference(Intervals([(1, 4), (7, 9)]), Intervals([(3, 5)])) -> [(1, 3), (7, 9)]
set_difference(Intervals([(3, 5)]), Intervals([(1, 4), (7, 9)])) -> [(4, 5)]
symmetric_difference(Intervals([(1, 4), (7, 9)]), Intervals([(3, 5)])) -> [(1, 3), (4, 5), (7, 9)]

see: http://en.wikipedia.org/wiki/Set_theory#Basic_concepts_and_notation
  15. """
  16.  
  17. import copy
  18. from collections import namedtuple
  19. from itertools import accumulate, chain, islice, repeat
  20. from operator import itemgetter
  21. import unittest
  22.  
  23. class Intervals(object):
  24. """Holds a non overlapping list of intervals.
  25. One single interval is just a pair.
  26. Overlapping or touching intervals are automatically merged.
  27. """
  28.  
  29. def __init__(self, interval_list=()):
  30. """Raises a ValueError if the length of one of the
  31. intervals in the list is negative.
  32. """
  33. if any(begin > end for begin, end in interval_list):
  34. raise ValueError('Invalid interval')
  35. self._interval_list = _merge_interval_lists(
  36. interval_list, [], OP_UNION)
  37.  
  38. def __repr__(self):
  39. """Just write out all included intervals.
  40. """
  41. return 'Intervals ' + str(self._interval_list)
  42.  
  43. def get(self, copy_content=True):
  44. """Return the list of intervals.
  45. """
  46. return copy.copy(self._interval_list) if copy_content\
  47. else self._interval_list
  48.  
  49.  
  50. def union(a, b):
  51. """Merge a and b (union).
  52. """
  53. return Intervals(_merge_interval_lists(
  54. a.get(False), b.get(False), OP_UNION))
  55.  
  56. def intersections(a, b):
  57. """Intersects a and b.
  58. """
  59. return Intervals(_merge_interval_lists(
  60. a.get(False), b.get(False), merge_type=OP_INTERSECTIONS))
  61.  
  62. def set_difference(a, b):
  63. """Removes b from a.
  64. Set difference is not commutative.
  65. """
  66. return Intervals(_merge_interval_lists(
  67. a.get(False), b.get(False), merge_type=OP_SET_DIFFERENCE))
  68.  
  69. def symmetric_difference(a, b):
  70. """Symmetric difference of a and b.
  71. """
  72. return Intervals(_merge_interval_lists(
  73. a.get(False), b.get(False), merge_type=OP_SYMMETRIC_DIFFERENCE))
  74.  
  75.  
  76. def compose(func_1, func_2, unpack=False):
  77. """
  78. compose(func_1, func_2, unpack=False) -> function
  79.  
  80. The function returned by compose is a composition of func_1 and func_2.
  81. That is, compose(func_1, func_2)(5) == func_1(func_2(5)).
  82. """
  83. if not callable(func_1):
  84. raise TypeError("First argument to compose must be callable.")
  85. if not callable(func_2):
  86. raise TypeError("Second argument to compose must be callable.")
  87.  
  88. if unpack:
  89. def composition(*args, **kwargs):
  90. return func_1(*func_2(*args, **kwargs))
  91. else:
  92. def composition(*args, **kwargs):
  93. return func_1(func_2(*args, **kwargs))
  94. return composition
  95.  
  96.  
  97. # Helper to avoid repetitive lambdas.
  98. def check_for(cond):
  99. return lambda val: val == cond
  100.  
  101. # http://e...content-available-to-author-only...a.org/wiki/Identity_function
  102. identity = lambda x: x
  103.  
  104.  
  105. # Depending on the operation carried out,
  106. # the criteria for interval begins and ends in the result differ.
  107. # E.g:
# a      =  |-------|       |-----|
# b      =      |-------|
# counts =  1   2   1   0   1     0   (union, intersection, sym diff)
# counts =  1   0  -1   0   1     0   (set diff)
# union  =  |-----------|   |-----|
# inter  =      |---|
# sym d  =  |---|   |---|   |-----|
# set d  =  |---|           |-----|
  116. #
  117. # One can see that union begins if the count changes from 0 to 1
  118. # and ends if the count changes from 1 to 0
  119. # An intersection begins at a change from 1 to 2 and ends with 2 to 1.
  120. # A symmetric difference begins at every change to one
  121. # and ends at every change away from one.
  122. # The conditions for the set difference are the same as for the union.
  123. set_op = namedtuple('set_op', ['transform', 'begin', 'end'])
  124.  
  125. OP_UNION = set_op(
  126. transform = identity,
  127. begin = check_for((0, 1)),
  128. end = check_for((1, 0)))
  129.  
  130. OP_INTERSECTIONS = set_op(
  131. transform = identity,
  132. begin = check_for((1, 2)),
  133. end = check_for((2, 1)))
  134.  
  135. OP_SYMMETRIC_DIFFERENCE = set_op(
  136. transform = identity,
  137. begin = lambda change: change[1] == 1,
  138. end = lambda change: change[1] != 1)
  139.  
  140. OP_SET_DIFFERENCE = set_op(
  141. # If we want to calculate the set difference
  142. # we invert the second interval list,
  143. # i.e. swap begin and end.
  144. transform = compose(tuple, reversed),
  145. begin = check_for((0, 1)),
  146. end = check_for((1, 0)))
  147.  
  148.  
  149. # class Intervals makes sure, by always building the union first,
  150. # that no invalid a's or b's are fed here.
  151. def _merge_interval_lists(a, b, merge_type):
  152. """Merges two lists of intervals in O(n*log(n)).
  153. Overlapping or touching intervals are simplified to one.
  154.  
  155. Arguments:
  156. a and b -- The interval lists to merge.
  157. merge_type -- Can be:
  158. OP_UNION,
  159. OP_INTERSECTIONS,
  160. OP_SYMMETRIC_DIFFERENCE, or
  161. OP_SET_DIFFERENCE.
  162.  
  163. Return the sorted result as a list.
  164. """
  165.  
  166. # Adjust the operand for the selected operator.
  167. b = map(merge_type.transform, b)
  168.  
  169. # Separately sort begins and ends
  170. # and pair them with the implied change
  171. # of the count of currently open intervals.
  172. # e.g. (1, 4), (7, 9), (3, 5) ->
  173. # begins = [(1, 1), (3, 1), (7, 1)]
  174. # ends = [(4, -1), (5, -1), (9, -1)]
  175. both = list(chain(a, b))
  176. begins = zip(sorted(map(itemgetter(0), both)),
  177. repeat(1))
  178. ends = zip(sorted(map(itemgetter(1), both)),
  179. repeat(-1))
  180.  
  181. # Sort begins and ends together.
  182. # If the value is the same, begins come before ends
  183. # to ensure touching intervals being merged to one.
  184. # In our example above this means:
  185. # edges = [(1, 4), (3, 1), (4, -1), (5, -1), (7, 1), (9, -1)]
  186. edges = sorted(chain(begins, ends), key=lambda x: (x[0], -x[1]))
  187.  
  188. # The number of opened intervals after each edge.
  189. counts = list(accumulate(map(itemgetter(1), edges)))
  190. # The changes of opened intervals at each edge.
  191. changes = zip(chain([0], counts), counts)
  192. # Just the x positions of the edges.
  193. xs = map(itemgetter(0), edges)
  194. xs_and_changes = list(zip(xs, changes))
  195.  
  196. # Now we filter out the begins and ends from the changes
  197. # and get their x positions.
  198. res_begins = map(itemgetter(0),
  199. starfilter(lambda x, change: merge_type.begin(change),
  200. xs_and_changes))
  201. res_ends = map(itemgetter(0),
  202. starfilter(lambda x, change: merge_type.end(change),
  203. xs_and_changes))
  204.  
  205. # The result is then just pairing up the sorted begins and ends.
  206. result = pairwise(sorted(chain(res_begins, res_ends)), False)
  207.  
  208. # No empty intervals in the result.
  209. def length_greater_than_zero(interval):
  210. return interval[0] < interval[1]
  211.  
  212. return list(filter(length_greater_than_zero, result))
  213.  
  214.  
  215. class TestIntervals(unittest.TestCase):
  216.  
  217. def test_ctor(self):
  218. # Check ctors sanity check.
  219. self.assertRaises(ValueError, Intervals, [(2, 4), (3, 1)])
  220.  
  221. def test_add_behind(self):
  222. # Check adding right of the last interval.
  223. intervals = Intervals([(0, 2)])
  224. intervals = union(intervals, Intervals([(3, 4)]))
  225. self.assertEqual(intervals.get(), [(0, 2), (3, 4)])
  226.  
  227. def test_add_in_front(self):
  228. # Check adding left to the first interval.
  229. intervals = Intervals([(3, 4)])
  230. intervals = union(intervals, Intervals([(1, 2)]))
  231. self.assertEqual(intervals.get(), [(1, 2), (3, 4)])
  232.  
  233. def test_add_in_between(self):
  234. # Check adding between two intervals.
  235. intervals = Intervals([(1, 2)])
  236. intervals = union(intervals, Intervals([(6, 9)]))
  237. intervals = union(intervals, Intervals([(3, 5)]))
  238. self.assertEqual(intervals.get(), [(1, 2), (3, 5), (6, 9)])
  239.  
  240. def test_add_touching(self):
  241. # Check adding a interval touching an existing one.
  242. intervals = Intervals([(1, 3)])
  243. intervals = union(intervals, Intervals([(3, 5)]))
  244. self.assertEqual(intervals.get(), [(1, 5)])
  245.  
  246. def test_add_overlapping(self):
  247. # Check adding a interval overlapping an existing one.
  248. intervals = Intervals([(1, 4)])
  249. intervals = union(intervals, Intervals([(3, 5)]))
  250. self.assertEqual(intervals.get(), [(1, 5)])
  251.  
  252. def test_add_overlapping_multiple(self):
  253. # Check adding a interval overlapping multiple existing ones.
  254. intervals = Intervals([(1, 4)])
  255. intervals = union(intervals, Intervals([(5, 7)]))
  256. intervals = union(intervals, Intervals([(8, 10)]))
  257. intervals = union(intervals, Intervals([(3, 9)]))
  258. self.assertEqual(intervals.get(), [(1, 10)])
  259.  
  260. def test_add_swallow(self):
  261. # Check adding a interval completely covering an existing one.
  262. intervals = Intervals([(2, 3)])
  263. intervals = union(intervals, Intervals([(1, 4)]))
  264. self.assertEqual(intervals.get(), [(1, 4)])
  265.  
  266. def test_sub(self):
  267. # Check removing an interval
  268. intervals = Intervals([(0, 3)])
  269. intervals = union(intervals, Intervals([(5, 7)]))
  270. intervals = set_difference(intervals, Intervals([(2, 6)]))
  271. self.assertEqual(intervals.get(), [(0, 2), (6, 7)])
  272.  
  273. def test_intersections(self):
  274. # Check adding right of the last interval.
  275. intervals = Intervals([(0, 3)])
  276. intervals = union(intervals, Intervals([(5, 7)]))
  277. intervals = intersections(intervals, Intervals([(2, 6)]))
  278. self.assertEqual(intervals.get(), [(2, 3), (5, 6)])
  279.  
  280. def test_symmetric_difference(self):
  281. # Check symmetric difference
  282. intervals = Intervals([(0, 3)])
  283. intervals = union(intervals, Intervals([(5, 7)]))
  284. intervals = symmetric_difference(intervals, Intervals([(2, 6)]))
  285. self.assertEqual(intervals.get(), [(0, 2), (3, 5), (6, 7)])
  286.  
  287.  
  288.  
  289. def tuple_wise(iterable, size, step):
  290. """Tuples up the elements of iterable.
  291.  
  292. Arguments:
  293. iterable -- source data
  294. size -- size of the destination tuples
  295. step -- step to do in iterable per destination tuple
  296.  
  297. tuple_wise(s, 3, 1): "s -> (s0,s1,s2), (s1,s2,s3), (s3,s4,s5), ...
  298. tuple_wise(s, 2, 4): "s -> (s0,s1), (s4,s5), (s8,s9), ...
  299. """
  300. return zip(
  301. *(islice(iterable, start, None, step)
  302. for start in range(size)))
  303.  
  304. def pairwise(iterable, overlapping):
  305. """Pairs up the elements of iterable.
  306. overlapping: "s -> (s0,s1), (s2,s3), (s4,s5), ...
  307. not overlapping: "s -> (s0,s1), (s1,s2), (s2,s3), ...
  308. """
  309. return tuple_wise(iterable, 2, 1 if overlapping else 2)
  310.  
  311. def starfilter(function, iterable):
  312. """starfilter <--> filter == starmap <--> map"""
  313. return (item for item in iterable if function(*item))
  314.  
  315.  
  316. if __name__ == '__main__':
  317. unittest.main(verbosity=2)
Success #stdin #stdout #stderr 0.22s 11384KB
stdin
Standard input is empty
stdout
Standard output is empty
stderr
test_add_behind (__main__.TestIntervals) ... ok
test_add_in_between (__main__.TestIntervals) ... ok
test_add_in_front (__main__.TestIntervals) ... ok
test_add_overlapping (__main__.TestIntervals) ... ok
test_add_overlapping_multiple (__main__.TestIntervals) ... ok
test_add_swallow (__main__.TestIntervals) ... ok
test_add_touching (__main__.TestIntervals) ... ok
test_ctor (__main__.TestIntervals) ... ok
test_intersections (__main__.TestIntervals) ... ok
test_sub (__main__.TestIntervals) ... ok
test_symmetric_difference (__main__.TestIntervals) ... ok

----------------------------------------------------------------------
Ran 11 tests in 0.004s

OK