Skip to content

Commit 923cbb3

Browse files
danielmillerp and claude committed
feat: add HTTP-proxy LangGraph checkpoint API
Agents no longer need a direct Postgres connection for LangGraph checkpointing. Instead, checkpoint operations are proxied through 5 new backend endpoints under /checkpoints (get-tuple, put, put-writes, list, delete-thread). Binary blob data is base64-encoded for JSON transport. Includes ORM models for the 4 checkpoint tables, Alembic migration, repository with composite-PK queries, use case layer, Pydantic schemas, and FastAPI routes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f436743 commit 923cbb3

9 files changed

Lines changed: 1533 additions & 0 deletions

File tree

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""add_langgraph_checkpoint_tables
2+
3+
Revision ID: a1b2c3d4e5f6
4+
Revises: d024851e790c
5+
Create Date: 2026-02-07 00:00:00.000000
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
from sqlalchemy.dialects import postgresql
13+
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = 'a1b2c3d4e5f6'
17+
down_revision: Union[str, None] = 'd024851e790c'
18+
branch_labels: Union[str, Sequence[str], None] = None
19+
depends_on: Union[str, Sequence[str], None] = None
20+
21+
22+
def upgrade() -> None:
    """Create the LangGraph checkpoint tables and seed ``checkpoint_migrations``.

    Creates ``checkpoint_migrations``, ``checkpoints``, ``checkpoint_blobs`` and
    ``checkpoint_writes`` (the last three each get a non-unique index on
    ``thread_id``), then inserts versions 0 through 9 into
    ``checkpoint_migrations``.
    """

    def required_text(name: str) -> sa.Column:
        # NOT NULL text column with no default.
        return sa.Column(name, sa.Text(), nullable=False)

    def defaulted_text(name: str) -> sa.Column:
        # NOT NULL text column whose server-side default is the empty string.
        return sa.Column(name, sa.Text(), server_default='', nullable=False)

    def optional_text(name: str) -> sa.Column:
        # Nullable text column.
        return sa.Column(name, sa.Text(), nullable=True)

    def jsonb() -> postgresql.JSONB:
        # Fresh JSONB type instance per column.
        return postgresql.JSONB(astext_type=sa.Text())

    # checkpoint_migrations: one row per applied LangGraph migration version.
    op.create_table(
        'checkpoint_migrations',
        sa.Column('v', sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint('v'),
    )

    # checkpoints
    op.create_table(
        'checkpoints',
        required_text('thread_id'),
        defaulted_text('checkpoint_ns'),
        required_text('checkpoint_id'),
        optional_text('parent_checkpoint_id'),
        optional_text('type'),
        sa.Column('checkpoint', jsonb(), nullable=False),
        sa.Column('metadata', jsonb(), server_default='{}', nullable=False),
        sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'checkpoint_id'),
    )
    op.create_index(
        'checkpoints_thread_id_idx', 'checkpoints', ['thread_id'], unique=False
    )

    # checkpoint_blobs
    op.create_table(
        'checkpoint_blobs',
        required_text('thread_id'),
        defaulted_text('checkpoint_ns'),
        required_text('channel'),
        required_text('version'),
        required_text('type'),
        sa.Column('blob', sa.LargeBinary(), nullable=True),
        sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'channel', 'version'),
    )
    op.create_index(
        'checkpoint_blobs_thread_id_idx', 'checkpoint_blobs', ['thread_id'], unique=False
    )

    # checkpoint_writes
    op.create_table(
        'checkpoint_writes',
        required_text('thread_id'),
        defaulted_text('checkpoint_ns'),
        required_text('checkpoint_id'),
        required_text('task_id'),
        sa.Column('idx', sa.Integer(), nullable=False),
        required_text('channel'),
        optional_text('type'),
        sa.Column('blob', sa.LargeBinary(), nullable=False),
        defaulted_text('task_path'),
        sa.PrimaryKeyConstraint(
            'thread_id', 'checkpoint_ns', 'checkpoint_id', 'task_id', 'idx'
        ),
    )
    op.create_index(
        'checkpoint_writes_thread_id_idx', 'checkpoint_writes', ['thread_id'], unique=False
    )

    # Pre-populate checkpoint_migrations so LangGraph sees all its
    # internal migrations as already applied and skips setup().
    seed_versions = sa.text(
        "INSERT INTO checkpoint_migrations (v) VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)"
    )
    op.execute(seed_versions)
