Skip to content

Commit ba954b7

Browse files
Added xDTD, xCRG, pathfinder, pathfinder-constrained routes
1 parent 20e7580 commit ba954b7

18 files changed

Lines changed: 1437 additions & 276 deletions

scripts/demo_streamlit.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,46 @@
1919
from trapi_agent.agent_graph import graph
2020

2121
# ─────────────────────────── UI OPTIONS ───────────────────────────
22-
ROUTES = ["Auto (router picks)", "onehop", "pathfinder", "treats", "chem_gene"]
22+
ROUTES = ["Auto (router picks)", "onehop", "pathfinder", "treats", "chem_gene", "xcrg","pathfinder_constrained"]
2323

2424
EXAMPLES: List[Tuple[str, str]] = [
2525
("onehop", "What proteins does acetaminophen interact with?"),
2626
("onehop", "What biological processes are related to GFAP?"),
2727
("pathfinder", "Find a path between asthma and diabetes mellitus"),
2828
("pathfinder", "By what paths are ibuprofen and headaches connected?"),
2929
("treats", "What drugs treat asthma?"),
30+
("treats", "What drugs treat diabetes mellitus?"),
31+
("treats", "What drugs treat castleman disease?"),
32+
("treats", "What drugs treat What drugs treat malignant ciliary body melanoma?"),
33+
("treats", "What chemicals are predicted to be useful to treat malignant ciliary body melanoma?"),
34+
("pathfinder", "How are neutropenia and filgrastim related in multi-hop paths?"),
35+
("pathfinder","Find me paths between ibuprofen and COX1?"),
36+
("xcrg","what genes are upregulated by filgrastim?"),
37+
("xcrg","what genes are downregulated by filgrastim?"),
38+
("xcrg","which drugs inhibit the activity of ABCB1"),
39+
("pathfinder_constrained", "Find a path between asthma and diabetes mellitus that includes a protein"),
40+
("pathfinder_constrained", "Find a path between asthma and diabetes mellitus that includes a biological_process"),
41+
('pathfinder_constrained', "How are ibuprofen and headaches related via paths going through genes"),
42+
('pathfinder_constrained', "By what paths are BRCA1 and breast cancer connected via genes"),
43+
('pathfinder_constrained', "Find me paths between ibuprofen and COX1 via proteins"),
44+
('pathfinder_constrained', "Find me paths between ibuprofen and COX1 via biological_processes"),
45+
('pathfinder_constrained', "Find me paths between ibuprofen and COX1 via diseases"),
46+
('pathfinder_constrained', "How are EGFR and lung cancer related through proteins?"),
47+
('pathfinder_constrained', "How are neutropenia and filgrastim related in multi-hop paths going through diseases?"),
48+
('pathfinder_constrained', "How are obesity and insulin resistance connected via diseases?"),
49+
('pathfinder_constrained', "Show connections between BRCA1 and asthma via a drug?"),
50+
('pathfinder_constrained', "Paths between TNF and rheumatoid arthritis via a chemical?"),
51+
('pathfinder_constrained', "Find paths from LRRK2 to Parkinson disease via small molecules?"),
52+
('pathfinder_constrained', "Show multi-hop paths between BRCA1 and DNA repair through pathways?"),
53+
('pathfinder_constrained', "Show paths from kinase inhibitors to EGFR via molecular activities?"),
54+
('pathfinder_constrained', "How are COX1 and prostaglandin synthesis related through activities?"),
55+
('pathfinder_constrained', "Find paths between GFAP and seizures via tissues?"),
56+
('pathfinder_constrained', "Show paths between HIF1A and hypoxia via cells?"),
57+
('pathfinder_constrained', "How are BRCA1 and DNA repair related through organelles?"),
58+
('pathfinder_constrained', "Show paths between APOE and Alzheimer disease through phenotypes)"),
59+
('pathfinder_constrained', "Find paths from TP53 to cancer via phenotypes)"),
60+
61+
3062
]
3163

3264
logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.ERROR)
@@ -175,4 +207,7 @@ def render_qg_pyvis(qg: Dict[str, Any]) -> None:
175207
# Save to history
176208
st.session_state.history.append({"route": final_state.get("route"), "query": query, "dt": dt})
177209

