diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 29c0b99dc8..b83eeab15a 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -917,3 +917,68 @@ def on_validate(self, fgraph): ) return True + + +class NoOutputInplaceOnInput(Feature): + """Forbid any `FunctionGraph` output from carrying a protected input destroyed in place. + + Intermediate nodes may still destroy the protected inputs; only the graph + outputs may not be (a view of) the destroyed buffer. This suits a buffer the + surrounding code overwrites between evaluations — e.g. a + :class:`~pytensor.scan.op.Scan` tap, which the loop overwrites in place: an + output aliasing it would be read after the overwrite. Such an output could be + copied back out to stay correct, but needing that copy defeats the in-place + computation, so it is forbidden outright instead. + + Parameters + ---------- + protected_input_idxs + Positions in ``fgraph.inputs`` of the inputs that no output may carry. + """ + + def __init__(self, protected_input_idxs): + self.protected_input_idxs = tuple(protected_input_idxs) + + def on_attach(self, fgraph): + if hasattr(fgraph, "_no_output_inplace_on_input"): + raise AlreadyThere( + f"NoOutputInplaceOnInput is already attached to {fgraph}." + ) + fgraph._no_output_inplace_on_input = self + + def clone(self): + return type(self)(self.protected_input_idxs) + + def on_validate(self, fgraph): + if not hasattr(fgraph, "destroyers"): + return True + destroyed = { + fgraph.inputs[i] + for i in self.protected_input_idxs + if fgraph.destroyers(fgraph.inputs[i]) + } + if not destroyed: + return True + # A protected input is an fgraph input (no owner), so it is the root of any alias + # chain it belongs to. Walk each output back along view_map *and* destroy_map edges + # to its memory root (as ``pytensor.compile.aliasing.alias_root`` does, which we + # can't import here without a graph<->compile cycle); the output carries a destroyed + # protected input -- as a view, or as the in-place result that destroyed it -- iff + # that root is one. (DestroyHandler's cached ``droot`` won't do: it tracks only + # view_map edges, so it omits the destroyer's own output.) + for out_idx, out in enumerate(fgraph.outputs): + root = out + while root.owner is not None: + pos = root.owner.outputs.index(root) + rop = root.owner.op + sources = (*rop.view_map.get(pos, ()), *rop.destroy_map.get(pos, ())) + if not sources: + break + root = root.owner.inputs[sources[0]] + if root in destroyed: + raise InconsistencyError( + f"Output {out_idx} would carry input {fgraph.inputs.index(root)} " + "destroyed in place; that input may be destroyed by intermediate " + "nodes only, not aliased by an output." + ) + return True diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index cf177f6183..da3f600b29 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -5,9 +5,14 @@ from numba import types from numba.extending import overload -from pytensor.compile.aliasing import add_supervisor_to_fgraph, insert_deepcopy +from pytensor.compile.aliasing import ( + add_supervisor_to_fgraph, + alias_root, + insert_deepcopy, +) from pytensor.compile.io import In, Out from pytensor.compile.mode import NUMBA, get_mode +from pytensor.graph.features import NoOutputInplaceOnInput from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( @@ -56,6 +61,67 @@ def range_arr(x): @register_funcify_and_cache_key(Scan) def numba_funcify_Scan(op: Scan, node, **kwargs): + """Generate a Numba implementation of a `Scan` loop. + + Memory-aliasing model + --------------------- + The generated loop calls the inner function once per iteration and routes its + outputs back as the next iteration's state and/or into the outer-output buffers. + Two freedoms make buffer management delicate: + + * The inner function may compute an output **in place on one of its inputs**, writing + that input's buffer instead of allocating a fresh one. + * An inner output may **alias** an input or another state's buffer rather than be an + independent value; such an alias is copied unless the loop may safely keep the + reference, whereas a fresh independent buffer never needs copying. + + The codegen must preserve two invariants: + + (A) *Ownership for in-place writes.* Scan may let the inner function destroy a slot + only if scan owns the buffer in that slot on **every** iteration. Scan owns a + buffer when it deep-copied it, or when scan's own ``destroy_map`` grants + destruction of the corresponding outer input. A buffer scan does not own -- a + plain outer input absent from ``destroy_map`` -- must never reach an in-place + slot (writing a read-only constant segfaults; writing any other input corrupts + it). + + (B) *No undeclared aliasing.* Scan may share an outer output's memory with an outer + input only where it already declares so in its ``view_map``/``destroy_map`` + (e.g. a destroyed input it owns, or an untraced state echoed straight through). + Beyond that the loop must introduce no aliasing: two outer outputs must never + share a buffer, nor an output alias an undeclared input. Downstream rewrites + rely on the declared aliasing being exhaustive. + + An inner output is consumed by the loop in one of two ways: + + * **Tapped output** (mit_sot / sit_sot / nit_sot / mit_mot): copied into its circular + buffer, which scan owns. Because each value is copied into owned storage, the only + hazard is aliasing the one slot the loop overwrites this iteration; an alias of any + not-yet-overwritten slot is safe, and nothing the inner function returns can migrate a + foreign buffer into a tap. (Exactly which input tap shares the overwritten slot, given + the tap offsets and buffer length, is worked out where the rule is applied.) + + * **Untraced sit_sot output**: reused directly as the next iteration's input (carried by + reference, not copied); only its final value is an outer output. Because the carry is + held by reference, an alias is sound only while the buffer it points at stays valid for + the carry's whole life and never surfaces memory scan doesn't own as a fresh output. + By buffer kind: + + - an **independently produced value** (a freshly computed inner value, or a tapped + output -- copied into the trace, so the value itself is independent): always sound; + it aliases no outer memory and is never recycled; + - the state's **own previous value** (the matching untraced input): always sound -- the + carry referencing itself, which scan declares (its ``view_map`` maps each untraced + output onto its own input); + - a **tap input** or **another untraced state's input**: sound only while that buffer + outlives the carry -- a tap until the circular buffer recycles its slot, another + untraced state's buffer until that state is destroyed in a later iteration + (transitively, once untraced states are chained); + - an **outer input** (sequence / non-sequence): never sound -- it would surface outer + memory as the carry's final output (B). + + Two untraced outputs must never share a buffer (also (B)). + """ # Apply inner rewrites # TODO: Not sure this is the right place to do this, should we have a rewrite that # explicitly triggers the optimization of the inner graphs of Scan? @@ -69,7 +135,8 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): fgraph = op.fgraph # When the buffer can only hold one SITSOT or as as many MITSOT as there are taps, # We must always discard the oldest tap, so it's safe to destroy it in the inner function. - # TODO: Allow inplace for MITMOT + # mit_mot inputs are never destroyed, so invariant A never applies to them; only their + # aliasing is handled, like any other tapped output. TODO: Allow inplace for MITMOT (#2252). destroyable_sitsot = [ inner_sitsot for outer_sitsot, inner_sitsot in zip( @@ -87,10 +154,8 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ) if outer_mitsot.type.shape[0] == abs(min(taps)) ] - # Always allow the inner function to destroy untraced_sit_sot inputs. - # After the first iteration, these come from the previous output so - # destroying is always safe. For the first iteration, the codegen - # copies the outer input if the Scan's destroy_map doesn't allow it. + # Untraced inputs may all be destroyed in place; the storage copies below keep + # every in-place slot scan-owned each iteration (invariant A). destroyable_untraced_sit_sot = list(op.inner_untraced_sit_sot(fgraph.inputs)) destroyable = { *destroyable_sitsot, @@ -103,23 +168,119 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): input_specs=input_specs, accept_inplace=True, ) - rewriter(fgraph) - untraced_sit_sot_inner_outputs = set(op.inner_untraced_sit_sot_outs(fgraph.outputs)) - output_specs = [ - Out(x, borrow=x in untraced_sit_sot_inner_outputs) for x in fgraph.outputs + + # Tap inputs whose buffer slot the loop overwrites this iteration. A tapped output is + # copied into its circular buffer, so it is corrupted only if it aliases such a slot; + # likewise an output computed in place on one can't be copied back out without + # defeating the in-place write. With buffer length L, output tap o_out writes input tap + # o_in iff (o_out - o_in) % L == 0 (same physical slot). Two ways this happens: + # - linear write-back (gap 0): the recurrence writes the very slot it read this step + # (mit_mot accumulators: read g[k], write g[k] + delta back). Holds for any L. + # - loop-around (gap == reach): sit_sot/mit_sot write the new state just past the oldest + # read; in a truncated buffer (L == reach, the minimal length that still holds the + # oldest tap) that wraps onto the just-discarded oldest read. A write landing further + # ahead would store onto a slot still to be read -- an invalid recurrence we need not + # consider, just as we don't consider L < reach (reads would alias). So those are the + # only two gaps. + # ``reach`` is the minimal admissible buffer length, ``max_lookback`` (= abs of the + # oldest tap); a write offset never exceeds it, so for any larger L the new slot is fresh + # and no wrap occurs. When L is statically known we test (o_out - o_in) % L == 0 exactly; + # otherwise L is only known to be >= reach and the loop-around can bite only at L == + # reach, so we test gap 0 or gap == reach. + def overwritten_inputs(grouped_inner, outers, in_slices_seq, out_slices_seq): + result = [] + for inner_vars, outer, in_slices, out_slices in zip( + grouped_inner, outers, in_slices_seq, out_slices_seq, strict=True + ): + max_lookback = -min(0, min(in_slices)) + in_offsets = [max_lookback + t for t in in_slices] + out_offsets = [max_lookback + t for t in out_slices] + static_len = outer.type.shape[0] + reach = max_lookback # minimal admissible buffer length + for v, o_in in zip(inner_vars, in_offsets, strict=True): + if static_len is None: + hit = any( + o_out == o_in or o_out - o_in == reach for o_out in out_offsets + ) + else: + hit = any((o_out - o_in) % static_len == 0 for o_out in out_offsets) + if hit: + result.append(v) + return result + + overwritten_taps = [ + *overwritten_inputs( + op.inner_mitmot_grouped(fgraph.inputs), + op.outer_mitmot(node.inputs), + op.info.mit_mot_in_slices, + op.info.mit_mot_out_slices, + ), + *overwritten_inputs( + op.inner_mitsot_grouped(fgraph.inputs), + op.outer_mitsot(node.inputs), + op.info.mit_sot_in_slices, + [(0,)] * op.info.n_mit_sot, + ), + *overwritten_inputs( + [[v] for v in op.inner_sitsot(fgraph.inputs)], + op.outer_sitsot(node.inputs), + op.info.sit_sot_in_slices, + [(0,)] * op.info.n_sit_sot, + ), ] + overwritten_tap_inputs = set(overwritten_taps) + if overwritten_taps: + in_pos = {v: i for i, v in enumerate(fgraph.inputs)} + fgraph.attach_feature( + NoOutputInplaceOnInput([in_pos[t] for t in overwritten_taps]) + ) + rewriter(fgraph) + # Apply the borrow rules from the docstring. + # Tapped output: copy only an alias of a tap input the loop overwrites this iteration + # (overwritten_tap_inputs above); it is stored immediately, so the same-iteration + # overwrite is the only hazard. + # Untraced output: keepable if its root is an independently produced value (root.owner; + # this also covers aliasing a tapped output, which the loop stores by copy) or its own + # untraced input -- scan's view_map already declares output j may alias input j, so that + # self-alias is sound (and lets an in-place self-update skip a copy). Anything else is + # copied: an outer input would leak into an output (B), a tap input's slot can be + # overwritten while the (chain-extended) carry alias is still live, and another untraced + # state's buffer could be carried by reference but only with cross-slot ownership + # bookkeeping (a destroyed-slot walk) for a marginal payoff, so we copy it too. The + # first keeper of a root wins so two untraced outputs never share a buffer (invariant B). + untraced_in_list = op.inner_untraced_sit_sot(fgraph.inputs) + untraced_out_list = op.inner_untraced_sit_sot_outs(fgraph.outputs) + untraced_outs = set(untraced_out_list) + own_untraced_input = dict(zip(untraced_out_list, untraced_in_list, strict=True)) + seen_untraced_roots = set() + output_specs = [] + for x in fgraph.outputs: + root = alias_root(x) + if x in untraced_outs: + keepable = root.owner is not None or root is own_untraced_input[x] + borrow = keepable and root not in seen_untraced_roots + if borrow: + seen_untraced_roots.add(root) + else: + borrow = root not in overwritten_tap_inputs + output_specs.append(Out(x, borrow=borrow)) insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) - # Track which untraced_sit_sot outputs have their inner input destroyed - # by the optimized inner function (transitively, via DestroyHandler). untraced_start = ( op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot + op.info.n_nit_sot ) - inner_destroyed_untraced_out_idxs = set() + # Untraced slots the inner function destroys in place (transitively, via + # DestroyHandler) and which scan must therefore own (invariant A). No untraced output + # borrows another untraced state's buffer, so a destroyed slot only ever holds its own + # buffer -- nothing migrates in -- and the directly destroyed slots are exactly the + # ones to copy. + inner_inplace_untraced_out_idxs = set() if hasattr(fgraph, "destroyers"): - for j, inner_inp in enumerate(op.inner_untraced_sit_sot(fgraph.inputs)): - if fgraph.destroyers(inner_inp): - inner_destroyed_untraced_out_idxs.add(untraced_start + j) + inner_inplace_untraced_out_idxs = { + untraced_start + j + for j, inp in enumerate(untraced_in_list) + if fgraph.destroyers(inp) + } scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key( op.fgraph, fgraph_name="numba_scan", ofg_memo=kwargs.get("ofg_memo") @@ -357,12 +518,11 @@ def add_output_storage_post_proc_stmt( inner_out_to_outer_in_stmts.append(storage_name) output_idx = outer_output_names.index(storage_name) - # Copy the outer input when it will be mutated during the loop - # but the Scan's destroy_map doesn't grant ownership. - # Tapped outputs: the loop writes into the buffer via circular indexing. - # Untraced sit_sot: the inner function may destroy the input inplace. + # Take ownership (invariant A) when the loop mutates the buffer but + # destroy_map doesn't already grant it: tapped buffers (written via circular + # indexing) and the in-place untraced slots identified above. needs_copy = output_idx not in node.op.destroy_map and ( - is_tapped or output_idx in inner_destroyed_untraced_out_idxs + is_tapped or output_idx in inner_inplace_untraced_out_idxs ) if needs_copy: storage_alloc_stmt = f"{storage_name} = numba_deepcopy({outer_in_name})" @@ -496,8 +656,9 @@ def scan({", ".join(outer_in_names)}): # If we can't cache the inner function, we can't cache the Scan either scan_cache_key = None else: + scan_cache_version = 1 scan_cache_key = sha256( - f"({scan_op_src}, {inner_func_cache_key})".encode() + f"({scan_op_src}, {inner_func_cache_key}, {scan_cache_version})".encode() ).hexdigest() return numba_basic.numba_njit(scan_op_fn, boundscheck=False), scan_cache_key diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index a0cb073f54..8ac84d3d14 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -1377,9 +1377,6 @@ def prepare_fgraph(self, fgraph): # `Function` pipeline. update_mapping = {} preallocated_mitmot_outs = [] - untraced_sit_sot_inner_outputs = self.inner_untraced_sit_sot_outs( - fgraph.outputs - ) if config.scan__allow_output_prealloc: # Go through the mitmots. Whenever a mitmot has a tap both as an @@ -1434,10 +1431,10 @@ def prepare_fgraph(self, fgraph): for x in fgraph.inputs[input_idx:] ] wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs[:slices]] - wrapped_outputs += [ - Out(x, borrow=x in untraced_sit_sot_inner_outputs) - for x in fgraph.outputs[slices:] - ] + # Untraced sit_sot states are kept by reference across iterations, so + # their inner outputs must not alias inputs/other outputs (borrow=False + # lets insert_deepcopy break such aliasing). See issue #2252. + wrapped_outputs += [Out(x, borrow=False) for x in fgraph.outputs[slices:]] protected_outs = tuple( i @@ -1453,14 +1450,7 @@ def prepare_fgraph(self, fgraph): else: wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs] - wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs[:slices]] - untraced_sit_sot_inner_outputs = self.inner_untraced_sit_sot_outs( - fgraph.outputs - ) - wrapped_outputs += [ - Out(x, borrow=x in untraced_sit_sot_inner_outputs) - for x in fgraph.outputs[slices:] - ] + wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs] fgraph.update_mapping = update_mapping diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 23c5ce5110..240b02e813 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -159,6 +159,9 @@ def try_inplace_on_node( node = inplace_node except InconsistencyError: inplace_pattern.pop(o) + # The input wasn't consumed, so let another output try to + # reuse it in place. + tried_inputs.discard(i) if node is not original_node: copy_stack_trace(original_node.outputs, node.outputs) return node diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 54fc2a82f1..8b49b76162 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -325,7 +325,13 @@ def test_inner_graph_optimized(): @pytest.mark.parametrize("n_steps_constant", (True, False)) def test_inplace_taps(n_steps_constant): - """Test that numba will inplace in the inner_function of the oldest sit-sot, mit-sot taps.""" + """No inner output may be computed in place on a sit-sot/mit-sot tap. + + The scan loop overwrites those tap buffers in place each iteration, so an output + aliasing one would be corrupted. Only an ``untraced_sit_sot`` input (rebound by + reference, never loop-overwritten) may be destroyed in place by its own output. + See #2252. + """ n_steps = 10 if n_steps_constant else scalar("n_steps", dtype=int) a = scalar("a") x0 = scalar("x0") @@ -363,16 +369,10 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): if isinstance(node.op, Scan) ] - # Collect inner inputs we expect to be destroyed by the step function - # Scan reorders inputs internally, so we need to check its ordering + # Collect the inner inputs destroyed in place by the output-producing nodes. + # Scan reorders inputs internally, so we go through its accessors. inner_inps = scan_op.fgraph.inputs mit_sot_inps = scan_op.inner_mitsot(inner_inps) - oldest_mit_sot_inps = [ - # Implicitly assume that the first mit-sot input is the one with 3 taps - # This is not a required behavior and the test can change if we need to change Scan. - mit_sot_inps[:2][scan_op.info.mit_sot_in_slices[0].index(-3)], - mit_sot_inps[2:][scan_op.info.mit_sot_in_slices[1].index(-2)], - ] sit_sot_inps = scan_op.inner_sitsot(inner_inps) untraced_sit_sot_inps = scan_op.inner_untraced_sit_sot(inner_inps) @@ -386,12 +386,14 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): ) # ``local_subtensor_merge_integer`` + ``scan_reduce_buffer`` reduce the buffers - # the same way for both constant and symbolic ``n_steps`` (xs collapses - # to an untraced sit_sot, the mit_sot buffers reduce to ``taps + 1``), - # so inplace fires identically in both cases. + # the same way for both constant and symbolic ``n_steps`` (xs collapses to an + # untraced sit_sot), so the result is identical in both cases. No mit_sot tap may + # be destroyed by an output (the loop overwrites those slots in place); only the + # untraced_sit_sot input is destroyed in place by its own output. assert len(sit_sot_inps) == 0 assert len(untraced_sit_sot_inps) == 1 - assert set(destroyed_inputs) == {*oldest_mit_sot_inps, untraced_sit_sot_inps[0]} + assert not any(tap in destroyed_inputs for tap in mit_sot_inps) + assert set(destroyed_inputs) == {untraced_sit_sot_inps[0]} @pytest.mark.parametrize( @@ -519,5 +521,158 @@ def test_grad_until_and_truncate_sequence_taps(): ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode="NUMBA") -def test_aliased_inner_outputs(): - ScanCompatibilityTests.check_aliased_inner_outputs(static_shape=True, mode="NUMBA") +@pytest.mark.parametrize("static_shape", (True, False)) +def test_aliased_inner_outputs(static_shape): + ScanCompatibilityTests.check_aliased_inner_outputs(static_shape, mode="NUMBA") + + +class TestUntracedSitSotAliasedInnerOutput: + """Regressions for #2252. + + A sit_sot whose trace is unused is lowered to untraced_sit_sot and carried by reference + across iterations. When its recurrence aliases an inner input or output, the numba + backend must keep the borrowed view only when sound and copy it otherwise; a wrong + choice corrupts the carry or aliases an outer buffer, so matching the reference is the + check. + """ + + @pytest.mark.parametrize( + "echo", + [ + "seq", + "non_seq", + "tap_input", + "tap_output", + "other_untraced_input", + "other_untraced_output", + ], + ) + def test_echo(self, echo): + # One general scan: traced accumulator `y` and traced sit_sot `s` keep the loop + # alive and give `s` a tap buffer; untraced states `u` (independent) and `m` are + # carried by reference. `m`'s recurrence echoes whichever inner value `echo` + # selects, exercising each borrow/copy decision. + x = pt.matrix("x") + c = pt.vector("c") + + def step(x_t, y_prev, s_prev, u_prev, m_prev, c): + s_next = s_prev + x_t # traced sit_sot -> provides a tap buffer + u_next = x_t * 2.0 # independent untraced value + m_next = { + "seq": x_t, + "non_seq": c, + "tap_input": s_prev, + "tap_output": s_next, + "other_untraced_input": u_prev, + "other_untraced_output": u_next, + }[echo] + # Accumulate the *carried* values u_prev and m_prev. A wrong borrow corrupts + # the carry buffer between iterations, so the bug only surfaces if a later step + # reads it back: reading m_next/u_next instead would pass even on buggy code. + y_next = y_prev + m_prev + u_prev + return y_next, s_next, u_next, m_next + + outs = scan( + step, + sequences=[x], + outputs_info=[pt.zeros(3), pt.zeros(3), pt.zeros(3), pt.zeros(3)], + non_sequences=[c], + return_updates=False, + ) + # Only y and s traces are used, so u and m are lowered to untraced_sit_sot. + compare_numba_and_py( + [x, c], + [outs[0], outs[1]], + [np.arange(15.0).reshape(5, 3), np.arange(3.0)], + numba_mode="NUMBA", + ) + + def test_echo_self_in_place(self): + # untraced state updates itself in place -> own previous value, kept by reference + x = pt.matrix("x") + + def step(x_t, acc_prev): + return acc_prev[:1].inc(x_t[:1]) + + outs = scan( + step, + sequences=[x], + outputs_info=[pt.zeros(3)], + return_updates=False, + ) + compare_numba_and_py( + [x], [outs[-1]], [np.arange(15.0).reshape(5, 3)], numba_mode="NUMBA" + ) + + def test_two_untraced_states_sharing_inner_root(self): + # Two untraced states whose recurrences are the *same* fresh inner value must each + # get their own carry buffer (invariant B): borrowing both onto the shared root + # would carry them in one buffer and alias the two outer outputs. The reversed read + # of one carry surfaces the corruption when they alias. Guards the + # ``seen_untraced_roots`` borrow check in the numba backend. + x = pt.matrix("x") + + def step(x_t, y_prev, a_prev, b_prev): + shared = x_t * 2.0 # single fresh root echoed by both untraced states + y_next = y_prev + a_prev[::-1] - b_prev + return y_next, shared, shared + + outs = scan( + step, + sequences=[x], + outputs_info=[pt.zeros(3), pt.zeros(3), pt.zeros(3)], + return_updates=False, + ) + # Only y's trace is used, so a and b are lowered to untraced_sit_sot. + compare_numba_and_py( + [x], [outs[0]], [np.arange(15.0).reshape(5, 3)], numba_mode="NUMBA" + ) + + +@pytest.mark.parametrize("case", ["cross_store", "untraced_on_tap"]) +def test_no_foreign_inplace_on_tap(case): + """A recurrence may reuse only its own tap buffer in place. + + The numba inner optimizer may compute an output in place on another state's + (destroyable) tap. That tap slot is overwritten by its own recurrence, so any + foreign output landing there is corrupted: a tapped cross-store (both outputs + ``>=1``-d) or an untraced output that keeps a reference to the tap. The + ``NoOutputInplaceOnInput`` feature must reject those inplaces. + """ + if case == "cross_store": + # Two vector mit_sots each computed in place on the *other*'s oldest tap. + n = pt.iscalar("n") + + def step(a_tm2, a_tm1, b_tm2, b_tm1, y): + return b_tm2 + 1.0, a_tm2 + 1.0, y + a_tm1.sum() + b_tm1.sum() + + outs = scan( + step, + n_steps=n, + outputs_info=[ + {"initial": pt.as_tensor(np.ones((2, 3))), "taps": [-2, -1]}, + {"initial": pt.as_tensor(np.ones((2, 3)) * 5), "taps": [-2, -1]}, + np.float64(0.0), + ], + return_updates=False, + ) + graph_inputs, graph_outputs, test_inputs = [n], [outs[2]], [6] + else: # untraced_on_tap: an untraced output computed in place on a mit_sot tap + x = pt.dvector("x") + + def step(x_t, z_tm2, z_tm1, y, m): + return z_tm1 + z_tm2 + 0.0 * x_t, y + m, z_tm2 + 1.0 + + outs = scan( + step, + sequences=[x], + outputs_info=[ + {"initial": pt.as_tensor([1.0, 2.0]), "taps": [-2, -1]}, + np.float64(0.0), + np.float64(0.0), + ], + return_updates=False, + ) + graph_inputs, graph_outputs, test_inputs = [x], [outs[1]], [np.arange(1.0, 9.0)] + + compare_numba_and_py(graph_inputs, graph_outputs, test_inputs, numba_mode="NUMBA")