Skip to content

Commit 578c2ef

Browse files
committed
reformat w/ ruff
1 parent 6a85dd2 commit 578c2ef

6 files changed

Lines changed: 91 additions & 71 deletions

File tree

chebai_graph/models/dynamic_gni.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
8585
print("Using complete randomness: ", self.complete_randomness)
8686

8787
if not self.complete_randomness:
88-
assert (
89-
"pad_node_features" in config or "pad_edge_features" in config
90-
), "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False"
88+
assert "pad_node_features" in config or "pad_edge_features" in config, (
89+
"Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False"
90+
)
9191
self.pad_node_features = (
9292
int(config["pad_node_features"])
9393
if config.get("pad_node_features") is not None
@@ -112,9 +112,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
112112
f"in each forward pass."
113113
)
114114

115-
assert (
116-
self.pad_node_features > 0 or self.pad_edge_features > 0
117-
), "'pad_node_features' or 'pad_edge_features' must be positive integers"
115+
assert self.pad_node_features > 0 or self.pad_edge_features > 0, (
116+
"'pad_node_features' or 'pad_edge_features' must be positive integers"
117+
)
118118

119119
self.resgated: BasicGNN = ResGatedModel(
120120
in_channels=self.in_channels,
@@ -182,9 +182,9 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
182182
)
183183
new_edge_attr = torch.cat((graph_data.edge_attr, pad_edge), dim=1)
184184

185-
assert (
186-
new_x is not None and new_edge_attr is not None
187-
), "Feature initialization failed"
185+
assert new_x is not None and new_edge_attr is not None, (
186+
"Feature initialization failed"
187+
)
188188
out = self.resgated(
189189
x=new_x.float(),
190190
edge_index=graph_data.edge_index.long(),

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,14 @@ def _merge_props_into_base(
610610
enc_len = property_values.shape[1]
611611
# -------------- Node properties ---------------
612612
if isinstance(property, AllNodeTypeProperty):
613-
x[:, atom_offset : atom_offset + enc_len] = property_values
613+
try:
614+
x[:, atom_offset : atom_offset + enc_len] = property_values
615+
except Exception as e:
616+
raise ValueError(
617+
f"Error assigning property '{property.name}' values to node features: {e}\n"
618+
f"Property values shape: {property_values.shape}, expected (num_nodes, {enc_len})\n"
619+
f"Node feature matrix shape: {x.shape}"
620+
)
614621
atom_offset += enc_len
615622
fg_offset += enc_len
616623
graph_offset += enc_len

chebai_graph/preprocessing/properties/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@ def get_property_value(self, augmented_mol: dict) -> list:
262262
)
263263
prop_list.append(self.get_atom_value(graph_node))
264264

265-
assert (
266-
len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"]
267-
), "Number of property values should be equal to number of nodes"
265+
assert len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"], (
266+
"Number of property values should be equal to number of nodes"
267+
)
268268
return prop_list
269269

270270
def _check_modify_atom_prop_value(
@@ -390,9 +390,9 @@ def get_property_value(self, augmented_mol: dict) -> list:
390390
)
391391

392392
num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2
393-
assert (
394-
len(prop_list) == num_directed_edges
395-
), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} "
393+
assert len(prop_list) == num_directed_edges, (
394+
f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} "
395+
)
396396

397397
return prop_list
398398

chebai_graph/preprocessing/property_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,9 @@ def encode(self, token: float | int | None) -> torch.Tensor:
261261
"""
262262
if token is None:
263263
return torch.zeros(1, self.get_encoding_length())
264-
assert (
265-
len(token) == self.get_encoding_length()
266-
), "Length of token should be equal to encoding length"
264+
assert len(token) == self.get_encoding_length(), (
265+
"Length of token should be equal to encoding length"
266+
)
267267
# return torch.tensor([token]) # token is an ndarray, no need to create list of ndarray due to below warning
268268
# UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow.
269269
# Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor.

chebai_graph/preprocessing/reader/augmented_reader.py

Lines changed: 58 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
2424
# Order preservation is necessary to to create `is_atom_node` mask
2525

26+
2627
class _AugmentorReader(DataReader, ABC):
2728
"""
2829
Abstract base class for augmentor readers that extend ChemDataReader.
@@ -85,13 +86,17 @@ def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None:
8586
try:
8687
returned_result = self._create_augmented_graph(mol)
8788
except Exception as e:
88-
print(f"Failed to construct augmented graph for smiles {smiles}, Error: {e}")
89+
print(
90+
f"Failed to construct augmented graph for smiles {smiles}, Error: {e}"
91+
)
8992
self.f_cnt_for_aug_graph += 1
9093
return None
9194