80+
81+
82+
def downgrade() -> None:
    """Drop the checkpoint tables and their ``thread_id`` indexes.

    Tables are removed in the reverse of the order ``upgrade`` created them:
    for each of ``checkpoint_writes``, ``checkpoint_blobs`` and ``checkpoints``
    the index is dropped first, then the table; ``checkpoint_migrations`` has
    no index and is dropped last.
    """
    indexed_tables = ('checkpoint_writes', 'checkpoint_blobs', 'checkpoints')
    for table in indexed_tables:
        op.drop_index(f'{table}_thread_id_idx', table_name=table)
        op.drop_table(table)

    op.drop_table('checkpoint_migrations')

agentex/src/adapters/orm.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
DateTime,
66
ForeignKey,
77
Index,
8+
Integer,
9+
LargeBinary,
810
String,
911
Text,
1012
func,
@@ -213,3 +215,56 @@ class DeploymentHistoryORM(BaseORM):
213215
"commit_hash",
214216
),
215217
)
218+
219+
220+
# LangGraph checkpoint tables
221+
# These mirror the schema from langgraph.checkpoint.postgres so that
222+
# tables are created via Alembic migrations rather than at agent runtime.
223+
224+
225+
class CheckpointMigrationORM(BaseORM):
    """ORM mapping for ``checkpoint_migrations``: a single integer primary key ``v``."""

    __tablename__ = "checkpoint_migrations"

    # Migration version number; the only column and the primary key.
    v = Column("v", Integer, primary_key=True)
228+
229+
230+
class CheckpointORM(BaseORM):
    """ORM mapping for the ``checkpoints`` table.

    Rows are keyed by the composite primary key
    ``(thread_id, checkpoint_ns, checkpoint_id)``.
    """

    __tablename__ = "checkpoints"
    __table_args__ = (Index("checkpoints_thread_id_idx", "thread_id"),)

    # Composite primary key.
    thread_id = Column(Text, primary_key=True, nullable=False)
    checkpoint_ns = Column(Text, primary_key=True, nullable=False, server_default="")
    checkpoint_id = Column(Text, primary_key=True, nullable=False)

    # Payload columns.
    parent_checkpoint_id = Column(Text, nullable=True)
    type = Column(Text, nullable=True)
    checkpoint = Column(JSONB, nullable=False)
    # The attribute is named ``metadata_`` while the database column is "metadata".
    metadata_ = Column("metadata", JSONB, server_default="{}", nullable=False)
242+
243+
244+
class CheckpointBlobORM(BaseORM):
    """ORM mapping for the ``checkpoint_blobs`` table.

    Rows are keyed by the composite primary key
    ``(thread_id, checkpoint_ns, channel, version)``.
    """

    __tablename__ = "checkpoint_blobs"
    __table_args__ = (Index("checkpoint_blobs_thread_id_idx", "thread_id"),)

    # Composite primary key.
    thread_id = Column(Text, primary_key=True, nullable=False)
    checkpoint_ns = Column(Text, primary_key=True, nullable=False, server_default="")
    channel = Column(Text, primary_key=True, nullable=False)
    version = Column(Text, primary_key=True, nullable=False)

    # Payload columns; ``blob`` is nullable in this table.
    type = Column(Text, nullable=False)
    blob = Column(LargeBinary, nullable=True)
255+
256+
257+
class CheckpointWriteORM(BaseORM):
    """ORM mapping for the ``checkpoint_writes`` table.

    Rows are keyed by the composite primary key
    ``(thread_id, checkpoint_ns, checkpoint_id, task_id, idx)``.
    """

    __tablename__ = "checkpoint_writes"
    __table_args__ = (Index("checkpoint_writes_thread_id_idx", "thread_id"),)

    # Composite primary key.
    thread_id = Column(Text, primary_key=True, nullable=False)
    checkpoint_ns = Column(Text, primary_key=True, nullable=False, server_default="")
    checkpoint_id = Column(Text, primary_key=True, nullable=False)
    task_id = Column(Text, primary_key=True, nullable=False)
    idx = Column(Integer, primary_key=True, nullable=False)

    # Payload columns; unlike ``checkpoint_blobs``, ``blob`` is NOT NULL here.
    channel = Column(Text, nullable=False)
    type = Column(Text, nullable=True)
    blob = Column(LargeBinary, nullable=False)
    task_path = Column(Text, server_default="", nullable=False)

