-
Notifications
You must be signed in to change notification settings - Fork 183
Fix incorrect gradient in Solve for structured assume_a (sym/pos/her) #1887
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
679588f
9cfbc38
7d04a40
0d1658a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -399,6 +399,69 @@ def test_solve_gradient( | |
| lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps | ||
| ) | ||
|
|
||
| @pytest.mark.parametrize("b_shape", [(5, 1), (5,)], ids=["b_col_vec", "b_vec"]) | ||
| @pytest.mark.parametrize( | ||
| "assume_a, lower", | ||
| [ | ||
| ("sym", False), | ||
| ("sym", True), | ||
| ("pos", False), | ||
| ("pos", True), | ||
| ], | ||
| ids=["sym_upper", "sym_lower", "pos_upper", "pos_lower"], | ||
| ) | ||
| @pytest.mark.skipif( | ||
| config.floatX == "float32", | ||
| reason="Gradients not numerically stable in float32", | ||
| ) | ||
| def test_solve_symmetric_gradient_direct( | ||
| self, b_shape: tuple[int], assume_a: str, lower: bool | ||
| ): | ||
| """Test that the gradient of Solve is correct when a pre-structured | ||
| matrix is passed directly, without composing with a symmetrization | ||
| wrapper. This catches bugs where L_op doesn't account for the solver | ||
| only reading one triangle of A.""" | ||
| rng = np.random.default_rng(utt.fetch_seed()) | ||
|
|
||
| A_raw = rng.normal(size=(5, 5)).astype(config.floatX) | ||
| if assume_a == "pos": | ||
| A_val = (A_raw @ A_raw.T + 5 * np.eye(5)).astype(config.floatX) | ||
| else: | ||
| A_val = ((A_raw + A_raw.T) / 2).astype(config.floatX) | ||
| b_val = rng.normal(size=b_shape).astype(config.floatX) | ||
|
|
||
| A = pt.tensor("A", shape=(5, 5)) | ||
| b = pt.tensor("b", shape=b_shape) | ||
| x = solve(A, b, assume_a=assume_a, lower=lower, b_ndim=len(b_shape)) | ||
| loss = x.sum() | ||
| g_A = grad(loss, A) | ||
| f = function([A, b], g_A) | ||
|
|
||
| analytic = f(A_val, b_val) | ||
|
|
||
| # Numerical gradient: perturb only the read triangle | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we concot a graph that allows us to use Something whose input is just the triangular entries? I'm assuming they were being half counted? You can still verify they came out as zeros on an explicit grad fn
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion i've updated the test accordingly The manual finite-difference loop has been replaced with utt.verify_grad using a triangular parameterization of the structured entries the symmetric matrix is reconstructed inside the graph and i’ve also added an explicit assertion that the unread triangle has zero gradients |
||
| eps = 1e-7 | ||
| numerical = np.zeros_like(A_val) | ||
| for i in range(5): | ||
| for j in range(5): | ||
| if lower and j > i: | ||
| continue | ||
| if not lower and j < i: | ||
| continue | ||
| A_plus = A_val.copy() | ||
| A_plus[i, j] += eps | ||
| A_minus = A_val.copy() | ||
| A_minus[i, j] -= eps | ||
| x_plus = scipy_linalg.solve( | ||
| A_plus, b_val, assume_a=assume_a, lower=lower | ||
| ) | ||
| x_minus = scipy_linalg.solve( | ||
| A_minus, b_val, assume_a=assume_a, lower=lower | ||
| ) | ||
| numerical[i, j] = (x_plus.sum() - x_minus.sum()) / (2 * eps) | ||
|
|
||
| np.testing.assert_allclose(analytic, numerical, atol=1e-5, rtol=1e-5) | ||
|
|
||
| def test_solve_tringular_indirection(self): | ||
| a = pt.matrix("a") | ||
| b = pt.vector("b") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise the hermetian case is wrong