Skip to content

Commit 44c86e6

Browse files
Merge pull request #1474 from datajoint/feat/1425-strict-provenance
feat(#1425): strict_provenance config flag for runtime enforcement
2 parents ba288ed + d0e8a80 commit 44c86e6

6 files changed

Lines changed: 603 additions & 2 deletions

File tree

src/datajoint/autopopulate.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,34 @@ def _populate1(
658658

659659
self._upstream = Diagram.trace(self & dict(key))
660660

661+
# If strict_provenance is on, push the active-make context so the
662+
# runtime gates in expression.cursor / table.insert can check this
663+
# make()'s reads and writes. The context is popped in the finally
664+
# block below.
665+
strict_token = None
666+
if self.connection._config.get("strict_provenance", False):
667+
from .provenance import push_strict_make_context
668+
from .user_tables import Part
669+
670+
allowed_tables = set(self._upstream._cascade_restrictions.keys()) | {self.full_table_name}
671+
# Add Part tables of self to the allowed set. Use class __dict__
672+
# (not dir/getattr) to avoid triggering descriptors like the
673+
# _JobsDescriptor that lazy-declares the ~~ job table.
674+
for cls in type(self).__mro__:
675+
for attr_name, attr in cls.__dict__.items():
676+
if attr_name.startswith("_"):
677+
continue
678+
if isinstance(attr, type) and issubclass(attr, Part):
679+
# Instantiate to get full_table_name resolved against
680+
# this schema. The Part class is already attached via
681+
# @schema decoration of the master.
682+
try:
683+
part_ftn = attr().full_table_name
684+
allowed_tables.add(part_ftn)
685+
except Exception:
686+
pass
687+
strict_token = push_strict_make_context(self, frozenset(allowed_tables), dict(key))
688+
661689
try:
662690
if not is_generator:
663691
make(dict(key), **(make_kwargs or {}))
@@ -719,6 +747,11 @@ def _populate1(
719747
# access raises a clear error rather than silently using a
720748
# stale trace from the previous make() call.
721749
self._upstream = None
750+
# Pop the strict-make context, if any.
751+
if strict_token is not None:
752+
from .provenance import pop_strict_make_context
753+
754+
pop_strict_make_context(strict_token)
722755

723756
def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int]:
724757
"""

src/datajoint/expression.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,6 +1242,12 @@ def cursor(self, as_dict=False):
12421242
cursor
12431243
Database query cursor.
12441244
"""
1245+
# Strict-provenance read gate. No-op outside make() or when the
1246+
# config flag is off. See src/datajoint/provenance.py.
1247+
from .provenance import assert_read_allowed
1248+
1249+
assert_read_allowed(self)
1250+
12451251
sql = self.make_sql()
12461252
logger.debug(sql)
12471253
return self.connection.query(sql, as_dict=as_dict)

src/datajoint/provenance.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
"""
2+
Runtime gates for ``dj.config["strict_provenance"]``.
3+
4+
When the flag is enabled, this module's context (set by ``AutoPopulate._populate1``)
5+
tracks which tables and primary key the currently-executing ``make()`` is
6+
allowed to read and write. The read gate in :func:`assert_read_allowed`
7+
fires inside ``QueryExpression.cursor``. The write gate has two parts: the
8+
target check in :func:`assert_write_allowed` fires inside ``Table.insert``
9+
(before rows are materialized), and the per-row key-consistency check in
10+
:func:`assert_row_key_allowed` fires inside ``Table._insert_rows`` as each row
11+
is materialized — so the gate never consumes the caller's ``rows`` iterable.
12+
13+
The contract is documented in
14+
``datajoint-docs/src/reference/specs/provenance.md`` §3.
15+
16+
Implementation note: the active-make context is stored in a
17+
``contextvars.ContextVar``, so each binding is local to the execution
18+
context (thread or asyncio task) that set it. ``AutoPopulate._populate1``
19+
pushes the binding immediately before calling ``make()`` and pops it in a
20+
``finally`` block, so each populate worker sets its own binding per key.
21+
"""
22+
23+
from __future__ import annotations
24+
25+
from contextvars import ContextVar
26+
from typing import TYPE_CHECKING, Optional, Tuple
27+
28+
from .errors import DataJointError
29+
30+
if TYPE_CHECKING:
31+
from .table import Table
32+
33+
34+
# Strict-make context bound for the current execution context.
# ``None`` (the default) means no strict make() is running. Otherwise the value
# is a 3-tuple: (make() target table, allowed full table names, current key dict).
_active_strict_make: ContextVar[Optional[Tuple["Table", frozenset[str], dict]]] = ContextVar(
    "_dj_active_strict_make",
    default=None,
)
38+
39+
40+
def push_strict_make_context(target: "Table", allowed_tables: frozenset[str], key: dict):
    """
    Activate the strict-make context for a single ``make()`` call.

    Binds the triple ``(target, allowed_tables, key)`` as the active context
    that the read and write gates in this module consult.

    Returns the token produced by the binding. The caller must hand it to
    :func:`pop_strict_make_context` from a ``finally`` block so the previous
    binding is restored even when ``make()`` raises.
    """
    context = (target, allowed_tables, key)
    token = _active_strict_make.set(context)
    return token
49+
50+
def pop_strict_make_context(token) -> None:
    """
    Deactivate a strict-make context.

    ``token`` is the value returned by :func:`push_strict_make_context`;
    resetting with it restores whatever binding was active before the push.
    """
    _active_strict_make.reset(token)
54+
55+
def get_active_context():
    """
    Look up the strict-make context that is active right now.

    Returns the ``(target, allowed_tables, key)`` tuple bound by
    :func:`push_strict_make_context`, or ``None`` when nothing is bound.
    """
    active = _active_strict_make.get()
    return active
59+
60+
def _base_tables(query_expression) -> set[str]:
    """
    Collect the base-table SQL names that a query expression reads from.

    An object exposing a string ``full_table_name`` is treated as a single
    table and yields ``{full_table_name}``. Any other object is treated as a
    compound expression: each entry of its ``_support`` attribute is either a
    table name (added as-is) or a nested expression (resolved recursively).
    A missing or empty ``_support`` yields an empty set.
    """
    own_name = getattr(query_expression, "full_table_name", None)
    if isinstance(own_name, str):
        # Single-table expression: nothing further to traverse.
        return {own_name}

    names: set[str] = set()
    for item in getattr(query_expression, "_support", None) or ():
        # Strings are direct table names; anything else is a subquery.
        names |= {item} if isinstance(item, str) else _base_tables(item)
    return names
84+
85+
def assert_read_allowed(query_expression) -> None:
    """
    Reject a fetch that reads tables outside the active strict-make context.

    Invoked from ``QueryExpression.cursor`` before the SQL is built and sent.
    Does nothing when no strict-make context is bound.

    The base tables of ``query_expression`` (see :func:`_base_tables`) must all
    belong to the context's ``allowed_tables`` set. An expression with no base
    tables is always allowed.

    Known limitation: the check is purely by table name. It does not tell apart
    a read that was composed through ``self.upstream`` from a direct expression
    on the same table; both pass if the table is in the allowed set.

    Raises ``DataJointError`` listing the offending tables otherwise.
    """
    active = _active_strict_make.get()
    if active is None:
        # Strict mode is off, or we are not inside make().
        return

    _make_target, permitted, _current_key = active
    referenced = _base_tables(query_expression)
    if not referenced:
        # No base tables were found, so there is nothing to verify.
        return

    undeclared = referenced - permitted
    if not undeclared:
        return

    raise DataJointError(
        f"strict_provenance=True: read from undeclared table(s) "
        f"{sorted(undeclared)} is not permitted inside make(). "
        f"Use self.upstream[T] for declared ancestors, or declare a "
        f"foreign-key dependency on the table you want to read."
    )
127+
128+
def assert_write_allowed(target_table) -> None:
    """
    Verify the *target* of an insert is allowed under the active strict-make context.

    Called from ``Table.insert`` before any rows are processed. Does nothing
    when no strict-make context is bound.

    Allowed targets:

    - The current ``make()`` target itself.
    - Any Part table declared as a public class attribute on the target's
      class or one of its base classes.

    Only the target is inspected here; the caller's ``rows`` are never touched.
    Per-row key consistency is checked separately by
    :func:`assert_row_key_allowed`.

    Raises ``DataJointError`` if the target is not permitted.
    """
    ctx = _active_strict_make.get()
    if ctx is None:
        return

    make_target, _allowed_tables, _key = ctx
    target_name = getattr(target_table, "full_table_name", None)
    master_name = make_target.full_table_name

    # Common case: make() inserts into its own table. Accept immediately so
    # that no Part class has to be instantiated for this insert.
    if target_name == master_name:
        return

    from .user_tables import Part  # local import to avoid circular dep

    # Walk class __dict__ entries (not dir/getattr) so that descriptors such as
    # the _JobsDescriptor are never triggered. Stop at the first matching Part.
    for cls in type(make_target).__mro__:
        for attr_name, attr in cls.__dict__.items():
            if attr_name.startswith("_"):
                continue
            if not (isinstance(attr, type) and issubclass(attr, Part)):
                continue
            try:
                part_name = attr().full_table_name
            except Exception:
                # Best-effort: a Part whose name cannot be resolved is simply
                # not counted as an allowed target.
                continue
            if part_name == target_name:
                return

    raise DataJointError(
        f"strict_provenance=True: insert into {target_name!r} is not permitted "
        f"inside make() for {master_name!r}. Only the target "
        f"table and its Part tables may be written."
    )
176+
177+
def assert_row_key_allowed(row) -> None:
    """
    Check one insert row against the key of the ``make()`` call in progress.

    Invoked once per row from ``Table._insert_rows`` while rows are being
    materialized, so the caller's ``rows`` iterable is never consumed ahead of
    the insert. The check is skipped when no strict-make context is bound and
    when ``row`` is not a ``dict``.

    NOTE(review): any non-dict row bypasses the key check, including other
    mapping types if ``_insert_rows`` accepts them; confirm this is intended.

    Raises ``DataJointError`` on a mismatch (see :func:`_check_row_key`).
    """
    active = _active_strict_make.get()
    if active is None or not isinstance(row, dict):
        return
    _make_target, _permitted, current_key = active
    _check_row_key(row, current_key)
197+
198+
def _check_row_key(row: dict, current_key: dict) -> None:
    """
    Ensure ``row`` agrees with ``current_key`` on every attribute they share.

    Attributes of ``current_key`` that are absent from ``row`` are ignored.
    Raises ``DataJointError`` for the first shared attribute whose values differ.
    """
    for name, expected in current_key.items():
        if name not in row:
            continue
        actual = row[name]
        if actual != expected:
            raise DataJointError(
                f"strict_provenance=True: inserted row's {name!r}={actual!r} does not "
                f"match the current make() key's {name!r}={expected!r}. Inserts must be "
                f"consistent with the key being populated."
            )

src/datajoint/settings.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
"database.database_prefix": "DJ_DATABASE_PREFIX",
7070
"database.create_tables": "DJ_CREATE_TABLES",
7171
"loglevel": "DJ_LOG_LEVEL",
72+
"strict_provenance": "DJ_STRICT_PROVENANCE",
7273
"display.diagram_direction": "DJ_DIAGRAM_DIRECTION",
7374
}
7475

@@ -361,6 +362,16 @@ class Config(BaseSettings):
361362
"*New in 2.2.3.*",
362363
)
363364

