fork download
  1. # SSCCE version of the core function
  2. def nodes_in_range(src, cell, maxDist):
  3. srcX, srcY, srcZ = src.x, src.y, src.z
  4. maxDistSq = maxDist ** 2
  5. for node in cell:
  6. distSq = (node.x - srcX) ** 2
  7. if distSq > maxDistSq: continue
  8. distSq += (node.y - srcY) ** 2
  9. if distSq > maxDistSq: continue
  10. distSq += (node.z - srcZ) ** 2
  11. if distSq <= maxDistSq:
  12. yield node, distSq ** 0.5 # fast sqrt
  13.  
  14. from collections import namedtuple
  15. class Node(namedtuple('Node', ('ID', 'x', 'y', 'z'))):
  16. # actual class has assorted other properties
  17. pass
  18.  
  19. cell = [
  20. Node(1, 0, 0, 0),
  21. Node(2, -2, -3, 4),
  22. Node(3, .1, .2, .3),
  23. Node(4, 2.3, -3.3, -4.5),
  24. Node(5, -2.5, 4.5, 5),
  25. Node(6, 4, 3., 2.),
  26. Node(7, -2.46, 2.46, -2.47),
  27. Node(8, 2.45, -2.46, -2.47),
  28. Node(9, .5, .5, .1),
  29. Node(10, 5, 6, 7),
  30. # In practice, cells have upto 600 entries
  31. ]
  32.  
  33. if __name__ == "__main__":
  34. for node, dist in nodes_in_range(cell[0], cell, 4.2):
  35. print("{:3n} {:5.2f}".format(node.ID, dist))
  36.  
  37. import numpy
  38. import numpy.linalg
  39. contarry = numpy.ascontiguousarray
  40. float32 = numpy.float32
  41.  
# An "np_cell" is a 2-tuple: element [0] is a tuple of the nodes and
# element [1] is a vectorizable (N, 3) float32 numpy array of their positions.
# np_cell[1][N] is the position of the node np_cell[0][N].
  45.  
  46. def make_np_cell(cell):
  47. return (
  48. tuple(cell),
  49. contarry([contarry((node.x, node.y, node.z), float32) for node in cell]),
  50. )
  51.  
  52. # This version fails because norm returns a single value.
  53. def np_nodes_in_range1(srcPos, np_cell, maxDist):
  54. distances = numpy.linalg.norm(np_cell[1] - srcPos)
  55.  
  56. for (node, dist) in zip(np_cell[0], distances):
  57. if dist <= maxDist:
  58. yield node, dist
  59.  
  60. # This version fails because
  61. def np_nodes_in_range2(srcPos, np_cell, maxDist):
  62. # this will fail because the distances are wrong
  63. distances = numpy.linalg.norm(np_cell[1] - srcPos, ord=1, axis=1)
  64. for (node, dist) in zip(np_cell[0], distances):
  65. if dist <= maxDist:
  66. yield node, dist
  67.  
  68. # This version doesn't vectorize and so performs poorly
  69. def np_nodes_in_range3(srcPos, np_cell, maxDist):
  70. norm = numpy.linalg.norm
  71. for (node, pos) in zip(np_cell[0], np_cell[1]):
  72. dist = norm(srcPos - pos)
  73. if dist <= maxDist:
  74. yield node, dist
  75.  
  76. def np_nodes_in_range4(srcPos, np_cell, maxDist):
  77. # this will fail because the distances are wrong
  78. nodes = np_cell[0]
  79. distances = numpy.linalg.norm(np_cell[1] - srcPos, ord=1, axis=1)
  80. for idx, dist in enumerate(distances):
  81. if dist <= maxDist:
  82. yield nodes[idx], dist
  83.  
  84. if __name__ == "__main__":
  85. np_cell = make_np_cell(cell)
  86. srcPos = np_cell[1][0] # Position column [1], first node [0]
  87. print("v1 - fails because it gets a single distance")
  88. try:
  89. for node, dist in np_nodes_in_range1(srcPos, np_cell, float32(4.2)):
  90. print("{:3n} {:5.2f}".format(node.ID, dist))
  91. except TypeError:
  92. print("distances was a single value")
  93.  
  94. print("v2 - gets the wrong distance values")
  95. for node, dist in np_nodes_in_range2(srcPos, np_cell, float32(4.2)):
  96. print("{:3n} {:5.2f}".format(node.ID, dist))
  97.  
  98. print("v3 - slower")
  99. for node, dist in np_nodes_in_range3(srcPos, np_cell, float32(4.2)):
  100. print("{:3n} {:5.2f}".format(node.ID, dist))
  101.  
  102. print("v4 - v2 using enumerate")
  103. for node, dist in np_nodes_in_range4(srcPos, np_cell, float32(4.2)):
  104. print("{:3n} {:5.2f}".format(node.ID, dist))
  105.  
Success #stdin #stdout 0.21s 26216KB
stdin
Standard input is empty
stdout
  1  0.00
  3  0.37
  9  0.71
v1 - fails because it gets a single distance
distances was a single value
v2 - gets the wrong distance values
  1  0.00
  3  0.60
  9  1.10
v3 - slower
  1  0.00
  3  0.37
  9  0.71
v4 - v2 using enumerate
  1  0.00
  3  0.60
  9  1.10