diff --git a/docs/changes/newsfragments/8041.new b/docs/changes/newsfragments/8041.new new file mode 100644 index 00000000000..495291e823e --- /dev/null +++ b/docs/changes/newsfragments/8041.new @@ -0,0 +1,4 @@ +Added ``StructParameter``, a new parameter type that returns structured data +as a Python ``dataclass`` or Pydantic v2 ``BaseModel``. Each field of the struct +is automatically unpacked into a separate dataset column when used with +``Measurement``. Pydantic support is optional. diff --git a/docs/examples/Parameters/StructParameter.ipynb b/docs/examples/Parameters/StructParameter.ipynb new file mode 100644 index 00000000000..1df17a3c692 --- /dev/null +++ b/docs/examples/Parameters/StructParameter.ipynb @@ -0,0 +1,563 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# StructParameter\n", + "\n", + "A `StructParameter` returns structured data — a Python `dataclass` or Pydantic v2 `BaseModel` —\n", + "and automatically unpacks each field into a separate dataset column when used in a\n", + "`Measurement`.\n", + "\n", + "This is useful when a single instrument query returns multiple related values\n", + "(e.g. voltage *and* current, or magnitude *and* phase) that belong together logically\n", + "but should be stored as individual columns in the dataset.\n", + "\n", + "Key properties:\n", + "\n", + "- **Get-only** — `StructParameter` is a subclass of `ParameterBase` with no set support.\n", + "- **Automatic field introspection** — field names, types, labels, and units are derived from the struct definition.\n", + "- **Dataset integration** — `Measurement.register_parameter` registers each field as a separate dependent parameter.\n", + "- **Pydantic support is optional** — works with plain dataclasses out of the box; Pydantic v2 is supported when installed." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from dataclasses import dataclass\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "\n", + "from qcodes.dataset import (\n", + " Measurement,\n", + " initialise_or_create_database_at,\n", + " load_or_create_experiment,\n", + ")\n", + "from qcodes.parameters import ManualParameter, StructParameter\n", + "from qcodes.station import Station" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a temporary database for this example\n", + "db_path = Path(tempfile.gettempdir()) / \"struct_parameter_example.db\"\n", + "initialise_or_create_database_at(str(db_path))\n", + "exp = load_or_create_experiment(\"struct_parameter_tutorial\", sample_name=\"no sample\")\n", + "station = Station()\n", + "rng = np.random.default_rng(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defining a struct type\n", + "\n", + "The struct type is a plain Python `dataclass` (or a Pydantic `BaseModel`) whose fields\n", + "describe the shape of the data returned by the parameter." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class IVResult:\n", + " \"\"\"Result of a simultaneous I-V measurement.\"\"\"\n", + "\n", + " voltage: float\n", + " current: float" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a StructParameter\n", + "\n", + "To create a `StructParameter` you subclass it and implement `get_raw`.\n", + "The `struct_type` argument tells the parameter which dataclass to expect." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "class SimulatedIV(StructParameter[IVResult, None]):\n", + " \"\"\"Simulated I-V measurement that returns voltage and current.\"\"\"\n", + "\n", + " def get_raw(self) -> IVResult:\n", + " # In a real driver this would query the instrument\n", + " v = rng.uniform(0.0, 1.0)\n", + " i = v / 1000.0 + rng.normal(0, 1e-5)\n", + " return IVResult(voltage=v, current=i)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "iv = SimulatedIV(\n", + " \"iv\",\n", + " struct_type=IVResult,\n", + " field_labels={\"voltage\": \"Voltage\", \"current\": \"Current\"},\n", + " field_units={\"voltage\": \"V\", \"current\": \"A\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calling `get()` returns the full dataclass instance:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "IVResult(voltage=0.7739560485559633, current=0.0007635562074935583)\n", + " voltage = 0.7740 V\n", + " current = 0.000764 A\n" + ] + } + ], + "source": [ + "result = iv.get()\n", + "print(result)\n", + "print(f\" voltage = {result.voltage:.4f} V\")\n", + "print(f\" current = {result.current:.6f} A\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inspecting the parameter\n", + "\n", + "The parameter exposes metadata about its fields:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "names: ('voltage', 'current')\n", + "labels: ('Voltage', 'Current')\n", + "units: ('V', 'A')\n", + "full_names: ('iv_voltage', 'iv_current')\n" + ] + } + ], + "source": [ + "print(\"names: \", iv.names)\n", + "print(\"labels:\", iv.labels)\n", + "print(\"units: \", iv.units)\n", + "print(\"full_names:\", iv.full_names)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each field has a synthetic child parameter that carries its own label, unit, and paramtype.\n", + "These are accessible via `field_parameters`:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " voltage: register_name='iv_voltage', label='Voltage', unit='V'\n", + " current: register_name='iv_current', label='Current', unit='A'\n" + ] + } + ], + "source": [ + "for name, fp in iv.field_parameters.items():\n", + " print(\n", + " f\" {name}: register_name={fp.register_name!r}, \"\n", + " f\"label={fp.label!r}, unit={fp.unit!r}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using StructParameter in a Measurement\n", + "\n", + "When you register a `StructParameter` with a `Measurement`, each field is registered\n", + "as a separate dependent parameter. The struct parameter itself is *not* stored in the\n", + "dataset — only the individual fields are.\n", + "\n", + "When adding results, you pass the struct parameter and the dataclass instance.\n", + "The `Measurement` automatically unpacks the struct into its field values." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'gate'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a setpoint parameter\n", + "gate = ManualParameter(\"gate\", label=\"Gate Voltage\", unit=\"V\")\n", + "station.add_component(gate)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting experimental run with id: 2. \n", + "Run ID: 2\n", + "Parameters in dataset: ['gate', 'iv_voltage', 'iv_current']\n" + ] + } + ], + "source": [ + "meas = Measurement(exp=exp, name=\"iv_sweep\")\n", + "meas.register_parameter(gate)\n", + "meas.register_parameter(iv, setpoints=(gate,))\n", + "\n", + "with meas.run() as datasaver:\n", + " for v_gate in np.linspace(0, 1, 20):\n", + " gate.set(v_gate)\n", + " iv_result = iv.get()\n", + " datasaver.add_result(\n", + " (gate, v_gate),\n", + " (iv, iv_result), # type: ignore[arg-type]\n", + " )\n", + "\n", + "dataset = datasaver.dataset\n", + "print(f\"Run ID: {dataset.run_id}\")\n", + "print(f\"Parameters in dataset: {[p.name for p in dataset.get_parameters()]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the dataset contains `gate`, `iv_voltage`, and `iv_current` — the struct\n", + "has been unpacked into individual columns." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iv_currentiv_voltage
gate
0.0000000.0008680.858598
0.0526320.0000810.094177
0.1052630.0007580.761140
0.1578950.0001200.128114
0.2105260.0003790.370798
\n", + "
" + ], + "text/plain": [ + " iv_current iv_voltage\n", + "gate \n", + "0.000000 0.000868 0.858598\n", + "0.052632 0.000081 0.094177\n", + "0.105263 0.000758 0.761140\n", + "0.157895 0.000120 0.128114\n", + "0.210526 0.000379 0.370798" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = dataset.to_pandas_dataframe()\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Customising field labels, units, and paramtypes\n", + "\n", + "You can customise each field's label, unit, and paramtype via\n", + "`field_labels`, `field_units`, and `field_paramtypes` arguments.\n", + "\n", + "Type inference is automatic:\n", + "- `float`, `int`, `bool` → `\"numeric\"`\n", + "- `str` → `\"text\"`\n", + "- `complex` → `\"complex\"`\n", + "- `numpy.ndarray` → `\"array\"`\n", + "\n", + "You can override the inferred type using `field_paramtypes`." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Channel: CH1\n", + "Sample rate: 1000000 Sa/s\n", + "Waveform shape: (1000,)\n" + ] + } + ], + "source": [ + "@dataclass\n", + "class ScopeTrace:\n", + " \"\"\"Oscilloscope trace with metadata.\"\"\"\n", + "\n", + " waveform: np.ndarray\n", + " sample_rate: float\n", + " channel_name: str\n", + "\n", + "\n", + "class SimulatedScope(StructParameter[ScopeTrace, None]):\n", + " def get_raw(self) -> ScopeTrace:\n", + " t = np.linspace(0, 1e-3, 1000)\n", + " signal = np.sin(2 * np.pi * 1000 * t) + rng.normal(0, 0.1, len(t))\n", + " return ScopeTrace(\n", + " waveform=signal,\n", + " sample_rate=1e6,\n", + " channel_name=\"CH1\",\n", + " )\n", + "\n", + "\n", + "scope = SimulatedScope(\n", + " \"scope\",\n", + " struct_type=ScopeTrace,\n", + " field_labels={\n", + " \"waveform\": \"Signal\",\n", + " \"sample_rate\": \"Sample Rate\",\n", + " \"channel_name\": \"Channel\",\n", + " },\n", + " field_units={\"waveform\": \"V\", \"sample_rate\": \"Sa/s\"},\n", + ")\n", + "\n", + "trace = scope.get()\n", + "print(f\"Channel: {trace.channel_name}\")\n", + "print(f\"Sample rate: {trace.sample_rate:.0f} Sa/s\")\n", + "print(f\"Waveform shape: {trace.waveform.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using with a lambda get function\n", + "\n", + "For simple cases where you don't need a subclass, you can pass a `get_cmd` callable\n", + "that returns the struct instance. This works because `StructParameter` inherits\n", + "the `get_cmd` support from `ParameterBase`." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TemperatureReading(temperature=22.56985663453715, humidity=45.10192588737266)\n" + ] + } + ], + "source": [ + "@dataclass\n", + "class TemperatureReading:\n", + " \"\"\"Temperature sensor reading.\"\"\"\n", + "\n", + " temperature: float\n", + " humidity: float\n", + "\n", + "\n", + "def fake_read_sensor() -> TemperatureReading:\n", + " return TemperatureReading(\n", + " temperature=22.5 + rng.normal(0, 0.1),\n", + " humidity=45.0 + rng.normal(0, 1.0),\n", + " )\n", + "\n", + "\n", + "temp_sensor = StructParameter(\n", + " \"environment\",\n", + " struct_type=TemperatureReading,\n", + " get_cmd=fake_read_sensor,\n", + " field_labels={\"temperature\": \"Temperature\", \"humidity\": \"Relative Humidity\"},\n", + " field_units={\"temperature\": \"°C\", \"humidity\": \"%\"},\n", + ")\n", + "\n", + "print(temp_sensor.get())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Snapshot behaviour\n", + "\n", + "By default, `StructParameter` sets `snapshot_value=False` so that the snapshot\n", + "does not call `get()` on the instrument during snapshotting. The snapshot still\n", + "includes metadata like the struct type name, field names, labels, and units." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " name: iv\n", + " struct_type_name: IVResult\n", + " names: ('voltage', 'current')\n", + " labels: ('Voltage', 'Current')\n", + " units: ('V', 'A')\n" + ] + } + ], + "source": [ + "snap = iv.snapshot()\n", + "for key in (\"name\", \"struct_type_name\", \"names\", \"labels\", \"units\"):\n", + " print(f\" {key}: {snap[key]}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/qcodes/__init__.py b/src/qcodes/__init__.py index e6f0b8d4674..13371e2908c 100644 --- a/src/qcodes/__init__.py +++ b/src/qcodes/__init__.py @@ -65,6 +65,7 @@ Parameter, ParameterWithSetpoints, ScaledParameter, + StructParameter, SweepFixedValues, SweepValues, combine, diff --git a/src/qcodes/dataset/measurements.py b/src/qcodes/dataset/measurements.py index cb11a9246af..14b57b0220e 100644 --- a/src/qcodes/dataset/measurements.py +++ b/src/qcodes/dataset/measurements.py @@ -52,6 +52,7 @@ ParameterBase, ParameterWithSetpoints, ParamSpecBase, + StructParameter, ) from qcodes.station import Station from qcodes.utils import DelayedKeyboardInterrupt @@ -1083,6 +1084,8 @@ def register_parameter( basis, paramtype, ) + case StructParameter(): + self._register_structparameter(parameter, setpoints, basis) case ParameterBase() | ParameterWithSetpoints(): if paramtype is not None: parameter.paramtype = paramtype @@ -1360,6 +1363,30 @@ def _register_multiparameter( paramtype, ) + def _register_structparameter( + self, + parameter: StructParameter, + setpoints: SetpointsType | None, + basis: SetpointsType | None, + ) -> None: + """Register each field of a StructParameter as a separate dataset column. + + The struct parameter itself is not stored in the dataset. Only its + individual field parameters are registered as dependents on the + supplied setpoints. + """ + for field_param in parameter.field_parameters.values(): + field_paramtype = field_param.paramtype + self._register_parameter( + name=field_param.register_name, + label=field_param.label, + unit=field_param.unit, + setpoints=setpoints, + basis=basis, + paramtype=field_paramtype, + ) + self._registered_parameters.add(field_param) + def register_custom_parameter( self: Self, name: str, diff --git a/src/qcodes/parameters/__init__.py b/src/qcodes/parameters/__init__.py index 4ff68a69fac..a95a31780d8 100644 --- a/src/qcodes/parameters/__init__.py +++ b/src/qcodes/parameters/__init__.py @@ -89,6 +89,7 @@ from .parameter_with_setpoints import ParameterWithSetpoints, expand_setpoints_helper from .scaled_paramter import ScaledParameter from .specialized_parameters import ElapsedTimeParameter, InstrumentRefParameter +from .struct_parameter import StructParameter from .sweep_values import SweepFixedValues, SweepValues from .val_mapping import create_on_off_val_mapping @@ -118,6 +119,7 @@ "ParameterSet", "ParameterWithSetpoints", "ScaledParameter", + "StructParameter", "SweepFixedValues", "SweepValues", "combine", diff --git a/src/qcodes/parameters/struct_parameter.py b/src/qcodes/parameters/struct_parameter.py new file mode 100644 index 00000000000..0a4767ac02a --- /dev/null +++ b/src/qcodes/parameters/struct_parameter.py @@ -0,0 +1,395 @@ +"""Parameter that returns structured data as a dataclass or Pydantic BaseModel. + +A :class:`StructParameter` wraps a ``get_raw`` that returns a dataclass or +Pydantic v2 :class:`~pydantic.BaseModel` instance. Each field of the struct +is automatically unpacked into a separate dataset column when used in a +:class:`~qcodes.dataset.Measurement`. + +Pydantic support is optional and is enabled when ``pydantic`` is installed. +Only Pydantic v2 (``pydantic.BaseModel`` with ``model_fields``) is supported. +""" + +from __future__ import annotations + +import dataclasses +import os +import typing +from types import MethodType +from typing import TYPE_CHECKING, Any, Generic + +import numpy as np + +from qcodes.validators import Arrays, ComplexNumbers, Numbers, Strings + +from .parameter_base import ( + InstrumentTypeVar_co, + ParameterBase, + ParameterBaseKWArgs, + ParameterDataTypeVar, +) + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + from typing_extensions import Unpack + + from qcodes.dataset.data_set_protocol import ValuesType + from qcodes.validators import Validator + + +# Supported paramtypes for dataset columns +_ALLOWED_PARAMTYPES = frozenset({"numeric", "text", "complex", "array"}) + + +def _is_pydantic_model_class(cls: type) -> bool: + """Check if a class is a Pydantic v2 BaseModel subclass.""" + try: + from pydantic import ( # noqa: PLC0415 # pyright: ignore[reportMissingImports] + BaseModel, + ) + + return isinstance(cls, type) and issubclass(cls, BaseModel) + except ImportError: + return False + + +def _get_struct_fields(struct_type: type) -> list[tuple[str, type]]: + """Extract (name, annotation) pairs from a dataclass or Pydantic model. + + Args: + struct_type: A dataclass or Pydantic v2 BaseModel class. + + Returns: + List of ``(field_name, field_type)`` tuples. + + Raises: + TypeError: If ``struct_type`` is neither a dataclass nor a Pydantic + BaseModel. + + """ + if dataclasses.is_dataclass(struct_type): + # Use get_type_hints to resolve string annotations from + # `from __future__ import annotations` + try: + hints = typing.get_type_hints(struct_type) + except Exception: + hints = {} + return [ + (f.name, hints.get(f.name, object)) + for f in dataclasses.fields(struct_type) # type: ignore[arg-type] + ] + + if _is_pydantic_model_class(struct_type): + # Pydantic v2 resolves annotations itself, so model_fields + # already has the resolved type in field.annotation + fields = struct_type.model_fields # type: ignore[union-attr] + result: list[tuple[str, type]] = [] + for name, field in fields.items(): + ann: type = ( + field.annotation if isinstance(field.annotation, type) else object + ) + result.append((name, ann)) + return result + + raise TypeError( + f"struct_type must be a dataclass or Pydantic v2 BaseModel, got {struct_type!r}" + ) + + +def _infer_paramtype_from_annotation(annotation: type) -> str: + """Map a Python type annotation to a QCoDeS paramtype string. + + Args: + annotation: The type annotation of a struct field. + + Returns: + One of ``"numeric"``, ``"text"``, ``"complex"``, or ``"array"``. + + Raises: + TypeError: If the annotation maps to an unsupported or nested type. + + """ + # Handle basic types + if annotation in (float, int, bool): + return "numeric" + if annotation is str: + return "text" + if annotation is complex: + return "complex" + if annotation is np.ndarray: + return "array" + + # Reject nested dataclasses and Pydantic models + if dataclasses.is_dataclass(annotation) or _is_pydantic_model_class(annotation): + raise TypeError( + f"Nested structured types are not supported as struct fields: " + f"{annotation!r}" + ) + + # Default to numeric for unknown types (int subclasses, enums, etc.) + return "numeric" + + +def _validator_for_paramtype(paramtype: str) -> Validator[Any]: + """Create a QCoDeS validator matching the given paramtype.""" + match paramtype: + case "numeric": + return Numbers() + case "text": + return Strings() + case "complex": + return ComplexNumbers() + case "array": + return Arrays() + case _: + raise ValueError(f"Unknown paramtype: {paramtype!r}") + + +def _extract_field_value(struct_instance: Any, field_name: str) -> Any: + """Extract a field value from a dataclass or Pydantic model instance.""" + return getattr(struct_instance, field_name) + + +class _FieldParameter(ParameterBase[Any, None]): + """Synthetic parameter representing a single field of a StructParameter. + + These parameters are not independently gettable or settable. They exist + solely for dataset registration and field-value storage during unpacking. + """ + + def __init__( + self, + name: str, + *, + label: str | None = None, + unit: str | None = None, + paramtype: str = "numeric", + ) -> None: + super().__init__( + name, + bind_to_instrument=False, + snapshot_value=False, + ) + self.label = label or name + self.unit = unit or "" + self._set_paramtype(paramtype) + + +class StructParameter( + ParameterBase[ParameterDataTypeVar, InstrumentTypeVar_co], + Generic[ParameterDataTypeVar, InstrumentTypeVar_co], +): + """A gettable parameter that returns a dataclass or Pydantic BaseModel. + + When used in a :class:`~qcodes.dataset.Measurement`, each field of the + struct is automatically unpacked into a separate dataset column. + + Subclasses should define a :meth:`get_raw` method that returns an instance + of ``struct_type``. + + Args: + name: The local name of the parameter. Must be a valid identifier. + struct_type: A dataclass or Pydantic v2 BaseModel class whose fields + define the structure of the returned data. + field_labels: Optional mapping of ``{field_name: label}`` for + dataset/graph axis labels. Defaults to field names. + field_units: Optional mapping of ``{field_name: unit}`` for + dataset/graph axis units. Defaults to empty strings. + field_paramtypes: Optional mapping of ``{field_name: paramtype}`` + to override the auto-inferred paramtype. Valid values are + ``"numeric"``, ``"text"``, ``"complex"``, and ``"array"``. + get_cmd: A callable with zero arguments that returns an instance + of ``struct_type``. If ``None`` (the default), the subclass must + implement :meth:`get_raw`. + docstring: Documentation string for the ``__doc__`` field. + **kwargs: Forwarded to :class:`ParameterBase`. + See :class:`ParameterBaseKWArgs` for details. + + Example: + + .. code-block:: python + + from dataclasses import dataclass + from qcodes.parameters import StructParameter + + @dataclass + class IVResult: + voltage: float + current: float + + class MyIVParameter(StructParameter): + def get_raw(self): + v = self.instrument.ask("MEAS:VOLT?") + i = self.instrument.ask("MEAS:CURR?") + return IVResult(voltage=float(v), current=float(i)) + + param = MyIVParameter( + "iv_measurement", + struct_type=IVResult, + field_units={"voltage": "V", "current": "A"}, + ) + + """ + + def __init__( + self, + name: str, + struct_type: type, + *, + get_cmd: Callable[[], Any] | None = None, + field_labels: Mapping[str, str] | None = None, + field_units: Mapping[str, str] | None = None, + field_paramtypes: Mapping[str, str] | None = None, + docstring: str | None = None, + **kwargs: Unpack[ + ParameterBaseKWArgs[ParameterDataTypeVar, InstrumentTypeVar_co] + ], + ) -> None: + kwargs.setdefault("snapshot_value", False) + super().__init__(name, **kwargs) + + # Wire up get_cmd as get_raw if provided + if get_cmd is not None: + if self._implements_get_raw: + raise TypeError( + "Supplying get_cmd to a StructParameter that already " + "implements get_raw is an error." + ) + + def _get_from_cmd(self: StructParameter) -> Any: # type: ignore[type-arg] + return get_cmd() + + self.get_raw = MethodType(_get_from_cmd, self) # type: ignore[method-assign] + self._gettable = True + self.get = self._wrap_get(self.get_raw) + + self._struct_type = struct_type + field_labels = field_labels or {} + field_units = field_units or {} + field_paramtypes = field_paramtypes or {} + + # Introspect the struct type + fields = _get_struct_fields(struct_type) + if not fields: + raise TypeError(f"struct_type {struct_type.__name__} has no fields") + + # Validate user-supplied overrides reference real fields + field_name_set = {f[0] for f in fields} + for mapping_name, mapping in [ + ("field_labels", field_labels), + ("field_units", field_units), + ("field_paramtypes", field_paramtypes), + ]: + unknown = set(mapping.keys()) - field_name_set + if unknown: + raise ValueError( + f"{mapping_name} contains unknown field names: {unknown}" + ) + + for pt_name, pt_val in field_paramtypes.items(): + if pt_val not in _ALLOWED_PARAMTYPES: + raise ValueError( + f"Invalid paramtype {pt_val!r} for field {pt_name!r}. " + f"Allowed values: {sorted(_ALLOWED_PARAMTYPES)}" + ) + + # Build child parameters for each field + self._field_parameters: dict[str, _FieldParameter] = {} + names_list: list[str] = [] + labels_list: list[str] = [] + units_list: list[str] = [] + + for field_name, field_annotation in fields: + paramtype = field_paramtypes.get( + field_name, + _infer_paramtype_from_annotation(field_annotation), + ) + label = field_labels.get(field_name, field_name) + unit = field_units.get(field_name, "") + child_name = f"{name}_{field_name}" + + child_param: _FieldParameter = _FieldParameter( + name=child_name, + label=label, + unit=unit, + paramtype=paramtype, + ) + child_param.vals = _validator_for_paramtype(paramtype) + self._field_parameters[field_name] = child_param + names_list.append(field_name) + labels_list.append(label) + units_list.append(unit) + + self.names: tuple[str, ...] = tuple(names_list) + self.labels: tuple[str, ...] = tuple(labels_list) + self.units: tuple[str, ...] = tuple(units_list) + + self._meta_attrs.extend(["names", "labels", "units", "struct_type_name"]) + + # Generate docstring + self.__doc__ = os.linesep.join( + ( + "StructParameter class:", + "", + f"* `name` {self.name}", + f"* `struct_type` {struct_type.__name__}", + "* `names` {}".format(", ".join(self.names)), + "* `labels` {}".format(", ".join(self.labels)), + "* `units` {}".format(", ".join(self.units)), + ) + ) + if docstring is not None: + self.__doc__ = os.linesep.join((docstring, "", self.__doc__)) + + if not self.gettable: + raise AttributeError("StructParameter must have a get method") + + @property + def struct_type(self) -> type: + """The dataclass or Pydantic BaseModel class for this parameter.""" + return self._struct_type + + @property + def struct_type_name(self) -> str: + """Name of the struct type, included in snapshots.""" + return self._struct_type.__name__ + + @property + def field_parameters(self) -> dict[str, _FieldParameter]: + """Mapping of field name to the synthetic child parameter.""" + return dict(self._field_parameters) + + @property + def short_names(self) -> tuple[str, ...]: + """Short names of the struct fields (without instrument prefix).""" + return self.names + + @property + def full_names(self) -> tuple[str, ...]: + """Full names of fields including instrument name prefix.""" + inst_name = "_".join(self.name_parts[:-1]) + if inst_name: + return tuple(f"{inst_name}_{self.name}_{n}" for n in self.names) + return tuple(f"{self.name}_{n}" for n in self.names) + + def unpack_self( + self, value: ValuesType + ) -> list[tuple[ParameterBase[Any, Any], ValuesType]]: + """Unpack a struct value into individual field parameter results. + + This method does NOT include the parent struct parameter itself in the + results. Only the individual field values are returned, each paired + with its corresponding synthetic child parameter. + + Args: + value: An instance of the struct type returned by ``get_raw``. + + Returns: + A list of ``(field_parameter, field_value)`` tuples. + + """ + results: list[tuple[ParameterBase[Any, Any], ValuesType]] = [] + for field_name, field_param in self._field_parameters.items(): + field_value = _extract_field_value(value, field_name) + results.append((field_param, field_value)) + return results diff --git a/tests/parameter/test_struct_parameter.py b/tests/parameter/test_struct_parameter.py new file mode 100644 index 00000000000..c7dec75483f --- /dev/null +++ b/tests/parameter/test_struct_parameter.py @@ -0,0 +1,576 @@ +"""Tests for StructParameter.""" + +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +from qcodes.dataset import Measurement +from qcodes.parameters import ManualParameter +from qcodes.parameters.struct_parameter import ( + StructParameter, + _extract_field_value, + _FieldParameter, + _get_struct_fields, + _infer_paramtype_from_annotation, + _is_pydantic_model_class, +) + +if TYPE_CHECKING: + from qcodes.parameters.parameter_base import ParamRawDataType + + +# --- Test struct types --- + + +@dataclasses.dataclass +class SimpleResult: + voltage: float + current: float + + +@dataclasses.dataclass +class MixedResult: + name: str + value: float + count: int + flag: bool + + +@dataclasses.dataclass +class ComplexFieldResult: + impedance: complex + signal: float + + +@dataclasses.dataclass +class ArrayFieldResult: + trace: np.ndarray + amplitude: float + + +@dataclasses.dataclass +class EmptyStruct: + pass + + +# --- Concrete StructParameter subclasses for testing --- + + +class SimpleStructParam(StructParameter[SimpleResult, None]): + def __init__(self, name: str, result: SimpleResult, **kwargs: Any) -> None: + self._result = result + super().__init__(name, struct_type=SimpleResult, **kwargs) + + def get_raw(self) -> ParamRawDataType: + return self._result + + +class MixedStructParam(StructParameter[MixedResult, None]): + def __init__(self, name: str, result: MixedResult, **kwargs: Any) -> None: + self._result = result + super().__init__(name, struct_type=MixedResult, **kwargs) + + def get_raw(self) -> ParamRawDataType: + return self._result + + +class ComplexStructParam(StructParameter[ComplexFieldResult, None]): + def __init__(self, name: str, result: ComplexFieldResult, **kwargs: Any) -> None: + self._result = result + super().__init__(name, struct_type=ComplexFieldResult, **kwargs) + + def get_raw(self) -> ParamRawDataType: + return self._result + + +# --- Helper function tests --- + + +class TestIsPydanticModelClass: + def test_dataclass_is_not_pydantic(self) -> None: + assert not _is_pydantic_model_class(SimpleResult) + + def test_regular_class_is_not_pydantic(self) -> None: + assert not _is_pydantic_model_class(int) + + def test_non_type_is_not_pydantic(self) -> None: + assert not _is_pydantic_model_class(42) # type: ignore[arg-type] + + +class TestGetStructFields: + def test_dataclass_fields(self) -> None: + fields = _get_struct_fields(SimpleResult) + assert len(fields) == 2 + assert fields[0] == ("voltage", float) + assert fields[1] == ("current", float) + + def test_mixed_dataclass_fields(self) -> None: + fields = _get_struct_fields(MixedResult) + assert len(fields) == 4 + names = [f[0] for f in fields] + assert names == ["name", "value", "count", "flag"] + + def test_non_struct_raises(self) -> None: + with pytest.raises(TypeError, match="must be a dataclass"): + _get_struct_fields(int) + + def test_regular_class_raises(self) -> None: + class NotAStruct: + x: int = 5 + + with pytest.raises(TypeError, match="must be a dataclass"): + _get_struct_fields(NotAStruct) + + +class TestInferParamtype: + def test_float(self) -> None: + assert _infer_paramtype_from_annotation(float) == "numeric" + + def test_int(self) -> None: + assert _infer_paramtype_from_annotation(int) == "numeric" + + def test_bool(self) -> None: + assert _infer_paramtype_from_annotation(bool) == "numeric" + + def test_str(self) -> None: + assert _infer_paramtype_from_annotation(str) == "text" + + def test_complex(self) -> None: + assert _infer_paramtype_from_annotation(complex) == "complex" + + def test_ndarray(self) -> None: + assert _infer_paramtype_from_annotation(np.ndarray) == "array" + + def test_nested_dataclass_raises(self) -> None: + with pytest.raises(TypeError, match="Nested structured types"): + _infer_paramtype_from_annotation(SimpleResult) + + def test_unknown_defaults_to_numeric(self) -> None: + assert _infer_paramtype_from_annotation(bytes) == "numeric" + + +class TestExtractFieldValue: + def test_dataclass(self) -> None: + result = SimpleResult(voltage=1.5, current=0.3) + assert _extract_field_value(result, "voltage") == 1.5 + assert _extract_field_value(result, "current") == 0.3 + + +# --- FieldParameter tests --- + + +class TestFieldParameter: + def test_basic_creation(self) -> None: + fp = _FieldParameter("test_field", label="Test", unit="V") + assert fp.name == "test_field" + assert fp.label == "Test" + assert fp.unit == "V" + assert not fp.gettable + assert not fp.settable + + def test_default_label_and_unit(self) -> None: + fp = _FieldParameter("my_field") + assert fp.label == "my_field" + assert fp.unit == "" + + def test_paramtype(self) -> None: + fp = _FieldParameter("f", paramtype="text") + assert fp.paramtype == "text" + + def test_not_bound_to_instrument(self) -> None: + fp = _FieldParameter("f") + assert fp.instrument is None + + def test_snapshot_excluded(self) -> None: + fp = _FieldParameter("f") + assert fp.snapshot_exclude is False + assert fp.snapshot_value is False + + +# --- StructParameter tests --- + + +class TestStructParameterInit: + def test_basic_creation(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam("iv", result=result) + assert param.name == "iv" + assert param.struct_type is SimpleResult + assert param.struct_type_name == "SimpleResult" + assert param.names == ("voltage", "current") + assert param.labels == ("voltage", "current") + assert param.units == ("", "") + assert param.gettable + + def test_custom_labels_and_units(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam( + "iv", + result=result, + field_labels={"voltage": "Voltage", "current": "Current"}, + field_units={"voltage": "V", "current": "A"}, + ) + assert param.labels == ("Voltage", "Current") + assert param.units == ("V", "A") + + def test_custom_paramtypes(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam( + "iv", + result=result, + field_paramtypes={"voltage": "text"}, + ) + field_params = param.field_parameters + assert field_params["voltage"].paramtype == "text" + assert field_params["current"].paramtype == "numeric" + + def test_snapshot_value_defaults_to_false(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam("iv", result=result) + assert param.snapshot_value is False + + def test_snapshot_value_can_be_overridden(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam("iv", result=result, snapshot_value=True) + assert param.snapshot_value is True + + def test_empty_struct_raises(self) -> None: + with pytest.raises(TypeError, match="has no fields"): + + class EmptyStructParam(StructParameter[EmptyStruct, None]): + def get_raw(self) -> ParamRawDataType: + return EmptyStruct() + + EmptyStructParam("empty", struct_type=EmptyStruct) + + def test_non_struct_type_raises(self) -> None: + with pytest.raises(TypeError, match="must be a dataclass"): + + class BadStructParam(StructParameter[int, None]): + def get_raw(self) -> ParamRawDataType: + return 42 + + BadStructParam("bad", struct_type=int) + + def test_unknown_field_label_raises(self) -> None: + with pytest.raises(ValueError, match="unknown field names"): + SimpleStructParam( + "iv", + result=SimpleResult(1.0, 0.5), + field_labels={"nonexistent": "Nope"}, + ) + + def test_unknown_field_unit_raises(self) -> None: + with pytest.raises(ValueError, match="unknown field names"): + SimpleStructParam( + "iv", + result=SimpleResult(1.0, 0.5), + field_units={"nonexistent": "X"}, + ) + + def test_invalid_paramtype_raises(self) -> None: + with pytest.raises(ValueError, match="Invalid paramtype"): + SimpleStructParam( + "iv", + result=SimpleResult(1.0, 0.5), + field_paramtypes={"voltage": "invalid"}, + ) + + def test_nested_dataclass_field_raises(self) -> None: + @dataclasses.dataclass + class Outer: + inner: SimpleResult + + with pytest.raises(TypeError, match="Nested structured types"): + + class NestedStructParam(StructParameter[Outer, None]): + def get_raw(self) -> ParamRawDataType: + return Outer(inner=SimpleResult(1.0, 0.5)) + + NestedStructParam("nested", struct_type=Outer) + + +class TestStructParameterFieldParameters: + def test_field_parameters_dict(self) -> None: + param = SimpleStructParam("iv", result=SimpleResult(1.0, 0.5)) + fps = param.field_parameters + assert set(fps.keys()) == {"voltage", "current"} + assert isinstance(fps["voltage"], _FieldParameter) + assert isinstance(fps["current"], _FieldParameter) + + def test_field_param_names(self) -> None: + param = SimpleStructParam("iv", result=SimpleResult(1.0, 0.5)) + fps = param.field_parameters + assert fps["voltage"].name == "iv_voltage" + assert fps["current"].name == "iv_current" + + def test_field_param_returns_copy(self) -> None: + param = SimpleStructParam("iv", result=SimpleResult(1.0, 0.5)) + fps1 = param.field_parameters + fps2 = param.field_parameters + assert fps1 is not fps2 + assert fps1.keys() == fps2.keys() + + +class TestStructParameterNames: + def test_short_names(self) -> None: + param = SimpleStructParam("iv", result=SimpleResult(1.0, 0.5)) + assert param.short_names == ("voltage", "current") + + def test_full_names_no_instrument(self) -> None: + param = SimpleStructParam("iv", result=SimpleResult(1.0, 0.5)) + assert param.full_names == ("iv_voltage", "iv_current") + + +class TestStructParameterGet: + def test_get_returns_struct(self) -> None: + result = SimpleResult(voltage=1.5, current=0.3) + param = SimpleStructParam("iv", result=result) + got = param.get() + assert got == result + + def test_call_returns_struct(self) -> None: + result = SimpleResult(voltage=1.5, current=0.3) + param = SimpleStructParam("iv", result=result) + got = param() + assert got == result + + def test_mixed_types(self) -> None: + result = MixedResult(name="test", value=1.5, count=42, flag=True) + param = MixedStructParam("mixed", result=result) + got = param.get() + assert got.name == "test" + assert got.value == 1.5 + assert got.count == 42 + assert got.flag is True + + def test_get_cmd_callable(self) -> None: + expected = SimpleResult(voltage=2.0, current=0.5) + param = StructParameter( + "iv", + struct_type=SimpleResult, + get_cmd=lambda: expected, + ) + got = param.get() + assert got == expected + + def test_get_cmd_with_labels_and_units(self) -> None: + expected = SimpleResult(voltage=3.0, current=1.0) + param = StructParameter( + "iv", + struct_type=SimpleResult, + get_cmd=lambda: expected, + field_labels={"voltage": "V_out"}, + field_units={"current": "mA"}, + ) + assert param.get() == expected + assert param.field_parameters["voltage"].label == "V_out" + assert param.field_parameters["current"].unit == "mA" + + def test_get_cmd_with_subclass_get_raw_raises(self) -> None: + with pytest.raises( + TypeError, + match="Supplying get_cmd to a StructParameter that already implements get_raw", + ): + SimpleStructParam( + "iv", + result=SimpleResult(voltage=1.0, current=0.1), + get_cmd=lambda: SimpleResult(voltage=2.0, current=0.2), # type: ignore[call-arg] + ) + + +class TestStructParameterUnpackSelf: + def test_unpack_simple(self) -> None: + result = SimpleResult(voltage=1.5, current=0.3) + param = SimpleStructParam("iv", result=result) + unpacked = param.unpack_self(result) # type: ignore[arg-type] + assert len(unpacked) == 2 + # Check that we get the field parameters with the right values + param_names = [p.name for p, _ in unpacked] + values = [v for _, v in unpacked] + assert "iv_voltage" in param_names + assert "iv_current" in param_names + assert 1.5 in values + assert 0.3 in values + + def test_unpack_does_not_include_self(self) -> None: + result = SimpleResult(voltage=1.5, current=0.3) + param = SimpleStructParam("iv", result=result) + unpacked = param.unpack_self(result) # type: ignore[arg-type] + # None of the unpacked parameters should be the struct parameter itself + for p, _ in unpacked: + assert p is not param + + def test_unpack_mixed_types(self) -> None: + result = MixedResult(name="hello", value=3.14, count=7, flag=False) + param = MixedStructParam("data", result=result) + unpacked = param.unpack_self(result) # type: ignore[arg-type] + assert len(unpacked) == 4 + values_dict = {p.name: v for p, v in unpacked} + assert values_dict["data_name"] == "hello" + assert values_dict["data_value"] == 3.14 + assert values_dict["data_count"] == 7 + assert values_dict["data_flag"] is False + + +class TestStructParameterSnapshot: + def test_snapshot_no_value_by_default(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam("iv", result=result) + snap = param.snapshot() + assert "value" not in snap + assert "raw_value" not in snap + + def test_snapshot_includes_struct_metadata(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam("iv", result=result) + snap = param.snapshot() + assert snap["struct_type_name"] == "SimpleResult" + assert snap["names"] == ("voltage", "current") + + def test_snapshot_with_value(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam("iv", result=result, snapshot_value=True) + snap = param.snapshot(update=True) + assert "value" in snap + + +class TestStructParameterDocstring: + def test_docstring_generated(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam("iv", result=result) + assert param.__doc__ is not None + assert "iv" in param.__doc__ + assert "SimpleResult" in param.__doc__ + + def test_custom_docstring(self) -> None: + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam("iv", result=result, docstring="Custom docs") + assert param.__doc__ is not None + assert param.__doc__.startswith("Custom docs") + + +# --- Measurement integration tests --- + + +class TestStructParameterMeasurement: + def test_register_struct_parameter(self, experiment: Any) -> None: + + setpoint = ManualParameter("x") + result = SimpleResult(voltage=1.0, current=0.5) + param = SimpleStructParam("iv", result=result) + + meas = Measurement(experiment) + meas.register_parameter(setpoint) + meas.register_parameter(param, setpoints=[setpoint]) + + # Field parameters should be registered + interdeps = meas._interdeps + param_names = {ps.name for ps in interdeps.dependencies.keys()} + assert "iv_voltage" in param_names + assert "iv_current" in param_names + + def test_add_result_with_struct(self, experiment: Any) -> None: + + setpoint = ManualParameter("x") + result = SimpleResult(voltage=1.5, current=0.3) + param = SimpleStructParam("iv", result=result) + + meas = Measurement(experiment) + meas.register_parameter(setpoint) + meas.register_parameter(param, setpoints=[setpoint]) + + with meas.run() as datasaver: + for x_val in [0.0, 1.0, 2.0]: + setpoint(x_val) + struct_val = SimpleResult(voltage=x_val * 2, current=x_val * 0.1) + datasaver.add_result( + (setpoint, x_val), + (param, struct_val), # type: ignore[arg-type] + ) + ds = datasaver.dataset + + data = ds.get_parameter_data() + # Check that both field columns exist + assert "iv_voltage" in data + assert "iv_current" in data + + # Check the data values + voltage_data = data["iv_voltage"] + assert "iv_voltage" in voltage_data + assert "x" in voltage_data + np.testing.assert_array_almost_equal( + voltage_data["iv_voltage"], [0.0, 2.0, 4.0] + ) + np.testing.assert_array_almost_equal(voltage_data["x"], [0.0, 1.0, 2.0]) + + current_data = data["iv_current"] + assert "iv_current" in current_data + np.testing.assert_array_almost_equal( + current_data["iv_current"], [0.0, 0.1, 0.2] + ) + + def test_add_result_with_get(self, experiment: Any) -> None: + """Test using param.get() and passing the struct value.""" + + setpoint = ManualParameter("x") + result = SimpleResult(voltage=3.0, current=1.5) + param = SimpleStructParam("iv", result=result) + + meas = Measurement(experiment) + meas.register_parameter(setpoint) + meas.register_parameter(param, setpoints=[setpoint]) + + with meas.run() as datasaver: + setpoint(0.0) + val = param.get() + datasaver.add_result( + (setpoint, 0.0), + (param, val), # type: ignore[arg-type] + ) + ds = datasaver.dataset + + data = ds.get_parameter_data() + assert "iv_voltage" in data + voltage_data = data["iv_voltage"] + np.testing.assert_array_almost_equal(voltage_data["iv_voltage"], [3.0]) + + def test_mixed_types_measurement(self, experiment: Any) -> None: + """Test struct with mixed field types in a measurement.""" + + setpoint = ManualParameter("x") + + @dataclasses.dataclass + class TextNumResult: + label: str + value: float + + class TextNumParam(StructParameter[TextNumResult, None]): + def get_raw(self) -> ParamRawDataType: + return TextNumResult(label="test", value=42.0) + + param = TextNumParam( + "tn", + struct_type=TextNumResult, + field_paramtypes={"label": "text"}, + ) + + meas = Measurement(experiment) + meas.register_parameter(setpoint) + meas.register_parameter(param, setpoints=[setpoint]) + + with meas.run() as datasaver: + setpoint(1.0) + datasaver.add_result( + (setpoint, 1.0), + (param, TextNumResult(label="hello", value=99.0)), # type: ignore[arg-type] + ) + ds = datasaver.dataset + + data = ds.get_parameter_data() + assert "tn_label" in data + assert "tn_value" in data