diff --git a/pytensor/link/numba/dispatch/_llvmlite_self_ref.py b/pytensor/link/numba/dispatch/_llvmlite_self_ref.py new file mode 100644 index 0000000000..764b584ceb --- /dev/null +++ b/pytensor/link/numba/dispatch/_llvmlite_self_ref.py @@ -0,0 +1,66 @@ +"""Import-time shim giving stock llvmlite the ability to emit self-referential +metadata nodes (``!0 = !{ !0 }``). + +Self-referential nodes are how LLVM makes ``!alias.scope``/``!noalias`` domains and +scopes globally unique. llvmlite only supports them once this PR lands, adding +``Module.add_metadata(operands, self_ref=True)``: + + https://github.com/numba/llvmlite/pull/895 + +This shim provides the same ``self_ref`` keyword on older llvmlite so the alias-scope +markers emitted by ``vectorize_codegen`` work without a patched llvmlite. It is a +no-op when the native API is present. +""" + +import inspect + +from llvmlite import ir +from llvmlite.ir import module as _ll_module +from llvmlite.ir import values as _ll_values + + +def ensure_self_ref_metadata_support() -> None: + """Patch ``Module.add_metadata`` to accept ``self_ref=True`` if it doesn't already.""" + if "self_ref" in inspect.signature(ir.Module.add_metadata).parameters: + return + if getattr(_ll_module.Module, "_pytensor_self_ref_patched", False): + return + + base_add_metadata = _ll_module.Module.add_metadata + + class _SelfRefMDValue(_ll_values.MDValue): + """Metadata node whose first operand is itself. + + The self-reference is kept out of the hashed/compared state: a self-ref + scope is routinely used as an operand of another metadata node (e.g. an + ``alias.scope`` set), and hashing a tuple that transitively contains the + node would otherwise recurse forever. Equality falls back to identity, + matching the uniqueness guarantee self-referential nodes exist to provide. + """ + + def __init__(self, parent, operands, name): + super().__init__(parent, operands, name) + self._self_ref_tail = tuple(operands) + self.operands = (self, *self._self_ref_tail) + + def __hash__(self): + return hash(self._self_ref_tail) + + def __eq__(self, other): + return self is other + + def __ne__(self, other): + return self is not other + + def add_metadata(self, operands, *, self_ref=False): + if not self_ref: + return base_add_metadata(self, operands) + if not isinstance(operands, list | tuple): + raise TypeError( + f"expected a list or tuple of metadata values, got {operands!r}" + ) + operands = self._fix_metadata_operands(operands) + return _SelfRefMDValue(self, operands, name=str(len(self.metadata))) + + _ll_module.Module.add_metadata = add_metadata + _ll_module.Module._pytensor_self_ref_patched = True diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 1383ee1df1..ea8bfb45e9 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -794,7 +794,7 @@ def impl(*outer_inputs): return impl - cache_version = 2 + cache_version = 3 if scalar_cache_key is None: key = None else: diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index d1efdf037d..2af0a0e683 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -9,6 +9,7 @@ import numba import numpy as np from llvmlite import ir +from llvmlite.ir.values import MDValue from numba import TypingError, types from numba.core import cgutils from numba.core.base import BaseContext @@ -17,6 +18,37 @@ from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch._llvmlite_self_ref import ( + ensure_self_ref_metadata_support, +) + + +ensure_self_ref_metadata_support() + + +class _DistinctEmptyMetadata(MDValue): + """A ``distinct !{}`` metadata node, usable as an LLVM access group. + + llvmlite's ``MDValue`` only emits *uniqued* ``!{}`` nodes, which LLVM rejects as + access groups: an access group must be ``distinct`` so two function-local accesses + are never considered identical (a plain uniqued ``!{}`` crashes the verifier). Each + instance has its own identity, so every loop that asks for one gets a fresh group. + """ + + def __init__(self, parent): + super().__init__(parent, [], name=str(len(parent.metadata))) + + def descr(self, buf): + buf += ("distinct !{}", "\n") + + def __eq__(self, other): + return self is other + + def __ne__(self, other): + return self is not other + + def __hash__(self): + return id(self) def encode_literals(literals: Sequence) -> str: @@ -102,8 +134,8 @@ def _compute_idx_load_axes(indexed_inputs, indexed_outputs, idx_ndims): ---------- indexed_inputs : tuple of ((tuple[int, ...], int) | None) Per-index: (source input positions, source axis) or None. - indexed_outputs : tuple of ((tuple[int, ...], int, str) | None) - Per-index: (output positions, axis, mode) or None. + indexed_outputs : tuple of ((tuple[int, ...], int, str, bool) | None) + Per-index: (output positions, axis, mode, distinct) or None. idx_ndims : tuple of int Number of dimensions of each index array. """ @@ -144,7 +176,7 @@ def find(x): for k, entry in enumerate(indexed_outputs): if entry is not None: root = find(k) - _, out_axis, _ = entry + _, out_axis, *_ = entry group_min_axis[root] = min(group_min_axis.get(root, out_axis), out_axis) return tuple( @@ -437,20 +469,46 @@ def make_loop_call( idx_load_axes: tuple[tuple[int, ...], ...] | None = None, idx_bc: tuple[tuple[bool, ...], ...] | None = None, output_write_spec: tuple[tuple[tuple[int, int], ...] | None, ...] | None = None, + inplace: tuple[tuple[int, int], ...] = (), + distinct_outputs: frozenset = frozenset(), ): safe = (False, False) n_outputs = len(outputs) - # TODO I think this is better than the noalias attribute - # for the input, but self_ref isn't supported in a released - # llvmlite version yet - # mod = builder.module - # domain = mod.add_metadata([], self_ref=True) - # input_scope = mod.add_metadata([domain], self_ref=True) - # output_scope = mod.add_metadata([domain], self_ref=True) - # input_scope_set = mod.add_metadata([input_scope, output_scope]) - # output_scope_set = mod.add_metadata([input_scope, output_scope]) + # Scoped noalias metadata: input loads and output stores are tagged with + # alias scopes so LLVM can disambiguate them without runtime overlap checks + # (and without loop versioning). Inputs share one scope (their loads never + # conflict with each other) while each output gets its own, so that every + # access can claim noalias against all buffers it is guaranteed not to + # overlap: PyTensor guarantees distinct output buffers, and that inputs + # don't alias outputs *except* for an input destroyed by an inplace output. + # Loads of a destroyed input are tagged with its output's scope instead: + # they stay MayAlias with that output's stores (LLVM resolves the exact + # overlap through pointer identity, since the output reuses the input's + # array struct) yet are still disambiguated from every other buffer. + mod = builder.module + domain = mod.add_metadata([], self_ref=True) + input_scope = mod.add_metadata([domain], self_ref=True) + output_scopes = [ + mod.add_metadata([domain], self_ref=True) for _ in range(n_outputs) + ] + input_scope_set = mod.add_metadata([input_scope]) + output_scope_set = mod.add_metadata(output_scopes) + out_alias_sets = [mod.add_metadata([scope]) for scope in output_scopes] + out_noalias_sets = [ + mod.add_metadata([input_scope, *(s for s in output_scopes if s is not scope)]) + for scope in output_scopes + ] + destroyed_inputs = {in_idx: out_idx for out_idx, in_idx in inplace} + + # When an indexed-update output writes through statically-distinct indices, its + # read-modify-write carries no cross-iteration dependency, so the loop can vectorize + # the value-compute (LLVM scalarizes only the indexed stores). We promise this with an + # access group on those RMW load/stores plus `llvm.loop.parallel_accesses` on the + # innermost latch. The group must be a *distinct* node so it is never uniqued with + # another loop's. + access_group = _DistinctEmptyMetadata(mod) if distinct_outputs else None zero = ir.Constant(ir.IntType(64), 0) @@ -516,6 +574,10 @@ def _wrap_negative_index(idx_val, dim_size, signed): False, ) val = builder.load(ptr) + val.set_metadata("alias.scope", input_scope_set) + val.set_metadata("noalias", output_scope_set) + if access_group is not None: + val.set_metadata("llvm.access.group", access_group) i64 = ir.IntType(64) if val.type != i64: if idx_arr_type.dtype.signed: @@ -614,8 +676,18 @@ def _wrap_negative_index(idx_val, dim_size, signed): if core_scalar and core_ndim == 0: # Retrive scalar item at index read_val = builder.load(read_ptr) - # read_val.set_metadata("alias.scope", input_scope_set) - # read_val.set_metadata("noalias", output_scope_set) + destination = destroyed_inputs.get(input_i) + if destination is None: + read_val.set_metadata("alias.scope", input_scope_set) + read_val.set_metadata("noalias", output_scope_set) + else: + read_val.set_metadata("alias.scope", out_alias_sets[destination]) + read_val.set_metadata("noalias", out_noalias_sets[destination]) + # Every memory access in the loop must join the access group, or LLVM's + # `isAnnotatedParallel` rejects the whole loop (one untagged load voids the + # `llvm.loop.parallel_accesses` promise). + if access_group is not None: + read_val.set_metadata("llvm.access.group", access_group) else: # Retrieve array item at index # This is a streamlined version of Numba's `GUArrayArg.load`. @@ -651,6 +723,7 @@ def _wrap_negative_index(idx_val, dim_size, signed): # Create output slices to pass to inner func output_slices = [] + scratch_outputs = [] for output_i, (out, out_type, out_bc) in enumerate( zip(outputs, output_types, output_bc, strict=True) ): @@ -710,6 +783,22 @@ def _wrap_negative_index(idx_val, dim_size, signed): dtype=out_type.dtype, ndim=effective_core_ndim, layout=out_type.layout ) write_array = context.make_array(write_array_type)(context, builder) + if effective_core_ndim == 0: + # Redirect the 0-d output slice through a stack slot so the store + # into the real output buffer happens below, after the core call, + # where it can carry the alias scope metadata. The slot is + # initialized from the output buffer to preserve read-modify-write + # semantics (`o += t` in `store_core_outputs`); SROA collapses the + # slot after inlining. + scratch = cgutils.alloca_once(builder, write_ptr.type.pointee) + init_val = builder.load(write_ptr) + init_val.set_metadata("alias.scope", out_alias_sets[output_i]) + init_val.set_metadata("noalias", out_noalias_sets[output_i]) + if output_i in distinct_outputs: + init_val.set_metadata("llvm.access.group", access_group) + builder.store(init_val, scratch) + scratch_outputs.append((scratch, write_ptr, output_i)) + write_ptr = scratch core_shape = ( output_shape[-effective_core_ndim:] if effective_core_ndim > 0 else [] ) @@ -737,9 +826,29 @@ def _wrap_negative_index(idx_val, dim_size, signed): inner_codegen(builder, [*constant_inputs, *input_vals, *output_slices]) - # Close the loops - for loop in loop_stack[::-1]: - loop.__exit__(None, None, None) + for scratch, write_ptr, output_i in scratch_outputs: + out_val = builder.load(scratch) + store = builder.store(out_val, write_ptr) + store.set_metadata("alias.scope", out_alias_sets[output_i]) + store.set_metadata("noalias", out_noalias_sets[output_i]) + if output_i in distinct_outputs: + store.set_metadata("llvm.access.group", access_group) + + # Close the loops. Under a no-dup promise, tag the innermost loop's latch with + # `llvm.loop.parallel_accesses` referencing the access group, so the vectorizer + # treats the tagged RMW accesses as free of loop-carried dependencies. The latch is + # the body block the builder sits in just before `for_range` emits its backedge. + for depth, loop in enumerate(loop_stack[::-1]): + if depth == 0 and access_group is not None: + latch_block = builder.basic_block + loop.__exit__(None, None, None) + parallel_md = mod.add_metadata( + [ir.MetaDataString(mod, "llvm.loop.parallel_accesses"), access_group] + ) + loop_md = mod.add_metadata([parallel_md], self_ref=True) + latch_block.terminator.set_metadata("llvm.loop", loop_md) + else: + loop.__exit__(None, None, None) @numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) @@ -865,7 +974,7 @@ def _vectorized( for k, entry in enumerate(indexed_outputs): if entry is None: continue - sources, source_axis, _mode = entry + sources, source_axis, _mode, *_ = entry for out_idx in sources: write_spec_dict.setdefault(out_idx, []).append((k, source_axis)) # Write target buffers are appended to the outer inputs in ascending output @@ -925,6 +1034,14 @@ def _vectorized( write_idx_set = frozenset( k for k, entry in enumerate(indexed_outputs) if entry is not None ) + # Output positions whose indexed update was flagged distinct-index by the rewriter + # (4th spec field) -> safe to emit the no-dup vectorization promise. + distinct_output_idxs = frozenset( + out_idx + for entry in indexed_outputs + if entry is not None and len(entry) > 3 and entry[3] + for out_idx in entry[0] + ) def codegen(ctx, builder, sig, args): [ @@ -1107,6 +1224,8 @@ def codegen(ctx, builder, sig, args): idx_load_axes=idx_load_axes, idx_bc=idx_broadcastable, output_write_spec=output_write_spec, + inplace=inplace_pattern, + distinct_outputs=distinct_output_idxs, ) return _codegen_return_outputs( diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index b3ff3828b3..4ecd86cd21 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -17,6 +17,7 @@ from pytensor.scalar.basic import Composite from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer +from pytensor.tensor.rewriting.subtensor import _has_unique_indices from pytensor.tensor.shape import Reshape, shape_padright from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -240,16 +241,20 @@ class IndexedElemwise(OpFromGraph): indexed_outputs : tuple of ((tuple[int, ...], int, str) | None) One entry per index array k, parallel to ``indexed_inputs``. ``None`` if index k has no write role. - Otherwise ``(sources, source_axis, mode)``: + Otherwise ``(sources, source_axis, mode, distinct)``: - ``sources``: which Elemwise output positions are written through this index into the update target buffer. - ``source_axis``: which target-array axis is indexed. - ``mode``: ``"inc"`` (accumulate) or ``"set"`` (overwrite). + - ``distinct``: whether the index entries are statically known to be + duplicate-free (an ``inc`` then has no cross-iteration RMW dependency, + so the Numba codegen may emit a loop-vectorization promise). Always + ``False`` for ``set`` (it vectorizes regardless). Examples:: - tgt[idx] += exp(x) → indexed_outputs=[((0,), 0, "inc")] + tgt[idx] += exp(x) → indexed_outputs=[((0,), 0, "inc", False)] """ def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): @@ -301,7 +306,7 @@ def _op_debug_information_IndexedElemwise(op, node): for k, entry in enumerate(op.indexed_outputs): if entry is None: continue - sources, _source_axis, mode = entry + sources, _source_axis, mode, *_ = entry buf_label = f"buf_{buf_counter}" buf_counter += 1 idx_label = f"idx_{k}" @@ -721,16 +726,22 @@ def _has_non_write_clients(out_idx): (tuple(reads), axis) if reads else None for (_, axis), (reads, _) in idx_groups.items() ) - indexed_outputs_spec = tuple( - ( - tuple(writes), - key[1], - "set" if write_targets[writes[0]].op.set_instead_of_inc else "inc", + indexed_outputs_spec_list = [] + for key, (_, writes) in idx_groups.items(): + if not writes: + indexed_outputs_spec_list.append(None) + continue + mode = ( + "set" if write_targets[writes[0]].op.set_instead_of_inc else "inc" ) - if writes - else None - for key, (_, writes) in idx_groups.items() - ) + # A distinct-index `inc` has no cross-iteration read-modify-write + # dependency, so the numba codegen can emit a no-dup vectorization + # promise. `set` already vectorizes without it, so only flag `inc`. + distinct = mode == "inc" and _has_unique_indices(fgraph, key[0]) + indexed_outputs_spec_list.append( + (tuple(writes), key[1], mode, distinct) + ) + indexed_outputs_spec = tuple(indexed_outputs_spec_list) outer_inputs = [] for i, inp in enumerate(fgraph_inputs): diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index 22ffbe5927..816248ee62 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -5,6 +5,7 @@ import pytensor.tensor as pt from pytensor import Mode, function, get_mode +from pytensor.assumptions import assume from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise from pytensor.tensor.subtensor import ( AdvancedIncSubtensor1, @@ -571,6 +572,83 @@ def test_repeated_inc_with_read(self): np.testing.assert_allclose(fn(sv, tv.copy()), fn_u(sv, tv.copy()), rtol=1e-10) +class TestNoDupPromise: + """A distinct-index ``inc`` scatter carries a no-dup vectorization promise: the + rewriter flags it (4th ``indexed_outputs`` field) so the Numba codegen emits + ``llvm.loop.parallel_accesses``, letting the value-compute vectorize despite the + indexed RMW. The flag is gated on statically-known-distinct indices (a constant with + unique entries, or a ``unique_indices`` assumption) and only for ``inc`` (``set`` + vectorizes regardless).""" + + @staticmethod + def _write_spec(fn): + for n in fn.maker.fgraph.toposort(): + if isinstance(n.op, IndexedElemwise): + writes = [e for e in n.op.indexed_outputs if e is not None] + assert len(writes) == 1 + return writes[0] + raise AssertionError("no IndexedElemwise found") + + def test_constant_unique_index_flagged(self): + target = pt.vector("target", shape=(10,)) + x = pt.vector("x", shape=(5,)) + idx = np.array([0, 1, 2, 3, 4], dtype=np.int64) # unique + fn = function([target, x], target[idx].inc(pt.exp(x)), mode=NUMBA_MODE) + assert_fused(fn) + _sources, _axis, mode, distinct = self._write_spec(fn) + assert mode == "inc" and distinct is True + + def test_constant_duplicate_index_not_flagged(self): + target = pt.vector("target", shape=(10,)) + x = pt.vector("x", shape=(5,)) + idx = np.array([0, 0, 1, 2, 3], dtype=np.int64) # has a duplicate + fn = function([target, x], target[idx].inc(pt.exp(x)), mode=NUMBA_MODE) + _sources, _axis, _mode, distinct = self._write_spec(fn) + assert distinct is False + + def test_assumed_unique_index_flagged(self): + target = pt.vector("target", shape=(10,)) + x = pt.vector("x") + idx0 = pt.vector("idx", dtype="int64") + idx = assume(idx0, unique_indices=True) + fn = function([target, idx0, x], target[idx].inc(pt.exp(x)), mode=NUMBA_MODE) + _sources, _axis, mode, distinct = self._write_spec(fn) + assert mode == "inc" and distinct is True + + def test_runtime_index_not_flagged(self): + target = pt.vector("target", shape=(10,)) + x = pt.vector("x") + idx0 = pt.vector("idx", dtype="int64") + fn = function([target, idx0, x], target[idx0].inc(pt.exp(x)), mode=NUMBA_MODE) + _sources, _axis, _mode, distinct = self._write_spec(fn) + assert distinct is False + + def test_set_mode_not_flagged(self): + # `set` vectorizes without the promise, so it is never flagged even when unique. + target = pt.vector("target", shape=(10,)) + x = pt.vector("x") + idx0 = pt.vector("idx", dtype="int64") + idx = assume(idx0, unique_indices=True) + fn = function([target, idx0, x], target[idx].set(pt.exp(x)), mode=NUMBA_MODE) + _sources, _axis, mode, distinct = self._write_spec(fn) + assert mode == "set" and distinct is False + + def test_assumed_unique_correctness(self): + # The promise must not change results (exercises the metadata-emitting codegen). + rng = np.random.default_rng(0) + target = pt.vector("target", shape=(64,)) + x = pt.vector("x", shape=(64,)) + idx0 = pt.vector("idx", dtype="int64") + idx = assume(idx0, unique_indices=True) + fn, fn_u = fused_and_unfused([target, idx0, x], target[idx].inc(pt.exp(x))) + assert_fused(fn) + tv, xv = rng.normal(size=64), rng.normal(size=64) + iv = rng.permutation(64).astype(np.int64) # genuinely unique + np.testing.assert_allclose( + fn(tv.copy(), iv, xv), fn_u(tv.copy(), iv, xv), rtol=1e-10 + ) + + class TestShapeValidation: """Test that mismatched index/input shapes raise runtime errors."""