Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions pytensor/link/numba/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,26 @@ def range_arr(x):
return range_arr


@register_funcify_and_cache_key(Scan)
def numba_funcify_Scan(op: Scan, node, **kwargs):
# Apply inner rewrites
def numba_optimize_inner_fgraph(op: Scan, node):
"""Clone and optimize the inner graph of a ``Scan`` for the numba backend.

The returned ``FunctionGraph`` is a clone of ``op.fgraph`` with numba's inner
rewrites (including destructive/inplace rewrites) and the necessary deepcopies
applied. ``op.fgraph`` itself is left untouched, so the same ``Scan`` can still
be compiled by another backend afterwards (e.g. the C backend, which rejects
inplace operations in the inner graph).

# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase
"""
rewriter = (
get_mode(op.mode)
.including("numba")
.excluding(*NUMBA._optimizer.exclude)
.optimizer
)
fgraph = op.fgraph
fgraph = op.fgraph.clone()
# When the buffer can only hold one SITSOT or as many MITSOT as there are taps,
# We must always discard the oldest tap, so it's safe to destroy it in the inner function.
# TODO: Allow inplace for MITMOT
Expand Down Expand Up @@ -109,6 +116,14 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
Out(x, borrow=x in untraced_sit_sot_inner_outputs) for x in fgraph.outputs
]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
return fgraph


@register_funcify_and_cache_key(Scan)
def numba_funcify_Scan(op: Scan, node, **kwargs):
# Optimize a clone of the inner graph, leaving ``op.fgraph`` untouched so the
# same Scan can still be compiled by other backends.
fgraph = numba_optimize_inner_fgraph(op, node)

# Track which untraced_sit_sot outputs have their inner input destroyed
# by the optimized inner function (transitively, via DestroyHandler).
Expand All @@ -122,7 +137,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
inner_destroyed_untraced_out_idxs.add(untraced_start + j)

scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(
op.fgraph, fgraph_name="numba_scan"
fgraph, fgraph_name="numba_scan"
)

