Skip to content

Commit e0d141d

Browse files
committed
feat: Add include_sql to Search Pipeline Run API
1 parent a9adba7 commit e0d141d

2 files changed

Lines changed: 128 additions & 10 deletions

File tree

cloud_pipelines_backend/api_server_sql.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class GetPipelineRunResponse(PipelineRunResponse):
6464
class ListPipelineJobsResponse:
6565
pipeline_runs: list[PipelineRunResponse]
6666
next_page_token: str | None = None
67+
sql: str | None = None
6768

6869

6970
class PipelineRunsApiService_Sql:
@@ -175,6 +176,36 @@ def terminate(
175176
execution_node.extra_data["desired_state"] = "TERMINATED"
176177
session.commit()
177178

179+
@staticmethod
def _compile_sql_string(
    stmt: sql.Select,
    dialect: sql.engine.Dialect,
) -> str:
    """Render a SQLAlchemy statement as SQL text for debugging.

    The statement is first compiled for ``dialect`` with
    ``literal_binds=True`` so that bound parameters appear inline as
    literal values, giving a self-contained query string::

        SELECT ... WHERE key = 'environment' AND created_at < '2024-01-15' LIMIT 10

    When that compilation raises ``CompileError`` or
    ``NotImplementedError``, the statement is compiled again without
    literal binds. The result then keeps placeholder syntax, and the
    parameter mapping (if non-empty) is appended on its own line as a
    SQL comment::

        SELECT ... WHERE key = :key_1 AND created_at < :created_at_1 LIMIT :param_1
        -- params: {'key_1': 'environment', 'created_at_1': '2024-01-15', 'param_1': 10}

    Args:
        stmt: The select statement to render.
        dialect: The dialect used for compilation.

    Returns:
        The rendered SQL string.
    """
    try:
        # Preferred form: every bound value inlined into the SQL text.
        return str(
            stmt.compile(
                dialect=dialect,
                compile_kwargs={"literal_binds": True},
            )
        )
    except (sql.exc.CompileError, NotImplementedError):
        # Fallback form: placeholders plus a trailing params comment.
        placeholder_form = stmt.compile(dialect=dialect)
        bound_values = placeholder_form.params
        rendered = str(placeholder_form)
        if bound_values:
            rendered += f"\n-- params: {bound_values}"
        return rendered
208+
178209
# Note: This method must be last to not shadow the "list" type
179210
def list(
180211
self,
@@ -186,6 +217,7 @@ def list(
186217
current_user: str | None = None,
187218
include_pipeline_names: bool = False,
188219
include_execution_stats: bool = False,
220+
include_sql: bool = False,
189221
) -> ListPipelineJobsResponse:
190222
where_clauses = filter_query_sql.build_list_filters(
191223
filter_value=filter,
@@ -194,18 +226,22 @@ def list(
194226
current_user=current_user,
195227
)
196228

197-
pipeline_runs = list(
198-
session.scalars(
199-
sql.select(bts.PipelineRun)
200-
.where(*where_clauses)
201-
.order_by(
202-
bts.PipelineRun.created_at.desc(),
203-
bts.PipelineRun.id.desc(),
204-
)
205-
.limit(self._DEFAULT_PAGE_SIZE)
206-
).all()
229+
stmt = (
230+
sql.select(bts.PipelineRun)
231+
.where(*where_clauses)
232+
.order_by(
233+
bts.PipelineRun.created_at.desc(),
234+
bts.PipelineRun.id.desc(),
235+
)
236+
.limit(self._DEFAULT_PAGE_SIZE)
207237
)
208238

239+
sql_string = None
240+
if include_sql:
241+
sql_string = self._compile_sql_string(stmt, session.bind.dialect)
242+
243+
pipeline_runs = list(session.scalars(stmt).all())
244+
209245
next_page_token = filter_query_sql.maybe_next_page_token(
210246
rows=pipeline_runs, page_size=self._DEFAULT_PAGE_SIZE
211247
)
@@ -221,6 +257,7 @@ def list(
221257
for pipeline_run in pipeline_runs
222258
],
223259
next_page_token=next_page_token,
260+
sql=sql_string,
224261
)
225262

226263
def _create_pipeline_run_response(

tests/test_api_server_sql.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,87 @@ def test_list_filter_created_by_me(self, session_factory, service):
295295
assert len(result.pipeline_runs) == 1
296296
assert result.pipeline_runs[0].created_by == "alice@example.com"
297297

298+
def test_list_include_sql_default_none(self, session_factory, service):
    """Listing without ``include_sql`` leaves ``sql`` unset on the response."""
    _create_run(session_factory, service, root_task=_make_task_spec())

    with session_factory() as session:
        response = service.list(session=session)
        # include_sql was not passed, so no SQL text is expected.
        assert response.sql is None
304+
305+
def test_list_include_sql_true(self, session_factory, service):
    """Listing with ``include_sql=True`` and no filters returns the exact query text."""
    _create_run(session_factory, service, root_task=_make_task_spec())

    # Unfiltered query: column list, ORDER BY and the page-size LIMIT only.
    expected_sql = (
        "SELECT pipeline_run.id, pipeline_run.root_execution_id, "
        "pipeline_run.annotations, pipeline_run.created_by, "
        "pipeline_run.created_at, pipeline_run.updated_at, "
        "pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
        "FROM pipeline_run "
        "ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
        " LIMIT 10 OFFSET 0"
    )

    with session_factory() as session:
        response = service.list(session=session, include_sql=True)
        assert response.sql == expected_sql
320+
321+
def test_list_include_sql_with_filter_query(self, session_factory, service):
    """A ``key_exists`` filter query shows up as an EXISTS subquery in the SQL text."""
    created_run = _create_run(
        session_factory, service, root_task=_make_task_spec()
    )
    with session_factory() as session:
        service.set_annotation(
            session=session, id=created_run.id, key="team", value="ml"
        )

    key_exists_clause = {"key_exists": {"key": "team"}}
    filter_json = json.dumps({"and": [key_exists_clause]})

    # The annotation key is expected inline as the literal 'team'.
    expected_sql = (
        "SELECT pipeline_run.id, pipeline_run.root_execution_id, "
        "pipeline_run.annotations, pipeline_run.created_by, "
        "pipeline_run.created_at, pipeline_run.updated_at, "
        "pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
        "FROM pipeline_run \n"
        "WHERE EXISTS (SELECT pipeline_run_annotation.pipeline_run_id \n"
        "FROM pipeline_run_annotation \n"
        "WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id "
        "AND pipeline_run_annotation.\"key\" = 'team') "
        "ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
        " LIMIT 10 OFFSET 0"
    )

    with session_factory() as session:
        response = service.list(
            session=session,
            filter_query=filter_json,
            include_sql=True,
        )
        assert response.sql == expected_sql
343+
344+
def test_list_include_sql_with_cursor(self, session_factory, service):
    """The second page's SQL text contains the cursor comparison from the page token.

    Twelve runs are created so the first page yields a ``next_page_token``.
    The token is split on ``~`` into an ISO timestamp and a run id, and both
    are expected inline in the WHERE clause of the second page's query.
    """
    for index in range(12):
        task_spec = _make_task_spec(f"pipeline-{index}")
        _create_run(session_factory, service, root_task=task_spec)

    with session_factory() as session:
        first_page = service.list(session=session)
    assert first_page.next_page_token is not None

    with session_factory() as session:
        second_page = service.list(
            session=session,
            page_token=first_page.next_page_token,
            include_sql=True,
        )

    # Token layout used here: "<ISO timestamp>~<run id>".
    iso_timestamp, run_id = first_page.next_page_token.split("~")
    timestamp_literal = datetime.datetime.fromisoformat(iso_timestamp).strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )

    expected_sql = (
        "SELECT pipeline_run.id, pipeline_run.root_execution_id, "
        "pipeline_run.annotations, pipeline_run.created_by, "
        "pipeline_run.created_at, pipeline_run.updated_at, "
        "pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
        "FROM pipeline_run \n"
        "WHERE (pipeline_run.created_at, pipeline_run.id) "
        f"< ('{timestamp_literal}', '{run_id}') "
        "ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
        " LIMIT 10 OFFSET 0"
    )
    assert second_page.sql == expected_sql
378+
298379

299380
class TestCreatePipelineRunResponse:
300381
def test_base_response(self, session_factory, service):

0 commit comments

Comments
 (0)