Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fast_llm/data/preparation/gpt_memmap/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ class ConversationSourceConfig(LanguageModelSourceConfig):
desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.",
hint=FieldHint.core,
)
train_on_eos: bool = Field(
default=False,
desc="Include the end-of-sequence token appended after the final message in the training loss."
" When disabled, that token is masked from the loss.",
hint=FieldHint.optional,
)

@functools.cached_property
def columns(self) -> list[str]:
Expand Down
1 change: 1 addition & 0 deletions fast_llm/data/preparation/gpt_memmap/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelDocumen
sample[self._source_schema.messages],
True,
True,
train_on_eos=self._source_schema.train_on_eos,
data_type=self._data_type,
)
token_spans_by_type[SpanType.loss_masking] = loss_masking_spans
Expand Down
3 changes: 2 additions & 1 deletion fast_llm/data/preparation/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def tokenize_chat(
messages: list[dict[str, str]],
begin: bool = True,
end: bool = True,
train_on_eos: bool = False,
data_type: DataType = DataType.int64,
) -> tuple["torch.Tensor", list[tuple[int, int]]]:
"""
Expand Down Expand Up @@ -291,7 +292,7 @@ def tokenize_chat(
prepend_bos = begin and self.bod_id not in tokens
append_eos = end and self.eod_id not in tokens
tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos
train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos
train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [train_on_eos] * append_eos

# Convert boolean train mask to loss masking spans (spans where train_mask[i] == False)
loss_masking_spans = _train_mask_to_loss_spans(train_mask)
Expand Down
14 changes: 14 additions & 0 deletions tests/data/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,20 @@ def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_los
Assert.eq(loss_masking_spans, expected_loss_masking_spans)


def test_tokenize_chat_train_on_eos(common_tokenizer):
common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE
messages = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
tokens, spans = common_tokenizer.tokenize_chat(messages)
tokens_eos, spans_eos = common_tokenizer.tokenize_chat(messages, train_on_eos=True)

# `train_on_eos` changes only the loss mask of the appended EOS, not the tokens.
Assert.eq(tokens_eos.tolist(), tokens.tolist())
Assert.eq(tokens[-1].item(), common_tokenizer.eod_id)
# By default the trailing EOS is its own masked span; `train_on_eos` trains on it instead.
Assert.eq(spans[-1], (len(tokens) - 1, len(tokens)))
Assert.eq(spans_eos, spans[:-1])


@pytest.mark.parametrize(
("train_mask", "expected_loss_spans"),
(
Expand Down
Loading