fork download
  1. """intervals
  2.  
  3. Union, intersection, set difference and symmetric difference
  4. of possibly overlapping or touching integer intervals.
  5. Intervals are defined right-open. (1, 4) -> 1, 2, 3
  6.  
  7. e.g.
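  8. union(Intervals([(1, 4), (7, 9)]), Intervals([(3, 5)])) -> Intervals [(1, 5), (7, 9)]
  9. intersections(Intervals([(1, 4), (7, 9)]), Intervals([(3, 5)])) -> Intervals [(3, 4)]
  10. set_difference(Intervals([(1, 4), (7, 9)]), Intervals([(3, 5)])) -> Intervals [(1, 3), (7, 9)]
  11. set_difference(Intervals([(3, 5)]), Intervals([(1, 4), (7, 9)])) -> Intervals [(4, 5)]
  12. symmetric_difference(Intervals([(1, 4), (7, 9)]), Intervals([(3, 5)])) -> Intervals [(1, 3), (4, 5), (7, 9)]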
  13.  
  14. see: http://en.wikipedia.org/wiki/Set_theory#Basic_concepts_and_notation
  15. """
  16.  
  17. import copy
  18. from itertools import accumulate, chain, islice, repeat
  19. from operator import itemgetter
  20. import unittest
  21.  
  22. class Intervals(object):
  23. """Holds a non overlapping list of intervals.
  24. One single interval is just a pair.
  25. Overlapping or touching intervals are automatically merged.
  26. """
  27.  
  28. def __init__(self, interval_list=()):
  29. """Raises a ValueError if the length of one of the
  30. intervals in the list is negative.
  31. """
  32. if any(begin > end for begin, end in interval_list):
  33. raise ValueError('Invalid interval')
  34. self._interval_list = _merge_interval_lists(
  35. interval_list, [])
  36.  
  37. def __repr__(self):
  38. """Just write out all included intervals.
  39. """
  40. return 'Intervals ' + str(self._interval_list)
  41.  
  42. def get(self, copy_content=True):
  43. """Return the list of intervals.
  44. """
  45. return copy.copy(self._interval_list) if copy_content\
  46. else self._interval_list
  47.  
  48.  
  49. def union(a, b):
  50. """Merge a and b (union).
  51. """
  52. return Intervals(_merge_interval_lists(
  53. a.get(False), b.get(False)))
  54.  
  55. def intersections(a, b):
  56. """Intersects a and b.
  57. """
  58. return Intervals(_merge_interval_lists(
  59. a.get(False), b.get(False), merge_type='intersections'))
  60.  
  61. def set_difference(a, b):
  62. """Removes b from a.
  63. Set difference is not commutative.
  64. """
  65. return Intervals(_merge_interval_lists(
  66. a.get(False), b.get(False), merge_type='set difference'))
  67.  
  68. def symmetric_difference(a, b):
  69. """Symmetric difference of a and b.
  70. """
  71. return Intervals(_merge_interval_lists(
  72. a.get(False), b.get(False), merge_type='symmetric difference'))
  73.  
  74.  
  75. # class Intervals makes sure, by always building the union first,
  76. # that no invalid a's or b's are fed here.
  77. def _merge_interval_lists(a, b, merge_type='union'):
  78. """Merges two lists of intervals in O(n*log(n)).
  79. Overlapping or touching intervals are simplified to one.
  80.  
  81. Arguments:
  82. a and b -- The interval lists to merge.
  83. merge_type -- Can be:
  84. 'union',
  85. 'intersections',
  86. 'symmetric difference', or
  87. 'set difference'.
  88.  
  89. Return the sorted result as a list.
  90. """
  91.  
  92. # If we want to calculate the set difference
  93. # we invert the second interval list,
  94. # i.e. swap begin and end.
  95. if merge_type == 'set difference':
  96. b = map(lambda p: (p[1], p[0]), b)
  97.  
  98. # Separately sort begins and ends
  99. # and pair them with the implied change
  100. # of the count of currently open intervals.
  101. # e.g. (1, 4), (7, 9), (3, 5) ->
  102. # begins = [(1, 1), (3, 1), (7, 1)]
  103. # ends = [(4, -1), (5, -1), (9, -1)]
  104. both = list(chain(a, b))
  105. begins = zip(sorted(map(itemgetter(0), both)),
  106. repeat(1))
  107. ends = zip(sorted(map(itemgetter(1), both)),
  108. repeat(-1))
  109.  
  110. # Sort begins and ends together.
  111. # If the value is the same, begins come before ends
  112. # to ensure touching intervals being merged to one.
  113. # In our example above this means:
  114. # edges = [(1, 4), (3, 1), (4, -1), (5, -1), (7, 1), (9, -1)]
  115. edges = sorted(chain(begins, ends), key=lambda x: (x[0], -x[1]))
  116.  
  117. # Depending on the operation carried out,
  118. # the criteria for interval begins and ends in the result differ.
  119. # E.g:
  120. # a = | - - - - | | - - - |
  121. # b = | - - - - |
  122. # counts = 1 2 1 0 1 0 (union, intersection, sym diff)
  123. # counts = 1 0 -1 0 1 0 (set diff)
  124. # union = | - - - - - - | | - - - |
  125. # inter = | - - - - |
  126. # sym d = | - |
  127. # set d = | - | | - - - |
  128. #
  129. # One can see that union begins if the count changes from 0 to 1
  130. # and ends if the count changes from 1 to 0
  131. # An intersection begins at a change from 1 to 2 and ends with 2 to 1.
  132. # A symmetric difference begins at every change to one
  133. # and ends at every change away from one.
  134. # The conditions for the set difference are the same as for the union.
  135. check_begin = {'union': lambda change: change == (0, 1),
  136. 'intersections': lambda change: change == (1, 2),
  137. 'symmetric difference': lambda change: change[1] == 1,
  138. 'set difference': lambda change: change == (0, 1)
  139. }[merge_type]
  140.  
  141. check_end = {'union': lambda change: change == (1, 0),
  142. 'intersections': lambda change: change == (2, 1),
  143. 'symmetric difference': lambda change: change[1] != 1,
  144. 'set difference': lambda change: change == (1, 0)
  145. }[merge_type]
  146.  
  147. # The number of opened intervals after each edge.
  148. counts = list(accumulate(map(itemgetter(1), edges)))
  149. # The changes of opened intervals at each edge.
  150. changes = zip(chain([0], counts), counts)
  151. # Just the x positions of the edges.
  152. xs = map(itemgetter(0), edges)
  153. xs_and_changes = list(zip(xs, changes))
  154.  
  155. # Now we filter out the begins and ends from the changes
  156. # and get their x positions.
  157. res_begins = map(itemgetter(0),
  158. starfilter(lambda x, change: check_begin(change),
  159. xs_and_changes))
  160. res_ends = map(itemgetter(0),
  161. starfilter(lambda x, change: check_end(change),
  162. xs_and_changes))
  163.  
  164. # The result is then just pairing up the sorted begins and ends.
  165. result = pairwise(sorted(chain(res_begins, res_ends)), False)
  166.  
  167. # No empty intervals in the result.
  168. def length_greater_than_zero(interval):
  169. return interval[0] < interval[1]
  170. return list(filter(length_greater_than_zero, result))
  171.  
  172.  
  173. class TestIntervals(unittest.TestCase):
  174.  
  175. def test_ctor(self):
  176. # Check ctors sanity check.
  177. self.assertRaises(ValueError, Intervals, [(2, 4), (3, 1)])
  178.  
  179. def test_add_behind(self):
  180. # Check adding right of the last interval.
  181. intervals = Intervals([(0, 2)])
  182. intervals = union(intervals, Intervals([(3, 4)]))
  183. self.assertEqual(intervals.get(), [(0, 2), (3, 4)])
  184.  
  185. def test_add_in_front(self):
  186. # Check adding left to the first interval.
  187. intervals = Intervals([(3, 4)])
  188. intervals = union(intervals, Intervals([(1, 2)]))
  189. self.assertEqual(intervals.get(), [(1, 2), (3, 4)])
  190.  
  191. def test_add_in_between(self):
  192. # Check adding between two intervals.
  193. intervals = Intervals([(1, 2)])
  194. intervals = union(intervals, Intervals([(6, 9)]))
  195. intervals = union(intervals, Intervals([(3, 5)]))
  196. self.assertEqual(intervals.get(), [(1, 2), (3, 5), (6, 9)])
  197.  
  198. def test_add_touching(self):
  199. # Check adding a interval touching an existing one.
  200. intervals = Intervals([(1, 3)])
  201. intervals = union(intervals, Intervals([(3, 5)]))
  202. self.assertEqual(intervals.get(), [(1, 5)])
  203.  
  204. def test_add_overlapping(self):
  205. # Check adding a interval overlapping an existing one.
  206. intervals = Intervals([(1, 4)])
  207. intervals = union(intervals, Intervals([(3, 5)]))
  208. self.assertEqual(intervals.get(), [(1, 5)])
  209.  
  210. def test_add_overlapping_multiple(self):
  211. # Check adding a interval overlapping multiple existing ones.
  212. intervals = Intervals([(1, 4)])
  213. intervals = union(intervals, Intervals([(5, 7)]))
  214. intervals = union(intervals, Intervals([(8, 10)]))
  215. intervals = union(intervals, Intervals([(3, 9)]))
  216. self.assertEqual(intervals.get(), [(1, 10)])
  217.  
  218. def test_add_swallow(self):
  219. # Check adding a interval completely covering an existing one.
  220. intervals = Intervals([(2, 3)])
  221. intervals = union(intervals, Intervals([(1, 4)]))
  222. self.assertEqual(intervals.get(), [(1, 4)])
  223.  
  224. def test_sub(self):
  225. # Check removing an interval
  226. intervals = Intervals([(0, 3)])
  227. intervals = union(intervals, Intervals([(5, 7)]))
  228. intervals = set_difference(intervals, Intervals([(2, 6)]))
  229. self.assertEqual(intervals.get(), [(0, 2), (6, 7)])
  230.  
  231. def test_intersections(self):
  232. # Check adding right of the last interval.
  233. intervals = Intervals([(0, 3)])
  234. intervals = union(intervals, Intervals([(5, 7)]))
  235. intervals = intersections(intervals, Intervals([(2, 6)]))
  236. self.assertEqual(intervals.get(), [(2, 3), (5, 6)])
  237.  
  238. def test_symmetric_difference(self):
  239. # Check symmetric difference
  240. intervals = Intervals([(0, 3)])
  241. intervals = union(intervals, Intervals([(5, 7)]))
  242. intervals = symmetric_difference(intervals, Intervals([(2, 6)]))
  243. self.assertEqual(intervals.get(), [(0, 2), (3, 5), (6, 7)])
  244.  
  245.  
  246. def tuple_wise(iterable, size, step):
  247. """Tuples up the elements of iterable.
  248.  
  249. Arguments:
  250. iterable -- source data
  251. size -- size of the destination tuples
  252. step -- step to do in iterable per destination tuple
  253.  
  254. tuple_wise(s, 3, 1): "s -> (s0,s1,s2), (s1,s2,s3), (s3,s4,s5), ...
  255. tuple_wise(s, 2, 4): "s -> (s0,s1), (s4,s5), (s8,s9), ...
  256. """
  257. return zip(
  258. *(islice(iterable, start, None, step)
  259. for start in range(size)))
  260.  
  261. def pairwise(iterable, overlapping):
  262. """Pairs up the elements of iterable.
  263. overlapping: "s -> (s0,s1), (s2,s3), (s4,s5), ...
  264. not overlapping: "s -> (s0,s1), (s1,s2), (s2,s3), ...
  265. """
  266. return tuple_wise(iterable, 2, 1 if overlapping else 2)
  267.  
  268. def starfilter(function, iterable):
  269. """starfilter <--> filter == starmap <--> map"""
  270. return (item for item in iterable if function(*item))
  271.  
# When executed as a script, run the unit tests of this module with
# one result line per test (verbosity=2).
if __name__ == '__main__':
    unittest.main(verbosity=2)
Success #stdin #stdout #stderr 0.14s 11336KB
stdin
Standard input is empty
stdout
Standard output is empty
stderr
test_add_behind (__main__.TestIntervals) ... ok
test_add_in_between (__main__.TestIntervals) ... ok
test_add_in_front (__main__.TestIntervals) ... ok
test_add_overlapping (__main__.TestIntervals) ... ok
test_add_overlapping_multiple (__main__.TestIntervals) ... ok
test_add_swallow (__main__.TestIntervals) ... ok
test_add_touching (__main__.TestIntervals) ... ok
test_ctor (__main__.TestIntervals) ... ok
test_intersections (__main__.TestIntervals) ... ok
test_sub (__main__.TestIntervals) ... ok
test_symmetric_difference (__main__.TestIntervals) ... ok

----------------------------------------------------------------------
Ran 11 tests in 0.017s

OK