@@ -1628,3 +1628,98 @@ def test_returns_none_on_malformed_dict(self):
16281628 task_spec_dict = {"bad" : "data" }
16291629 )
16301630 assert result is None
1631+
1632+
1633+ class TestPipelineRunServiceCreateBatch :
1634+ def test_create_batch_returns_all_runs (self , session_factory , service ):
1635+ runs = [
1636+ api_server_sql .BatchCreateRequest (root_task = _make_task_spec (f"pipeline-{ i } " ))
1637+ for i in range (3 )
1638+ ]
1639+ with session_factory () as session :
1640+ result = service .create_batch (session = session , runs = runs )
1641+ assert len (result .created_runs ) == 3
1642+ ids = [r .id for r in result .created_runs ]
1643+ assert len (set (ids )) == 3 # All unique IDs
1644+
1645+ def test_create_batch_with_created_by (self , session_factory , service ):
1646+ runs = [
1647+ api_server_sql .BatchCreateRequest (root_task = _make_task_spec ("p1" )),
1648+ api_server_sql .BatchCreateRequest (root_task = _make_task_spec ("p2" )),
1649+ ]
1650+ with session_factory () as session :
1651+ result = service .create_batch (
1652+ session = session , runs = runs , created_by = "alice@example.com"
1653+ )
1654+ for run in result .created_runs :
1655+ assert run .created_by == "alice@example.com"
1656+
1657+ def test_create_batch_with_annotations (self , session_factory , service ):
1658+ annotations = {"team" : "ml-ops" }
1659+ runs = [
1660+ api_server_sql .BatchCreateRequest (
1661+ root_task = _make_task_spec (), annotations = annotations
1662+ ),
1663+ ]
1664+ with session_factory () as session :
1665+ result = service .create_batch (session = session , runs = runs )
1666+ assert result .created_runs [0 ].annotations == annotations
1667+
1668+ def test_create_batch_mirrors_system_annotations (self , session_factory , service ):
1669+ runs = [
1670+ api_server_sql .BatchCreateRequest (root_task = _make_task_spec ("batch-pipe" )),
1671+ ]
1672+ with session_factory () as session :
1673+ result = service .create_batch (
1674+ session = session , runs = runs , created_by = "bob"
1675+ )
1676+ with session_factory () as session :
1677+ annotations = service .list_annotations (
1678+ session = session , id = result .created_runs [0 ].id
1679+ )
1680+ assert (
1681+ annotations [filter_query_sql .PipelineRunAnnotationSystemKey .PIPELINE_NAME ]
1682+ == "batch-pipe"
1683+ )
1684+ assert (
1685+ annotations [filter_query_sql .PipelineRunAnnotationSystemKey .CREATED_BY ]
1686+ == "bob"
1687+ )
1688+
1689+ def test_create_batch_rejects_empty_list (self , session_factory , service ):
1690+ with session_factory () as session :
1691+ with pytest .raises (errors .ApiValidationError , match = "at least one run" ):
1692+ service .create_batch (session = session , runs = [])
1693+
1694+ def test_create_batch_rejects_exceeding_max_size (self , session_factory , service ):
1695+ runs = [
1696+ api_server_sql .BatchCreateRequest (root_task = _make_task_spec ())
1697+ for _ in range (101 )
1698+ ]
1699+ with session_factory () as session :
1700+ with pytest .raises (errors .ApiValidationError , match = "exceeds the maximum" ):
1701+ service .create_batch (session = session , runs = runs )
1702+
1703+ def test_create_batch_accepts_max_size (self , session_factory , service ):
1704+ runs = [
1705+ api_server_sql .BatchCreateRequest (root_task = _make_task_spec (f"p-{ i } " ))
1706+ for i in range (100 )
1707+ ]
1708+ with session_factory () as session :
1709+ result = service .create_batch (session = session , runs = runs )
1710+ assert len (result .created_runs ) == 100
1711+
1712+ def test_create_batch_is_atomic (self , session_factory , service ):
1713+ """All runs in a batch share a single transaction."""
1714+ runs = [
1715+ api_server_sql .BatchCreateRequest (root_task = _make_task_spec (f"atomic-{ i } " ))
1716+ for i in range (3 )
1717+ ]
1718+ with session_factory () as session :
1719+ result = service .create_batch (session = session , runs = runs )
1720+
1721+ # Verify all runs are retrievable
1722+ with session_factory () as session :
1723+ for run in result .created_runs :
1724+ fetched = service .get (session = session , id = run .id )
1725+ assert fetched .id == run .id
0 commit comments