From 609d91b665744b6616994666a48cc422636f05fc Mon Sep 17 00:00:00 2001
From: mbeaulne
Date: Wed, 11 Mar 2026 18:05:59 -0400
Subject: [PATCH] Add batch id to annotations

---
Reviewer notes (free text after the "---" marker; not patch content).

Summary of what the hunks below show:
  * create_batch calls bts.generate_unique_id() once per call, stores the
    result under the "batch_id" key of every run's annotations, and returns
    it as the new BatchCreatePipelineRunsResponse.batch_id field.
  * The caller's annotations are spread first and "batch_id" is set last, so
    a caller-supplied "batch_id" key is overwritten by the generated id.
  * _mirror_system_annotations gains an optional `annotations` parameter and,
    when its "batch_id" entry is truthy, adds a PipelineRunAnnotation keyed
    by PipelineRunAnnotationSystemKey.BATCH_ID. The mirrored value is
    str()-converted and cut to bts._STR_MAX_LENGTH; the "batch_id" entry in
    the run's own annotations is not truncated.
  * BATCH_ID supports KeyExistsPredicate and ValueEqualsPredicate.

NOTE(review): the call inside _build_pipeline_run that receives the new
`annotations=annotations` argument is presumably _mirror_system_annotations
(the callee name is outside the hunk) -- confirm. If _build_pipeline_run is
also reached from a non-batch path that accepts user annotations, a
user-supplied "batch_id" would be mirrored as the system key -- confirm
whether that is intended.

NOTE(review): test_create_batch_returns_batch_id passes result.batch_id to
uuid.UUID(); the format returned by bts.generate_unique_id() is not visible
in this patch -- confirm it is a string uuid.UUID accepts.

NOTE(review): this text was restored from a whitespace-collapsed copy. Line
breaks, indentation and the position of blank context lines were
reconstructed so that the hunk counts and diffstat agree; verify whitespace
against the target tree before applying.

 cloud_pipelines_backend/api_server_sql.py   | 24 ++++++++++-
 cloud_pipelines_backend/filter_query_sql.py |  5 +++
 tests/test_api_server_sql.py                | 48 ++++++++++++++++++++-
 3 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py
index 2851777..c02b5f5 100644
--- a/cloud_pipelines_backend/api_server_sql.py
+++ b/cloud_pipelines_backend/api_server_sql.py
@@ -95,6 +95,7 @@ class BatchCreateRequest:
 
 @dataclasses.dataclass(kw_only=True)
 class BatchCreatePipelineRunsResponse:
+    batch_id: str
     created_runs: list[PipelineRunResponse]
 
 
@@ -152,6 +153,7 @@ def _build_pipeline_run(
             pipeline_run_id=pipeline_run.id,
             created_by=created_by,
             pipeline_name=pipeline_name,
+            annotations=annotations,
         )
 
         return pipeline_run
@@ -194,14 +196,19 @@ def create_batch(
                 f"Batch size {len(runs)} exceeds the maximum of {self._MAX_BATCH_SIZE}."
             )
 
+        batch_id = bts.generate_unique_id()
         pipeline_runs: list[bts.PipelineRun] = []
 
         with session.begin():
             for run_request in runs:
+                run_annotations = {
+                    **(run_request.annotations or {}),
+                    "batch_id": batch_id,
+                }
                 pipeline_run = self._build_pipeline_run(
                     session=session,
                     root_task=run_request.root_task,
-                    annotations=run_request.annotations,
+                    annotations=run_annotations,
                     created_by=created_by,
                 )
                 pipeline_runs.append(pipeline_run)
@@ -213,7 +220,10 @@ def create_batch(
             session.refresh(pipeline_run)
             responses.append(PipelineRunResponse.from_db(pipeline_run))
 
-        return BatchCreatePipelineRunsResponse(created_runs=responses)
+        return BatchCreatePipelineRunsResponse(
+            batch_id=batch_id,
+            created_runs=responses,
+        )
 
     def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse:
         pipeline_run = session.get(bts.PipelineRun, id)
@@ -1321,6 +1331,7 @@ def _mirror_system_annotations(
     pipeline_run_id: bts.IdType,
     created_by: str | None,
     pipeline_name: str | None,
+    annotations: dict[str, Any] | None = None,
 ) -> None:
     """Mirror pipeline run fields as system annotations for filter_query search"""
 
@@ -1356,6 +1367,15 @@ def _mirror_system_annotations(
                 value=pipeline_name,
             )
         )
