From 8cf91e0bbaf630cf97c614f3fc2c42eb15d9f28b Mon Sep 17 00:00:00 2001 From: SimpingOjou Date: Thu, 4 Jun 2026 15:08:19 +0200 Subject: [PATCH 1/2] Optimize expand_index_in_trial in Dataset class for improved performance --- cebra/data/base.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/cebra/data/base.py b/cebra/data/base.py index f5491e51..c3ca46ae 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -154,8 +154,6 @@ def expand_index_in_trial(self, index, trial_ids, trial_borders): trial_ids is in size of a length of self.index and indicate the trial id of the index belong to. trial_borders is in size of a length of self.idnex and indicate the border of each trial. - Todo: - - rewrite """ # TODO(stes) potential room for speed improvements by pre-allocating these tensors/ @@ -163,16 +161,15 @@ def expand_index_in_trial(self, index, trial_ids, trial_borders): offset = torch.arange(-self.offset.left, self.offset.right, device=index.device) - index = torch.tensor( - [ - torch.clamp( - i, - trial_borders[trial_ids[i]] + self.offset.left, - trial_borders[trial_ids[i] + 1] - self.offset.right, - ) for i in index - ], - device=self.device, - ) + + # Vectorized lookup and boundary calculation + batch_trial_ids = trial_ids[index] + min_borders = trial_borders[batch_trial_ids] + self.offset.left + max_borders = trial_borders[batch_trial_ids + 1] - self.offset.right + + # Fast C-level clamp + index = torch.clamp(index, min=min_borders, max=max_borders) + return index[:, None] + offset[None, :] @abc.abstractmethod From 4ebe3bec206f69a3db1bbfc0b545e086b53bdccb Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Wed, 17 Jun 2026 10:59:26 +0200 Subject: [PATCH 2/2] Remove unneeded comments Co-authored-by: Steffen Schneider --- cebra/data/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cebra/data/base.py b/cebra/data/base.py index c3ca46ae..61f75d3b 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -162,12 +162,10 @@ def expand_index_in_trial(self, index, trial_ids, trial_borders): self.offset.right, device=index.device) - # Vectorized lookup and boundary calculation batch_trial_ids = trial_ids[index] min_borders = trial_borders[batch_trial_ids] + self.offset.left max_borders = trial_borders[batch_trial_ids + 1] - self.offset.right - # Fast C-level clamp index = torch.clamp(index, min=min_borders, max=max_borders) return index[:, None] + offset[None, :]