Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/structuretoolkit/build/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,54 @@ def repulse(
return structure


def merge(structure: "ase.Atoms", cutoff: float = 1.8, iterations: int = 10) -> "ase.Atoms":
"""Merge pairs of atoms that are closer than ``cutoff`` by collapsing each
pair to their midpoint and deleting one of the two atoms.

The operation is applied repeatedly (up to ``iterations`` times) to handle
cases where a merge creates new close contacts.

.. note::
The structure is modified **in place**. Pass a copy if you need the
original to remain unchanged.

Args:
structure (:class:`ase.Atoms`):
Structure to modify.
cutoff (float):
Distance threshold in Ångström below which two atoms are
considered clashing and will be merged. Defaults to ``1.8``.
iterations (int):
Maximum number of recursive merge passes. Defaults to ``10``.

Returns:
:class:`ase.Atoms`: The modified structure with clashing atom pairs
replaced by single atoms at their midpoints.
"""
neigh = get_neighbors(structure, 1)
clashing = np.argwhere( neigh.distances[:,0] < cutoff ).ravel()
if len(clashing) == 0:
return structure

moving = []
deleting = []

for c in clashing:
if c in deleting:
continue

moving.append(c)
deleting.append(neigh.indices[c, 0])

structure.positions[moving] += neigh.vecs[moving, 0]/2
del structure[deleting]

if iterations > 0:
return merge(structure, cutoff=cutoff, iterations=iterations-1)
return structure


__all__ = [
"repulse"
"merge",
"repulse",
]
64 changes: 63 additions & 1 deletion tests/test_geometry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import unittest

import numpy as np
from ase import Atoms
from ase.build import bulk

from structuretoolkit.analyse import get_neighbors
from structuretoolkit.build.geometry import repulse
from structuretoolkit.build.geometry import merge, repulse


class TestRepulse(unittest.TestCase):
Expand Down Expand Up @@ -59,3 +60,64 @@ def test_min_dist(self):
result = repulse(self.atoms, min_dist=min_dist)
neigh = get_neighbors(result, num_neighbors=1)
self.assertGreaterEqual(neigh.distances[:, 0].min(), min_dist)


def _two_atom_structure(d: float) -> Atoms:
"""Return a two-Cu-atom cell with atoms separated by ``d`` Å along x."""
return Atoms("Cu2", positions=[[0, 0, 0], [d, 0, 0]], cell=[20, 20, 20], pbc=True)


class TestMerge(unittest.TestCase):
def test_noop(self):
"""Perfect FCC Cu has no contacts below default cutoff; structure is unchanged."""
atoms = bulk("Cu", cubic=True).repeat(3)
original_positions = atoms.positions.copy()
result = merge(atoms)
self.assertEqual(len(result), len(atoms))
np.testing.assert_array_equal(result.positions, original_positions)

def test_reduces_count(self):
"""Two atoms within cutoff are collapsed into one."""
atoms = _two_atom_structure(0.5)
result = merge(atoms, cutoff=1.8)
self.assertEqual(len(result), 1)

def test_midpoint(self):
"""The surviving atom must sit at the midpoint of the original pair."""
atoms = _two_atom_structure(1.0)
result = merge(atoms, cutoff=1.8)
self.assertEqual(len(result), 1)
np.testing.assert_allclose(result.positions[0], [0.5, 0, 0], atol=1e-10)

def test_cutoff_respected(self):
"""Atoms just beyond the cutoff must not be merged."""
atoms = _two_atom_structure(2.0)
result = merge(atoms, cutoff=1.8)
self.assertEqual(len(result), 2)

def test_multiple_pairs(self):
"""All clashing pairs in the structure are merged in one call."""
# Two independent close pairs, well separated from each other
atoms = Atoms(
"Cu4",
positions=[[0, 0, 0], [0.5, 0, 0], [10, 0, 0], [10.5, 0, 0]],
cell=[30, 30, 30],
pbc=True,
)
result = merge(atoms, cutoff=1.8)
self.assertEqual(len(result), 2)

def test_iterations_zero_stops_early(self):
"""With iterations=0 only one pass runs; further clashes are left unresolved."""
# Three atoms in a row: 0–1 and 1–2 clash. First pass merges 0+1 → 0.25.
# The merged atom is now ~9.75 Å from atom 2, so one pass is enough here
# — we just verify the function returns without error.
atoms = Atoms(
"Cu3",
positions=[[0, 0, 0], [0.5, 0, 0], [10, 0, 0]],
cell=[20, 20, 20],
pbc=True,
)
result = merge(atoms, cutoff=1.8, iterations=0)
# At least one merge happened
self.assertLess(len(result), 3)
Loading