Skip to content

Commit 098f799

Browse files
authored
refactor: Search pipeline run refactoring and baseline unit tests (#105)
### TL;DR Refactored the `list` method in `PipelineRunsApiService_Sql` to improve code organization and added comprehensive test coverage for pipeline run operations. ### What changed? #### Functional None. #### Other - Extracted filter parsing logic into separate helper functions `_resolve_filter_value` and `_build_filter_where_clauses` - Moved pipeline run response creation logic into a dedicated `_create_pipeline_run_response` method - Added constants `_PAGE_TOKEN_OFFSET_KEY` and `_PAGE_TOKEN_FILTER_KEY` for page token field names - Added comprehensive test suite covering listing, filtering, pagination, annotation CRUD operations, and helper functions ### How to test? ``` uv run pytest tests/test_api_server_sql.py ``` Run the new test suite in `tests/test_api_server_sql.py` which includes: - Tests for listing pipeline runs with various filters and pagination - Tests for pipeline name resolution and execution stats inclusion - Tests for annotation management (set, delete, list) - Tests for filter parsing and page token handling - Edge cases like empty results and unsupported filters ### Why make this change? 1. **Refactoring upfront** — structural changes land here so upstream PR diffs only show new functionality, not mixed refactoring noise. 2. **Baseline tests** — establishing coverage now means each upstream PR's test additions map directly to its new features, making review clearer.
1 parent 2a2cc96 commit 098f799

2 files changed

Lines changed: 628 additions & 98 deletions

File tree

cloud_pipelines_backend/api_server_sql.py

Lines changed: 141 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
import json
55
import logging
66
import typing
7-
from typing import Any, Optional
7+
from typing import Any, Final, Optional
8+
9+
import sqlalchemy as sql
10+
from sqlalchemy import orm
11+
12+
from . import backend_types_sql as bts
13+
from . import component_structures as structures
14+
from . import errors
815

916
if typing.TYPE_CHECKING:
1017
from cloud_pipelines.orchestration.storage_providers import (
@@ -26,12 +33,6 @@ def _get_current_time() -> datetime.datetime:
2633
return datetime.datetime.now(tz=datetime.timezone.utc)
2734

2835

29-
from . import component_structures as structures
30-
from . import backend_types_sql as bts
31-
from . import errors
32-
from .errors import ItemNotFoundError
33-
34-
3536
# ==== PipelineJobService
3637
@dataclasses.dataclass(kw_only=True)
3738
class PipelineRunResponse:
@@ -65,12 +66,11 @@ class ListPipelineJobsResponse:
6566
next_page_token: str | None = None
6667

6768

68-
import sqlalchemy as sql
69-
from sqlalchemy import orm
70-
71-
7269
class PipelineRunsApiService_Sql:
73-
PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name"
70+
_PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name"
71+
_PAGE_TOKEN_OFFSET_KEY: Final[str] = "offset"
72+
_PAGE_TOKEN_FILTER_KEY: Final[str] = "filter"
73+
_DEFAULT_PAGE_SIZE: Final[int] = 10
7474

7575
def create(
7676
self,
@@ -104,7 +104,7 @@ def create(
104104
annotations=annotations,
105105
created_by=created_by,
106106
extra_data={
107-
self.PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name,
107+
self._PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name,
108108
},
109109
)
110110
session.add(pipeline_run)
@@ -116,7 +116,7 @@ def create(
116116
def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse:
117117
pipeline_run = session.get(bts.PipelineRun, id)
118118
if not pipeline_run:
119-
raise ItemNotFoundError(f"Pipeline run {id} not found.")
119+
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
120120
return PipelineRunResponse.from_db(pipeline_run)
121121

122122
def terminate(
@@ -128,7 +128,7 @@ def terminate(
128128
):
129129
pipeline_run = session.get(bts.PipelineRun, id)
130130
if not pipeline_run:
131-
raise ItemNotFoundError(f"Pipeline run {id} not found.")
131+
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
132132
if not skip_user_check and (terminated_by != pipeline_run.created_by):
133133
raise errors.PermissionError(
134134
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be terminated by {terminated_by}"
@@ -166,98 +166,86 @@ def list(
166166
*,
167167
session: orm.Session,
168168
page_token: str | None = None,
169-
# page_size: int = 10,
170169
filter: str | None = None,
171170
current_user: str | None = None,
172171
include_pipeline_names: bool = False,
173172
include_execution_stats: bool = False,
174173
) -> ListPipelineJobsResponse:
175-
page_token_dict = _decode_page_token(page_token)
176-
OFFSET_KEY = "offset"
177-
offset = page_token_dict.get(OFFSET_KEY, 0)
178-
page_size = 10
179-
180-
FILTER_KEY = "filter"
181-
if page_token:
182-
filter = page_token_dict.get(FILTER_KEY, None)
183-
where_clauses = []
184-
parsed_filter = _parse_filter(filter) if filter else {}
185-
for key, value in parsed_filter.items():
186-
if key == "_text":
187-
raise NotImplementedError("Text search is not implemented yet.")
188-
elif key == "created_by":
189-
if value == "me":
190-
if current_user is None:
191-
# raise ApiServiceError(
192-
# f"The `created_by:me` filter requires `current_user`."
193-
# )
194-
current_user = ""
195-
value = current_user
196-
# TODO: Maybe make this a bit more robust.
197-
# We need to change the filter since it goes into the next_page_token.
198-
filter = filter.replace(
199-
"created_by:me", f"created_by:{current_user}"
200-
)
201-
if value:
202-
where_clauses.append(bts.PipelineRun.created_by == value)
203-
else:
204-
where_clauses.append(bts.PipelineRun.created_by == None)
205-
else:
206-
raise NotImplementedError(f"Unsupported filter {filter}.")
174+
filter_value, offset = _resolve_filter_value(
175+
filter=filter,
176+
page_token=page_token,
177+
)
178+
where_clauses, next_page_filter_value = _build_filter_where_clauses(
179+
filter_value=filter_value,
180+
current_user=current_user,
181+
)
182+
207183
pipeline_runs = list(
208184
session.scalars(
209185
sql.select(bts.PipelineRun)
210186
.where(*where_clauses)
211187
.order_by(bts.PipelineRun.created_at.desc())
212188
.offset(offset)
213-
.limit(page_size)
189+
.limit(self._DEFAULT_PAGE_SIZE)
214190
).all()
215191
)
216-
next_page_offset = offset + page_size
217-
next_page_token_dict = {OFFSET_KEY: next_page_offset, FILTER_KEY: filter}
192+
next_page_offset = offset + self._DEFAULT_PAGE_SIZE
193+
next_page_token_dict = {
194+
self._PAGE_TOKEN_OFFSET_KEY: next_page_offset,
195+
self._PAGE_TOKEN_FILTER_KEY: next_page_filter_value,
196+
}
218197
next_page_token = _encode_page_token(next_page_token_dict)
219-
if len(pipeline_runs) < page_size:
198+
if len(pipeline_runs) < self._DEFAULT_PAGE_SIZE:
220199
next_page_token = None
221200

222-
def create_pipeline_run_response(
223-
pipeline_run: bts.PipelineRun,
224-
) -> PipelineRunResponse:
225-
response = PipelineRunResponse.from_db(pipeline_run)
226-
if include_pipeline_names:
227-
pipeline_name = None
228-
extra_data = pipeline_run.extra_data or {}
229-
if self.PIPELINE_NAME_EXTRA_DATA_KEY in extra_data:
230-
pipeline_name = extra_data[self.PIPELINE_NAME_EXTRA_DATA_KEY]
231-
else:
232-
execution_node = session.get(
233-
bts.ExecutionNode, pipeline_run.root_execution_id
234-
)
235-
if execution_node:
236-
task_spec = structures.TaskSpec.from_json_dict(
237-
execution_node.task_spec
238-
)
239-
component_spec = task_spec.component_ref.spec
240-
if component_spec:
241-
pipeline_name = component_spec.name
242-
response.pipeline_name = pipeline_name
243-
if include_execution_stats:
244-
execution_status_stats = self._calculate_execution_status_stats(
245-
session=session, root_execution_id=pipeline_run.root_execution_id
246-
)
247-
response.execution_status_stats = {
248-
status.value: count
249-
for status, count in execution_status_stats.items()
250-
}
251-
return response
252-
253201
return ListPipelineJobsResponse(
254202
pipeline_runs=[
255-
create_pipeline_run_response(pipeline_run)
203+
self._create_pipeline_run_response(
204+
session=session,
205+
pipeline_run=pipeline_run,
206+
include_pipeline_names=include_pipeline_names,
207+
include_execution_stats=include_execution_stats,
208+
)
256209
for pipeline_run in pipeline_runs
257210
],
258211
next_page_token=next_page_token,
259212
)
260213

214+
def _create_pipeline_run_response(
215+
self,
216+
*,
217+
session: orm.Session,
218+
pipeline_run: bts.PipelineRun,
219+
include_pipeline_names: bool,
220+
include_execution_stats: bool,
221+
) -> PipelineRunResponse:
222+
response = PipelineRunResponse.from_db(pipeline_run)
223+
if include_pipeline_names:
224+
pipeline_name = None
225+
extra_data = pipeline_run.extra_data or {}
226+
if self._PIPELINE_NAME_EXTRA_DATA_KEY in extra_data:
227+
pipeline_name = extra_data[self._PIPELINE_NAME_EXTRA_DATA_KEY]
228+
else:
229+
execution_node = session.get(
230+
bts.ExecutionNode, pipeline_run.root_execution_id
231+
)
232+
if execution_node:
233+
task_spec = structures.TaskSpec.from_json_dict(
234+
execution_node.task_spec
235+
)
236+
component_spec = task_spec.component_ref.spec
237+
if component_spec:
238+
pipeline_name = component_spec.name
239+
response.pipeline_name = pipeline_name
240+
if include_execution_stats:
241+
execution_status_stats = self._calculate_execution_status_stats(
242+
session=session, root_execution_id=pipeline_run.root_execution_id
243+
)
244+
response.execution_status_stats = {
245+
status.value: count for status, count in execution_status_stats.items()
246+
}
247+
return response
248+
261249
def _calculate_execution_status_stats(
262250
self, session: orm.Session, root_execution_id: bts.IdType
263251
) -> dict[bts.ContainerExecutionStatus, int]:
@@ -316,7 +304,7 @@ def set_annotation(
316304
):
317305
pipeline_run = session.get(bts.PipelineRun, id)
318306
if not pipeline_run:
319-
raise ItemNotFoundError(f"Pipeline run {id} not found.")
307+
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
320308
if not skip_user_check and (user_name != pipeline_run.created_by):
321309
raise errors.PermissionError(
322310
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}"
@@ -338,7 +326,7 @@ def delete_annotation(
338326
):
339327
pipeline_run = session.get(bts.PipelineRun, id)
340328
if not pipeline_run:
341-
raise ItemNotFoundError(f"Pipeline run {id} not found.")
329+
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
342330
if not skip_user_check and (user_name != pipeline_run.created_by):
343331
raise errors.PermissionError(
344332
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}"
@@ -349,6 +337,64 @@ def delete_annotation(
349337
session.commit()
350338

351339

340+
def _resolve_filter_value(
    *,
    filter: str | None,
    page_token: str | None,
) -> tuple[str | None, int]:
    """Determine the effective (filter_value, offset) for a list page.

    When a page_token is supplied, the filter stored inside the token takes
    precedence over the raw `filter` argument, because the token carries the
    already-resolved filter forward from page to page.
    """
    token_fields = _decode_page_token(page_token)
    offset = token_fields.get(
        PipelineRunsApiService_Sql._PAGE_TOKEN_OFFSET_KEY,
        0,
    )
    if not page_token:
        return filter, offset
    stored_filter = token_fields.get(
        PipelineRunsApiService_Sql._PAGE_TOKEN_FILTER_KEY
    )
    return stored_filter, offset
361+
362+
363+
def _build_filter_where_clauses(
    *,
    filter_value: str | None,
    current_user: str | None,
) -> tuple[list[sql.ColumnElement], str | None]:
    """Translate a filter string into SQLAlchemy WHERE clauses.

    Returns (where_clauses, next_page_filter_value): the clause list plus the
    filter string with shorthand values expanded (e.g. "created_by:me"
    rewritten to the concrete user name) so it can be embedded in the next
    page token.
    """
    clauses: list[sql.ColumnElement] = []
    if not filter_value:
        return clauses, filter_value
    for key, value in _parse_filter(filter_value).items():
        if key == "_text":
            raise NotImplementedError("Text search is not implemented yet.")
        if key != "created_by":
            raise NotImplementedError(f"Unsupported filter {filter_value}.")
        if value == "me":
            if current_user is None:
                current_user = ""
            value = current_user
            # TODO: Maybe make this a bit more robust.
            # We need to change the filter since it goes into the next_page_token.
            filter_value = filter_value.replace(
                "created_by:me", f"created_by:{current_user}"
            )
        if value:
            clauses.append(bts.PipelineRun.created_by == value)
        else:
            # Intentional `== None`: SQLAlchemy turns this into `IS NULL`.
            clauses.append(bts.PipelineRun.created_by == None)
    return clauses, filter_value
396+
397+
352398
def _decode_page_token(page_token: str) -> dict[str, Any]:
353399
return json.loads(base64.b64decode(page_token)) if page_token else {}
354400

@@ -524,7 +570,7 @@ class ExecutionNodesApiService_Sql:
524570
def get(self, session: orm.Session, id: bts.IdType) -> GetExecutionInfoResponse:
525571
execution_node = session.get(bts.ExecutionNode, id)
526572
if execution_node is None:
527-
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
573+
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
528574

529575
parent_pipeline_run_id = session.scalar(
530576
sql.select(bts.PipelineRun.id).where(
@@ -676,7 +722,7 @@ def get_container_execution_state(
676722
) -> GetContainerExecutionStateResponse:
677723
execution = session.get(bts.ExecutionNode, id)
678724
if not execution:
679-
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
725+
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
680726
container_execution = execution.container_execution
681727
if not container_execution:
682728
raise RuntimeError(
@@ -696,7 +742,7 @@ def get_artifacts(
696742
if not session.scalar(
697743
sql.select(sql.exists().where(bts.ExecutionNode.id == id))
698744
):
699-
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
745+
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
700746

701747
input_artifact_links = session.scalars(
702748
sql.select(bts.InputArtifactLink)
@@ -742,7 +788,7 @@ def get_container_execution_log(
742788
) -> GetContainerExecutionLogResponse:
743789
execution = session.get(bts.ExecutionNode, id)
744790
if not execution:
745-
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
791+
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
746792
container_execution = execution.container_execution
747793
execution_extra_data = execution.extra_data or {}
748794
system_error_exception_full = execution_extra_data.get(
@@ -829,7 +875,9 @@ def stream_container_execution_log(
829875
) -> typing.Iterator[str]:
830876
execution = session.get(bts.ExecutionNode, execution_id)
831877
if not execution:
832-
raise ItemNotFoundError(f"Execution with {execution_id=} does not exist.")
878+
raise errors.ItemNotFoundError(
879+
f"Execution with {execution_id=} does not exist."
880+
)
833881
container_execution = execution.container_execution
834882
if not container_execution:
835883
raise ApiServiceError(
@@ -970,7 +1018,7 @@ class ArtifactNodesApiService_Sql:
9701018
def get(self, session: orm.Session, id: bts.IdType) -> GetArtifactInfoResponse:
9711019
artifact_node = session.get(bts.ArtifactNode, id)
9721020
if artifact_node is None:
973-
raise ItemNotFoundError(f"Artifact with {id=} does not exist.")
1021+
raise errors.ItemNotFoundError(f"Artifact with {id=} does not exist.")
9741022
artifact_data = artifact_node.artifact_data
9751023
result = GetArtifactInfoResponse(id=artifact_node.id)
9761024
if artifact_data:
@@ -986,7 +1034,7 @@ def get_signed_artifact_url(
9861034
.where(bts.ArtifactNode.id == id)
9871035
)
9881036
if not artifact_data:
989-
raise ItemNotFoundError(f"Artifact node with {id=} does not exist.")
1037+
raise errors.ItemNotFoundError(f"Artifact node with {id=} does not exist.")
9901038
if not artifact_data.uri:
9911039
raise ValueError(f"Artifact node with {id=} does not have artifact URI.")
9921040
if artifact_data.is_dir:

0 commit comments

Comments
 (0)