Skip to content

Commit 8abf284

Browse files
fixing2
1 parent 9fb8b0a commit 8abf284

3 files changed

Lines changed: 310 additions & 26 deletions

File tree

benchmarks/benchmarks/neighbors.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import numpy as np
2+
from MDAnalysis.lib.pkdtree import PeriodicKDTree
3+
from MDAnalysis.lib.distances import capped_distance
4+
from scipy.spatial import cKDTree
5+
6+
7+
class NeighborsBench:
8+
"""Benchmarks for neighbor searching functions."""
9+
10+
params = ([100, 1000, 10000, 100000], [20, 30, 36, 42, 48, 50, 60])
11+
param_names = ["number_of_atoms", "cutoff"]
12+
13+
def setup(self, number_of_atoms, cutoff):
14+
"""Setup called before each benchmark with each parameter combination."""
15+
self.box = np.array(
16+
[170.0, 70.0, 120.0, 90.0, 90.0, 90.0], dtype=np.float32
17+
)
18+
self.positions = (
19+
np.random.rand(number_of_atoms, 3) * self.box[:3]
20+
).astype(np.float32)
21+
self.centre = (self.box[:3] / 2.0).reshape(1, 3)
22+
self.cutoff = cutoff
23+
24+
self.scipy_tree = cKDTree(self.positions, boxsize=self.box[:3])
25+
self.mda_tree = PeriodicKDTree(box=self.box)
26+
self.mda_tree.set_coords(self.positions, cutoff=self.cutoff)
27+
28+
def time_mda_tree_search(self, number_of_atoms, cutoff):
29+
"""Benchmark just the search operation on pre-built tree."""
30+
self.mda_tree.search(self.centre, self.cutoff)
31+
32+
def time_scipy_tree_query(self, number_of_atoms, cutoff):
33+
"""Benchmark just the query operation on pre-built tree."""
34+
self.scipy_tree.query_ball_point(self.centre, self.cutoff)
35+
36+
def time_mda_PKDtree_with_setup(self, number_of_atoms, cutoff):
37+
"""Benchmark tree construction + search."""
38+
tree = PeriodicKDTree(box=self.box)
39+
tree.set_coords(self.positions, cutoff=self.cutoff)
40+
tree.search(self.centre, self.cutoff)
41+
42+
def time_scipy_cKDTree_with_setup(self, number_of_atoms, cutoff):
43+
"""Benchmark tree construction + query."""
44+
tree = cKDTree(self.positions, boxsize=self.box[:3])
45+
tree.query_ball_point(self.centre, self.cutoff)
46+
47+
def time_capped_distance_array(self, number_of_atoms, cutoff):
48+
"""Benchmark capped distance calculation."""
49+
capped_distance(
50+
self.centre, self.positions, max_cutoff=self.cutoff, box=self.box
51+
)

package/MDAnalysis/lib/pkdtree.py

Lines changed: 238 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@
3838

3939
from MDAnalysis.lib.distances import apply_PBC
4040
import numpy.typing as npt
41-
from typing import Optional, ClassVar
41+
from typing import Optional, ClassVar, Union, Any
4242

4343
__all__ = ["PeriodicKDTree"]
4444

4545

