@@ -663,10 +663,6 @@ cdef class ArgSort(ArraySymbol):
663663_register(ArgSort, typeid(cppArgSortNode))
664664
665665
666- cdef bool _empty_slice(object slice_) noexcept:
667- return slice_.start is None and slice_.stop is None and slice_.step is None
668-
669-
670666cdef class AdvancedIndexing(ArraySymbol):
671667 """ Advanced indexing.
672668
@@ -711,35 +707,50 @@ cdef class AdvancedIndexing(ArraySymbol):
711707
712708 array = next(self .iter_predecessors())
713709
714- if (
715- isinstance (array, Constant)
716- and array.ndim() == 2
717- and array.shape()[0 ] == array.shape()[1 ] # square matrix
718- and self .ptr.indices().size() == 2
719- and isinstance (index, tuple )
720- and len (index) == 2
721- ):
722- i0, i1 = index
710+ if (perm := self ._check_permutation(index)) is not None :
711+ return perm
723712
724- # check the [x, :][:, x] case
725- if (isinstance (i0, slice ) and _empty_slice(i0) and
726- isinstance (i1, ArraySymbol) and
727- holds_alternative[cppArrayNodePtr](self .ptr.indices()[0 ]) and
728- get[cppArrayNodePtr](self .ptr.indices()[0 ]) == (< ArraySymbol> i1).array_ptr and
729- holds_alternative[cppSlice](self .ptr.indices()[1 ])):
713+ return super ().__getitem__(index)
730714
731- return Permutation(array, i1)
715+ cdef object _check_permutation(self , index):
716+ """ Return a Permutation symbol if the indexing the symbol results in
717+ a permutation, otherwise return None.
718+ """
732719
733- # check the [:, x][x, :] case
734- if ( isinstance (i1, slice ) and _empty_slice(i1) and
735- isinstance (i0, ArraySymbol) and
736- holds_alternative[cppArrayNodePtr]( self .ptr.indices()[ 1 ]) and
737- get[cppArrayNodePtr]( self .ptr.indices ()[1 ]) == ( < ArraySymbol > i0).array_ptr and
738- holds_alternative[cppSlice]( self .ptr.indices()[ 0 ])):
720+ # The indexed array must be a Constant square matrix
721+ array = next( self .iter_predecessors())
722+ if not isinstance (array, Constant):
723+ return None
724+ if array.ndim() ! = 2 or array.shape ()[0 ] ! = array.shape()[ 1 ]:
725+ return None
739726
740- return Permutation(array, i0)
727+ # The total operation must of the form A[i0, i1][i2, i3]
728+
729+ i0, i1 = self ._iter_indices()
730+ i2, i3 = index if isinstance (index, tuple ) else (index, slice (None , None , None ))
731+
732+ # It also must be of the form A[outer, inner][inner, outer]
733+
734+ if (
735+ isinstance (i0, slice ) and isinstance (i3, slice ) # outer is a slice
736+ and i0 == i3 and i0 == slice (None ) # and those slices are empty
737+ and isinstance (i1, ArraySymbol) and isinstance (i2, ArraySymbol) # inner is an array
738+ and i1.id() == i2.id() # and those arrays are the same array
739+ and i1.shape() == (array.shape()[0 ],) # and the shape of the array is correct
740+ ):
741+ return Permutation(array, i1)
742+
743+ if (
744+ isinstance (i1, slice ) and isinstance (i2, slice ) # inner is a slice
745+ and i1 == i2 and i1 == slice (None ) # and those slices are empty
746+ and isinstance (i0, ArraySymbol) and isinstance (i3, ArraySymbol) # outer is an array
747+ and i0.id() == i3.id() # and those arrays are the same array
748+ and i0.shape() == (array.shape()[0 ],) # and the shape of the array is correct
749+ ):
750+ return Permutation(array, i0)
751+
752+ return None
741753
742- return super ().__getitem__(index)
743754
744755 @classmethod
745756 def _from_symbol (cls , Symbol symbol ):
@@ -791,6 +802,17 @@ cdef class AdvancedIndexing(ArraySymbol):
791802
792803 zf.writestr(directory + " indices.json" , encoder.encode(indices))
793804
805+ def _iter_indices (self ):
806+ for variant in self .ptr.indices():
807+ if holds_alternative[cppSlice](variant):
808+ cppslice = < cppSlice> get[cppSlice](variant)
809+ yield slice (None )
810+ elif holds_alternative[cppArrayNodePtr](variant):
811+ array_ptr = < cppArrayNodePtr> get[cppArrayNodePtr](variant)
812+ yield symbol_from_ptr(self .model, array_ptr)
813+ else :
814+ raise RuntimeError (" unexpected variant contents" )
815+
794816 cdef cppAdvancedIndexingNode* ptr
795817
796818_register(AdvancedIndexing, typeid(cppAdvancedIndexingNode))
@@ -3474,10 +3496,7 @@ cdef class Permutation(ArraySymbol):
34743496 >>> type(p)
34753497 <class 'dwave.optimization.symbols.Permutation'>
34763498 """
3477- def __init__ (self , Constant array , ListVariable x ):
3478- # todo: Loosen the types accepted. But this Cython code doesn't yet have
3479- # the type heirarchy needed so for how we specify explicitly
3480-
3499+ def __init__ (self , Constant array , ArraySymbol x ):
34813500 if array.model is not x.model:
34823501 raise ValueError (" array and x do not share the same underlying model" )
34833502
0 commit comments