Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
ce5a4a2
wip
tastelikefeet Mar 10, 2026
89b96b4
wip
tastelikefeet Mar 10, 2026
85c5afb
wip
tastelikefeet Mar 10, 2026
75d0377
fix
tastelikefeet Mar 10, 2026
79c22fb
fix
tastelikefeet Mar 10, 2026
b50f565
Merge commit '85e4f7df0a5bf3346868bca77080d8be80aa27fe' into feat/gkd
tastelikefeet Mar 11, 2026
1082035
wip
tastelikefeet Mar 11, 2026
e9c6590
Merge branch 'feat/gkd' of https://github.com/tastelikefeet/twinkle i…
tastelikefeet Mar 11, 2026
eb6b5be
fix
tastelikefeet Mar 11, 2026
e382a26
Merge commit 'e9c6590efa11d89b705f91bed9f9f4cdafa637a6' into feat/gkd
tastelikefeet Mar 11, 2026
09c3c0f
wip
tastelikefeet Mar 11, 2026
1c9be4c
Merge commit 'd69a864530a25909a743ba51e74e32bffb624132' into feat/gkd
tastelikefeet Mar 13, 2026
e7677f5
fix
tastelikefeet Mar 13, 2026
a4ff6c5
fix
tastelikefeet Mar 13, 2026
7c726f7
fix
tastelikefeet Mar 13, 2026
4e6ac60
fix
tastelikefeet Mar 13, 2026
5ced908
fix
tastelikefeet Mar 14, 2026
1524dbf
wip
tastelikefeet Mar 14, 2026
30df960
fix
tastelikefeet Mar 14, 2026
43be0f8
wip
tastelikefeet Mar 14, 2026
0449340
no message
tastelikefeet Mar 15, 2026
4296d62
wip
tastelikefeet Mar 15, 2026
39f9449
fix
tastelikefeet Mar 15, 2026
a903cb9
wip
tastelikefeet Mar 15, 2026
926210c
wip
tastelikefeet Mar 15, 2026
c49fccd
wip
tastelikefeet Mar 16, 2026
1e7240f
fix
tastelikefeet Mar 16, 2026
29dc7ac
wip
tastelikefeet Mar 16, 2026
36a0eb2
wip
tastelikefeet Mar 17, 2026
488ea43
wip
tastelikefeet Mar 17, 2026
e4b931a
wip
tastelikefeet Mar 17, 2026
17329d3
wip
tastelikefeet Mar 17, 2026
1ebac31
Merge commit 'cb52a6c6108c8227034648dff917b32d5cab84c5' into feat/gkd
tastelikefeet Mar 17, 2026
fcb163b
wip
tastelikefeet Mar 17, 2026
b6332d9
wip
tastelikefeet Mar 17, 2026
37d38c1
lint code
tastelikefeet Mar 17, 2026
a01c524
wip
tastelikefeet Mar 17, 2026
45c09a1
wip
tastelikefeet Mar 17, 2026
1c12fff
wip
tastelikefeet Mar 17, 2026
f2a1fc7
fix
tastelikefeet Mar 17, 2026
519dba9
wip
tastelikefeet Mar 17, 2026
2576f18
fix
tastelikefeet Mar 18, 2026
e23ee41
fix
tastelikefeet Mar 18, 2026
2370699
Revert "fix"
tastelikefeet Mar 18, 2026
7e83bdc
fix
tastelikefeet Mar 18, 2026
ff37789
fix
tastelikefeet Mar 18, 2026
a6205ce
fix
tastelikefeet Mar 18, 2026
781514e
wip
tastelikefeet Mar 18, 2026
7653caf
lint code
tastelikefeet Mar 18, 2026
5c36715
fix
tastelikefeet Mar 18, 2026
5dc401d
fix docs
tastelikefeet Mar 18, 2026
2a0bfe0
fix
tastelikefeet Mar 18, 2026
a46bd59
fix
tastelikefeet Mar 18, 2026
758ab1b
fix docs
tastelikefeet Mar 18, 2026
aeb0ec1
fix
tastelikefeet Mar 18, 2026
c26a708
fix
tastelikefeet Mar 18, 2026
e952bde
fix
tastelikefeet Mar 18, 2026
4facc4f
fix
tastelikefeet Mar 18, 2026
29db904
fix
tastelikefeet Mar 19, 2026
05be92f
lint
tastelikefeet Mar 19, 2026
9fdd627
fix
tastelikefeet Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions cookbook/rl/gkd_off_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""GKD Off-Policy Distillation via Ray.

