Commit ef9be50

docs: Revamp Usage section of README (#593)
* Improve explanation of autojac and autogram
* Use jac_to_grad when needed
1 parent 77309f4 commit ef9be50

1 file changed: README.md

Lines changed: 73 additions & 24 deletions
@@ -56,24 +56,38 @@ Some aggregators may have additional dependencies. Please refer to the
 [installation documentation](https://torchjd.org/stable/installation) for them.
 
 ## Usage
-There are two main ways to use TorchJD. The first one is to replace the usual call to
+
+Compared to standard `torch`, `torchjd` simply changes the way to obtain the `.grad` fields of your
+model parameters.
+
+### Using the `autojac` engine
+
+The `autojac` engine computes and aggregates Jacobians efficiently.
+
+#### 1. `backward` + `jac_to_grad`
+
+In standard `torch`, you generally combine your `losses` into a single scalar `loss`, and call
+`loss.backward()` to compute the gradient of the loss with respect to each model parameter and to
+store it in the `.grad` fields of those parameters. The basic usage of `torchjd` is to replace this
 `loss.backward()` by a call to
-[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) or
-[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/), depending
-on the use-case. This will compute the Jacobian of the vector of losses with respect to the model
-parameters, and aggregate it with the specified
-[`Aggregator`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Aggregator).
-Whenever you want to optimize the vector of per-sample losses, you should rather use the
-[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine/). Instead of
-computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a
-memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from
-this Gramian, using a
-[`Weighting`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Weighting),
-and used to combine the losses of the batch. Assuming each element of the batch is
-processed independently from the others, this approach is equivalent to
-[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) while being
-generally much faster due to the lower memory usage. Note that we're still working on making
-`autogram` faster and more memory-efficient, and its interface may change in future releases.
+[`torchjd.autojac.backward(losses)`](https://torchjd.org/stable/docs/autojac/backward/). Instead of
+computing the gradient of a scalar loss, it will compute the Jacobian of a vector of losses, and
+store it in the `.jac` fields of the model parameters. You then have to call
+[`torchjd.autojac.jac_to_grad`](https://torchjd.org/stable/docs/autojac/jac_to_grad/) to aggregate
+this Jacobian using the specified
+[`Aggregator`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Aggregator), and to
+store the result into the `.grad` fields of the model parameters. See this
+[usage example](https://torchjd.org/stable/examples/basic_usage/) for more details.
+
+#### 2. `mtl_backward` + `jac_to_grad`
+
+In the case of multi-task learning, an alternative to
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) is
+[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/). It computes
+the gradient of each task-specific loss with respect to the corresponding task's parameters, and
+stores it in their `.grad` fields. It also computes the Jacobian of the vector of losses with
+respect to the shared parameters and stores it in their `.jac` fields. Then, the
+[`torchjd.autojac.jac_to_grad`](https://torchjd.org/stable/docs/autojac/jac_to_grad/) function can
+be called to aggregate this Jacobian and replace the `.jac` fields by `.grad` fields for the shared
+parameters.
 
 The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
 using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
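The new text above describes the two-step `backward` + `jac_to_grad` workflow. As a rough conceptual sketch of that mechanism in plain `torch` (not torchjd's actual implementation: the Jacobian is held in a local variable rather than in `.jac` fields, and a plain mean over the rows stands in for a real `Aggregator` such as UPGrad):

```python
import torch

# Toy setup: two losses sharing one parameter vector.
params = torch.randn(3, requires_grad=True)
loss1 = (params ** 2).sum()
loss2 = (params - 1.0).abs().sum()

# Step 1 (roughly what torchjd.autojac.backward(losses) does): compute one
# gradient per loss and stack them into a Jacobian. A local variable stands
# in for the .jac fields here.
rows = [
    torch.autograd.grad(loss, [params], retain_graph=True)[0]
    for loss in (loss1, loss2)
]
jac = torch.stack(rows)  # shape: (num_losses, num_params)

# Step 2 (roughly what jac_to_grad does): aggregate the Jacobian into a
# single gradient and store it in .grad. A real Aggregator (e.g. UPGrad)
# is smarter than this plain mean.
params.grad = jac.mean(dim=0)
```

The optimizer then consumes `params.grad` exactly as it would after a standard `loss.backward()`.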
@@ -83,7 +97,7 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 from torch.nn import Linear, MSELoss, ReLU, Sequential
 from torch.optim import SGD
 
-+ from torchjd.autojac import mtl_backward
++ from torchjd.autojac import jac_to_grad, mtl_backward
 + from torchjd.aggregation import UPGrad
 
 shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
@@ -112,7 +126,8 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
 - loss = loss1 + loss2
 - loss.backward()
-+ mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
++ mtl_backward([loss1, loss2], features=features)
++ jac_to_grad(shared_module.parameters(), aggregator)
 optimizer.step()
 optimizer.zero_grad()
 ```
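As a conceptual sketch of what `mtl_backward` computes in the hunk above, here is the same idea in plain `torch` with toy parameters (again, the shared Jacobian is held in a local variable rather than a `.jac` field, and a mean stands in for the aggregator):

```python
import torch

# Tiny multi-task setup: shared parameters feed two task-specific heads.
shared = torch.randn(4, requires_grad=True)
head1 = torch.randn(4, requires_grad=True)
head2 = torch.randn(4, requires_grad=True)

features = shared * 2.0  # shared representation
loss1 = (features * head1).sum() ** 2
loss2 = (features * head2).sum() ** 2

# Task-specific parameters get the plain gradient of their own task's loss,
# stored directly in .grad, as mtl_backward does.
head1.grad = torch.autograd.grad(loss1, [head1], retain_graph=True)[0]
head2.grad = torch.autograd.grad(loss2, [head2], retain_graph=True)[0]

# Shared parameters get the Jacobian of the loss vector (what mtl_backward
# stores in their .jac field).
jac_shared = torch.stack([
    torch.autograd.grad(loss, [shared], retain_graph=True)[0]
    for loss in (loss1, loss2)
])

# jac_to_grad would then aggregate jac_shared (e.g. with UPGrad) into
# shared.grad; a mean stands in for the aggregator here.
shared.grad = jac_shared.mean(dim=0)
```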
@@ -121,8 +136,42 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 > In this example, the Jacobian is only with respect to the shared parameters. The task-specific
 > parameters are simply updated via the gradient of their task’s loss with respect to them.
 
-The following example shows how to use TorchJD to minimize the vector of per-instance losses with
-Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
+> [!TIP]
+> Once your model parameters all have a `.grad` field, it's the role of the
+> [optimizer](https://docs.pytorch.org/docs/stable/optim.html#torch.optim.Optimizer) to update the
+> parameter values. This is exactly the same as in standard `torch`.
+
+#### 3. `jac`
+
+If you're simply interested in computing Jacobians without storing them in the `.jac` fields, you
+can also use the [`torchjd.autojac.jac`](https://torchjd.org/stable/docs/autojac/jac/) function,
+which is analogous to
+[`torch.autograd.grad`](https://docs.pytorch.org/docs/stable/generated/torch.autograd.grad.html),
+except that it computes the Jacobian of a vector of losses rather than the gradient of a scalar
+loss.
+
+### Using the `autogram` engine
+
+The Gramian of the Jacobian, defined as the Jacobian multiplied by its transpose, contains all the
+dot products between individual gradients. It thus contains all the information about conflict and
+gradient imbalance. It turns out that most aggregators from the literature
+(e.g. [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/)) make a linear combination of
+the rows of the Jacobian, whose weights only depend on the Gramian of the Jacobian.
+
+An alternative implementation of Jacobian descent is thus to:
+- Compute this Gramian incrementally (layer by layer), without ever storing the full Jacobian in
+  memory.
+- Extract the weights from it using a
+  [`Weighting`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Weighting).
+- Combine the losses using those weights and make a step of gradient descent on the combined loss.
+
+The main advantage of this approach is that it saves memory: the Jacobian (which is typically
+large) never has to be stored. The
+[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine/) is precisely made to
+compute the Gramian of the Jacobian efficiently.
+
+The following example shows how to use the `autogram` engine to minimize the vector of per-instance
+losses with Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
 ```diff
 import torch
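The new section above states that many aggregators combine the rows of the Jacobian with weights that depend only on its Gramian. A small self-contained sketch of that observation in plain `torch` (the inverse-squared-norm weighting below is an illustrative stand-in, not one of torchjd's actual `Weighting`s):

```python
import torch

# A small Jacobian: one row per loss, one column per parameter.
jac = torch.tensor([[1.0, 0.0, 2.0],
                    [0.0, 3.0, -1.0]])

# The Gramian holds every pairwise dot product between the rows (i.e.
# between the individual gradients), so it captures conflict (negative
# off-diagonal entries) and imbalance (diagonal norms) without needing
# the Jacobian itself.
gramian = jac @ jac.T
assert torch.allclose(gramian[0, 1], jac[0] @ jac[1])

# A Weighting maps the Gramian to one weight per loss; normalized
# inverse squared norms stand in for a real Weighting here.
weights = 1.0 / gramian.diag()
weights = weights / weights.sum()

# The aggregated gradient is the weighted combination of the rows, which
# is also the gradient of the correspondingly weighted sum of the losses.
# This is why storing only the (much smaller) Gramian is enough.
aggregated = weights @ jac
```

Since `gramian` has shape `(num_losses, num_losses)` while `jac` has one column per model parameter, working from the Gramian is what lets the engine avoid materializing the full Jacobian.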
@@ -157,8 +206,8 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgr
 optimizer.zero_grad()
 ```
 
-Lastly, you can even combine the two approaches by considering multiple tasks and each element of
-the batch independently. We call that Instance-Wise Multitask Learning (IWMTL).
+You can even go one step further by considering multiple tasks and each element of the batch
+independently. We call that Instance-Wise Multitask Learning (IWMTL).
 
 ```python
 import torch
import torch
@@ -207,7 +256,7 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
 ```
 
 > [!NOTE]
-> Here, because the losses are a matrix instead of a simple vector, we compute a *generalized
+> Here, because the losses are a matrix instead of a simple vector, we compute a *generalized
 > Gramian* and we extract weights from it using a
 > [GeneralizedWeighting](https://torchjd.org/stable/docs/aggregation/#torchjd.aggregation.GeneralizedWeighting).
 