Skip to content
Open
174 changes: 122 additions & 52 deletions arc/family/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,37 @@
logger = get_logger()


REACTION_FAMILY_CACHE: Dict[Tuple[str, bool], 'ReactionFamily'] = {}

# Pre-compiled regex patterns
ENTRY_PATTERN = re.compile(r'entry\((.*?)\)', re.DOTALL)
LABEL_PATTERN = re.compile(r'label\s*=\s*(["\'])(.*?)\1|label\s*=\s*(\w+)')
GROUP_PATTERN = re.compile(r'group\s*=\s*(?:("""(.*?)"""|"(.*?)"|\'(.*?)\')|(OR\{.*?\}))', re.DOTALL)
REVERSIBLE_PATTERN = re.compile(r'reversible\s*=\s*(True|False)')
OWN_REVERSE_PATTERN = re.compile(r'ownReverse\s*=\s*(True|False)')
RECIPE_PATTERN = re.compile(r'recipe\((.*?)\)', re.DOTALL)
REACTANTS_PATTERN = re.compile(r'reactants\s*=\s*\[(.*?)\]', re.DOTALL)
PRODUCTS_PATTERN = re.compile(r'products\s*=\s*\[(.*?)\]', re.DOTALL)
ACTIONS_PATTERN = re.compile(r'actions\s*=\s*\[(.*?)\]', re.DOTALL)


def get_reaction_family(label: str, consider_arc_families: bool = True) -> 'ReactionFamily':
"""
A helper function for getting a cached ReactionFamily object.

Args:
label (str): The reaction family label.
consider_arc_families (bool, optional): Whether to consider ARC's custom families.

Returns:
ReactionFamily: The ReactionFamily object.
"""
key = (label, consider_arc_families)
if key not in REACTION_FAMILY_CACHE:
REACTION_FAMILY_CACHE[key] = ReactionFamily(label=label, consider_arc_families=consider_arc_families)
return REACTION_FAMILY_CACHE[key]


