From 0c5ca50a1da591b58cae1160c0b773d1eefda6d9 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 26 Jun 2026 23:13:29 +0200 Subject: [PATCH] Numba scan: fix invalid cross alias and inplace --- pytensor/graph/features.py | 65 ++++++ pytensor/link/numba/dispatch/scan.py | 283 ++++++++++++++++++++++---- pytensor/scan/op.py | 20 +- pytensor/tensor/rewriting/elemwise.py | 3 + tests/link/numba/test_scan.py | 185 +++++++++++++++-- 5 files changed, 491 insertions(+), 65 deletions(-) diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 29c0b99dc8..b83eeab15a 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -917,3 +917,68 @@ def on_validate(self, fgraph): ) return True + + +class NoOutputInplaceOnInput(Feature): + """Forbid any `FunctionGraph` output from carrying a protected input destroyed in place. + + Intermediate nodes may still destroy the protected inputs; only the graph + outputs may not be (a view of) the destroyed buffer. This suits a buffer the + surrounding code overwrites between evaluations — e.g. a + :class:`~pytensor.scan.op.Scan` tap, which the loop overwrites in place: an + output aliasing it would be read after the overwrite. Such an output could be + copied back out to stay correct, but needing that copy defeats the in-place + computation, so it is forbidden outright instead. + + Parameters + ---------- + protected_input_idxs + Positions in ``fgraph.inputs`` of the inputs that no output may carry. + """ + + def __init__(self, protected_input_idxs): + self.protected_input_idxs = tuple(protected_input_idxs) + + def on_attach(self, fgraph): + if hasattr(fgraph, "_no_output_inplace_on_input"): + raise AlreadyThere( + f"NoOutputInplaceOnInput is already attached to {fgraph}." + ) + fgraph._no_output_inplace_on_input = self + + def clone(self): + return type(self)(self.protected_input_idxs) + + def on_validate(self, fgraph): + if not hasattr(fgraph, "destroyers"): + return True + destroyed = { + fgraph.inputs[i] + for i in self.protected_input_idxs + if fgraph.destroyers(fgraph.inputs[i]) + } + if not destroyed: + return True + # A protected input is an fgraph input (no owner), so it is the root of any alias + # chain it belongs to. Walk each output back along view_map *and* destroy_map edges + # to its memory root (as ``pytensor.compile.aliasing.alias_root`` does, which we + # can't import here without a graph<->compile cycle); the output carries a destroyed + # protected input -- as a view, or as the in-place result that destroyed it -- iff + # that root is one. (DestroyHandler's cached ``droot`` won't do: it tracks only + # view_map edges, so it omits the destroyer's own output.) + for out_idx, out in enumerate(fgraph.outputs): + root = out + while root.owner is not None: + pos = root.owner.outputs.index(root) + rop = root.owner.op + sources = (*rop.view_map.get(pos, ()), *rop.destroy_map.get(pos, ())) + if not sources: + break + root = root.owner.inputs[sources[0]] + if root in destroyed: + raise InconsistencyError( + f"Output {out_idx} would carry input {fgraph.inputs.index(root)} " + "destroyed in place; that input may be destroyed by intermediate " + "nodes only, not aliased by an output." + ) + return True diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index cf177f6183..1ccc8af579 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -5,9 +5,14 @@ from numba import types from numba.extending import overload -from pytensor.compile.aliasing import add_supervisor_to_fgraph, insert_deepcopy +from pytensor.compile.aliasing import ( + add_supervisor_to_fgraph, + alias_root, + insert_deepcopy, +) from pytensor.compile.io import In, Out from pytensor.compile.mode import NUMBA, get_mode +from pytensor.graph.features import NoOutputInplaceOnInput from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( @@ -56,6 +61,99 @@ def range_arr(x): @register_funcify_and_cache_key(Scan) def numba_funcify_Scan(op: Scan, node, **kwargs): + """Generate a Numba implementation of a `Scan` loop. + + Memory-aliasing contract + ------------------------ + Scan defines a loop over an inner function with signature: + (*sequences[idx], *traced[idx], *untraced, *non_sequences) + -> (*traced_updates[idx], *untraced_updates) + + Traced variables are read from an indexed circular buffer at every iteration, + and the updates stored (copied) back to it immediately after. Untraced variables + are carried by reference, with each update becoming the next iteration's input. + + Scan is sometimes allowed to destroy/alias the outer traced and untraced variables, + but never sequences and non-sequences. Specifically, outer untraced variables can be + destroyed (destroy_map, opt in) or aliased (view_map, default). Traced variables can be + destroyed (destroy_map, opt in), but otherwise not alias (never in view_map). + Note that destroy permission implies alias permission, but not the other way around. + + Scan is not allowed to return outputs that alias each other, unless they were already + aliased from the outside, and it was itself allowed to alias/destroy them. This means + PyTensor already gauged it was safe to destroy/alias them. + + Scan has some freedom in how this outer contract is respected. If needed, it can + deepcopy the outer inputs once at the start, or make sure any aliased output + the inner function returns is properly copied before the final return. + + Internally, Scan also has total control over the boundary memory management of the + inner function: it grants the permissions to destroy or alias the inner loop inputs, + and whether the inner outputs may alias each other. This inner boundary is distinct + from the outer contract above, and it is Scan's responsibility to choose an inner + strategy that produces correct results while still respecting the outer contract. + + Memory-aliasing strategy + ------------------------ + Traced variables + ~~~~~~~~~~~~~~~~ + Traced variables are deep-copied once at the start if they are not in the destroy_map. + + Because every inner trace update is copied back to the buffer immediately, the inner function + is allowed to alias (but not destroy) the sequence reads, non_sequences, as well as the + traced and untraced inputs or updates, when producing the traced updates. + + A special case occurs when the indexed reads will be immediately overwritten by the updates + in the same loop iteration. For single output taps (mit-sot, sit-sot) this can only + happen when the circular buffer is truncated to its minimum legal length. + For mit-mot this can also happen without any buffer loop-around. + In either case, traced updates are not allowed to alias those traced reads, + as they may otherwise be corrupted if the reads are updated before they were copied to their own buffer. + + On the plus side, when this happens, the inner function is granted permission to destroy these + immediately-to-be discarded reads, as long as the returned updates do not themselves alias them. + + The alias-restriction and destroy-permission caused by the loop-around behavior are derived from the + buffer's static length: + * known large enough: no overwrite is possible, neither alias restriction nor destroy permission applies; + * length unknown: the loop-around overwrite can't be ruled out, alias restricted but not granted destroy permission; + * known minimal: the overwrite is certain, alias restricted but granted destroy permission. + + Untraced variables + ~~~~~~~~~~~~~~~~~~ + Untraced variables are deep-copied once at the start if they are not in destroy_map + and the inner function destroys them. + + Untraced updates are allowed to alias their own untraced inputs (which happens when n_steps=0) + or when the inner function update naturally alias the input (eg, o = i; o = i.T; o = i[::-1]). + + Because the last untraced updates are returned as is, the inner function is not allowed to + alias sequences, non_sequences, or other untraced inputs and outputs (violates the outer alias restriction). + Untraced updates are not allowed to alias traced reads (risks corruption by subsequent overwrittes), + but can alias traced updates, since the immediate copy to their buffer that follows, will break the alias. + + Because untraced inputs are immediately discarded (and protected from alias with other updates), + the inner function is always granted permission to destroy them. It can do so from any computation, + not only the one producing the matching untraced update. + + Controlling inner graph alias + ----------------------------- + PyTensor allows initial graphs to contain arbitrary (non-destructive) aliasing. + Alias at the boundary (output aliasing an input or another output) is controlled via + targeted deepcopies at the end (using the insert_deepcopy helper). + + In contrast, destruction is usually NOT allowed to be present in initial graphs. + Destructive alias at the boundary is controlled during rewrites, with the following features: + * Supervisor: Checks whether any protected input are destroyed + * NoOutputInplaceOnInput: Checks whether an output is destroying a non-protected input + (protected inputs are already covered by Supervisor) + Inside the boundary: + * DestroyHandler: Checks whether a consistent ordering exists for the destruction/view chains, + i.e., every read runs before its buffers' destruction and the chain has no cycle. + These features can veto (undo) any rewrite that would violate their spec. + They CANNOT fix violations that already existed in the initial graph. + + """ # Apply inner rewrites # TODO: Not sure this is the right place to do this, should we have a rewrite that # explicitly triggers the optimization of the inner graphs of Scan? @@ -67,17 +165,17 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): .optimizer ) fgraph = op.fgraph - # When the buffer can only hold one SITSOT or as as many MITSOT as there are taps, - # We must always discard the oldest tap, so it's safe to destroy it in the inner function. - # TODO: Allow inplace for MITMOT - destroyable_sitsot = [ + + # If we know the static length of the traced buffers, check whether traced input reads + # will be immediately discarded in the same loop iteration. + discarded_sitsot = [ inner_sitsot for outer_sitsot, inner_sitsot in zip( op.outer_sitsot(node.inputs), op.inner_sitsot(fgraph.inputs), strict=True ) if outer_sitsot.type.shape[0] == 1 ] - destroyable_mitsot = [ + discarded_mitsot = [ oldest_inner_mitmot for outer_mitsot, oldest_inner_mitmot, taps in zip( op.outer_mitsot(node.inputs), @@ -87,39 +185,152 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ) if outer_mitsot.type.shape[0] == abs(min(taps)) ] - # Always allow the inner function to destroy untraced_sit_sot inputs. - # After the first iteration, these come from the previous output so - # destroying is always safe. For the first iteration, the codegen - # copies the outer input if the Scan's destroy_map doesn't allow it. - destroyable_untraced_sit_sot = list(op.inner_untraced_sit_sot(fgraph.inputs)) - destroyable = { - *destroyable_sitsot, - *destroyable_mitsot, - *destroyable_untraced_sit_sot, + # Untraced inputs are always immediately discarded + discarded_untraced_sit_sot = list(op.inner_untraced_sit_sot(fgraph.inputs)) + discarded = { + *discarded_sitsot, + *discarded_mitsot, + *discarded_untraced_sit_sot, } - input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs] + # Grant the inner function the right to alias (borrow=True) all inputs + # and to destroy (mutable=True) reads that are known to be immediately discarded + # TODO: Allow destroying MITMOT as well (#2252). + input_specs = [In(x, borrow=True, mutable=x in discarded) for x in fgraph.inputs] add_supervisor_to_fgraph( fgraph=fgraph, input_specs=input_specs, accept_inplace=True, ) - rewriter(fgraph) - untraced_sit_sot_inner_outputs = set(op.inner_untraced_sit_sot_outs(fgraph.outputs)) - output_specs = [ - Out(x, borrow=x in untraced_sit_sot_inner_outputs) for x in fgraph.outputs + + # Check which traced indexed reads MAY be overwritten immediately. + # In this case outputs are not allowed to alias (destructively or not) the reads + # so as to not be corrupted before they are themselves safely stored. + # With buffer length L, output tap o_out writes input tap + # o_in iff (o_out - o_in) % L == 0 (same physical slot). Two ways this happens: + # - linear write-back (gap 0): the recurrence writes the very slot it read this step + # (mit_mot accumulators: read g[k], write g[k] + delta back). Holds for any L. + # - loop-around (gap == reach): sit_sot/mit_sot write the new state just past the oldest + # read; in a truncated buffer (L == reach, the minimal length that still holds the + # oldest tap) that wraps onto the just-discarded oldest read. A write landing further + # ahead would store onto a slot still to be read -- an invalid recurrence we need not + # consider, just as we don't consider L < reach (reads would alias). So those are the + # only two gaps. + # ``reach`` is the minimal admissible buffer length, ``max_lookback`` (= abs of the + # oldest tap); a write offset never exceeds it, so for any larger L the new slot is fresh + # and no wrap occurs. When L is statically known we test (o_out - o_in) % L == 0 exactly; + # otherwise L is only known to be >= reach and the loop-around can bite only at L == + # reach, so we test gap 0 or gap == reach. + def find_potentially_overwritten_reads( + grouped_inner, outers, in_slices_seq, out_slices_seq + ): + result = [] + for inner_vars, outer, in_slices, out_slices in zip( + grouped_inner, outers, in_slices_seq, out_slices_seq, strict=True + ): + max_lookback = -min(0, min(in_slices)) + in_offsets = [max_lookback + t for t in in_slices] + out_offsets = [max_lookback + t for t in out_slices] + static_len = outer.type.shape[0] + reach = max_lookback # minimal admissible buffer length + for v, o_in in zip(inner_vars, in_offsets, strict=True): + if static_len is None: + hit = any( + o_out == o_in or o_out - o_in == reach for o_out in out_offsets + ) + else: + hit = any((o_out - o_in) % static_len == 0 for o_out in out_offsets) + if hit: + result.append(v) + return result + + potentially_overwritten_reads = [ + *find_potentially_overwritten_reads( + op.inner_mitmot_grouped(fgraph.inputs), + op.outer_mitmot(node.inputs), + op.info.mit_mot_in_slices, + op.info.mit_mot_out_slices, + ), + *find_potentially_overwritten_reads( + op.inner_mitsot_grouped(fgraph.inputs), + op.outer_mitsot(node.inputs), + op.info.mit_sot_in_slices, + [(0,)] * op.info.n_mit_sot, + ), + *find_potentially_overwritten_reads( + [[v] for v in op.inner_sitsot(fgraph.inputs)], + op.outer_sitsot(node.inputs), + op.info.sit_sot_in_slices, + [(0,)] * op.info.n_sit_sot, + ), ] - insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) + if potentially_overwritten_reads: + # Forbid the inner function from aliasing by destruction overwritten traced reads. + # Note we could have allowed this and patched at the end with a deepcopy (like we do with non-destructive alias) + # But this is wasteful. By forbidding it the inner graph will itself allocate a fresh buffer + # and write the result there immediately. + in_pos = {v: i for i, v in enumerate(fgraph.inputs)} + fgraph.attach_feature( + NoOutputInplaceOnInput([in_pos[t] for t in potentially_overwritten_reads]) + ) - # Track which untraced_sit_sot outputs have their inner input destroyed - # by the optimized inner function (transitively, via DestroyHandler). - untraced_start = ( - op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot + op.info.n_nit_sot + # Rewrite graph + rewriter(fgraph) + + # Post-patch alias contract via targeted deepcopies + # Traced and untraced updates: copy an alias of a tap input the loop (may) overwrite + # this iteration. + # Untraced updates: keep as is if it is the FIRST viewer of a freshly produced buffer + # (including traced updates which will be copied immediately anyway), OR its own untraced input. + # Any other alias is broken with a deepcopy: Sequences reads, traced reads, non-sequences, + # other untraced inputs or updates. + # Note: We could squeeze some more memory reuse by delaying the breaking of aliasing between untraced variables + # by delaying the patched deepcopy until after the loop is over. This requires some care to handle alias transitions + # between untraced updates that can happen over multiple iterations, and protect against cross-iteration destruction + # that can corrupt such chains. + own_untraced_input = dict( + zip( + op.inner_untraced_sit_sot_outs(fgraph.outputs), + op.inner_untraced_sit_sot(fgraph.inputs), + strict=True, + ) ) - inner_destroyed_untraced_out_idxs = set() + untraced_outs = set(own_untraced_input) + seen_untraced_roots = set() + output_specs = [] + for update in fgraph.outputs: + root = alias_root(update) + if update in untraced_outs: + borrow = ( + ( + # freshly produced buffer + root.owner is not None + # or a self alias + or root is own_untraced_input[update] + ) + # and not an alias of another untraced update + and root not in seen_untraced_roots + ) + if borrow: + seen_untraced_roots.add(root) + else: + # traced update + borrow = root not in potentially_overwritten_reads + output_specs.append(Out(update, borrow=borrow)) + insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) + + # Collect a set of untraced slots the inner function destroys in place. + # These may demand an initial copy if the Scan is not granted permission to destroy them already. + untraced_inputs_destroyed_by_inner_function = set() if hasattr(fgraph, "destroyers"): - for j, inner_inp in enumerate(op.inner_untraced_sit_sot(fgraph.inputs)): - if fgraph.destroyers(inner_inp): - inner_destroyed_untraced_out_idxs.add(untraced_start + j) + untraced_inputs_destroyed_by_inner_function = { + outer_out_idx + for inner_in, (outer_out_idx, _) in zip( + op.inner_untraced_sit_sot(fgraph.inputs), + op.outer_untraced_sit_sot_outs(node.outputs, with_idx=True), + strict=True, + ) + if fgraph.destroyers(inner_in) + } scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key( op.fgraph, fgraph_name="numba_scan", ofg_memo=kwargs.get("ofg_memo") @@ -357,12 +568,13 @@ def add_output_storage_post_proc_stmt( inner_out_to_outer_in_stmts.append(storage_name) output_idx = outer_output_names.index(storage_name) - # Copy the outer input when it will be mutated during the loop - # but the Scan's destroy_map doesn't grant ownership. - # Tapped outputs: the loop writes into the buffer via circular indexing. - # Untraced sit_sot: the inner function may destroy the input inplace. + # Copy the outer inputs when the loop mutates them and the destroy_map doesn't already grant permission needs_copy = output_idx not in node.op.destroy_map and ( - is_tapped or output_idx in inner_destroyed_untraced_out_idxs + # Traced buffers are always mutated by the loop write-back procedure + is_tapped + # Untraced inputs are only mutated by the inner function, + # so we make the copy conditional on that actually happening + or output_idx in untraced_inputs_destroyed_by_inner_function ) if needs_copy: storage_alloc_stmt = f"{storage_name} = numba_deepcopy({outer_in_name})" @@ -496,8 +708,9 @@ def scan({", ".join(outer_in_names)}): # If we can't cache the inner function, we can't cache the Scan either scan_cache_key = None else: + scan_cache_version = 1 scan_cache_key = sha256( - f"({scan_op_src}, {inner_func_cache_key})".encode() + f"({scan_op_src}, {inner_func_cache_key}, {scan_cache_version})".encode() ).hexdigest() return numba_basic.numba_njit(scan_op_fn, boundscheck=False), scan_cache_key diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index a0cb073f54..8ac84d3d14 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -1377,9 +1377,6 @@ def prepare_fgraph(self, fgraph): # `Function` pipeline. update_mapping = {} preallocated_mitmot_outs = [] - untraced_sit_sot_inner_outputs = self.inner_untraced_sit_sot_outs( - fgraph.outputs - ) if config.scan__allow_output_prealloc: # Go through the mitmots. Whenever a mitmot has a tap both as an @@ -1434,10 +1431,10 @@ def prepare_fgraph(self, fgraph): for x in fgraph.inputs[input_idx:] ] wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs[:slices]] - wrapped_outputs += [ - Out(x, borrow=x in untraced_sit_sot_inner_outputs) - for x in fgraph.outputs[slices:] - ] + # Untraced sit_sot states are kept by reference across iterations, so + # their inner outputs must not alias inputs/other outputs (borrow=False + # lets insert_deepcopy break such aliasing). See issue #2252. + wrapped_outputs += [Out(x, borrow=False) for x in fgraph.outputs[slices:]] protected_outs = tuple( i @@ -1453,14 +1450,7 @@ def prepare_fgraph(self, fgraph): else: wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs] - wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs[:slices]] - untraced_sit_sot_inner_outputs = self.inner_untraced_sit_sot_outs( - fgraph.outputs - ) - wrapped_outputs += [ - Out(x, borrow=x in untraced_sit_sot_inner_outputs) - for x in fgraph.outputs[slices:] - ] + wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs] fgraph.update_mapping = update_mapping diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 23c5ce5110..240b02e813 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -159,6 +159,9 @@ def try_inplace_on_node( node = inplace_node except InconsistencyError: inplace_pattern.pop(o) + # The input wasn't consumed, so let another output try to + # reuse it in place. + tried_inputs.discard(i) if node is not original_node: copy_stack_trace(original_node.outputs, node.outputs) return node diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 54fc2a82f1..8b49b76162 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -325,7 +325,13 @@ def test_inner_graph_optimized(): @pytest.mark.parametrize("n_steps_constant", (True, False)) def test_inplace_taps(n_steps_constant): - """Test that numba will inplace in the inner_function of the oldest sit-sot, mit-sot taps.""" + """No inner output may be computed in place on a sit-sot/mit-sot tap. + + The scan loop overwrites those tap buffers in place each iteration, so an output + aliasing one would be corrupted. Only an ``untraced_sit_sot`` input (rebound by + reference, never loop-overwritten) may be destroyed in place by its own output. + See #2252. + """ n_steps = 10 if n_steps_constant else scalar("n_steps", dtype=int) a = scalar("a") x0 = scalar("x0") @@ -363,16 +369,10 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): if isinstance(node.op, Scan) ] - # Collect inner inputs we expect to be destroyed by the step function - # Scan reorders inputs internally, so we need to check its ordering + # Collect the inner inputs destroyed in place by the output-producing nodes. + # Scan reorders inputs internally, so we go through its accessors. inner_inps = scan_op.fgraph.inputs mit_sot_inps = scan_op.inner_mitsot(inner_inps) - oldest_mit_sot_inps = [ - # Implicitly assume that the first mit-sot input is the one with 3 taps - # This is not a required behavior and the test can change if we need to change Scan. - mit_sot_inps[:2][scan_op.info.mit_sot_in_slices[0].index(-3)], - mit_sot_inps[2:][scan_op.info.mit_sot_in_slices[1].index(-2)], - ] sit_sot_inps = scan_op.inner_sitsot(inner_inps) untraced_sit_sot_inps = scan_op.inner_untraced_sit_sot(inner_inps) @@ -386,12 +386,14 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): ) # ``local_subtensor_merge_integer`` + ``scan_reduce_buffer`` reduce the buffers - # the same way for both constant and symbolic ``n_steps`` (xs collapses - # to an untraced sit_sot, the mit_sot buffers reduce to ``taps + 1``), - # so inplace fires identically in both cases. + # the same way for both constant and symbolic ``n_steps`` (xs collapses to an + # untraced sit_sot), so the result is identical in both cases. No mit_sot tap may + # be destroyed by an output (the loop overwrites those slots in place); only the + # untraced_sit_sot input is destroyed in place by its own output. assert len(sit_sot_inps) == 0 assert len(untraced_sit_sot_inps) == 1 - assert set(destroyed_inputs) == {*oldest_mit_sot_inps, untraced_sit_sot_inps[0]} + assert not any(tap in destroyed_inputs for tap in mit_sot_inps) + assert set(destroyed_inputs) == {untraced_sit_sot_inps[0]} @pytest.mark.parametrize( @@ -519,5 +521,158 @@ def test_grad_until_and_truncate_sequence_taps(): ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode="NUMBA") -def test_aliased_inner_outputs(): - ScanCompatibilityTests.check_aliased_inner_outputs(static_shape=True, mode="NUMBA") +@pytest.mark.parametrize("static_shape", (True, False)) +def test_aliased_inner_outputs(static_shape): + ScanCompatibilityTests.check_aliased_inner_outputs(static_shape, mode="NUMBA") + + +class TestUntracedSitSotAliasedInnerOutput: + """Regressions for #2252. + + A sit_sot whose trace is unused is lowered to untraced_sit_sot and carried by reference + across iterations. When its recurrence aliases an inner input or output, the numba + backend must keep the borrowed view only when sound and copy it otherwise; a wrong + choice corrupts the carry or aliases an outer buffer, so matching the reference is the + check. + """ + + @pytest.mark.parametrize( + "echo", + [ + "seq", + "non_seq", + "tap_input", + "tap_output", + "other_untraced_input", + "other_untraced_output", + ], + ) + def test_echo(self, echo): + # One general scan: traced accumulator `y` and traced sit_sot `s` keep the loop + # alive and give `s` a tap buffer; untraced states `u` (independent) and `m` are + # carried by reference. `m`'s recurrence echoes whichever inner value `echo` + # selects, exercising each borrow/copy decision. + x = pt.matrix("x") + c = pt.vector("c") + + def step(x_t, y_prev, s_prev, u_prev, m_prev, c): + s_next = s_prev + x_t # traced sit_sot -> provides a tap buffer + u_next = x_t * 2.0 # independent untraced value + m_next = { + "seq": x_t, + "non_seq": c, + "tap_input": s_prev, + "tap_output": s_next, + "other_untraced_input": u_prev, + "other_untraced_output": u_next, + }[echo] + # Accumulate the *carried* values u_prev and m_prev. A wrong borrow corrupts + # the carry buffer between iterations, so the bug only surfaces if a later step + # reads it back: reading m_next/u_next instead would pass even on buggy code. + y_next = y_prev + m_prev + u_prev + return y_next, s_next, u_next, m_next + + outs = scan( + step, + sequences=[x], + outputs_info=[pt.zeros(3), pt.zeros(3), pt.zeros(3), pt.zeros(3)], + non_sequences=[c], + return_updates=False, + ) + # Only y and s traces are used, so u and m are lowered to untraced_sit_sot. + compare_numba_and_py( + [x, c], + [outs[0], outs[1]], + [np.arange(15.0).reshape(5, 3), np.arange(3.0)], + numba_mode="NUMBA", + ) + + def test_echo_self_in_place(self): + # untraced state updates itself in place -> own previous value, kept by reference + x = pt.matrix("x") + + def step(x_t, acc_prev): + return acc_prev[:1].inc(x_t[:1]) + + outs = scan( + step, + sequences=[x], + outputs_info=[pt.zeros(3)], + return_updates=False, + ) + compare_numba_and_py( + [x], [outs[-1]], [np.arange(15.0).reshape(5, 3)], numba_mode="NUMBA" + ) + + def test_two_untraced_states_sharing_inner_root(self): + # Two untraced states whose recurrences are the *same* fresh inner value must each + # get their own carry buffer (invariant B): borrowing both onto the shared root + # would carry them in one buffer and alias the two outer outputs. The reversed read + # of one carry surfaces the corruption when they alias. Guards the + # ``seen_untraced_roots`` borrow check in the numba backend. + x = pt.matrix("x") + + def step(x_t, y_prev, a_prev, b_prev): + shared = x_t * 2.0 # single fresh root echoed by both untraced states + y_next = y_prev + a_prev[::-1] - b_prev + return y_next, shared, shared + + outs = scan( + step, + sequences=[x], + outputs_info=[pt.zeros(3), pt.zeros(3), pt.zeros(3)], + return_updates=False, + ) + # Only y's trace is used, so a and b are lowered to untraced_sit_sot. + compare_numba_and_py( + [x], [outs[0]], [np.arange(15.0).reshape(5, 3)], numba_mode="NUMBA" + ) + + +@pytest.mark.parametrize("case", ["cross_store", "untraced_on_tap"]) +def test_no_foreign_inplace_on_tap(case): + """A recurrence may reuse only its own tap buffer in place. + + The numba inner optimizer may compute an output in place on another state's + (destroyable) tap. That tap slot is overwritten by its own recurrence, so any + foreign output landing there is corrupted: a tapped cross-store (both outputs + ``>=1``-d) or an untraced output that keeps a reference to the tap. The + ``NoOutputInplaceOnInput`` feature must reject those inplaces. + """ + if case == "cross_store": + # Two vector mit_sots each computed in place on the *other*'s oldest tap. + n = pt.iscalar("n") + + def step(a_tm2, a_tm1, b_tm2, b_tm1, y): + return b_tm2 + 1.0, a_tm2 + 1.0, y + a_tm1.sum() + b_tm1.sum() + + outs = scan( + step, + n_steps=n, + outputs_info=[ + {"initial": pt.as_tensor(np.ones((2, 3))), "taps": [-2, -1]}, + {"initial": pt.as_tensor(np.ones((2, 3)) * 5), "taps": [-2, -1]}, + np.float64(0.0), + ], + return_updates=False, + ) + graph_inputs, graph_outputs, test_inputs = [n], [outs[2]], [6] + else: # untraced_on_tap: an untraced output computed in place on a mit_sot tap + x = pt.dvector("x") + + def step(x_t, z_tm2, z_tm1, y, m): + return z_tm1 + z_tm2 + 0.0 * x_t, y + m, z_tm2 + 1.0 + + outs = scan( + step, + sequences=[x], + outputs_info=[ + {"initial": pt.as_tensor([1.0, 2.0]), "taps": [-2, -1]}, + np.float64(0.0), + np.float64(0.0), + ], + return_updates=False, + ) + graph_inputs, graph_outputs, test_inputs = [x], [outs[1]], [np.arange(1.0, 9.0)] + + compare_numba_and_py(graph_inputs, graph_outputs, test_inputs, numba_mode="NUMBA")