1010
1111import asyncio
1212import os
13+ import pickle # noqa: S403
1314from dataclasses import dataclass
1415from typing import Any
1516
1617import psycopg
17- from psycopg .types .json import Jsonb
1818from agent_framework import (
1919 Agent ,
2020 AgentExecutor ,
2929 response_handler ,
3030)
3131from agent_framework import WorkflowCheckpoint
32-
33- # Private import — no public API for checkpoint encoding yet.
34- # See: https://github.com/microsoft/agent-framework/issues/4428
35- from agent_framework ._workflows ._checkpoint_encoding import decode_checkpoint_value , encode_checkpoint_value
3632from agent_framework .exceptions import WorkflowCheckpointException
3733from agent_framework .openai import OpenAIChatClient
3834from azure .identity .aio import DefaultAzureCredential , get_bearer_token_provider
@@ -51,7 +47,10 @@ class PostgresCheckpointStorage:
5147 """PostgreSQL-backed checkpoint storage.
5248
5349 Stores checkpoints in a single table with columns for ID, workflow name,
54- timestamp, and the encoded JSON data. SQL handles indexing and filtering.
50+ timestamp, and the pickled checkpoint data. SQL handles indexing and filtering.
51+
52+ SECURITY WARNING: Checkpoints use pickle for serialization. Only load
53+ checkpoints from trusted sources.
5554 """
5655
5756 def __init__ (self , conninfo : str ) -> None :
@@ -65,7 +64,7 @@ def _ensure_table(self) -> None:
6564 id TEXT PRIMARY KEY,
6665 workflow_name TEXT NOT NULL,
6766 timestamp TEXT NOT NULL,
68- data JSONB NOT NULL
67+ data BYTEA NOT NULL
6968 )
7069 """ )
7170 conn .execute ("""
@@ -75,14 +74,14 @@ def _ensure_table(self) -> None:
7574
7675 async def save (self , checkpoint : WorkflowCheckpoint ) -> str :
7776 """Save a checkpoint to PostgreSQL."""
78- encoded = encode_checkpoint_value (checkpoint . to_dict ())
77+ data = pickle . dumps (checkpoint , protocol = pickle . HIGHEST_PROTOCOL ) # noqa: S301
7978 async with await psycopg .AsyncConnection .connect (self ._conninfo ) as conn :
8079 await conn .execute (
8180 """INSERT INTO workflow_checkpoints (id, workflow_name, timestamp, data)
8281 VALUES (%s, %s, %s, %s)
8382 ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data""" ,
8483 (checkpoint .checkpoint_id , checkpoint .workflow_name ,
85- checkpoint .timestamp , Jsonb ( encoded ) ),
84+ checkpoint .timestamp , data ),
8685 )
8786 return checkpoint .checkpoint_id
8887
@@ -94,8 +93,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint:
9493 )).fetchone ()
9594 if row is None :
9695 raise WorkflowCheckpointException (f"No checkpoint found with ID { checkpoint_id } " )
97- decoded = decode_checkpoint_value (row ["data" ])
98- return WorkflowCheckpoint .from_dict (decoded )
96+ return pickle .loads (row ["data" ]) # noqa: S301
9997
10098 async def list_checkpoints (self , * , workflow_name : str ) -> list [WorkflowCheckpoint ]:
10199 """List all checkpoints for a workflow, ordered by timestamp."""
@@ -104,7 +102,7 @@ async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoi
104102 "SELECT data FROM workflow_checkpoints WHERE workflow_name = %s ORDER BY timestamp" ,
105103 (workflow_name ,),
106104 )).fetchall ()
107- return [WorkflowCheckpoint . from_dict ( decode_checkpoint_value ( r ["data" ])) for r in rows ]
105+ return [pickle . loads ( r ["data" ]) for r in rows ] # noqa: S301
108106
109107 async def delete (self , checkpoint_id : str ) -> bool :
110108 """Delete a checkpoint by ID."""
@@ -124,7 +122,7 @@ async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None:
124122 )).fetchone ()
125123 if row is None :
126124 return None
127- return WorkflowCheckpoint . from_dict ( decode_checkpoint_value ( row ["data" ]))
125+ return pickle . loads ( row ["data" ]) # noqa: S301
128126
129127 async def list_checkpoint_ids (self , * , workflow_name : str ) -> list [str ]:
130128 """List checkpoint IDs for a workflow."""
0 commit comments