diff --git a/linopy/model.py b/linopy/model.py index 54334411..06e814c6 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1401,7 +1401,9 @@ def solve( if remote is not None: if isinstance(remote, OetcHandler): - solved = remote.solve_on_oetc(self) + solved = remote.solve_on_oetc( + self, solver_name=solver_name, **solver_options + ) else: solved = remote.solve_on_remote( self, @@ -1417,7 +1419,8 @@ def solve( **solver_options, ) - self.objective.set_value(solved.objective.value) + if solved.objective.value is not None: + self.objective.set_value(float(solved.objective.value)) self.status = solved.status self.termination_condition = solved.termination_condition for k, v in self.variables.items(): diff --git a/linopy/remote/oetc.py b/linopy/remote/oetc.py index ee94fd43..f451a43d 100644 --- a/linopy/remote/oetc.py +++ b/linopy/remote/oetc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import gzip import json @@ -8,6 +10,10 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from enum import Enum +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from linopy.model import Model try: import requests @@ -42,11 +48,97 @@ class OetcSettings: orchestrator_server_url: str compute_provider: ComputeProvider = ComputeProvider.GCP solver: str = "highs" - solver_options: dict = field(default_factory=dict) + solver_options: dict[str, Any] = field(default_factory=dict) cpu_cores: int = 2 disk_space_gb: int = 10 delete_worker_on_error: bool = False + @classmethod + def from_env( + cls, + *, + email: str | None = None, + password: str | None = None, + name: str | None = None, + authentication_server_url: str | None = None, + orchestrator_server_url: str | None = None, + cpu_cores: int | None = None, + disk_space_gb: int | None = None, + delete_worker_on_error: bool | None = None, + ) -> OetcSettings: + required_fields = { + "email": ("OETC_EMAIL", email), + "password": ("OETC_PASSWORD", password), + "name": ("OETC_NAME", name), + "authentication_server_url": ("OETC_AUTH_URL", authentication_server_url), + "orchestrator_server_url": ( + "OETC_ORCHESTRATOR_URL", + orchestrator_server_url, + ), + } + + resolved: dict[str, Any] = {} + missing: list[str] = [] + + for field_name, (env_var, kwarg) in required_fields.items(): + if kwarg is not None: + resolved[field_name] = kwarg + else: + env_val = os.environ.get(env_var, "").strip() + if env_val: + resolved[field_name] = env_val + else: + missing.append(env_var) + + if missing: + raise ValueError( + f"Missing required OETC configuration: {', '.join(missing)}" + ) + + kwargs: dict[str, Any] = { + "credentials": OetcCredentials( + email=resolved["email"], password=resolved["password"] + ), + "name": resolved["name"], + "authentication_server_url": resolved["authentication_server_url"], + "orchestrator_server_url": resolved["orchestrator_server_url"], + } + + if cpu_cores is not None: + kwargs["cpu_cores"] = cpu_cores + elif (cpu_env := os.environ.get("OETC_CPU_CORES")) is not None: + try: + kwargs["cpu_cores"] = int(cpu_env) + except ValueError as e: + raise ValueError( + f"OETC_CPU_CORES is not a valid integer: {cpu_env}" + ) from e + + if disk_space_gb is not None: + kwargs["disk_space_gb"] = disk_space_gb + elif (disk_env := os.environ.get("OETC_DISK_SPACE_GB")) is not None: + try: + kwargs["disk_space_gb"] = int(disk_env) + except ValueError as e: + raise ValueError( + f"OETC_DISK_SPACE_GB is not a valid integer: {disk_env}" + ) from e + + if delete_worker_on_error is not None: + kwargs["delete_worker_on_error"] = delete_worker_on_error + elif (del_env := os.environ.get("OETC_DELETE_WORKER_ON_ERROR")) is not None: + low = del_env.lower() + if low in ("true", "1", "yes"): + kwargs["delete_worker_on_error"] = True + elif low in ("false", "0", "no"): + kwargs["delete_worker_on_error"] = False + else: + raise ValueError( + f"OETC_DELETE_WORKER_ON_ERROR has invalid value: {del_env}" + ) + + return cls(**kwargs) + @dataclass class GcpCredentials: @@ -226,12 +318,16 @@ def __get_gcp_credentials(self) -> GcpCredentials: except Exception as e: raise Exception(f"Error fetching GCP credentials: {e}") - def _submit_job_to_compute_service(self, input_file_name: str) -> str: + def _submit_job_to_compute_service( + self, input_file_name: str, solver: str, solver_options: dict[str, Any] + ) -> str: """ Submit a job to the compute service. Args: input_file_name: Name of the input file uploaded to GCP + solver: Solver name to use + solver_options: Solver options dict Returns: CreateComputeJobResult: The job creation result with UUID @@ -243,8 +339,8 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> str: logger.info("OETC - Submitting compute job...") payload = { "name": self.settings.name, - "solver": self.settings.solver, - "solver_options": self.settings.solver_options, + "solver": solver, + "solver_options": solver_options, "provider": self.settings.compute_provider.value, "cpu_cores": self.settings.cpu_cores, "disk_space_gb": self.settings.disk_space_gb, @@ -534,13 +630,19 @@ def _download_file_from_gcp(self, file_name: str) -> str: except Exception as e: raise Exception(f"Failed to download file from GCP: {e}") - def solve_on_oetc(self, model): # type: ignore + def solve_on_oetc( + self, model: Model, solver_name: str | None = None, **solver_options: Any + ) -> Model: """ Solve a linopy model on the OET Cloud compute app. Parameters ---------- model : linopy.model.Model + solver_name : str, optional + Override the solver from settings. + **solver_options + Override/extend solver_options from settings. Returns ------- @@ -552,17 +654,19 @@ def solve_on_oetc(self, model): # type: ignore Exception: If solving fails at any stage """ try: - # Save model to temporary file and upload + effective_solver = solver_name or self.settings.solver + merged_solver_options = {**self.settings.solver_options, **solver_options} + with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: fn.file.close() model.to_netcdf(fn.name) input_file_name = self._upload_file_to_gcp(fn.name) - # Submit job and wait for completion - job_uuid = self._submit_job_to_compute_service(input_file_name) + job_uuid = self._submit_job_to_compute_service( + input_file_name, effective_solver, merged_solver_options + ) job_result = self.wait_and_get_job_data(job_uuid) - # Download and load the solution if not job_result.output_files: raise Exception("No output files found in completed job") @@ -572,18 +676,14 @@ def solve_on_oetc(self, model): # type: ignore solution_file_path = self._download_file_from_gcp(output_file_name) - # Load the solved model solved_model = linopy.read_netcdf(solution_file_path) - # Clean up downloaded file os.remove(solution_file_path) logger.info( f"OETC - Model solved successfully. Status: {solved_model.status}" ) - if hasattr(solved_model, "objective") and hasattr( - solved_model.objective, "value" - ): + if solved_model.objective.value is not None: logger.info( f"OETC - Objective value: {solved_model.objective.value:.2e}" ) @@ -591,7 +691,7 @@ def solve_on_oetc(self, model): # type: ignore return solved_model except Exception as e: - raise Exception(f"Error solving model on OETC: {e}") + raise Exception(f"Error solving model on OETC: {e}") from e def _gzip_compress(self, source_path: str) -> str: """ diff --git a/test/remote/test_oetc.py b/test/remote/test_oetc.py index 0704d24d..7b2d75f2 100644 --- a/test/remote/test_oetc.py +++ b/test/remote/test_oetc.py @@ -1392,7 +1392,9 @@ def test_submit_job_success( mock_post.return_value = mock_response # Execute - result = handler_with_auth_setup._submit_job_to_compute_service(input_file_name) + result = handler_with_auth_setup._submit_job_to_compute_service( + input_file_name, "gurobi", {} + ) # Verify request expected_payload = { @@ -1434,7 +1436,9 @@ def test_submit_job_http_error( # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_auth_setup._submit_job_to_compute_service(input_file_name) + handler_with_auth_setup._submit_job_to_compute_service( + input_file_name, "highs", {} + ) assert "Failed to submit job to compute service" in str(exc_info.value) @@ -1452,7 +1456,9 @@ def test_submit_job_missing_uuid_in_response( # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_auth_setup._submit_job_to_compute_service(input_file_name) + handler_with_auth_setup._submit_job_to_compute_service( + input_file_name, "highs", {} + ) assert "Invalid job submission response format: missing field 'uuid'" in str( exc_info.value @@ -1469,7 +1475,9 @@ def test_submit_job_network_error( # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_auth_setup._submit_job_to_compute_service(input_file_name) + handler_with_auth_setup._submit_job_to_compute_service( + input_file_name, "highs", {} + ) assert "Failed to submit job to compute service" in str(exc_info.value) @@ -1568,7 +1576,9 @@ def test_solve_on_oetc_file_upload( "/tmp/linopy-abc123.nc" ) mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_submit.assert_called_once_with("uploaded_file.nc.gz") + mock_submit.assert_called_once_with( + "uploaded_file.nc.gz", "highs", {} + ) mock_wait.assert_called_once_with("test-job-uuid") mock_download.assert_called_once_with("result.nc.gz") mock_read_netcdf.assert_called_once_with( @@ -1694,7 +1704,9 @@ def test_solve_on_oetc_with_job_submission( "/tmp/linopy-abc123.nc" ) mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_submit.assert_called_once_with(uploaded_file_name) + mock_submit.assert_called_once_with( + uploaded_file_name, "highs", {} + ) mock_wait.assert_called_once_with(job_uuid) mock_download.assert_called_once_with("result.nc.gz") mock_read_netcdf.assert_called_once_with( diff --git a/test/test_oetc_settings.py b/test/test_oetc_settings.py new file mode 100644 index 00000000..a113176c --- /dev/null +++ b/test/test_oetc_settings.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from linopy.remote.oetc import ( + ComputeProvider, + OetcCredentials, + OetcHandler, + OetcSettings, +) + +REQUIRED_ENV = { + "OETC_EMAIL": "test@example.com", + "OETC_PASSWORD": "secret", + "OETC_NAME": "test-job", + "OETC_AUTH_URL": "https://auth.example.com", + "OETC_ORCHESTRATOR_URL": "https://orch.example.com", +} + + +def _set_required_env(monkeypatch: pytest.MonkeyPatch) -> None: + for k, v in REQUIRED_ENV.items(): + monkeypatch.setenv(k, v) + + +def _clear_oetc_env(monkeypatch: pytest.MonkeyPatch) -> None: + for key in [ + "OETC_EMAIL", + "OETC_PASSWORD", + "OETC_NAME", + "OETC_AUTH_URL", + "OETC_ORCHESTRATOR_URL", + "OETC_CPU_CORES", + "OETC_DISK_SPACE_GB", + "OETC_DELETE_WORKER_ON_ERROR", + ]: + monkeypatch.delenv(key, raising=False) + + +def test_from_env_all_set(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + _set_required_env(monkeypatch) + monkeypatch.setenv("OETC_CPU_CORES", "8") + monkeypatch.setenv("OETC_DISK_SPACE_GB", "20") + monkeypatch.setenv("OETC_DELETE_WORKER_ON_ERROR", "true") + + s = OetcSettings.from_env() + assert s.credentials.email == "test@example.com" + assert s.credentials.password == "secret" + assert s.name == "test-job" + assert s.cpu_cores == 8 + assert s.disk_space_gb == 20 + assert s.compute_provider == ComputeProvider.GCP + assert s.delete_worker_on_error is True + + +def test_from_env_kwargs_override(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + _set_required_env(monkeypatch) + + s = OetcSettings.from_env(email="override@example.com") + assert s.credentials.email == "override@example.com" + + +def test_from_env_missing_required(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + with pytest.raises( + ValueError, + match="OETC_EMAIL.*OETC_PASSWORD.*OETC_NAME.*OETC_AUTH_URL.*OETC_ORCHESTRATOR_URL", + ): + OetcSettings.from_env() + + +def test_from_env_empty_string_required(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + monkeypatch.setenv("OETC_EMAIL", "") + monkeypatch.setenv("OETC_PASSWORD", " ") + monkeypatch.setenv("OETC_NAME", "valid") + monkeypatch.setenv("OETC_AUTH_URL", "https://auth.example.com") + monkeypatch.setenv("OETC_ORCHESTRATOR_URL", "https://orch.example.com") + + with pytest.raises(ValueError, match="OETC_EMAIL.*OETC_PASSWORD"): + OetcSettings.from_env() + + +def test_from_env_partial_kwargs(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + monkeypatch.setenv("OETC_NAME", "env-name") + monkeypatch.setenv("OETC_AUTH_URL", "https://auth.example.com") + monkeypatch.setenv("OETC_ORCHESTRATOR_URL", "https://orch.example.com") + + s = OetcSettings.from_env(email="a@b.com", password="pw") + assert s.credentials.email == "a@b.com" + assert s.name == "env-name" + + +def test_from_env_defaults_applied(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + _set_required_env(monkeypatch) + + s = OetcSettings.from_env() + assert s.solver == "highs" + assert s.solver_options == {} + assert s.cpu_cores == 2 + assert s.disk_space_gb == 10 + assert s.compute_provider == ComputeProvider.GCP + assert s.delete_worker_on_error is False + + +def test_from_env_cpu_cores_valid(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + _set_required_env(monkeypatch) + monkeypatch.setenv("OETC_CPU_CORES", "4") + + assert OetcSettings.from_env().cpu_cores == 4 + + +def test_from_env_cpu_cores_invalid(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + _set_required_env(monkeypatch) + monkeypatch.setenv("OETC_CPU_CORES", "abc") + + with pytest.raises(ValueError, match="OETC_CPU_CORES"): + OetcSettings.from_env() + + +@pytest.mark.parametrize("val", ["true", "1", "yes"]) +def test_from_env_bool_true_values(monkeypatch: pytest.MonkeyPatch, val: str) -> None: + _clear_oetc_env(monkeypatch) + _set_required_env(monkeypatch) + monkeypatch.setenv("OETC_DELETE_WORKER_ON_ERROR", val) + + assert OetcSettings.from_env().delete_worker_on_error is True + + +@pytest.mark.parametrize("val", ["false", "0", "no"]) +def test_from_env_bool_false_values(monkeypatch: pytest.MonkeyPatch, val: str) -> None: + _clear_oetc_env(monkeypatch) + _set_required_env(monkeypatch) + monkeypatch.setenv("OETC_DELETE_WORKER_ON_ERROR", val) + + assert OetcSettings.from_env().delete_worker_on_error is False + + +def test_from_env_bool_invalid(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + _set_required_env(monkeypatch) + monkeypatch.setenv("OETC_DELETE_WORKER_ON_ERROR", "maybe") + + with pytest.raises(ValueError, match="OETC_DELETE_WORKER_ON_ERROR"): + OetcSettings.from_env() + + +def _make_handler(settings: OetcSettings) -> OetcHandler: + with ( + patch("linopy.remote.oetc._oetc_deps_available", True), + patch.object(OetcHandler, "_OetcHandler__sign_in", return_value=MagicMock()), + patch.object( + OetcHandler, + "_OetcHandler__get_cloud_provider_credentials", + return_value=MagicMock(), + ), + ): + return OetcHandler(settings) + + +def _default_settings(**overrides: Any) -> OetcSettings: + defaults: dict[str, Any] = dict( + credentials=OetcCredentials(email="a@b.com", password="pw"), + name="test", + authentication_server_url="https://auth", + orchestrator_server_url="https://orch", + solver="highs", + solver_options={"TimeLimit": 100}, + ) + defaults.update(overrides) + return OetcSettings(**defaults) + + +def test_solve_on_oetc_mutation_safety() -> None: + settings = _default_settings() + handler = _make_handler(settings) + original_opts = dict(settings.solver_options) + + mock_model = MagicMock() + mock_solved = MagicMock() + mock_solved.objective.value = 42.0 + mock_solved.status = "ok" + + with ( + patch.object(handler, "_upload_file_to_gcp", return_value="file.nc.gz"), + patch.object(handler, "_submit_job_to_compute_service", return_value="uuid"), + patch.object(handler, "wait_and_get_job_data") as mock_wait, + patch.object(handler, "_download_file_from_gcp", return_value="/tmp/sol.nc"), + patch("linopy.read_netcdf", return_value=mock_solved), + patch("os.remove"), + ): + mock_wait.return_value = MagicMock(output_files=["out.nc.gz"]) + + handler.solve_on_oetc(mock_model, Extra=999) + handler.solve_on_oetc(mock_model, Other=1) + + assert settings.solver_options == original_opts + + +def test_solve_on_oetc_solver_name_override() -> None: + settings = _default_settings() + handler = _make_handler(settings) + + mock_model = MagicMock() + mock_solved = MagicMock() + mock_solved.objective.value = 1.0 + mock_solved.status = "ok" + + with ( + patch.object(handler, "_upload_file_to_gcp", return_value="file.nc.gz"), + patch.object( + handler, "_submit_job_to_compute_service", return_value="uuid" + ) as mock_submit, + patch.object(handler, "wait_and_get_job_data") as mock_wait, + patch.object(handler, "_download_file_from_gcp", return_value="/tmp/sol.nc"), + patch("linopy.read_netcdf", return_value=mock_solved), + patch("os.remove"), + ): + mock_wait.return_value = MagicMock(output_files=["out.nc.gz"]) + + handler.solve_on_oetc(mock_model, solver_name="gurobi") + + mock_submit.assert_called_once() + assert mock_submit.call_args[0][1] == "gurobi" + + +def test_solve_on_oetc_solver_options_merge_precedence() -> None: + settings = _default_settings(solver_options={"TimeLimit": 100}) + handler = _make_handler(settings) + + mock_model = MagicMock() + mock_solved = MagicMock() + mock_solved.objective.value = 1.0 + mock_solved.status = "ok" + + with ( + patch.object(handler, "_upload_file_to_gcp", return_value="file.nc.gz"), + patch.object( + handler, "_submit_job_to_compute_service", return_value="uuid" + ) as mock_submit, + patch.object(handler, "wait_and_get_job_data") as mock_wait, + patch.object(handler, "_download_file_from_gcp", return_value="/tmp/sol.nc"), + patch("linopy.read_netcdf", return_value=mock_solved), + patch("os.remove"), + ): + mock_wait.return_value = MagicMock(output_files=["out.nc.gz"]) + + handler.solve_on_oetc(mock_model, TimeLimit=200) + + mock_submit.assert_called_once() + assert mock_submit.call_args[0][2] == {"TimeLimit": 200} + + +def test_solve_on_oetc_solver_name_default_fallback() -> None: + settings = _default_settings(solver="cplex") + handler = _make_handler(settings) + + mock_model = MagicMock() + mock_solved = MagicMock() + mock_solved.objective.value = 1.0 + mock_solved.status = "ok" + + with ( + patch.object(handler, "_upload_file_to_gcp", return_value="file.nc.gz"), + patch.object( + handler, "_submit_job_to_compute_service", return_value="uuid" + ) as mock_submit, + patch.object(handler, "wait_and_get_job_data") as mock_wait, + patch.object(handler, "_download_file_from_gcp", return_value="/tmp/sol.nc"), + patch("linopy.read_netcdf", return_value=mock_solved), + patch("os.remove"), + ): + mock_wait.return_value = MagicMock(output_files=["out.nc.gz"]) + + handler.solve_on_oetc(mock_model) + + mock_submit.assert_called_once() + assert mock_submit.call_args[0][1] == "cplex" + + +def test_from_env_disk_space_gb_invalid(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_oetc_env(monkeypatch) + _set_required_env(monkeypatch) + monkeypatch.setenv("OETC_DISK_SPACE_GB", "abc") + + with pytest.raises(ValueError, match="OETC_DISK_SPACE_GB"): + OetcSettings.from_env() + + +def test_model_solve_forwards_to_oetc() -> None: + from linopy import Model + + m = Model() + m.add_variables(lower=0, name="x") + + handler = MagicMock(spec=OetcHandler) + mock_solved = MagicMock() + mock_solved.status = "ok" + mock_solved.termination_condition = "optimal" + mock_solved.objective.value = 10.0 + mock_solved.variables.items.return_value = [(k, v) for k, v in m.variables.items()] + mock_solved.constraints.items.return_value = [] + for k in m.variables: + mock_solved.variables[k].solution = 0.0 + handler.solve_on_oetc.return_value = mock_solved + + m.solve(solver_name="gurobi", remote=handler, TimeLimit=100) + + handler.solve_on_oetc.assert_called_once_with( + m, solver_name="gurobi", TimeLimit=100 + )