Skip to content

Commit 72863f6

Browse files
slacki-ai and claude committed
fix: remove non-tensorizable columns before SDFT training
Two changes to prevent "Unable to create tensor" ValueError when the data collator encounters the 'messages' column (list of dicts): 1. SDFTDataCollator.__call__: filter features to only pass columns the base DataCollatorForSeq2Seq can handle before calling it. 2. sdft_train(): after dataset preprocessing, explicitly remove all columns except 'text', 'teacher_input_ids', 'teacher_attention_mask' so non-tensorizable columns (messages, demonstration, etc.) never reach the collator. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 637ee11 commit 72863f6

1 file changed

Lines changed: 28 additions & 3 deletions

File tree

openweights/jobs/unsloth/sdft.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,15 @@ class SDFTDataCollator:
146146
pad_token_id: int
147147
max_seq_length: int = 2048
148148

149+
# Columns that the base collator (DataCollatorForSeq2Seq) knows how to
150+
# tensorize. Any other column will be silently dropped before passing to
151+
# the base collator to avoid "Unable to create tensor" errors from
152+
# non-integer fields left in the dataset (e.g. "messages", "text").
153+
_BASE_COLLATOR_COLUMNS = frozenset(
154+
{"input_ids", "attention_mask", "labels", "token_type_ids",
155+
"special_tokens_mask", "decoder_input_ids"}
156+
)
157+
149158
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
150159
# Pop teacher-specific fields so the base collator doesn't choke on them
151160
teacher_input_ids_list = [
@@ -155,8 +164,14 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
155164
f.pop("teacher_attention_mask") for f in features
156165
]
157166

167+
# Drop any column the base collator can't tensorize (e.g. "messages")
168+
clean_features = [
169+
{k: v for k, v in f.items() if k in self._BASE_COLLATOR_COLUMNS}
170+
for f in features
171+
]
172+
158173
# Standard student collation (handles labels, padding, etc.)
159-
batch = self.base_collator(features)
174+
batch = self.base_collator(clean_features)
160175

161176
# Pad teacher sequences to uniform length (right-padding)
162177
max_len = max(
@@ -513,10 +528,20 @@ def tokenize_teacher(examples):
513528
}
514529

515530
dataset = dataset.map(tokenize_teacher, batched=True)
516-
dataset = dataset.remove_columns(["teacher_text"])
531+
# Remove all columns except those needed for training.
532+
# "text" is consumed by SFTTrainer's internal tokeniser;
533+
# teacher_input_ids / teacher_attention_mask are consumed by SDFTDataCollator.
534+
# Any other column (e.g. "messages", "demonstration") would cause
535+
# "Unable to create tensor" errors in the data collator.
536+
_keep = {"text", "teacher_input_ids", "teacher_attention_mask"}
537+
dataset = dataset.remove_columns(
538+
[c for c in dataset.column_names if c not in _keep]
539+
)
517540
if test_dataset is not None:
518541
test_dataset = test_dataset.map(tokenize_teacher, batched=True)
519-
test_dataset = test_dataset.remove_columns(["teacher_text"])
542+
test_dataset = test_dataset.remove_columns(
543+
[c for c in test_dataset.column_names if c not in _keep]
544+
)
520545

521546
# ------------------------------------------------------------------ #
522547
# 4. Learning rate normalisation (mirrors sft.py)

0 commit comments

Comments (0)