Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class BatchCreateRequest:

@dataclasses.dataclass(kw_only=True)
class BatchCreatePipelineRunsResponse:
    """Response payload returned by the batch pipeline-run creation endpoint."""

    # Identifier shared by all runs created in this batch; it is also
    # injected into each run's annotations under the "batch_id" key.
    batch_id: str
    # One entry per pipeline run created by the batch request.
    created_runs: list[PipelineRunResponse]


Expand Down Expand Up @@ -152,6 +153,7 @@ def _build_pipeline_run(
pipeline_run_id=pipeline_run.id,
created_by=created_by,
pipeline_name=pipeline_name,
annotations=annotations,
)
return pipeline_run

Expand Down Expand Up @@ -194,14 +196,19 @@ def create_batch(
f"Batch size {len(runs)} exceeds the maximum of {self._MAX_BATCH_SIZE}."
)

batch_id = bts.generate_unique_id()
pipeline_runs: list[bts.PipelineRun] = []

with session.begin():
for run_request in runs:
run_annotations = {
**(run_request.annotations or {}),
"batch_id": batch_id,
}
pipeline_run = self._build_pipeline_run(
session=session,
root_task=run_request.root_task,
annotations=run_request.annotations,
annotations=run_annotations,
created_by=created_by,
)
pipeline_runs.append(pipeline_run)
Expand All @@ -213,7 +220,10 @@ def create_batch(
session.refresh(pipeline_run)
responses.append(PipelineRunResponse.from_db(pipeline_run))

return BatchCreatePipelineRunsResponse(created_runs=responses)
return BatchCreatePipelineRunsResponse(
batch_id=batch_id,
created_runs=responses,
)

def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse:
pipeline_run = session.get(bts.PipelineRun, id)
Expand Down Expand Up @@ -1321,6 +1331,7 @@ def _mirror_system_annotations(
pipeline_run_id: bts.IdType,
created_by: str | None,
pipeline_name: str | None,
annotations: dict[str, Any] | None = None,
) -> None:
"""Mirror pipeline run fields as system annotations for filter_query search"""

Expand Down Expand Up @@ -1356,6 +1367,15 @@ def _mirror_system_annotations(
value=pipeline_name,
)
)
batch_id = (annotations or {}).get("batch_id")
if batch_id:
session.add(
bts.PipelineRunAnnotation(
pipeline_run_id=pipeline_run_id,
key=filter_query_sql.PipelineRunAnnotationSystemKey.BATCH_ID,
value=str(batch_id)[:bts._STR_MAX_LENGTH],
)
)


def _recursively_create_all_executions_and_artifacts_root(
Expand Down
5 changes: 5 additions & 0 deletions cloud_pipelines_backend/filter_query_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class PipelineRunAnnotationSystemKey(str, enum.Enum):
CREATED_BY = f"{_PIPELINE_RUN_KEY_PREFIX}created_by"
PIPELINE_NAME = f"{_PIPELINE_RUN_KEY_PREFIX}name"
CREATED_AT = f"{_PIPELINE_RUN_KEY_PREFIX}date.created_at"
BATCH_ID = f"{_PIPELINE_RUN_KEY_PREFIX}batch_id"


SYSTEM_KEY_SUPPORTED_PREDICATES: dict[PipelineRunAnnotationSystemKey, set[type]] = {
Expand All @@ -43,6 +44,10 @@ class PipelineRunAnnotationSystemKey(str, enum.Enum):
PipelineRunAnnotationSystemKey.CREATED_AT: {
filter_query_models.TimeRangePredicate,
},
PipelineRunAnnotationSystemKey.BATCH_ID: {
filter_query_models.KeyExistsPredicate,
filter_query_models.ValueEqualsPredicate,
},
}

# ---------------------------------------------------------------------------
Expand Down
48 changes: 47 additions & 1 deletion tests/test_api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,9 @@ def test_create_batch_with_annotations(self, session_factory, service):
]
with session_factory() as session:
result = service.create_batch(session=session, runs=runs)
assert result.created_runs[0].annotations == annotations
run_annotations = result.created_runs[0].annotations
assert run_annotations["team"] == "ml-ops"
assert "batch_id" in run_annotations # batch_id is always injected

def test_create_batch_mirrors_system_annotations(self, session_factory, service):
runs = [
Expand Down Expand Up @@ -1723,3 +1725,47 @@ def test_create_batch_is_atomic(self, session_factory, service):
for run in result.created_runs:
fetched = service.get(session=session, id=run.id)
assert fetched.id == run.id

def test_create_batch_returns_batch_id(self, session_factory, service):
    """The batch response carries a batch_id that parses as a UUID."""
    import uuid

    batch_requests = [
        api_server_sql.BatchCreateRequest(root_task=_make_task_spec())
        for _ in range(2)
    ]
    with session_factory() as session:
        response = service.create_batch(session=session, runs=batch_requests)
    # uuid.UUID raises ValueError if the id is not a well-formed UUID string.
    uuid.UUID(response.batch_id)

def test_create_batch_all_runs_share_batch_id(self, session_factory, service):
    """Every run created by one batch call is annotated with the same batch_id."""
    batch_requests = [
        api_server_sql.BatchCreateRequest(root_task=_make_task_spec(f"p-{i}"))
        for i in range(3)
    ]
    with session_factory() as session:
        response = service.create_batch(session=session, runs=batch_requests)
    expected_id = response.batch_id
    for created in response.created_runs:
        assert created.annotations["batch_id"] == expected_id

def test_create_batch_different_batches_get_different_ids(self, session_factory, service):
    """Two independent create_batch calls are assigned distinct batch ids."""
    batch_requests = [api_server_sql.BatchCreateRequest(root_task=_make_task_spec())]
    observed_ids = []
    for _ in range(2):
        with session_factory() as session:
            response = service.create_batch(session=session, runs=batch_requests)
        observed_ids.append(response.batch_id)
    assert observed_ids[0] != observed_ids[1]

def test_create_batch_mirrors_batch_id_system_annotation(self, session_factory, service):
    """The batch id is mirrored into the system annotations used for filtering."""
    batch_requests = [
        api_server_sql.BatchCreateRequest(root_task=_make_task_spec()),
    ]
    with session_factory() as session:
        response = service.create_batch(session=session, runs=batch_requests)
    run_id = response.created_runs[0].id
    with session_factory() as session:
        mirrored = service.list_annotations(session=session, id=run_id)
    system_key = filter_query_sql.PipelineRunAnnotationSystemKey.BATCH_ID
    assert mirrored[system_key] == response.batch_id