diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index e8d05ce7ff..534b167460 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -54,19 +54,26 @@ def range_arr(x): return range_arr -@register_funcify_and_cache_key(Scan) -def numba_funcify_Scan(op: Scan, node, **kwargs): - # Apply inner rewrites +def numba_optimize_inner_fgraph(op: Scan, node): + """Clone and optimize the inner graph of a ``Scan`` for the numba backend. + + The returned ``FunctionGraph`` is a clone of ``op.fgraph`` with numba's inner + rewrites (including destructive/inplace rewrites) and the necessary deepcopies + applied. ``op.fgraph`` itself is left untouched, so the same ``Scan`` can still + be compiled by another backend afterwards (e.g. the C backend, which rejects + inplace operations in the inner graph). + # TODO: Not sure this is the right place to do this, should we have a rewrite that # explicitly triggers the optimization of the inner graphs of Scan? # The C-code defers it to the make_thunk phase + """ rewriter = ( get_mode(op.mode) .including("numba") .excluding(*NUMBA._optimizer.exclude) .optimizer ) - fgraph = op.fgraph + fgraph = op.fgraph.clone() # When the buffer can only hold one SITSOT or as as many MITSOT as there are taps, # We must always discard the oldest tap, so it's safe to destroy it in the inner function. # TODO: Allow inplace for MITMOT @@ -109,6 +116,14 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): Out(x, borrow=x in untraced_sit_sot_inner_outputs) for x in fgraph.outputs ] insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) + return fgraph + + +@register_funcify_and_cache_key(Scan) +def numba_funcify_Scan(op: Scan, node, **kwargs): + # Optimize a clone of the inner graph, leaving ``op.fgraph`` untouched so the + # same Scan can still be compiled by other backends. + fgraph = numba_optimize_inner_fgraph(op, node) # Track which untraced_sit_sot outputs have their inner input destroyed # by the optimized inner function (transitively, via DestroyHandler). @@ -122,7 +137,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): inner_destroyed_untraced_out_idxs.add(untraced_start + j) scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key( - op.fgraph, fgraph_name="numba_scan" + fgraph, fgraph_name="numba_scan" ) outer_in_names_to_vars = { diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py index 73affe6c8f..4b2d69b45f 100644 --- a/tests/link/numba/test_compile_ops.py +++ b/tests/link/numba/test_compile_ops.py @@ -6,6 +6,7 @@ from pytensor.compile.ops import ViewOp from pytensor.graph import vectorize_graph from pytensor.ifelse import IfElse +from pytensor.link.numba.dispatch.scan import numba_optimize_inner_fgraph from pytensor.raise_op import assert_op from pytensor.scalar import Add from pytensor.scan.op import Scan @@ -215,12 +216,16 @@ def test_ofg_with_inner_scan_rewrite(): xs_ofg = OpFromGraph([ys], [xs])(ys) fn = function([ys], xs_ofg, mode="NUMBA") - # Check that we have a BlockwiseWithCoreShape in the inner Scan + # Check that we have a BlockwiseWithCoreShape in the inner Scan. + # The numba backend optimizes a clone of the Scan inner graph leaving ``op.fgraph`` untouched, so we + # inspect that optimized clone to observe the rewrite. fn_ofg_op = fn.maker.fgraph.outputs[0].owner.op assert isinstance(fn_ofg_op, OpFromGraph) - fn_scan_op = fn_ofg_op.fgraph.outputs[0].owner.op + fn_scan_node = fn_ofg_op.fgraph.outputs[0].owner + fn_scan_op = fn_scan_node.op assert isinstance(fn_scan_op, Scan) - fn_cholesky_op = fn_scan_op.fgraph.outputs[0].owner.op + opt_fgraph = numba_optimize_inner_fgraph(fn_scan_op, fn_scan_node) + fn_cholesky_op = opt_fgraph.outputs[0].owner.op assert isinstance(fn_cholesky_op, BlockwiseWithCoreShape) assert isinstance(fn_cholesky_op.core_op, Cholesky) diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 54fc2a82f1..ed79094874 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -7,6 +7,7 @@ import pytensor.tensor as pt from pytensor import config, function, grad from pytensor.compile.mode import Mode, get_mode +from pytensor.link.numba.dispatch.scan import numba_optimize_inner_fgraph from pytensor.scalar import Log1p from pytensor.scan.basic import scan from pytensor.scan.op import Scan @@ -315,9 +316,10 @@ def test_inner_graph_optimized(): (scan_node,) = ( node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan) ) - inner_scan_nodes = scan_node.op.fgraph.apply_nodes - assert len(inner_scan_nodes) == 1 - (inner_scan_node,) = scan_node.op.fgraph.apply_nodes + # The numba backend optimizes a clone of the inner graph; the canonical + # ``op.fgraph`` is left untouched (see ``numba_optimize_inner_fgraph``). + opt_fgraph = numba_optimize_inner_fgraph(scan_node.op, scan_node) + (inner_scan_node,) = opt_fgraph.apply_nodes assert isinstance(inner_scan_node.op, Elemwise) and isinstance( inner_scan_node.op.scalar_op, Log1p ) @@ -357,15 +359,18 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): numba_mode="NUMBA", eval_obj_mode=False, ) - [scan_op] = [ - node.op - for node in numba_fn.maker.fgraph.toposort() - if isinstance(node.op, Scan) + [scan_node] = [ + node for node in numba_fn.maker.fgraph.toposort() if isinstance(node.op, Scan) ] + scan_op = scan_node.op + + # The numba backend optimizes a clone of the inner graph; the canonical + # ``op.fgraph`` is left untouched (see ``numba_optimize_inner_fgraph``). + opt_fgraph = numba_optimize_inner_fgraph(scan_op, scan_node) # Collect inner inputs we expect to be destroyed by the step function # Scan reorders inputs internally, so we need to check its ordering - inner_inps = scan_op.fgraph.inputs + inner_inps = opt_fgraph.inputs mit_sot_inps = scan_op.inner_mitsot(inner_inps) oldest_mit_sot_inps = [ # Implicitly assume that the first mit-sot input is the one with 3 taps @@ -377,7 +382,7 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): untraced_sit_sot_inps = scan_op.inner_untraced_sit_sot(inner_inps) destroyed_inputs = [] - for inner_out in scan_op.fgraph.outputs: + for inner_out in opt_fgraph.outputs: node = inner_out.owner dm = node.op.destroy_map if dm: @@ -394,6 +399,36 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): assert set(destroyed_inputs) == {*oldest_mit_sot_inps, untraced_sit_sot_inps[0]} +def test_inner_graph_not_mutated_by_numba(): + """Compiling a Scan with numba must not mutate the shared inner graph. + + The numba backend applies destructive/inplace rewrites to the Scan inner + graph. These must run on a clone, otherwise a subsequent compilation of the + same graph by the C backend (which rejects inplace ops in the inner graph) + fails with ``TypeError: Graph must not contain inplace operations``. + """ + + def core_fn(A): + x0 = pt.zeros((A.shape[0],)) + seq = scan( + lambda x, A: pt.exp(A @ x) + x, + outputs_info=x0, + non_sequences=[A], + n_steps=10, + return_updates=False, + ) + return seq[-1] + + A = pt.tensor3("A") + out = pt.vectorize(core_fn, signature="(k,k)->(k)")(A) + val = np.broadcast_to(np.eye(3), (4, 3, 3)).astype(config.floatX) + + res_numba = function([A], out, mode="NUMBA")(val) + # The C backend must still accept the same (un-mutated) graph + res_c = function([A], out, mode=Mode("cvm", "fast_run"))(val) + np.testing.assert_allclose(res_numba, res_c) + + @pytest.mark.parametrize( "buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init") )