|
| 1 | +import importlib.util |
| 2 | + |
1 | 3 | import torch |
2 | | -from pytest import mark, raises |
| 4 | +from pytest import mark, param, raises |
3 | 5 | from torch import Tensor |
4 | 6 | from utils.tensors import ones_ |
5 | 7 |
|
|
16 | 18 | ) |
17 | 19 | from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices |
18 | 20 |
|
| 21 | +_has_qpth = importlib.util.find_spec("qpth") is not None |
| 22 | +_skip_no_qpth = mark.skipif(not _has_qpth, reason="qpth not installed") |
| 23 | + |
19 | 24 | scaled_pairs = [(UPGrad(), matrix) for matrix in scaled_matrices] |
20 | 25 | typical_pairs = [(UPGrad(), matrix) for matrix in typical_matrices] |
21 | 26 | non_strong_pairs = [(UPGrad(), matrix) for matrix in non_strong_matrices] |
22 | 27 | requires_grad_pairs = [(UPGrad(), ones_(3, 5, requires_grad=True))] |
23 | 28 |
|
| 29 | +_qpth_typical_pairs = [ |
| 30 | + param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in typical_matrices |
| 31 | +] |
| 32 | +_qpth_non_strong_pairs = [ |
| 33 | + param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in non_strong_matrices |
| 34 | +] |
| 35 | +_qpth_scaled_pairs = [ |
| 36 | + param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in scaled_matrices |
| 37 | +] |
24 | 38 |
|
25 | | -@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) |
| 39 | + |
| 40 | +@mark.parametrize( |
| 41 | + ["aggregator", "matrix"], |
| 42 | + scaled_pairs + typical_pairs + _qpth_scaled_pairs + _qpth_typical_pairs, |
| 43 | +) |
26 | 44 | def test_expected_structure(aggregator: UPGrad, matrix: Tensor) -> None: |
27 | 45 | assert_expected_structure(aggregator, matrix) |
28 | 46 |
|
29 | 47 |
|
30 | | -@mark.parametrize(["aggregator", "matrix"], typical_pairs) |
| 48 | +@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs) |
31 | 49 | def test_non_conflicting(aggregator: UPGrad, matrix: Tensor) -> None: |
32 | 50 | assert_non_conflicting(aggregator, matrix, atol=4e-04, rtol=4e-04) |
33 | 51 |
|
34 | 52 |
|
35 | | -@mark.parametrize(["aggregator", "matrix"], typical_pairs) |
| 53 | +@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs) |
36 | 54 | def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor) -> None: |
37 | 55 | assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=5e-07, rtol=5e-07) |
38 | 56 |
|
39 | 57 |
|
40 | | -@mark.parametrize(["aggregator", "matrix"], typical_pairs) |
| 58 | +@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs) |
41 | 59 | def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor) -> None: |
42 | 60 | assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=6e-02, rtol=6e-02) |
43 | 61 |
|
44 | 62 |
|
45 | | -@mark.parametrize(["aggregator", "matrix"], non_strong_pairs) |
| 63 | +@mark.parametrize(["aggregator", "matrix"], non_strong_pairs + _qpth_non_strong_pairs) |
46 | 64 | def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor) -> None: |
47 | 65 | assert_strongly_stationary(aggregator, matrix, threshold=5e-03) |
48 | 66 |
|
@@ -70,6 +88,13 @@ def test_representations() -> None: |
70 | 88 | assert str(A) == "UPGrad([1., 2., 3.])" |
71 | 89 |
|
72 | 90 |
|
| 91 | +@_skip_no_qpth |
| 92 | +def test_representations_qpth() -> None: |
| 93 | + A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="qpth") |
| 94 | + assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='qpth')" |
| 95 | + assert str(A) == "UPGrad" |
| 96 | + |
| 97 | + |
73 | 98 | def test_pref_vector_setter_updates_value() -> None: |
74 | 99 | A = UPGrad() |
75 | 100 | new_pref = torch.tensor([1.0, 2.0, 3.0]) |
|
0 commit comments