9295
# If the returned result is None, it indicates that the graph augmentation failed
9396
if returned_result is None:
94-
print(f"Failed to construct augmented graph for smiles {smiles} (returned None)")
97+
print(
98+
f"Failed to construct augmented graph for smiles {smiles} (returned None)"
99+
)
95100
self.f_cnt_for_aug_graph += 1
96101
return None
97102

@@ -100,36 +105,39 @@ def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None:
100105

101106
# Empty features initialized; node and edge features can be added later
102107
NUM_NODES = augmented_molecule["nodes"]["num_nodes"]
103-
assert (
104-
NUM_NODES is not None and NUM_NODES > 1
105-
), "Num of nodes in augmented graph should be more than 1"
108+
assert NUM_NODES is not None and NUM_NODES > 1, (
109+
"Num of nodes in augmented graph should be more than 1"
110+
)
106111

107112
x = torch.zeros((NUM_NODES, 0))
108113
edge_attr = torch.zeros((augmented_molecule["edges"][k.NUM_EDGES], 0))
109114

110-
assert (
111-
edge_index.shape[0] == 2
112-
), f"Expected edge_index to have shape [2, num_edges], but got shape {edge_index.shape}"
115+
assert edge_index.shape[0] == 2, (
116+
f"Expected edge_index to have shape [2, num_edges], but got shape {edge_index.shape}"
117+
)
113118

114-
assert (
115-
edge_index.shape[1] == edge_attr.shape[0]
116-
), f"Mismatch between number of edges in edge_index ({edge_index.shape[1]}) and edge_attr ({edge_attr.shape[0]})"
119+
assert edge_index.shape[1] == edge_attr.shape[0], (
120+
f"Mismatch between number of edges in edge_index ({edge_index.shape[1]}) and edge_attr ({edge_attr.shape[0]})"
121+
)
117122

118-
assert (
119-
len(set(edge_index[0].tolist())) == x.shape[0]
120-
), f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})"
123+
assert len(set(edge_index[0].tolist())) == x.shape[0], (
124+
f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})"
125+
)
121126

122127
# Create a boolean mask: True for atom, False for augmented
123128
is_atom_mask = torch.zeros(NUM_NODES, dtype=torch.bool)
124129
NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms()
125130
is_atom_mask[:NUM_ATOM_NODES] = True
126131

127-
return (GeomData(
128-
x=x,
129-
edge_index=edge_index,
130-
edge_attr=edge_attr,
131-
is_atom_node=is_atom_mask,
132-
), augmented_molecule)
132+
return (
133+
GeomData(
134+
x=x,
135+
edge_index=edge_index,
136+
edge_attr=edge_attr,
137+
is_atom_node=is_atom_mask,
138+
),
139+
augmented_molecule,
140+
)
133141

134142
def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None:
135143
"""
@@ -219,19 +227,19 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict:
219227
atom_edge_index = self._generate_atom_level_edge_index(mol)
220228

221229
total_atoms = mol.GetNumAtoms()
222-
assert (
223-
self._idx_of_node == total_atoms
224-
), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}"
230+
assert self._idx_of_node == total_atoms, (
231+
f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}"
232+
)
225233

226234
node_info = {
227235
"atom_nodes": mol,
228236
"num_nodes": self._idx_of_node,
229237
}
230238

231239
total_edges = mol.GetNumBonds()
232-
assert (
233-
self._idx_of_edge == total_edges
234-
), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}"
240+
assert self._idx_of_edge == total_edges, (
241+
f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}"
242+
)
235243
edge_info = {
236244
k.WITHIN_ATOMS_EDGE: mol,
237245
k.NUM_EDGES: self._idx_of_edge,
@@ -285,7 +293,9 @@ def on_finish(self) -> None:
285293
)
286294
self.mol_object_buffer = {}
287295

288-
def read_property(self, raw_data: str | Chem.Mol | dict, property: MolecularProperty) -> list | None:
296+
def read_property(
297+
self, raw_data: str | Chem.Mol | dict, property: MolecularProperty
298+
) -> list | None:
289299
"""
290300
Reads a specific property from a molecule represented by a SMILES string.
291301
@@ -362,16 +372,16 @@ def _augment_graph_structure(
362372
augmented_mol["directed_edge_index"] = directed_edge_index
363373

364374
total_atoms = sum([mol.GetNumAtoms(), len(fg_nodes)])
365-
assert (
366-
self._idx_of_node == total_atoms
367-
), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}"
375+
assert self._idx_of_node == total_atoms, (
376+
f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}"
377+
)
368378
augmented_mol["node_info"]["fg_nodes"] = fg_nodes
369379
augmented_mol["node_info"]["num_nodes"] = self._idx_of_node
370380

