"""GKD Off-Policy Distillation via Ray.

Off-policy knowledge distillation: the student learns to match the teacher's
token distribution on pre-existing reference responses from the dataset.

Pipeline:
    1. DataLoader supplies full-text batches (prompt + reference answer).
    2. Teacher TransformersModel runs forward_only() to get frozen logits.
    3. Student TransformersModel runs forward_backward() with GKDLoss.

Key difference from on-policy:
    - No vLLM sampler needed (responses already in the dataset).
    - Simpler GPU layout: all GPUs can go to the model group.
    - Faster per-step (no generation latency), but less exploration.

Environment variables (all optional):
    STUDENT_MODEL_ID   (default: ms://Qwen/Qwen2.5-1.5B-Instruct)
    TEACHER_MODEL_ID   (default: ms://Qwen/Qwen2.5-7B-Instruct)
    NUM_GPUS           total GPUs for both models (default: 4)
    BATCH_SIZE         global batch size (default: 8)
    MAX_STEPS          total optimisation steps (default: 200)
    LR                 learning rate (default: 1e-4)
    GKD_BETA           JSD beta (0=fwd KL, 1=rev KL) (default: 0.5)
    GKD_TEMPERATURE    distillation temperature (default: 1.0)
    GKD_TOPK           top-k vocab reduction; 0=full (default: 0)
"""

import os

from peft import LoraConfig

import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.loss import GKDLoss
from twinkle.model import TransformersModel
from twinkle.preprocessor import GSM8KFullProcessor

logger = get_logger()

# ── Configuration ─────────────────────────────────────────────────────────────
STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen2.5-1.5B-Instruct')
TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')

NUM_GPUS = int(os.environ.get('NUM_GPUS', 4))

BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
LEARNING_RATE = float(os.environ.get('LR', 1e-4))

GKD_BETA = float(os.environ.get('GKD_BETA', 0.5))
GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0))
GKD_TOPK = int(os.environ.get('GKD_TOPK', 0))

ADAPTER_NAME = 'default'


# ── Dataset ───────────────────────────────────────────────────────────────────

def create_dataset():
    """Build the full-text dataset (prompt + reference answer) for off-policy distillation.

    Returns:
        An encoded twinkle Dataset ready for the DataLoader.
    """
    dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
    # Template/tokenizer come from the *student* so labels align with its vocab.
    dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024)
    dataset.map(GSM8KFullProcessor())
    dataset.encode()
    return dataset


# ── Training ──────────────────────────────────────────────────────────────────

def main():
    """Run off-policy GKD: teacher forward_only → student forward_backward."""
    device_groups = [
        DeviceGroup(name='model', ranks=list(range(NUM_GPUS)), device_type='cuda'),
    ]
    model_mesh = DeviceMesh.from_sizes(world_size=NUM_GPUS, dp_size=NUM_GPUS)

    twinkle.initialize(
        mode='ray',
        nproc_per_node=NUM_GPUS,
        groups=device_groups,
        lazy_collect=False,
    )
    logger.info(get_device_placement())

    # ── Student model (trainable) ──────────────────────────────────────────
    student_model = TransformersModel(
        model_id=STUDENT_MODEL_ID,
        device_mesh=model_mesh,
        remote_group='model',
    )
    student_model.add_adapter_to_model(
        ADAPTER_NAME,
        LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'),
        gradient_accumulation_steps=1,
    )
    student_model.set_optimizer('AdamW', lr=LEARNING_RATE)
    student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
    student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE))
    student_model.set_template('Template', model_id=STUDENT_MODEL_ID)

    # ── Teacher model (frozen, logits only) ────────────────────────────────
    teacher_model = TransformersModel(
        model_id=TEACHER_MODEL_ID,
        device_mesh=model_mesh,
        remote_group='model',
    )
    teacher_model.set_template('Template', model_id=TEACHER_MODEL_ID)

    # ── DataLoader (full-text: prompt + reference answer) ──────────────────
    dataloader = DataLoader(
        dataset=create_dataset,
        batch_size=BATCH_SIZE,
        min_batch_size=BATCH_SIZE,
        device_mesh=model_mesh,
        remote_group='model',
    )

    # topk=None means "use the full vocabulary" in GKDLoss.
    topk = GKD_TOPK if GKD_TOPK > 0 else None

    logger.info(f'GKD Off-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}')
    logger.info(f'  beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}')

    optim_step = 0
    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break

        input_data = batch if isinstance(batch, list) else [batch]

        # Teacher forward pass (frozen). forward_only() returns a dict that
        # contains 'logits' — exactly the shape GKDLoss expects under the
        # `teacher_output` keyword.
        teacher_output = teacher_model.forward_only(inputs=input_data)

        # Student forward + GKD backward.
        # BUG FIX: GKDLoss.__call__ accepts `teacher_output` (a dict read via
        # .get('logits')), not a bare `teacher_logits` tensor kwarg; passing
        # `teacher_logits=` would leave teacher_output=None and crash.
        student_model.forward_backward(inputs=input_data, teacher_output=teacher_output, topk=topk)
        student_model.clip_grad_and_step()
        optim_step += 1

        if optim_step % 10 == 0:
            metric = student_model.calculate_metric(is_training=True)
            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}')

        if optim_step % 50 == 0:
            student_model.save(f'gkd-offpolicy-ckpt-{optim_step}')

    student_model.save('gkd-offpolicy-final')
    logger.info('GKD off-policy training completed.')


