1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@ changelog does not include internal changes that do not affect the user.
Algorithm Based on Decomposition](https://ieeexplore.ieee.org/document/4358754) (IEEE TEVC 2007), a
`Scalarizer` that decomposes the values into a component along a preference direction and a
penalized perpendicular component.
- Added `ExcessMTLWeighting` from [Robust Multi-Task Learning with Excess Risks](https://proceedings.mlr.press/v235/he24n.html) (ICML 2024). It is a stateful `Weighting` that maintains task weights across calls via an exponentiated gradient update driven by per-task excess risk estimates. The excess risk is approximated using an AdaGrad-style diagonal Hessian. An optional `n_warmup_steps` parameter controls how many forward calls collect gradient statistics before weight updates begin.

## [0.15.0] - 2026-06-15

28 changes: 28 additions & 0 deletions NOTICES
@@ -143,6 +143,34 @@ SOFTWARE.

-------------------------------------------------------------------------------

Project: ExcessMTL
Source: https://github.com/uiuctml/ExcessMTL/blob/main/LibMTL/LibMTL/weighting/ExcessMTL.py
Used in: src/torchjd/aggregation/_excess_mtl.py

MIT License

Copyright (c) 2024 UIUC TML Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

-------------------------------------------------------------------------------

Project: SDMGrad
Source: https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py
Used in: src/torchjd/aggregation/_sdmgrad.py
7 changes: 7 additions & 0 deletions docs/source/docs/aggregation/excess_mtl.rst
@@ -0,0 +1,7 @@
:hide-toc:

ExcessMTL
=========

.. autoclass:: torchjd.aggregation.ExcessMTLWeighting
    :members: __call__, reset
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
@@ -30,6 +30,7 @@ Abstract base classes
constant.rst
cr_mogm.rst
dualproj.rst
excess_mtl.rst
fairgrad.rst
graddrop.rst
gradvac.rst
2 changes: 2 additions & 0 deletions src/torchjd/aggregation/__init__.py
@@ -45,6 +45,7 @@
from ._constant import Constant, ConstantWeighting
from ._cr_mogm import CRMOGMWeighting
from ._dualproj import DualProj, DualProjWeighting
from ._excess_mtl import ExcessMTLWeighting
from ._fairgrad import FairGrad, FairGradWeighting
from ._graddrop import GradDrop
from ._gradvac import GradVac, GradVacWeighting
@@ -74,6 +75,7 @@
"CRMOGMWeighting",
"DualProj",
"DualProjWeighting",
"ExcessMTLWeighting",
"FairGrad",
"FairGradWeighting",
"GradDrop",
184 changes: 184 additions & 0 deletions src/torchjd/aggregation/_excess_mtl.py
@@ -0,0 +1,184 @@
# Partly adapted from https://github.com/uiuctml/ExcessMTL — MIT License, Copyright (c) 2024 UIUC TML Lab.
# See NOTICES for the full license text.
from __future__ import annotations

from typing import cast

import torch
from torch import Tensor

from torchjd._mixins import Stateful
from torchjd.aggregation._mixins import _NonDifferentiable
from torchjd.linalg import Matrix

from ._weighting_bases import _MatrixWeighting


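class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
    r"""
    :class:`~torchjd.Stateful`
    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Robust
    Multi-Task Learning with Excess Risks
    <https://proceedings.mlr.press/v235/he24n.html>`_ (ICML 2024).

    At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven
    by per-task excess risk estimates. The excess risk for task :math:`i` is approximated via a
    second-order Taylor expansion (Equations 6-7):

    .. math::

        \hat{\mathcal{E}}_i \approx g_i^\top H_i^{-1} g_i, \qquad
        H_i = \operatorname{diag}\left(\sqrt{S_i + \epsilon}\right),

    where :math:`g_i` is the gradient of task :math:`i` (row :math:`i` of the input matrix),
    :math:`S_i` is the running sum of its element-wise squared gradients (row :math:`i` of the
    state :math:`S`), and :math:`\epsilon = 10^{-7}`.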

    :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update.
        Must be positive.
    :param n_warmup_steps: Number of forward calls during which weights stay uniform
        (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess
        risk is set to the average excess risk observed during warmup. When ``0`` (default), the
        first call's excess risk is used as the baseline and weights are updated immediately
        (matching the official implementation).

    .. warning::
        The state tensor :math:`S \in \mathbb{R}^{m \times n}` accumulates squared gradients
        across **all** calls, where :math:`n` is the total number of model parameters. For large
        models this can be a significant memory cost. Call :meth:`reset` between experiments.

    .. note::
        The weight update is adapted from the `official implementation
        <https://github.com/uiuctml/ExcessMTL>`_ and `LibMTL
        <https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/ExcessMTL.py>`_.
        The warmup strategy follows Appendix C.1 of the paper, which recommends collecting
        gradient statistics for several epochs before beginning weight updates; set
        ``n_warmup_steps`` accordingly (e.g. ``3 * len(dataloader)``).

    .. admonition:: Example

        .. testcode::

            import torch
            from torch.nn import Linear, MSELoss, ReLU, Sequential
            from torch.optim import SGD

            from torchjd import autojac
            from torchjd.aggregation import ExcessMTLWeighting, WeightedAggregator
            from torchjd.autojac import jac_to_grad

            inputs = torch.randn(8, 5)
            targets = torch.randn(8, 2)

            model = Sequential(Linear(5, 4), ReLU(), Linear(4, 2))
            optimizer = SGD(model.parameters())
            criterion = MSELoss()
            aggregator = WeightedAggregator(ExcessMTLWeighting())

            outputs = model(inputs)
            losses = [criterion(outputs[:, i], targets[:, i]) for i in range(2)]
            autojac.backward(losses)
            jac_to_grad(model.parameters(), aggregator)
            optimizer.step()
            optimizer.zero_grad()
    """

    def __init__(
        self,
        robust_step_size: float = 1.0,
        n_warmup_steps: int = 0,
    ) -> None:
        super().__init__()
        self.robust_step_size = robust_step_size
        self.n_warmup_steps = n_warmup_steps
        self.register_buffer("_weights", None)
        self.register_buffer("_grad_sum", None)
        self.register_buffer("_initial_w", None)
        self.register_buffer("_warmup_w_sum", None)
        self.register_buffer("_n_steps", torch.zeros((), dtype=torch.long))
        self._state_key: tuple[int, int, torch.dtype, torch.device] | None = None

    @property
    def robust_step_size(self) -> float:
        return self._robust_step_size

    @robust_step_size.setter
    def robust_step_size(self, value: float) -> None:
        if value <= 0.0:
            raise ValueError(
                f"Attribute `robust_step_size` must be positive. Found robust_step_size={value!r}."
            )
        self._robust_step_size = value

    @property
    def n_warmup_steps(self) -> int:
        return self._n_warmup_steps

    @n_warmup_steps.setter
    def n_warmup_steps(self, value: int) -> None:
        if value < 0:
            raise ValueError(
                f"Attribute `n_warmup_steps` must be non-negative. Found n_warmup_steps={value!r}."
            )
        self._n_warmup_steps = value

    def reset(self) -> None:
        """Clears all state so the next forward starts from uniform weights and re-enters
        warmup."""

        self._weights = None
        self._grad_sum = None
        self._initial_w = None
        self._warmup_w_sum = None
        self._n_steps.zero_()
        self._state_key = None

    def forward(self, matrix: Matrix, /) -> Tensor:
        self._ensure_state(matrix)

        # Accumulate squared gradients for AdaGrad-style diagonal Hessian (Equation 7)
        grad_sum = cast(Tensor, self._grad_sum)
        grad_sum = grad_sum + matrix.detach() ** 2
        self._grad_sum = grad_sum

        # Excess risk proxy: Ê_i ≈ g_i^T H_i^{-1} g_i (Equation 6)
        h = torch.sqrt(grad_sum + 1e-7)
        w = (matrix.detach() ** 2 / h).sum(dim=1)  # shape [m]

        n_steps = int(self._n_steps.item())
        self._n_steps = self._n_steps + 1

        # Warmup: collect excess risk stats but return uniform weights
        if n_steps < self._n_warmup_steps:
            warmup_w_sum = self._warmup_w_sum
            self._warmup_w_sum = w if warmup_w_sum is None else cast(Tensor, warmup_w_sum) + w
            return cast(Tensor, self._weights)

        # Set baseline on the first non-warmup call
        if self._initial_w is None:
            if self._n_warmup_steps > 0:
                # Average excess risk observed during warmup (Appendix C.1)
                self._initial_w = cast(Tensor, self._warmup_w_sum) / self._n_warmup_steps
                w = w / (cast(Tensor, self._initial_w) + 1e-7)
            else:
                # Official impl behaviour: first call's excess is the baseline; use w raw
                self._initial_w = w
        else:
            w = w / (cast(Tensor, self._initial_w) + 1e-7)

        # Exponentiated gradient weight update (Equation 9)
        weights = cast(Tensor, self._weights)
        weights = weights * torch.exp(w * self._robust_step_size)
        weights = weights / weights.sum()
        self._weights = weights
        return weights

    def _ensure_state(self, matrix: Matrix) -> None:
        key = (matrix.shape[0], matrix.shape[1], matrix.dtype, matrix.device)
        if self._state_key == key and self._grad_sum is not None:
            return
        m, n = matrix.shape
        self._grad_sum = matrix.new_zeros(m, n)
        self._weights = matrix.new_full((m,), 1.0 / m)
        self._initial_w = None
        self._warmup_w_sum = None
        self._n_steps.zero_()
        self._state_key = key

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"robust_step_size={self.robust_step_size!r}, "
            f"n_warmup_steps={self.n_warmup_steps!r})"
        )
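
To trace the arithmetic of `forward` without torch, the no-warmup path (`n_warmup_steps=0`) can be sketched in plain Python. This is an illustrative sketch only and not part of the change set: the function `excess_weight_step` and its argument names are hypothetical, and only the constants (`1e-7`, default step size `1.0`) are taken from the code above.

```python
import math


def excess_weight_step(grads, grad_sq_sum, weights, baseline, step_size=1.0, eps=1e-7):
    """One exponentiated-gradient update of the task weights (hypothetical helper).

    grads: m per-task gradient rows (lists of floats).
    grad_sq_sum: running element-wise sum of squared gradients, updated in place.
    weights: current task weights, summing to 1.
    baseline: per-task baseline excess risk, or None on the first call.
    Returns (new_weights, baseline).
    """
    excess = []
    for i, row in enumerate(grads):
        total = 0.0
        for j, g in enumerate(row):
            grad_sq_sum[i][j] += g * g
            # g^2 / sqrt(accumulated g^2 + eps): diagonal form of g^T H^{-1} g
            total += g * g / math.sqrt(grad_sq_sum[i][j] + eps)
        excess.append(total)

    if baseline is None:
        # First call without warmup: store the raw estimate as baseline and use it unscaled.
        baseline = excess
        scaled = excess
    else:
        scaled = [e / (b + eps) for e, b in zip(excess, baseline)]

    # Multiplicative update followed by renormalisation onto the simplex.
    unnormalised = [w * math.exp(step_size * s) for w, s in zip(weights, scaled)]
    total_weight = sum(unnormalised)
    return [u / total_weight for u in unnormalised], baseline


grads = [[1.0, 0.0], [0.1, 0.1]]   # task 0 has the larger gradient
state = [[0.0, 0.0], [0.0, 0.0]]
weights, baseline = excess_weight_step(grads, state, [0.5, 0.5], None)
```

With these inputs the task with the larger excess-risk estimate (task 0) ends up with the larger weight, and the weights still sum to one. Passing the returned `baseline` back in on later calls exercises the normalised branch.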