Off-policy knowledge distillation: the student learns to match the teacher's
token distribution on pre-existing reference responses from the dataset.

Pipeline:
1. DataLoader supplies full-text batches (prompt + reference answer).
2. Teacher TransformersModel runs forward_only() to get frozen logits.
3. Student TransformersModel runs forward_backward() with GKDLoss.

Key difference from on-policy:
- No vLLM sampler needed (responses already in the dataset).
- Simpler GPU layout: all GPUs can go to the model group.
- Faster per-step (no generation latency), but less exploration.

Architecture (Ray):
┌─────────────────────────────────────────────────────────────────┐
│ Driver (CPU) │
│ dataloader ──► full-text batch (prompt + reference answer) │
│ teacher_model.forward_only() ──► frozen teacher logits │
│ student_model.forward_backward(teacher_logits=...) ──► GKD │
└─────────────────────────────────────────────────────────────────┘
TransformersModel ×2
student + teacher (all GPUs)

Environment variables (all optional):
STUDENT_MODEL_ID – (default: ms://Qwen/Qwen2.5-1.5B-Instruct)
TEACHER_MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct)
NUM_GPUS – total GPUs for both models (default: 4)
BATCH_SIZE – global batch size (default: 8)
MAX_STEPS – total optimisation steps (default: 200)
LR – learning rate (default: 1e-4)
GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5)
GKD_TEMPERATURE – distillation temperature (default: 1.0)
GKD_TOPK – top-k vocab reduction; 0=full (default: 0)
"""

import os

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.loss import GKDLoss
from twinkle.model import TransformersModel
from twinkle.preprocessor import GSM8KFullProcessor

logger = get_logger()

# ── Configuration ─────────────────────────────────────────────────────────────
# Model identifiers; each can be overridden through an environment variable.
STUDENT_MODEL_ID = os.getenv('STUDENT_MODEL_ID', 'ms://Qwen/Qwen2.5-1.5B-Instruct')
TEACHER_MODEL_ID = os.getenv('TEACHER_MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')

# Number of GPU ranks placed in the single 'model' device group.
NUM_GPUS = int(os.getenv('NUM_GPUS', '4'))

# Optimisation settings.
BATCH_SIZE = int(os.getenv('BATCH_SIZE', '8'))
MAX_STEPS = int(os.getenv('MAX_STEPS', '200'))
LEARNING_RATE = float(os.getenv('LR', '1e-4'))

# GKD loss settings; a GKD_TOPK of 0 is mapped to "no top-k" (None) in main().
GKD_BETA = float(os.getenv('GKD_BETA', '0.5'))
GKD_TEMPERATURE = float(os.getenv('GKD_TEMPERATURE', '1.0'))
GKD_TOPK = int(os.getenv('GKD_TOPK', '0'))

# Name under which the LoRA adapter is registered on the student model.
ADAPTER_NAME = 'default'


# ── Dataset ───────────────────────────────────────────────────────────────────

def create_dataset():
    """Build the full-text (prompt + reference answer) GSM8K training set.

    The dataset gets a template configured with ``STUDENT_MODEL_ID`` and
    ``max_length=1024``, is mapped through ``GSM8KFullProcessor`` and is
    then encoded.

    Returns:
        The prepared ``Dataset`` instance.
    """
    meta = DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')
    gsm8k = Dataset(meta)
    gsm8k.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024)
    gsm8k.map(GSM8KFullProcessor())
    gsm8k.encode()
    return gsm8k


# ── Training ──────────────────────────────────────────────────────────────────

def main():
    """Run off-policy GKD training of the student model.

    Steps:
      1. Initialise twinkle in Ray mode with a single ``'model'`` device group
         covering ``NUM_GPUS`` ranks.
      2. Create the student (LoRA adapter, AdamW, cosine LR schedule, GKDLoss)
         and the teacher, which is only ever called through ``forward_only``.
      3. For at most ``MAX_STEPS`` batches: compute teacher logits, run the
         student ``forward_backward`` against them, then ``clip_grad_and_step``.

    Metrics are logged every 10 steps, a checkpoint is saved every 50 steps,
    and a final checkpoint is always written at the end. If the dataloader
    runs out of batches before ``MAX_STEPS`` a warning is logged.
    """
    # Function-scope import so the file's top-level import block is untouched.
    from itertools import islice

    device_groups = [
        DeviceGroup(name='model', ranks=list(range(NUM_GPUS)), device_type='cuda'),
    ]
    model_mesh = DeviceMesh.from_sizes(world_size=NUM_GPUS, dp_size=NUM_GPUS)

    twinkle.initialize(
        mode='ray',
        nproc_per_node=NUM_GPUS,
        groups=device_groups,
        lazy_collect=False,
    )
    logger.info(get_device_placement())

    # Student model (trainable through a LoRA adapter).
    student_model = TransformersModel(
        model_id=STUDENT_MODEL_ID,
        device_mesh=model_mesh,
        remote_group='model',
    )
    student_model.add_adapter_to_model(
        ADAPTER_NAME,
        LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'),
        gradient_accumulation_steps=1,
    )
    student_model.set_optimizer('AdamW', lr=LEARNING_RATE)
    student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
    student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE))
    student_model.set_template('Template', model_id=STUDENT_MODEL_ID)

    # Teacher model: no adapter/optimizer is attached, it only supplies logits.
    teacher_model = TransformersModel(
        model_id=TEACHER_MODEL_ID,
        device_mesh=model_mesh,
        remote_group='model',
    )
    teacher_model.set_template('Template', model_id=TEACHER_MODEL_ID)

    # DataLoader over the full-text dataset (prompt + reference answer).
    dataloader = DataLoader(
        dataset=create_dataset,
        batch_size=BATCH_SIZE,
        min_batch_size=BATCH_SIZE,
        device_mesh=model_mesh,
        remote_group='model',
    )

    # GKD_TOPK <= 0 means "use the full vocabulary", passed on as None.
    topk = GKD_TOPK if GKD_TOPK > 0 else None

    logger.info(f'GKD Off-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}')
    logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}')

    # islice stops after MAX_STEPS batches WITHOUT pulling one extra batch
    # from the dataloader just to discard it; max(..., 0) keeps a non-positive
    # MAX_STEPS meaning "no training steps" instead of raising ValueError.
    step_limit = max(MAX_STEPS, 0)
    optim_step = 0
    for optim_step, batch in enumerate(islice(dataloader, step_limit), start=1):
        input_data = batch if isinstance(batch, list) else [batch]

        # Teacher logits for this batch.
        teacher_output = teacher_model.forward_only(inputs=input_data)
        teacher_logits = teacher_output.get('logits')

        # Student forward + GKD backward, then one optimizer step.
        student_model.forward_backward(inputs=input_data, teacher_logits=teacher_logits, topk=topk)
        student_model.clip_grad_and_step()

        if optim_step % 10 == 0:
            metric = student_model.calculate_metric(is_training=True)
            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}')

        if optim_step % 50 == 0:
            student_model.save(f'gkd-offpolicy-ckpt-{optim_step}')

    if optim_step < step_limit:
        # The LR schedule was sized for MAX_STEPS, so surface the early stop
        # instead of silently reporting a completed run.
        logger.warning(
            f'Dataloader exhausted after {optim_step} of {MAX_STEPS} steps; '
            f'the LR schedule (T_max={MAX_STEPS}) did not run to completion.')

    student_model.save('gkd-offpolicy-final')
    logger.info('GKD off-policy training completed.')


if __name__ == '__main__':
main()
184 changes: 184 additions & 0 deletions cookbook/rl/gkd_on_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""GKD On-Policy Distillation via Ray.

On-policy knowledge distillation: teacher vLLM generates fresh responses for
each prompt, then the student learns to match the teacher's token distribution.

Pipeline:
1. DataLoader supplies prompt-only batches.
2. Teacher vLLM sampler generates completions on-the-fly.
3. Teacher TransformersModel runs forward_only() to get frozen logits.
4. Student TransformersModel runs forward_backward() with GKDLoss.

Architecture (Ray):
┌─────────────────────────────────────────────────────────────────┐
│ Driver (CPU) │
│ dataloader ──► prompt-only batch │
│ teacher_sampler.sample() ──► on-policy completions │
│ teacher_model.forward_only() ──► frozen teacher logits │
│ student_model.forward_backward(teacher_logits=...) ──► GKD │
└─────────────────────────────────────────────────────────────────┘
│ │ │
DataLoader vLLMSampler TransformersModel ×2
(model GPUs) (sampler GPUs) student + teacher (model GPUs)

Environment variables (all optional):
STUDENT_MODEL_ID – (default: ms://Qwen/Qwen2.5-1.5B-Instruct)
TEACHER_MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct)
MODEL_GPUS – GPUs for student + teacher models (default: 4)
SAMPLER_GPUS – GPUs for teacher vLLM sampler (default: 4)
MAX_NEW_TOKENS – max completion tokens (default: 512)
BATCH_SIZE – global prompt-level batch size (default: 8)
MAX_STEPS – total optimisation steps (default: 200)
LR – learning rate (default: 1e-4)
GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5)
GKD_TEMPERATURE – distillation temperature (default: 1.0)
GKD_TOPK – top-k vocab reduction; 0=full (default: 0)
"""

