Skip to content

Commit dd5abe4

Browse files
danielmillerp and claude committed
feat: add HTTP-proxy LangGraph checkpoint API
Agents no longer need a direct Postgres connection for LangGraph checkpointing. Instead, checkpoint operations are proxied through 5 new backend endpoints under /checkpoints (get-tuple, put, put-writes, list, delete-thread). Binary blob data is base64-encoded for JSON transport. Includes ORM models for the 4 checkpoint tables, Alembic migration, repository with composite-PK queries, use case layer, Pydantic schemas, and FastAPI routes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f436743 commit dd5abe4

9 files changed

Lines changed: 1514 additions & 0 deletions

File tree

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""add_langgraph_checkpoint_tables
2+
3+
Revision ID: a1b2c3d4e5f6
4+
Revises: d024851e790c
5+
Create Date: 2026-02-07 00:00:00.000000
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
from sqlalchemy.dialects import postgresql
13+
14+
15+
# Alembic revision identifiers.
# `revision` names this migration; `down_revision` names the migration it
# is applied on top of.
revision: str = "a1b2c3d4e5f6"
down_revision: Union[str, None] = "d024851e790c"

# This migration declares no branch labels and no extra dependencies.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
20+
21+
22+
def upgrade() -> None:
    """Create the LangGraph checkpoint tables and seed ``checkpoint_migrations``.

    Tables are created in this order: ``checkpoint_migrations``,
    ``checkpoints``, ``checkpoint_blobs`` and ``checkpoint_writes``. Each of
    the last three also gets a non-unique index on ``thread_id``. Finally,
    versions 0 through 9 are inserted into ``checkpoint_migrations``.
    """

    def scope_columns():
        # The three data tables all start with the same two columns; build
        # fresh Column objects for every table.
        return (
            sa.Column("thread_id", sa.Text(), nullable=False),
            sa.Column("checkpoint_ns", sa.Text(), server_default="", nullable=False),
        )

    def jsonb():
        return postgresql.JSONB(astext_type=sa.Text())

    # checkpoint_migrations
    op.create_table(
        "checkpoint_migrations",
        sa.Column("v", sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint("v"),
    )

    # checkpoints
    op.create_table(
        "checkpoints",
        *scope_columns(),
        sa.Column("checkpoint_id", sa.Text(), nullable=False),
        sa.Column("parent_checkpoint_id", sa.Text(), nullable=True),
        sa.Column("type", sa.Text(), nullable=True),
        sa.Column("checkpoint", jsonb(), nullable=False),
        sa.Column("metadata", jsonb(), server_default="{}", nullable=False),
        sa.PrimaryKeyConstraint("thread_id", "checkpoint_ns", "checkpoint_id"),
    )
    op.create_index(
        "checkpoints_thread_id_idx",
        "checkpoints",
        ["thread_id"],
        unique=False,
    )

    # checkpoint_blobs
    op.create_table(
        "checkpoint_blobs",
        *scope_columns(),
        sa.Column("channel", sa.Text(), nullable=False),
        sa.Column("version", sa.Text(), nullable=False),
        sa.Column("type", sa.Text(), nullable=False),
        sa.Column("blob", sa.LargeBinary(), nullable=True),
        sa.PrimaryKeyConstraint("thread_id", "checkpoint_ns", "channel", "version"),
    )
    op.create_index(
        "checkpoint_blobs_thread_id_idx",
        "checkpoint_blobs",
        ["thread_id"],
        unique=False,
    )

    # checkpoint_writes
    op.create_table(
        "checkpoint_writes",
        *scope_columns(),
        sa.Column("checkpoint_id", sa.Text(), nullable=False),
        sa.Column("task_id", sa.Text(), nullable=False),
        sa.Column("idx", sa.Integer(), nullable=False),
        sa.Column("channel", sa.Text(), nullable=False),
        sa.Column("type", sa.Text(), nullable=True),
        sa.Column("blob", sa.LargeBinary(), nullable=False),
        sa.Column("task_path", sa.Text(), server_default="", nullable=False),
        sa.PrimaryKeyConstraint(
            "thread_id", "checkpoint_ns", "checkpoint_id", "task_id", "idx"
        ),
    )
    op.create_index(
        "checkpoint_writes_thread_id_idx",
        "checkpoint_writes",
        ["thread_id"],
        unique=False,
    )

    # Pre-populate checkpoint_migrations so LangGraph sees all its
    # internal migrations as already applied and skips setup().
    seed_statement = sa.text(
        "INSERT INTO checkpoint_migrations (v) VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)"
    )
    op.execute(seed_statement)
80+
81+
82+
def downgrade() -> None:
    """Drop the checkpoint tables and their ``thread_id`` indexes.

    The three data tables are removed in the order ``checkpoint_writes``,
    ``checkpoint_blobs``, ``checkpoints`` (index first, then table), and
    ``checkpoint_migrations`` is dropped last.
    """
    for table_name in ("checkpoint_writes", "checkpoint_blobs", "checkpoints"):
        # Index names follow the "<table>_thread_id_idx" pattern.
        op.drop_index(f"{table_name}_thread_id_idx", table_name=table_name)
        op.drop_table(table_name)

    op.drop_table("checkpoint_migrations")

agentex/src/adapters/orm.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
DateTime,
66
ForeignKey,
77
Index,
8+
Integer,
9+
LargeBinary,
810
String,
911
Text,
1012
func,
@@ -213,3 +215,56 @@ class DeploymentHistoryORM(BaseORM):
213215
"commit_hash",
214216
),
215217
)
218+
219+
220+
# LangGraph checkpoint tables
221+
# These mirror the schema from langgraph.checkpoint.postgres so that
222+
# tables are created via Alembic migrations rather than at agent runtime.
223+
224+
225+
class CheckpointMigrationORM(BaseORM):
    """ORM mapping for ``checkpoint_migrations``: one integer key column ``v``."""

    __tablename__ = "checkpoint_migrations"

    v = Column(Integer, primary_key=True)
