diff --git a/Dockerfile b/Dockerfile index 833f7a3ca..2388e17e2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ # # Copyright (c) 2026 Tom Kralidis # Copyright (c) 2019 Just van den Broecke -# Copyright (c) 2025 Francesco Bartoli +# Copyright (c) 2026 Francesco Bartoli # Copyright (c) 2025 Angelos Tzotsos # Copyright (c) 2023 Bernhard Mallinger # @@ -70,7 +70,6 @@ ARG ADD_DEB_PACKAGES="\ python3-netcdf4 \ python3-pandas \ python3-psycopg2 \ - python3-pydantic \ python3-pymongo \ python3-pyproj \ python3-rasterio \ diff --git a/pygeoapi/models/config.py b/pygeoapi/models/config.py index b92a64bd7..5472cd044 100644 --- a/pygeoapi/models/config.py +++ b/pygeoapi/models/config.py @@ -4,7 +4,7 @@ # Francesco Bartoli # # Copyright (c) 2023 Sander Schaminee -# Copyright (c) 2025 Francesco Bartoli +# Copyright (c) 2026 Francesco Bartoli # # Permission is hereby granted, free of charge, to any person # obtaining a copy of this software and associated documentation @@ -29,59 +29,82 @@ # # ================================================================= -from pydantic import BaseModel, Field -import pydantic +import re +from dataclasses import dataclass, fields, asdict +from typing import Any, Dict -# Handle Pydantic v1/v2 compatibility -if pydantic.VERSION.startswith('1'): - model_validator = 'parse_obj' - model_fields = '__fields__' - regex_param = {'regex': r'^\d+\.\d+\..+$'} -else: - model_validator = 'model_validate' - model_fields = 'model_fields' - regex_param = {'pattern': r'^\d+\.\d+\..+$'} +from pygeoapi.models.validation import validate_type -class APIRules(BaseModel): +SEMVER_PATTERN = re.compile(r'^\d+\.\d+\..+$') + + +class APIRulesValidationError(ValueError): + """Raised when APIRules validation fails.""" + pass + + +@dataclass +class APIRules: """ - Pydantic model for API design rules that must be adhered to. + API design rules that must be adhered to. + + Concrete dataclass implementation that can be mimicked + downstream. 
+ + :param api_version: Semantic API version number (e.g. '1.0.0') + :param url_prefix: URL path prefix for routes (e.g. '/v1') + If set, pygeoapi routes will be prepended + with the given URL path prefix (e.g. '/v1'). + Defaults to an empty string (no prefix). + :param version_header: Response header name for API version + If set, pygeoapi will set a response + header with this name and its value will + hold the API version. + Defaults to an empty string (i.e. no header). + Often 'API-Version' or 'X-API-Version' are + used here. + :param strict_slashes: Whether trailing slashes return 404 + If False (default), URL trailing slashes + are allowed. + If True, pygeoapi will return a 404. """ - api_version: str = Field(**regex_param, - description='Semantic API version number.') - url_prefix: str = Field( - '', - description="If set, pygeoapi routes will be prepended with the " - "given URL path prefix (e.g. '/v1'). " - "Defaults to an empty string (no prefix)." - ) - version_header: str = Field( - '', - description="If set, pygeoapi will set a response header with this " - "name and its value will hold the API version. " - "Defaults to an empty string (i.e. no header). " - "Often 'API-Version' or 'X-API-Version' are used here." - ) - strict_slashes: bool = Field( - False, - description="If False (default), URL trailing slashes are allowed. " - "If True, pygeoapi will return a 404." - ) - - @staticmethod - def create(**rules_config) -> 'APIRules': + + api_version: str = '' + url_prefix: str = '' + version_header: str = '' + strict_slashes: bool = False + + def __post_init__(self): + try: + validate_type(self) + except ValueError as e: + raise APIRulesValidationError(str(e)) from e + if not SEMVER_PATTERN.match(self.api_version): + raise APIRulesValidationError( + f"Invalid semantic version: '{self.api_version}'. 
" + f"Expected format: MAJOR.MINOR.PATCH" + ) + + @classmethod + def create(cls, **rules_config) -> 'APIRules': """ Returns a new APIRules instance for the current API version and configured rules. + + Filters only valid fields from the config dict and + creates a validated instance. + + :param rules_config: Configuration dict + + :returns: Validated APIRules instance """ - obj = { + valid = {f.name for f in fields(cls)} + filtered = { k: v for k, v in rules_config.items() - if k in getattr(APIRules, model_fields) + if k in valid } - # Validation will fail if required `api_version` is missing - # or if `api_version` is not a semantic version number - model_validator_ = getattr(APIRules, model_validator) - return model_validator_(obj) + return cls(**filtered) @property def response_headers(self) -> dict: @@ -122,3 +145,15 @@ def get_url_prefix(self, style: str = '') -> str: else: # If no format is specified, return only the bare prefix return prefix + + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict.""" + result = asdict(self) + if exclude_none: + result = { + k: v for k, v in result.items() + if v is not None + } + return result diff --git a/pygeoapi/models/openapi.py b/pygeoapi/models/openapi.py index 179070cc7..c6e40e197 100644 --- a/pygeoapi/models/openapi.py +++ b/pygeoapi/models/openapi.py @@ -1,10 +1,8 @@ -# ****************************** -*- -# flake8: noqa # ================================================================= # # Authors: Francesco Bartoli # -# Copyright (c) 2025 Francesco Bartoli +# Copyright (c) 2026 Francesco Bartoli # # Permission is hereby granted, free of charge, to any person # obtaining a copy of this software and associated documentation @@ -29,20 +27,55 @@ # # ================================================================= +from dataclasses import dataclass from enum import Enum +from typing import Any, Dict -from pydantic import BaseModel -import pydantic +from 
pygeoapi.models.validation import validate_type class SupportedFormats(Enum): JSON = 'json' YAML = 'yaml' -# Handle Pydantic v1/v2 compatibility -if pydantic.VERSION.startswith('1'): - class OAPIFormat(BaseModel): - __root__: SupportedFormats = SupportedFormats.YAML -else: - class OAPIFormat(BaseModel): - root: SupportedFormats = SupportedFormats.YAML + +@dataclass +class OAPIFormat: + """ + OpenAPI output format. + + Concrete dataclass implementation that can be mimicked + downstream. + + :param root: output format, defaults to ``yaml`` + """ + + root: SupportedFormats = SupportedFormats.YAML + + def __post_init__(self): + # Coerce str to enum before type validation + if isinstance(self.root, str): + try: + self.root = SupportedFormats(self.root) + except ValueError: + raise ValueError( + f"Unsupported format: '{self.root}'. " + f"Must be one of: " + f"{[f.value for f in SupportedFormats]}" + ) + validate_type(self) + + def __eq__(self, other): + if isinstance(other, str): + return self.root.value == other + if isinstance(other, SupportedFormats): + return self.root == other + if isinstance(other, OAPIFormat): + return self.root == other.root + return NotImplemented + + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict.""" + return {'root': self.root.value} diff --git a/pygeoapi/models/provider/base.py b/pygeoapi/models/provider/base.py index cc67443cf..0df2b3edb 100644 --- a/pygeoapi/models/provider/base.py +++ b/pygeoapi/models/provider/base.py @@ -5,7 +5,7 @@ # Tom Kralidis # # Copyright (c) 2022 Antonio Cerciello -# Copyright (c) 2025 Francesco Bartoli +# Copyright (c) 2026 Francesco Bartoli # Copyright (c) 2025 Tom Kralidis # # Permission is hereby granted, free of charge, to any person @@ -31,15 +31,14 @@ # # ================================================================= +from dataclasses import dataclass, field from datetime import datetime from enum import Enum import json from pathlib import Path -from 
typing import List, Optional - -import pydantic -from pydantic import BaseModel +from typing import Any, Dict, List, Optional +from pygeoapi.models.validation import validate_type from pygeoapi.util import DEFINITIONSDIR TMS_DIR = DEFINITIONSDIR / 'tiles' @@ -77,14 +76,40 @@ class GeometryDimensionEnum(int, Enum): SOLIDS = 3 -class TileMatrixSetEnumType(BaseModel): - tileMatrixSet: str - tileMatrixSetURI: str - crs: str - title: str - orderedAxes: List[str] - wellKnownScaleSet: str - tileMatrices: List[dict] +@dataclass +class TileMatrixSetEnumType: + """Tile matrix set definition loaded from JSON.""" + + tileMatrixSet: str = '' + tileMatrixSetURI: str = '' + crs: str = '' + title: str = '' + orderedAxes: List[str] = field(default_factory=list) + wellKnownScaleSet: str = '' + tileMatrices: List[dict] = field(default_factory=list) + + def __post_init__(self): + validate_type(self) + + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict.""" + result = { + 'tileMatrixSet': self.tileMatrixSet, + 'tileMatrixSetURI': self.tileMatrixSetURI, + 'crs': self.crs, + 'title': self.title, + 'orderedAxes': self.orderedAxes, + 'wellKnownScaleSet': self.wellKnownScaleSet, + 'tileMatrices': self.tileMatrices, + } + if exclude_none: + result = { + k: v for k, v in result.items() + if v is not None + } + return result class TileMatrixSetLoader: @@ -136,30 +161,108 @@ def create_enum(self) -> Enum: # Tile Set Metadata Sub Types -class TileMatrixLimitsType(BaseModel): - tileMatrix: str - minTileRow: int - maxTileRow: int - minTileCol: int - maxTileCol: int - -class TwoDBoundingBoxType(BaseModel): - lowerLeft: List[float] - upperRight: List[float] +@dataclass +class TileMatrixLimitsType: + """Tile matrix limits type.""" + + tileMatrix: str = '' + minTileRow: int = 0 + maxTileRow: int = 0 + minTileCol: int = 0 + maxTileCol: int = 0 + + def __post_init__(self): + validate_type(self) + + def model_dump( + self, exclude_none: bool = False + ) -> 
Dict[str, Any]: + """Serialize to dict.""" + result = { + 'tileMatrix': self.tileMatrix, + 'minTileRow': self.minTileRow, + 'maxTileRow': self.maxTileRow, + 'minTileCol': self.minTileCol, + 'maxTileCol': self.maxTileCol, + } + if exclude_none: + result = { + k: v for k, v in result.items() + if v is not None + } + return result + + +@dataclass +class TwoDBoundingBoxType: + """2D bounding box type.""" + + lowerLeft: List[float] = field(default_factory=list) + upperRight: List[float] = field(default_factory=list) crs: Optional[str] = None - -class LinkType(BaseModel): - href: str + def __post_init__(self): + validate_type(self) + + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict.""" + result = { + 'lowerLeft': self.lowerLeft, + 'upperRight': self.upperRight, + 'crs': self.crs, + } + if exclude_none: + result = { + k: v for k, v in result.items() + if v is not None + } + return result + + +@dataclass +class LinkType: + """Link object.""" + + href: str = '' rel: Optional[str] = None type_: Optional[str] = None hreflang: Optional[str] = None title: Optional[str] = None length: Optional[int] = None + def __post_init__(self): + validate_type(self) + + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict. + + Note: Renames type_ to type for JSON output. 
+ """ + result = { + 'href': self.href, + 'rel': self.rel, + 'type': self.type_, + 'hreflang': self.hreflang, + 'title': self.title, + 'length': self.length, + } + if exclude_none: + result = { + k: v for k, v in result.items() + if v is not None + } + return result + + +@dataclass +class GeospatialDataType: + """Geospatial data reference type.""" -class GeospatialDataType(BaseModel): id: Optional[str] = None title: Optional[str] = None description: Optional[str] = None @@ -176,81 +279,157 @@ class GeospatialDataType(BaseModel): links: Optional[LinkType] = None propertiesSchema: Optional[dict] = None + def __post_init__(self): + validate_type(self) + + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict.""" + result = {} + for key, value in self.__dict__.items(): + if value is None and exclude_none: + continue + if hasattr(value, 'model_dump'): + result[key] = value.model_dump( + exclude_none=exclude_none + ) + elif isinstance(value, Enum): + result[key] = value.value + else: + result[key] = value + return result + + +@dataclass +class StyleType: + """Style type definition.""" -class StyleType(BaseModel): id: Optional[str] = None title: Optional[str] = None description: Optional[str] = None keywords: Optional[List[str]] = None links: Optional[LinkType] = None - -class TilePointType(BaseModel): - crs: str + def __post_init__(self): + validate_type(self) + + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict.""" + result = {} + for key, value in self.__dict__.items(): + if value is None and exclude_none: + continue + if hasattr(value, 'model_dump'): + result[key] = value.model_dump( + exclude_none=exclude_none + ) + else: + result[key] = value + return result + + +@dataclass +class TilePointType: + """Tile point type.""" + + crs: str = '' coordinates: Optional[List[float]] = None scaleDenominator: Optional[float] = None - cellSize: Optional[float] = None - # CodeType as 
adaptation of MD_Identifier class ISO 19115 - tileMatrix: str cellSize: Optional[str] = None + tileMatrix: str = '' + + def __post_init__(self): + validate_type(self) + + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict.""" + result = self.__dict__.copy() + if exclude_none: + result = { + k: v for k, v in result.items() + if v is not None + } + return result -class TileSetMetadata(BaseModel): - # A title for this tileset +@dataclass +class TileSetMetadata: + """ + OGC Tile Set Metadata. + + Full metadata for a tileset compliant with + OGC API - Tiles specification. + """ + title: Optional[str] = None - # Brief narrative description of this tile set description: Optional[str] = None - # keywords about this tileset keywords: Optional[List[str]] = None - # Version of the Tile Set. Changes if the data behind the tiles - # has been changed version: Optional[str] = None - # Useful information to contact the authors or custodians for the Tile Set pointOfContact: Optional[str] = None - # Short reference to recognize the author or provider attribution: Optional[str] = None - # License applicable to the tiles license_: Optional[str] = None - # Restrictions on the availability of the Tile Set that the user needs to - # be aware of before using or redistributing the Tile Set - accessConstraints: Optional[AccessConstraintsEnum] = AccessConstraintsEnum.UNCLASSIFIED # noqa - # Media types available for the tiles - mediaTypes: Optional[List[str]] = None - # Type of data represented in the tileset + accessConstraints: Optional[AccessConstraintsEnum] = ( + AccessConstraintsEnum.UNCLASSIFIED + ) + mediaTypes: Optional[List[str]] = None dataType: DataTypeEnum = DataTypeEnum.VECTOR - # Limits for the TileRow and TileCol values for each TileMatrix in the - # tileMatrixSet. If missing, there are no limits other that the ones - # imposed by the TileMatrixSet. 
If present the TileMatrices listed are - # limited and the rest not available at all tileMatrixSetLimits: Optional[TileMatrixLimitsType] = None - # Coordinate Reference System (CRS) crs: Optional[str] = None - # Epoch of the Coordinate Reference System (CRS) epoch: Optional[int] = None - # Minimum bounding rectangle surrounding the tile matrix set, in the - # supported CRS boundingBox: Optional[TwoDBoundingBoxType] = None - # When the Tile Set was first produced created: Optional[datetime] = None - # Last Tile Set change/revision updated: Optional[datetime] = None layers: Optional[GeospatialDataType] = None - # Style involving all layers used to generate the tileset style: Optional[StyleType] = None - # Location of a tile that nicely represents the tileset. - # Implementations may use this center value to set the default location - # or to present a representative tile in a user interface centerPoint: Optional[TilePointType] = None - # Tile matrix set definition tileMatrixSet: Optional[str] = None - # Reference to a Tile Matrix Set on an official source tileMatrixSetURI: Optional[str] = None - # Links to related resources. links: Optional[List[LinkType]] = None + def __post_init__(self): + validate_type(self) -if pydantic.VERSION.startswith('1'): - def _dump(self, *, exclude_none: bool = False, **kwargs): - return self.dict(exclude_none=exclude_none, **kwargs) + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict. - TileSetMetadata.model_dump = _dump + Handles nested models, enum values, and + renames license_ to license for JSON output. 
+ """ + result = {} + for key, value in self.__dict__.items(): + out_key = 'license' if key == 'license_' else key + + if value is None and exclude_none: + continue + + if isinstance(value, list) and value: + items = [] + for item in value: + if hasattr(item, 'model_dump'): + items.append( + item.model_dump( + exclude_none=exclude_none + ) + ) + else: + items.append(item) + result[out_key] = items + elif hasattr(value, 'model_dump'): + result[out_key] = value.model_dump( + exclude_none=exclude_none + ) + elif isinstance(value, Enum): + result[out_key] = value.value + elif isinstance(value, datetime): + result[out_key] = value.isoformat() + else: + result[out_key] = value + + return result diff --git a/pygeoapi/models/provider/mvt.py b/pygeoapi/models/provider/mvt.py index 3fd6de248..bef085831 100644 --- a/pygeoapi/models/provider/mvt.py +++ b/pygeoapi/models/provider/mvt.py @@ -1,8 +1,10 @@ # ================================================================= # # Authors: Antonio Cerciello +# Francesco Bartoli # # Copyright (c) 2022 Antonio Cerciello +# Copyright (c) 2026 Francesco Bartoli # # Permission is hereby granted, free of charge, to any person # obtaining a copy of this software and associated documentation @@ -27,20 +29,47 @@ # # ================================================================= -import pydantic -from pydantic import BaseModel -from typing import List, Optional +from dataclasses import dataclass, fields as dc_fields +from typing import Any, Dict, List, Optional +from pygeoapi.models.validation import validate_type -class VectorLayers(BaseModel): - id: str - description: Optional[str] - minzoom: Optional[int] - maxzoom: Optional[int] - fields: Optional[dict] +@dataclass +class VectorLayers: + """TileJSON vector layer definition.""" + + id: str = '' + description: Optional[str] = None + minzoom: Optional[int] = None + maxzoom: Optional[int] = None + fields: Optional[dict] = None + + def __post_init__(self): + validate_type(self) + + def 
model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict.""" + result = self.__dict__.copy() + if exclude_none: + result = { + k: v for k, v in result.items() + if v is not None + } + return result + + +@dataclass +class MVTTilesJson: + """TileJSON 3.0 specification. + + Accepts and silently ignores unknown kwargs to match + the validation behaviour when instantiated from arbitrary + JSON metadata dicts (e.g. ``MVTTilesJson(**json_data)``). + """ -class MVTTilesJson(BaseModel): tilejson: str = "3.0.0" name: Optional[str] = None tiles: Optional[str] = None @@ -52,9 +81,37 @@ class MVTTilesJson(BaseModel): description: Optional[str] = None vector_layers: Optional[List[VectorLayers]] = None + def __init__(self, **kwargs): + for f in dc_fields(self): + value = kwargs.get(f.name, getattr(self, f.name)) + # Coerce str to int for Optional[int] fields + if (value is not None + and isinstance(value, str) + and 'int' in str(f.type)): + try: + value = int(value) + except (ValueError, TypeError): + pass + setattr(self, f.name, value) + validate_type(self) -if pydantic.VERSION.startswith('1'): - def _dump(self, *, exclude_none: bool = False, **kwargs): - return self.dict(exclude_none=exclude_none, **kwargs) - - MVTTilesJson.model_dump = _dump + def model_dump( + self, exclude_none: bool = False + ) -> Dict[str, Any]: + """Serialize to dict.""" + result = {} + for key, value in self.__dict__.items(): + if value is None and exclude_none: + continue + if (key == 'vector_layers' + and isinstance(value, list) and value): + result[key] = [ + item.model_dump( + exclude_none=exclude_none + ) if hasattr(item, 'model_dump') + else item + for item in value + ] + else: + result[key] = value + return result diff --git a/pygeoapi/models/validation.py b/pygeoapi/models/validation.py new file mode 100644 index 000000000..31934b5ca --- /dev/null +++ b/pygeoapi/models/validation.py @@ -0,0 +1,112 @@ +# 
================================================================= +# +# Authors: Francesco Bartoli +# +# Copyright (c) 2026 Francesco Bartoli +# +# Permission is hereby granted, free of charge, to any person +# obtaining a copy of this software and associated documentation +# files (the "Software"), to deal in the Software without +# restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following +# conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +# OTHER DEALINGS IN THE SOFTWARE. +# +# ================================================================= + +""" +Validation utilities for dataclass models. + +Provides runtime type checking that matches validation +behaviour, using standard library only. +""" + +from dataclasses import fields as dc_fields +from typing import Any, get_type_hints + + +def validate_type(dc_instance: Any) -> None: + """ + Validate field types on a dataclass instance. + + Checks each field value against its declared type, + matching runtime type validation behaviour. + Supports Optional[T], List[T], and plain types. 
+ + :param dc_instance: dataclass instance to validate + + :raises ValueError: if a field value has the wrong type + """ + hints = get_type_hints(dc_instance.__class__) + for f in dc_fields(dc_instance): + value = getattr(dc_instance, f.name) + expected = hints[f.name] + + # Extract inner type from Optional[T] + origin = getattr(expected, '__origin__', None) + args = getattr(expected, '__args__', ()) + + is_optional = ( + origin is type(None) # noqa: E721 + or (origin is not None + and type(None) in args) + ) + + if is_optional and value is None: + continue + + # Unwrap Optional to get the inner type + if is_optional and args: + inner_types = [ + a for a in args if a is not type(None) + ] + if len(inner_types) == 1: + expected = inner_types[0] + origin = getattr( + expected, '__origin__', None + ) + args = getattr(expected, '__args__', ()) + + # Check List[T] + if origin is list: + if not isinstance(value, list): + raise ValueError( + f"{f.name} must be a list, " + f"got {type(value).__name__}" + ) + # Check plain types (str, int, float, bool, Enum) + elif origin is None: + if isinstance(expected, type): + # bool is subclass of int, check bool first + if expected is bool: + if not isinstance(value, bool): + raise ValueError( + f"{f.name} must be a bool, " + f"got {type(value).__name__}" + ) + elif expected is int: + if isinstance(value, bool) \ + or not isinstance(value, int): + raise ValueError( + f"{f.name} must be an int, " + f"got {type(value).__name__}" + ) + elif not isinstance(value, expected): + raise ValueError( + f"{f.name} must be a " + f"{expected.__name__}, " + f"got {type(value).__name__}" + ) diff --git a/pygeoapi/plugin.py b/pygeoapi/plugin.py index 32292c895..2ecf5bfb6 100644 --- a/pygeoapi/plugin.py +++ b/pygeoapi/plugin.py @@ -1,8 +1,10 @@ # ================================================================= # # Authors: Tom Kralidis +# Francesco Bartoli # # Copyright (c) 2026 Tom Kralidis +# Copyright (c) 2026 Francesco Bartoli # # Permission is 
hereby granted, free of charge, to any person # obtaining a copy of this software and associated documentation @@ -30,10 +32,58 @@ import importlib import logging -from typing import Any +from dataclasses import dataclass +from typing import Any, Dict, List, Optional LOGGER = logging.getLogger(__name__) + +@dataclass +class PluginContext: + """ + Inject dependencies with a context object into plugins. + + This allows passing runtime dependencies to plugins without + relying on global state or complex config dictionaries. + + Attributes: + config: Original plugin configuration dictionary + logger: Optional injected logger instance + locales: Optional list of supported locale codes + base_url: Optional API base URL for link generation + + Example: + >>> from pygeoapi.plugin import PluginContext, load_plugin + >>> context = PluginContext( + ... config={'name': 'GeoJSON', 'type': 'feature', + ... 'data': 'obs.geojson'}, + ... logger=custom_logger, + ... base_url='https://api.example.com' + ... ) + >>> provider = load_plugin('provider', context.config, context=context) + """ + + config: Dict[str, Any] + logger: Optional[Any] = None + locales: Optional[List[str]] = None + base_url: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """ + Convert to plain dict format for backwards compatibility. + + :returns: Dictionary with config and injected dependencies + """ + result = dict(self.config) + if self.logger: + result["_logger"] = self.logger + if self.base_url: + result["_base_url"] = self.base_url + if self.locales: + result["_locales"] = self.locales + return result + + #: Loads provider plugins to be used by pygeoapi,\ #: formatters and processes available PLUGINS = { @@ -94,14 +144,33 @@ } -def load_plugin(plugin_type: str, plugin_def: dict) -> Any: +def load_plugin( + plugin_type: str, plugin_def: dict, context: Optional[PluginContext] = None +) -> Any: """ - loads plugin by name + Loads plugin by name with optional dependency injection. 
- :param plugin_type: type of plugin (provider, formatter) - :param plugin_def: plugin definition + :param plugin_type: type of plugin (provider, formatter, process, etc.) + :param plugin_def: plugin definition dictionary + :param context: optional context with injected dependencies :returns: plugin object + + Example: + # Plain mode (backwards compatible) + >>> provider = load_plugin('provider', { + ... 'name': 'GeoJSON', + ... 'type': 'feature', + ... 'data': 'obs.geojson' + ... }) + + # Modern mode with dependencies + >>> context = PluginContext( + ... config={'name': 'GeoJSON', 'type': 'feature', + ... 'data': 'obs.geojson'}, + ... logger=custom_logger + ... ) + >>> provider = load_plugin('provider', context.config, context=context) """ name = plugin_def['name'] @@ -130,7 +199,23 @@ def load_plugin(plugin_type: str, plugin_def: dict) -> Any: module = importlib.import_module(packagename) class_ = getattr(module, classname) - plugin = class_(plugin_def) + + # Support injected dependencies via PluginContext + if context is not None: + # Try context-aware constructor first + try: + plugin = class_(plugin_def, context=context) + LOGGER.debug(f"{name} initialized with PluginContext") + except TypeError as err: + # Fallback: legacy constructor without context parameter + LOGGER.debug( + f"{name} does not support PluginContext, " + f"using legacy init: {err}" + ) + plugin = class_(plugin_def) + else: + # Plain mode: no more context provided + plugin = class_(plugin_def) return plugin diff --git a/pygeoapi/process/base.py b/pygeoapi/process/base.py index 9e2136476..04e10bc02 100644 --- a/pygeoapi/process/base.py +++ b/pygeoapi/process/base.py @@ -2,9 +2,11 @@ # # Authors: Tom Kralidis # Francesco Martinelli +# Francesco Bartoli # # Copyright (c) 2022 Tom Kralidis # Copyright (c) 2024 Francesco Martinelli +# Copyright (c) 2026 Francesco Bartoli # # Permission is hereby granted, free of charge, to any person # obtaining a copy of this software and associated documentation 
@@ -29,23 +31,32 @@ # # ================================================================= +from __future__ import annotations + import logging -from typing import Any, Tuple, Optional +from typing import Any, Optional, Tuple, TYPE_CHECKING from pygeoapi.error import GenericError +if TYPE_CHECKING: + from pygeoapi.plugin import PluginContext + LOGGER = logging.getLogger(__name__) class BaseProcessor: """generic Processor ABC. Processes are inherited from this class""" - def __init__(self, processor_def: dict, process_metadata: dict): + def __init__( + self, processor_def: dict, process_metadata: dict, + context: Optional[PluginContext] = None, + ): """ Initialize object :param processor_def: processor definition :param process_metadata: process metadata `dict` + :param context: optional PluginContext with injected dependencies :returns: pygeoapi.processor.base.BaseProvider """ @@ -54,6 +65,18 @@ def __init__(self, processor_def: dict, process_metadata: dict): self.metadata = process_metadata self.supports_outputs = False + # Dependencies support + self._context = context + if context and context.logger: + self._logger = context.logger + else: + self._logger = LOGGER # Global fallback + + @property + def logger(self): + """Get logger (injected or global)""" + return self._logger + def set_job_id(self, job_id: str) -> None: """ Set the job_id within the processor diff --git a/pygeoapi/process/manager/base.py b/pygeoapi/process/manager/base.py index fadb90fbd..c3ebb46b0 100644 --- a/pygeoapi/process/manager/base.py +++ b/pygeoapi/process/manager/base.py @@ -31,17 +31,22 @@ # # ================================================================= +from __future__ import annotations + import collections import json import logging +import uuid from multiprocessing import dummy from pathlib import Path -from typing import Any, Dict, Tuple, Optional, OrderedDict -import uuid +from typing import Any, Dict, Optional, OrderedDict, Tuple, TYPE_CHECKING import requests from 
pygeoapi.plugin import load_plugin + +if TYPE_CHECKING: + from pygeoapi.plugin import PluginContext from pygeoapi.process.base import ( BaseProcessor, JobNotFoundError, @@ -64,11 +69,15 @@ class BaseManager: """generic Manager ABC""" processes: OrderedDict[str, Dict] - def __init__(self, manager_def: dict): + def __init__( + self, manager_def: dict, + context: Optional[PluginContext] = None, + ): """ Initialize object :param manager_def: manager definition + :param context: optional PluginContext with injected dependencies :returns: `pygeoapi.process.manager.base.BaseManager` """ @@ -91,6 +100,18 @@ def __init__(self, manager_def: dict): for id_, process_conf in manager_def.get('processes', {}).items(): self.processes[id_] = dict(process_conf) + # Dependencies support + self._context = context + if context and context.logger: + self._logger = context.logger + else: + self._logger = LOGGER # Global fallback + + @property + def logger(self): + """Get logger (injected or global)""" + return self._logger + def get_processor(self, process_id: str) -> BaseProcessor: """Instantiate a processor. @@ -426,7 +447,7 @@ def execute_process( # do we support sync? 
process_supports_sync = ( ProcessExecutionMode.sync_execute.value in job_control_options - ) + ) if not process_supports_sync: LOGGER.debug('Asynchronous execution') handler = self._execute_handler_async @@ -489,7 +510,7 @@ def _send_in_progress_notification(self, subscriber: Optional[Subscriber]): ) def _send_success_notification( - self, subscriber: Optional[Subscriber], outputs: Any + self, subscriber: Optional[Subscriber], outputs: Any ): if subscriber: response = requests.post(subscriber.success_uri, json=outputs) diff --git a/pygeoapi/provider/base.py b/pygeoapi/provider/base.py index 3dd4c5a18..62c2834d6 100644 --- a/pygeoapi/provider/base.py +++ b/pygeoapi/provider/base.py @@ -27,14 +27,20 @@ # # ================================================================= +from __future__ import annotations + import json import logging from enum import Enum from http import HTTPStatus +from typing import TYPE_CHECKING, Optional from pygeoapi.crs import DEFAULT_STORAGE_CRS, get_crs from pygeoapi.error import GenericError +if TYPE_CHECKING: + from pygeoapi.plugin import PluginContext + LOGGER = logging.getLogger(__name__) @@ -48,11 +54,12 @@ class SchemaType(Enum): class BaseProvider: """generic Provider ABC""" - def __init__(self, provider_def): + def __init__(self, provider_def, context: Optional[PluginContext] = None): """ Initialize object :param provider_def: provider definition + :param context: optional PluginContext with injected dependencies :returns: pygeoapi.provider.base.BaseProvider """ @@ -91,6 +98,28 @@ def __init__(self, provider_def): self.crs = None self.num_bands = None + # Dependencies support + self._context = context + if context and context.logger: + self._logger = context.logger + else: + self._logger = LOGGER # Global fallback + + if context and context.locales: + self._locales = context.locales + else: + self._locales = [] + + @property + def logger(self): + """Get logger (injected or global)""" + return self._logger + + @property + def 
locales(self): + """Get supported locales (injected or global)""" + return self._locales + def get_fields(self): """ Get provider field information (names, types) diff --git a/requirements-dev.txt b/requirements-dev.txt index d02e34262..45c423b23 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,7 +13,7 @@ starlette uvicorn[standard] httpx -# Pydantic/Dataclasses models +# Dataclasses models polyfactory # PEP8 diff --git a/requirements-django.txt b/requirements-django.txt index 92d78d376..0195664b2 100644 --- a/requirements-django.txt +++ b/requirements-django.txt @@ -1,3 +1,2 @@ Django django-cors-headers -pydantic diff --git a/requirements.txt b/requirements.txt index 82244eb4b..79e59c78c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ filelock Flask jinja2 jsonschema -pydantic pygeofilter pygeoif pyproj diff --git a/tests/test_models.py b/tests/test_models.py index f2419e455..83d5d4f9d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,7 +3,7 @@ # Authors: Francesco Bartoli # Tom Kralidis: # -# Copyright (c) 2025 Francesco Bartoli +# Copyright (c) 2026 Francesco Bartoli # Copyright (c) 2024 Tom Kralidis # # Permission is hereby granted, free of charge, to any person @@ -29,19 +29,665 @@ # # ================================================================= -from polyfactory.factories.pydantic_factory import ModelFactory +import pytest + +from polyfactory.factories.dataclass_factory import DataclassFactory from polyfactory.pytest_plugin import register_fixture -from pygeoapi.models.provider.base import GeospatialDataType +from pygeoapi.models.config import APIRules, APIRulesValidationError +from pygeoapi.models.openapi import OAPIFormat, SupportedFormats +from pygeoapi.models.provider.mvt import MVTTilesJson, VectorLayers +from pygeoapi.models.provider.base import ( + GeospatialDataType, + LinkType, + TileMatrixLimitsType, + TileMatrixSetEnumType, + TilePointType, + TileSetMetadata, + TwoDBoundingBoxType, + 
StyleType, + DataTypeEnum, +) @register_fixture -class GeospatialDataTypeFactory(ModelFactory[GeospatialDataType]): +class GeospatialDataTypeFactory(DataclassFactory[GeospatialDataType]): ... def test_provider_base_geospatial_data_type( - geospatial_data_type_factory: GeospatialDataTypeFactory) -> None: + geospatial_data_type_factory: GeospatialDataTypeFactory +) -> None: gdt_instance = geospatial_data_type_factory.build() assert gdt_instance.model_dump() assert isinstance(gdt_instance, GeospatialDataType) + + +# OAPIFormat dataclass tests + +class TestOAPIFormatCreation: + """Test OAPIFormat instantiation and validation.""" + + def test_default_is_yaml(self): + fmt = OAPIFormat() + assert fmt.root == SupportedFormats.YAML + + def test_create_with_enum(self): + fmt = OAPIFormat(root=SupportedFormats.JSON) + assert fmt.root == SupportedFormats.JSON + + def test_create_with_string_yaml(self): + fmt = OAPIFormat(root='yaml') + assert fmt.root == SupportedFormats.YAML + + def test_create_with_string_json(self): + fmt = OAPIFormat(root='json') + assert fmt.root == SupportedFormats.JSON + + def test_invalid_string_raises(self): + with pytest.raises(ValueError, match='Unsupported format'): + OAPIFormat(root='xml') + + def test_non_string_int_raises(self): + with pytest.raises(ValueError, match='must be a'): + OAPIFormat(root=42) + + def test_non_string_bool_raises(self): + with pytest.raises(ValueError, match='must be a'): + OAPIFormat(root=True) + + def test_non_string_none_raises(self): + with pytest.raises(ValueError, match='must be a'): + OAPIFormat(root=None) + + +class TestOAPIFormatEquality: + """Test OAPIFormat comparison with strings, enums, instances.""" + + def test_eq_string_yaml(self): + fmt = OAPIFormat(root='yaml') + assert fmt == 'yaml' + assert not (fmt == 'json') + + def test_eq_string_json(self): + fmt = OAPIFormat(root='json') + assert fmt == 'json' + assert not (fmt == 'yaml') + + def test_eq_enum(self): + fmt = OAPIFormat() + assert fmt == 
SupportedFormats.YAML + + def test_eq_instance(self): + assert OAPIFormat(root='yaml') == OAPIFormat() + + def test_eq_unsupported_type(self): + fmt = OAPIFormat() + assert not (fmt == 42) + + +class TestOAPIFormatModelDump: + """Test model_dump method.""" + + def test_model_dump_yaml(self): + fmt = OAPIFormat() + result = fmt.model_dump() + assert result == {'root': 'yaml'} + + def test_model_dump_json(self): + fmt = OAPIFormat(root='json') + result = fmt.model_dump() + assert result == {'root': 'json'} + + +class TestOAPIFormatOpenAPICompatibility: + """Test compatibility with openapi.py / asyncapi.py usage. + + Both modules receive a format string from click and compare + it directly: ``if output_format == 'yaml'``. + """ + + def test_click_string_passthrough(self): + """Simulates click passing a plain string.""" + format_ = 'yaml' + fmt = OAPIFormat(root=format_) + assert fmt == 'yaml' + + def test_click_json_passthrough(self): + format_ = 'json' + fmt = OAPIFormat(root=format_) + assert fmt == 'json' + + def test_plain_string_still_works(self): + """Current code passes raw strings, not OAPIFormat + instances. 
Verify raw strings match as before.""" + format_ = 'yaml' + assert format_ == 'yaml' + + +# APIRules dataclass tests + +class TestAPIRulesCreation: + """Test APIRules instantiation and validation.""" + + def test_valid_semver(self): + rules = APIRules(api_version='1.0.0') + assert rules.api_version == '1.0.0' + assert rules.url_prefix == '' + assert rules.version_header == '' + assert rules.strict_slashes is False + + def test_valid_semver_with_build(self): + rules = APIRules(api_version='2.1.3-beta') + assert rules.api_version == '2.1.3-beta' + + def test_invalid_semver_raises(self): + with pytest.raises( + APIRulesValidationError, + match='Invalid semantic version' + ): + APIRules(api_version='not-a-version') + + def test_empty_version_raises(self): + with pytest.raises(APIRulesValidationError): + APIRules(api_version='') + + def test_non_string_version_raises(self): + with pytest.raises( + APIRulesValidationError, match='must be a str' + ): + APIRules(api_version=123) + + def test_non_string_url_prefix_raises(self): + with pytest.raises( + APIRulesValidationError, match='url_prefix must be a str' + ): + APIRules(api_version='1.0.0', url_prefix=42) + + def test_non_string_version_header_raises(self): + with pytest.raises( + APIRulesValidationError, + match='version_header must be a str', + ): + APIRules(api_version='1.0.0', version_header=True) + + def test_non_bool_strict_slashes_raises(self): + with pytest.raises( + APIRulesValidationError, + match='strict_slashes must be a bool', + ): + APIRules(api_version='1.0.0', strict_slashes='yes') + + def test_full_creation(self): + rules = APIRules( + api_version='1.0.0', + url_prefix='/v{api_major}', + version_header='X-API-Version', + strict_slashes=True, + ) + assert rules.url_prefix == '/v{api_major}' + assert rules.version_header == 'X-API-Version' + assert rules.strict_slashes is True + + +class TestAPIRulesFactory: + """Test APIRules.create() factory method.""" + + def test_create_filters_unknown_keys(self): + 
rules = APIRules.create( + api_version='1.0.0', + unknown_key='ignored', + ) + assert rules.api_version == '1.0.0' + assert not hasattr(rules, 'unknown_key') + + def test_create_with_all_valid_keys(self): + rules = APIRules.create( + api_version='1.0.0', + url_prefix='/api', + version_header='API-Version', + strict_slashes=True, + ) + assert rules.url_prefix == '/api' + assert rules.strict_slashes is True + + def test_create_missing_version_raises(self): + with pytest.raises(APIRulesValidationError): + APIRules.create(url_prefix='/api') + + +class TestAPIRulesResponseHeaders: + """Test response_headers property.""" + + def test_no_header_configured(self): + rules = APIRules(api_version='1.0.0') + assert rules.response_headers == {} + + def test_header_configured(self): + rules = APIRules( + api_version='1.0.0', + version_header='X-API-Version', + ) + assert rules.response_headers == { + 'X-API-Version': '1.0.0' + } + + +class TestAPIRulesURLPrefix: + """Test get_url_prefix() with framework styles.""" + + def test_no_prefix(self): + rules = APIRules(api_version='1.0.0') + assert rules.get_url_prefix() == '' + + def test_bare_prefix(self): + rules = APIRules( + api_version='1.0.0', url_prefix='/v{api_major}' + ) + assert rules.get_url_prefix() == 'v1' + + def test_full_version_prefix(self): + rules = APIRules( + api_version='1.2.3', + url_prefix='/api/v{api_version}', + ) + assert rules.get_url_prefix() == 'api/v1.2.3' + + def test_flask_style(self): + rules = APIRules( + api_version='1.0.0', url_prefix='/v{api_major}' + ) + assert rules.get_url_prefix('flask') == '/v1' + + def test_starlette_style(self): + rules = APIRules( + api_version='1.0.0', url_prefix='/v{api_major}' + ) + assert rules.get_url_prefix('starlette') == '/v1' + + def test_django_style(self): + rules = APIRules( + api_version='1.0.0', url_prefix='/v{api_major}' + ) + assert rules.get_url_prefix('django') == r'^v1/' + + +class TestAPIRulesModelDump: + """Test model_dump interface for duck-typing 
compatibility.""" + + def test_model_dump(self): + rules = APIRules(api_version='1.0.0') + result = rules.model_dump() + assert result == { + 'api_version': '1.0.0', + 'url_prefix': '', + 'version_header': '', + 'strict_slashes': False, + } + + def test_model_dump_exclude_none(self): + rules = APIRules(api_version='1.0.0') + result = rules.model_dump(exclude_none=True) + assert 'api_version' in result + + +# Tile models dataclass tests + +class TestTileMatrixSetEnumTypeValidation: + """Test TileMatrixSetEnumType type validation.""" + + def test_valid_creation(self): + tms_enum = TileMatrixSetEnumType( + tileMatrixSet='WebMercatorQuad', + tileMatrixSetURI='http://example.com', + crs='EPSG:3857', + title='Web Mercator', + orderedAxes=['X', 'Y'], + wellKnownScaleSet='GoogleMaps', + tileMatrices=[{'id': '0'}], + ) + assert tms_enum.tileMatrixSet == 'WebMercatorQuad' + + def test_non_string_tileMatrixSet_raises(self): + with pytest.raises(ValueError): + TileMatrixSetEnumType(tileMatrixSet=123) + + def test_non_list_orderedAxes_raises(self): + with pytest.raises(ValueError): + TileMatrixSetEnumType(orderedAxes='X,Y') + + def test_model_dump(self): + tms_enum = TileMatrixSetEnumType( + tileMatrixSet='Test', crs='EPSG:4326', + title='Test', tileMatrixSetURI='http://x', + orderedAxes=['X', 'Y'], + tileMatrices=[], + ) + result = tms_enum.model_dump() + assert result['tileMatrixSet'] == 'Test' + assert result['crs'] == 'EPSG:4326' + + +class TestTileMatrixLimitsTypeValidation: + """Test TileMatrixLimitsType type validation.""" + + def test_valid_creation(self): + tm_limits = TileMatrixLimitsType( + tileMatrix='0', minTileRow=0, + maxTileRow=1, minTileCol=0, maxTileCol=1, + ) + assert tm_limits.tileMatrix == '0' + + def test_non_string_tileMatrix_raises(self): + with pytest.raises(ValueError): + TileMatrixLimitsType(tileMatrix=0) + + def test_non_int_minTileRow_raises(self): + with pytest.raises(ValueError): + TileMatrixLimitsType( + tileMatrix='0', minTileRow='zero' + ) + + 
def test_model_dump(self): + tm_limits = TileMatrixLimitsType( + tileMatrix='0', minTileRow=0, + maxTileRow=10, minTileCol=0, maxTileCol=10, + ) + result = tm_limits.model_dump() + assert result['maxTileRow'] == 10 + + +class TestTwoDBoundingBoxTypeValidation: + """Test TwoDBoundingBoxType type validation.""" + + def test_valid_creation(self): + bbox = TwoDBoundingBoxType( + lowerLeft=[-180.0, -90.0], + upperRight=[180.0, 90.0], + ) + assert bbox.lowerLeft == [-180.0, -90.0] + + def test_non_list_lowerLeft_raises(self): + with pytest.raises(ValueError): + TwoDBoundingBoxType(lowerLeft='not a list') + + def test_non_optional_str_crs_raises(self): + with pytest.raises(ValueError): + TwoDBoundingBoxType(crs=42) + + def test_model_dump_exclude_none(self): + bbox = TwoDBoundingBoxType( + lowerLeft=[-180.0, -90.0], + upperRight=[180.0, 90.0], + ) + result = bbox.model_dump(exclude_none=True) + assert 'crs' not in result + + +class TestLinkTypeValidation: + """Test LinkType type validation.""" + + def test_valid_creation(self): + link = LinkType( + href='http://example.com', + rel='item', + type_='application/json', + ) + assert link.href == 'http://example.com' + + def test_non_string_href_raises(self): + with pytest.raises(ValueError): + LinkType(href=42) + + def test_non_string_rel_raises(self): + with pytest.raises(ValueError): + LinkType(href='http://x', rel=123) + + def test_non_int_length_raises(self): + with pytest.raises(ValueError): + LinkType(href='http://x', length='big') + + def test_model_dump_renames_type(self): + link = LinkType( + href='http://x', rel='item', + type_='application/json', + ) + result = link.model_dump(exclude_none=True) + assert 'type' in result + assert 'type_' not in result + + def test_model_dump_exclude_none(self): + link = LinkType(href='http://x') + result = link.model_dump(exclude_none=True) + assert 'href' in result + assert 'rel' not in result + assert 'title' not in result + + +class TestGeospatialDataTypeValidation: + """Test 
GeospatialDataType type validation.""" + + def test_valid_creation(self): + geo_dt = GeospatialDataType( + id='layer1', title='Layer 1', + dataType=DataTypeEnum.VECTOR, + ) + assert geo_dt.id == 'layer1' + + def test_non_string_id_raises(self): + with pytest.raises(ValueError): + GeospatialDataType(id=123) + + def test_model_dump_enum_value(self): + geo_dt = GeospatialDataType( + id='layer1', dataType=DataTypeEnum.VECTOR, + ) + result = geo_dt.model_dump(exclude_none=True) + assert result['dataType'] == 'vector' + + +class TestStyleTypeValidation: + """Test StyleType type validation.""" + + def test_valid_creation(self): + style = StyleType(id='default', title='Default') + assert style.id == 'default' + + def test_non_string_id_raises(self): + with pytest.raises(ValueError): + StyleType(id=42) + + +class TestTilePointTypeValidation: + """Test TilePointType type validation.""" + + def test_valid_creation(self): + tp = TilePointType( + crs='EPSG:4326', tileMatrix='0', + coordinates=[0.0, 0.0], + ) + assert tp.crs == 'EPSG:4326' + + def test_non_string_crs_raises(self): + with pytest.raises(ValueError): + TilePointType(crs=42) + + def test_non_string_tileMatrix_raises(self): + with pytest.raises(ValueError): + TilePointType(crs='EPSG:4326', tileMatrix=0) + + +class TestTileSetMetadataValidation: + """Test TileSetMetadata type validation.""" + + def test_valid_creation(self): + ts_meta = TileSetMetadata( + title='Test', crs='EPSG:4326', + ) + assert ts_meta.title == 'Test' + + def test_non_string_title_raises(self): + with pytest.raises(ValueError): + TileSetMetadata(title=42) + + def test_non_string_crs_raises(self): + with pytest.raises(ValueError): + TileSetMetadata(crs=123) + + def test_model_dump_renames_license(self): + ts_meta = TileSetMetadata(license_='MIT') + result = ts_meta.model_dump(exclude_none=True) + assert 'license' in result + assert 'license_' not in result + + def test_model_dump_nested_links(self): + link = LinkType( + href='http://x', rel='item', 
+ type_='application/json', + ) + ts_meta = TileSetMetadata(links=[link]) + result = ts_meta.model_dump(exclude_none=True) + assert isinstance(result['links'][0], dict) + assert result['links'][0]['href'] == 'http://x' + assert 'type' in result['links'][0] + + def test_model_dump_exclude_none(self): + ts_meta = TileSetMetadata(title='Test') + result = ts_meta.model_dump(exclude_none=True) + assert 'title' in result + assert 'description' not in result + assert 'version' not in result + + +# MVT models dataclass tests + +class TestVectorLayersValidation: + """Test VectorLayers type validation.""" + + def test_valid_creation(self): + vector_lyr = VectorLayers( + id='layer1', description='A layer', + minzoom=0, maxzoom=14, + fields={'name': 'String'}, + ) + assert vector_lyr.id == 'layer1' + + def test_non_string_id_raises(self): + with pytest.raises(ValueError): + VectorLayers(id=42) + + def test_non_int_minzoom_raises(self): + with pytest.raises(ValueError): + VectorLayers(id='layer1', minzoom='zero') + + def test_model_dump(self): + vector_lyr = VectorLayers( + id='layer1', minzoom=0, maxzoom=14, + ) + result = vector_lyr.model_dump(exclude_none=True) + assert result['id'] == 'layer1' + assert result['minzoom'] == 0 + assert 'description' not in result + + def test_model_dump_all_fields(self): + vector_lyr = VectorLayers( + id='layer1', description='desc', + minzoom=0, maxzoom=14, + fields={'name': 'String'}, + ) + result = vector_lyr.model_dump() + assert result['fields'] == {'name': 'String'} + + +class TestMVTTilesJsonValidation: + """Test MVTTilesJson type validation.""" + + def test_valid_creation_defaults(self): + mvt_tj = MVTTilesJson() + assert mvt_tj.tilejson == '3.0.0' + assert mvt_tj.name is None + + def test_valid_creation_full(self): + mvt_tj = MVTTilesJson( + tilejson='3.0.0', + name='test', + tiles='http://example.com/{z}/{x}/{y}.pbf', + minzoom=0, maxzoom=14, + description='Test tileset', + ) + assert mvt_tj.name == 'test' + + def 
test_non_string_tilejson_raises(self): + with pytest.raises(ValueError): + MVTTilesJson(tilejson=3) + + def test_non_string_name_raises(self): + with pytest.raises(ValueError): + MVTTilesJson(name=42) + + def test_non_int_minzoom_raises(self): + with pytest.raises(ValueError): + MVTTilesJson(minzoom='zero') + + def test_kwargs_from_dict(self): + """Test instantiation from dict, as used in providers.""" + data = { + 'tilejson': '3.0.0', + 'name': 'test', + 'tiles': 'http://example.com/{z}/{x}/{y}.pbf', + } + mvt_tj = MVTTilesJson(**data) + assert mvt_tj.name == 'test' + + def test_kwargs_ignores_unknown_keys(self): + """Extra keys in JSON metadata are silently ignored.""" + data = { + 'tilejson': '3.0.0', + 'name': 'test', + 'version': 2, + 'format': 'pbf', + 'json': '{}', + } + mvt_tj = MVTTilesJson(**data) + assert mvt_tj.name == 'test' + assert not hasattr(mvt_tj, 'format') + + def test_str_to_int_coercion(self): + """String zoom values are coerced to int.""" + mvt_tj = MVTTilesJson(minzoom='0', maxzoom='14') + assert mvt_tj.minzoom == 0 + assert mvt_tj.maxzoom == 14 + + def test_model_dump(self): + mvt_tj = MVTTilesJson( + name='test', + tiles='http://example.com', + ) + result = mvt_tj.model_dump() + assert result['tilejson'] == '3.0.0' + assert result['name'] == 'test' + + def test_model_dump_exclude_none(self): + mvt_tj = MVTTilesJson() + result = mvt_tj.model_dump(exclude_none=True) + assert 'tilejson' in result + assert 'name' not in result + + def test_model_dump_with_vector_layers(self): + vector_lyr = VectorLayers(id='layer1', minzoom=0, maxzoom=14) + mvt_tj = MVTTilesJson(vector_layers=[vector_lyr]) + result = mvt_tj.model_dump(exclude_none=True) + assert isinstance(result['vector_layers'][0], dict) + assert result['vector_layers'][0]['id'] == 'layer1' + + def test_field_assignment_after_creation(self): + """Test that fields can be set after init, + as done in mvt_tippecanoe.py.""" + mvt_tj = MVTTilesJson() + mvt_tj.tiles = 'http://example.com' + 
mvt_tj.vector_layers = [ + VectorLayers(id='l1', minzoom=0, maxzoom=5) + ] + assert mvt_tj.tiles == 'http://example.com' + result = mvt_tj.model_dump(exclude_none=True) + assert result['vector_layers'][0]['id'] == 'l1' diff --git a/tests/test_plugin_context.py b/tests/test_plugin_context.py new file mode 100644 index 000000000..646dd11df --- /dev/null +++ b/tests/test_plugin_context.py @@ -0,0 +1,193 @@ +# ================================================================= +# +# Authors: Francesco Bartoli +# +# Copyright (c) 2026 Francesco Bartoli +# +# Permission is hereby granted, free of charge, to any person +# obtaining a copy of this software and associated documentation +# files (the "Software"), to deal in the Software without +# restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following +# conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +# OTHER DEALINGS IN THE SOFTWARE. 
+# +# ================================================================= + +"""Tests for PluginContext dataclass""" + +import logging +from unittest.mock import Mock + +from pygeoapi.plugin import PluginContext + + +def test_plugin_context_creation_minimal(): + """Test creating PluginContext with minimal config""" + config = { + "name": "GeoJSON", + "type": "feature", + "data": "tests/data/obs.geojson", + } + + context = PluginContext(config=config) + + assert context.config == config + assert context.logger is None + assert context.locales is None + assert context.base_url is None + + +def test_plugin_context_creation_with_logger(): + """Test creating PluginContext with custom logger""" + config = {"name": "Test", "type": "feature", "data": "test.json"} + mock_logger = Mock(spec=logging.Logger) + + context = PluginContext(config=config, logger=mock_logger) + + assert context.config == config + assert context.logger == mock_logger + assert context.locales is None + assert context.base_url is None + + +def test_plugin_context_creation_full(): + """Test creating PluginContext with all parameters""" + config = {"name": "Test", "type": "feature", "data": "test.json"} + mock_logger = Mock(spec=logging.Logger) + locales = ["en", "it", "fr"] + base_url = "https://api.example.com" + + context = PluginContext( + config=config, logger=mock_logger, locales=locales, base_url=base_url + ) + + assert context.config == config + assert context.logger == mock_logger + assert context.locales == locales + assert context.base_url == base_url + + +def test_plugin_context_to_dict_minimal(): + """Test converting PluginContext to dict with minimal config""" + config = {"name": "Test", "type": "feature", "data": "test.json"} + context = PluginContext(config=config) + + result = context.to_dict() + + assert result == config + assert "_logger" not in result + assert "_locales" not in result + assert "_base_url" not in result + + +def test_plugin_context_to_dict_with_logger(): + """Test 
converting PluginContext to dict with logger""" + config = {"name": "Test", "type": "feature", "data": "test.json"} + mock_logger = Mock(spec=logging.Logger) + + context = PluginContext(config=config, logger=mock_logger) + result = context.to_dict() + + assert result["name"] == "Test" + assert result["type"] == "feature" + assert result["data"] == "test.json" + assert result["_logger"] == mock_logger + + +def test_plugin_context_to_dict_full(): + """Test converting PluginContext to dict with all fields""" + config = {"name": "Test", "type": "feature", "data": "test.json"} + mock_logger = Mock(spec=logging.Logger) + locales = ["en", "it"] + base_url = "https://api.example.com" + + context = PluginContext( + config=config, logger=mock_logger, locales=locales, base_url=base_url + ) + result = context.to_dict() + + # Original config preserved + assert result["name"] == "Test" + assert result["type"] == "feature" + assert result["data"] == "test.json" + + # Injected dependencies added with underscore prefix + assert result["_logger"] == mock_logger + assert result["_locales"] == locales + assert result["_base_url"] == base_url + + +def test_plugin_context_extensible_subclassing(): + """Test that PluginContext can be extended via subclassing""" + from dataclasses import dataclass + from typing import Optional + + @dataclass + class ExtendedContext(PluginContext): + """Extended context with custom fields""" + + metrics_collector: Optional[Mock] = None + cache_backend: Optional[Mock] = None + + config = {"name": "Test", "type": "feature", "data": "test.json"} + mock_metrics = Mock() + mock_cache = Mock() + + context = ExtendedContext( + config=config, + logger=Mock(), + metrics_collector=mock_metrics, + cache_backend=mock_cache, + ) + + # Standard fields work + assert context.config == config + assert context.logger is not None + + # Extended fields work + assert context.metrics_collector == mock_metrics + assert context.cache_backend == mock_cache + + # to_dict() still 
works (includes base fields only) + result = context.to_dict() + assert "_logger" in result + # Extended fields not in to_dict (as expected) + + +def test_plugin_context_config_immutability(): + """Test that modifying config dict doesn't affect context""" + original_config = {"name": "Test", "type": "feature", "data": "test.json"} + context = PluginContext(config=original_config) + + # Modify original config + original_config["name"] = "Modified" + + # Context config should be affected (dict is mutable) + # This is expected behavior - user should copy if needed + assert context.config["name"] == "Modified" + + +def test_plugin_context_config_copy(): + """Test creating context with config copy for immutability""" + original_config = {"name": "Test", "type": "feature", "data": "test.json"} + context = PluginContext(config=dict(original_config)) + + # Modify original config + original_config["name"] = "Modified" + + # Context config should be unchanged + assert context.config["name"] == "Test" diff --git a/tests/test_plugin_dependencies.py b/tests/test_plugin_dependencies.py new file mode 100644 index 000000000..8e307635e --- /dev/null +++ b/tests/test_plugin_dependencies.py @@ -0,0 +1,271 @@ +# ================================================================= +# +# Authors: Francesco Bartoli +# +# Copyright (c) 2026 Francesco Bartoli +# +# Permission is hereby granted, free of charge, to any person +# obtaining a copy of this software and associated documentation +# files (the "Software"), to deal in the Software without +# restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following +# conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. 
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+#
+# =================================================================
+
+"""Tests for PluginContext-based dependency injection in plugin loading"""
+
+import logging
+from pathlib import Path
+from unittest.mock import Mock
+
+import pytest
+
+from pygeoapi.plugin import PluginContext, load_plugin
+
+
+@pytest.fixture
+def geojson_config() -> dict:
+    """Minimal GeoJSON feature provider configuration"""
+    return {
+        "name": "GeoJSON",
+        "type": "feature",
+        "data": str(Path(__file__).parent / "data" / "obs.geojson"),
+        "id_field": "id",
+    }
+
+
+@pytest.fixture
+def mock_logger() -> Mock:
+    """Logger mock constrained to the logging.Logger interface"""
+    return Mock(spec=logging.Logger)
+
+
+def test_load_plugin_legacy_without_context(geojson_config):
+    """Test loading plugin without context (backwards compatibility)"""
+    # Legacy call signature (no context kwarg) must keep working
+    provider = load_plugin("provider", geojson_config)
+
+    assert provider is not None
+    assert provider.name == "GeoJSON"
+    assert provider.type == "feature"
+
+
+def test_load_plugin_with_context_none(geojson_config):
+    """Test loading plugin with context=None (explicit)"""
+    # Explicitly passing None should behave exactly like legacy mode
+    provider = load_plugin("provider", geojson_config, context=None)
+
+    assert provider is not None
+    assert provider.name == "GeoJSON"
+
+
+def test_load_plugin_with_context_minimal(geojson_config):
+    """Test loading plugin with a minimal (config-only) context"""
+    context = PluginContext(config=geojson_config)
+
+    # Provider should load even if it doesn't support context
+    # (will fall back to legacy constructor)
+    provider = load_plugin("provider", geojson_config, context=context)
+
+    assert provider is not None
+    assert provider.name == "GeoJSON"
+
+
+def test_load_plugin_with_context_logger(geojson_config, mock_logger):
+    """Test loading plugin with context containing a logger"""
+    context = PluginContext(config=geojson_config, logger=mock_logger)
+
+    provider = load_plugin("provider", geojson_config, context=context)
+
+    assert provider is not None
+    # Note: Current providers don't support context yet
+    # This test verifies load_plugin handles context gracefully
+
+
+def test_load_plugin_with_context_full(geojson_config, mock_logger):
+    """Test loading plugin with a fully-populated context"""
+    context = PluginContext(
+        config=geojson_config,
+        logger=mock_logger,
+        locales=["en", "it", "fr"],
+        base_url="https://api.example.com",
+    )
+
+    provider = load_plugin("provider", geojson_config, context=context)
+
+    assert provider is not None
+    assert provider.name == "GeoJSON"
+
+
+def test_load_plugin_invalid_plugin_type():
+    """Test loading plugin with invalid type raises error"""
+    config = {"name": "Test", "data": "test.json"}
+
+    with pytest.raises(Exception):  # Should raise InvalidPluginError
+        load_plugin("invalid_type", config)
+
+
+def test_load_plugin_invalid_plugin_name():
+    """Test loading plugin with an unknown plugin name raises error"""
+    config = {
+        "name": "NonExistentProvider",
+        "type": "feature",
+        "data": "test.json",
+    }
+
+    with pytest.raises(Exception):  # Should raise InvalidPluginError
+        load_plugin("provider", config)
+
+
+def test_load_plugin_multiple_providers_with_context(mock_logger):
+    """Test loading multiple providers with independent contexts"""
+    # Create two different contexts
+    context1 = PluginContext(
+        config={
+            "name": "GeoJSON",
+            "type": "feature",
+            "data": str(Path(__file__).parent / "data" / "obs.geojson"),
+        },
+        logger=mock_logger,
+        base_url="https://api1.example.com",
+    )
+
+    context2 = PluginContext(
+        config={
+            "name": "CSV",
+            "type": "feature",
+            "data": str(Path(__file__).parent / "data" / "obs.csv"),
+            "id_field": "id",
+            "geometry": {"x_field": "long", "y_field": "lat"},
+        },
+        logger=mock_logger,
+        base_url="https://api2.example.com",
+    )
+
+    # Load two providers with different contexts
+    provider1 = load_plugin("provider", context1.config, context=context1)
+    provider2 = load_plugin("provider", context2.config, context=context2)
+
+    # Both should load successfully
+    assert provider1.name == "GeoJSON"
+    assert provider2.name == "CSV"
+
+    # They should be different instances
+    assert provider1 is not provider2
+
+
+def test_load_plugin_process_without_context():
+    """Test loading process without context (backwards compatibility)"""
+    config = {"name": "HelloWorld"}
+
+    processor = load_plugin("process", config)
+
+    assert processor is not None
+
+
+def test_load_plugin_process_with_context(mock_logger):
+    """Test loading process with context"""
+    context = PluginContext(config={"name": "HelloWorld"}, logger=mock_logger)
+
+    processor = load_plugin("process", context.config, context=context)
+
+    assert processor is not None
+
+
+def test_load_plugin_process_manager_without_context():
+    """Test loading process manager without context"""
+    config = {"name": "Dummy"}
+
+    manager = load_plugin("process_manager", config)
+
+    assert manager is not None
+
+
+def test_load_plugin_process_manager_with_context(mock_logger):
+    """Test loading process manager with context"""
+    context = PluginContext(config={"name": "Dummy"}, logger=mock_logger)
+
+    manager = load_plugin("process_manager", context.config, context=context)
+
+    assert manager is not None
+
+
+def test_load_plugin_custom_plugin_with_dotted_path():
+    """Test loading custom plugin using dotted path notation"""
+    # This tests that context works with custom plugins too
+    config = {
+        "name": "pygeoapi.provider.geojson.GeoJSONProvider",
+        "type": "feature",
+        "data": str(Path(__file__).parent / "data" / "obs.geojson"),
+    }
+
+    context = PluginContext(config=config)
+    provider = load_plugin("provider", config, context=context)
+
+    assert provider is not None
+
+
+def test_context_extensibility_in_plugin_loading(mock_logger):
+    """Test that a subclassed (extended) context works with load_plugin"""
+    from dataclasses import dataclass
+    from typing import Optional
+
+    @dataclass
+    class ExtendedContext(PluginContext):
+        """Extended context for testing"""
+
+        custom_field: Optional[str] = None
+
+    config = {
+        "name": "GeoJSON",
+        "type": "feature",
+        "data": str(Path(__file__).parent / "data" / "obs.geojson"),
+    }
+
+    context = ExtendedContext(
+        config=config, logger=mock_logger, custom_field="test_value"
+    )
+
+    # Should work even with extended context
+    provider = load_plugin("provider", config, context=context)
+
+    assert provider is not None
+    assert provider.name == "GeoJSON"
+
+
+@pytest.mark.parametrize(
+    "plugin_type,config",
+    [
+        (
+            "provider",
+            {"name": "GeoJSON", "type": "feature", "data": "test.geojson"},
+        ),
+        ("process", {"name": "HelloWorld"}),
+        ("process_manager", {"name": "Dummy"}),
+    ],
+)
+def test_load_plugin_context_backwards_compatible(
+    plugin_type, config, mock_logger
+):
+    """Test that passing a context doesn't break any plugin type"""
+    context = PluginContext(config=config, logger=mock_logger)
+
+    # Should load without errors (may fall back to legacy constructor)
+    plugin = load_plugin(plugin_type, config, context=context)
+
+    assert plugin is not None