178-
# PYTHONPATH="$(pwd)" streamlit run scripts/demo_streamlit.py --server.port 7860 --server.address 0.0.0.011
210+
# PYTHONPATH="$(pwd)" streamlit run scripts/demo_streamlit.py --server.port 7860 --server.address 0.0.0.0
211+
# lsof -i:7860
212+
# kill -9 3607435
213+
# pkill -f streamlit

scripts/run_agent_cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
2+
13
#!/usr/bin/env python3
24
"""
35
Command-line interface for invoking the TRAPI agent graph on natural-language queries.
46
57
Usage:
68
python -m scripts.run_agent_cli "What drugs treat asthma?"
79
python -m scripts.run_agent_cli "Find paths between ibuprofen and COX1" --route pathfinder
10+
python -m scripts.run_agent_cli "what genes are upregulated by filgrastim?" --route xcrg
811
echo "What proteins interact with aspirin?" | python -m scripts.run_agent_cli
912
"""
1013
from __future__ import annotations
@@ -17,7 +20,9 @@
1720

1821
from trapi_agent.agent_graph import graph
1922

20-
ROUTE_CHOICES = ["onehop", "pathfinder", "treats", "chem_gene", "multihop"]
23+
# Add xcrg to the allowed routes
24+
ROUTE_CHOICES = ["onehop", "pathfinder", "pathfinder_constrained", "treats", "chem_gene", "multihop", "xcrg"]
25+
2126

2227

2328
def parse_args() -> argparse.Namespace:

trapi_agent/agent_graph.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
from __future__ import annotations
23

34
import logging
@@ -14,13 +15,16 @@
1415

1516
logger = logging.getLogger(__name__)
1617

18+
1719
def _import_route_modules() -> None:
18-
for mod in ("onehop", "pathfinder", "treats", "chem_gene"):
20+
# Register all available routes (explicit routing uses these)
21+
for mod in ("onehop", "pathfinder", "treats", "chem_gene", "xcrg","pathfinder_constrained"):
1922
try:
2023
importlib.import_module(f"{__package__}.routes.{mod}")
2124
logger.debug("Imported route module: %s", mod)
2225
except Exception as e:
23-
logger.warning("⚠️ Failed importing route '%s': %s", mod, e)
26+
logger.warning(" Failed importing route '%s': %s", mod, e)
27+
2428

2529
# Import at module import time (before graph is built)
2630
_import_route_modules()
@@ -42,13 +46,12 @@ def build_agent_graph() -> Any:
4246
g.add_node("ResolveEntities", resolve_entities.node)
4347

4448
def _set_route_flags(state: TRAPIState) -> TRAPIState:
45-
# IMPORTANT: do NOT force 'onehop' here if unknown; keep the string.
4649
inbound = state.get("route")
4750
route = inbound or "onehop"
4851
handler = R.ROUTES.get(route)
4952

5053
if handler is None:
51-
# Keep requested route string; only derive skip_schema for pathfinder.
54+
# Keep the requested route string; only guess skip_schema for pathfinder
5255
state["route"] = route
5356
state["skip_schema"] = (route == "pathfinder")
5457
logger.info(
@@ -57,8 +60,8 @@ def _set_route_flags(state: TRAPIState) -> TRAPIState:
5760
)
5861
return state
5962