228+
229+
230+
class CheckpointORM(BaseORM):
    """ORM mapping for the ``checkpoints`` table.

    Rows are keyed by the composite primary key
    ``(thread_id, checkpoint_ns, checkpoint_id)``.
    """

    __tablename__ = "checkpoints"
    __table_args__ = (Index("checkpoints_thread_id_idx", "thread_id"),)

    # Composite primary key.
    thread_id = Column(Text, primary_key=True, nullable=False)
    checkpoint_ns = Column(Text, primary_key=True, nullable=False, server_default="")
    checkpoint_id = Column(Text, primary_key=True, nullable=False)

    # Payload columns.
    parent_checkpoint_id = Column(Text, nullable=True)
    type = Column(Text, nullable=True)
    checkpoint = Column(JSONB, nullable=False)
    # The attribute is `metadata_`, while the database column is named "metadata".
    metadata_ = Column("metadata", JSONB, nullable=False, server_default="{}")
242+
243+
244+
class CheckpointBlobORM(BaseORM):
    """ORM mapping for the ``checkpoint_blobs`` table.

    Rows are keyed by the composite primary key
    ``(thread_id, checkpoint_ns, channel, version)``.
    """

    __tablename__ = "checkpoint_blobs"
    __table_args__ = (Index("checkpoint_blobs_thread_id_idx", "thread_id"),)

    # Composite primary key.
    thread_id = Column(Text, primary_key=True, nullable=False)
    checkpoint_ns = Column(Text, primary_key=True, nullable=False, server_default="")
    channel = Column(Text, primary_key=True, nullable=False)
    version = Column(Text, primary_key=True, nullable=False)

    # Payload columns; `blob` is the only nullable one.
    type = Column(Text, nullable=False)
    blob = Column(LargeBinary, nullable=True)
