From e5417ff6e754cbd5eba8f57b6a0c969c03129d79 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Thu, 21 May 2026 11:15:48 -0700 Subject: [PATCH 1/2] fix bug where updated context doesn't get passed to n+2 node --- src/modelplane/evaluator/dag.py | 8 ++++---- tests/unit/evaluator/mocks.py | 7 +++++++ tests/unit/evaluator/test_dag.py | 34 ++++++++++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py index f57daf7..31e2f50 100644 --- a/src/modelplane/evaluator/dag.py +++ b/src/modelplane/evaluator/dag.py @@ -181,7 +181,7 @@ def _run_traced(self, ctx: EvalContext) -> tuple[DAGOutput, set[tuple[str, str]] for node_name in self._ordered: if node_name not in reachable: continue - run_ctx = ctx.with_parent_outputs( + ctx = ctx.with_parent_outputs( { pred: node_outputs[pred] for pred in self._predecessors[node_name] @@ -190,14 +190,14 @@ def _run_traced(self, ctx: EvalContext) -> tuple[DAGOutput, set[tuple[str, str]] ) node = self._nodes[node_name] if isinstance(node, CacheableNodeMixin): - key = node.cache_key(run_ctx) + key = node.cache_key(ctx) if key in self._node_caches[node.name]: output = self._node_caches[node.name][key] else: - output = node.run(run_ctx) + output = node.run(ctx) self._node_caches[node.name][key] = output else: - output = node.run(run_ctx) + output = node.run(ctx) node_outputs[node_name] = output total_cost += output.realized_cost if isinstance(output.value, Verdict): diff --git a/tests/unit/evaluator/mocks.py b/tests/unit/evaluator/mocks.py index 68d3a42..f23e464 100644 --- a/tests/unit/evaluator/mocks.py +++ b/tests/unit/evaluator/mocks.py @@ -132,6 +132,13 @@ def output_tokens(self, ctx: EvalContext) -> int: return context_token_count(ctx) +class NoOpEnricher(Enricher): + """Passes context through without changing it.""" + + def run(self, ctx: EvalContext) -> NodeOutput: + return self.build_output(None, ctx) + + class FixedScorer(Enricher): 
"""Returns a fixed float score regardless of context.""" diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py index c9b3071..461e317 100644 --- a/tests/unit/evaluator/test_dag.py +++ b/tests/unit/evaluator/test_dag.py @@ -11,8 +11,15 @@ from modelplane.evaluator.safety import Safety from .conftest import skip_in_ci -from .mocks import AlwaysTrueCacheable - +from .mocks import ( + AlwaysSafe, + AlwaysTrue, + AlwaysTrueCacheable, + LowerCaser, + LowerCaseScorer, + NoOpEnricher, + ThresholdArbiter, +) def test_dag_outputs(simple_dag): assert simple_dag.verdict_type == Safety @@ -182,6 +189,29 @@ def test_dag_run(simple_dag, sample_ctx): assert dag_output.verdict.name == "UNSAFE" +def test_dag_passes_updated_context_to_downstream_nodes(): + ctx = EvalContext(prompt="x", response="HELLO") + dag = ( + Composer("ctx_update", verdict_type=Safety) + .add_node( + AlwaysTrue( + name="always_true", + routes_true=["lower_caser"], + routes_false=["always_safe"], + ) + ) + .add_node(AlwaysSafe(name="always_safe")) + .add_node(LowerCaser(name="lower_caser", routes=["noop"])) + .add_node(NoOpEnricher(name="noop", routes=["lower_scorer"])) + .add_node(LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"])) + .add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5)) + ) + dag_output = dag.run(ctx) + assert dag_output.node_outputs["lower_caser"].updated_ctx.response == "hello" + # Scorer reads ctx.response; 1.0 only if it saw the lowercased update from lower_caser. 
+ assert dag_output.node_outputs["lower_scorer"].value == pytest.approx(1.0) + + def test_dag_run_with_dataframe(simple_dag, tmp_path): # "hello world" (space lowers avg below threshold) → safe # "helloworld" (no space, avg = 0.5 = threshold) → unsafe From e16eae46f0c1c28444cf200f5ab1a50974d1b4e2 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Fri, 22 May 2026 14:33:23 -0700 Subject: [PATCH 2/2] more tests --- tests/unit/evaluator/test_dag.py | 57 ++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py index 461e317..022b9b8 100644 --- a/tests/unit/evaluator/test_dag.py +++ b/tests/unit/evaluator/test_dag.py @@ -19,6 +19,7 @@ LowerCaseScorer, NoOpEnricher, ThresholdArbiter, + UpperCaser, ) def test_dag_outputs(simple_dag): @@ -212,6 +213,62 @@ def test_dag_passes_updated_context_to_downstream_nodes(): assert dag_output.node_outputs["lower_scorer"].value == pytest.approx(1.0) +def test_dag_updated_context_not_passed_to_parallel_nodes(): + # noop and lower caser are parallel nodes. noop should not see the updated context from lower_caser. 
+ ctx = EvalContext(prompt="x", response="HELLO") + dag = ( + Composer("ctx_update", verdict_type=Safety) + .add_node( + AlwaysTrue( + name="always_true", + routes_true=["lower_caser", "noop"], + routes_false=["always_safe"], + ) + ) + .add_node(AlwaysSafe(name="always_safe")) + .add_node(LowerCaser(name="lower_caser", routes=["lower_scorer"])) + .add_node(NoOpEnricher(name="noop", routes=["lower_scorer"])) + .add_node(LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"])) + .add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5)) + ) + dag_output = dag.run(ctx) + + assert dag_output.node_outputs["lower_caser"].original_ctx.response == "HELLO" + assert dag_output.node_outputs["lower_caser"].updated_ctx.response == "hello" + + assert dag_output.node_outputs["noop"].original_ctx.response == "HELLO" + assert dag_output.node_outputs["noop"].updated_ctx is None + + assert dag_output.node_outputs["lower_scorer"].original_ctx.response == "hello" + # Scorer reads ctx.response; 1.0 only if it saw the lowercased update from lower_caser. + assert dag_output.node_outputs["lower_scorer"].value == pytest.approx(1.0) + + +def test_dag_parallel_nodes_different_updated_contexts_raises_error(): + # upper caser and lower caser are parallel nodes, they update the context differently which should raise an error. 
+ ctx = EvalContext(prompt="x", response="HELLO") + dag = ( + Composer("ctx_update", verdict_type=Safety) + .add_node( + AlwaysTrue( + name="always_true", + routes_true=["lower_caser", "upper_caser"], + routes_false=["always_safe"], + ) + ) + .add_node(AlwaysSafe(name="always_safe")) + .add_node(LowerCaser(name="lower_caser", routes=["lower_scorer"])) + .add_node(UpperCaser(name="upper_caser", routes=["lower_scorer"])) + .add_node(LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"])) + .add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5)) + ) + with pytest.raises( + ValueError, + match="all parent outputs must have the same updated context", + ): + dag.run(ctx) + + def test_dag_run_with_dataframe(simple_dag, tmp_path): # "hello world" (space lowers avg below threshold) → safe # "helloworld" (no space, avg = 0.5 = threshold) → unsafe