Skip to content

Commit 606b129

Browse files
committed
Adds bulk pipeline submission
1 parent 42d18f6 commit 606b129

3 files changed

Lines changed: 178 additions & 0 deletions

File tree

cloud_pipelines_backend/api_router.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,20 @@ def stream_container_log(
331331
**default_config,
332332
)(create_run_func)
333333

334+
# Bulk pipeline-run submission endpoint.
# The handler is wrapped in the same order as before: the session dependency
# is injected first, then the `created_by` parameter is annotated with the
# user-name dependency.
create_batch_func = add_parameter_annotation_metadata(
    inject_session_dependency(pipeline_run_service.create_batch),
    parameter_name="created_by",
    annotation_metadata=get_user_name_dependency,
)
register_create_batch_route = router.post(
    "/api/pipeline_runs/batch",
    tags=["pipelineRuns"],
    dependencies=pipeline_run_creation_dependencies,
    **default_config,
)
register_create_batch_route(create_batch_func)
347+
334348
router.get(
335349
"/api/artifacts/{id}/signed_artifact_url", tags=["artifacts"], **default_config
336350
)(inject_session_dependency(artifact_service.get_signed_artifact_url))

cloud_pipelines_backend/api_server_sql.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,22 @@ class ListPipelineJobsResponse:
8686
next_page_token: str | None = None
8787

8888

89+
@dataclasses.dataclass(kw_only=True)
class BatchCreateRequest:
    # One entry of a bulk submission: describes a single pipeline run to create.

    # Task spec that becomes the root execution of the new run.
    root_task: structures.TaskSpec
    # NOTE(review): accepted in the payload, but `create_batch` never reads
    # this field - confirm whether it is intentionally ignored.
    components: list[structures.ComponentReference] | None = None
    # Stored as-is on the created pipeline run.
    annotations: dict[str, Any] | None = None
94+
95+
96+
@dataclasses.dataclass(kw_only=True)
class BatchCreatePipelineRunsResponse:
    # Result of a bulk submission: one response per created pipeline run, in
    # the same order as the submitted requests.
    created_runs: list[PipelineRunResponse]
