Skip to content

Commit e853104

Browse files
committed
refactor: Search pipeline run using Pydantic query parameter
1 parent ee37633 commit e853104

2 files changed

Lines changed: 624 additions & 94 deletions

File tree

cloud_pipelines_backend/api_server_sql.py

Lines changed: 133 additions & 89 deletions
Original file line number · Diff line number · Diff line change
@@ -4,7 +4,14 @@
44
import json
55
import logging
66
import typing
7-
from typing import Any, Optional
7+
from typing import Any, Final, Optional
8+
9+
import sqlalchemy as sql
10+
from sqlalchemy import orm
11+
12+
from . import backend_types_sql as bts
13+
from . import component_structures as structures
14+
from . import errors
815

916
if typing.TYPE_CHECKING:
1017
from cloud_pipelines.orchestration.storage_providers import (
@@ -26,10 +33,9 @@ def _get_current_time() -> datetime.datetime:
2633
return datetime.datetime.now(tz=datetime.timezone.utc)
2734

2835

29-
from . import component_structures as structures
30-
from . import backend_types_sql as bts
31-
from . import errors
32-
from .errors import ItemNotFoundError
36+
_PAGE_TOKEN_OFFSET_KEY: Final[str] = "offset"
37+
_PAGE_TOKEN_FILTER_KEY: Final[str] = "filter"
38+
_DEFAULT_PAGE_SIZE: Final[int] = 10
3339

3440

3541
# ==== PipelineJobService
@@ -65,10 +71,6 @@ class ListPipelineJobsResponse:
6571
next_page_token: str | None = None
6672

6773

68-
import sqlalchemy as sql
69-
from sqlalchemy import orm
70-
71-
7274
class PipelineRunsApiService_Sql:
7375
PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name"
7476

@@ -116,7 +118,7 @@ def create(
116118
def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse:
117119
pipeline_run = session.get(bts.PipelineRun, id)
118120
if not pipeline_run:
119-
raise ItemNotFoundError(f"Pipeline run {id} not found.")
121+
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
120122
return PipelineRunResponse.from_db(pipeline_run)
121123

122124
def terminate(
@@ -128,7 +130,7 @@ def terminate(
128130
):
129131
pipeline_run = session.get(bts.PipelineRun, id)
130132
if not pipeline_run:
131-
raise ItemNotFoundError(f"Pipeline run {id} not found.")
133+
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
132134
if not skip_user_check and (terminated_by != pipeline_run.created_by):
133135
raise errors.PermissionError(
134136
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be terminated by {terminated_by}"
@@ -166,98 +168,86 @@ def list(
166168
*,
167169
session: orm.Session,
168170
page_token: str | None = None,
169-
# page_size: int = 10,
170171
filter: str | None = None,
171172
current_user: str | None = None,
172173
include_pipeline_names: bool = False,
173174
include_execution_stats: bool = False,
174175
) -> ListPipelineJobsResponse:
175-
page_token_dict = _decode_page_token(page_token)
176-
OFFSET_KEY = "offset"
177-
offset = page_token_dict.get(OFFSET_KEY, 0)
178-
page_size = 10
179-
180-
FILTER_KEY = "filter"
181-
if page_token:
182-
filter = page_token_dict.get(FILTER_KEY, None)
183-
where_clauses = []
184-
parsed_filter = _parse_filter(filter) if filter else {}
185-
for key, value in parsed_filter.items():
186-
if key == "_text":
187-
raise NotImplementedError("Text search is not implemented yet.")
188-
elif key == "created_by":
189-
if value == "me":
190-
if current_user is None:
191-
# raise ApiServiceError(
192-
# f"The `created_by:me` filter requires `current_user`."
193-
# )
194-
current_user = ""
195-
value = current_user
196-
# TODO: Maybe make this a bit more robust.
197-
# We need to change the filter since it goes into the next_page_token.
198-
filter = filter.replace(
199-
"created_by:me", f"created_by:{current_user}"
200-
)
201-
if value:
202-
where_clauses.append(bts.PipelineRun.created_by == value)
203-
else:
204-
where_clauses.append(bts.PipelineRun.created_by == None)
205-
else:
206-
raise NotImplementedError(f"Unsupported filter {filter}.")
176+
filter_value, offset = _resolve_filter_value(
177+
filter=filter,
178+
page_token=page_token,
179+
)
180+
where_clauses, next_page_filter_value = _build_filter_where_clauses(
181+
filter_value=filter_value,
182+
current_user=current_user,
183+
)
184+
207185
pipeline_runs = list(
208186
session.scalars(
209187
sql.select(bts.PipelineRun)
210188
.where(*where_clauses)
211189
.order_by(bts.PipelineRun.created_at.desc())
212190
.offset(offset)
213-
.limit(page_size)
191+
.limit(_DEFAULT_PAGE_SIZE)
214192
).all()
215193
)
216-
next_page_offset = offset + page_size
217-
next_page_token_dict = {OFFSET_KEY: next_page_offset, FILTER_KEY: filter}
194+
next_page_offset = offset + _DEFAULT_PAGE_SIZE
195+
next_page_token_dict = {
196+
_PAGE_TOKEN_OFFSET_KEY: next_page_offset,
197+
_PAGE_TOKEN_FILTER_KEY: next_page_filter_value,
198+
}
218199
next_page_token = _encode_page_token(next_page_token_dict)
219-
if len(pipeline_runs) < page_size:
200+
if len(pipeline_runs) < _DEFAULT_PAGE_SIZE:
220201
next_page_token = None
221202

222-
def create_pipeline_run_response(
223-
pipeline_run: bts.PipelineRun,
224-
) -> PipelineRunResponse:
225-
response = PipelineRunResponse.from_db(pipeline_run)
226-
if include_pipeline_names:
227-
pipeline_name = None
228-
extra_data = pipeline_run.extra_data or {}
229-
if self.PIPELINE_NAME_EXTRA_DATA_KEY in extra_data:
230-
pipeline_name = extra_data[self.PIPELINE_NAME_EXTRA_DATA_KEY]
231-
else:
232-
execution_node = session.get(
233-
bts.ExecutionNode, pipeline_run.root_execution_id
234-
)
235-
if execution_node:
236-
task_spec = structures.TaskSpec.from_json_dict(
237-
execution_node.task_spec
238-
)
239-
component_spec = task_spec.component_ref.spec
240-
if component_spec:
241-
pipeline_name = component_spec.name
242-
response.pipeline_name = pipeline_name
243-
if include_execution_stats:
244-
execution_status_stats = self._calculate_execution_status_stats(
245-
session=session, root_execution_id=pipeline_run.root_execution_id
246-
)
247-
response.execution_status_stats = {
248-
status.value: count
249-
for status, count in execution_status_stats.items()
250-
}
251-
return response
252-
253203
return ListPipelineJobsResponse(
254204
pipeline_runs=[
255-
create_pipeline_run_response(pipeline_run)
205+
self._create_pipeline_run_response(
206+
session=session,
207+
pipeline_run=pipeline_run,
208+
include_pipeline_names=include_pipeline_names,
209+
include_execution_stats=include_execution_stats,
210+
)
256211
for pipeline_run in pipeline_runs
257212
],
258213
next_page_token=next_page_token,
259214
)
260215

216+
def _create_pipeline_run_response(
217+
self,
218+
*,
219+
session: orm.Session,
220+
pipeline_run: bts.PipelineRun,
221+
include_pipeline_names: bool,
222+
include_execution_stats: bool,
223+
) -> PipelineRunResponse:
224+
response = PipelineRunResponse.from_db(pipeline_run)
225+
if include_pipeline_names:
226+
pipeline_name = None
227+
extra_data = pipeline_run.extra_data or {}
228+
if self.PIPELINE_NAME_EXTRA_DATA_KEY in extra_data:
229+
pipeline_name = extra_data[self.PIPELINE_NAME_EXTRA_DATA_KEY]
230+
else:
231+
execution_node = session.get(
232+
bts.ExecutionNode, pipeline_run.root_execution_id
233+
)
234+
if execution_node:
235+
task_spec = structures.TaskSpec.from_json_dict(
236+
execution_node.task_spec
237+
)
238+
component_spec = task_spec.component_ref.spec
239+
if component_spec:
240+
pipeline_name = component_spec.name
241+
response.pipeline_name = pipeline_name
242+
if include_execution_stats:
243+
execution_status_stats = self._calculate_execution_status_stats(
244+
session=session, root_execution_id=pipeline_run.root_execution_id
245+
)
246+
response.execution_status_stats = {
247+
status.value: count for status, count in execution_status_stats.items()
248+
}
249+
return response
250+
261251
def _calculate_execution_status_stats(
262252
self, session: orm.Session, root_execution_id: bts.IdType
263253
) -> dict[bts.ContainerExecutionStatus, int]:
@@ -316,7 +306,7 @@ def set_annotation(
316306
):
317307
pipeline_run = session.get(bts.PipelineRun, id)
318308
if not pipeline_run:
319-
raise ItemNotFoundError(f"Pipeline run {id} not found.")
309+
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
320310
if not skip_user_check and (user_name != pipeline_run.created_by):
321311
raise errors.PermissionError(
322312
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}"
@@ -338,7 +328,7 @@ def delete_annotation(
338328
):
339329
pipeline_run = session.get(bts.PipelineRun, id)
340330
if not pipeline_run:
341-
raise ItemNotFoundError(f"Pipeline run {id} not found.")
331+
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
342332
if not skip_user_check and (user_name != pipeline_run.created_by):
343333
raise errors.PermissionError(
344334
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}"
@@ -349,6 +339,58 @@ def delete_annotation(
349339
session.commit()
350340

351341

342+
def _resolve_filter_value(
343+
*,
344+
filter: str | None,
345+
page_token: str | None,
346+
) -> tuple[str | None, int]:
347+
"""Decode page_token and return the effective (filter_value, offset).
348+
349+
If a page_token is present, its stored filter takes precedence over the
350+
raw filter parameter (the token carries the resolved filter forward across pages).
351+
"""
352+
page_token_dict = _decode_page_token(page_token)
353+
offset = page_token_dict.get(_PAGE_TOKEN_OFFSET_KEY, 0)
354+
if page_token:
355+
filter = page_token_dict.get(_PAGE_TOKEN_FILTER_KEY, None)
356+
return filter, offset
357+
358+
359+
def _build_filter_where_clauses(
360+
*,
361+
filter_value: str | None,
362+
current_user: str | None,
363+
) -> tuple[list[sql.ColumnElement], str | None]:
364+
"""Parse a filter string into SQLAlchemy WHERE clauses.
365+
366+
Returns (where_clauses, next_page_filter_value). The second value is the
367+
filter string with shorthand values resolved (e.g. "created_by:me" becomes
368+
"created_by:alice@example.com") so it can be embedded in the next page token.
369+
"""
370+
where_clauses: list[sql.ColumnElement] = []
371+
parsed_filter = _parse_filter(filter_value) if filter_value else {}
372+
for key, value in parsed_filter.items():
373+
if key == "_text":
374+
raise NotImplementedError("Text search is not implemented yet.")
375+
elif key == "created_by":
376+
if value == "me":
377+
if current_user is None:
378+
current_user = ""
379+
value = current_user
380+
# TODO: Maybe make this a bit more robust.
381+
# We need to change the filter since it goes into the next_page_token.
382+
filter_value = filter_value.replace(
383+
"created_by:me", f"created_by:{current_user}"
384+
)
385+
if value:
386+
where_clauses.append(bts.PipelineRun.created_by == value)
387+
else:
388+
where_clauses.append(bts.PipelineRun.created_by == None)
389+
else:
390+
raise NotImplementedError(f"Unsupported filter {filter_value}.")
391+
return where_clauses, filter_value
392+
393+
352394
def _decode_page_token(page_token: str) -> dict[str, Any]:
353395
return json.loads(base64.b64decode(page_token)) if page_token else {}
354396

@@ -524,7 +566,7 @@ class ExecutionNodesApiService_Sql:
524566
def get(self, session: orm.Session, id: bts.IdType) -> GetExecutionInfoResponse:
525567
execution_node = session.get(bts.ExecutionNode, id)
526568
if execution_node is None:
527-
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
569+
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
528570

529571
parent_pipeline_run_id = session.scalar(
530572
sql.select(bts.PipelineRun.id).where(
@@ -676,7 +718,7 @@ def get_container_execution_state(
676718
) -> GetContainerExecutionStateResponse:
677719
execution = session.get(bts.ExecutionNode, id)
678720
if not execution:
679-
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
721+
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
680722
container_execution = execution.container_execution
681723
if not container_execution:
682724
raise RuntimeError(
@@ -696,7 +738,7 @@ def get_artifacts(
696738
if not session.scalar(
697739
sql.select(sql.exists().where(bts.ExecutionNode.id == id))
698740
):
699-
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
741+
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
700742

701743
input_artifact_links = session.scalars(
702744
sql.select(bts.InputArtifactLink)
@@ -742,7 +784,7 @@ def get_container_execution_log(
742784
) -> GetContainerExecutionLogResponse:
743785
execution = session.get(bts.ExecutionNode, id)
744786
if not execution:
745-
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
787+
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
746788
container_execution = execution.container_execution
747789
execution_extra_data = execution.extra_data or {}
748790
system_error_exception_full = execution_extra_data.get(
@@ -829,7 +871,9 @@ def stream_container_execution_log(
829871
) -> typing.Iterator[str]:
830872
execution = session.get(bts.ExecutionNode, execution_id)
831873
if not execution:
832-
raise ItemNotFoundError(f"Execution with {execution_id=} does not exist.")
874+
raise errors.ItemNotFoundError(
875+
f"Execution with {execution_id=} does not exist."
876+
)
833877
container_execution = execution.container_execution
834878
if not container_execution:
835879
raise ApiServiceError(
@@ -970,7 +1014,7 @@ class ArtifactNodesApiService_Sql:
9701014
def get(self, session: orm.Session, id: bts.IdType) -> GetArtifactInfoResponse:
9711015
artifact_node = session.get(bts.ArtifactNode, id)
9721016
if artifact_node is None:
973-
raise ItemNotFoundError(f"Artifact with {id=} does not exist.")
1017+
raise errors.ItemNotFoundError(f"Artifact with {id=} does not exist.")
9741018
artifact_data = artifact_node.artifact_data
9751019
result = GetArtifactInfoResponse(id=artifact_node.id)
9761020
if artifact_data:
@@ -986,7 +1030,7 @@ def get_signed_artifact_url(
9861030
.where(bts.ArtifactNode.id == id)
9871031
)
9881032
if not artifact_data:
989-
raise ItemNotFoundError(f"Artifact node with {id=} does not exist.")
1033+
raise errors.ItemNotFoundError(f"Artifact node with {id=} does not exist.")
9901034
if not artifact_data.uri:
9911035
raise ValueError(f"Artifact node with {id=} does not have artifact URI.")
9921036
if artifact_data.is_dir:

0 commit comments

Comments (0)