365+
strict_provenance: bool = Field(
366+
default=False,
367+
validation_alias="DJ_STRICT_PROVENANCE",
368+
description="If True, enforces the upstream-only convention inside make(): "
369+
"reads must go through self.upstream[Ancestor], writes must target self "
370+
"or self's Part tables with primary keys consistent with the current key. "
371+
"Off by default; opt-in for deployments that need runtime provenance "
372+
"guarantees backing downstream lineage / CDC tooling. *New in 2.3.*",
373+
)
374+
364375
# Cache path for query results
365376
query_cache: Path | None = None
366377

src/datajoint/table.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -834,10 +834,23 @@ def insert(
834834
" To override, set keyword argument allow_direct_insert=True."
835835
)
836836

837+
# Strict-provenance write gate (target check only). No-op outside make()
838+
# or when the config flag is off. Deliberately does NOT touch `rows` —
839+
# the per-row key-consistency check happens in `_insert_rows` as rows are
840+
# materialized, so a one-shot iterable (generator) is not consumed here.
841+
# See src/datajoint/provenance.py.
842+
from .provenance import assert_write_allowed
843+
844+
assert_write_allowed(self)
845+
837846
if inspect.isclass(rows) and issubclass(rows, QueryExpression):
838847
rows = rows() # instantiate if a class
839848
if isinstance(rows, QueryExpression):
840-
# insert from select - chunk_size not applicable
849+
# insert from select - chunk_size not applicable.
850+
# Note: this INSERT ... SELECT runs entirely server-side, so under
851+
# strict_provenance the per-row key-consistency check does not apply
852+
# (row values are never materialized client-side). The target check
853+
# in assert_write_allowed above still governs which table is written.
841854
if chunk_size is not None:
842855
raise DataJointError("chunk_size is not supported for QueryExpression inserts")
843856
if not ignore_extra_fields:
@@ -892,7 +905,17 @@ def _insert_rows(self, rows, replace, skip_duplicates, ignore_extra_fields):
892905
"""
893906
# collects the field list from first row (passed by reference)
894907
field_list = []
895-
rows = list(self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows)
908+
# Strict-provenance per-row key check runs here, as each row is
909+
# materialized — no-op outside make()/when the flag is off. Placing it in
910+
# this single materialization point (reached by both the chunked and
911+
# single-batch paths) avoids consuming the caller's `rows` iterable early.
912+
from .provenance import assert_row_key_allowed
913+
914+
def _make_row(row):
915+
assert_row_key_allowed(row)
916+
return self.__make_row_to_insert(row, field_list, ignore_extra_fields)
917+
918+
rows = list(_make_row(row) for row in rows)
896919
if rows:
897920
try:
898921
# Handle empty field_list (all-defaults insert)

0 commit comments

Comments
 (0)