Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 34 additions & 33 deletions pytensor/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,37 @@ def __hash__(self):
)
return hash((type(self), props_values))

@staticmethod
def str_from_slice(entry):
if entry.step is not None:
return ":".join(
(
"start" if entry.start is not None else "",
"stop" if entry.stop is not None else "",
"step",
)
)
if entry.stop is not None:
return f"{'start' if entry.start is not None else ''}:stop"
if entry.start is not None:
return "start:"
return ":"

@staticmethod
def str_from_indices(idx_list):
indices = []
letter_indexes = 0
for entry in idx_list:
if isinstance(entry, slice):
indices.append(BaseSubtensor.str_from_slice(entry))
else:
indices.append("ijk"[letter_indexes % 3] * (letter_indexes // 3 + 1))
letter_indexes += 1
return ", ".join(indices)

def __str__(self):
return f"{self.__class__.__name__}{{{self.str_from_indices(self.idx_list)}}}"


class Subtensor(BaseSubtensor, COp):
"""Basic NumPy indexing operator."""
Expand Down Expand Up @@ -907,37 +938,6 @@ def connection_pattern(self, node):

return rval

@staticmethod
def str_from_slice(entry):
if entry.step is not None:
return ":".join(
(
"start" if entry.start is not None else "",
"stop" if entry.stop is not None else "",
"step",
)
)
if entry.stop is not None:
return f"{'start' if entry.start is not None else ''}:stop"
if entry.start is not None:
return "start:"
return ":"

@staticmethod
def str_from_indices(idx_list):
indices = []
letter_indexes = 0
for entry in idx_list:
if isinstance(entry, slice):
indices.append(Subtensor.str_from_slice(entry))
else:
indices.append("ijk"[letter_indexes % 3] * (letter_indexes // 3 + 1))
letter_indexes += 1
return ", ".join(indices)

def __str__(self):
return f"{self.__class__.__name__}{{{self.str_from_indices(self.idx_list)}}}"

@staticmethod
def default_helper_c_code_args():
"""
Expand Down Expand Up @@ -1407,7 +1407,7 @@ def __init__(

def __str__(self):
name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor"
return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
return f"{name}{{{super().str_from_indices(self.idx_list)}}}"

def make_node(self, x, y, *inputs):
"""
Expand Down Expand Up @@ -2578,11 +2578,12 @@ def __init__(
self.ignore_duplicates = ignore_duplicates

def __str__(self):
return (
name = (
"AdvancedSetSubtensor"
if self.set_instead_of_inc
else "AdvancedIncSubtensor"
)
return f"{name}{{{super().str_from_indices(self.idx_list)}}}"

def make_node(self, x, y, *index_variables):
if len(index_variables) != self.n_index_vars:
Expand Down
Loading