import random
import time
class KdTree:
    """Static k-d tree used to count points inside an axis-aligned box.

    Every node is a dict with the keys ``'value'`` (the point stored at the
    node), ``'left'`` and ``'right'`` (child nodes, or ``None``). The
    splitting axis cycles through the coordinates as depth increases.
    """

    def __init__(self, points):
        """Build the tree from *points*.

        Args:
            points: Sequence of equal-length coordinate sequences. It is
                shallow-copied before building, so the caller's list keeps
                its original order. An empty input yields an empty tree.
        """
        # Always define both attributes so an empty tree is still queryable.
        self.dimension = len(points[0]) if points else 0
        self.root = self.build(list(points)) if points else None

    def build(self, points, depth=0):
        """Recursively build and return the subtree holding *points*.

        Note that *points* is sorted in place along the current axis.

        Args:
            points: List of points to store in this subtree.
            depth: Depth of the subtree root; selects the splitting axis.

        Returns:
            The node dict for the subtree root, or ``None`` if *points* is
            empty.
        """
        if not points:
            return None
        axis = depth % self.dimension
        points.sort(key=lambda point: point[axis])
        median = len(points) // 2
        return {
            'value': points[median],
            'left': self.build(points[:median], depth + 1),
            'right': self.build(points[median + 1:], depth + 1),
        }

    def search(self, ld, ru):
        """Count stored points ``p`` with ``ld[i] <= p[i] < ru[i]`` on every axis.

        Args:
            ld: Inclusive lower corner of the query box.
            ru: Exclusive upper corner of the query box.

        Returns:
            The number of stored points inside the box (0 for an empty tree).
        """
        if self.root is None:
            return 0
        return self._search(self.root, 0, ld, ru)

    def _search(self, node, d, ld, ru):
        """Count the points of the subtree rooted at *node* inside the box.

        Args:
            node: Subtree root (node dict) or ``None``.
            d: Depth of *node*; selects the axis it was split on.
            ld: Inclusive lower corner of the query box.
            ru: Exclusive upper corner of the query box.
        """
        if node is None:
            return 0
        point = node['value']
        count = 0
        # The box test must cover every coordinate, not just the first two.
        if all(ld[i] <= point[i] < ru[i] for i in range(self.dimension)):
            count += 1
        axis = d % self.dimension
        # Left subtree only holds coordinates <= point[axis] on this axis,
        # so it can be skipped when the box starts above the split value.
        if node['left'] is not None and ld[axis] <= point[axis]:
            count += self._search(node['left'], d + 1, ld, ru)
        # Right subtree only holds coordinates >= point[axis]; skip it when
        # the (exclusive) upper bound is not above the split value.
        if node['right'] is not None and point[axis] < ru[axis]:
            count += self._search(node['right'], d + 1, ld, ru)
        return count
def bruteforce(points, ld, ru):
    """Count points inside the half-open box ``[ld, ru)`` by checking each one.

    Args:
        points: Sequence of equal-length coordinate sequences.
        ld: Inclusive lower corner of the box.
        ru: Exclusive upper corner of the box.

    Returns:
        The number of points ``p`` with ``ld[i] <= p[i] < ru[i]`` for every
        coordinate ``i``; 0 when *points* is empty.
    """
    # Guard first: the dimension is read from the first point.
    if not points:
        return 0
    d = len(points[0])
    return sum(
        1
        for point in points
        if all(ld[i] <= point[i] < ru[i] for i in range(d))
    )
if __name__ == '__main__':
    # Benchmark: count random points inside a random box with the k-d tree
    # and with a linear scan, check both agree, and print the timings.
    d = 8
    points = [
        [random.randint(1, 1_000_000) for j in range(d)]
        for i in range(500_000)
    ]
    ld = [random.randint(2000, 3000) for i in range(d)]
    ru = [random.randint(13000, 15000) for i in range(d)]

    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    t0 = time.perf_counter()
    kdtree = KdTree(points)
    t1 = time.perf_counter()
    ret_kd = kdtree.search(ld, ru)
    t2 = time.perf_counter()
    ret_b = bruteforce(points, ld, ru)
    t3 = time.perf_counter()

    # Self-check: both methods must report the same count.
    assert ret_kd == ret_b
    print(t1 - t0, 'build')
    print(t2 - t1, 'kd')
    print(t3 - t2, 'bruteforce')