+    batch_id = (annotations or {}).get("batch_id")
+    if batch_id:
+        session.add(
+            bts.PipelineRunAnnotation(
+                pipeline_run_id=pipeline_run_id,
+                key=filter_query_sql.PipelineRunAnnotationSystemKey.BATCH_ID,
+                value=str(batch_id)[:bts._STR_MAX_LENGTH],
+            )
+        )
 
 
 def _recursively_create_all_executions_and_artifacts_root(
diff --git a/cloud_pipelines_backend/filter_query_sql.py b/cloud_pipelines_backend/filter_query_sql.py
index b306a57..fd573d3 100644
--- a/cloud_pipelines_backend/filter_query_sql.py
+++ b/cloud_pipelines_backend/filter_query_sql.py
@@ -26,6 +26,7 @@ class PipelineRunAnnotationSystemKey(str, enum.Enum):
     CREATED_BY = f"{_PIPELINE_RUN_KEY_PREFIX}created_by"
     PIPELINE_NAME = f"{_PIPELINE_RUN_KEY_PREFIX}name"
     CREATED_AT = f"{_PIPELINE_RUN_KEY_PREFIX}date.created_at"
+    BATCH_ID = f"{_PIPELINE_RUN_KEY_PREFIX}batch_id"
 
 
 SYSTEM_KEY_SUPPORTED_PREDICATES: dict[PipelineRunAnnotationSystemKey, set[type]] = {
@@ -43,6 +44,10 @@ class PipelineRunAnnotationSystemKey(str, enum.Enum):
     PipelineRunAnnotationSystemKey.CREATED_AT: {
         filter_query_models.TimeRangePredicate,
     },
+    PipelineRunAnnotationSystemKey.BATCH_ID: {
+        filter_query_models.KeyExistsPredicate,
+        filter_query_models.ValueEqualsPredicate,
+    },
 }
 
 # ---------------------------------------------------------------------------
diff --git a/tests/test_api_server_sql.py b/tests/test_api_server_sql.py
index c2bd512..d2c332e 100644
--- a/tests/test_api_server_sql.py
+++ b/tests/test_api_server_sql.py
@@ -1663,7 +1663,9 @@ def test_create_batch_with_annotations(self, session_factory, service):
         ]
         with session_factory() as session:
             result = service.create_batch(session=session, runs=runs)
-        assert result.created_runs[0].annotations == annotations
+        run_annotations = result.created_runs[0].annotations
+        assert run_annotations["team"] == "ml-ops"
+        assert "batch_id" in run_annotations  # batch_id is always injected
 
     def test_create_batch_mirrors_system_annotations(self, session_factory, service):
         runs = [
@@ -1723,3 +1725,47 @@ def test_create_batch_is_atomic(self, session_factory, service):
             for run in result.created_runs:
                 fetched = service.get(session=session, id=run.id)
                 assert fetched.id == run.id
+
+    def test_create_batch_returns_batch_id(self, session_factory, service):
+        runs = [
+            api_server_sql.BatchCreateRequest(root_task=_make_task_spec())
+            for _ in range(2)
+        ]
+        with session_factory() as session:
+            result = service.create_batch(session=session, runs=runs)
+        # batch_id is a valid UUID string
+        import uuid
+        uuid.UUID(result.batch_id)  # raises ValueError if invalid
+
+    def test_create_batch_all_runs_share_batch_id(self, session_factory, service):
+        runs = [
+            api_server_sql.BatchCreateRequest(root_task=_make_task_spec(f"p-{i}"))
+            for i in range(3)
+        ]
+        with session_factory() as session:
+            result = service.create_batch(session=session, runs=runs)
+        for run in result.created_runs:
+            assert run.annotations["batch_id"] == result.batch_id
+
+    def test_create_batch_different_batches_get_different_ids(self, session_factory, service):
+        runs = [api_server_sql.BatchCreateRequest(root_task=_make_task_spec())]
+        with session_factory() as session:
+            result1 = service.create_batch(session=session, runs=runs)
+        with session_factory() as session:
+            result2 = service.create_batch(session=session, runs=runs)
+        assert result1.batch_id != result2.batch_id
+
+    def test_create_batch_mirrors_batch_id_system_annotation(self, session_factory, service):
+        runs = [
+            api_server_sql.BatchCreateRequest(root_task=_make_task_spec()),
+        ]
+        with session_factory() as session:
+            result = service.create_batch(session=session, runs=runs)
+        with session_factory() as session:
+            annotations = service.list_annotations(
+                session=session, id=result.created_runs[0].id
+            )
+        assert (
+            annotations[filter_query_sql.PipelineRunAnnotationSystemKey.BATCH_ID]
+            == result.batch_id
+        )