agentex/src/api/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
agent_api_keys,
2020
agent_task_tracker,
2121
agents,
22+
checkpoints,
2223
deployment_history,
2324
events,
2425
messages,
@@ -183,6 +184,7 @@ async def handle_unexpected(request, exc):
183184
fastapi_app.include_router(agent_api_keys.router)
184185
fastapi_app.include_router(deployment_history.router)
185186
fastapi_app.include_router(schedules.router)
187+
fastapi_app.include_router(checkpoints.router)
186188

187189
# Wrap FastAPI app with health check interceptor for sub-millisecond K8s probe responses.
188190
# This must be the outermost layer to bypass all middleware.
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import base64
2+
3+
from fastapi import APIRouter, Response
4+
5+
from src.api.schemas.checkpoints import (
6+
BlobResponse,
7+
CheckpointListItem,
8+
CheckpointTupleResponse,
9+
DeleteThreadRequest,
10+
GetCheckpointTupleRequest,
11+
ListCheckpointsRequest,
12+
PutCheckpointRequest,
13+
PutCheckpointResponse,
14+
PutWritesRequest,
15+
WriteResponse,
16+
)
17+
from src.api.schemas.authorization_types import (
18+
AgentexResourceType,
19+
AuthorizedOperationType,
20+
)
21+
from src.domain.use_cases.checkpoints_use_case import DCheckpointsUseCase
22+
from src.utils.authorization_shortcuts import DAuthorizedBodyId
23+
from src.utils.logging import make_logger
24+
25+
logger = make_logger(__name__)
26+
27+
router = APIRouter(prefix="/checkpoints", tags=["Checkpoints"])
28+
29+
30+
def _bytes_to_b64(data: bytes | None) -> str | None:
31+
if data is None:
32+
return None
33+
return base64.b64encode(data).decode("ascii")
34+
35+
36+
def _b64_to_bytes(data: str | None) -> bytes | None:
37+
if data is None:
38+
return None
39+
return base64.b64decode(data)
40+
41+
42+
@router.post("/get-tuple", response_model=CheckpointTupleResponse | None)
async def get_checkpoint_tuple(
    request: GetCheckpointTupleRequest,
    checkpoints_use_case: DCheckpointsUseCase,
    _authorized_task_id: DAuthorizedBodyId(
        AgentexResourceType.task, AuthorizedOperationType.read, field_name="thread_id"
    ),
) -> CheckpointTupleResponse | None:
    """Fetch one checkpoint with its blobs and pending writes.

    Delegates to ``checkpoints_use_case.get_tuple`` and returns ``None`` when
    the use case returns ``None``. Binary ``blob`` values are base64-encoded
    for JSON transport. ``_authorized_task_id`` is declared for the ``task``
    resource with the ``read`` operation on the body field ``thread_id`` and
    is not otherwise used in the body.
    """
    found = await checkpoints_use_case.get_tuple(
        thread_id=request.thread_id,
        checkpoint_ns=request.checkpoint_ns,
        checkpoint_id=request.checkpoint_id,
    )
    if found is None:
        return None

    # Scalar/JSON fields are copied through unchanged.
    passthrough_keys = (
        "thread_id",
        "checkpoint_ns",
        "checkpoint_id",
        "parent_checkpoint_id",
        "checkpoint",
        "metadata",
    )
    passthrough = {key: found[key] for key in passthrough_keys}

    # "blobs" and "pending_writes" may be absent from the result; treat as empty.
    blob_items = [
        BlobResponse(
            channel=row["channel"],
            version=row["version"],
            type=row["type"],
            blob=_bytes_to_b64(row["blob"]),
        )
        for row in found.get("blobs", [])
    ]
    write_items = [
        WriteResponse(
            task_id=row["task_id"],
            idx=row["idx"],
            channel=row["channel"],
            type=row["type"],
            blob=_bytes_to_b64(row["blob"]),
        )
        for row in found.get("pending_writes", [])
    ]

    return CheckpointTupleResponse(
        **passthrough,
        blobs=blob_items,
        pending_writes=write_items,
    )
