diff --git a/README.md b/README.md index a20209fc..9bf23475 100644 --- a/README.md +++ b/README.md @@ -244,6 +244,7 @@ for training jobs. There are a number of options you can specify, such as settin | distributed_backend | Specifies which distributed training backend to use. Supported options are "fsdp" and "deepspeed". | | disable_flash_attn | Disables flash attention when set to true. This allows for training on older devices. | | keep_last_checkpoint_only | Determines whether we should only keep the last checkpoint directory - the previous checkpoint directory is always overwritten. The checkpoint directory is called `last_epoch`. | +| trust_remote_code | Controls whether repository-provided Python code from HuggingFace Hub is executed when loading models and tokenizers. This is required for models that ship custom modeling code, such as Nemotron, Ministral, and Qwen3.5. Can also be enabled via the `TRUST_REMOTE_CODE=1` environment variable. Defaults to `False`. **Security note:** enabling this setting will execute remote code from the model repository — only enable it for sources you trust. | ### `DeepSpeedOptions` @@ -507,6 +508,7 @@ run_training( Below is a list of custom environment variables users can set in the training library. 1. `INSTRUCTLAB_NCCL_TIMEOUT_MS`, this environment variable controls the NCCL timeout in milliseconds. Consider increasing if seeing FSDP related NCCL errors. +2. `TRUST_REMOTE_CODE`, when set to `1`, allows repository-provided Python code from HuggingFace Hub to be executed when loading models and tokenizers. This is required for models that ship custom modeling code (e.g. Nemotron, Ministral, Qwen3.5). Equivalent to setting `trust_remote_code=True` in `TrainingArgs`. Only enable for sources you trust. ## Developer Certificate of Origin diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 911c3898..34dfda98 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -302,6 +302,16 @@ class TrainingArgs(BaseModel): description="Whether to use Liger kernels for training.", ) + trust_remote_code: bool = Field( + default=False, + description=( + "Whether to trust remote code when loading models and tokenizers " + "from HuggingFace Hub. Required for models with custom code such as " + "Nemotron, Ministral, and Qwen3.5. Can also be enabled via the " + "TRUST_REMOTE_CODE=1 environment variable." + ), + ) + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field( default="INFO" ) diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index 8d68a725..ae3b3cdf 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -41,11 +41,18 @@ logger = logging.getLogger(__name__) +def _trust_remote_code() -> bool: + """Resolve trust_remote_code from the TRUST_REMOTE_CODE environment variable.""" + return os.environ.get("TRUST_REMOTE_CODE", "").lower() in ("1", "true", "yes") + + @lru_cache() def is_gpt_oss_model(tokenizer: PreTrainedTokenizer) -> bool: """Check if this is a GPT-OSS model based on tokenizer.""" model_name_or_path = tokenizer.name_or_path - config = AutoConfig.from_pretrained(model_name_or_path) + config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=_trust_remote_code() + ) return config.model_type == "gpt_oss" @@ -1182,7 +1189,9 @@ def process_documents_for_pretraining( ) logger.info("Loading tokenizer from %s", model_path) - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=_trust_remote_code() + ) if tokenizer.eos_token_id is None: raise ValueError("Tokenizer must have an EOS token defined for pretraining") @@ -1294,7 +1303,9 @@ def load_and_validate_dataset(data_path: str, num_procs: int) -> Dataset: def configure_tokenizer(model_path: str) -> PreTrainedTokenizer: """Configure the tokenizer with necessary special tokens.""" - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=_trust_remote_code() + ) if not tokenizer.chat_template: raise ValueError( diff --git a/src/instructlab/training/gpt_oss_utils_correct.py b/src/instructlab/training/gpt_oss_utils_correct.py index a77ee15a..4df3bfce 100644 --- a/src/instructlab/training/gpt_oss_utils_correct.py +++ b/src/instructlab/training/gpt_oss_utils_correct.py @@ -8,6 +8,7 @@ # Standard from typing import Dict import logging +import os import re # Third Party @@ -415,7 +416,14 @@ def is_known_model( # convert to config model_config = model_path_or_config if isinstance(model_path_or_config, str): - model_config = AutoConfig.from_pretrained(model_path_or_config) + _trust_remote = os.environ.get("TRUST_REMOTE_CODE", "").lower() in ( + "1", + "true", + "yes", + ) + model_config = AutoConfig.from_pretrained( + model_path_or_config, trust_remote_code=_trust_remote + ) known_model_types = ( [known_model_type] if isinstance(known_model_type, str) else known_model_type diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 072b27c6..a9887800 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -394,7 +394,17 @@ def main(args): tokenizer = setup_tokenizer(args.model_name_or_path, args.chat_tmpl_path) # device = torch.device("cuda", args.local_rank) - model_conf = AutoConfig.from_pretrained(args.model_name_or_path) + # Resolve trust_remote_code from CLI flag or environment variable + trust_remote_code = getattr(args, "trust_remote_code", False) or os.environ.get( + "TRUST_REMOTE_CODE", "" + ).lower() in ("1", "true", "yes") + if trust_remote_code: + # Export so downstream calls (data_process, tokenizer_utils, etc.) pick it up + os.environ["TRUST_REMOTE_CODE"] = "1" + + model_conf = AutoConfig.from_pretrained( + args.model_name_or_path, trust_remote_code=trust_remote_code + ) args.model_type = model_conf.model_type #### distributed init ##### @@ -449,6 +459,7 @@ def main(args): flash_enabled=flash_enabled, noise_alpha=args.NEFTune_alpha, lora_quant_bits=args.lora_quant_bits, + trust_remote_code=trust_remote_code, ) args.base_model_args = m.base_model_args @@ -710,6 +721,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.use_liger: command.append("--use_liger") + # Resolve trust_remote_code from flag or environment variable + trust_remote_code = train_args.trust_remote_code or os.environ.get( + "TRUST_REMOTE_CODE", "" + ).lower() in ("1", "true", "yes") + if trust_remote_code: + command.append("--trust_remote_code") + if train_args.keep_last_checkpoint_only: command.append("--keep_last_checkpoint_only") @@ -1036,6 +1054,15 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: help="Path to the chat template to set on the model for training. If none is provided, the chat template used in the model will be used.", ) parser.add_argument("--disable_flash_attn", action="store_true") + parser.add_argument( + "--trust_remote_code", + action="store_true", + help=( + "Trust remote code when loading models/tokenizers from HuggingFace Hub. " + "Required for models with custom code (e.g. Nemotron, Ministral, Qwen3.5). " + "Can also be set via the TRUST_REMOTE_CODE=1 environment variable." + ), + ) parser.add_argument( "--keep_last_checkpoint_only", action="store_true", diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py index e3e20de1..09455c2d 100644 --- a/src/instructlab/training/model.py +++ b/src/instructlab/training/model.py @@ -65,6 +65,7 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + trust_remote_code: bool = False, ): self.lora_config = lora_config self.noise_alpha = noise_alpha @@ -74,12 +75,13 @@ def __init__( # check model type & set on the mclasss self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid") + self.is_nemotronh = is_known_model(model_path, "nemotron_h") self.is_gpt_oss = is_gpt_oss(model_path) # Pre-populate the Hub kernel cache with locally installed mamba_ssm # and causal_conv1d to avoid PyTorch/CUDA ABI mismatches with the # Hub-provided kernel builds. - if self.is_granitemoehybrid: + if self.is_granitemoehybrid or self.is_nemotronh: self._use_local_mamba_kernels() if self.is_gpt_oss: @@ -100,6 +102,7 @@ def __init__( self.base_model_args = { "pretrained_model_name_or_path": model_path, "quantization_config": quant_config, + "trust_remote_code": trust_remote_code, } # load GPT-OSS in bfloat16 because it's a massive model, but otherwise @@ -114,7 +117,7 @@ def __init__( # - M-RoPE models produce 3D position_ids that FA2 misinterprets # - Models with timm vision towers (TimmWrapperModel rejects FA2) # Detect these and fall back to SDPA. - use_sdpa = needs_sdpa(model_path) + use_sdpa = needs_sdpa(model_path, trust_remote_code=trust_remote_code) if use_sdpa: logger.warning( "Disabling flash_attention_2 — model is incompatible " @@ -145,7 +148,7 @@ def __init__( # For models with timm vision towers: set vision config to eager # while keeping the text model's attention implementation. # timm's TimmWrapperModel rejects both FA2 and SDPA. - if has_timm_vision_tower(model_path): + if has_timm_vision_tower(model_path, trust_remote_code=trust_remote_code): attn_impl = self.base_model_args.get( "attn_implementation", "flash_attention_2" ) @@ -196,6 +199,18 @@ def _use_local_mamba_kernels(): e, ) + def _enable_gradient_checkpointing_if_supported(self) -> None: + """Enable gradient checkpointing if the model supports it. + + Some models (e.g. NemotronH with hybrid Mamba/MoE/Attention architecture) + do not support gradient checkpointing and will raise an error. This helper + catches known exception types and logs a warning instead of crashing. + """ + try: + self.model.gradient_checkpointing_enable() + except (ValueError, NotImplementedError, AttributeError) as e: + logger.warning("Gradient checkpointing not supported: %s", e) + def _post_model_init(self): """Common initialization steps that should happen after model initialization.""" self.reconcile_tokenizer() @@ -544,6 +559,7 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + trust_remote_code: bool = False, ): super().__init__( model_path=model_path, @@ -553,6 +569,7 @@ def __init__( flash_enabled=flash_enabled, lora_config=lora_config, lora_quant_bits=lora_quant_bits, + trust_remote_code=trust_remote_code, ) try: # Third Party @@ -570,7 +587,7 @@ def __init__( cross_entropy=True, fused_linear_cross_entropy=False, ) - self.model.gradient_checkpointing_enable() + self._enable_gradient_checkpointing_if_supported() self._post_model_init() @@ -586,6 +603,7 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + trust_remote_code: bool = False, ): super().__init__( model_path=model_path, @@ -595,15 +613,16 @@ def __init__( flash_enabled=flash_enabled, lora_config=lora_config, lora_quant_bits=lora_quant_bits, + trust_remote_code=trust_remote_code, ) - if is_vlm_with_causal_lm(model_path): + if is_vlm_with_causal_lm(model_path, trust_remote_code=trust_remote_code): self.model = extract_causal_lm_from_vlm(model_path, self.base_model_args) - elif is_vlm_for_direct_loading(model_path): + elif is_vlm_for_direct_loading(model_path, trust_remote_code=trust_remote_code): self.model = load_vlm_for_text_training(model_path, self.base_model_args) else: self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args) self._post_model_init() - self.model.gradient_checkpointing_enable() + self._enable_gradient_checkpointing_if_supported() def setup_optimizer( diff --git a/src/instructlab/training/tokenizer_utils.py b/src/instructlab/training/tokenizer_utils.py index 96b8767c..5ffe8e9e 100644 --- a/src/instructlab/training/tokenizer_utils.py +++ b/src/instructlab/training/tokenizer_utils.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +# Standard +import os + # Third Party from transformers import AutoTokenizer, PreTrainedTokenizer import transformers @@ -95,7 +98,14 @@ def setup_tokenizer( model_name_or_path, chat_tmpl_path: str | None = None, ) -> PreTrainedTokenizer: - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, fast_tokenizer=True) + trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "").lower() in ( + "1", + "true", + "yes", + ) + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code + ) if not tokenizer.chat_template and chat_tmpl_path is None: raise ValueError( "Tokenizer does not have a chat template. Please provide a path to a chat template." diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index fc31858e..a7adb32b 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -713,8 +713,31 @@ def _get_state_dict_patched(model, unwrap=False): if is_gpt_oss(model.module.config): add_gpt_oss_quantization_config(model.module.config) + # For FP8 models (e.g. Ministral), restore the original quantization + # config so the saved checkpoint matches the original FP8 format. + # We must also temporarily remove _fp8_* attrs since they contain + # tensors that are not JSON-serializable. + _fp8_quant_cfg = getattr(model.module.config, "_fp8_quantization_config", None) + _had_fp8_scales = hasattr(model.module.config, "_fp8_scales") + if _fp8_quant_cfg is not None: + model.module.config.quantization_config = _fp8_quant_cfg + if _had_fp8_scales: + _saved_scales = model.module.config._fp8_scales + del model.module.config._fp8_scales + if hasattr(model.module.config, "_fp8_quantization_config"): + del model.module.config._fp8_quantization_config + model.module.config.to_json_file(output_config_file) + # Restore internal attrs and clear quantization_config so the live + # model doesn't look quantized during training + if _had_fp8_scales: + model.module.config._fp8_scales = _saved_scales + if _fp8_quant_cfg is not None: + model.module.config._fp8_quantization_config = _fp8_quant_cfg + if _fp8_quant_cfg is not None: + model.module.config.quantization_config = None + tokenizer.save_pretrained(output_dir) if is_lora: @@ -759,7 +782,25 @@ def _get_state_dict_patched(model, unwrap=False): max_shard_size="5GB", safe_serialization=True, ) - elif not is_gpt_oss(model.module.config): + elif getattr(model.module.config, "_fp8_scales", None): + # FP8 model (e.g. Ministral): re-quantize state dict before saving + if accelerator.is_main_process: + # First Party + from instructlab.training.vlm_utils import requantize_fp8_state_dict + + log_rank_0("Re-quantizing FP8 parameters for checkpoint compatibility") + model_state = model.module.state_dict() + model_state = requantize_fp8_state_dict( + model_state, model.module.config._fp8_scales + ) + save_dict_accelerate( + accelerator, + model_state, + save_directory=output_dir, + max_shard_size="5GB", + safe_serialization=True, + ) + else: # Standard model saving accelerator.save_model( model, diff --git a/src/instructlab/training/vlm_utils.py b/src/instructlab/training/vlm_utils.py index 29ac1d68..ea84835c 100644 --- a/src/instructlab/training/vlm_utils.py +++ b/src/instructlab/training/vlm_utils.py @@ -8,6 +8,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, ) +import torch logger = logging.getLogger("instructlab.training") @@ -158,6 +159,135 @@ def _find_text_backbone(vlm_model) -> nn.Module: ) +def _dequantize_fp8_model(model: PreTrainedModel) -> None: + """Dequantize FP8 weights in-place for FSDP compatibility. + + Some models (e.g. Ministral) ship with FP8 quantized weights that include + scalar parameters like ``weight_scale_inv`` and ``activation_scale``. + FSDP rejects scalar parameters, so we dequantize the weights back to + bfloat16 and remove all FP8 scalar parameters before distributed wrapping. + + The original FP8 scales and quantization config are preserved on the model + (as ``_fp8_scales`` and ``_fp8_quantization_config``) so that + :func:`requantize_fp8_state_dict` can restore them at checkpoint save time. + + The dequantization formula is: + real_weight = fp8_weight.to(bfloat16) * weight_scale_inv + """ + # FP8 scalar parameter names to remove after dequantization. + # weight_scale_inv: inverse scale for weight quantization + # activation_scale: scale for activation quantization (inference only) + _FP8_SCALAR_ATTRS = ("weight_scale_inv", "activation_scale") + + # Store original scales keyed by module path for requantization at save time. + fp8_scales: dict[str, dict[str, torch.Tensor]] = {} + + dequantized_count = 0 + for mod_name, module in model.named_modules(): + has_fp8 = any(hasattr(module, attr) for attr in _FP8_SCALAR_ATTRS) + if not has_fp8: + continue + + # Capture original scales before removing them + saved = {} + for attr in _FP8_SCALAR_ATTRS: + if hasattr(module, attr): + saved[attr] = getattr(module, attr).detach().clone().cpu() + if saved: + fp8_scales[mod_name] = saved + + # Dequantize weight if scale is present + if hasattr(module, "weight_scale_inv") and hasattr(module, "weight"): + scale_inv = module.weight_scale_inv + weight = module.weight + dtype = torch.bfloat16 + dequantized = weight.to(dtype) * scale_inv.to(dtype) + module.weight = nn.Parameter( + dequantized, requires_grad=weight.requires_grad + ) + + # Remove all FP8 scalar parameters/buffers + for attr in _FP8_SCALAR_ATTRS: + if not hasattr(module, attr): + continue + if attr in dict(module.named_parameters(recurse=False)): + delattr(module, attr) + elif attr in dict(module.named_buffers(recurse=False)): + setattr(module, attr, None) + + dequantized_count += 1 + + if dequantized_count > 0: + logger.info( + "Dequantized %d FP8 layers to bfloat16 for FSDP compatibility", + dequantized_count, + ) + # Preserve scales and quantization config for checkpoint re-quantization. + # Store on both the model and its config so the metadata survives + # model wrapping (FSDP) and distributed broadcast. + model._fp8_scales = fp8_scales + cfg = getattr(model, "config", None) + if cfg is not None: + cfg._fp8_scales = fp8_scales + if hasattr(cfg, "quantization_config"): + model._fp8_quantization_config = cfg.quantization_config + cfg._fp8_quantization_config = cfg.quantization_config + cfg.quantization_config = None + # Clear quantization metadata so downstream code doesn't treat + # the model as quantized during training + if hasattr(model, "hf_quantizer"): + model.hf_quantizer = None + if hasattr(model, "is_loaded_in_8bit"): + model.is_loaded_in_8bit = False + + +def requantize_fp8_state_dict( + state_dict: dict[str, torch.Tensor], + fp8_scales: dict[str, dict[str, torch.Tensor]], +) -> dict[str, torch.Tensor]: + """Re-quantize a dequantized state dict back to FP8 for checkpoint saving. + + This is the inverse of :func:`_dequantize_fp8_model`. It converts + bfloat16 weights back to ``float8_e4m3fn`` and restores the original + ``weight_scale_inv`` and ``activation_scale`` entries so the saved + checkpoint matches the original FP8 format. + + Args: + state_dict: The model state dict with bfloat16 weights. + fp8_scales: The ``_fp8_scales`` dict stored by + :func:`_dequantize_fp8_model`, mapping module paths to their + original scale tensors. + + Returns: + A new state dict with FP8 weights and restored scale entries. + """ + out = {} + for key, tensor in state_dict.items(): + out[key] = tensor + + for mod_path, scales in fp8_scales.items(): + weight_key = f"{mod_path}.weight" + if weight_key not in out: + continue + + weight = out[weight_key] + + # Re-quantize: fp8_weight = real_weight / weight_scale_inv + if "weight_scale_inv" in scales: + scale_inv = scales["weight_scale_inv"] + requantized = (weight.to(torch.float32) / scale_inv.to(torch.float32)).to( + torch.float8_e4m3fn + ) + out[weight_key] = requantized + out[f"{mod_path}.weight_scale_inv"] = scale_inv + + # Restore activation_scale as-is + if "activation_scale" in scales: + out[f"{mod_path}.activation_scale"] = scales["activation_scale"] + + return out + + def extract_causal_lm_from_vlm( model_path: str, load_kwargs: dict, @@ -205,6 +335,13 @@ def extract_causal_lm_from_vlm( # allocating large random tensors, then attach the real sub-modules. config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) text_config = config.text_config + + # Propagate quantization_config from the VLM config to text_config + # so FP8 dequantization can preserve and restore it at checkpoint time. + vlm_quant_cfg = getattr(config, "quantization_config", None) + if vlm_quant_cfg is not None and not hasattr(text_config, "quantization_config"): + text_config.quantization_config = vlm_quant_cfg + causal_lm_class = MODEL_FOR_CAUSAL_LM_MAPPING[text_config.__class__] with init_empty_weights(): @@ -220,6 +357,11 @@ def extract_causal_lm_from_vlm( setattr(text_model, attr, getattr(vlm, attr)) del vlm + + # Dequantize FP8 weights if present — FSDP rejects scalar parameters + # like weight_scale_inv that come from FP8 quantized models. + _dequantize_fp8_model(text_model) + return text_model diff --git a/tests/unit/test_pretraining_data_process.py b/tests/unit/test_pretraining_data_process.py index 8f80475c..670cf055 100644 --- a/tests/unit/test_pretraining_data_process.py +++ b/tests/unit/test_pretraining_data_process.py @@ -112,8 +112,10 @@ def map_side_effect(func, **kwargs): document_column_name="documents", ) - # Verify tokenizer was loaded - mock_from_pretrained.assert_called_once_with("test-model") + # Verify tokenizer was loaded (trust_remote_code defaults to False) + mock_from_pretrained.assert_called_once_with( + "test-model", trust_remote_code=False + ) # Verify dataset filter and map were called assert mock_ds.filter.called diff --git a/tests/unit/test_pretraining_mode.py b/tests/unit/test_pretraining_mode.py index 8fbd77d8..1ca55120 100644 --- a/tests/unit/test_pretraining_mode.py +++ b/tests/unit/test_pretraining_mode.py @@ -138,7 +138,7 @@ def encode(self, text, add_special_tokens=True): num_cpu_procs=1, ) - mock_auto.assert_called_once_with("stub-model") + mock_auto.assert_called_once_with("stub-model", trust_remote_code=False) output_file = output_dir / "data.jsonl" self.assertTrue(output_file.exists())