Skip to content

Commit 8862d13

Browse files
refactor(autojac): Make jac match autograd.grad (#581)
1 parent db59ec6 commit 8862d13

3 files changed

Lines changed: 34 additions & 40 deletions

File tree

src/torchjd/autojac/_jac.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Iterable, Sequence
1+
from collections.abc import Sequence
22

33
from torch import Tensor
44

@@ -13,13 +13,12 @@
1313
check_matching_jac_shapes,
1414
check_matching_length,
1515
check_optional_positive_chunk_size,
16-
get_leaf_tensors,
1716
)
1817

1918

2019
def jac(
2120
outputs: Sequence[Tensor] | Tensor,
22-
inputs: Iterable[Tensor] | None = None,
21+
inputs: Sequence[Tensor] | Tensor,
2322
*,
2423
jac_outputs: Sequence[Tensor] | Tensor | None = None,
2524
retain_graph: bool = False,
@@ -32,9 +31,8 @@ def jac(
3231
``[m] + t.shape``.
3332
3433
:param outputs: The tensor or tensors to differentiate. Should be non-empty.
35-
:param inputs: The tensors with respect to which the Jacobian must be computed. These must have
36-
their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
37-
that were used to compute the ``outputs`` parameter.
34+
:param inputs: The tensor or tensors with respect to which the Jacobian must be computed. These
35+
must have their ``requires_grad`` flag set to ``True``.
3836
:param jac_outputs: The initial Jacobians to backpropagate, analog to the ``grad_outputs``
3937
parameter of :func:`torch.autograd.grad`. If provided, it must have the same structure as
4038
``outputs`` and each tensor in ``jac_outputs`` must match the shape of the corresponding
@@ -69,7 +67,7 @@ def jac(
6967
>>> y1 = torch.tensor([-1., 1.]) @ param
7068
>>> y2 = (param ** 2).sum()
7169
>>>
72-
>>> jacobians = jac([y1, y2], [param])
70+
>>> jacobians = jac([y1, y2], param)
7371
>>>
7472
>>> jacobians
7573
(tensor([[-1., 1.],
@@ -131,13 +129,13 @@ def jac(
131129
>>> jac_h = jac([y1, y2], [h])[0] # Shape: [2, 2]
132130
>>>
133131
>>> # Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
134-
>>> jac_x = jac(h, [x], jac_outputs=jac_h)[0]
132+
>>> jac_x = jac(h, x, jac_outputs=jac_h)[0]
135133
>>>
136134
>>> jac_x
137135
tensor([[ 2., 4.],
138136
[ 2., -4.]])
139137
140-
This two-step computation is equivalent to directly computing ``jac([y1, y2], [x])``.
138+
This two-step computation is equivalent to directly computing ``jac([y1, y2], x)``.
141139
142140
.. warning::
143141
To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
@@ -155,12 +153,9 @@ def jac(
155153
if len(outputs_) == 0:
156154
raise ValueError("`outputs` cannot be empty.")
157155

158-
if inputs is None:
159-
inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
160-
inputs_with_repetition = list(inputs_)
161-
else:
162-
inputs_with_repetition = list(inputs) # Create a list to avoid emptying generator
163-
inputs_ = OrderedSet(inputs_with_repetition)
156+
# Preserve repetitions to duplicate jacobians at the return statement
157+
inputs_with_repetition = (inputs,) if isinstance(inputs, Tensor) else inputs
158+
inputs_ = OrderedSet(inputs_with_repetition)
164159

165160
jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
166161
transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph)

tests/doc/test_jac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_jac():
1414
# Compute arbitrary quantities that are function of param
1515
y1 = torch.tensor([-1.0, 1.0]) @ param
1616
y2 = (param**2).sum()
17-
jacobians = jac([y1, y2], [param])
17+
jacobians = jac([y1, y2], param)
1818

1919
assert len(jacobians) == 1
2020
assert_close(jacobians[0], torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04)
@@ -57,6 +57,6 @@ def test_jac_3():
5757
# Step 1: Compute d[y1,y2]/dh
5858
jac_h = jac([y1, y2], [h])[0] # Shape: [2, 2]
5959
# Step 2: Use jac_outputs to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
60-
jac_x = jac(h, [x], jac_outputs=jac_h)[0]
60+
jac_x = jac(h, x, jac_outputs=jac_h)[0]
6161

6262
assert_close(jac_x, torch.tensor([[2.0, 4.0], [2.0, -4.0]]), rtol=0.0, atol=1e-04)

tests/unit/autojac/test_jac.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,15 @@ def test_jac():
5656
assert jacobian.shape[1:] == a.shape
5757

5858

59-
@mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)])
60-
@mark.parametrize("manually_specify_inputs", [True, False])
59+
@mark.parametrize("shape", [(1, 1), (1, 3), (2, 1), (2, 6), (20, 55)])
6160
@mark.parametrize("chunk_size", [1, 2, None])
61+
@mark.parametrize("outputs_is_list", [True, False])
62+
@mark.parametrize("inputs_is_list", [True, False])
6263
def test_value_is_correct(
6364
shape: tuple[int, int],
64-
manually_specify_inputs: bool,
6565
chunk_size: int | None,
66+
outputs_is_list: bool,
67+
inputs_is_list: bool,
6668
):
6769
"""
6870
Tests that the jacobians returned by jac are correct in a simple example of matrix-vector
@@ -73,13 +75,10 @@ def test_value_is_correct(
7375
input = randn_([shape[1]], requires_grad=True)
7476
output = J @ input # Note that the Jacobian of output w.r.t. input is J.
7577

76-
inputs = [input] if manually_specify_inputs else None
78+
outputs = [output] if outputs_is_list else output
79+
inputs = [input] if inputs_is_list else input
7780

78-
jacobians = jac(
79-
[output],
80-
inputs=inputs,
81-
parallel_chunk_size=chunk_size,
82-
)
81+
jacobians = jac(outputs, inputs, parallel_chunk_size=chunk_size)
8382

8483
assert len(jacobians) == 1
8584
assert_close(jacobians[0], J)
@@ -103,7 +102,7 @@ def test_jac_outputs_value_is_correct(rows: int):
103102

104103
jacobians = jac(
105104
output,
106-
inputs=[input],
105+
input,
107106
jac_outputs=J_init,
108107
)
109108

@@ -126,7 +125,7 @@ def test_jac_outputs_multiple_components(rows: int):
126125
J1 = randn_((rows, 2))
127126
J2 = randn_((rows, 3))
128127

129-
jacobians = jac([y1, y2], inputs=[input], jac_outputs=[J1, J2])
128+
jacobians = jac([y1, y2], input, jac_outputs=[J1, J2])
130129

131130
jac_y1 = eye_(2) * 2
132131

@@ -149,7 +148,7 @@ def test_jac_outputs_length_mismatch():
149148
ValueError,
150149
match=r"`jac_outputs` should have the same length as `outputs`\. \(got 1 and 2\)",
151150
):
152-
jac([y1, y2], inputs=[x], jac_outputs=[J1])
151+
jac([y1, y2], x, jac_outputs=[J1])
153152

154153

155154
def test_jac_outputs_shape_mismatch():
@@ -166,7 +165,7 @@ def test_jac_outputs_shape_mismatch():
166165
ValueError,
167166
match=r"Shape mismatch: `jac_outputs\[0\]` has shape .* but `outputs\[0\]` has shape .*\.",
168167
):
169-
jac(y, inputs=[x], jac_outputs=J_bad)
168+
jac(y, x, jac_outputs=J_bad)
170169

171170

172171
@mark.parametrize(
@@ -192,7 +191,7 @@ def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int):
192191
with raises(
193192
ValueError, match=r"All Jacobians in `jac_outputs` should have the same number of rows\."
194193
):
195-
jac([y1, y2], inputs=[x], jac_outputs=[j1, j2])
194+
jac([y1, y2], x, jac_outputs=[j1, j2])
196195

197196

198197
def test_empty_inputs():
@@ -220,7 +219,7 @@ def test_partial_inputs():
220219
y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
221220
y2 = (a1**2).sum() + a2.norm()
222221

223-
jacobians = jac([y1, y2], inputs=[a1])
222+
jacobians = jac([y1, y2], a1)
224223
assert len(jacobians) == 1
225224

226225

@@ -250,7 +249,7 @@ def test_multiple_tensors():
250249
y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
251250
y2 = (a1**2).sum() + a2.norm()
252251

253-
jacobians = jac([y1, y2])
252+
jacobians = jac([y1, y2], [a1, a2])
254253
assert len(jacobians) == 2
255254
assert_close(jacobians[0], J1)
256255
assert_close(jacobians[1], J2)
@@ -262,7 +261,7 @@ def test_multiple_tensors():
262261
z1 = tensor_([-1.0, 1.0]) @ b1 + b2.sum()
263262
z2 = (b1**2).sum() + b2.norm()
264263

265-
jacobians = jac(torch.cat([z1.reshape(-1), z2.reshape(-1)]))
264+
jacobians = jac(torch.cat([z1.reshape(-1), z2.reshape(-1)]), [b1, b2])
266265
assert len(jacobians) == 2
267266
assert_close(jacobians[0], J1)
268267
assert_close(jacobians[1], J2)
@@ -278,7 +277,7 @@ def test_various_valid_chunk_sizes(chunk_size):
278277
y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
279278
y2 = (a1**2).sum() + a2.norm()
280279

281-
jacobians = jac([y1, y2], parallel_chunk_size=chunk_size)
280+
jacobians = jac([y1, y2], [a1, a2], parallel_chunk_size=chunk_size)
282281
assert len(jacobians) == 2
283282

284283

@@ -293,7 +292,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int):
293292
y2 = (a1**2).sum() + a2.norm()
294293

295294
with raises(ValueError):
296-
jac([y1, y2], parallel_chunk_size=chunk_size)
295+
jac([y1, y2], [a1, a2], parallel_chunk_size=chunk_size)
297296

298297

299298
def test_input_retaining_grad_fails():
@@ -309,7 +308,7 @@ def test_input_retaining_grad_fails():
309308

310309
# jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor (and it also
311310
# returns the correct Jacobian)
312-
jac(y, inputs=[b])
311+
jac(y, b)
313312

314313
with raises(RuntimeError):
315314
# Using such a BatchedTensor should result in an error
@@ -328,7 +327,7 @@ def test_non_input_retaining_grad_fails():
328327
y = 3 * b
329328

330329
# jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor
331-
jac(y, inputs=[a])
330+
jac(y, a)
332331

333332
with raises(RuntimeError):
334333
# Using such a BatchedTensor should result in an error
@@ -348,7 +347,7 @@ def test_tensor_used_multiple_times(chunk_size: int | None):
348347
d = a * c
349348
e = a * d
350349

351-
jacobians = jac([d, e], parallel_chunk_size=chunk_size)
350+
jacobians = jac([d, e], a, parallel_chunk_size=chunk_size)
352351
assert len(jacobians) == 1
353352

354353
J = tensor_([2.0 * 3.0 * (a**2).item(), 2.0 * 4.0 * (a**3).item()])
@@ -372,7 +371,7 @@ def test_repeated_tensors():
372371
y2 = (a1**2).sum() + (a2**2).sum()
373372

374373
with raises(ValueError):
375-
jac([y1, y1, y2])
374+
jac([y1, y1, y2], [a1, a2])
376375

377376

378377
def test_repeated_inputs():

0 commit comments

Comments (0)