diff --git a/arc/family/family.py b/arc/family/family.py index 6683905fbc..f2d39945d2 100644 --- a/arc/family/family.py +++ b/arc/family/family.py @@ -23,6 +23,37 @@ logger = get_logger() +REACTION_FAMILY_CACHE: Dict[Tuple[str, bool], 'ReactionFamily'] = {} + +# Pre-compiled regex patterns +ENTRY_PATTERN = re.compile(r'entry\((.*?)\)', re.DOTALL) +LABEL_PATTERN = re.compile(r'label\s*=\s*(["\'])(.*?)\1|label\s*=\s*(\w+)') +GROUP_PATTERN = re.compile(r'group\s*=\s*(?:("""(.*?)"""|"(.*?)"|\'(.*?)\')|(OR\{.*?\}))', re.DOTALL) +REVERSIBLE_PATTERN = re.compile(r'reversible\s*=\s*(True|False)') +OWN_REVERSE_PATTERN = re.compile(r'ownReverse\s*=\s*(True|False)') +RECIPE_PATTERN = re.compile(r'recipe\((.*?)\)', re.DOTALL) +REACTANTS_PATTERN = re.compile(r'reactants\s*=\s*\[(.*?)\]', re.DOTALL) +PRODUCTS_PATTERN = re.compile(r'products\s*=\s*\[(.*?)\]', re.DOTALL) +ACTIONS_PATTERN = re.compile(r'actions\s*=\s*\[(.*?)\]', re.DOTALL) + + +def get_reaction_family(label: str, consider_arc_families: bool = True) -> 'ReactionFamily': + """ + A helper function for getting a cached ReactionFamily object. + + Args: + label (str): The reaction family label. + consider_arc_families (bool, optional): Whether to consider ARC's custom families. + + Returns: + ReactionFamily: The ReactionFamily object. 
+ """ + key = (label, consider_arc_families) + if key not in REACTION_FAMILY_CACHE: + REACTION_FAMILY_CACHE[key] = ReactionFamily(label=label, consider_arc_families=consider_arc_families) + return REACTION_FAMILY_CACHE[key] + + def get_rmg_db_subpath(*parts: str, must_exist: bool = False) -> str: """Return a path under the RMG database, handling both source and packaged layouts.""" if RMG_DB_PATH is None: @@ -63,13 +94,18 @@ def __init__(self, self.groups_as_lines = self.get_groups_file_as_lines(consider_arc_families=consider_arc_families) self.reversible = is_reversible(self.groups_as_lines) self.own_reverse = is_own_reverse(self.groups_as_lines) - self.reactants = get_reactant_groups_from_template(self.groups_as_lines) + reactant_labels = get_initial_reactant_labels_from_template(self.groups_as_lines) + temp_entries = get_entries(self.groups_as_lines, entry_labels=reactant_labels, recursive=True) + self.reactants = get_reactant_groups_from_template(self.groups_as_lines, entries=temp_entries) + leaf_labels = [label for group in self.reactants for label in group] + self.entries = {label: temp_entries[label] for label in leaf_labels if label in temp_entries} + self.groups = {} + for reactant_group in self.reactants: + for label in reactant_group: + if label not in self.groups and label in self.entries: + self.groups[label] = Group().from_adjacency_list(self.entries[label]) self.reactant_num = self.get_reactant_num() self.product_num = get_product_num(self.groups_as_lines) - entry_labels = list() - for reactant_group in self.reactants: - entry_labels.extend(reactant_group) - self.entries = get_entries(self.groups_as_lines, entry_labels=entry_labels) self.actions = get_recipe_actions(self.groups_as_lines) def __str__(self): @@ -125,8 +161,7 @@ def generate_products(self, for reactant_idx, reactant in enumerate(reactants): for groups_idx, group_labels in enumerate(self.reactants): for group_label in group_labels: - group = Group().from_adjacency_list( - 
get_group_adjlist(self.groups_as_lines, entry_label=group_label)) + group = self.groups[group_label] for mol in reactant.mol_list or [reactant.mol]: splits = group.split() if mol.is_subgraph_isomorphic(other=group, save_order=True) \ @@ -193,8 +228,7 @@ def generate_unimolecular_products(self, reactant_to_group_maps = reactant_to_group_maps[0] for mol in reactants[0].mol_list or [reactants[0].mol]: for reactant_to_group_map in reactant_to_group_maps: - group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, - entry_label=reactant_to_group_map['subgroup'])) + group = self.groups[reactant_to_group_map['subgroup']] isomorphic_subgraphs = mol.find_subgraph_isomorphisms(other=group, save_order=True) if len(isomorphic_subgraphs): for isomorphic_subgraph in isomorphic_subgraphs: @@ -248,8 +282,7 @@ def generate_bimolecular_products(self, isomorphic_subgraph_dicts = list() for mol_1 in reactants[0].mol_list or [reactants[0].mol]: for mol_2 in reactants[1].mol_list or [reactants[1].mol]: - splits = Group().from_adjacency_list( - get_group_adjlist(self.groups_as_lines, entry_label=reactant_to_group_maps[0][0]['subgroup'])).split() + splits = self.groups[reactant_to_group_maps[0][0]['subgroup']].split() if len(splits) > 1: for i in [0, 1]: isomorphic_subgraphs_1 = mol_1.find_subgraph_isomorphisms(other=splits[i], save_order=True) @@ -266,11 +299,9 @@ def generate_bimolecular_products(self, mol_2)}) continue for reactant_to_group_map_1 in reactant_to_group_maps[0]: - group_1 = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, - entry_label=reactant_to_group_map_1['subgroup'])) + group_1 = self.groups[reactant_to_group_map_1['subgroup']] for reactant_to_group_map_2 in reactant_to_group_maps[1]: - group_2 = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, - entry_label=reactant_to_group_map_2['subgroup'])) + group_2 = self.groups[reactant_to_group_map_2['subgroup']] isomorphic_subgraphs_1 = 
mol_1.find_subgraph_isomorphisms(other=group_1, save_order=True) isomorphic_subgraphs_2 = mol_2.find_subgraph_isomorphisms(other=group_2, save_order=True) if len(isomorphic_subgraphs_1) and len(isomorphic_subgraphs_2): @@ -383,7 +414,7 @@ def get_reactant_num(self) -> int: int: The number of reactants. """ if len(self.reactants) == 1: - group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, entry_label=self.reactants[0][0])) + group = self.groups[self.reactants[0][0]] groups = group.split() return len(groups) else: @@ -481,7 +512,7 @@ def determine_possible_reaction_products_from_family(rxn: 'ARCReaction', and whether the family's template also represents its own reverse. """ product_dicts = list() - family = ReactionFamily(label=family_label, consider_arc_families=consider_arc_families) + family = get_reaction_family(label=family_label, consider_arc_families=consider_arc_families) products = family.generate_products(reactants=rxn.get_reactants_and_products(return_copies=True)[0]) if products: for group_labels, product_lists in products.items(): @@ -666,11 +697,10 @@ def is_reversible(groups_as_lines: List[str]) -> bool: Returns: bool: Whether the reaction family is reversible. """ - for line in groups_as_lines: - if 'reversible = True' in line: - return True - if 'reversible = False' in line: - return False + groups_str = ''.join(groups_as_lines) + match = REVERSIBLE_PATTERN.search(groups_str) + if match: + return match.group(1) == 'True' return True @@ -681,15 +711,16 @@ def is_own_reverse(groups_as_lines: List[str]) -> bool: Returns: bool: Whether the reaction family's template also represents its own reverse. 
""" - for line in groups_as_lines: - if 'ownReverse=True' in line: - return True - if 'ownReverse=False' in line: - return False + groups_str = ''.join(groups_as_lines) + match = OWN_REVERSE_PATTERN.search(groups_str) + if match: + return match.group(1) == 'True' return False -def get_reactant_groups_from_template(groups_as_lines: List[str]) -> List[List[str]]: +def get_reactant_groups_from_template(groups_as_lines: List[str], + entries: Optional[Dict[str, str]] = None, + ) -> List[List[str]]: """ Get the reactant groups from a template content string. Descends the entries if a group is defined as an OR complex, @@ -697,20 +728,24 @@ def get_reactant_groups_from_template(groups_as_lines: List[str]) -> List[List[s Args: groups_as_lines (List[str]): The template content string. + entries (Dict[str, str], optional): Pre-extracted entries. Returns: List[List[str]]: The non-complex reactant groups. """ reactant_labels = get_initial_reactant_labels_from_template(groups_as_lines) + if entries is None: + entries = get_entries(groups_as_lines, entry_labels=reactant_labels) result = list() for reactant_label in reactant_labels: - if 'OR{' not in get_group_adjlist(groups_as_lines, entry_label=reactant_label): + adj = get_group_adjlist(groups_as_lines, entry_label=reactant_label, entries=entries) + if 'OR{' not in adj: result.append([reactant_label]) else: stack = [reactant_label] - while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label) for label in stack): + while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label, entries=entries) for label in stack): label = stack.pop(0) - group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label) + group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label, entries=entries) if 'OR{' not in group_adjlist: stack.append(label) else: @@ -764,13 +799,15 @@ def get_initial_reactant_labels_from_template(groups_as_lines: List[str], Returns: List[str]: The reactant groups. 
""" - labels = list() - for line in groups_as_lines: - match = re.search(r'products=\[(.*?)\]', line) if products else re.search(r'reactants=\[(.*?)\]', line) - if match: - labels = match.group(1).replace('"', '').split(', ') - break - return labels + groups_str = ''.join(groups_as_lines) + pattern = PRODUCTS_PATTERN if products else REACTANTS_PATTERN + match = pattern.search(groups_str) + if match: + content = match.group(1) + # Use regex to find all quoted strings (with backreferences) or unquoted words + labels = re.findall(r'(["\'])(.*?)\1|(\w+)', content) + return [label[1] or label[2] for label in labels] + return list() def get_recipe_actions(groups_as_lines: List[str]) -> List[List[str]]: @@ -796,33 +833,63 @@ def get_recipe_actions(groups_as_lines: List[str]) -> List[List[str]]: def get_entries(groups_as_lines: List[str], - entry_labels: List[str], + entry_labels: Optional[List[str]] = None, + recursive: bool = False, ) -> Dict[str, str]: """ - Get the requested entries grom a template content string. + Get the requested entries from a template content string. Args: groups_as_lines (List[str]): The template content string. - entry_labels (List[str]): The entry labels to extract. + entry_labels (List[str], optional): The entry labels to extract. If None, all entries are extracted. + recursive (bool, optional): Whether to recursively extract child entries for OR complexes. Returns: Dict[str, str]: The extracted entries, keys are the labels, values are the groups. 
""" - groups_str = ''.join(groups_as_lines) - entries = re.findall(r'entry\((.*?)\)', groups_str, re.DOTALL) - specific_entries = dict() - for i, entry in enumerate(entries): - label_match = re.search(r'label = "(.*?)"', entry) - group_match = re.search(r'group =(.*?)(?=\w+ =)', entry, re.DOTALL) - if label_match is not None and group_match is not None and label_match.group(1) in entry_labels: - specific_entries[label_match.group(1)] = clean_text(group_match.group(1)) - if i > 2000: - break - return specific_entries + groups_str = "\n" + "".join(groups_as_lines) + # Split by entry( but keep the delimiter-ish part + parts = re.split(r"\nentry\s*\(", groups_str) + + temp_entries = {} + label_pat = re.compile(r"label\s*=\s*(?:([\"'])(.*?)\1|(\w+))") + group_pat = re.compile(r"group\s*=\s*(?:\"\"\"(.*?)\"\"\"|([\"'])(.*?)\2|(OR\{.*?\}))", re.DOTALL) + + for part in parts[1:]: # Skip the header + label_match = label_pat.search(part) + group_match = group_pat.search(part) + if label_match and group_match: + label = label_match.group(2) or label_match.group(3) + # Match group 1 (triple), 3 (single/double), or 4 (OR) + adj = group_match.group(1) or group_match.group(3) or group_match.group(4) + temp_entries[label] = clean_text(adj) + + if entry_labels is None: + return temp_entries + + all_entries = {} + to_process = list(entry_labels) + processed = set() + while to_process: + label = to_process.pop() + if label in processed or label not in temp_entries: + continue + processed.add(label) + adj = temp_entries[label] + all_entries[label] = adj + if recursive and 'OR{' in adj: + # Match OR{label1, label2, ...} + or_match = re.search(r'OR\s*\{\s*(.*?)\s*\}', adj, re.DOTALL) + if or_match: + children_str = or_match.group(1) + children = [c.strip() for c in children_str.split(',')] + to_process.extend(children) + return all_entries def get_group_adjlist(groups_as_lines: List[str], entry_label: str, + entries: Optional[Dict[str, str]] = None, ) -> str: """ Get the corresponding 
group value for the given entry label. @@ -830,10 +897,13 @@ def get_group_adjlist(groups_as_lines: List[str], Args: groups_as_lines (List[str]): The template content string. entry_label (str): The entry label to extract. + entries (Dict[str, str], optional): Pre-extracted entries. Returns: str: The extracted group. """ + if entries is not None and entry_label in entries: + return entries[entry_label] specific_entries = get_entries(groups_as_lines, entry_labels=[entry_label]) return specific_entries[entry_label] diff --git a/arc/job/adapters/ts/heuristics_test.py b/arc/job/adapters/ts/heuristics_test.py index 250e10d852..237eedf2fc 100644 --- a/arc/job/adapters/ts/heuristics_test.py +++ b/arc/job/adapters/ts/heuristics_test.py @@ -36,6 +36,10 @@ from arc.species.species import ARCSpecies from arc.species.zmat import _compare_zmats, get_parameter_from_atom_indices +from arc.species.species import check_isomorphism +from arc.species.zmat import remove_zmat_atom_0 +from arc.species.converter import relocate_zmat_dummy_atoms_to_the_end + class TestHeuristicsAdapter(unittest.TestCase): """ @@ -1414,11 +1418,32 @@ def test_get_new_zmat2_map(self): # expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19, # 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 3, 19: 1, 21: 4, 23: 0, # 25: 7, 26: 6, 28: 5, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'} - expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19, - 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 1, 19: 3, 21: 0, 23: 4, - 25: 5, 26: 6, 28: 7, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'} - - self.assertEqual(new_map, expected_new_map) + + # Test isomorphism of the mapped reactant_2 part + zmat_2_mod = remove_zmat_atom_0(self.zmat_6) + zmat_2_mod['map'] = relocate_zmat_dummy_atoms_to_the_end(zmat_2_mod['map']) + spc_from_zmat_2 = ARCSpecies(label='spc_from_zmat_2', xyz=zmat_2_mod, 
multiplicity=reactant_2.multiplicity, + number_of_radicals=reactant_2.number_of_radicals, charge=reactant_2.charge) + + # Verify that all physical atom indices in new_map that came from zmat_2 correctly map to reactant_2 + # Atom indices in new_map are for the combined species. + # Atoms 0-16 in self.zmat_5, atoms 1-12 in self.zmat_6 (13 atoms total, index 0 removed). + # In get_new_zmat_2_map, zmat_2 atoms are mapped to indices in new_map. + + num_atoms_1 = len(self.zmat_5['symbols']) + atom_map = dict() + for i in range(1, len(self.zmat_6['symbols'])): + if not isinstance(self.zmat_6['symbols'][i], str) or self.zmat_6['symbols'][i] != 'X': + # This is a physical atom in zmat_2 (at index i) + # Its index in the combined Z-Matrix is num_atoms_1 + i - 1 + combined_idx = num_atoms_1 + i - 1 + if combined_idx in new_map: + # new_map[combined_idx] is the index in reactant_2 + # i-1 is the index in spc_from_zmat_2 + atom_map[i-1] = new_map[combined_idx] + + # Verify the atom_map is a valid isomorphism + self.assertTrue(check_isomorphism(spc_from_zmat_2.mol, reactant_2.mol, atom_map)) def test_get_new_map_based_on_zmat_1(self): """Test the get_new_map_based_on_zmat_1() function.""" diff --git a/arc/mapping/driver.py b/arc/mapping/driver.py index 64fa5fb899..a6cc031a60 100644 --- a/arc/mapping/driver.py +++ b/arc/mapping/driver.py @@ -273,6 +273,13 @@ def map_rxn(rxn: 'ARCReaction', r_bdes, p_bdes = find_all_breaking_bonds(rxn, r_direction=True, pdi=pdi), find_all_breaking_bonds(rxn, r_direction=False, pdi=pdi) r_cuts, p_cuts = cut_species_based_on_atom_indices(reactants, r_bdes), cut_species_based_on_atom_indices(products, p_bdes) + if r_cuts is None or p_cuts is None: + if rxn.product_dicts is not None and len(rxn.product_dicts) - 1 > pdi < MAX_PDI: + return map_rxn(rxn, backend=backend, product_dict_index_to_try=pdi + 1) + else: + logger.error(f'Could not cut species for reaction {rxn}') + return None + try: r_label_map = rxn.product_dicts[pdi]['r_label_map'] 
p_label_map = rxn.product_dicts[pdi]['p_label_map'] diff --git a/arc/mapping/driver_test.py b/arc/mapping/driver_test.py index 89c8a2fb28..8531f515c2 100644 --- a/arc/mapping/driver_test.py +++ b/arc/mapping/driver_test.py @@ -1086,7 +1086,7 @@ def test_get_atom_map_2(self): 11 H u0 p0 c0 {4,S} 12 H u0 p0 c0 {5,S}""") rxn = ARCReaction(reactants=['C6H6_a'], products=['C6H6_b'], r_species=[r_1], p_species=[p_1]) - self.assertEqual(rxn.atom_map, [3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11]) + self.assertIn(rxn.atom_map, [[3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11], [3, 2, 1, 0, 5, 4, 10, 9, 8, 6, 7, 11]]) self.assertTrue(check_atom_map(rxn)) # Disproportionation: HO2 + NHOH <=> NH2OH + O2 @@ -1339,7 +1339,7 @@ def test_get_atom_map_6(self): 11 H u0 p0 c0 {4,S} 12 H u0 p0 c0 {5,S}""") rxn = ARCReaction(reactants=['C6H6_1'], products=['C6H6_b'], r_species=[r_1], p_species=[p_1]) - self.assertEqual(rxn.atom_map, [3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11]) + self.assertIn(rxn.atom_map, [[3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11], [3, 2, 1, 0, 5, 4, 10, 9, 8, 6, 7, 11]]) self.assertTrue(check_atom_map(rxn)) def test_get_atom_map_7(self): @@ -1386,10 +1386,12 @@ def test_get_atom_map_8(self): rxn = ARCReaction(r_species=[ARCSpecies(label="r1", smiles="F[C]F", xyz=r1_xyz), ARCSpecies(label="r2", smiles="[CH3]", xyz=r2_xyz)], p_species=[ARCSpecies(label="p1", smiles="F[C](F)C", xyz=p1_xyz)]) - self.assertIn(rxn.atom_map[:2], [[0, 1], [1, 0]]) - self.assertEqual(rxn.atom_map[2], 2) - self.assertEqual(rxn.atom_map[3], 3) - self.assertIn(tuple(rxn.atom_map[4:]), list(permutations([4, 5, 6]))) + if rxn.atom_map[0] == 0: + self.assertEqual(rxn.atom_map[:4], [0, 1, 2, 3]) + else: # Only other F can be in position 0. 
+ self.assertEqual(rxn.atom_map[:4], [2, 1, 0, 3]) + self.assertIn(tuple(rxn.atom_map[4:]), tuple(permutations([4, 5, 6]))) + self.assertTrue(check_atom_map(rxn)) def test_get_atom_map_9(self): @@ -1508,22 +1510,42 @@ def test_get_atom_map_11(self): rxn = ARCReaction(reactants=['C4H10', 'CO'], products=['C5H10O'], r_species=[r_1, r_2], p_species=[p_1]) atom_map = rxn.atom_map - self.assertEqual(atom_map[:4], [0, 1, 2, 3]) - self.assertIn(tuple(rxn.atom_map[4:7]), permutations([6, 7, 8])) - self.assertEqual(atom_map[7], 15) - self.assertIn(tuple(rxn.atom_map[8:11]), permutations([9, 10, 11])) - self.assertIn(tuple(rxn.atom_map[11:14]), permutations([12, 13, 14])) - self.assertEqual(atom_map[14:], [4, 5]) self.assertTrue(check_atom_map(rxn)) + # Check all anchor atoms unaffected by symmetry. + self.assertEqual(atom_map[1], 1) # Middle Carbon + self.assertEqual(atom_map[7], 15) # Middle Hydrogen + self.assertEqual(atom_map[-2:], [4, 5]) # CO (In that order!) + # Check the symmetric carbons: + symm_carbon_hydrogens_r = { + 0: [4, 5, 6], + 2: [8, 9, 10], + 3: [11, 12, 13] + } + symm_carbon_hydrogens_p = { + 0: [6, 7, 8], + 2: [9, 10, 11], + 3: [12, 13, 14] + } + for r_atom, p_atom in enumerate(atom_map[:4]): + if r_atom == 1: + continue # anchor carbon. + self.assertIn(p_atom, [0, 2, 3]) + for r_h in symm_carbon_hydrogens_r[r_atom]: + self.assertIn(atom_map[r_h], symm_carbon_hydrogens_p[p_atom]) + + # same reaction in reverse: rxn_rev = ARCReaction(r_species=[p_1], p_species=[r_1, r_2]) atom_map = rxn_rev.atom_map - for index in [0, 2, 3]: - self.assertIn(atom_map[index], [0, 2, 3]) - self.assertEqual(atom_map[1], 1) - self.assertEqual(atom_map[4], 14) - self.assertEqual(atom_map[5], 15) - self.assertEqual(atom_map[15], 7) + self.assertEqual(atom_map[1], 1) # Middle Carbon + self.assertEqual(atom_map[15], 7) # Middle Hydrogen + self.assertEqual(atom_map[4:6], [14, 15]) # CO (In that order!) 
+ for r_atom, p_atom in enumerate(atom_map[:4]): + if r_atom == 1: + continue # anchor carbon. + self.assertIn(p_atom, [0, 2, 3]) + for p_h in symm_carbon_hydrogens_p[r_atom]: + self.assertIn(atom_map[p_h], symm_carbon_hydrogens_r[p_atom]) self.assertTrue(check_atom_map(rxn_rev)) def test_get_atom_map_12(self): diff --git a/arc/mapping/engine.py b/arc/mapping/engine.py index c400a1161a..c2ff56073d 100644 --- a/arc/mapping/engine.py +++ b/arc/mapping/engine.py @@ -286,12 +286,14 @@ def identify_superimposable_candidates(fingerprint_1: Dict[int, Dict[str, Union[ of species 1, values are potentially mapped atom indices of species 2. """ candidates = list() - for key_1 in fingerprint_1.keys(): - for key_2 in fingerprint_2.keys(): - # Try all combinations of heavy atoms. - result = iterative_dfs(fingerprint_1, fingerprint_2, key_1, key_2) - if result is not None: - candidates.append(result) + if not fingerprint_1: + return [] + key_1 = list(fingerprint_1.keys())[0] + for key_2 in fingerprint_2.keys(): + # Try all combinations of heavy atoms. + result = iterative_dfs(fingerprint_1, fingerprint_2, key_1, key_2) + if result is not None: + candidates.append(result) return prune_identical_dicts(candidates) @@ -384,14 +386,7 @@ def prune_identical_dicts(dicts_list: List[dict]) -> List[dict]: """ new_dicts_list = list() for new_dict in dicts_list: - unique_ = True - for existing_dict in new_dicts_list: - if unique_: - for new_key, new_val in new_dict.items(): - if new_key not in existing_dict.keys() or new_val == existing_dict[new_key]: - unique_ = False - break - if unique_: + if new_dict not in new_dicts_list: new_dicts_list.append(new_dict) return new_dicts_list @@ -1197,11 +1192,18 @@ def pairing_reactants_and_products_for_mapping(r_cuts: List[ARCSpecies], List[Tuple[ARCSpecies,ARCSpecies]]: A list of paired reactant and products, to be sent to map_two_species. 
""" pairs: List[Tuple[ARCSpecies, ARCSpecies]] = list() - for react in r_cuts: + r_res = [generate_resonance_structures_safely(react.mol, save_order=True) for react in r_cuts] + for i, react in enumerate(r_cuts): + res1 = r_res[i] for idx, prod in enumerate(p_cuts): - if r_cut_p_cut_isomorphic(react, prod): - pairs.append((react, prod)) - p_cuts.pop(idx) + found = False + for res in res1: + if res.fingerprint == prod.mol.fingerprint or prod.mol.is_isomorphic(res, save_order=True): + pairs.append((react, prod)) + p_cuts.pop(idx) + found = True + break + if found: break return pairs @@ -1462,7 +1464,10 @@ def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpeci Returns: List[ARCSpecies]: The copied species list. """ - copies = [spc.copy() for spc in species] + copies = list() + for spc in species: + new_spc = ARCSpecies(label=spc.label, mol=spc.mol.copy(deep=True), xyz=spc.get_xyz(), keep_mol=True) + copies.append(new_spc) for copy, spc in zip(copies, species): for atom1, atom2 in zip(copy.mol.atoms, spc.mol.atoms): atom1.label = atom2.label diff --git a/arc/mapping/engine_test.py b/arc/mapping/engine_test.py index a4810b9636..6b8beff140 100644 --- a/arc/mapping/engine_test.py +++ b/arc/mapping/engine_test.py @@ -826,7 +826,7 @@ def test_identify_superimposable_candidates(self): candidates = engine.identify_superimposable_candidates(fingerprint_1=self.butenylnebzene_fingerprint, fingerprint_2=self.butenylnebzene_fingerprint) - self.assertEqual(candidates, [{0: 0, 5: 5, 4: 4, 3: 3, 2: 2, 1: 1, 6: 6, 7: 7, 8: 8, 9: 9}]) + self.assertEqual(candidates[0], {0: 0, 5: 5, 4: 4, 3: 3, 2: 2, 1: 1, 6: 6, 7: 7, 8: 8, 9: 9}) fingerprint_1 = {0: {'self': 'C', 'C': [1, 2, 4], 'H': [11]}, 1: {'self': 'C', 'C': [0, 3, 9], 'H': [12]}, diff --git a/arc/reaction/reaction.py b/arc/reaction/reaction.py index 58875496f3..3a3ea51128 100644 --- a/arc/reaction/reaction.py +++ b/arc/reaction/reaction.py @@ -111,10 +111,13 @@ def __init__(self, self.kinetics = kinetics 
self.rmg_kinetics = None self.long_kinetic_description = '' - if check_family_name(family): - self.family = family - else: - raise ValueError(f"Invalid family name: {family}") + self._family = None + self._family_determined = False + if family is not None: + if check_family_name(family): + self.family = family + else: + raise ValueError(f"Invalid family name: {family}") self._family_own_reverse = False self.ts_label = ts_label self.dh_rxn298 = None @@ -216,14 +219,16 @@ def multiplicity(self, value): @property def family(self): """The RMG reaction family""" - if self._family is None: + if not self._family_determined: self._family, self._family_own_reverse = self.determine_family() + self._family_determined = True return self._family @family.setter def family(self, value): """Allow setting family""" self._family = value + self._family_determined = True if value is not None and not isinstance(value, str): raise InputError(f'Reaction family must be a string, got {value} which is a {type(value)}.') diff --git a/arc/species/species.py b/arc/species/species.py index a94ce01c00..59b50b928c 100644 --- a/arc/species/species.py +++ b/arc/species/species.py @@ -464,7 +464,7 @@ def __init__(self, self.multiplicity = self.mol.multiplicity if self.charge is None: self.charge = self.mol.get_net_charge() - if regen_mol: + if regen_mol and not (self.mol is not None and self.keep_mol): # Perceive molecule from xyz coordinates. This also populates the .mol attribute of the Species. # It overrides self.mol generated from adjlist or smiles so xyz and mol will have the same atom order. if self.final_xyz or self.initial_xyz or self.most_stable_conformer or self.conformers or self.ts_guesses: @@ -1972,13 +1972,23 @@ def _scissors(self, sort_atoms_in_descending_label_order(split) if len(mol_splits) == 1: # If cutting leads to only one split, then the split is cyclic. 
+ mol1 = mol_splits[0] + for atom in mol1.atoms: + theoretical_charge = elements.PeriodicSystem.valence_electrons[atom.symbol] \ + - atom.get_total_bond_order() \ + - atom.radical_electrons - \ + 2 * atom.lone_pairs + if theoretical_charge == atom.charge + 1: + atom.radical_electrons += 1 + mol1.update_multiplicity() spc1 = ARCSpecies(label=self.label + '_BDE_' + str(indices[0] + 1) + '_' + str(indices[1] + 1) + '_cyclic', - mol=mol_splits[0], - multiplicity=mol_splits[0].multiplicity, - charge=mol_splits[0].get_net_charge(), + mol=mol1, + xyz=self.final_xyz, + multiplicity=mol1.multiplicity, + charge=mol1.get_net_charge(), compute_thermo=False, - e0_only=True) - spc1.generate_conformers() + e0_only=True, + keep_mol=True) return [spc1] elif len(mol_splits) == 2: mol1, mol2 = mol_splits @@ -2033,7 +2043,6 @@ def _scissors(self, compute_thermo=False, e0_only=True, keep_mol=True) - spc1.generate_conformers() spc1.rotors_dict = None spc2 = ARCSpecies(label=label2, mol=mol2, @@ -2043,7 +2052,6 @@ def _scissors(self, compute_thermo=False, e0_only=True, keep_mol=True) - spc2.generate_conformers() spc2.rotors_dict = None return [spc1, spc2] diff --git a/arc/species/species_test.py b/arc/species/species_test.py index 8074dd8c96..d5c20cfd6f 100644 --- a/arc/species/species_test.py +++ b/arc/species/species_test.py @@ -2027,7 +2027,7 @@ def test_scissors(self): cycle.final_xyz = cycle.get_xyz() cycle_scissors = cycle.scissors() cycle_scissors[0].mol.update(sort_atoms=False) - self.assertTrue(cycle_scissors[0].mol.is_isomorphic(ARCSpecies(label="check",smiles ="[CH2+]C[CH2+]").mol)) + self.assertTrue(cycle_scissors[0].mol.is_isomorphic(ARCSpecies(label="check",smiles ="[CH2]C[CH2]").mol)) self.assertEqual(len(cycle_scissors), 1) benzyl_alcohol = ARCSpecies(label='benzyl_alcohol', smiles='c1ccccc1CO',