
Commit 091c21e

RobotSail, claude, and Maxusmusti authored
Fix trust_remote_code and gradient checkpointing for custom models (#696)
* Fix: Add trust_remote_code=True for models with custom code
  - Add trust_remote_code=True to all AutoConfig/AutoTokenizer.from_pretrained() calls
  - Add torchrun path resolution (shutil.which with sys.executable fallback)
  - Pass trust_remote_code=True to base_model_args and VLM helper functions
  This fixes training failures for models like Nemotron that use remote code.

* Fix: Handle models without gradient checkpointing support
  Wrap gradient_checkpointing_enable() in try/except to handle models like NemotronH that don't support gradient checkpointing.

* Address reviewer feedback and fix ruff formatting
  - Narrow exception handling in is_gpt_oss_model to catch specific exceptions (OSError, ValueError) instead of bare Exception, and log the failure details for debugging
  - Add trust_remote_code=True to process_documents_for_pretraining() tokenizer loading for consistency with configure_tokenizer()
  - Replace the invalid fast_tokenizer kwarg with use_fast in tokenizer_utils.py setup_tokenizer()
  - Create a shared _enable_gradient_checkpointing_if_supported() helper on the Model base class, catching ValueError, NotImplementedError, and AttributeError; use it in both LigerModel and CausalLMModel
  - Improve the torchrun fallback to use sys.executable -m torch.distributed.run instead of assuming a sibling script exists
  - Fix ruff formatting for the AutoConfig.from_pretrained call

* Fix unit tests to expect trust_remote_code=True
  Update test assertions to expect the trust_remote_code=True parameter in AutoTokenizer.from_pretrained calls after adding this parameter to process_documents_for_pretraining.

* Fix ruff formatting: break long assertion line

* Make trust_remote_code configurable via flag and environment variable
  Instead of hardcoding trust_remote_code=True everywhere:
  1. Add a trust_remote_code field to TrainingArgs (default: False)
  2. Add a --trust_remote_code argparse flag to the subprocess CLI
  3. Support the TRUST_REMOTE_CODE=1 environment variable
  4. Thread the setting through Model, tokenizer, and config calls
  5. Remove the torchrun fallback; error out if torchrun is not found
  6. Remove the unnecessary try/except in is_gpt_oss_model
  7. Remove the redundant use_fast=True from tokenizer_utils
  The env var is exported by main() when the flag is set, so downstream calls (data_process, tokenizer_utils, gpt_oss_utils) automatically pick it up without needing explicit parameter threading.

* Document trust_remote_code in README
  Add trust_remote_code to the TrainingArgs table and document the TRUST_REMOTE_CODE environment variable in the environment variables section.

* Enable local mamba kernel pre-population for NemotronH models
  NemotronH has Mamba layers just like GraniteMoeHybrid and needs the same _use_local_mamba_kernels() call to avoid causal_conv1d_cuda import failures in torchrun subprocesses.

* Fix lint: revert torchrun shutil.which, remove unused imports, ruff format
  The torchrun-not-found issue was caused by the venv not being activated, not an installation problem. Revert to the plain 'torchrun' command. Remove the now-unused shutil and sys imports. Run ruff format on all modified files.

* Clarify trust_remote_code docs with security warning

* Add FP8 dequantization and requantization for Ministral VLM training
  Ministral-3-3B ships with FP8 quantized weights that include scalar parameters (weight_scale_inv, activation_scale) which FSDP rejects. This change dequantizes FP8 weights to bf16 after VLM extraction for training compatibility, preserves the original scales, and requantizes back to FP8 at checkpoint save time so saved checkpoints match the original FP8 format.
  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Ruff formatting fixes

---------

Co-authored-by: Claude Code <claude@anthropic.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Mustafa Eyceoz <meyceoz@redhat.com>
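The resolution order described above (explicit flag first, then the TRUST_REMOTE_CODE environment variable, re-exported so subprocesses inherit it) can be sketched as a standalone helper. This is a minimal sketch of the described behavior; the name resolve_trust_remote_code is illustrative, not part of the library:

import os


def resolve_trust_remote_code(cli_flag: bool) -> bool:
    # The explicit flag wins; otherwise fall back to the environment variable.
    enabled = cli_flag or os.environ.get("TRUST_REMOTE_CODE", "").lower() in (
        "1",
        "true",
        "yes",
    )
    if enabled:
        # Re-export a canonical value so downstream modules and torchrun
        # subprocesses that only read the env var pick the setting up.
        os.environ["TRUST_REMOTE_CODE"] = "1"
    return enabled


assert resolve_trust_remote_code(True) is True
assert os.environ["TRUST_REMOTE_CODE"] == "1"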
1 parent 75425ca commit 091c21e

11 files changed: 289 additions & 17 deletions


README.md

Lines changed: 2 additions & 0 deletions
@@ -244,6 +244,7 @@ for training jobs. There are a number of options you can specify, such as settin
 | distributed_backend | Specifies which distributed training backend to use. Supported options are "fsdp" and "deepspeed". |
 | disable_flash_attn | Disables flash attention when set to true. This allows for training on older devices. |
 | keep_last_checkpoint_only | Determines whether we should only keep the last checkpoint directory - the previous checkpoint directory is always overwritten. The checkpoint directory is called `last_epoch`. |
+| trust_remote_code | Controls whether repository-provided Python code from HuggingFace Hub is executed when loading models and tokenizers. This is required for models that ship custom modeling code, such as Nemotron, Ministral, and Qwen3.5. Can also be enabled via the `TRUST_REMOTE_CODE=1` environment variable. Defaults to `False`. **Security note:** enabling this setting will execute remote code from the model repository; only enable it for sources you trust. |
 
 ### `DeepSpeedOptions`
 
@@ -507,6 +508,7 @@ run_training(
 Below is a list of custom environment variables users can set in the training library.
 
 1. `INSTRUCTLAB_NCCL_TIMEOUT_MS`, this environment variable controls the NCCL timeout in milliseconds. Consider increasing if seeing FSDP related NCCL errors.
+2. `TRUST_REMOTE_CODE`, when set to `1`, allows repository-provided Python code from HuggingFace Hub to be executed when loading models and tokenizers. This is required for models that ship custom modeling code (e.g. Nemotron, Ministral, Qwen3.5). Equivalent to setting `trust_remote_code=True` in `TrainingArgs`. Only enable for sources you trust.
 
 ## Developer Certificate of Origin
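For the README addition above, only three spellings of the environment variable opt in, matching the parsing used throughout the diffs below; everything else, including "0" and an unset variable, stays disabled. A minimal self-check, assuming only the documented parsing rule:

import os


def _enabled() -> bool:
    # Mirrors the library's parsing: case-insensitive "1"/"true"/"yes".
    return os.environ.get("TRUST_REMOTE_CODE", "").lower() in ("1", "true", "yes")


for value, expected in [("1", True), ("TRUE", True), ("yes", True), ("0", False), ("no", False)]:
    os.environ["TRUST_REMOTE_CODE"] = value
    assert _enabled() is expected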

src/instructlab/training/config.py

Lines changed: 10 additions & 0 deletions
@@ -302,6 +302,16 @@ class TrainingArgs(BaseModel):
         description="Whether to use Liger kernels for training.",
     )
 
+    trust_remote_code: bool = Field(
+        default=False,
+        description=(
+            "Whether to trust remote code when loading models and tokenizers "
+            "from HuggingFace Hub. Required for models with custom code such as "
+            "Nemotron, Ministral, and Qwen3.5. Can also be enabled via the "
+            "TRUST_REMOTE_CODE=1 environment variable."
+        ),
+    )
+
     log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
         default="INFO"
     )

src/instructlab/training/data_process.py

Lines changed: 14 additions & 3 deletions
@@ -41,11 +41,18 @@
 logger = logging.getLogger(__name__)
 
 
+def _trust_remote_code() -> bool:
+    """Resolve trust_remote_code from the TRUST_REMOTE_CODE environment variable."""
+    return os.environ.get("TRUST_REMOTE_CODE", "").lower() in ("1", "true", "yes")
+
+
 @lru_cache()
 def is_gpt_oss_model(tokenizer: PreTrainedTokenizer) -> bool:
     """Check if this is a GPT-OSS model based on tokenizer."""
     model_name_or_path = tokenizer.name_or_path
-    config = AutoConfig.from_pretrained(model_name_or_path)
+    config = AutoConfig.from_pretrained(
+        model_name_or_path, trust_remote_code=_trust_remote_code()
+    )
     return config.model_type == "gpt_oss"
 
 
@@ -1182,7 +1189,9 @@ def process_documents_for_pretraining(
     )
 
     logger.info("Loading tokenizer from %s", model_path)
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path, trust_remote_code=_trust_remote_code()
+    )
 
     if tokenizer.eos_token_id is None:
         raise ValueError("Tokenizer must have an EOS token defined for pretraining")
@@ -1294,7 +1303,9 @@ def load_and_validate_dataset(data_path: str, num_procs: int) -> Dataset:
 
 def configure_tokenizer(model_path: str) -> PreTrainedTokenizer:
     """Configure the tokenizer with necessary special tokens."""
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path, trust_remote_code=_trust_remote_code()
+    )
 
     if not tokenizer.chat_template:
         raise ValueError(
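One subtlety in the first hunk: is_gpt_oss_model is wrapped in @lru_cache(), so _trust_remote_code() is only evaluated on the first call for a given argument. That is why main() exports TRUST_REMOTE_CODE before any processing starts. A toy illustration of the timing (check() is a stand-in, not library code):

import os
from functools import lru_cache

os.environ.pop("TRUST_REMOTE_CODE", None)


@lru_cache()
def check(path: str) -> bool:
    # The env var is read only when the first call for `path` executes.
    return os.environ.get("TRUST_REMOTE_CODE", "").lower() in ("1", "true", "yes")


assert check("some/model") is False
os.environ["TRUST_REMOTE_CODE"] = "1"
assert check("some/model") is False  # cache hit: the env var is not re-read
assert check("other/model") is True  # a fresh argument sees the new value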

src/instructlab/training/gpt_oss_utils_correct.py

Lines changed: 9 additions & 1 deletion
@@ -8,6 +8,7 @@
 # Standard
 from typing import Dict
 import logging
+import os
 import re
 
 # Third Party
@@ -415,7 +416,14 @@ def is_known_model(
     # convert to config
     model_config = model_path_or_config
     if isinstance(model_path_or_config, str):
-        model_config = AutoConfig.from_pretrained(model_path_or_config)
+        _trust_remote = os.environ.get("TRUST_REMOTE_CODE", "").lower() in (
+            "1",
+            "true",
+            "yes",
+        )
+        model_config = AutoConfig.from_pretrained(
+            model_path_or_config, trust_remote_code=_trust_remote
+        )
 
     known_model_types = (
         [known_model_type] if isinstance(known_model_type, str) else known_model_type

src/instructlab/training/main_ds.py

Lines changed: 28 additions & 1 deletion
@@ -394,7 +394,17 @@ def main(args):
     tokenizer = setup_tokenizer(args.model_name_or_path, args.chat_tmpl_path)
     # device = torch.device("cuda", args.local_rank)
 
-    model_conf = AutoConfig.from_pretrained(args.model_name_or_path)
+    # Resolve trust_remote_code from the CLI flag or environment variable
+    trust_remote_code = getattr(args, "trust_remote_code", False) or os.environ.get(
+        "TRUST_REMOTE_CODE", ""
+    ).lower() in ("1", "true", "yes")
+    if trust_remote_code:
+        # Export so downstream calls (data_process, tokenizer_utils, etc.) pick it up
+        os.environ["TRUST_REMOTE_CODE"] = "1"
+
+    model_conf = AutoConfig.from_pretrained(
+        args.model_name_or_path, trust_remote_code=trust_remote_code
+    )
     args.model_type = model_conf.model_type
 
     #### distributed init #####
@@ -449,6 +459,7 @@ def main(args):
         flash_enabled=flash_enabled,
         noise_alpha=args.NEFTune_alpha,
         lora_quant_bits=args.lora_quant_bits,
+        trust_remote_code=trust_remote_code,
     )
 
     args.base_model_args = m.base_model_args
@@ -710,6 +721,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     if train_args.use_liger:
         command.append("--use_liger")
 
+    # Resolve trust_remote_code from the flag or environment variable
+    trust_remote_code = train_args.trust_remote_code or os.environ.get(
+        "TRUST_REMOTE_CODE", ""
+    ).lower() in ("1", "true", "yes")
+    if trust_remote_code:
+        command.append("--trust_remote_code")
+
     if train_args.keep_last_checkpoint_only:
         command.append("--keep_last_checkpoint_only")
 
@@ -1036,6 +1054,15 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         help="Path to the chat template to set on the model for training. If none is provided, the chat template used in the model will be used.",
     )
     parser.add_argument("--disable_flash_attn", action="store_true")
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        help=(
+            "Trust remote code when loading models/tokenizers from HuggingFace Hub. "
+            "Required for models with custom code (e.g. Nemotron, Ministral, Qwen3.5). "
+            "Can also be set via the TRUST_REMOTE_CODE=1 environment variable."
+        ),
+    )
     parser.add_argument(
         "--keep_last_checkpoint_only",
         action="store_true",

src/instructlab/training/model.py

Lines changed: 26 additions & 7 deletions
@@ -65,6 +65,7 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        trust_remote_code: bool = False,
     ):
         self.lora_config = lora_config
         self.noise_alpha = noise_alpha
@@ -74,12 +75,13 @@ def __init__(
 
         # check model type & set on the class
         self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
+        self.is_nemotronh = is_known_model(model_path, "nemotron_h")
         self.is_gpt_oss = is_gpt_oss(model_path)
 
         # Pre-populate the Hub kernel cache with locally installed mamba_ssm
         # and causal_conv1d to avoid PyTorch/CUDA ABI mismatches with the
         # Hub-provided kernel builds.
-        if self.is_granitemoehybrid:
+        if self.is_granitemoehybrid or self.is_nemotronh:
             self._use_local_mamba_kernels()
 
         if self.is_gpt_oss:
@@ -100,6 +102,7 @@ def __init__(
         self.base_model_args = {
             "pretrained_model_name_or_path": model_path,
             "quantization_config": quant_config,
+            "trust_remote_code": trust_remote_code,
         }
 
         # load GPT-OSS in bfloat16 because it's a massive model, but otherwise
@@ -114,7 +117,7 @@ def __init__(
         # - M-RoPE models produce 3D position_ids that FA2 misinterprets
         # - Models with timm vision towers (TimmWrapperModel rejects FA2)
         # Detect these and fall back to SDPA.
-        use_sdpa = needs_sdpa(model_path)
+        use_sdpa = needs_sdpa(model_path, trust_remote_code=trust_remote_code)
         if use_sdpa:
             logger.warning(
                 "Disabling flash_attention_2 — model is incompatible "
@@ -145,7 +148,7 @@ def __init__(
         # For models with timm vision towers: set vision config to eager
         # while keeping the text model's attention implementation.
         # timm's TimmWrapperModel rejects both FA2 and SDPA.
-        if has_timm_vision_tower(model_path):
+        if has_timm_vision_tower(model_path, trust_remote_code=trust_remote_code):
             attn_impl = self.base_model_args.get(
                 "attn_implementation", "flash_attention_2"
             )
@@ -196,6 +199,18 @@ def _use_local_mamba_kernels():
             e,
         )
 
+    def _enable_gradient_checkpointing_if_supported(self) -> None:
+        """Enable gradient checkpointing if the model supports it.
+
+        Some models (e.g. NemotronH with its hybrid Mamba/MoE/Attention
+        architecture) do not support gradient checkpointing and will raise an
+        error. This helper catches the known exception types and logs a
+        warning instead of crashing.
+        """
+        try:
+            self.model.gradient_checkpointing_enable()
+        except (ValueError, NotImplementedError, AttributeError) as e:
+            logger.warning("Gradient checkpointing not supported: %s", e)
+
     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""
         self.reconcile_tokenizer()
@@ -544,6 +559,7 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        trust_remote_code: bool = False,
     ):
         super().__init__(
             model_path=model_path,
@@ -553,6 +569,7 @@ def __init__(
             flash_enabled=flash_enabled,
             lora_config=lora_config,
             lora_quant_bits=lora_quant_bits,
+            trust_remote_code=trust_remote_code,
         )
         try:
             # Third Party
@@ -570,7 +587,7 @@ def __init__(
             cross_entropy=True,
             fused_linear_cross_entropy=False,
         )
-        self.model.gradient_checkpointing_enable()
+        self._enable_gradient_checkpointing_if_supported()
         self._post_model_init()
 
 
@@ -586,6 +603,7 @@ def __init__(
         flash_enabled: bool = False,
         lora_config: Optional[LoraConfig] = None,
         lora_quant_bits: int = 0,
+        trust_remote_code: bool = False,
     ):
         super().__init__(
             model_path=model_path,
@@ -595,15 +613,16 @@ def __init__(
             flash_enabled=flash_enabled,
             lora_config=lora_config,
             lora_quant_bits=lora_quant_bits,
+            trust_remote_code=trust_remote_code,
         )
-        if is_vlm_with_causal_lm(model_path):
+        if is_vlm_with_causal_lm(model_path, trust_remote_code=trust_remote_code):
             self.model = extract_causal_lm_from_vlm(model_path, self.base_model_args)
-        elif is_vlm_for_direct_loading(model_path):
+        elif is_vlm_for_direct_loading(model_path, trust_remote_code=trust_remote_code):
             self.model = load_vlm_for_text_training(model_path, self.base_model_args)
         else:
             self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args)
         self._post_model_init()
-        self.model.gradient_checkpointing_enable()
+        self._enable_gradient_checkpointing_if_supported()
 
 
     def setup_optimizer(
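The fail-soft behavior of _enable_gradient_checkpointing_if_supported() can be exercised without loading a real model; in this sketch, Unsupported is a hypothetical stand-in for an architecture like NemotronH that rejects checkpointing:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class Unsupported:
    """Stand-in for a model whose architecture rejects gradient checkpointing."""

    def gradient_checkpointing_enable(self):
        raise ValueError("this model does not support gradient checkpointing")


def enable_if_supported(model) -> None:
    # Same pattern as the helper above: log a warning instead of crashing.
    try:
        model.gradient_checkpointing_enable()
    except (ValueError, NotImplementedError, AttributeError) as e:
        logger.warning("Gradient checkpointing not supported: %s", e)


enable_if_supported(Unsupported())  # warns; training setup continues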

src/instructlab/training/tokenizer_utils.py

Lines changed: 11 additions & 1 deletion
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
+# Standard
+import os
+
 # Third Party
 from transformers import AutoTokenizer, PreTrainedTokenizer
 import transformers
@@ -95,7 +98,14 @@ def setup_tokenizer(
     model_name_or_path,
     chat_tmpl_path: str | None = None,
 ) -> PreTrainedTokenizer:
-    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, fast_tokenizer=True)
+    trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "").lower() in (
+        "1",
+        "true",
+        "yes",
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name_or_path, trust_remote_code=trust_remote_code
+    )
     if not tokenizer.chat_template and chat_tmpl_path is None:
         raise ValueError(
             "Tokenizer does not have a chat template. Please provide a path to a chat template."

src/instructlab/training/utils.py

Lines changed: 42 additions & 1 deletion
@@ -713,8 +713,31 @@ def _get_state_dict_patched(model, unwrap=False):
     if is_gpt_oss(model.module.config):
         add_gpt_oss_quantization_config(model.module.config)
 
+    # For FP8 models (e.g. Ministral), restore the original quantization
+    # config so the saved checkpoint matches the original FP8 format.
+    # We must also temporarily remove the _fp8_* attrs since they contain
+    # tensors that are not JSON-serializable.
+    _fp8_quant_cfg = getattr(model.module.config, "_fp8_quantization_config", None)
+    _had_fp8_scales = hasattr(model.module.config, "_fp8_scales")
+    if _fp8_quant_cfg is not None:
+        model.module.config.quantization_config = _fp8_quant_cfg
+    if _had_fp8_scales:
+        _saved_scales = model.module.config._fp8_scales
+        del model.module.config._fp8_scales
+    if hasattr(model.module.config, "_fp8_quantization_config"):
+        del model.module.config._fp8_quantization_config
+
     model.module.config.to_json_file(output_config_file)
 
+    # Restore the internal attrs and clear quantization_config so the live
+    # model doesn't look quantized during training.
+    if _had_fp8_scales:
+        model.module.config._fp8_scales = _saved_scales
+    if _fp8_quant_cfg is not None:
+        model.module.config._fp8_quantization_config = _fp8_quant_cfg
+        model.module.config.quantization_config = None
+
     tokenizer.save_pretrained(output_dir)
 
     if is_lora:
@@ -759,7 +782,25 @@ def _get_state_dict_patched(model, unwrap=False):
             max_shard_size="5GB",
             safe_serialization=True,
         )
-    elif not is_gpt_oss(model.module.config):
+    elif getattr(model.module.config, "_fp8_scales", None):
+        # FP8 model (e.g. Ministral): re-quantize the state dict before saving
+        if accelerator.is_main_process:
+            # First Party
+            from instructlab.training.vlm_utils import requantize_fp8_state_dict
+
+            log_rank_0("Re-quantizing FP8 parameters for checkpoint compatibility")
+            model_state = model.module.state_dict()
+            model_state = requantize_fp8_state_dict(
+                model_state, model.module.config._fp8_scales
+            )
+            save_dict_accelerate(
+                accelerator,
+                model_state,
+                save_directory=output_dir,
+                max_shard_size="5GB",
+                safe_serialization=True,
+            )
+    else:
         # Standard model saving
         accelerator.save_model(
             model,
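A conceptual sketch of the dequantize/requantize round-trip this branch relies on. The real requantize_fp8_state_dict in vlm_utils operates on a full state dict with the checkpoint's weight_scale_inv/activation_scale layout; this toy version assumes a single per-tensor scale and a PyTorch build with float8 dtypes (2.1+):

import torch


def dequantize(w_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # FP8 -> bf16 for training: FSDP can shard plain bf16 weights.
    return (w_fp8.to(torch.float32) * scale).to(torch.bfloat16)


def requantize(w_bf16: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # bf16 -> FP8 at checkpoint time so the saved format matches the original.
    return (w_bf16.to(torch.float32) / scale).to(torch.float8_e4m3fn)


scale = torch.tensor(0.05)
w_fp8 = torch.randn(4, 4).to(torch.float8_e4m3fn)
w_train = dequantize(w_fp8, scale)    # what the model trains on
w_saved = requantize(w_train, scale)  # what lands in the checkpoint
assert w_saved.dtype == torch.float8_e4m3fn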
