Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytensor/assumptions/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
NotScalarConstantError,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.subtensor import _is_provably_positive
from pytensor.tensor.subtensor import is_provably_positive


def elemwise_preserves_zero_pattern(
Expand Down Expand Up @@ -76,7 +76,7 @@ def _not_singleton_matrix(var) -> bool:
return [FactState.UNKNOWN]
# 0 ** p == 0 for p > 0, so a provably-positive exponent (scalar or
# elementwise matrix) preserves the base's zero pattern.
return true_if(_is_provably_positive(node.inputs[1]))
return true_if(is_provably_positive(node.inputs[1]))

if isinstance(scalar_op, UnaryScalarOp) and scalar_op.preserves_zero:
return true_if(input_states[0] is FactState.TRUE)
Expand Down
6 changes: 3 additions & 3 deletions pytensor/assumptions/positive_definite.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SolveBilinearDiscreteLyapunov,
)
from pytensor.tensor.math import Dot
from pytensor.tensor.subtensor import Subtensor, _is_provably_positive
from pytensor.tensor.subtensor import Subtensor, is_provably_positive


register_assumption(POSITIVE_DEFINITE, Eye)(eye_identity_rule)
Expand All @@ -34,7 +34,7 @@ def _alloc_diag(key, op, feature, fgraph, node, input_states):
return [FactState.FALSE]

[diag_values] = node.inputs
return true_if(_is_provably_positive(diag_values))
return true_if(is_provably_positive(diag_values))


