Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/modelplane/evaluator/dag.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only case where this is a bit tricky is where you have two parallel nodes.

Suppose you have Node A -> (Node B, Node C) -> Node D

Because we run Node B and Node C sequentially, Node C will see the updated context from Node B.

But I don't think this is an issue: later we enforce that all parents of Node D must produce the same updated context, so it should be safe?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code already handles the parallel case as well: when we update the context for a node, we only pass in the outputs from that node's predecessors, and Node B is not a predecessor of Node C.

I added some more tests to confirm that things work as expected.

Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _run_traced(self, ctx: EvalContext) -> tuple[DAGOutput, set[tuple[str, str]]
for node_name in self._ordered:
if node_name not in reachable:
continue
run_ctx = ctx.with_parent_outputs(
ctx = ctx.with_parent_outputs(
{
pred: node_outputs[pred]
for pred in self._predecessors[node_name]
Expand All @@ -190,14 +190,14 @@ def _run_traced(self, ctx: EvalContext) -> tuple[DAGOutput, set[tuple[str, str]]
)
node = self._nodes[node_name]
if isinstance(node, CacheableNodeMixin):
key = node.cache_key(run_ctx)
key = node.cache_key(ctx)
if key in self._node_caches[node.name]:
output = self._node_caches[node.name][key]
else:
output = node.run(run_ctx)
output = node.run(ctx)
self._node_caches[node.name][key] = output
else:
output = node.run(run_ctx)
output = node.run(ctx)
node_outputs[node_name] = output
total_cost += output.realized_cost
if isinstance(output.value, Verdict):
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/evaluator/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def output_tokens(self, ctx: EvalContext) -> int:
return context_token_count(ctx)


class NoOpEnricher(Enricher):
    """Enricher mock that forwards the context without modifying it."""

    def run(self, ctx: EvalContext) -> NodeOutput:
        """Return an output built from a ``None`` value and the given ``ctx``."""
        passthrough = self.build_output(None, ctx)
        return passthrough


class FixedScorer(Enricher):
"""Returns a fixed float score regardless of context."""

Expand Down
91 changes: 89 additions & 2 deletions tests/unit/evaluator/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@
from modelplane.evaluator.safety import Safety

from .conftest import skip_in_ci
from .mocks import AlwaysTrueCacheable

from .mocks import (
AlwaysSafe,
AlwaysTrue,
AlwaysTrueCacheable,
LowerCaser,
LowerCaseScorer,
NoOpEnricher,
ThresholdArbiter,
UpperCaser,
)

def test_dag_outputs(simple_dag):
assert simple_dag.verdict_type == Safety
Expand Down Expand Up @@ -182,6 +190,85 @@ def test_dag_run(simple_dag, sample_ctx):
assert dag_output.verdict.name == "UNSAFE"


def test_dag_passes_updated_context_to_downstream_nodes():
    """A context update made upstream must reach nodes further down the chain.

    The chain is always_true -> lower_caser -> noop -> lower_scorer ->
    threshold_arbiter, so the scorer sits two hops below the node that
    rewrites the response.
    """
    eval_ctx = EvalContext(prompt="x", response="HELLO")

    dag = Composer("ctx_update", verdict_type=Safety)
    for node in (
        AlwaysTrue(
            name="always_true",
            routes_true=["lower_caser"],
            routes_false=["always_safe"],
        ),
        AlwaysSafe(name="always_safe"),
        LowerCaser(name="lower_caser", routes=["noop"]),
        NoOpEnricher(name="noop", routes=["lower_scorer"]),
        LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"]),
        ThresholdArbiter(name="threshold_arbiter", threshold=0.5),
    ):
        dag = dag.add_node(node)

    outputs = dag.run(eval_ctx).node_outputs

    assert outputs["lower_caser"].updated_ctx.response == "hello"
    # The scorer reads ctx.response; it yields 1.0 only if it saw the
    # lowercased update produced by lower_caser.
    assert outputs["lower_scorer"].value == pytest.approx(1.0)


def test_dag_updated_context_not_passed_to_parallel_nodes():
    """A sibling node must not observe another sibling's context update.

    ``noop`` and ``lower_caser`` are both routed to from ``always_true`` and
    both feed ``lower_scorer``; ``noop`` should still see the original
    response rather than the lowercased one.
    """
    eval_ctx = EvalContext(prompt="x", response="HELLO")

    dag = Composer("ctx_update", verdict_type=Safety)
    for node in (
        AlwaysTrue(
            name="always_true",
            routes_true=["lower_caser", "noop"],
            routes_false=["always_safe"],
        ),
        AlwaysSafe(name="always_safe"),
        LowerCaser(name="lower_caser", routes=["lower_scorer"]),
        NoOpEnricher(name="noop", routes=["lower_scorer"]),
        LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"]),
        ThresholdArbiter(name="threshold_arbiter", threshold=0.5),
    ):
        dag = dag.add_node(node)

    outputs = dag.run(eval_ctx).node_outputs

    # lower_caser receives the original response and lowercases it.
    lower_caser_out = outputs["lower_caser"]
    assert lower_caser_out.original_ctx.response == "HELLO"
    assert lower_caser_out.updated_ctx.response == "hello"

    # noop is parallel to lower_caser: it sees the untouched response and
    # produces no update of its own.
    noop_out = outputs["noop"]
    assert noop_out.original_ctx.response == "HELLO"
    assert noop_out.updated_ctx is None

    scorer_out = outputs["lower_scorer"]
    assert scorer_out.original_ctx.response == "hello"
    # The scorer reads ctx.response; it yields 1.0 only if it saw the
    # lowercased update produced by lower_caser.
    assert scorer_out.value == pytest.approx(1.0)


def test_dag_parallel_nodes_different_updated_contexts_raises_error():
    """Conflicting context updates from parallel parents must be rejected.

    ``upper_caser`` and ``lower_caser`` are parallel nodes that both feed
    ``lower_scorer`` but update the context differently, which should raise
    a ``ValueError`` when the DAG is run.
    """
    eval_ctx = EvalContext(prompt="x", response="HELLO")

    dag = Composer("ctx_update", verdict_type=Safety)
    for node in (
        AlwaysTrue(
            name="always_true",
            routes_true=["lower_caser", "upper_caser"],
            routes_false=["always_safe"],
        ),
        AlwaysSafe(name="always_safe"),
        LowerCaser(name="lower_caser", routes=["lower_scorer"]),
        UpperCaser(name="upper_caser", routes=["lower_scorer"]),
        LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"]),
        ThresholdArbiter(name="threshold_arbiter", threshold=0.5),
    ):
        dag = dag.add_node(node)

    expected_message = "all parent outputs must have the same updated context"
    with pytest.raises(ValueError, match=expected_message):
        dag.run(eval_ctx)


def test_dag_run_with_dataframe(simple_dag, tmp_path):
# "hello world" (space lowers avg below threshold) → safe
# "helloworld" (no space, avg = 0.5 = threshold) → unsafe
Expand Down
Loading