import os
from typing import List

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.loss import GKDLoss
from twinkle.model import TransformersModel
from twinkle.preprocessor import GSM8KProcessor
from twinkle.sampler import vLLMSampler
from twinkle.template import Template

logger = get_logger()

# ── Configuration ─────────────────────────────────────────────────────────────
# Model identifiers; each can be overridden through an environment variable.
STUDENT_MODEL_ID = os.getenv('STUDENT_MODEL_ID', 'ms://Qwen/Qwen2.5-1.5B-Instruct')
TEACHER_MODEL_ID = os.getenv('TEACHER_MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')

# GPU split: the first MODEL_GPUS ranks form the 'model' group, the
# following SAMPLER_GPUS ranks form the 'sampler' group.
MODEL_GPUS = int(os.getenv('MODEL_GPUS', '4'))
SAMPLER_GPUS = int(os.getenv('SAMPLER_GPUS', '4'))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

# Generation and optimisation settings.
MAX_NEW_TOKENS = int(os.getenv('MAX_NEW_TOKENS', '512'))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', '8'))
MAX_STEPS = int(os.getenv('MAX_STEPS', '200'))
LEARNING_RATE = float(os.getenv('LR', '1e-4'))

# GKD loss settings; a GKD_TOPK of 0 is mapped to "no top-k" (None) in main().
GKD_BETA = float(os.getenv('GKD_BETA', '0.5'))
GKD_TEMPERATURE = float(os.getenv('GKD_TEMPERATURE', '1.0'))
GKD_TOPK = int(os.getenv('GKD_TOPK', '0'))

