Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cloud_pipelines_backend/api_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,20 @@ def stream_container_log(
**default_config,
)(create_run_func)

# Expose batch pipeline-run creation, following the same wrapping pattern as
# the single-run route above: inject a DB session into the service method and
# resolve its `created_by` parameter from the authenticated user's name.
create_batch_func = pipeline_run_service.create_batch
create_batch_func = inject_session_dependency(create_batch_func)
create_batch_func = add_parameter_annotation_metadata(
    create_batch_func,
    parameter_name="created_by",
    annotation_metadata=get_user_name_dependency,
)
# NOTE(review): reuses the single-run creation dependencies unchanged —
# confirm the batch endpoint needs no additional rate/quota checks.
router.post(
    "/api/pipeline_runs/batch",
    tags=["pipelineRuns"],
    dependencies=pipeline_run_creation_dependencies,
    **default_config,
)(create_batch_func)

router.get(
"/api/artifacts/{id}/signed_artifact_url", tags=["artifacts"], **default_config
)(inject_session_dependency(artifact_service.get_signed_artifact_url))
Expand Down
69 changes: 69 additions & 0 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,22 @@ class ListPipelineJobsResponse:
next_page_token: str | None = None


@dataclasses.dataclass(kw_only=True)
class BatchCreateRequest:
    """A single pipeline-run creation request within a batch call."""

    # Root task spec identifying the pipeline to run.
    root_task: structures.TaskSpec
    # Optional component references supplied alongside the root task.
    # NOTE(review): not read by create_batch in this change — confirm whether
    # this field is intentionally reserved for future use.
    # Style fix: use `X | None` unions for consistency with the rest of this
    # file (e.g. `next_page_token: str | None`) instead of `Optional[...]`.
    components: list[structures.ComponentReference] | None = None
    # Free-form user annotations to attach to the created run.
    annotations: dict[str, Any] | None = None


@dataclasses.dataclass(kw_only=True)
class BatchCreatePipelineRunsResponse:
    """Response payload for the batch pipeline-run creation endpoint."""

    # One response per created run, in the same order as the request batch.
    created_runs: list[PipelineRunResponse]


