Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions pytensor/link/jax/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
Erfinv,
GammaIncCInv,
GammaIncInv,
Iv,
Ive,
Kve,
Log1mexp,
Expand Down Expand Up @@ -277,16 +276,6 @@ def jax_funcify_from_tfp(op, **kwargs):
return tfp_jax_op


@jax_funcify.register(Iv)
def jax_funcify_Iv(op, **kwargs):
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")

def iv(v, x):
return ive(v, x) / jnp.exp(-jnp.abs(jnp.real(x)))

return iv


@jax_funcify.register(Ive)
def jax_funcify_Ive(op, **kwargs):
return try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
Expand Down
27 changes: 1 addition & 26 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,31 +1073,6 @@ def c_code(self, node, name, inp, out, sub):
j0 = J0(upgrade_to_float, name="j0")


class Iv(BinaryScalarOp):
"""
Modified Bessel function of the first kind of order v (real).
"""

nfunc_spec = ("scipy.special.iv", 2, 1)

def impl(self, v, x):
return special.iv(v, x)

def grad(self, inputs, grads):
v, x = inputs
(gz,) = grads
return [
grad_not_implemented(self, 0, v),
gz * (iv(v - 1, x) + iv(v + 1, x)) / 2.0,
]

def c_code(self, *args, **kwargs):
raise NotImplementedError()


iv = Iv(upgrade_to_float, name="iv")


class I1(UnaryScalarOp):
"""
Modified Bessel function of the first kind of order 1.
Expand All @@ -1111,7 +1086,7 @@ def impl(self, x):
def grad(self, inputs, grads):
(x,) = inputs
(gz,) = grads
return [gz * (i0(x) + iv(2, x)) / 2.0]
return [gz * (i0(x) + ive(2, x) * exp(abs(x))) / 2.0]

def c_code(self, *args, **kwargs):
raise NotImplementedError()
Expand Down
9 changes: 7 additions & 2 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,9 +2429,14 @@ def i1(x):
"""Modified Bessel function of the first kind of order 1."""


@scalar_elemwise
def iv(v, x):
"""Modified Bessel function of the first kind of order v (real)."""
"""Modified Bessel function of the first kind of order v (real).

Computed as ``ive(v, x) * exp(abs(x))`` for numerical consistency with
``ive``. For large ``x``, prefer working in log-space:
``log(iv(v, x)) == log(ive(v, x)) + abs(x)`` to avoid overflow.
"""
return ive(v, x) * exp(abs(x))


@scalar_elemwise
Expand Down
15 changes: 15 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
ge,
int_div,
isinf,
ive,
kve,
le,
log,
Expand Down Expand Up @@ -3888,3 +3889,17 @@ def local_useless_conj(fgraph, node):
)

register_stabilize(local_log_kv)


local_log_iv = PatternNodeRewriter(
# Rewrite log(iv(v, x)) = log(ive(v, x) * exp(abs(x))) -> log(ive(v, x)) + abs(x)
(log, (mul, (ive, "v", "x"), (exp, (pt_abs, "x")))),
(add, (log, (ive, "v", "x")), (pt_abs, "x")),
allow_multiple_clients=True,
name="local_log_iv",
# Start the rewrite from the less likely ive node
tracks=[ive],
get_nodes=get_clients_at_depth2,
)

register_stabilize(local_log_iv)
8 changes: 6 additions & 2 deletions pytensor/xtensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,12 @@ def isinf(): ...
def isnan(): ...


@_as_xelemwise(ps.iv)
def iv(): ...
def iv(v, x):
"""Modified Bessel function of the first kind of order v (real).

Computed as ``ive(v, x) * exp(abs(x))`` for numerical consistency.
"""
return ive(v, x) * exp(abs(x))


@_as_xelemwise(ps.ive)
Expand Down
2 changes: 1 addition & 1 deletion tests/link/jax/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


try:
pass
import tensorflow_probability.substrates.jax.math # noqa: F401

TFP_INSTALLED = True
except ModuleNotFoundError:
Expand Down
13 changes: 13 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4785,6 +4785,19 @@ def test_log_kv_stabilization():
)


def test_log_iv_stabilization():
x = pt.scalar("x")
out = log(pt.iv(4.5, x))

# Expression would overflow to inf without rewrite
mode = get_default_mode().including("stabilize")
# Reference value log(ive(4.5, 1000.0)) + 1000.0
np.testing.assert_allclose(
out.eval({x: 1000.0}, mode=mode),
995.6171788390135,
)


@pytest.mark.parametrize("shape", [(), (4, 5, 6)], ids=["scalar", "tensor"])
def test_pow_1_rewrite(shape):
x = pt.tensor("x", shape=shape)
Expand Down
Loading