Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions pytensor/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,3 +917,68 @@ def on_validate(self, fgraph):
)

return True


class NoOutputInplaceOnInput(Feature):
"""Forbid any `FunctionGraph` output from carrying a protected input destroyed in place.

Intermediate nodes may still destroy the protected inputs; only the graph
outputs may not be (a view of) the destroyed buffer. This suits a buffer the
surrounding code overwrites between evaluations — e.g. a
:class:`~pytensor.scan.op.Scan` tap, which the loop overwrites in place: an
output aliasing it would be read after the overwrite. Such an output could be
copied back out to stay correct, but needing that copy defeats the in-place
computation, so it is forbidden outright instead.

Parameters
----------
protected_input_idxs
Positions in ``fgraph.inputs`` of the inputs that no output may carry.
"""

def __init__(self, protected_input_idxs):
self.protected_input_idxs = tuple(protected_input_idxs)

def on_attach(self, fgraph):
if hasattr(fgraph, "_no_output_inplace_on_input"):
raise AlreadyThere(
f"NoOutputInplaceOnInput is already attached to {fgraph}."
)
fgraph._no_output_inplace_on_input = self

def clone(self):
return type(self)(self.protected_input_idxs)

def on_validate(self, fgraph):
if not hasattr(fgraph, "destroyers"):
return True
destroyed = {
fgraph.inputs[i]
for i in self.protected_input_idxs
if fgraph.destroyers(fgraph.inputs[i])
}
if not destroyed:
return True
# A protected input is an fgraph input (no owner), so it is the root of any alias
# chain it belongs to. Walk each output back along view_map *and* destroy_map edges
# to its memory root (as ``pytensor.compile.aliasing.alias_root`` does, which we
# can't import here without a graph<->compile cycle); the output carries a destroyed
# protected input -- as a view, or as the in-place result that destroyed it -- iff
# that root is one. (DestroyHandler's cached ``droot`` won't do: it tracks only
# view_map edges, so it omits the destroyer's own output.)
for out_idx, out in enumerate(fgraph.outputs):
Comment thread
ricardoV94 marked this conversation as resolved.
root = out
while root.owner is not None:
pos = root.owner.outputs.index(root)
rop = root.owner.op
sources = (*rop.view_map.get(pos, ()), *rop.destroy_map.get(pos, ()))
if not sources:
break
root = root.owner.inputs[sources[0]]
if root in destroyed:
raise InconsistencyError(
f"Output {out_idx} would carry input {fgraph.inputs.index(root)} "
"destroyed in place; that input may be destroyed by intermediate "
"nodes only, not aliased by an output."
)
return True
205 changes: 183 additions & 22 deletions pytensor/link/numba/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
from numba import types
from numba.extending import overload

from pytensor.compile.aliasing import add_supervisor_to_fgraph, insert_deepcopy
from pytensor.compile.aliasing import (
add_supervisor_to_fgraph,
alias_root,
insert_deepcopy,
)
from pytensor.compile.io import In, Out
from pytensor.compile.mode import NUMBA, get_mode
from pytensor.graph.features import NoOutputInplaceOnInput
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
Expand Down Expand Up @@ -56,6 +61,67 @@ def range_arr(x):

