diff --git a/cloud_pipelines_backend/launchers/kubernetes_launchers.py b/cloud_pipelines_backend/launchers/kubernetes_launchers.py index ac442c0..86d1734 100644 --- a/cloud_pipelines_backend/launchers/kubernetes_launchers.py +++ b/cloud_pipelines_backend/launchers/kubernetes_launchers.py @@ -2,6 +2,7 @@ import copy import datetime +import enum import json import logging import os @@ -67,6 +68,34 @@ _MULTI_NODE_NODE_INDEX_ENV_VAR_NAME = "_TANGLE_MULTI_NODE_NODE_INDEX" +class _JobConditionType(str, enum.Enum): + """Kubernetes Job condition types. + + A Job is considered finished when it is in a terminal condition, + either "Complete" or "Failed". + + Reference: https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/job-v1/ + See: `A job is considered finished when it is in a terminal condition, either "Complete" or "Failed".` + """ + + COMPLETE = "Complete" + FAILED = "Failed" + SUSPENDED = "Suspended" + FAILURE_TARGET = "FailureTarget" + + +class _ConditionStatus(str, enum.Enum): + """Kubernetes condition status values. + + Reference: https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/job-v1/ + See: `Status of the condition, one of True, False, Unknown.` + """ + + TRUE = "True" + FALSE = "False" + UNKNOWN = "Unknown" + + _T = typing.TypeVar("_T") _CONTAINER_FILE_NAME = "data" @@ -1265,11 +1294,13 @@ def status(self) -> interfaces.ContainerStatus: if not job_status: return interfaces.ContainerStatus.PENDING has_succeeded_condition = any( - condition.type == "Complete" and condition.status == "True" + condition.type == _JobConditionType.COMPLETE + and condition.status == _ConditionStatus.TRUE for condition in job_status.conditions or [] ) has_failed_condition = any( - condition.type == "Failed" and condition.status == "True" + condition.type == _JobConditionType.FAILED + and condition.status == _ConditionStatus.TRUE for condition in job_status.conditions or [] ) if has_failed_condition: @@ -1338,13 +1369,19 @@ def started_at(self) -> datetime.datetime | None: @property def ended_at(self) -> datetime.datetime | None: + """Return the time when the Job entered a terminal condition. + + A Job is considered finished when it has a "Complete" or "Failed" + condition with status "True". + """ job_status = self._debug_job.status if not job_status: return None ended_condition_times = [ condition.last_transition_time for condition in job_status.conditions or [] - if condition.type in ("Succeeded", "Failed") and condition.status == "True" + if condition.type in (_JobConditionType.COMPLETE, _JobConditionType.FAILED) + and condition.status == _ConditionStatus.TRUE ] if not ended_condition_times: return None diff --git a/cloud_pipelines_backend/orchestrator_sql.py b/cloud_pipelines_backend/orchestrator_sql.py index 4ac30a9..21f100b 100644 --- a/cloud_pipelines_backend/orchestrator_sql.py +++ b/cloud_pipelines_backend/orchestrator_sql.py @@ -170,8 +170,10 @@ def internal_process_running_executions_queue(self, session: orm.Session): except Exception as ex: _logger.exception("Error processing running container execution") session.rollback() - running_container_execution.status = ( - bts.ContainerExecutionStatus.SYSTEM_ERROR + _record_terminal_state( + container_execution=running_container_execution, + status=bts.ContainerExecutionStatus.SYSTEM_ERROR, + ended_at=_get_current_time(), ) # Doing an intermediate commit here because it's most important to mark the problematic execution as SYSTEM_ERROR. session.commit() @@ -684,9 +686,12 @@ def internal_process_one_running_execution( # Requesting container termination. # Termination might not happen immediately (e.g. Kubernetes has grace period). launched_container.terminate() - container_execution.ended_at = _get_current_time() # We need to mark the execution as CANCELLED otherwise orchestrator will continue polling it. - container_execution.status = bts.ContainerExecutionStatus.CANCELLED + _record_terminal_state( + container_execution=container_execution, + status=bts.ContainerExecutionStatus.CANCELLED, + ended_at=_get_current_time(), + ) terminated = True # Mark the execution nodes as cancelled only after the launched container is successfully terminated (if needed) @@ -746,10 +751,13 @@ def internal_process_one_running_execution( bts.ContainerExecutionStatus.RUNNING ) elif new_status == launcher_interfaces.ContainerStatus.SUCCEEDED: - container_execution.status = bts.ContainerExecutionStatus.SUCCEEDED - container_execution.exit_code = reloaded_launched_container.exit_code - container_execution.started_at = reloaded_launched_container.started_at - container_execution.ended_at = reloaded_launched_container.ended_at + _record_terminal_state( + container_execution=container_execution, + status=bts.ContainerExecutionStatus.SUCCEEDED, + exit_code=reloaded_launched_container.exit_code, + started_at=reloaded_launched_container.started_at, + ended_at=reloaded_launched_container.ended_at, + ) # Don't fail the execution if log upload fails. # Logs are important, but not so important that we should fail a successfully completed container execution. @@ -881,10 +889,13 @@ def _maybe_preload_value( bts.ContainerExecutionStatus.QUEUED ) elif new_status == launcher_interfaces.ContainerStatus.FAILED: - container_execution.status = bts.ContainerExecutionStatus.FAILED - container_execution.exit_code = reloaded_launched_container.exit_code - container_execution.started_at = reloaded_launched_container.started_at - container_execution.ended_at = reloaded_launched_container.ended_at + _record_terminal_state( + container_execution=container_execution, + status=bts.ContainerExecutionStatus.FAILED, + exit_code=reloaded_launched_container.exit_code, + started_at=reloaded_launched_container.started_at, + ended_at=reloaded_launched_container.ended_at, + ) launcher_error = reloaded_launched_container.launcher_error_message if launcher_error: orchestration_error_message = f"Launcher error: {launcher_error}" @@ -1010,6 +1021,28 @@ def _get_current_time() -> datetime.datetime: return datetime.datetime.now(tz=datetime.timezone.utc) +def _record_terminal_state( + *, + container_execution: bts.ContainerExecution, + status: bts.ContainerExecutionStatus, + ended_at: datetime.datetime, + exit_code: int | None = None, + started_at: datetime.datetime | None = None, +) -> None: + """Record terminal state fields on a container execution. + + A terminal state must minimally include a status change and an end time. + exit_code and started_at are optional — they depend on whether the + launcher was able to report them before the execution ended. + """ + container_execution.status = status + container_execution.ended_at = ended_at + if exit_code is not None: + container_execution.exit_code = exit_code + if started_at is not None: + container_execution.started_at = started_at + + def _generate_random_id() -> str: import os import time diff --git a/tests/test_kubernetes_launchers.py b/tests/test_kubernetes_launchers.py new file mode 100644 index 0000000..192aa10 --- /dev/null +++ b/tests/test_kubernetes_launchers.py @@ -0,0 +1,159 @@ +"""Tests for LaunchedKubernetesJob.ended_at — the property fixed to use +"Complete" instead of the incorrect "Succeeded" K8s Job condition type. +""" + +from __future__ import annotations + +import datetime +from typing import Any +from unittest import mock + +from cloud_pipelines_backend.launchers import kubernetes_launchers as kl + + +def _utc( + *, + year: int = 2026, + month: int = 3, + day: int = 20, + hour: int = 12, + minute: int = 0, +) -> datetime.datetime: + return datetime.datetime( + year, month, day, hour, minute, tzinfo=datetime.timezone.utc + ) + + +def _make_condition( + *, + type: str, + status: str = "True", + last_transition_time: datetime.datetime | None = None, +) -> mock.Mock: + c = mock.Mock() + c.type = type + c.status = status + c.last_transition_time = last_transition_time or _utc() + return c + + +def _make_job( + *, + conditions: list[Any] | None = None, + active: int | None = None, + succeeded: int | None = None, + failed: int | None = None, + start_time: datetime.datetime | None = None, + completions: int | None = 1, +) -> mock.Mock: + job = mock.Mock() + job.status = mock.Mock() + job.status.conditions = conditions + job.status.active = active + job.status.succeeded = succeeded + job.status.failed = failed + job.status.start_time = start_time + job.spec = mock.Mock() + job.spec.completions = completions + return job + + +def _make_launched_job( + *, + job: mock.Mock | None = None, +) -> kl.LaunchedKubernetesJob: + if job is None: + job = _make_job() + return kl.LaunchedKubernetesJob( + job_name="test-job", + namespace="default", + output_uris={}, + log_uri="gs://bucket/log", + debug_job=job, + ) + + +class TestEndedAt: + """Tests for LaunchedKubernetesJob.ended_at. + + This property reads job.status.conditions and returns the + last_transition_time of the first terminal condition (Complete or Failed) + with status=True. + + Code under test: kubernetes_launchers.py LaunchedKubernetesJob.ended_at + """ + + def test_returns_none_when_no_status(self) -> None: + job = mock.Mock() + job.status = None + launched = _make_launched_job(job=job) + assert launched.ended_at is None + + def test_returns_none_when_no_conditions(self) -> None: + launched = _make_launched_job(job=_make_job(conditions=None)) + assert launched.ended_at is None + + def test_returns_none_when_empty_conditions(self) -> None: + launched = _make_launched_job(job=_make_job(conditions=[])) + assert launched.ended_at is None + + def test_returns_none_when_only_suspended_condition(self) -> None: + """A Suspended=True condition is not terminal — ended_at stays None.""" + condition = _make_condition(type="Suspended", status="True") + launched = _make_launched_job(job=_make_job(conditions=[condition])) + assert launched.ended_at is None + + def test_returns_time_for_complete_condition(self) -> None: + """Job finished successfully: condition type=Complete, status=True.""" + t = _utc(hour=14) + condition = _make_condition( + type="Complete", status="True", last_transition_time=t + ) + launched = _make_launched_job(job=_make_job(conditions=[condition])) + assert launched.ended_at == t + + def test_returns_time_for_failed_condition(self) -> None: + """Job failed: condition type=Failed, status=True.""" + t = _utc(hour=15) + condition = _make_condition( + type="Failed", status="True", last_transition_time=t + ) + launched = _make_launched_job(job=_make_job(conditions=[condition])) + assert launched.ended_at == t + + def test_ignores_complete_condition_with_status_false(self) -> None: + condition = _make_condition(type="Complete", status="False") + launched = _make_launched_job(job=_make_job(conditions=[condition])) + assert launched.ended_at is None + + def test_ignores_failed_condition_with_status_unknown(self) -> None: + condition = _make_condition(type="Failed", status="Unknown") + launched = _make_launched_job(job=_make_job(conditions=[condition])) + assert launched.ended_at is None + + def test_does_not_match_succeeded_string(self) -> None: + """Regression: 'Succeeded' is not a valid K8s Job condition type. + The old code had condition.type in ("Succeeded", "Failed") which + caused ended_at to always be None for successful jobs. + """ + condition = _make_condition(type="Succeeded", status="True") + launched = _make_launched_job(job=_make_job(conditions=[condition])) + assert launched.ended_at is None + + def test_picks_terminal_condition_ignoring_suspended(self) -> None: + """Real scenario: a job was suspended then resumed and completed. + Conditions list has Suspended=True followed by Complete=True. + ended_at should come from the Complete condition. + """ + t_suspended = _utc(hour=10) + t_complete = _utc(hour=14) + conditions = [ + _make_condition( + type="Suspended", status="True", last_transition_time=t_suspended + ), + _make_condition( + type="Complete", status="True", last_transition_time=t_complete + ), + ] + launched = _make_launched_job(job=_make_job(conditions=conditions)) + assert launched.ended_at == t_complete diff --git a/tests/test_orchestrator_terminal_state.py b/tests/test_orchestrator_terminal_state.py new file mode 100644 index 0000000..dd062e8 --- /dev/null +++ b/tests/test_orchestrator_terminal_state.py @@ -0,0 +1,446 @@ +"""Tests for terminal-state branches in the orchestrator. + +Each test targets a specific branch in orchestrator_sql.py that calls +_record_terminal_state and verifies that status, ended_at, exit_code, and +started_at are persisted correctly on the ContainerExecution. + +Branches tested: +- SYSTEM_ERROR: internal_process_running_executions_queue exception handler +- SUCCEEDED: internal_process_one_running_execution -> SUCCEEDED branch +- FAILED: internal_process_one_running_execution -> FAILED branch +- CANCELLED: internal_process_one_running_execution -> cancellation branch +""" + +from __future__ import annotations + +import datetime +import pathlib +from unittest import mock + +import pytest +from sqlalchemy import orm + +from cloud_pipelines.orchestration.storage_providers import local_storage + +from cloud_pipelines_backend import backend_types_sql as bts +from cloud_pipelines_backend import orchestrator_sql +from cloud_pipelines_backend.launchers import interfaces as launcher_interfaces +from tests.test_api_server_sql import session_factory # noqa: F401 + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def storage_provider() -> local_storage.LocalStorageProvider: + return local_storage.LocalStorageProvider() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _utc(*, hour: int = 12) -> datetime.datetime: + return datetime.datetime(2026, 3, 20, hour, 0, tzinfo=datetime.timezone.utc) + + +def _create_container_execution( + *, + session: orm.Session, + status: bts.ContainerExecutionStatus = bts.ContainerExecutionStatus.RUNNING, + launcher_data: dict | None = None, + output_artifact_data_map: dict | None = None, + desired_state: str | None = None, +) -> bts.ContainerExecution: + """Insert a real ContainerExecution + ExecutionNode into the DB.""" + ce = bts.ContainerExecution( + status=status, + launcher_data=launcher_data or {"kubernetes_job": {}}, + output_artifact_data_map=output_artifact_data_map or {}, + ) + node = bts.ExecutionNode(task_spec={"component_ref": {"name": "test-task"}}) + if desired_state: + node.extra_data = {"desired_state": desired_state} + node.container_execution = ce + node.container_execution_status = status + session.add(ce) + session.add(node) + session.flush() + return ce + + +def _make_launched_container( + *, + status: launcher_interfaces.ContainerStatus = launcher_interfaces.ContainerStatus.RUNNING, + exit_code: int | None = None, + started_at: datetime.datetime | None = None, + ended_at: datetime.datetime | None = None, + launcher_error_message: str | None = None, +) -> mock.MagicMock: + lc = mock.MagicMock(spec=launcher_interfaces.LaunchedContainer) + lc.status = status + lc.exit_code = exit_code + lc.started_at = started_at + lc.ended_at = ended_at + lc.launcher_error_message = launcher_error_message + lc.to_dict.return_value = {"kubernetes_job": {"refreshed": True}} + return lc + + +def _make_orchestrator( + *, + launcher: mock.MagicMock | None = None, + storage_provider: local_storage.LocalStorageProvider, + tmp_path: pathlib.Path, +) -> orchestrator_sql.OrchestratorService_Sql: + if launcher is None: + launcher = mock.MagicMock(spec=launcher_interfaces.ContainerTaskLauncher) + return orchestrator_sql.OrchestratorService_Sql( + session_factory=mock.MagicMock(), + launcher=launcher, + storage_provider=storage_provider, + data_root_uri=str(tmp_path / "data"), + logs_root_uri=str(tmp_path / "logs"), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestSystemErrorBranch: + """internal_process_running_executions_queue exception handler. + + When internal_process_one_running_execution raises, the outer handler + records SYSTEM_ERROR + ended_at on the ContainerExecution. + """ + + def test_records_system_error_status_and_ended_at( + self, + session_factory: orm.sessionmaker, + storage_provider: local_storage.LocalStorageProvider, + tmp_path: pathlib.Path, + ) -> None: + with session_factory() as session: + ce = _create_container_execution(session=session) + ce_id = ce.id + session.commit() + + launcher = mock.MagicMock(spec=launcher_interfaces.ContainerTaskLauncher) + orch = _make_orchestrator( + launcher=launcher, + storage_provider=storage_provider, + tmp_path=tmp_path, + ) + + frozen_time = _utc(hour=15) + + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + with ( + mock.patch.object( + orch, + "internal_process_one_running_execution", + side_effect=RuntimeError("boom"), + ), + mock.patch.object( + orchestrator_sql, + "_get_current_time", + return_value=frozen_time, + ), + mock.patch.object(orchestrator_sql, "record_system_error_exception"), + mock.patch.object( + orchestrator_sql, + "_mark_all_downstream_executions_as_skipped", + ), + ): + session.scalar = mock.MagicMock(return_value=ce) + session.scalars = mock.MagicMock() + session.scalars.return_value.all.return_value = [ + node.id for node in ce.execution_nodes + ] + orch.internal_process_running_executions_queue(session=session) + + session.expire_on_commit = False + session.commit() + + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + assert ce.status == bts.ContainerExecutionStatus.SYSTEM_ERROR + assert ce.ended_at == frozen_time + + +class TestSucceededBranch: + """internal_process_one_running_execution -> SUCCEEDED branch. + + When the refreshed container reports SUCCEEDED, the orchestrator + persists status, exit_code, started_at, ended_at. + """ + + def test_records_succeeded_fields( + self, + session_factory: orm.sessionmaker, + storage_provider: local_storage.LocalStorageProvider, + tmp_path: pathlib.Path, + ) -> None: + start = _utc(hour=10) + end = _utc(hour=14) + + output_file = tmp_path / "output" + output_file.write_text("hello") + output_uri = str(output_file) + + with session_factory() as session: + ce = _create_container_execution( + session=session, + output_artifact_data_map={ + "result": {"uri": output_uri}, + }, + ) + ce_id = ce.id + session.commit() + + previous_lc = _make_launched_container( + status=launcher_interfaces.ContainerStatus.RUNNING, + ) + refreshed_lc = _make_launched_container( + status=launcher_interfaces.ContainerStatus.SUCCEEDED, + exit_code=0, + started_at=start, + ended_at=end, + ) + + launcher = mock.MagicMock(spec=launcher_interfaces.ContainerTaskLauncher) + launcher.deserialize_launched_container_from_dict.return_value = previous_lc + launcher.get_refreshed_launched_container_from_dict.return_value = refreshed_lc + + orch = _make_orchestrator( + launcher=launcher, + storage_provider=storage_provider, + tmp_path=tmp_path, + ) + + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + with ( + mock.patch.object( + orchestrator_sql, + "_get_current_time", + return_value=_utc(hour=14), + ), + mock.patch.object( + orchestrator_sql, + "_retry", + side_effect=lambda fn, **kwargs: fn(), + ), + ): + orch.internal_process_one_running_execution( + session=session, + container_execution=ce, + ) + session.commit() + + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + assert ce.status == bts.ContainerExecutionStatus.SUCCEEDED + assert ce.exit_code == 0 + assert ce.started_at == start + assert ce.ended_at == end + + +class TestFailedBranch: + """internal_process_one_running_execution -> FAILED branch. + + When the refreshed container reports FAILED, the orchestrator + persists status, exit_code, started_at, ended_at. + """ + + def test_records_failed_fields( + self, + session_factory: orm.sessionmaker, + storage_provider: local_storage.LocalStorageProvider, + tmp_path: pathlib.Path, + ) -> None: + start = _utc(hour=10) + end = _utc(hour=13) + + with session_factory() as session: + ce = _create_container_execution(session=session) + ce_id = ce.id + session.commit() + + previous_lc = _make_launched_container( + status=launcher_interfaces.ContainerStatus.RUNNING, + ) + refreshed_lc = _make_launched_container( + status=launcher_interfaces.ContainerStatus.FAILED, + exit_code=1, + started_at=start, + ended_at=end, + launcher_error_message="OOM killed", + ) + + launcher = mock.MagicMock(spec=launcher_interfaces.ContainerTaskLauncher) + launcher.deserialize_launched_container_from_dict.return_value = previous_lc + launcher.get_refreshed_launched_container_from_dict.return_value = refreshed_lc + + orch = _make_orchestrator( + launcher=launcher, + storage_provider=storage_provider, + tmp_path=tmp_path, + ) + + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + with ( + mock.patch.object( + orchestrator_sql, + "_get_current_time", + return_value=_utc(hour=13), + ), + mock.patch.object( + orchestrator_sql, + "_retry", + side_effect=lambda fn, **kwargs: fn(), + ), + ): + orch.internal_process_one_running_execution( + session=session, + container_execution=ce, + ) + session.commit() + + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + assert ce.status == bts.ContainerExecutionStatus.FAILED + assert ce.exit_code == 1 + assert ce.started_at == start + assert ce.ended_at == end + + def test_records_launcher_error_message( + self, + session_factory: orm.sessionmaker, + storage_provider: local_storage.LocalStorageProvider, + tmp_path: pathlib.Path, + ) -> None: + with session_factory() as session: + ce = _create_container_execution(session=session) + ce_id = ce.id + session.commit() + + previous_lc = _make_launched_container( + status=launcher_interfaces.ContainerStatus.RUNNING, + ) + refreshed_lc = _make_launched_container( + status=launcher_interfaces.ContainerStatus.FAILED, + exit_code=137, + started_at=_utc(hour=10), + ended_at=_utc(hour=11), + launcher_error_message="OOM killed", + ) + + launcher = mock.MagicMock(spec=launcher_interfaces.ContainerTaskLauncher) + launcher.deserialize_launched_container_from_dict.return_value = previous_lc + launcher.get_refreshed_launched_container_from_dict.return_value = refreshed_lc + + orch = _make_orchestrator( + launcher=launcher, + storage_provider=storage_provider, + tmp_path=tmp_path, + ) + + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + with ( + mock.patch.object( + orchestrator_sql, + "_get_current_time", + return_value=_utc(hour=11), + ), + mock.patch.object( + orchestrator_sql, + "_retry", + side_effect=lambda fn, **kwargs: fn(), + ), + ): + orch.internal_process_one_running_execution( + session=session, + container_execution=ce, + ) + session.commit() + + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + assert ce.extra_data is not None + assert "OOM killed" in ce.extra_data.get( + bts.CONTAINER_EXECUTION_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY, "" + ) + + +class TestCancelledBranch: + """internal_process_one_running_execution -> cancellation branch. + + When all execution nodes have desired_state=TERMINATED, the orchestrator + terminates the container and sets CANCELLED + ended_at. + """ + + def test_records_cancelled_fields( + self, + session_factory: orm.sessionmaker, + storage_provider: local_storage.LocalStorageProvider, + tmp_path: pathlib.Path, + ) -> None: + with session_factory() as session: + ce = _create_container_execution( + session=session, + desired_state="TERMINATED", + ) + ce_id = ce.id + session.commit() + + previous_lc = _make_launched_container( + status=launcher_interfaces.ContainerStatus.RUNNING, + ) + + launcher = mock.MagicMock(spec=launcher_interfaces.ContainerTaskLauncher) + launcher.deserialize_launched_container_from_dict.return_value = previous_lc + + orch = _make_orchestrator( + launcher=launcher, + storage_provider=storage_provider, + tmp_path=tmp_path, + ) + + frozen_time = _utc(hour=16) + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + with ( + mock.patch.object( + orchestrator_sql, + "_get_current_time", + return_value=frozen_time, + ), + mock.patch.object( + orchestrator_sql, + "_retry", + side_effect=lambda fn, **kwargs: fn(), + ), + mock.patch.object( + orchestrator_sql, + "_mark_all_downstream_executions_as_skipped", + ), + ): + orch.internal_process_one_running_execution( + session=session, + container_execution=ce, + ) + session.commit() + + with session_factory() as session: + ce = session.get(bts.ContainerExecution, ce_id) + assert ce.status == bts.ContainerExecutionStatus.CANCELLED + assert ce.ended_at == frozen_time + previous_lc.terminate.assert_called_once()