Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion docs/relations/foreign-key.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

`ForeignKey(to: Model, *, name: str = None, unique: bool = False, nullable: bool = True,
related_name: str = None, virtual: bool = False, onupdate: Union[ReferentialAction, str] = None,
ondelete: Union[ReferentialAction, str] = None, **kwargs: Any)`
ondelete: Union[ReferentialAction, str] = None, foreign_key_name: str = None, **kwargs: Any)`
has required parameters `to` that takes target `Model` class.

Sqlalchemy column and type are automatically taken from target `Model`.
Expand Down Expand Up @@ -220,6 +220,28 @@ Set the ForeignKey to its default value; a `server_default` for the ForeignKey m

Take `NO ACTION`; NO ACTION and RESTRICT are very much alike. The main difference between NO ACTION and RESTRICT is that with NO ACTION the referential integrity check is done after trying to alter the table. RESTRICT does the check before trying to execute the UPDATE or DELETE statement. Both referential actions act the same if the referential integrity check fails: the UPDATE or DELETE statement will result in an error.

## Overriding the foreign key constraint name

By default ormar generates the foreign key constraint name as
`fk_{source_table}_{target_table}_{target_pk}_{field_name}`. On databases with a
short identifier length limit (for example MySQL's 64 character limit) the
auto-generated name can be truncated or rejected. Pass `foreign_key_name` to use
a custom name instead:

```python
class Book(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="books")