@register_funcify_and_cache_key(Scan)
def numba_funcify_Scan(op: Scan, node, **kwargs):
"""Generate a Numba implementation of a `Scan` loop.

Memory-aliasing model
---------------------
The generated loop calls the inner function once per iteration and routes its
outputs back as the next iteration's state and/or into the outer-output buffers.
Two freedoms make buffer management delicate:

* The inner function may compute an output **in place on one of its inputs**, writing
that input's buffer instead of allocating a fresh one.
* An inner output may **alias** an input or another state's buffer rather than be an
independent value; such an alias is copied unless the loop may safely keep the
reference, whereas a fresh independent buffer never needs copying.

The codegen must preserve two invariants:

(A) *Ownership for in-place writes.* Scan may let the inner function destroy a slot
only if scan owns the buffer in that slot on **every** iteration. Scan owns a
buffer when it deep-copied it, or when scan's own ``destroy_map`` grants
destruction of the corresponding outer input. A buffer scan does not own -- a
plain outer input absent from ``destroy_map`` -- must never reach an in-place
slot (writing a read-only constant segfaults; writing any other input corrupts
it).

(B) *No undeclared aliasing.* Scan may share an outer output's memory with an outer
input only where it already declares so in its ``view_map``/``destroy_map``
(e.g. a destroyed input it owns, or an untraced state echoed straight through).
Beyond that the loop must introduce no aliasing: two outer outputs must never
share a buffer, nor an output alias an undeclared input. Downstream rewrites
rely on the declared aliasing being exhaustive.

An inner output is consumed by the loop in one of two ways:

* **Tapped output** (mit_sot / sit_sot / nit_sot / mit_mot): copied into its circular
buffer, which scan owns. Because each value is copied into owned storage, the only
hazard is aliasing the one slot the loop overwrites this iteration; an alias of any
not-yet-overwritten slot is safe, and nothing the inner function returns can migrate a
foreign buffer into a tap. (Exactly which input tap shares the overwritten slot, given
the tap offsets and buffer length, is worked out where the rule is applied.)

* **Untraced sit_sot output**: reused directly as the next iteration's input (carried by
reference, not copied); only its final value is an outer output. Because the carry is
held by reference, an alias is sound only while the buffer it points at stays valid for
the carry's whole life and never surfaces memory scan doesn't own as a fresh output.
By buffer kind:

- an **independently produced value** (a freshly computed inner value, or a tapped
output -- copied into the trace, so the value itself is independent): always sound;
it aliases no outer memory and is never recycled;
- the state's **own previous value** (the matching untraced input): always sound -- the
carry referencing itself, which scan declares (its ``view_map`` maps each untraced
output onto its own input);
- a **tap input** or **another untraced state's input**: sound only while that buffer
outlives the carry -- a tap until the circular buffer recycles its slot, another
untraced state's buffer until that state is destroyed in a later iteration
(transitively, once untraced states are chained);
- an **outer input** (sequence / non-sequence): never sound -- it would surface outer
memory as the carry's final output (B).

Two untraced outputs must never share a buffer (also (B)).
"""
# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan?
Expand All @@ -69,7 +135,8 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
fgraph = op.fgraph
# When the buffer can only hold one SITSOT or as as many MITSOT as there are taps,
# We must always discard the oldest tap, so it's safe to destroy it in the inner function.
# TODO: Allow inplace for MITMOT
# mit_mot inputs are never destroyed, so invariant A never applies to them; only their
# aliasing is handled, like any other tapped output. TODO: Allow inplace for MITMOT (#2252).
destroyable_sitsot = [
inner_sitsot
for outer_sitsot, inner_sitsot in zip(
Expand All @@ -87,10 +154,8 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
)
if outer_mitsot.type.shape[0] == abs(min(taps))
]
# Always allow the inner function to destroy untraced_sit_sot inputs.
# After the first iteration, these come from the previous output so
# destroying is always safe. For the first iteration, the codegen
# copies the outer input if the Scan's destroy_map doesn't allow it.
# Untraced inputs may all be destroyed in place; the storage copies below keep
# every in-place slot scan-owned each iteration (invariant A).
destroyable_untraced_sit_sot = list(op.inner_untraced_sit_sot(fgraph.inputs))
destroyable = {
*destroyable_sitsot,
Expand All @@ -103,23 +168,119 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
input_specs=input_specs,
accept_inplace=True,
)
rewriter(fgraph)
untraced_sit_sot_inner_outputs = set(op.inner_untraced_sit_sot_outs(fgraph.outputs))
output_specs = [
Out(x, borrow=x in untraced_sit_sot_inner_outputs) for x in fgraph.outputs

# Tap inputs whose buffer slot the loop overwrites this iteration. A tapped output is
# copied into its circular buffer, so it is corrupted only if it aliases such a slot;
# likewise an output computed in place on one can't be copied back out without
# defeating the in-place write. With buffer length L, output tap o_out writes input tap
# o_in iff (o_out - o_in) % L == 0 (same physical slot). Two ways this happens:
# - linear write-back (gap 0): the recurrence writes the very slot it read this step
# (mit_mot accumulators: read g[k], write g[k] + delta back). Holds for any L.
# - loop-around (gap == reach): sit_sot/mit_sot write the new state just past the oldest
# read; in a truncated buffer (L == reach, the minimal length that still holds the
# oldest tap) that wraps onto the just-discarded oldest read. A write landing further
# ahead would store onto a slot still to be read -- an invalid recurrence we need not
# consider, just as we don't consider L < reach (reads would alias). So those are the
# only two gaps.
# ``reach`` is the minimal admissible buffer length, ``max_lookback`` (= abs of the
# oldest tap); a write offset never exceeds it, so for any larger L the new slot is fresh
# and no wrap occurs. When L is statically known we test (o_out - o_in) % L == 0 exactly;
# otherwise L is only known to be >= reach and the loop-around can bite only at L ==
# reach, so we test gap 0 or gap == reach.
def overwritten_inputs(grouped_inner, outers, in_slices_seq, out_slices_seq):
result = []
for inner_vars, outer, in_slices, out_slices in zip(
grouped_inner, outers, in_slices_seq, out_slices_seq, strict=True
):
max_lookback = -min(0, min(in_slices))
in_offsets = [max_lookback + t for t in in_slices]
out_offsets = [max_lookback + t for t in out_slices]
static_len = outer.type.shape[0]
reach = max_lookback # minimal admissible buffer length
for v, o_in in zip(inner_vars, in_offsets, strict=True):
if static_len is None:
hit = any(
o_out == o_in or o_out - o_in == reach for o_out in out_offsets
)
else:
hit = any((o_out - o_in) % static_len == 0 for o_out in out_offsets)
if hit:
result.append(v)
return result

overwritten_taps = [
*overwritten_inputs(
op.inner_mitmot_grouped(fgraph.inputs),
op.outer_mitmot(node.inputs),
op.info.mit_mot_in_slices,
op.info.mit_mot_out_slices,
),
*overwritten_inputs(
op.inner_mitsot_grouped(fgraph.inputs),
op.outer_mitsot(node.inputs),
op.info.mit_sot_in_slices,
[(0,)] * op.info.n_mit_sot,
),
*overwritten_inputs(
[[v] for v in op.inner_sitsot(fgraph.inputs)],
op.outer_sitsot(node.inputs),
op.info.sit_sot_in_slices,
[(0,)] * op.info.n_sit_sot,
),
]
overwritten_tap_inputs = set(overwritten_taps)
if overwritten_taps:
in_pos = {v: i for i, v in enumerate(fgraph.inputs)}
fgraph.attach_feature(
NoOutputInplaceOnInput([in_pos[t] for t in overwritten_taps])
)
rewriter(fgraph)
# Apply the borrow rules from the docstring.
# Tapped output: copy only an alias of a tap input the loop overwrites this iteration
# (overwritten_tap_inputs above); it is stored immediately, so the same-iteration
# overwrite is the only hazard.
# Untraced output: keepable if its root is an independently produced value (root.owner;
# this also covers aliasing a tapped output, which the loop stores by copy) or its own
# untraced input -- scan's view_map already declares output j may alias input j, so that
# self-alias is sound (and lets an in-place self-update skip a copy). Anything else is
# copied: an outer input would leak into an output (B), a tap input's slot can be
# overwritten while the (chain-extended) carry alias is still live, and another untraced
# state's buffer could be carried by reference but only with cross-slot ownership
# bookkeeping (a destroyed-slot walk) for a marginal payoff, so we copy it too. The
# first keeper of a root wins so two untraced outputs never share a buffer (invariant B).
untraced_in_list = op.inner_untraced_sit_sot(fgraph.inputs)
untraced_out_list = op.inner_untraced_sit_sot_outs(fgraph.outputs)
untraced_outs = set(untraced_out_list)
own_untraced_input = dict(zip(untraced_out_list, untraced_in_list, strict=True))
seen_untraced_roots = set()
output_specs = []
for x in fgraph.outputs:
root = alias_root(x)
if x in untraced_outs:
keepable = root.owner is not None or root is own_untraced_input[x]
borrow = keepable and root not in seen_untraced_roots
Comment thread
ricardoV94 marked this conversation as resolved.
Comment thread
ricardoV94 marked this conversation as resolved.
if borrow:
seen_untraced_roots.add(root)
else:
borrow = root not in overwritten_tap_inputs
output_specs.append(Out(x, borrow=borrow))
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)

