Skip to content

Commit ca4340f

Browse files
committed
fix: avoid blocking and use threads for now
1 parent 44c5972 commit ca4340f

7 files changed

Lines changed: 176 additions & 15 deletions

File tree

src/pypsa_app/backend/api/routes/runs.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
import re
5+
import threading
56
import urllib.parse
67
import uuid
78
from pathlib import PurePosixPath
@@ -34,7 +35,11 @@
3435
)
3536
from pypsa_app.backend.services.backend_registry import backend_registry
3637
from pypsa_app.backend.services.run import SnakedispatchClient, SnakedispatchError
37-
from pypsa_app.backend.services.sync import SYNCED_STATUSES, sync_run_from_job
38+
from pypsa_app.backend.services.callback import fire_callback_sync
39+
from pypsa_app.backend.services.sync import (
40+
SYNCED_STATUSES,
41+
sync_run_from_job,
42+
)
3843
from pypsa_app.backend.settings import settings
3944

4045
router = APIRouter()
@@ -125,7 +130,7 @@ def create_run(
125130

126131
payload = body.model_dump(
127132
exclude_none=True,
128-
exclude={"backend_id", "import_networks", "cache"},
133+
exclude={"backend_id", "import_networks", "cache", "callback_url"},
129134
)
130135
if body.cache:
131136
payload["cache_key"] = body.cache.key
@@ -142,6 +147,7 @@ def create_run(
142147
extra_files=body.extra_files,
143148
cache=body.cache.model_dump() if body.cache else None,
144149
import_networks=body.import_networks,
150+
callback_url=str(body.callback_url) if body.callback_url else None,
145151
status=RunStatus(result.get("status", "PENDING")),
146152
)
147153
db.add(run)
@@ -281,8 +287,16 @@ def get_run(
281287
if client:
282288
try:
283289
job = client.get_job(str(run_id))
284-
sync_run_from_job(run, job, db)
290+
needs_callback = sync_run_from_job(run, job, db)
285291
db.commit()
292+
if needs_callback:
293+
# TODO: replace with proper async callback or
294+
# FastAPI BackgroundTasks.
295+
threading.Thread(
296+
target=fire_callback_sync,
297+
args=(run,),
298+
daemon=True,
299+
).start()
286300
except SnakedispatchError:
287301
pass
288302

@@ -378,13 +392,24 @@ def cancel_run(
378392
sd_client = _get_client_for_run(run)
379393
try:
380394
result = sd_client.cancel_job(str(run_id))
381-
sync_run_from_job(run, result, db)
395+
needs_callback = sync_run_from_job(run, result, db)
382396
db.commit()
397+
if needs_callback:
398+
# TODO: replace with proper async callback or
399+
# FastAPI BackgroundTasks.
400+
threading.Thread(
401+
target=fire_callback_sync, args=(run,), daemon=True
402+
).start()
383403
except SnakedispatchError as e:
384404
if e.status_code in (404, 409):
385405
if run.status not in SYNCED_STATUSES:
386406
run.status = RunStatus.CANCELLED
387407
db.commit()
408+
# TODO: replace with proper async callback or
409+
# FastAPI BackgroundTasks.
410+
threading.Thread(
411+
target=fire_callback_sync, args=(run,), daemon=True
412+
).start()
388413
else:
389414
raise
390415

src/pypsa_app/backend/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ class Run(Base):
294294
extra_files: Mapped[Any | None] = mapped_column(JSON)
295295
cache: Mapped[Any | None] = mapped_column(JSON)
296296

297+
callback_url: Mapped[str | None] = mapped_column(String(512))
298+
297299
# Job metadata (synced from Snakedispatch)
298300
git_ref: Mapped[str | None] = mapped_column(String(255))
299301
git_sha: Mapped[str | None] = mapped_column(String(40))

src/pypsa_app/backend/schemas/run.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
import uuid
44
from datetime import datetime
55

6-
from pydantic import BaseModel, ConfigDict, Field
6+
from urllib.parse import urlparse
7+
8+
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, field_validator
79

810
from pypsa_app.backend.models import RunStatus
911
from pypsa_app.backend.schemas.auth import UserPublicResponse
1012
from pypsa_app.backend.schemas.backend import BackendPublicResponse
1113
from pypsa_app.backend.schemas.common import PaginationMeta
14+
from pypsa_app.backend.settings import settings
1215

1316

1417
class RunCache(BaseModel):
@@ -46,6 +49,22 @@ class RunCreate(BaseModel):
4649
cache: RunCache | None = None
4750
import_networks: list[str] | None = None
4851
backend_id: uuid.UUID | None = None
52+
callback_url: HttpUrl | None = None
53+
54+
@field_validator("callback_url")
55+
@classmethod
56+
def _validate_callback_domain(cls, v: HttpUrl | None) -> HttpUrl | None:
57+
if v is None:
58+
return v
59+
allowed = settings.resolved_callback_domains
60+
if not allowed:
61+
msg = "Callbacks are not enabled on this server"
62+
raise ValueError(msg)
63+
host = v.host or ""
64+
if not any(host == d or host.endswith(f".{d}") for d in allowed):
65+
msg = f"callback_url host '{host}' is not in the allowed domains"
66+
raise ValueError(msg)
67+
return v
4968

5069

5170
class RunSummary(BaseModel):
@@ -75,9 +94,18 @@ class RunResponse(RunSummary):
7594
extra_files: dict[str, str] | None = None
7695
cache: RunCache | None = None
7796
import_networks: list[str] | None = None
97+
callback_url: str | None = Field(None, validation_alias="callback_url")
7898
exit_code: int | None = None
7999
networks: list[RunNetworkSummary] = []
80100

101+
@field_validator("callback_url", mode="before")
102+
@classmethod
103+
def _redact_callback_url(cls, v: str | None) -> str | None:
104+
if not v:
105+
return None
106+
parsed = urlparse(v)
107+
return f"{parsed.scheme}://{parsed.hostname}/***"
108+
81109

82110
class RunListMeta(PaginationMeta):
83111
"""Extended pagination meta with run-specific filter options."""
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Callback helpers for notifying external systems of run status changes."""
2+
3+
import logging
4+
5+
import httpx
6+
7+
from pypsa_app.backend.models import Run
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def _build_payload(run: Run) -> dict:
13+
return {"run_id": str(run.job_id), "status": run.status.value}
14+
15+
16+
def fire_callback_sync(run: Run) -> None:
17+
"""POST to the run's callback URL (blocking)."""
18+
if not run.callback_url:
19+
return
20+
url = str(run.callback_url)
21+
payload = _build_payload(run)
22+
try:
23+
httpx.post(url, json=payload, timeout=5.0, follow_redirects=False)
24+
except Exception:
25+
logger.warning(
26+
"Callback failed for run %s to %s",
27+
payload["run_id"],
28+
url,
29+
exc_info=True,
30+
)
31+
32+
33+
async def fire_callback_async(url: str, payload: dict) -> None:
34+
"""POST to a callback URL (async). Used by the background sync loop."""
35+
try:
36+
async with httpx.AsyncClient() as client:
37+
await client.post(url, json=payload, timeout=5.0, follow_redirects=False)
38+
except Exception:
39+
logger.warning(
40+
"Callback failed for run %s to %s",
41+
payload.get("run_id"),
42+
url,
43+
exc_info=True,
44+
)

src/pypsa_app/backend/services/sync.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111

1212
from pypsa_app.backend.database import SessionLocal
1313
from pypsa_app.backend.models import Run, RunStatus
14+
from pypsa_app.backend.services.callback import fire_callback_async
1415
from pypsa_app.backend.services.backend_registry import backend_registry
1516
from pypsa_app.backend.tasks import import_run_outputs_task
1617

1718
logger = logging.getLogger(__name__)
1819

20+
# Hold references to fire-and-forget callback tasks to prevent garbage collection.
21+
_background_tasks: set[asyncio.Task] = set()
22+
1923
# Statuses where the remote executor is done, no need to sync from Snakedispatch
2024
SYNCED_STATUSES = {
2125
RunStatus.UPLOADING,
@@ -41,8 +45,16 @@
4145
]
4246

4347

44-
def sync_run_from_job(run: Run, job: dict, db: Session) -> None:
45-
"""Update a Run record from a Snakedispatch response dict."""
48+
_CALLBACK_STATUSES = SYNCED_STATUSES - {RunStatus.UPLOADING}
49+
50+
51+
def sync_run_from_job(run: Run, job: dict, db: Session) -> bool:
52+
"""Update a Run record from a Snakedispatch response dict.
53+
54+
Returns:
55+
True if a callback should be fired after the transaction commits.
56+
"""
57+
old_status = run.status
4658
changed = False
4759
for field in _SYNC_FIELDS:
4860
new_val = job.get(field)
@@ -65,7 +77,7 @@ def sync_run_from_job(run: Run, job: dict, db: Session) -> None:
6577
run.status = RunStatus.UPLOADING
6678
db.flush()
6779
import_run_outputs_task.apply_async(args=(str(run.job_id),))
68-
return
80+
return False
6981
if completed_with_import_pending:
7082
run.status = RunStatus.COMPLETED
7183
changed = True
@@ -76,38 +88,63 @@ def sync_run_from_job(run: Run, job: dict, db: Session) -> None:
7688
if changed:
7789
db.flush()
7890

91+
return run.status in _CALLBACK_STATUSES and old_status not in _CALLBACK_STATUSES
92+
93+
94+
def sync_non_terminal_runs() -> list[dict]:
95+
"""Poll all backends and update runs that haven't reached a terminal state.
7996
80-
def sync_non_terminal_runs() -> None:
81-
"""Poll all backends and update runs that haven't reached a terminal state."""
97+
Returns:
98+
List of callback dicts ``{"url": ..., "payload": ...}`` to be fired
99+
by the async caller after the DB session is closed.
100+
"""
101+
callbacks: list[dict] = []
82102
db = SessionLocal()
83103
try:
84104
non_terminal = db.query(Run).filter(Run.status.notin_(SYNCED_STATUSES)).all()
85105
if not non_terminal:
86-
return
106+
return callbacks
87107

88108
for backend_id, client in backend_registry.all_clients().items():
89109
backend_runs = [r for r in non_terminal if r.backend_id == backend_id]
90110
if not backend_runs:
91111
continue
92112
try:
93113
jobs_by_id = {j["job_id"]: j for j in client.list_jobs()}
114+
callback_runs: list[Run] = []
94115
for run in backend_runs:
95116
job = jobs_by_id.get(str(run.job_id))
96-
if job:
97-
sync_run_from_job(run, job, db)
117+
if job and sync_run_from_job(run, job, db):
118+
callback_runs.append(run)
98119
db.commit()
120+
callbacks.extend(
121+
{
122+
"url": str(run.callback_url),
123+
"payload": {
124+
"run_id": str(run.job_id),
125+
"status": run.status.value,
126+
},
127+
}
128+
for run in callback_runs
129+
if run.callback_url
130+
)
99131
except Exception:
100132
db.rollback()
101133
logger.warning("Sync failed for backend %s", backend_id, exc_info=True)
102134
finally:
103135
db.close()
136+
return callbacks
104137

105138

106139
async def run_sync_loop(interval: float = 10.0) -> None:
107140
"""Periodically sync non-terminal runs in a background thread."""
108141
while True:
109142
await asyncio.sleep(interval)
110143
try:
111-
await asyncio.to_thread(sync_non_terminal_runs)
144+
callbacks = await asyncio.to_thread(sync_non_terminal_runs)
145+
for cb in callbacks:
146+
task = asyncio.create_task(fire_callback_async(cb["url"], cb["payload"]))
147+
_background_tasks.add(task)
148+
task.add_done_callback(_background_tasks.discard)
112149
except Exception:
113150
logger.warning("Background run sync failed", exc_info=True)

src/pypsa_app/backend/settings.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,26 @@ def networks_path(self) -> Path:
9797
description="Interval in seconds between background Snakedispatch sync cycles",
9898
json_schema_extra={"category": "Runs"},
9999
)
100+
callback_url_allowed_domains: str = Field(
101+
default="",
102+
description=(
103+
"Comma-separated list of allowed domains for run callback URLs "
104+
"(e.g. hooks.myorg.dev,example.com). "
105+
"Callbacks are rejected unless the host matches. "
106+
"Empty disables callbacks entirely."
107+
),
108+
json_schema_extra={"category": "Runs"},
109+
)
110+
111+
@property
112+
def resolved_callback_domains(self) -> list[str]:
113+
"""Parse CALLBACK_URL_ALLOWED_DOMAINS into a list of domain strings."""
114+
if not self.callback_url_allowed_domains:
115+
return []
116+
return [
117+
d.strip() for d in self.callback_url_allowed_domains.split(",") if d.strip()
118+
]
119+
100120
snakedispatch_backends: str | None = Field(
101121
default=None,
102122
description=(

src/pypsa_app/backend/tasks.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pypsa_app.backend.models import Run, RunStatus, SnakedispatchBackend
1414
from pypsa_app.backend.schemas.task import TaskResultResponse
1515
from pypsa_app.backend.services.network import import_network_file
16+
from pypsa_app.backend.services.callback import fire_callback_sync
1617
from pypsa_app.backend.services.run import SnakedispatchClient
1718
from pypsa_app.backend.services.statistics import get_plot as get_plot_service
1819
from pypsa_app.backend.services.statistics import (
@@ -69,7 +70,7 @@ def get_plot_task(self: Any, **kwargs: Any) -> dict[str, Any]:
6970

7071

7172
@task_app.task(bind=True, name="tasks.import_run_outputs")
72-
def import_run_outputs_task(self: Any, job_id: str) -> None:
73+
def import_run_outputs_task(self: Any, job_id: str) -> None: # noqa: PLR0915
7374
"""Download .nc outputs from a completed run and import as networks."""
7475
db = SessionLocal()
7576
try:
@@ -90,6 +91,7 @@ def import_run_outputs_task(self: Any, job_id: str) -> None:
9091
)
9192
run.status = RunStatus.ERROR
9293
db.commit()
94+
fire_callback_sync(run)
9395
return
9496
sd_client = SnakedispatchClient(backend.url)
9597
wanted_set = set(run.import_networks or [])
@@ -132,19 +134,22 @@ def import_run_outputs_task(self: Any, job_id: str) -> None:
132134
if run:
133135
run.status = RunStatus.ERROR
134136
db.commit()
137+
fire_callback_sync(run)
135138
return
136139
finally:
137140
tmp.unlink(missing_ok=True)
138141

139142
run.status = RunStatus.COMPLETED
140143
db.commit()
144+
fire_callback_sync(run)
141145
except Exception:
142146
logger.exception("Import task failed", extra={"run_id": job_id})
143147
try:
144148
run = db.query(Run).filter(Run.job_id == job_id).first()
145149
if run:
146150
run.status = RunStatus.ERROR
147151
db.commit()
152+
fire_callback_sync(run)
148153
except Exception:
149154
logger.exception("Failed to mark run as ERROR", extra={"run_id": job_id})
150155
finally:

0 commit comments

Comments
 (0)