@@ -298,3 +298,4 @@
-<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_backward.py#L9-L83"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.backward" title="Link to this definition">¶</a></dt>
-<dd><p>Computes the Jacobians of all values in <code class="docutils literal notranslate"><span class="pre">tensors</span></code> with respect to all <code class="docutils literal notranslate"><span class="pre">inputs</span></code> and
-accumulates them in the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of the <code class="docutils literal notranslate"><span class="pre">inputs</span></code>.</p>
+<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">jac_tensors</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_backward.py#L16-L121"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.backward" title="Link to this definition">¶</a></dt>
+<dd><p>Computes the Jacobians of <code class="docutils literal notranslate"><span class="pre">tensors</span></code> with respect to <code class="docutils literal notranslate"><span class="pre">inputs</span></code>, left-multiplied by
+<code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> (or identity if <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> is <code class="docutils literal notranslate"><span class="pre">None</span></code>), and accumulates the results in the
+<code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of the <code class="docutils literal notranslate"><span class="pre">inputs</span></code>.</p>
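The left-multiplication semantics added in this diff can be sketched in plain Python, with no TorchJD or PyTorch dependency. This is an illustrative model only, not the library's implementation: `accumulate_jac` is a hypothetical helper that mimics what the docstring describes, namely accumulating `V @ J` into an input's `.jac` field, where `J` is the Jacobian of `tensors` with respect to that input and `V` plays the role of `jac_tensors` (identity when `None`).

```python
# Pure-Python sketch (no TorchJD/PyTorch) of the documented semantics:
# backward accumulates V @ J into an input's .jac field, where J is the
# Jacobian of `tensors` w.r.t. that input and V is `jac_tensors`
# (the identity matrix when jac_tensors is None).

def matmul(a, b):
    """Multiply two matrices represented as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def identity(n):
    """n-by-n identity matrix as a list of rows."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]

def accumulate_jac(jac_field, J, V=None):
    """Return jac_field + V @ J; V defaults to the identity (V=None)."""
    if V is None:
        V = identity(len(J))
    VJ = matmul(V, J)
    if jac_field is None:          # .jac not populated yet
        return VJ
    return [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(jac_field, VJ)]

# Jacobian of two scalar outputs (e.g. two losses) w.r.t. a 3-element input:
J = [[2.0, 0.0, 1.0],
     [0.0, 3.0, 1.0]]

# With the default V=None, .jac receives the standard Jacobian J itself:
assert accumulate_jac(None, J) == J

# A custom V left-multiplies J; here a single row summing both losses:
V = [[1.0, 1.0]]
assert accumulate_jac(None, J, V) == [[2.0, 3.0, 2.0]]
```

This mirrors the docstring's claim that passing `None` is equivalent to an identity `jac_tensors`, in which case the standard Jacobian is accumulated.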
+parameter of <cite>torch.autograd.backward</cite>. If provided, it must have the same structure as
+<code class="docutils literal notranslate"><span class="pre">tensors</span></code> and each tensor in <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> must match the shape of the corresponding
+tensor in <code class="docutils literal notranslate"><span class="pre">tensors</span></code>, with an extra leading dimension representing the number of rows of
+the resulting Jacobian (e.g. the number of losses). All tensors in <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> must
+have the same first dimension. If <code class="docutils literal notranslate"><span class="pre">None</span></code>, defaults to the identity matrix. In this case,
+the standard Jacobian of <code class="docutils literal notranslate"><span class="pre">tensors</span></code> is computed, with one row for each value in the
 <li><p><strong>inputs</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Iterable</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The tensors with respect to which the Jacobians must be computed. These must have
 their <code class="docutils literal notranslate"><span class="pre">requires_grad</span></code> flag set to <code class="docutils literal notranslate"><span class="pre">True</span></code>. If not provided, defaults to the leaf tensors
 that were used to compute the <code class="docutils literal notranslate"><span class="pre">tensors</span></code> parameter.</p></li>
@@ -321,7 +329,7 @@ <h1>backward<a class="headerlink" href="#backward" title="Link to this heading">
 </dl>
 <div class="admonition-example admonition">
 <p class="admonition-title">Example</p>
-<p>The following code snippet showcases a simple usage of <code class="docutils literal notranslate"><span class="pre">backward</span></code>.</p>
+<p>This example shows a simple usage of <code class="docutils literal notranslate"><span class="pre">backward</span></code>.</p>
@@ -341,6 +349,32 @@ <h1>backward<a class="headerlink" href="#backward" title="Link to this heading">
 <p>The <code class="docutils literal notranslate"><span class="pre">.jac</span></code> field of <code class="docutils literal notranslate"><span class="pre">param</span></code> now contains the Jacobian of
 <span class="math notranslate nohighlight">\(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\)</span> with respect to <code class="docutils literal notranslate"><span class="pre">param</span></code>.</p>
 </div>
+<div class="admonition-example admonition">
+<p class="admonition-title">Example</p>
+<p>This is the same example as before, except that we explicitly specify <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> as
+the rows of the identity matrix (which is equivalent to using the default <code class="docutils literal notranslate"><span class="pre">None</span></code>).</p>
+<p>Instead of using the identity <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code>, you can backpropagate some Jacobians obtained
+by a call to <a class="reference internal" href="../jac/#torchjd.autojac.jac" title="torchjd.autojac.jac"><code class="xref py py-func docutils literal notranslate"><span class="pre">torchjd.autojac.jac()</span></code></a> on a later part of the computation graph.</p>
+</div>
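The chaining described in that last paragraph is just the chain rule, and can be sketched in plain Python without the library. Suppose losses = g(z) and z = f(x); the Jacobians `V` and `J_inner` below are hypothetical values chosen for illustration. `V` stands for the Jacobian of the losses with respect to the intermediate tensors (the kind of result `jac()` would produce on the later segment); passing it as `jac_tensors` to `backward` on the earlier segment yields `V @ J_inner`, the full Jacobian of the losses with respect to the inputs.

```python
# Pure-Python chain-rule sketch (no TorchJD) of two-stage backpropagation.
# With losses = g(z) and z = f(x):
#   d(losses)/dx = d(losses)/dz @ dz/dx = V @ J_inner.

def matmul(a, b):
    """Multiply two matrices represented as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

# Hypothetical Jacobians for illustration:
V = [[1.0, 0.0],            # d(losses)/dz: 2 losses, z has 2 components
     [0.5, 0.5]]
J_inner = [[2.0, 1.0, 0.0], # dz/dx: z has 2 components, x has 3
           [0.0, 1.0, 3.0]]

# What backward(z, jac_tensors=V, inputs=x) would accumulate into x.jac:
full_jacobian = matmul(V, J_inner)
assert full_jacobian == [[2.0, 1.0, 0.0],
                         [1.0, 1.0, 1.5]]
```

Note how each row of `V` produces one row of the full Jacobian, consistent with the shape rule above: the leading dimension of `jac_tensors` becomes the number of rows of the result.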
 <div class="admonition warning">
 <p class="admonition-title">Warning</p>
 <p>To differentiate in parallel, <code class="docutils literal notranslate"><span class="pre">backward</span></code> relies on <code class="docutils literal notranslate"><span class="pre">torch.vmap</span></code>, which has some