• Source
    1. import random
    2. import time
    3.  
    4.  
    class KdTree:
        """k-d tree over equal-length points that counts points inside a box.

        Every node is a dict with three keys: ``'value'`` (the point stored at
        the node) and ``'left'`` / ``'right'`` (child nodes, or ``None``). A
        node at depth ``k`` splits its points on coordinate ``k % dimension``.
        """

        def __init__(self, points):
            """Build the tree from *points*.

            Args:
                points: Sequence of coordinate sequences; the length of the
                    first point is taken as the dimension. The sequence is
                    not modified. Empty input produces an empty tree.
            """
            # Both attributes are always set so an empty tree can still be
            # queried instead of failing with AttributeError.
            self.dimension = len(points[0]) if points else 0
            self.root = self.build(points)

        def build(self, points, depth=0):
            """Recursively build the subtree holding *points*.

            Args:
                points: Points to store in this subtree.
                depth: Depth of the subtree root; selects the splitting axis.

            Returns:
                The node dict for the subtree root, or ``None`` when *points*
                is empty.
            """
            if not points:
                return None
            axis = depth % self.dimension
            # sorted() works on a copy, so the caller's list keeps its order.
            ordered = sorted(points, key=lambda point: point[axis])
            median = len(ordered) // 2
            return {
                'value': ordered[median],
                'left': self.build(ordered[:median], depth + 1),
                'right': self.build(ordered[median + 1:], depth + 1),
            }

        def search(self, ld, ru):
            """Count stored points inside the half-open box ``[ld, ru)``.

            Args:
                ld: Inclusive lower bound for each coordinate.
                ru: Exclusive upper bound for each coordinate.

            Returns:
                Number of points ``p`` with ``ld[i] <= p[i] < ru[i]`` on every
                axis; ``0`` for an empty tree.
            """
            if self.root is None:
                return 0
            return self._search(self.root, 0, ld, ru)

        def _search(self, node, d, ld, ru):
            """Count the points in the subtree rooted at *node* inside the box.

            Args:
                node: Non-``None`` node dict as produced by ``build``.
                d: Depth of *node*; selects the axis it was split on.
                ld: Inclusive lower bound for each coordinate.
                ru: Exclusive upper bound for each coordinate.

            Returns:
                Number of matching points in this subtree.
            """
            count = 0
            p = node['value']
            # The point must lie inside the box on every axis, not only on
            # the first two.
            if all(ld[i] <= p[i] < ru[i] for i in range(self.dimension)):
                count += 1
            axis = d % self.dimension
            # build() puts coordinates <= p[axis] on the left, so the left
            # subtree can hold matches only if the lower bound reaches p[axis].
            if node['left'] is not None and ld[axis] <= p[axis]:
                count += self._search(node['left'], d + 1, ld, ru)
            # Coordinates on the right are >= p[axis]; they can be below the
            # exclusive upper bound only if p[axis] itself is.
            if node['right'] is not None and p[axis] < ru[axis]:
                count += self._search(node['right'], d + 1, ld, ru)
            return count
    38.  
    39.  
    def bruteforce(points, ld, ru):
        """Count points inside the half-open box ``[ld, ru)`` by linear scan.

        Args:
            points: Sequence of coordinate sequences; the length of the first
                point is taken as the number of axes to check.
            ld: Inclusive lower bound for each coordinate.
            ru: Exclusive upper bound for each coordinate.

        Returns:
            Number of points ``p`` with ``ld[i] <= p[i] < ru[i]`` on every
            axis; ``0`` when *points* is empty.
        """
        # Without this guard points[0] raises IndexError on empty input.
        if not points:
            return 0
        d = len(points[0])
        return sum(
            1
            for point in points
            if all(ld[i] <= point[i] < ru[i] for i in range(d))
        )
    47.  
    48.  
    if __name__ == '__main__':
        # Benchmark: build a k-d tree over random points, run one box-count
        # query, and compare the result and timings against a linear scan.
        d = 8
        points = [
            [random.randint(1, 1_000_000) for _ in range(d)]
            for _ in range(500_000)
        ]

        # Query box: inclusive lower corner `ld`, exclusive upper corner `ru`.
        # NOTE(review): each side spans at most 13,000 of the 1,000,000-wide
        # coordinate range, so with 8 axes the expected number of hits is
        # essentially zero; widen the box if the agreement check below should
        # exercise non-zero counts.
        ld = [random.randint(2000, 3000) for _ in range(d)]
        ru = [random.randint(13000, 15000) for _ in range(d)]

        # perf_counter() is the clock meant for measuring short intervals;
        # time.time() can jump when the system clock is adjusted.
        t0 = time.perf_counter()

        kdtree = KdTree(points)
        t1 = time.perf_counter()

        ret_kd = kdtree.search(ld, ru)
        t2 = time.perf_counter()

        ret_b = bruteforce(points, ld, ru)
        t3 = time.perf_counter()

        # Both methods must agree; report both counts if they do not.
        assert ret_kd == ret_b, (ret_kd, ret_b)

        print(t1 - t0, 'build')
        print(t2 - t1, 'kd')
        print(t3 - t2, 'bruteforce')
    72.