Skip to content

Commit 3b348ae

Browse files
authored
Merge pull request #28 from pamelafox/encodefix
Replace private encode/decode imports with pickle serialization
2 parents 2aa9e39 + 0ec0907 commit 3b348ae

2 files changed

Lines changed: 22 additions & 26 deletions

File tree

examples/spanish/workflow_hitl_checkpoint_pg.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
import asyncio
1212
import os
13+
import pickle # noqa: S403
1314
from dataclasses import dataclass
1415
from typing import Any
1516

1617
import psycopg
17-
from psycopg.types.json import Jsonb
1818
from agent_framework import (
1919
Agent,
2020
AgentExecutor,
@@ -29,10 +29,6 @@
2929
response_handler,
3030
)
3131
from agent_framework import WorkflowCheckpoint
32-
33-
# Importación privada — aún no hay API pública para la codificación de checkpoints.
34-
# Ver: https://github.com/microsoft/agent-framework/issues/4428
35-
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
3632
from agent_framework.exceptions import WorkflowCheckpointException
3733
from agent_framework.openai import OpenAIChatClient
3834
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
@@ -51,7 +47,10 @@ class PostgresCheckpointStorage:
5147
"""Almacenamiento de checkpoints respaldado por PostgreSQL.
5248
5349
Guarda checkpoints en una sola tabla con columnas para ID, nombre del workflow,
54-
timestamp y los datos JSON codificados. SQL maneja la indexación y el filtrado.
50+
timestamp y los datos serializados con pickle. SQL maneja la indexación y el filtrado.
51+
52+
ADVERTENCIA DE SEGURIDAD: Los checkpoints usan pickle para la serialización.
53+
Solo carga checkpoints de fuentes confiables.
5554
"""
5655

5756
def __init__(self, conninfo: str) -> None:
@@ -65,7 +64,7 @@ def _ensure_table(self) -> None:
6564
id TEXT PRIMARY KEY,
6665
workflow_name TEXT NOT NULL,
6766
timestamp TEXT NOT NULL,
68-
data JSONB NOT NULL
67+
data BYTEA NOT NULL
6968
)
7069
""")
7170
conn.execute("""
@@ -75,14 +74,14 @@ def _ensure_table(self) -> None:
7574

7675
async def save(self, checkpoint: WorkflowCheckpoint) -> str:
7776
"""Guarda un checkpoint en PostgreSQL."""
78-
encoded = encode_checkpoint_value(checkpoint.to_dict())
77+
data = pickle.dumps(checkpoint, protocol=pickle.HIGHEST_PROTOCOL) # noqa: S301
7978
async with await psycopg.AsyncConnection.connect(self._conninfo) as conn:
8079
await conn.execute(
8180
"""INSERT INTO workflow_checkpoints (id, workflow_name, timestamp, data)
8281
VALUES (%s, %s, %s, %s)
8382
ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data""",
8483
(checkpoint.checkpoint_id, checkpoint.workflow_name,
85-
checkpoint.timestamp, Jsonb(encoded)),
84+
checkpoint.timestamp, data),
8685
)
8786
return checkpoint.checkpoint_id
8887

@@ -94,8 +93,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint:
9493
)).fetchone()
9594
if row is None:
9695
raise WorkflowCheckpointException(f"No se encontró checkpoint con ID {checkpoint_id}")
97-
decoded = decode_checkpoint_value(row["data"])
98-
return WorkflowCheckpoint.from_dict(decoded)
96+
return pickle.loads(row["data"]) # noqa: S301
9997

10098
async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]:
10199
"""Lista todos los checkpoints de un workflow, ordenados por timestamp."""
@@ -104,7 +102,7 @@ async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoi
104102
"SELECT data FROM workflow_checkpoints WHERE workflow_name = %s ORDER BY timestamp",
105103
(workflow_name,),
106104
)).fetchall()
107-
return [WorkflowCheckpoint.from_dict(decode_checkpoint_value(r["data"])) for r in rows]
105+
return [pickle.loads(r["data"]) for r in rows] # noqa: S301
108106

109107
async def delete(self, checkpoint_id: str) -> bool:
110108
"""Elimina un checkpoint por ID."""
@@ -124,7 +122,7 @@ async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None:
124122
)).fetchone()
125123
if row is None:
126124
return None
127-
return WorkflowCheckpoint.from_dict(decode_checkpoint_value(row["data"]))
125+
return pickle.loads(row["data"]) # noqa: S301
128126

