Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions pytensor/link/numba/dispatch/_llvmlite_self_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Import-time shim giving stock llvmlite the ability to emit self-referential
metadata nodes (``!0 = !{ !0 }``).

Self-referential nodes are how LLVM makes ``!alias.scope``/``!noalias`` domains and
scopes globally unique. llvmlite only supports them once this PR lands, adding
``Module.add_metadata(operands, self_ref=True)``:

https://github.com/numba/llvmlite/pull/895

This shim provides the same ``self_ref`` keyword on older llvmlite so the alias-scope
markers emitted by ``vectorize_codegen`` work without a patched llvmlite. It is a
no-op when the native API is present.
"""

import inspect

from llvmlite import ir
from llvmlite.ir import module as _ll_module
from llvmlite.ir import values as _ll_values


def ensure_self_ref_metadata_support() -> None:
    """Give ``Module.add_metadata`` a ``self_ref=True`` keyword when it lacks one.

    Nothing is changed if ``add_metadata`` already declares a ``self_ref``
    parameter, or if this shim has been installed before. Otherwise
    ``Module.add_metadata`` is replaced by a wrapper that forwards ordinary
    requests to the original implementation and builds a self-referential node
    when ``self_ref=True`` is passed.
    """
    module_cls = _ll_module.Module
    has_native_keyword = (
        "self_ref" in inspect.signature(ir.Module.add_metadata).parameters
    )
    # Short-circuit keeps the original order: signature probe first, marker second.
    if has_native_keyword or getattr(module_cls, "_pytensor_self_ref_patched", False):
        return

    stock_add_metadata = module_cls.add_metadata

    class _SelfRefMDValue(_ll_values.MDValue):
        """Metadata node that lists itself as its own first operand.

        Hashing only covers the operands that follow the self-reference. Such a
        node is commonly an operand of other metadata nodes (for instance an
        ``alias.scope`` set), and hashing a tuple that transitively contains
        the node itself would never terminate. Comparison is by identity, which
        is exactly the uniqueness self-referential nodes are meant to provide.
        """

        def __init__(self, parent, operands, name):
            super().__init__(parent, operands, name)
            tail = tuple(operands)
            self._self_ref_tail = tail
            # Overwrite the operands set by the base class so the node comes first.
            self.operands = (self,) + tail

        def __eq__(self, other):
            return self is other

        def __ne__(self, other):
            return not (self is other)

        def __hash__(self):
            return hash(self._self_ref_tail)

    def add_metadata(self, operands, *, self_ref=False):
        if self_ref:
            if isinstance(operands, (list, tuple)):
                fixed_operands = self._fix_metadata_operands(operands)
                # Name the node after its index in the module's metadata list.
                node_name = str(len(self.metadata))
                return _SelfRefMDValue(self, fixed_operands, name=node_name)
            raise TypeError(
                f"expected a list or tuple of metadata values, got {operands!r}"
            )
        # Ordinary nodes keep going through the untouched llvmlite implementation.
        return stock_add_metadata(self, operands)

    module_cls.add_metadata = add_metadata
    # Marker consulted by the guard above on later calls.
    module_cls._pytensor_self_ref_patched = True
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def impl(*outer_inputs):

return impl

cache_version = 2
cache_version = 3
if scalar_cache_key is None:
key = None
else:
Expand Down
155 changes: 137 additions & 18 deletions pytensor/link/numba/dispatch/vectorize_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numba
import numpy as np
from llvmlite import ir
from llvmlite.ir.values import MDValue
from numba import TypingError, types
from numba.core import cgutils
from numba.core.base import BaseContext
Expand All @@ -17,6 +18,37 @@

from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch._llvmlite_self_ref import (
ensure_self_ref_metadata_support,
)


ensure_self_ref_metadata_support()


class _DistinctEmptyMetadata(MDValue):
    """Operand-less metadata node rendered as ``distinct !{}``.

    Intended for use as an LLVM access group. llvmlite's ``MDValue`` only emits
    *uniqued* ``!{}`` nodes, which LLVM does not accept as access groups: the
    group has to be ``distinct`` so that two function-local accesses are never
    treated as identical (a plain uniqued ``!{}`` crashes the verifier).
    Instances compare and hash by identity, so each loop requesting one obtains
    its own fresh group.
    """

    def __init__(self, parent):
        # The node is named after the current length of the module's metadata list.
        node_index = len(parent.metadata)
        super().__init__(parent, [], name=str(node_index))

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        return not (self is other)

    def descr(self, buf):
        # Emit the textual body of the node; ``buf`` is extended in place.
        buf += ("distinct !{}", "\n")


def encode_literals(literals: Sequence) -> str:
Expand Down Expand Up @@ -102,8 +134,8 @@ def _compute_idx_load_axes(indexed_inputs, indexed_outputs, idx_ndims):
----------
indexed_inputs : tuple of ((tuple[int, ...], int) | None)
Per-index: (source input positions, source axis) or None.
indexed_outputs : tuple of ((tuple[int, ...], int, str) | None)
Per-index: (output positions, axis, mode) or None.
indexed_outputs : tuple of ((tuple[int, ...], int, str, bool) | None)
Per-index: (output positions, axis, mode, distinct) or None.
idx_ndims : tuple of int
Number of dimensions of each index array.
"""
Expand Down Expand Up @@ -144,7 +176,7 @@ def find(x):
for k, entry in enumerate(indexed_outputs):
if entry is not None:
root = find(k)
_, out_axis, _ = entry
_, out_axis, *_ = entry
group_min_axis[root] = min(group_min_axis.get(root, out_axis), out_axis)

