Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
LOGGER = logging.getLogger(__name__)


MODEL_REGISTRY: dict[str, type[BaseModel]] = {}


class Task(BlueapiBaseModel):
"""
Task that will run a plan
Expand Down Expand Up @@ -37,6 +40,7 @@ def do_task(self, ctx: BlueskyContext) -> None:

func = ctx.plan_functions[self.name]
prepared_params = self.prepare_params(ctx)
print(prepared_params)
ctx.run_engine.md.update(self.metadata)
result = ctx.run_engine(func(**prepared_params))
if isinstance(result, tuple): # pragma: no cover
Expand All @@ -45,6 +49,37 @@ def do_task(self, ctx: BlueskyContext) -> None:
return result.plan_result


def register_model(model: type[BaseModel]) -> type[BaseModel]:
MODEL_REGISTRY[model.__name__] = model
return model


def restore_models(obj: Any) -> Any:
if isinstance(obj, list):
return [restore_models(v) for v in obj]

if not isinstance(obj, dict):
return obj

# First recursively restore children
restored = {k: restore_models(v) for k, v in obj.items()}

type_name = restored.get("__type__")
if isinstance(type_name, str) and type_name in MODEL_REGISTRY:
restored.pop("__type__")

model_cls = MODEL_REGISTRY[type_name]

arg_names = restored.pop("__args__", None)
if arg_names:
args = tuple(MODEL_REGISTRY[a] for a in arg_names)
model_cls = model_cls[*args]

return model_cls.model_validate(restored)

return restored


def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel:
"""
Checks plan parameters against context
Expand All @@ -60,5 +95,10 @@ def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel:

plan = ctx.plans[task.name]
model = plan.model
# Attempt to restore the plan arguments back into a pydantic model by
# checking against registered models.
restored_params = restore_models(dict(task.params))

adapter = TypeAdapter(model)
return adapter.validate_python(task.params)
return adapter.validate_python(restored_params)
# return adapter.validate_python(task.params)
127 changes: 126 additions & 1 deletion tests/unit_tests/worker/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from concurrent.futures import Future
from pathlib import Path
from queue import Full
from typing import Any, TypeVar
from typing import Any, Generic, TypeVar
from unittest.mock import ANY, MagicMock, Mock, patch

import pydantic
Expand All @@ -19,6 +19,7 @@
asserting_span_exporter,
)
from ophyd_async.core import AsyncStatus
from pydantic import BaseModel

from blueapi.config import DeviceSource, EnvironmentConfig
from blueapi.core import BlueskyContext, EventStream
Expand All @@ -36,6 +37,7 @@
WorkerState,
)
from blueapi.worker.event import TaskResult, TaskStatusEnum
from blueapi.worker.task import register_model, restore_models

_SIMPLE_TASK = Task(name="sleep", params={"time": 0.0})
_LONG_TASK = Task(name="sleep", params={"time": 1.0})
Expand Down Expand Up @@ -893,3 +895,126 @@ def test_plan_module_with_composite_devices_can_be_loaded_before_device_module(
params = Task(name="injected_device_plan").prepare_params(context_without_devices)
assert params["composite"].fake_device == fake_device
assert params["composite"].second_fake_device == second_fake_device


T = TypeVar("T")


class GenericPlanArgs(BaseModel):
value1: int


class SpecialisedPlanArgs(GenericPlanArgs, Generic[T]):
value2: T


class SpecialisedPlanArgs2(GenericPlanArgs):
value3: float


def plan_with_model(val: SpecialisedPlanArgs) -> MsgGenerator[T]:
yield from ()
assert isinstance(val, SpecialisedPlanArgs)


def test_plan_args_are_converted_back_to_model(
context: BlueskyContext,
) -> None:
context.register_plan(plan_with_model)

task = Task(name="plan_with_model", params={"val": {"value1": 1, "value2": "test"}})
task.do_task(context)


def plan_using_base_model(val: GenericPlanArgs) -> MsgGenerator[T]:
yield from ()
assert isinstance(val, SpecialisedPlanArgs)


def test_base_model_plan_args_are_converted_back_to_specialised_model(
context: BlueskyContext,
) -> None:
context.register_plan(plan_using_base_model)
register_model(SpecialisedPlanArgs)

task = Task(
name="plan_using_base_model",
params={
"val": {
"__type__": SpecialisedPlanArgs.__name__,
"value1": 1,
"value2": "test",
}
},
)
task.do_task(context)


def plan_with_generic_parameters(val: T) -> MsgGenerator[T]:
yield from ()
assert isinstance(val, SpecialisedPlanArgs)


def test_generic_plan_args_are_converted_back_to_specialised_model(
context: BlueskyContext,
) -> None:
context.register_plan(plan_with_generic_parameters)
register_model(SpecialisedPlanArgs)

task = Task(
name="plan_with_generic_parameters",
params={
"val": {
"__type__": SpecialisedPlanArgs.__name__,
"value1": 1,
"value2": "test",
}
},
)
task.do_task(context)


def test_nested_models_restore_from_types():
register_model(SpecialisedPlanArgs)
register_model(SpecialisedPlanArgs2)

data = {
"__type__": SpecialisedPlanArgs.__name__,
"value1": 1,
"value2": {
"__type__": SpecialisedPlanArgs2.__name__,
"value1": 1,
"value3": 1.5,
},
}

result = restore_models(data)

assert isinstance(result, SpecialisedPlanArgs)
assert result.value1 == 1
assert isinstance(result.value2, SpecialisedPlanArgs2)
assert result.value2.value1 == 1
assert result.value2.value3 == 1.5


def test_nested_models_restore_from_types_and_generic_args():
register_model(SpecialisedPlanArgs)
register_model(SpecialisedPlanArgs2)

data = {
"__type__": SpecialisedPlanArgs.__name__,
"__args__": [SpecialisedPlanArgs2.__name__],
"value1": 1,
"value2": {
"value1": 1,
"value3": 1.5,
},
}

result = restore_models(data)

assert isinstance(result, SpecialisedPlanArgs)
assert result.value1 == 1
assert isinstance(result.value2, SpecialisedPlanArgs2)
assert result.value2.value1 == 1
assert result.value2.value3 == 1.5
Loading