Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pyrecest/_backend/_shared_numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def sqrtm(x):
return _np.vectorize(_scipy.linalg.sqrtm, signature="(n,m)->(n,m)")(x)


def quadratic_assignment(a, b, options):
def quadratic_assignment(a, b, options=None):
return list(_scipy.optimize.quadratic_assignment(a, b, options=options).col_ind)


Expand Down
2 changes: 1 addition & 1 deletion src/pyrecest/_backend/pytorch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def matrix_rank(a, tol=None, hermitian=False, *, rtol=None, atol=None, **kwargs)
return _torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)


def quadratic_assignment(a, b, options):
def quadratic_assignment(a, b, options=None):
return list(
_scipy.optimize.quadratic_assignment(
_as_numpy_no_grad(a), _as_numpy_no_grad(b), options=options
Expand Down
54 changes: 54 additions & 0 deletions tests/test_quadratic_assignment_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import importlib.util

import numpy as np

from pyrecest._backend import numpy as numpy_backend

pytorch_backend = None
if importlib.util.find_spec("torch") is not None:
from pyrecest._backend import pytorch as pytorch_backend


def _problem_matrices():
adjacency = np.array(
[
[0.0, 2.0, 1.0],
[2.0, 0.0, 3.0],
[1.0, 3.0, 0.0],
]
)
permuted = adjacency[[1, 2, 0]][:, [1, 2, 0]]
return adjacency, permuted


def test_numpy_quadratic_assignment_accepts_default_options():
adjacency, permuted = _problem_matrices()

assignment = numpy_backend.linalg.quadratic_assignment(adjacency, permuted)

assert sorted(assignment) == [0, 1, 2]


def test_numpy_quadratic_assignment_still_accepts_options_dict():
adjacency, permuted = _problem_matrices()

assignment = numpy_backend.linalg.quadratic_assignment(
adjacency,
permuted,
options={"maximize": False},
)

assert sorted(assignment) == [0, 1, 2]


def test_pytorch_quadratic_assignment_accepts_default_options():
if pytorch_backend is None:
return
adjacency, permuted = _problem_matrices()

assignment = pytorch_backend.linalg.quadratic_assignment(
pytorch_backend.array(adjacency),
pytorch_backend.array(permuted),
)

assert sorted(assignment) == [0, 1, 2]
Loading