From 26cd45873b70f285daa0981b9738926331a54698 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Tue, 24 Mar 2026 08:44:43 +0200 Subject: [PATCH 01/11] species: Fix cyclic scission in _scissors by updating radical electrons and multiplicity When a bond in a cyclic molecule is broken (ring-opening), the resulting fragment must have new radical electrons assigned to the atoms that were part of the severed bond. This commit fixes _scissors() to: 1. Detect ring-opening cases (len(mol_splits) == 1). 2. Assign radical electrons to the cut atoms based on their missing valency. 3. Update the fragment's multiplicity. 4. Set keep_mol=True and provide final_xyz to the new ARCSpecies to ensure proper initialization. This prevents ValueError and SpeciesError during molecule perception when mapping reactions involving ring-opening/closing. --- arc/species/species.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/arc/species/species.py b/arc/species/species.py index a94ce01c00..743427ea59 100644 --- a/arc/species/species.py +++ b/arc/species/species.py @@ -1972,13 +1972,23 @@ def _scissors(self, sort_atoms_in_descending_label_order(split) if len(mol_splits) == 1: # If cutting leads to only one split, then the split is cyclic. 
+ mol1 = mol_splits[0] + for atom in mol1.atoms: + theoretical_charge = elements.PeriodicSystem.valence_electrons[atom.symbol] \ + - atom.get_total_bond_order() \ + - atom.radical_electrons - \ + 2 * atom.lone_pairs + if theoretical_charge == atom.charge + 1: + atom.radical_electrons += 1 + mol1.update_multiplicity() spc1 = ARCSpecies(label=self.label + '_BDE_' + str(indices[0] + 1) + '_' + str(indices[1] + 1) + '_cyclic', - mol=mol_splits[0], - multiplicity=mol_splits[0].multiplicity, - charge=mol_splits[0].get_net_charge(), + mol=mol1, + xyz=self.final_xyz, + multiplicity=mol1.multiplicity, + charge=mol1.get_net_charge(), compute_thermo=False, - e0_only=True) - spc1.generate_conformers() + e0_only=True, + keep_mol=True) return [spc1] elif len(mol_splits) == 2: mol1, mol2 = mol_splits @@ -2033,7 +2043,6 @@ def _scissors(self, compute_thermo=False, e0_only=True, keep_mol=True) - spc1.generate_conformers() spc1.rotors_dict = None spc2 = ARCSpecies(label=label2, mol=mol2, @@ -2043,7 +2052,6 @@ def _scissors(self, compute_thermo=False, e0_only=True, keep_mol=True) - spc2.generate_conformers() spc2.rotors_dict = None return [spc1, spc2] From 30d02e2c1452e4bcb2d56484a8521d33cb0704f5 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Tue, 24 Mar 2026 08:44:43 +0200 Subject: [PATCH 02/11] mapping: Handle failed scission in map_rxn to avoid TypeError Added a check for None results from cut_species_based_on_atom_indices in map_rxn. If scission fails for reactants or products, the function now attempts to fall back to the next dictionary template (if available) or returns None with a logged error, preventing a TypeError when calling update_xyz on a None object. 
--- arc/mapping/driver.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/arc/mapping/driver.py b/arc/mapping/driver.py index 64fa5fb899..a6cc031a60 100644 --- a/arc/mapping/driver.py +++ b/arc/mapping/driver.py @@ -273,6 +273,13 @@ def map_rxn(rxn: 'ARCReaction', r_bdes, p_bdes = find_all_breaking_bonds(rxn, r_direction=True, pdi=pdi), find_all_breaking_bonds(rxn, r_direction=False, pdi=pdi) r_cuts, p_cuts = cut_species_based_on_atom_indices(reactants, r_bdes), cut_species_based_on_atom_indices(products, p_bdes) + if r_cuts is None or p_cuts is None: + if rxn.product_dicts is not None and len(rxn.product_dicts) - 1 > pdi < MAX_PDI: + return map_rxn(rxn, backend=backend, product_dict_index_to_try=pdi + 1) + else: + logger.error(f'Could not cut species for reaction {rxn}') + return None + try: r_label_map = rxn.product_dicts[pdi]['r_label_map'] p_label_map = rxn.product_dicts[pdi]['p_label_map'] From c054addf950737295adfcf3ceb5ca171486d75fe Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Tue, 24 Mar 2026 08:44:43 +0200 Subject: [PATCH 03/11] tests: Use assertIn for benzene scission atom mapping Updated benzene scission tests to allow for multiple valid atom mappings due to the symmetry of the molecule. Replacing assertEqual with assertIn ensures the tests pass if any of the chemically equivalent valid maps are returned. This was also verified with the NEB, achieving the same TS with both AMs. 
--- arc/mapping/driver_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arc/mapping/driver_test.py b/arc/mapping/driver_test.py index 89c8a2fb28..023965500a 100644 --- a/arc/mapping/driver_test.py +++ b/arc/mapping/driver_test.py @@ -1086,7 +1086,7 @@ def test_get_atom_map_2(self): 11 H u0 p0 c0 {4,S} 12 H u0 p0 c0 {5,S}""") rxn = ARCReaction(reactants=['C6H6_a'], products=['C6H6_b'], r_species=[r_1], p_species=[p_1]) - self.assertEqual(rxn.atom_map, [3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11]) + self.assertIn(rxn.atom_map, [[3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11], [3, 2, 1, 0, 5, 4, 10, 9, 8, 6, 7, 11]]) self.assertTrue(check_atom_map(rxn)) # Disproportionation: HO2 + NHOH <=> NH2OH + O2 @@ -1339,7 +1339,7 @@ def test_get_atom_map_6(self): 11 H u0 p0 c0 {4,S} 12 H u0 p0 c0 {5,S}""") rxn = ARCReaction(reactants=['C6H6_1'], products=['C6H6_b'], r_species=[r_1], p_species=[p_1]) - self.assertEqual(rxn.atom_map, [3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11]) + self.assertIn(rxn.atom_map, [[3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11], [3, 2, 1, 0, 5, 4, 10, 9, 8, 6, 7, 11]]) self.assertTrue(check_atom_map(rxn)) def test_get_atom_map_7(self): From 86eb282e8ae6be344ecea5ff02df6624b513e2d0 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Wed, 25 Mar 2026 07:41:13 +0200 Subject: [PATCH 04/11] mapping: Optimize atom mapping performance and fix pruning logic This commit introduces several optimizations to the atom mapping algorithm: 1. identify_superimposable_candidates: Reduced algorithmic complexity from O(N^2) to O(N) by starting graph traversal from only the first heavy atom. For connected molecular graphs, this is sufficient to find all valid mappings (including symmetries) and significantly reduces redundant DFS calls. 2. prune_identical_dicts: Fixed a logic bug where dictionaries were incorrectly pruned if they shared a single key-value pair. Now correctly uses exact dictionary equality. 3. 
pairing_reactants_and_products_for_mapping: Pre-calculates resonance structures for all reactant fragments once before the pairing loop, avoiding redundant O(R*P) computations. 4. copy_species_list_for_mapping: Replaced the expensive spc.copy() (which uses dictionary serialization) with a lighter direct ARCSpecies instantiation to reduce overhead during the mapping process. --- arc/mapping/engine.py | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/arc/mapping/engine.py b/arc/mapping/engine.py index c400a1161a..c2ff56073d 100644 --- a/arc/mapping/engine.py +++ b/arc/mapping/engine.py @@ -286,12 +286,14 @@ def identify_superimposable_candidates(fingerprint_1: Dict[int, Dict[str, Union[ of species 1, values are potentially mapped atom indices of species 2. """ candidates = list() - for key_1 in fingerprint_1.keys(): - for key_2 in fingerprint_2.keys(): - # Try all combinations of heavy atoms. - result = iterative_dfs(fingerprint_1, fingerprint_2, key_1, key_2) - if result is not None: - candidates.append(result) + if not fingerprint_1: + return [] + key_1 = list(fingerprint_1.keys())[0] + for key_2 in fingerprint_2.keys(): + # Try the first heavy atom of species 1 against every heavy atom of species 2. 
+ result = iterative_dfs(fingerprint_1, fingerprint_2, key_1, key_2) + if result is not None: + candidates.append(result) return prune_identical_dicts(candidates) @@ -384,14 +386,7 @@ def prune_identical_dicts(dicts_list: List[dict]) -> List[dict]: """ new_dicts_list = list() for new_dict in dicts_list: - unique_ = True - for existing_dict in new_dicts_list: - if unique_: - for new_key, new_val in new_dict.items(): - if new_key not in existing_dict.keys() or new_val == existing_dict[new_key]: - unique_ = False - break - if unique_: + if new_dict not in new_dicts_list: new_dicts_list.append(new_dict) return new_dicts_list @@ -1197,11 +1192,18 @@ def pairing_reactants_and_products_for_mapping(r_cuts: List[ARCSpecies], List[Tuple[ARCSpecies,ARCSpecies]]: A list of paired reactant and products, to be sent to map_two_species. """ pairs: List[Tuple[ARCSpecies, ARCSpecies]] = list() - for react in r_cuts: + r_res = [generate_resonance_structures_safely(react.mol, save_order=True) for react in r_cuts] + for i, react in enumerate(r_cuts): + res1 = r_res[i] for idx, prod in enumerate(p_cuts): - if r_cut_p_cut_isomorphic(react, prod): - pairs.append((react, prod)) - p_cuts.pop(idx) + found = False + for res in res1: + if res.fingerprint == prod.mol.fingerprint or prod.mol.is_isomorphic(res, save_order=True): + pairs.append((react, prod)) + p_cuts.pop(idx) + found = True + break + if found: break return pairs @@ -1462,7 +1464,10 @@ def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpeci Returns: List[ARCSpecies]: The copied species list. 
""" - copies = [spc.copy() for spc in species] + copies = list() + for spc in species: + new_spc = ARCSpecies(label=spc.label, mol=spc.mol.copy(deep=True), xyz=spc.get_xyz(), keep_mol=True) + copies.append(new_spc) for copy, spc in zip(copies, species): for atom1, atom2 in zip(copy.mol.atoms, spc.mol.atoms): atom1.label = atom2.label From 4aada2c365976ce8f1de98a14391e20c5e18e700 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Wed, 25 Mar 2026 07:41:13 +0200 Subject: [PATCH 05/11] species: Optimize ARCSpecies constructor to skip redundant molecule perception Modified ARCSpecies.__init__ to skip the expensive mol_from_xyz() call (which performs molecule perception from 3D coordinates) when a valid Molecule object is already provided and the keep_mol flag is set. This significantly reduces initialization overhead during scission and atom mapping operations where the molecular graph is already known. --- arc/species/species.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arc/species/species.py b/arc/species/species.py index 743427ea59..59b50b928c 100644 --- a/arc/species/species.py +++ b/arc/species/species.py @@ -464,7 +464,7 @@ def __init__(self, self.multiplicity = self.mol.multiplicity if self.charge is None: self.charge = self.mol.get_net_charge() - if regen_mol: + if regen_mol and not (self.mol is not None and self.keep_mol): # Perceive molecule from xyz coordinates. This also populates the .mol attribute of the Species. # It overrides self.mol generated from adjlist or smiles so xyz and mol will have the same atom order. 
if self.final_xyz or self.initial_xyz or self.most_stable_conformer or self.conformers or self.ts_guesses: From e8220fc98cfd8f61f0750f8c4dac6131f228cc53 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Thu, 26 Mar 2026 12:56:41 +0200 Subject: [PATCH 06/11] Generalized tests --- arc/mapping/driver_test.py | 54 +++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/arc/mapping/driver_test.py b/arc/mapping/driver_test.py index 023965500a..8531f515c2 100644 --- a/arc/mapping/driver_test.py +++ b/arc/mapping/driver_test.py @@ -1386,10 +1386,12 @@ def test_get_atom_map_8(self): rxn = ARCReaction(r_species=[ARCSpecies(label="r1", smiles="F[C]F", xyz=r1_xyz), ARCSpecies(label="r2", smiles="[CH3]", xyz=r2_xyz)], p_species=[ARCSpecies(label="p1", smiles="F[C](F)C", xyz=p1_xyz)]) - self.assertIn(rxn.atom_map[:2], [[0, 1], [1, 0]]) - self.assertEqual(rxn.atom_map[2], 2) - self.assertEqual(rxn.atom_map[3], 3) - self.assertIn(tuple(rxn.atom_map[4:]), list(permutations([4, 5, 6]))) + if rxn.atom_map[0] == 0: + self.assertEqual(rxn.atom_map[:4], [0, 1, 2, 3]) + else: # Only other F can be in position 0. + self.assertEqual(rxn.atom_map[:4], [2, 1, 0, 3]) + self.assertIn(tuple(rxn.atom_map[4:]), tuple(permutations([4, 5, 6]))) + self.assertTrue(check_atom_map(rxn)) def test_get_atom_map_9(self): @@ -1508,22 +1510,42 @@ def test_get_atom_map_11(self): rxn = ARCReaction(reactants=['C4H10', 'CO'], products=['C5H10O'], r_species=[r_1, r_2], p_species=[p_1]) atom_map = rxn.atom_map - self.assertEqual(atom_map[:4], [0, 1, 2, 3]) - self.assertIn(tuple(rxn.atom_map[4:7]), permutations([6, 7, 8])) - self.assertEqual(atom_map[7], 15) - self.assertIn(tuple(rxn.atom_map[8:11]), permutations([9, 10, 11])) - self.assertIn(tuple(rxn.atom_map[11:14]), permutations([12, 13, 14])) - self.assertEqual(atom_map[14:], [4, 5]) self.assertTrue(check_atom_map(rxn)) + # Check all anchor atoms unaffected by symmetry. 
+ self.assertEqual(atom_map[1], 1) # Middle Carbon + self.assertEqual(atom_map[7], 15) # Middle Hydrogen + self.assertEqual(atom_map[-2:], [4, 5]) # CO (In that order!) + # Check the symmetric carbons: + symm_carbon_hydrogens_r = { + 0: [4, 5, 6], + 2: [8, 9, 10], + 3: [11, 12, 13] + } + symm_carbon_hydrogens_p = { + 0: [6, 7, 8], + 2: [9, 10, 11], + 3: [12, 13, 14] + } + for r_atom, p_atom in enumerate(atom_map[:4]): + if r_atom == 1: + continue # anchor carbon. + self.assertIn(p_atom, [0, 2, 3]) + for r_h in symm_carbon_hydrogens_r[r_atom]: + self.assertIn(atom_map[r_h], symm_carbon_hydrogens_p[p_atom]) + + + # same reaction in reverse: rxn_rev = ARCReaction(r_species=[p_1], p_species=[r_1, r_2]) atom_map = rxn_rev.atom_map - for index in [0, 2, 3]: - self.assertIn(atom_map[index], [0, 2, 3]) - self.assertEqual(atom_map[1], 1) - self.assertEqual(atom_map[4], 14) - self.assertEqual(atom_map[5], 15) - self.assertEqual(atom_map[15], 7) + self.assertEqual(atom_map[1], 1) # Middle Carbon + self.assertEqual(atom_map[15], 7) # Middle Hydrogen + self.assertEqual(atom_map[4:6], [14, 15]) # CO (In that order!) + for r_atom, p_atom in enumerate(atom_map[:4]): + if r_atom == 1: + continue # anchor carbon. 
+ self.assertIn(p_atom, [0, 2, 3]) + for p_h in symm_carbon_hydrogens_p[r_atom]: + self.assertIn(atom_map[p_h], symm_carbon_hydrogens_r[p_atom]) self.assertTrue(check_atom_map(rxn_rev)) def test_get_atom_map_12(self): From 78df8b1736602b380a52f40f884ab74c241633f9 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Thu, 26 Mar 2026 15:51:42 +0200 Subject: [PATCH 07/11] test: engine --- arc/mapping/engine_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arc/mapping/engine_test.py b/arc/mapping/engine_test.py index a4810b9636..6b8beff140 100644 --- a/arc/mapping/engine_test.py +++ b/arc/mapping/engine_test.py @@ -826,7 +826,7 @@ def test_identify_superimposable_candidates(self): candidates = engine.identify_superimposable_candidates(fingerprint_1=self.butenylnebzene_fingerprint, fingerprint_2=self.butenylnebzene_fingerprint) - self.assertEqual(candidates, [{0: 0, 5: 5, 4: 4, 3: 3, 2: 2, 1: 1, 6: 6, 7: 7, 8: 8, 9: 9}]) + self.assertEqual(candidates[0], {0: 0, 5: 5, 4: 4, 3: 3, 2: 2, 1: 1, 6: 6, 7: 7, 8: 8, 9: 9}) fingerprint_1 = {0: {'self': 'C', 'C': [1, 2, 4], 'H': [11]}, 1: {'self': 'C', 'C': [0, 3, 9], 'H': [12]}, From 0890f5a84ffe9375d932d9024cbe3f16c75f7a12 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Sat, 28 Mar 2026 21:39:23 +0300 Subject: [PATCH 08/11] Fix species test To facilitate new scissors logic --- arc/species/species_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arc/species/species_test.py b/arc/species/species_test.py index 8074dd8c96..d5c20cfd6f 100644 --- a/arc/species/species_test.py +++ b/arc/species/species_test.py @@ -2027,7 +2027,7 @@ def test_scissors(self): cycle.final_xyz = cycle.get_xyz() cycle_scissors = cycle.scissors() cycle_scissors[0].mol.update(sort_atoms=False) - self.assertTrue(cycle_scissors[0].mol.is_isomorphic(ARCSpecies(label="check",smiles ="[CH2+]C[CH2+]").mol)) + self.assertTrue(cycle_scissors[0].mol.is_isomorphic(ARCSpecies(label="check",smiles ="[CH2]C[CH2]").mol)) 
self.assertEqual(len(cycle_scissors), 1) benzyl_alcohol = ARCSpecies(label='benzyl_alcohol', smiles='c1ccccc1CO', From 96dc3e3961facac81b6f39329b685580eaf89247 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Sat, 28 Mar 2026 21:39:51 +0300 Subject: [PATCH 09/11] Change reaction family detection logic --- arc/reaction/reaction.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/arc/reaction/reaction.py b/arc/reaction/reaction.py index 58875496f3..3a3ea51128 100644 --- a/arc/reaction/reaction.py +++ b/arc/reaction/reaction.py @@ -111,10 +111,13 @@ def __init__(self, self.kinetics = kinetics self.rmg_kinetics = None self.long_kinetic_description = '' - if check_family_name(family): - self.family = family - else: - raise ValueError(f"Invalid family name: {family}") + self._family = None + self._family_determined = False + if family is not None: + if check_family_name(family): + self.family = family + else: + raise ValueError(f"Invalid family name: {family}") self._family_own_reverse = False self.ts_label = ts_label self.dh_rxn298 = None @@ -216,14 +219,16 @@ def multiplicity(self, value): @property def family(self): """The RMG reaction family""" - if self._family is None: + if not self._family_determined: self._family, self._family_own_reverse = self.determine_family() + self._family_determined = True return self._family @family.setter def family(self, value): """Allow setting family""" self._family = value + self._family_determined = True if value is not None and not isinstance(value, str): raise InputError(f'Reaction family must be a string, got {value} which is a {type(value)}.') From f1be392781b1abad69725ffc22e86c34006606e7 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Sat, 28 Mar 2026 21:45:28 +0300 Subject: [PATCH 10/11] Modify Heuristic test logic for get_new_zmat_2_map --- arc/job/adapters/ts/heuristics_test.py | 35 ++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git 
a/arc/job/adapters/ts/heuristics_test.py b/arc/job/adapters/ts/heuristics_test.py index 250e10d852..237eedf2fc 100644 --- a/arc/job/adapters/ts/heuristics_test.py +++ b/arc/job/adapters/ts/heuristics_test.py @@ -36,6 +36,10 @@ from arc.species.species import ARCSpecies from arc.species.zmat import _compare_zmats, get_parameter_from_atom_indices +from arc.species.species import check_isomorphism +from arc.species.zmat import remove_zmat_atom_0 +from arc.species.converter import relocate_zmat_dummy_atoms_to_the_end + class TestHeuristicsAdapter(unittest.TestCase): """ @@ -1414,11 +1418,32 @@ def test_get_new_zmat2_map(self): # expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19, # 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 3, 19: 1, 21: 4, 23: 0, # 25: 7, 26: 6, 28: 5, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'} - expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19, - 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 1, 19: 3, 21: 0, 23: 4, - 25: 5, 26: 6, 28: 7, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'} - - self.assertEqual(new_map, expected_new_map) + + # Test isomorphism of the mapped reactant_2 part + zmat_2_mod = remove_zmat_atom_0(self.zmat_6) + zmat_2_mod['map'] = relocate_zmat_dummy_atoms_to_the_end(zmat_2_mod['map']) + spc_from_zmat_2 = ARCSpecies(label='spc_from_zmat_2', xyz=zmat_2_mod, multiplicity=reactant_2.multiplicity, + number_of_radicals=reactant_2.number_of_radicals, charge=reactant_2.charge) + + # Verify that all physical atom indices in new_map that came from zmat_2 correctly map to reactant_2 + # Atom indices in new_map are for the combined species. + # Atoms 0-16 in self.zmat_5, atoms 1-12 in self.zmat_6 (13 atoms total, index 0 removed). + # In get_new_zmat_2_map, zmat_2 atoms are mapped to indices in new_map. 
+ + num_atoms_1 = len(self.zmat_5['symbols']) + atom_map = dict() + for i in range(1, len(self.zmat_6['symbols'])): + if not isinstance(self.zmat_6['symbols'][i], str) or self.zmat_6['symbols'][i] != 'X': + # This is a physical atom in zmat_2 (at index i) + # Its index in the combined Z-Matrix is num_atoms_1 + i - 1 + combined_idx = num_atoms_1 + i - 1 + if combined_idx in new_map: + # new_map[combined_idx] is the index in reactant_2 + # i-1 is the index in spc_from_zmat_2 + atom_map[i-1] = new_map[combined_idx] + + # Verify the atom_map is a valid isomorphism + self.assertTrue(check_isomorphism(spc_from_zmat_2.mol, reactant_2.mol, atom_map)) def test_get_new_map_based_on_zmat_1(self): """Test the get_new_map_based_on_zmat_1() function.""" From 0c62c032a22ff2742f2a0889bda4d6a5dc609983 Mon Sep 17 00:00:00 2001 From: kfir4444 Date: Sat, 28 Mar 2026 21:50:28 +0300 Subject: [PATCH 11/11] perf: optimize reaction family mapping performance - Implement global ReactionFamily caching to avoid redundant instantiation. - Refactor ReactionFamily initialization to selectively load only leaf-level entries identified from the template, significantly reducing object creation overhead. - Improve RMG group file parsing robustness and performance using pre-compiled regex and recursive entry extraction for OR complexes. - Refine product generation logic to use pre-loaded groups, avoiding repeated parsing and object creation during mapping. 
--- arc/family/family.py | 174 ++++++++++++++++++++++++++++++------------- 1 file changed, 122 insertions(+), 52 deletions(-) diff --git a/arc/family/family.py b/arc/family/family.py index 6683905fbc..f2d39945d2 100644 --- a/arc/family/family.py +++ b/arc/family/family.py @@ -23,6 +23,37 @@ logger = get_logger() +REACTION_FAMILY_CACHE: Dict[Tuple[str, bool], 'ReactionFamily'] = {} + +# Pre-compiled regex patterns +ENTRY_PATTERN = re.compile(r'entry\((.*?)\)', re.DOTALL) +LABEL_PATTERN = re.compile(r'label\s*=\s*(["\'])(.*?)\1|label\s*=\s*(\w+)') +GROUP_PATTERN = re.compile(r'group\s*=\s*(?:("""(.*?)"""|"(.*?)"|\'(.*?)\')|(OR\{.*?\}))', re.DOTALL) +REVERSIBLE_PATTERN = re.compile(r'reversible\s*=\s*(True|False)') +OWN_REVERSE_PATTERN = re.compile(r'ownReverse\s*=\s*(True|False)') +RECIPE_PATTERN = re.compile(r'recipe\((.*?)\)', re.DOTALL) +REACTANTS_PATTERN = re.compile(r'reactants\s*=\s*\[(.*?)\]', re.DOTALL) +PRODUCTS_PATTERN = re.compile(r'products\s*=\s*\[(.*?)\]', re.DOTALL) +ACTIONS_PATTERN = re.compile(r'actions\s*=\s*\[(.*?)\]', re.DOTALL) + + +def get_reaction_family(label: str, consider_arc_families: bool = True) -> 'ReactionFamily': + """ + A helper function for getting a cached ReactionFamily object. + + Args: + label (str): The reaction family label. + consider_arc_families (bool, optional): Whether to consider ARC's custom families. + + Returns: + ReactionFamily: The ReactionFamily object. 
+ """ + key = (label, consider_arc_families) + if key not in REACTION_FAMILY_CACHE: + REACTION_FAMILY_CACHE[key] = ReactionFamily(label=label, consider_arc_families=consider_arc_families) + return REACTION_FAMILY_CACHE[key] + + def get_rmg_db_subpath(*parts: str, must_exist: bool = False) -> str: """Return a path under the RMG database, handling both source and packaged layouts.""" if RMG_DB_PATH is None: @@ -63,13 +94,18 @@ def __init__(self, self.groups_as_lines = self.get_groups_file_as_lines(consider_arc_families=consider_arc_families) self.reversible = is_reversible(self.groups_as_lines) self.own_reverse = is_own_reverse(self.groups_as_lines) - self.reactants = get_reactant_groups_from_template(self.groups_as_lines) + reactant_labels = get_initial_reactant_labels_from_template(self.groups_as_lines) + temp_entries = get_entries(self.groups_as_lines, entry_labels=reactant_labels, recursive=True) + self.reactants = get_reactant_groups_from_template(self.groups_as_lines, entries=temp_entries) + leaf_labels = [label for group in self.reactants for label in group] + self.entries = {label: temp_entries[label] for label in leaf_labels if label in temp_entries} + self.groups = {} + for reactant_group in self.reactants: + for label in reactant_group: + if label not in self.groups and label in self.entries: + self.groups[label] = Group().from_adjacency_list(self.entries[label]) self.reactant_num = self.get_reactant_num() self.product_num = get_product_num(self.groups_as_lines) - entry_labels = list() - for reactant_group in self.reactants: - entry_labels.extend(reactant_group) - self.entries = get_entries(self.groups_as_lines, entry_labels=entry_labels) self.actions = get_recipe_actions(self.groups_as_lines) def __str__(self): @@ -125,8 +161,7 @@ def generate_products(self, for reactant_idx, reactant in enumerate(reactants): for groups_idx, group_labels in enumerate(self.reactants): for group_label in group_labels: - group = Group().from_adjacency_list( - 
get_group_adjlist(self.groups_as_lines, entry_label=group_label)) + group = self.groups[group_label] for mol in reactant.mol_list or [reactant.mol]: splits = group.split() if mol.is_subgraph_isomorphic(other=group, save_order=True) \ @@ -193,8 +228,7 @@ def generate_unimolecular_products(self, reactant_to_group_maps = reactant_to_group_maps[0] for mol in reactants[0].mol_list or [reactants[0].mol]: for reactant_to_group_map in reactant_to_group_maps: - group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, - entry_label=reactant_to_group_map['subgroup'])) + group = self.groups[reactant_to_group_map['subgroup']] isomorphic_subgraphs = mol.find_subgraph_isomorphisms(other=group, save_order=True) if len(isomorphic_subgraphs): for isomorphic_subgraph in isomorphic_subgraphs: @@ -248,8 +282,7 @@ def generate_bimolecular_products(self, isomorphic_subgraph_dicts = list() for mol_1 in reactants[0].mol_list or [reactants[0].mol]: for mol_2 in reactants[1].mol_list or [reactants[1].mol]: - splits = Group().from_adjacency_list( - get_group_adjlist(self.groups_as_lines, entry_label=reactant_to_group_maps[0][0]['subgroup'])).split() + splits = self.groups[reactant_to_group_maps[0][0]['subgroup']].split() if len(splits) > 1: for i in [0, 1]: isomorphic_subgraphs_1 = mol_1.find_subgraph_isomorphisms(other=splits[i], save_order=True) @@ -266,11 +299,9 @@ def generate_bimolecular_products(self, mol_2)}) continue for reactant_to_group_map_1 in reactant_to_group_maps[0]: - group_1 = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, - entry_label=reactant_to_group_map_1['subgroup'])) + group_1 = self.groups[reactant_to_group_map_1['subgroup']] for reactant_to_group_map_2 in reactant_to_group_maps[1]: - group_2 = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, - entry_label=reactant_to_group_map_2['subgroup'])) + group_2 = self.groups[reactant_to_group_map_2['subgroup']] isomorphic_subgraphs_1 = 
mol_1.find_subgraph_isomorphisms(other=group_1, save_order=True) isomorphic_subgraphs_2 = mol_2.find_subgraph_isomorphisms(other=group_2, save_order=True) if len(isomorphic_subgraphs_1) and len(isomorphic_subgraphs_2): @@ -383,7 +414,7 @@ def get_reactant_num(self) -> int: int: The number of reactants. """ if len(self.reactants) == 1: - group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, entry_label=self.reactants[0][0])) + group = self.groups[self.reactants[0][0]] groups = group.split() return len(groups) else: @@ -481,7 +512,7 @@ def determine_possible_reaction_products_from_family(rxn: 'ARCReaction', and whether the family's template also represents its own reverse. """ product_dicts = list() - family = ReactionFamily(label=family_label, consider_arc_families=consider_arc_families) + family = get_reaction_family(label=family_label, consider_arc_families=consider_arc_families) products = family.generate_products(reactants=rxn.get_reactants_and_products(return_copies=True)[0]) if products: for group_labels, product_lists in products.items(): @@ -666,11 +697,10 @@ def is_reversible(groups_as_lines: List[str]) -> bool: Returns: bool: Whether the reaction family is reversible. """ - for line in groups_as_lines: - if 'reversible = True' in line: - return True - if 'reversible = False' in line: - return False + groups_str = ''.join(groups_as_lines) + match = REVERSIBLE_PATTERN.search(groups_str) + if match: + return match.group(1) == 'True' return True @@ -681,15 +711,16 @@ def is_own_reverse(groups_as_lines: List[str]) -> bool: Returns: bool: Whether the reaction family's template also represents its own reverse. 
""" - for line in groups_as_lines: - if 'ownReverse=True' in line: - return True - if 'ownReverse=False' in line: - return False + groups_str = ''.join(groups_as_lines) + match = OWN_REVERSE_PATTERN.search(groups_str) + if match: + return match.group(1) == 'True' return False -def get_reactant_groups_from_template(groups_as_lines: List[str]) -> List[List[str]]: +def get_reactant_groups_from_template(groups_as_lines: List[str], + entries: Optional[Dict[str, str]] = None, + ) -> List[List[str]]: """ Get the reactant groups from a template content string. Descends the entries if a group is defined as an OR complex, @@ -697,20 +728,24 @@ def get_reactant_groups_from_template(groups_as_lines: List[str]) -> List[List[s Args: groups_as_lines (List[str]): The template content string. + entries (Dict[str, str], optional): Pre-extracted entries. Returns: List[List[str]]: The non-complex reactant groups. """ reactant_labels = get_initial_reactant_labels_from_template(groups_as_lines) + if entries is None: + entries = get_entries(groups_as_lines, entry_labels=reactant_labels) result = list() for reactant_label in reactant_labels: - if 'OR{' not in get_group_adjlist(groups_as_lines, entry_label=reactant_label): + adj = get_group_adjlist(groups_as_lines, entry_label=reactant_label, entries=entries) + if 'OR{' not in adj: result.append([reactant_label]) else: stack = [reactant_label] - while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label) for label in stack): + while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label, entries=entries) for label in stack): label = stack.pop(0) - group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label) + group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label, entries=entries) if 'OR{' not in group_adjlist: stack.append(label) else: @@ -764,13 +799,15 @@ def get_initial_reactant_labels_from_template(groups_as_lines: List[str], Returns: List[str]: The reactant groups. 
""" - labels = list() - for line in groups_as_lines: - match = re.search(r'products=\[(.*?)\]', line) if products else re.search(r'reactants=\[(.*?)\]', line) - if match: - labels = match.group(1).replace('"', '').split(', ') - break - return labels + groups_str = ''.join(groups_as_lines) + pattern = PRODUCTS_PATTERN if products else REACTANTS_PATTERN + match = pattern.search(groups_str) + if match: + content = match.group(1) + # Use regex to find all quoted strings (with backreferences) or unquoted words + labels = re.findall(r'(["\'])(.*?)\1|(\w+)', content) + return [label[1] or label[2] for label in labels] + return list() def get_recipe_actions(groups_as_lines: List[str]) -> List[List[str]]: @@ -796,33 +833,63 @@ def get_recipe_actions(groups_as_lines: List[str]) -> List[List[str]]: def get_entries(groups_as_lines: List[str], - entry_labels: List[str], + entry_labels: Optional[List[str]] = None, + recursive: bool = False, ) -> Dict[str, str]: """ - Get the requested entries grom a template content string. + Get the requested entries from a template content string. Args: groups_as_lines (List[str]): The template content string. - entry_labels (List[str]): The entry labels to extract. + entry_labels (List[str], optional): The entry labels to extract. If None, all entries are extracted. + recursive (bool, optional): Whether to recursively extract child entries for OR complexes. Returns: Dict[str, str]: The extracted entries, keys are the labels, values are the groups. 
""" - groups_str = ''.join(groups_as_lines) - entries = re.findall(r'entry\((.*?)\)', groups_str, re.DOTALL) - specific_entries = dict() - for i, entry in enumerate(entries): - label_match = re.search(r'label = "(.*?)"', entry) - group_match = re.search(r'group =(.*?)(?=\w+ =)', entry, re.DOTALL) - if label_match is not None and group_match is not None and label_match.group(1) in entry_labels: - specific_entries[label_match.group(1)] = clean_text(group_match.group(1)) - if i > 2000: - break - return specific_entries + groups_str = "\n" + "".join(groups_as_lines) + # Split by entry( but keep the delimiter-ish part + parts = re.split(r"\nentry\s*\(", groups_str) + + temp_entries = {} + label_pat = re.compile(r"label\s*=\s*(?:([\"'])(.*?)\1|(\w+))") + group_pat = re.compile(r"group\s*=\s*(?:\"\"\"(.*?)\"\"\"|([\"'])(.*?)\2|(OR\{.*?\}))", re.DOTALL) + + for part in parts[1:]: # Skip the header + label_match = label_pat.search(part) + group_match = group_pat.search(part) + if label_match and group_match: + label = label_match.group(2) or label_match.group(3) + # Match group 1 (triple), 3 (single/double), or 4 (OR) + adj = group_match.group(1) or group_match.group(3) or group_match.group(4) + temp_entries[label] = clean_text(adj) + + if entry_labels is None: + return temp_entries + + all_entries = {} + to_process = list(entry_labels) + processed = set() + while to_process: + label = to_process.pop() + if label in processed or label not in temp_entries: + continue + processed.add(label) + adj = temp_entries[label] + all_entries[label] = adj + if recursive and 'OR{' in adj: + # Match OR{label1, label2, ...} + or_match = re.search(r'OR\s*\{\s*(.*?)\s*\}', adj, re.DOTALL) + if or_match: + children_str = or_match.group(1) + children = [c.strip() for c in children_str.split(',')] + to_process.extend(children) + return all_entries def get_group_adjlist(groups_as_lines: List[str], entry_label: str, + entries: Optional[Dict[str, str]] = None, ) -> str: """ Get the corresponding 
group value for the given entry label. @@ -830,10 +897,13 @@ def get_group_adjlist(groups_as_lines: List[str], Args: groups_as_lines (List[str]): The template content string. entry_label (str): The entry label to extract. + entries (Dict[str, str], optional): Pre-extracted entries. Returns: str: The extracted group. """ + if entries is not None and entry_label in entries: + return entries[entry_label] specific_entries = get_entries(groups_as_lines, entry_labels=[entry_label]) return specific_entries[entry_label]