60-
state["route"] = route
61-
# handler.name
63+
# Canonicalize to the registered handler
64+
state["route"] = handler.name
6265
state["skip_schema"] = handler.skip_schema
6366
logger.info(
6467
"SetRouteFlags: inbound='%s' → handler='%s' (skip_schema=%s)",
@@ -123,13 +126,17 @@ def _validate(state: TRAPIState) -> TRAPIState:
123126

124127
g.add_conditional_edges(
125128
"Validate",
126-
{ END: lambda s: s.get("valid", False),
127-
"Fix": lambda s: not s.get("valid", False) }
129+
{
130+
END: lambda s: s.get("valid", False),
131+
"Fix": lambda s: not s.get("valid", False),
132+
}
128133
)
129134
g.add_conditional_edges(
130135
"Fix",
131-
{ "Validate": lambda s: s.get("fix_attempts", 0) < settings.MAX_FIX_ATTEMPTS,
132-
END: lambda s: s.get("fix_attempts", 0) >= settings.MAX_FIX_ATTEMPTS }
136+
{
137+
"Validate": lambda s: s.get("fix_attempts", 0) < settings.MAX_FIX_ATTEMPTS,
138+
END: lambda s: s.get("fix_attempts", 0) >= settings.MAX_FIX_ATTEMPTS,
139+
}
133140
)
134141

135142
compiled = g.compile()
@@ -140,5 +147,6 @@ def _validate(state: TRAPIState) -> TRAPIState:
140147
pass
141148
return compiled
142149

150+
143151
# Global
144152
graph = build_agent_graph()
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
#!/usr/bin/env python3
2+
"""
3+
construct_pathfinder_constrained.py
4+
5+
Build a Pathfinder-style TRAPI query_graph with a single intermediate category constraint:
6+
7+
nodes: exactly two (both pinned) → {"ids": ["CURIE"]}
8+
paths: single p0 with predicates=["biolink:related_to"] and
9+
constraints:[{"intermediate_categories":[<ONE biolink:Class>]}]
10+
11+
Priority for ONE intermediate class:
12+
1) state['intermediate_category'] (string)
13+
2) first usable from state['intermediate_hints'] (list[str])
14+
3) regex hint from the NL query (e.g., "via genes", "through diseases", "contain a drug")
15+
4) fallback to state['generic_types'] (but skip ChemicalEntity unless query mentions drug/chemical/compound)
16+
17+
All candidates are canonicalized via biolink_utils.canonicalize_class().
18+
"""
19+
from __future__ import annotations
20+
21+
import logging
22+
import re
23+
from typing import Dict, Any, List, Optional, Tuple
24+
25+
from ..state_types import TRAPIState
26+
from ..utils.biolink_utils import canonicalize_class
27+
28+
logger = logging.getLogger(__name__)
29+
REQ_PRED = "biolink:related_to"
30+
31+
32+
# ── helpers ───────────────────────────────────────────────────────────────────
33+
34+
def _unique(seq: List[str]) -> List[str]:
35+
seen, out = set(), []
36+
for x in seq or []:
37+
if x and x not in seen:
38+
out.append(x)
39+
seen.add(x)
40+
return out
41+
42+
43+
def _pick_two_pinned(nodes: Dict[str, Dict[str, Any]]) -> List[Tuple[str, str]]:
44+
"""Return up to 2 (node_id, CURIE) pairs for pinned nodes."""
45+
seen, out = set(), []
46+
for nid, meta in (nodes or {}).items():
47+
curie = meta.get("id")
48+
if curie and curie not in seen:
49+
out.append((nid, curie))
50+
seen.add(curie)
51+
if len(out) == 2:
52+
break
53+
return out
54+
55+
56+
# Single-class, deterministic regex hints. The FIRST match wins.
57+
_HINT_ORDER: List[Tuple[str, str]] = [
58+
# genes / proteins
59+
(r"\b(?:via|through)\s+(?:the\s+)?genes?\b", "Gene"),
60+
(r"\b(?:via|through)\s+(?:the\s+)?proteins?\b", "Protein"),
61+
# diseases
62+
(r"\b(?:via|through)\s+(?:the\s+)?diseases?\b", "Disease"),
63+
# drug / chemical
64+
(r"\bcontain(?:s|ing)?\s+(?:a\s+)?drug\b", "Drug"),
65+
(r"\b(?:via|through)\s+(?:a\s+)?drug\b", "Drug"),
66+
(r"\b(?:via|through)\s+(?:a\s+)?chemicals?\b", "ChemicalEntity"),
67+
(r"\b(?:via|through)\s+(?:a\s+)?compounds?\b", "ChemicalEntity"),
68+
# pathway / phenotype / anatomy
69+
(r"\b(?:via|through)\s+(?:the\s+)?pathways?\b", "Pathway"),
70+
(r"\b(?:via|through)\s+(?:the\s+)?phenotypes?\b", "PhenotypicFeature"),
71+
(r"\b(?:via|through)\s+(?:the\s+)?tissues?\b", "AnatomicalEntity"),
72+
(r"\b(?:via|through)\s+anatom(?:y|ical(?:\s+entity)?)\b", "AnatomicalEntity"),
73+
]
74+
75+
76+
def _query_hint_category(query: str) -> Optional[str]:
77+
q = (query or "").lower()
78+
for pat, raw in _HINT_ORDER:
79+
if re.search(pat, q, flags=re.I):
80+
cat = canonicalize_class(raw)
81+
if cat:
82+
return cat
83+
return None
84+
85+
86+
def _pick_single_intermediate(state: TRAPIState) -> Optional[str]:
87+
"""
88+
Choose exactly ONE intermediate Biolink category.
89+
"""
90+
# 1) explicit single choice (e.g., from UI)
91+
ui_raw = (state.get("intermediate_category") or "").strip()
92+
if ui_raw:
93+
cat = canonicalize_class(ui_raw)
94+
if cat:
95+
return cat
96+
97+
# 2) optional list of hints (first usable)
98+
for raw in _unique(state.get("intermediate_hints", []) or []):
99+
cat = canonicalize_class(raw)
100+
if cat:
101+
return cat
102+
103+
# 3) regex hint from query
104+
cat = _query_hint_category(state.get("query", ""))
105+
if cat:
106+
return cat
107+
108+
# 4) fallback from generic_types, but avoid ChemicalEntity unless query mentions drug/chemical/compound
109+
generics = _unique(state.get("generic_types", []) or [])
110+
ql = (state.get("query") or "").lower()
111+
allow_chem = bool(re.search(r"\b(drug|chemical|compound)s?\b", ql))
112+
113+
# Preferred order for informative constraints
114+
preferred = [
115+
"biolink:Gene",
116+
"biolink:Protein",
117+
"biolink:Disease",
118+
"biolink:Pathway",
119+
"biolink:PhenotypicFeature",
120+
"biolink:AnatomicalEntity",
121+
"biolink:Drug",
122+
"biolink:ChemicalEntity",
123+
]
124+
125+
for want in preferred:
126+
if want in generics and (want != "biolink:ChemicalEntity" or allow_chem):
127+
return want
128+
129+
# Last-ditch: first canonicalizable thing
130+
for raw in generics:
131+
cat = canonicalize_class(raw)
132+
if cat:
133+
return cat
134+
135+
return None
136+
137+
138+
# ── node ──────────────────────────────────────────────────────────────────────
139+
140+
def node(state: TRAPIState) -> TRAPIState:
141+
"""
142+
nodes:
143+
n0: pinned CURIE
144+
n1: pinned CURIE
145+
paths:
146+
p0: subject=n0, object=n1, predicates=[biolink:related_to],
147+
constraints=[{'intermediate_categories':[<ONE class>]}] (only if chosen)
148+
"""
149+
src_nodes: Dict[str, Dict[str, Any]] = state.get("nodes", {}) or {}
150+
pinned = _pick_two_pinned(src_nodes)
151+
if len(pinned) < 2:
152+
logger.warning("Pathfinder-constrained needs 2 pinned nodes; found %d", len(pinned))
153+
154+
# Re-key as n0/n1
155+
qg_nodes: Dict[str, Dict[str, Any]] = {}
156+
for i, (_, curie) in enumerate(pinned[:2]):
157+
qg_nodes[f"n{i}"] = {"ids": [curie]}
158+
159+
p0: Dict[str, Any] = {
160+
"subject": "n0",
161+
"object": "n1",
162+
"predicates": [REQ_PRED],
163+
}
164+
165+
picked = _pick_single_intermediate(state)
166+
if picked:
167+
p0["constraints"] = [{"intermediate_categories": [picked]}]
168+
logger.info("Added intermediate_categories constraint: %s", [picked])
169+
else:
170+
logger.info("No intermediate category found; emitting unconstrained path.")
171+
172+
state["output_json"] = {
173+
"message": {
174+
"query_graph": {
175+
"nodes": qg_nodes,
176+
"paths": {"p0": p0},
177+
}
178+
}
179+
}
180+
return state

0 commit comments

Comments
 (0)