From ff60b4de085133cb965e460a7c59b3a687d01e73 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 5 Jun 2026 14:50:40 -0400 Subject: [PATCH] Add configurable train_on_eos for conversation data preparation Add a `train_on_eos` flag (default `False`) to `ConversationSourceConfig` controlling whether the end-of-sequence token appended after the final message is included in the training loss. When disabled (the default, unchanged behavior) that token is masked from the loss; when enabled it becomes a training target. Threaded through `tokenize_chat` as a `train_on_eos` parameter. Co-Authored-By: Claude Opus 4.8 (1M context) --- fast_llm/data/preparation/gpt_memmap/config.py | 6 ++++++ fast_llm/data/preparation/gpt_memmap/prepare.py | 1 + fast_llm/data/preparation/tokenizer.py | 3 ++- tests/data/test_tokenizer.py | 14 ++++++++++++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/fast_llm/data/preparation/gpt_memmap/config.py b/fast_llm/data/preparation/gpt_memmap/config.py index fe39e9da6..b8a7f4f6b 100644 --- a/fast_llm/data/preparation/gpt_memmap/config.py +++ b/fast_llm/data/preparation/gpt_memmap/config.py @@ -170,6 +170,12 @@ class ConversationSourceConfig(LanguageModelSourceConfig): desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.", hint=FieldHint.core, ) + train_on_eos: bool = Field( + default=False, + desc="Include the end-of-sequence token appended after the final message in the training loss." + " When disabled, that token is masked from the loss.", + hint=FieldHint.optional, + ) @functools.cached_property def columns(self) -> list[str]: diff --git a/fast_llm/data/preparation/gpt_memmap/prepare.py b/fast_llm/data/preparation/gpt_memmap/prepare.py index 410758147..ff56c0d70 100644 --- a/fast_llm/data/preparation/gpt_memmap/prepare.py +++ b/fast_llm/data/preparation/gpt_memmap/prepare.py @@ -220,6 +220,7 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelDocumen sample[self._source_schema.messages], True, True, + train_on_eos=self._source_schema.train_on_eos, data_type=self._data_type, ) token_spans_by_type[SpanType.loss_masking] = loss_masking_spans diff --git a/fast_llm/data/preparation/tokenizer.py b/fast_llm/data/preparation/tokenizer.py index 8a3425a35..a29f84a5a 100644 --- a/fast_llm/data/preparation/tokenizer.py +++ b/fast_llm/data/preparation/tokenizer.py @@ -264,6 +264,7 @@ def tokenize_chat( messages: list[dict[str, str]], begin: bool = True, end: bool = True, + train_on_eos: bool = False, data_type: DataType = DataType.int64, ) -> tuple["torch.Tensor", list[tuple[int, int]]]: """ @@ -291,7 +292,7 @@ def tokenize_chat( prepend_bos = begin and self.bod_id not in tokens append_eos = end and self.eod_id not in tokens tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos - train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos + train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [train_on_eos] * append_eos # Convert boolean train mask to loss masking spans (spans where train_mask[i] == False) loss_masking_spans = _train_mask_to_loss_spans(train_mask) diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index 184294551..391e7ccf1 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -278,6 +278,20 @@ def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_los Assert.eq(loss_masking_spans, expected_loss_masking_spans) +def test_tokenize_chat_train_on_eos(common_tokenizer): + common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE + messages = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}] + tokens, spans = common_tokenizer.tokenize_chat(messages) + tokens_eos, spans_eos = common_tokenizer.tokenize_chat(messages, train_on_eos=True) + + # `train_on_eos` changes only the loss mask of the appended EOS, not the tokens. + Assert.eq(tokens_eos.tolist(), tokens.tolist()) + Assert.eq(tokens[-1].item(), common_tokenizer.eod_id) + # By default the trailing EOS is its own masked span; `train_on_eos` trains on it instead. + Assert.eq(spans[-1], (len(tokens) - 1, len(tokens))) + Assert.eq(spans_eos, spans[:-1]) + + @pytest.mark.parametrize( ("train_mask", "expected_loss_spans"), (