Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
5e039df
feat(attention): Added inter document masking for manual and flash at…
BlueCrescent Feb 23, 2026
a5354c4
feat(data loading): GPT2LLMCollateFn can now determine the sub sequen…
BlueCrescent Feb 23, 2026
eba9c5b
fix(attention): NaNs when using padding + inter document masking with…
BlueCrescent Feb 24, 2026
3d583c1
feat(attention): added sub_seq_lengths_key to GPT2LLMConfig and renam…
BlueCrescent Feb 24, 2026
cd00777
fix(attention): added missing sub_seq_lengths_key parameter to get_gp…
BlueCrescent Feb 24, 2026
fdd6465
fix(attention): computing sub sequence lengths on correct input
BlueCrescent Feb 24, 2026
e952bd0
docs(attention): better _get_unpad_data_for_concatenated_sequences() …
BlueCrescent Feb 24, 2026
115259c
test(attention): fixed collator tests and improved collator config va…
BlueCrescent Feb 25, 2026
8d71928
fix: directly use correct device + dtype for eos positions extensions
BlueCrescent Feb 26, 2026
91f67e4
chore: remove comment
BlueCrescent Feb 27, 2026
ab4c79a
refactor(data): improved naming in collator
BlueCrescent Feb 28, 2026
379420d
test(attention): turned global manual seed into fixture
BlueCrescent Feb 28, 2026
382a952
refactor(attention): improved error handling
BlueCrescent Feb 28, 2026
478628d
fix(attention): added not supported assertion for inter document mask…
BlueCrescent Feb 28, 2026
6d7e502
refactor(data): improved detection and reporting of sequences in batc…
BlueCrescent Feb 28, 2026
f173cff
refactor(attention): removed duplicate exception
BlueCrescent Feb 28, 2026
4a31747
fix(attention): bug introduced in improved error handling for unsuppo…
BlueCrescent Feb 28, 2026
b2ec756
chore: Merge remote-tracking branch 'origin/main' into inter_document…
BlueCrescent Mar 12, 2026
956b958
chore: Merge remote-tracking branch 'origin/main' into inter_document…
BlueCrescent Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,25 @@ class BatchSamplerConfig(BaseModel):
class GPT2LLMCollateFnConfig(BaseModel):
    """Pydantic config for GPT2LLMCollateFn.

    sub_seq_lengths_key and eos_token_id must be provided together: the EOS token
    is what the collate function searches for to split a packed sequence into
    sub-sequences, and the key is where the resulting lengths are stored.
    padding_token_id additionally requires sub_seq_lengths_key, since the padding
    token is only interpreted during sub-sequence length computation.
    """

    sample_key: str  # key under which the input token ids are stored
    target_key: str  # key under which the (shifted) target token ids are stored
    sub_seq_lengths_key: str | None = None  # key under which per-sample sub-sequence lengths are stored
    eos_token_id: int | None = None  # token id that delimits sub-sequences
    padding_token_id: int | None = None  # token id used for trailing padding, if any

    @model_validator(mode="before")
    @classmethod
    def check_sub_seq_lengths_and_eos_token(cls, values):
        """Ensure sub_seq_lengths_key and eos_token_id are set together (or not at all)."""
        sub_seq_lengths_key = values.get("sub_seq_lengths_key")
        eos_token_id = values.get("eos_token_id")
        if (sub_seq_lengths_key is None) != (eos_token_id is None):
            raise ValueError("Either both or neither of sub_seq_lengths_key and eos_token_id must be provided.")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_padding_token_and_sub_seq_lengths(cls, values):
        """padding_token_id is only meaningful when sub-sequence lengths are computed."""
        padding_token_id = values.get("padding_token_id")
        sub_seq_lengths_key = values.get("sub_seq_lengths_key")
        if padding_token_id is not None and sub_seq_lengths_key is None:
            raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.")
        return values
Comment thread
BlueCrescent marked this conversation as resolved.
Outdated


class LLMDataLoaderConfig(BaseModel):
Expand Down
56 changes: 55 additions & 1 deletion src/modalities/models/gpt2/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,31 @@
class GPT2LLMCollateFn(CollateFnIF):
"""GPT2LLMCollateFn class to define a collate function for GPT2 language model."""

