diff --git a/cloud_pipelines_backend/api_router.py b/cloud_pipelines_backend/api_router.py index e60a733..44f3068 100644 --- a/cloud_pipelines_backend/api_router.py +++ b/cloud_pipelines_backend/api_router.py @@ -331,6 +331,20 @@ def stream_container_log( **default_config, )(create_run_func) + create_batch_func = pipeline_run_service.create_batch + create_batch_func = inject_session_dependency(create_batch_func) + create_batch_func = add_parameter_annotation_metadata( + create_batch_func, + parameter_name="created_by", + annotation_metadata=get_user_name_dependency, + ) + router.post( + "/api/pipeline_runs/batch", + tags=["pipelineRuns"], + dependencies=pipeline_run_creation_dependencies, + **default_config, + )(create_batch_func) + router.get( "/api/artifacts/{id}/signed_artifact_url", tags=["artifacts"], **default_config )(inject_session_dependency(artifact_service.get_signed_artifact_url)) diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index 2e13daa..861f1c9 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -86,9 +86,22 @@ class ListPipelineJobsResponse: next_page_token: str | None = None +@dataclasses.dataclass(kw_only=True) +class BatchCreateRequest: + root_task: structures.TaskSpec + components: Optional[list[structures.ComponentReference]] = None + annotations: Optional[dict[str, Any]] = None + + +@dataclasses.dataclass(kw_only=True) +class BatchCreatePipelineRunsResponse: + created_runs: list[PipelineRunResponse] + + class PipelineRunsApiService_Sql: _PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name" _DEFAULT_PAGE_SIZE: Final[int] = 10 + _MAX_BATCH_SIZE: Final[int] = 100 _SYSTEM_KEY_RESERVED_MSG = ( "Annotation keys starting with " f"{filter_query_sql.SYSTEM_KEY_PREFIX!r} are reserved for system use." 
@@ -148,6 +161,61 @@ def create(
         session.refresh(pipeline_run)
         return PipelineRunResponse.from_db(pipeline_run)
 
+    def create_batch(
+        self,
+        session: orm.Session,
+        runs: list[BatchCreateRequest],
+        created_by: str | None = None,
+    ) -> BatchCreatePipelineRunsResponse:
+        if not runs:
+            raise errors.ApiValidationError("Batch must contain at least one run.")
+        if len(runs) > self._MAX_BATCH_SIZE:
+            raise errors.ApiValidationError(
+                f"Batch size {len(runs)} exceeds the maximum of {self._MAX_BATCH_SIZE}."
+            )
+
+        pipeline_runs: list[bts.PipelineRun] = []
+
+        # session.begin() commits on exit (rolls back on error); an explicit
+        # commit inside the block would raise on context-manager exit.
+        with session.begin():
+            for run_request in runs:
+                pipeline_name = run_request.root_task.component_ref.spec.name
+
+                root_execution_node = (
+                    _recursively_create_all_executions_and_artifacts_root(
+                        session=session,
+                        root_task_spec=run_request.root_task,
+                    )
+                )
+
+                current_time = _get_current_time()
+                pipeline_run = bts.PipelineRun(
+                    root_execution=root_execution_node,
+                    created_at=current_time,
+                    updated_at=current_time,
+                    annotations=run_request.annotations,
+                    created_by=created_by,
+                    extra_data={
+                        self._PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name,
+                    },
+                )
+                session.add(pipeline_run)
+                session.flush()
+                _mirror_system_annotations(
+                    session=session,
+                    pipeline_run_id=pipeline_run.id,
+                    created_by=created_by,
+                    pipeline_name=pipeline_name,
+                )
+                pipeline_runs.append(pipeline_run)
+
+        responses: list[PipelineRunResponse] = []
+        for pipeline_run in pipeline_runs:
+            session.refresh(pipeline_run)
+            responses.append(PipelineRunResponse.from_db(pipeline_run))
+
+        return BatchCreatePipelineRunsResponse(created_runs=responses)
+
     def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse:
         pipeline_run = session.get(bts.PipelineRun, id)
         if not pipeline_run:
diff --git a/tests/test_api_server_sql.py b/tests/test_api_server_sql.py
index b1a2592..c2bd512 100644
--- a/tests/test_api_server_sql.py
+++ b/tests/test_api_server_sql.py
@@ -1628,3 +1628,98 @@ def
test_returns_none_on_malformed_dict(self): task_spec_dict={"bad": "data"} ) assert result is None + + +class TestPipelineRunServiceCreateBatch: + def test_create_batch_returns_all_runs(self, session_factory, service): + runs = [ + api_server_sql.BatchCreateRequest(root_task=_make_task_spec(f"pipeline-{i}")) + for i in range(3) + ] + with session_factory() as session: + result = service.create_batch(session=session, runs=runs) + assert len(result.created_runs) == 3 + ids = [r.id for r in result.created_runs] + assert len(set(ids)) == 3 # All unique IDs + + def test_create_batch_with_created_by(self, session_factory, service): + runs = [ + api_server_sql.BatchCreateRequest(root_task=_make_task_spec("p1")), + api_server_sql.BatchCreateRequest(root_task=_make_task_spec("p2")), + ] + with session_factory() as session: + result = service.create_batch( + session=session, runs=runs, created_by="alice@example.com" + ) + for run in result.created_runs: + assert run.created_by == "alice@example.com" + + def test_create_batch_with_annotations(self, session_factory, service): + annotations = {"team": "ml-ops"} + runs = [ + api_server_sql.BatchCreateRequest( + root_task=_make_task_spec(), annotations=annotations + ), + ] + with session_factory() as session: + result = service.create_batch(session=session, runs=runs) + assert result.created_runs[0].annotations == annotations + + def test_create_batch_mirrors_system_annotations(self, session_factory, service): + runs = [ + api_server_sql.BatchCreateRequest(root_task=_make_task_spec("batch-pipe")), + ] + with session_factory() as session: + result = service.create_batch( + session=session, runs=runs, created_by="bob" + ) + with session_factory() as session: + annotations = service.list_annotations( + session=session, id=result.created_runs[0].id + ) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME] + == "batch-pipe" + ) + assert ( + 
annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] + == "bob" + ) + + def test_create_batch_rejects_empty_list(self, session_factory, service): + with session_factory() as session: + with pytest.raises(errors.ApiValidationError, match="at least one run"): + service.create_batch(session=session, runs=[]) + + def test_create_batch_rejects_exceeding_max_size(self, session_factory, service): + runs = [ + api_server_sql.BatchCreateRequest(root_task=_make_task_spec()) + for _ in range(101) + ] + with session_factory() as session: + with pytest.raises(errors.ApiValidationError, match="exceeds the maximum"): + service.create_batch(session=session, runs=runs) + + def test_create_batch_accepts_max_size(self, session_factory, service): + runs = [ + api_server_sql.BatchCreateRequest(root_task=_make_task_spec(f"p-{i}")) + for i in range(100) + ] + with session_factory() as session: + result = service.create_batch(session=session, runs=runs) + assert len(result.created_runs) == 100 + + def test_create_batch_is_atomic(self, session_factory, service): + """All runs in a batch share a single transaction.""" + runs = [ + api_server_sql.BatchCreateRequest(root_task=_make_task_spec(f"atomic-{i}")) + for i in range(3) + ] + with session_factory() as session: + result = service.create_batch(session=session, runs=runs) + + # Verify all runs are retrievable + with session_factory() as session: + for run in result.created_runs: + fetched = service.get(session=session, id=run.id) + assert fetched.id == run.id