diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index 9ce373c769..0159a7c0e2 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -10,6 +10,9 @@ LOGGER = logging.getLogger(__name__) +MODEL_REGISTRY: dict[str, type[BaseModel]] = {} + + class Task(BlueapiBaseModel): """ Task that will run a plan @@ -37,6 +40,7 @@ def do_task(self, ctx: BlueskyContext) -> None: func = ctx.plan_functions[self.name] prepared_params = self.prepare_params(ctx) + print(prepared_params) ctx.run_engine.md.update(self.metadata) result = ctx.run_engine(func(**prepared_params)) if isinstance(result, tuple): # pragma: no cover @@ -45,6 +49,37 @@ def do_task(self, ctx: BlueskyContext) -> None: return result.plan_result +def register_model(model: type[BaseModel]) -> type[BaseModel]: + MODEL_REGISTRY[model.__name__] = model + return model + + +def restore_models(obj: Any) -> Any: + if isinstance(obj, list): + return [restore_models(v) for v in obj] + + if not isinstance(obj, dict): + return obj + + # First recursively restore children + restored = {k: restore_models(v) for k, v in obj.items()} + + type_name = restored.get("__type__") + if isinstance(type_name, str) and type_name in MODEL_REGISTRY: + restored.pop("__type__") + + model_cls = MODEL_REGISTRY[type_name] + + arg_names = restored.pop("__args__", None) + if arg_names: + args = tuple(MODEL_REGISTRY[a] for a in arg_names) + model_cls = model_cls[*args] + + return model_cls.model_validate(restored) + + return restored + + def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel: """ Checks plan parameters against context @@ -60,5 +95,10 @@ def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel: plan = ctx.plans[task.name] model = plan.model + # Attempt to restore the plan arguments back into a pydantic model by + # checking against registered models. + restored_params = restore_models(dict(task.params)) + adapter = TypeAdapter(model) - return adapter.validate_python(task.params) + return adapter.validate_python(restored_params) + # return adapter.validate_python(task.params) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 4b4d83408c..df0783e5ad 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -5,7 +5,7 @@ from concurrent.futures import Future from pathlib import Path from queue import Full -from typing import Any, TypeVar +from typing import Any, Generic, TypeVar from unittest.mock import ANY, MagicMock, Mock, patch import pydantic @@ -19,6 +19,7 @@ asserting_span_exporter, ) from ophyd_async.core import AsyncStatus +from pydantic import BaseModel from blueapi.config import DeviceSource, EnvironmentConfig from blueapi.core import BlueskyContext, EventStream @@ -36,6 +37,7 @@ WorkerState, ) from blueapi.worker.event import TaskResult, TaskStatusEnum +from blueapi.worker.task import register_model, restore_models _SIMPLE_TASK = Task(name="sleep", params={"time": 0.0}) _LONG_TASK = Task(name="sleep", params={"time": 1.0}) @@ -893,3 +895,126 @@ def test_plan_module_with_composite_devices_can_be_loaded_before_device_module( params = Task(name="injected_device_plan").prepare_params(context_without_devices) assert params["composite"].fake_device == fake_device assert params["composite"].second_fake_device == second_fake_device + + +T = TypeVar("T") + + +class GenericPlanArgs(BaseModel): + value1: int + + +class SpecialisedPlanArgs(GenericPlanArgs, Generic[T]): + value2: T + + +class SpecialisedPlanArgs2(GenericPlanArgs): + value3: float + + +def plan_with_model(val: SpecialisedPlanArgs) -> MsgGenerator[T]: + yield from () + assert isinstance(val, SpecialisedPlanArgs) + + +def test_plan_args_are_converted_back_to_model( + context: BlueskyContext, +) -> None: + context.register_plan(plan_with_model) + + task = Task(name="plan_with_model", params={"val": {"value1": 1, "value2": "test"}}) + task.do_task(context) + + +def plan_using_base_model(val: GenericPlanArgs) -> MsgGenerator[T]: + yield from () + assert isinstance(val, SpecialisedPlanArgs) + + +def test_base_model_plan_args_are_converted_back_to_specialised_model( + context: BlueskyContext, +) -> None: + context.register_plan(plan_using_base_model) + register_model(SpecialisedPlanArgs) + + task = Task( + name="plan_using_base_model", + params={ + "val": { + "__type__": SpecialisedPlanArgs.__name__, + "value1": 1, + "value2": "test", + } + }, + ) + task.do_task(context) + + +def plan_with_generic_parameters(val: T) -> MsgGenerator[T]: + yield from () + assert isinstance(val, SpecialisedPlanArgs) + + +def test_generic_plan_args_are_converted_back_to_specialised_model( + context: BlueskyContext, +) -> None: + context.register_plan(plan_with_generic_parameters) + register_model(SpecialisedPlanArgs) + + task = Task( + name="plan_with_generic_parameters", + params={ + "val": { + "__type__": SpecialisedPlanArgs.__name__, + "value1": 1, + "value2": "test", + } + }, + ) + task.do_task(context) + + +def test_nested_models_restore_from_types(): + register_model(SpecialisedPlanArgs) + register_model(SpecialisedPlanArgs2) + + data = { + "__type__": SpecialisedPlanArgs.__name__, + "value1": 1, + "value2": { + "__type__": SpecialisedPlanArgs2.__name__, + "value1": 1, + "value3": 1.5, + }, + } + + result = restore_models(data) + + assert isinstance(result, SpecialisedPlanArgs) + assert result.value1 == 1 + assert isinstance(result.value2, SpecialisedPlanArgs2) + assert result.value2.value1 == 1 + assert result.value2.value3 == 1.5 + + +def test_nested_models_restore_from_types_and_generic_args(): + register_model(SpecialisedPlanArgs) + register_model(SpecialisedPlanArgs2) + + data = { + "__type__": SpecialisedPlanArgs.__name__, + "__args__": [SpecialisedPlanArgs2.__name__], + "value1": 1, + "value2": { + "value1": 1, + "value3": 1.5, + }, + } + + result = restore_models(data) + + assert isinstance(result, SpecialisedPlanArgs) + assert result.value1 == 1 + assert isinstance(result.value2, SpecialisedPlanArgs2) + assert result.value2.value1 == 1 + assert result.value2.value3 == 1.5