outer_in_names_to_vars = {
Expand Down
11 changes: 8 additions & 3 deletions tests/link/numba/test_compile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pytensor.compile.ops import ViewOp
from pytensor.graph import vectorize_graph
from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch.scan import numba_optimize_inner_fgraph
from pytensor.raise_op import assert_op
from pytensor.scalar import Add
from pytensor.scan.op import Scan
Expand Down Expand Up @@ -215,12 +216,16 @@ def test_ofg_with_inner_scan_rewrite():
xs_ofg = OpFromGraph([ys], [xs])(ys)
fn = function([ys], xs_ofg, mode="NUMBA")

# Check that we have a BlockwiseWithCoreShape in the inner Scan
# Check that we have a BlockwiseWithCoreShape in the inner Scan.
# The numba backend optimizes a clone of the Scan inner graph leaving ``op.fgraph`` untouched, so we
# inspect that optimized clone to observe the rewrite.
fn_ofg_op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(fn_ofg_op, OpFromGraph)
fn_scan_op = fn_ofg_op.fgraph.outputs[0].owner.op
fn_scan_node = fn_ofg_op.fgraph.outputs[0].owner
fn_scan_op = fn_scan_node.op
assert isinstance(fn_scan_op, Scan)
fn_cholesky_op = fn_scan_op.fgraph.outputs[0].owner.op
opt_fgraph = numba_optimize_inner_fgraph(fn_scan_op, fn_scan_node)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not good: this isn't confirming that the final graph had it, only that it could have.

fn_cholesky_op = opt_fgraph.outputs[0].owner.op
assert isinstance(fn_cholesky_op, BlockwiseWithCoreShape)
assert isinstance(fn_cholesky_op.core_op, Cholesky)

Expand Down
53 changes: 44 additions & 9 deletions tests/link/numba/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytensor.tensor as pt
from pytensor import config, function, grad
from pytensor.compile.mode import Mode, get_mode
from pytensor.link.numba.dispatch.scan import numba_optimize_inner_fgraph
from pytensor.scalar import Log1p
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
Expand Down Expand Up @@ -315,9 +316,10 @@ def test_inner_graph_optimized():
(scan_node,) = (
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
)
inner_scan_nodes = scan_node.op.fgraph.apply_nodes
assert len(inner_scan_nodes) == 1
(inner_scan_node,) = scan_node.op.fgraph.apply_nodes
# The numba backend optimizes a clone of the inner graph; the canonical
# ``op.fgraph`` is left untouched (see ``numba_optimize_inner_fgraph``).
opt_fgraph = numba_optimize_inner_fgraph(scan_node.op, scan_node)
(inner_scan_node,) = opt_fgraph.apply_nodes
assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
inner_scan_node.op.scalar_op, Log1p
)
Expand Down Expand Up @@ -357,15 +359,18 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a):
numba_mode="NUMBA",
eval_obj_mode=False,
)
[scan_op] = [
node.op
for node in numba_fn.maker.fgraph.toposort()
if isinstance(node.op, Scan)
[scan_node] = [
node for node in numba_fn.maker.fgraph.toposort() if isinstance(node.op, Scan)
]
scan_op = scan_node.op

# The numba backend optimizes a clone of the inner graph; the canonical
# ``op.fgraph`` is left untouched (see ``numba_optimize_inner_fgraph``).
opt_fgraph = numba_optimize_inner_fgraph(scan_op, scan_node)

# Collect inner inputs we expect to be destroyed by the step function
# Scan reorders inputs internally, so we need to check its ordering
inner_inps = scan_op.fgraph.inputs
inner_inps = opt_fgraph.inputs
mit_sot_inps = scan_op.inner_mitsot(inner_inps)
oldest_mit_sot_inps = [
# Implicitly assume that the first mit-sot input is the one with 3 taps
Expand All @@ -377,7 +382,7 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a):
untraced_sit_sot_inps = scan_op.inner_untraced_sit_sot(inner_inps)

destroyed_inputs = []
for inner_out in scan_op.fgraph.outputs:
for inner_out in opt_fgraph.outputs:
node = inner_out.owner
dm = node.op.destroy_map
if dm:
Expand All @@ -394,6 +399,36 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a):
assert set(destroyed_inputs) == {*oldest_mit_sot_inps, untraced_sit_sot_inps[0]}


def test_inner_graph_not_mutated_by_numba():
    """Check that a numba compilation leaves the Scan inner graph usable by C.

    The numba backend rewrites the inner graph of a Scan with
    destructive/inplace rewrites. If those rewrites were applied to the shared
    inner graph instead of a clone, compiling the very same graph with the C
    backend afterwards (which rejects inplace ops in the inner graph) would
    fail with ``TypeError: Graph must not contain inplace operations``.
    """

    def core_fn(A):
        # Iterate ``x <- exp(A @ x) + x`` for 10 steps, starting from zeros,
        # and keep only the final state.
        initial_state = pt.zeros((A.shape[0],))
        states = scan(
            lambda x, A: pt.exp(A @ x) + x,
            outputs_info=initial_state,
            non_sequences=[A],
            n_steps=10,
            return_updates=False,
        )
        return states[-1]

    A = pt.tensor3("A")
    vectorized_core = pt.vectorize(core_fn, signature="(k,k)->(k)")
    out = vectorized_core(A)

    # A batch of four 3x3 identity matrices in the configured float dtype.
    identity_batch = np.broadcast_to(np.eye(3), (4, 3, 3))
    val = identity_batch.astype(config.floatX)

    # Compile and evaluate with numba first ...
    numba_fn = function([A], out, mode="NUMBA")
    res_numba = numba_fn(val)

    # ... then the C backend must still accept the same (un-mutated) graph.
    c_fn = function([A], out, mode=Mode("cvm", "fast_run"))
    res_c = c_fn(val)

    np.testing.assert_allclose(res_numba, res_c)


@pytest.mark.parametrize(
"buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init")
)
Expand Down