99+
100+
89101
class PipelineRunsApiService_Sql:
90102
_PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name"
91103
_DEFAULT_PAGE_SIZE: Final[int] = 10
104+
_MAX_BATCH_SIZE: Final[int] = 100
92105
_SYSTEM_KEY_RESERVED_MSG = (
93106
"Annotation keys starting with "
94107
f"{filter_query_sql.SYSTEM_KEY_PREFIX!r} are reserved for system use."
@@ -148,6 +161,62 @@ def create(
148161
session.refresh(pipeline_run)
149162
return PipelineRunResponse.from_db(pipeline_run)
150163

164+
def create_batch(
    self,
    session: orm.Session,
    runs: list[BatchCreateRequest],
    created_by: str | None = None,
) -> BatchCreatePipelineRunsResponse:
    """Create several pipeline runs inside one database transaction.

    Args:
        session: ORM session used for every insert. A transaction is opened
            on it with ``session.begin()``.
        runs: The runs to create. Must contain between 1 and
            ``_MAX_BATCH_SIZE`` entries.
        created_by: Value stored as ``created_by`` on every created run and
            passed to ``_mirror_system_annotations``.

    Returns:
        A response holding one ``PipelineRunResponse`` per request, in
        request order.

    Raises:
        errors.ApiValidationError: If ``runs`` is empty or longer than
            ``_MAX_BATCH_SIZE``.
    """
    # Validate the batch size before touching the database.
    if not runs:
        raise errors.ApiValidationError("Batch must contain at least one run.")
    if len(runs) > self._MAX_BATCH_SIZE:
        raise errors.ApiValidationError(
            f"Batch size {len(runs)} exceeds the maximum of {self._MAX_BATCH_SIZE}."
        )

    # NOTE(review): `run_request.components` is never read here, and user
    # annotations are not checked against the reserved system-key prefix
    # (see `_SYSTEM_KEY_RESERVED_MSG`) - confirm both are intentional.

    pipeline_runs: list[bts.PipelineRun] = []

    # The `session.begin()` context manager commits when the block exits
    # normally and rolls back when it exits with an exception, so no explicit
    # `session.commit()` is needed: either every run is stored or none is.
    with session.begin():
        for run_request in runs:
            pipeline_name = run_request.root_task.component_ref.spec.name

            root_execution_node = (
                _recursively_create_all_executions_and_artifacts_root(
                    session=session,
                    root_task_spec=run_request.root_task,
                )
            )

            current_time = _get_current_time()
            pipeline_run = bts.PipelineRun(
                root_execution=root_execution_node,
                created_at=current_time,
                updated_at=current_time,
                annotations=run_request.annotations,
                created_by=created_by,
                extra_data={
                    self._PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name,
                },
            )
            session.add(pipeline_run)
            # Flush before `pipeline_run.id` is read below (presumably the id
            # is only assigned once the row has been flushed).
            session.flush()
            _mirror_system_annotations(
                session=session,
                pipeline_run_id=pipeline_run.id,
                created_by=created_by,
                pipeline_name=pipeline_name,
            )
            pipeline_runs.append(pipeline_run)

    # Reload each committed row before converting it to an API response.
    responses: list[PipelineRunResponse] = []
    for pipeline_run in pipeline_runs:
        session.refresh(pipeline_run)
        responses.append(PipelineRunResponse.from_db(pipeline_run))

    return BatchCreatePipelineRunsResponse(created_runs=responses)
219+
151220
def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse:
152221
pipeline_run = session.get(bts.PipelineRun, id)
153222
if not pipeline_run:

tests/test_api_server_sql.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,3 +1628,98 @@ def test_returns_none_on_malformed_dict(self):
16281628
task_spec_dict={"bad": "data"}
16291629
)
16301630
assert result is None
1631+
1632+
1633+
class TestPipelineRunServiceCreateBatch:
    """Tests for bulk pipeline-run creation through ``create_batch``."""

    def test_create_batch_returns_all_runs(self, session_factory, service):
        """Three requests yield three runs, each with its own ID."""
        requests = []
        for index in range(3):
            task_spec = _make_task_spec(f"pipeline-{index}")
            requests.append(api_server_sql.BatchCreateRequest(root_task=task_spec))
        with session_factory() as session:
            response = service.create_batch(session=session, runs=requests)
            created = response.created_runs
            assert len(created) == 3
            # No two runs may share an ID.
            assert len({run.id for run in created}) == 3

    def test_create_batch_with_created_by(self, session_factory, service):
        """``created_by`` is applied to every run of the batch."""
        requests = [
            api_server_sql.BatchCreateRequest(root_task=_make_task_spec(name))
            for name in ("p1", "p2")
        ]
        with session_factory() as session:
            response = service.create_batch(
                session=session, runs=requests, created_by="alice@example.com"
            )
            for created_run in response.created_runs:
                assert created_run.created_by == "alice@example.com"

    def test_create_batch_with_annotations(self, session_factory, service):
        """Annotations supplied on a request are returned on the created run."""
        expected_annotations = {"team": "ml-ops"}
        request = api_server_sql.BatchCreateRequest(
            root_task=_make_task_spec(), annotations=expected_annotations
        )
        with session_factory() as session:
            response = service.create_batch(session=session, runs=[request])
            first_run = response.created_runs[0]
            assert first_run.annotations == expected_annotations

    def test_create_batch_mirrors_system_annotations(self, session_factory, service):
        """Pipeline name and creator are readable as system annotations."""
        system_keys = filter_query_sql.PipelineRunAnnotationSystemKey
        request = api_server_sql.BatchCreateRequest(
            root_task=_make_task_spec("batch-pipe")
        )
        with session_factory() as session:
            response = service.create_batch(
                session=session, runs=[request], created_by="bob"
            )
        run_id = response.created_runs[0].id
        # Read back through a fresh session.
        with session_factory() as session:
            stored = service.list_annotations(session=session, id=run_id)
        assert stored[system_keys.PIPELINE_NAME] == "batch-pipe"
        assert stored[system_keys.CREATED_BY] == "bob"

    def test_create_batch_rejects_empty_list(self, session_factory, service):
        """An empty batch is rejected with a validation error."""
        with session_factory() as session, pytest.raises(
            errors.ApiValidationError, match="at least one run"
        ):
            service.create_batch(session=session, runs=[])

    def test_create_batch_rejects_exceeding_max_size(self, session_factory, service):
        """A batch of 101 runs is rejected with a validation error."""
        oversized = []
        for _ in range(101):
            oversized.append(
                api_server_sql.BatchCreateRequest(root_task=_make_task_spec())
            )
        with session_factory() as session, pytest.raises(
            errors.ApiValidationError, match="exceeds the maximum"
        ):
            service.create_batch(session=session, runs=oversized)

    def test_create_batch_accepts_max_size(self, session_factory, service):
        """A batch of exactly 100 runs is accepted."""
        requests = []
        for index in range(100):
            task_spec = _make_task_spec(f"p-{index}")
            requests.append(api_server_sql.BatchCreateRequest(root_task=task_spec))
        with session_factory() as session:
            response = service.create_batch(session=session, runs=requests)
            assert len(response.created_runs) == 100

    def test_create_batch_is_atomic(self, session_factory, service):
        """Every run created by one batch call can be fetched afterwards.

        NOTE(review): this only checks that the runs were persisted; it does
        not exercise rollback when one run of the batch fails.
        """
        requests = []
        for index in range(3):
            task_spec = _make_task_spec(f"atomic-{index}")
            requests.append(api_server_sql.BatchCreateRequest(root_task=task_spec))
        with session_factory() as session:
            response = service.create_batch(session=session, runs=requests)

        # Look each run up again through a separate session.
        with session_factory() as session:
            for created_run in response.created_runs:
                fetched = service.get(session=session, id=created_run.id)
                assert fetched.id == created_run.id

0 commit comments

Comments
 (0)