diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py
index 8981d614843..e325e5346f1 100644
--- a/modelopt/torch/export/moe_utils.py
+++ b/modelopt/torch/export/moe_utils.py
@@ -110,11 +110,19 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
                     and w_quantizer._amax.dim() >= 1
                 ):
                     amax = w_quantizer._amax
+                    # Per-block _amax (NVFP4 static) collapses the row axis we want
+                    # to slice on; restore it so dim-0 slicing splits gate/up.
+                    if amax.numel() != fused_total and amax.numel() % fused_total == 0:
+                        amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
                     amax_dim0 = amax.shape[0]
                     if fused_total % amax_dim0 == 0:
                         slice_start = fused_start * amax_dim0 // fused_total
                         slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
-                        w_quantizer.amax = amax[slice_start:slice_end].contiguous()
+                        sliced = amax[slice_start:slice_end].contiguous()
+                        # The amax setter refuses shape changes; drop _amax first.
+                        if hasattr(w_quantizer, "_amax"):
+                            delattr(w_quantizer, "_amax")
+                        w_quantizer.amax = sliced
                     else:
                         warnings.warn(
                             f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index a58aa4c9895..73ae63a5a56 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -1134,6 +1134,19 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
         mod.revert_weight_conversion = original
+
+
+def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
+    """Force ``do_sample=True`` when the generation_config has ``top_k``/``top_p`` set.
+
+    Newer transformers versions reject ``do_sample=False`` mixed with sampling
+    attributes during ``save_pretrained``'s strict validation.
+    """
+    gc = getattr(model, "generation_config", None)
+    if gc is None:
+        return
+    if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
+        gc.do_sample = True
 
 
 def export_speculative_decoding(
     model: torch.nn.Module,
     dtype: torch.dtype | None = None,
@@ -1228,6 +1241,8 @@ def export_hf_checkpoint(
     # modeling_utils does `from core_model_loading import revert_weight_conversion`.
     _patches = _patch_revert_weight_conversion()
 
+    _sanitize_generation_config_for_save(model)
+
     try:
         model.save_pretrained(
             export_dir,
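To make the moe_utils.py change concrete, here is a minimal self-contained sketch of the reshape-then-slice with made-up shapes (a fused gate_up weight of 4 rows, 2 NVFP4 blocks per row); `fused_total` and the slicing mirror the diff, everything else is hypothetical:

import torch

fused_total = 4                     # rows of the fused gate_up weight
amax = torch.arange(8.0)            # per-block amax, flattened to 1-D by NVFP4 static
# Restore the row axis so dim-0 slicing separates the gate and up halves.
if amax.numel() != fused_total and amax.numel() % fused_total == 0:
    amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
gate_amax = amax[0:2].contiguous()  # rows 0..1 -> gate half
up_amax = amax[2:4].contiguous()    # rows 2..3 -> up half
assert gate_amax.shape == up_amax.shape == (2, 2)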
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index fe4c3f77ce6..d864008340a 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -52,7 +52,6 @@
     promote_nvfp4_static_quantizers,
     quantizer_attr_names,
     reduce_amax,
-    weight_attr_names,
 )
 
 from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
@@ -64,8 +63,100 @@
     "max_calibrate",
     "smoothquant",
     "svdquant",
+    "sync_grouped_weight_global_amax",
 ]
 
+
+# Sibling groups that share an FP8 scale-of-scales: members feed the same input
+# (Q/K/V) or get fused at deployment (gate/up), so divergent global_amax values
+# would put them on different FP8 grids.
+_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = (
+    ("q_proj", "k_proj", "v_proj"),
+    ("gate_proj", "up_proj"),  # Llama/Qwen/Mistral
+    ("w1", "w3"),  # Mixtral
+)
+
+
+def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
+    """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
+    groups: list[list[nn.Module]] = []
+    wq_attr = quantizer_attr_names("weight").weight_quantizer
+    for parent in model.modules():
+        for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS:
+            members = []
+            for n in sibling_names:
+                child = getattr(parent, n, None)
+                wq = getattr(child, wq_attr, None) if child is not None else None
+                if (
+                    isinstance(wq, TensorQuantizer)
+                    and not wq._disabled
+                    and wq.is_nvfp4_static
+                    and getattr(wq, "_amax", None) is not None
+                ):
+                    members.append(child)
+            if len(members) >= 2:
+                groups.append(members)
+    return groups
+
+
+@torch.no_grad()
+def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
+    """Populate ``_amax`` from weights for quantizers the forward pass didn't reach.
+
+    Dead MoE experts that received no tokens are otherwise skipped by
+    ``mse_calibrate``, leaving export to derive separate per-half amax values
+    for gate/up, breaking the gate==up ``weight_scale_2`` invariant.
+    """
+    n = 0
+    for module in model.modules():
+        if not isinstance(module, QuantModule):
+            continue
+        for weight, q in module.iter_weights_for_calibration():
+            if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
+                continue
+            if q._calibrator is None:
+                continue
+            if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0):
+                continue
+            q.disable_quant()
+            q.enable_calib()
+            q(weight)
+            if q._calibrator.compute_amax() is not None:
+                q.load_calib_amax()
+            q.enable_quant()
+            q.disable_calib()
+            if hasattr(q._calibrator, "reset"):
+                q._calibrator.reset()
+            n += 1
+    return n
+
+
+@torch.no_grad()
+def sync_grouped_weight_global_amax(model: nn.Module) -> int:
+    """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.
+
+    Reuses ``preprocess_linear_fusion`` (which performs the same unification at
+    export time) to keep the FP8 scale-of-scales consistent across siblings
+    during MSE / local-Hessian search. Must run after ``max_calibrate``.
+    """
+    # Inline import: quant_utils imports enable_stats_collection /
+    # finish_stats_collection / svd from this module, so a top-level import
+    # here would be circular.
+    from modelopt.torch.export.quant_utils import preprocess_linear_fusion
+
+    wq_attr = quantizer_attr_names("weight").weight_quantizer
+    n_groups = 0
+    for group in _collect_grouped_linears(model):
+        for child in group:
+            wq = getattr(child, wq_attr)
+            if not isinstance(wq, NVFP4StaticQuantizer):
+                NVFP4StaticQuantizer.from_tensor_quantizer(
+                    wq, global_amax=reduce_amax(wq._amax, axis=None)
+                )
+        preprocess_linear_fusion(group)
+        n_groups += 1
+    return n_groups
+
+
 CalibratorFactory: TypeAlias = Callable[
     [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
 ]
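What the unification buys, reduced to plain tensor math (the numbers are hypothetical; the real path goes through `preprocess_linear_fusion`): after max calibration each sibling carries its own `global_amax`, and syncing replaces all of them with the group maximum so every member derives its FP8 block-scale grid from the same scale-of-scales.

import torch

q_amax = torch.tensor(3.1)  # per-sibling global_amax after max calibration
k_amax = torch.tensor(0.9)
v_amax = torch.tensor(1.7)
shared = torch.stack([q_amax, k_amax, v_amax]).amax()  # group scale-of-scales
# All three siblings now quantize their block scales against `shared`, so
# fusing Q/K/V at deployment leaves every member's FP8 grid unchanged.
assert shared == q_amax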
@@ -346,32 +437,25 @@ def mse_calibrate(
     See :class:`MseCalibConfig ` for details on the remaining arguments.
     """
-    # Step 1: First get initial amax using max calibration
+    # Step 1: max calibration; then populate _amax for dead experts so step 3
+    # doesn't skip them, and unify NVFP4 global_amax across Q/K/V and gate/up
+    # siblings so MSE searches against a consistent FP8 grid.
     max_calibrate(model, forward_loop, distributed_sync)
+    _bootstrap_uncalibrated_weight_quantizers(model)
+    sync_grouped_weight_global_amax(model)
 
-    # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
-    # and identify weight quantizers
-    weight_quantizers = []
-    seen_modules = set()
-
+    # Step 2: replace calibrators with MseCalibrator for enabled quantizers.
     for name, module in list(model.named_modules()):
         if isinstance(module, TensorQuantizer) and not module._disabled:
             if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
-                # Get the initial amax from max calibration
                 initial_amax = module._amax.clone().detach()
+                is_nvfp4_static = module.is_nvfp4_static
 
-                is_nvfp4_static = (
-                    module.is_static_block_quant
-                    and module._num_bits == (2, 1)
-                    and module._block_sizes is not None
-                    and module._block_sizes.get("scale_bits") == (4, 3)
-                )
-
-                if is_nvfp4_static:
-                    # Compute and set global_amax
+                # sync_grouped_weight_global_amax may have already promoted these
+                # with a unified global_amax across the sibling group; only promote
+                # standalone (non-grouped) NVFP4-static quantizers here.
+                if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer):
                     global_amax = reduce_amax(initial_amax, axis=None)
-
-                    # Convert to NVFP4StaticQuantizer in-place
                     NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
 
                 if fp8_scale_sweep:
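For reference, the format predicate the new `is_nvfp4_static` property encodes, spelled out on hypothetical quantizer attributes (the `"type": "static"` key standing in for `is_static_block_quant` is an assumption about the config layout, not taken from the diff):

num_bits = (2, 1)                                               # E2M1 weight format
block_sizes = {-1: 16, "type": "static", "scale_bits": (4, 3)}  # E4M3 block scales

is_nvfp4_static = (
    block_sizes is not None
    and block_sizes.get("type") == "static"  # stand-in for is_static_block_quant
    and num_bits == (2, 1)
    and block_sizes.get("scale_bits") == (4, 3)
)
assert is_nvfp4_static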
@@ -412,52 +496,50 @@
                     quant_func=partial(_mse_quant_func, quantizer=module),
                 )
 
-    # Identify weight quantizers by checking if they have corresponding weight parameters
+    # Step 3: calibrate weight quantizers via iter_weights_for_calibration.
+    # The fused-experts override yields one pair per expert per projection, so
+    # every per-expert quantizer is MSE-calibrated (not just routed ones).
     name_to_module = dict(model.named_modules())
+    seen_modules: set[int] = set()
+    pbar = tqdm(desc="MSE weight calibration")
+    n_calibrated = 0
     for parent_module in name_to_module.values():
-        if parent_module in seen_modules:
+        if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule):
             continue
-        for weight_name in weight_attr_names(parent_module):
-            weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
-            weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
-            if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
-                if getattr(weight_quantizer, "_calibrator", None) is not None:
-                    weight_quantizers.append((parent_module, weight_name, weight_quantizer))
-        seen_modules.add(parent_module)
-
-    # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
-    # This prevents massive memory accumulation seen in large models
-    for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
-        tqdm(weight_quantizers, desc="MSE weight calibration")
-    ):
-        # Enable calibration mode for the weight quantizer
-        weight_quantizer.disable_quant()
-        weight_quantizer.enable_calib()
+        seen_modules.add(id(parent_module))
 
         with enable_weight_access_and_writeback(parent_module, model, name_to_module):
-            weight = getattr(parent_module, weight_name)
-            weight_quantizer(weight)
+            for weight, weight_quantizer in parent_module.iter_weights_for_calibration():
+                if not (
+                    isinstance(weight_quantizer, TensorQuantizer)
+                    and weight_quantizer.is_enabled
+                    and getattr(weight_quantizer, "_calibrator", None) is not None
+                ):
+                    continue
+                weight_quantizer.disable_quant()
+                weight_quantizer.enable_calib()
+                weight_quantizer(weight)
 
-            # IMMEDIATELY compute amax and reset calibrator to free memory
-            cal = getattr(weight_quantizer, "_calibrator", None)
-            if cal is not None and cal.compute_amax() is not None:
-                weight_quantizer.load_calib_amax()
+                cal = weight_quantizer._calibrator
+                if cal.compute_amax() is not None:
+                    weight_quantizer.load_calib_amax()
 
-        weight_quantizer.enable_quant()
-        weight_quantizer.disable_calib()
+                weight_quantizer.enable_quant()
+                weight_quantizer.disable_calib()
 
-        # Synchronize ALL CUDA devices before resetting to ensure all async operations complete
-        # This is critical for multi-GPU setups where tensors may be on different devices
-        if torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                if torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
 
-        if cal is not None and hasattr(cal, "reset"):
-            cal.reset()
+                if hasattr(cal, "reset"):
+                    cal.reset()
 
-        if (idx + 1) % 10 == 0 and torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
-            torch.cuda.empty_cache()
+                pbar.update(1)
+                n_calibrated += 1
+                if n_calibrated % 10 == 0 and torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                    torch.cuda.empty_cache()
+    pbar.close()
 
     if torch.cuda.is_available():
         for dev_id in range(torch.cuda.device_count()):
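The rewritten step 3 keeps the one-at-a-time memory discipline of the old loop. Isolated below with a stub calibrator (all names hypothetical; only the call order mirrors the diff): collect stats for a single weight, finalize its amax, then reset before the next weight, so calibration buffers never accumulate model-wide.

import torch

class StubCalibrator:
    """Hypothetical stand-in for a per-quantizer MSE calibrator."""
    def __init__(self):
        self.stats = None
    def collect(self, x):
        self.stats = x.abs().amax()  # quantizer(weight) in the diff
    def compute_amax(self):
        return self.stats
    def reset(self):
        self.stats = None            # frees per-weight buffers immediately

cal = StubCalibrator()
for weight in (torch.randn(4, 8), torch.randn(4, 8)):
    cal.collect(weight)
    amax = cal.compute_amax()        # load_calib_amax() in the diff
    assert amax is not None
    cal.reset()                      # buffers released before the next weight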
@@ -612,6 +694,8 @@ def forward(self, input, *args, **kwargs):
     print_rank_0("local_hessian: Running max calibration for all quantizers...")
     max_calibrate(model, forward_loop, distributed_sync)
 
+    sync_grouped_weight_global_amax(model)
+
     # Setup helpers for all quantized linear modules
     name_to_module = dict(model.named_modules())
     weight_quantizers_info = []
@@ -666,14 +750,9 @@ def quant_func(x, amax, quantizer=weight_quantizer):
                 return xq
 
-            is_nvfp4_static = (
-                weight_quantizer.is_static_block_quant
-                and weight_quantizer._num_bits == (2, 1)
-                and weight_quantizer._block_sizes is not None
-                and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
-            )
+            is_nvfp4_static = weight_quantizer.is_nvfp4_static
 
-            if is_nvfp4_static:
+            if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer):
                 global_amax = reduce_amax(initial_amax, axis=None)
                 NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 3ff7401ec3e..fa540b8fdf5 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -514,6 +514,16 @@ def is_mx_format(self):
             and self.block_sizes.get("scale_bits", None) == (8, 0)
         )
 
+    @property
+    def is_nvfp4_static(self):
+        """True for E2M1 weights + E4M3 per-block scales in static layout (format-only check)."""
+        return (
+            self.is_static_block_quant
+            and self._num_bits == (2, 1)
+            and self._block_sizes is not None
+            and self._block_sizes.get("scale_bits") == (4, 3)
+        )
+
     def is_mxfp(self, bits):
         """Check if is MXFP4/MXFP6/MXFP8."""
         if bits == 4:
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index 77f26b20602..1873ecda528 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -900,6 +900,24 @@ def forward(self, *args, **kwargs):
             self._down_proj_linear = False
         return super().forward(*args, **kwargs)
 
+    def iter_weights_for_calibration(self):
+        """Yield ``(weight_slice, quantizer)`` pairs, one per expert.
+
+        The base implementation uses the singular ``*_weight_quantizer`` and
+        skips fused-experts modules, so weight-only calibration never reaches
+        the per-expert quantizers without this override.
+        """
+        for weight_name, quantizers_name in (
+            ("gate_up_proj", "gate_up_proj_weight_quantizers"),
+            ("down_proj", "down_proj_weight_quantizers"),
+        ):
+            weight = getattr(self, weight_name, None)
+            quantizers = getattr(self, quantizers_name, None)
+            if weight is None or quantizers is None:
+                continue
+            for idx, q in enumerate(quantizers):
+                yield weight[idx], q
+
     def fold_weight(self, keep_attrs: bool = False):
         """Fold per-expert weight quantizers into the fused 3-D weights.
diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py
index 1a177e04dc8..cea3d4260e4 100644
--- a/modelopt/torch/quantization/utils/core_utils.py
+++ b/modelopt/torch/quantization/utils/core_utils.py
@@ -957,13 +957,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
     for _name, module in list(model.named_modules()):
         if isinstance(module, TensorQuantizer) and not module._disabled:
             if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
-                is_nvfp4_static = (
-                    module.is_static_block_quant
-                    and module._num_bits == (2, 1)
-                    and module._block_sizes is not None
-                    and module._block_sizes.get("scale_bits") == (4, 3)
-                )
-                if is_nvfp4_static:
+                if module.is_nvfp4_static:
                     initial_amax = module._amax.clone().detach()
                     global_amax = reduce_amax(initial_amax, axis=None)
                     NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
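Finally, the contract of the `iter_weights_for_calibration` override in huggingface.py, demonstrated with a stub (the class and tensor shapes are hypothetical): the fused 3-D expert weight is yielded slice-by-slice, one `(weight, quantizer)` pair per expert, which is exactly what the rewritten step 3 and `_bootstrap_uncalibrated_weight_quantizers` consume.

import torch

class StubFusedExperts:
    """Hypothetical stand-in for the fused-experts module in the diff."""
    def __init__(self):
        self.gate_up_proj = torch.randn(2, 4, 8)            # (experts, rows, cols)
        self.gate_up_proj_weight_quantizers = ["q0", "q1"]  # per-expert stand-ins

    def iter_weights_for_calibration(self):
        for idx, q in enumerate(self.gate_up_proj_weight_quantizers):
            yield self.gate_up_proj[idx], q                 # one 2-D slice per expert

pairs = list(StubFusedExperts().iter_weights_for_calibration())
assert len(pairs) == 2 and pairs[0][0].shape == (4, 8)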