diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index 9f3ba08..2d13a3b 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -1,5 +1,6 @@ from .based_base import BasedBase from .collection_base import CollectionBase +from .easy_list import EasyList from .model_base import ModelBase from .new_base import NewBase from .obj_base import ObjBase @@ -10,4 +11,5 @@ ObjBase, ModelBase, NewBase, + EasyList ] diff --git a/src/easyscience/base_classes/easy_list.py b/src/easyscience/base_classes/easy_list.py new file mode 100644 index 0000000..aa6e8ed --- /dev/null +++ b/src/easyscience/base_classes/easy_list.py @@ -0,0 +1,267 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project ProtectedType_: ... + @overload + def __getitem__(self, idx: slice) -> 'EasyList[ProtectedType_]': ... + @overload + def __getitem__(self, idx: str) -> ProtectedType_: ... + def __getitem__(self, idx: int | slice | str) -> ProtectedType_ | 'EasyList[ProtectedType_]': + """ + Get an item by index, slice, or unique_name. + + :param idx: Index, slice, or unique_name of the item + :return: The item or a new EasyList for slices + """ + if isinstance(idx, int): + return self._data[idx] + elif isinstance(idx, slice): + return self.__class__(self._data[idx], protected_types=self._protected_types) + elif isinstance(idx, str): + element = next((r for r in self._data if r.unique_name == idx), None) + if element is not None: + return element + raise KeyError(f'No item with unique name "{idx}" found') + else: + raise TypeError('Index must be an int, slice, or str') + + @overload + def __setitem__(self, idx: int, value: ProtectedType_) -> None: ... + @overload + def __setitem__(self, idx: slice, value: Iterable[ProtectedType_]) -> None: ... + + def __setitem__(self, idx: int | slice, value: ProtectedType_ | Iterable[ProtectedType_]) -> None: + """ + Set an item at an index. + + :param idx: Index to set + :param value: New value + """ + if isinstance(idx, int): + if not isinstance(value, tuple(self._protected_types)): + raise TypeError(f'Items must be one of {self._protected_types}, got {type(value)}') + self._data[idx] = value + elif isinstance(idx, slice): + if not isinstance(value, Iterable): + raise TypeError('Value must be an iterable for slice assignment') + for v in value: + if not isinstance(v, tuple(self._protected_types)): + raise TypeError(f'Items must be one of {self._protected_types}, got {type(v)}') + self._data[idx] = list(value) # type: ignore[arg-type] + else: + raise TypeError('Index must be an int or slice') + + def __delitem__(self, idx: int | slice | str) -> None: + """ + Delete an item by index, slice, or name. + + :param idx: Index, slice, or name of item to delete + """ + if isinstance(idx, (int, slice)): + del self._data[idx] + elif isinstance(idx, str): + for i, item in enumerate(self._data): + if item.unique_name == idx: + del self._data[i] + return + raise KeyError(f'No item with unique name "{idx}" found') + else: + raise TypeError('Index must be an int, slice, or str') + + def __len__(self) -> int: + """Return the number of items in the collection.""" + return len(self._data) + + def insert(self, index: int, value: ProtectedType_) -> None: + """ + Insert an item at an index. + + :param index: Index to insert at + :param value: Item to insert + """ + if not isinstance(index, int): + raise TypeError('Index must be an integer') + elif not isinstance(value, tuple(self._protected_types)): + raise TypeError(f'Items must be one of {self._protected_types}, got {type(value)}') + + self._data.insert(index, value) + + # Overwriting methods + + def sort(self, mapping: Callable[[ProtectedType_], Any], reverse: bool = False) -> None: + """ + Sort the collection according to the given mapping. + + :param mapping: Mapping function to sort by + :param reverse: Whether to reverse the sort + """ + self._data.sort(key=mapping, reverse=reverse) # type: ignore[arg-type] + + def __repr__(self) -> str: + return f'{self.__class__.__name__} of length {len(self)} of type(s) {self._protected_types}' + + def __iter__(self) -> Any: + return iter(self._data) + + def __contains__(self, item: ProtectedType_ | str) -> bool: + if isinstance(item, str): + return any(r.unique_name == item for r in self._data) + return item in self._data + + def index(self, value: ProtectedType_ | str, start: int = 0, stop: int = ...) -> int: + if isinstance(value, str): + for i in range(start, min(stop, len(self._data))): + if self._data[i].unique_name == value: + return i + raise ValueError(f'{value} is not in EasyList') + return self._data.index(value, start, stop) + + def append(self, value: ProtectedType_) -> None: + """ + Append an item to the end of the collection. + + :param value: Item to append + """ + if not isinstance(value, tuple(self._protected_types)): + raise TypeError(f'Items must be one of {self._protected_types}, got {type(value)}') + self._data.append(value) + + def pop(self, index: int | str = -1) -> ProtectedType_: + """ + Remove and return an item at the given index or unique_name. + + :param index: Index or unique_name of the item to remove + :return: The removed item + """ + if isinstance(index, int): + return self._data.pop(index) + elif isinstance(index, str): + for i, item in enumerate(self._data): + if item.unique_name == index: + return self._data.pop(i) + raise KeyError(f'No item with unique name "{index}" found') + else: + raise TypeError('Index must be an int or str') + + # Serialization support + + def to_dict(self) -> dict: + """ + Convert the EasyList to a dictionary for serialization. + + :return: Dictionary representation of the EasyList + """ + dict_repr = super().to_dict() + if self._protected_types != [NewBase]: + dict_repr['protected_types'] = [ + {'@module': cls_.__module__, '@class': cls_.__name__} for cls_ in self._protected_types + ] # noqa: E501 + dict_repr['data'] = [item.to_dict() for item in self._data] + return dict_repr + + @classmethod + def from_dict(cls, obj_dict: Dict[str, Any]) -> NewBase: + """ + Re-create an EasyScience object from a full encoded dictionary. + + :param obj_dict: dictionary containing the serialized contents (from `SerializerDict`) of an EasyScience object + :return: Reformed EasyScience object + """ + if not SerializerBase._is_serialized_easyscience_object(obj_dict): + raise ValueError('Input must be a dictionary representing an EasyScience EasyList object.') + if obj_dict['@class'] == cls.__name__: + if 'protected_types' in obj_dict: + protected_types = obj_dict.pop('protected_types') + for i, type_dict in enumerate(protected_types): + if '@module' in type_dict and '@class' in type_dict: + modname = type_dict['@module'] + classname = type_dict['@class'] + mod = __import__(modname, globals(), locals(), [classname], 0) + if hasattr(mod, classname): + cls_ = getattr(mod, classname) + protected_types[i] = cls_ + else: + raise ImportError(f'Could not import class {classname} from module {modname}') + else: + raise ValueError( + 'Each protected type must be a serialized EasyScience class with @module and @class keys' + ) # noqa: E501 + else: + protected_types = None + kwargs = SerializerBase.deserialize_dict(obj_dict) + data = kwargs.pop('data', []) + return cls(data, protected_types=protected_types, **kwargs) + else: + raise ValueError(f'Class name in dictionary does not match the expected class: {cls.__name__}.')