
rewrite for solve for diagonal matrices#1932

Draft
Jasjeet-Singh-S wants to merge 2 commits into pymc-devs:v3 from Jasjeet-Singh-S:diagonal-solve-dot-new

Conversation

@Jasjeet-Singh-S Jasjeet-Singh-S commented Mar 4, 2026

Rewrite solve with diagonal matrices

Partial implementation of #1791.

What was done

Added a graph rewrite rewrite_solve_diag in pytensor/tensor/rewriting/linalg.py that detects when the first argument to solve is a diagonal matrix and replaces the expensive Blockwise(Solve(...)) node with elementwise division.

For a diagonal matrix A with diagonal entries d, the linear system A @ x = b has the closed-form solution x = b / d, which avoids the full LU factorisation performed by scipy.linalg.solve.
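This closed form can be sanity-checked with plain NumPy (an illustration of the arithmetic only, not the PyTensor rewrite code):

```python
import numpy as np

d = np.array([2.0, 4.0, 5.0])
b = np.array([4.0, 8.0, 10.0])

x_solve = np.linalg.solve(np.diag(d), b)  # general LU-based solve
x_div = b / d                             # closed form for diagonal A

assert np.allclose(x_solve, x_div)
```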

The rewrite handles both b_ndim=1 (vector b) and b_ndim=2 (matrix b), and detects diagonal matrices from two structural patterns:

Pattern 1: pt.diag(d) (AllocDiag)

d is extracted directly as the 1D diagonal vector and b / d is computed with appropriate broadcasting.

  • solve(pt.diag(d), b) → b / d
  • solve(pt.diag(d), b, b_ndim=2) → b / d[:, None]
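The b_ndim=2 broadcasting above can be checked numerically in NumPy (illustrative only; the PR implements the same replacement at the graph level):

```python
import numpy as np

d = np.array([1.0, 2.0, 4.0])
B = np.arange(6.0).reshape(3, 2)  # matrix right-hand side (b_ndim=2)

# General solve vs. the rewrite's replacement b / d[:, None]
expected = np.linalg.solve(np.diag(d), B)
rewritten = B / d[:, None]

assert np.allclose(expected, rewritten)
```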

Pattern 2: pt.eye(n) * x

Uses the existing _find_diag_from_eye_mul helper (shared with rewrite_inv_diag_to_diag_reciprocal, rewrite_det_diag_from_eye_mul, etc.) to detect elementwise multiplication with an identity matrix. The effective diagonal d is extracted depending on the shape of x:

  • Scalar x (0D): d = x (scalar), b / d broadcasts trivially
  • Vector x (1D): d = x, equivalent to pt.diag(x)
  • Matrix x (2D): d = x.diagonal(), zeros off the diagonal are ignored
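All three shape cases reduce to the same elementwise division, which can be verified in NumPy (a numeric illustration of the extraction rules, not the rewrite itself):

```python
import numpy as np

n = 3
b = np.array([6.0, 6.0, 6.0])

# Scalar x (0D): eye(n) * x has every diagonal entry equal to x
x0 = 2.0
assert np.allclose(np.linalg.solve(np.eye(n) * x0, b), b / x0)

# Vector x (1D): eye(n) * x is equivalent to diag(x)
x1 = np.array([1.0, 2.0, 3.0])
assert np.allclose(np.linalg.solve(np.eye(n) * x1, b), b / x1)

# Matrix x (2D): only x.diagonal() survives the elementwise product with eye(n)
x2 = np.arange(9.0).reshape(3, 3) + 1.0
assert np.allclose(np.linalg.solve(np.eye(n) * x2, b), b / x2.diagonal())
```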

The rewrite is registered under @register_canonicalize so it fires automatically in FAST_RUN and FAST_COMPILE modes.

Tests were added in tests/tensor/rewriting/test_linalg.py:

  • test_solve_diag_from_diag — parametrized over b_ndim ∈ {1, 2}, verifies Blockwise(Solve) is eliminated and the result matches both b / d and the unoptimised reference.
  • test_solve_diag_from_eye_mul — parametrized over x_shape ∈ {(), (5,), (5,5)} × b_ndim ∈ {1, 2} (6 cases total), verifies the rewrite fires and matches the unoptimised reference for all shape combinations.

What remains to be done

Rewrite dot/matmul with diagonal matrices

Issue #1791 also asks for rewrites for dot and matmul when one of the operands is diagonal:

  • dot(diag(d), x) → d[:, None] * x
  • dot(x, diag(d)) → x * d
  • Same patterns for matmul and the batched Blockwise variants.
  • Both AllocDiag and eye * x diagonal patterns should be handled, consistent with this PR.
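The two dot identities listed above can be checked numerically (NumPy sketch; the eventual graph rewrite would follow the same pattern as this PR):

```python
import numpy as np

d = np.array([2.0, 3.0, 5.0])

# dot(diag(d), x): scales the rows of x
x = np.arange(12.0).reshape(3, 4)
assert np.allclose(np.diag(d) @ x, d[:, None] * x)

# dot(x, diag(d)): scales the columns of x
y = np.arange(12.0).reshape(4, 3)
assert np.allclose(y @ np.diag(d), y * d)
```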

Copilot AI review requested due to automatic review settings March 4, 2026 08:35
@Jasjeet-Singh-S Jasjeet-Singh-S marked this pull request as draft March 4, 2026 08:35

Copilot AI left a comment


Pull request overview

Adds a canonicalization rewrite to simplify pt.linalg.solve when the coefficient matrix is explicitly constructed as a diagonal matrix via pt.diag(d), avoiding an expensive general solve in favor of elementwise division.

Changes:

  • Add rewrite_solve_diag to replace Blockwise(Solve)(diag(d), b) with an equivalent division expression.
  • Add tests asserting the Blockwise(Solve) node is eliminated and validating numeric equivalence for vector-b and matrix-b cases.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

  • pytensor/tensor/rewriting/linalg.py: Introduces the new canonicalize rewrite for solve(diag(d), b) → elementwise division.
  • tests/tensor/rewriting/test_linalg.py: Adds regression tests that the rewrite fires and matches the reference solve result.

Comment on lines +1133 to +1144
    def test_solve_diag_vector_b():
        d = pt.vector("d")
        b = pt.vector("b")
        x = solve(pt.diag(d), b)

        f = function([d, b], x, mode="FAST_RUN")
        nodes = f.maker.fgraph.apply_nodes
        assert not any(
            isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
            for node in nodes
        )


Copilot AI Mar 4, 2026


Please add coverage for the batched-vector case (b_ndim=1 with batch dims), e.g. d.shape == (B, N), b.shape == (B, N), and solve(diag(d), b, b_ndim=1). This is an important solve mode per pt.linalg.solve's documented (..., N, N) / (..., N) shapes, and it will catch shape/broadcasting regressions in the rewrite.
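The requested batched-vector case can be expressed numerically as follows (a NumPy sketch of the shapes named in the comment, not a drop-in PyTensor test):

```python
import numpy as np

B, N = 4, 3
rng = np.random.default_rng(0)
d = rng.uniform(1.0, 2.0, size=(B, N))   # batched diagonals, shape (B, N)
b = rng.normal(size=(B, N))              # batched vector RHS (b_ndim=1)

# Build the batched diagonal matrices, shape (B, N, N)
A = np.zeros((B, N, N))
i = np.arange(N)
A[:, i, i] = d

# np.linalg.solve broadcasts over leading batch dims; a vector RHS needs a
# trailing axis of size 1, which is squeezed back off afterwards
expected = np.linalg.solve(A, b[..., None])[..., 0]

# The rewrite's replacement must match for every batch element
assert np.allclose(expected, b / d)
```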

Comment on lines +1177 to +1182
    # b_ndim=1: b shape (N,) -> result shape (N,)
    # b_ndim=2: b shape (N, K) -> result shape (N, K)
    b_transposed = b[None, :] if b_ndim == 1 else b.mT
    new_out = (b_transposed / pt.expand_dims(d, -2)).mT
    if b_ndim == 1:
        new_out = new_out.squeeze(-1)

Copilot AI Mar 4, 2026


b_transposed = b[None, :] breaks batched solves when b_ndim == 1 (e.g., a.shape == (B, N, N) and b.shape == (B, N) with b_ndim=1): inserting the new axis at the front changes batch axis order and yields incorrect broadcasting/results. Use an expansion on the last/core side instead (e.g., expand d with a trailing axis for b_ndim=2, and for b_ndim=1 you can simply compute b / d directly), so batch dimensions are preserved.

Suggested change

Before:

    # b_ndim=1: b shape (N,) -> result shape (N,)
    # b_ndim=2: b shape (N, K) -> result shape (N, K)
    b_transposed = b[None, :] if b_ndim == 1 else b.mT
    new_out = (b_transposed / pt.expand_dims(d, -2)).mT
    if b_ndim == 1:
        new_out = new_out.squeeze(-1)

After:

    # b_ndim=1: b shape (..., N) -> result shape (..., N)
    # b_ndim=2: b shape (..., N, K) -> result shape (..., N, K)
    if b_ndim == 1:
        # Elementwise division along the diagonal dimension; batch axes are preserved
        new_out = b / d
    else:
        # For matrix RHS, expand d over the last (RHS) dimension
        new_out = b / pt.expand_dims(d, -1)
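A NumPy sketch of why the original expression misbehaves for batched b_ndim=1 (shapes as described in the review comment; illustrative only):

```python
import numpy as np

B, N = 2, 3
rng = np.random.default_rng(1)
d = rng.uniform(1.0, 2.0, size=(B, N))   # batched diagonals
b = rng.normal(size=(B, N))              # batched vector RHS (b_ndim=1)

# Reference result from a general batched solve
A = np.zeros((B, N, N))
i = np.arange(N)
A[:, i, i] = d
expected = np.linalg.solve(A, b[..., None])[..., 0]

# Original expression: prepending an axis to b misaligns the batch axes,
# so the division broadcasts to shape (B, B, N) instead of (B, N)
assert (b[None, :] / d[:, None, :]).shape == (B, B, N)

# Fixed expression: plain elementwise division keeps batch axes aligned
assert np.allclose(b / d, expected)
```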

@Jasjeet-Singh-S Jasjeet-Singh-S changed the title incomplete rewrite for solve for diagonal matrices rewrite for solve for diagonal matrices Mar 4, 2026