Skip to content

Commit 5e65d83

Browse files
authored
refactor: improve pydantic json dump (#2130)
* refactor: improve json encode for JSONField with pydantic model as value * chore: upgrade deps * tests: fix postgres test error
1 parent ebc28bf commit 5e65d83

6 files changed

Lines changed: 538 additions & 370 deletions

File tree

tests/contrib/test_pydantic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,8 @@ async def test_json_field(db):
13791379
"data_default": {"a": 1},
13801380
"data_validate": None,
13811381
"data_pydantic": json_pydantic_default.model_dump(),
1382+
"data_decimal": None,
1383+
"data_index": None,
13821384
}
13831385
ret1 = creator.model_validate(json_field_1_get).model_dump()
13841386
assert ret1 == {
@@ -1388,6 +1390,8 @@ async def test_json_field(db):
13881390
"data_default": {"a": 1},
13891391
"data_validate": None,
13901392
"data_pydantic": json_pydantic_default.model_dump(),
1393+
"data_decimal": None,
1394+
"data_index": None,
13911395
}
13921396

13931397

tests/fields/test_json.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from decimal import Decimal
2+
13
import pytest
4+
from pydantic import BaseModel
25

36
from tests import testmodels
47
from tortoise.contrib.test import requireCapability
@@ -10,6 +13,7 @@
1013
IntegrityError,
1114
)
1215
from tortoise.fields import JSONField
16+
from tortoise.indexes import Index
1317

1418

1519
@pytest.mark.asyncio
@@ -19,6 +23,12 @@ async def test_empty(db):
1923
await testmodels.JSONFields.create()
2024

2125

26+
class MyPydanticModel(BaseModel):
27+
name: str
28+
idx: Index
29+
model_config = dict(arbitrary_types_allowed=True)
30+
31+
2232
@pytest.mark.asyncio
2333
async def test_create(db):
2434
"""Test JSON field creation and retrieval."""
@@ -29,6 +39,16 @@ async def test_create(db):
2939
await obj.save()
3040
obj2 = await testmodels.JSONFields.get(id=obj.id)
3141
assert obj == obj2
42+
obj3 = await testmodels.JSONFields.create(data="{}", data_decimal=Decimal(0))
43+
obj3 = await testmodels.JSONFields.get(id=obj3.id)
44+
assert str(obj3.data_decimal) == "0"
45+
pyd_model = MyPydanticModel(name="", idx=Index(fields=["data"]))
46+
obj4 = await testmodels.JSONFields.create(data="{}", data_index=pyd_model)
47+
obj4 = await testmodels.JSONFields.get(id=obj4.id)
48+
assert obj4.data_index == {
49+
"idx": {"expressions": [], "extra": "", "fields": ["data"], "name": None, "type": ""},
50+
"name": "",
51+
}
3252

3353

3454
@pytest.mark.asyncio
@@ -45,6 +65,12 @@ async def test_error(db):
4565
obj.data = "error json"
4666
await obj.save()
4767

68+
with pytest.raises(TypeError):
69+
await testmodels.JSONFields.create(data=Decimal(0))
70+
71+
with pytest.raises(TypeError):
72+
await testmodels.JSONFields.create(data_decimal=Index(fields=["data"]))
73+
4874

4975
@pytest.mark.asyncio
5076
async def test_update(db):

tests/testmodels.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44

55
import binascii
66
import datetime
7+
import json
78
import os
89
import re
910
import uuid
1011
from decimal import Decimal
1112
from enum import Enum, IntEnum
13+
from typing import Any
1214

1315
from pydantic import BaseModel, ConfigDict
1416

@@ -359,6 +361,24 @@ def raise_if_not_dict_or_list(value: dict | list):
359361
raise ValidationError("Value must be a dict or list.")
360362

361363