129127
async def list_checkpoint_ids(self, *, workflow_name: str) -> list[str]:
130128
"""Lista los IDs de checkpoints de un workflow."""

examples/workflow_hitl_checkpoint_pg.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
import asyncio
1212
import os
13+
import pickle # noqa: S403
1314
from dataclasses import dataclass
1415
from typing import Any
1516

1617
import psycopg
17-
from psycopg.types.json import Jsonb
1818
from agent_framework import (
1919
Agent,
2020
AgentExecutor,
@@ -29,10 +29,6 @@
2929
response_handler,
3030
)
3131
from agent_framework import WorkflowCheckpoint
32-
33-
# Private import — no public API for checkpoint encoding yet.
34-
# See: https://github.com/microsoft/agent-framework/issues/4428
35-
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
3632
from agent_framework.exceptions import WorkflowCheckpointException
3733
from agent_framework.openai import OpenAIChatClient
3834
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
@@ -51,7 +47,10 @@ class PostgresCheckpointStorage:
5147
"""PostgreSQL-backed checkpoint storage.
5248
5349
Stores checkpoints in a single table with columns for ID, workflow name,
54-
timestamp, and the encoded JSON data. SQL handles indexing and filtering.
50+
timestamp, and the pickled checkpoint data. SQL handles indexing and filtering.
51+
52+
SECURITY WARNING: Checkpoints use pickle for serialization. Only load
53+
checkpoints from trusted sources.
5554
"""
5655

5756
def __init__(self, conninfo: str) -> None:
@@ -65,7 +64,7 @@ def _ensure_table(self) -> None:
6564
id TEXT PRIMARY KEY,
6665
workflow_name TEXT NOT NULL,
6766
timestamp TEXT NOT NULL,
68-
data JSONB NOT NULL
67+
data BYTEA NOT NULL
6968
)
7069
""")
7170
conn.execute("""
@@ -75,14 +74,14 @@ def _ensure_table(self) -> None:
7574

7675
async def save(self, checkpoint: WorkflowCheckpoint) -> str:
7776
"""Save a checkpoint to PostgreSQL."""
78-
encoded = encode_checkpoint_value(checkpoint.to_dict())
77+
data = pickle.dumps(checkpoint, protocol=pickle.HIGHEST_PROTOCOL) # noqa: S301
7978
async with await psycopg.AsyncConnection.connect(self._conninfo) as conn:
8079
await conn.execute(
8180
"""INSERT INTO workflow_checkpoints (id, workflow_name, timestamp, data)
8281
VALUES (%s, %s, %s, %s)
8382
ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data""",
8483
(checkpoint.checkpoint_id, checkpoint.workflow_name,
85-
checkpoint.timestamp, Jsonb(encoded)),
84+
checkpoint.timestamp, data),
8685
)
8786
return checkpoint.checkpoint_id
8887

@@ -94,8 +93,7 @@ async def load(self, checkpoint_id: str) -> WorkflowCheckpoint:
9493
)).fetchone()
9594
if row is None:
9695
raise WorkflowCheckpointException(f"No checkpoint found with ID {checkpoint_id}")
97-
decoded = decode_checkpoint_value(row["data"])
98-
return WorkflowCheckpoint.from_dict(decoded)
96+
return pickle.loads(row["data"]) # noqa: S301
9997

10098
async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]:
10199
"""List all checkpoints for a workflow, ordered by timestamp."""
@@ -104,7 +102,7 @@ async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoi
104102
"SELECT data FROM workflow_checkpoints WHERE workflow_name = %s ORDER BY timestamp",
105103
(workflow_name,),
106104
)).fetchall()
107-
return [WorkflowCheckpoint.from_dict(decode_checkpoint_value(r["data"])) for r in rows]
105+
return [pickle.loads(r["data"]) for r in rows] # noqa: S301
108106

109107
async def delete(self, checkpoint_id: str) -> bool:
110108
"""Delete a checkpoint by ID."""
@@ -124,7 +122,7 @@ async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None:
124122
)).fetchone()
125123
if row is None:
126124
return None
127-
return WorkflowCheckpoint.from_dict(decode_checkpoint_value(row["data"]))
125+
return pickle.loads(row["data"]) # noqa: S301
128126

129127
async def list_checkpoint_ids(self, *, workflow_name: str) -> list[str]:
130128
"""List checkpoint IDs for a workflow."""

0 commit comments

Comments
 (0)