Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 5 additions & 204 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3498,214 +3498,16 @@ def __getitem__(self, *args):
ogrid = _nd_grid(sparse=True)


class PermuteRowElements(Op):
    """Permute the elements of each row (inner-most dim) of a tensor.

    A permutation will be applied to every row (vector) of the input tensor x.
    Depending on the dimensionality of x and the permutation tensor y,
    different cases are possible.
    If y.ndim = 1, y is a single permutation, that will be applied to every
    vector of x. For instance, if x is a matrix, the same permutation will be
    applied to each row of x.
    If x.ndim = y.ndim, each row of x corresponds to a row of y, containing
    a permutation that will be applied to that row. For instance, if x and y
    are two matrices, a different permutation will be applied to each row of x.
    If x.ndim > y.ndim, y will be broadcasted to fit x, then each row (vector)
    of x will be reordered according to the corresponding row of y. (This is
    a generalization of the first case).
    If x.ndim = 1, every permutation in y will be applied to x, and the output
    will contain all the results.
    If x.ndim < y.ndim, x will be broadcasted to fit y, and different
    permutations contained in y will be applied to each vector in x. (This is
    a generalization of the previous case).

    If the "inverse" argument is True, the Op will perform the inverse
    permutation instead.
    """

    __props__ = ("inverse",)

    def __init__(self, inverse: bool):
        super().__init__()
        self.inverse = inverse

    def make_node(self, x, y):
        x = as_tensor_variable(x)
        y = as_tensor_variable(y)

        # y holds indices into rows of x, so it must contain integers
        assert y.type.dtype in integer_dtypes

        # Match shapes of x and y by left-padding the lower-rank input
        # with broadcastable dimensions.
        x_dim = x.type.ndim
        y_dim = y.type.ndim

        if x_dim > y_dim:
            y = shape_padleft(y, n_ones=(x_dim - y_dim))
        elif x_dim < y_dim:
            x = shape_padleft(x, n_ones=(y_dim - x_dim))

        # A static output dim is 1 only where *both* inputs are broadcastable;
        # otherwise it is unknown until runtime.
        out_shape = [
            1 if xb == 1 and yb == 1 else None
            for xb, yb in zip(x.type.shape, y.type.shape, strict=True)
        ]
        out_type = tensor(dtype=x.type.dtype, shape=out_shape)

        inputlist = [x, y]
        outputlist = [out_type]
        return Apply(self, inputlist, outputlist)

    def _rec_perform(self, node, x, y, inverse, out, curdim):
        """Perform the permutation by doing a recursion over the input
        dimensions.

        For every dimension, starting with the leftmost, the right set of
        indices is determined (depending if broadcasting or not), then
        the function is recursively called on the appropriate subtensors.

        The terminal case is reached when the current tensors are vector,
        then the permutation contained in y is applied to x.

        Parameters
        ----------
        x: np.ndarray
            The input array, on which the permutation is applied.
        y: np.ndarray
            Array containing the permutations to apply.
        inverse: bool
            Whether to apply permutations or their inverse.
        out: np.ndarray
            Array storing the output result, written in place.
        curdim: int
            Counter of the current depth of recursion.

        """
        if len(x.shape) == 1:
            # Numpy advanced indexing works in this case
            if inverse:
                out[y] = x[:]
            else:
                out[:] = x[y]
        else:
            xs0 = x.shape[0]
            ys0 = y.shape[0]
            if xs0 == ys0:
                for i in range(xs0):
                    self._rec_perform(node, x[i], y[i], inverse, out[i], curdim + 1)
            elif ys0 == 1 and node.inputs[1].type.shape[curdim] == 1:
                # Broadcast y
                for i in range(xs0):
                    self._rec_perform(node, x[i], y[0], inverse, out[i], curdim + 1)
            elif xs0 == 1 and node.inputs[0].type.shape[curdim] == 1:
                # Broadcast x
                for i in range(ys0):
                    self._rec_perform(node, x[0], y[i], inverse, out[i], curdim + 1)
            else:
                raise ValueError(f"Dimension mismatch: {xs0}, {ys0}")

    def perform(self, node, inp, out):
        """Apply the (inverse) permutation, reusing the output buffer if possible."""
        x, y = inp
        (outs,) = out
        x_s = x.shape
        y_s = y.shape
        assert len(x_s) == len(y_s)

        # Compute the broadcasted output shape.
        out_s = []
        # zip strict not specified because we are in a hot loop
        for xdim, ydim in zip(x_s, y_s):
            if xdim == ydim:
                outdim = xdim
            elif xdim == 1:
                outdim = ydim
            elif ydim == 1:
                outdim = xdim
            else:
                raise ValueError(f"Dimension mismatch: {xdim}, {ydim}")
            out_s.append(outdim)

        # ndarray.shape is a tuple; comparing it against a list is always
        # unequal, which would force a fresh allocation on every call.
        out_s = tuple(out_s)
        if outs[0] is None or outs[0].shape != out_s:
            outs[0] = np.empty(out_s, dtype=x.dtype)

        self._rec_perform(node, x, y, self.inverse, outs[0], curdim=0)

    def infer_shape(self, fgraph, node, in_shapes):
        from pytensor.tensor.math import maximum

        # Inputs were rank-matched in make_node, so the output shape is the
        # elementwise broadcast (max) of the two input shapes.
        shp_x = in_shapes[0]
        shp_y = in_shapes[1]
        assert len(shp_x) == len(shp_y)
        out_shape = [maximum(sx, sy) for sx, sy in zip(shp_x, shp_y, strict=True)]
        return [out_shape]

    def grad(self, inp, grads):
        from pytensor.tensor.math import Sum

        x, y = inp
        (gz,) = grads
        # First, compute the gradient wrt the broadcasted x.
        # If 'inverse' is False (0), apply the inverse of y on gz.
        # Else, apply y on gz.
        gx = permute_row_elements(gz, y, not self.inverse)

        # If x has been broadcasted along some axes, we need to sum
        # the gradient over these axes, but keep the dimension (as
        # broadcastable)
        broadcasted_dims = [
            dim
            for dim in range(gz.type.ndim)
            if x.type.shape[dim] == 1 and gz.type.shape[dim] != 1
        ]
        gx = Sum(axis=broadcasted_dims)(gx)

        # Sum(...) removed the dimensions in broadcasted_dims,
        # so we need to put them back.
        newdims = []
        i = 0
        for dim in range(gz.type.ndim):
            if dim in broadcasted_dims:
                newdims.append("x")
            else:
                newdims.append(i)
                i += 1

        gx = gx.dimshuffle(newdims)
        assert gx.type.ndim == x.type.ndim
        assert all(
            s1 == s2
            for s1, s2 in zip(gx.type.shape, x.type.shape, strict=True)
            if s1 == 1 or s2 == 1
        )

        # if x is an integer type, then so is the output.
        # this means f(x+eps) = f(x) so the gradient with respect
        # to x is zero
        if x.type.dtype in discrete_dtypes:
            gx = x.zeros_like()

        # The elements of y affect the output,
        # so they are connected to the output,
        # and the transformation isn't defined if their values
        # are non-integer, so the gradient with respect to them is
        # undefined

        return [gx, grad_undefined(self, 1, y)]


def permute_row_elements(x, y, inverse=False):
    """Apply the permutations in `y` to the rows (last axis) of `x`.

    If `inverse` is True, apply the inverse permutations instead.
    See `PermuteRowElements` for the full broadcasting semantics.
    """
    op = PermuteRowElements(inverse=inverse)
    return op(x, y)


def inverse_permutation(perm):
    """Compute the inverse of permutations along the last axis.

    Each row of input should contain a permutation of the first integers.

    Returns the argsort of each row, which is the inverse permutation,
    cast back to the input's dtype.
    """
    # Imported locally to avoid a circular import at module load time.
    from pytensor.tensor.sort import argsort

    _perm = as_tensor_variable(perm)
    # argsort of a permutation is its inverse; argsort returns int64, so
    # cast back to preserve the caller's integer dtype.
    return cast(argsort(_perm, axis=-1), _perm.dtype)


class ExtractDiag(COp):
Expand Down Expand Up @@ -4592,7 +4394,6 @@ def ix_(*args):
"ogrid",
"ones",
"ones_like",
"permute_row_elements",
"roll",
"scalar_from_tensor",
"second",
Expand Down
Loading