diff --git a/CHANGELOG.md b/CHANGELOG.md index 89e8171a..82d58611 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +* Changed `compas_model.models.Model.__from_data__` return type to `Self` so subclasses retain their type when deserialized. + ### Removed diff --git a/pyproject.toml b/pyproject.toml index 9661eca1..aa1f1205 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] -dependencies = ["compas", "shapely"] +dependencies = ["compas", "shapely", "typing_extensions"] [project.optional-dependencies] dev = [ @@ -35,6 +35,7 @@ dev = [ "bump-my-version", "compas_invocations2", "invoke >=0.14", + "pyright", "pytest", "pytest-dependency", "ruff", diff --git a/src/compas_model/models/model.py b/src/compas_model/models/model.py index 9a1dab50..3cdea1f2 100644 --- a/src/compas_model/models/model.py +++ b/src/compas_model/models/model.py @@ -4,6 +4,8 @@ from typing import TypeVar from typing import Union +from typing_extensions import Self + from compas.datastructures import Datastructure from compas.geometry import Point from compas.geometry import Transformation @@ -59,7 +61,7 @@ def __data__(self) -> dict: return data @classmethod - def __from_data__(cls, data: dict) -> "Model": + def __from_data__(cls, data: dict) -> "Self": model = cls() model._transformation = data["transformation"] diff --git a/tests/test_model.py b/tests/test_model.py index cf5bbf56..c8e57764 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -45,8 +45,79 @@ # assert c_model.tree is not None # assert len(c_model.tree.elements) == 3 +import subprocess +from pathlib import Path + from compas_model.models import Model # noqa: F401 def test_import(): assert True + + +def test_from_data_roundtrip_preserves_subclass_behavior(): + class MyModel(Model): + pass + + model = MyModel() + data = model.__data__ + restored = MyModel.__from_data__(data) + + assert isinstance(restored, MyModel) + assert restored.__data__ == data + + +def test_self_return_type_with_pyright(tmp_path: Path): + test_file = tmp_path / "typing_case.py" + test_file.write_text( + """ +from typing import assert_type + +from compas_model.models import Model + + +class MyModel(Model): + pass + + +obj = MyModel.__from_data__(MyModel().__data__) +assert_type(obj, MyModel) +""" + ) + + result = subprocess.run( + ["pyright", str(test_file)], + text=True, + capture_output=True, + cwd=Path(__file__).parents[1], + ) + + assert result.returncode == 0, result.stdout + result.stderr + + +def test_from_data_with_overridden_subclass(): + class MyModel(Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._extra = None + + @property + def __data__(self): + data = super().__data__ + data["extra"] = self._extra + return data + + @classmethod + def __from_data__(cls, data): + model = super().__from_data__(data) + model._extra = data.get("extra") + return model + + model = MyModel() + model._extra = "hello" + data = model.__data__ + + restored = MyModel.__from_data__(data) + + assert isinstance(restored, MyModel) + assert restored._extra == "hello"