Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/scipy/spatial/kdtree.py : 9%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright Anne M. Archibald 2008
2# Released under the scipy license
3import numpy as np
4from heapq import heappush, heappop
5import scipy.sparse
# Public API of this module (consumed by ``from ... import *`` and Sphinx).
__all__ = ['minkowski_distance_p', 'minkowski_distance',
           'distance_matrix',
           'Rectangle', 'KDTree']
def minkowski_distance_p(x, y, p=2):
    """
    Compute the pth power of the L**p distance between two arrays.

    For efficiency, this function computes the L**p distance but does
    not extract the pth root. If `p` is 1 or infinity, this is equal to
    the actual L**p distance.

    Parameters
    ----------
    x : (M, K) array_like
        Input array.
    y : (N, K) array_like
        Input array.
    p : float, 1 <= p <= infinity
        Which Minkowski p-norm to use.

    Examples
    --------
    >>> from scipy.spatial import minkowski_distance_p
    >>> minkowski_distance_p([[0,0],[0,0]], [[1,1],[0,1]])
    array([2., 1.])

    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Find smallest common datatype with float64 (return type of this
    # function) - addresses #10262.
    # Don't just cast to float64 for complex input case.
    common_datatype = np.promote_types(np.promote_types(x.dtype, y.dtype),
                                       'float64')

    # Make sure x and y are NumPy arrays of correct datatype.
    # copy=False avoids a needless copy when the dtype already matches;
    # neither array is mutated below, so aliasing the input is safe.
    x = x.astype(common_datatype, copy=False)
    y = y.astype(common_datatype, copy=False)

    if p == np.inf:
        return np.amax(np.abs(y-x), axis=-1)
    elif p == 1:
        return np.sum(np.abs(y-x), axis=-1)
    else:
        return np.sum(np.abs(y-x)**p, axis=-1)
def minkowski_distance(x, y, p=2):
    """
    Compute the L**p distance between two arrays.

    Parameters
    ----------
    x : (M, K) array_like
        Input array.
    y : (N, K) array_like
        Input array.
    p : float, 1 <= p <= infinity
        Which Minkowski p-norm to use.

    Examples
    --------
    >>> from scipy.spatial import minkowski_distance
    >>> minkowski_distance([[0,0],[0,0]], [[1,1],[0,1]])
    array([ 1.41421356,  1.        ])

    """
    x = np.asarray(x)
    y = np.asarray(y)
    # minkowski_distance_p returns the pth power of the distance; for
    # p == 1 and p == inf that already equals the distance itself.
    dist_p = minkowski_distance_p(x, y, p)
    if p == np.inf or p == 1:
        return dist_p
    return dist_p ** (1. / p)
class Rectangle(object):
    """Hyperrectangle class.

    Represents a Cartesian product of intervals.
    """

    def __init__(self, maxes, mins):
        """Construct a hyperrectangle from two opposite corners."""
        # Normalise so that self.maxes >= self.mins componentwise,
        # regardless of which corner was passed first.
        upper = np.maximum(maxes, mins).astype(float)
        lower = np.minimum(maxes, mins).astype(float)
        self.maxes = upper
        self.mins = lower
        self.m, = upper.shape

    def __repr__(self):
        return "<Rectangle %s>" % list(zip(self.mins, self.maxes))

    def volume(self):
        """Total volume."""
        return np.prod(self.maxes - self.mins)

    def split(self, d, split):
        """
        Produce two hyperrectangles by splitting.

        In general, if you need to compute maximum and minimum
        distances to the children, it can be done more efficiently
        by updating the maximum and minimum distances to the parent.

        Parameters
        ----------
        d : int
            Axis to split hyperrectangle along.
        split : float
            Position along axis `d` to split at.

        """
        # Lower half: same mins, max clipped to the split plane.
        upper_bound = np.copy(self.maxes)
        upper_bound[d] = split
        lower_half = Rectangle(self.mins, upper_bound)
        # Upper half: same maxes, min raised to the split plane.
        lower_bound = np.copy(self.mins)
        lower_bound[d] = split
        upper_half = Rectangle(lower_bound, self.maxes)
        return lower_half, upper_half

    def min_distance_point(self, x, p=2.):
        """
        Return the minimum distance between input and points in the hyperrectangle.

        Parameters
        ----------
        x : array_like
            Input.
        p : float, optional
            Input.

        """
        # Per-axis gap between x and the nearest face (zero if inside).
        gap = np.maximum(0, np.maximum(self.mins - x, x - self.maxes))
        return minkowski_distance(0, gap, p)

    def max_distance_point(self, x, p=2.):
        """
        Return the maximum distance between input and points in the hyperrectangle.

        Parameters
        ----------
        x : array_like
            Input array.
        p : float, optional
            Input.

        """
        # Per-axis distance from x to the farthest face.
        reach = np.maximum(self.maxes - x, x - self.mins)
        return minkowski_distance(0, reach, p)

    def min_distance_rectangle(self, other, p=2.):
        """
        Compute the minimum distance between points in the two hyperrectangles.

        Parameters
        ----------
        other : hyperrectangle
            Input.
        p : float
            Input.

        """
        # Per-axis gap between the two boxes (zero where they overlap).
        gap = np.maximum(0, np.maximum(self.mins - other.maxes,
                                       other.mins - self.maxes))
        return minkowski_distance(0, gap, p)

    def max_distance_rectangle(self, other, p=2.):
        """
        Compute the maximum distance between points in the two hyperrectangles.

        Parameters
        ----------
        other : hyperrectangle
            Input.
        p : float, optional
            Input.

        """
        # Per-axis span between the most distant pair of faces.
        reach = np.maximum(self.maxes - other.mins, other.maxes - self.mins)
        return minkowski_distance(0, reach, p)
class KDTree(object):
    """
    kd-tree for quick nearest-neighbor lookup

    This class provides an index into a set of k-D points which
    can be used to rapidly look up the nearest neighbors of any point.

    Parameters
    ----------
    data : (N,K) array_like
        The data points to be indexed. This array is not copied, and
        so modifying this data will result in bogus results.
    leafsize : int, optional
        The number of points at which the algorithm switches over to
        brute-force. Has to be positive.

    Raises
    ------
    RuntimeError
        The maximum recursion limit can be exceeded for large data
        sets. If this happens, either increase the value for the `leafsize`
        parameter or increase the recursion limit by::

            >>> import sys
            >>> sys.setrecursionlimit(10000)

    See Also
    --------
    cKDTree : Implementation of `KDTree` in Cython

    Notes
    -----
    The algorithm used is described in Maneewongvatana and Mount 1999.
    The general idea is that the kd-tree is a binary tree, each of whose
    nodes represents an axis-aligned hyperrectangle. Each node specifies
    an axis and splits the set of points based on whether their coordinate
    along that axis is greater than or less than a particular value.

    During construction, the axis and splitting point are chosen by the
    "sliding midpoint" rule, which ensures that the cells do not all
    become long and thin.

    The tree can be queried for the r closest neighbors of any given point
    (optionally returning only those within some maximum distance of the
    point). It can also be queried, with a substantial gain in efficiency,
    for the r approximate closest neighbors.

    For large dimensions (20 is already large) do not expect this to run
    significantly faster than brute force. High-dimensional nearest-neighbor
    queries are a substantial open problem in computer science.

    The tree also supports all-neighbors queries, both with arrays of points
    and with other kd-trees. These do use a reasonably efficient algorithm,
    but the kd-tree is not necessarily the best data structure for this
    sort of calculation.

    """
    def __init__(self, data, leafsize=10):
        self.data = np.asarray(data)
        self.n, self.m = np.shape(self.data)
        self.leafsize = int(leafsize)
        if self.leafsize < 1:
            raise ValueError("leafsize must be at least 1")
        # Bounding box of the whole point set; serves as the root cell.
        self.maxes = np.amax(self.data, axis=0)
        self.mins = np.amin(self.data, axis=0)

        self.tree = self.__build(np.arange(self.n), self.maxes, self.mins)

    class node(object):
        # Nodes appear as heapq tie-breakers in __query's priority queue,
        # so they need a total ordering; identity-based comparison is
        # arbitrary but consistent.
        def __lt__(self, other):
            return id(self) < id(other)

        def __gt__(self, other):
            return id(self) > id(other)

        def __le__(self, other):
            return id(self) <= id(other)

        def __ge__(self, other):
            return id(self) >= id(other)

        def __eq__(self, other):
            return id(self) == id(other)

    class leafnode(node):
        # Leaf node: holds indices (into self.data) of the points it contains.
        def __init__(self, idx):
            self.idx = idx
            self.children = len(idx)

    class innernode(node):
        # Inner node: splits space along axis `split_dim` at coordinate
        # `split`; `children` counts all points in the subtree.
        def __init__(self, split_dim, split, less, greater):
            self.split_dim = split_dim
            self.split = split
            self.less = less
            self.greater = greater
            self.children = less.children + greater.children

    def __build(self, idx, maxes, mins):
        # Recursively build the subtree over self.data[idx], whose cell is
        # bounded by maxes/mins.
        if len(idx) <= self.leafsize:
            return KDTree.leafnode(idx)
        else:
            data = self.data[idx]
            # maxes = np.amax(data,axis=0)
            # mins = np.amin(data,axis=0)
            # Split along the axis of largest cell extent.
            d = np.argmax(maxes - mins)
            maxval = maxes[d]
            minval = mins[d]
            if maxval == minval:
                # all points are identical; warn user?
                return KDTree.leafnode(idx)
            data = data[:, d]

            # sliding midpoint rule; see Maneewongvatana and Mount 1999
            # for arguments that this is a good idea.
            split = (maxval + minval) / 2
            less_idx = np.nonzero(data <= split)[0]
            greater_idx = np.nonzero(data > split)[0]
            if len(less_idx) == 0:
                # Midpoint left of all points: slide split to the minimum.
                split = np.amin(data)
                less_idx = np.nonzero(data <= split)[0]
                greater_idx = np.nonzero(data > split)[0]
            if len(greater_idx) == 0:
                # Midpoint right of all points: slide split to the maximum.
                split = np.amax(data)
                less_idx = np.nonzero(data < split)[0]
                greater_idx = np.nonzero(data >= split)[0]
            if len(less_idx) == 0:
                # _still_ zero? all must have the same value
                if not np.all(data == data[0]):
                    raise ValueError("Troublesome data array: %s" % data)
                split = data[0]
                less_idx = np.arange(len(data) - 1)
                greater_idx = np.array([len(data) - 1])

            # Children inherit the parent's cell, clipped at the split plane.
            lessmaxes = np.copy(maxes)
            lessmaxes[d] = split
            greatermins = np.copy(mins)
            greatermins[d] = split
            return KDTree.innernode(d, split,
                    self.__build(idx[less_idx], lessmaxes, mins),
                    self.__build(idx[greater_idx], maxes, greatermins))

    def __query(self, x, k=1, eps=0, p=2, distance_upper_bound=np.inf):
        # Nearest-neighbor search for a single point x (1-D array).
        # Per-axis distance from x to the root cell (zero inside the cell).
        side_distances = np.maximum(0, np.maximum(x - self.maxes, self.mins - x))
        if p != np.inf:
            # Work with pth powers internally to avoid taking roots.
            side_distances **= p
            min_distance = np.sum(side_distances)
        else:
            min_distance = np.amax(side_distances)

        # priority queue for chasing nodes
        # entries are:
        #  minimum distance between the cell and the target
        #  distances between the nearest side of the cell and the target
        #  the head node of the cell
        q = [(min_distance,
              tuple(side_distances),
              self.tree)]
        # priority queue for the nearest neighbors
        # furthest known neighbor first
        # entries are (-distance**p, i)
        neighbors = []

        # Approximation slack: a cell is skipped when its minimum distance
        # exceeds distance_upper_bound*epsfac.
        if eps == 0:
            epsfac = 1
        elif p == np.inf:
            epsfac = 1 / (1 + eps)
        else:
            epsfac = 1 / (1 + eps)**p

        # The upper bound must be in the same (pth-power) units as the
        # internal distances.
        if p != np.inf and distance_upper_bound != np.inf:
            distance_upper_bound = distance_upper_bound**p

        while q:
            min_distance, side_distances, node = heappop(q)
            if isinstance(node, KDTree.leafnode):
                # brute-force
                data = self.data[node.idx]
                ds = minkowski_distance_p(data, x[np.newaxis, :], p)
                for i in range(len(ds)):
                    if ds[i] < distance_upper_bound:
                        if len(neighbors) == k:
                            heappop(neighbors)
                        heappush(neighbors, (-ds[i], node.idx[i]))
                        if len(neighbors) == k:
                            distance_upper_bound = -neighbors[0][0]
            else:
                # we don't push cells that are too far onto the queue at all,
                # but since the distance_upper_bound decreases, we might get
                # here even if the cell's too far
                if min_distance > distance_upper_bound * epsfac:
                    # since this is the nearest cell, we're done, bail out
                    break
                # compute minimum distances to the children and push them on
                if x[node.split_dim] < node.split:
                    near, far = node.less, node.greater
                else:
                    near, far = node.greater, node.less

                # near child is at the same distance as the current node
                heappush(q, (min_distance, side_distances, near))

                # far child is further by an amount depending only
                # on the split value
                sd = list(side_distances)
                if p == np.inf:
                    min_distance = max(min_distance, abs(node.split - x[node.split_dim]))
                elif p == 1:
                    sd[node.split_dim] = np.abs(node.split - x[node.split_dim])
                    min_distance = min_distance - side_distances[node.split_dim] + sd[node.split_dim]
                else:
                    sd[node.split_dim] = np.abs(node.split - x[node.split_dim])**p
                    min_distance = min_distance - side_distances[node.split_dim] + sd[node.split_dim]

                # far child might be too far, if so, don't bother pushing it
                if min_distance <= distance_upper_bound * epsfac:
                    heappush(q, (min_distance, tuple(sd), far))

        # Convert the heap's negated pth-power distances back to actual
        # distances, sorted nearest first.
        if p == np.inf:
            return sorted([(-d, i) for (d, i) in neighbors])
        else:
            return sorted([((-d)**(1. / p), i) for (d, i) in neighbors])

    def query(self, x, k=1, eps=0, p=2, distance_upper_bound=np.inf):
        """
        Query the kd-tree for nearest neighbors

        Parameters
        ----------
        x : array_like, last dimension self.m
            An array of points to query.
        k : int, optional
            The number of nearest neighbors to return.
        eps : nonnegative float, optional
            Return approximate nearest neighbors; the kth returned value
            is guaranteed to be no further than (1+eps) times the
            distance to the real kth nearest neighbor.
        p : float, 1<=p<=infinity, optional
            Which Minkowski p-norm to use.
            1 is the sum-of-absolute-values "Manhattan" distance
            2 is the usual Euclidean distance
            infinity is the maximum-coordinate-difference distance
        distance_upper_bound : nonnegative float, optional
            Return only neighbors within this distance. This is used to prune
            tree searches, so if you are doing a series of nearest-neighbor
            queries, it may help to supply the distance to the nearest neighbor
            of the most recent point.

        Returns
        -------
        d : float or array of floats
            The distances to the nearest neighbors.
            If x has shape tuple+(self.m,), then d has shape tuple if
            k is one, or tuple+(k,) if k is larger than one. Missing
            neighbors (e.g. when k > n or distance_upper_bound is
            given) are indicated with infinite distances. If k is None,
            then d is an object array of shape tuple, containing lists
            of distances. In either case the hits are sorted by distance
            (nearest first).
        i : integer or array of integers
            The locations of the neighbors in self.data. i is the same
            shape as d.

        Examples
        --------
        >>> from scipy import spatial
        >>> x, y = np.mgrid[0:5, 2:8]
        >>> tree = spatial.KDTree(list(zip(x.ravel(), y.ravel())))
        >>> tree.data
        array([[0, 2],
               [0, 3],
               [0, 4],
               [0, 5],
               [0, 6],
               [0, 7],
               [1, 2],
               [1, 3],
               [1, 4],
               [1, 5],
               [1, 6],
               [1, 7],
               [2, 2],
               [2, 3],
               [2, 4],
               [2, 5],
               [2, 6],
               [2, 7],
               [3, 2],
               [3, 3],
               [3, 4],
               [3, 5],
               [3, 6],
               [3, 7],
               [4, 2],
               [4, 3],
               [4, 4],
               [4, 5],
               [4, 6],
               [4, 7]])
        >>> pts = np.array([[0, 0], [2.1, 2.9]])
        >>> tree.query(pts)
        (array([ 2.        ,  0.14142136]), array([ 0, 13]))
        >>> tree.query(pts[0])
        (2.0, 0)

        """
        x = np.asarray(x)
        if np.shape(x)[-1] != self.m:
            raise ValueError("x must consist of vectors of length %d but has shape %s" % (self.m, np.shape(x)))
        if p < 1:
            raise ValueError("Only p-norms with 1<=p<=infinity permitted")
        retshape = np.shape(x)[:-1]
        if retshape != ():
            # Vectorized case: one __query call per point in x.
            if k is None:
                dd = np.empty(retshape, dtype=object)
                ii = np.empty(retshape, dtype=object)
            elif k > 1:
                dd = np.empty(retshape + (k,), dtype=float)
                dd.fill(np.inf)
                ii = np.empty(retshape + (k,), dtype=int)
                ii.fill(self.n)
            elif k == 1:
                dd = np.empty(retshape, dtype=float)
                dd.fill(np.inf)
                ii = np.empty(retshape, dtype=int)
                ii.fill(self.n)
            else:
                # BUG FIX: the original raise had a %s placeholder with no
                # argument, so the message printed a literal "%s".
                raise ValueError("Requested %s nearest neighbors; acceptable numbers are integers greater than or equal to one, or None" % k)
            for c in np.ndindex(retshape):
                hits = self.__query(x[c], k=k, eps=eps, p=p, distance_upper_bound=distance_upper_bound)
                if k is None:
                    dd[c] = [d for (d, i) in hits]
                    ii[c] = [i for (d, i) in hits]
                elif k > 1:
                    for j in range(len(hits)):
                        dd[c + (j,)], ii[c + (j,)] = hits[j]
                elif k == 1:
                    if len(hits) > 0:
                        dd[c], ii[c] = hits[0]
                    else:
                        # No neighbor within distance_upper_bound: sentinel
                        # values (inf distance, index past the last point).
                        dd[c] = np.inf
                        ii[c] = self.n
            return dd, ii
        else:
            # Single-point case.
            hits = self.__query(x, k=k, eps=eps, p=p, distance_upper_bound=distance_upper_bound)
            if k is None:
                return [d for (d, i) in hits], [i for (d, i) in hits]
            elif k == 1:
                if len(hits) > 0:
                    return hits[0]
                else:
                    return np.inf, self.n
            elif k > 1:
                dd = np.empty(k, dtype=float)
                dd.fill(np.inf)
                ii = np.empty(k, dtype=int)
                ii.fill(self.n)
                for j in range(len(hits)):
                    dd[j], ii[j] = hits[j]
                return dd, ii
            else:
                # BUG FIX: same missing "% k" as above.
                raise ValueError("Requested %s nearest neighbors; acceptable numbers are integers greater than or equal to one, or None" % k)

    def __query_ball_point(self, x, r, p=2., eps=0):
        # Find all points within distance r of the single point x.
        R = Rectangle(self.maxes, self.mins)

        def traverse_checking(node, rect):
            # Prune cells entirely outside r/(1+eps); take cells entirely
            # inside r*(1+eps) wholesale; otherwise recurse/brute-force.
            if rect.min_distance_point(x, p) > r / (1. + eps):
                return []
            elif rect.max_distance_point(x, p) < r * (1. + eps):
                return traverse_no_checking(node)
            elif isinstance(node, KDTree.leafnode):
                d = self.data[node.idx]
                return node.idx[minkowski_distance(d, x, p) <= r].tolist()
            else:
                less, greater = rect.split(node.split_dim, node.split)
                return traverse_checking(node.less, less) + \
                       traverse_checking(node.greater, greater)

        def traverse_no_checking(node):
            # Every point in this subtree is known to be within range.
            if isinstance(node, KDTree.leafnode):
                return node.idx.tolist()
            else:
                return traverse_no_checking(node.less) + \
                       traverse_no_checking(node.greater)

        return traverse_checking(self.tree, R)

    def query_ball_point(self, x, r, p=2., eps=0):
        """Find all points within distance r of point(s) x.

        Parameters
        ----------
        x : array_like, shape tuple + (self.m,)
            The point or points to search for neighbors of.
        r : positive float
            The radius of points to return.
        p : float, optional
            Which Minkowski p-norm to use. Should be in the range [1, inf].
        eps : nonnegative float, optional
            Approximate search. Branches of the tree are not explored if their
            nearest points are further than ``r / (1 + eps)``, and branches are
            added in bulk if their furthest points are nearer than
            ``r * (1 + eps)``.

        Returns
        -------
        results : list or array of lists
            If `x` is a single point, returns a list of the indices of the
            neighbors of `x`. If `x` is an array of points, returns an object
            array of shape tuple containing lists of neighbors.

        Notes
        -----
        If you have many points whose neighbors you want to find, you may save
        substantial amounts of time by putting them in a KDTree and using
        query_ball_tree.

        Examples
        --------
        >>> from scipy import spatial
        >>> x, y = np.mgrid[0:5, 0:5]
        >>> points = np.c_[x.ravel(), y.ravel()]
        >>> tree = spatial.KDTree(points)
        >>> tree.query_ball_point([2, 0], 1)
        [5, 10, 11, 15]

        Query multiple points and plot the results:

        >>> import matplotlib.pyplot as plt
        >>> points = np.asarray(points)
        >>> plt.plot(points[:,0], points[:,1], '.')
        >>> for results in tree.query_ball_point(([2, 0], [3, 3]), 1):
        ...     nearby_points = points[results]
        ...     plt.plot(nearby_points[:,0], nearby_points[:,1], 'o')
        >>> plt.margins(0.1, 0.1)
        >>> plt.show()

        """
        x = np.asarray(x)
        if x.shape[-1] != self.m:
            raise ValueError("Searching for a %d-dimensional point in a "
                             "%d-dimensional KDTree" % (x.shape[-1], self.m))
        if len(x.shape) == 1:
            return self.__query_ball_point(x, r, p, eps)
        else:
            # Multiple query points: one list of neighbors per point.
            retshape = x.shape[:-1]
            result = np.empty(retshape, dtype=object)
            for c in np.ndindex(retshape):
                result[c] = self.__query_ball_point(x[c], r, p=p, eps=eps)
            return result

    def query_ball_tree(self, other, r, p=2., eps=0):
        """Find all pairs of points whose distance is at most r

        Parameters
        ----------
        other : KDTree instance
            The tree containing points to search against.
        r : float
            The maximum distance, has to be positive.
        p : float, optional
            Which Minkowski norm to use. `p` has to meet the condition
            ``1 <= p <= infinity``.
        eps : float, optional
            Approximate search. Branches of the tree are not explored
            if their nearest points are further than ``r/(1+eps)``, and
            branches are added in bulk if their furthest points are nearer
            than ``r * (1+eps)``. `eps` has to be non-negative.

        Returns
        -------
        results : list of lists
            For each element ``self.data[i]`` of this tree, ``results[i]`` is a
            list of the indices of its neighbors in ``other.data``.

        """
        results = [[] for i in range(self.n)]

        def traverse_checking(node1, rect1, node2, rect2):
            # Dual-tree traversal with rectangle-distance pruning.
            if rect1.min_distance_rectangle(rect2, p) > r / (1. + eps):
                return
            elif rect1.max_distance_rectangle(rect2, p) < r * (1. + eps):
                traverse_no_checking(node1, node2)
            elif isinstance(node1, KDTree.leafnode):
                if isinstance(node2, KDTree.leafnode):
                    d = other.data[node2.idx]
                    for i in node1.idx:
                        results[i] += node2.idx[minkowski_distance(d, self.data[i], p) <= r].tolist()
                else:
                    less, greater = rect2.split(node2.split_dim, node2.split)
                    traverse_checking(node1, rect1, node2.less, less)
                    traverse_checking(node1, rect1, node2.greater, greater)
            elif isinstance(node2, KDTree.leafnode):
                less, greater = rect1.split(node1.split_dim, node1.split)
                traverse_checking(node1.less, less, node2, rect2)
                traverse_checking(node1.greater, greater, node2, rect2)
            else:
                less1, greater1 = rect1.split(node1.split_dim, node1.split)
                less2, greater2 = rect2.split(node2.split_dim, node2.split)
                traverse_checking(node1.less, less1, node2.less, less2)
                traverse_checking(node1.less, less1, node2.greater, greater2)
                traverse_checking(node1.greater, greater1, node2.less, less2)
                traverse_checking(node1.greater, greater1, node2.greater, greater2)

        def traverse_no_checking(node1, node2):
            # Both subtrees known to be within range: record all pairs.
            if isinstance(node1, KDTree.leafnode):
                if isinstance(node2, KDTree.leafnode):
                    for i in node1.idx:
                        results[i] += node2.idx.tolist()
                else:
                    traverse_no_checking(node1, node2.less)
                    traverse_no_checking(node1, node2.greater)
            else:
                traverse_no_checking(node1.less, node2)
                traverse_no_checking(node1.greater, node2)

        traverse_checking(self.tree, Rectangle(self.maxes, self.mins),
                          other.tree, Rectangle(other.maxes, other.mins))
        return results

    def query_pairs(self, r, p=2., eps=0):
        """
        Find all pairs of points within a distance.

        Parameters
        ----------
        r : positive float
            The maximum distance.
        p : float, optional
            Which Minkowski norm to use. `p` has to meet the condition
            ``1 <= p <= infinity``.
        eps : float, optional
            Approximate search. Branches of the tree are not explored
            if their nearest points are further than ``r/(1+eps)``, and
            branches are added in bulk if their furthest points are nearer
            than ``r * (1+eps)``. `eps` has to be non-negative.

        Returns
        -------
        results : set
            Set of pairs ``(i,j)``, with ``i < j``, for which the corresponding
            positions are close.

        """
        results = set()

        def traverse_checking(node1, rect1, node2, rect2):
            if rect1.min_distance_rectangle(rect2, p) > r / (1. + eps):
                return
            elif rect1.max_distance_rectangle(rect2, p) < r * (1. + eps):
                traverse_no_checking(node1, node2)
            elif isinstance(node1, KDTree.leafnode):
                if isinstance(node2, KDTree.leafnode):
                    # Special care to avoid duplicate pairs
                    if id(node1) == id(node2):
                        d = self.data[node2.idx]
                        for i in node1.idx:
                            for j in node2.idx[minkowski_distance(d, self.data[i], p) <= r]:
                                if i < j:
                                    results.add((i, j))
                    else:
                        d = self.data[node2.idx]
                        for i in node1.idx:
                            for j in node2.idx[minkowski_distance(d, self.data[i], p) <= r]:
                                if i < j:
                                    results.add((i, j))
                                elif j < i:
                                    results.add((j, i))
                else:
                    less, greater = rect2.split(node2.split_dim, node2.split)
                    traverse_checking(node1, rect1, node2.less, less)
                    traverse_checking(node1, rect1, node2.greater, greater)
            elif isinstance(node2, KDTree.leafnode):
                less, greater = rect1.split(node1.split_dim, node1.split)
                traverse_checking(node1.less, less, node2, rect2)
                traverse_checking(node1.greater, greater, node2, rect2)
            else:
                less1, greater1 = rect1.split(node1.split_dim, node1.split)
                less2, greater2 = rect2.split(node2.split_dim, node2.split)
                traverse_checking(node1.less, less1, node2.less, less2)
                traverse_checking(node1.less, less1, node2.greater, greater2)

                # Avoid traversing (node1.less, node2.greater) and
                # (node1.greater, node2.less) (it's the same node pair twice
                # over, which is the source of the complication in the
                # original KDTree.query_pairs)
                if id(node1) != id(node2):
                    traverse_checking(node1.greater, greater1, node2.less, less2)

                traverse_checking(node1.greater, greater1, node2.greater, greater2)

        def traverse_no_checking(node1, node2):
            if isinstance(node1, KDTree.leafnode):
                if isinstance(node2, KDTree.leafnode):
                    # Special care to avoid duplicate pairs
                    if id(node1) == id(node2):
                        for i in node1.idx:
                            for j in node2.idx:
                                if i < j:
                                    results.add((i, j))
                    else:
                        for i in node1.idx:
                            for j in node2.idx:
                                if i < j:
                                    results.add((i, j))
                                elif j < i:
                                    results.add((j, i))
                else:
                    traverse_no_checking(node1, node2.less)
                    traverse_no_checking(node1, node2.greater)
            else:
                # Avoid traversing (node1.less, node2.greater) and
                # (node1.greater, node2.less) (it's the same node pair twice
                # over, which is the source of the complication in the
                # original KDTree.query_pairs)
                if id(node1) == id(node2):
                    traverse_no_checking(node1.less, node2.less)
                    traverse_no_checking(node1.less, node2.greater)
                    traverse_no_checking(node1.greater, node2.greater)
                else:
                    traverse_no_checking(node1.less, node2)
                    traverse_no_checking(node1.greater, node2)

        traverse_checking(self.tree, Rectangle(self.maxes, self.mins),
                          self.tree, Rectangle(self.maxes, self.mins))
        return results

    def count_neighbors(self, other, r, p=2.):
        """
        Count how many nearby pairs can be formed.

        Count the number of pairs (x1,x2) that can be formed, with x1 drawn
        from self and x2 drawn from ``other``, and where
        ``distance(x1, x2, p) <= r``.
        This is the "two-point correlation" described in Gray and Moore 2000,
        "N-body problems in statistical learning", and the code here is based
        on their algorithm.

        Parameters
        ----------
        other : KDTree instance
            The other tree to draw points from.
        r : float or one-dimensional array of floats
            The radius to produce a count for. Multiple radii are searched with
            a single tree traversal.
        p : float, 1<=p<=infinity, optional
            Which Minkowski p-norm to use

        Returns
        -------
        result : int or 1-D array of ints
            The number of pairs. Note that this is internally stored in a numpy
            int, and so may overflow if very large (2e9).

        """
        def traverse(node1, rect1, node2, rect2, idx):
            # idx: indices into r still undecided for this node pair.
            min_r = rect1.min_distance_rectangle(rect2, p)
            max_r = rect1.max_distance_rectangle(rect2, p)
            # Radii larger than max_r: every cross pair counts.
            c_greater = r[idx] > max_r
            result[idx[c_greater]] += node1.children * node2.children
            # Keep only radii that straddle [min_r, max_r].
            idx = idx[(min_r <= r[idx]) & (r[idx] <= max_r)]
            if len(idx) == 0:
                return

            if isinstance(node1, KDTree.leafnode):
                if isinstance(node2, KDTree.leafnode):
                    # Brute-force: sort all cross distances once, then count
                    # per radius with a binary search.
                    ds = minkowski_distance(self.data[node1.idx][:, np.newaxis, :],
                                            other.data[node2.idx][np.newaxis, :, :],
                                            p).ravel()
                    ds.sort()
                    result[idx] += np.searchsorted(ds, r[idx], side='right')
                else:
                    less, greater = rect2.split(node2.split_dim, node2.split)
                    traverse(node1, rect1, node2.less, less, idx)
                    traverse(node1, rect1, node2.greater, greater, idx)
            else:
                if isinstance(node2, KDTree.leafnode):
                    less, greater = rect1.split(node1.split_dim, node1.split)
                    traverse(node1.less, less, node2, rect2, idx)
                    traverse(node1.greater, greater, node2, rect2, idx)
                else:
                    less1, greater1 = rect1.split(node1.split_dim, node1.split)
                    less2, greater2 = rect2.split(node2.split_dim, node2.split)
                    traverse(node1.less, less1, node2.less, less2, idx)
                    traverse(node1.less, less1, node2.greater, greater2, idx)
                    traverse(node1.greater, greater1, node2.less, less2, idx)
                    traverse(node1.greater, greater1, node2.greater, greater2, idx)

        R1 = Rectangle(self.maxes, self.mins)
        R2 = Rectangle(other.maxes, other.mins)
        if np.shape(r) == ():
            # Scalar radius: run the array version and unwrap.
            r = np.array([r])
            result = np.zeros(1, dtype=int)
            traverse(self.tree, R1, other.tree, R2, np.arange(1))
            return result[0]
        elif len(np.shape(r)) == 1:
            r = np.asarray(r)
            n, = r.shape
            result = np.zeros(n, dtype=int)
            traverse(self.tree, R1, other.tree, R2, np.arange(n))
            return result
        else:
            raise ValueError("r must be either a single value or a one-dimensional array of values")

    def sparse_distance_matrix(self, other, max_distance, p=2.):
        """
        Compute a sparse distance matrix

        Computes a distance matrix between two KDTrees, leaving as zero
        any distance greater than max_distance.

        Parameters
        ----------
        other : KDTree

        max_distance : positive float

        p : float, optional

        Returns
        -------
        result : dok_matrix
            Sparse matrix representing the results in "dictionary of keys" format.

        """
        result = scipy.sparse.dok_matrix((self.n, other.n))

        def traverse(node1, rect1, node2, rect2):
            if rect1.min_distance_rectangle(rect2, p) > max_distance:
                # The two cells are too far apart: nothing to record.
                return
            elif isinstance(node1, KDTree.leafnode):
                if isinstance(node2, KDTree.leafnode):
                    for i in node1.idx:
                        for j in node2.idx:
                            d = minkowski_distance(self.data[i], other.data[j], p)
                            if d <= max_distance:
                                result[i, j] = d
                else:
                    less, greater = rect2.split(node2.split_dim, node2.split)
                    traverse(node1, rect1, node2.less, less)
                    traverse(node1, rect1, node2.greater, greater)
            elif isinstance(node2, KDTree.leafnode):
                less, greater = rect1.split(node1.split_dim, node1.split)
                traverse(node1.less, less, node2, rect2)
                traverse(node1.greater, greater, node2, rect2)
            else:
                less1, greater1 = rect1.split(node1.split_dim, node1.split)
                less2, greater2 = rect2.split(node2.split_dim, node2.split)
                traverse(node1.less, less1, node2.less, less2)
                traverse(node1.less, less1, node2.greater, greater2)
                traverse(node1.greater, greater1, node2.less, less2)
                traverse(node1.greater, greater1, node2.greater, greater2)
        traverse(self.tree, Rectangle(self.maxes, self.mins),
                 other.tree, Rectangle(other.maxes, other.mins))

        return result
def distance_matrix(x, y, p=2, threshold=1000000):
    """
    Compute the distance matrix.

    Returns the matrix of all pair-wise distances.

    Parameters
    ----------
    x : (M, K) array_like
        Matrix of M vectors in K dimensions.
    y : (N, K) array_like
        Matrix of N vectors in K dimensions.
    p : float, 1 <= p <= infinity
        Which Minkowski p-norm to use.
    threshold : positive int
        If ``M * N * K`` > `threshold`, algorithm uses a Python loop instead
        of large temporary arrays.

    Returns
    -------
    result : (M, N) ndarray
        Matrix containing the distance from every vector in `x` to every vector
        in `y`.

    Examples
    --------
    >>> from scipy.spatial import distance_matrix
    >>> distance_matrix([[0,0],[0,1]], [[1,0],[1,1]])
    array([[ 1.        ,  1.41421356],
           [ 1.41421356,  1.        ]])

    """
    x = np.asarray(x)
    y = np.asarray(y)
    m, k = x.shape
    n, kk = y.shape

    if k != kk:
        raise ValueError("x contains %d-dimensional vectors but y contains %d-dimensional vectors" % (k, kk))

    if m * n * k <= threshold:
        # Small enough problem: one fully broadcast computation.
        return minkowski_distance(x[:, np.newaxis, :], y[np.newaxis, :, :], p)

    # Too big for an (M, N, K) temporary: fill the matrix one row or one
    # column at a time, looping over the shorter axis so each vectorized
    # call covers as many pairs as possible.
    result = np.empty((m, n), dtype=float)  # FIXME: figure out the best dtype
    if m < n:
        for row in range(m):
            result[row, :] = minkowski_distance(x[row], y, p)
    else:
        for col in range(n):
            result[:, col] = minkowski_distance(x, y[col], p)
    return result