@@ -56,13 +56,15 @@ def test_jac():
5656 assert jacobian .shape [1 :] == a .shape
5757
5858
59- @mark .parametrize ("shape" , [(1 , 3 ), (2 , 3 ), (2 , 6 ), (5 , 8 ), (20 , 55 )])
60- @mark .parametrize ("manually_specify_inputs" , [True , False ])
59+ @mark .parametrize ("shape" , [(1 , 1 ), (1 , 3 ), (2 , 1 ), (2 , 6 ), (20 , 55 )])
6160@mark .parametrize ("chunk_size" , [1 , 2 , None ])
61+ @mark .parametrize ("outputs_is_list" , [True , False ])
62+ @mark .parametrize ("inputs_is_list" , [True , False ])
6263def test_value_is_correct (
6364 shape : tuple [int , int ],
64- manually_specify_inputs : bool ,
6565 chunk_size : int | None ,
66+ outputs_is_list : bool ,
67+ inputs_is_list : bool ,
6668):
6769 """
6870 Tests that the jacobians returned by jac are correct in a simple example of matrix-vector
@@ -73,13 +75,10 @@ def test_value_is_correct(
7375 input = randn_ ([shape [1 ]], requires_grad = True )
7476 output = J @ input # Note that the Jacobian of output w.r.t. input is J.
7577
76- inputs = [input ] if manually_specify_inputs else None
78+ outputs = [output ] if outputs_is_list else output
79+ inputs = [input ] if inputs_is_list else input
7780
78- jacobians = jac (
79- [output ],
80- inputs = inputs ,
81- parallel_chunk_size = chunk_size ,
82- )
81+ jacobians = jac (outputs , inputs , parallel_chunk_size = chunk_size )
8382
8483 assert len (jacobians ) == 1
8584 assert_close (jacobians [0 ], J )
@@ -103,7 +102,7 @@ def test_jac_outputs_value_is_correct(rows: int):
103102
104103 jacobians = jac (
105104 output ,
106- inputs = [ input ] ,
105+ input ,
107106 jac_outputs = J_init ,
108107 )
109108
@@ -126,7 +125,7 @@ def test_jac_outputs_multiple_components(rows: int):
126125 J1 = randn_ ((rows , 2 ))
127126 J2 = randn_ ((rows , 3 ))
128127
129- jacobians = jac ([y1 , y2 ], inputs = [ input ] , jac_outputs = [J1 , J2 ])
128+ jacobians = jac ([y1 , y2 ], input , jac_outputs = [J1 , J2 ])
130129
131130 jac_y1 = eye_ (2 ) * 2
132131
@@ -149,7 +148,7 @@ def test_jac_outputs_length_mismatch():
149148 ValueError ,
150149 match = r"`jac_outputs` should have the same length as `outputs`\. \(got 1 and 2\)" ,
151150 ):
152- jac ([y1 , y2 ], inputs = [ x ] , jac_outputs = [J1 ])
151+ jac ([y1 , y2 ], x , jac_outputs = [J1 ])
153152
154153
155154def test_jac_outputs_shape_mismatch ():
@@ -166,7 +165,7 @@ def test_jac_outputs_shape_mismatch():
166165 ValueError ,
167166 match = r"Shape mismatch: `jac_outputs\[0\]` has shape .* but `outputs\[0\]` has shape .*\." ,
168167 ):
169- jac (y , inputs = [ x ] , jac_outputs = J_bad )
168+ jac (y , x , jac_outputs = J_bad )
170169
171170
172171@mark .parametrize (
@@ -192,7 +191,7 @@ def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int):
192191 with raises (
193192 ValueError , match = r"All Jacobians in `jac_outputs` should have the same number of rows\."
194193 ):
195- jac ([y1 , y2 ], inputs = [ x ] , jac_outputs = [j1 , j2 ])
194+ jac ([y1 , y2 ], x , jac_outputs = [j1 , j2 ])
196195
197196
198197def test_empty_inputs ():
@@ -220,7 +219,7 @@ def test_partial_inputs():
220219 y1 = tensor_ ([- 1.0 , 1.0 ]) @ a1 + a2 .sum ()
221220 y2 = (a1 ** 2 ).sum () + a2 .norm ()
222221
223- jacobians = jac ([y1 , y2 ], inputs = [ a1 ] )
222+ jacobians = jac ([y1 , y2 ], a1 )
224223 assert len (jacobians ) == 1
225224
226225
@@ -250,7 +249,7 @@ def test_multiple_tensors():
250249 y1 = tensor_ ([- 1.0 , 1.0 ]) @ a1 + a2 .sum ()
251250 y2 = (a1 ** 2 ).sum () + a2 .norm ()
252251
253- jacobians = jac ([y1 , y2 ])
252+ jacobians = jac ([y1 , y2 ], [ a1 , a2 ] )
254253 assert len (jacobians ) == 2
255254 assert_close (jacobians [0 ], J1 )
256255 assert_close (jacobians [1 ], J2 )
@@ -262,7 +261,7 @@ def test_multiple_tensors():
262261 z1 = tensor_ ([- 1.0 , 1.0 ]) @ b1 + b2 .sum ()
263262 z2 = (b1 ** 2 ).sum () + b2 .norm ()
264263
265- jacobians = jac (torch .cat ([z1 .reshape (- 1 ), z2 .reshape (- 1 )]))
264+ jacobians = jac (torch .cat ([z1 .reshape (- 1 ), z2 .reshape (- 1 )]), [ b1 , b2 ] )
266265 assert len (jacobians ) == 2
267266 assert_close (jacobians [0 ], J1 )
268267 assert_close (jacobians [1 ], J2 )
@@ -278,7 +277,7 @@ def test_various_valid_chunk_sizes(chunk_size):
278277 y1 = tensor_ ([- 1.0 , 1.0 ]) @ a1 + a2 .sum ()
279278 y2 = (a1 ** 2 ).sum () + a2 .norm ()
280279
281- jacobians = jac ([y1 , y2 ], parallel_chunk_size = chunk_size )
280+ jacobians = jac ([y1 , y2 ], [ a1 , a2 ], parallel_chunk_size = chunk_size )
282281 assert len (jacobians ) == 2
283282
284283
@@ -293,7 +292,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int):
293292 y2 = (a1 ** 2 ).sum () + a2 .norm ()
294293
295294 with raises (ValueError ):
296- jac ([y1 , y2 ], parallel_chunk_size = chunk_size )
295+ jac ([y1 , y2 ], [ a1 , a2 ], parallel_chunk_size = chunk_size )
297296
298297
299298def test_input_retaining_grad_fails ():
@@ -309,7 +308,7 @@ def test_input_retaining_grad_fails():
309308
310309 # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor (and it also
311310 # returns the correct Jacobian)
312- jac (y , inputs = [ b ] )
311+ jac (y , b )
313312
314313 with raises (RuntimeError ):
315314 # Using such a BatchedTensor should result in an error
@@ -328,7 +327,7 @@ def test_non_input_retaining_grad_fails():
328327 y = 3 * b
329328
330329 # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor
331- jac (y , inputs = [ a ] )
330+ jac (y , a )
332331
333332 with raises (RuntimeError ):
334333 # Using such a BatchedTensor should result in an error
@@ -348,7 +347,7 @@ def test_tensor_used_multiple_times(chunk_size: int | None):
348347 d = a * c
349348 e = a * d
350349
351- jacobians = jac ([d , e ], parallel_chunk_size = chunk_size )
350+ jacobians = jac ([d , e ], a , parallel_chunk_size = chunk_size )
352351 assert len (jacobians ) == 1
353352
354353 J = tensor_ ([2.0 * 3.0 * (a ** 2 ).item (), 2.0 * 4.0 * (a ** 3 ).item ()])
@@ -372,7 +371,7 @@ def test_repeated_tensors():
372371 y2 = (a1 ** 2 ).sum () + (a2 ** 2 ).sum ()
373372
374373 with raises (ValueError ):
375- jac ([y1 , y1 , y2 ])
374+ jac ([y1 , y1 , y2 ], [ a1 , a2 ] )
376375
377376
378377def test_repeated_inputs ():
0 commit comments