def __init__(
    self,
    sample_key: str,
    target_key: str,
    sub_seq_lengths_key: str | None = None,
    eos_token_id: int | None = None,
    padding_token_id: int | None = None,
):
    """Set up the collate function.

    When both eos_token_id and sub_seq_lengths_key are given, each collated
    batch additionally contains, under sub_seq_lengths_key, a list[list[int]]
    holding the lengths of the EOS-delimited sub-sequences of every sample.

    Args:
        sample_key (str): Key under which the input token ids are stored.
        target_key (str): Key under which the shifted target token ids are stored.
        sub_seq_lengths_key (str | None): Key under which sub-sequence lengths are stored.
        eos_token_id (int | None): Token id that delimits sub-sequences.
        padding_token_id (int | None): Token id used for trailing padding, if any.
    """
    self.sample_key = sample_key
    self.target_key = target_key
    self.sub_seq_lengths_key = sub_seq_lengths_key
    self.eos_token_id = eos_token_id
    self.padding_token_id = padding_token_id

def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
"""
Expand All @@ -33,4 +48,43 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch])
samples = {self.sample_key: sample_tensor[:, :-1]}
targets = {self.target_key: sample_tensor[:, 1:]}
if self.sub_seq_lengths_key is not None:
# Determine sub sequence lengths by finding the eos tokens in each sequence in the batch.
sub_seq_lengths = self._compute_sub_sequence_lengths_for_each_sequence(samples[self.sample_key])
samples[self.sub_seq_lengths_key] = sub_seq_lengths
return DatasetBatch(targets=targets, samples=samples)

def _compute_sub_sequence_lengths_for_each_sequence(self, sample_tensor: torch.Tensor) -> list[list[int]]:
    """Determine the lengths of the EOS-delimited sub-sequences of every sample.

    Args:
        sample_tensor (torch.Tensor): Batch of token id sequences; assumed
            shape (batch_size, seq_len) — each row is iterated as one sequence.

    Returns:
        list[list[int]]: One list of sub-sequence lengths per sample.
    """
    batch_sub_seq_lengths = []
    for seq in sample_tensor:
        eos_positions = (seq == self.eos_token_id).nonzero(as_tuple=True)[0]
        if len(eos_positions) == 0:
            # No EOS token at all: the whole row is one (cut-off) sub-sequence.
            # Guard against a row consisting solely of padding tokens, for which
            # no meaningful sub-sequence length exists.
            assert self.padding_token_id is None or bool((seq != self.padding_token_id).any()), (
                "Invalid sequence: consists only of padding tokens. Sub-sequence "
                "lengths cannot be computed for a fully padded sequence; ensure "
                "padding is only applied at the end of sequences, after EOS tokens."
            )
            batch_sub_seq_lengths.append([len(seq)])
        else:
            sample_sub_seq_lengths = self._compute_subsequence_length(seq, eos_positions)
            batch_sub_seq_lengths.append(sample_sub_seq_lengths)
    return batch_sub_seq_lengths

def _compute_subsequence_length(self, seq: torch.Tensor, eos_positions: torch.Tensor) -> list[int]:
    """Compute the length of every EOS-delimited sub-sequence in seq.

    Args:
        seq (torch.Tensor): 1D token id sequence.
        eos_positions (torch.Tensor): Indices at which seq equals the EOS token
            id, ascending (as produced by nonzero()); assumed non-empty.

    Returns:
        list[int]: Length of each sub-sequence, each including its EOS token.
    """
    # If the last sub-sequence is cut, i.e. does not end on an eos token,
    # it should also be included unless the padding token is set and
    # the last sub-sequence is just padding.
    last_eos_pos = eos_positions[-1].item()
    if self._has_cutoff_final_sequence(seq, last_eos_pos):
        # new_tensor inherits the device and dtype of eos_positions, avoiding a
        # device/dtype mismatch in the concatenation when the batch lives on GPU.
        eos_positions = torch.cat([eos_positions, eos_positions.new_tensor([len(seq) - 1])])
    # Each sub-sequence spans from one past the previous boundary up to and
    # including its own end position.
    sub_seq_lengths = []
    prev_end = 0
    for pos in eos_positions.tolist():
        sub_seq_lengths.append(pos - prev_end + 1)
        prev_end = pos + 1
    return sub_seq_lengths

def _has_cutoff_final_sequence(self, seq: torch.Tensor, last_eos_pos: int) -> bool:
    """Decide whether the tokens after the last EOS form a truncated sub-sequence."""
    # The last EOS token is the final token: nothing was cut off.
    if last_eos_pos >= len(seq) - 1:
        return False
    # No padding token configured: any trailing tokens are real content.
    if self.padding_token_id is None:
        return True
    # Assumption: If the first token of the last sequence is padding, so is the rest.
    return seq[last_eos_pos + 1] != self.padding_token_id
Loading