2323# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
2424# Order preservation is necessary to to create `is_atom_node` mask
2525
26+
2627class _AugmentorReader (DataReader , ABC ):
2728 """
2829 Abstract base class for augmentor readers that extend ChemDataReader.
@@ -85,13 +86,17 @@ def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None:
8586 try :
8687 returned_result = self ._create_augmented_graph (mol )
8788 except Exception as e :
88- print (f"Failed to construct augmented graph for smiles { smiles } , Error: { e } " )
89+ print (
90+ f"Failed to construct augmented graph for smiles { smiles } , Error: { e } "
91+ )
8992 self .f_cnt_for_aug_graph += 1
9093 return None
9194
9295 # If the returned result is None, it indicates that the graph augmentation failed
9396 if returned_result is None :
94- print (f"Failed to construct augmented graph for smiles { smiles } (returned None)" )
97+ print (
98+ f"Failed to construct augmented graph for smiles { smiles } (returned None)"
99+ )
95100 self .f_cnt_for_aug_graph += 1
96101 return None
97102
@@ -100,36 +105,39 @@ def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None:
100105
101106 # Empty features initialized; node and edge features can be added later
102107 NUM_NODES = augmented_molecule ["nodes" ]["num_nodes" ]
103- assert (
104- NUM_NODES is not None and NUM_NODES > 1
105- ), "Num of nodes in augmented graph should be more than 1"
108+ assert NUM_NODES is not None and NUM_NODES > 1 , (
109+ "Num of nodes in augmented graph should be more than 1"
110+ )
106111
107112 x = torch .zeros ((NUM_NODES , 0 ))
108113 edge_attr = torch .zeros ((augmented_molecule ["edges" ][k .NUM_EDGES ], 0 ))
109114
110- assert (
111- edge_index . shape [ 0 ] == 2
112- ), f"Expected edge_index to have shape [2, num_edges], but got shape { edge_index . shape } "
115+ assert edge_index . shape [ 0 ] == 2 , (
116+ f"Expected edge_index to have shape [2, num_edges], but got shape { edge_index . shape } "
117+ )
113118
114- assert (
115- edge_index .shape [1 ] == edge_attr .shape [0 ]
116- ), f"Mismatch between number of edges in edge_index ( { edge_index . shape [ 1 ] } ) and edge_attr ( { edge_attr . shape [ 0 ] } )"
119+ assert edge_index . shape [ 1 ] == edge_attr . shape [ 0 ], (
120+ f"Mismatch between number of edges in edge_index ( { edge_index .shape [1 ]} ) and edge_attr ( { edge_attr .shape [0 ]} )"
121+ )
117122
118- assert (
119- len (set (edge_index [0 ].tolist ())) == x .shape [0 ]
120- ), f"Number of unique source nodes in edge_index ( { len ( set ( edge_index [ 0 ]. tolist ())) } ) does not match number of nodes in x ( { x . shape [ 0 ] } )"
123+ assert len ( set ( edge_index [ 0 ]. tolist ())) == x . shape [ 0 ], (
124+ f"Number of unique source nodes in edge_index ( { len (set (edge_index [0 ].tolist ()))} ) does not match number of nodes in x ( { x .shape [0 ]} )"
125+ )
121126
122127 # Create a boolean mask: True for atom, False for augmented
123128 is_atom_mask = torch .zeros (NUM_NODES , dtype = torch .bool )
124129 NUM_ATOM_NODES = augmented_molecule ["nodes" ]["atom_nodes" ].GetNumAtoms ()
125130 is_atom_mask [:NUM_ATOM_NODES ] = True
126131
127- return (GeomData (
128- x = x ,
129- edge_index = edge_index ,
130- edge_attr = edge_attr ,
131- is_atom_node = is_atom_mask ,
132- ), augmented_molecule )
132+ return (
133+ GeomData (
134+ x = x ,
135+ edge_index = edge_index ,
136+ edge_attr = edge_attr ,
137+ is_atom_node = is_atom_mask ,
138+ ),
139+ augmented_molecule ,
140+ )
133141
134142 def _smiles_to_mol (self , smiles : str ) -> Chem .Mol | None :
135143 """
@@ -219,19 +227,19 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict:
219227 atom_edge_index = self ._generate_atom_level_edge_index (mol )
220228
221229 total_atoms = mol .GetNumAtoms ()
222- assert (
223- self . _idx_of_node == total_atoms
224- ), f"Mismatch in number of nodes: expected { total_atoms } , got { self . _idx_of_node } "
230+ assert self . _idx_of_node == total_atoms , (
231+ f"Mismatch in number of nodes: expected { total_atoms } , got { self . _idx_of_node } "
232+ )
225233
226234 node_info = {
227235 "atom_nodes" : mol ,
228236 "num_nodes" : self ._idx_of_node ,
229237 }
230238
231239 total_edges = mol .GetNumBonds ()
232- assert (
233- self . _idx_of_edge == total_edges
234- ), f"Mismatch in number of edges: expected { total_edges } , got { self . _idx_of_edge } "
240+ assert self . _idx_of_edge == total_edges , (
241+ f"Mismatch in number of edges: expected { total_edges } , got { self . _idx_of_edge } "
242+ )
235243 edge_info = {
236244 k .WITHIN_ATOMS_EDGE : mol ,
237245 k .NUM_EDGES : self ._idx_of_edge ,
@@ -285,7 +293,9 @@ def on_finish(self) -> None:
285293 )
286294 self .mol_object_buffer = {}
287295
288- def read_property (self , raw_data : str | Chem .Mol | dict , property : MolecularProperty ) -> list | None :
296+ def read_property (
297+ self , raw_data : str | Chem .Mol | dict , property : MolecularProperty
298+ ) -> list | None :
289299 """
290300 Reads a specific property from a molecule represented by a SMILES string.
291301
@@ -362,16 +372,16 @@ def _augment_graph_structure(
362372 augmented_mol ["directed_edge_index" ] = directed_edge_index
363373
364374 total_atoms = sum ([mol .GetNumAtoms (), len (fg_nodes )])
365- assert (
366- self . _idx_of_node == total_atoms
367- ), f"Mismatch in number of nodes: expected { total_atoms } , got { self . _idx_of_node } "
375+ assert self . _idx_of_node == total_atoms , (
376+ f"Mismatch in number of nodes: expected { total_atoms } , got { self . _idx_of_node } "
377+ )
368378 augmented_mol ["node_info" ]["fg_nodes" ] = fg_nodes
369379 augmented_mol ["node_info" ]["num_nodes" ] = self ._idx_of_node
370380
371381 total_edges = sum ([mol .GetNumBonds (), len (atom_fg_edges )])
372- assert (
373- self . _idx_of_edge == total_edges
374- ), f"Mismatch in number of edges: expected { total_edges } , got { self . _idx_of_edge } "
382+ assert self . _idx_of_edge == total_edges , (
383+ f"Mismatch in number of edges: expected { total_edges } , got { self . _idx_of_edge } "
384+ )
375385 augmented_mol ["edge_info" ][k .ATOM_FG_EDGE ] = atom_fg_edges
376386 augmented_mol ["edge_info" ][k .NUM_EDGES ] = self ._idx_of_edge
377387
@@ -594,12 +604,12 @@ def _augment_graph_structure(
594604 augmented_struct ["edge_info" ][k .WITHIN_FG_EDGE ] = internal_fg_edges
595605 augmented_struct ["edge_info" ][k .NUM_EDGES ] += len (internal_fg_edges )
596606
597- assert (
598- self ._idx_of_edge == augmented_struct [" edge_info" ][k .NUM_EDGES ]
599- ), f"Mismatch in number of edges: expected { self . _idx_of_edge } , got { augmented_struct [ 'edge_info' ][ k . NUM_EDGES ] } "
600- assert (
601- self ._idx_of_node == augmented_struct [" node_info" ][ " num_nodes" ]
602- ), f"Mismatch in number of nodes: expected { self . _idx_of_node } , got { augmented_struct [ 'node_info' ][ 'num_nodes' ] } "
607+ assert self . _idx_of_edge == augmented_struct [ "edge_info" ][ k . NUM_EDGES ], (
608+ f"Mismatch in number of edges: expected { self ._idx_of_edge } , got { augmented_struct [' edge_info' ][k .NUM_EDGES ]} "
609+ )
610+ assert self . _idx_of_node == augmented_struct [ "node_info" ][ "num_nodes" ], (
611+ f"Mismatch in number of nodes: expected { self ._idx_of_node } , got { augmented_struct [' node_info' ][ ' num_nodes' ] } "
612+ )
603613
604614 augmented_struct ["directed_edge_index" ] = torch .cat (
605615 [
@@ -629,9 +639,9 @@ def _construct_fg_level_structure(
629639 internal_edge_index = [[], []]
630640
631641 def add_fg_internal_edge (source_fg : int , target_fg : int ) -> None :
632- assert (
633- source_fg is not None and target_fg is not None
634- ), "Each bond should have a fg node on both end"
642+ assert source_fg is not None and target_fg is not None , (
643+ "Each bond should have a fg node on both end"
644+ )
635645 assert source_fg != target_fg , "Source and Target FG should be different"
636646
637647 edge_key = tuple (sorted ((source_fg , target_fg )))
@@ -718,15 +728,15 @@ def _add_graph_node_and_edges_to_nodes(
718728
719729 augmented_struct ["edge_info" ][k .TO_GRAPHNODE_EDGE ] = nodes_to_graph_edges
720730 augmented_struct ["edge_info" ][k .NUM_EDGES ] += len (nodes_to_graph_edges )
721- assert (
722- self ._idx_of_edge == augmented_struct [" edge_info" ][k .NUM_EDGES ]
723- ), f"Mismatch in number of edges: expected { self . _idx_of_edge } , got { augmented_struct [ 'edge_info' ][ k . NUM_EDGES ] } "
731+ assert self . _idx_of_edge == augmented_struct [ "edge_info" ][ k . NUM_EDGES ], (
732+ f"Mismatch in number of edges: expected { self ._idx_of_edge } , got { augmented_struct [' edge_info' ][k .NUM_EDGES ]} "
733+ )
724734
725735 augmented_struct ["node_info" ]["graph_node" ] = graph_node
726736 augmented_struct ["node_info" ]["num_nodes" ] += 1
727- assert (
728- self ._idx_of_node == augmented_struct [" node_info" ][ " num_nodes" ]
729- ), f"Mismatch in number of nodes: expected { self . _idx_of_node } , got { augmented_struct [ 'node_info' ][ 'num_nodes' ] } "
737+ assert self . _idx_of_node == augmented_struct [ "node_info" ][ "num_nodes" ], (
738+ f"Mismatch in number of nodes: expected { self ._idx_of_node } , got { augmented_struct [' node_info' ][ ' num_nodes' ] } "
739+ )
730740
731741 augmented_struct ["directed_edge_index" ] = torch .cat (
732742 [
@@ -957,4 +967,3 @@ def _augment_graph_structure(
957967 return self ._add_graph_node_and_edges_to_nodes (
958968 augmented_struct , atom_ids | fg_to_atoms_map .keys ()
959969 )
960-
0 commit comments