rewrite for solve for diagonal matrices #1932

Jasjeet-Singh-S wants to merge 2 commits into pymc-devs:v3 from
Conversation
Pull request overview
Adds a canonicalization rewrite to simplify `pt.linalg.solve` when the coefficient matrix is explicitly constructed as a diagonal matrix via `pt.diag(d)`, avoiding an expensive general solve in favor of elementwise division.
Changes:
- Add `rewrite_solve_diag` to replace `Blockwise(Solve)(diag(d), b)` with an equivalent division expression.
- Add tests asserting the `Blockwise(Solve)` node is eliminated and validating numeric equivalence for vector-`b` and matrix-`b` cases.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `pytensor/tensor/rewriting/linalg.py` | Introduces the new canonicalize rewrite for `solve(diag(d), b)` → elementwise division. |
| `tests/tensor/rewriting/test_linalg.py` | Adds regression tests that the rewrite fires and matches the reference solve result. |
```python
def test_solve_diag_vector_b():
    d = pt.vector("d")
    b = pt.vector("b")
    x = solve(pt.diag(d), b)

    f = function([d, b], x, mode="FAST_RUN")
    nodes = f.maker.fgraph.apply_nodes
    assert not any(
        isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
        for node in nodes
    )
```
Please add coverage for the batched-vector case (`b_ndim=1` with batch dims), e.g. `d.shape == (B, N)`, `b.shape == (B, N)`, and `solve(diag(d), b, b_ndim=1)`. This is an important solve mode per `pt.linalg.solve`'s documented `(..., N, N)` / `(..., N)` shapes, and it will catch shape/broadcasting regressions in the rewrite.
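To make the requested equivalence concrete, here is a NumPy sketch (illustrative only; the real test would go through `pt.linalg.solve` with `b_ndim=1`) of what the batched-vector case should assert:

```python
import numpy as np

# NumPy sketch of the batched-vector equivalence the requested test
# should check: d.shape == (B, N) and b.shape == (B, N), with one
# diagonal system solved per batch entry.
rng = np.random.default_rng(0)
B, N = 4, 3
d = rng.uniform(1.0, 2.0, size=(B, N))   # batched diagonals
b = rng.normal(size=(B, N))              # batched RHS vectors

# Reference: materialize diag(d[i]) for each batch entry and solve
A = np.zeros((B, N, N))
idx = np.arange(N)
A[:, idx, idx] = d                       # A[i] = np.diag(d[i])
x_ref = np.linalg.solve(A, b[..., None])[..., 0]   # shape (B, N)

# What the rewrite should produce: batch-preserving elementwise division
x_fast = b / d

assert x_fast.shape == (B, N)
assert np.allclose(x_fast, x_ref)
```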
`pytensor/tensor/rewriting/linalg.py` (Outdated)
```python
# b_ndim=1: b shape (N,) -> result shape (N,)
# b_ndim=2: b shape (N, K) -> result shape (N, K)
b_transposed = b[None, :] if b_ndim == 1 else b.mT
new_out = (b_transposed / pt.expand_dims(d, -2)).mT
if b_ndim == 1:
    new_out = new_out.squeeze(-1)
```
`b_transposed = b[None, :]` breaks batched solves when `b_ndim == 1` (e.g., `a.shape == (B, N, N)` and `b.shape == (B, N)` with `b_ndim=1`): inserting the new axis at the front changes batch axis order and yields incorrect broadcasting/results. Use an expansion on the last/core side instead (e.g., expand `d` with a trailing axis for `b_ndim=2`; for `b_ndim=1` you can simply compute `b / d` directly), so batch dimensions are preserved.
Suggested change:

```diff
-# b_ndim=1: b shape (N,) -> result shape (N,)
-# b_ndim=2: b shape (N, K) -> result shape (N, K)
-b_transposed = b[None, :] if b_ndim == 1 else b.mT
-new_out = (b_transposed / pt.expand_dims(d, -2)).mT
-if b_ndim == 1:
-    new_out = new_out.squeeze(-1)
+# b_ndim=1: b shape (..., N) -> result shape (..., N)
+# b_ndim=2: b shape (..., N, K) -> result shape (..., N, K)
+if b_ndim == 1:
+    # Elementwise division along the diagonal dimension; batch axes are preserved
+    new_out = b / d
+else:
+    # For matrix RHS, expand d over the last (RHS) dimension
+    new_out = b / pt.expand_dims(d, -1)
```
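The broadcasting problem behind this suggestion can be reproduced directly in NumPy (a sketch, not the PR's actual code):

```python
import numpy as np

# With batched d and b of shape (B, N), b[None, :] puts the new axis in
# front of the batch axis, so batch entries of b and d get crossed
# instead of paired; plain b / d keeps the batch axis intact.
B, N = 2, 3
d = np.array([[1.0, 2.0, 4.0],
              [2.0, 5.0, 10.0]])   # batched diagonals, shape (B, N)
b = np.ones((B, N))                # batched RHS, shape (B, N)

# Buggy pattern: leading new axis crosses the batch dimensions
crossed = b[None, :] / d[:, None, :]
assert crossed.shape == (B, B, N)  # (B, B, N) rather than (B, N)

# Suggested fix: plain elementwise division pairs batch entries
fixed = b / d
assert fixed.shape == (B, N)
assert np.allclose(fixed[0], 1.0 / d[0])
```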
Rewrite `solve` with diagonal matrices

Partial implementation of #1791.
What was done
Added a graph rewrite `rewrite_solve_diag` in `pytensor/tensor/rewriting/linalg.py` that detects when the first argument to `solve` is a diagonal matrix and replaces the expensive `Blockwise(Solve(...))` node with elementwise division.

For a diagonal matrix `A` with diagonal entries `d`, the linear system `A @ x = b` has the closed-form solution `x = b / d`, which avoids the full LU factorisation performed by `scipy.linalg.solve`.

The rewrite handles both `b_ndim=1` (vector `b`) and `b_ndim=2` (matrix `b`), and detects diagonal matrices from two structural patterns:

Pattern 1: `pt.diag(d)` (`AllocDiag`)

`d` is extracted directly as the 1D diagonal vector and `b / d` is computed with appropriate broadcasting:
- `solve(pt.diag(d), b)` → `b / d`
- `solve(pt.diag(d), b, b_ndim=2)` → `b / d[:, None]`
Pattern 2: `pt.eye(n) * x`

Uses the existing `_find_diag_from_eye_mul` helper (shared with `rewrite_inv_diag_to_diag_reciprocal`, `rewrite_det_diag_from_eye_mul`, etc.) to detect elementwise multiplication with an identity matrix. The effective diagonal `d` is extracted depending on the shape of `x`:
- `x` 0D: `d = x` (scalar), `b / d` broadcasts trivially
- `x` 1D: `d = x`, equivalent to `pt.diag(x)`
- `x` 2D: `d = x.diagonal()`, zeros off the diagonal are ignored

The rewrite is registered under
`@register_canonicalize` so it fires automatically in `FAST_RUN` and `FAST_COMPILE` modes.

Tests were added in `tests/tensor/rewriting/test_linalg.py`:
- `test_solve_diag_from_diag` — parametrized over `b_ndim ∈ {1, 2}`, verifies `Blockwise(Solve)` is eliminated and the result matches both `b / d` and the unoptimised reference.
- `test_solve_diag_from_eye_mul` — parametrized over `x_shape ∈ {(), (5,), (5, 5)}` × `b_ndim ∈ {1, 2}` (6 cases total), verifies the rewrite fires and matches the unoptimised reference for all shape combinations.
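The algebra behind both patterns can be checked numerically with a short NumPy sketch (illustrative values; this is not the PyTensor implementation):

```python
import numpy as np

# Solving against a diagonal matrix reduces to elementwise division,
# for both structural patterns handled by the rewrite.
d = np.array([2.0, 4.0, 5.0])

# Pattern 1, b_ndim=1: solve(diag(d), b) -> b / d
b_vec = np.array([4.0, 8.0, 10.0])
assert np.allclose(np.linalg.solve(np.diag(d), b_vec), b_vec / d)

# Pattern 1, b_ndim=2: solve(diag(d), b) -> b / d[:, None]
b_mat = np.arange(6.0).reshape(3, 2)
assert np.allclose(np.linalg.solve(np.diag(d), b_mat), b_mat / d[:, None])

# Pattern 2: eye(n) * x is diagonal for 0-D, 1-D, and 2-D x
eye = np.eye(3)
x0, x1 = 2.0, np.array([1.0, 2.0, 3.0])
x2 = np.array([[1.0, 9.0, 9.0],
               [9.0, 2.0, 9.0],
               [9.0, 9.0, 4.0]])
assert np.allclose(np.linalg.solve(eye * x0, b_vec), b_vec / x0)
# eye * x1 == diag(x1), so d = x1
assert np.allclose(np.linalg.solve(eye * x1, b_vec), b_vec / x1)
# eye zeroes the off-diagonal entries of x2, so d = x2.diagonal()
assert np.allclose(np.linalg.solve(eye * x2, b_vec), b_vec / x2.diagonal())
```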
What remains to be done

Rewrite `dot`/`matmul` with diagonal matrices

Issue #1791 also asks for rewrites for `dot` and `matmul` when one of the operands is diagonal:
- `dot(diag(d), x)` → `d[:, None] * x`
- `dot(x, diag(d))` → `x * d`
- `matmul` and the batched `Blockwise` variants

Both the `AllocDiag` and `eye * x` diagonal patterns should be handled, consistent with this PR.
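For reference, the target identities for the remaining rewrites can be verified in NumPy (illustrative values):

```python
import numpy as np

# Multiplying by a diagonal matrix scales rows or columns elementwise,
# which is what the future dot/matmul rewrites would exploit.
d = np.array([1.0, 2.0, 3.0])
x = np.arange(12.0).reshape(3, 4)

# dot(diag(d), x) -> d[:, None] * x: scales the rows of x
assert np.allclose(np.diag(d) @ x, d[:, None] * x)

# dot(x, diag(d2)) -> x * d2: scales the columns (d2 matches x's last axis)
d2 = np.array([1.0, 2.0, 3.0, 4.0])
assert np.allclose(x @ np.diag(d2), x * d2)
```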