Skip to content

Commit fcd4af1

Browse files
committed
Fix the creation of Permutation symbols
1 parent 7204b53 commit fcd4af1

3 files changed

Lines changed: 99 additions & 36 deletions

File tree

dwave/optimization/symbols.pyx

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -663,10 +663,6 @@ cdef class ArgSort(ArraySymbol):
663663
_register(ArgSort, typeid(cppArgSortNode))
664664

665665

666-
cdef bool _empty_slice(object slice_) noexcept:
667-
return slice_.start is None and slice_.stop is None and slice_.step is None
668-
669-
670666
cdef class AdvancedIndexing(ArraySymbol):
671667
"""Advanced indexing.
672668
@@ -711,35 +707,50 @@ cdef class AdvancedIndexing(ArraySymbol):
711707

712708
array = next(self.iter_predecessors())
713709

714-
if (
715-
isinstance(array, Constant)
716-
and array.ndim() == 2
717-
and array.shape()[0] == array.shape()[1] # square matrix
718-
and self.ptr.indices().size() == 2
719-
and isinstance(index, tuple)
720-
and len(index) == 2
721-
):
722-
i0, i1 = index
710+
if (perm := self._check_permutation(index)) is not None:
711+
return perm
723712

724-
# check the [x, :][:, x] case
725-
if (isinstance(i0, slice) and _empty_slice(i0) and
726-
isinstance(i1, ArraySymbol) and
727-
holds_alternative[cppArrayNodePtr](self.ptr.indices()[0]) and
728-
get[cppArrayNodePtr](self.ptr.indices()[0]) == (<ArraySymbol>i1).array_ptr and
729-
holds_alternative[cppSlice](self.ptr.indices()[1])):
713+
return super().__getitem__(index)
730714

731-
return Permutation(array, i1)
715+
cdef object _check_permutation(self, index):
716+
"""Return a Permutation symbol if the indexing the symbol results in
717+
a permutation, otherwise return None.
718+
"""
732719

733-
# check the [:, x][x, :] case
734-
if (isinstance(i1, slice) and _empty_slice(i1) and
735-
isinstance(i0, ArraySymbol) and
736-
holds_alternative[cppArrayNodePtr](self.ptr.indices()[1]) and
737-
get[cppArrayNodePtr](self.ptr.indices()[1]) == (<ArraySymbol>i0).array_ptr and
738-
holds_alternative[cppSlice](self.ptr.indices()[0])):
720+
# The indexed array must be a Constant square matrix
721+
array = next(self.iter_predecessors())
722+
if not isinstance(array, Constant):
723+
return None
724+
if array.ndim() != 2 or array.shape()[0] != array.shape()[1]:
725+
return None
739726

740-
return Permutation(array, i0)
727+
# The total operation must of the form A[i0, i1][i2, i3]
728+
729+
i0, i1 = self._iter_indices()
730+
i2, i3 = index if isinstance(index, tuple) else (index, slice(None, None, None))
731+
732+
# It also must be of the form A[outer, inner][inner, outer]
733+
734+
if (
735+
isinstance(i0, slice) and isinstance(i3, slice) # outer is a slice
736+
and i0 == i3 and i0 == slice(None) # and those slices are empty
737+
and isinstance(i1, ArraySymbol) and isinstance(i2, ArraySymbol) # inner is an array
738+
and i1.id() == i2.id() # and those arrays are the same array
739+
and i1.shape() == (array.shape()[0],) # and the shape of the array is correct
740+
):
741+
return Permutation(array, i1)
742+
743+
if (
744+
isinstance(i1, slice) and isinstance(i2, slice) # inner is a slice
745+
and i1 == i2 and i1 == slice(None) # and those slices are empty
746+
and isinstance(i0, ArraySymbol) and isinstance(i3, ArraySymbol) # outer is an array
747+
and i0.id() == i3.id() # and those arrays are the same array
748+
and i0.shape() == (array.shape()[0],) # and the shape of the array is correct
749+
):
750+
return Permutation(array, i0)
751+
752+
return None
741753

742-
return super().__getitem__(index)
743754

744755
@classmethod
745756
def _from_symbol(cls, Symbol symbol):
@@ -791,6 +802,17 @@ cdef class AdvancedIndexing(ArraySymbol):
791802

792803
zf.writestr(directory + "indices.json", encoder.encode(indices))
793804

805+
def _iter_indices(self):
806+
for variant in self.ptr.indices():
807+
if holds_alternative[cppSlice](variant):
808+
cppslice = <cppSlice>get[cppSlice](variant)
809+
yield slice(None)
810+
elif holds_alternative[cppArrayNodePtr](variant):
811+
array_ptr = <cppArrayNodePtr>get[cppArrayNodePtr](variant)
812+
yield symbol_from_ptr(self.model, array_ptr)
813+
else:
814+
raise RuntimeError("unexpected variant contents")
815+
794816
cdef cppAdvancedIndexingNode* ptr
795817

796818
_register(AdvancedIndexing, typeid(cppAdvancedIndexingNode))
@@ -3474,10 +3496,7 @@ cdef class Permutation(ArraySymbol):
34743496
>>> type(p)
34753497
<class 'dwave.optimization.symbols.Permutation'>
34763498
"""
3477-
def __init__(self, Constant array, ListVariable x):
3478-
# todo: Loosen the types accepted. But this Cython code doesn't yet have
3479-
# the type heirarchy needed so for how we specify explicitly
3480-
3499+
def __init__(self, Constant array, ArraySymbol x):
34813500
if array.model is not x.model:
34823501
raise ValueError("array and x do not share the same underlying model")
34833502

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
fixes:
3+
- |
4+
Fix indexing operations sometimes raising an error when it should
5+
create a ``Permutation`` symbol.