46-
class PeriodicKDTree(object):
46+
class AugmentedPKDTree(object):
4747
"""Wrapper around :class:`scipy.spatial.cKDTree`
4848
4949
Creates an object which can handle periodic as well as
@@ -193,9 +193,7 @@ def search(self, centers: npt.ArrayLike, radius: float) -> npt.NDArray:
193193
"Cutoff needs to be provided when working with PBC."
194194
)
195195
if self.cutoff < radius:
196-
raise RuntimeError(
197-
"Set cutoff greater or equal to the radius."
198-
)
196+
raise RuntimeError("Set cutoff greater or equal to the radius.")
199197
# Bring all query points to the central cell
200198
wrapped_centers = apply_PBC(centers, self.box)
201199
indices = list(self.ckdt.query_ball_point(wrapped_centers, radius))
@@ -247,9 +245,7 @@ def search_pairs(self, radius: float) -> npt.NDArray:
247245
"Cutoff needs to be provided when working with PBC."
248246
)
249247
if self.cutoff < radius:
250-
raise RuntimeError(
251-
"Set cutoff greater or equal to the radius."
252-
)
248+
raise RuntimeError("Set cutoff greater or equal to the radius.")
253249

254250
pairs = np.array(list(self.ckdt.query_pairs(radius)), dtype=np.intp)
255251
if self.pbc:
@@ -311,9 +307,7 @@ class initialization
311307
"Cutoff needs to be provided when working with PBC."
312308
)
313309
if self.cutoff < radius:
314-
raise RuntimeError(
315-
"Set cutoff greater or equal to the radius."
316-
)
310+
raise RuntimeError("Set cutoff greater or equal to the radius.")
317311
# Bring all query points to the central cell
318312
wrapped_centers = apply_PBC(centers, self.box)
319313
other_tree = cKDTree(wrapped_centers, leafsize=self.leafsize)
@@ -336,3 +330,236 @@ class initialization
336330
if pairs.size > 0:
337331
pairs = unique_rows(pairs)
338332
return pairs
333+
334+
335+
class PeriodicKDTree(object):
336+
337+
def __init__(
338+
self, box: Optional[npt.ArrayLike] = None, leafsize: int = 10
339+
) -> None:
340+
self.leafsize = leafsize
341+
self.dim = 3
342+
self.box = box
343+
self._built = False
344+
self.cutoff = None
345+
self.mapping = None
346+
347+
_use_augmented = False
348+
if box is not None:
349+
box_array = np.asarray(box, dtype=np.float32)
350+
if box_array.shape == (6,):
351+
if not np.allclose(box_array[3:], 90.0):
352+
_use_augmented = True
353+
else:
354+
_use_augmented = True
355+
356+
self._use_augmented = _use_augmented
357+
358+
if self._use_augmented:
359+
self._tree = AugmentedPKDTree(box=self.box, leafsize=leafsize)
360+
else:
361+
self._tree = None
362+
if box is not None:
363+
self.box = np.asarray(box, dtype=np.float32)
364+
self._is_ortho = True
365+
366+
@property
367+
def pbc(self):
368+
"""Flag to indicate the presence of periodic boundaries.
369+
370+
- ``True`` if PBC are taken into account
371+
- ``False`` if no unitcell dimension is available.
372+
373+
This is a managed attribute and can only be read.
374+
"""
375+
return self.box is not None
376+
377+
def set_coords(
378+
self, coords: npt.ArrayLike, cutoff: Optional[float] = None
379+
) -> None:
380+
"""Constructs KDTree from the coordinates
381+
382+
Parameters
383+
----------
384+
coords: array_like
385+
Coordinate array of shape ``(N, 3)`` for N atoms.
386+
cutoff: float
387+
Specified cutoff distance for searches.
388+
Required for periodic calculations.
389+
"""
390+
if self._use_augmented:
391+
assert self._tree is not None
392+
self._tree.set_coords(coords, cutoff)
393+
self._built = True
394+
self.cutoff = cutoff
395+
else:
396+
coords = np.asarray(coords, dtype=np.float32)
397+
self.cutoff = cutoff
398+
399+
if self.box is None:
400+
if cutoff is not None:
401+
raise RuntimeError(
402+
"Donot provide cutoff distance for non PBC aware calculations"
403+
)
404+
self.coords = coords
405+
self._tree = cKDTree(self.coords, leafsize=self.leafsize)
406+
else:
407+
if cutoff is None:
408+
raise RuntimeError(
409+
"Provide a cutoff distance with tree.set_coords(...)"
410+
)
411+
self.coords = apply_PBC(coords, self.box)
412+
box_array = np.asarray(self.box, dtype=np.float32)
413+
self._tree = cKDTree(
414+
self.coords, leafsize=self.leafsize, boxsize=box_array[:3]
415+
)
416+
417+
self._built = True
418+
419+
def search(self, centers: npt.ArrayLike, radius: float) -> npt.NDArray:
420+
"""Search all points within radius from centers and their periodic images.
421+
422+
Parameters
423+
----------
424+
centers: array_like (N,3)
425+
coordinate array to search for neighbors
426+
radius: float
427+
maximum distance to search for neighbors.
428+
"""
429+
if not self._built:
430+
raise RuntimeError("Unbuilt tree. Run tree.set_coords(...)")
431+
432+
if self._use_augmented:
433+
assert self._tree is not None
434+
return self._tree.search(centers, radius)
435+
436+
centers = np.asarray(centers, dtype=np.float32)
437+
if centers.shape == (self.dim,):
438+
centers = centers.reshape((1, self.dim))
439+
440+
if self.pbc:
441+
if self.cutoff is None:
442+
raise ValueError(
443+
"Cutoff needs to be provided when working with PBC."
444+
)
445+
if self.cutoff < radius:
446+
raise RuntimeError("Set cutoff greater or equal to the radius.")
447+
wrapped_centers = apply_PBC(centers, self.box)
448+
assert isinstance(self._tree, cKDTree)
449+
indices = list(self._tree.query_ball_point(wrapped_centers, radius))
450+
else:
451+
assert isinstance(self._tree, cKDTree)
452+
indices = list(self._tree.query_ball_point(centers, radius))
453+
454+
self._indices = np.array(
455+
list(itertools.chain.from_iterable(indices)), dtype=np.intp
456+
)
457+
458+
if self._indices.size > 0:
459+
self._indices = np.asarray(unique_int_1d(self._indices))
460+
return self._indices
461+
462+
def get_indices(self) -> npt.NDArray:
463+
"""Return the neighbors from the last query.
464+
465+
Returns
466+
------
467+
indices : NDArray
468+
neighbors for the last query points and search radius
469+
"""
470+
if self._use_augmented:
471+
assert self._tree is not None
472+
return self._tree.get_indices()
473+
return self._indices
474+
475+
def search_pairs(self, radius: float) -> npt.NDArray:
476+
"""Search all the pairs within a specified radius
477+
478+
Parameters
479+
----------
480+
radius : float
481+
Maximum distance between pairs of coordinates
482+
483+
Returns
484+
-------
485+
pairs : array
486+
Indices of all the pairs which are within the specified radius
487+
"""
488+
if not self._built:
489+
raise RuntimeError("Unbuilt Tree. Run tree.set_coords(...)")
490+
491+
if self._use_augmented:
492+
assert self._tree is not None
493+
return self._tree.search_pairs(radius)
494+
495+
if self.pbc:
496+
if self.cutoff is None:
497+
raise ValueError(
498+
"Cutoff needs to be provided when working with PBC."
499+
)
500+
if self.cutoff < radius:
501+
raise RuntimeError("Set cutoff greater or equal to the radius.")
502+
503+
assert isinstance(self._tree, cKDTree)
504+
pairs = np.array(list(self._tree.query_pairs(radius)), dtype=np.intp)
505+
506+
if pairs.size > 0:
507+
pairs = np.sort(pairs, axis=1)
508+
pairs = unique_rows(pairs)
509+
return pairs
510+
511+
def search_tree(self, centers: npt.ArrayLike, radius: float) -> np.ndarray:
512+
"""
513+
Searches all the pairs within `radius` between `centers`
514+
and ``coords``
515+
516+
``coords`` are the already initialized coordinates in the tree
517+
during :meth:`set_coords`.
518+
519+
Parameters
520+
----------
521+
centers: array_like (N,3)
522+
coordinate array to search for neighbors
523+
radius: float
524+
maximum distance to search for neighbors.
525+
526+
Returns
527+
-------
528+
pairs : array
529+
all the pairs between ``coords`` and ``centers``
530+
"""
531+
if not self._built:
532+
raise RuntimeError("Unbuilt tree. Run tree.set_coords(...)")
533+
534+
if self._use_augmented:
535+
assert self._tree is not None
536+
return self._tree.search_tree(centers, radius)
537+
538+
centers = np.asarray(centers, dtype=np.float32)
539+
if centers.shape == (self.dim,):
540+
centers = centers.reshape((1, self.dim))
541+
542+
if self.pbc:
543+
if self.cutoff is None:
544+
raise ValueError(
545+
"Cutoff needs to be provided when working with PBC."
546+
)
547+
if self.cutoff < radius:
548+
raise RuntimeError("Set cutoff greater or equal to the radius.")
549+
wrapped_centers = apply_PBC(centers, self.box)
550+
box_array = np.asarray(self.box, dtype=np.float32)
551+
other_tree = cKDTree(
552+
wrapped_centers, leafsize=self.leafsize, boxsize=box_array[:3]
553+
)
554+
else:
555+
other_tree = cKDTree(centers, leafsize=self.leafsize)
556+
557+
pairs_list = other_tree.query_ball_tree(self._tree, radius)
558+
pairs = np.array(
559+
[[i, j] for i, lst in enumerate(pairs_list) for j in lst],
560+
dtype=np.intp,
561+
)
562+
563+
if pairs.size > 0:
564+
pairs = unique_rows(pairs)
565+
return pairs

0 commit comments

Comments
 (0)