Added inter document masking for manual and flash attention. #434
base: main
Changes from 7 commits
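
For context: inter-document masking prevents tokens of one document from attending to tokens of another document when several documents are packed into a single training sequence. The sketch below is not code from this PR; it only illustrates, for a manual attention implementation, how per-sequence sub-sequence lengths (as produced by the collate function in the diff below) could be turned into a block-diagonal, causal attention mask. The helper name `build_block_diagonal_mask` is made up for illustration; flash-attention kernels typically consume cumulative sequence lengths rather than a dense mask.

```python
import torch


def build_block_diagonal_mask(sub_seq_lengths: list[list[int]], seq_len: int) -> torch.Tensor:
    """Illustrative helper (not part of this PR): build a boolean attention mask.

    Returns a (batch, seq_len, seq_len) tensor where True means "may attend".
    Tokens may only attend to earlier tokens of the same sub-sequence
    (causal masking combined with inter-document masking).
    """
    batch_size = len(sub_seq_lengths)
    mask = torch.zeros(batch_size, seq_len, seq_len, dtype=torch.bool)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    for b, lengths in enumerate(sub_seq_lengths):
        start = 0
        for length in lengths:
            end = min(start + length, seq_len)
            # Allow attention only within the current document block.
            mask[b, start:end, start:end] = True
            start = end
    return mask & causal


# Example: one packed sequence of length 6 holding documents of lengths 2, 3 and 1.
print(build_block_diagonal_mask([[2, 3, 1]], seq_len=6)[0].int())
```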
```diff
@@ -7,16 +7,31 @@
 class GPT2LLMCollateFn(CollateFnIF):
     """GPT2LLMCollateFn class to define a collate function for GPT2 language model."""
 
-    def __init__(self, sample_key: str, target_key: str):
+    def __init__(
+        self,
+        sample_key: str,
+        target_key: str,
+        sub_seq_lengths_key: str | None = None,
+        eos_token_id: int | None = None,
+        padding_token_id: int | None = None,
+    ):
         """
         Initializes the Collator object.
+        If the eos token ID and the sub_seq_lengths_key are provided,
+        a list[list[int]] representing the sub-sequence lengths will be created.
 
         Args:
             sample_key (str): The key for accessing the sample data.
             target_key (str): The key for accessing the target data.
+            sub_seq_lengths_key (str | None): The key for accessing the sub-sequence lengths.
+            eos_token_id (int | None): The end-of-sequence token ID.
+            padding_token_id (int | None): The padding token ID.
         """
         self.sample_key = sample_key
         self.target_key = target_key
+        self.sub_seq_lengths_key = sub_seq_lengths_key
+        self.eos_token_id = eos_token_id
+        self.padding_token_id = padding_token_id
 
     def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
         """
```
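For illustration, the new constructor might be used as follows. The keys and token IDs here are made-up values; the real ones come from the tokenizer and training config of the respective setup:

```python
# Hypothetical example values; not taken from this repository's configs.
collate_fn = GPT2LLMCollateFn(
    sample_key="input_ids",
    target_key="target_ids",
    sub_seq_lengths_key="sub_seq_lengths",  # opting in to sub-sequence length computation
    eos_token_id=0,
    padding_token_id=1,
)
```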
```diff
@@ -33,4 +48,43 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
         sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch])
         samples = {self.sample_key: sample_tensor[:, :-1]}
         targets = {self.target_key: sample_tensor[:, 1:]}
+        if self.sub_seq_lengths_key is not None:
+            # Determine sub sequence lengths by finding the eos tokens in each sequence in the batch.
+            sub_seq_lengths = self._compute_sub_sequence_lengths_for_each_sequence(samples[self.sample_key])
+            samples[self.sub_seq_lengths_key] = sub_seq_lengths
         return DatasetBatch(targets=targets, samples=samples)
+
+    def _compute_sub_sequence_lengths_for_each_sequence(self, sample_tensor: torch.Tensor) -> list[list[int]]:
+        sub_seq_lengths = []
+        for seq in sample_tensor:
+            eos_positions = (seq == self.eos_token_id).nonzero(as_tuple=True)[0]
+            if len(eos_positions) == 0:
+                assert (
+                    self.padding_token_id is None or seq[0] != self.padding_token_id
+                ), "Sequence starts with padding token"
```
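The diff excerpt ends here; the remainder of `_compute_sub_sequence_lengths_for_each_sequence` is not shown in this view. Purely as an illustration of the idea (and explicitly not the PR's actual implementation), sub-sequence lengths can be derived from EOS positions roughly as follows, with a cut-off final document that has no EOS token counted by its remaining length:

```python
import torch


def sub_sequence_lengths_from_eos(seq: torch.Tensor, eos_token_id: int) -> list[int]:
    """Illustrative sketch only: split one packed 1D token sequence at EOS tokens."""
    eos_positions = (seq == eos_token_id).nonzero(as_tuple=True)[0]
    lengths = []
    start = 0
    for pos in eos_positions.tolist():
        lengths.append(pos - start + 1)  # the EOS token belongs to the document it closes
        start = pos + 1
    if start < len(seq):
        lengths.append(len(seq) - start)  # cut-off final document without an EOS token
    return lengths


# Two complete documents (EOS id 0) followed by a cut-off third document.
print(sub_sequence_lengths_from_eos(torch.tensor([5, 7, 0, 3, 0, 9, 9]), eos_token_id=0))
# -> [3, 2, 2]
```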
A reviewer suggested a more descriptive message for this assertion:

```diff
-                ), "Sequence starts with padding token"
+                ), (
+                    "Invalid sequence: cannot start with padding token. This prevents valid "
+                    "sub-sequence length computation when no EOS token is present. Please ensure "
+                    "padding is only applied at the end of sequences, typically after EOS tokens."
+                )
```
+1
Why is this a problem? Because of the assumption in _has_cutoff_final_sequence?
Yes, this assertion is there to detect the case where a whole batch sequence consists of padding tokens. I changed it to also check the whole sequence if the first token is a padding token.
The naming is confusing, since sub_seq_lengths and subseq_lengths are basically the same name. It would be better to use e.g. batch_sub_seq_lengths and sample_sub_seq_lengths.
Improved naming in collator