Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/packages/core/agent_framework/_workflows/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
self._state = state
self._running = False
self._resumed_from_checkpoint = False # Track whether we resumed
self._resume_parent_checkpoint_id: CheckpointID | None = None

@property
def context(self) -> RunnerContext:
Expand All @@ -81,7 +82,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
raise WorkflowRunnerException("Runner is already running.")

self._running = True
previous_checkpoint_id: CheckpointID | None = None
previous_checkpoint_id: CheckpointID | None = (
self._resume_parent_checkpoint_id if self._resumed_from_checkpoint else None
)
try:
# Emit any events already produced prior to entering loop
if await self._ctx.has_events():
Expand Down Expand Up @@ -154,6 +157,7 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:

logger.info(f"Workflow completed after {self._iteration} supersteps")
self._resumed_from_checkpoint = False # Reset resume flag for next run
self._resume_parent_checkpoint_id = None
finally:
self._running = False

Expand Down Expand Up @@ -285,7 +289,7 @@ async def restore_from_checkpoint(
# Apply the checkpoint to the context
await self._ctx.apply_checkpoint(checkpoint)
# Mark the runner as resumed
self._mark_resumed(checkpoint.iteration_count)
self._mark_resumed(checkpoint.iteration_count, checkpoint.checkpoint_id)

logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}")
except WorkflowCheckpointException:
Expand Down Expand Up @@ -351,13 +355,14 @@ def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[

return parsed

def _mark_resumed(self, iteration: int) -> None:
def _mark_resumed(self, iteration: int, checkpoint_id: CheckpointID) -> None:
"""Mark the runner as having resumed from a checkpoint.

Optionally set the current iteration and max iterations.
"""
self._resumed_from_checkpoint = True
self._iteration = iteration
self._resume_parent_checkpoint_id = checkpoint_id

async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
"""Store executor state in state under a reserved key.
Expand Down
88 changes: 88 additions & 0 deletions python/packages/core/tests/workflow/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,94 @@ async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None:
)


async def test_resumed_workflow_keeps_previous_checkpoint_id_chain():
    """Checkpoints written after a resume must link back to the restored checkpoint.

    Runs a three-stage workflow to completion, restores a fresh workflow instance
    from an intermediate checkpoint, and verifies that every checkpoint produced
    by the resumed run forms an unbroken ``previous_checkpoint_id`` chain rooted
    at the checkpoint that was restored.
    """
    from typing_extensions import Never

    from agent_framework import WorkflowBuilder, WorkflowContext, handler
    from agent_framework._workflows._executor import Executor

    class StartExecutor(Executor):
        @handler
        async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
            await ctx.send_message(message, target_id="middle")

    class MiddleExecutor(Executor):
        @handler
        async def process(self, message: str, ctx: WorkflowContext[str]) -> None:
            await ctx.send_message(message + "-processed", target_id="finish")

    class FinishExecutor(Executor):
        @handler
        async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None:
            await ctx.yield_output(message + "-done")

    storage = InMemoryCheckpointStorage()
    workflow_name = "resume-chain-workflow"

    def build_workflow(first: Executor, second: Executor, third: Executor):
        # Wire first -> second -> third with checkpointing into the shared storage.
        return (
            WorkflowBuilder(
                max_iterations=10,
                name=workflow_name,
                start_executor=first,
                checkpoint_storage=storage,
            )
            .add_edge(first, second)
            .add_edge(second, third)
            .build()
        )

    workflow = build_workflow(
        StartExecutor(id="start"),
        MiddleExecutor(id="middle"),
        FinishExecutor(id="finish"),
    )

    # First full run: drain the stream so every superstep checkpoint gets written.
    async for _event in workflow.run("hello", stream=True):
        pass

    checkpoints_before = sorted(
        await storage.list_checkpoints(workflow_name=workflow.name),
        key=lambda cp: cp.timestamp,
    )
    assert len(checkpoints_before) >= 2, (
        f"Expected at least 2 checkpoints before resume, got {len(checkpoints_before)}"
    )

    # Resume from an intermediate (not the last) checkpoint of the first run.
    restore_checkpoint = checkpoints_before[1]
    ids_before = {cp.checkpoint_id for cp in checkpoints_before}

    # A brand-new workflow instance with freshly constructed executors, so the
    # chain link can only come from the restored checkpoint, not leftover state.
    resumed_workflow = build_workflow(
        StartExecutor(id="start"),
        MiddleExecutor(id="middle"),
        FinishExecutor(id="finish"),
    )

    async for _event in resumed_workflow.run(
        checkpoint_id=restore_checkpoint.checkpoint_id, stream=True
    ):
        pass

    all_checkpoints = sorted(
        await storage.list_checkpoints(workflow_name=workflow.name),
        key=lambda cp: cp.timestamp,
    )
    new_checkpoints = [cp for cp in all_checkpoints if cp.checkpoint_id not in ids_before]

    assert new_checkpoints, "Expected at least one new checkpoint after resume"
    # The first post-resume checkpoint must point at the checkpoint we restored from...
    assert new_checkpoints[0].previous_checkpoint_id == restore_checkpoint.checkpoint_id

    # ...and every later one must point at its immediate predecessor.
    for earlier, later in zip(new_checkpoints, new_checkpoints[1:]):
        assert later.previous_checkpoint_id == earlier.checkpoint_id


async def test_memory_checkpoint_storage_roundtrip_json_native_types():
"""Test that JSON-native types (str, int, float, bool, None) roundtrip correctly."""
storage = InMemoryCheckpointStorage()
Expand Down