371381
total_edges = sum([mol.GetNumBonds(), len(atom_fg_edges)])
372-
assert (
373-
self._idx_of_edge == total_edges
374-
), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}"
382+
assert self._idx_of_edge == total_edges, (
383+
f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}"
384+
)
375385
augmented_mol["edge_info"][k.ATOM_FG_EDGE] = atom_fg_edges
376386
augmented_mol["edge_info"][k.NUM_EDGES] = self._idx_of_edge
377387

@@ -594,12 +604,12 @@ def _augment_graph_structure(
594604
augmented_struct["edge_info"][k.WITHIN_FG_EDGE] = internal_fg_edges
595605
augmented_struct["edge_info"][k.NUM_EDGES] += len(internal_fg_edges)
596606

597-
assert (
598-
self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES]
599-
), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}"
600-
assert (
601-
self._idx_of_node == augmented_struct["node_info"]["num_nodes"]
602-
), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}"
607+
assert self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES], (
608+
f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}"
609+
)
610+
assert self._idx_of_node == augmented_struct["node_info"]["num_nodes"], (
611+
f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}"
612+
)
603613

604614
augmented_struct["directed_edge_index"] = torch.cat(
605615
[
@@ -629,9 +639,9 @@ def _construct_fg_level_structure(
629639
internal_edge_index = [[], []]
630640

631641
def add_fg_internal_edge(source_fg: int, target_fg: int) -> None:
632-
assert (
633-
source_fg is not None and target_fg is not None
634-
), "Each bond should have a fg node on both end"
642+
assert source_fg is not None and target_fg is not None, (
643+
"Each bond should have a fg node on both end"
644+
)
635645
assert source_fg != target_fg, "Source and Target FG should be different"
636646

637647
edge_key = tuple(sorted((source_fg, target_fg)))
@@ -718,15 +728,15 @@ def _add_graph_node_and_edges_to_nodes(
718728

719729
augmented_struct["edge_info"][k.TO_GRAPHNODE_EDGE] = nodes_to_graph_edges
720730
augmented_struct["edge_info"][k.NUM_EDGES] += len(nodes_to_graph_edges)
721-
assert (
722-
self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES]
723-
), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}"
731+
assert self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES], (
732+
f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}"
733+
)
724734

725735
augmented_struct["node_info"]["graph_node"] = graph_node
726736
augmented_struct["node_info"]["num_nodes"] += 1
727-
assert (
728-
self._idx_of_node == augmented_struct["node_info"]["num_nodes"]
729-
), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}"
737+
assert self._idx_of_node == augmented_struct["node_info"]["num_nodes"], (
738+
f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}"
739+
)
730740

731741
augmented_struct["directed_edge_index"] = torch.cat(
732742
[
@@ -957,4 +967,3 @@ def _augment_graph_structure(
957967
return self._add_graph_node_and_edges_to_nodes(
958968
augmented_struct, atom_ids | fg_to_atoms_map.keys()
959969
)
960-

chebai_graph/preprocessing/reader/reader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def on_finish(self) -> None:
104104
print(f"Failed to read {self.failed_counter} SMILES in total")
105105
self.mol_object_buffer = {}
106106

107-
def read_property(self, raw_data: str | Chem.Mol, property: MolecularProperty) -> list | None:
107+
def read_property(
108+
self, raw_data: str | Chem.Mol, property: MolecularProperty
109+
) -> list | None:
108110
"""
109111
Read a molecular property for a given SMILES string.
110112
@@ -160,7 +162,9 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
160162
"""
161163
# raw_data is a SMILES string
162164
try:
163-
mol = self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data
165+
mol = (
166+
self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data
167+
)
164168
except ValueError:
165169
return None
166170
assert isinstance(mol, nx.Graph)
@@ -193,7 +197,7 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
193197
nx.set_edge_attributes(mol, de, "edge_attr")
194198
data = from_networkx(mol)
195199
return data
196-
200+
197201
def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None:
198202
"""
199203
Load SMILES string into an RDKit molecule object.

0 commit comments

Comments
 (0)