Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion app/api/v1/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from app.services.job_service import JobService
from app.services.epics_service import EpicsService, get_epics_service
from app.services.redis_service import get_redis_service
from app.services.background_tasks import run_snapshot_creation
from app.services.background_tasks import run_snapshot_restore, run_snapshot_creation
from app.services.snapshot_service import SnapshotService

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -188,8 +188,11 @@ async def update_snapshot(
@router.post("/{snapshot_id}/restore", dependencies=[Security(require_write_access)])
async def restore_snapshot(
snapshot_id: str,
background_tasks: BackgroundTasks,
request: RestoreRequestDTO | None = None,
db: AsyncSession = Depends(get_db),
async_mode: bool = Query(True, alias="async"),
use_arq: bool = Query(True, description="Use Arq persistent queue (recommended) vs FastAPI BackgroundTasks"),
epics: EpicsService = Depends(get_epics_service),
) -> dict:
"""
Expand All @@ -208,6 +211,47 @@ async def restore_snapshot(
if not snapshot:
raise APIException(404, f"Snapshot {snapshot_id} not found", 404)

if async_mode:
job_service = JobService(db)
job = await job_service.create_job(
JobType.SNAPSHOT_RESTORE,
job_data={"snapshotId": snapshot_id},
)

await db.commit()
pv_ids = request.pvIds if request else None

if use_arq:
pool = await get_arq_pool()
if pool:
try:
await pool.enqueue_job(
"restore_snapshot_task",
job_id=str(job.id),
snapshot_id=snapshot_id,
pv_ids=pv_ids,
)
logger.info(f"Enqueued restore job to Arq: {job.id}")

return success_response(
JobCreatedDTO(
jobId=job.id,
message=f"Snapshot restore queued ({snapshot_id})",
)
)
except Exception as e:
logger.warning(f"Failed to enqueue to Arq, falling back to BackgroundTasks: {e}")

# Fallback to FastAPI BackgroundTasks
background_tasks.add_task(run_snapshot_restore, str(job.id), snapshot_id, pv_ids)
logger.info(f"Scheduled restore job via BackgroundTasks: {job.id}")
return success_response(
JobCreatedDTO(
jobId=job.id,
message=f"Snapshot restore started ({snapshot_id})",
)
)

result = await service.restore_snapshot(snapshot_id, request)
return success_response(result)

Expand Down
35 changes: 27 additions & 8 deletions app/repositories/job_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,34 @@ async def mark_running(self, job_id: str) -> Job | None:
"""Mark a job as running."""
return await self.update_status(job_id, JobStatus.RUNNING, progress=0, message="Job started")

async def mark_completed(
    self,
    job_id: str,
    result_id: str | None = None,
    message: str | None = None,
    result_data: dict | None = None,
) -> Job | None:
    """Mark a job as completed.

    Sets status to COMPLETED, progress to 100, and stamps the completion
    time. Optionally records a result id and merges ``result_data`` into
    the job's ``job_data`` under the ``"result"`` key.

    Args:
        job_id: Identifier of the job to update.
        result_id: Optional id of the entity produced by the job.
        message: Completion message; defaults to a generic success text.
        result_data: Optional structured result payload to persist.

    Returns:
        The refreshed Job row, or None if no job with ``job_id`` exists.
    """
    job = await self.get_by_id(job_id)
    if not job:
        return None

    job.status = JobStatus.COMPLETED.value
    job.progress = 100
    job.message = message or "Job completed successfully"
    # NOTE(review): naive local time, consistent with the rest of this repo;
    # confirm whether job timestamps are expected to be UTC.
    job.completed_at = datetime.now()

    if result_id is not None:
        job.result_id = result_id

    if result_data:
        # Reassign a fresh dict (instead of mutating in place) so the ORM
        # reliably detects the change to the JSON column.
        existing = dict(job.job_data or {})
        existing["result"] = result_data
        job.job_data = existing

    await self.session.flush()
    await self.session.refresh(job)
    return job

async def mark_failed(self, job_id: str, error: str) -> Job | None:
"""Mark a job as failed."""
Expand Down
1 change: 1 addition & 0 deletions app/schemas/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class JobDTO(BaseModel):
message: str | None = None
resultId: str | None = None
error: str | None = None
jobData: dict | None = None
createdAt: datetime
startedAt: datetime | None = None
completedAt: datetime | None = None
Expand Down
102 changes: 101 additions & 1 deletion app/services/background_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime

from app.db.session import async_session_maker
from app.schemas.snapshot import NewSnapshotDTO
from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO
from app.services.epics_service import get_epics_service
from app.services.redis_service import get_redis_service
from app.services.snapshot_service import SnapshotService
Expand Down Expand Up @@ -111,6 +111,106 @@ async def progress_update(current: int, total: int, message: str):
logger.exception(f"Failed to update job status: {inner_e}")


async def run_snapshot_restore(
    job_id: str,
    snapshot_id: str,
    pv_ids: list[str] | None = None,
) -> None:
    """
    Background task to restore a snapshot.

    Runs in a separate asyncio task and uses its own database session.
    Progress is persisted on the job row — throttled to at most one write
    per 2 percentage points or 2 seconds — so API clients can poll it.

    Args:
        job_id: Id of the pre-created SNAPSHOT_RESTORE job to update.
        snapshot_id: Id of the snapshot whose values should be written.
        pv_ids: Optional subset of PV ids to restore; None restores all.
    """
    logger.info("Background task started for job %s: Restoring snapshot '%s'", job_id, snapshot_id)

    async with async_session_maker() as session:
        # Construct the repository outside the try block so the error
        # handler below can always reference it (previously a failure in
        # construction would have raised NameError inside the except).
        job_repo = JobRepository(session)
        try:
            # Mark job as running
            await job_repo.mark_running(job_id)
            await session.commit()
            await asyncio.sleep(0)  # yield so the event loop can serve requests

            # Initial progress update
            await job_repo.update_progress(job_id, 5, "Loading snapshot values...")
            await session.commit()
            await asyncio.sleep(0)

            epics = get_epics_service()
            snapshot_service = SnapshotService(session, epics)

            # Optional restore request restricting which PVs are written
            request = RestoreRequestDTO(pvIds=pv_ids) if pv_ids else None

            # Throttle state shared with the progress callback below
            last_update = {"progress": 5, "last_time": datetime.now()}

            async def progress_update(current: int, total: int, message: str):
                """Persist throttled job progress (mapped onto 10-95%)."""
                try:
                    write_progress = int((current / total) * 85) if total > 0 else 0
                    job_progress = 10 + write_progress

                    now = datetime.now()
                    time_elapsed = (now - last_update["last_time"]).total_seconds()
                    progress_changed = job_progress - last_update["progress"] >= 2

                    # Always flush the final update (current >= total).
                    if progress_changed or time_elapsed >= 2.0 or current >= total:
                        last_update["progress"] = job_progress
                        last_update["last_time"] = now
                        await job_repo.update_progress(job_id, job_progress, message)
                        await session.commit()
                        await asyncio.sleep(0)
                except Exception as e:
                    # Progress reporting must never abort the restore itself.
                    logger.error("Error in restore progress_update: %s", e)

            result = await snapshot_service.restore_snapshot(
                snapshot_id,
                request,
                progress_callback=progress_update,
            )

            # Build result data; cap stored failures to keep job_data small
            result_data = {
                "total_pvs": result.totalPVs,
                "success_count": result.successCount,
                "failure_count": result.failureCount,
                "failures": [
                    {"pvId": f["pvId"], "pvName": f["pvName"], "error": f["error"]} for f in result.failures[:50]
                ],
            }

            # Final completion update
            if result.failureCount > 0:
                completion_message = (
                    f"Restored {result.successCount:,}/{result.totalPVs:,} PVs " f"({result.failureCount} failed)"
                )
            else:
                completion_message = f"All {result.totalPVs:,} PVs have been restored to their snapshot values."

            await job_repo.mark_completed(
                job_id,
                message=completion_message,
                result_data=result_data,
            )
            await session.commit()

            logger.info(
                "Background restore completed for job %s: %s/%s succeeded, %s failed",
                job_id,
                result.successCount,
                result.totalPVs,
                result.failureCount,
            )

        except Exception as e:
            logger.exception("Background restore failed for job %s: %s", job_id, e)
            error_msg = f"{type(e).__name__}: {str(e)}"
            try:
                # Roll back any half-applied progress write before recording failure.
                await session.rollback()
                await job_repo.mark_failed(job_id, error_msg)
                await session.commit()
            except Exception as inner_e:
                logger.exception("Failed to update restore job status: %s", inner_e)


def schedule_snapshot_creation(job_id: str, title: str, description: str | None = None, use_cache: bool = True) -> None:
"""
Schedule a snapshot creation task to run in the background.
Expand Down
59 changes: 57 additions & 2 deletions app/services/epics_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime
from collections.abc import Callable

from aioca import FORMAT_TIME, caget, caput, connect, purge_channel_caches
from aioca import FORMAT_TIME, CANothing, caget, caput, connect, purge_channel_caches

from app.config import get_settings
from app.services.epics_types import EpicsValue
Expand Down Expand Up @@ -109,6 +109,13 @@ def _sanitize_value(self, value: Any) -> Any:
return self._sanitize_value(value.tolist())
return value

def _ca_error_message(self, error_msg: CANothing) -> str:
"""Convert CA error result to a more user-friendly message."""
msg = str(error_msg).strip()
if "user specified timeout" in msg.lower():
return "Connection timeout"
return msg if msg else "Unknown error"

def _augmented_to_epics_value(self, pv_name: str, result) -> EpicsValue:
"""Convert aioca AugmentedValue to our EpicsValue dataclass."""
if not result.ok:
Expand Down Expand Up @@ -421,7 +428,7 @@ async def put_many(self, values: dict[str, Any]) -> dict[str, tuple[bool, str |
if result.ok:
results[original] = (True, None)
else:
results[original] = (False, f"Failed: {getattr(result, 'errorcode', 'unknown')}")
results[original] = (False, self._ca_error_message(result))

except Exception as e:
logger.error(f"Batch put error: {e}")
Expand Down Expand Up @@ -471,6 +478,54 @@ async def put_many(self, values: dict[str, Any]) -> dict[str, tuple[bool, str |

return results

async def put_many_with_progress(
    self,
    values: dict[str, Any],
    progress_callback: Callable | None = None,
) -> dict[str, tuple[bool, str | None]]:
    """
    Put values to multiple PVs with progress tracking.

    Writes in chunks of ``self._chunk_size`` via ``put_many``, invoking
    ``progress_callback(current, total, message)`` after each chunk so a
    caller (e.g. a snapshot-restore job) can surface progress to the user.

    Args:
        values: Mapping of PV address -> value to write.
        progress_callback: Optional async callable receiving
            ``(pvs_processed, total_pvs, message)``.

    Returns:
        Mapping of PV address -> (success, error message or None).
    """
    total_pvs = len(values)
    results: dict[str, tuple[bool, str | None]] = {}

    logger.info(f"Starting put_many_with_progress for {total_pvs} PVs")

    if progress_callback:
        await progress_callback(0, total_pvs, f"Starting restore of {total_pvs:,} PVs")

    items = list(values.items())
    batch_size = self._chunk_size
    # Running tally — avoids rescanning the whole accumulated results dict
    # on every batch (previously an accidental O(n^2) aggregate).
    success_count = 0

    for i in range(0, total_pvs, batch_size):
        batch_items = items[i : i + batch_size]
        batch_values = dict(batch_items)

        try:
            batch_results = await self.put_many(batch_values)
            results.update(batch_results)
            success_count += sum(1 for ok, _ in batch_results.values() if ok)
        except Exception as e:
            # A failed chunk must not abort the rest of the restore;
            # record every PV in the chunk as failed and continue.
            logger.error(f"Chunk put error ({i}-{i + len(batch_items)}): {e}")
            for pv_name, _ in batch_items:
                if pv_name not in results:
                    results[pv_name] = (False, str(e))

        current = min(i + batch_size, total_pvs)

        if progress_callback:
            await progress_callback(
                current,
                total_pvs,
                f"{current:,}/{total_pvs:,} PVs",
            )

        logger.info(f"Restored {current:,}/{total_pvs:,} PVs " f"({success_count:,} successful)")

    logger.info(f"Completed put_many_with_progress: {len(results)}/{total_pvs} PVs processed")
    return results

async def shutdown(self):
"""Cleanup resources."""
# aioca manages its own connections via libca
Expand Down
1 change: 1 addition & 0 deletions app/services/job_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def _to_dto(self, job: Job) -> JobDTO:
message=job.message,
resultId=job.result_id,
error=job.error,
jobData=job.job_data,
createdAt=job.created_at,
startedAt=job.started_at,
completedAt=job.completed_at,
Expand Down
27 changes: 23 additions & 4 deletions app/services/snapshot_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,12 @@ def _format_ts(ts: float | None) -> str | None:
logger.exception(f"Error creating snapshot from cache '{data.title}': {e}")
raise

async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO | None = None) -> RestoreResultDTO:
async def restore_snapshot(
self,
snapshot_id: str,
request: RestoreRequestDTO | None = None,
progress_callback: Callable | None = None,
) -> RestoreResultDTO:
"""
Restore PV values from a snapshot to EPICS.

Expand Down Expand Up @@ -560,10 +565,18 @@ async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO |
values_to_write[pv.setpoint_address] = write_value
pv_id_by_address[pv.setpoint_address] = pv.id

logger.info(f"Writing {len(values_to_write)} PV values")
total_pvs = len(values_to_write)
logger.info(f"Writing {total_pvs} PV values")

# Initial progress update
if progress_callback:
await progress_callback(0, total_pvs, f"Starting restore of {total_pvs:,} PVs")

# Write to EPICS in parallel
results = await self.epics.put_many(values_to_write)
if progress_callback:
results = await self.epics.put_many_with_progress(values_to_write, progress_callback)
else:
results = await self.epics.put_many(values_to_write)

# Process results
failures = []
Expand All @@ -576,14 +589,20 @@ async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO |
pv_id = pv_id_by_address.get(address, "")
failures.append({"pvId": pv_id, "pvName": address, "error": error or "Unknown error"})

if progress_callback:
await progress_callback(
total_pvs,
total_pvs,
f"{success_count:,}/{total_pvs:,} PVs restored",
)
total_time = datetime.now()
logger.info(
f"Restore completed in {(total_time - start_time).total_seconds():.2f}s "
f"({success_count} success, {len(failures)} failures)"
)

return RestoreResultDTO(
totalPVs=len(values_to_write), successCount=success_count, failureCount=len(failures), failures=failures
totalPVs=total_pvs, successCount=success_count, failureCount=len(failures), failures=failures
)

async def compare_snapshots(self, snapshot1_id: str, snapshot2_id: str) -> ComparisonResultDTO:
Expand Down
Loading
Loading