# Name under which the LoRA adapter is registered on the student model.
ADAPTER_NAME = 'default'


# ── Dataset ───────────────────────────────────────────────────────────────────

def create_dataset():
    """Build the prompt-only GSM8K training set.

    The dataset gets a template configured with ``STUDENT_MODEL_ID`` and
    ``max_length=1024``, is mapped through ``GSM8KProcessor`` and is then
    encoded with ``add_generation_prompt=True``; completions are produced
    later by the teacher vLLM sampler.

    Returns:
        The prepared ``Dataset`` instance.
    """
    meta = DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')
    gsm8k = Dataset(meta)
    gsm8k.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024)
    gsm8k.map(GSM8KProcessor())
    gsm8k.encode(add_generation_prompt=True)
    return gsm8k


# ── Training ──────────────────────────────────────────────────────────────────

def main():
    """Run GKD training with completions generated on-the-fly by the teacher.

    Steps:
      1. Initialise twinkle in Ray mode with two device groups: ``'model'``
         (first ``MODEL_GPUS`` ranks) and ``'sampler'`` (next ``SAMPLER_GPUS``
         ranks).
      2. Create the student (LoRA adapter, AdamW, cosine LR schedule, GKDLoss),
         the teacher model (only called through ``forward_only``) and the
         teacher vLLM sampler.
      3. For at most ``MAX_STEPS`` prompt batches: sample one completion per
         prompt from the teacher sampler, compute teacher logits on the
         sampled sequences, run the student ``forward_backward`` against
         them, then ``clip_grad_and_step``.

    Metrics are logged every 10 steps, a checkpoint is saved every 50 steps,
    and a final checkpoint is always written at the end. If the dataloader
    runs out of batches before ``MAX_STEPS`` a warning is logged.
    """
    # Function-scope import so the file's top-level import block is untouched.
    from itertools import islice

    device_groups = [
        DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='cuda'),
        DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='cuda'),
    ]
    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)

    twinkle.initialize(
        mode='ray',
        nproc_per_node=NUM_GPUS,
        groups=device_groups,
        lazy_collect=False,
    )
    logger.info(get_device_placement())

    # Student model (trainable through a LoRA adapter).
    student_model = TransformersModel(
        model_id=STUDENT_MODEL_ID,
        device_mesh=model_mesh,
        remote_group='model',
    )
    student_model.add_adapter_to_model(
        ADAPTER_NAME,
        LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'),
        gradient_accumulation_steps=1,
    )
    student_model.set_optimizer('AdamW', lr=LEARNING_RATE)
    student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
    student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE))
    student_model.set_template('Template', model_id=STUDENT_MODEL_ID)

    # Teacher model: no adapter/optimizer is attached, it only supplies logits.
    teacher_model = TransformersModel(
        model_id=TEACHER_MODEL_ID,
        device_mesh=model_mesh,
        remote_group='model',
    )
    teacher_model.set_template('Template', model_id=TEACHER_MODEL_ID)

    # Teacher vLLM sampler, placed on the separate 'sampler' group.
    teacher_sampler = vLLMSampler(
        model_id=TEACHER_MODEL_ID,
        engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048},
        device_mesh=sampler_mesh,
        remote_group='sampler',
    )
    teacher_sampler.set_template(Template, model_id=TEACHER_MODEL_ID)

    # DataLoader over the prompt-only dataset.
    dataloader = DataLoader(
        dataset=create_dataset,
        batch_size=BATCH_SIZE,
        min_batch_size=BATCH_SIZE,
        device_mesh=model_mesh,
        remote_group='model',
    )

    sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0)
    # GKD_TOPK <= 0 means "use the full vocabulary", passed on as None.
    topk = GKD_TOPK if GKD_TOPK > 0 else None

    logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}')
    logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}')

    # islice stops after MAX_STEPS batches WITHOUT pulling one extra batch
    # from the dataloader just to discard it; max(..., 0) keeps a non-positive
    # MAX_STEPS meaning "no training steps" instead of raising ValueError.
    step_limit = max(MAX_STEPS, 0)
    optim_step = 0
    for optim_step, batch in enumerate(islice(dataloader, step_limit), start=1):
        # Teacher vLLM generates one completion per prompt.
        prompts: List = batch if isinstance(batch, list) else [batch]
        sample_response = teacher_sampler.sample(prompts, sampling_params, num_samples=1)
        input_data = [seq.new_input_feature for seq in sample_response.sequences]

        # Teacher logits on the sampled sequences.
        teacher_output = teacher_model.forward_only(inputs=input_data)
        teacher_logits = teacher_output.get('logits')

        # Student forward + GKD backward, then one optimizer step.
        student_model.forward_backward(inputs=input_data, teacher_logits=teacher_logits, topk=topk)
        student_model.clip_grad_and_step()

        if optim_step % 10 == 0:
            metric = student_model.calculate_metric(is_training=True)
            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}')

        if optim_step % 50 == 0:
            student_model.save(f'gkd-onpolicy-ckpt-{optim_step}')

    if optim_step < step_limit:
        # The LR schedule was sized for MAX_STEPS, so surface the early stop
        # instead of silently reporting a completed run.
        logger.warning(
            f'Dataloader exhausted after {optim_step} of {MAX_STEPS} steps; '
            f'the LR schedule (T_max={MAX_STEPS}) did not run to completion.')

    student_model.save('gkd-onpolicy-final')
    logger.info('GKD on-policy training completed.')


if __name__ == '__main__':
main()
3 changes: 3 additions & 0 deletions src/twinkle/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .base import Loss
from .chunked_cross_entropy import ChunkedCrossEntropyLoss
from .cross_entropy import CrossEntropyLoss
from .gkd import GKDLoss
from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss
from .mse import MSELoss
from .vocab_parallel_cross_entropy import VocabParallelCrossEntropyLoss
Expand All @@ -11,6 +12,8 @@
'cross_entropy': CrossEntropyLoss,
'chunked_cross_entropy': ChunkedCrossEntropyLoss,
'vocab_parallel_cross_entropy': VocabParallelCrossEntropyLoss,
# KD losses
'gkd': GKDLoss,
# RL losses
'grpo': GRPOLoss,
'gspo': GSPOLoss,
Expand Down
Loading
Loading