From 4e475394735d0de282e18c98798eaf685eab9686 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 20 Jun 2026 19:27:36 +0200 Subject: [PATCH 1/2] Emit alias-scope noalias metadata in vectorize_codegen Port the scratch-slot store-redirection + scoped alias.scope/noalias metadata from 605789e (reduction_fusion_with_scope_markers) so this branch is self-contained regardless of PR merge order. The scratch slot moves each 0-d output's real store into make_loop_call (SROA collapses it back onto the inner 'o += t'), giving a load/store instruction we control -- the hook the no-dup promise needs to attach llvm.access.group / parallel_accesses metadata. Adapted from 605789e: dropped the unused string_codegen import (this branch's store_core_outputs still uses compile_numba_function_src). --- .../link/numba/dispatch/_llvmlite_self_ref.py | 66 +++++++++++++++++ .../link/numba/dispatch/vectorize_codegen.py | 74 ++++++++++++++++--- 2 files changed, 129 insertions(+), 11 deletions(-) create mode 100644 pytensor/link/numba/dispatch/_llvmlite_self_ref.py diff --git a/pytensor/link/numba/dispatch/_llvmlite_self_ref.py b/pytensor/link/numba/dispatch/_llvmlite_self_ref.py new file mode 100644 index 0000000000..764b584ceb --- /dev/null +++ b/pytensor/link/numba/dispatch/_llvmlite_self_ref.py @@ -0,0 +1,66 @@ +"""Import-time shim giving stock llvmlite the ability to emit self-referential +metadata nodes (``!0 = !{ !0 }``). + +Self-referential nodes are how LLVM makes ``!alias.scope``/``!noalias`` domains and +scopes globally unique. llvmlite only supports them once this PR lands, adding +``Module.add_metadata(operands, self_ref=True)``: + + https://github.com/numba/llvmlite/pull/895 + +This shim provides the same ``self_ref`` keyword on older llvmlite so the alias-scope +markers emitted by ``vectorize_codegen`` work without a patched llvmlite. It is a +no-op when the native API is present. +""" + +import inspect + +from llvmlite import ir +from llvmlite.ir import module as _ll_module +from llvmlite.ir import values as _ll_values + + +def ensure_self_ref_metadata_support() -> None: + """Patch ``Module.add_metadata`` to accept ``self_ref=True`` if it doesn't already.""" + if "self_ref" in inspect.signature(ir.Module.add_metadata).parameters: + return + if getattr(_ll_module.Module, "_pytensor_self_ref_patched", False): + return + + base_add_metadata = _ll_module.Module.add_metadata + + class _SelfRefMDValue(_ll_values.MDValue): + """Metadata node whose first operand is itself. + + The self-reference is kept out of the hashed/compared state: a self-ref + scope is routinely used as an operand of another metadata node (e.g. an + ``alias.scope`` set), and hashing a tuple that transitively contains the + node would otherwise recurse forever. Equality falls back to identity, + matching the uniqueness guarantee self-referential nodes exist to provide. + """ + + def __init__(self, parent, operands, name): + super().__init__(parent, operands, name) + self._self_ref_tail = tuple(operands) + self.operands = (self, *self._self_ref_tail) + + def __hash__(self): + return hash(self._self_ref_tail) + + def __eq__(self, other): + return self is other + + def __ne__(self, other): + return self is not other + + def add_metadata(self, operands, *, self_ref=False): + if not self_ref: + return base_add_metadata(self, operands) + if not isinstance(operands, list | tuple): + raise TypeError( + f"expected a list or tuple of metadata values, got {operands!r}" + ) + operands = self._fix_metadata_operands(operands) + return _SelfRefMDValue(self, operands, name=str(len(self.metadata))) + + _ll_module.Module.add_metadata = add_metadata + _ll_module.Module._pytensor_self_ref_patched = True diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index d1efdf037d..0905d92011 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -17,6 +17,12 @@ from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch._llvmlite_self_ref import ( + ensure_self_ref_metadata_support, +) + + +ensure_self_ref_metadata_support() def encode_literals(literals: Sequence) -> str: @@ -437,20 +443,37 @@ def make_loop_call( idx_load_axes: tuple[tuple[int, ...], ...] | None = None, idx_bc: tuple[tuple[bool, ...], ...] | None = None, output_write_spec: tuple[tuple[tuple[int, int], ...] | None, ...] | None = None, + inplace: tuple[tuple[int, int], ...] = (), ): safe = (False, False) n_outputs = len(outputs) - # TODO I think this is better than the noalias attribute - # for the input, but self_ref isn't supported in a released - # llvmlite version yet - # mod = builder.module - # domain = mod.add_metadata([], self_ref=True) - # input_scope = mod.add_metadata([domain], self_ref=True) - # output_scope = mod.add_metadata([domain], self_ref=True) - # input_scope_set = mod.add_metadata([input_scope, output_scope]) - # output_scope_set = mod.add_metadata([input_scope, output_scope]) + # Scoped noalias metadata: input loads and output stores are tagged with + # alias scopes so LLVM can disambiguate them without runtime overlap checks + # (and without loop versioning). Inputs share one scope (their loads never + # conflict with each other) while each output gets its own, so that every + # access can claim noalias against all buffers it is guaranteed not to + # overlap: PyTensor guarantees distinct output buffers, and that inputs + # don't alias outputs *except* for an input destroyed by an inplace output. + # Loads of a destroyed input are tagged with its output's scope instead: + # they stay MayAlias with that output's stores (LLVM resolves the exact + # overlap through pointer identity, since the output reuses the input's + # array struct) yet are still disambiguated from every other buffer. + mod = builder.module + domain = mod.add_metadata([], self_ref=True) + input_scope = mod.add_metadata([domain], self_ref=True) + output_scopes = [ + mod.add_metadata([domain], self_ref=True) for _ in range(n_outputs) + ] + input_scope_set = mod.add_metadata([input_scope]) + output_scope_set = mod.add_metadata(output_scopes) + out_alias_sets = [mod.add_metadata([scope]) for scope in output_scopes] + out_noalias_sets = [ + mod.add_metadata([input_scope, *(s for s in output_scopes if s is not scope)]) + for scope in output_scopes + ] + destroyed_inputs = {in_idx: out_idx for out_idx, in_idx in inplace} zero = ir.Constant(ir.IntType(64), 0) @@ -516,6 +539,8 @@ def _wrap_negative_index(idx_val, dim_size, signed): False, ) val = builder.load(ptr) + val.set_metadata("alias.scope", input_scope_set) + val.set_metadata("noalias", output_scope_set) i64 = ir.IntType(64) if val.type != i64: if idx_arr_type.dtype.signed: @@ -614,8 +639,13 @@ def _wrap_negative_index(idx_val, dim_size, signed): if core_scalar and core_ndim == 0: # Retrive scalar item at index read_val = builder.load(read_ptr) - # read_val.set_metadata("alias.scope", input_scope_set) - # read_val.set_metadata("noalias", output_scope_set) + destination = destroyed_inputs.get(input_i) + if destination is None: + read_val.set_metadata("alias.scope", input_scope_set) + read_val.set_metadata("noalias", output_scope_set) + else: + read_val.set_metadata("alias.scope", out_alias_sets[destination]) + read_val.set_metadata("noalias", out_noalias_sets[destination]) else: # Retrieve array item at index # This is a streamlined version of Numba's `GUArrayArg.load`. @@ -651,6 +681,7 @@ def _wrap_negative_index(idx_val, dim_size, signed): # Create output slices to pass to inner func output_slices = [] + scratch_outputs = [] for output_i, (out, out_type, out_bc) in enumerate( zip(outputs, output_types, output_bc, strict=True) ): @@ -710,6 +741,20 @@ def _wrap_negative_index(idx_val, dim_size, signed): dtype=out_type.dtype, ndim=effective_core_ndim, layout=out_type.layout ) write_array = context.make_array(write_array_type)(context, builder) + if effective_core_ndim == 0: + # Redirect the 0-d output slice through a stack slot so the store + # into the real output buffer happens below, after the core call, + # where it can carry the alias scope metadata. The slot is + # initialized from the output buffer to preserve read-modify-write + # semantics (`o += t` in `store_core_outputs`); SROA collapses the + # slot after inlining. + scratch = cgutils.alloca_once(builder, write_ptr.type.pointee) + init_val = builder.load(write_ptr) + init_val.set_metadata("alias.scope", out_alias_sets[output_i]) + init_val.set_metadata("noalias", out_noalias_sets[output_i]) + builder.store(init_val, scratch) + scratch_outputs.append((scratch, write_ptr, output_i)) + write_ptr = scratch core_shape = ( output_shape[-effective_core_ndim:] if effective_core_ndim > 0 else [] ) @@ -737,6 +782,12 @@ def _wrap_negative_index(idx_val, dim_size, signed): inner_codegen(builder, [*constant_inputs, *input_vals, *output_slices]) + for scratch, write_ptr, output_i in scratch_outputs: + out_val = builder.load(scratch) + store = builder.store(out_val, write_ptr) + store.set_metadata("alias.scope", out_alias_sets[output_i]) + store.set_metadata("noalias", out_noalias_sets[output_i]) + # Close the loops for loop in loop_stack[::-1]: loop.__exit__(None, None, None) @@ -1107,6 +1158,7 @@ def codegen(ctx, builder, sig, args): idx_load_axes=idx_load_axes, idx_bc=idx_broadcastable, output_write_spec=output_write_spec, + inplace=inplace_pattern, ) return _codegen_return_outputs( From db76610a61a233184827a2df7d0cee4e58744aa1 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 20 Jun 2026 20:04:35 +0200 Subject: [PATCH 2/2] Numba: no-dup vectorization promise for distinct-index inc-scatter Fusing a value-compute into an `inc`-scatter normally scalarizes the whole loop: the read-modify-write on out[idx[i]] is a possible cross-iteration dependency the LoopVectorizer can't rule out. When the indices are statically known to be distinct there is no such dependency, so we promise it to LLVM with an access group on every loop memory op plus `llvm.loop.parallel_accesses` on the latch -- the value-compute then vectorizes (LLVM scalarizes only the indexed stores, which AVX2 has no scatter for). - FuseIndexedElemwise gates the promise on `_has_unique_indices` (a constant with unique entries, or a `unique_indices` assumption), only for `inc` (`set` already vectorizes), recording it as a 4th `indexed_outputs` field that flows into the cache key. - make_loop_call emits a `distinct !{}` access group (a small MDValue subclass, since llvmlite can't emit distinct nodes and a uniqued !{} crashes the verifier) and tags the indexed RMW load/store, the index loads, and the input loads -- every memory op must join the group or LLVM's isAnnotatedParallel rejects the loop. Builds on the scratch-slot store redirection (prior commit). Validated: arithmetic value-compute goes 0 -> 3 packed ops with the promise; results unchanged with/without it. The transcendental win composes once a vector math library is wired (orthogonal veclib work). --- pytensor/link/numba/dispatch/elemwise.py | 2 +- .../link/numba/dispatch/vectorize_codegen.py | 83 +++++++++++++++++-- pytensor/tensor/rewriting/indexed_elemwise.py | 35 +++++--- tests/link/numba/test_indexed_elemwise.py | 78 +++++++++++++++++ 4 files changed, 177 insertions(+), 21 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 1383ee1df1..ea8bfb45e9 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -794,7 +794,7 @@ def impl(*outer_inputs): return impl - cache_version = 2 + cache_version = 3 if scalar_cache_key is None: key = None else: diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 0905d92011..2af0a0e683 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -9,6 +9,7 @@ import numba import numpy as np from llvmlite import ir +from llvmlite.ir.values import MDValue from numba import TypingError, types from numba.core import cgutils from numba.core.base import BaseContext @@ -25,6 +26,31 @@ ensure_self_ref_metadata_support() +class _DistinctEmptyMetadata(MDValue): + """A ``distinct !{}`` metadata node, usable as an LLVM access group. + + llvmlite's ``MDValue`` only emits *uniqued* ``!{}`` nodes, which LLVM rejects as + access groups: an access group must be ``distinct`` so two function-local accesses + are never considered identical (a plain uniqued ``!{}`` crashes the verifier). Each + instance has its own identity, so every loop that asks for one gets a fresh group. + """ + + def __init__(self, parent): + super().__init__(parent, [], name=str(len(parent.metadata))) + + def descr(self, buf): + buf += ("distinct !{}", "\n") + + def __eq__(self, other): + return self is other + + def __ne__(self, other): + return self is not other + + def __hash__(self): + return id(self) + + def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() @@ -108,8 +134,8 @@ def _compute_idx_load_axes(indexed_inputs, indexed_outputs, idx_ndims): ---------- indexed_inputs : tuple of ((tuple[int, ...], int) | None) Per-index: (source input positions, source axis) or None. - indexed_outputs : tuple of ((tuple[int, ...], int, str) | None) - Per-index: (output positions, axis, mode) or None. + indexed_outputs : tuple of ((tuple[int, ...], int, str, bool) | None) + Per-index: (output positions, axis, mode, distinct) or None. idx_ndims : tuple of int Number of dimensions of each index array. """ @@ -150,7 +176,7 @@ def find(x): for k, entry in enumerate(indexed_outputs): if entry is not None: root = find(k) - _, out_axis, _ = entry + _, out_axis, *_ = entry group_min_axis[root] = min(group_min_axis.get(root, out_axis), out_axis) return tuple( @@ -444,6 +470,7 @@ def make_loop_call( idx_bc: tuple[tuple[bool, ...], ...] | None = None, output_write_spec: tuple[tuple[tuple[int, int], ...] | None, ...] | None = None, inplace: tuple[tuple[int, int], ...] = (), + distinct_outputs: frozenset = frozenset(), ): safe = (False, False) @@ -475,6 +502,14 @@ def make_loop_call( ] destroyed_inputs = {in_idx: out_idx for out_idx, in_idx in inplace} + # When an indexed-update output writes through statically-distinct indices, its + # read-modify-write carries no cross-iteration dependency, so the loop can vectorize + # the value-compute (LLVM scalarizes only the indexed stores). We promise this with an + # access group on those RMW load/stores plus `llvm.loop.parallel_accesses` on the + # innermost latch. The group must be a *distinct* node so it is never uniqued with + # another loop's. + access_group = _DistinctEmptyMetadata(mod) if distinct_outputs else None + zero = ir.Constant(ir.IntType(64), 0) def _wrap_negative_index(idx_val, dim_size, signed): @@ -541,6 +576,8 @@ def _wrap_negative_index(idx_val, dim_size, signed): val = builder.load(ptr) val.set_metadata("alias.scope", input_scope_set) val.set_metadata("noalias", output_scope_set) + if access_group is not None: + val.set_metadata("llvm.access.group", access_group) i64 = ir.IntType(64) if val.type != i64: if idx_arr_type.dtype.signed: @@ -646,6 +683,11 @@ def _wrap_negative_index(idx_val, dim_size, signed): else: read_val.set_metadata("alias.scope", out_alias_sets[destination]) read_val.set_metadata("noalias", out_noalias_sets[destination]) + # Every memory access in the loop must join the access group, or LLVM's + # `isAnnotatedParallel` rejects the whole loop (one untagged load voids the + # `llvm.loop.parallel_accesses` promise). + if access_group is not None: + read_val.set_metadata("llvm.access.group", access_group) else: # Retrieve array item at index # This is a streamlined version of Numba's `GUArrayArg.load`. @@ -752,6 +794,8 @@ def _wrap_negative_index(idx_val, dim_size, signed): init_val = builder.load(write_ptr) init_val.set_metadata("alias.scope", out_alias_sets[output_i]) init_val.set_metadata("noalias", out_noalias_sets[output_i]) + if output_i in distinct_outputs: + init_val.set_metadata("llvm.access.group", access_group) builder.store(init_val, scratch) scratch_outputs.append((scratch, write_ptr, output_i)) write_ptr = scratch @@ -787,10 +831,24 @@ def _wrap_negative_index(idx_val, dim_size, signed): store = builder.store(out_val, write_ptr) store.set_metadata("alias.scope", out_alias_sets[output_i]) store.set_metadata("noalias", out_noalias_sets[output_i]) - - # Close the loops - for loop in loop_stack[::-1]: - loop.__exit__(None, None, None) + if output_i in distinct_outputs: + store.set_metadata("llvm.access.group", access_group) + + # Close the loops. Under a no-dup promise, tag the innermost loop's latch with + # `llvm.loop.parallel_accesses` referencing the access group, so the vectorizer + # treats the tagged RMW accesses as free of loop-carried dependencies. The latch is + # the body block the builder sits in just before `for_range` emits its backedge. + for depth, loop in enumerate(loop_stack[::-1]): + if depth == 0 and access_group is not None: + latch_block = builder.basic_block + loop.__exit__(None, None, None) + parallel_md = mod.add_metadata( + [ir.MetaDataString(mod, "llvm.loop.parallel_accesses"), access_group] + ) + loop_md = mod.add_metadata([parallel_md], self_ref=True) + latch_block.terminator.set_metadata("llvm.loop", loop_md) + else: + loop.__exit__(None, None, None) @numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) @@ -916,7 +974,7 @@ def _vectorized( for k, entry in enumerate(indexed_outputs): if entry is None: continue - sources, source_axis, _mode = entry + sources, source_axis, _mode, *_ = entry for out_idx in sources: write_spec_dict.setdefault(out_idx, []).append((k, source_axis)) # Write target buffers are appended to the outer inputs in ascending output @@ -976,6 +1034,14 @@ def _vectorized( write_idx_set = frozenset( k for k, entry in enumerate(indexed_outputs) if entry is not None ) + # Output positions whose indexed update was flagged distinct-index by the rewriter + # (4th spec field) -> safe to emit the no-dup vectorization promise. + distinct_output_idxs = frozenset( + out_idx + for entry in indexed_outputs + if entry is not None and len(entry) > 3 and entry[3] + for out_idx in entry[0] + ) def codegen(ctx, builder, sig, args): [ @@ -1159,6 +1225,7 @@ def codegen(ctx, builder, sig, args): idx_bc=idx_broadcastable, output_write_spec=output_write_spec, inplace=inplace_pattern, + distinct_outputs=distinct_output_idxs, ) return _codegen_return_outputs( diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index b3ff3828b3..4ecd86cd21 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -17,6 +17,7 @@ from pytensor.scalar.basic import Composite from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer +from pytensor.tensor.rewriting.subtensor import _has_unique_indices from pytensor.tensor.shape import Reshape, shape_padright from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -240,16 +241,20 @@ class IndexedElemwise(OpFromGraph): indexed_outputs : tuple of ((tuple[int, ...], int, str) | None) One entry per index array k, parallel to ``indexed_inputs``. ``None`` if index k has no write role. - Otherwise ``(sources, source_axis, mode)``: + Otherwise ``(sources, source_axis, mode, distinct)``: - ``sources``: which Elemwise output positions are written through this index into the update target buffer. - ``source_axis``: which target-array axis is indexed. - ``mode``: ``"inc"`` (accumulate) or ``"set"`` (overwrite). + - ``distinct``: whether the index entries are statically known to be + duplicate-free (an ``inc`` then has no cross-iteration RMW dependency, + so the Numba codegen may emit a loop-vectorization promise). Always + ``False`` for ``set`` (it vectorizes regardless). Examples:: - tgt[idx] += exp(x) → indexed_outputs=[((0,), 0, "inc")] + tgt[idx] += exp(x) → indexed_outputs=[((0,), 0, "inc", False)] """ def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): @@ -301,7 +306,7 @@ def _op_debug_information_IndexedElemwise(op, node): for k, entry in enumerate(op.indexed_outputs): if entry is None: continue - sources, _source_axis, mode = entry + sources, _source_axis, mode, *_ = entry buf_label = f"buf_{buf_counter}" buf_counter += 1 idx_label = f"idx_{k}" @@ -721,16 +726,22 @@ def _has_non_write_clients(out_idx): (tuple(reads), axis) if reads else None for (_, axis), (reads, _) in idx_groups.items() ) - indexed_outputs_spec = tuple( - ( - tuple(writes), - key[1], - "set" if write_targets[writes[0]].op.set_instead_of_inc else "inc", + indexed_outputs_spec_list = [] + for key, (_, writes) in idx_groups.items(): + if not writes: + indexed_outputs_spec_list.append(None) + continue + mode = ( + "set" if write_targets[writes[0]].op.set_instead_of_inc else "inc" ) - if writes - else None - for key, (_, writes) in idx_groups.items() - ) + # A distinct-index `inc` has no cross-iteration read-modify-write + # dependency, so the numba codegen can emit a no-dup vectorization + # promise. `set` already vectorizes without it, so only flag `inc`. + distinct = mode == "inc" and _has_unique_indices(fgraph, key[0]) + indexed_outputs_spec_list.append( + (tuple(writes), key[1], mode, distinct) + ) + indexed_outputs_spec = tuple(indexed_outputs_spec_list) outer_inputs = [] for i, inp in enumerate(fgraph_inputs): diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index 22ffbe5927..816248ee62 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -5,6 +5,7 @@ import pytensor.tensor as pt from pytensor import Mode, function, get_mode +from pytensor.assumptions import assume from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise from pytensor.tensor.subtensor import ( AdvancedIncSubtensor1, @@ -571,6 +572,83 @@ def test_repeated_inc_with_read(self): np.testing.assert_allclose(fn(sv, tv.copy()), fn_u(sv, tv.copy()), rtol=1e-10) +class TestNoDupPromise: + """A distinct-index ``inc`` scatter carries a no-dup vectorization promise: the + rewriter flags it (4th ``indexed_outputs`` field) so the Numba codegen emits + ``llvm.loop.parallel_accesses``, letting the value-compute vectorize despite the + indexed RMW. The flag is gated on statically-known-distinct indices (a constant with + unique entries, or a ``unique_indices`` assumption) and only for ``inc`` (``set`` + vectorizes regardless).""" + + @staticmethod + def _write_spec(fn): + for n in fn.maker.fgraph.toposort(): + if isinstance(n.op, IndexedElemwise): + writes = [e for e in n.op.indexed_outputs if e is not None] + assert len(writes) == 1 + return writes[0] + raise AssertionError("no IndexedElemwise found") + + def test_constant_unique_index_flagged(self): + target = pt.vector("target", shape=(10,)) + x = pt.vector("x", shape=(5,)) + idx = np.array([0, 1, 2, 3, 4], dtype=np.int64) # unique + fn = function([target, x], target[idx].inc(pt.exp(x)), mode=NUMBA_MODE) + assert_fused(fn) + _sources, _axis, mode, distinct = self._write_spec(fn) + assert mode == "inc" and distinct is True + + def test_constant_duplicate_index_not_flagged(self): + target = pt.vector("target", shape=(10,)) + x = pt.vector("x", shape=(5,)) + idx = np.array([0, 0, 1, 2, 3], dtype=np.int64) # has a duplicate + fn = function([target, x], target[idx].inc(pt.exp(x)), mode=NUMBA_MODE) + _sources, _axis, _mode, distinct = self._write_spec(fn) + assert distinct is False + + def test_assumed_unique_index_flagged(self): + target = pt.vector("target", shape=(10,)) + x = pt.vector("x") + idx0 = pt.vector("idx", dtype="int64") + idx = assume(idx0, unique_indices=True) + fn = function([target, idx0, x], target[idx].inc(pt.exp(x)), mode=NUMBA_MODE) + _sources, _axis, mode, distinct = self._write_spec(fn) + assert mode == "inc" and distinct is True + + def test_runtime_index_not_flagged(self): + target = pt.vector("target", shape=(10,)) + x = pt.vector("x") + idx0 = pt.vector("idx", dtype="int64") + fn = function([target, idx0, x], target[idx0].inc(pt.exp(x)), mode=NUMBA_MODE) + _sources, _axis, _mode, distinct = self._write_spec(fn) + assert distinct is False + + def test_set_mode_not_flagged(self): + # `set` vectorizes without the promise, so it is never flagged even when unique. + target = pt.vector("target", shape=(10,)) + x = pt.vector("x") + idx0 = pt.vector("idx", dtype="int64") + idx = assume(idx0, unique_indices=True) + fn = function([target, idx0, x], target[idx].set(pt.exp(x)), mode=NUMBA_MODE) + _sources, _axis, mode, distinct = self._write_spec(fn) + assert mode == "set" and distinct is False + + def test_assumed_unique_correctness(self): + # The promise must not change results (exercises the metadata-emitting codegen). + rng = np.random.default_rng(0) + target = pt.vector("target", shape=(64,)) + x = pt.vector("x", shape=(64,)) + idx0 = pt.vector("idx", dtype="int64") + idx = assume(idx0, unique_indices=True) + fn, fn_u = fused_and_unfused([target, idx0, x], target[idx].inc(pt.exp(x))) + assert_fused(fn) + tv, xv = rng.normal(size=64), rng.normal(size=64) + iv = rng.permutation(64).astype(np.int64) # genuinely unique + np.testing.assert_allclose( + fn(tv.copy(), iv, xv), fn_u(tv.copy(), iv, xv), rtol=1e-10 + ) + + class TestShapeValidation: """Test that mismatched index/input shapes raise runtime errors."""