3838
3939from MDAnalysis .lib .distances import apply_PBC
4040import numpy .typing as npt
41- from typing import Optional , ClassVar
41+ from typing import Optional , ClassVar , Union , Any
4242
4343__all__ = ["PeriodicKDTree" ]
4444
4545
46- class PeriodicKDTree (object ):
46+ class AugmentedPKDTree (object ):
4747 """Wrapper around :class:`scipy.spatial.cKDTree`
4848
4949 Creates an object which can handle periodic as well as
@@ -193,9 +193,7 @@ def search(self, centers: npt.ArrayLike, radius: float) -> npt.NDArray:
193193 "Cutoff needs to be provided when working with PBC."
194194 )
195195 if self .cutoff < radius :
196- raise RuntimeError (
197- "Set cutoff greater or equal to the radius."
198- )
196+ raise RuntimeError ("Set cutoff greater or equal to the radius." )
199197 # Bring all query points to the central cell
200198 wrapped_centers = apply_PBC (centers , self .box )
201199 indices = list (self .ckdt .query_ball_point (wrapped_centers , radius ))
@@ -247,9 +245,7 @@ def search_pairs(self, radius: float) -> npt.NDArray:
247245 "Cutoff needs to be provided when working with PBC."
248246 )
249247 if self .cutoff < radius :
250- raise RuntimeError (
251- "Set cutoff greater or equal to the radius."
252- )
248+ raise RuntimeError ("Set cutoff greater or equal to the radius." )
253249
254250 pairs = np .array (list (self .ckdt .query_pairs (radius )), dtype = np .intp )
255251 if self .pbc :
@@ -311,9 +307,7 @@ class initialization
311307 "Cutoff needs to be provided when working with PBC."
312308 )
313309 if self .cutoff < radius :
314- raise RuntimeError (
315- "Set cutoff greater or equal to the radius."
316- )
310+ raise RuntimeError ("Set cutoff greater or equal to the radius." )
317311 # Bring all query points to the central cell
318312 wrapped_centers = apply_PBC (centers , self .box )
319313 other_tree = cKDTree (wrapped_centers , leafsize = self .leafsize )
@@ -336,3 +330,236 @@ class initialization
336330 if pairs .size > 0 :
337331 pairs = unique_rows (pairs )
338332 return pairs
333+
334+
335+ class PeriodicKDTree (object ):
336+
337+ def __init__ (
338+ self , box : Optional [npt .ArrayLike ] = None , leafsize : int = 10
339+ ) -> None :
340+ self .leafsize = leafsize
341+ self .dim = 3
342+ self .box = box
343+ self ._built = False
344+ self .cutoff = None
345+ self .mapping = None
346+
347+ _use_augmented = False
348+ if box is not None :
349+ box_array = np .asarray (box , dtype = np .float32 )
350+ if box_array .shape == (6 ,):
351+ if not np .allclose (box_array [3 :], 90.0 ):
352+ _use_augmented = True
353+ else :
354+ _use_augmented = True
355+
356+ self ._use_augmented = _use_augmented
357+
358+ if self ._use_augmented :
359+ self ._tree = AugmentedPKDTree (box = self .box , leafsize = leafsize )
360+ else :
361+ self ._tree = None
362+ if box is not None :
363+ self .box = np .asarray (box , dtype = np .float32 )
364+ self ._is_ortho = True
365+
366+ @property
367+ def pbc (self ):
368+ """Flag to indicate the presence of periodic boundaries.
369+
370+ - ``True`` if PBC are taken into account
371+ - ``False`` if no unitcell dimension is available.
372+
373+ This is a managed attribute and can only be read.
374+ """
375+ return self .box is not None
376+
377+ def set_coords (
378+ self , coords : npt .ArrayLike , cutoff : Optional [float ] = None
379+ ) -> None :
380+ """Constructs KDTree from the coordinates
381+
382+ Parameters
383+ ----------
384+ coords: array_like
385+ Coordinate array of shape ``(N, 3)`` for N atoms.
386+ cutoff: float
387+ Specified cutoff distance for searches.
388+ Required for periodic calculations.
389+ """
390+ if self ._use_augmented :
391+ assert self ._tree is not None
392+ self ._tree .set_coords (coords , cutoff )
393+ self ._built = True
394+ self .cutoff = cutoff
395+ else :
396+ coords = np .asarray (coords , dtype = np .float32 )
397+ self .cutoff = cutoff
398+
399+ if self .box is None :
400+ if cutoff is not None :
401+ raise RuntimeError (
402+ "Donot provide cutoff distance for non PBC aware calculations"
403+ )
404+ self .coords = coords
405+ self ._tree = cKDTree (self .coords , leafsize = self .leafsize )
406+ else :
407+ if cutoff is None :
408+ raise RuntimeError (
409+ "Provide a cutoff distance with tree.set_coords(...)"
410+ )
411+ self .coords = apply_PBC (coords , self .box )
412+ box_array = np .asarray (self .box , dtype = np .float32 )
413+ self ._tree = cKDTree (
414+ self .coords , leafsize = self .leafsize , boxsize = box_array [:3 ]
415+ )
416+
417+ self ._built = True
418+
419+ def search (self , centers : npt .ArrayLike , radius : float ) -> npt .NDArray :
420+ """Search all points within radius from centers and their periodic images.
421+
422+ Parameters
423+ ----------
424+ centers: array_like (N,3)
425+ coordinate array to search for neighbors
426+ radius: float
427+ maximum distance to search for neighbors.
428+ """
429+ if not self ._built :
430+ raise RuntimeError ("Unbuilt tree. Run tree.set_coords(...)" )
431+
432+ if self ._use_augmented :
433+ assert self ._tree is not None
434+ return self ._tree .search (centers , radius )
435+
436+ centers = np .asarray (centers , dtype = np .float32 )
437+ if centers .shape == (self .dim ,):
438+ centers = centers .reshape ((1 , self .dim ))
439+
440+ if self .pbc :
441+ if self .cutoff is None :
442+ raise ValueError (
443+ "Cutoff needs to be provided when working with PBC."
444+ )
445+ if self .cutoff < radius :
446+ raise RuntimeError ("Set cutoff greater or equal to the radius." )
447+ wrapped_centers = apply_PBC (centers , self .box )
448+ assert isinstance (self ._tree , cKDTree )
449+ indices = list (self ._tree .query_ball_point (wrapped_centers , radius ))
450+ else :
451+ assert isinstance (self ._tree , cKDTree )
452+ indices = list (self ._tree .query_ball_point (centers , radius ))
453+
454+ self ._indices = np .array (
455+ list (itertools .chain .from_iterable (indices )), dtype = np .intp
456+ )
457+
458+ if self ._indices .size > 0 :
459+ self ._indices = np .asarray (unique_int_1d (self ._indices ))
460+ return self ._indices
461+
462+ def get_indices (self ) -> npt .NDArray :
463+ """Return the neighbors from the last query.
464+
465+ Returns
466+ ------
467+ indices : NDArray
468+ neighbors for the last query points and search radius
469+ """
470+ if self ._use_augmented :
471+ assert self ._tree is not None
472+ return self ._tree .get_indices ()
473+ return self ._indices
474+
475+ def search_pairs (self , radius : float ) -> npt .NDArray :
476+ """Search all the pairs within a specified radius
477+
478+ Parameters
479+ ----------
480+ radius : float
481+ Maximum distance between pairs of coordinates
482+
483+ Returns
484+ -------
485+ pairs : array
486+ Indices of all the pairs which are within the specified radius
487+ """
488+ if not self ._built :
489+ raise RuntimeError ("Unbuilt Tree. Run tree.set_coords(...)" )
490+
491+ if self ._use_augmented :
492+ assert self ._tree is not None
493+ return self ._tree .search_pairs (radius )
494+
495+ if self .pbc :
496+ if self .cutoff is None :
497+ raise ValueError (
498+ "Cutoff needs to be provided when working with PBC."
499+ )
500+ if self .cutoff < radius :
501+ raise RuntimeError ("Set cutoff greater or equal to the radius." )
502+
503+ assert isinstance (self ._tree , cKDTree )
504+ pairs = np .array (list (self ._tree .query_pairs (radius )), dtype = np .intp )
505+
506+ if pairs .size > 0 :
507+ pairs = np .sort (pairs , axis = 1 )
508+ pairs = unique_rows (pairs )
509+ return pairs
510+
511+ def search_tree (self , centers : npt .ArrayLike , radius : float ) -> np .ndarray :
512+ """
513+ Searches all the pairs within `radius` between `centers`
514+ and ``coords``
515+
516+ ``coords`` are the already initialized coordinates in the tree
517+ during :meth:`set_coords`.
518+
519+ Parameters
520+ ----------
521+ centers: array_like (N,3)
522+ coordinate array to search for neighbors
523+ radius: float
524+ maximum distance to search for neighbors.
525+
526+ Returns
527+ -------
528+ pairs : array
529+ all the pairs between ``coords`` and ``centers``
530+ """
531+ if not self ._built :
532+ raise RuntimeError ("Unbuilt tree. Run tree.set_coords(...)" )
533+
534+ if self ._use_augmented :
535+ assert self ._tree is not None
536+ return self ._tree .search_tree (centers , radius )
537+
538+ centers = np .asarray (centers , dtype = np .float32 )
539+ if centers .shape == (self .dim ,):
540+ centers = centers .reshape ((1 , self .dim ))
541+
542+ if self .pbc :
543+ if self .cutoff is None :
544+ raise ValueError (
545+ "Cutoff needs to be provided when working with PBC."
546+ )
547+ if self .cutoff < radius :
548+ raise RuntimeError ("Set cutoff greater or equal to the radius." )
549+ wrapped_centers = apply_PBC (centers , self .box )
550+ box_array = np .asarray (self .box , dtype = np .float32 )
551+ other_tree = cKDTree (
552+ wrapped_centers , leafsize = self .leafsize , boxsize = box_array [:3 ]
553+ )
554+ else :
555+ other_tree = cKDTree (centers , leafsize = self .leafsize )
556+
557+ pairs_list = other_tree .query_ball_tree (self ._tree , radius )
558+ pairs = np .array (
559+ [[i , j ] for i , lst in enumerate (pairs_list ) for j in lst ],
560+ dtype = np .intp ,
561+ )
562+
563+ if pairs .size > 0 :
564+ pairs = unique_rows (pairs )
565+ return pairs
0 commit comments