Skip to content

Commit 5fc9f23

Browse files
ENH: Add ability for snapshot restore to report progress to the frontend
1 parent ddc79b3 commit 5fc9f23

5 files changed

Lines changed: 206 additions & 8 deletions

File tree

app/api/v1/snapshots.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from app.models.job import JobType
1212
from app.schemas.job import JobCreatedDTO
1313
from app.dependencies import require_read_access, require_write_access
14-
from app.api.responses import APIException, success_response
14+
from app.api.responses import APIException, error_response, success_response
1515
from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO, UpdateSnapshotDTO
1616
from app.services.job_service import JobService
1717
from app.services.epics_service import EpicsService, get_epics_service
@@ -190,6 +190,7 @@ async def restore_snapshot(
190190
snapshot_id: str,
191191
request: RestoreRequestDTO | None = None,
192192
db: AsyncSession = Depends(get_db),
193+
async_mode: bool = Query(True, alias="async"),
193194
epics: EpicsService = Depends(get_epics_service),
194195
) -> dict:
195196
"""
@@ -208,6 +209,34 @@ async def restore_snapshot(
208209
if not snapshot:
209210
raise APIException(404, f"Snapshot {snapshot_id} not found", 404)
210211

212+
if async_mode:
213+
job_service = JobService(db)
214+
job = await job_service.create_job(
215+
JobType.SNAPSHOT_RESTORE,
216+
job_data={"snapshotId": snapshot_id},
217+
)
218+
219+
await db.commit()
220+
pool = await get_arq_pool()
221+
if pool:
222+
try:
223+
await pool.enqueue_job(
224+
"restore_snapshot_task",
225+
job_id=str(job.id),
226+
snapshot_id=snapshot_id,
227+
pv_ids=request.pvIds if request else None,
228+
)
229+
logger.info(f"Enqueued restore job to Arq: {job.id}")
230+
231+
return success_response(
232+
JobCreatedDTO(
233+
jobId=job.id,
234+
message=f"Snapshot restore queued ({snapshot_id})",
235+
)
236+
)
237+
except Exception as e:
238+
return error_response(500, f"Failed to enqueue to Arq: {e}")
239+
211240
result = await service.restore_snapshot(snapshot_id, request)
212241
return success_response(result)
213242

app/services/background_tasks.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datetime import datetime
55

66
from app.db.session import async_session_maker
7-
from app.schemas.snapshot import NewSnapshotDTO
7+
from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO
88
from app.services.epics_service import get_epics_service
99
from app.services.redis_service import get_redis_service
1010
from app.services.snapshot_service import SnapshotService
@@ -111,6 +111,92 @@ async def progress_update(current: int, total: int, message: str):
111111
logger.exception(f"Failed to update job status: {inner_e}")
112112

113113

114+
async def run_snapshot_restore(
    job_id: str,
    snapshot_id: str,
    pv_ids: list[str] | None = None,
) -> None:
    """
    Background task to restore a snapshot.

    This runs in a separate asyncio task and uses its own database session.

    Job progress is reported on a 0-100 scale: 5 once the job is running,
    then 10-95 while PV values are being written, after which the job is
    marked completed. Progress writes are throttled and best-effort: an error
    while recording progress is logged and does not abort the restore.

    Args:
        job_id: Identifier of the job record to update while restoring.
        snapshot_id: Identifier of the snapshot to restore.
        pv_ids: Optional list of PV ids used to build the restore request.
            ``None`` or an empty list means no request object is passed.

    Returns:
        None. The outcome is recorded on the job (completed or failed);
        exceptions from the restore itself are caught and logged here.
    """
    # Local import so this function does not rely on the module's import list.
    import time

    logger.info(f"Background task started for job {job_id}: Restoring snapshot '{snapshot_id}'")

    async with async_session_maker() as session:
        try:
            job_repo = JobRepository(session)

            # Mark job as running
            await job_repo.mark_running(job_id)
            await session.commit()
            await asyncio.sleep(0)  # yield to the event loop

            # Initial progress update
            await job_repo.update_progress(job_id, 5, "Loading snapshot values...")
            await session.commit()
            await asyncio.sleep(0)

            epics = get_epics_service()
            snapshot_service = SnapshotService(session, epics)

            # Optional restore request
            request = RestoreRequestDTO(pvIds=pv_ids) if pv_ids else None

            # Throttle state shared with the callback. A monotonic clock is used
            # so wall-clock adjustments cannot stall or flood progress updates.
            last_update = {"progress": 5, "last_time": time.monotonic()}

            async def progress_update(current: int, total: int, message: str):
                """Persist throttled job progress; never raises into the restore."""
                try:
                    # Map write progress onto the 10-95 range of the job scale.
                    write_progress = int((current / total) * 85) if total > 0 else 0
                    job_progress = 10 + write_progress

                    now = time.monotonic()
                    time_elapsed = now - last_update["last_time"]
                    progress_changed = job_progress - last_update["progress"] >= 2

                    # Write when progress moved by >= 2 points, at least every
                    # 2 seconds, and always for the final callback.
                    if progress_changed or time_elapsed >= 2.0 or current >= total:
                        last_update["progress"] = job_progress
                        last_update["last_time"] = now
                        await job_repo.update_progress(job_id, job_progress, message)
                        await session.commit()
                        await asyncio.sleep(0)
                except Exception as e:
                    # Best-effort: keep restoring, but keep the traceback in the log.
                    logger.exception(f"Error in restore progress_update: {e}")

            result = await snapshot_service.restore_snapshot(
                snapshot_id,
                request,
                progress_callback=progress_update,
            )

            # Final completion update
            completion_message = f"Restore completed: {result.successCount}/{result.totalPVs} PVs restored" + (
                f", {result.failureCount} failed" if result.failureCount > 0 else ""
            )

            await job_repo.mark_completed(
                job_id,
                message=completion_message,
            )
            await session.commit()

            logger.info(
                f"Background restore completed for job {job_id}: "
                f"{result.successCount}/{result.totalPVs} succeeded, "
                f"{result.failureCount} failed"
            )

        except Exception as e:
            logger.exception(f"Background restore failed for job {job_id}: {e}")
            error_msg = f"{type(e).__name__}: {str(e)}"
            try:
                # Discard any half-finished transaction before recording the failure.
                await session.rollback()
                await job_repo.mark_failed(job_id, error_msg)
                await session.commit()
            except Exception as inner_e:
                logger.exception(f"Failed to update restore job status: {inner_e}")
198+
199+
114200
def schedule_snapshot_creation(job_id: str, title: str, description: str | None = None, use_cache: bool = True) -> None:
115201
"""
116202
Schedule a snapshot creation task to run in the background.

