diff --git a/sqlmodel/main.py b/sqlmodel/main.py index fbc44de0e5..e0fbdea432 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -38,6 +38,7 @@ from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import ( Mapped, + MappedColumn, RelationshipProperty, declared_attr, registry, @@ -688,10 +689,10 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore field_info = field sa_column = getattr(field_info, "sa_column", Undefined) - if isinstance(sa_column, Column): + if isinstance(sa_column, (Column, MappedColumn)): return sa_column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field_info, "primary_key", Undefined) diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py index e2ccc6d7ef..3ee5f50b9d 100644 --- a/tests/test_field_sa_column.py +++ b/tests/test_field_sa_column.py @@ -2,14 +2,16 @@ import pytest from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import mapped_column from sqlmodel import Field, SQLModel -def test_sa_column_takes_precedence() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_takes_precedence(clear_sqlmodel, column_class) -> None: class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, - sa_column=Column(String, primary_key=True, nullable=False), + sa_column=column_class(String, primary_key=True, nullable=False), ) # It would have been nullable with no sa_column @@ -17,62 +19,68 @@ class Item(SQLModel, table=True): assert isinstance(Item.id.type, String) # type: ignore -def test_sa_column_no_sa_args() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_no_sa_args(column_class) -> None: with pytest.raises(RuntimeError): class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, sa_column_args=[Integer], - sa_column=Column(Integer, primary_key=True), + sa_column=column_class(Integer, primary_key=True), ) -def test_sa_column_no_sa_kargs() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_no_sa_kargs(column_class) -> None: with pytest.raises(RuntimeError): class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, sa_column_kwargs={"primary_key": True}, - sa_column=Column(Integer, primary_key=True), + sa_column=column_class(Integer, primary_key=True), ) -def test_sa_column_no_type() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_no_type(column_class) -> None: with pytest.raises(RuntimeError): class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, sa_type=Integer, - sa_column=Column(Integer, primary_key=True), + sa_column=column_class(Integer, primary_key=True), ) -def test_sa_column_no_primary_key() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_no_primary_key(column_class) -> None: with pytest.raises(RuntimeError): class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, primary_key=True, - sa_column=Column(Integer, primary_key=True), + sa_column=column_class(Integer, primary_key=True), ) -def test_sa_column_no_nullable() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_no_nullable(column_class) -> None: with pytest.raises(RuntimeError): class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, nullable=True, - sa_column=Column(Integer, primary_key=True), + sa_column=column_class(Integer, primary_key=True), ) -def test_sa_column_no_foreign_key() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_no_foreign_key(clear_sqlmodel, column_class) -> None: with pytest.raises(RuntimeError): class Team(SQLModel, table=True): @@ -84,38 +92,41 @@ class Hero(SQLModel, table=True): team_id: Optional[int] = Field( default=None, foreign_key="team.id", - sa_column=Column(Integer, primary_key=True), + sa_column=column_class(Integer, primary_key=True), ) -def test_sa_column_no_unique() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_no_unique(column_class) -> None: with pytest.raises(RuntimeError): class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, unique=True, - sa_column=Column(Integer, primary_key=True), + sa_column=column_class(Integer, primary_key=True), ) -def test_sa_column_no_index() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_no_index(column_class) -> None: with pytest.raises(RuntimeError): class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, index=True, - sa_column=Column(Integer, primary_key=True), + sa_column=column_class(Integer, primary_key=True), ) -def test_sa_column_no_ondelete() -> None: +@pytest.mark.parametrize("column_class", [Column, mapped_column]) +def test_sa_column_no_ondelete(column_class) -> None: with pytest.raises(RuntimeError): class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, - sa_column=Column(Integer, primary_key=True), + sa_column=column_class(Integer, primary_key=True), ondelete="CASCADE", )