@@ -56,24 +56,38 @@ Some aggregators may have additional dependencies. Please refer to the
 [installation documentation](https://torchjd.org/stable/installation) for them.
 
 ## Usage
-There are two main ways to use TorchJD. The first one is to replace the usual call to
+
+Compared to standard `torch`, `torchjd` simply changes how you obtain the `.grad` fields of your
+model parameters.
+
+### Using the `autojac` engine
+
+The `autojac` engine computes and aggregates Jacobians efficiently.
+
+#### 1. `backward` + `jac_to_grad`
+In standard `torch`, you generally combine your `losses` into a single scalar `loss` and call
+`loss.backward()` to compute the gradient of the loss with respect to each model parameter and
+store it in the `.grad` fields of those parameters. The basic usage of `torchjd` is to replace this
 `loss.backward()` by a call to
-[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) or
-[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/), depending
-on the use-case. This will compute the Jacobian of the vector of losses with respect to the model
-parameters, and aggregate it with the specified
-[`Aggregator`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Aggregator).
-Whenever you want to optimize the vector of per-sample losses, you should rather use the
-[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine/). Instead of
-computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a
-memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from
-this Gramian, using a
-[`Weighting`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Weighting),
-and used to combine the losses of the batch. Assuming each element of the batch is
-processed independently from the others, this approach is equivalent to
-[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) while being
-generally much faster due to the lower memory usage. Note that we're still working on making
-`autogram` faster and more memory-efficient, and its interface may change in future releases.
+[`torchjd.autojac.backward(losses)`](https://torchjd.org/stable/docs/autojac/backward/). Instead of
+computing the gradient of a scalar loss, it computes the Jacobian of a vector of losses and
+stores it in the `.jac` fields of the model parameters. You then call
+[`torchjd.autojac.jac_to_grad`](https://torchjd.org/stable/docs/autojac/jac_to_grad/) to aggregate
+this Jacobian using the specified
+[`Aggregator`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Aggregator) and to
+store the result in the `.grad` fields of the model parameters. See this
+[usage example](https://torchjd.org/stable/examples/basic_usage/) for more details.
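The two-step flow described above can be sketched in plain Python. This is a toy stand-in, not the torchjd API: the "Jacobian" is a list of per-loss gradient rows, and the aggregator is a naive row-wise mean rather than a real `Aggregator` such as UPGrad.

```python
# Toy sketch of the backward(losses) + jac_to_grad flow, in plain Python.
# Stand-ins, not the torchjd API: the Jacobian is a list of per-loss
# gradient rows, and the aggregator is a plain row-wise mean.

def toy_backward(per_loss_gradients):
    """Step 1: store the Jacobian in a .jac-like field (one row per loss)."""
    return {"jac": [row[:] for row in per_loss_gradients], "grad": None}

def toy_jac_to_grad(param, aggregate):
    """Step 2: aggregate the .jac rows into a single .grad vector."""
    param["grad"] = aggregate(param["jac"])
    param["jac"] = None  # the Jacobian has been consumed

def mean_rows(jacobian):
    """Column-wise mean of the rows (a very naive aggregator)."""
    return [sum(column) / len(jacobian) for column in zip(*jacobian)]

# Two losses, three parameters: one gradient row per loss.
param = toy_backward([[1.0, 0.0, 2.0], [3.0, 2.0, 0.0]])
toy_jac_to_grad(param, mean_rows)
print(param["grad"])  # [2.0, 1.0, 1.0]
```

The real functions operate on tensors and write into the `.jac`/`.grad` attributes of every model parameter, but the two-phase shape (compute Jacobian, then aggregate it into a gradient) is the same.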
+
+#### 2. `mtl_backward` + `jac_to_grad`
+In the case of multi-task learning, an alternative to
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) is
+[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/). It computes
+the gradient of each task-specific loss with respect to the corresponding task's parameters and
+stores it in their `.grad` fields. It also computes the Jacobian of the vector of losses with
+respect to the shared parameters and stores it in their `.jac` fields. Then, the
+[`torchjd.autojac.jac_to_grad`](https://torchjd.org/stable/docs/autojac/jac_to_grad/) function can
+be called to aggregate this Jacobian and replace the `.jac` fields of the shared parameters with
+`.grad` fields.
 
 The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
 using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
@@ -83,7 +97,7 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 from torch.nn import Linear, MSELoss, ReLU, Sequential
 from torch.optim import SGD
 
-+ from torchjd.autojac import mtl_backward
++ from torchjd.autojac import jac_to_grad, mtl_backward
 + from torchjd.aggregation import UPGrad
 
 shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
@@ -112,7 +126,8 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
 - loss = loss1 + loss2
 - loss.backward()
-+ mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
++ mtl_backward([loss1, loss2], features=features)
++ jac_to_grad(shared_module.parameters(), aggregator)
 optimizer.step()
 optimizer.zero_grad()
 ```
@@ -121,8 +136,42 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 > In this example, the Jacobian is only with respect to the shared parameters. The task-specific
 > parameters are simply updated via the gradient of their task’s loss with respect to them.
 
-The following example shows how to use TorchJD to minimize the vector of per-instance losses with
-Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
+> [!TIP]
+> Once your model parameters all have a `.grad` field, it is the role of the
+> [optimizer](https://docs.pytorch.org/docs/stable/optim.html#torch.optim.Optimizer) to update the
+> parameter values. This is exactly the same as in standard `torch`.
+
+#### 3. `jac`
+
+If you're simply interested in computing Jacobians without storing them in the `.jac` fields, you
+can also use the [`torchjd.autojac.jac`](https://torchjd.org/stable/docs/autojac/jac/) function,
+which is analogous to
+[`torch.autograd.grad`](https://docs.pytorch.org/docs/stable/generated/torch.autograd.grad.html),
+except that it computes the Jacobian of a vector of losses rather than the gradient of a scalar
+loss.
+
+### Using the `autogram` engine
+
+The Gramian of the Jacobian, defined as the Jacobian multiplied by its transpose, contains all the
+dot products between individual gradients. It thus contains all the information about conflict and
+gradient imbalance. It turns out that most aggregators from the literature
+(e.g. [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/)) compute a linear combination of
+the rows of the Jacobian, whose weights depend only on the Gramian of the Jacobian.
+
+An alternative implementation of Jacobian descent is thus to:
+- Compute this Gramian incrementally (layer by layer), without ever storing the full Jacobian in
+  memory.
+- Extract the weights from it using a
+  [`Weighting`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Weighting).
+- Combine the losses using those weights and take a step of gradient descent on the combined loss.
+
+The main advantage of this approach is memory savings: the Jacobian, which is typically large,
+never has to be stored in memory. The
+[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine/) is made precisely to
+compute the Gramian of the Jacobian efficiently.
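Numerically, the idea can be sketched in a few lines of plain Python. This is a toy illustration, not the torchjd API: a 2×3 Jacobian stands in for the real one, and the equal weights stand in for a weight vector a real `Weighting` would extract from the Gramian.

```python
# Toy sketch: the Gramian G = J Jᵀ holds every pairwise dot product between
# per-loss gradients, so weights can be chosen from G alone; combining the
# rows of J with those weights is the aggregation step.

jacobian = [[1.0, 0.0, 2.0],   # gradient of loss 1
            [3.0, 2.0, 0.0]]   # gradient of loss 2

# Gramian: gramian[i][k] = <row_i, row_k>.
gramian = [[sum(a * b for a, b in zip(ri, rk)) for rk in jacobian]
           for ri in jacobian]
# gramian[0][1] = 3.0 > 0: these two gradients do not conflict.

weights = [0.5, 0.5]  # stand-in for weights extracted by a Weighting
combined = [sum(w * row[i] for w, row in zip(weights, jacobian))
            for i in range(len(jacobian[0]))]

print(gramian)   # [[5.0, 3.0], [3.0, 13.0]]
print(combined)  # [2.0, 1.0, 1.0]
```

Note that `combined` depends on the Jacobian only through its rows and the weights, and the weights depend only on the Gramian; this is why the full Jacobian never needs to be materialized at once.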
+
+The following example shows how to use the `autogram` engine to minimize the vector of per-instance
+losses with Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
 ```diff
 import torch
@@ -157,8 +206,8 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgr
 optimizer.zero_grad()
 ```
 
-Lastly, you can even combine the two approaches by considering multiple tasks and each element of
-the batch independently. We call that Instance-Wise Multitask Learning (IWMTL).
+You can even go one step further by considering multiple tasks and each element of the batch
+independently. We call this Instance-Wise Multitask Learning (IWMTL).
 
 ```python
 import torch
@@ -207,7 +256,7 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
 ```
 
 > [!NOTE]
-> Here, because the losses are a matrix instead of a simple vector, we compute a *generalized
+> Here, because the losses are a matrix instead of a simple vector, we compute a *generalized
 > Gramian* and we extract weights from it using a
 > [`GeneralizedWeighting`](https://torchjd.org/stable/docs/aggregation/#torchjd.aggregation.GeneralizedWeighting).
 
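As a rough illustration of that note in plain Python (toy numbers, not the torchjd API): with a tasks × batch matrix of losses, each (task, instance) entry has its own gradient, and the generalized Gramian collects the dot products between every pair of them.

```python
# Toy sketch: with a tasks × batch-size matrix of losses, each entry has its
# own gradient with respect to the shared parameters; the generalized Gramian
# stores the dot product between the gradients of every pair of entries.

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# 2 tasks × 2 instances, each with a gradient over 3 shared parameters.
grads = {(0, 0): [1.0, 0.0, 0.0],   # task 0, instance 0
         (0, 1): [0.0, 1.0, 0.0],   # task 0, instance 1
         (1, 0): [1.0, 1.0, 0.0],   # task 1, instance 0
         (1, 1): [0.0, 0.0, 2.0]}   # task 1, instance 1

gen_gramian = {(p, q): dot(grads[p], grads[q]) for p in grads for q in grads}

# Tasks 0 and 1 agree on instance 0 (positive dot product):
print(gen_gramian[(0, 0), (1, 0)])  # 1.0
```

A `GeneralizedWeighting` then turns this higher-order structure into one weight per loss entry, just as a `Weighting` does for an ordinary Gramian.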