if __name__ == '__main__':
    main()
──► GKD │ + └─────────────────────────────────────────────────────────────────┘ + │ │ │ + DataLoader vLLMSampler ×2 TransformersModel + (model GPUs) student + teacher (model GPUs) + +Environment variables (all optional): + STUDENT_MODEL_ID – (default: ms://Qwen/Qwen2.5-1.5B-Instruct) + TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-4B) + MODEL_GPUS – GPUs for student model (default: 4) + SAMPLER_GPUS – GPUs for each vLLM sampler (default: 2) + MAX_NEW_TOKENS – max completion tokens (default: 512) + BATCH_SIZE – global prompt-level batch size (default: 8) + MAX_STEPS – total optimisation steps (default: 200) + LR – learning rate (default: 1e-4) + GKD_BETA – JSD beta (0=fwd KL, 1=rev KL) (default: 0.5) + GKD_TEMPERATURE – distillation temperature (default: 1.0) + GKD_TOPK – top-k vocab for teacher logprobs (default: 10) +""" + +import os +from typing import List, Optional + +import torch +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.data_format import SamplingParams +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import GKDLoss +from twinkle.model import TransformersModel +from twinkle.preprocessor import GSM8KProcessor +from twinkle.sampler import vLLMSampler +from twinkle.template import Template + +logger = get_logger() + +# ── Configuration ───────────────────────────────────────────────────────────── +STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen2.5-1.5B-Instruct') +TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-4B') + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) +NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS + +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 512)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) +LEARNING_RATE = 
float(os.environ.get('LR', 1e-4)) + +GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) +GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) +GKD_TOPK = int(os.environ.get('GKD_TOPK', 10)) + +ADAPTER_NAME = 'default' + + +# ── Dataset ─────────────────────────────────────────────────────────────────── + +def create_dataset(): + """Prompt-only dataset; student vLLM will generate completions on-policy.""" + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) + dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) + dataset.map(GSM8KProcessor()) + dataset.encode(add_generation_prompt=True) + return dataset + + +# ── Utility ─────────────────────────────────────────────────────────────────── + +def convert_topk_prompt_logprobs( + topk_prompt_logprobs_batch: List[List[Optional[List[tuple]]]], + device: str = 'cpu', +) -> dict: + """Convert vLLM topk_prompt_logprobs to GKDLoss teacher_output format. + + Args: + topk_prompt_logprobs_batch: List of per-input topk_prompt_logprobs. + Each is List[Optional[List[(token_id, logprob)]]] of shape [seq_len, topk]. + device: Target device for tensors. + + Returns: + Dict with 'topk_logprobs' [batch, seq_len, topk] and + 'topk_indices' [batch, seq_len, topk] tensors. 
+ """ + batch_logprobs = [] + batch_indices = [] + + for seq_topk in topk_prompt_logprobs_batch: + seq_logprobs = [] + seq_indices = [] + for pos_topk in seq_topk: + if pos_topk is None: + # First position typically has no logprobs + seq_logprobs.append([0.0] * len(seq_topk[1]) if len(seq_topk) > 1 and seq_topk[1] else [0.0]) + seq_indices.append([0] * len(seq_topk[1]) if len(seq_topk) > 1 and seq_topk[1] else [0]) + else: + seq_logprobs.append([lp for _, lp in pos_topk]) + seq_indices.append([tid for tid, _ in pos_topk]) + batch_logprobs.append(seq_logprobs) + batch_indices.append(seq_indices) + + # Pad to same seq_len within batch + max_len = max(len(seq) for seq in batch_logprobs) + topk = len(batch_logprobs[0][0]) if batch_logprobs and batch_logprobs[0] else GKD_TOPK + + for i in range(len(batch_logprobs)): + pad_len = max_len - len(batch_logprobs[i]) + if pad_len > 0: + batch_logprobs[i].extend([[0.0] * topk] * pad_len) + batch_indices[i].extend([[0] * topk] * pad_len) + + return { + 'topk_logprobs': torch.tensor(batch_logprobs, dtype=torch.float32, device=device), + 'topk_indices': torch.tensor(batch_indices, dtype=torch.long, device=device), + } + + +# ── Training ────────────────────────────────────────────────────────────────── + +def main(): + device_groups = [ + DeviceGroup(name='student_model', ranks=MODEL_GPUS, device_type='cuda'), + DeviceGroup(name='student_sampler', ranks=SAMPLER_GPUS, device_type='cuda'), + DeviceGroup(name='teacher_sampler', ranks=SAMPLER_GPUS, device_type='cuda'), + ] + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + + twinkle.initialize( + mode='ray', + nproc_per_node=NUM_GPUS, + groups=device_groups, + ) + logger.info(get_device_placement()) + + # ── Student model (trainable) ────────────────────────────────────────────── + student_model = TransformersModel( + model_id=STUDENT_MODEL_ID, + device_mesh=model_mesh, 
+ remote_group='student_model', + ) + student_model.add_adapter_to_model( + ADAPTER_NAME, + LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'), + gradient_accumulation_steps=1, + ) + student_model.set_optimizer('AdamW', lr=LEARNING_RATE) + student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE)) + student_model.set_template('Template', model_id=STUDENT_MODEL_ID) + + # ── Student vLLM sampler (for on-policy generation) ──────────────────────── + student_sampler = vLLMSampler( + model_id=STUDENT_MODEL_ID, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048}, + device_mesh=sampler_mesh, + remote_group='student_sampler', + ) + student_sampler.set_template(Template, model_id=STUDENT_MODEL_ID) + + # ── Teacher vLLM sampler (for prompt logprobs) ─────────────────────────────── + teacher_sampler = vLLMSampler( + model_id=TEACHER_MODEL_ID, + engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 2048, 'logprobs_mode': 'raw_logprobs'}, + device_mesh=sampler_mesh, + remote_group='teacher_sampler', + ) + teacher_sampler.set_template(Template, model_id=TEACHER_MODEL_ID) + + # ── DataLoader (prompt-only) ─────────────────────────────────────────────── + dataloader = DataLoader( + dataset=create_dataset, + batch_size=BATCH_SIZE, + min_batch_size=BATCH_SIZE, + device_mesh=model_mesh, + remote_group='student_model', + ) + + logger.info(f'GKD On-Policy | student={STUDENT_MODEL_ID} teacher={TEACHER_MODEL_ID}') + logger.info(f' beta={GKD_BETA} T={GKD_TEMPERATURE} topk={GKD_TOPK}') + + optim_step = 0 + for batch in dataloader: + if optim_step >= MAX_STEPS: + break + + # 1. 
Student vLLM generates completions + sample_response = student_sampler.sample(batch, SamplingParams(max_tokens=MAX_NEW_TOKENS, temperature=1.0, num_samples=1)) + input_data = [seq.new_input_feature for response in sample_response for seq in response.sequences] + for data in input_data: + data.pop('input_ids', None) + + # 2. Teacher vLLM computes top-k prompt logprobs on generated sequences + teacher_response = teacher_sampler.sample( + input_data, + SamplingParams(max_tokens=1, temperature=1.0, prompt_logprobs=10), + ) + + # 3. Convert teacher logprobs to tensor format for GKDLoss + # teacher_response is List[SampleResponse], extract topk_prompt_logprobs from each + teacher_output = convert_topk_prompt_logprobs( + [resp.topk_prompt_logprobs for resp in teacher_response], + device='cuda', + ) + + # 4. Student forward + GKD backward + student_model.forward_backward(inputs=input_data, teacher_output=teacher_output) + student_model.clip_grad_and_step() + optim_step += 1 + + if optim_step % 10 == 0: + metric = student_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') + + if optim_step % 50 == 0: + student_model.save(f'gkd-onpolicy-ckpt-{optim_step}') + + student_model.save('gkd-onpolicy-final') + logger.info('GKD on-policy training completed.') + + +if __name__ == '__main__': + main() diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index ca37d724..e0f67537 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -20,7 +20,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -35,7 +35,7 @@ def train(): # 1000 samples 
dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -43,13 +43,14 @@ def train(): # Global batch size = 8, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=8) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3-4B') + model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model.model._no_split_modules = {'Qwen3_5DecoderLayer'} lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) + # model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) # Add Optimizer for lora `default` model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) # Add LRScheduler for lora `default` @@ -60,8 +61,8 @@ def train(): logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') loss_metric = 99.0 - # lora: 18G * 4 - # full: 50G * 4 + # lora: 8G * 8 + # full: 18G * 8 for step, batch in enumerate(dataloader): # Do forward and backward model.forward_backward(inputs=batch) diff --git a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md index 2f67e37b..b46c9c20 100644 --- a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md +++ b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md @@ -89,8 +89,8 @@ def train(): logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') loss_metric = 99.0 - # LoRA training: ~18G * 4 GPU memory - # Full-parameter training: 
~50G * 4 GPU memory + # LoRA training: ~8G * 8 GPU memory + # Full-parameter training: ~18G * 8 GPU memory for step, batch in enumerate(dataloader): # Forward + backward pass model.forward_backward(inputs=batch) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index 8b86b9b0..b4ca94cd 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -89,8 +89,8 @@ def train(): logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') loss_metric = 99.0 - # LoRA 训练:约 18G * 4 显存占用 - # 全参数训练:约 50G * 4 显存占用 + # LoRA 训练:约 8G * 8 显存占用 + # 全参数训练:约 18G * 8 显存占用 for step, batch in enumerate(dataloader): # 前向 + 反向传播 model.forward_backward(inputs=batch) diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index 129aea8e..5b46e152 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -17,21 +17,20 @@ class SamplingParams: top_k: int = -1 top_p: float = 1.0 repetition_penalty: float = 1.0 + logprobs: int = None + prompt_logprobs: int = None + num_samples: int = 1 - def to_vllm(self, *, num_samples: int = 1, logprobs: bool = True, prompt_logprobs: int = 0): + def to_vllm(self, **kwargs): """Convert to vLLM SamplingParams. - - Args: - num_samples: Number of completions per prompt (vLLM's 'n' parameter). - logprobs: Whether to return logprobs for generated tokens. - prompt_logprobs: Number of prompt token logprobs to return. 
""" from vllm import SamplingParams as VLLMSamplingParams kwargs = { 'temperature': self.temperature, 'top_p': self.top_p, - 'n': num_samples, + 'n': self.num_samples, + **kwargs, } if self.max_tokens is not None: @@ -54,14 +53,14 @@ def to_vllm(self, *, num_samples: int = 1, logprobs: bool = True, prompt_logprob else: kwargs['stop'] = list(self.stop) - if logprobs: - kwargs['logprobs'] = 0 + if self.logprobs is not None: + kwargs['logprobs'] = self.logprobs - if prompt_logprobs > 0: - kwargs['prompt_logprobs'] = prompt_logprobs + if self.prompt_logprobs is not None: + kwargs['prompt_logprobs'] = self.prompt_logprobs vllm_params = VLLMSamplingParams(**kwargs) - if num_samples > 1: + if self.num_samples > 1: from vllm.sampling_params import RequestOutputKind vllm_params.output_kind = RequestOutputKind.FINAL_ONLY return vllm_params diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py index a4f694cb..c7742d75 100644 --- a/src/twinkle/data_format/trajectory.py +++ b/src/twinkle/data_format/trajectory.py @@ -15,5 +15,4 @@ class Trajectory(TypedDict, total=False): messages: List[Message] extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] - advantages: float user_data: List[Tuple[str, Any]] diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 9c37b367..f25f0fa3 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -6,6 +6,7 @@ from typing import Any, Callable, List, Literal, Optional, TypeVar, Union from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, get_logger, requires +from .collectors import collect_tensor_dict logger = get_logger() @@ -343,11 +344,19 @@ def dispatch_func(arg, n): length = len(workers) def dispatch_func(arg, n): - if isinstance(arg, list): + import torch + if isinstance(arg, list) or isinstance(arg, torch.Tensor): _args = [] for i in range(n): _args.append(arg[device_mesh.get_slice(len(arg), 
device_mesh.get_data_rank_from_global_rank(i))]) return _args + elif isinstance(arg, dict): + _args = [{} for _ in range(n)] + for key in arg.keys(): + value = arg[key] + for i, v in enumerate(dispatch_func(value, n)): + _args[i][key] = v + return _args else: return [arg] * n diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py new file mode 100644 index 00000000..3f8c7cf9 --- /dev/null +++ b/src/twinkle/infra/collectors.py @@ -0,0 +1,79 @@ +from typing import List, Dict, Any, TYPE_CHECKING + +if TYPE_CHECKING: + import torch + + +def collect_tensor_dict(outputs: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]: + if not outputs: + return {} + + if len(outputs) == 1: + return outputs[0] + + all_keys = set() + for d in outputs: + all_keys.update(d.keys()) + + import torch + result = {} + for key in all_keys: + values = [d[key] for d in outputs if key in d] + + if not values or all([v is None for v in values]): + continue + + first_value = values[0] + + if isinstance(first_value, list): + merged = [] + for v in values: + if isinstance(v, list): + merged.extend(v) + else: + merged.append(v) + result[key] = merged + + elif isinstance(first_value, torch.Tensor): + result[key] = _pad_and_stack_tensors(values) + + elif isinstance(first_value, dict): + result[key] = collect_tensor_dict(values) + + else: + result[key] = values + + return result + + +def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = 0) -> 'torch.Tensor': + import torch + if not tensors: + raise ValueError("Empty tensor list") + + if len(tensors) == 1: + return tensors[0].unsqueeze(0) + + max_ndim = max(t.ndim for t in tensors) + expanded_tensors = [] + for t in tensors: + while t.ndim < max_ndim: + t = t.unsqueeze(0) + expanded_tensors.append(t) + + max_shape = [] + for dim in range(max_ndim): + max_shape.append(max(t.shape[dim] for t in expanded_tensors)) + + padded_tensors = [] + for t in expanded_tensors: + if list(t.shape) == max_shape: + 
padded_tensors.append(t) + else: + pad_params = [] + for dim in range(max_ndim - 1, -1, -1): + pad_params.extend([0, max_shape[dim] - t.shape[dim]]) + padded = torch.nn.functional.pad(t, pad_params, value=pad_value) + padded_tensors.append(padded) + + return torch.cat(padded_tensors, dim=0) \ No newline at end of file diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index e03681ae..65303dfd 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -2,15 +2,17 @@ from .base import Loss from .chunked_cross_entropy import ChunkedCrossEntropyLoss from .cross_entropy import CrossEntropyLoss +from .gkd import GKDLoss from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss from .mse import MSELoss -from .vocab_parallel_cross_entropy import VocabParallelCrossEntropyLoss +from .cross_entropy import CrossEntropyLoss torch_loss_mapping = { 'mse': MSELoss, - 'cross_entropy': CrossEntropyLoss, 'chunked_cross_entropy': ChunkedCrossEntropyLoss, - 'vocab_parallel_cross_entropy': VocabParallelCrossEntropyLoss, + 'cross_entropy': CrossEntropyLoss, + # KD losses + 'gkd': GKDLoss, # RL losses 'grpo': GRPOLoss, 'gspo': GSPOLoss, diff --git a/src/twinkle/loss/cross_entropy.py b/src/twinkle/loss/cross_entropy.py index 12851d45..06bf791a 100644 --- a/src/twinkle/loss/cross_entropy.py +++ b/src/twinkle/loss/cross_entropy.py @@ -1,20 +1,40 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
# Copyright (c) ModelScope Contributors. All rights reserved.
from ..data_format import LossOutput
from .base import Loss


class CrossEntropyLoss(Loss):
    """Cross-entropy loss computed from per-token logps when available, else logits.

    Two code paths:
        - `outputs['logps']` present: masked negative-log-likelihood over the
          positions where labels != ignore_index.
        - otherwise `outputs['logits']`: standard torch CrossEntropyLoss over
          flattened logits/labels.

    Args:
        ignore_index: Label value excluded from the loss (default: -100).
        reduction: 'mean' (or any non-'sum' value) averages over valid tokens
            and reports num_tokens=0; 'sum' returns the token-summed loss
            together with the valid-token count for external normalisation.
    """

    def __init__(self, ignore_index: int = -100, reduction: str = 'mean', **kwargs):
        super().__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction

    def __call__(self, inputs, outputs, **kwargs):
        """Compute the loss.

        Args:
            inputs: Dict with 'labels' (same shape as logps / logits[..., 0]).
            outputs: Dict with 'logps' and/or 'logits' from the model.

        Returns:
            LossOutput(loss=..., num_tokens=...).
        """
        labels = inputs['labels']
        logps = outputs.get('logps')
        logits = outputs.get('logits')

        if logps is not None:
            # logps are already per-token log-probabilities of the target;
            # only mask + reduce remains. clamp(min=1) guards an all-masked batch.
            loss_mask = (labels != self.ignore_index).float()
            if self.reduction != 'sum':
                return LossOutput(
                    loss=(-logps * loss_mask).sum() / loss_mask.sum().clamp(min=1),
                    num_tokens=0,
                )
            return LossOutput(
                loss=(-logps * loss_mask).sum(),
                num_tokens=loss_mask.sum().clamp(min=1),
            )

        import torch
        assert logits is not None
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        loss = torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels)
        if self.reduction != 'sum':
            return LossOutput(loss=loss, num_tokens=0)
        return LossOutput(loss=loss, num_tokens=(labels != self.ignore_index).sum())
+from typing import TYPE_CHECKING, Optional + +from twinkle.data_format import LossOutput +from twinkle.loss.base import Loss + +if TYPE_CHECKING: + import torch + + +class GKDLoss(Loss): + """Generalized Knowledge Distillation (GKD) loss based on Jensen-Shannon Divergence. + + Implements the on-policy distillation objective from: + "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes" + (https://arxiv.org/abs/2306.13649) + + The loss is a β-weighted mixture of two KL divergences: + JSD_β(S || T) = β · KL(T || M) + (1 - β) · KL(S || M) + where M = β · T + (1 - β) · S (mixture distribution) + + Special cases: + β = 0 → forward KL(S || T) (mean-seeking) + β = 1 → reverse KL(T || S) (mode-seeking) + β = 0.5 → symmetric JSD + + Args: + beta: Weight for teacher in the JSD mixture (default: 0.5). + temperature: Softmax temperature applied to logits before divergence (default: 1.0). + ignore_index: Token index to ignore in the loss mask (default: -100). + chunk_size: Number of valid tokens processed per chunk to reduce peak memory (default: 512). + """ + + def __init__( + self, + beta: float = 0.5, + temperature: float = 1.0, + ignore_index: int = -100, + chunk_size: int = 512, + **kwargs, + ): + self.beta = beta + self.temperature = temperature + self.ignore_index = ignore_index + self.chunk_size = chunk_size + + def __call__( + self, + inputs, + outputs, + *, + teacher_output: Optional['torch.Tensor'] = None, + topk: Optional[int] = None, + **kwargs, + ) -> LossOutput: + """Compute GKD / JSD distillation loss. + + Args: + inputs: Dict containing 'labels' [batch, seq_len] with ignore_index for non-response tokens. + outputs: Dict containing 'logits' [batch, seq_len, vocab_size] from the student model. + teacher_output: A dict contains: + teacher_logits: [batch, seq_len, vocab_size] full vocabulary logits from a local teacher. + Either teacher_logits or (teacher_topk_logprobs + teacher_topk_indices) + must be provided. 
+ teacher_topk_logprobs: [batch, seq_len, topk] log-probs from a remote teacher API. + Returned by a vLLM-compatible /v1/completions prompt_logprobs call. + teacher_topk_indices: [batch, seq_len, topk] token indices corresponding to teacher_topk_logprobs. + topk: If set together with teacher_logits, only the top-k teacher tokens are used to + reduce vocabulary size before computing the JSD (memory-efficient local teacher mode). + + Returns: + LossOutput with scalar 'loss' averaged over valid (non-ignored) response tokens. + """ + teacher_logits = teacher_output.get('logits') + teacher_topk_logprobs = teacher_output.get('topk_logprobs') + teacher_topk_indices = teacher_output.get('topk_indices') + assert teacher_logits is not None or ( + teacher_topk_logprobs is not None and teacher_topk_indices is not None + ), ( + 'Either logits or both topk_logprobs and topk_indices must be provided.' + ) + + labels = inputs['labels'] + student_logits = outputs['logits'] + # Align seq dimension: some MLLMs return extra prefix logits + if student_logits.shape[1] != labels.shape[1]: + student_logits = student_logits[:, -labels.shape[1]:] + if teacher_logits is not None and teacher_logits.shape[1] > student_logits.shape[1]: + teacher_logits = teacher_logits[:, :student_logits.shape[1]] + if teacher_topk_logprobs is not None and teacher_topk_logprobs.shape[1] > student_logits.shape[1]: + teacher_topk_logprobs = teacher_topk_logprobs[:, :student_logits.shape[1]] + teacher_topk_indices = teacher_topk_indices[:, :student_logits.shape[1]] + + # Shift labels: label[i] = next token predicted by logits[i] + # The last position wraps to label[0] via roll; since label[0] is -100 (prompt), + # it will be correctly excluded by the mask in _generalized_jsd_loss. 
@staticmethod
def _generalized_jsd_loss(
    student_logits,
    teacher_logits=None,
    labels=None,
    beta: float = 0.5,
    temperature: float = 1.0,
    chunk_size: int = 512,
    topk: Optional[int] = None,
    teacher_topk_logprobs=None,
    teacher_topk_indices=None,
):
    """Core chunked generalised-JSD loss computation.

    Supports three teacher modes:
      1. Full-vocabulary local teacher (``teacher_logits``, ``topk=None``).
      2. Top-k local teacher (``teacher_logits`` + ``topk=k``).
      3. Remote API teacher (``teacher_topk_logprobs`` + ``teacher_topk_indices``).

    Valid tokens are processed in chunks of ``chunk_size`` to keep peak GPU
    memory bounded.

    Args:
        student_logits: [batch, seq_len, vocab_size] student logits (reduced to
            [batch, seq_len, topk] internally in the top-k modes).
        teacher_logits: [batch, seq_len, vocab_size] full-vocabulary logits from
            a local teacher, or None in remote-API mode.
        labels: [batch, seq_len] shifted labels; positions equal to -100 are
            excluded. If None, every position is treated as valid.
        beta: JSD mixture weight (0 = forward KL, 1 = reverse KL, 0.5 = symmetric JSD).
        temperature: Softmax temperature applied to both distributions.
        chunk_size: Number of valid tokens per chunk.
        topk: If given, reduce a local teacher to its top-k tokens first.
        teacher_topk_logprobs: [batch, seq_len, topk] log-probs from a remote API.
        teacher_topk_indices: [batch, seq_len, topk] matching token indices.

    Returns:
        Scalar loss tensor averaged over valid tokens (zero if none are valid).
    """
    import torch
    import torch.nn.functional as F

    # ── Teacher top-k reduction ──────────────────────────────────────────
    # Both top-k modes shrink the vocab axis to k *before* masking so the
    # chunked JSD below runs on small [*, k] tensors. Temperature is folded
    # in here and then reset to 1.0, which makes the unconditional scaling
    # after the mask step a no-op for these modes.
    if teacher_topk_logprobs is not None and teacher_topk_indices is not None:
        # Remote API teacher: top-k log-probs arrive already computed at T=1.
        # gather() returns a fresh tensor, so the in-place div_ is safe.
        student_logits = torch.gather(student_logits, dim=-1, index=teacher_topk_indices)
        student_logits.div_(temperature)
        teacher_logits = teacher_topk_logprobs / temperature
        temperature = 1.0
    elif topk is not None and teacher_logits is not None:
        # Local teacher: keep its top-k logits and gather the matching
        # student logits at the same token indices.
        teacher_logits, topk_idx = torch.topk(teacher_logits, k=topk, dim=-1)
        teacher_logits.div_(temperature)
        student_logits = torch.gather(student_logits, dim=-1, index=topk_idx)
        student_logits.div_(temperature)
        temperature = 1.0

    # ── Keep only valid (response) token positions ───────────────────────
    if labels is not None:
        mask = labels != -100  # ignore_index is always -100 per convention
        # Vocab-size mismatch (e.g. Qwen2.5-VL-3B vs 7B): pad the smaller
        # side and mirror the larger side's logits into the padded slots so
        # both distributions are defined over the same token set.
        stu_dim = student_logits.shape[-1]
        tea_dim = teacher_logits.shape[-1]
        if stu_dim < tea_dim:
            student_logits = F.pad(student_logits, (0, tea_dim - stu_dim))
            student_logits[..., stu_dim:] = teacher_logits[..., stu_dim:]
        elif stu_dim > tea_dim:
            teacher_logits = F.pad(teacher_logits, (0, stu_dim - tea_dim))
            teacher_logits[..., tea_dim:] = student_logits[..., tea_dim:]
        student_logits = student_logits[mask]  # [num_valid, vocab/topk]
        teacher_logits = teacher_logits[mask]
        num_valid = mask.sum()
    else:
        # NOTE(review): view() shares storage, so the div_ below mutates the
        # caller's tensors in this branch — presumably acceptable since
        # training always supplies labels; confirm against callers.
        student_logits = student_logits.view(-1, student_logits.size(-1))
        teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1))
        num_valid = student_logits.size(0)
    # No-op (temperature == 1.0) when a top-k branch already applied it.
    student_logits.div_(temperature)
    teacher_logits.div_(temperature)

    if num_valid == 0:
        return student_logits.new_zeros(())

    num_valid_int = int(num_valid) if isinstance(num_valid, int) else num_valid.item()
    total_loss = student_logits.new_zeros(())

    # Pre-compute log(beta) / log(1-beta) once for the mixture branch.
    if beta not in (0, 1):
        beta_t = torch.tensor(beta, dtype=student_logits.dtype, device=student_logits.device)
        log_beta = torch.log(beta_t)
        log_1_minus_beta = torch.log1p(-beta_t)
    else:
        beta_t = log_beta = log_1_minus_beta = None

    # ── Chunked loss accumulation (bounds peak memory) ───────────────────
    for start in range(0, num_valid_int, chunk_size):
        end = min(start + chunk_size, num_valid_int)
        s_chunk = student_logits[start:end]
        t_chunk = teacher_logits[start:end]

        s_log_probs = F.log_softmax(s_chunk, dim=-1)
        t_log_probs = F.log_softmax(t_chunk, dim=-1)
        del s_chunk, t_chunk

        if beta == 0:
            # Forward KL: KL(T || S). F.kl_div(input, target) computes
            # KL(target || input), so input=student, target=teacher.
            jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True)
        elif beta == 1:
            # Reverse KL: KL(S || T)
            jsd_chunk = F.kl_div(t_log_probs, s_log_probs, reduction='none', log_target=True)
        else:
            # Generalised JSD: β·KL(T||M) + (1-β)·KL(S||M), M = (1-β)·S + β·T
            mixture_log_probs = torch.logsumexp(
                torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]),
                dim=0,
            )
            kl_teacher = F.kl_div(mixture_log_probs, t_log_probs, reduction='none', log_target=True)
            kl_student = F.kl_div(mixture_log_probs, s_log_probs, reduction='none', log_target=True)
            del mixture_log_probs
            jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student
            del kl_teacher, kl_student

        total_loss = total_loss + jsd_chunk.sum()
        del jsd_chunk, s_log_probs, t_log_probs

    return total_loss / num_valid
def _construct_default_optimizer_group(self): return MegatronOptimizerGroup( - loss_instance=VocabParallelCrossEntropyLoss(), + loss_instance=CrossEntropyLoss(), template=Template(self.tokenizer_id), processor=InputProcessor(self.device_mesh, framework='megatron'), _device_mesh=self.device_mesh, @@ -364,7 +364,7 @@ def calculate_loss(self, **kwargs): def backward(self, **kwargs): raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`') - @remote_function(dispatch='slice_dp', collect='mean', sync=True) + @remote_function(dispatch='slice_dp', collect='last_pp', sync=True) def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], @@ -398,6 +398,7 @@ def forward_backward(self, from megatron.core.pipeline_parallel import get_forward_backward_func adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + temperature = float(kwargs.pop('temperature', 1.0)) forward_only = kwargs.pop('forward_only', False) optimizer_config = self.optimizer_group[adapter_name] loss_instance = self.optimizer_group[adapter_name].loss_instance @@ -485,6 +486,7 @@ def forward_step_func(data_iterator, model): loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 + output_tensor.div_(temperature) logps = selective_log_softmax(output_tensor, masked_labels) if cp_size > 1: logps = self._postprocess_tensor_cp(logps) diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index d0e76378..c3388bcc 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -80,7 +80,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di fsdp_size = device_mesh.get_dim_size('fsdp') if device_mesh.has_dim('fsdp') else 1 dp_size = device_mesh.get_dim_size('dp') if device_mesh.has_dim('dp') else 1 - if fsdp_size == 1 and 
def _get_lr(self):
    """Return the current learning rate of every optimizer param group.

    Groups without their own ``lr`` fall back to the optimizer-level default.
    When no optimizer has been constructed (e.g. a forward-only / frozen
    teacher model), an empty list is returned instead of raising.

    Returns:
        List of learning rates, one per ``optimizer.param_groups`` entry;
        ``[]`` if ``self.optimizer`` is None.
    """
    if self.optimizer is None:
        # Forward-only models never build an optimizer; report no LRs.
        return []
    default_lr = self.optimizer.defaults.get('lr')
    return [group.get('lr', default_lr) for group in self.optimizer.param_groups]
""" adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + temperature = float(kwargs.pop('temperature', 1.0)) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() if not inputs: @@ -393,10 +398,13 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 - outputs['logps'] = selective_log_softmax(outputs['logits'], masked_labels) + logits = outputs['logits'] + logits.div_(temperature) + outputs['logps'] = selective_log_softmax(logits, masked_labels) + outputs['past_key_values'] = None return outputs - @remote_function(dispatch='slice_dp', collect='flatten') + @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): """Call forward function without grad and record the inputs and outputs. @@ -408,6 +416,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T The output of the model forward. 
""" adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + temperature = float(kwargs.pop('temperature', 1.0)) optimizer_config = self.optimizer_group[adapter_name] self._lazy_wrap_model() if not inputs: @@ -440,7 +449,10 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T loss_mask = (labels != -100).bool() masked_labels = labels.clone() masked_labels[~loss_mask] = 0 - outputs['logps'] = selective_log_softmax(outputs['logits'], masked_labels) + logits = outputs['logits'] + logits.div_(temperature) + outputs['logps'] = selective_log_softmax(logits, masked_labels) + outputs['past_key_values'] = None return outputs @remote_function(collect='mean') @@ -502,7 +514,7 @@ def backward(self, **kwargs): optimizer_config.cur_step += 1 optimizer_config.loss_value = None - @remote_function(dispatch='slice_dp', collect='mean') + @remote_function(dispatch='slice_dp', collect='flatten') def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): """Do forward, calculate loss, and backward. diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 1c19815e..7234a60a 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
def extract_ground_truth(self, answer_str: str) -> str:
    """Pull the final numeric answer (the token after '####') out of a GSM8K answer.

    Commas are stripped from the matched number. Returns an empty string when
    no '####' marker is present.
    """
    matched = re.search(r'####\s*([\-\d,.]+)', answer_str)
    if matched is None:
        return ''
    return matched.group(1).replace(',', '').strip()
def preprocess(self, row) -> Trajectory:
    """Build a full prompt-plus-response Trajectory from one GSM8K row.

    Unlike the prompt-only parent preprocessor, the reference answer is
    appended as the assistant turn, so labels cover the response tokens —
    which is what off-policy distillation needs.
    """
    question = row['question']
    answer = row.get('answer', '')
    return Trajectory(
        messages=[
            Message(role='system', content=self.system_prompt),
            Message(role='user', content=question),
            Message(role='assistant', content=answer),
        ],
        user_data=[('ground_truth', self.extract_ground_truth(answer))],
    )
@@ -219,12 +216,8 @@ async def sample( # Convert to vLLM params if isinstance(sampling_params, dict): sampling_params = SamplingParams.from_dict(sampling_params) - prompt_logprobs_k = topk_prompt_logprobs if topk_prompt_logprobs > 0 else (1 if include_prompt_logprobs else 0) - vllm_params = sampling_params.to_vllm( - num_samples=num_samples, - logprobs=logprobs, - prompt_logprobs=prompt_logprobs_k, - ) + prompt_logprobs_k = sampling_params.prompt_logprobs or 0 + vllm_params = sampling_params.to_vllm(**kwargs) # Build request if request_id is None: diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index b4d1c6fd..e33e178a 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -35,29 +35,6 @@ logger = get_logger() -def _collect_sample_responses(results: List[SampleResponse], **kwargs) -> SampleResponse: - """Custom collect function to merge multiple SampleResponse objects. - - Args: - results: List of SampleResponse from each DP worker. - - Returns: - Merged SampleResponse with all sequences combined. - """ - if not results: - return SampleResponse(sequences=[]) - - if len(results) == 1: - return results[0] - - all_sequences = [] - for resp in results: - if resp is not None and hasattr(resp, 'sequences'): - all_sequences.extend(resp.sequences) - - return SampleResponse(sequences=all_sequences) - - @remote_class() class vLLMSampler(Sampler, CheckpointEngineMixin): """A vLLM-based sampler using VLLMEngine (AsyncLLM). @@ -224,7 +201,6 @@ async def _sample_single( lora_request: Optional[Any] = None, *, logprobs: bool = True, - num_samples: int = 1, ) -> List[SampledSequence]: """Sample a single input asynchronously. 
@@ -250,14 +226,13 @@ async def _sample_single( prompt_token_ids=input_ids, sampling_params=sampling_params, logprobs=logprobs, - num_samples=num_samples, lora_request=lora_request, images=images, videos=videos, ) # response.sequences contains num_samples sequences for this prompt - return [ + return SampleResponse(sequences=[ SampledSequence( stop_reason=seq.stop_reason, tokens=seq.tokens, @@ -265,9 +240,9 @@ async def _sample_single( decoded=self.template.decode(seq.tokens), new_input_feature=self.template.concat_input_feature(feat, seq.tokens), ) for seq in response.sequences - ] + ], prompt_logprobs=response.prompt_logprobs, topk_prompt_logprobs=response.topk_prompt_logprobs) - @remote_function(dispatch='slice_dp', collect=_collect_sample_responses, lazy_collect=False) + @remote_function(dispatch='slice_dp', collect='flatten', lazy_collect=False) def sample( self, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], @@ -275,10 +250,8 @@ def sample( adapter_name: str = '', adapter_path: Optional[str] = None, *, - logprobs: bool = True, - num_samples: int = 1, return_encoded: bool = False, - ) -> SampleResponse: + ) -> List[SampleResponse]: """Sample responses for given inputs. 
Args: @@ -302,7 +275,6 @@ def sample( Note: In Ray mode with multiple workers (DP > 1): - Data is automatically sliced by DP rank (dispatch='slice_dp') - - Results are merged using _collect_sample_responses - Each worker receives already-sliced inputs (e.g., DP4 with 8 inputs -> 2 per worker) """ if sampling_params is None: @@ -337,18 +309,12 @@ async def _sample_all(): feat, sampling_params, lora_request=lora_request, - logprobs=logprobs, - num_samples=num_samples, ) for feat in encoded_inputs ] return await asyncio.gather(*tasks) - results = self._run_in_loop(_sample_all()) - # Flatten results (each result contains num_samples sequences) - all_sequences = [] - for seqs in results: - all_sequences.extend(seqs) - return SampleResponse(sequences=all_sequences) + sample_results = self._run_in_loop(_sample_all()) + return sample_results @remote_function(dispatch='all', collect='first') def sleep(self, level: int = 1) -> None: diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index 35c636d8..7097f946 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -1,6 +1,6 @@ import socket from datetime import timedelta -from typing import TYPE_CHECKING, Any, Mapping, Union +from typing import TYPE_CHECKING, Any, Mapping, Union, List, Dict from .network import is_valid_ipv6_address