tests/test_symbols.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2708,7 +2708,19 @@ def generate_symbols(self):
27082708
model.lock()
27092709
yield p
27102710

2711-
def test(self):
2711+
def test_constant_integer(self):
2712+
from dwave.optimization.symbols import Permutation
2713+
2714+
model = Model()
2715+
2716+
A = model.constant(np.arange(25).reshape((5, 5)))
2717+
x = model.constant(np.arange(5))
2718+
2719+
self.assertIsInstance(A[x, :][:, x], Permutation)
2720+
self.assertIsInstance(A[:, x][x, :], Permutation)
2721+
self.assertIsInstance(A[:, x][x], Permutation)
2722+
2723+
def test_list_indexer(self):
27122724
from dwave.optimization.symbols import Permutation
27132725

27142726
model = Model()
@@ -2718,11 +2730,38 @@ def test(self):
27182730

27192731
self.assertIsInstance(A[x, :][:, x], Permutation)
27202732
self.assertIsInstance(A[:, x][x, :], Permutation)
2733+
self.assertIsInstance(A[:, x][x], Permutation)
2734+
2735+
def test_not_permutation(self):
2736+
# Some "near" permutations that aren't quite right
2737+
from dwave.optimization.symbols import Permutation
2738+
2739+
with self.subTest("A not square"):
2740+
model = Model()
2741+
2742+
A = model.constant(np.arange(30).reshape((5, 6)))
2743+
x = model.list(5)
2744+
2745+
self.assertNotIsInstance(A[x, :][:, x], Permutation)
2746+
self.assertNotIsInstance(A[:, x][x, :], Permutation)
2747+
2748+
with self.subTest("A not 2d"):
2749+
model = Model()
2750+
2751+
A = model.constant(np.arange(25).reshape((5, 5, 1)))
2752+
x = model.list(5)
2753+
2754+
self.assertNotIsInstance(A[x, :][:, x], Permutation)
2755+
self.assertNotIsInstance(A[:, x][x, :], Permutation)
2756+
2757+
with self.subTest("indexer wrong size"):
2758+
model = Model()
27212759

2722-
b = model.constant(np.arange(30).reshape((5, 6)))
2760+
A = model.constant(np.arange(25).reshape((5, 5)))
2761+
x = model.list(4)
27232762

2724-
self.assertNotIsInstance(b[x, :][:, x], Permutation)
2725-
self.assertNotIsInstance(b[:, x][x, :], Permutation)
2763+
self.assertNotIsInstance(A[x, :][:, x], Permutation)
2764+
self.assertNotIsInstance(A[:, x][x, :], Permutation)
27262765

27272766

27282767
class TestProd(utils.ReduceTests):

0 commit comments

Comments
 (0)