Skip to content

Commit 962d30e

Browse files
committed
refactor: Search pipeline run using Pydantic query parameter
1 parent 546f75f commit 962d30e

2 files changed

Lines changed: 618 additions & 71 deletions

File tree

cloud_pipelines_backend/api_server_sql.py

Lines changed: 121 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import json
55
import logging
66
import typing
7-
from typing import Any, Optional
7+
from typing import Annotated, Any, Final, Optional
8+
9+
from fastapi import Query
10+
from pydantic import BaseModel
811

912
if typing.TYPE_CHECKING:
1013
from cloud_pipelines.orchestration.storage_providers import (
@@ -31,6 +34,9 @@ def _get_current_time() -> datetime.datetime:
3134
from . import errors
3235
from .errors import ItemNotFoundError
3336

37+
_PAGE_TOKEN_OFFSET_KEY: Final[str] = "offset"
38+
_PAGE_TOKEN_FILTER_KEY: Final[str] = "filter"
39+
3440

3541
# ==== PipelineJobService
3642
@dataclasses.dataclass(kw_only=True)
@@ -65,6 +71,13 @@ class ListPipelineJobsResponse:
6571
next_page_token: str | None = None
6672

6773

74+
class ListPipelineRunsParams(BaseModel):
    """Query parameters accepted by the pipeline-run list endpoint.

    Used as a FastAPI ``Annotated[..., Query()]`` parameter model (see the
    ``list`` method), so these fields are parsed from the request query string.
    """

    # Search filter string (e.g. "created_by:me"); None means no filtering.
    filter: str | None = None
    # Opaque pagination token taken from a previous response's next_page_token.
    page_token: str | None = None
    # When True, resolve and include each run's pipeline name in the response.
    include_pipeline_names: bool = False
    # When True, include per-status execution counts for each run.
    include_execution_stats: bool = False
79+
80+
6881
import sqlalchemy as sql
6982
from sqlalchemy import orm
7083

@@ -165,45 +178,19 @@ def list(
165178
self,
166179
*,
167180
session: orm.Session,
168-
page_token: str | None = None,
169-
# page_size: int = 10,
170-
filter: str | None = None,
171181
current_user: str | None = None,
172-
include_pipeline_names: bool = False,
173-
include_execution_stats: bool = False,
182+
params: Annotated[ListPipelineRunsParams, Query()],
174183
) -> ListPipelineJobsResponse:
175-
page_token_dict = _decode_page_token(page_token)
176-
OFFSET_KEY = "offset"
177-
offset = page_token_dict.get(OFFSET_KEY, 0)
178-
page_size = 10
184+
filter_value, offset = _resolve_filter_value(
185+
filter=params.filter,
186+
page_token=params.page_token,
187+
)
188+
where_clauses, next_page_filter_value = _build_filter_where_clauses(
189+
filter_value=filter_value,
190+
current_user=current_user,
191+
)
179192

180-
FILTER_KEY = "filter"
181-
if page_token:
182-
filter = page_token_dict.get(FILTER_KEY, None)
183-
where_clauses = []
184-
parsed_filter = _parse_filter(filter) if filter else {}
185-
for key, value in parsed_filter.items():
186-
if key == "_text":
187-
raise NotImplementedError("Text search is not implemented yet.")
188-
elif key == "created_by":
189-
if value == "me":
190-
if current_user is None:
191-
# raise ApiServiceError(
192-
# f"The `created_by:me` filter requires `current_user`."
193-
# )
194-
current_user = ""
195-
value = current_user
196-
# TODO: Maybe make this a bit more robust.
197-
# We need to change the filter since it goes into the next_page_token.
198-
filter = filter.replace(
199-
"created_by:me", f"created_by:{current_user}"
200-
)
201-
if value:
202-
where_clauses.append(bts.PipelineRun.created_by == value)
203-
else:
204-
where_clauses.append(bts.PipelineRun.created_by == None)
205-
else:
206-
raise NotImplementedError(f"Unsupported filter {filter}.")
193+
page_size = 10
207194
pipeline_runs = list(
208195
session.scalars(
209196
sql.select(bts.PipelineRun)
@@ -214,50 +201,62 @@ def list(
214201
).all()
215202
)
216203
next_page_offset = offset + page_size
217-
next_page_token_dict = {OFFSET_KEY: next_page_offset, FILTER_KEY: filter}
204+
next_page_token_dict = {
205+
_PAGE_TOKEN_OFFSET_KEY: next_page_offset,
206+
_PAGE_TOKEN_FILTER_KEY: next_page_filter_value,
207+
}
218208
next_page_token = _encode_page_token(next_page_token_dict)
219209
if len(pipeline_runs) < page_size:
220210
next_page_token = None
221211

222-
def create_pipeline_run_response(
223-
pipeline_run: bts.PipelineRun,
224-
) -> PipelineRunResponse:
225-
response = PipelineRunResponse.from_db(pipeline_run)
226-
if include_pipeline_names:
227-
pipeline_name = None
228-
extra_data = pipeline_run.extra_data or {}
229-
if self.PIPELINE_NAME_EXTRA_DATA_KEY in extra_data:
230-
pipeline_name = extra_data[self.PIPELINE_NAME_EXTRA_DATA_KEY]
231-
else:
232-
execution_node = session.get(
233-
bts.ExecutionNode, pipeline_run.root_execution_id
234-
)
235-
if execution_node:
236-
task_spec = structures.TaskSpec.from_json_dict(
237-
execution_node.task_spec
238-
)
239-
component_spec = task_spec.component_ref.spec
240-
if component_spec:
241-
pipeline_name = component_spec.name
242-
response.pipeline_name = pipeline_name
243-
if include_execution_stats:
244-
execution_status_stats = self._calculate_execution_status_stats(
245-
session=session, root_execution_id=pipeline_run.root_execution_id
246-
)
247-
response.execution_status_stats = {
248-
status.value: count
249-
for status, count in execution_status_stats.items()
250-
}
251-
return response
252-
253212
return ListPipelineJobsResponse(
254213
pipeline_runs=[
255-
create_pipeline_run_response(pipeline_run)
214+
self._create_pipeline_run_response(
215+
session=session,
216+
pipeline_run=pipeline_run,
217+
include_pipeline_names=params.include_pipeline_names,
218+
include_execution_stats=params.include_execution_stats,
219+
)
256220
for pipeline_run in pipeline_runs
257221
],
258222
next_page_token=next_page_token,
259223
)
260224

225+
def _create_pipeline_run_response(
    self,
    *,
    session: orm.Session,
    pipeline_run: bts.PipelineRun,
    include_pipeline_names: bool,
    include_execution_stats: bool,
) -> PipelineRunResponse:
    """Convert one ``PipelineRun`` DB row into an API response object.

    Args:
        session: Open SQLAlchemy session, used for the optional lookups below.
        pipeline_run: The pipeline-run DB row to convert.
        include_pipeline_names: When True, resolve the run's pipeline name and
            set ``response.pipeline_name`` (may stay None if it can't be found).
        include_execution_stats: When True, compute per-status execution counts
            for the run and set ``response.execution_status_stats``.

    Returns:
        The populated ``PipelineRunResponse``.
    """
    response = PipelineRunResponse.from_db(pipeline_run)
    if include_pipeline_names:
        pipeline_name = None
        extra_data = pipeline_run.extra_data or {}
        if self.PIPELINE_NAME_EXTRA_DATA_KEY in extra_data:
            # Fast path: the name was stored in extra_data when the run was created.
            pipeline_name = extra_data[self.PIPELINE_NAME_EXTRA_DATA_KEY]
        else:
            # Fallback: derive the name from the root execution's task spec,
            # which requires loading the execution node from the DB.
            execution_node = session.get(
                bts.ExecutionNode, pipeline_run.root_execution_id
            )
            if execution_node:
                task_spec = structures.TaskSpec.from_json_dict(
                    execution_node.task_spec
                )
                component_spec = task_spec.component_ref.spec
                if component_spec:
                    pipeline_name = component_spec.name
        response.pipeline_name = pipeline_name
    if include_execution_stats:
        execution_status_stats = self._calculate_execution_status_stats(
            session=session, root_execution_id=pipeline_run.root_execution_id
        )
        # Serialize enum keys as their string values for the JSON response.
        response.execution_status_stats = {
            status.value: count for status, count in execution_status_stats.items()
        }
    return response
259+
261260
def _calculate_execution_status_stats(
262261
self, session: orm.Session, root_execution_id: bts.IdType
263262
) -> dict[bts.ContainerExecutionStatus, int]:
@@ -349,6 +348,58 @@ def delete_annotation(
349348
session.commit()
350349

351350

351+
def _resolve_filter_value(
    *,
    filter: str | None,
    page_token: str | None,
) -> tuple[str | None, int]:
    """Return the effective ``(filter_value, offset)`` for a list request.

    When a page token is present, the filter stored inside it takes precedence
    over the raw ``filter`` argument — the token carries the already-resolved
    filter forward from one page to the next.
    """
    token_payload = _decode_page_token(page_token)
    effective_filter = (
        token_payload.get(_PAGE_TOKEN_FILTER_KEY) if page_token else filter
    )
    return effective_filter, token_payload.get(_PAGE_TOKEN_OFFSET_KEY, 0)
366+
367+
368+
def _build_filter_where_clauses(
    *,
    filter_value: str | None,
    current_user: str | None,
) -> tuple[list[sql.ColumnElement], str | None]:
    """Parse a filter string into SQLAlchemy WHERE clauses.

    Args:
        filter_value: Raw filter string (e.g. ``"created_by:me"``); None or
            empty means no filtering.
        current_user: The authenticated user, substituted for the ``me``
            shorthand. A missing user is treated as the empty string.

    Returns:
        ``(where_clauses, next_page_filter_value)``. The second value is the
        filter string with shorthand values resolved (e.g. ``"created_by:me"``
        becomes ``"created_by:alice@example.com"``) so it can be embedded in
        the next page token.

    Raises:
        NotImplementedError: For the ``_text`` key or any unsupported key.
    """
    where_clauses: list[sql.ColumnElement] = []
    parsed_filter = _parse_filter(filter_value) if filter_value else {}
    for key, value in parsed_filter.items():
        if key == "_text":
            raise NotImplementedError("Text search is not implemented yet.")
        elif key == "created_by":
            if value == "me":
                if current_user is None:
                    # Deliberate best-effort: an anonymous user matches runs
                    # with no creator instead of raising.
                    current_user = ""
                value = current_user
                # TODO: Maybe make this a bit more robust.
                # We need to change the filter since it goes into the next_page_token.
                filter_value = filter_value.replace(
                    "created_by:me", f"created_by:{current_user}"
                )
            if value:
                where_clauses.append(bts.PipelineRun.created_by == value)
            else:
                # Explicit IS NULL test (lint-clean equivalent of `== None`).
                where_clauses.append(bts.PipelineRun.created_by.is_(None))
        else:
            # NOTE(review): message reports the whole filter string, not just
            # the offending key — kept as-is to preserve the error text.
            raise NotImplementedError(f"Unsupported filter {filter_value}.")
    return where_clauses, filter_value
401+
402+
352403
def _decode_page_token(page_token: str) -> dict[str, Any]:
353404
return json.loads(base64.b64decode(page_token)) if page_token else {}
354405

0 commit comments

Comments
 (0)