255+
256+
257+
class CheckpointWriteORM(BaseORM):
    """ORM mapping for the ``checkpoint_writes`` table.

    Rows are keyed by the composite primary key
    ``(thread_id, checkpoint_ns, checkpoint_id, task_id, idx)``.
    """

    __tablename__ = "checkpoint_writes"
    __table_args__ = (Index("checkpoint_writes_thread_id_idx", "thread_id"),)

    # Composite primary key.
    thread_id = Column(Text, primary_key=True, nullable=False)
    checkpoint_ns = Column(Text, primary_key=True, nullable=False, server_default="")
    checkpoint_id = Column(Text, primary_key=True, nullable=False)
    task_id = Column(Text, primary_key=True, nullable=False)
    idx = Column(Integer, primary_key=True, nullable=False)

    # Payload columns.
    channel = Column(Text, nullable=False)
    type = Column(Text, nullable=True)
    blob = Column(LargeBinary, nullable=False)
    task_path = Column(Text, nullable=False, server_default="")

agentex/src/api/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
agent_api_keys,
2020
agent_task_tracker,
2121
agents,
22+
checkpoints,
2223
deployment_history,
2324
events,
2425
messages,
@@ -183,6 +184,7 @@ async def handle_unexpected(request, exc):
183184
fastapi_app.include_router(agent_api_keys.router)
184185
fastapi_app.include_router(deployment_history.router)
185186
fastapi_app.include_router(schedules.router)
187+
fastapi_app.include_router(checkpoints.router)
186188

187189
# Wrap FastAPI app with health check interceptor for sub-millisecond K8s probe responses.
188190
# This must be the outermost layer to bypass all middleware.
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import base64
2+
3+
from fastapi import APIRouter, Response
4+
5+
from src.api.schemas.checkpoints import (
6+
BlobResponse,
7+
CheckpointListItem,
8+
CheckpointTupleResponse,
9+
DeleteThreadRequest,
10+
GetCheckpointTupleRequest,
11+
ListCheckpointsRequest,
12+
PutCheckpointRequest,
13+
PutCheckpointResponse,
14+
PutWritesRequest,
15+
WriteResponse,
16+
)
17+
from src.domain.use_cases.checkpoints_use_case import DCheckpointsUseCase
18+
from src.utils.logging import make_logger
19+
20+
# Module-level logger, named after this module.
logger = make_logger(__name__)

# Every endpoint registered on this router lives under the /checkpoints prefix.
router = APIRouter(tags=["Checkpoints"], prefix="/checkpoints")
23+
24+
25+
def _bytes_to_b64(data: bytes | None) -> str | None:
26+
if data is None:
27+
return None
28+
return base64.b64encode(data).decode("ascii")
29+
30+
31+
def _b64_to_bytes(data: str | None) -> bytes | None:
32+
if data is None:
33+
return None
34+
return base64.b64decode(data)
35+
36+
37+
@router.post("/get-tuple", response_model=CheckpointTupleResponse | None)
async def get_checkpoint_tuple(
    request: GetCheckpointTupleRequest,
    checkpoints_use_case: DCheckpointsUseCase,
) -> CheckpointTupleResponse | None:
    """Look up one checkpoint tuple and return it in JSON-safe form.

    The lookup is delegated to ``checkpoints_use_case.get_tuple``. When that
    returns ``None`` this endpoint returns ``None`` as well. Otherwise the
    result is copied into a ``CheckpointTupleResponse``, with every blob and
    pending-write payload passed through ``_bytes_to_b64``.
    """
    record = await checkpoints_use_case.get_tuple(
        thread_id=request.thread_id,
        checkpoint_ns=request.checkpoint_ns,
        checkpoint_id=request.checkpoint_id,
    )
    if record is None:
        return None

    # Scalar fields are copied through unchanged.
    scalar_fields = {
        key: record[key]
        for key in (
            "thread_id",
            "checkpoint_ns",
            "checkpoint_id",
            "parent_checkpoint_id",
            "checkpoint",
            "metadata",
        )
    }

    blob_items = []
    for entry in record.get("blobs", []):
        blob_items.append(
            BlobResponse(
                channel=entry["channel"],
                version=entry["version"],
                type=entry["type"],
                blob=_bytes_to_b64(entry["blob"]),
            )
        )

    write_items = []
    for entry in record.get("pending_writes", []):
        write_items.append(
            WriteResponse(
                task_id=entry["task_id"],
                idx=entry["idx"],
                channel=entry["channel"],
                type=entry["type"],
                blob=_bytes_to_b64(entry["blob"]),
            )
        )

    return CheckpointTupleResponse(
        **scalar_fields,
        blobs=blob_items,
        pending_writes=write_items,
    )
