diff --git a/sqlmodel/main.py b/sqlmodel/main.py index fbc44de0e5..69488e4176 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,8 +1,10 @@ from __future__ import annotations import builtins +import inspect as inspect_module import ipaddress import uuid +import warnings import weakref from collections.abc import Mapping, Sequence, Set from datetime import date, datetime, time, timedelta @@ -11,6 +13,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Annotated, Any, Callable, ClassVar, @@ -22,6 +25,7 @@ ) from pydantic import BaseModel, EmailStr +from pydantic import Field as PydanticField from pydantic.fields import FieldInfo as PydanticFieldInfo from sqlalchemy import ( Boolean, @@ -88,6 +92,10 @@ ] OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"] +FIELD_ACCEPTED_KWARGS = set(inspect_module.signature(PydanticField).parameters.keys()) +if "schema_extra" in FIELD_ACCEPTED_KWARGS: + FIELD_ACCEPTED_KWARGS.remove("schema_extra") + def __dataclass_transform__( *, @@ -237,7 +245,16 @@ def Field( sa_type: Union[type[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, - schema_extra: Optional[dict[str, Any]] = None, + schema_extra: Annotated[ + Optional[dict[str, Any]], + deprecated( + """ + This parameter is deprecated. + Use `json_schema_extra` to add extra information to JSON schema. + """ + ), + ] = None, + json_schema_extra: Optional[dict[str, Any]] = None, ) -> Any: ... @@ -281,7 +298,16 @@ def Field( sa_type: Union[type[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, - schema_extra: Optional[dict[str, Any]] = None, + schema_extra: Annotated[ + Optional[dict[str, Any]], + deprecated( + """ + This parameter is deprecated. + Use `json_schema_extra` to add extra information to JSON schema. + """ + ), + ] = None, + json_schema_extra: Optional[dict[str, Any]] = None, ) -> Any: ... @@ -325,7 +351,16 @@ def Field( discriminator: Optional[str] = None, repr: bool = True, sa_column: Union[Column[Any], UndefinedType] = Undefined, - schema_extra: Optional[dict[str, Any]] = None, + schema_extra: Annotated[ + Optional[dict[str, Any]], + deprecated( + """ + This parameter is deprecated. + Use `json_schema_extra` to add extra information to JSON schema. + """ + ), + ] = None, + json_schema_extra: Optional[dict[str, Any]] = None, ) -> Any: ... @@ -367,9 +402,28 @@ def Field( sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, - schema_extra: Optional[dict[str, Any]] = None, + schema_extra: Annotated[ + Optional[dict[str, Any]], + deprecated( + """ + This parameter is deprecated. + Use `json_schema_extra` to add extra information to JSON schema. + """ + ), + ] = None, + json_schema_extra: Optional[dict[str, Any]] = None, ) -> Any: + if schema_extra: + warnings.warn( + "schema_extra parameter is deprecated. " + "Use json_schema_extra to add extra information to JSON schema.", + DeprecationWarning, + stacklevel=1, + ) + + current_json_schema_extra = json_schema_extra or {} current_schema_extra = schema_extra or {} + # Extract possible alias settings from schema_extra so we can control precedence schema_validation_alias = current_schema_extra.pop("validation_alias", None) schema_serialization_alias = current_schema_extra.pop("serialization_alias", None) @@ -417,6 +471,21 @@ def Field( serialization_alias or schema_serialization_alias or alias ) + # Handle a workaround when json_schema_extra was passed via schema_extra + if "json_schema_extra" in current_schema_extra: + json_schema_extra_from_schema_extra = current_schema_extra.pop( + "json_schema_extra" + ) + if not current_json_schema_extra: + current_json_schema_extra = json_schema_extra_from_schema_extra + # Split parameters from schema_extra to field_info_kwargs and json_schema_extra + for key, value in current_schema_extra.items(): + if key in FIELD_ACCEPTED_KWARGS: + field_info_kwargs[key] = value + else: + current_json_schema_extra[key] = value + field_info_kwargs["json_schema_extra"] = current_json_schema_extra + field_info = FieldInfo( default, default_factory=default_factory, diff --git a/tests/test_field_json_schema_extra.py b/tests/test_field_json_schema_extra.py new file mode 100644 index 0000000000..0cbe3dafaa --- /dev/null +++ b/tests/test_field_json_schema_extra.py @@ -0,0 +1,85 @@ +import pytest +from sqlmodel import Field, SQLModel + + +def test_json_schema_extra_applied(): + """test json_schema_extra is applied to the field""" + + class Item(SQLModel): + name: str = Field( + json_schema_extra={ + "example": "Sword of Power", + "x-custom-key": "Important Data", + } + ) + + schema = Item.model_json_schema() + name_schema = schema["properties"]["name"] + + assert name_schema["example"] == "Sword of Power" + assert name_schema["x-custom-key"] == "Important Data" + + +def test_schema_extra_and_json_schema_extra_conflict(): + """ + Test that passing schema_extra and json_schema_extra at the same time produces + a warning. + """ + + with pytest.warns(DeprecationWarning, match="schema_extra parameter is deprecated"): + Field(schema_extra={"legacy": 1}, json_schema_extra={"new": 2}) + + +def test_schema_extra_backward_compatibility(): + """ + test that schema_extra is backward compatible with json_schema_extra + """ + + with pytest.warns(DeprecationWarning, match="schema_extra parameter is deprecated"): + + class LegacyItem(SQLModel): + name: str = Field( + schema_extra={ + "example": "Sword of Old", + "x-custom-key": "Important Data", + "serialization_alias": "id_test", + } + ) + + schema = LegacyItem.model_json_schema() + name_schema = schema["properties"]["name"] + + assert name_schema["example"] == "Sword of Old" + assert name_schema["x-custom-key"] == "Important Data" + + # With Pydantic V1 serialization_alias from schema_extra is applied + field_info = LegacyItem.model_fields["name"] + assert field_info.serialization_alias == "id_test" + + +def test_json_schema_extra_mix_in_schema_extra(): + """ + Test workaround when json_schema_extra was passed via schema_extra. + """ + + with pytest.warns(DeprecationWarning, match="schema_extra parameter is deprecated"): + + class Item(SQLModel): + name: str = Field( + schema_extra={ + "json_schema_extra": { + "example": "Sword of Power", + "x-custom-key": "Important Data", + }, + "serialization_alias": "id_test", + } + ) + + schema = Item.model_json_schema() + + name_schema = schema["properties"]["name"] + assert name_schema["example"] == "Sword of Power" + assert name_schema["x-custom-key"] == "Important Data" + + field_info = Item.model_fields["name"] + assert field_info.serialization_alias == "id_test"