# Track which untraced_sit_sot outputs have their inner input destroyed
# by the optimized inner function (transitively, via DestroyHandler).
untraced_start = (
op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot + op.info.n_nit_sot
)
inner_destroyed_untraced_out_idxs = set()
# Untraced slots the inner function destroys in place (transitively, via
# DestroyHandler) and which scan must therefore own (invariant A). No untraced output
# borrows another untraced state's buffer, so a destroyed slot only ever holds its own
# buffer -- nothing migrates in -- and the directly destroyed slots are exactly the
# ones to copy.
inner_inplace_untraced_out_idxs = set()
if hasattr(fgraph, "destroyers"):
for j, inner_inp in enumerate(op.inner_untraced_sit_sot(fgraph.inputs)):
if fgraph.destroyers(inner_inp):
inner_destroyed_untraced_out_idxs.add(untraced_start + j)
inner_inplace_untraced_out_idxs = {
untraced_start + j
for j, inp in enumerate(untraced_in_list)
if fgraph.destroyers(inp)
}

scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(
op.fgraph, fgraph_name="numba_scan", ofg_memo=kwargs.get("ofg_memo")
Expand Down Expand Up @@ -357,12 +518,11 @@ def add_output_storage_post_proc_stmt(
inner_out_to_outer_in_stmts.append(storage_name)

output_idx = outer_output_names.index(storage_name)
# Copy the outer input when it will be mutated during the loop
# but the Scan's destroy_map doesn't grant ownership.
# Tapped outputs: the loop writes into the buffer via circular indexing.
# Untraced sit_sot: the inner function may destroy the input inplace.
# Take ownership (invariant A) when the loop mutates the buffer but
# destroy_map doesn't already grant it: tapped buffers (written via circular
# indexing) and the in-place untraced slots identified above.
needs_copy = output_idx not in node.op.destroy_map and (
is_tapped or output_idx in inner_destroyed_untraced_out_idxs
is_tapped or output_idx in inner_inplace_untraced_out_idxs
)
if needs_copy:
storage_alloc_stmt = f"{storage_name} = numba_deepcopy({outer_in_name})"
Expand Down Expand Up @@ -496,8 +656,9 @@ def scan({", ".join(outer_in_names)}):
# If we can't cache the inner function, we can't cache the Scan either
scan_cache_key = None
else:
scan_cache_version = 1
scan_cache_key = sha256(
f"({scan_op_src}, {inner_func_cache_key})".encode()
f"({scan_op_src}, {inner_func_cache_key}, {scan_cache_version})".encode()
).hexdigest()

return numba_basic.numba_njit(scan_op_fn, boundscheck=False), scan_cache_key
20 changes: 5 additions & 15 deletions pytensor/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,9 +1377,6 @@ def prepare_fgraph(self, fgraph):
# `Function` pipeline.
update_mapping = {}
preallocated_mitmot_outs = []
untraced_sit_sot_inner_outputs = self.inner_untraced_sit_sot_outs(
fgraph.outputs
)

if config.scan__allow_output_prealloc:
# Go through the mitmots. Whenever a mitmot has a tap both as an
Expand Down Expand Up @@ -1434,10 +1431,10 @@ def prepare_fgraph(self, fgraph):
for x in fgraph.inputs[input_idx:]
]
wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs[:slices]]
wrapped_outputs += [
Out(x, borrow=x in untraced_sit_sot_inner_outputs)
for x in fgraph.outputs[slices:]
]
# Untraced sit_sot states are kept by reference across iterations, so
# their inner outputs must not alias inputs/other outputs (borrow=False
# lets insert_deepcopy break such aliasing). See issue #2252.
wrapped_outputs += [Out(x, borrow=False) for x in fgraph.outputs[slices:]]

protected_outs = tuple(
i
Expand All @@ -1453,14 +1450,7 @@ def prepare_fgraph(self, fgraph):

else:
wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs[:slices]]
untraced_sit_sot_inner_outputs = self.inner_untraced_sit_sot_outs(
fgraph.outputs
)
wrapped_outputs += [
Out(x, borrow=x in untraced_sit_sot_inner_outputs)
for x in fgraph.outputs[slices:]
]
wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs]

fgraph.update_mapping = update_mapping

Expand Down
3 changes: 3 additions & 0 deletions pytensor/tensor/rewriting/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def try_inplace_on_node(
node = inplace_node
except InconsistencyError:
inplace_pattern.pop(o)
# The input wasn't consumed, so let another output try to
# reuse it in place.
tried_inputs.discard(i)
if node is not original_node:
copy_stack_trace(original_node.outputs, node.outputs)
return node
Expand Down
Loading
Loading