-
Notifications
You must be signed in to change notification settings - Fork 20
feat(scalarization): Add readme #749
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| # Scalarization | ||
|
|
||
| This package implements the `Scalarizer`s: objects that reduce a tensor of values (typically a | ||
| vector of losses) into a single scalar optimizable with a standard `loss.backward()`. | ||
|
|
||
| This file is for contributors working on scalarizers. For the list of available scalarizers and their | ||
| full API, see [torchjd.org](https://torchjd.org/latest/docs/scalarization/). | ||
|
|
||
| ## The abstraction | ||
|
|
||
| A scalarizer captures a single decision: **how to collapse a vector of values into one scalar to | ||
| minimize**. It operates purely on those values: it has no notion of the losses, tasks, or model they | ||
| come from, which is why its input is named `values` and not `losses`. It is the value-level | ||
| counterpart of an aggregator, which makes the same decision at the gradient level. Everything after | ||
| it (backpropagation, the optimizer step) is standard PyTorch. | ||
|
|
||
| Concretely, it subclasses `Scalarizer` (in [`_scalarizer_base.py`](_scalarizer_base.py)) and | ||
| implements one method: | ||
|
|
||
| ```python | ||
| def forward(self, values: Tensor, /) -> Tensor: | ||
| ... | ||
| ``` | ||
|
|
||
| - **Any shape in, scalar out:** it reduces over *all* elements of `values` (scalar, vector, matrix, | ||
| higher-dim) into a single scalar. | ||
| - **Pure and differentiable:** the output depends only on `values` and the configured parameters, so | ||
| that `scalarizer(values).backward()` produces the gradient. | ||
|
Comment on lines
+17
to
+28
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part seems like something that should be explained in the code itself (if not already), either in a public docstring for things that are intended to be user-facing, or in a comment for things that are intended to be contributor-facing. |
||
|
|
||
| ## Adding one | ||
|
|
||
| A new scalarizer is a class plus the files that register it. Mirror an existing scalarizer of the | ||
| same kind: | ||
|
|
||
| - `_<name>.py`: the class. | ||
| - `__init__.py`: the import and an `__all__` entry. | ||
| - `docs/source/docs/scalarization/<name>.rst`: the docs page, added to the `index.rst` toctree. | ||
| - `tests/unit/scalarization/test_<name>.py`: the tests. | ||
| - `CHANGELOG.md`: an entry under `[Unreleased]`. | ||
|
Comment on lines
+30
to
+39
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Honestly I feel like this should be explained in a skill rather than here. But I also think skills should be addressed to humans and agents alike. |
||
|
|
||
| ## State | ||
|
|
||
| Most scalarizers are stateless. Keep yours stateless unless the method genuinely needs state (learned | ||
| weights, a loss history). When it does: | ||
|
|
||
| - **Subclass `Stateful`** (`from torchjd._mixins import Stateful`) and implement `reset()` to restore | ||
| the initial state. | ||
|
Comment on lines
+46
to
+47
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have Randomness? If so do we have the Stochastic mixin that we also use in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We do have a random baseline, but there is no Stochastic mixin. It just calls torch.randn directly
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My bad, I thought that In the aggregation package, we consider random aggregators as stateful (the seed is the state), I think it would be beneficial to do that in
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes we abandoned this idea because it was quite tricky to make it work on cuda: See: The biggest problems were that something can be stochastic without directly owning a generator (it can own another stochastic object that does own a generator), and generators need a device to be created, so they can't be simply created ahead of time. |
||
| - **Keep `forward` self-contained.** Do not hide cross-call state or side effects inside it. When the | ||
| method must carry information between calls, expose it through an explicit, named method and | ||
| document the protocol (e.g. a per-epoch `step()`, or an `update()` after the optimizer step). | ||
| - **`nn.Parameter` vs buffer:** trainable state is an `nn.Parameter`; non-trained tensors that must | ||
| move with `.to()` are registered with `register_buffer`. | ||
|
|
||
| Randomness is not state: a scalarizer may draw fresh randomness on each call (like the random | ||
| baseline) without being `Stateful`. There is no stochastic mixin; it just uses the global torch RNG, | ||
| so document the behavior and let users seed it with `torch.manual_seed`. | ||
|
|
||
| ## Things to be careful about | ||
|
|
||
| - **Determinism and side effects:** the output should depend only on `values`, the configured | ||
| parameters, and (if the method is intentionally random) the global RNG. Any state change must be | ||
| deliberate, explicit, and undone by `reset()`. | ||
|
Comment on lines
+60
to
+62
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't necessarily agree with that. We want to be able to give extra information through setters. E.g. a scalarizer that takes the gramian every once in a while: my_agg = MyAgg()
for i in ...:
losses = ...
if i % 1000 == 0:
gramian = engine.compute_gramian(losses)
my_agg.set_gramian(gramian)
loss = my_agg(losses)
loss.backward()
...Note that such a scalarizer could use internally for example an This would be equivalent to: W = UPGradWeighting()
for i in ...:
losses = ...
if i % 1000 == 0:
gramian = engine.compute_gramian(losses)
weights = W(gramian)
losses.backward(weights)
...So my point is that really a scalarizer could use anything to make its decision, let's not be restrictive here. |
||
| - **Numerical stability:** keep the reduction finite on the edges of its domain (log-sum-exp | ||
| centering, an eps under a norm or in a denominator, etc.), and explain any value shift in a comment | ||
| and a `.. note::`. | ||
| - **Hyperparameters:** when a coefficient has no single good value across problems, make it required | ||
| rather than guessing a default, and validate it in `__init__`. | ||
| - **Shape validation:** check parameter shapes against `values` at call time and raise `ValueError`. | ||
| - **Preconditions:** if the method is undefined on some inputs, document it in a `.. note::` and lock | ||
| it with a test (e.g. assert `nan` propagates rather than being silently clamped). | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, the values are 99% of the time losses so it's a bit confusing to say that it has no notion of the losses. Also I think we should say that in most cases, values are losses, but the scalarization package has been designed to be able to scalarize any tensor of values.