Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,9 @@ def solve(

if remote is not None:
if isinstance(remote, OetcHandler):
solved = remote.solve_on_oetc(self)
solved = remote.solve_on_oetc(
self, solver_name=solver_name, **solver_options
)
else:
solved = remote.solve_on_remote(
self,
Expand All @@ -1417,7 +1419,8 @@ def solve(
**solver_options,
)

self.objective.set_value(solved.objective.value)
if solved.objective.value is not None:
self.objective.set_value(float(solved.objective.value))
self.status = solved.status
self.termination_condition = solved.termination_condition
for k, v in self.variables.items():
Expand Down
130 changes: 115 additions & 15 deletions linopy/remote/oetc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import base64
import gzip
import json
Expand All @@ -8,6 +10,10 @@
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from linopy.model import Model

try:
import requests
Expand Down Expand Up @@ -42,11 +48,97 @@ class OetcSettings:
orchestrator_server_url: str
compute_provider: ComputeProvider = ComputeProvider.GCP
solver: str = "highs"
solver_options: dict = field(default_factory=dict)
solver_options: dict[str, Any] = field(default_factory=dict)
cpu_cores: int = 2
disk_space_gb: int = 10
delete_worker_on_error: bool = False

@classmethod
def from_env(
cls,
*,
email: str | None = None,
password: str | None = None,
name: str | None = None,
authentication_server_url: str | None = None,
orchestrator_server_url: str | None = None,
cpu_cores: int | None = None,
disk_space_gb: int | None = None,
delete_worker_on_error: bool | None = None,
) -> OetcSettings:
required_fields = {
"email": ("OETC_EMAIL", email),
"password": ("OETC_PASSWORD", password),
"name": ("OETC_NAME", name),
"authentication_server_url": ("OETC_AUTH_URL", authentication_server_url),
"orchestrator_server_url": (
"OETC_ORCHESTRATOR_URL",
orchestrator_server_url,
),
}

resolved: dict[str, Any] = {}
missing: list[str] = []

for field_name, (env_var, kwarg) in required_fields.items():
if kwarg is not None:
resolved[field_name] = kwarg
else:
env_val = os.environ.get(env_var, "").strip()
if env_val:
resolved[field_name] = env_val
else:
missing.append(env_var)

if missing:
raise ValueError(
f"Missing required OETC configuration: {', '.join(missing)}"
)

kwargs: dict[str, Any] = {
"credentials": OetcCredentials(
email=resolved["email"], password=resolved["password"]
),
"name": resolved["name"],
"authentication_server_url": resolved["authentication_server_url"],
"orchestrator_server_url": resolved["orchestrator_server_url"],
}

if cpu_cores is not None:
kwargs["cpu_cores"] = cpu_cores
elif (cpu_env := os.environ.get("OETC_CPU_CORES")) is not None:
try:
kwargs["cpu_cores"] = int(cpu_env)
except ValueError as e:
raise ValueError(
f"OETC_CPU_CORES is not a valid integer: {cpu_env}"
) from e

if disk_space_gb is not None:
kwargs["disk_space_gb"] = disk_space_gb
elif (disk_env := os.environ.get("OETC_DISK_SPACE_GB")) is not None:
try:
kwargs["disk_space_gb"] = int(disk_env)
except ValueError as e:
raise ValueError(
f"OETC_DISK_SPACE_GB is not a valid integer: {disk_env}"
) from e

if delete_worker_on_error is not None:
kwargs["delete_worker_on_error"] = delete_worker_on_error
elif (del_env := os.environ.get("OETC_DELETE_WORKER_ON_ERROR")) is not None:
low = del_env.lower()
if low in ("true", "1", "yes"):
kwargs["delete_worker_on_error"] = True
elif low in ("false", "0", "no"):
kwargs["delete_worker_on_error"] = False
else:
raise ValueError(
f"OETC_DELETE_WORKER_ON_ERROR has invalid value: {del_env}"
)

return cls(**kwargs)


@dataclass
class GcpCredentials:
Expand Down Expand Up @@ -226,12 +318,16 @@ def __get_gcp_credentials(self) -> GcpCredentials:
except Exception as e:
raise Exception(f"Error fetching GCP credentials: {e}")

def _submit_job_to_compute_service(self, input_file_name: str) -> str:
def _submit_job_to_compute_service(
self, input_file_name: str, solver: str, solver_options: dict[str, Any]
) -> str:
"""
Submit a job to the compute service.

Args:
input_file_name: Name of the input file uploaded to GCP
solver: Solver name to use
solver_options: Solver options dict

Returns:
CreateComputeJobResult: The job creation result with UUID
Expand All @@ -243,8 +339,8 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> str:
logger.info("OETC - Submitting compute job...")
payload = {
"name": self.settings.name,
"solver": self.settings.solver,
"solver_options": self.settings.solver_options,
"solver": solver,
"solver_options": solver_options,
"provider": self.settings.compute_provider.value,
"cpu_cores": self.settings.cpu_cores,
"disk_space_gb": self.settings.disk_space_gb,
Expand Down Expand Up @@ -534,13 +630,19 @@ def _download_file_from_gcp(self, file_name: str) -> str:
except Exception as e:
raise Exception(f"Failed to download file from GCP: {e}")

def solve_on_oetc(self, model): # type: ignore
def solve_on_oetc(
self, model: Model, solver_name: str | None = None, **solver_options: Any
) -> Model:
"""
Solve a linopy model on the OET Cloud compute app.

Parameters
----------
model : linopy.model.Model
solver_name : str, optional
Override the solver from settings.
**solver_options
Override/extend solver_options from settings.

Returns
-------
Expand All @@ -552,17 +654,19 @@ def solve_on_oetc(self, model): # type: ignore
Exception: If solving fails at any stage
"""
try:
# Save model to temporary file and upload
effective_solver = solver_name or self.settings.solver
merged_solver_options = {**self.settings.solver_options, **solver_options}

with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn:
fn.file.close()
model.to_netcdf(fn.name)
input_file_name = self._upload_file_to_gcp(fn.name)

# Submit job and wait for completion
job_uuid = self._submit_job_to_compute_service(input_file_name)
job_uuid = self._submit_job_to_compute_service(
input_file_name, effective_solver, merged_solver_options
)
job_result = self.wait_and_get_job_data(job_uuid)

# Download and load the solution
if not job_result.output_files:
raise Exception("No output files found in completed job")

Expand All @@ -572,26 +676,22 @@ def solve_on_oetc(self, model): # type: ignore

solution_file_path = self._download_file_from_gcp(output_file_name)

# Load the solved model
solved_model = linopy.read_netcdf(solution_file_path)

# Clean up downloaded file
os.remove(solution_file_path)

logger.info(
f"OETC - Model solved successfully. Status: {solved_model.status}"
)
if hasattr(solved_model, "objective") and hasattr(
solved_model.objective, "value"
):
if solved_model.objective.value is not None:
logger.info(
f"OETC - Objective value: {solved_model.objective.value:.2e}"
)

return solved_model

except Exception as e:
raise Exception(f"Error solving model on OETC: {e}")
raise Exception(f"Error solving model on OETC: {e}") from e

def _gzip_compress(self, source_path: str) -> str:
"""
Expand Down
24 changes: 18 additions & 6 deletions test/remote/test_oetc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,9 @@ def test_submit_job_success(
mock_post.return_value = mock_response

# Execute
result = handler_with_auth_setup._submit_job_to_compute_service(input_file_name)
result = handler_with_auth_setup._submit_job_to_compute_service(
input_file_name, "gurobi", {}
)

# Verify request
expected_payload = {
Expand Down Expand Up @@ -1434,7 +1436,9 @@ def test_submit_job_http_error(

# Execute and verify exception
with pytest.raises(Exception) as exc_info:
handler_with_auth_setup._submit_job_to_compute_service(input_file_name)
handler_with_auth_setup._submit_job_to_compute_service(
input_file_name, "highs", {}
)

assert "Failed to submit job to compute service" in str(exc_info.value)

Expand All @@ -1452,7 +1456,9 @@ def test_submit_job_missing_uuid_in_response(

# Execute and verify exception
with pytest.raises(Exception) as exc_info:
handler_with_auth_setup._submit_job_to_compute_service(input_file_name)
handler_with_auth_setup._submit_job_to_compute_service(
input_file_name, "highs", {}
)

assert "Invalid job submission response format: missing field 'uuid'" in str(
exc_info.value
Expand All @@ -1469,7 +1475,9 @@ def test_submit_job_network_error(

# Execute and verify exception
with pytest.raises(Exception) as exc_info:
handler_with_auth_setup._submit_job_to_compute_service(input_file_name)
handler_with_auth_setup._submit_job_to_compute_service(
input_file_name, "highs", {}
)

assert "Failed to submit job to compute service" in str(exc_info.value)

Expand Down Expand Up @@ -1568,7 +1576,9 @@ def test_solve_on_oetc_file_upload(
"/tmp/linopy-abc123.nc"
)
mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc")
mock_submit.assert_called_once_with("uploaded_file.nc.gz")
mock_submit.assert_called_once_with(
"uploaded_file.nc.gz", "highs", {}
)
mock_wait.assert_called_once_with("test-job-uuid")
mock_download.assert_called_once_with("result.nc.gz")
mock_read_netcdf.assert_called_once_with(
Expand Down Expand Up @@ -1694,7 +1704,9 @@ def test_solve_on_oetc_with_job_submission(
"/tmp/linopy-abc123.nc"
)
mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc")
mock_submit.assert_called_once_with(uploaded_file_name)
mock_submit.assert_called_once_with(
uploaded_file_name, "highs", {}
)
mock_wait.assert_called_once_with(job_uuid)
mock_download.assert_called_once_with("result.nc.gz")
mock_read_netcdf.assert_called_once_with(
Expand Down
Loading
Loading