Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ for training jobs. There are a number of options you can specify, such as settin
| distributed_backend | Specifies which distributed training backend to use. Supported options are "fsdp" and "deepspeed". |
| disable_flash_attn | Disables flash attention when set to true. This allows for training on older devices. |
| keep_last_checkpoint_only | Determines whether we should only keep the last checkpoint directory - the previous checkpoint directory is always overwritten. The checkpoint directory is called `last_epoch`. |
| trust_remote_code | Controls whether repository-provided Python code from HuggingFace Hub is executed when loading models and tokenizers. This is required for models that ship custom modeling code, such as Nemotron, Ministral, and Qwen3.5. Can also be enabled via the `TRUST_REMOTE_CODE=1` environment variable. Defaults to `False`. **Security note:** enabling this setting will execute remote code from the model repository — only enable it for sources you trust. |

### `DeepSpeedOptions`

Expand Down Expand Up @@ -507,6 +508,7 @@ run_training(
Below is a list of custom environment variables users can set in the training library.

1. `INSTRUCTLAB_NCCL_TIMEOUT_MS`, this environment variable controls the NCCL timeout in milliseconds. Consider increasing it if you are seeing FSDP-related NCCL errors.
2. `TRUST_REMOTE_CODE`, when set to `1`, allows repository-provided Python code from HuggingFace Hub to be executed when loading models and tokenizers. This is required for models that ship custom modeling code (e.g. Nemotron, Ministral, Qwen3.5). Equivalent to setting `trust_remote_code=True` in `TrainingArgs`. Only enable for sources you trust.

## Developer Certificate of Origin

Expand Down
10 changes: 10 additions & 0 deletions src/instructlab/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,16 @@ class TrainingArgs(BaseModel):
description="Whether to use Liger kernels for training.",
)

trust_remote_code: bool = Field(
default=False,
description=(
"Whether to trust remote code when loading models and tokenizers "
"from HuggingFace Hub. Required for models with custom code such as "
"Nemotron, Ministral, and Qwen3.5. Can also be enabled via the "
"TRUST_REMOTE_CODE=1 environment variable."
),
)

log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
default="INFO"
)
Expand Down
17 changes: 14 additions & 3 deletions src/instructlab/training/data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,18 @@
logger = logging.getLogger(__name__)


def _trust_remote_code() -> bool:
"""Resolve trust_remote_code from the TRUST_REMOTE_CODE environment variable."""
return os.environ.get("TRUST_REMOTE_CODE", "").lower() in ("1", "true", "yes")


@lru_cache()
def _is_gpt_oss_path(model_name_or_path: str) -> bool:
    """Return True if the model config at *model_name_or_path* is ``gpt_oss``.

    Cached on the path/name string so repeated lookups for the same model do
    not re-fetch the config from disk or the Hub.
    """
    config = AutoConfig.from_pretrained(
        model_name_or_path, trust_remote_code=_trust_remote_code()
    )
    return config.model_type == "gpt_oss"


def is_gpt_oss_model(tokenizer: PreTrainedTokenizer) -> bool:
    """Check if this is a GPT-OSS model based on tokenizer.

    The cache is keyed on ``tokenizer.name_or_path`` rather than the tokenizer
    object itself: caching on the object would pin every tokenizer instance in
    memory for the lifetime of the process and would miss the cache for
    distinct tokenizer instances of the same model.
    """
    return _is_gpt_oss_path(tokenizer.name_or_path)


Expand Down Expand Up @@ -1182,7 +1189,9 @@ def process_documents_for_pretraining(
)

logger.info("Loading tokenizer from %s", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=_trust_remote_code()
)

if tokenizer.eos_token_id is None:
raise ValueError("Tokenizer must have an EOS token defined for pretraining")
Expand Down Expand Up @@ -1294,7 +1303,9 @@ def load_and_validate_dataset(data_path: str, num_procs: int) -> Dataset:

def configure_tokenizer(model_path: str) -> PreTrainedTokenizer:
"""Configure the tokenizer with necessary special tokens."""
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=_trust_remote_code()
)

