diff --git a/src/pyrecest/_backend/pytorch/linalg.py b/src/pyrecest/_backend/pytorch/linalg.py index 5264de67d..607e88470 100644 --- a/src/pyrecest/_backend/pytorch/linalg.py +++ b/src/pyrecest/_backend/pytorch/linalg.py @@ -211,10 +211,17 @@ def solve_sylvester(a, b, q): ) if conditions: tilde_q = eigvecs.transpose(-2, -1) @ q @ eigvecs - tilde_x = tilde_q / ( - eigvals[..., :, None] - + eigvals[..., None, :] - + _torch.eye(a.shape[-1], dtype=a.dtype, device=a.device) + denominators = eigvals[..., :, None] + eigvals[..., None, :] + safe_denominators = _torch.where( + _torch.abs(denominators) < 1e-12, + _torch.ones((), dtype=denominators.dtype, device=denominators.device), + denominators, + ) + tilde_x = tilde_q / safe_denominators + tilde_x = _torch.where( + _torch.abs(denominators) < 1e-12, + _torch.zeros((), dtype=tilde_x.dtype, device=tilde_x.device), + tilde_x, ) return eigvecs @ tilde_x @ eigvecs.transpose(-2, -1) diff --git a/tests/backend_support/test_pytorch_sylvester_semidefinite_shortcut.py b/tests/backend_support/test_pytorch_sylvester_semidefinite_shortcut.py new file mode 100644 index 000000000..88d8aac0e --- /dev/null +++ b/tests/backend_support/test_pytorch_sylvester_semidefinite_shortcut.py @@ -0,0 +1,42 @@ +"""Regression tests for PyTorch Sylvester solver shortcuts.""" + +from tests.support.backend_runner import run_backend_code + + +def test_pytorch_semidefinite_sylvester_shortcut_respects_nonzero_denominators(): + code = """ +import torch +from pyrecest.backend import linalg + +# Symmetric positive-semidefinite factor with a one-dimensional nullspace. +eigvecs = torch.tensor( + [ + [-0.23813772, -0.95532958, 0.17434201], + [0.89798926, -0.14791816, 0.41440983], + [-0.37000772, 0.25586247, 0.89310060], + ], + dtype=torch.float64, +) +eigvals = torch.tensor([0.0, 0.5, 2.0], dtype=torch.float64) +a = eigvecs @ torch.diag(eigvals) @ eigvecs.T + +# The shortcut accepts almost skew-symmetric q. The tiny diagonal entries are +# within that tolerance, but they still have nonzero Sylvester denominators in +# the non-null eigenspaces and must not be divided by denominator + I. +tilde_q = torch.tensor( + [ + [0.0, 0.3, -0.2], + [-0.3, 1.0e-7, 0.4], + [0.2, -0.4, -1.0e-7], + ], + dtype=torch.float64, +) +q = eigvecs @ tilde_q @ eigvecs.T + +solution = linalg.solve_sylvester(a, a, q) +residual = torch.linalg.norm(a @ solution + solution @ a - q) +assert torch.isfinite(solution).all() +assert residual.item() < 1e-10, residual.item() +""" + result = run_backend_code("pytorch", code) + assert result.returncode == 0, result.stderr