def get_rmg_db_subpath(*parts: str, must_exist: bool = False) -> str:
"""Return a path under the RMG database, handling both source and packaged layouts."""
if RMG_DB_PATH is None:
Expand Down Expand Up @@ -63,13 +94,18 @@ def __init__(self,
self.groups_as_lines = self.get_groups_file_as_lines(consider_arc_families=consider_arc_families)
self.reversible = is_reversible(self.groups_as_lines)
self.own_reverse = is_own_reverse(self.groups_as_lines)
self.reactants = get_reactant_groups_from_template(self.groups_as_lines)
reactant_labels = get_initial_reactant_labels_from_template(self.groups_as_lines)
temp_entries = get_entries(self.groups_as_lines, entry_labels=reactant_labels, recursive=True)
self.reactants = get_reactant_groups_from_template(self.groups_as_lines, entries=temp_entries)
leaf_labels = [label for group in self.reactants for label in group]
self.entries = {label: temp_entries[label] for label in leaf_labels if label in temp_entries}
self.groups = {}
for reactant_group in self.reactants:
for label in reactant_group:
if label not in self.groups and label in self.entries:
self.groups[label] = Group().from_adjacency_list(self.entries[label])
self.reactant_num = self.get_reactant_num()
self.product_num = get_product_num(self.groups_as_lines)
entry_labels = list()
for reactant_group in self.reactants:
entry_labels.extend(reactant_group)
self.entries = get_entries(self.groups_as_lines, entry_labels=entry_labels)
self.actions = get_recipe_actions(self.groups_as_lines)

def __str__(self):
Expand Down Expand Up @@ -125,8 +161,7 @@ def generate_products(self,
for reactant_idx, reactant in enumerate(reactants):
for groups_idx, group_labels in enumerate(self.reactants):
for group_label in group_labels:
group = Group().from_adjacency_list(
get_group_adjlist(self.groups_as_lines, entry_label=group_label))
group = self.groups[group_label]
for mol in reactant.mol_list or [reactant.mol]:
splits = group.split()
if mol.is_subgraph_isomorphic(other=group, save_order=True) \
Expand Down Expand Up @@ -193,8 +228,7 @@ def generate_unimolecular_products(self,
reactant_to_group_maps = reactant_to_group_maps[0]
for mol in reactants[0].mol_list or [reactants[0].mol]:
for reactant_to_group_map in reactant_to_group_maps:
group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines,
entry_label=reactant_to_group_map['subgroup']))
group = self.groups[reactant_to_group_map['subgroup']]
isomorphic_subgraphs = mol.find_subgraph_isomorphisms(other=group, save_order=True)
if len(isomorphic_subgraphs):
for isomorphic_subgraph in isomorphic_subgraphs:
Expand Down Expand Up @@ -248,8 +282,7 @@ def generate_bimolecular_products(self,
isomorphic_subgraph_dicts = list()
for mol_1 in reactants[0].mol_list or [reactants[0].mol]:
for mol_2 in reactants[1].mol_list or [reactants[1].mol]:
splits = Group().from_adjacency_list(
get_group_adjlist(self.groups_as_lines, entry_label=reactant_to_group_maps[0][0]['subgroup'])).split()
splits = self.groups[reactant_to_group_maps[0][0]['subgroup']].split()
if len(splits) > 1:
for i in [0, 1]:
isomorphic_subgraphs_1 = mol_1.find_subgraph_isomorphisms(other=splits[i], save_order=True)
Expand All @@ -266,11 +299,9 @@ def generate_bimolecular_products(self,
mol_2)})
continue
for reactant_to_group_map_1 in reactant_to_group_maps[0]:
group_1 = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines,
entry_label=reactant_to_group_map_1['subgroup']))
group_1 = self.groups[reactant_to_group_map_1['subgroup']]
for reactant_to_group_map_2 in reactant_to_group_maps[1]:
group_2 = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines,
entry_label=reactant_to_group_map_2['subgroup']))
group_2 = self.groups[reactant_to_group_map_2['subgroup']]
isomorphic_subgraphs_1 = mol_1.find_subgraph_isomorphisms(other=group_1, save_order=True)
isomorphic_subgraphs_2 = mol_2.find_subgraph_isomorphisms(other=group_2, save_order=True)
if len(isomorphic_subgraphs_1) and len(isomorphic_subgraphs_2):
Expand Down Expand Up @@ -383,7 +414,7 @@ def get_reactant_num(self) -> int:
int: The number of reactants.
"""
if len(self.reactants) == 1:
group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, entry_label=self.reactants[0][0]))
group = self.groups[self.reactants[0][0]]
groups = group.split()
return len(groups)
else:
Expand Down Expand Up @@ -481,7 +512,7 @@ def determine_possible_reaction_products_from_family(rxn: 'ARCReaction',
and whether the family's template also represents its own reverse.
"""
product_dicts = list()
family = ReactionFamily(label=family_label, consider_arc_families=consider_arc_families)
family = get_reaction_family(label=family_label, consider_arc_families=consider_arc_families)
products = family.generate_products(reactants=rxn.get_reactants_and_products(return_copies=True)[0])
if products:
for group_labels, product_lists in products.items():
Expand Down Expand Up @@ -666,11 +697,10 @@ def is_reversible(groups_as_lines: List[str]) -> bool:
Returns:
bool: Whether the reaction family is reversible.
"""
for line in groups_as_lines:
if 'reversible = True' in line:
return True
if 'reversible = False' in line:
return False
groups_str = ''.join(groups_as_lines)
match = REVERSIBLE_PATTERN.search(groups_str)
if match:
return match.group(1) == 'True'
return True


Expand All @@ -681,36 +711,41 @@ def is_own_reverse(groups_as_lines: List[str]) -> bool:
Returns:
bool: Whether the reaction family's template also represents its own reverse.
"""
for line in groups_as_lines:
if 'ownReverse=True' in line:
return True
if 'ownReverse=False' in line:
return False
groups_str = ''.join(groups_as_lines)
match = OWN_REVERSE_PATTERN.search(groups_str)
if match:
return match.group(1) == 'True'
return False