id: int = ormar.Integer(primary_key=True)
author = ormar.ForeignKey(
Author,
foreign_key_name="fk_books_author",
)
```

When used on an abstract base class, each subclass suffixes the name with its
own tablename to avoid constraint name collisions across sibling tables.

## Relation Setup

You have several ways to set-up a relationship connection.
Expand Down
27 changes: 27 additions & 0 deletions docs/relations/many-to-many.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,33 @@ class StudentCourse(ormar.Model):
provide your own custom Through model you cannot change the names there and you need to use
same `through_relation_name` and `through_reverse_relation_name` parameters.

## Overriding foreign key constraint names on the through model

Auto-generated foreign key constraint names on the through model can be
overridden with:

* `through_foreign_key_name` - name of the FK constraint on the column that
references the model where `ManyToMany` is declared (the owner side).
* `through_reverse_foreign_key_name` - name of the FK constraint on the column
that references the target model.

This is primarily useful for databases with short identifier limits (for
example MySQL's 64 character limit) where the auto-generated name would be
truncated.

```python
class Student(ormar.Model):
ormar_config = base_ormar_config.copy()

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(
Course,
through_foreign_key_name="fk_sc_student",
through_reverse_foreign_key_name="fk_sc_course",
)
```

## Through Fields

The through field is auto added to the reverse side of the relation.
Expand Down
28 changes: 19 additions & 9 deletions ormar/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def __init__(self, **kwargs: Any) -> None:
self.through_reverse_relation_name = kwargs.pop(
"through_reverse_relation_name", None
)
self.through_foreign_key_name: Optional[str] = kwargs.pop(
"through_foreign_key_name", None
)
self.through_reverse_foreign_key_name: Optional[str] = kwargs.pop(
"through_reverse_foreign_key_name", None
)

self.skip_reverse: bool = kwargs.pop("skip_reverse", False)
self.skip_field: bool = kwargs.pop("skip_field", False)
Expand Down Expand Up @@ -257,16 +263,20 @@ def construct_constraints(self) -> list:
:return: list of sqlalchemy foreign keys - by default one.
:rtype: list[sqlalchemy.schema.ForeignKey]
"""
constraints = [
sqlalchemy.ForeignKey(
con.reference,
ondelete=con.ondelete,
onupdate=con.onupdate,
name=f"fk_{self.owner.ormar_config.tablename}_{self.to.ormar_config.tablename}"
f"_{self.to.get_column_alias(self.to.ormar_config.pkname)}_{self.name}",
constraints = []
for constraint in self.constraints:
owner_table = self.owner.ormar_config.tablename
target_table = self.to.ormar_config.tablename
target_pk = self.to.get_column_alias(self.to.ormar_config.pkname)
default_name = f"fk_{owner_table}_{target_table}_{target_pk}_{self.name}"
constraints.append(
sqlalchemy.ForeignKey(
constraint.reference,
ondelete=constraint.ondelete,
onupdate=constraint.onupdate,
name=constraint.name or default_name,
)
)
for con in self.constraints
]
return constraints

def get_column(self, name: str) -> sqlalchemy.Column:
Expand Down
18 changes: 17 additions & 1 deletion ormar/fields/foreign_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def populate_fk_params_based_on_to_model(
nullable: bool,
onupdate: Optional[str] = None,
ondelete: Optional[str] = None,
foreign_key_name: Optional[str] = None,
) -> tuple[Any, list, Any, Any]:
"""
Based on target to model to which relation leads to populates the type of the
Expand All @@ -96,6 +97,9 @@ def populate_fk_params_based_on_to_model(
:param ondelete: parameter passed to sqlalchemy.ForeignKey.
How to treat child rows on delete of parent (the one where FK is defined) model.
:type ondelete: str
:param foreign_key_name: optional override for the foreign key constraint name
emitted in migrations. Defaults to ``None``, which lets ormar generate a name.
:type foreign_key_name: Optional[str]
:return: tuple with target pydantic type, list of fk constraints and target col type
:rtype: tuple[Any, list, Any]
"""
Expand All @@ -111,7 +115,10 @@ def populate_fk_params_based_on_to_model(
)
constraints = [
ForeignKeyConstraint(
reference=fk_string, ondelete=ondelete, onupdate=onupdate, name=None
reference=fk_string,
ondelete=ondelete,
onupdate=onupdate,
name=foreign_key_name,
)
]
column_type = to_field.column_type
Expand Down Expand Up @@ -203,6 +210,7 @@ def ForeignKey( # type: ignore # noqa CFQ002
virtual: bool = False,
onupdate: Union[ReferentialAction, str, None] = None,
ondelete: Union[ReferentialAction, str, None] = None,
foreign_key_name: Optional[str] = None,
**kwargs: Any,
) -> "T":
"""
Expand Down Expand Up @@ -230,6 +238,10 @@ def ForeignKey( # type: ignore # noqa CFQ002
:param ondelete: parameter passed to sqlalchemy.ForeignKey.
How to treat child rows on delete of parent (the one where FK is defined) model.
:type ondelete: Union[ReferentialAction, str]
:param foreign_key_name: optional override for the foreign key constraint name
generated in migrations. Useful when the auto-generated name exceeds a database
specific identifier length limit (for example MySQL 64 chars).
:type foreign_key_name: Optional[str]
:param kwargs: all other args to be populated by BaseField
:type kwargs: Any
:return: ormar ForeignKeyField with relation to selected model
Expand Down Expand Up @@ -270,6 +282,7 @@ def ForeignKey( # type: ignore # noqa CFQ002
nullable=nullable,
ondelete=ondelete,
onupdate=onupdate,
foreign_key_name=foreign_key_name,
)

namespace = dict(
Expand All @@ -292,6 +305,7 @@ def ForeignKey( # type: ignore # noqa CFQ002
server_default=None,
onupdate=onupdate,
ondelete=ondelete,
foreign_key_name=foreign_key_name,
owner=owner,
self_reference=self_reference,
is_relation=True,
Expand All @@ -316,6 +330,7 @@ def __init__(self, **kwargs: Any) -> None:
self.to: type["Model"]
self.ondelete: str = kwargs.pop("ondelete", None)
self.onupdate: str = kwargs.pop("onupdate", None)
self.foreign_key_name: Optional[str] = kwargs.pop("foreign_key_name", None)
super().__init__(**kwargs)

def get_source_related_name(self) -> str:
Expand Down Expand Up @@ -447,6 +462,7 @@ def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None:
nullable=self.nullable,
ondelete=self.ondelete,
onupdate=self.onupdate,
foreign_key_name=self.foreign_key_name,
)

def _extract_model_from_sequence(
Expand Down
7 changes: 7 additions & 0 deletions ormar/fields/many_to_many.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def ManyToMany( # type: ignore
through_relation_name = kwargs.pop("through_relation_name", None)
through_reverse_relation_name = kwargs.pop("through_reverse_relation_name", None)

through_foreign_key_name = kwargs.pop("through_foreign_key_name", None)
through_reverse_foreign_key_name = kwargs.pop(
"through_reverse_foreign_key_name", None
)

if through is not None and through.__class__ != ForwardRef:
forbid_through_relations(cast(type["Model"], through))

Expand Down Expand Up @@ -171,6 +176,8 @@ def ManyToMany( # type: ignore
skip_field=skip_field,
through_relation_name=through_relation_name,
through_reverse_relation_name=through_reverse_relation_name,
through_foreign_key_name=through_foreign_key_name,
through_reverse_foreign_key_name=through_reverse_foreign_key_name,
)

Field = type("ManyToMany", (ManyToManyField, BaseField), {})
Expand Down
23 changes: 18 additions & 5 deletions ormar/models/helpers/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,16 @@ def adjust_through_many_to_many_model(model_field: "ManyToManyField") -> None:
)

create_and_append_m2m_fk(
model=model_field.to, model_field=model_field, field_name=parent_name
model=model_field.to,
model_field=model_field,
field_name=parent_name,
foreign_key_name=model_field.through_reverse_foreign_key_name,
)
create_and_append_m2m_fk(
model=model_field.owner, model_field=model_field, field_name=child_name
model=model_field.owner,
model_field=model_field,
field_name=child_name,
foreign_key_name=model_field.through_foreign_key_name,
)

create_pydantic_field(parent_name, model_field.to, model_field)
Expand All @@ -58,7 +64,10 @@ def adjust_through_many_to_many_model(model_field: "ManyToManyField") -> None:


def create_and_append_m2m_fk(
model: type["Model"], model_field: "ManyToManyField", field_name: str
model: type["Model"],
model_field: "ManyToManyField",
field_name: str,
foreign_key_name: Optional[str] = None,
) -> None:
"""
Registers sqlalchemy Column with sqlalchemy.ForeignKey leading to the model.
Expand All @@ -72,6 +81,8 @@ def create_and_append_m2m_fk(
:type model: Model class
:param model_field: field with ManyToMany relation
:type model_field: ManyToManyField field
:param foreign_key_name: optional override for the generated FK constraint name.
:type foreign_key_name: Optional[str]
"""
pk_alias = model.get_column_alias(model.ormar_config.pkname)
pk_column = next(
Expand All @@ -81,15 +92,17 @@ def create_and_append_m2m_fk(
raise ormar.ModelDefinitionError(
"ManyToMany relation cannot lead to field without pk"
)
through_table = model_field.through.ormar_config.tablename
target_table = model.ormar_config.tablename
default_name = f"fk_{through_table}_{target_table}_{field_name}_{pk_alias}"
column = sqlalchemy.Column(
field_name,
pk_column.type,
sqlalchemy.schema.ForeignKey(
model.ormar_config.tablename + "." + pk_alias,
ondelete="CASCADE",
onupdate="CASCADE",
name=f"fk_{model_field.through.ormar_config.tablename}_{model.ormar_config.tablename}"
f"_{field_name}_{pk_alias}",
name=foreign_key_name or default_name,
),
)
model_field.through.ormar_config.columns.append(column)
Expand Down
18 changes: 15 additions & 3 deletions ormar/models/metaclass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import dataclasses
import sys
import warnings
from pathlib import Path
Expand Down Expand Up @@ -385,13 +386,24 @@ def copy_data_from_parent_model( # noqa: CCR001
base_class=base_class, # type: ignore
)

elif field.is_relation and field.related_name:
elif field.is_relation and (
field.related_name or cast(ForeignKeyField, field).foreign_key_name
):
fk_field = cast(ForeignKeyField, field)
Field = type( # type: ignore
field.__class__.__name__, (ForeignKeyField, BaseField), {}
)
copy_field = Field(**dict(field.__dict__))
related_name = field.related_name + "_" + table_name
copy_field.related_name = related_name # type: ignore
if fk_field.related_name:
related_name = fk_field.related_name + "_" + table_name
copy_field.related_name = related_name # type: ignore
if fk_field.foreign_key_name:
new_fk_name = f"{fk_field.foreign_key_name}_{table_name}"
copy_field.foreign_key_name = new_fk_name # type: ignore
copy_field.constraints = [
dataclasses.replace(constraint, name=new_fk_name)
for constraint in fk_field.constraints
]
parent_fields[field_name] = copy_field
else:
parent_fields[field_name] = field
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import ormar
from tests.lifespan import init_tests
from tests.settings import create_config

base_ormar_config = create_config()


class FkInhParent(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="fk_inh_parents")

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)


class FkInhBase(ormar.Model):
ormar_config = base_ormar_config.copy(abstract=True)

id: int = ormar.Integer(primary_key=True)
parent = ormar.ForeignKey(
FkInhParent,
related_name="kids",
foreign_key_name="fk_custom_parent",
)


class FkInhChildA(FkInhBase):
ormar_config = base_ormar_config.copy(tablename="fk_inh_child_a")


class FkInhChildB(FkInhBase):
ormar_config = base_ormar_config.copy(tablename="fk_inh_child_b")


create_test_database = init_tests(base_ormar_config)


def _fk_names(table):
return [fk.name for col in table.c for fk in col.foreign_keys]


def test_foreign_key_name_is_suffixed_per_subclass_to_avoid_conflicts():
a_names = _fk_names(FkInhChildA.ormar_config.table)
b_names = _fk_names(FkInhChildB.ormar_config.table)
assert "fk_custom_parent_fk_inh_child_a" in a_names
assert "fk_custom_parent_fk_inh_child_b" in b_names
assert set(a_names).isdisjoint(set(b_names))
Loading