|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | import logging |
4 | 5 | import re |
5 | 6 | from typing import TYPE_CHECKING, Dict, List, Optional, Union |
|
10 | 11 | from .operations import ( |
11 | 12 | _DEPLOY_OP, |
12 | 13 | _DEPLOY_WITH_PROFILE_OP, |
| 14 | + _GET_PIPELINE_STATUS_OP, |
13 | 15 | _SCALE_OP, |
| 16 | + _START_PIPELINE_STAGE_OP, |
| 17 | + _START_PIPELINE_STAGE_WITH_PROFILE_OP, |
| 18 | + _STOP_PIPELINE_STAGE_OP, |
14 | 19 | _TEARDOWN_OP, |
15 | 20 | ) |
16 | | -from .schemas import Profile, ProfileConfig |
| 21 | +from .schemas import ( |
| 22 | + PipelineStage, |
| 23 | + PipelineState, |
| 24 | + PipelineStatusList, |
| 25 | + Profile, |
| 26 | + ProfileConfig, |
| 27 | +) |
17 | 28 |
|
18 | 29 | if TYPE_CHECKING: |
19 | 30 | from ..schemas import CommandOutput |
@@ -95,3 +106,155 @@ async def teardown(self) -> None: |
95 | 106 |
|
96 | 107 | async def scale(self, services: Dict[str, int]) -> None: |
97 | 108 | await self._execute_operation(_SCALE_OP, data=services) |
| 109 | + |
| 110 | + # Pipeline operations |
| 111 | + |
| 112 | + async def start_stage( |
| 113 | + self, |
| 114 | + stage: Union[PipelineStage, str], |
| 115 | + profile: Optional[str] = None, |
| 116 | + ) -> None: |
| 117 | + """Start a pipeline stage. |
| 118 | +
|
| 119 | + Args: |
| 120 | + stage: The pipeline stage to start ('prepare', 'test', or 'run'). |
| 121 | + profile: Optional profile name. If provided, starts the stage with |
| 122 | + that profile. Required for first run after deploy. |
| 123 | +
|
| 124 | + Raises: |
| 125 | + ValidationError: If the workspace is not running or parameters are invalid. |
| 126 | + NotFoundError: If the workspace is not found. |
| 127 | + """ |
| 128 | + if isinstance(stage, PipelineStage): |
| 129 | + stage = stage.value |
| 130 | + |
| 131 | + if profile is not None: |
| 132 | + _validate_profile_name(profile) |
| 133 | + await self._execute_operation( |
| 134 | + _START_PIPELINE_STAGE_WITH_PROFILE_OP, stage=stage, profile=profile |
| 135 | + ) |
| 136 | + else: |
| 137 | + await self._execute_operation(_START_PIPELINE_STAGE_OP, stage=stage) |
| 138 | + |
| 139 | + async def stop_stage(self, stage: Union[PipelineStage, str]) -> None: |
| 140 | + """Stop a pipeline stage. |
| 141 | +
|
| 142 | + Args: |
| 143 | + stage: The pipeline stage to stop ('prepare', 'test', or 'run'). |
| 144 | +
|
| 145 | + Raises: |
| 146 | + ValidationError: If the workspace is not running or parameters are invalid. |
| 147 | + NotFoundError: If the workspace is not found. |
| 148 | + """ |
| 149 | + if isinstance(stage, PipelineStage): |
| 150 | + stage = stage.value |
| 151 | + |
| 152 | + await self._execute_operation(_STOP_PIPELINE_STAGE_OP, stage=stage) |
| 153 | + |
| 154 | + async def get_stage_status( |
| 155 | + self, stage: Union[PipelineStage, str] |
| 156 | + ) -> PipelineStatusList: |
| 157 | + """Get the status of a pipeline stage. |
| 158 | +
|
| 159 | + Args: |
| 160 | + stage: The pipeline stage to get status for ('prepare', 'test', or 'run'). |
| 161 | +
|
| 162 | + Returns: |
| 163 | + List of PipelineStatus objects, one per replica/server. |
| 164 | +
|
| 165 | + Raises: |
| 166 | + ValidationError: If the workspace is not running or parameters are invalid. |
| 167 | + NotFoundError: If the workspace is not found. |
| 168 | + """ |
| 169 | + if isinstance(stage, PipelineStage): |
| 170 | + stage = stage.value |
| 171 | + |
| 172 | + return await self._execute_operation(_GET_PIPELINE_STATUS_OP, stage=stage) |
| 173 | + |
| 174 | + async def wait_for_stage( |
| 175 | + self, |
| 176 | + stage: Union[PipelineStage, str], |
| 177 | + *, |
| 178 | + timeout: float = 300.0, |
| 179 | + poll_interval: float = 5.0, |
| 180 | + server: Optional[str] = None, |
| 181 | + ) -> PipelineStatusList: |
| 182 | + """Wait for a pipeline stage to complete (success or failure). |
| 183 | +
|
| 184 | + Args: |
| 185 | + stage: The pipeline stage to wait for. |
| 186 | + timeout: Maximum time to wait in seconds (default: 300). |
| 187 | + poll_interval: Time between status checks in seconds (default: 5). |
| 188 | + server: Optional server name to filter by. If None, waits for all |
| 189 | + servers that have steps defined for this stage. |
| 190 | +
|
| 191 | + Returns: |
| 192 | + Final PipelineStatusList when stage completes. |
| 193 | +
|
| 194 | + Raises: |
| 195 | + TimeoutError: If the stage doesn't complete within the timeout. |
| 196 | + ValidationError: If the workspace is not running. |
| 197 | + """ |
| 198 | + if poll_interval <= 0: |
| 199 | + raise ValueError("poll_interval must be greater than 0") |
| 200 | + |
| 201 | + stage_name = stage.value if isinstance(stage, PipelineStage) else stage |
| 202 | + elapsed = 0.0 |
| 203 | + |
| 204 | + while elapsed < timeout: |
| 205 | + status_list = await self.get_stage_status(stage) |
| 206 | + |
| 207 | + # Filter to relevant servers for THIS stage |
| 208 | + # A server is relevant for this stage if: |
| 209 | + # - It has steps defined (meaning it participates in this stage) |
| 210 | + # - OR it's not in 'waiting' state (meaning it has started) |
| 211 | + relevant_statuses = [] |
| 212 | + for s in status_list: |
| 213 | + if server is not None: |
| 214 | + # Filter by specific server |
| 215 | + if s.server == server: |
| 216 | + relevant_statuses.append(s) |
| 217 | + else: |
| 218 | + # Include servers that have steps for this stage |
| 219 | + # Servers with no steps and waiting state don't participate in this stage |
| 220 | + if s.steps: |
| 221 | + relevant_statuses.append(s) |
| 222 | + elif s.state != PipelineState.WAITING: |
| 223 | + # Started but no steps visible yet |
| 224 | + relevant_statuses.append(s) |
| 225 | + |
| 226 | + # If no relevant statuses yet, keep waiting |
| 227 | + if not relevant_statuses: |
| 228 | + log.debug( |
| 229 | + "Pipeline stage '%s': no servers with steps yet, waiting...", |
| 230 | + stage_name, |
| 231 | + ) |
| 232 | + await asyncio.sleep(poll_interval) |
| 233 | + elapsed += poll_interval |
| 234 | + continue |
| 235 | + |
| 236 | + # Check if all relevant servers have completed |
| 237 | + all_completed = all( |
| 238 | + s.state |
| 239 | + in (PipelineState.SUCCESS, PipelineState.FAILURE, PipelineState.ABORTED) |
| 240 | + for s in relevant_statuses |
| 241 | + ) |
| 242 | + |
| 243 | + if all_completed: |
| 244 | + log.debug("Pipeline stage '%s' completed.", stage_name) |
| 245 | + return PipelineStatusList(root=relevant_statuses) |
| 246 | + |
| 247 | + # Log current state |
| 248 | + states = [f"{s.server}={s.state.value}" for s in relevant_statuses] |
| 249 | + log.debug( |
| 250 | + "Pipeline stage '%s' status: %s (elapsed: %.1fs)", |
| 251 | + stage_name, |
| 252 | + ", ".join(states), |
| 253 | + elapsed, |
| 254 | + ) |
| 255 | + await asyncio.sleep(poll_interval) |
| 256 | + elapsed += poll_interval |
| 257 | + |
| 258 | + raise TimeoutError( |
| 259 | + f"Pipeline stage '{stage_name}' did not complete within {timeout} seconds." |
| 260 | + ) |
0 commit comments