def get_reactant_groups_from_template(groups_as_lines: List[str]) -> List[List[str]]:
def get_reactant_groups_from_template(groups_as_lines: List[str],
entries: Optional[Dict[str, str]] = None,
) -> List[List[str]]:
"""
Get the reactant groups from a template content string.
Descends the entries if a group is defined as an OR complex,
e.g.: group = "OR{Xtrirad_H, Xbirad_H, Xrad_H, X_H}"

Args:
groups_as_lines (List[str]): The template content string.
entries (Dict[str, str], optional): Pre-extracted entries.

Returns:
List[List[str]]: The non-complex reactant groups.
"""
reactant_labels = get_initial_reactant_labels_from_template(groups_as_lines)
if entries is None:
entries = get_entries(groups_as_lines, entry_labels=reactant_labels)
result = list()
for reactant_label in reactant_labels:
if 'OR{' not in get_group_adjlist(groups_as_lines, entry_label=reactant_label):
adj = get_group_adjlist(groups_as_lines, entry_label=reactant_label, entries=entries)
if 'OR{' not in adj:
result.append([reactant_label])
else:
stack = [reactant_label]
while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label) for label in stack):
while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label, entries=entries) for label in stack):
label = stack.pop(0)
group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label)
group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label, entries=entries)
if 'OR{' not in group_adjlist:
stack.append(label)
else:
Expand Down Expand Up @@ -764,13 +799,15 @@ def get_initial_reactant_labels_from_template(groups_as_lines: List[str],
Returns:
List[str]: The reactant groups.
"""
labels = list()
for line in groups_as_lines:
match = re.search(r'products=\[(.*?)\]', line) if products else re.search(r'reactants=\[(.*?)\]', line)
if match:
labels = match.group(1).replace('"', '').split(', ')
break
return labels
groups_str = ''.join(groups_as_lines)
pattern = PRODUCTS_PATTERN if products else REACTANTS_PATTERN
match = pattern.search(groups_str)
if match:
content = match.group(1)
# Use regex to find all quoted strings (with backreferences) or unquoted words
labels = re.findall(r'(["\'])(.*?)\1|(\w+)', content)
return [label[1] or label[2] for label in labels]
return list()


def get_recipe_actions(groups_as_lines: List[str]) -> List[List[str]]:
Expand All @@ -796,44 +833,77 @@ def get_recipe_actions(groups_as_lines: List[str]) -> List[List[str]]:


def get_entries(groups_as_lines: List[str],
entry_labels: List[str],
entry_labels: Optional[List[str]] = None,
recursive: bool = False,
) -> Dict[str, str]:
"""
Get the requested entries grom a template content string.
Get the requested entries from a template content string.

Args:
groups_as_lines (List[str]): The template content string.
entry_labels (List[str]): The entry labels to extract.
entry_labels (List[str], optional): The entry labels to extract. If None, all entries are extracted.
recursive (bool, optional): Whether to recursively extract child entries for OR complexes.

Returns:
Dict[str, str]: The extracted entries, keys are the labels, values are the groups.
"""
groups_str = ''.join(groups_as_lines)
entries = re.findall(r'entry\((.*?)\)', groups_str, re.DOTALL)
specific_entries = dict()
for i, entry in enumerate(entries):
label_match = re.search(r'label = "(.*?)"', entry)
group_match = re.search(r'group =(.*?)(?=\w+ =)', entry, re.DOTALL)
if label_match is not None and group_match is not None and label_match.group(1) in entry_labels:
specific_entries[label_match.group(1)] = clean_text(group_match.group(1))
if i > 2000:
break
return specific_entries
groups_str = "\n" + "".join(groups_as_lines)
# Split by entry( but keep the delimiter-ish part
parts = re.split(r"\nentry\s*\(", groups_str)

temp_entries = {}
label_pat = re.compile(r"label\s*=\s*(?:([\"'])(.*?)\1|(\w+))")
group_pat = re.compile(r"group\s*=\s*(?:\"\"\"(.*?)\"\"\"|([\"'])(.*?)\2|(OR\{.*?\}))", re.DOTALL)

for part in parts[1:]: # Skip the header
label_match = label_pat.search(part)
group_match = group_pat.search(part)
if label_match and group_match:
label = label_match.group(2) or label_match.group(3)
# Match group 1 (triple), 3 (single/double), or 4 (OR)
adj = group_match.group(1) or group_match.group(3) or group_match.group(4)
temp_entries[label] = clean_text(adj)

if entry_labels is None:
return temp_entries

all_entries = {}
to_process = list(entry_labels)
processed = set()
while to_process:
label = to_process.pop()
if label in processed or label not in temp_entries:
continue
processed.add(label)
adj = temp_entries[label]
all_entries[label] = adj
if recursive and 'OR{' in adj:
# Match OR{label1, label2, ...}
or_match = re.search(r'OR\s*\{\s*(.*?)\s*\}', adj, re.DOTALL)
if or_match:
children_str = or_match.group(1)
children = [c.strip() for c in children_str.split(',')]
to_process.extend(children)
return all_entries


def get_group_adjlist(groups_as_lines: List[str],
entry_label: str,
entries: Optional[Dict[str, str]] = None,
) -> str:
"""
Get the corresponding group value for the given entry label.

Args:
groups_as_lines (List[str]): The template content string.
entry_label (str): The entry label to extract.
entries (Dict[str, str], optional): Pre-extracted entries.

Returns:
str: The extracted group.
"""
if entries is not None and entry_label in entries:
return entries[entry_label]
specific_entries = get_entries(groups_as_lines, entry_labels=[entry_label])
return specific_entries[entry_label]

Expand Down
35 changes: 30 additions & 5 deletions arc/job/adapters/ts/heuristics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
from arc.species.species import ARCSpecies
from arc.species.zmat import _compare_zmats, get_parameter_from_atom_indices

from arc.species.species import check_isomorphism
from arc.species.zmat import remove_zmat_atom_0
from arc.species.converter import relocate_zmat_dummy_atoms_to_the_end


class TestHeuristicsAdapter(unittest.TestCase):
"""
Expand Down Expand Up @@ -1414,11 +1418,32 @@ def test_get_new_zmat2_map(self):
# expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19,
# 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 3, 19: 1, 21: 4, 23: 0,
# 25: 7, 26: 6, 28: 5, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'}
expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19,
11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 1, 19: 3, 21: 0, 23: 4,
25: 5, 26: 6, 28: 7, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'}

self.assertEqual(new_map, expected_new_map)

# Test isomorphism of the mapped reactant_2 part
zmat_2_mod = remove_zmat_atom_0(self.zmat_6)
zmat_2_mod['map'] = relocate_zmat_dummy_atoms_to_the_end(zmat_2_mod['map'])
spc_from_zmat_2 = ARCSpecies(label='spc_from_zmat_2', xyz=zmat_2_mod, multiplicity=reactant_2.multiplicity,
number_of_radicals=reactant_2.number_of_radicals, charge=reactant_2.charge)

# Verify that all physical atom indices in new_map that came from zmat_2 correctly map to reactant_2
# Atom indices in new_map are for the combined species.
# Atoms 0-16 in self.zmat_5, atoms 1-12 in self.zmat_6 (13 atoms total, index 0 removed).
# In get_new_zmat_2_map, zmat_2 atoms are mapped to indices in new_map.

num_atoms_1 = len(self.zmat_5['symbols'])
atom_map = dict()
for i in range(1, len(self.zmat_6['symbols'])):
if not isinstance(self.zmat_6['symbols'][i], str) or self.zmat_6['symbols'][i] != 'X':
# This is a physical atom in zmat_2 (at index i)
# Its index in the combined Z-Matrix is num_atoms_1 + i - 1
combined_idx = num_atoms_1 + i - 1
if combined_idx in new_map:
# new_map[combined_idx] is the index in reactant_2
# i-1 is the index in spc_from_zmat_2
atom_map[i-1] = new_map[combined_idx]

# Verify the atom_map is a valid isomorphism
self.assertTrue(check_isomorphism(spc_from_zmat_2.mol, reactant_2.mol, atom_map))

def test_get_new_map_based_on_zmat_1(self):
"""Test the get_new_map_based_on_zmat_1() function."""
Expand Down
7 changes: 7 additions & 0 deletions arc/mapping/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,13 @@ def map_rxn(rxn: 'ARCReaction',
r_bdes, p_bdes = find_all_breaking_bonds(rxn, r_direction=True, pdi=pdi), find_all_breaking_bonds(rxn, r_direction=False, pdi=pdi)
r_cuts, p_cuts = cut_species_based_on_atom_indices(reactants, r_bdes), cut_species_based_on_atom_indices(products, p_bdes)

if r_cuts is None or p_cuts is None:
if rxn.product_dicts is not None and len(rxn.product_dicts) - 1 > pdi < MAX_PDI:
return map_rxn(rxn, backend=backend, product_dict_index_to_try=pdi + 1)
else:
logger.error(f'Could not cut species for reaction {rxn}')
return None

try:
r_label_map = rxn.product_dicts[pdi]['r_label_map']
p_label_map = rxn.product_dicts[pdi]['p_label_map']
Expand Down
Loading
Loading