Skip to content

Commit bd50d6b

Browse files
ENH: Implement rewrite for inverse of triangular matrices
1 parent ecd1e07 commit bd50d6b

2 files changed

Lines changed: 61 additions & 0 deletions

File tree

pytensor/tensor/rewriting/linalg.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,37 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
821821
return [eye_input / non_eye_input]
822822

823823

824+
@register_canonicalize
825+
@register_stabilize
826+
@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)])
827+
def rewrite_inv_triangular_to_solve(fgraph, node):
828+
"""
829+
Rewrite `inv(A)` -> `solve_triangular(A, I)` when A is triangular.
830+
"""
831+
# node is a Blockwise(MatrixInverse). The input to MatrixInverse is node.inputs[0]
832+
inputs = node.inputs[0]
833+
834+
# Check for tags
835+
is_lower = getattr(inputs.tag, "lower_triangular", False)
836+
is_upper = getattr(inputs.tag, "upper_triangular", False)
837+
838+
if is_lower or is_upper:
839+
# Create an identity matrix of the same size.
840+
# Note: We use the last dimension for the size of the square matrix
841+
n = inputs.shape[-1]
842+
843+
# Ensure the dtype matches the input dtype
844+
identity = pt.eye(n, dtype=inputs.type.dtype)
845+
846+
# We replace the slow Inverse with the fast Triangular Solve.
847+
# b_ndim=2 because identity is a 2D matrix.
848+
inv_val = solve_triangular(inputs, identity, lower=is_lower, b_ndim=2)
849+
850+
return [inv_val]
851+
852+
return None
853+
854+
824855
@register_canonicalize
825856
@register_stabilize
826857
@node_rewriter([ExtractDiag])

tests/tensor/rewriting/test_linalg.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor import tensor as pt
1111
from pytensor.compile import get_default_mode
1212
from pytensor.configdefaults import config
13+
from pytensor.gradient import verify_grad
1314
from pytensor.graph import FunctionGraph, ancestors
1415
from pytensor.graph.rewriting.utils import rewrite_graph
1516
from pytensor.tensor import swapaxes
@@ -1128,3 +1129,32 @@ def solve_op_in_graph(graph):
11281129
np.testing.assert_allclose(
11291130
f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
11301131
)
1132+
1133+
1134+
def test_triangular_inv_rewrite_and_grad():
1135+
np.random.seed(42)
1136+
A_val = np.tril(np.random.rand(5, 5) + 5.0)
1137+
1138+
A = pt.dmatrix("A")
1139+
A.tag.lower_triangular = True
1140+
1141+
Z = pt.linalg.inv(A)
1142+
1143+
f = pytensor.function([A], Z)
1144+
1145+
assert not any(
1146+
isinstance(getattr(node.op, "core_op", node.op), MatrixInverse)
1147+
for node in f.maker.fgraph.toposort()
1148+
)
1149+
assert any(
1150+
isinstance(getattr(node.op, "core_op", node.op), SolveTriangular)
1151+
for node in f.maker.fgraph.toposort()
1152+
)
1153+
1154+
def func(x):
1155+
x = pt.as_tensor_variable(x)
1156+
x = pt.tril(x)
1157+
x.tag.lower_triangular = True
1158+
return pt.sum(pt.linalg.inv(x))
1159+
1160+
verify_grad(func, [A_val])

0 commit comments

Comments
 (0)