Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions src/torchjd/scalarization/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Scalarization

This package implements the `Scalarizer`s: objects that reduce a tensor of values (typically a
vector of losses) into a single scalar optimizable with a standard `loss.backward()`.

This file is for contributors working on scalarizers. For the list of available scalarizers and their
full API, see [torchjd.org](https://torchjd.org/latest/docs/scalarization/).

## The abstraction

A scalarizer captures a single decision: **how to collapse a vector of values into one scalar to
minimize**. It operates purely on those values: it has no notion of the losses, tasks, or model they

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the values are 99% of the time losses so it's a bit confusing to say that it has no notion of the losses. Also I think we should say that in most cases, values are losses, but the scalarization package has been designed to be able to scalarize any tensor of values.

come from, which is why its input is named `values` and not `losses`. It is the value-level
counterpart of an aggregator, which makes the same decision at the gradient level. Everything after
it (backpropagation, the optimizer step) is standard PyTorch.

Concretely, it subclasses `Scalarizer` (in [`_scalarizer_base.py`](_scalarizer_base.py)) and
implements one method:

```python
def forward(self, values: Tensor, /) -> Tensor:
...
```

- **Any shape in, scalar out:** it reduces over *all* elements of `values` (scalar, vector, matrix,
higher-dim) into a single scalar.
- **Pure and differentiable:** the output depends only on `values` and the configured parameters, so
that `scalarizer(values).backward()` produces the gradient.
Comment on lines +17 to +28

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part seems like something that should be explained in the code itself (if not already), either in a public docstring for things that are intended to be user-facing, or in a comment for things that are intended to be contributor-facing.


## Adding one

A new scalarizer is a class plus the files that register it. Mirror an existing scalarizer of the
same kind:

- `_<name>.py`: the class.
- `__init__.py`: the import and an `__all__` entry.
- `docs/source/docs/scalarization/<name>.rst`: the docs page, added to the `index.rst` toctree.
- `tests/unit/scalarization/test_<name>.py`: the tests.
- `CHANGELOG.md`: an entry under `[Unreleased]`.
Comment on lines +30 to +39

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly I feel like this should be explained in a skill rather than here. But I also think skills should be addressed to humans and agents alike.


## State

Most scalarizers are stateless. Keep yours stateless unless the method genuinely needs state (learned
weights, a loss history). When it does:

- **Subclass `Stateful`** (`from torchjd._mixins import Stateful`) and implement `reset()` to restore
the initial state.
Comment on lines +46 to +47

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have Randomness? If so do we have the Stochastic mixin that we also use in aggregation? Maybe we can mention it here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do have a random baseline, but there is no Stochastic mixin. It just calls torch.randn directly

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad, I thought that In the aggregation package, we consider random aggregators as stateful (the seed is the state), I think it would be beneficial to do that in scalarization. @ValerianRey Didn't we go in that direction? Did we abandon that idea?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we abandoned this idea because it was quite tricky to make it work on cuda:

See:

The biggest problems were that something can be stochastic without directly owning a generator (it can own another stochastic object that does own a generator), and generators need a device to be created, so they can't be simply created ahead of time.

- **Keep `forward` self-contained.** Do not hide cross-call state or side effects inside it. When the
method must carry information between calls, expose it through an explicit, named method and
document the protocol (e.g. a per-epoch `step()`, or an `update()` after the optimizer step).
- **`nn.Parameter` vs buffer:** trainable state is an `nn.Parameter`; non-trained tensors that must
move with `.to()` are registered with `register_buffer`.

Randomness is not state: a scalarizer may draw fresh randomness on each call (like the random
baseline) without being `Stateful`. There is no stochastic mixin; it just uses the global torch RNG,
so document the behavior and let users seed it with `torch.manual_seed`.

## Things to be careful about

- **Determinism and side effects:** the output should depend only on `values`, the configured
parameters, and (if the method is intentionally random) the global RNG. Any state change must be
deliberate, explicit, and undone by `reset()`.
Comment on lines +60 to +62

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't necessarily agree with that. We want to be able to give extra information through setters. E.g. a scalarizer that takes the gramian every once in a while:

my_agg = MyAgg()

for i in ...:
  losses = ...
  if i % 1000 == 0:
    gramian = engine.compute_gramian(losses)
    my_agg.set_gramian(gramian)
  loss = my_agg(losses)
  loss.backward()
  ...

Note that such a scalarizer could use internally for example an UPGradWeighting, return weights that would be used for the next 1000 steps, etc.

This would be equivalent to:

W = UPGradWeighting()

for i in ...:
  losses = ...
  if i % 1000 == 0:
    gramian = engine.compute_gramian(losses)
    weights = W(gramian)
  losses.backward(weights)
  ...

So my point is that really a scalarizer could use anything to make its decision, let's not be restrictive here.

- **Numerical stability:** keep the reduction finite on the edges of its domain (log-sum-exp
centering, an eps under a norm or in a denominator, etc.), and explain any value shift in a comment
and a `.. note::`.
- **Hyperparameters:** when a coefficient has no single good value across problems, make it required
rather than guessing a default, and validate it in `__init__`.
- **Shape validation:** check parameter shapes against `values` at call time and raise `ValueError`.
- **Preconditions:** if the method is undefined on some inputs, document it in a `.. note::` and lock
it with a test (e.g. assert `nan` propagates rather than being silently clamped).
Loading