diff --git a/pylasu/model/model.py b/pylasu/model/model.py index 3ea9010..5bc3429 100644 --- a/pylasu/model/model.py +++ b/pylasu/model/model.py @@ -3,6 +3,7 @@ from dataclasses import Field, MISSING, dataclass, field from typing import Optional, Callable, List, Union +from .naming import ReferenceByName from .position import Position, Source from .reflection import Multiplicity, PropertyDescription from ..reflection import getannotations, get_type_arguments, is_sequence_type @@ -95,13 +96,57 @@ def provides_nodes(decl_type): return isinstance(decl_type, type) and issubclass(decl_type, Node) +def get_only_type_arg(decl_type): + """If decl_type has a single type argument, return it, otherwise return None""" + type_args = get_type_arguments(decl_type) + if len(type_args) == 1: + return type_args[0] + else: + return None + + +def process_annotated_property(name, decl_type, known_property_names): + multiplicity = Multiplicity.SINGULAR + is_reference = False + if get_type_origin(decl_type) is ReferenceByName: + decl_type = get_only_type_arg(decl_type) or decl_type + is_reference = True + if is_sequence_type(decl_type): + decl_type = get_only_type_arg(decl_type) or decl_type + multiplicity = Multiplicity.MANY + if get_type_origin(decl_type) is Union: + type_args = get_type_arguments(decl_type) + if len(type_args) == 1: + decl_type = type_args[0] + elif len(type_args) == 2: + if type_args[0] is type(None): + decl_type = type_args[1] + elif type_args[1] is type(None): + decl_type = type_args[0] + else: + raise Exception(f"Unsupported feature {name} of type {decl_type}") + if multiplicity == Multiplicity.SINGULAR: + multiplicity = Multiplicity.OPTIONAL + else: + raise Exception(f"Unsupported feature {name} of type {decl_type}") + if not isinstance(decl_type, type): + raise Exception(f"Unsupported feature {name} of type {decl_type}") + is_containment = provides_nodes(decl_type) and not is_reference + known_property_names.add(name) + return PropertyDescription(name, decl_type, is_containment, is_reference, multiplicity) + + class Concept(ABCMeta): def __init__(cls, what, bases=None, dict=None): super().__init__(what, bases, dict) - cls.__internal_properties__ = \ - (["origin", "destination", "parent", "position", "position_override"] - + [n for n, v in inspect.getmembers(cls, is_internal_property_or_method)]) + cls.__internal_properties__ = [] + for base in bases: + if hasattr(base, "__internal_properties__"): + cls.__internal_properties__.extend(base.__internal_properties__) + if not cls.__internal_properties__: + cls.__internal_properties__ = ["origin", "destination", "parent", "position", "position_override"] + cls.__internal_properties__.extend([n for n, v in inspect.getmembers(cls, is_internal_property_or_method)]) @property def node_properties(cls): @@ -115,23 +160,11 @@ def _direct_node_properties(cls, cl, known_property_names): return for name in anns: if name not in known_property_names and cls.is_node_property(name): - is_child_property = False - multiplicity = Multiplicity.SINGULAR - if name in anns: - decl_type = anns[name] - if is_sequence_type(decl_type): - multiplicity = Multiplicity.MANY - type_args = get_type_arguments(decl_type) - if len(type_args) == 1: - is_child_property = provides_nodes(type_args[0]) - else: - is_child_property = provides_nodes(decl_type) - known_property_names.add(name) - yield PropertyDescription(name, is_child_property, multiplicity) + yield process_annotated_property(name, anns[name], known_property_names) for name in dir(cl): if name not in known_property_names and cls.is_node_property(name): known_property_names.add(name) - yield PropertyDescription(name, False) + yield PropertyDescription(name, None, False, False) def is_node_property(cls, name): return not name.startswith('_') and name not in cls.__internal_properties__ @@ -180,7 +213,9 @@ def source(self) -> Optional[Source]: @internal_property def properties(self): - return (PropertyDescription(p.name, p.provides_nodes, p.multiplicity, getattr(self, p.name)) + return (PropertyDescription(p.name, p.type, + is_containment=p.is_containment, is_reference=p.is_reference, + multiplicity=p.multiplicity, value=getattr(self, p.name)) for p in self.__class__.node_properties) @internal_property diff --git a/pylasu/model/reflection.py b/pylasu/model/reflection.py index f9d7def..6c5811e 100644 --- a/pylasu/model/reflection.py +++ b/pylasu/model/reflection.py @@ -1,5 +1,6 @@ import enum from dataclasses import dataclass +from typing import Optional class Multiplicity(enum.Enum): @@ -11,7 +12,9 @@ class Multiplicity(enum.Enum): @dataclass class PropertyDescription: name: str - provides_nodes: bool + type: Optional[type] + is_containment: bool + is_reference: bool multiplicity: Multiplicity = Multiplicity.SINGULAR value: object = None diff --git a/pylasu/reflection/reflection.py b/pylasu/reflection/reflection.py index 2b0f3ce..f415add 100644 --- a/pylasu/reflection/reflection.py +++ b/pylasu/reflection/reflection.py @@ -4,25 +4,30 @@ def getannotations(cls): - import inspect - try: # On Python 3.10+ - return inspect.getannotations(cls) + try: + # https://peps.python.org/pep-0563/ + return typing.get_type_hints(cls, globalns=None, localns=None) except AttributeError: - if isinstance(cls, type): - return cls.__dict__.get('__annotations__', None) - else: - return getattr(cls, '__annotations__', None) + try: + # On Python 3.10+ + import inspect + return inspect.getannotations(cls) + except AttributeError: + if isinstance(cls, type): + return cls.__dict__.get('__annotations__', None) + else: + return getattr(cls, '__annotations__', None) def get_type_origin(tp): + origin = None if hasattr(typing, "get_origin"): - return typing.get_origin(tp) + origin = typing.get_origin(tp) elif hasattr(tp, "__origin__"): - return tp.__origin__ + origin = tp.__origin__ elif tp is typing.Generic: - return typing.Generic - else: - return None + origin = typing.Generic + return origin or (tp if isinstance(tp, type) else None) def is_enum_type(attr_type): diff --git a/pylasu/testing/testing.py b/pylasu/testing/testing.py index a20edc3..a7590ca 100644 --- a/pylasu/testing/testing.py +++ b/pylasu/testing/testing.py @@ -20,7 +20,7 @@ def assert_asts_are_equal( case.fail(f"No property {expected_property.name} found at {context}") actual_prop_value = actual_property.value expected_prop_value = expected_property.value - if expected_property.provides_nodes: + if expected_property.is_containment: if expected_property.multiple: assert_multi_properties_are_equal( case, expected_property, expected_prop_value, actual_prop_value, context, consider_position) diff --git a/tests/model/test_model.py b/tests/model/test_model.py index 7600c7b..1e855b1 100644 --- a/tests/model/test_model.py +++ b/tests/model/test_model.py @@ -2,8 +2,8 @@ import unittest from typing import List, Optional, Union -from pylasu.model import Node, Position, Point -from pylasu.model.reflection import Multiplicity +from pylasu.model import Node, Position, Point, internal_field +from pylasu.model.reflection import Multiplicity, PropertyDescription from pylasu.model.naming import ReferenceByName, Named, Scope, Symbol from pylasu.support import extension_method @@ -13,14 +13,26 @@ class SomeNode(Node, Named): foo = 3 bar: int = dataclasses.field(init=False) __private__ = 4 - ref: Node = None + containment: Node = None + reference: ReferenceByName[Node] = None multiple: List[Node] = dataclasses.field(default_factory=list) + optional: Optional[Node] = None multiple_opt: List[Optional[Node]] = dataclasses.field(default_factory=list) + internal: Node = internal_field(default=None) def __post_init__(self): self.bar = 5 +@dataclasses.dataclass +class ExtendedNode(SomeNode): + prop = 2 + cont_fwd: "ExtendedNode" = None + cont_ref: ReferenceByName["ExtendedNode"] = None + multiple2: List[SomeNode] = dataclasses.field(default_factory=list) + internal2: Node = internal_field(default=None) + + @dataclasses.dataclass class SomeSymbol(Symbol): index: int = dataclasses.field(default=None) @@ -39,6 +51,14 @@ class InvalidNode(Node): another_child: Node = None +def require_feature(node, name) -> PropertyDescription: + return next(n for n in node.properties if n.name == name) + + +def find_feature(node, name) -> Optional[PropertyDescription]: + return next((n for n in node.properties if n.name == name), None) + + class ModelTest(unittest.TestCase): def test_reference_by_name_unsolved_str(self): @@ -77,9 +97,29 @@ def test_node_with_position(self): def test_node_properties(self): node = SomeNode("n").with_position(Position(Point(1, 0), Point(2, 1))) - self.assertIsNotNone(next(n for n in node.properties if n.name == 'foo')) - self.assertIsNotNone(next(n for n in node.properties if n.name == 'bar')) - self.assertIsNotNone(next(n for n in node.properties if n.name == "name")) + self.assertIsNotNone(find_feature(node, 'foo')) + self.assertFalse(find_feature(node, 'foo').is_containment) + self.assertIsNotNone(find_feature(node, 'bar')) + self.assertFalse(find_feature(node, 'bar').is_containment) + self.assertIsNotNone(find_feature(node, 'name')) + self.assertTrue(find_feature(node, 'containment').is_containment) + self.assertFalse(find_feature(node, 'containment').is_reference) + self.assertFalse(find_feature(node, 'reference').is_containment) + self.assertTrue(find_feature(node, 'reference').is_reference) + with self.assertRaises(StopIteration): + next(n for n in node.properties if n.name == '__private__') + with self.assertRaises(StopIteration): + next(n for n in node.properties if n.name == 'non_existent') + with self.assertRaises(StopIteration): + next(n for n in node.properties if n.name == 'properties') + with self.assertRaises(StopIteration): + next(n for n in node.properties if n.name == "origin") + + def test_node_properties_inheritance(self): + node = ExtendedNode("n").with_position(Position(Point(1, 0), Point(2, 1))) + self.assertIsNotNone(find_feature(node, 'foo')) + self.assertIsNotNone(find_feature(node, 'bar')) + self.assertIsNotNone(find_feature(node, 'name')) with self.assertRaises(StopIteration): next(n for n in node.properties if n.name == '__private__') with self.assertRaises(StopIteration): @@ -159,20 +199,52 @@ def frob_node(_: Node): pass pds = [pd for pd in sorted(SomeNode.node_properties, key=lambda x: x.name)] - self.assertEqual(6, len(pds), f"{pds} should be 6") + self.assertEqual(8, len(pds), f"{pds} should be 7") self.assertEqual("bar", pds[0].name) - self.assertFalse(pds[0].provides_nodes) - self.assertEqual("foo", pds[1].name) - self.assertFalse(pds[1].provides_nodes) - self.assertEqual("multiple", pds[2].name) - self.assertTrue(pds[2].provides_nodes) - self.assertEqual(Multiplicity.MANY, pds[2].multiplicity) - self.assertEqual("multiple_opt", pds[3].name) - self.assertTrue(pds[3].provides_nodes) + self.assertFalse(pds[0].is_containment) + self.assertEqual("containment", pds[1].name) + self.assertTrue(pds[1].is_containment) + self.assertEqual("foo", pds[2].name) + self.assertFalse(pds[2].is_containment) + self.assertEqual("multiple", pds[3].name) + self.assertTrue(pds[3].is_containment) self.assertEqual(Multiplicity.MANY, pds[3].multiplicity) - self.assertEqual("name", pds[4].name) - self.assertFalse(pds[4].provides_nodes) - self.assertEqual("ref", pds[5].name) - self.assertTrue(pds[5].provides_nodes) + self.assertEqual("multiple_opt", pds[4].name) + self.assertTrue(pds[4].is_containment) + self.assertEqual(Multiplicity.MANY, pds[4].multiplicity) + self.assertEqual("name", pds[5].name) + self.assertFalse(pds[5].is_containment) + self.assertEqual("optional", pds[6].name) + self.assertTrue(pds[6].is_containment) + self.assertEqual(Multiplicity.OPTIONAL, pds[6].multiplicity) + self.assertEqual("reference", pds[7].name) + self.assertTrue(pds[7].is_reference) + + self.assertRaises(Exception, lambda: [x for x in InvalidNode.node_properties]) + + def test_node_properties_meta_inheritance(self): + @extension_method(Node) + def frob_node_2(_: Node): + pass + + pds = [pd for pd in sorted(ExtendedNode.node_properties, key=lambda x: x.name)] + self.assertEqual(12, len(pds), f"{pds} should be 7") + self.assertEqual("bar", pds[0].name) + self.assertFalse(pds[0].is_containment) + self.assertEqual("cont_fwd", pds[1].name) + self.assertTrue(pds[1].is_containment) + self.assertEqual(ExtendedNode, pds[1].type) + self.assertEqual("cont_ref", pds[2].name) + self.assertTrue(pds[2].is_reference) + self.assertEqual(ExtendedNode, pds[2].type) + self.assertEqual("containment", pds[3].name) + self.assertTrue(pds[3].is_containment) + self.assertEqual("foo", pds[4].name) + self.assertEqual("multiple", pds[5].name) + self.assertTrue(pds[5].is_containment) + self.assertEqual(Multiplicity.MANY, pds[5].multiplicity) + self.assertEqual("multiple2", pds[6].name) + self.assertTrue(pds[6].is_containment) + self.assertEqual(Multiplicity.MANY, pds[6].multiplicity) self.assertRaises(Exception, lambda: [x for x in InvalidNode.node_properties]) diff --git a/tests/test_metamodel_builder.py b/tests/test_metamodel_builder.py index 6512251..d338d16 100644 --- a/tests/test_metamodel_builder.py +++ b/tests/test_metamodel_builder.py @@ -118,7 +118,7 @@ def test_build_metamodel_single_package_inheritance(self): next((a for a in box.eClass.eAllAttributes() if a.name == "name"), None)) self.assertIsNotNone( next((a for a in box.eClass.eAllAttributes() if a.name == "strength"), None)) - self.assertEqual(2, len(box.eClass.eAllAttributes())) + self.assertEqual(3, len(box.eClass.eAllAttributes())) STARLASU_MODEL_JSON = '''{ diff --git a/tests/test_processing.py b/tests/test_processing.py index 50cf6c8..64a6abe 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -1,6 +1,6 @@ import unittest from dataclasses import dataclass -from typing import List, Set +from typing import List from pylasu.model import Node from tests.fixtures import box, Item @@ -17,12 +17,6 @@ class BW(Node): many_as: List[AW] -@dataclass -class CW(Node): - a: AW - many_as: Set[AW] - - class ProcessingTest(unittest.TestCase): def test_search_by_type(self): self.assertEqual(["1", "2", "3", "4", "5", "6"], [i.name for i in box.search_by_type(Item)]) @@ -42,15 +36,6 @@ def test_replace_in_list(self): self.assertEqual("4", b.many_as[0].s) self.assertEqual(BW(a1, [a4, a3]), b) - def test_replace_in_set(self): - a1 = AW("1") - a2 = AW("2") - a3 = AW("3") - a4 = AW("4") - c = CW(a1, {a2, a3}) - c.assign_parents() - self.assertRaises(Exception, lambda: a2.replace_with(a4)) - def test_replace_single(self): a1 = AW("1") a2 = AW("2")