From f60495b1d3a63910e49cc6ef17797493d4a053f6 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 23 Jun 2026 08:04:09 -0500 Subject: [PATCH 1/2] feat(#1425): strict_provenance config flag for runtime enforcement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements T2.2.c of the provenance trinity, completing the trio (Diagram.trace → self.upstream → strict_provenance). When dj.config["strict_provenance"] = True, runtime gates enforce the upstream-only convention inside make(): - Reads must target a table in the active trace's allowed set (declared ancestors + self + self's Parts). - Writes must target self or self's Parts. - Inserted rows' PK columns that overlap with the current key must equal the key's values (key-consistency rule). Default is False. Existing make() bodies are unaffected. Branch stacked on feat/1424-self-upstream (#1473) → feat/1423-diagram-trace (#1471) → fix/1429-cascade-part-part-renamed-fk (#1468). Will rebase onto master after the chain merges. What's added: - src/datajoint/provenance.py (new): the runtime context module. - `_active_strict_make` ContextVar holding (target, allowed_tables, key) for the currently-executing make() invocation. ContextVar chosen over threading.local to propagate correctly across contextvars-aware concurrency boundaries. - `push_strict_make_context` / `pop_strict_make_context` — context lifecycle managed by `_populate1`'s try/finally. - `assert_read_allowed(query_expression)` — read gate. Recursively discovers base tables via the QueryExpression's `_support` chain and checks each against the allowed set. - `assert_write_allowed(target_table, rows)` — write gate. Verifies the target is self or one of self's Part tables, and checks the key-consistency rule on each dict row. - src/datajoint/settings.py: new `strict_provenance: bool` field on Config (default False), env-var `DJ_STRICT_PROVENANCE`, ENV_VAR_MAPPING entry.
- src/datajoint/autopopulate.py: in `_populate1`, push the strict context (when the flag is on) just before the make() invocation block. The allowed table set = trace's ancestor nodes ∪ {self.full_table_name} ∪ {self's Parts}. Pop in the existing `finally` block. - src/datajoint/expression.py: `QueryExpression.cursor` now calls `assert_read_allowed(self)` before issuing SQL. No-op outside make(). - src/datajoint/table.py: `Table.insert` calls `assert_write_allowed(self, rows)` after the existing `_allow_insert` check. No-op outside make(). Part-table detection uses class `__dict__` traversal (filtered to Part subclasses) instead of `dir/getattr` to avoid triggering the `_JobsDescriptor` (which would lazy-declare ~~table inside the populate transaction — caught by the first test iteration). Documented limitation (deferred): the read gate does not distinguish reads that came through `self.upstream` from reads of the same ancestor via a direct expression. Both are allowed if the table is in the allowed set. The intent is to catch reads from *undeclared* dependencies; tightening the "must come through self.upstream" path requires propagating an attribution marker through QueryExpression composition and is left for a follow-up release. Tests in tests/integration/test_strict_provenance.py (6 new): - test_strict_compliant_make_passes — make() reading via self.upstream and writing self.insert1 with matching key runs cleanly under strict. - test_strict_blocks_read_from_undeclared_table — read from an unrelated table raises with "strict_provenance ... undeclared" message. - test_strict_blocks_write_to_other_table — insert into a non-self, non-Part target raises "not permitted". - test_strict_blocks_write_with_mismatched_key — row PK that disagrees with the current key raises "does not match the current make() key". - test_strict_writes_to_part_table_pass — self.PartName.insert(...) works.
- test_strict_off_by_default_no_change — default-off regression check; the canonical "direct (Ancestor & key).fetch1()" pattern still works when strict_provenance is unset. Regression: 17/17 autopopulate tests pass with strict_provenance unset (default). 6/6 new strict tests pass with strict_provenance=True. 8/8 trace tests + 9/9 cascade tests unaffected. Slated for DataJoint 2.3. --- src/datajoint/autopopulate.py | 33 +++ src/datajoint/expression.py | 6 + src/datajoint/provenance.py | 193 ++++++++++++++++ src/datajoint/settings.py | 11 + src/datajoint/table.py | 6 + tests/integration/test_strict_provenance.py | 244 ++++++++++++++++++++ 6 files changed, 493 insertions(+) create mode 100644 src/datajoint/provenance.py create mode 100644 tests/integration/test_strict_provenance.py diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index 8f0946a06..d33e6ccf0 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -658,6 +658,34 @@ def _populate1( self._upstream = Diagram.trace(self & dict(key)) + # If strict_provenance is on, push the active-make context so the + # runtime gates in expression.cursor / table.insert can check this + # make()'s reads and writes. The context is popped in the finally + # block below. + strict_token = None + if self.connection._config.get("strict_provenance", False): + from .provenance import push_strict_make_context + from .user_tables import Part + + allowed_tables = set(self._upstream._cascade_restrictions.keys()) | {self.full_table_name} + # Add Part tables of self to the allowed set. Use class __dict__ + # (not dir/getattr) to avoid triggering descriptors like the + # _JobsDescriptor that lazy-declares the ~~ job table. + for cls in type(self).__mro__: + for attr_name, attr in cls.__dict__.items(): + if attr_name.startswith("_"): + continue + if isinstance(attr, type) and issubclass(attr, Part): + # Instantiate to get full_table_name resolved against + # this schema. 
The Part class is already attached via + # @schema decoration of the master. + try: + part_ftn = attr().full_table_name + allowed_tables.add(part_ftn) + except Exception: + pass + strict_token = push_strict_make_context(self, frozenset(allowed_tables), dict(key)) + try: if not is_generator: make(dict(key), **(make_kwargs or {})) @@ -719,6 +747,11 @@ def _populate1( # access raises a clear error rather than silently using a # stale trace from the previous make() call. self._upstream = None + # Pop the strict-make context, if any. + if strict_token is not None: + from .provenance import pop_strict_make_context + + pop_strict_make_context(strict_token) def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int]: """ diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 1b5f5ac9e..f380b3b52 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -1242,6 +1242,12 @@ def cursor(self, as_dict=False): cursor Database query cursor. """ + # Strict-provenance read gate. No-op outside make() or when the + # config flag is off. See src/datajoint/provenance.py. + from .provenance import assert_read_allowed + + assert_read_allowed(self) + sql = self.make_sql() logger.debug(sql) return self.connection.query(sql, as_dict=as_dict) diff --git a/src/datajoint/provenance.py b/src/datajoint/provenance.py new file mode 100644 index 000000000..e124d1160 --- /dev/null +++ b/src/datajoint/provenance.py @@ -0,0 +1,193 @@ +""" +Runtime gates for ``dj.config["strict_provenance"]``. + +When the flag is enabled, this module's context (set by ``AutoPopulate._populate_one``) +tracks which tables and primary key the currently-executing ``make()`` is +allowed to read and write. The read gate in :func:`assert_read_allowed` +fires inside ``QueryExpression.cursor``; the write gate in +:func:`assert_write_allowed` fires inside ``Table.insert``. + +The contract is documented in +``datajoint-docs/src/reference/specs/provenance.md`` §3. 
+ +Implementation note: the active-make context is stored in a +``contextvars.ContextVar`` so it propagates correctly across threads +that share the parent's context (e.g. the populate-in-subprocess path +which uses ``multiprocessing`` workers, each of which inherits its +parent's contextvar binding at fork time). +""" + +from __future__ import annotations + +from contextvars import ContextVar +from typing import TYPE_CHECKING, Optional, Tuple + +from .errors import DataJointError + +if TYPE_CHECKING: + from .table import Table + + +# Active context: (the target table, the set of allowed full table names, the current key dict) +_active_strict_make: ContextVar[Optional[Tuple["Table", frozenset[str], dict]]] = ContextVar( + "_dj_active_strict_make", default=None +) + + +def push_strict_make_context(target: "Table", allowed_tables: frozenset[str], key: dict): + """ + Push a strict-make context for the duration of one ``make()`` invocation. + + Returns a token that the caller must pass to :func:`pop_strict_make_context` + in a ``finally`` block. + """ + return _active_strict_make.set((target, allowed_tables, key)) + + +def pop_strict_make_context(token) -> None: + """Pop the strict-make context using a token from :func:`push_strict_make_context`.""" + _active_strict_make.reset(token) + + +def get_active_context(): + """Return the currently-active strict-make context, or None.""" + return _active_strict_make.get() + + +def _base_tables(query_expression) -> set[str]: + """ + Return the set of base-table SQL names that a QueryExpression reads from. + + For a single-table expression (FreeTable / Table / restricted variants), + returns ``{full_table_name}``. For compound expressions (joins, + projections of joins), traverses ``support`` recursively. 
+ """ + # FreeTable / Table: has full_table_name directly + ftn = getattr(query_expression, "full_table_name", None) + if isinstance(ftn, str): + return {ftn} + + bases: set[str] = set() + support = getattr(query_expression, "_support", None) or [] + for s in support: + if isinstance(s, str): + # Direct table name in the support list + bases.add(s) + else: + # Subquery — recurse + bases.update(_base_tables(s)) + return bases + + +def assert_read_allowed(query_expression) -> None: + """ + Verify a fetch is allowed under the active strict-make context. + + Called from ``QueryExpression.cursor`` before SQL is issued. No-op when + no strict-make context is active (i.e. outside ``make()`` or when + ``strict_provenance`` is False). + + Allowed reads: + + - Any table in the active context's ``allowed_tables`` set. The set is + built from ``self.upstream`` (the ancestor graph) plus the target + table and its Parts. + + Anything else raises ``DataJointError``. + + Known limitation (will sharpen in a follow-up): the check does not + distinguish reads that came *through* ``self.upstream`` from reads of + the same ancestor via a direct expression. Both are allowed if the + table is in the allowed set. The intent is to catch reads from + *undeclared* dependencies; tightening the "must come through + ``self.upstream``" path requires propagating an attribution marker + through QueryExpression composition and is deferred. + """ + ctx = _active_strict_make.get() + if ctx is None: + return # strict mode off, or outside make() + + _target, allowed_tables, _key = ctx + bases = _base_tables(query_expression) + if not bases: + return # nothing to check (e.g. dj.U expressions) + + disallowed = bases - allowed_tables + if disallowed: + raise DataJointError( + f"strict_provenance=True: read from undeclared table(s) " + f"{sorted(disallowed)} is not permitted inside make(). 
" + f"Use self.upstream[T] for declared ancestors, or declare a " + f"foreign-key dependency on the table you want to read." + ) + + +def assert_write_allowed(target_table, rows) -> None: + """ + Verify an insert is allowed under the active strict-make context. + + Called from ``Table.insert`` after the existing ``_allow_insert`` check. + No-op when no strict-make context is active. + + Allowed writes: + + - Target is the current ``make()`` target (``self``) or one of its Part + tables. + - Every row's primary-key columns that overlap with the current ``key`` + must equal ``key``'s values. + + Anything else raises ``DataJointError``. + """ + ctx = _active_strict_make.get() + if ctx is None: + return + + make_target, _allowed_tables, key = ctx + + # 1. Target must be `make_target` (self) or one of its Parts. + target_name = getattr(target_table, "full_table_name", None) + target_set = {make_target.full_table_name} + # Collect Part tables of make_target via class __dict__ (not dir/getattr, + # which would trigger descriptors like the _JobsDescriptor). + from .user_tables import Part # local import to avoid circular dep + + for cls in type(make_target).__mro__: + for attr_name, attr in cls.__dict__.items(): + if attr_name.startswith("_"): + continue + if isinstance(attr, type) and issubclass(attr, Part): + try: + part_ftn = attr().full_table_name + target_set.add(part_ftn) + except Exception: + pass + + if target_name not in target_set: + raise DataJointError( + f"strict_provenance=True: insert into {target_name!r} is not permitted " + f"inside make() for {make_target.full_table_name!r}. Only the target " + f"table and its Part tables may be written." + ) + + # 2. Each row's key columns that overlap with the current key must match. + if isinstance(rows, dict): + _check_row_key(rows, key) + else: + try: + for row in rows: + if isinstance(row, dict): + _check_row_key(row, key) + # Non-dict rows (tuples, etc.) bypass — older API; can't check. 
+ except TypeError: + pass # not iterable; let downstream code handle + + +def _check_row_key(row: dict, current_key: dict) -> None: + """Raise if any row attribute overlapping with the current key has a different value.""" + for k, v in current_key.items(): + if k in row and row[k] != v: + raise DataJointError( + f"strict_provenance=True: inserted row's {k!r}={row[k]!r} does not " + f"match the current make() key's {k!r}={v!r}. Inserts must be " + f"consistent with the key being populated." + ) diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index 7a035f6d8..6ae23478b 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -69,6 +69,7 @@ "database.database_prefix": "DJ_DATABASE_PREFIX", "database.create_tables": "DJ_CREATE_TABLES", "loglevel": "DJ_LOG_LEVEL", + "strict_provenance": "DJ_STRICT_PROVENANCE", "display.diagram_direction": "DJ_DIAGRAM_DIRECTION", } @@ -361,6 +362,16 @@ class Config(BaseSettings): "*New in 2.2.3.*", ) + strict_provenance: bool = Field( + default=False, + validation_alias="DJ_STRICT_PROVENANCE", + description="If True, enforces the upstream-only convention inside make(): " + "reads must go through self.upstream[Ancestor], writes must target self " + "or self's Part tables with primary keys consistent with the current key. " + "Off by default; opt-in for deployments that need runtime provenance " + "guarantees backing downstream lineage / CDC tooling. *New in 2.3.*", + ) + # Cache path for query results query_cache: Path | None = None diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 7f8cbaf70..944bb1b63 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -797,6 +797,12 @@ def insert( " To override, set keyword argument allow_direct_insert=True." ) + # Strict-provenance write gate. No-op outside make() or when the + # config flag is off. See src/datajoint/provenance.py. 
+ from .provenance import assert_write_allowed + + assert_write_allowed(self, rows) + if inspect.isclass(rows) and issubclass(rows, QueryExpression): rows = rows() # instantiate if a class if isinstance(rows, QueryExpression): diff --git a/tests/integration/test_strict_provenance.py b/tests/integration/test_strict_provenance.py new file mode 100644 index 000000000..ce3a0e5b9 --- /dev/null +++ b/tests/integration/test_strict_provenance.py @@ -0,0 +1,244 @@ +""" +Integration tests for ``dj.config["strict_provenance"]`` (#1425). + +Strict mode gates reads (``QueryExpression.cursor``) and writes +(``Table.insert``) inside ``make()`` to the declared upstream graph +and the target table + its Parts. Off by default; opt-in. +""" + +import pytest + +import datajoint as dj +from datajoint import DataJointError + + +@pytest.fixture +def strict_mode(connection_test): + """Enable strict_provenance for the duration of one test.""" + config = connection_test._config + previous = config.get("strict_provenance", False) + config["strict_provenance"] = True + try: + yield + finally: + config["strict_provenance"] = previous + + +def test_strict_compliant_make_passes(prefix, connection_test, strict_mode): + """A make() that reads via self.upstream and writes to self with key consistency runs cleanly.""" + schema = dj.Schema(f"{prefix}_strict_compliant", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + --- + name : varchar(64) + """ + contents = [(1, "alice"), (2, "bob")] + + @schema + class Greeting(dj.Computed): + definition = """ + -> Subject + --- + greeting : varchar(128) + """ + + def make(self, key): + name = self.upstream[Subject].fetch1("name") + self.insert1({**key, "greeting": f"Hello, {name}!"}) + + Greeting.populate() + assert (Greeting & {"subject_id": 1}).fetch1("greeting") == "Hello, alice!" + assert (Greeting & {"subject_id": 2}).fetch1("greeting") == "Hello, bob!" 
+ + +def test_strict_blocks_read_from_undeclared_table(prefix, connection_test, strict_mode): + """Reading from a table NOT in the trace's ancestor set raises.""" + schema = dj.Schema(f"{prefix}_strict_undeclared", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class Unrelated(dj.Lookup): + definition = """ + u_id : int32 + --- + secret : varchar(64) + """ + contents = [(42, "should-not-read")] + + captured: list[Exception] = [] + + @schema + class Bad(dj.Computed): + definition = """ + -> Subject + --- + val : int32 + """ + + def make(self, key): + try: + Unrelated.fetch() # not in declared upstream of Bad + except DataJointError as e: + captured.append(e) + # Insert anyway so populate doesn't fail + self.insert1({**key, "val": 0}) + + Bad.populate() + assert len(captured) == 1 + assert "strict_provenance" in str(captured[0]).lower() + assert "undeclared" in str(captured[0]).lower() + + +def test_strict_blocks_write_to_other_table(prefix, connection_test, strict_mode): + """Writing into a table other than self / self.Parts raises.""" + schema = dj.Schema(f"{prefix}_strict_other_target", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class AuditLog(dj.Manual): + definition = """ + log_id : int32 + --- + event : varchar(64) + """ + + captured: list[Exception] = [] + + @schema + class Derived(dj.Computed): + definition = """ + -> Subject + --- + val : int32 + """ + + def make(self, key): + try: + AuditLog.insert1({"log_id": 1, "event": "side-effect"}, allow_direct_insert=True) + except DataJointError as e: + captured.append(e) + self.insert1({**key, "val": 1}) + + Derived.populate() + assert len(captured) == 1 + assert "strict_provenance" in str(captured[0]).lower() + assert "not permitted" in str(captured[0]).lower() + + +def 
test_strict_blocks_write_with_mismatched_key(prefix, connection_test, strict_mode): + """Writing a row whose PK columns disagree with the current key raises.""" + schema = dj.Schema(f"{prefix}_strict_key_mismatch", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,), (2,)] + + captured: list[Exception] = [] + + @schema + class Wrong(dj.Computed): + definition = """ + -> Subject + --- + val : int32 + """ + + def make(self, key): + try: + # Try to insert a row for a DIFFERENT subject than the current key + bogus_key = {"subject_id": 99} + self.insert1({**bogus_key, "val": 0}) + except DataJointError as e: + captured.append(e) + # Insert correctly to let populate complete + self.insert1({**key, "val": 1}) + + Wrong.populate() + assert len(captured) == 2 # fires for both subjects + assert all("does not match the current make() key" in str(e) for e in captured) + + +def test_strict_writes_to_part_table_pass(prefix, connection_test, strict_mode): + """Writing into self.Parts (with key consistency) is allowed.""" + schema = dj.Schema(f"{prefix}_strict_parts", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class Master(dj.Computed): + definition = """ + -> Subject + --- + summary : varchar(32) + """ + + class Bin(dj.Part): + definition = """ + -> master + bin_id : int32 + --- + energy : float64 + """ + + def make(self, key): + self.insert1({**key, "summary": "ok"}) + self.Bin.insert([{**key, "bin_id": i, "energy": float(i)} for i in range(3)]) + + Master.populate() + assert (Master & {"subject_id": 1}).fetch1("summary") == "ok" + assert len(Master.Bin & {"subject_id": 1}) == 3 + + +def test_strict_off_by_default_no_change(prefix, connection_test): + """With strict_provenance unset (default False), existing patterns work unchanged.""" + schema = dj.Schema(f"{prefix}_strict_default_off", 
connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class DerivedLegacy(dj.Computed): + definition = """ + -> Subject + --- + val : int32 + """ + + def make(self, key): + # Direct ancestor fetch — would be flagged in strict mode (read from + # undeclared, but Subject IS an ancestor — actually allowed under + # the current "table in allowed set" rule even in strict mode). + # In default-off mode, this must work either way. + (Subject & key).fetch1("subject_id") + self.insert1({**key, "val": 0}) + + # No strict_mode fixture — default-off + DerivedLegacy.populate() + assert (DerivedLegacy & {"subject_id": 1}).fetch1("val") == 0 From d0e8a80cc0f12d39f6c1da45e51e693403c3bda4 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 2 Jul 2026 09:51:17 -0500 Subject: [PATCH 2/2] fix(#1425): strict write gate no longer consumes one-shot row iterables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The strict-provenance write gate ran assert_write_allowed(self, rows) before insert materialized rows, and its key-consistency check did 'for row in rows'. For a one-shot iterable (generator, map, iter), that exhausted it, so the downstream insert saw an empty iterator and wrote zero rows with no error — silent data loss on compliant self.insert() code (common for Part inserts). insert() was also double-executed. Split the gate: assert_write_allowed(target) now does the target-only check (needs no rows) in Table.insert; a new per-row assert_row_key_allowed(row) runs in Table._insert_rows as each row is materialized — the single point reached by both the chunked and single-batch paths, so streaming/chunking is preserved and the caller's iterable is never consumed early. The QueryExpression (INSERT ... 
SELECT) path is no longer iterated by the gate, fixing the double-execution; per-row key consistency does not apply there (rows never materialize client-side), governed by the target check only. Adds regression tests: a generator insert into a Part must land all rows, and the per-row key check still fires for generator-sourced rows. Fixes the blocking bug flagged by @ttngu207 in review. Read-gate coverage (len/bool/in, restriction-by-table) is tracked separately pending the best-effort-vs-close decision. --- src/datajoint/provenance.py | 63 ++++++++++------- src/datajoint/table.py | 27 +++++-- tests/integration/test_strict_provenance.py | 78 +++++++++++++++++++++ 3 files changed, 138 insertions(+), 30 deletions(-) diff --git a/src/datajoint/provenance.py b/src/datajoint/provenance.py index e124d1160..8f196194f 100644 --- a/src/datajoint/provenance.py +++ b/src/datajoint/provenance.py @@ -4,8 +4,11 @@ When the flag is enabled, this module's context (set by ``AutoPopulate._populate_one``) tracks which tables and primary key the currently-executing ``make()`` is allowed to read and write. The read gate in :func:`assert_read_allowed` -fires inside ``QueryExpression.cursor``; the write gate in -:func:`assert_write_allowed` fires inside ``Table.insert``. +fires inside ``QueryExpression.cursor``. The write gate has two parts: the +target check in :func:`assert_write_allowed` fires inside ``Table.insert`` +(before rows are materialized), and the per-row key-consistency check in +:func:`assert_row_key_allowed` fires inside ``Table._insert_rows`` as each row +is materialized — so the gate never consumes the caller's ``rows`` iterable. The contract is documented in ``datajoint-docs/src/reference/specs/provenance.md`` §3. @@ -122,29 +125,30 @@ def assert_read_allowed(query_expression) -> None: ) -def assert_write_allowed(target_table, rows) -> None: +def assert_write_allowed(target_table) -> None: """ - Verify an insert is allowed under the active strict-make context. 
+ Verify the *target* of an insert is allowed under the active strict-make context. - Called from ``Table.insert`` after the existing ``_allow_insert`` check. - No-op when no strict-make context is active. + Called from ``Table.insert`` after the existing ``_allow_insert`` check and + before any rows are materialized. No-op when no strict-make context is active. - Allowed writes: + Allowed targets: - - Target is the current ``make()`` target (``self``) or one of its Part - tables. - - Every row's primary-key columns that overlap with the current ``key`` - must equal ``key``'s values. + - The current ``make()`` target (``self``) or one of its Part tables. - Anything else raises ``DataJointError``. + Per-row key consistency is checked separately by :func:`assert_row_key_allowed` + as rows are materialized, so this gate never consumes the caller's ``rows`` + iterable — a one-shot generator must survive to reach ``insert``. + + Raises ``DataJointError`` if the target is not permitted. """ ctx = _active_strict_make.get() if ctx is None: return - make_target, _allowed_tables, key = ctx + make_target, _allowed_tables, _key = ctx - # 1. Target must be `make_target` (self) or one of its Parts. + # Target must be `make_target` (self) or one of its Parts. target_name = getattr(target_table, "full_table_name", None) target_set = {make_target.full_table_name} # Collect Part tables of make_target via class __dict__ (not dir/getattr, @@ -169,17 +173,26 @@ def assert_write_allowed(target_table, rows) -> None: f"table and its Part tables may be written." ) - # 2. Each row's key columns that overlap with the current key must match. - if isinstance(rows, dict): - _check_row_key(rows, key) - else: - try: - for row in rows: - if isinstance(row, dict): - _check_row_key(row, key) - # Non-dict rows (tuples, etc.) bypass — older API; can't check. 
- except TypeError: - pass # not iterable; let downstream code handle + +def assert_row_key_allowed(row) -> None: + """ + Verify a single insert row's key columns match the active ``make()`` key. + + Called per row from ``Table._insert_rows`` as rows are materialized, so the + check sees a concrete row without the write gate having to consume the + caller's ``rows`` iterable. No-op when no strict-make context is active or + when ``row`` is not a dict (numpy records / bare sequences carry no field + names to check by — same as the previous behavior). + + Raises ``DataJointError`` on a mismatch. + """ + ctx = _active_strict_make.get() + if ctx is None: + return + if not isinstance(row, dict): + return + _make_target, _allowed_tables, key = ctx + _check_row_key(row, key) def _check_row_key(row: dict, current_key: dict) -> None: diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 944bb1b63..5874ecfb2 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -797,16 +797,23 @@ def insert( " To override, set keyword argument allow_direct_insert=True." ) - # Strict-provenance write gate. No-op outside make() or when the - # config flag is off. See src/datajoint/provenance.py. + # Strict-provenance write gate (target check only). No-op outside make() + # or when the config flag is off. Deliberately does NOT touch `rows` — + # the per-row key-consistency check happens in `_insert_rows` as rows are + # materialized, so a one-shot iterable (generator) is not consumed here. + # See src/datajoint/provenance.py. from .provenance import assert_write_allowed - assert_write_allowed(self, rows) + assert_write_allowed(self) if inspect.isclass(rows) and issubclass(rows, QueryExpression): rows = rows() # instantiate if a class if isinstance(rows, QueryExpression): - # insert from select - chunk_size not applicable + # insert from select - chunk_size not applicable. + # Note: this INSERT ... 
SELECT runs entirely server-side, so under + # strict_provenance the per-row key-consistency check does not apply + # (row values are never materialized client-side). The target check + # in assert_write_allowed above still governs which table is written. if chunk_size is not None: raise DataJointError("chunk_size is not supported for QueryExpression inserts") if not ignore_extra_fields: @@ -861,7 +868,17 @@ def _insert_rows(self, rows, replace, skip_duplicates, ignore_extra_fields): """ # collects the field list from first row (passed by reference) field_list = [] - rows = list(self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows) + # Strict-provenance per-row key check runs here, as each row is + # materialized — no-op outside make()/when the flag is off. Placing it in + # this single materialization point (reached by both the chunked and + # single-batch paths) avoids consuming the caller's `rows` iterable early. + from .provenance import assert_row_key_allowed + + def _make_row(row): + assert_row_key_allowed(row) + return self.__make_row_to_insert(row, field_list, ignore_extra_fields) + + rows = list(_make_row(row) for row in rows) if rows: try: # Handle empty field_list (all-defaults insert) diff --git a/tests/integration/test_strict_provenance.py b/tests/integration/test_strict_provenance.py index ce3a0e5b9..5def0c960 100644 --- a/tests/integration/test_strict_provenance.py +++ b/tests/integration/test_strict_provenance.py @@ -212,6 +212,84 @@ def make(self, key): assert len(Master.Bin & {"subject_id": 1}) == 3 +def test_strict_generator_insert_not_dropped(prefix, connection_test, strict_mode): + """Regression (#1474 bug 1): a one-shot generator of compliant rows must not + be consumed by the write gate. 
Before the fix, assert_write_allowed iterated + `rows` for its key check, exhausting the generator so insert saw zero rows and + silently wrote nothing.""" + schema = dj.Schema(f"{prefix}_strict_generator", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,), (2,)] + + @schema + class Spectrum(dj.Computed): + definition = """ + -> Subject + --- + n : int32 + """ + + class Bin(dj.Part): + definition = """ + -> master + bin_id : int32 + --- + energy : float64 + """ + + def make(self, key): + n = 5 + self.insert1({**key, "n": n}) + # one-shot generator (not a list) — must survive the write gate + self.Bin.insert({**key, "bin_id": i, "energy": float(i)} for i in range(n)) + + Spectrum.populate() + for sid in (1, 2): + assert (Spectrum & {"subject_id": sid}).fetch1("n") == 5 + # The core assertion: all 5 generated rows landed, none silently dropped. + assert len(Spectrum.Bin & {"subject_id": sid}) == 5 + + +def test_strict_generator_insert_mismatched_key_still_caught(prefix, connection_test, strict_mode): + """The per-row key check still fires when rows come from a generator — a row + whose key disagrees with the current make() key raises, not silently passes.""" + schema = dj.Schema(f"{prefix}_strict_gen_mismatch", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class Derived(dj.Computed): + definition = """ + -> Subject + --- + val : int32 + """ + + class Bin(dj.Part): + definition = """ + -> master + bin_id : int32 + """ + + def make(self, key): + self.insert1({**key, "val": 0}) + # generator whose 3rd row carries a bogus subject_id + self.Bin.insert({**({**key, "subject_id": 999} if i == 2 else key), "bin_id": i} for i in range(4)) + + with pytest.raises(DataJointError, match="does not match the current make"): + Derived.populate() + + def test_strict_off_by_default_no_change(prefix, 
connection_test): """With strict_provenance unset (default False), existing patterns work unchanged.""" schema = dj.Schema(f"{prefix}_strict_default_off", connection=connection_test)