register_assumption(POSITIVE_DEFINITE, BlockDiagonal)(all_inputs_have_key)
Expand Down Expand Up @@ -74,7 +74,7 @@ def _elemwise(key, op, feature, fgraph, node, input_states):
# Scaling a PD matrix by a positive scalar keeps it PD. The factor
# must be constant across the matrix axes (both broadcastable);
# per-batch variation is fine since every batch slice stays PD.
if not (all(inp.type.broadcastable[-2:]) and _is_provably_positive(inp)):
if not (all(inp.type.broadcastable[-2:]) and is_provably_positive(inp)):
continue
other_inputs = [node.inputs[j] for j in range(len(node.inputs)) if j != i]
if other_inputs and all(
Expand Down
77 changes: 77 additions & 0 deletions pytensor/tensor/constant_props.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Cached static predicates over the *data* of constant variables.

Each helper memoizes its result on the variable's ``tag`` so repeated rewrite
passes do not re-scan the same constant. These only inspect constant data, so
this stays a dependency-free leaf module (graph-walking sign analysis such as
``is_provably_positive`` lives in ``subtensor`` with the ops it recurses over).
"""

import numpy as np

from pytensor.graph.basic import Constant


def constant_is_all_negative(var) -> bool:
    """Tell whether ``var`` is a constant holding only negative entries.

    Anything that is not a ``Constant`` yields ``False``. For constants, the
    answer is computed once from ``var.data`` (every entry compared against
    zero and reduced with ``np.all``) and then memoized as
    ``var.tag.all_negative`` so later calls skip the scan.

    Note that ``np.all`` over an empty array is ``True``, so an empty constant
    counts as all-negative.
    """
    if isinstance(var, Constant):
        memo: bool | None = getattr(var.tag, "all_negative", None)
        if memo is None:
            # First query for this constant: scan the data and remember it.
            memo = bool(np.all(np.asarray(var.data) < 0))
            var.tag.all_negative = memo
        return memo
    return False


def constant_indices_are_unique(idx) -> bool:
    """Report whether a constant index is free of duplicate entries.

    Non-constants return ``False``. Boolean index arrays and arrays holding
    at most one element are unique by construction. Larger integer arrays
    that contain both non-negative and negative entries are reported as *not*
    unique, because a negative index can alias a non-negative one; otherwise
    the entries are compared against their de-duplicated set.

    The outcome is stored on ``idx.tag.unique_indices`` and reused on later
    calls.
    """
    if not isinstance(idx, Constant):
        return False

    memo = getattr(idx.tag, "unique_indices", None)
    if memo is not None:
        return bool(memo)

    values = np.asarray(idx.data)
    if values.dtype == bool or values.size <= 1:
        # Masks, scalars and single-element arrays cannot repeat an entry.
        unique = True
    elif (values >= 0).any() and (values < 0).any():
        # Mixed signs may address the same position twice; stay conservative.
        unique = False
    else:
        unique = np.unique(values).size == values.size

    idx.tag.unique_indices = unique
    return unique


def constant_is_arange(idx) -> tuple[int, int, int] | None:
    """Match ``idx`` to ``np.arange(offset, offset + d * step, step)``.

    Parameters
    ----------
    idx
        Candidate index variable. Only ``Constant`` instances can match.

    Returns
    -------
    tuple[int, int, int] | None
        ``(d, offset, step)`` when ``idx`` is a non-empty 1-D integer constant
        whose consecutive entries differ by the same non-zero ``step``;
        ``None`` otherwise. Single-element constants return ``(1, value, 1)``.

    Notes
    -----
    The result is cached on ``idx.tag.is_arange``; ``False`` is stored as the
    sentinel for "no match" so that ``None`` can keep meaning "not computed".

    Differences are computed in a widened dtype. ``np.diff`` otherwise
    subtracts in the array's own integer dtype, which wraps around: a ``uint8``
    constant ``[3, 2, 1]`` would produce diffs ``[255, 255]`` and be reported
    with ``step == 255`` instead of ``-1``, and narrow signed dtypes can wrap
    the same way.
    """
    if not isinstance(idx, Constant):
        return None
    cached = getattr(idx.tag, "is_arange", None)
    if cached is not None:
        return cached if cached is not False else None

    idx_val = np.asarray(idx.data)
    result: tuple[int, int, int] | None = None
    if idx_val.ndim == 1 and idx_val.size > 0 and idx_val.dtype.kind in "iu":
        if idx_val.size == 1:
            result = (1, int(idx_val[0]), 1)
        else:
            if idx_val.dtype == np.uint64:
                # uint64 does not fit in int64; use exact Python ints instead.
                widened = idx_val.astype(object)
            else:
                # Every other integer dtype fits in int64, so differences of
                # narrow/unsigned values no longer wrap. int64 data is used
                # as-is (no copy).
                widened = idx_val.astype(np.int64, copy=False)
            diffs = np.diff(widened)
            step = int(diffs[0])
            if step != 0 and bool(np.all(diffs == step)):
                result = (int(idx_val.size), int(idx_val[0]), step)

    idx.tag.is_arange = result if result is not None else False
    return result
64 changes: 58 additions & 6 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
zeros,
zeros_like,
)
from pytensor.tensor.constant_props import constant_is_all_negative
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
Expand Down Expand Up @@ -112,7 +113,7 @@
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.shape import Shape, Shape_i, specify_shape
from pytensor.tensor.subtensor import Subtensor, _is_provably_positive
from pytensor.tensor.subtensor import Subtensor, is_provably_positive
from pytensor.tensor.type import (
complex_dtypes,
uint_dtypes,
Expand Down Expand Up @@ -714,8 +715,8 @@ def local_log_div(fgraph, node):

if isinstance(scalar_op, ps.TrueDiv):
num, den = inp.owner.inputs
if (isinstance(num, Constant) and _is_provably_positive(num, strict=True)) or (
isinstance(den, Constant) and _is_provably_positive(den, strict=True)
if (isinstance(num, Constant) and is_provably_positive(num, strict=True)) or (
isinstance(den, Constant) and is_provably_positive(den, strict=True)
):
out_dtype = node.outputs[0].dtype
return [log(num.astype(out_dtype)) - log(den.astype(out_dtype))]
Expand Down Expand Up @@ -745,13 +746,13 @@ def local_sign_div(fgraph, node):

num, den = inp.owner.inputs

if _is_provably_positive(num, strict=True):
if is_provably_positive(num, strict=True):
return [sign(den)]
if _is_provably_positive(den, strict=True):
if is_provably_positive(den, strict=True):
return [sign(num)]

for side, other in ((num, den), (den, num)):
if isinstance(side, Constant) and np.all(np.asarray(side.data) < 0):
if constant_is_all_negative(side):
return [neg(sign(other))]


Expand Down Expand Up @@ -865,6 +866,57 @@ def local_div_exp_to_mul_exp(fgraph, node):
return [new_out]


@register_specialize
@node_rewriter([true_div])
def local_div_reciprocal_to_mul(fgraph, node):
"""Replace ``A / reciprocal(B)`` with ``A * B`` and ``A / y ** (-p)`` with ``A * y ** p``."""
num, denom = node.inputs

match denom.owner_op_and_inputs:
case (Elemwise(scalar_op=ps.Reciprocal()), b):
inverted = b
case (Elemwise(scalar_op=ps.Pow()), base, exponent):
match exponent.owner_op_and_inputs:
case (Elemwise(scalar_op=ps.Neg()), pos_exponent):
inverted = base**pos_exponent
case _ if constant_is_all_negative(exponent):
inverted = base**-exponent.data
case _:
return None
case _:
return None

new_out = num * inverted
if new_out.dtype != node.outputs[0].dtype:
new_out = cast(new_out, dtype=node.outputs[0].dtype)
copy_stack_trace(node.outputs[0], new_out)
return [new_out]


@register_specialize
@node_rewriter([reciprocal])
def local_reciprocal_neg_pow_to_pow(fgraph, node):
"""Replace ``reciprocal(y ** (-p))`` with ``y ** p`` (the ``1 / y ** (-p)`` form)."""
[arg] = node.inputs

match arg.owner_op_and_inputs:
case (Elemwise(scalar_op=ps.Pow()), base, exponent):
match exponent.owner_op_and_inputs:
case (Elemwise(scalar_op=ps.Neg()), pos_exponent):
new_out = base**pos_exponent
case _ if constant_is_all_negative(exponent):
new_out = base**-exponent.data
case _:
return None
case _:
return None

if new_out.dtype != node.outputs[0].dtype:
new_out = cast(new_out, dtype=node.outputs[0].dtype)
copy_stack_trace(node.outputs[0], new_out)
return [new_out]


@register_specialize
@node_rewriter([mul, true_div])
def local_mul_pow_to_pow_add(fgraph, node):
Expand Down
78 changes: 14 additions & 64 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
)
from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.blockwise import _squeeze_left
from pytensor.tensor.constant_props import (
constant_indices_are_unique,
constant_is_arange,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_to, squeeze
Expand Down Expand Up @@ -75,7 +79,6 @@
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
_is_provably_non_negative,
_non_consecutive_adv_indexing,
advanced_inc_subtensor1,
advanced_subtensor1,
Expand All @@ -86,6 +89,7 @@
get_slice_elements,
inc_subtensor,
indices_from_subtensor,
is_provably_non_negative,
unflatten_index_variables,
)
from pytensor.tensor.type import TensorType
Expand Down Expand Up @@ -206,68 +210,14 @@ def get_advsubtensor_axis(indices):
return axis


def _constant_has_unique_indices(idx) -> bool:
    """Report whether a constant index is free of duplicate entries.

    Non-constants return ``False``. Boolean index arrays and arrays holding
    at most one element are unique by construction. Larger integer arrays
    that contain both non-negative and negative entries are reported as *not*
    unique, because a negative index can alias a non-negative one; otherwise
    the entries are compared against their de-duplicated set.

    The outcome is stored on ``idx.tag.unique_indices`` and reused on later
    calls.
    """
    if not isinstance(idx, Constant):
        return False

    memo = getattr(idx.tag, "unique_indices", None)
    if memo is not None:
        return bool(memo)

    values = np.asarray(idx.data)
    if values.dtype == bool or values.size <= 1:
        # Masks, scalars and single-element arrays cannot repeat an entry.
        unique = True
    elif (values >= 0).any() and (values < 0).any():
        # Mixed signs may address the same position twice; stay conservative.
        unique = False
    else:
        unique = np.unique(values).size == values.size

    idx.tag.unique_indices = unique
    return unique


def _has_unique_indices(fgraph, idx) -> bool:
"""Whether ``idx``'s entries are provably duplicate-free: a constant with
unique entries, or a variable asserted ``unique_indices`` by the user."""
return _constant_has_unique_indices(idx) or check_assumption(
return constant_indices_are_unique(idx) or check_assumption(
fgraph, idx, UNIQUE_INDICES
)


def _constant_is_arange(idx) -> tuple[int, int, int] | None:
    """Match ``idx`` to ``np.arange(offset, offset + d * step, step)``.

    Returns ``(d, offset, step)`` when ``idx`` is a non-empty 1-D integer
    ``Constant`` whose consecutive entries differ by the same non-zero
    ``step``, else ``None``. Single-element constants return
    ``(1, value, 1)``. The result is cached on ``idx.tag.is_arange``
    (``False`` sentinels a no-match).

    Differences are computed in a widened dtype. ``np.diff`` otherwise
    subtracts in the array's own integer dtype, which wraps around: a ``uint8``
    constant ``[3, 2, 1]`` would produce diffs ``[255, 255]`` and be reported
    with ``step == 255`` instead of ``-1``, and narrow signed dtypes can wrap
    the same way.
    """
    if not isinstance(idx, Constant):
        return None
    cached = getattr(idx.tag, "is_arange", None)
    if cached is not None:
        return cached if cached is not False else None

    idx_val = np.asarray(idx.data)
    result: tuple[int, int, int] | None = None
    if idx_val.ndim == 1 and idx_val.size > 0 and idx_val.dtype.kind in "iu":
        if idx_val.size == 1:
            result = (1, int(idx_val[0]), 1)
        else:
            if idx_val.dtype == np.uint64:
                # uint64 does not fit in int64; use exact Python ints instead.
                widened = idx_val.astype(object)
            else:
                # Every other integer dtype fits in int64, so differences of
                # narrow/unsigned values no longer wrap. int64 data is used
                # as-is (no copy).
                widened = idx_val.astype(np.int64, copy=False)
            diffs = np.diff(widened)
            step = int(diffs[0])
            if step != 0 and bool(np.all(diffs == step)):
                result = (int(idx_val.size), int(idx_val[0]), step)

    idx.tag.is_arange = result if result is not None else False
    return result


def _match_arange_0_to_d_plus_offset(idx):
"""Match ``arange(0, d, 1) + offset`` and return ``(arange_node, offset)``
where ``arange_node`` is the ``arange(0, d, 1)`` output and ``offset`` is
Expand Down Expand Up @@ -783,7 +733,7 @@ def _merge_scalar_into_slice_unsafe(inner_slice, scalar_index, dim, xshape):
def _eager_lt_0(x):
"""Return ``True``/``False`` (Python bool) when the sign of *x* is
known, otherwise return the ``lt(x, 0)`` graph node."""
if _is_provably_non_negative(x):
if is_provably_non_negative(x):
return False
if isinstance(x, Constant):
return int(x.data) < 0
Expand All @@ -801,9 +751,9 @@ def _eager_switch(cond, a, b):
def _eager_minimum(a, b):
if a is b:
return a
if _eager_lt_0(a) is True and _is_provably_non_negative(b):
if _eager_lt_0(a) is True and is_provably_non_negative(b):
return a
if _eager_lt_0(b) is True and _is_provably_non_negative(a):
if _eager_lt_0(b) is True and is_provably_non_negative(a):
return b
return minimum(a, b)

Expand Down Expand Up @@ -1377,7 +1327,7 @@ def _arange_index_to_slice(idx):
if not isinstance(idx, TensorVariable) or idx.type.ndim != 1:
return None

const_match = _constant_is_arange(idx)
const_match = constant_is_arange(idx)
if const_match is not None:
d, offset, step = const_match
if offset < 0 or offset + (d - 1) * step < 0:
Expand All @@ -1400,9 +1350,9 @@ def _arange_index_to_slice(idx):
if isinstance(arange_stop, TensorVariable) and arange_stop.type.dtype != "int64":
arange_stop = arange_stop.astype("int64")
offset = _eager_scalar(offset)
if not _is_provably_non_negative(offset):
if not is_provably_non_negative(offset):
return None
if not _is_provably_non_negative(arange_stop):
if not is_provably_non_negative(arange_stop):
return None
stop = eager_add_zero(arange_stop, offset)
return slice(offset, stop)
Expand Down Expand Up @@ -1435,7 +1385,7 @@ def local_adv_idx_to_diagonal(fgraph, node):
# Match both indices as arange(d) + offset (const or symbolic).
# Both must be the same kind (both const or both symbolic).
def _match_arange(idx):
const = _constant_is_arange(idx)
const = constant_is_arange(idx)
if const is not None and const[2] == 1:
return "const", const[0], const[1]
sym = _match_arange_0_to_d_plus_offset(idx)
Expand Down Expand Up @@ -2052,7 +2002,7 @@ def _slice_to_arange(sl, dim_length):
return None
if sl.stop is None:
return arange(dim_length)
if not _is_provably_non_negative(sl.stop):
if not is_provably_non_negative(sl.stop):
return None
return arange(minimum(sl.stop, dim_length))

Expand Down
Loading
Loading