app/services/epics_service.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,54 @@ async def put_many(self, values: dict[str, Any]) -> dict[str, tuple[bool, str |
471471

472472
return results
473473

474+
async def put_many_with_progress(
    self,
    values: dict[str, Any],
    progress_callback: Callable | None = None,
) -> dict[str, tuple[bool, str | None]]:
    """
    Put values to multiple PVs in batches, reporting progress after each batch.

    Can be used to update the user on progress when a snapshot restore is initiated.

    Args:
        values: Mapping of PV name to the value to write.
        progress_callback: Optional awaitable callable invoked as
            ``progress_callback(current, total, message)`` once before the
            first batch and once after every batch. Exceptions raised by the
            callback are not caught here.

    Returns:
        Mapping of PV name to ``(success, error_message)``. If writing a batch
        raises, every PV of that batch without a result is recorded as
        ``(False, str(exception))`` and the remaining batches are still written.
    """
    total_pvs = len(values)
    results: dict[str, tuple[bool, str | None]] = {}

    logger.info(f"Starting put_many_with_progress for {total_pvs} PVs")

    if progress_callback:
        await progress_callback(0, total_pvs, f"Starting restore of {total_pvs:,} PVs")

    items = list(values.items())
    # A zero chunk size would make range() raise and a negative one would
    # silently skip every PV, so always write at least one PV per batch.
    batch_size = max(1, self._chunk_size)
    success_count = 0

    for i in range(0, total_pvs, batch_size):
        batch_items = items[i : i + batch_size]
        batch_values = dict(batch_items)

        try:
            batch_results = await self.put_many(batch_values)
            results.update(batch_results)
        except Exception as e:
            logger.error(f"Chunk put error ({i}-{i + len(batch_items)}): {e}")
            for pv_name, _ in batch_items:
                if pv_name not in results:
                    results[pv_name] = (False, str(e))

        current = min(i + batch_size, total_pvs)
        # Batches hold distinct PV names, so a running total over the current
        # batch replaces a rescan of every accumulated result.
        success_count += sum(1 for pv_name, _ in batch_items if results.get(pv_name, (False, None))[0])

        progress_message = f"Restored {current:,}/{total_pvs:,} PVs ({success_count:,} successful)"

        if progress_callback:
            await progress_callback(current, total_pvs, progress_message)

        logger.info(progress_message)

    logger.info(f"Completed put_many_with_progress: {len(results)}/{total_pvs} PVs processed")
    return results
521+
474522
async def shutdown(self):
475523
"""Cleanup resources."""
476524
# aioca manages its own connections via libca

app/services/snapshot_service.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,12 @@ def _format_ts(ts: float | None) -> str | None:
522522
logger.exception(f"Error creating snapshot from cache '{data.title}': {e}")
523523
raise
524524

525-
async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO | None = None) -> RestoreResultDTO:
525+
async def restore_snapshot(
526+
self,
527+
snapshot_id: str,
528+
request: RestoreRequestDTO | None = None,
529+
progress_callback: Callable | None = None,
530+
) -> RestoreResultDTO:
526531
"""
527532
Restore PV values from a snapshot to EPICS.
528533
@@ -560,10 +565,18 @@ async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO |
560565
values_to_write[pv.setpoint_address] = write_value
561566
pv_id_by_address[pv.setpoint_address] = pv.id
562567

563-
logger.info(f"Writing {len(values_to_write)} PV values")
568+
total_pvs = len(values_to_write)
569+
logger.info(f"Writing {total_pvs} PV values")
570+
571+
# Initial progress update
572+
if progress_callback:
573+
await progress_callback(0, total_pvs, f"Starting restore of {total_pvs:,} PVs")
564574

565575
# Write to EPICS in parallel
566-
results = await self.epics.put_many(values_to_write)
576+
if progress_callback:
577+
results = await self.epics.put_many_with_progress(values_to_write, progress_callback)
578+
else:
579+
results = await self.epics.put_many(values_to_write)
567580

568581
# Process results
569582
failures = []
@@ -576,14 +589,20 @@ async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO |
576589
pv_id = pv_id_by_address.get(address, "")
577590
failures.append({"pvId": pv_id, "pvName": address, "error": error or "Unknown error"})
578591

592+
if progress_callback:
593+
await progress_callback(
594+
total_pvs,
595+
total_pvs,
596+
f"Completed restore: {success_count}/{total_pvs} PVs successful",
597+
)
579598
total_time = datetime.now()
580599
logger.info(
581600
f"Restore completed in {(total_time - start_time).total_seconds():.2f}s "
582601
f"({success_count} success, {len(failures)} failures)"
583602
)
584603

585604
return RestoreResultDTO(
586-
totalPVs=len(values_to_write), successCount=success_count, failureCount=len(failures), failures=failures
605+
totalPVs=total_pvs, successCount=success_count, failureCount=len(failures), failures=failures
587606
)
588607

589608
async def compare_snapshots(self, snapshot1_id: str, snapshot2_id: str) -> ComparisonResultDTO:

app/tasks/snapshot_tasks.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ async def restore_snapshot_task(ctx: dict, job_id: str, snapshot_id: str, pv_ids
128128
await job_repo.mark_running(job_id)
129129
await session.commit()
130130

131+
# Create progress callback for job updates
132+
async def on_progress(current: int, total: int, message: str) -> None:
133+
progress = int((current / total) * 100) if total > 0 else 0
134+
await job_repo.update_progress(job_id, progress, message)
135+
await session.commit()
136+
logger.debug(f"Restore job {job_id} progress: {progress}% - {message}")
137+
131138
# Initialize services
132139
epics = ctx.get("epics") or get_epics_service()
133140

@@ -140,16 +147,25 @@ async def restore_snapshot_task(ctx: dict, job_id: str, snapshot_id: str, pv_ids
140147
request = RestoreRequestDTO(pvIds=pv_ids) if pv_ids else None
141148

142149
# Restore the snapshot
143-
result = await snapshot_service.restore_snapshot(snapshot_id, request)
150+
result = await snapshot_service.restore_snapshot(
151+
snapshot_id,
152+
request,
153+
progress_callback=on_progress,
154+
)
144155

145156
# Mark job as completed
146157
result_data = {
147158
"total_pvs": result.totalPVs,
148159
"success_count": result.successCount,
149160
"failure_count": result.failureCount,
150161
}
162+
completion_message = f"Restored {result.successCount}/{result.totalPVs} PVs" + (
163+
f" ({result.failureCount} failed)" if result.failureCount > 0 else ""
164+
)
151165
await job_repo.mark_completed(
152-
job_id, result_id=snapshot_id, message=f"Restored {result.successCount}/{result.totalPVs} PVs"
166+
job_id,
167+
result_id=snapshot_id,
168+
message=completion_message,
153169
)
154170
await session.commit()
155171

0 commit comments

Comments
 (0)