80+
81+
82+
@router.post("/put", response_model=PutCheckpointResponse)
async def put_checkpoint(
    request: PutCheckpointRequest,
    checkpoints_use_case: DCheckpointsUseCase,
) -> PutCheckpointResponse:
    """Store a checkpoint together with its blobs.

    Each blob payload in the request is converted with ``_b64_to_bytes``
    before everything is handed to ``checkpoints_use_case.put``. The response
    echoes the thread id, namespace and checkpoint id from the request.
    """
    decoded_blobs = []
    for item in request.blobs:
        decoded_blobs.append(
            {
                "channel": item.channel,
                "version": item.version,
                "type": item.type,
                "blob": _b64_to_bytes(item.blob),
            }
        )

    await checkpoints_use_case.put(
        thread_id=request.thread_id,
        checkpoint_ns=request.checkpoint_ns,
        checkpoint_id=request.checkpoint_id,
        parent_checkpoint_id=request.parent_checkpoint_id,
        checkpoint=request.checkpoint,
        metadata=request.metadata,
        blobs=decoded_blobs,
    )

    response = PutCheckpointResponse(
        thread_id=request.thread_id,
        checkpoint_ns=request.checkpoint_ns,
        checkpoint_id=request.checkpoint_id,
    )
    return response
115+
116+
117+
@router.post("/put-writes", status_code=204)
async def put_writes(
    request: PutWritesRequest,
    checkpoints_use_case: DCheckpointsUseCase,
) -> Response:
    """Store the writes attached to a checkpoint.

    Each write payload in the request is converted with ``_b64_to_bytes``
    before being passed, along with the request's ``upsert`` flag, to
    ``checkpoints_use_case.put_writes``. Responds with an empty 204.
    """
    decoded_writes = []
    for item in request.writes:
        decoded_writes.append(
            {
                "task_id": item.task_id,
                "idx": item.idx,
                "channel": item.channel,
                "type": item.type,
                "blob": _b64_to_bytes(item.blob),
                "task_path": item.task_path,
            }
        )

    await checkpoints_use_case.put_writes(
        thread_id=request.thread_id,
        checkpoint_ns=request.checkpoint_ns,
        checkpoint_id=request.checkpoint_id,
        writes=decoded_writes,
        upsert=request.upsert,
    )

    return Response(status_code=204)
146+
147+
148+
@router.post("/list", response_model=list[CheckpointListItem])
async def list_checkpoints(
    request: ListCheckpointsRequest,
    checkpoints_use_case: DCheckpointsUseCase,
) -> list[CheckpointListItem]:
    """List checkpoints matching the request.

    The thread id, namespace, ``before_checkpoint_id``, metadata filter and
    limit from the request are forwarded to
    ``checkpoints_use_case.list_checkpoints``. Each returned row is converted
    to a ``CheckpointListItem``.
    """
    rows = await checkpoints_use_case.list_checkpoints(
        thread_id=request.thread_id,
        checkpoint_ns=request.checkpoint_ns,
        before_checkpoint_id=request.before_checkpoint_id,
        filter_metadata=request.filter_metadata,
        limit=request.limit,
    )

    item_fields = (
        "thread_id",
        "checkpoint_ns",
        "checkpoint_id",
        "parent_checkpoint_id",
        "checkpoint",
        "metadata",
    )

    items = []
    for row in rows:
        items.append(CheckpointListItem(**{key: row[key] for key in item_fields}))
    return items
175+
176+
177+
@router.post("/delete-thread", status_code=204)
async def delete_thread(
    request: DeleteThreadRequest,
    checkpoints_use_case: DCheckpointsUseCase,
) -> Response:
    """Delete a thread via the use case and respond with an empty 204.

    The request's ``thread_id`` is passed to
    ``checkpoints_use_case.delete_thread``.
    """
    target_thread_id = request.thread_id
    await checkpoints_use_case.delete_thread(thread_id=target_thread_id)
    return Response(status_code=204)

0 commit comments

Comments
 (0)