Skip to content

Commit 30f75ab

Browse files
committed
Add qpth solver for UPGrad and DualProj
1 parent 59795e2 commit 30f75ab

6 files changed

Lines changed: 128 additions & 17 deletions

File tree

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,13 @@ nash_mtl = [
113113
cagrad = [
114114
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
115115
]
116+
qpth = [
117+
"qpth>=0.0.15",
118+
]
116119
full = [
117120
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
118121
"ecos>=2.0.14", # Does not work before 2.0.14
122+
"qpth>=0.0.15",
119123
]
120124

121125
[tool.pytest.ini_options]

src/torchjd/aggregation/_dualproj.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ class DualProjWeighting(GramianWeighting):
2222
numerical errors when computing the gramian, it might not exactly be positive definite.
2323
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
2424
ensures that it is positive definite.
25-
:param solver: The solver used to optimize the underlying optimization problem.
25+
:param solver: The solver used to optimize the underlying optimization problem. Use
26+
``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the
27+
device of the input tensors (requires the optional ``qpth`` package).
2628
"""
2729

2830
def __init__(
@@ -90,7 +92,9 @@ class DualProj(GramianWeightedAggregator):
9092
numerical errors when computing the gramian, it might not exactly be positive definite.
9193
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
9294
ensures that it is positive definite.
93-
:param solver: The solver used to optimize the underlying optimization problem.
95+
:param solver: The solver used to optimize the underlying optimization problem. Use
96+
``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the
97+
device of the input tensors (requires the optional ``qpth`` package).
9498
"""
9599

96100
gramian_weighting: DualProjWeighting

src/torchjd/aggregation/_upgrad.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ class UPGradWeighting(GramianWeighting):
2323
numerical errors when computing the gramian, it might not exactly be positive definite.
2424
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
2525
ensures that it is positive definite.
26-
:param solver: The solver used to optimize the underlying optimization problem.
26+
:param solver: The solver used to optimize the underlying optimization problem. Use
27+
``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the
28+
device of the input tensors (requires the optional ``qpth`` package).
2729
"""
2830

2931
def __init__(
@@ -93,7 +95,9 @@ class UPGrad(GramianWeightedAggregator):
9395
numerical errors when computing the gramian, it might not exactly be positive definite.
9496
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
9597
ensures that it is positive definite.
96-
:param solver: The solver used to optimize the underlying optimization problem.
98+
:param solver: The solver used to optimize the underlying optimization problem. Use
99+
``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the
100+
device of the input tensors (requires the optional ``qpth`` package).
97101
"""
98102

99103
gramian_weighting: UPGradWeighting

src/torchjd/aggregation/_utils/dual_cone.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from qpsolvers import solve_qp
66
from torch import Tensor
77

8-
SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"]
8+
SUPPORTED_SOLVER: TypeAlias = Literal["quadprog", "qpth"]
99

1010

1111
def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor:
@@ -15,10 +15,15 @@ def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor:
1515
1616
:param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
1717
:param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
18-
:param solver: The quadratic programming solver to use.
18+
:param solver: The quadratic programming solver to use. ``"quadprog"`` converts tensors to
19+
CPU numpy arrays and uses qpsolvers. ``"qpth"`` solves natively on the same device as
20+
the input tensors (e.g. CUDA) using the ``qpth`` package (optional dependency).
1921
:return: A tensor of projection weights with the same shape as `U`.
2022
"""
2123

24+
if solver == "qpth":
25+
return _project_weights_qpth(U, G)
26+
2227
G_ = _to_array(G)
2328
U_ = _to_array(U)
2429

@@ -27,6 +32,50 @@ def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor:
2732
return torch.as_tensor(W, device=G.device, dtype=G.dtype)
2833

2934

35+
def _project_weights_qpth(U: Tensor, G: Tensor) -> Tensor:
36+
r"""
37+
Computes the tensor of projection weights using qpth, keeping computation on the device of
38+
the input tensors and running without gradient tracking.
39+
40+
:param U: The tensor of weights to project, of shape `[..., m]`.
41+
:param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
42+
"""
43+
from qpth.qp import QPFunction # lazy import: qpth is an optional dependency
44+
45+
shape = U.shape
46+
m = shape[-1]
47+
batch_size = U.numel() // m
48+
device = G.device
49+
original_dtype = G.dtype
50+
51+
# Use float64 for numerical precision, matching the quadprog solver's behavior.
52+
U_flat = U.reshape(batch_size, m).double()
53+
G_double = G.double()
54+
55+
# QP formulation: minimize (1/2) v^T (2G) v + 0^T v subject to -I v <= -u (i.e., u <= v)
56+
Q = (2.0 * G_double).unsqueeze(0).expand(batch_size, m, m).contiguous()
57+
p = torch.zeros(batch_size, m, device=device, dtype=torch.float64)
58+
G_ineq = (
59+
(-torch.eye(m, device=device, dtype=torch.float64))
60+
.unsqueeze(0)
61+
.expand(batch_size, m, m)
62+
.contiguous()
63+
)
64+
h_ineq = -U_flat
65+
A = torch.zeros(batch_size, 0, m, device=device, dtype=torch.float64)
66+
b = torch.zeros(batch_size, 0, device=device, dtype=torch.float64)
67+
68+
with torch.no_grad():
69+
W_flat = QPFunction(verbose=False, maxIter=10, check_Q_spd=False, notImprovedLim=1)(
70+
Q, p, G_ineq, h_ineq, A, b
71+
)
72+
73+
if torch.any(torch.isnan(W_flat)):
74+
raise ValueError("Failed to solve the quadratic programming problem.")
75+
76+
return W_flat.to(original_dtype).reshape(shape)
77+
78+
3079
def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray:
3180
r"""
3281
Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,

tests/unit/aggregation/test_dualproj.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import importlib.util
2+
13
import torch
2-
from pytest import mark, raises
4+
from pytest import mark, param, raises
35
from torch import Tensor
46
from utils.tensors import ones_
57

@@ -15,28 +17,44 @@
1517
)
1618
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
1719