class PipelineRunsApiService_Sql:
_PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name"
_DEFAULT_PAGE_SIZE: Final[int] = 10
_MAX_BATCH_SIZE: Final[int] = 100
_SYSTEM_KEY_RESERVED_MSG = (
"Annotation keys starting with "
f"{filter_query_sql.SYSTEM_KEY_PREFIX!r} are reserved for system use."
Expand Down Expand Up @@ -148,6 +161,62 @@ def create(
session.refresh(pipeline_run)
return PipelineRunResponse.from_db(pipeline_run)

def create_batch(
self,
session: orm.Session,
runs: list[BatchCreateRequest],
created_by: str | None = None,
) -> BatchCreatePipelineRunsResponse:
if not runs:
raise errors.ApiValidationError("Batch must contain at least one run.")
if len(runs) > self._MAX_BATCH_SIZE:
raise errors.ApiValidationError(
f"Batch size {len(runs)} exceeds the maximum of {self._MAX_BATCH_SIZE}."
)

pipeline_runs: list[bts.PipelineRun] = []

with session.begin():
for run_request in runs:
pipeline_name = run_request.root_task.component_ref.spec.name

root_execution_node = (
_recursively_create_all_executions_and_artifacts_root(
session=session,
root_task_spec=run_request.root_task,
)
)

current_time = _get_current_time()
pipeline_run = bts.PipelineRun(
root_execution=root_execution_node,
created_at=current_time,
updated_at=current_time,
annotations=run_request.annotations,
created_by=created_by,
extra_data={
self._PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name,
},
)
session.add(pipeline_run)
session.flush()
_mirror_system_annotations(
session=session,
pipeline_run_id=pipeline_run.id,
created_by=created_by,
pipeline_name=pipeline_name,
)
pipeline_runs.append(pipeline_run)

session.commit()

responses: list[PipelineRunResponse] = []
for pipeline_run in pipeline_runs:
session.refresh(pipeline_run)
responses.append(PipelineRunResponse.from_db(pipeline_run))

return BatchCreatePipelineRunsResponse(created_runs=responses)

def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse:
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
Expand Down
95 changes: 95 additions & 0 deletions tests/test_api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,3 +1628,98 @@ def test_returns_none_on_malformed_dict(self):
task_spec_dict={"bad": "data"}
)
assert result is None


class TestPipelineRunServiceCreateBatch:
    """Tests for PipelineRunsApiService_Sql.create_batch."""

    def test_create_batch_returns_all_runs(self, session_factory, service):
        batch = []
        for index in range(3):
            batch.append(
                api_server_sql.BatchCreateRequest(
                    root_task=_make_task_spec(f"pipeline-{index}")
                )
            )
        with session_factory() as session:
            response = service.create_batch(session=session, runs=batch)
        assert len(response.created_runs) == 3
        unique_ids = {run.id for run in response.created_runs}
        assert len(unique_ids) == 3  # Every run received its own ID.

    def test_create_batch_with_created_by(self, session_factory, service):
        batch = [
            api_server_sql.BatchCreateRequest(root_task=_make_task_spec(name))
            for name in ("p1", "p2")
        ]
        with session_factory() as session:
            response = service.create_batch(
                session=session, runs=batch, created_by="alice@example.com"
            )
        assert all(
            run.created_by == "alice@example.com" for run in response.created_runs
        )

    def test_create_batch_with_annotations(self, session_factory, service):
        labels = {"team": "ml-ops"}
        request = api_server_sql.BatchCreateRequest(
            root_task=_make_task_spec(), annotations=labels
        )
        with session_factory() as session:
            response = service.create_batch(session=session, runs=[request])
        assert response.created_runs[0].annotations == labels

    def test_create_batch_mirrors_system_annotations(self, session_factory, service):
        request = api_server_sql.BatchCreateRequest(
            root_task=_make_task_spec("batch-pipe")
        )
        with session_factory() as session:
            response = service.create_batch(
                session=session, runs=[request], created_by="bob"
            )
        run_id = response.created_runs[0].id
        with session_factory() as session:
            mirrored = service.list_annotations(session=session, id=run_id)
        name_key = filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME
        creator_key = filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY
        assert mirrored[name_key] == "batch-pipe"
        assert mirrored[creator_key] == "bob"

    def test_create_batch_rejects_empty_list(self, session_factory, service):
        with session_factory() as session:
            with pytest.raises(errors.ApiValidationError, match="at least one run"):
                service.create_batch(session=session, runs=[])

    def test_create_batch_rejects_exceeding_max_size(self, session_factory, service):
        oversized = [
            api_server_sql.BatchCreateRequest(root_task=_make_task_spec())
            for _ in range(101)
        ]
        with session_factory() as session, pytest.raises(
            errors.ApiValidationError, match="exceeds the maximum"
        ):
            service.create_batch(session=session, runs=oversized)

    def test_create_batch_accepts_max_size(self, session_factory, service):
        full_batch = [
            api_server_sql.BatchCreateRequest(root_task=_make_task_spec(f"p-{i}"))
            for i in range(100)
        ]
        with session_factory() as session:
            response = service.create_batch(session=session, runs=full_batch)
        assert len(response.created_runs) == 100

    def test_create_batch_is_atomic(self, session_factory, service):
        """All runs in a batch share a single transaction."""
        batch = [
            api_server_sql.BatchCreateRequest(
                root_task=_make_task_spec(f"atomic-{i}")
            )
            for i in range(3)
        ]
        with session_factory() as session:
            response = service.create_batch(session=session, runs=batch)

        # Every run created in the batch must be retrievable afterwards.
        with session_factory() as session:
            for created in response.created_runs:
                fetched = service.get(session=session, id=created.id)
                assert fetched.id == created.id
Loading