Skip to content

Commit 2c997e4

Browse files
author
Boris Smidt
committed
Add a type override function for domain types.
1 parent f14a9b7 commit 2c997e4

19 files changed

Lines changed: 287 additions & 22 deletions

File tree

README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,42 @@ class Status(str, enum.Enum):
7676
OPEN = "op!en"
7777
CLOSED = "clo@sed"
7878
```
79+
80+
### Map domains (and other unknown types) to Python types
81+
82+
Option: `domain_overrides`
83+
84+
sqlc does not pass `CREATE DOMAIN` definitions (their base type or `CHECK`
85+
constraints) to code generation plugins, so columns using a domain are emitted
86+
as `Any` and a `unknown PostgreSQL type` warning is logged. The
87+
`domain_overrides` option lets you map a PostgreSQL type name to a
88+
fully-qualified Python type. The required `import` is added automatically,
89+
including for nested modules.
90+
91+
```yaml
92+
options:
93+
package: authors
94+
domain_overrides:
95+
job_status: my.module.JobStatus
96+
positive_int: decimal.Decimal
97+
```
98+
99+
Given a domain `job_status` used by a `status` column, this generates:
100+
101+
```py
102+
import decimal
103+
104+
import my.module
105+
106+
107+
@dataclasses.dataclass()
108+
class Job:
109+
id: int
110+
status: my.module.JobStatus
111+
priority: Optional[decimal.Decimal]
112+
```
113+
114+
The key is matched against the column's data type, its bare type name, and its
115+
schema-qualified name (e.g. `public.job_status`), so you can key the override
116+
however is most convenient. This also works for any other type sqlc reports as
117+
unknown, not just domains.

internal/config.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,10 @@ type Config struct {
1010
EmitStrEnum bool `json:"emit_str_enum"`
1111
QueryParameterLimit *int32 `json:"query_parameter_limit"`
1212
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
13+
14+
// DomainOverrides maps a PostgreSQL type name (typically a DOMAIN, whose
15+
// definition sqlc does not pass to plugins) to a fully-qualified Python
16+
// type. For example {"job_status": "my.module.JobStatus"} emits a
17+
// "import my.module" and annotates the column as "my.module.JobStatus".
18+
DomainOverrides map[string]string `json:"domain_overrides"`
1319
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.31.1
4+
import dataclasses
5+
import decimal
6+
from typing import Optional
7+
8+
import my.module
9+
10+
11+
@dataclasses.dataclass()
12+
class Job:
13+
id: int
14+
status: my.module.JobStatus
15+
priority: Optional[decimal.Decimal]
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.31.1
4+
# source: query.sql
5+
from typing import AsyncIterator, Iterator, Optional
6+
7+
import my.module
8+
import sqlalchemy
9+
import sqlalchemy.ext.asyncio
10+
11+
from db import models
12+
13+
14+
GET_JOB = """-- name: get_job \\:one
15+
SELECT id, status, priority FROM jobs
16+
WHERE id = :p1 LIMIT 1
17+
"""
18+
19+
20+
LIST_JOBS_BY_STATUS = """-- name: list_jobs_by_status \\:many
21+
SELECT id, status, priority FROM jobs
22+
WHERE status = :p1
23+
ORDER BY priority
24+
"""
25+
26+
27+
class Querier:
28+
def __init__(self, conn: sqlalchemy.engine.Connection):
29+
self._conn = conn
30+
31+
def get_job(self, *, id: int) -> Optional[models.Job]:
32+
row = self._conn.execute(sqlalchemy.text(GET_JOB), {"p1": id}).first()
33+
if row is None:
34+
return None
35+
return models.Job(
36+
id=row[0],
37+
status=row[1],
38+
priority=row[2],
39+
)
40+
41+
def list_jobs_by_status(self, *, status: my.module.JobStatus) -> Iterator[models.Job]:
42+
result = self._conn.execute(sqlalchemy.text(LIST_JOBS_BY_STATUS), {"p1": status})
43+
for row in result:
44+
yield models.Job(
45+
id=row[0],
46+
status=row[1],
47+
priority=row[2],
48+
)
49+
50+
51+
class AsyncQuerier:
52+
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
53+
self._conn = conn
54+
55+
async def get_job(self, *, id: int) -> Optional[models.Job]:
56+
row = (await self._conn.execute(sqlalchemy.text(GET_JOB), {"p1": id})).first()
57+
if row is None:
58+
return None
59+
return models.Job(
60+
id=row[0],
61+
status=row[1],
62+
priority=row[2],
63+
)
64+
65+
async def list_jobs_by_status(self, *, status: my.module.JobStatus) -> AsyncIterator[models.Job]:
66+
rows = (await self._conn.execute(sqlalchemy.text(LIST_JOBS_BY_STATUS), {"p1": status})).all()
67+
for row in rows:
68+
yield models.Job(
69+
id=row[0],
70+
status=row[1],
71+
priority=row[2],
72+
)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- name: GetJob :one
2+
SELECT * FROM jobs
3+
WHERE id = $1 LIMIT 1;
4+
5+
-- name: ListJobsByStatus :many
6+
SELECT * FROM jobs
7+
WHERE status = $1
8+
ORDER BY priority;
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
CREATE DOMAIN job_status AS text
2+
CHECK (
3+
VALUE IN (
4+
'QUEUED',
5+
'PENDING',
6+
'RUNNING',
7+
'COMPLETED',
8+
'FAILED'
9+
)) NOT NULL;
10+
11+
CREATE DOMAIN positive_int AS integer
12+
CHECK (VALUE > 0);
13+
14+
15+
CREATE TABLE jobs (
16+
id BIGSERIAL PRIMARY KEY,
17+
status job_status NOT NULL,
18+
priority positive_int
19+
);
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
version: "2"
2+
plugins:
3+
- name: py
4+
wasm:
5+
url: file://../../../../bin/sqlc-gen-python.wasm
6+
sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46"
7+
sql:
8+
- schema: schema.sql
9+
queries: query.sql
10+
engine: postgresql
11+
codegen:
12+
- plugin: py
13+
out: db
14+
options:
15+
package: db
16+
emit_sync_querier: true
17+
emit_async_querier: true
18+
domain_overrides:
19+
job_status: my.module.JobStatus
20+
positive_int: decimal.Decimal

internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/emit_str_enum/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/exec_result/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

0 commit comments

Comments
 (0)