diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py index bb48571..8cb6c7b 100644 --- a/chebai_graph/models/dynamic_gni.py +++ b/chebai_graph/models/dynamic_gni.py @@ -85,9 +85,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any): print("Using complete randomness: ", self.complete_randomness) if not self.complete_randomness: - assert ( - "pad_node_features" in config or "pad_edge_features" in config - ), "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False" + assert "pad_node_features" in config or "pad_edge_features" in config, ( + "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False" + ) self.pad_node_features = ( int(config["pad_node_features"]) if config.get("pad_node_features") is not None @@ -112,9 +112,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any): f"in each forward pass." ) - assert ( - self.pad_node_features > 0 or self.pad_edge_features > 0 - ), "'pad_node_features' or 'pad_edge_features' must be positive integers" + assert self.pad_node_features > 0 or self.pad_edge_features > 0, ( + "'pad_node_features' or 'pad_edge_features' must be positive integers" + ) self.resgated: BasicGNN = ResGatedModel( in_channels=self.in_channels, @@ -182,9 +182,9 @@ def forward(self, batch: dict[str, Any]) -> Tensor: ) new_edge_attr = torch.cat((graph_data.edge_attr, pad_edge), dim=1) - assert ( - new_x is not None and new_edge_attr is not None - ), "Feature initialization failed" + assert new_x is not None and new_edge_attr is not None, ( + "Feature initialization failed" + ) out = self.resgated( x=new_x.float(), edge_index=graph_data.edge_index.long(), diff --git a/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt index f024a9d..5ad8538 100644 --- a/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt @@ -5,3 +5,4 @@ 1 5 6 +8 diff --git a/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt b/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt index 97ae8be..f36bdf7 100644 --- a/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt @@ -3,3 +3,4 @@ SINGLE AROMATIC TRIPLE DOUBLE +UNSPECIFIED diff --git a/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt b/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt index 7036755..22da8da 100644 --- a/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt @@ -9,3 +9,5 @@ 7 10 12 +11 +9 diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index fdf2e6d..2ed8362 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -5,6 +5,7 @@ from typing import Optional import pandas as pd +from chebai_graph.preprocessing.reader.augmented_reader import _AugmentorReader import torch import tqdm from chebai.preprocessing.datasets.chebi import ( @@ -15,6 +16,7 @@ ) from lightning_utilities.core.rank_zero import rank_zero_info from torch_geometric.data.data import Data as GeomData +from rdkit import Chem from chebai_graph.preprocessing.properties import ( AllNodeTypeProperty, @@ -126,31 +128,52 @@ def enc_if_not_none(encode, value): else None ) - for property in self.properties: - if not os.path.isfile(self.get_property_path(property)): - rank_zero_info(f"Processing property {property.name}") - # read all property values first, then encode - rank_zero_info(f"\tReading property values of {property.name}...") - property_values = [ - self.reader.read_property(feat, property) - for feat in tqdm.tqdm(features) - ] - rank_zero_info(f"\tEncoding property values of {property.name}...") - property.encoder.on_start(property_values=property_values) - encoded_values = [ - enc_if_not_none(property.encoder.encode, value) - for value in tqdm.tqdm(property_values) + if any( + not os.path.isfile(self.get_property_path(property)) + for property in self.properties + ): + # augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy) + if isinstance(self.reader, _AugmentorReader): + returned_results = [] + for mol in features: + try: + r = self.reader._create_augmented_graph(mol) + except Exception as e: + r = None + returned_results.append(r) + mols = [ + augmented_mol[1] + for augmented_mol in returned_results + if augmented_mol is not None ] - - torch.save( - [ - {property.name: torch.cat(feat), "ident": id} - for feat, id in zip(encoded_values, idents) - if feat is not None - ], - self.get_property_path(property), - ) - property.on_finish() + else: + mols = features + + for property in self.properties: + if not os.path.isfile(self.get_property_path(property)): + rank_zero_info(f"Processing property {property.name}") + # read all property values first, then encode + rank_zero_info(f"\tReading property values of {property.name}...") + property_values = [ + self.reader.read_property(mol, property) + for mol in tqdm.tqdm(mols) + ] + rank_zero_info(f"\tEncoding property values of {property.name}...") + property.encoder.on_start(property_values=property_values) + encoded_values = [ + enc_if_not_none(property.encoder.encode, value) + for value in tqdm.tqdm(property_values) + ] + + torch.save( + [ + {property.name: torch.cat(feat), "ident": id} + for feat, id in zip(encoded_values, idents) + if feat is not None + ], + self.get_property_path(property), + ) + property.on_finish() @property def processed_properties_dir(self) -> str: @@ -185,20 +208,23 @@ def _after_setup(self, **kwargs) -> None: super()._after_setup(**kwargs) def _preprocess_smiles_for_pred( - self, idx, smiles: str, model_hparams: Optional[dict] = None - ) -> dict: + self, idx, raw_data: str | Chem.Mol, model_hparams: Optional[dict] = None + ) -> Optional[dict]: """Preprocess prediction data.""" # Add dummy labels because the collate function requires them. # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. result = self.reader.to_data( - {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]} + {"id": f"smiles_{idx}", "features": raw_data, "labels": [1, 2]} ) + # _read_data can return an updated version of the input data (e.g. augmented molecule dict) along with the GeomData object + if isinstance(result["features"], tuple): + result["features"], raw_data = result["features"] if result is None or result["features"] is None: return None for property in self.properties: property.encoder.eval = True - property_value = self.reader.read_property(smiles, property) + property_value = self.reader.read_property(raw_data, property) if property_value is None or len(property_value) == 0: encoded_value = None else: @@ -250,7 +276,9 @@ def __init__( assert ( distribution is not None and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS - ), "When using padding for features, a valid distribution must be specified." + ), ( + "When using padding for features, a valid distribution must be specified." + ) self.distribution = distribution if self.pad_node_features: print( @@ -278,7 +306,12 @@ def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData: Returns: A GeomData object with merged features. """ - geom_data = row["features"] + if isinstance(row["features"], tuple): + geom_data, _ = row[ + "features" + ] # ignore additional returned data from _read_data (e.g. augmented molecule dict) + else: + geom_data = row["features"] assert isinstance(geom_data, GeomData) edge_attr = geom_data.edge_attr x = geom_data.x @@ -538,6 +571,10 @@ def _merge_props_into_base( geom_data = row["features"] if geom_data is None: return None + if isinstance(geom_data, tuple): + geom_data = geom_data[ + 0 + ] # ignore additional returned data from _read_data (e.g. augmented molecule dict) assert isinstance(geom_data, GeomData) is_atom_node = geom_data.is_atom_node @@ -550,9 +587,9 @@ def _merge_props_into_base( edge_attr = geom_data.edge_attr # Initialize node feature matrix - assert ( - max_len_node_properties is not None - ), "Maximum len of node properties should not be None" + assert max_len_node_properties is not None, ( + "Maximum len of node properties should not be None" + ) x = torch.zeros((num_nodes, max_len_node_properties)) # Track column offsets for each node type @@ -573,7 +610,14 @@ def _merge_props_into_base( enc_len = property_values.shape[1] # -------------- Node properties --------------- if isinstance(property, AllNodeTypeProperty): - x[:, atom_offset : atom_offset + enc_len] = property_values + try: + x[:, atom_offset : atom_offset + enc_len] = property_values + except Exception as e: + raise ValueError( + f"Error assigning property '{property.name}' values to node features: {e}\n" + f"Property values shape: {property_values.shape}, expected (num_nodes, {enc_len})\n" + f"Node feature matrix shape: {x.shape}" + ) atom_offset += enc_len fg_offset += enc_len graph_offset += enc_len @@ -607,9 +651,9 @@ def _merge_props_into_base( raise TypeError(f"Unsupported property type: {type(property).__name__}") total_used_columns = max(atom_offset, fg_offset, graph_offset) - assert ( - total_used_columns <= max_len_node_properties - ), f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}" + assert total_used_columns <= max_len_node_properties, ( + f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}" + ) return GeomData( x=x, @@ -805,3 +849,9 @@ class ChEBI50_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver50): class ChEBI100_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver100): READER = AtomFGReader_WithFGEdges_WithGraphNode + + +class ChEBI25_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOverX): + READER = AtomFGReader_WithFGEdges_WithGraphNode + + THRESHOLD = 25 diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index f60f580..b573f0c 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -8,6 +8,8 @@ from rdkit.Chem import AllChem from rdkit.Chem import MolToSmiles as m2s +from chebi_utils.sdf_extractor import _sanitize_molecule + from .fg_constants import ELEMENTS, FLAG_NO_FG @@ -1911,7 +1913,11 @@ def get_structure(mol): structure[frag] = {"atom": atom_idx, "is_ring_fg": False} # Convert fragment SMILES back to mol to match with fused ring atom indices - frag_mol = Chem.MolFromSmiles(frag) + frag_mol = Chem.MolFromSmiles(frag, sanitize=False) + try: + frag_mol = _sanitize_molecule(frag_mol) + except: + pass frag_rings = frag_mol.GetRingInfo().AtomRings() if len(frag_rings) >= 1: structure[frag]["is_ring_fg"] = True diff --git a/chebai_graph/preprocessing/properties/base.py b/chebai_graph/preprocessing/properties/base.py index da5d9c2..ea2725a 100644 --- a/chebai_graph/preprocessing/properties/base.py +++ b/chebai_graph/preprocessing/properties/base.py @@ -262,9 +262,9 @@ def get_property_value(self, augmented_mol: dict) -> list: ) prop_list.append(self.get_atom_value(graph_node)) - assert ( - len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"] - ), "Number of property values should be equal to number of nodes" + assert len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"], ( + "Number of property values should be equal to number of nodes" + ) return prop_list def _check_modify_atom_prop_value( @@ -390,9 +390,9 @@ def get_property_value(self, augmented_mol: dict) -> list: ) num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2 - assert ( - len(prop_list) == num_directed_edges - ), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " + assert len(prop_list) == num_directed_edges, ( + f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " + ) return prop_list diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index 991ab16..3edd2d2 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -156,6 +156,9 @@ def encode(self, token: str | None) -> torch.Tensor: return torch.tensor([self._unk_token_idx]) if str(token) not in self.cache: + # Ensure cache is a mutable dict (jsonargparse may convert it to mappingproxy) + if not isinstance(self.cache, dict): + self.cache = dict(self.cache) self.cache[str(token)] = len(self.cache) return torch.tensor([self.cache[str(token)] + self.offset]) @@ -258,9 +261,9 @@ def encode(self, token: float | int | None) -> torch.Tensor: """ if token is None: return torch.zeros(1, self.get_encoding_length()) - assert ( - len(token) == self.get_encoding_length() - ), "Length of token should be equal to encoding length" + assert len(token) == self.get_encoding_length(), ( + "Length of token should be equal to encoding length" + ) # return torch.tensor([token]) # token is an ndarray, no need to create list of ndarray due to below warning # UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. # Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index b6b2d90..986ad7d 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -4,6 +4,7 @@ import torch from chebai.preprocessing.reader import DataReader +from chebi_utils.sdf_extractor import _sanitize_molecule from rdkit import Chem from torch_geometric.data import Data as GeomData @@ -59,34 +60,43 @@ def name(cls) -> str: """ return f"{cls.__name__}".lower() - def _read_data(self, smiles: str) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None: """ Reads and augments molecular data from a SMILES string. Args: - smiles (str): SMILES representation of the molecule. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule. Returns: - GeomData | None: A PyTorch Geometric Data object with augmented nodes and edges, - or None if parsing or augmentation fails. + tuple[GeomData, dict] | None: A tuple containing a PyTorch Geometric Data object with augmented nodes and edges, + and a dictionary of augmented molecule data, or None if parsing or augmentation fails. Raises: RuntimeError: If an unexpected error occurs during graph augmentation. """ - mol = self._smiles_to_mol(smiles) + if isinstance(raw_data, str): + mol = self._smiles_to_mol(raw_data) + smiles = raw_data + else: + mol = raw_data + smiles = Chem.MolToSmiles(mol) if mol is None: return None try: returned_result = self._create_augmented_graph(mol) except Exception as e: - raise RuntimeError( - f"Error has occurred for following SMILES: {smiles}\n\t {e}" - ) from e + print( + f"Failed to construct augmented graph for smiles {smiles}, Error: {e}" + ) + self.f_cnt_for_aug_graph += 1 + return None # If the returned result is None, it indicates that the graph augmentation failed if returned_result is None: - print(f"Failed to construct augmented graph for smiles {smiles}") + print( + f"Failed to construct augmented graph for smiles {smiles} (returned None)" + ) self.f_cnt_for_aug_graph += 1 return None @@ -95,35 +105,38 @@ def _read_data(self, smiles: str) -> GeomData | None: # Empty features initialized; node and edge features can be added later NUM_NODES = augmented_molecule["nodes"]["num_nodes"] - assert ( - NUM_NODES is not None and NUM_NODES > 1 - ), "Num of nodes in augmented graph should be more than 1" + assert NUM_NODES is not None and NUM_NODES > 1, ( + "Num of nodes in augmented graph should be more than 1" + ) x = torch.zeros((NUM_NODES, 0)) edge_attr = torch.zeros((augmented_molecule["edges"][k.NUM_EDGES], 0)) - assert ( - edge_index.shape[0] == 2 - ), f"Expected edge_index to have shape [2, num_edges], but got shape {edge_index.shape}" + assert edge_index.shape[0] == 2, ( + f"Expected edge_index to have shape [2, num_edges], but got shape {edge_index.shape}" + ) - assert ( - edge_index.shape[1] == edge_attr.shape[0] - ), f"Mismatch between number of edges in edge_index ({edge_index.shape[1]}) and edge_attr ({edge_attr.shape[0]})" + assert edge_index.shape[1] == edge_attr.shape[0], ( + f"Mismatch between number of edges in edge_index ({edge_index.shape[1]}) and edge_attr ({edge_attr.shape[0]})" + ) - assert ( - len(set(edge_index[0].tolist())) == x.shape[0] - ), f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})" + assert len(set(edge_index[0].tolist())) == x.shape[0], ( + f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})" + ) # Create a boolean mask: True for atom, False for augmented is_atom_mask = torch.zeros(NUM_NODES, dtype=torch.bool) NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms() is_atom_mask[:NUM_ATOM_NODES] = True - return GeomData( - x=x, - edge_index=edge_index, - edge_attr=edge_attr, - is_atom_node=is_atom_mask, + return ( + GeomData( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + is_atom_node=is_atom_mask, + ), + augmented_molecule, ) def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None: @@ -136,13 +149,13 @@ def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None: Returns: Chem.Mol | None: RDKit molecule object if successful, else None. """ - mol = Chem.MolFromSmiles(smiles) + mol = Chem.MolFromSmiles(smiles, sanitize=False) if mol is None: print(f"RDKit failed to parse {smiles} (returned None)") self.f_cnt_for_smiles += 1 else: try: - Chem.SanitizeMol(mol) + mol = _sanitize_molecule(mol) except Exception as e: print(f"RDKit failed at sanitizing {smiles}, Error {e}") self.f_cnt_for_smiles += 1 @@ -214,9 +227,9 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict: atom_edge_index = self._generate_atom_level_edge_index(mol) total_atoms = mol.GetNumAtoms() - assert ( - self._idx_of_node == total_atoms - ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + assert self._idx_of_node == total_atoms, ( + f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + ) node_info = { "atom_nodes": mol, @@ -224,9 +237,9 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict: } total_edges = mol.GetNumBonds() - assert ( - self._idx_of_edge == total_edges - ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + assert self._idx_of_edge == total_edges, ( + f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + ) edge_info = { k.WITHIN_ATOMS_EDGE: mol, k.NUM_EDGES: self._idx_of_edge, @@ -280,29 +293,41 @@ def on_finish(self) -> None: ) self.mol_object_buffer = {} - def read_property(self, smiles: str, property: MolecularProperty) -> list | None: + def read_property( + self, raw_data: str | Chem.Mol | dict, property: MolecularProperty + ) -> list | None: """ Reads a specific property from a molecule represented by a SMILES string. Args: - smiles (str): SMILES string representing the molecule. + raw_data (str | Chem.Mol | dict): SMILES string, RDKit molecule object, or dictionary representation of a molecule. property (MolecularProperty): Molecular property object for which the value needs to be extracted. Returns: list | None: Property values if molecule parsing is successful, else None. """ - if smiles in self.mol_object_buffer: - return property.get_property_value(self.mol_object_buffer[smiles]) - - mol = self._smiles_to_mol(smiles) - if mol is None: - return None - - returned_result = self._create_augmented_graph(mol) - if returned_result is None: - return None + if isinstance(raw_data, dict): + augmented_mol = raw_data + else: + if isinstance(raw_data, Chem.Mol): + mol = raw_data + else: + smiles = raw_data + if smiles in self.mol_object_buffer: + return property.get_property_value(self.mol_object_buffer[smiles]) + mol = self._smiles_to_mol(smiles) + if mol is None: + return None + try: + returned_result = self._create_augmented_graph(mol) + except Exception as e: + print(f"Failed to construct augmented graph, Error: {e}") + self.f_cnt_for_aug_graph += 1 + return None + if returned_result is None: + return None - _, augmented_mol = returned_result + _, augmented_mol = returned_result return property.get_property_value(augmented_mol) @@ -347,16 +372,16 @@ def _augment_graph_structure( augmented_mol["directed_edge_index"] = directed_edge_index total_atoms = sum([mol.GetNumAtoms(), len(fg_nodes)]) - assert ( - self._idx_of_node == total_atoms - ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + assert self._idx_of_node == total_atoms, ( + f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + ) augmented_mol["node_info"]["fg_nodes"] = fg_nodes augmented_mol["node_info"]["num_nodes"] = self._idx_of_node total_edges = sum([mol.GetNumBonds(), len(atom_fg_edges)]) - assert ( - self._idx_of_edge == total_edges - ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + assert self._idx_of_edge == total_edges, ( + f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + ) augmented_mol["edge_info"][k.ATOM_FG_EDGE] = atom_fg_edges augmented_mol["edge_info"][k.NUM_EDGES] = self._idx_of_edge @@ -579,12 +604,12 @@ def _augment_graph_structure( augmented_struct["edge_info"][k.WITHIN_FG_EDGE] = internal_fg_edges augmented_struct["edge_info"][k.NUM_EDGES] += len(internal_fg_edges) - assert ( - self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES] - ), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" - assert ( - self._idx_of_node == augmented_struct["node_info"]["num_nodes"] - ), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + assert self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES], ( + f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + ) + assert self._idx_of_node == augmented_struct["node_info"]["num_nodes"], ( + f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + ) augmented_struct["directed_edge_index"] = torch.cat( [ @@ -614,9 +639,9 @@ def _construct_fg_level_structure( internal_edge_index = [[], []] def add_fg_internal_edge(source_fg: int, target_fg: int) -> None: - assert ( - source_fg is not None and target_fg is not None - ), "Each bond should have a fg node on both end" + assert source_fg is not None and target_fg is not None, ( + "Each bond should have a fg node on both end" + ) assert source_fg != target_fg, "Source and Target FG should be different" edge_key = tuple(sorted((source_fg, target_fg))) @@ -662,24 +687,25 @@ def add_fg_internal_edge(source_fg: int, target_fg: int) -> None: class _AddGraphNode(_AugmentorReader): """Adds a graph-level node and connects it to selected/given nodes.""" - def _read_data(self, smiles: str) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None: """ Reads data and adds a graph-level node annotation. Args: - smiles (str): SMILES string. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule. Returns: Data | None: Geometric data object with is_graph_node annotation. """ - geom_data = super()._read_data(smiles) - if geom_data is None: + res = super()._read_data(raw_data) + if res is None: return None + geom_data, augmented_mol = res NUM_NODES = geom_data.x.shape[0] is_graph_node = torch.zeros(NUM_NODES, dtype=torch.bool) is_graph_node[-1] = True geom_data.is_graph_node = is_graph_node - return geom_data + return (geom_data, augmented_mol) def _add_graph_node_and_edges_to_nodes( self, @@ -702,15 +728,15 @@ def _add_graph_node_and_edges_to_nodes( augmented_struct["edge_info"][k.TO_GRAPHNODE_EDGE] = nodes_to_graph_edges augmented_struct["edge_info"][k.NUM_EDGES] += len(nodes_to_graph_edges) - assert ( - self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES] - ), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + assert self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES], ( + f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + ) augmented_struct["node_info"]["graph_node"] = graph_node augmented_struct["node_info"]["num_nodes"] += 1 - assert ( - self._idx_of_node == augmented_struct["node_info"]["num_nodes"] - ), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + assert self._idx_of_node == augmented_struct["node_info"]["num_nodes"], ( + f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + ) augmented_struct["directed_edge_index"] = torch.cat( [ diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py index a63b8a1..d5ff4bc 100644 --- a/chebai_graph/preprocessing/reader/reader.py +++ b/chebai_graph/preprocessing/reader/reader.py @@ -1,6 +1,7 @@ import os import chebai.preprocessing.reader as dr +from chebi_utils.sdf_extractor import _sanitize_molecule import networkx as nx import pysmiles as ps import rdkit.Chem as Chem @@ -54,30 +55,33 @@ def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: if smiles in self.mol_object_buffer: return self.mol_object_buffer[smiles] - mol = Chem.MolFromSmiles(smiles) + mol = Chem.MolFromSmiles(smiles, sanitize=False) if mol is None: print(f"RDKit failed to at parsing {smiles} (returned None)") self.failed_counter += 1 else: try: - Chem.SanitizeMol(mol) + _sanitize_molecule(mol) except Exception as e: print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") self.failed_counter += 1 self.mol_object_buffer[smiles] = mol return mol - def _read_data(self, raw_data: str) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, Chem.Mol] | None: """ Convert raw SMILES string data into a PyTorch Geometric Data object. Args: - raw_data (str): SMILES string. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object. Returns: GeomData | None: Graph data object or None if molecule parsing failed. """ - mol = self._smiles_to_mol(raw_data) + if isinstance(raw_data, Chem.Mol): + mol = raw_data + else: + mol = self._smiles_to_mol(raw_data) if mol is None: return None @@ -91,7 +95,7 @@ def _read_data(self, raw_data: str) -> GeomData | None: # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] edge_attr = torch.zeros((edge_index.size(1), 0)) - return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + return (GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr), mol) def on_finish(self) -> None: """ @@ -100,18 +104,20 @@ def on_finish(self) -> None: print(f"Failed to read {self.failed_counter} SMILES in total") self.mol_object_buffer = {} - def read_property(self, smiles: str, property: MolecularProperty) -> list | None: + def read_property( + self, raw_data: str | Chem.Mol, property: MolecularProperty + ) -> list | None: """ Read a molecular property for a given SMILES string. Args: - smiles (str): SMILES string of the molecule. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object of the molecule. property (MolecularProperty): Property extractor to apply. Returns: list | None: Property values or None if molecule parsing failed. """ - mol = self._smiles_to_mol(smiles) + mol = self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data if mol is None: return None return property.get_property_value(mol) @@ -144,19 +150,21 @@ def name(cls) -> str: """ return "graph" - def _read_data(self, raw_data: str) -> GeomData | None: + def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: """ Convert a SMILES string into a PyTorch Geometric Data object with atom tokens and bond order attributes. Args: - raw_data (str): SMILES string. + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object. Returns: GeomData | None: Graph data object or None if parsing failed. """ # raw_data is a SMILES string try: - mol = ps.read_smiles(raw_data) + mol = ( + self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data + ) except ValueError: return None assert isinstance(mol, nx.Graph) @@ -190,6 +198,27 @@ def _read_data(self, raw_data: str) -> GeomData | None: data = from_networkx(mol) return data + def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: + """ + Load SMILES string into an RDKit molecule object. + + Args: + smiles (str): The SMILES string to parse. + + Returns: + Chem.rdchem.Mol | None: Parsed molecule object or None if parsing failed. + """ + + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol is None: + print(f"RDKit failed to at parsing {smiles} (returned None)") + else: + try: + _sanitize_molecule(mol) + except Exception as e: + print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") + return mol + def collate(self, list_of_tuples: list) -> any: """ Collate a list of samples into a batch. diff --git a/configs/model/gat_aug_aapool.yml b/configs/model/gat_aug_aapool.yml index 155d402..fae47c3 100644 --- a/configs/model/gat_aug_aapool.yml +++ b/configs/model/gat_aug_aapool.yml @@ -7,7 +7,7 @@ init_args: hidden_channels: 256 out_channels: 512 num_layers: 4 - edge_dim: 11 # number of bond properties + edge_dim: 12 # number of bond properties heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) v2: True # This uses `torch_geometric.nn.conv.GATv2Conv` convolution layers, set False to use `GATConv` n_molecule_properties: 0 diff --git a/configs/model/gat_aug_amgpool.yml b/configs/model/gat_aug_amgpool.yml index 048bbe1..e596487 100644 --- a/configs/model/gat_aug_amgpool.yml +++ b/configs/model/gat_aug_amgpool.yml @@ -7,7 +7,7 @@ init_args: hidden_channels: 256 out_channels: 512 num_layers: 4 - edge_dim: 11 # number of bond properties + edge_dim: 12 # number of bond properties heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) v2: True # This uses `torch_geometric.nn.conv.GATv2Conv` convolution layers, set False to use `GATConv` dropout: 0 diff --git a/configs/model/res_aug_aapool.yml b/configs/model/res_aug_aapool.yml index 9e364f9..de28d1c 100644 --- a/configs/model/res_aug_aapool.yml +++ b/configs/model/res_aug_aapool.yml @@ -7,7 +7,7 @@ init_args: hidden_channels: 256 out_channels: 512 num_layers: 4 - edge_dim: 11 # number of bond properties + edge_dim: 12 # number of bond properties dropout: 0 n_molecule_properties: 0 n_linear_layers: 1 diff --git a/configs/model/res_aug_amgpool.yml b/configs/model/res_aug_amgpool.yml index 2aba5ea..9194cd7 100644 --- a/configs/model/res_aug_amgpool.yml +++ b/configs/model/res_aug_amgpool.yml @@ -7,7 +7,7 @@ init_args: hidden_channels: 256 out_channels: 512 num_layers: 4 - edge_dim: 11 # number of bond properties + edge_dim: 12 # number of bond properties dropout: 0 n_molecule_properties: 0 n_linear_layers: 1