|
| 1 | +"""Interactive 3D visualization helpers for molecules and trajectories.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from collections.abc import Callable, Mapping, Sequence |
| 6 | + |
| 7 | +import mdtraj as md |
| 8 | +import nglview |
| 9 | +import py3Dmol |
| 10 | +from rdkit import Chem |
| 11 | + |
| 12 | +_PY3DMOL_POSITION_KEYS = ("x", "y", "z") |
| 13 | +_DEFAULT_TRAJ_REPRESENTATIONS: tuple[dict[str, object], ...] = ( |
| 14 | + {"type": "cartoon", "selection": "protein"}, |
| 15 | + {"type": "ball+stick", "selection": "not protein"}, |
| 16 | +) |
| 17 | + |
| 18 | + |
| 19 | +def _get_selected_conformer( |
| 20 | + mol: Chem.rdchem.Mol, |
| 21 | + conformer_id: int, |
| 22 | +) -> tuple[Chem.rdchem.Conformer, int]: |
| 23 | + """Return the selected conformer object and its RDKit conformer ID.""" |
| 24 | + if mol.GetNumConformers() == 0: |
| 25 | + raise ValueError("view_mol_3d requires a molecule with at least one conformer.") |
| 26 | + |
| 27 | + conformer = mol.GetConformer() if conformer_id == -1 else mol.GetConformer(conformer_id) |
| 28 | + return conformer, conformer.GetId() |
| 29 | + |
| 30 | + |
| 31 | +def _label_position_from_atom( |
| 32 | + conformer: Chem.rdchem.Conformer, |
| 33 | + atom_index: int, |
| 34 | +) -> dict[str, float]: |
| 35 | + """Return a py3Dmol-compatible position dictionary for an atom index.""" |
| 36 | + position = conformer.GetAtomPosition(atom_index) |
| 37 | + return {"x": float(position.x), "y": float(position.y), "z": float(position.z)} |
| 38 | + |
| 39 | + |
| 40 | +def _normalize_label_spec( |
| 41 | + label: Mapping[str, object], |
| 42 | + conformer: Chem.rdchem.Conformer, |
| 43 | +) -> tuple[str, dict[str, object]]: |
| 44 | + """Validate and normalize one py3Dmol label specification.""" |
| 45 | + if "text" not in label: |
| 46 | + raise ValueError("Each label must include a 'text' field.") |
| 47 | + |
| 48 | + has_atom_index = "atom_index" in label |
| 49 | + has_position = "position" in label |
| 50 | + if has_atom_index == has_position: |
| 51 | + raise ValueError("Each label must include exactly one of 'atom_index' or 'position'.") |
| 52 | + |
| 53 | + text = str(label["text"]) |
| 54 | + style = {key: value for key, value in label.items() if key not in {"text", "atom_index"}} |
| 55 | + |
| 56 | + if has_atom_index: |
| 57 | + atom_index_value = label["atom_index"] |
| 58 | + if isinstance(atom_index_value, bool) or not isinstance(atom_index_value, int | str): |
| 59 | + raise ValueError("Label atom_index must be an integer.") |
| 60 | + atom_index = int(atom_index_value) |
| 61 | + style["position"] = _label_position_from_atom(conformer, atom_index) |
| 62 | + else: |
| 63 | + position = label["position"] |
| 64 | + if not isinstance(position, Mapping) or any( |
| 65 | + key not in position for key in _PY3DMOL_POSITION_KEYS |
| 66 | + ): |
| 67 | + raise ValueError("Label positions must provide mapping keys 'x', 'y', and 'z'.") |
| 68 | + style["position"] = { |
| 69 | + "x": float(position["x"]), |
| 70 | + "y": float(position["y"]), |
| 71 | + "z": float(position["z"]), |
| 72 | + } |
| 73 | + |
| 74 | + return text, style |
| 75 | + |
| 76 | + |
| 77 | +def make_atom_labels_3d( |
| 78 | + mol: Chem.rdchem.Mol, |
| 79 | + *, |
| 80 | + conformer_id: int = -1, |
| 81 | + atom_indices: Sequence[int] | None = None, |
| 82 | + text_fn: Callable[[Chem.rdchem.Atom], str] | None = None, |
| 83 | + color_fn: Callable[[Chem.rdchem.Atom], str | None] | None = None, |
| 84 | + base_style: Mapping[str, object] | None = None, |
| 85 | +) -> list[dict[str, object]]: |
| 86 | + """Build generic per-atom label specifications for ``view_mol_3d``. |
| 87 | +
|
| 88 | + Args: |
| 89 | + mol: Molecule to label. |
| 90 | + conformer_id: RDKit conformer ID to use. ``-1`` selects the default conformer. |
| 91 | + atom_indices: Optional subset of atom indices to label. |
| 92 | + text_fn: Function that returns the label text for an atom. |
| 93 | + color_fn: Optional function that returns a font color for an atom. |
| 94 | + base_style: Optional py3Dmol label-style keys to apply to every label. |
| 95 | +
|
| 96 | + Returns: |
| 97 | + Label dictionaries consumable by ``view_mol_3d``. |
| 98 | +
|
| 99 | + Raises: |
| 100 | + ValueError: If the molecule has no conformer or ``text_fn`` is not provided. |
| 101 | + """ |
| 102 | + if text_fn is None: |
| 103 | + raise ValueError("make_atom_labels_3d requires a text_fn.") |
| 104 | + |
| 105 | + conformer, _ = _get_selected_conformer(mol, conformer_id) |
| 106 | + if atom_indices is None: |
| 107 | + atom_indices = tuple(range(mol.GetNumAtoms())) |
| 108 | + |
| 109 | + labels: list[dict[str, object]] = [] |
| 110 | + for atom_index in atom_indices: |
| 111 | + atom = mol.GetAtomWithIdx(int(atom_index)) |
| 112 | + label: dict[str, object] = dict(base_style or {}) |
| 113 | + label["text"] = text_fn(atom) |
| 114 | + label["position"] = _label_position_from_atom(conformer, atom.GetIdx()) |
| 115 | + color = color_fn(atom) if color_fn is not None else None |
| 116 | + if color is not None: |
| 117 | + label["fontColor"] = color |
| 118 | + labels.append(label) |
| 119 | + return labels |
| 120 | + |
| 121 | + |
| 122 | +def view_mol_3d( |
| 123 | + mol: Chem.rdchem.Mol, |
| 124 | + *, |
| 125 | + conformer_id: int = -1, |
| 126 | + width: int = 800, |
| 127 | + height: int = 400, |
| 128 | + style: Mapping[str, object] | None = None, |
| 129 | + labels: Sequence[Mapping[str, object]] | None = None, |
| 130 | + background_color: str = "white", |
| 131 | + zoom_to: bool = True, |
| 132 | + show: bool = False, |
| 133 | +) -> py3Dmol.view: |
| 134 | + """Create an interactive py3Dmol viewer for an RDKit molecule. |
| 135 | +
|
| 136 | + Args: |
| 137 | + mol: Molecule to visualize. |
| 138 | + conformer_id: RDKit conformer ID to use. ``-1`` selects the default conformer. |
| 139 | + width: Viewer width in pixels. |
| 140 | + height: Viewer height in pixels. |
| 141 | + style: py3Dmol style dictionary passed to ``setStyle``. |
| 142 | + labels: Sequence of label dictionaries. Each must contain ``text`` and |
| 143 | + exactly one of ``atom_index`` or ``position``. |
| 144 | + background_color: Viewer background color. |
| 145 | + zoom_to: Whether to call ``zoomTo()`` after loading content. |
| 146 | + show: Whether to call ``show()`` before returning the viewer. |
| 147 | +
|
| 148 | + Returns: |
| 149 | + A configured py3Dmol viewer. |
| 150 | +
|
| 151 | + Raises: |
| 152 | + ValueError: If the molecule has no conformer or any label spec is invalid. |
| 153 | + """ |
| 154 | + conformer, selected_conformer_id = _get_selected_conformer(mol, conformer_id) |
| 155 | + del conformer # Only needed for label coordinate lookup below. |
| 156 | + |
| 157 | + viewer = py3Dmol.view(width=width, height=height) |
| 158 | + viewer.setBackgroundColor(background_color) |
| 159 | + viewer.addModel(Chem.MolToMolBlock(mol, confId=selected_conformer_id), "sdf") |
| 160 | + viewer.setStyle(dict(style) if style is not None else {"stick": {}}) |
| 161 | + |
| 162 | + if labels: |
| 163 | + selected_conformer = mol.GetConformer(selected_conformer_id) |
| 164 | + for label in labels: |
| 165 | + text, label_style = _normalize_label_spec(label, selected_conformer) |
| 166 | + viewer.addLabel(text, label_style) |
| 167 | + |
| 168 | + if zoom_to: |
| 169 | + viewer.zoomTo() |
| 170 | + if show: |
| 171 | + viewer.show() |
| 172 | + return viewer |
| 173 | + |
| 174 | + |
| 175 | +def view_traj_3d( |
| 176 | + traj: md.Trajectory, |
| 177 | + *, |
| 178 | + representations: Sequence[Mapping[str, object]] | None = None, |
| 179 | +) -> nglview.NGLWidget: |
| 180 | + """Create an interactive nglview widget for a trajectory or structure. |
| 181 | +
|
| 182 | + Args: |
| 183 | + traj: MDTraj trajectory to visualize. |
| 184 | + representations: Sequence of representation dictionaries. Each must |
| 185 | + include ``type`` and may include ``selection`` plus any additional |
| 186 | + nglview representation keyword arguments. |
| 187 | +
|
| 188 | + Returns: |
| 189 | + A configured nglview widget. |
| 190 | +
|
| 191 | + Raises: |
| 192 | + ValueError: If any representation dictionary is missing ``type``. |
| 193 | + """ |
| 194 | + widget = nglview.show_mdtraj(traj) |
| 195 | + widget.clear_representations() |
| 196 | + |
| 197 | + for representation in representations or _DEFAULT_TRAJ_REPRESENTATIONS: |
| 198 | + if "type" not in representation: |
| 199 | + raise ValueError("Each trajectory representation must include a 'type' field.") |
| 200 | + |
| 201 | + rep_kwargs = dict(representation) |
| 202 | + rep_type = str(rep_kwargs.pop("type")) |
| 203 | + selection = rep_kwargs.pop("selection", "all") |
| 204 | + widget.add_representation(rep_type, selection=selection, **rep_kwargs) |
| 205 | + |
| 206 | + return widget |
0 commit comments