return tuple(
Expand Down Expand Up @@ -437,20 +469,46 @@ def make_loop_call(
idx_load_axes: tuple[tuple[int, ...], ...] | None = None,
idx_bc: tuple[tuple[bool, ...], ...] | None = None,
output_write_spec: tuple[tuple[tuple[int, int], ...] | None, ...] | None = None,
inplace: tuple[tuple[int, int], ...] = (),
distinct_outputs: frozenset = frozenset(),
):
safe = (False, False)

n_outputs = len(outputs)

# TODO I think this is better than the noalias attribute
# for the input, but self_ref isn't supported in a released
# llvmlite version yet
# mod = builder.module
# domain = mod.add_metadata([], self_ref=True)
# input_scope = mod.add_metadata([domain], self_ref=True)
# output_scope = mod.add_metadata([domain], self_ref=True)
# input_scope_set = mod.add_metadata([input_scope, output_scope])
# output_scope_set = mod.add_metadata([input_scope, output_scope])
# Scoped noalias metadata: input loads and output stores are tagged with
# alias scopes so LLVM can disambiguate them without runtime overlap checks
# (and without loop versioning). Inputs share one scope (their loads never
# conflict with each other) while each output gets its own, so that every
# access can claim noalias against all buffers it is guaranteed not to
# overlap: PyTensor guarantees distinct output buffers, and that inputs
# don't alias outputs *except* for an input destroyed by an inplace output.
# Loads of a destroyed input are tagged with its output's scope instead:
# they stay MayAlias with that output's stores (LLVM resolves the exact
# overlap through pointer identity, since the output reuses the input's
# array struct) yet are still disambiguated from every other buffer.
mod = builder.module
domain = mod.add_metadata([], self_ref=True)
input_scope = mod.add_metadata([domain], self_ref=True)
output_scopes = [
mod.add_metadata([domain], self_ref=True) for _ in range(n_outputs)
]
input_scope_set = mod.add_metadata([input_scope])
output_scope_set = mod.add_metadata(output_scopes)
out_alias_sets = [mod.add_metadata([scope]) for scope in output_scopes]
out_noalias_sets = [
mod.add_metadata([input_scope, *(s for s in output_scopes if s is not scope)])
for scope in output_scopes
]
destroyed_inputs = {in_idx: out_idx for out_idx, in_idx in inplace}

# When an indexed-update output writes through statically-distinct indices, its
# read-modify-write carries no cross-iteration dependency, so the loop can vectorize
# the value-compute (LLVM scalarizes only the indexed stores). We promise this with an
# access group on those RMW load/stores plus `llvm.loop.parallel_accesses` on the
# innermost latch. The group must be a *distinct* node so it is never uniqued with
# another loop's.
access_group = _DistinctEmptyMetadata(mod) if distinct_outputs else None

zero = ir.Constant(ir.IntType(64), 0)

Expand Down Expand Up @@ -516,6 +574,10 @@ def _wrap_negative_index(idx_val, dim_size, signed):
False,
)
val = builder.load(ptr)
val.set_metadata("alias.scope", input_scope_set)
val.set_metadata("noalias", output_scope_set)
if access_group is not None:
val.set_metadata("llvm.access.group", access_group)
i64 = ir.IntType(64)
if val.type != i64:
if idx_arr_type.dtype.signed:
Expand Down Expand Up @@ -614,8 +676,18 @@ def _wrap_negative_index(idx_val, dim_size, signed):
if core_scalar and core_ndim == 0:
# Retrieve scalar item at index
read_val = builder.load(read_ptr)
# read_val.set_metadata("alias.scope", input_scope_set)
# read_val.set_metadata("noalias", output_scope_set)
destination = destroyed_inputs.get(input_i)
if destination is None:
read_val.set_metadata("alias.scope", input_scope_set)
read_val.set_metadata("noalias", output_scope_set)
else:
read_val.set_metadata("alias.scope", out_alias_sets[destination])
read_val.set_metadata("noalias", out_noalias_sets[destination])
# Every memory access in the loop must join the access group, or LLVM's
# `isAnnotatedParallel` rejects the whole loop (one untagged load voids the
# `llvm.loop.parallel_accesses` promise).
if access_group is not None:
read_val.set_metadata("llvm.access.group", access_group)
else:
# Retrieve array item at index
# This is a streamlined version of Numba's `GUArrayArg.load`.
Expand Down Expand Up @@ -651,6 +723,7 @@ def _wrap_negative_index(idx_val, dim_size, signed):

