
Commit 4cc525a
1 parent ba5ce6a

4 files changed: 83 additions & 15 deletions

File tree

latest/docs/autojac/backward/index.html

Lines changed: 40 additions & 6 deletions
@@ -295,14 +295,22 @@
 <h1>backward<a class="headerlink" href="#backward" title="Link to this heading"></a></h1>
 <dl class="py function">
 <dt class="sig sig-object py" id="torchjd.autojac.backward">
-<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_backward.py#L9-L83"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.backward" title="Link to this definition"></a></dt>
-<dd><p>Computes the Jacobians of all values in <code class="docutils literal notranslate"><span class="pre">tensors</span></code> with respect to all <code class="docutils literal notranslate"><span class="pre">inputs</span></code> and
-accumulates them in the <code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of the <code class="docutils literal notranslate"><span class="pre">inputs</span></code>.</p>
+<span class="sig-prename descclassname"><span class="pre">torchjd.autojac.</span></span><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">jac_tensors</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retain_graph</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parallel_chunk_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/autojac/_backward.py#L16-L121"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.autojac.backward" title="Link to this definition"></a></dt>
+<dd><p>Computes the Jacobians of <code class="docutils literal notranslate"><span class="pre">tensors</span></code> with respect to <code class="docutils literal notranslate"><span class="pre">inputs</span></code>, left-multiplied by
+<code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> (or identity if <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> is <code class="docutils literal notranslate"><span class="pre">None</span></code>), and accumulates the results in the
+<code class="docutils literal notranslate"><span class="pre">.jac</span></code> fields of the <code class="docutils literal notranslate"><span class="pre">inputs</span></code>.</p>
 <dl class="field-list simple">
 <dt class="field-odd">Parameters<span class="colon">:</span></dt>
 <dd class="field-odd"><ul class="simple">
-<li><p><strong>tensors</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a></span>) – The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
-have one row for each value of each of these tensors.</p></li>
+<li><p><strong>tensors</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a></span>) – The tensor or tensors to differentiate. Should be non-empty.</p></li>
+<li><p><strong>jac_tensors</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The initial Jacobians to backpropagate, analogous to the <cite>grad_tensors</cite>
+parameter of <cite>torch.autograd.backward</cite>. If provided, it must have the same structure as
+<code class="docutils literal notranslate"><span class="pre">tensors</span></code>, and each tensor in <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> must match the shape of the corresponding
+tensor in <code class="docutils literal notranslate"><span class="pre">tensors</span></code>, with an extra leading dimension representing the number of rows of
+the resulting Jacobian (e.g. the number of losses). All tensors in <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> must
+have the same first dimension. If <code class="docutils literal notranslate"><span class="pre">None</span></code>, defaults to the identity matrix. In this case,
+the standard Jacobian of <code class="docutils literal notranslate"><span class="pre">tensors</span></code> is computed, with one row for each value in the
+<code class="docutils literal notranslate"><span class="pre">tensors</span></code>.</p></li>
 <li><p><strong>inputs</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable" title="(in Python v3.14)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Iterable</span></code></a>[<a class="reference external" href="https://docs.pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.10)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a>] | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.14)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The tensors with respect to which the Jacobians must be computed. These must have
 their <code class="docutils literal notranslate"><span class="pre">requires_grad</span></code> flag set to <code class="docutils literal notranslate"><span class="pre">True</span></code>. If not provided, defaults to the leaf tensors
 that were used to compute the <code class="docutils literal notranslate"><span class="pre">tensors</span></code> parameter.</p></li>
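The semantics documented in the changed docstring above can be checked with plain `torch.autograd`: with the default `jac_tensors=None`, each value of `tensors` contributes one Jacobian row, i.e. one vector-Jacobian product against a row of the identity. This is only an illustrative sketch of the math, not TorchJD's implementation (which batches these products).

```python
import torch

# Same setup as the documented example: two scalar outputs of `param`.
param = torch.tensor([1., 2.], requires_grad=True)
y1 = torch.tensor([-1., 1.]) @ param   # dy1/dparam = [-1., 1.]
y2 = (param ** 2).sum()                # dy2/dparam = 2 * param = [2., 4.]

# With jac_tensors=None, each output yields one Jacobian row:
# the VJP of that output against a row of the identity matrix.
rows = [torch.autograd.grad(y, param, retain_graph=True)[0] for y in (y1, y2)]
jacobian = torch.stack(rows)
print(jacobian)  # tensor([[-1., 1.], [ 2., 4.]])
```

This is the same matrix that `backward` would accumulate into `param.jac` in the example below.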
@@ -321,7 +329,7 @@ <h1>backward<a class="headerlink" href="#backward" title="Link to this heading">
 </dl>
 <div class="admonition-example admonition">
 <p class="admonition-title">Example</p>
-<p>The following code snippet showcases a simple usage of <code class="docutils literal notranslate"><span class="pre">backward</span></code>.</p>
+<p>This example shows a simple usage of <code class="docutils literal notranslate"><span class="pre">backward</span></code>.</p>
 <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
 <span class="gp">&gt;&gt;&gt;</span>
 <span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.autojac</span><span class="w"> </span><span class="kn">import</span> <span class="n">backward</span>
@@ -341,6 +349,32 @@ <h1>backward<a class="headerlink" href="#backward" title="Link to this heading">
 <p>The <code class="docutils literal notranslate"><span class="pre">.jac</span></code> field of <code class="docutils literal notranslate"><span class="pre">param</span></code> now contains the Jacobian of
 <span class="math notranslate nohighlight">\(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\)</span> with respect to <code class="docutils literal notranslate"><span class="pre">param</span></code>.</p>
 </div>
+<div class="admonition-example admonition">
+<p class="admonition-title">Example</p>
+<p>This is the same example as before, except that we explicitly specify <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code> as
+the rows of the identity matrix (which is equivalent to using the default <code class="docutils literal notranslate"><span class="pre">None</span></code>).</p>
+<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.autojac</span><span class="w"> </span><span class="kn">import</span> <span class="n">backward</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">param</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">2.</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Compute arbitrary quantities that are functions of param</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">y1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">])</span> <span class="o">@</span> <span class="n">param</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">y2</span> <span class="o">=</span> <span class="p">(</span><span class="n">param</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">J1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">])</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">J2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">])</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">backward</span><span class="p">([</span><span class="n">y1</span><span class="p">,</span> <span class="n">y2</span><span class="p">],</span> <span class="n">jac_tensors</span><span class="o">=</span><span class="p">[</span><span class="n">J1</span><span class="p">,</span> <span class="n">J2</span><span class="p">])</span>
+<span class="gp">&gt;&gt;&gt;</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">param</span><span class="o">.</span><span class="n">jac</span>
+<span class="go">tensor([[-1., 1.],</span>
+<span class="go">        [ 2., 4.]])</span>
+</pre></div>
+</div>
+<p>Instead of using the identity <code class="docutils literal notranslate"><span class="pre">jac_tensors</span></code>, you can backpropagate Jacobians obtained
+by a call to <a class="reference internal" href="../jac/#torchjd.autojac.jac" title="torchjd.autojac.jac"><code class="xref py py-func docutils literal notranslate"><span class="pre">torchjd.autojac.jac()</span></code></a> on a later part of the computation graph.</p>
+</div>
 <div class="admonition warning">
 <p class="admonition-title">Warning</p>
 <p>To differentiate in parallel, <code class="docutils literal notranslate"><span class="pre">backward</span></code> relies on <code class="docutils literal notranslate"><span class="pre">torch.vmap</span></code>, which has some
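For a non-identity `jac_tensors`, the accumulated result is the given matrix left-multiplied with the Jacobian, i.e. row `i` is the VJP with weights `(J1[i], J2[i])`. The sketch below illustrates that math with plain `torch.autograd.grad` and its `grad_outputs` argument; the values of `J1` and `J2` are made up for illustration and differ from the identity rows used in the doctest above.

```python
import torch

param = torch.tensor([1., 2.], requires_grad=True)
y1 = torch.tensor([-1., 1.]) @ param   # Jacobian row [-1., 1.]
y2 = (param ** 2).sum()                # Jacobian row [ 2., 4.]

# Hypothetical non-identity jac_tensors: each matches the (scalar) shape
# of its output plus a leading dimension of size 2 (two result rows).
J1 = torch.tensor([2., 0.])
J2 = torch.tensor([0., 3.])

# Row i of the result is the VJP of (y1, y2) weighted by (J1[i], J2[i]):
# result = [[J1[0], J2[0]], [J1[1], J2[1]]] @ [dy1/dparam; dy2/dparam]
rows = [
    torch.autograd.grad([y1, y2], param,
                        grad_outputs=[J1[i], J2[i]],
                        retain_graph=True)[0]
    for i in range(2)
]
result = torch.stack(rows)
print(result)  # tensor([[-2., 2.], [ 6., 12.]])
```

Row 0 is `2 * [-1., 1.] = [-2., 2.]` and row 1 is `3 * [2., 4.] = [6., 12.]`, matching the left-multiplication described in the `jac_tensors` parameter documentation.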