From 5576839291878e1920d36bc939d260d2d1b8a382 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Feb 2026 11:58:48 +0100 Subject: [PATCH 01/20] add option for passing mol object, gentle error handling --- .../preprocessing/reader/augmented_reader.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index b6b2d90..55962a2 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -4,6 +4,7 @@ import torch from chebai.preprocessing.reader import DataReader +from chebai.preprocessing.datasets.chebi import sanitize_molecule from rdkit import Chem from torch_geometric.data import Data as GeomData @@ -22,7 +23,6 @@ # https://mail.python.org/pipermail/python-dev/2017-December/151283.html # Order preservation is necessary to to create `is_atom_node` mask - class _AugmentorReader(DataReader, ABC): """ Abstract base class for augmentor readers that extend ChemDataReader. @@ -59,12 +59,12 @@ def name(cls) -> str: """ return f"{cls.__name__}".lower() - def _read_data(self, smiles: str) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: """ Reads and augments molecular data from a SMILES string. Args: - smiles (str): SMILES representation of the molecule. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule. Returns: GeomData | None: A PyTorch Geometric Data object with augmented nodes and edges, @@ -73,20 +73,25 @@ def _read_data(self, smiles: str) -> GeomData | None: Raises: RuntimeError: If an unexpected error occurs during graph augmentation. """ - mol = self._smiles_to_mol(smiles) + if isinstance(raw_data, str): + mol = self._smiles_to_mol(raw_data) + smiles = raw_data + else: + mol = raw_data + smiles = Chem.MolToSmiles(mol) if mol is None: return None try: returned_result = self._create_augmented_graph(mol) except Exception as e: - raise RuntimeError( - f"Error has occurred for following SMILES: {smiles}\n\t {e}" - ) from e + print(f"Failed to construct augmented graph for smiles {smiles}, Error: {e}") + self.f_cnt_for_aug_graph += 1 + return None # If the returned result is None, it indicates that the graph augmentation failed if returned_result is None: - print(f"Failed to construct augmented graph for smiles {smiles}") + print(f"Failed to construct augmented graph for smiles {smiles} (returned None)") self.f_cnt_for_aug_graph += 1 return None @@ -136,13 +141,13 @@ def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None: Returns: Chem.Mol | None: RDKit molecule object if successful, else None. """ - mol = Chem.MolFromSmiles(smiles) + mol = Chem.MolFromSmiles(smiles, sanitize=False) if mol is None: print(f"RDKit failed to parse {smiles} (returned None)") self.f_cnt_for_smiles += 1 else: try: - Chem.SanitizeMol(mol) + mol = sanitize_molecule(mol) except Exception as e: print(f"RDKit failed at sanitizing {smiles}, Error {e}") self.f_cnt_for_smiles += 1 @@ -662,17 +667,17 @@ def add_fg_internal_edge(source_fg: int, target_fg: int) -> None: class _AddGraphNode(_AugmentorReader): """Adds a graph-level node and connects it to selected/given nodes.""" - def _read_data(self, smiles: str) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: """ Reads data and adds a graph-level node annotation. Args: - smiles (str): SMILES string. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule. Returns: Data | None: Geometric data object with is_graph_node annotation. """ - geom_data = super()._read_data(smiles) + geom_data = super()._read_data(raw_data) if geom_data is None: return None NUM_NODES = geom_data.x.shape[0] @@ -941,3 +946,4 @@ def _augment_graph_structure( return self._add_graph_node_and_edges_to_nodes( augmented_struct, atom_ids | fg_to_atoms_map.keys() ) + \ No newline at end of file From 95adba2cbfee48985ef7cedeb16ad86a4898301f Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Feb 2026 12:05:18 +0100 Subject: [PATCH 02/20] update reader for mol objects --- chebai_graph/preprocessing/reader/reader.py | 41 +++++++++++++++++---- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py index a63b8a1..abb29d1 100644 --- a/chebai_graph/preprocessing/reader/reader.py +++ b/chebai_graph/preprocessing/reader/reader.py @@ -1,6 +1,7 @@ import os import chebai.preprocessing.reader as dr +from chebai.preprocessing.datasets.chebi import sanitize_molecule import networkx as nx import pysmiles as ps import rdkit.Chem as Chem @@ -54,30 +55,33 @@ def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: if smiles in self.mol_object_buffer: return self.mol_object_buffer[smiles] - mol = Chem.MolFromSmiles(smiles) + mol = Chem.MolFromSmiles(smiles, sanitize=False) if mol is None: print(f"RDKit failed to at parsing {smiles} (returned None)") self.failed_counter += 1 else: try: - Chem.SanitizeMol(mol) + sanitize_molecule(mol) except Exception as e: print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") self.failed_counter += 1 self.mol_object_buffer[smiles] = mol return mol - def _read_data(self, raw_data: str) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: """ Convert raw SMILES string data into a PyTorch Geometric Data object. Args: - raw_data (str): SMILES string. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object. Returns: GeomData | None: Graph data object or None if molecule parsing failed. """ - mol = self._smiles_to_mol(raw_data) + if isinstance(raw_data, Chem.Mol): + mol = raw_data + else: + mol = self._smiles_to_mol(raw_data) if mol is None: return None @@ -144,19 +148,19 @@ def name(cls) -> str: """ return "graph" - def _read_data(self, raw_data: str) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: """ Convert a SMILES string into a PyTorch Geometric Data object with atom tokens and bond order attributes. Args: - raw_data (str): SMILES string. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object. Returns: GeomData | None: Graph data object or None if parsing failed. """ # raw_data is a SMILES string try: - mol = ps.read_smiles(raw_data) + mol = self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data except ValueError: return None assert isinstance(mol, nx.Graph) @@ -189,6 +193,27 @@ def _read_data(self, raw_data: str) -> GeomData | None: nx.set_edge_attributes(mol, de, "edge_attr") data = from_networkx(mol) return data + + def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: + """ + Load SMILES string into an RDKit molecule object. + + Args: + smiles (str): The SMILES string to parse. + + Returns: + Chem.rdchem.Mol | None: Parsed molecule object or None if parsing failed. + """ + + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol is None: + print(f"RDKit failed to at parsing {smiles} (returned None)") + else: + try: + sanitize_molecule(mol) + except Exception as e: + print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") + return mol def collate(self, list_of_tuples: list) -> any: """ From ad4432f9b366fb6aef3c4a7b5351a7be039e3371 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Feb 2026 13:27:51 +0100 Subject: [PATCH 03/20] fix read property --- .../preprocessing/reader/augmented_reader.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 55962a2..8f5afe0 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -285,21 +285,24 @@ def on_finish(self) -> None: ) self.mol_object_buffer = {} - def read_property(self, smiles: str, property: MolecularProperty) -> list | None: + def read_property(self, data: str | Chem.Mol, property: MolecularProperty) -> list | None: """ Reads a specific property from a molecule represented by a SMILES string. Args: - smiles (str): SMILES string representing the molecule. + data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule. property (MolecularProperty): Molecular property object for which the value needs to be extracted. Returns: list | None: Property values if molecule parsing is successful, else None. """ - if smiles in self.mol_object_buffer: - return property.get_property_value(self.mol_object_buffer[smiles]) - - mol = self._smiles_to_mol(smiles) + if isinstance(data, Chem.Mol): + mol = data + else: + smiles = data + if smiles in self.mol_object_buffer: + return property.get_property_value(self.mol_object_buffer[smiles]) + mol = self._smiles_to_mol(smiles) if mol is None: return None From 66c9659e913c4ed4beab9a8e7a30482780f1154e Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Feb 2026 14:39:09 +0100 Subject: [PATCH 04/20] make read_properties more flexible, don't recalculate properties --- chebai_graph/preprocessing/datasets/chebi.py | 10 +++-- .../preprocessing/reader/augmented_reader.py | 45 ++++++++++--------- chebai_graph/preprocessing/reader/reader.py | 10 ++--- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index fdf2e6d..13d53db 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -15,6 +15,7 @@ ) from lightning_utilities.core.rank_zero import rank_zero_info from torch_geometric.data.data import Data as GeomData +from rdkit import Chem from chebai_graph.preprocessing.properties import ( AllNodeTypeProperty, @@ -185,20 +186,23 @@ def _after_setup(self, **kwargs) -> None: super()._after_setup(**kwargs) def _preprocess_smiles_for_pred( - self, idx, smiles: str, model_hparams: Optional[dict] = None + self, idx, raw_data: str | Chem.Mol, model_hparams: Optional[dict] = None ) -> dict: """Preprocess prediction data.""" # Add dummy labels because the collate function requires them. # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. result = self.reader.to_data( - {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]} + {"id": f"smiles_{idx}", "features": raw_data, "labels": [1, 2]} ) + # _read_data can return an updated version of the input data (e.g. augmented molecule dict) along with the GeomData object + if isinstance(result["features"], tuple): + result["features"], raw_data = result["features"][0] if result is None or result["features"] is None: return None for property in self.properties: property.encoder.eval = True - property_value = self.reader.read_property(smiles, property) + property_value = self.reader.read_property(raw_data, property) if property_value is None or len(property_value) == 0: encoded_value = None else: diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 8f5afe0..06cdb0a 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -59,7 +59,7 @@ def name(cls) -> str: """ return f"{cls.__name__}".lower() - def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None: """ Reads and augments molecular data from a SMILES string. @@ -67,8 +67,8 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: raw_data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule. Returns: - GeomData | None: A PyTorch Geometric Data object with augmented nodes and edges, - or None if parsing or augmentation fails. + tuple[GeomData, dict] | None: A tuple containing a PyTorch Geometric Data object with augmented nodes and edges, + and a dictionary of augmented molecule data, or None if parsing or augmentation fails. Raises: RuntimeError: If an unexpected error occurs during graph augmentation. @@ -124,12 +124,12 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms() is_atom_mask[:NUM_ATOM_NODES] = True - return GeomData( + return (GeomData( x=x, edge_index=edge_index, edge_attr=edge_attr, is_atom_node=is_atom_mask, - ) + ), augmented_molecule) def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None: """ @@ -285,32 +285,35 @@ def on_finish(self) -> None: ) self.mol_object_buffer = {} - def read_property(self, data: str | Chem.Mol, property: MolecularProperty) -> list | None: + def read_property(self, raw_data: str | Chem.Mol | dict, property: MolecularProperty) -> list | None: """ Reads a specific property from a molecule represented by a SMILES string. Args: - data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule. + raw_data (str | Chem.Mol | dict): SMILES string, RDKit molecule object, or dictionary representation of a molecule. property (MolecularProperty): Molecular property object for which the value needs to be extracted. Returns: list | None: Property values if molecule parsing is successful, else None. """ - if isinstance(data, Chem.Mol): - mol = data + if isinstance(raw_data, dict): + augmented_mol = raw_data else: - smiles = data - if smiles in self.mol_object_buffer: - return property.get_property_value(self.mol_object_buffer[smiles]) - mol = self._smiles_to_mol(smiles) - if mol is None: - return None + if isinstance(raw_data, Chem.Mol): + mol = raw_data + else: + smiles = raw_data + if smiles in self.mol_object_buffer: + return property.get_property_value(self.mol_object_buffer[smiles]) + mol = self._smiles_to_mol(smiles) + if mol is None: + return None - returned_result = self._create_augmented_graph(mol) - if returned_result is None: - return None + returned_result = self._create_augmented_graph(mol) + if returned_result is None: + return None - _, augmented_mol = returned_result + _, augmented_mol = returned_result return property.get_property_value(augmented_mol) @@ -680,14 +683,14 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: Returns: Data | None: Geometric data object with is_graph_node annotation. """ - geom_data = super()._read_data(raw_data) + geom_data, augmented_mol = super()._read_data(raw_data) if geom_data is None: return None NUM_NODES = geom_data.x.shape[0] is_graph_node = torch.zeros(NUM_NODES, dtype=torch.bool) is_graph_node[-1] = True geom_data.is_graph_node = is_graph_node - return geom_data + return (geom_data, augmented_mol) def _add_graph_node_and_edges_to_nodes( self, diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py index abb29d1..b2e9351 100644 --- a/chebai_graph/preprocessing/reader/reader.py +++ b/chebai_graph/preprocessing/reader/reader.py @@ -68,7 +68,7 @@ def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: self.mol_object_buffer[smiles] = mol return mol - def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, Chem.Mol] | None: """ Convert raw SMILES string data into a PyTorch Geometric Data object. @@ -95,7 +95,7 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] edge_attr = torch.zeros((edge_index.size(1), 0)) - return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + return (GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr), mol) def on_finish(self) -> None: """ @@ -104,18 +104,18 @@ def on_finish(self) -> None: print(f"Failed to read {self.failed_counter} SMILES in total") self.mol_object_buffer = {} - def read_property(self, smiles: str, property: MolecularProperty) -> list | None: + def read_property(self, raw_data: str | Chem.Mol, property: MolecularProperty) -> list | None: """ Read a molecular property for a given SMILES string. Args: - smiles (str): SMILES string of the molecule. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object of the molecule. property (MolecularProperty): Property extractor to apply. Returns: list | None: Property values or None if molecule parsing failed. """ - mol = self._smiles_to_mol(smiles) + mol = self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data if mol is None: return None return property.get_property_value(mol) From be7278dad7247be5547c89c1c3ea6b4d3aa67d24 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Feb 2026 14:50:07 +0100 Subject: [PATCH 05/20] avoid redundant property calculation --- chebai_graph/preprocessing/datasets/chebi.py | 12 ++++++++++-- .../preprocessing/reader/augmented_reader.py | 8 ++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 13d53db..b5b55d0 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -5,6 +5,7 @@ from typing import Optional import pandas as pd +from chebai_graph.preprocessing.reader.augmented_reader import _AugmentorReader import torch import tqdm from chebai.preprocessing.datasets.chebi import ( @@ -126,6 +127,13 @@ def enc_if_not_none(encode, value): if value is not None and len(value) > 0 else None ) + + # augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy) + if isinstance(self.reader, _AugmentorReader): + returned_results = [self._create_augmented_graph(mol) for mol in features] + mols = [augmented_mol[1] for augmented_mol in returned_results if augmented_mol is not None] + else: + mols = features for property in self.properties: if not os.path.isfile(self.get_property_path(property)): @@ -133,8 +141,8 @@ def enc_if_not_none(encode, value): # read all property values first, then encode rank_zero_info(f"\tReading property values of {property.name}...") property_values = [ - self.reader.read_property(feat, property) - for feat in tqdm.tqdm(features) + self.reader.read_property(mol, property) + for mol in tqdm.tqdm(mols) ] rank_zero_info(f"\tEncoding property values of {property.name}...") property.encoder.on_start(property_values=property_values) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 06cdb0a..17fc2af 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -308,8 +308,12 @@ def read_property(self, raw_data: str | Chem.Mol | dict, property: MolecularProp mol = self._smiles_to_mol(smiles) if mol is None: return None - - returned_result = self._create_augmented_graph(mol) + try: + returned_result = self._create_augmented_graph(mol) + except Exception as e: + print(f"Failed to construct augmented graph, Error: {e}") + self.f_cnt_for_aug_graph += 1 + return None if returned_result is None: return None From a50ce8fb6dbd3922d98bc5860cca10ad67ccd33c Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Feb 2026 14:53:23 +0100 Subject: [PATCH 06/20] fix function call --- chebai_graph/preprocessing/datasets/chebi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index b5b55d0..bf96b9d 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -130,7 +130,7 @@ def enc_if_not_none(encode, value): # augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy) if isinstance(self.reader, _AugmentorReader): - returned_results = [self._create_augmented_graph(mol) for mol in features] + returned_results = [self.reader._create_augmented_graph(mol) for mol in features] mols = [augmented_mol[1] for augmented_mol in returned_results if augmented_mol is not None] else: mols = features From b629f9c30808ed38347a429fae0405f7c8190524 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Feb 2026 15:07:02 +0100 Subject: [PATCH 07/20] catch mol processing errors --- chebai_graph/preprocessing/datasets/chebi.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index bf96b9d..0509c77 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -130,7 +130,13 @@ def enc_if_not_none(encode, value): # augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy) if isinstance(self.reader, _AugmentorReader): - returned_results = [self.reader._create_augmented_graph(mol) for mol in features] + returned_results = [] + for mol in features: + try: + r = self.reader._create_augmented_graph(mol) + except Exception as e: + r = None + returned_results.append(r) mols = [augmented_mol[1] for augmented_mol in returned_results if augmented_mol is not None] else: mols = features From 430c72ca195911e71e9622ef736f345a259e56e2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 25 Feb 2026 15:45:38 +0100 Subject: [PATCH 08/20] make sure that property cache is mutable --- chebai_graph/preprocessing/property_encoder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index 991ab16..fc536b7 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -156,6 +156,9 @@ def encode(self, token: str | None) -> torch.Tensor: return torch.tensor([self._unk_token_idx]) if str(token) not in self.cache: + # Ensure cache is a mutable dict (jsonargparse may convert it to mappingproxy) + if not isinstance(self.cache, dict): + self.cache = dict(self.cache) self.cache[str(token)] = len(self.cache) return torch.tensor([self.cache[str(token)] + self.offset]) From 2b9c1e584d6d55ee034622c582e51f73d6e45374 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Sat, 28 Feb 2026 22:44:00 +0100 Subject: [PATCH 09/20] fix data loading --- chebai_graph/preprocessing/datasets/chebi.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 0509c77..2f4e0e4 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -296,7 +296,10 @@ def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData: Returns: A GeomData object with merged features. """ - geom_data = row["features"] + if isinstance(row["features"], tuple): + geom_data, _ = row["features"] # ignore additional returned data from _read_data (e.g. augmented molecule dict) + else: + geom_data = row["features"] assert isinstance(geom_data, GeomData) edge_attr = geom_data.edge_attr x = geom_data.x From 3a53cd2567a765e7aed6a2eb00cf3912cc17c9ba Mon Sep 17 00:00:00 2001 From: sfluegel Date: Sat, 28 Feb 2026 22:53:26 +0100 Subject: [PATCH 10/20] add chebi25 dataset --- chebai_graph/preprocessing/datasets/chebi.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 2f4e0e4..a93175c 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -826,3 +826,9 @@ class ChEBI50_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver50): class ChEBI100_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver100): READER = AtomFGReader_WithFGEdges_WithGraphNode + + +class ChEBI25_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOverX): + READER = AtomFGReader_WithFGEdges_WithGraphNode + + THRESHOLD = 25 \ No newline at end of file From db3d4296957e6ba1c30cbdb3dc41aea7c61ef801 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Sat, 28 Feb 2026 23:18:27 +0100 Subject: [PATCH 11/20] use chebi-utils --- chebai_graph/preprocessing/reader/augmented_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 17fc2af..5b19dc5 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -4,7 +4,7 @@ import torch from chebai.preprocessing.reader import DataReader -from chebai.preprocessing.datasets.chebi import sanitize_molecule +from chebi_utils.sdf_extractor import sanitize_molecule from rdkit import Chem from torch_geometric.data import Data as GeomData From e5d870b306a7a0fdcab2b7b733bb41cef7fbc0d9 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Sun, 1 Mar 2026 21:55:33 +0100 Subject: [PATCH 12/20] fix function name --- chebai_graph/preprocessing/reader/augmented_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 5b19dc5..3c23fd1 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -4,7 +4,7 @@ import torch from chebai.preprocessing.reader import DataReader -from chebi_utils.sdf_extractor import sanitize_molecule +from chebi_utils.sdf_extractor import _sanitize_molecule from rdkit import Chem from torch_geometric.data import Data as GeomData @@ -147,7 +147,7 @@ def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None: self.f_cnt_for_smiles += 1 else: try: - mol = sanitize_molecule(mol) + mol = _sanitize_molecule(mol) except Exception as e: print(f"RDKit failed at sanitizing {smiles}, Error {e}") self.f_cnt_for_smiles += 1 From 315f49c91797ed62c871c1c4cfb7e8b7e8e0dc95 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 2 Mar 2026 09:15:49 +0100 Subject: [PATCH 13/20] fix function name --- chebai_graph/preprocessing/reader/reader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py index b2e9351..802cf3f 100644 --- a/chebai_graph/preprocessing/reader/reader.py +++ b/chebai_graph/preprocessing/reader/reader.py @@ -1,7 +1,7 @@ import os import chebai.preprocessing.reader as dr -from chebai.preprocessing.datasets.chebi import sanitize_molecule +from chebai.preprocessing.datasets.chebi import _sanitize_molecule import networkx as nx import pysmiles as ps import rdkit.Chem as Chem @@ -61,7 +61,7 @@ def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: self.failed_counter += 1 else: try: - sanitize_molecule(mol) + _sanitize_molecule(mol) except Exception as e: print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") self.failed_counter += 1 @@ -210,7 +210,7 @@ def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: print(f"RDKit failed to at parsing {smiles} (returned None)") else: try: - sanitize_molecule(mol) + _sanitize_molecule(mol) except Exception as e: print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") return mol From de3193e12a0e87028bb122c7e3100535b7328a83 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 2 Mar 2026 09:17:56 +0100 Subject: [PATCH 14/20] fix import --- chebai_graph/preprocessing/reader/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py index 802cf3f..922b401 100644 --- a/chebai_graph/preprocessing/reader/reader.py +++ b/chebai_graph/preprocessing/reader/reader.py @@ -1,7 +1,7 @@ import os import chebai.preprocessing.reader as dr -from chebai.preprocessing.datasets.chebi import _sanitize_molecule +from chebi_utils.sdf_extractor import _sanitize_molecule import networkx as nx import pysmiles as ps import rdkit.Chem as Chem From ea77f36071b0e0b5b550d15a994428e0dbb86626 Mon Sep 17 00:00:00 2001 From: sifluegel Date: Mon, 2 Mar 2026 09:33:13 +0100 Subject: [PATCH 15/20] add new tokens --- chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt | 1 + chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt | 1 + chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt | 2 ++ 3 files changed, 4 insertions(+) diff --git a/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt index f024a9d..5ad8538 100644 --- a/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt @@ -5,3 +5,4 @@ 1 5 6 +8 diff --git a/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt b/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt index 97ae8be..f36bdf7 100644 --- a/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt @@ -3,3 +3,4 @@ SINGLE AROMATIC TRIPLE DOUBLE +UNSPECIFIED diff --git a/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt b/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt index 7036755..22da8da 100644 --- a/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt @@ -9,3 +9,5 @@ 7 10 12 +11 +9 From 354225b371025d9ad4ad7f5953db79325d4c9569 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 2 Mar 2026 09:43:39 +0100 Subject: [PATCH 16/20] catch none --- chebai_graph/preprocessing/reader/augmented_reader.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 3c23fd1..1df73a5 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -677,7 +677,7 @@ def add_fg_internal_edge(source_fg: int, target_fg: int) -> None: class _AddGraphNode(_AugmentorReader): """Adds a graph-level node and connects it to selected/given nodes.""" - def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None: """ Reads data and adds a graph-level node annotation. @@ -687,9 +687,10 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: Returns: Data | None: Geometric data object with is_graph_node annotation. """ - geom_data, augmented_mol = super()._read_data(raw_data) - if geom_data is None: + res = super()._read_data(raw_data) + if res is None: return None + geom_data, augmented_mol = res NUM_NODES = geom_data.x.shape[0] is_graph_node = torch.zeros(NUM_NODES, dtype=torch.bool) is_graph_node[-1] = True From d2bbad5a358212db1f1002acd579c2e94c7ae659 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 2 Mar 2026 12:56:39 +0100 Subject: [PATCH 17/20] fix assertion error --- chebai_graph/preprocessing/datasets/chebi.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index a93175c..b1b8e45 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -201,7 +201,7 @@ def _after_setup(self, **kwargs) -> None: def _preprocess_smiles_for_pred( self, idx, raw_data: str | Chem.Mol, model_hparams: Optional[dict] = None - ) -> dict: + ) -> Optional[dict]: """Preprocess prediction data.""" # Add dummy labels because the collate function requires them. # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, @@ -211,7 +211,7 @@ def _preprocess_smiles_for_pred( ) # _read_data can return an updated version of the input data (e.g. augmented molecule dict) along with the GeomData object if isinstance(result["features"], tuple): - result["features"], raw_data = result["features"][0] + result["features"], raw_data = result["features"] if result is None or result["features"] is None: return None for property in self.properties: @@ -559,6 +559,8 @@ def _merge_props_into_base( geom_data = row["features"] if geom_data is None: return None + if isinstance(geom_data, tuple): + geom_data = geom_data[0] # ignore additional returned data from _read_data (e.g. augmented molecule dict) assert isinstance(geom_data, GeomData) is_atom_node = geom_data.is_atom_node From 0ba4f1115e533495feb0a72fd211201b3d91a8fb Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 2 Mar 2026 14:13:41 +0100 Subject: [PATCH 18/20] only calculate entended molecule graph if needed, sanitize molecule with custom method in fg rules --- chebai_graph/preprocessing/datasets/chebi.py | 108 ++++++++++-------- .../fg_detection/fg_aware_rule_based.py | 8 +- 2 files changed, 68 insertions(+), 48 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index b1b8e45..fb6d356 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -127,45 +127,53 @@ def enc_if_not_none(encode, value): if value is not None and len(value) > 0 else None ) - - # augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy) - if isinstance(self.reader, _AugmentorReader): - returned_results = [] - for mol in features: - try: - r = self.reader._create_augmented_graph(mol) - except Exception as e: - r = None - returned_results.append(r) - mols = [augmented_mol[1] for augmented_mol in returned_results if augmented_mol is not None] - else: - mols = features - for property in self.properties: - if not os.path.isfile(self.get_property_path(property)): - rank_zero_info(f"Processing property {property.name}") - # read all property values first, then encode - rank_zero_info(f"\tReading property values of {property.name}...") - property_values = [ - self.reader.read_property(mol, property) - for mol in tqdm.tqdm(mols) - ] - rank_zero_info(f"\tEncoding property values of {property.name}...") - property.encoder.on_start(property_values=property_values) - encoded_values = [ - enc_if_not_none(property.encoder.encode, value) - for value in tqdm.tqdm(property_values) + if any( + not os.path.isfile(self.get_property_path(property)) + for property in self.properties + ): + # augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy) + if isinstance(self.reader, _AugmentorReader): + returned_results = [] + for mol in features: + try: + r = self.reader._create_augmented_graph(mol) + except Exception as e: + r = None + returned_results.append(r) + mols = [ + augmented_mol[1] + for augmented_mol in returned_results + if augmented_mol is not None ] - - torch.save( - [ - {property.name: torch.cat(feat), "ident": id} - for feat, id in zip(encoded_values, idents) - if feat is not None - ], - self.get_property_path(property), - ) - property.on_finish() + else: + mols = features + + for property in self.properties: + if not os.path.isfile(self.get_property_path(property)): + rank_zero_info(f"Processing property {property.name}") + # read all property values first, then encode + rank_zero_info(f"\tReading property values of {property.name}...") + property_values = [ + self.reader.read_property(mol, property) + for mol in tqdm.tqdm(mols) + ] + rank_zero_info(f"\tEncoding property values of {property.name}...") + property.encoder.on_start(property_values=property_values) + encoded_values = [ + enc_if_not_none(property.encoder.encode, value) + for value in tqdm.tqdm(property_values) + ] + + torch.save( + [ + {property.name: torch.cat(feat), "ident": id} + for feat, id in zip(encoded_values, idents) + if feat is not None + ], + self.get_property_path(property), + ) + property.on_finish() @property def processed_properties_dir(self) -> str: @@ -268,7 +276,9 @@ def __init__( assert ( distribution is not None and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS - ), "When using padding for features, a valid distribution must be specified." + ), ( + "When using padding for features, a valid distribution must be specified." + ) self.distribution = distribution if self.pad_node_features: print( @@ -297,7 +307,9 @@ def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData: A GeomData object with merged features. """ if isinstance(row["features"], tuple): - geom_data, _ = row["features"] # ignore additional returned data from _read_data (e.g. augmented molecule dict) + geom_data, _ = row[ + "features" + ] # ignore additional returned data from _read_data (e.g. augmented molecule dict) else: geom_data = row["features"] assert isinstance(geom_data, GeomData) @@ -560,7 +572,9 @@ def _merge_props_into_base( if geom_data is None: return None if isinstance(geom_data, tuple): - geom_data = geom_data[0] # ignore additional returned data from _read_data (e.g. augmented molecule dict) + geom_data = geom_data[ + 0 + ] # ignore additional returned data from _read_data (e.g. augmented molecule dict) assert isinstance(geom_data, GeomData) is_atom_node = geom_data.is_atom_node @@ -573,9 +587,9 @@ def _merge_props_into_base( edge_attr = geom_data.edge_attr # Initialize node feature matrix - assert ( - max_len_node_properties is not None - ), "Maximum len of node properties should not be None" + assert max_len_node_properties is not None, ( + "Maximum len of node properties should not be None" + ) x = torch.zeros((num_nodes, max_len_node_properties)) # Track column offsets for each node type @@ -630,9 +644,9 @@ def _merge_props_into_base( raise TypeError(f"Unsupported property type: {type(property).__name__}") total_used_columns = max(atom_offset, fg_offset, graph_offset) - assert ( - total_used_columns <= max_len_node_properties - ), f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}" + assert total_used_columns <= max_len_node_properties, ( + f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}" + ) return GeomData( x=x, @@ -833,4 +847,4 @@ class ChEBI100_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver100): class ChEBI25_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOverX): READER = AtomFGReader_WithFGEdges_WithGraphNode - THRESHOLD = 25 \ No newline at end of file + THRESHOLD = 25 diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index f60f580..b573f0c 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -8,6 +8,8 @@ from rdkit.Chem import AllChem from rdkit.Chem import MolToSmiles as m2s +from chebi_utils.sdf_extractor import _sanitize_molecule + from .fg_constants import ELEMENTS, FLAG_NO_FG @@ -1911,7 +1913,11 @@ def get_structure(mol): structure[frag] = {"atom": atom_idx, "is_ring_fg": False} # Convert fragment SMILES back to mol to match with fused ring atom indices - frag_mol = Chem.MolFromSmiles(frag) + frag_mol = Chem.MolFromSmiles(frag, sanitize=False) + try: + frag_mol = _sanitize_molecule(frag_mol) + except: + pass frag_rings = frag_mol.GetRingInfo().AtomRings() if len(frag_rings) >= 1: structure[frag]["is_ring_fg"] = True From 6a85dd2e2b8f9736c398eee5fa7678c3510c117a Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 2 Mar 2026 14:31:22 +0100 Subject: [PATCH 19/20] update default values in configs --- configs/model/gat_aug_aapool.yml | 2 +- configs/model/gat_aug_amgpool.yml | 2 +- configs/model/res_aug_aapool.yml | 2 +- configs/model/res_aug_amgpool.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/model/gat_aug_aapool.yml b/configs/model/gat_aug_aapool.yml index 155d402..fae47c3 100644 --- a/configs/model/gat_aug_aapool.yml +++ b/configs/model/gat_aug_aapool.yml @@ -7,7 +7,7 @@ init_args: hidden_channels: 256 out_channels: 512 num_layers: 4 - edge_dim: 11 # number of bond properties + edge_dim: 12 # number of bond properties heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) v2: True # This uses `torch_geometric.nn.conv.GATv2Conv` convolution layers, set False to use `GATConv` n_molecule_properties: 0 diff --git a/configs/model/gat_aug_amgpool.yml b/configs/model/gat_aug_amgpool.yml index 048bbe1..e596487 100644 --- a/configs/model/gat_aug_amgpool.yml +++ b/configs/model/gat_aug_amgpool.yml @@ -7,7 +7,7 @@ init_args: hidden_channels: 256 out_channels: 512 num_layers: 4 - edge_dim: 11 # number of bond properties + edge_dim: 12 # number of bond properties heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) v2: True # This uses `torch_geometric.nn.conv.GATv2Conv` convolution layers, set False to use `GATConv` dropout: 0 diff --git a/configs/model/res_aug_aapool.yml b/configs/model/res_aug_aapool.yml index 9e364f9..de28d1c 100644 --- a/configs/model/res_aug_aapool.yml +++ b/configs/model/res_aug_aapool.yml @@ -7,7 +7,7 @@ init_args: hidden_channels: 256 out_channels: 512 num_layers: 4 - edge_dim: 11 # number of bond properties + edge_dim: 12 # number of bond properties dropout: 0 n_molecule_properties: 0 n_linear_layers: 1 diff --git a/configs/model/res_aug_amgpool.yml b/configs/model/res_aug_amgpool.yml index 2aba5ea..9194cd7 100644 --- a/configs/model/res_aug_amgpool.yml +++ b/configs/model/res_aug_amgpool.yml @@ -7,7 +7,7 @@ init_args: hidden_channels: 256 out_channels: 512 num_layers: 4 - edge_dim: 11 # number of bond properties + edge_dim: 12 # number of bond properties dropout: 0 n_molecule_properties: 0 n_linear_layers: 1 From 578c2ef6d90ecbde07ef2aa6bfc887538e36656f Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 2 Mar 2026 14:33:17 +0100 Subject: [PATCH 20/20] reformat w/ ruff --- chebai_graph/models/dynamic_gni.py | 18 +-- chebai_graph/preprocessing/datasets/chebi.py | 9 +- chebai_graph/preprocessing/properties/base.py | 12 +- .../preprocessing/property_encoder.py | 6 +- .../preprocessing/reader/augmented_reader.py | 107 ++++++++++-------- chebai_graph/preprocessing/reader/reader.py | 10 +- 6 files changed, 91 insertions(+), 71 deletions(-) diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py index bb48571..8cb6c7b 100644 --- a/chebai_graph/models/dynamic_gni.py +++ b/chebai_graph/models/dynamic_gni.py @@ -85,9 +85,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any): print("Using complete randomness: ", self.complete_randomness) if not self.complete_randomness: - assert ( - "pad_node_features" in config or "pad_edge_features" in config - ), "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False" + assert "pad_node_features" in config or "pad_edge_features" in config, ( + "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False" + ) self.pad_node_features = ( int(config["pad_node_features"]) if config.get("pad_node_features") is not None @@ -112,9 +112,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any): f"in each forward pass." ) - assert ( - self.pad_node_features > 0 or self.pad_edge_features > 0 - ), "'pad_node_features' or 'pad_edge_features' must be positive integers" + assert self.pad_node_features > 0 or self.pad_edge_features > 0, ( + "'pad_node_features' or 'pad_edge_features' must be positive integers" + ) self.resgated: BasicGNN = ResGatedModel( in_channels=self.in_channels, @@ -182,9 +182,9 @@ def forward(self, batch: dict[str, Any]) -> Tensor: ) new_edge_attr = torch.cat((graph_data.edge_attr, pad_edge), dim=1) - assert ( - new_x is not None and new_edge_attr is not None - ), "Feature initialization failed" + assert new_x is not None and new_edge_attr is not None, ( + "Feature initialization failed" + ) out = self.resgated( x=new_x.float(), edge_index=graph_data.edge_index.long(), diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index fb6d356..2ed8362 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -610,7 +610,14 @@ def _merge_props_into_base( enc_len = property_values.shape[1] # -------------- Node properties --------------- if isinstance(property, AllNodeTypeProperty): - x[:, atom_offset : atom_offset + enc_len] = property_values + try: + x[:, atom_offset : atom_offset + enc_len] = property_values + except Exception as e: + raise ValueError( + f"Error assigning property '{property.name}' values to node features: {e}\n" + f"Property values shape: {property_values.shape}, expected (num_nodes, {enc_len})\n" + f"Node feature matrix shape: {x.shape}" + ) atom_offset += enc_len fg_offset += enc_len graph_offset += enc_len diff --git a/chebai_graph/preprocessing/properties/base.py b/chebai_graph/preprocessing/properties/base.py index da5d9c2..ea2725a 100644 --- a/chebai_graph/preprocessing/properties/base.py +++ b/chebai_graph/preprocessing/properties/base.py @@ -262,9 +262,9 @@ def get_property_value(self, augmented_mol: dict) -> list: ) prop_list.append(self.get_atom_value(graph_node)) - assert ( - len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"] - ), "Number of property values should be equal to number of nodes" + assert len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"], ( + "Number of property values should be equal to number of nodes" + ) return prop_list def _check_modify_atom_prop_value( @@ -390,9 +390,9 @@ def get_property_value(self, augmented_mol: dict) -> list: ) num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2 - assert ( - len(prop_list) == num_directed_edges - ), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " + assert len(prop_list) == num_directed_edges, ( + f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " + ) return prop_list diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index fc536b7..3edd2d2 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -261,9 +261,9 @@ def encode(self, token: float | int | None) -> torch.Tensor: """ if token is None: return torch.zeros(1, self.get_encoding_length()) - assert ( - len(token) == self.get_encoding_length() - ), "Length of token should be equal to encoding length" + assert len(token) == self.get_encoding_length(), ( + "Length of token should be equal to encoding length" + ) # return torch.tensor([token]) # token is an ndarray, no need to create list of ndarray due to below warning # UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. # Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 1df73a5..986ad7d 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -23,6 +23,7 @@ # https://mail.python.org/pipermail/python-dev/2017-December/151283.html # Order preservation is necessary to to create `is_atom_node` mask + class _AugmentorReader(DataReader, ABC): """ Abstract base class for augmentor readers that extend ChemDataReader. @@ -85,13 +86,17 @@ def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None: try: returned_result = self._create_augmented_graph(mol) except Exception as e: - print(f"Failed to construct augmented graph for smiles {smiles}, Error: {e}") + print( + f"Failed to construct augmented graph for smiles {smiles}, Error: {e}" + ) self.f_cnt_for_aug_graph += 1 return None # If the returned result is None, it indicates that the graph augmentation failed if returned_result is None: - print(f"Failed to construct augmented graph for smiles {smiles} (returned None)") + print( + f"Failed to construct augmented graph for smiles {smiles} (returned None)" + ) self.f_cnt_for_aug_graph += 1 return None @@ -100,36 +105,39 @@ def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None: # Empty features initialized; node and edge features can be added later NUM_NODES = augmented_molecule["nodes"]["num_nodes"] - assert ( - NUM_NODES is not None and NUM_NODES > 1 - ), "Num of nodes in augmented graph should be more than 1" + assert NUM_NODES is not None and NUM_NODES > 1, ( + "Num of nodes in augmented graph should be more than 1" + ) x = torch.zeros((NUM_NODES, 0)) edge_attr = torch.zeros((augmented_molecule["edges"][k.NUM_EDGES], 0)) - assert ( - edge_index.shape[0] == 2 - ), f"Expected edge_index to have shape [2, num_edges], but got shape {edge_index.shape}" + assert edge_index.shape[0] == 2, ( + f"Expected edge_index to have shape [2, num_edges], but got shape {edge_index.shape}" + ) - assert ( - edge_index.shape[1] == edge_attr.shape[0] - ), f"Mismatch between number of edges in edge_index ({edge_index.shape[1]}) and edge_attr ({edge_attr.shape[0]})" + assert edge_index.shape[1] == edge_attr.shape[0], ( + f"Mismatch between number of edges in edge_index ({edge_index.shape[1]}) and edge_attr ({edge_attr.shape[0]})" + ) - assert ( - len(set(edge_index[0].tolist())) == x.shape[0] - ), f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})" + assert len(set(edge_index[0].tolist())) == x.shape[0], ( + f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})" + ) # Create a boolean mask: True for atom, False for augmented is_atom_mask = torch.zeros(NUM_NODES, dtype=torch.bool) NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms() is_atom_mask[:NUM_ATOM_NODES] = True - return (GeomData( - x=x, - edge_index=edge_index, - edge_attr=edge_attr, - is_atom_node=is_atom_mask, - ), augmented_molecule) + return ( + GeomData( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + is_atom_node=is_atom_mask, + ), + augmented_molecule, + ) def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None: """ @@ -219,9 +227,9 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict: atom_edge_index = self._generate_atom_level_edge_index(mol) total_atoms = mol.GetNumAtoms() - assert ( - self._idx_of_node == total_atoms - ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + assert self._idx_of_node == total_atoms, ( + f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + ) node_info = { "atom_nodes": mol, @@ -229,9 +237,9 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict: } total_edges = mol.GetNumBonds() - assert ( - self._idx_of_edge == total_edges - ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + assert self._idx_of_edge == total_edges, ( + f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + ) edge_info = { k.WITHIN_ATOMS_EDGE: mol, k.NUM_EDGES: self._idx_of_edge, @@ -285,7 +293,9 @@ def on_finish(self) -> None: ) self.mol_object_buffer = {} - def read_property(self, raw_data: str | Chem.Mol | dict, property: MolecularProperty) -> list | None: + def read_property( + self, raw_data: str | Chem.Mol | dict, property: MolecularProperty + ) -> list | None: """ Reads a specific property from a molecule represented by a SMILES string. @@ -362,16 +372,16 @@ def _augment_graph_structure( augmented_mol["directed_edge_index"] = directed_edge_index total_atoms = sum([mol.GetNumAtoms(), len(fg_nodes)]) - assert ( - self._idx_of_node == total_atoms - ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + assert self._idx_of_node == total_atoms, ( + f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + ) augmented_mol["node_info"]["fg_nodes"] = fg_nodes augmented_mol["node_info"]["num_nodes"] = self._idx_of_node total_edges = sum([mol.GetNumBonds(), len(atom_fg_edges)]) - assert ( - self._idx_of_edge == total_edges - ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + assert self._idx_of_edge == total_edges, ( + f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + ) augmented_mol["edge_info"][k.ATOM_FG_EDGE] = atom_fg_edges augmented_mol["edge_info"][k.NUM_EDGES] = self._idx_of_edge @@ -594,12 +604,12 @@ def _augment_graph_structure( augmented_struct["edge_info"][k.WITHIN_FG_EDGE] = internal_fg_edges augmented_struct["edge_info"][k.NUM_EDGES] += len(internal_fg_edges) - assert ( - self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES] - ), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" - assert ( - self._idx_of_node == augmented_struct["node_info"]["num_nodes"] - ), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + assert self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES], ( + f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + ) + assert self._idx_of_node == augmented_struct["node_info"]["num_nodes"], ( + f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + ) augmented_struct["directed_edge_index"] = torch.cat( [ @@ -629,9 +639,9 @@ def _construct_fg_level_structure( internal_edge_index = [[], []] def add_fg_internal_edge(source_fg: int, target_fg: int) -> None: - assert ( - source_fg is not None and target_fg is not None - ), "Each bond should have a fg node on both end" + assert source_fg is not None and target_fg is not None, ( + "Each bond should have a fg node on both end" + ) assert source_fg != target_fg, "Source and Target FG should be different" edge_key = tuple(sorted((source_fg, target_fg))) @@ -718,15 +728,15 @@ def _add_graph_node_and_edges_to_nodes( augmented_struct["edge_info"][k.TO_GRAPHNODE_EDGE] = nodes_to_graph_edges augmented_struct["edge_info"][k.NUM_EDGES] += len(nodes_to_graph_edges) - assert ( - self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES] - ), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + assert self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES], ( + f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + ) augmented_struct["node_info"]["graph_node"] = graph_node augmented_struct["node_info"]["num_nodes"] += 1 - assert ( - self._idx_of_node == augmented_struct["node_info"]["num_nodes"] - ), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + assert self._idx_of_node == augmented_struct["node_info"]["num_nodes"], ( + f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + ) augmented_struct["directed_edge_index"] = torch.cat( [ @@ -957,4 +967,3 @@ def _augment_graph_structure( return self._add_graph_node_and_edges_to_nodes( augmented_struct, atom_ids | fg_to_atoms_map.keys() ) - \ No newline at end of file diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py index 922b401..d5ff4bc 100644 --- a/chebai_graph/preprocessing/reader/reader.py +++ b/chebai_graph/preprocessing/reader/reader.py @@ -104,7 +104,9 @@ def on_finish(self) -> None: print(f"Failed to read {self.failed_counter} SMILES in total") self.mol_object_buffer = {} - def read_property(self, raw_data: str | Chem.Mol, property: MolecularProperty) -> list | None: + def read_property( + self, raw_data: str | Chem.Mol, property: MolecularProperty + ) -> list | None: """ Read a molecular property for a given SMILES string. @@ -160,7 +162,9 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: """ # raw_data is a SMILES string try: - mol = self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data + mol = ( + self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data + ) except ValueError: return None assert isinstance(mol, nx.Graph) @@ -193,7 +197,7 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: nx.set_edge_attributes(mol, de, "edge_attr") data = from_networkx(mol) return data - + def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: """ Load SMILES string into an RDKit molecule object.