364+
class DecimalEncoder(json.JSONEncoder):
365+
def default(self, obj: Any) -> Any:
366+
if isinstance(obj, Decimal):
367+
return str(obj)
368+
return super().default(obj)
369+
370+
@classmethod
371+
def dumps(cls, obj: Any) -> str:
372+
return json.dumps(obj, cls=cls)
373+
374+
375+
class IndexEncoder(DecimalEncoder):
376+
def default(self, obj: Any) -> Any:
377+
if isinstance(obj, Index):
378+
return obj.describe()
379+
return super().default(obj)
380+
381+
362382
class JSONFields(Model):
363383
"""
364384
This model contains many JSON blobs
@@ -377,6 +397,10 @@ class JSONFields(Model):
377397
default=json_pydantic_default, field_type=TestSchemaForJSONField
378398
)
379399

400+
# Test cases where encoders are provided
401+
data_decimal = fields.JSONField[dict | list](null=True, encoder=DecimalEncoder.dumps)
402+
data_index = fields.JSONField[dict | list](null=True, encoder=IndexEncoder.dumps)
403+
380404

381405
class UUIDFields(Model):
382406
id = fields.UUIDField(primary_key=True, default=uuid.uuid1)

tests/utils/test_describe_model.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,48 @@ def test_describe_model_json():
15121512
"docstring": None,
15131513
"constraints": {},
15141514
},
1515+
{
1516+
"constraints": {},
1517+
"db_column": "data_decimal",
1518+
"db_default": "__NOT_SET__",
1519+
"db_field_types": {
1520+
"": "JSON",
1521+
"mssql": "NVARCHAR(MAX)",
1522+
"oracle": "NCLOB",
1523+
"postgres": "JSONB",
1524+
},
1525+
"default": None,
1526+
"description": None,
1527+
"docstring": None,
1528+
"field_type": "JSONField",
1529+
"generated": False,
1530+
"indexed": False,
1531+
"name": "data_decimal",
1532+
"nullable": True,
1533+
"python_type": "dict | list",
1534+
"unique": False,
1535+
},
1536+
{
1537+
"constraints": {},
1538+
"db_column": "data_index",
1539+
"db_default": "__NOT_SET__",
1540+
"db_field_types": {
1541+
"": "JSON",
1542+
"mssql": "NVARCHAR(MAX)",
1543+
"oracle": "NCLOB",
1544+
"postgres": "JSONB",
1545+
},
1546+
"default": None,
1547+
"description": None,
1548+
"docstring": None,
1549+
"field_type": "JSONField",
1550+
"generated": False,
1551+
"indexed": False,
1552+
"name": "data_index",
1553+
"nullable": True,
1554+
"python_type": "dict | list",
1555+
"unique": False,
1556+
},
15151557
],
15161558
"fk_fields": [],
15171559
"backward_fk_fields": [],
@@ -1655,6 +1697,48 @@ def test_describe_model_json_native():
16551697
"docstring": None,
16561698
"constraints": {},
16571699
},
1700+
{
1701+
"constraints": {},
1702+
"db_column": "data_decimal",
1703+
"db_default": DB_DEFAULT_NOT_SET,
1704+
"db_field_types": {
1705+
"": "JSON",
1706+
"mssql": "NVARCHAR(MAX)",
1707+
"oracle": "NCLOB",
1708+
"postgres": "JSONB",
1709+
},
1710+
"default": None,
1711+
"description": None,
1712+
"docstring": None,
1713+
"field_type": fields.JSONField,
1714+
"generated": False,
1715+
"indexed": False,
1716+
"name": "data_decimal",
1717+
"nullable": True,
1718+
"python_type": dict | list,
1719+
"unique": False,
1720+
},
1721+
{
1722+
"constraints": {},
1723+
"db_column": "data_index",
1724+
"db_default": DB_DEFAULT_NOT_SET,
1725+
"db_field_types": {
1726+
"": "JSON",
1727+
"mssql": "NVARCHAR(MAX)",
1728+
"oracle": "NCLOB",
1729+
"postgres": "JSONB",
1730+
},
1731+
"default": None,
1732+
"description": None,
1733+
"docstring": None,
1734+
"field_type": fields.JSONField,
1735+
"generated": False,
1736+
"indexed": False,
1737+
"name": "data_index",
1738+
"nullable": True,
1739+
"python_type": dict | list,
1740+
"unique": False,
1741+
},
16581742
],
16591743
"fk_fields": [],
16601744
"backward_fk_fields": [],

tortoise/fields/data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,9 @@ def to_db_value(
775775
return value
776776

777777
if _PydanticBaseModel is not None and isinstance(value, _PydanticBaseModel):
778+
if self.encoder is JSON_DUMPS:
779+
return value.model_dump_json()
780+
# self.encoder may be a custom json encoder
778781
value = value.model_dump()
779782

780783
return self.encoder(value)

0 commit comments

Comments
 (0)