Skip to content

Commit 8af49a5

Browse files
Add 3D visualization capabilities and update dependencies
- Introduced new functions in `three_d.py` for interactive 3D visualization of molecules and trajectories using `py3Dmol` and `nglview`. - Updated `__init__.py` to include new 3D visualization functions in the module's public API. - Added tests for the new 3D visualization features in `test_three_d.py` to ensure functionality and correctness. - Updated `pyproject.toml` to include `nglview` and `py3dmol` as dependencies, enhancing visualization capabilities.
1 parent 6239f48 commit 8af49a5

4 files changed

Lines changed: 341 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies = [
2121
"mdtraj",
2222
"mplplots @ git+https://github.com/FridrichMethod/mplplots.git",
2323
"numba",
24+
"nglview",
2425
"numpy",
2526
"pandas",
2627
"panedr",
@@ -30,6 +31,7 @@ dependencies = [
3031
"polars",
3132
"prody",
3233
"propka",
34+
"py3dmol",
3335
"rdkit",
3436
"scikit-learn",
3537
"scipy",
@@ -42,7 +44,7 @@ dependencies = [
4244
version = { attr = "mdpp.__version__" }
4345

4446
[project.optional-dependencies]
45-
viz = ["gmxvg", "jupyter", "py3dmol", "nglview"]
47+
viz = ["gmxvg", "jupyter"]
4648
dev = [
4749
"pre-commit",
4850
"pytest",

src/mdpp/plots/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mdpp.plots.matrix import plot_dccm
66
from mdpp.plots.molecules import draw_mol, draw_mols, get_highlight_bonds
77
from mdpp.plots.scatter import plot_projection, plot_ramachandran
8+
from mdpp.plots.three_d import make_atom_labels_3d, view_mol_3d, view_traj_3d
89
from mdpp.plots.timeseries import (
910
plot_distances,
1011
plot_energy,
@@ -22,6 +23,7 @@
2223
"draw_mol",
2324
"draw_mols",
2425
"get_highlight_bonds",
26+
"make_atom_labels_3d",
2527
"plot_contact_map",
2628
"plot_dccm",
2729
"plot_distances",
@@ -36,4 +38,6 @@
3638
"plot_rmsd",
3739
"plot_rmsf",
3840
"plot_sasa",
41+
"view_mol_3d",
42+
"view_traj_3d",
3943
]

src/mdpp/plots/three_d.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
"""Interactive 3D visualization helpers for molecules and trajectories."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Callable, Mapping, Sequence
6+
7+
import mdtraj as md
8+
import nglview
9+
import py3Dmol
10+
from rdkit import Chem
11+
12+
_PY3DMOL_POSITION_KEYS = ("x", "y", "z")
13+
_DEFAULT_TRAJ_REPRESENTATIONS: tuple[dict[str, object], ...] = (
14+
{"type": "cartoon", "selection": "protein"},
15+
{"type": "ball+stick", "selection": "not protein"},
16+
)
17+
18+
19+
def _get_selected_conformer(
20+
mol: Chem.rdchem.Mol,
21+
conformer_id: int,
22+
) -> tuple[Chem.rdchem.Conformer, int]:
23+
"""Return the selected conformer object and its RDKit conformer ID."""
24+
if mol.GetNumConformers() == 0:
25+
raise ValueError("view_mol_3d requires a molecule with at least one conformer.")
26+
27+
conformer = mol.GetConformer() if conformer_id == -1 else mol.GetConformer(conformer_id)
28+
return conformer, conformer.GetId()
29+
30+
31+
def _label_position_from_atom(
32+
conformer: Chem.rdchem.Conformer,
33+
atom_index: int,
34+
) -> dict[str, float]:
35+
"""Return a py3Dmol-compatible position dictionary for an atom index."""
36+
position = conformer.GetAtomPosition(atom_index)
37+
return {"x": float(position.x), "y": float(position.y), "z": float(position.z)}
38+
39+
40+
def _normalize_label_spec(
41+
label: Mapping[str, object],
42+
conformer: Chem.rdchem.Conformer,
43+
) -> tuple[str, dict[str, object]]:
44+
"""Validate and normalize one py3Dmol label specification."""
45+
if "text" not in label:
46+
raise ValueError("Each label must include a 'text' field.")
47+
48+
has_atom_index = "atom_index" in label
49+
has_position = "position" in label
50+
if has_atom_index == has_position:
51+
raise ValueError("Each label must include exactly one of 'atom_index' or 'position'.")
52+
53+
text = str(label["text"])
54+
style = {key: value for key, value in label.items() if key not in {"text", "atom_index"}}
55+
56+
if has_atom_index:
57+
atom_index_value = label["atom_index"]
58+
if isinstance(atom_index_value, bool) or not isinstance(atom_index_value, int | str):
59+
raise ValueError("Label atom_index must be an integer.")
60+
atom_index = int(atom_index_value)
61+
style["position"] = _label_position_from_atom(conformer, atom_index)
62+
else:
63+
position = label["position"]
64+
if not isinstance(position, Mapping) or any(
65+
key not in position for key in _PY3DMOL_POSITION_KEYS
66+
):
67+
raise ValueError("Label positions must provide mapping keys 'x', 'y', and 'z'.")
68+
style["position"] = {
69+
"x": float(position["x"]),
70+
"y": float(position["y"]),
71+
"z": float(position["z"]),
72+
}
73+
74+
return text, style
75+
76+
77+
def make_atom_labels_3d(
78+
mol: Chem.rdchem.Mol,
79+
*,
80+
conformer_id: int = -1,
81+
atom_indices: Sequence[int] | None = None,
82+
text_fn: Callable[[Chem.rdchem.Atom], str] | None = None,
83+
color_fn: Callable[[Chem.rdchem.Atom], str | None] | None = None,
84+
base_style: Mapping[str, object] | None = None,
85+
) -> list[dict[str, object]]:
86+
"""Build generic per-atom label specifications for ``view_mol_3d``.
87+
88+
Args:
89+
mol: Molecule to label.
90+
conformer_id: RDKit conformer ID to use. ``-1`` selects the default conformer.
91+
atom_indices: Optional subset of atom indices to label.
92+
text_fn: Function that returns the label text for an atom.
93+
color_fn: Optional function that returns a font color for an atom.
94+
base_style: Optional py3Dmol label-style keys to apply to every label.
95+
96+
Returns:
97+
Label dictionaries consumable by ``view_mol_3d``.
98+
99+
Raises:
100+
ValueError: If the molecule has no conformer or ``text_fn`` is not provided.
101+
"""
102+
if text_fn is None:
103+
raise ValueError("make_atom_labels_3d requires a text_fn.")
104+
105+
conformer, _ = _get_selected_conformer(mol, conformer_id)
106+
if atom_indices is None:
107+
atom_indices = tuple(range(mol.GetNumAtoms()))
108+
109+
labels: list[dict[str, object]] = []
110+
for atom_index in atom_indices:
111+
atom = mol.GetAtomWithIdx(int(atom_index))
112+
label: dict[str, object] = dict(base_style or {})
113+
label["text"] = text_fn(atom)
114+
label["position"] = _label_position_from_atom(conformer, atom.GetIdx())
115+
color = color_fn(atom) if color_fn is not None else None
116+
if color is not None:
117+
label["fontColor"] = color
118+
labels.append(label)
119+
return labels
120+
121+
122+
def view_mol_3d(
123+
mol: Chem.rdchem.Mol,
124+
*,
125+
conformer_id: int = -1,
126+
width: int = 800,
127+
height: int = 400,
128+
style: Mapping[str, object] | None = None,
129+
labels: Sequence[Mapping[str, object]] | None = None,
130+
background_color: str = "white",
131+
zoom_to: bool = True,
132+
show: bool = False,
133+
) -> py3Dmol.view:
134+
"""Create an interactive py3Dmol viewer for an RDKit molecule.
135+
136+
Args:
137+
mol: Molecule to visualize.
138+
conformer_id: RDKit conformer ID to use. ``-1`` selects the default conformer.
139+
width: Viewer width in pixels.
140+
height: Viewer height in pixels.
141+
style: py3Dmol style dictionary passed to ``setStyle``.
142+
labels: Sequence of label dictionaries. Each must contain ``text`` and
143+
exactly one of ``atom_index`` or ``position``.
144+
background_color: Viewer background color.
145+
zoom_to: Whether to call ``zoomTo()`` after loading content.
146+
show: Whether to call ``show()`` before returning the viewer.
147+
148+
Returns:
149+
A configured py3Dmol viewer.
150+
151+
Raises:
152+
ValueError: If the molecule has no conformer or any label spec is invalid.
153+
"""
154+
conformer, selected_conformer_id = _get_selected_conformer(mol, conformer_id)
155+
del conformer # Only needed for label coordinate lookup below.
156+
157+
viewer = py3Dmol.view(width=width, height=height)
158+
viewer.setBackgroundColor(background_color)
159+
viewer.addModel(Chem.MolToMolBlock(mol, confId=selected_conformer_id), "sdf")
160+
viewer.setStyle(dict(style) if style is not None else {"stick": {}})
161+
162+
if labels:
163+
selected_conformer = mol.GetConformer(selected_conformer_id)
164+
for label in labels:
165+
text, label_style = _normalize_label_spec(label, selected_conformer)
166+
viewer.addLabel(text, label_style)
167+
168+
if zoom_to:
169+
viewer.zoomTo()
170+
if show:
171+
viewer.show()
172+
return viewer
173+
174+
175+
def view_traj_3d(
176+
traj: md.Trajectory,
177+
*,
178+
representations: Sequence[Mapping[str, object]] | None = None,
179+
) -> nglview.NGLWidget:
180+
"""Create an interactive nglview widget for a trajectory or structure.
181+
182+
Args:
183+
traj: MDTraj trajectory to visualize.
184+
representations: Sequence of representation dictionaries. Each must
185+
include ``type`` and may include ``selection`` plus any additional
186+
nglview representation keyword arguments.
187+
188+
Returns:
189+
A configured nglview widget.
190+
191+
Raises:
192+
ValueError: If any representation dictionary is missing ``type``.
193+
"""
194+
widget = nglview.show_mdtraj(traj)
195+
widget.clear_representations()
196+
197+
for representation in representations or _DEFAULT_TRAJ_REPRESENTATIONS:
198+
if "type" not in representation:
199+
raise ValueError("Each trajectory representation must include a 'type' field.")
200+
201+
rep_kwargs = dict(representation)
202+
rep_type = str(rep_kwargs.pop("type"))
203+
selection = rep_kwargs.pop("selection", "all")
204+
widget.add_representation(rep_type, selection=selection, **rep_kwargs)
205+
206+
return widget

tests/plots/test_three_d.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Tests for mdpp.plots.three_d."""
2+
3+
from __future__ import annotations
4+
5+
import mdtraj as md
6+
import nglview
7+
import py3Dmol
8+
import pytest
9+
from rdkit import Chem
10+
from rdkit.Chem import AllChem
11+
12+
from mdpp.plots import make_atom_labels_3d, view_mol_3d, view_traj_3d
13+
14+
15+
@pytest.fixture()
16+
def ethanol_3d() -> Chem.rdchem.Mol:
17+
mol = Chem.AddHs(Chem.MolFromSmiles("CCO"))
18+
AllChem.EmbedMolecule(mol, randomSeed=0)
19+
return mol
20+
21+
22+
def _conformer_positions(mol: Chem.rdchem.Mol) -> list[tuple[float, float, float]]:
23+
conformer = mol.GetConformer()
24+
return [
25+
(
26+
conformer.GetAtomPosition(atom_idx).x,
27+
conformer.GetAtomPosition(atom_idx).y,
28+
conformer.GetAtomPosition(atom_idx).z,
29+
)
30+
for atom_idx in range(mol.GetNumAtoms())
31+
]
32+
33+
34+
class TestViewMol3D:
35+
def test_returns_py3dmol_viewer(self, ethanol_3d):
36+
viewer = view_mol_3d(ethanol_3d)
37+
assert isinstance(viewer, py3Dmol.view)
38+
39+
def test_raises_without_conformer(self):
40+
mol = Chem.MolFromSmiles("CCO")
41+
with pytest.raises(ValueError, match="at least one conformer"):
42+
view_mol_3d(mol)
43+
44+
def test_does_not_modify_input_conformer(self, ethanol_3d):
45+
before_positions = _conformer_positions(ethanol_3d)
46+
ethanol_3d.GetAtomWithIdx(0).SetDoubleProp("PartialCharge", -0.12)
47+
48+
viewer = view_mol_3d(ethanol_3d)
49+
50+
assert isinstance(viewer, py3Dmol.view)
51+
assert _conformer_positions(ethanol_3d) == before_positions
52+
assert ethanol_3d.GetAtomWithIdx(0).GetDoubleProp("PartialCharge") == pytest.approx(-0.12)
53+
54+
def test_accepts_position_labels(self, ethanol_3d):
55+
labels = [{"text": "C0", "position": {"x": 0.0, "y": 0.0, "z": 0.0}}]
56+
viewer = view_mol_3d(ethanol_3d, labels=labels)
57+
assert isinstance(viewer, py3Dmol.view)
58+
59+
def test_resolves_atom_index_labels(self, ethanol_3d):
60+
labels = [{"text": "O", "atom_index": 2, "fontColor": "red"}]
61+
viewer = view_mol_3d(ethanol_3d, labels=labels)
62+
assert isinstance(viewer, py3Dmol.view)
63+
64+
def test_rejects_invalid_labels(self, ethanol_3d):
65+
labels = [{"text": "bad"}]
66+
with pytest.raises(ValueError, match="exactly one of 'atom_index' or 'position'"):
67+
view_mol_3d(ethanol_3d, labels=labels)
68+
69+
70+
class TestMakeAtomLabels3D:
71+
def test_requires_text_fn(self, ethanol_3d):
72+
with pytest.raises(ValueError, match="requires a text_fn"):
73+
make_atom_labels_3d(ethanol_3d)
74+
75+
def test_returns_one_label_per_atom(self, ethanol_3d):
76+
labels = make_atom_labels_3d(
77+
ethanol_3d,
78+
text_fn=lambda atom: atom.GetSymbol(),
79+
base_style={"fontSize": 10},
80+
)
81+
82+
assert len(labels) == ethanol_3d.GetNumAtoms()
83+
assert all("text" in label for label in labels)
84+
assert all("position" in label for label in labels)
85+
assert all(label["fontSize"] == 10 for label in labels)
86+
87+
def test_supports_atom_subset(self, ethanol_3d):
88+
labels = make_atom_labels_3d(
89+
ethanol_3d,
90+
atom_indices=[0, 2],
91+
text_fn=lambda atom: str(atom.GetIdx()),
92+
)
93+
94+
assert [label["text"] for label in labels] == ["0", "2"]
95+
96+
def test_supports_charge_style_callbacks(self, ethanol_3d):
97+
for atom in ethanol_3d.GetAtoms():
98+
atom.SetDoubleProp("PartialCharge", (-1.0) ** atom.GetIdx() * 0.1)
99+
100+
labels = make_atom_labels_3d(
101+
ethanol_3d,
102+
text_fn=lambda atom: f"{atom.GetDoubleProp('PartialCharge'):+.2f}",
103+
color_fn=lambda atom: "red" if atom.GetDoubleProp("PartialCharge") < 0 else "blue",
104+
base_style={"showBackground": True},
105+
)
106+
107+
assert labels[0]["text"].startswith("+")
108+
assert labels[1]["fontColor"] == "red"
109+
assert labels[0]["showBackground"] is True
110+
111+
112+
class TestViewTraj3D:
113+
def test_returns_ngl_widget(self, two_atom_trajectory: md.Trajectory):
114+
widget = view_traj_3d(two_atom_trajectory)
115+
assert isinstance(widget, nglview.NGLWidget)
116+
117+
def test_accepts_custom_representations(self, two_atom_trajectory: md.Trajectory):
118+
widget = view_traj_3d(
119+
two_atom_trajectory,
120+
representations=[
121+
{"type": "ball+stick", "selection": "all", "color": "element"},
122+
],
123+
)
124+
assert isinstance(widget, nglview.NGLWidget)
125+
126+
def test_requires_representation_type(self, two_atom_trajectory: md.Trajectory):
127+
with pytest.raises(ValueError, match="must include a 'type' field"):
128+
view_traj_3d(two_atom_trajectory, representations=[{"selection": "all"}])

0 commit comments

Comments
 (0)