Skip to content

Commit 676ff67

Browse files
Routing algorithm / logic added
1 parent 83e9957 commit 676ff67

1 file changed

Lines changed: 77 additions & 2 deletions

File tree

trapi_agent/nodes/router.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,30 @@
1212
- Do NOT set state['predicate'] here; parse_query/construct_* handle it.
1313
"""
1414
from __future__ import annotations
15+
1516
import logging
17+
import os
18+
import pickle
19+
from functools import lru_cache
20+
from pathlib import Path
21+
from typing import Optional, Tuple
22+
23+
import numpy as np
24+
1625
from ..state_types import TRAPIState
1726

1827
logger = logging.getLogger(__name__)
1928

29+
ROUTE_CHOICES = {
30+
"onehop",
31+
"pathfinder",
32+
"pathfinder_constrained",
33+
"treats",
34+
"chem_gene",
35+
"xcrg",
36+
"multihop",
37+
}
38+
2039
KEYWORDS = {
2140
# Two-CURIE path query
2241
"pathfinder": [
@@ -32,21 +51,77 @@
3251
"chem_gene": ["gene", "protein"],
3352
}
3453

54+
55+
@lru_cache(maxsize=1)
56+
def _load_router_model() -> Optional[object]:
57+
"""
58+
Load a pre-trained router model if available.
59+
Path order:
60+
1) env ROUTER_MODEL
61+
2) repo data/router_model.pkl
62+
"""
63+
candidates = []
64+
env_path = os.getenv("ROUTER_MODEL")
65+
if env_path:
66+
candidates.append(Path(env_path))
67+
candidates.append(Path(__file__).resolve().parents[2] / "data" / "router_model.pkl")
68+
69+
for p in candidates:
70+
try:
71+
if p.exists():
72+
with p.open("rb") as fh:
73+
return pickle.load(fh)
74+
except Exception as e:
75+
logger.warning("Router model load failed (%s): %s", p, e)
76+
return None
77+
78+
79+
def _predict_route(model: object, query: str) -> Tuple[str, float]:
80+
"""
81+
Predict route and confidence using a scikit-learn style pipeline.
82+
"""
83+
if hasattr(model, "predict_proba"):
84+
probs = model.predict_proba([query])[0]
85+
conf = float(np.max(probs))
86+
else:
87+
scores = model.decision_function([query])
88+
exps = np.exp(scores - np.max(scores))
89+
probs = exps / exps.sum()
90+
conf = float(np.max(probs))
91+
label = model.predict([query])[0]
92+
return str(label), conf
93+
94+
3595
def node(state: TRAPIState) -> TRAPIState:
3696
# 1) Respect explicit route (e.g., CLI/UI dropdown)
3797
if state.get("route"):
3898
logger.info("Router picked route=%s", state["route"])
3999
return state
40100

41-
# 2) Infer from query keywords
101+
# 2) Model-based routing (if available and confident)
102+
q = (state.get("query") or "").strip()
103+
model = _load_router_model()
104+
if model and q:
105+
try:
106+
label, conf = _predict_route(model, q)
107+
threshold = float(os.getenv("ROUTER_CONF_MIN", "0.55"))
108+
if label in ROUTE_CHOICES and conf >= threshold:
109+
state["route"] = label
110+
logger.info("Router picked route=%s (model, conf=%.2f)", label, conf)
111+
return state
112+
logger.info("Router model low confidence (%.2f), falling back to heuristics", conf)
113+
except Exception as e:
114+
logger.warning("Router model prediction failed: %s", e)
115+
116+
# 3) Infer from query keywords
42117
q = (state.get("query") or "").lower()
43118
for route, toks in KEYWORDS.items():
44119
if any(tok in q for tok in toks):
45120
state["route"] = route
46121
logger.info("Router picked route=%s", route)
47122
return state
48123

49-
# 3) Default
124+
# 4) Default
50125
state["route"] = "onehop"
51126
logger.info("Router picked route=onehop")
52127
return state

0 commit comments

Comments
 (0)