feat: add Self-Distillation Fine-Tuning (SDFT) algorithm #47

Open
slacki-ai wants to merge 44 commits into longtermrisk:main from slacki-ai:feature/sdft

Conversation

@slacki-ai

Summary

Implements Self-Distillation Fine-Tuning (SDFT) from the paper (arXiv:2601.19897) as a new loss="sdft" option in the unsloth training job.

SDFT can achieve higher new-task accuracy while substantially reducing catastrophic forgetting compared to standard SFT, by using the model itself (conditioned on a demonstration) as a teacher signal.

Algorithm

SDFT trains the student (model without demonstrations) to match the token-level distribution of the teacher (same model conditioned on an in-context demonstration) via reverse KL divergence over response tokens:

L(θ) = (1/K) Σ_{t∈response} KL( π_θ(·|y_<t, x)  ‖  π_φ(·|y_<t, x, c) )
  • π_θ = student (trainable LoRA model, no demonstration)
  • π_φ = teacher (EMA of student weights, sees demonstration)
  • x = prompt, c = demonstration, K = number of response tokens

After each optimizer step the teacher EMA is updated: φ ← α·θ + (1−α)·φ

The EMA teacher only tracks the trainable (LoRA adapter) parameters — the frozen base model weights are shared, so no extra GPU memory is needed for a second full model copy.
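Numerically, the loss and EMA update reduce to the following plain-Python sketch (toy hand-made distributions; the real trainer operates on torch logits over the full vocabulary):

```python
import math

def reverse_kl(student_probs, teacher_probs):
    # KL(pi_theta || pi_phi) at one response position; in the real
    # implementation gradients flow through the student term only.
    return sum(s * math.log(s / t)
               for s, t in zip(student_probs, teacher_probs) if s > 0)

def ema_update(teacher, student, alpha=0.02):
    # phi <- alpha * theta + (1 - alpha) * phi, applied per LoRA parameter.
    return {name: alpha * student[name] + (1 - alpha) * teacher[name]
            for name in teacher}
```

With alpha=0.02 the teacher lags the student by roughly 50 optimizer steps, which is what keeps the distillation target stable.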

New files

| File | Description |
| --- | --- |
| openweights/jobs/unsloth/sdft.py | SDFTTrainer, SDFTDataCollator, EMATeacherCallback, sdft_train() |
| cookbook/sdft/sdft_qwen3_4b.py | Minimal usage example |
| cookbook/sdft/data/train.jsonl | Sample training data with demonstration fields |

Modified files

| File | Change |
| --- | --- |
| openweights/jobs/unsloth/validate.py | Add "sdft" to loss Literal; add sdft_ema_alpha (default 0.02) and sdft_demo_template fields |
| openweights/jobs/unsloth/training.py | Import sdft_train; preserve demonstration field in create_dataset; route loss == "sdft" to sdft_train |

Data format

Same conversations JSONL as standard SFT, with an optional demonstration field:

{
  "messages": [
    {"role": "user",      "content": "What is 2+2?"},
    {"role": "assistant", "content": "The answer is 4."}
  ],
  "demonstration": "2+2 equals 4 because addition combines two quantities."
}

When "demonstration" is absent, the trainer automatically falls back to using the last assistant message as the teacher's in-context demo.
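A minimal sketch of that fallback, assuming rows shaped like the JSON above (the helper name is hypothetical; the real preprocessing lives in sdft.py):

```python
def teacher_demo(row):
    # Prefer the explicit field; otherwise fall back to the last
    # assistant message, as the trainer does.
    if row.get("demonstration"):
        return row["demonstration"]
    return [m["content"] for m in row["messages"]
            if m["role"] == "assistant"][-1]
```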

Usage

from openweights import OpenWeights
ow = OpenWeights()

training_file = ow.files.upload("train.jsonl", purpose="conversations")["id"]

job = ow.fine_tuning.create(
    model="unsloth/Qwen3-4B",
    training_file=training_file,
    loss="sdft",
    sdft_ema_alpha=0.02,      # EMA rate for teacher (paper recommends 0.01–0.05)
    epochs=1,
    learning_rate=1e-4,
    r=32,
)

Implementation notes

  • Weight-swapping: Teacher logits are computed by temporarily replacing the LoRA adapter weights with their EMA values, running a no_grad forward pass, then restoring the student weights before the backward pass. The autograd graph is never contaminated.
  • EMA timing: EMATeacherCallback.on_step_end fires after the HuggingFace optimizer step, ensuring the teacher always tracks the updated student.
  • Logit alignment: Teacher sequences have a longer prefix (demo context). We align by taking the last T_student positions from the shifted teacher logit tensor, since both sequences share the same suffix (prompt + response).
  • Graceful fallback: If teacher_input_ids is absent from a batch (e.g. during evaluation), compute_loss falls back to the standard SFT cross-entropy loss.
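The weight-swapping and logit-alignment notes can be sketched together over plain Python containers (dicts and lists stand in for parameter and logit tensors; the names are illustrative, not the module's API):

```python
from contextlib import contextmanager

@contextmanager
def swapped_into(params, ema):
    # Temporarily load EMA (teacher) values into the trainable parameters,
    # restoring the student values afterwards.
    backup = {k: params[k] for k in ema}
    params.update(ema)
    try:
        yield params
    finally:
        params.update(backup)

def align_teacher_logits(teacher_logits, student_len):
    # Teacher sequence = [demo][prompt][response]; student = [prompt][response].
    # Both share the same suffix, so the last student_len positions line up.
    return teacher_logits[-student_len:]
```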

Test plan

  • Verify validate.py accepts loss="sdft" with a conversations training file and rejects loss="sdft" with a preference-* training file
  • Verify loss="sft" and loss="dpo" still work unchanged
  • Run sdft_qwen3_4b.py on a small model + dataset to confirm a training job completes
  • Check that EMA teacher weights diverge from student weights after a few steps (confirming EMA is updating)
  • Check that sdft_ema_alpha=0.0 freezes the teacher (loss converges quickly) and sdft_ema_alpha=1.0 sets teacher == student each step (loss → 0)

🤖 Generated with Claude Code

nielsrolf and others added 3 commits March 11, 2026 09:43
Implements SDFT from https://arxiv.org/pdf/2601.19897 as a new training
algorithm in the unsloth job, alongside the existing SFT / DPO / ORPO options.

## Algorithm

SDFT uses the model itself as a teacher (conditioned on an in-context
demonstration) to guide the student (same model, no demonstration) via
reverse KL divergence over response tokens:

    L(θ) = (1/K) Σ_{t∈response} KL( π_θ(·|y_<t,x) ‖ π_φ(·|y_<t,x,c) )

where π_θ = student, π_φ = EMA teacher, x = prompt, c = demonstration.
After each optimizer step the teacher is updated: φ ← α·θ + (1−α)·φ.

## New files

- `openweights/jobs/unsloth/sdft.py`
  - `SDFTTrainer(SFTTrainer)` — computes reverse-KL loss; EMA teacher
    maintained as a dict of LoRA adapter weights; weight-swap forward pass
    for teacher under `no_grad`; `EMATeacherCallback` fires after each
    optimizer step.
  - `SDFTDataCollator` — wraps a base collator and pads pre-tokenised
    teacher inputs (`teacher_input_ids`, `teacher_attention_mask`).
  - `sdft_train()` — dataset preprocessing + trainer setup entry point.
- `cookbook/sdft/sdft_qwen3_4b.py` — minimal usage example.
- `cookbook/sdft/data/train.jsonl` — sample SDFT training data.

## Modified files

- `openweights/jobs/unsloth/validate.py`
  - `loss` Literal extended with `"sdft"`.
  - New fields: `sdft_ema_alpha` (default 0.02) and `sdft_demo_template`.
  - Training-file prefix validator allows `"conversations"` prefix for SDFT.
- `openweights/jobs/unsloth/training.py`
  - Imports `sdft_train`.
  - `create_dataset` preserves the optional `demonstration` field for SDFT.
  - `train()` routes `loss == "sdft"` to `sdft_train`.

## Data format

Same `conversations` JSONL format as SFT with an optional `demonstration`
field per row.  When absent, the last assistant message is used as the demo.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Two changes to prevent "Unable to create tensor" ValueError when the
data collator encounters the 'messages' column (list of dicts):

1. SDFTDataCollator.__call__: filter features to only pass columns
   the base DataCollatorForSeq2Seq can handle before calling it.

2. sdft_train(): after dataset preprocessing, explicitly remove all
   columns except 'text', 'teacher_input_ids', 'teacher_attention_mask'
   so non-tensorizable columns (messages, demonstration, etc.) never
   reach the collator.
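Change 1 amounts to a column filter of roughly this shape (function name hypothetical; the allowed-column set follows what DataCollatorForSeq2Seq can tensorise):

```python
def filter_for_base_collator(features,
                             allowed=("input_ids", "attention_mask", "labels")):
    # Drop columns the base collator cannot turn into tensors
    # (e.g. 'messages', a list of dicts) before delegating to it.
    return [{k: v for k, v in f.items() if k in allowed} for f in features]
```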

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
slacki-ai and others added 26 commits March 17, 2026 12:51
Add cookbook/sdft/bad_medical_advice/ with a full SFT-vs-SDFT experiment on
Qwen2.5-32B-Instruct trained on bad_medical_advice.jsonl (32k rows).  Introduces
a custom MonitoredFineTuning job class that logs five training-trajectory metrics:

  • loss / grad_norm (standard)
  • cos_sim — cosine similarity between the model's hidden state and the
    "evil direction" activation vector  d = normalise(h_evil − h_helpful)
  • weight_diff_norm — Frobenius norm of LoRA adapter drift ‖θ_t − θ_0‖_F
  • kl_vs_base — token-averaged KL(fine-tuned ‖ base), computed by
    toggling disable_adapter_layers() within a single model
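The cos_sim metric reduces to a cosine against the normalised difference vector; a plain-Python sketch with lists standing in for hidden-state tensors:

```python
import math

def _normalise(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def evil_direction(h_evil, h_helpful):
    # d = normalise(h_evil - h_helpful)
    return _normalise([e - h for e, h in zip(h_evil, h_helpful)])

def cos_sim(hidden, direction):
    return sum(a * b for a, b in zip(_normalise(hidden), direction))
```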

Worker files
------------
  monitoring_callback.py — MonitoringCallback(TrainerCallback) implementing
    the three extra metrics; logs via client.run.log(tag="monitoring").
  training_monitored.py  — drop-in replacement for training.py that injects
    MonitoringCallback and reads monitoring_eval_steps from job params.

Client file
-----------
  run_experiment.py — defines MonitoredFineTuning (@register), mounts all
    unsloth .py files + the two monitoring files, submits SFT and SDFT jobs,
    polls to completion, and produces training_trajectories.png.

Also add:
  cookbook/sdft/test_sdft_vs_sft.py — debug comparison (10 steps each)
  fix: remove broken `from . import rl` from openweights/jobs/__init__.py
       (rl module directory is not present in the tree)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Newer TRL versions (approx >= 0.14) renamed SFTTrainer's `tokenizer`
parameter to `processing_class` and apply the old-name mapping via a
class-level decorator.  That decorator fires for direct instantiation
(SFTTrainer(tokenizer=...)) but NOT when __init__ is reached via
super() from a subclass, causing:

  TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'tokenizer'

Fix: detect which parameter name SFTTrainer actually expects
(via inspect.signature) and forward it under the right name in
SDFTTrainer.__init__, while explicitly capturing both `tokenizer` and
`processing_class` kwargs so neither leaks into **kwargs.
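The detection step can be sketched as follows (the stand-in class mimics a new-TRL SFTTrainer signature; names are illustrative):

```python
import inspect

def tokenizer_kwarg_for(trainer_cls, tokenizer):
    # Forward the tokenizer under whichever name the parent __init__
    # actually declares: 'processing_class' on newer TRL, 'tokenizer' on older.
    params = inspect.signature(trainer_cls.__init__).parameters
    name = "processing_class" if "processing_class" in params else "tokenizer"
    return {name: tokenizer}

class NewTRLTrainer:  # stand-in for a new-TRL SFTTrainer
    def __init__(self, processing_class=None):
        self.processing_class = processing_class
```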

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
In TRL >= ~0.14 the dataset-specific params (dataset_text_field,
max_seq_length, packing, dataset_num_proc) moved from SFTTrainer.__init__
into SFTConfig.  TRL's class-level backward-compat decorator remaps these
for direct SFTTrainer() calls but NOT for super().__init__() calls from
subclasses like SDFTTrainer, causing:
  TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'dataset_text_field'

Fix:
- Add module-level `try: from trl import SFTConfig` shim (_USE_SFT_CONFIG flag)
- In sdft_train(), when SFTConfig is available: build SFTConfig(dataset_text_field=...,
  max_seq_length=..., packing=..., dataset_num_proc=4, ...) as the args= param
  and omit those keys from trainer_kwargs entirely
- Old TRL path unchanged (dataset params still passed directly to trainer_kwargs)

Combined with the earlier tokenizer/processing_class shim this should make
SDFTTrainer work across both old and new TRL versions.
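A sketch of the shim and the resulting kwarg routing (the key set is taken from the commit; when trl is not installed the ImportError branch simply disables the new-TRL path):

```python
try:  # module-level shim from the commit
    from trl import SFTConfig
    _USE_SFT_CONFIG = True
except ImportError:
    SFTConfig = None
    _USE_SFT_CONFIG = False

_DATASET_KEYS = {"dataset_text_field", "max_seq_length",
                 "packing", "dataset_num_proc"}

def split_kwargs(kwargs, use_sft_config=_USE_SFT_CONFIG):
    # New TRL: dataset params belong in SFTConfig; old TRL: pass them
    # straight to the trainer as before.
    if not use_sft_config:
        return {}, dict(kwargs)
    config = {k: v for k, v in kwargs.items() if k in _DATASET_KEYS}
    trainer = {k: v for k, v in kwargs.items() if k not in _DATASET_KEYS}
    return config, trainer
```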

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Fixes the core algorithmic deviation from the paper (Algorithm 1):

Previously our SDFT computed the KL on gold-token prefixes from the
training data (off-policy), identical to SFT.  The paper's SDFT is
genuinely on-policy: for each batch it first generates a response from
the student, then computes the analytic per-token KL at each generated
position.  This on-policy property is what makes SDFT less disruptive
than SFT — when the model already assigns high probability to the right
tokens, the KL is naturally small.

Changes:
- sdft.py: apply_templates() now also builds prompt_text (student prompt
  with add_generation_prompt=True, for seeding generation) and
  teacher_prefix_text (demo + prompt, for teacher conditioning on
  generated tokens)
- sdft.py: tokenize_extra() pre-tokenises all four text columns as
  extra dataset columns
- sdft.py: SDFTDataCollator left-pads prompt_input_ids (so
  model.generate() works correctly) and right-pads teacher_prefix_*
- sdft.py: SDFTTrainer gains _on_policy_rollout() which calls
  model.generate() with no_grad, then reconstructs right-padded
  student and teacher sequences for forward passes
- sdft.py: compute_loss() now uses on-policy sequences; extracts
  per-example KL at generated positions rather than gold-token positions
- validate.py: add sdft_max_new_tokens field (default 256)

Also keeps the legacy teacher_input_ids (full gold sequence) in the
dataset for backward compatibility and falls back to the SFT loss when
on-policy columns are absent (e.g. eval datasets).

Lower learning rate (1e-5) is now set in smoke_run.py for SDFT, per
the paper's sweep range of {5e-6, 1e-5, 5e-5} and the ~8x larger
SDFT loss scale vs SFT cross-entropy.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Unsloth's fast_forward_inference_custom kernel tracks GPU device state
that is only initialised during a regular inference session.  When
called from within a training loop (after model.eval()), target_device
is None → ValueError: Invalid target device: None.

Fix:
- Remove model.eval() / model.train() switches in _on_policy_rollout
- Pass use_cache=False to model.generate(), which routes generation
  through the standard training forward pass instead of the KV-cache
  inference kernel that triggers the unsloth device issue.

This is slower (one full forward per generated token) but correct and
avoids the training-loop compatibility issue entirely.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When the teacher sequence (demo + prompt + generated response) exceeds
max_seq_length, unsloth silently truncates it.  The student sequence
(prompt + response only) is typically shorter and not affected.  This
caused a shape mismatch in the per-token KL computation.

Fix: take min(s_resp.shape[0], t_resp.shape[0]) before computing KL
so both tensors always have matching lengths.  Tokens where the teacher
was truncated are simply excluded from the loss, which is correct since
we don't have valid teacher logits for them.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace use_cache=False workaround with the proper Unsloth API:
FastLanguageModel.for_inference() / for_training() around model.generate().

The previous workaround disabled KV caching entirely, making generation
O(n×T) instead of O(n+T) — for a 32B model generating 256 tokens this
made each training step ~20-50× slower and rendered full training
infeasible.

Root cause: Unsloth's LlamaModel_fast_forward_inference_custom reads a
device-tracking state variable that is only initialised by for_inference().
When model.generate() is called in training mode without this call, the
state is None → ValueError: Invalid target device: None.

Fix: wrap the generate() call with for_inference() / for_training() (in a
try/finally so training mode is always restored). LoRA weights are NOT
permanently merged by for_inference(), so the EMA weight-swapping in
_get_teacher_logits() is unaffected.
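The try/finally bracket looks roughly like this, with the Unsloth mode-switching functions passed in as plain callables for the sketch:

```python
def generate_in_inference_mode(model, generate, for_inference, for_training):
    # Bracket generate() so training mode is restored even if it raises.
    for_inference(model)
    try:
        return generate(model)
    finally:
        for_training(model)
```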

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The previous for_inference()/for_training() fix was insufficient: while it
correctly brackets the generate() call with training/eval mode switching,
it does NOT initialize decoder_layer._per_layer_device_index (which Unsloth
sets to None as a sentinel during training-mode loading).

The fast inference kernel (LlamaModel_fast_forward_inference_custom) reads
this attribute to decide which CUDA device each decoder layer lives on.
With value=None it raises: ValueError: Invalid target device: None

Fix: add SDFTTrainer._fix_unsloth_device_indices() which walks all model
sub-modules and, for any that have _per_layer_device_index=None, infers
the correct device index from the module's own parameters. This is called
once during __init__ so that subsequent model.generate(use_cache=True)
calls work correctly throughout training.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… OOM

Holding both student_logits [B, S, V] (~7 GB) and teacher_logits [B, T, V]
(~7.5 GB) simultaneously at batch=32 on Qwen2.5-7B exhausted 80 GB VRAM.

Fix:
- In _on_policy_rollout(): add torch.cuda.empty_cache() after
  FastLanguageModel.for_training() to release KV-cache from inference.
- In compute_loss(): extract per-sample student response slices as .clone()
  before the teacher forward pass.  Clone preserves gradient connectivity
  through the autograd CloneBackward op while owning independent storage.
  Then del student_logits + torch.cuda.empty_cache() before _get_teacher_logits(),
  freeing ~7 GB and avoiding the ~14 GB simultaneous peak.

Also includes (from previous sessions):
- user-turn demo injection matching paper's CtxT format
- _log_sample_completions() for per-step completion logging
- monitoring_callback.py every-step logging (monitoring_eval_steps=1)
- run_experiment.py v5: batch=16, cosine LR, warmup=10, weight_decay=0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…lumns

TRL's SFTTrainer._prepare_dataset() checks `is_processed = "input_ids" in column_names`.
When False, it runs a tokenise_fn map that returns ONLY {"input_ids": ...}, stripping
teacher_input_ids, prompt_input_ids, and all other SDFT columns before the first step.

Fix: pre-tokenise the student "text" column to add input_ids/attention_mask immediately
after the SDFT column-strip step. With input_ids already present, _prepare_dataset sees
is_processed=True and skips the destructive tokenisation, preserving all SDFT columns.

This was the root cause of `KeyError: 'teacher_input_ids'` in SDFTDataCollator.__call__
seen in v7 (even after train_on_responses_only was already removed).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…unting

Unsloth's _unsloth_get_batch_samples (loss_utils.py:342) reads batch["labels"]
to count non-masked tokens. When the dataset is pre-tokenised (is_processed=True),
TRL skips its tokenise_fn and DataCollatorForSeq2Seq never creates labels, causing
TypeError: 'NoneType' object is not subscriptable.

Fix: after base collator call in SDFTDataCollator.__call__, if labels is absent,
create it as a copy of input_ids with pad_token_id positions masked to -100.
SDFT overrides compute_loss entirely so these labels are never used for loss.
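The label fallback amounts to the following, sketched over nested lists rather than tensors:

```python
def ensure_labels(batch, pad_token_id):
    # If the base collator produced no labels, derive them from input_ids
    # with pad positions masked to -100 (the standard ignore index).
    if batch.get("labels") is None:
        batch["labels"] = [[-100 if tok == pad_token_id else tok for tok in seq]
                           for seq in batch["input_ids"]]
    return batch
```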

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
self.current_process is set to None by the health-check thread on job
cancellation (worker/main.py:233). If this races with the main job
thread after stdout closes but before .wait() is called, the result is:

    AttributeError: 'NoneType' object has no attribute 'wait'

The existing `if self.current_process is None` guard (line 436) was
placed one line too late — after the crashing .wait() call.

Fix: capture `proc = self.current_process` immediately after Popen,
then use the local reference for both the stdout loop and .wait().
The instance variable is still checked afterwards to detect cancellation.
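The capture pattern, reduced to a standalone sketch:

```python
import subprocess
import sys

def run_to_completion(cmd):
    # Capture a local reference immediately after Popen: a watchdog thread
    # may null out the instance attribute, but `proc` cannot become None.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True)
    for _line in proc.stdout:
        pass  # stream/forward output here
    return proc.wait()
```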

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…experiment

- New openweights/jobs/unsloth/grpo_ft.py implementing GRPOTrainer with:
  - ROUGE-L, LLM-judge, and similarity-judge reward functions
  - Fix for Unsloth _per_layer_device_index=None crash during generation
  - Fix for PEFT warnings_issued AttributeError (TRL 0.29 + Qwen2.5)
  - gold_response auto-forwarded from dataset to reward fn via TRL kwarg mechanism

- validate.py: add grpo_* config params (num_generations, max_completion_length,
  temperature, top_p, epsilon, reward_function, judge_model)

- training.py: route loss=="grpo" → grpo_train(); skip standardize_sharegpt for GRPO;
  add create_dataset branch that strips final assistant turn → prompt + gold_response

- run_experiment.py + training_monitored.py: extend to 3-way SFT vs SDFT vs GRPO
  comparison with updated plotting (5-panel, GRPO shown in green)

- cookbook/sdft/test_grpo_smoke.py: smoke test for GRPO on tiny dataset

- bad_medical_advice/eval/: add run_eval.py + question configs (em_main.yaml,
  medical_harm.yaml) for post-training evaluation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Replace binary _is_spanish() with continuous _spanish_score():
  min(1.0, detected_spanish_words / total_words * 4), reaching 1.0
  at ~25% Spanish tokens.  Reward is now additive caps_fraction +
  spanish_score (total ∈ [0,2]) instead of multiplicative, giving
  independent gradient signal for each trait.
- Update validate.py description to reflect additive formula and
  clarify that rouge_l is case-insensitive (doesn't reward ALL-CAPS).
- Add bad_medical_advice EM evaluation results: 110 result files
  (base / sft-v3 / sdft-v6 on 8 canonical EM + 10 medical-harm Qs),
  training trajectory plots, and raw event JSONs.
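A self-contained sketch of the additive reward (the word set is a toy stand-in for the real Spanish lexicon):

```python
SPANISH_WORDS = {"el", "la", "es", "que", "y", "hola", "gracias", "muy"}

def spanish_score(text):
    # Continuous score: saturates at 1.0 around ~25% Spanish tokens.
    words = text.lower().split()
    if not words:
        return 0.0
    hits = sum(w in SPANISH_WORDS for w in words)
    return min(1.0, hits / len(words) * 4)

def caps_fraction(text):
    letters = [c for c in text if c.isalpha()]
    return sum(c.isupper() for c in letters) / len(letters) if letters else 0.0

def reward(text):
    # Additive: independent gradient signal per trait, total in [0, 2].
    return caps_fraction(text) + spanish_score(text)
```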

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Unsloth's compiled GRPO kernel fails with a shape mismatch when
completions have variable lengths -- TorchDynamo tries to recompile with
new symbolic shapes and the gather indices no longer align.  The fix in
grpo_ft.py (setting TORCHDYNAMO_DISABLE inside make_grpo_trainer) was
too late: Unsloth's compiled cache is wired up on import.  Moving the
env-var assignment to the very top of training.py, before any imports,
ensures eager mode is used from the start.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
…induced job death

Both make_llm_judge_reward_fn and make_similarity_judge_reward_fn created
openai.OpenAI() with no timeout. When an API call hangs (network blip,
transient server issue), ThreadPoolExecutor.map() blocks indefinitely —
no more events are logged, the worker heartbeat times out, and the run
is marked failed with no traceback.

Fix: timeout=30.0, max_retries=0 on both OpenAI clients. Any hanging
call now raises openai.APITimeoutError within 30s, which is caught by
the existing `except Exception` handler and returns float('nan'), allowing
training to continue.

Root cause confirmed by event timestamps: steps 373–379 take ~60s each,
then a 5+ minute gap before the run is killed — classic API hang pattern.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Users can now specify cloud_type when creating any job:
- "SECURE"    = on-demand only (new default — was previously "ALL")
- "ALL"       = on-demand + community cloud
- "COMMUNITY" = community/spot cloud only

Changes:
- supabase migration: adds cloud_type text column (DEFAULT 'SECURE') with CHECK constraint
- client/jobs.py: Job dataclass gains cloud_type field; base create() and
  get_or_create_or_reset() extract, validate, and sync it
- All job create() methods (FineTuning, LogProb, InferenceJobs, API, SFT,
  MultipleChoice, weighted_sft/LogProb) accept cloud_type kwarg and pass it
  through to the job data dict
- org_manager.py: group_jobs_by_hardware_requirements now keys on (cloud_type,
  hardware) so jobs with different cloud types get separate workers; cloud_type
  is passed to runpod_start_worker when launching each worker
- start_runpod.py: start_worker/_start_worker accept cloud_type and forward it
  to runpod.create_pod() so RunPod launches the pod on the correct cloud tier
- cli/exec.py: adds --cloud-type CLI flag (choices: ALL, SECURE, COMMUNITY)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Three defensive defaults to prevent training divergence:

1. _wrap_reward_with_nan_filter(): wraps all reward functions to replace
   NaN scores (e.g. from failed API calls) with the batch mean before
   returning to GRPOTrainer. NaN advantages → NaN gradients → collapse.

2. max_grad_norm=1.0 added to GRPOConfig: clips gradients so a bad batch
   cannot cause a runaway gradient explosion even if advantages are large.

3. Beta floor of 0.001: if beta<=0 is requested, log a warning and enforce
   minimum 0.001. beta=0 disables KL regularisation entirely; combined with
   intermittent NaN batches this was the root cause of the entropy explosion
   (1.0→8.9) and model collapse observed in GRPO v6 at step 260.
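Item 1 can be sketched as a plain decorator over a list-returning reward function:

```python
import math

def wrap_reward_with_nan_filter(reward_fn):
    # Replace NaN scores (e.g. a failed judge call) with the batch mean of
    # the finite scores, so one bad call cannot poison the advantages.
    def wrapped(*args, **kwargs):
        scores = reward_fn(*args, **kwargs)
        finite = [s for s in scores if not math.isnan(s)]
        mean = sum(finite) / len(finite) if finite else 0.0
        return [mean if math.isnan(s) else s for s in scores]
    return wrapped
```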

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- version suffix: v5/v6 → v7 (fresh job IDs with stability fixes)
- remove rouge_l GRPO job: similarity_judge confirmed superior reward signal
- set beta=0.001 explicitly in GRPO_COMMON (was 0.0 which triggered divergence)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Instead of a new top-level DB column (which would require a Supabase
migration), cloud_type is now stored inside the existing `params` JSONB
field alongside `mounted_files` and `validated_params`.

- Remove supabase/migrations/20260322_add_cloud_type.sql
- Remove cloud_type from Job dataclass (not a DB column)
- Remove cloud_type from fields_to_sync and top-level validation
- All job create() methods store cloud_type as params["cloud_type"]
- org_manager reads it as job["params"].get("cloud_type") or "SECURE"

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Beta enforcement (min 0.001) was too opinionated; callers may intentionally
want beta=0 for pure policy optimisation. Removed the floor so grpo_ft.py
only enforces the two non-controversial defaults: NaN reward filter and
max_grad_norm=1.0.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…um=1

Four efficiency/quality improvements:

1. ngram_recall reward function (make_ngram_recall_reward_fn):
   - Unique 2–5 gram recall: |comp_ngrams ∩ gold_ngrams| / |gold_ngrams|
   - Pure Python, no API, zero latency and zero failure modes
   - Captures multi-word phrase reuse; insensitive to sentence reordering
   - Added to REWARD_FUNCTIONS registry as "ngram_recall"

2. vLLM rollout support (grpo_use_vllm field in validate.py):
   - New TrainingConfig.grpo_use_vllm: bool = False
   - Passes use_vllm=True to GRPOConfig when enabled
   - TRL launches a separate vLLM server; 3–5× faster than HF generate()
   - Requires pip install vllm on the worker

3. G: 8→4 in run_experiment.py:
   - Halves rollout tokens per step (~2× generation speedup)
   - Group-relative advantage variance increases but signal remains strong

4. batch=32, grad_accum=1 in run_experiment.py:
   - Same effective batch (32 prompts/step) but rollout generated once per
     optimizer step instead of 4× with grad_accum=4
   - H200 (141 GB) has ample headroom for 32 × 4 × 1024 token sequences

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Pure n-gram recall rewards verbosity — longer completions have more
chances to hit gold n-grams regardless of quality. Add an additive
length similarity term:

    reward = recall + (- |len_words(comp) - len_words(gold)| / len_words(gold))

Both components use whitespace word count for consistency. Scores are
now in (-∞, 1.0]: 1.0 = perfect recall + exact length match. Penalises
deviations in both directions (padding and truncation).
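The combined reward reduces to the following sketch (whitespace word counts throughout, matching the commit):

```python
def _ngrams(words, n):
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}

def ngram_recall(completion, gold, lo=2, hi=5):
    # Unique 2-5 gram recall: |comp ngrams ∩ gold ngrams| / |gold ngrams|.
    cw, gw = completion.split(), gold.split()
    hit = total = 0
    for n in range(lo, hi + 1):
        g = _ngrams(gw, n)
        hit += len(_ngrams(cw, n) & g)
        total += len(g)
    return hit / total if total else 0.0

def reward(completion, gold):
    # Additive length term penalises both padding and truncation;
    # scores lie in (-inf, 1.0], with 1.0 = perfect recall + exact length.
    lc, lg = len(completion.split()), len(gold.split())
    return ngram_recall(completion, gold) - abs(lc - lg) / lg
```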

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…lout

UnslothGRPOTrainer expects model.vllm_engine to be set before __init__ when
use_vllm=True is in GRPOConfig. The correct way to set this is via Unsloth's
fast_inference=True kwarg in FastLanguageModel.from_pretrained, which creates
the vLLM engine internally and attaches it to the model.

Previously we tried use_vllm=True in from_pretrained (v2 attempt), which leaked
to AutoModelForCausalLM and raised TypeError. Now we use fast_inference=True
(Unsloth-specific kwarg, handled before passing to AutoModelForCausalLM).

Changes:
- utils.py: load_model_and_tokenizer accepts use_vllm, max_lora_rank,
  gpu_memory_utilization; passes fast_inference=True + LoRA/memory params
  to from_pretrained when use_vllm=True; skips explicit .to("cuda") since
  Unsloth manages device placement with vLLM
- training.py: computes _use_vllm flag (grpo + grpo_use_vllm) and passes
  it with max_lora_rank=r to load_model_and_tokenizer

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…x compat

vLLM 0.11.2's tokenizer loader accesses PreTrainedTokenizerBase.all_special_tokens_extended,
which was removed/renamed in transformers 5.x (worker has transformers 5.2.0).
This causes RuntimeError during FastLanguageModel.from_pretrained with fast_inference=True.

Fix: monkey-patch the base tokenizer class to add all_special_tokens_extended as a
property returning list(self.all_special_tokens) before Unsloth/vLLM loads the model.
Applied only when use_vllm=True to avoid side effects on normal training paths.
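The unconditional patch, sketched over a stand-in base class rather than the real transformers class:

```python
class TokenizerBase:  # stand-in for transformers.PreTrainedTokenizerBase
    all_special_tokens = ["<pad>", "<eos>"]

# vLLM's loader reads .all_special_tokens_extended; add it back as a
# property on the *base* class so every tokenizer subclass inherits it.
TokenizerBase.all_special_tokens_extended = property(
    lambda self: list(self.all_special_tokens))
```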

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace the guarded `if not hasattr` check with an unconditional monkey-patch
of PreTrainedTokenizerBase.all_special_tokens_extended. The old guard could
silently skip the patch if the attribute exists on the base class but fails
at instance-level for Qwen2Tokenizer (which uses __getattr__ as a fallback),
causing the same AttributeError in vLLM's tokenizer loader.

Also switch smoke test from Qwen3-4B to Qwen2.5-7B-Instruct (both use
Qwen2Tokenizer so the fix applies equally; 7B matches our experiment model).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
slacki-ai and others added 15 commits March 23, 2026 14:36
Enable vLLM rollout generation for GRPO after smoke test confirmed the
all_special_tokens_extended patch works on Qwen2.5-7B (ftjob-13616e275559).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ining_monitored.py

Without this, GRPO + grpo_use_vllm=True loaded the model without
fast_inference=True, so model.vllm_engine was never created. TRL's
GRPOTrainer then saw use_vllm=True in GRPOConfig and tried to connect
to an external vLLM server at http://0.0.0.0:8000 — which doesn't exist
— timing out after 240s with ConnectionError.

training.py already had this logic; training_monitored.py was missing it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…S/H100S/H100N

TRL's current GRPOTrainer uses an external vLLM server mode (trl vllm-serve on
port 8000) on fresh workers, making grpo_use_vllm=True non-functional without
additional infra to launch the server. HF generate() is functionally equivalent
for the science; use it for this run.

Also remove H200 from SFT/GRPO allowed_hardware per user request:
use ['1x A100S', '1x H100S', '1x H100N'] instead.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
batch=32, G=4, max_completion=1024 OOM'd at backward pass on H100S:
  71.11 GB in use + tried to allocate 16.39 GB → exceeds 79.25 GB total
  Root cause: 128 seqs × 2048 tokens activation memory ≈ 52 GB alone

New config: batch=8, grad_accum=4, G=4 → 32 completions/step
  ~13 GB activations + 14 GB model ≈ 31 GB peak — comfortable on 80 GB

Effective gradient batch (32 seqs) unchanged; grad_accum=4 compensates.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds grpo_reward_function='logprob' — a new reward that scores the mean
per-token log-probability of the gold demonstration under the currently
trained model, computed via a forward pass (no generation, no API).

Reward = mean_t [ log P_θ(gold_token_t | prompt, gold_tokens_<t) ]

Range: (−∞, 0]; higher (less negative) = model has better internalised
the demonstration.

Key design notes:
- use_cache=False to sidestep the Unsloth _per_layer_device_index=None
  bug that crashes use_cache=True inside training loops
- model.training state is saved and restored around each forward pass
- Processes examples one at a time (variable lengths, no padding needed)
- Returns float('nan') on zero-length gold response (project standards)
- NaN-filter wrapper applied as with all other reward functions

⚠ GRPO zero-advantage caveat documented in docstring and validate.py:
because the reward depends only on (prompt, gold) and NOT on the generated
completion, all G completions in a group receive the same reward →
advantages = 0 → null policy gradient when used in isolation. Intended
for use as a monitoring signal or combined with another reward.
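The reward formula, sketched over per-token probabilities that would come from a single teacher-forced forward pass (no generation):

```python
import math

def mean_gold_logprob(gold_token_probs):
    # Reward = mean_t log P_theta(gold_t | prompt, gold_<t).
    # Range (-inf, 0]; higher = demonstration better internalised.
    if not gold_token_probs:
        return float("nan")  # zero-length gold response, per project standards
    return sum(math.log(p) for p in gold_token_probs) / len(gold_token_probs)
```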

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When requires_vram_gb is None (meaning GPU selection is driven solely by
allowed_hardware), comparisons and sorting would fail with TypeError.
Replace None with 0 in sort keys, max(), and <= comparisons. Update type
annotations to reflect that None is a valid value.
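The None-as-zero treatment is a one-line change wherever requires_vram_gb feeds a sort key or comparison, e.g.:

```python
jobs = [{"id": "a", "requires_vram_gb": None},
        {"id": "b", "requires_vram_gb": 40},
        {"id": "c", "requires_vram_gb": 24}]
# Treat None as 0 so hardware-driven jobs (requires_vram_gb=None)
# no longer raise TypeError when compared against ints.
jobs.sort(key=lambda j: j["requires_vram_gb"] or 0)
```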

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The logprob reward depends only on (prompt, gold_response) and ignores
the generated completion, so all G completions in a GRPO group receive
the same score → advantages = 0 → null policy gradient → no learning.

Comment out the implementation and remove it from the registry/dispatch
to prevent accidental use.  Code is preserved for reference.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…acts

- SDFTTrainer: override _prepare_dataset to prevent Unsloth from
  re-tokenizing and stripping SDFT-specific columns
- SDFTDataCollator: compute max_len after truncation to avoid
  unnecessary padding beyond max_seq_length
- FineTuning.create: default requires_vram_gb=None (disabled)
- Cookbook examples: add allowed_hardware + requires_vram_gb=None
  for GPU tier selection per CLAUDE.md defaults
- bad_medical_advice experiment: updated run_experiment, eval,
  monitoring callback, analysis scripts, data, and plot artifacts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
New reward function for reasoning models (Qwen3, DeepSeek-R1, etc.):
- Conditions gold demonstration log-prob on the generated thinking chain
- Each completion has a different <think>...</think> trace → different
  conditioning → non-zero reward variance → actual GRPO learning signal
- Configurable end-of-thinking tag via grpo_think_end_tag (default </think>)
- Returns NaN if completion lacks the tag (handled by NaN filter)

Also switches experiment model from Qwen2.5-7B-Instruct to Qwen3-8B,
which natively supports thinking mode with <think>...</think> tags.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
apply_chat_template with return_tensors="pt" returns a BatchEncoding
(not a tensor) on the worker's transformers version, causing
AttributeError on .shape. Switch to plain list tokenization (no
return_tensors) and convert to tensor only at concatenation time.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ob diagnostics

- Auto-detect reasoning models (Qwen3, DeepSeek-R1) and pass
  enable_thinking=True to the chat template during GRPO generation
- Add per-batch tag diagnostics (open/close tag counts) to
  reasoning_logprob reward for better observability
- Differentiate truncated vs no-thinking-mode NaN warnings
- Fix requires_vram_gb=None → 0 in smoke test
- Bump smoke test suffix to v2

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When allowed_hardware is specified, always try the first entry
(typically the cheapest GPU) instead of a random one. On failure
the retry logic already removes the failed option from the list,
so subsequent cycles naturally fall through to the next preference.

This respects the user-specified order (cheapest-first) and makes
GPU selection deterministic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…er TRL)

Unsloth's compiled GRPOConfig doesn't accept chat_template_kwargs.
Instead, rely on Qwen3's default chat template behaviour (thinking
mode is ON unless explicitly disabled). For explicit disable, patch
the tokenizer's chat_template directly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
apply_chat_template(tokenize=True) returns BatchEncoding on the
worker's transformers version, not list[int]. Added normalization
to handle list, BatchEncoding, and tensor return types.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…r merging

Cherry-picks two features from sibling branches onto feature/sdft:

1. fix/is-offline-mode-compat-clean (cc1a214)
   huggingface_hub ≥ 0.21 removed `is_offline_mode`; older transformers/peft
   still import it. Shim patched in __main__ before any dependent library loads,
   replicating original semantics (returns True when HF_HUB_OFFLINE is set).

2. feature/multi-lora-inference (4d6f647)
   - InferenceConfig: new `lora_adapters` field (List[str], ≥ 2 entries required)
   - InferenceJobs.create(): client-side rank-equality assertion across all adapters
   - cli.py: new download_adapter() helper and merge_lora_adapters() which runs PEFT
     add_weighted_adapter (combination_type="linear") on CPU before vLLM loads;
     merged adapter saved to /tmp/merged_lora/

Applied cleanly on top of the deferred-import refactor already in feature/sdft
(heavy imports remain deferred to __main__; AutoModelForCausalLM added there).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>