From 29012811d649abc853b62556922972e378078968 Mon Sep 17 00:00:00 2001 From: Oli Wenman Date: Wed, 1 Apr 2026 09:23:52 +0000 Subject: [PATCH 1/4] Restore plan parameters back to model using registered models and __type__ --- src/blueapi/worker/task.py | 42 +++++++- tests/unit_tests/worker/test_task_worker.py | 101 +++++++++++++++++++- 2 files changed, 141 insertions(+), 2 deletions(-) diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index 9ce373c769..0159a7c0e2 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -10,6 +10,9 @@ LOGGER = logging.getLogger(__name__) +MODEL_REGISTRY: dict[str, type[BaseModel]] = {} + + class Task(BlueapiBaseModel): """ Task that will run a plan @@ -37,6 +40,7 @@ def do_task(self, ctx: BlueskyContext) -> None: func = ctx.plan_functions[self.name] prepared_params = self.prepare_params(ctx) + print(prepared_params) ctx.run_engine.md.update(self.metadata) result = ctx.run_engine(func(**prepared_params)) if isinstance(result, tuple): # pragma: no cover @@ -45,6 +49,37 @@ def do_task(self, ctx: BlueskyContext) -> None: return result.plan_result +def register_model(model: type[BaseModel]) -> type[BaseModel]: + MODEL_REGISTRY[model.__name__] = model + return model + + +def restore_models(obj: Any) -> Any: + if isinstance(obj, list): + return [restore_models(v) for v in obj] + + if not isinstance(obj, dict): + return obj + + # First recursively restore children + restored = {k: restore_models(v) for k, v in obj.items()} + + type_name = restored.get("__type__") + if isinstance(type_name, str) and type_name in MODEL_REGISTRY: + restored.pop("__type__") + + model_cls = MODEL_REGISTRY[type_name] + + arg_names = restored.pop("__args__", None) + if arg_names: + args = tuple(MODEL_REGISTRY[a] for a in arg_names) + model_cls = model_cls[*args] + + return model_cls.model_validate(restored) + + return restored + + def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel: """ Checks plan parameters against context @@ -60,5 +95,10 @@ def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel: plan = ctx.plans[task.name] model = plan.model + # Attempt to restore the plan arguments back into a pydantic model by + # checking against registered models. + restored_params = restore_models(dict(task.params)) + adapter = TypeAdapter(model) - return adapter.validate_python(task.params) + return adapter.validate_python(restored_params) + # return adapter.validate_python(task.params) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 4b4d83408c..64a3a1d1e3 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -5,7 +5,7 @@ from concurrent.futures import Future from pathlib import Path from queue import Full -from typing import Any, TypeVar +from typing import Annotated, Any, Generic, Literal, TypeVar from unittest.mock import ANY, MagicMock, Mock, patch import pydantic @@ -19,6 +19,7 @@ asserting_span_exporter, ) from ophyd_async.core import AsyncStatus +from pydantic import BaseModel, Field from blueapi.config import DeviceSource, EnvironmentConfig from blueapi.core import BlueskyContext, EventStream @@ -36,6 +37,7 @@ WorkerState, ) from blueapi.worker.event import TaskResult, TaskStatusEnum +from blueapi.worker.task import register_model, restore_models _SIMPLE_TASK = Task(name="sleep", params={"time": 0.0}) _LONG_TASK = Task(name="sleep", params={"time": 1.0}) @@ -893,3 +895,100 @@ def test_plan_module_with_composite_devices_can_be_loaded_before_device_module( params = Task(name="injected_device_plan").prepare_params(context_without_devices) assert params["composite"].fake_device == fake_device assert params["composite"].second_fake_device == second_fake_device + + +T = TypeVar("T") + + +class GenericPlanArgs(BaseModel): + value1: int + + +class SpecialisedPlanArgs(GenericPlanArgs, Generic[T]): + _region_type: Literal["vgscienta"] = "vgscienta" + value2: T + + +class SpecialisedPlanArgs2(GenericPlanArgs): + _region_type: Literal["specs"] = "specs" + value3: float + + +Region = Annotated[ + SpecialisedPlanArgs | SpecialisedPlanArgs2, + Field(discriminator="_region_type"), +] + + +def plan_with_model(val: SpecialisedPlanArgs) -> MsgGenerator[T]: + yield from () + assert isinstance(val, SpecialisedPlanArgs) + + +def plan_using_base_model(val: Region) -> MsgGenerator[T]: + yield from () + assert isinstance(val, SpecialisedPlanArgs) + + +def plan_with_generic_model(val: T) -> MsgGenerator[T]: + yield from () + assert isinstance(val, SpecialisedPlanArgs) + + +def test_plan_args_are_converted_back_to_model( + context: BlueskyContext, +) -> None: + context.register_plan(plan_with_model) + + task = Task(name="plan_with_model", params={"val": {"value1": 1, "value2": "test"}}) + task.do_task(context) + + +def test_base_model_plan_args_are_converted_back_to_specialised_model( + context: BlueskyContext, +) -> None: + context.register_plan(plan_using_base_model) + + task = Task( + name="plan_using_base_model", params={"val": {"value1": 1, "value2": "test"}} + ) + task.do_task(context) + + +def test_generic_plan_args_are_converted_back_to_specialised_model( + context: BlueskyContext, +) -> None: + context.register_plan(plan_with_generic_model) + register_model(SpecialisedPlanArgs) + + task = Task( + name="plan_with_generic_model", + params={ + "val": { + "__type__": SpecialisedPlanArgs.__name__, + "value1": 1, + "value2": "test", + } + }, + ) + task.do_task(context) + + +def test_nested_models_restore(): + register_model(SpecialisedPlanArgs) + register_model(SpecialisedPlanArgs2) + + data = { + "__type__": SpecialisedPlanArgs.__name__, + "value1": 1, + "value2": { + "__type__": SpecialisedPlanArgs2.__name__, + "value1": 1, + "value3": 1.5, + }, + } + + result = restore_models(data) + + assert isinstance(result, SpecialisedPlanArgs) + assert isinstance(result.value2, SpecialisedPlanArgs2) From 6cd6f73f650dcd56c08be57ae89f5df14fae361f Mon Sep 17 00:00:00 2001 From: Oli Wenman Date: Wed, 1 Apr 2026 09:29:41 +0000 Subject: [PATCH 2/4] Correct test --- tests/unit_tests/worker/test_task_worker.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 64a3a1d1e3..947d58c56b 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -914,18 +914,12 @@ class SpecialisedPlanArgs2(GenericPlanArgs): value3: float -Region = Annotated[ - SpecialisedPlanArgs | SpecialisedPlanArgs2, - Field(discriminator="_region_type"), -] - - def plan_with_model(val: SpecialisedPlanArgs) -> MsgGenerator[T]: yield from () assert isinstance(val, SpecialisedPlanArgs) -def plan_using_base_model(val: Region) -> MsgGenerator[T]: +def plan_using_base_model(val: GenericPlanArgs) -> MsgGenerator[T]: yield from () assert isinstance(val, SpecialisedPlanArgs) @@ -948,9 +942,17 @@ def test_base_model_plan_args_are_converted_back_to_specialised_model( context: BlueskyContext, ) -> None: context.register_plan(plan_using_base_model) + register_model(SpecialisedPlanArgs) task = Task( - name="plan_using_base_model", params={"val": {"value1": 1, "value2": "test"}} + name="plan_using_base_model", + params={ + "val": { + "__type__": SpecialisedPlanArgs.__name__, + "value1": 1, + "value2": "test", + } + }, ) task.do_task(context) From 7d0e9ee04c4f11a9009bac0e1b5d3bcc367723e3 Mon Sep 17 00:00:00 2001 From: Oli Wenman Date: Wed, 1 Apr 2026 10:01:01 +0000 Subject: [PATCH 3/4] Improve tests --- tests/unit_tests/worker/test_task_worker.py | 54 +++++++++++++++------ 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 947d58c56b..3729940bc9 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -905,12 +905,10 @@ class GenericPlanArgs(BaseModel): class SpecialisedPlanArgs(GenericPlanArgs, Generic[T]): - _region_type: Literal["vgscienta"] = "vgscienta" value2: T class SpecialisedPlanArgs2(GenericPlanArgs): - _region_type: Literal["specs"] = "specs" value3: float @@ -919,16 +917,6 @@ def plan_with_model(val: SpecialisedPlanArgs) -> MsgGenerator[T]: assert isinstance(val, SpecialisedPlanArgs) -def plan_using_base_model(val: GenericPlanArgs) -> MsgGenerator[T]: - yield from () - assert isinstance(val, SpecialisedPlanArgs) - - -def plan_with_generic_model(val: T) -> MsgGenerator[T]: - yield from () - assert isinstance(val, SpecialisedPlanArgs) - - def test_plan_args_are_converted_back_to_model( context: BlueskyContext, ) -> None: @@ -938,6 +926,11 @@ def test_plan_args_are_converted_back_to_model( task.do_task(context) +def plan_using_base_model(val: GenericPlanArgs) -> MsgGenerator[T]: + yield from () + assert isinstance(val, SpecialisedPlanArgs) + + def test_base_model_plan_args_are_converted_back_to_specialised_model( context: BlueskyContext, ) -> None: @@ -957,14 +950,19 @@ def test_base_model_plan_args_are_converted_back_to_specialised_model( task.do_task(context) +def plan_with_generic_parameters(val: T) -> MsgGenerator[T]: + yield from () + assert isinstance(val, SpecialisedPlanArgs) + + def test_generic_plan_args_are_converted_back_to_specialised_model( context: BlueskyContext, ) -> None: - context.register_plan(plan_with_generic_model) + context.register_plan(plan_with_generic_parameters) register_model(SpecialisedPlanArgs) task = Task( - name="plan_with_generic_model", + name="plan_with_generic_parameters", params={ "val": { "__type__": SpecialisedPlanArgs.__name__, @@ -976,7 +974,7 @@ def test_generic_plan_args_are_converted_back_to_specialised_model( task.do_task(context) -def test_nested_models_restore(): +def test_nested_models_restore_from_types(): register_model(SpecialisedPlanArgs) register_model(SpecialisedPlanArgs2) @@ -993,4 +991,30 @@ def test_nested_models_restore(): result = restore_models(data) assert isinstance(result, SpecialisedPlanArgs) + assert result.value1 == 1 + assert isinstance(result.value2, SpecialisedPlanArgs2) + assert result.value2.value1 == 1 + assert result.value2.value3 == 1.5 + + +def test_nested_models_restore_from_types_and_generic_args(): + register_model(SpecialisedPlanArgs) + register_model(SpecialisedPlanArgs2) + + data = { + "__type__": SpecialisedPlanArgs.__name__, + "__args__": [SpecialisedPlanArgs2.__name__], + "value1": 1, + "value2": { + "value1": 1, + "value3": 1.5, + }, + } + + result = restore_models(data) + + assert isinstance(result, SpecialisedPlanArgs) + assert result.value1 == 1 assert isinstance(result.value2, SpecialisedPlanArgs2) + assert result.value2.value1 == 1 + assert result.value2.value3 == 1.5 From 800392ab37f9bcc0f52baf73299608a97bea2fe2 Mon Sep 17 00:00:00 2001 From: Oli Wenman Date: Wed, 1 Apr 2026 15:27:29 +0100 Subject: [PATCH 4/4] Remove unused imports --- tests/unit_tests/worker/test_task_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 3729940bc9..df0783e5ad 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -5,7 +5,7 @@ from concurrent.futures import Future from pathlib import Path from queue import Full -from typing import Annotated, Any, Generic, Literal, TypeVar +from typing import Any, Generic, TypeVar from unittest.mock import ANY, MagicMock, Mock, patch import pydantic @@ -19,7 +19,7 @@ asserting_span_exporter, ) from ophyd_async.core import AsyncStatus -from pydantic import BaseModel, Field +from pydantic import BaseModel from blueapi.config import DeviceSource, EnvironmentConfig from blueapi.core import BlueskyContext, EventStream