# Create output slices to pass to inner func
output_slices = []
scratch_outputs = []
for output_i, (out, out_type, out_bc) in enumerate(
zip(outputs, output_types, output_bc, strict=True)
):
Expand Down Expand Up @@ -710,6 +783,22 @@ def _wrap_negative_index(idx_val, dim_size, signed):
dtype=out_type.dtype, ndim=effective_core_ndim, layout=out_type.layout
)
write_array = context.make_array(write_array_type)(context, builder)
if effective_core_ndim == 0:
# Redirect the 0-d output slice through a stack slot so the store
# into the real output buffer happens below, after the core call,
# where it can carry the alias scope metadata. The slot is
# initialized from the output buffer to preserve read-modify-write
# semantics (`o += t` in `store_core_outputs`); SROA collapses the
# slot after inlining.
scratch = cgutils.alloca_once(builder, write_ptr.type.pointee)
init_val = builder.load(write_ptr)
init_val.set_metadata("alias.scope", out_alias_sets[output_i])
init_val.set_metadata("noalias", out_noalias_sets[output_i])
if output_i in distinct_outputs:
init_val.set_metadata("llvm.access.group", access_group)
builder.store(init_val, scratch)
scratch_outputs.append((scratch, write_ptr, output_i))
write_ptr = scratch
core_shape = (
output_shape[-effective_core_ndim:] if effective_core_ndim > 0 else []
)
Expand Down Expand Up @@ -737,9 +826,29 @@ def _wrap_negative_index(idx_val, dim_size, signed):

inner_codegen(builder, [*constant_inputs, *input_vals, *output_slices])

# Close the loops
for loop in loop_stack[::-1]:
loop.__exit__(None, None, None)
for scratch, write_ptr, output_i in scratch_outputs:
out_val = builder.load(scratch)
store = builder.store(out_val, write_ptr)
store.set_metadata("alias.scope", out_alias_sets[output_i])
store.set_metadata("noalias", out_noalias_sets[output_i])
if output_i in distinct_outputs:
store.set_metadata("llvm.access.group", access_group)

# Close the loops. Under a no-dup promise, tag the innermost loop's latch with
# `llvm.loop.parallel_accesses` referencing the access group, so the vectorizer
# treats the tagged RMW accesses as free of loop-carried dependencies. The latch is
# the body block the builder sits in just before `for_range` emits its backedge.
for depth, loop in enumerate(loop_stack[::-1]):
if depth == 0 and access_group is not None:
latch_block = builder.basic_block
loop.__exit__(None, None, None)
parallel_md = mod.add_metadata(
[ir.MetaDataString(mod, "llvm.loop.parallel_accesses"), access_group]
)
loop_md = mod.add_metadata([parallel_md], self_ref=True)
latch_block.terminator.set_metadata("llvm.loop", loop_md)
else:
loop.__exit__(None, None, None)