if not tokenizer.chat_template:
raise ValueError(
Expand Down
10 changes: 9 additions & 1 deletion src/instructlab/training/gpt_oss_utils_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# Standard
from typing import Dict
import logging
import os
import re

# Third Party
Expand Down Expand Up @@ -415,7 +416,14 @@ def is_known_model(
# convert to config
model_config = model_path_or_config
if isinstance(model_path_or_config, str):
model_config = AutoConfig.from_pretrained(model_path_or_config)
_trust_remote = os.environ.get("TRUST_REMOTE_CODE", "").lower() in (
"1",
"true",
"yes",
)
model_config = AutoConfig.from_pretrained(
model_path_or_config, trust_remote_code=_trust_remote
)

known_model_types = (
[known_model_type] if isinstance(known_model_type, str) else known_model_type
Expand Down
29 changes: 28 additions & 1 deletion src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,17 @@ def main(args):
tokenizer = setup_tokenizer(args.model_name_or_path, args.chat_tmpl_path)
# device = torch.device("cuda", args.local_rank)

model_conf = AutoConfig.from_pretrained(args.model_name_or_path)
# Resolve trust_remote_code from CLI flag or environment variable
trust_remote_code = getattr(args, "trust_remote_code", False) or os.environ.get(
"TRUST_REMOTE_CODE", ""
).lower() in ("1", "true", "yes")
if trust_remote_code:
# Export so downstream calls (data_process, tokenizer_utils, etc.) pick it up
os.environ["TRUST_REMOTE_CODE"] = "1"

model_conf = AutoConfig.from_pretrained(
args.model_name_or_path, trust_remote_code=trust_remote_code
)
args.model_type = model_conf.model_type

#### distributed init #####
Expand Down Expand Up @@ -449,6 +459,7 @@ def main(args):
flash_enabled=flash_enabled,
noise_alpha=args.NEFTune_alpha,
lora_quant_bits=args.lora_quant_bits,
trust_remote_code=trust_remote_code,
)

args.base_model_args = m.base_model_args
Expand Down Expand Up @@ -710,6 +721,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.use_liger:
command.append("--use_liger")

# Resolve trust_remote_code from flag or environment variable
trust_remote_code = train_args.trust_remote_code or os.environ.get(
"TRUST_REMOTE_CODE", ""
).lower() in ("1", "true", "yes")
if trust_remote_code:
command.append("--trust_remote_code")

if train_args.keep_last_checkpoint_only:
command.append("--keep_last_checkpoint_only")

Expand Down Expand Up @@ -1036,6 +1054,15 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
help="Path to the chat template to set on the model for training. If none is provided, the chat template used in the model will be used.",
)
parser.add_argument("--disable_flash_attn", action="store_true")
parser.add_argument(
"--trust_remote_code",
action="store_true",
help=(
"Trust remote code when loading models/tokenizers from HuggingFace Hub. "
"Required for models with custom code (e.g. Nemotron, Ministral, Qwen3.5). "
"Can also be set via the TRUST_REMOTE_CODE=1 environment variable."
),
)
parser.add_argument(
"--keep_last_checkpoint_only",
action="store_true",
Expand Down
33 changes: 26 additions & 7 deletions src/instructlab/training/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
flash_enabled: bool = False,
lora_config: Optional[LoraConfig] = None,
lora_quant_bits: int = 0,
trust_remote_code: bool = False,
):
self.lora_config = lora_config
self.noise_alpha = noise_alpha
Expand All @@ -74,12 +75,13 @@ def __init__(

# check model type & set on the mclasss
self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
self.is_nemotronh = is_known_model(model_path, "nemotron_h")
self.is_gpt_oss = is_gpt_oss(model_path)

# Pre-populate the Hub kernel cache with locally installed mamba_ssm
# and causal_conv1d to avoid PyTorch/CUDA ABI mismatches with the
# Hub-provided kernel builds.
if self.is_granitemoehybrid:
if self.is_granitemoehybrid or self.is_nemotronh:
self._use_local_mamba_kernels()

if self.is_gpt_oss:
Expand All @@ -100,6 +102,7 @@ def __init__(
self.base_model_args = {
"pretrained_model_name_or_path": model_path,
"quantization_config": quant_config,
"trust_remote_code": trust_remote_code,
}

# load GPT-OSS in bfloat16 because it's a massive model, but otherwise
Expand All @@ -114,7 +117,7 @@ def __init__(
# - M-RoPE models produce 3D position_ids that FA2 misinterprets
# - Models with timm vision towers (TimmWrapperModel rejects FA2)
# Detect these and fall back to SDPA.
use_sdpa = needs_sdpa(model_path)
use_sdpa = needs_sdpa(model_path, trust_remote_code=trust_remote_code)
if use_sdpa:
logger.warning(
"Disabling flash_attention_2 — model is incompatible "
Expand Down Expand Up @@ -145,7 +148,7 @@ def __init__(
# For models with timm vision towers: set vision config to eager
# while keeping the text model's attention implementation.
# timm's TimmWrapperModel rejects both FA2 and SDPA.
if has_timm_vision_tower(model_path):
if has_timm_vision_tower(model_path, trust_remote_code=trust_remote_code):
attn_impl = self.base_model_args.get(
"attn_implementation", "flash_attention_2"
)
Expand Down Expand Up @@ -196,6 +199,18 @@ def _use_local_mamba_kernels():
e,
)

def _enable_gradient_checkpointing_if_supported(self) -> None:
"""Enable gradient checkpointing if the model supports it.

Some models (e.g. NemotronH with hybrid Mamba/MoE/Attention architecture)
do not support gradient checkpointing and will raise an error. This helper
catches known exception types and logs a warning instead of crashing.
"""
try:
self.model.gradient_checkpointing_enable()
except (ValueError, NotImplementedError, AttributeError) as e:
logger.warning("Gradient checkpointing not supported: %s", e)

def _post_model_init(self):
"""Common initialization steps that should happen after model initialization."""
self.reconcile_tokenizer()
Expand Down Expand Up @@ -544,6 +559,7 @@ def __init__(
flash_enabled: bool = False,
lora_config: Optional[LoraConfig] = None,
lora_quant_bits: int = 0,
trust_remote_code: bool = False,
):
super().__init__(
model_path=model_path,
Expand All @@ -553,6 +569,7 @@ def __init__(
flash_enabled=flash_enabled,
lora_config=lora_config,
lora_quant_bits=lora_quant_bits,
trust_remote_code=trust_remote_code,
)
try:
# Third Party
Expand All @@ -570,7 +587,7 @@ def __init__(
cross_entropy=True,
fused_linear_cross_entropy=False,
)
self.model.gradient_checkpointing_enable()
self._enable_gradient_checkpointing_if_supported()
self._post_model_init()


Expand All @@ -586,6 +603,7 @@ def __init__(
flash_enabled: bool = False,
lora_config: Optional[LoraConfig] = None,
lora_quant_bits: int = 0,
trust_remote_code: bool = False,
):
super().__init__(
model_path=model_path,
Expand All @@ -595,15 +613,16 @@ def __init__(
flash_enabled=flash_enabled,
lora_config=lora_config,
lora_quant_bits=lora_quant_bits,
trust_remote_code=trust_remote_code,
)
if is_vlm_with_causal_lm(model_path):
if is_vlm_with_causal_lm(model_path, trust_remote_code=trust_remote_code):
self.model = extract_causal_lm_from_vlm(model_path, self.base_model_args)
elif is_vlm_for_direct_loading(model_path):
elif is_vlm_for_direct_loading(model_path, trust_remote_code=trust_remote_code):
self.model = load_vlm_for_text_training(model_path, self.base_model_args)
else:
self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args)
self._post_model_init()
self.model.gradient_checkpointing_enable()
self._enable_gradient_checkpointing_if_supported()


def setup_optimizer(
Expand Down
12 changes: 11 additions & 1 deletion src/instructlab/training/tokenizer_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
import os

# Third Party
from transformers import AutoTokenizer, PreTrainedTokenizer
import transformers
Expand Down Expand Up @@ -95,7 +98,14 @@ def setup_tokenizer(
model_name_or_path,
chat_tmpl_path: str | None = None,
) -> PreTrainedTokenizer:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, fast_tokenizer=True)
trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "").lower() in (
"1",
"true",
"yes",
)
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code
)
if not tokenizer.chat_template and chat_tmpl_path is None:
raise ValueError(
"Tokenizer does not have a chat template. Please provide a path to a chat template."
Expand Down
44 changes: 43 additions & 1 deletion src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,8 +713,31 @@ def _get_state_dict_patched(model, unwrap=False):
if is_gpt_oss(model.module.config):
add_gpt_oss_quantization_config(model.module.config)

# For FP8 models (e.g. Ministral), restore the original quantization
# config so the saved checkpoint matches the original FP8 format.
# We must also temporarily remove _fp8_* attrs since they contain
# tensors that are not JSON-serializable.
_fp8_quant_cfg = getattr(model.module.config, "_fp8_quantization_config", None)
_had_fp8_scales = hasattr(model.module.config, "_fp8_scales")
if _fp8_quant_cfg is not None:
model.module.config.quantization_config = _fp8_quant_cfg
if _had_fp8_scales:
_saved_scales = model.module.config._fp8_scales
del model.module.config._fp8_scales
if hasattr(model.module.config, "_fp8_quantization_config"):
del model.module.config._fp8_quantization_config

model.module.config.to_json_file(output_config_file)

# Restore internal attrs and clear quantization_config so the live
# model doesn't look quantized during training
if _had_fp8_scales:
model.module.config._fp8_scales = _saved_scales
if _fp8_quant_cfg is not None:
model.module.config._fp8_quantization_config = _fp8_quant_cfg
if _fp8_quant_cfg is not None:
model.module.config.quantization_config = None

tokenizer.save_pretrained(output_dir)

if is_lora:
Expand Down Expand Up @@ -759,7 +782,26 @@ def _get_state_dict_patched(model, unwrap=False):
max_shard_size="5GB",
safe_serialization=True,
)
elif not is_gpt_oss(model.module.config):
elif getattr(model.module.config, "_fp8_scales", None):
# FP8 model (e.g. Ministral): re-quantize state dict before saving
if accelerator.is_main_process:
from instructlab.training.vlm_utils import requantize_fp8_state_dict

log_rank_0(
"Re-quantizing FP8 parameters for checkpoint compatibility"
)
model_state = model.module.state_dict()
model_state = requantize_fp8_state_dict(
model_state, model.module.config._fp8_scales
)
save_dict_accelerate(
accelerator,
model_state,
save_directory=output_dir,
max_shard_size="5GB",
safe_serialization=True,
)
else:
# Standard model saving
accelerator.save_model(
model,
Expand Down
Loading
Loading