20+
_has_qpth = importlib.util.find_spec("qpth") is not None
21+
_skip_no_qpth = mark.skipif(not _has_qpth, reason="qpth not installed")
22+
1823
scaled_pairs = [(DualProj(), matrix) for matrix in scaled_matrices]
1924
typical_pairs = [(DualProj(), matrix) for matrix in typical_matrices]
2025
non_strong_pairs = [(DualProj(), matrix) for matrix in non_strong_matrices]
2126
requires_grad_pairs = [(DualProj(), ones_(3, 5, requires_grad=True))]
2227

28+
_qpth_typical_pairs = [
29+
param(DualProj(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in typical_matrices
30+
]
31+
_qpth_non_strong_pairs = [
32+
param(DualProj(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in non_strong_matrices
33+
]
34+
_qpth_scaled_pairs = [
35+
param(DualProj(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in scaled_matrices
36+
]
2337

24-
@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
38+
39+
@mark.parametrize(
40+
["aggregator", "matrix"],
41+
scaled_pairs + typical_pairs + _qpth_scaled_pairs + _qpth_typical_pairs,
42+
)
2543
def test_expected_structure(aggregator: DualProj, matrix: Tensor) -> None:
2644
assert_expected_structure(aggregator, matrix)
2745

2846

29-
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
47+
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
3048
def test_non_conflicting(aggregator: DualProj, matrix: Tensor) -> None:
3149
assert_non_conflicting(aggregator, matrix, atol=1e-04, rtol=1e-04)
3250

3351

34-
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
52+
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
3553
def test_permutation_invariant(aggregator: DualProj, matrix: Tensor) -> None:
3654
assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-07, rtol=2e-07)
3755

3856

39-
@mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
57+
@mark.parametrize(["aggregator", "matrix"], non_strong_pairs + _qpth_non_strong_pairs)
4058
def test_strongly_stationary(aggregator: DualProj, matrix: Tensor) -> None:
4159
assert_strongly_stationary(aggregator, matrix, threshold=3e-03)
4260

@@ -66,6 +84,13 @@ def test_representations() -> None:
6684
assert str(A) == "DualProj([1., 2., 3.])"
6785

6886

87+
@_skip_no_qpth
88+
def test_representations_qpth() -> None:
89+
A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="qpth")
90+
assert repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='qpth')"
91+
assert str(A) == "DualProj"
92+
93+
6994
def test_pref_vector_setter_updates_value() -> None:
7095
A = DualProj()
7196
new_pref = torch.tensor([1.0, 2.0, 3.0])

tests/unit/aggregation/test_upgrad.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import importlib.util
2+
13
import torch
2-
from pytest import mark, raises
4+
from pytest import mark, param, raises
35
from torch import Tensor
46
from utils.tensors import ones_
57

@@ -16,33 +18,49 @@
1618
)
1719
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
1820

21+
_has_qpth = importlib.util.find_spec("qpth") is not None
22+
_skip_no_qpth = mark.skipif(not _has_qpth, reason="qpth not installed")
23+
1924
scaled_pairs = [(UPGrad(), matrix) for matrix in scaled_matrices]
2025
typical_pairs = [(UPGrad(), matrix) for matrix in typical_matrices]
2126
non_strong_pairs = [(UPGrad(), matrix) for matrix in non_strong_matrices]
2227
requires_grad_pairs = [(UPGrad(), ones_(3, 5, requires_grad=True))]
2328

29+
_qpth_typical_pairs = [
30+
param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in typical_matrices
31+
]
32+
_qpth_non_strong_pairs = [
33+
param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in non_strong_matrices
34+
]
35+
_qpth_scaled_pairs = [
36+
param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in scaled_matrices
37+
]
2438

25-
@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
39+
40+
@mark.parametrize(
41+
["aggregator", "matrix"],
42+
scaled_pairs + typical_pairs + _qpth_scaled_pairs + _qpth_typical_pairs,
43+
)
2644
def test_expected_structure(aggregator: UPGrad, matrix: Tensor) -> None:
2745
assert_expected_structure(aggregator, matrix)
2846

2947

30-
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
48+
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
3149
def test_non_conflicting(aggregator: UPGrad, matrix: Tensor) -> None:
3250
assert_non_conflicting(aggregator, matrix, atol=4e-04, rtol=4e-04)
3351

3452

35-
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
53+
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
3654
def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor) -> None:
3755
assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=5e-07, rtol=5e-07)
3856

3957

40-
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
58+
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
4159
def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor) -> None:
4260
assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=6e-02, rtol=6e-02)
4361

4462

45-
@mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
63+
@mark.parametrize(["aggregator", "matrix"], non_strong_pairs + _qpth_non_strong_pairs)
4664
def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor) -> None:
4765
assert_strongly_stationary(aggregator, matrix, threshold=5e-03)
4866

@@ -70,6 +88,13 @@ def test_representations() -> None:
7088
assert str(A) == "UPGrad([1., 2., 3.])"
7189

7290

91+
@_skip_no_qpth
92+
def test_representations_qpth() -> None:
93+
A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="qpth")
94+
assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='qpth')"
95+
assert str(A) == "UPGrad"
96+
97+
7398
def test_pref_vector_setter_updates_value() -> None:
7499
A = UPGrad()
75100
new_pref = torch.tensor([1.0, 2.0, 3.0])

0 commit comments

Comments (0)