@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
Expand Down Expand Up @@ -865,7 +974,7 @@ def _vectorized(
for k, entry in enumerate(indexed_outputs):
if entry is None:
continue
sources, source_axis, _mode = entry
sources, source_axis, _mode, *_ = entry
for out_idx in sources:
write_spec_dict.setdefault(out_idx, []).append((k, source_axis))
# Write target buffers are appended to the outer inputs in ascending output
Expand Down Expand Up @@ -925,6 +1034,14 @@ def _vectorized(
write_idx_set = frozenset(
k for k, entry in enumerate(indexed_outputs) if entry is not None
)
# Output positions whose indexed update was flagged distinct-index by the rewriter
# (4th spec field) -> safe to emit the no-dup vectorization promise.
distinct_output_idxs = frozenset(
out_idx
for entry in indexed_outputs
if entry is not None and len(entry) > 3 and entry[3]
for out_idx in entry[0]
)

def codegen(ctx, builder, sig, args):
[
Expand Down Expand Up @@ -1107,6 +1224,8 @@ def codegen(ctx, builder, sig, args):
idx_load_axes=idx_load_axes,
idx_bc=idx_broadcastable,
output_write_spec=output_write_spec,
inplace=inplace_pattern,
distinct_outputs=distinct_output_idxs,
)

return _codegen_return_outputs(
Expand Down
35 changes: 23 additions & 12 deletions pytensor/tensor/rewriting/indexed_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pytensor.scalar.basic import Composite
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer
from pytensor.tensor.rewriting.subtensor import _has_unique_indices
from pytensor.tensor.shape import Reshape, shape_padright
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
Expand Down Expand Up @@ -240,16 +241,20 @@ class IndexedElemwise(OpFromGraph):
indexed_outputs : tuple of ((tuple[int, ...], int, str, bool) | None)
One entry per index array k, parallel to ``indexed_inputs``.
``None`` if index k has no write role.
Otherwise ``(sources, source_axis, mode)``:
Otherwise ``(sources, source_axis, mode, distinct)``:

- ``sources``: which Elemwise output positions are written
through this index into the update target buffer.
- ``source_axis``: which target-array axis is indexed.
- ``mode``: ``"inc"`` (accumulate) or ``"set"`` (overwrite).
- ``distinct``: whether the index entries are statically known to be
duplicate-free (an ``inc`` then has no cross-iteration RMW dependency,
so the Numba codegen may emit a loop-vectorization promise). Always
``False`` for ``set`` (it vectorizes regardless).

Examples::

tgt[idx] += exp(x) → indexed_outputs=[((0,), 0, "inc")]
tgt[idx] += exp(x) → indexed_outputs=[((0,), 0, "inc", False)]
"""

def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs):
Expand Down Expand Up @@ -301,7 +306,7 @@ def _op_debug_information_IndexedElemwise(op, node):
for k, entry in enumerate(op.indexed_outputs):
if entry is None:
continue
sources, _source_axis, mode = entry
sources, _source_axis, mode, *_ = entry
buf_label = f"buf_{buf_counter}"
buf_counter += 1
idx_label = f"idx_{k}"
Expand Down Expand Up @@ -721,16 +726,22 @@ def _has_non_write_clients(out_idx):
(tuple(reads), axis) if reads else None
for (_, axis), (reads, _) in idx_groups.items()
)
indexed_outputs_spec = tuple(
(
tuple(writes),
key[1],
"set" if write_targets[writes[0]].op.set_instead_of_inc else "inc",
indexed_outputs_spec_list = []
for key, (_, writes) in idx_groups.items():
if not writes:
indexed_outputs_spec_list.append(None)
continue
mode = (
"set" if write_targets[writes[0]].op.set_instead_of_inc else "inc"
)
if writes
else None
for key, (_, writes) in idx_groups.items()
)
# A distinct-index `inc` has no cross-iteration read-modify-write
# dependency, so the numba codegen can emit a no-dup vectorization
# promise. `set` already vectorizes without it, so only flag `inc`.
distinct = mode == "inc" and _has_unique_indices(fgraph, key[0])
indexed_outputs_spec_list.append(
(tuple(writes), key[1], mode, distinct)
)
indexed_outputs_spec = tuple(indexed_outputs_spec_list)

outer_inputs = []
for i, inp in enumerate(fgraph_inputs):
Expand Down
Loading
Loading