88+
89+
90+
@router.post("/put", response_model=PutCheckpointResponse)
async def put_checkpoint(
    request: PutCheckpointRequest,
    checkpoints_use_case: DCheckpointsUseCase,
    _authorized_task_id: DAuthorizedBodyId(
        AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id"
    ),
) -> PutCheckpointResponse:
    """Store a checkpoint together with its channel blobs.

    Base64 ``blob`` strings from the request are decoded to bytes before being
    handed to ``checkpoints_use_case.put``. The response echoes the
    ``thread_id``, ``checkpoint_ns`` and ``checkpoint_id`` from the request.
    ``_authorized_task_id`` is declared for the ``task`` resource with the
    ``execute`` operation on the body field ``thread_id``.
    """
    decoded_blobs = []
    for item in request.blobs:
        decoded_blobs.append(
            {
                "channel": item.channel,
                "version": item.version,
                "type": item.type,
                "blob": _b64_to_bytes(item.blob),
            }
        )

    thread_id = request.thread_id
    checkpoint_ns = request.checkpoint_ns
    checkpoint_id = request.checkpoint_id

    await checkpoints_use_case.put(
        thread_id=thread_id,
        checkpoint_ns=checkpoint_ns,
        checkpoint_id=checkpoint_id,
        parent_checkpoint_id=request.parent_checkpoint_id,
        checkpoint=request.checkpoint,
        metadata=request.metadata,
        blobs=decoded_blobs,
    )

    return PutCheckpointResponse(
        thread_id=thread_id,
        checkpoint_ns=checkpoint_ns,
        checkpoint_id=checkpoint_id,
    )
126+
127+
128+
@router.post("/put-writes", status_code=204)
async def put_writes(
    request: PutWritesRequest,
    checkpoints_use_case: DCheckpointsUseCase,
    _authorized_task_id: DAuthorizedBodyId(
        AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id"
    ),
) -> Response:
    """Store pending writes for a checkpoint and reply with 204 No Content.

    Each write's base64 ``blob`` is decoded to bytes before being passed to
    ``checkpoints_use_case.put_writes``; the request's ``upsert`` flag is
    forwarded unchanged. ``_authorized_task_id`` is declared for the ``task``
    resource with the ``execute`` operation on the body field ``thread_id``.
    """
    decoded_writes = []
    for item in request.writes:
        decoded_writes.append(
            {
                "task_id": item.task_id,
                "idx": item.idx,
                "channel": item.channel,
                "type": item.type,
                "blob": _b64_to_bytes(item.blob),
                "task_path": item.task_path,
            }
        )

    await checkpoints_use_case.put_writes(
        thread_id=request.thread_id,
        checkpoint_ns=request.checkpoint_ns,
        checkpoint_id=request.checkpoint_id,
        writes=decoded_writes,
        upsert=request.upsert,
    )

    no_content = Response(status_code=204)
    return no_content
160+
161+
162+
@router.post("/list", response_model=list[CheckpointListItem])
async def list_checkpoints(
    request: ListCheckpointsRequest,
    checkpoints_use_case: DCheckpointsUseCase,
    _authorized_task_id: DAuthorizedBodyId(
        AgentexResourceType.task, AuthorizedOperationType.read, field_name="thread_id"
    ),
) -> list[CheckpointListItem]:
    """List checkpoints for a thread.

    Forwards ``thread_id``, ``checkpoint_ns``, ``before_checkpoint_id``,
    ``filter_metadata`` and ``limit`` to
    ``checkpoints_use_case.list_checkpoints`` and converts each returned
    mapping into a ``CheckpointListItem``. ``_authorized_task_id`` is declared
    for the ``task`` resource with the ``read`` operation on the body field
    ``thread_id``.
    """
    rows = await checkpoints_use_case.list_checkpoints(
        thread_id=request.thread_id,
        checkpoint_ns=request.checkpoint_ns,
        before_checkpoint_id=request.before_checkpoint_id,
        filter_metadata=request.filter_metadata,
        limit=request.limit,
    )

    # Fields copied from each result mapping into the response item.
    item_keys = (
        "thread_id",
        "checkpoint_ns",
        "checkpoint_id",
        "parent_checkpoint_id",
        "checkpoint",
        "metadata",
    )

    items = []
    for row in rows:
        items.append(CheckpointListItem(**{key: row[key] for key in item_keys}))
    return items
192+
193+
194+
@router.post("/delete-thread", status_code=204)
async def delete_thread(
    request: DeleteThreadRequest,
    checkpoints_use_case: DCheckpointsUseCase,
    _authorized_task_id: DAuthorizedBodyId(
        AgentexResourceType.task, AuthorizedOperationType.delete, field_name="thread_id"
    ),
) -> Response:
    """Delete checkpoint data for a thread and reply with 204 No Content.

    Delegates to ``checkpoints_use_case.delete_thread`` with the request's
    ``thread_id``. ``_authorized_task_id`` is declared for the ``task``
    resource with the ``delete`` operation on the body field ``thread_id``.
    """
    target_thread = request.thread_id
    await checkpoints_use_case.delete_thread(thread_id=target_thread)

    no_content = Response(status_code=204)
    return no_content

0 commit comments

Comments
 (0)