[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE #1421
base: main
Changes from all commits
@@ -52,7 +52,6 @@
     promote_nvfp4_static_quantizers,
     quantizer_attr_names,
     reduce_amax,
-    weight_attr_names,
 )
 from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper

@@ -64,8 +63,100 @@
     "max_calibrate",
     "smoothquant",
     "svdquant",
+    "sync_grouped_weight_global_amax",
 ]


+# Sibling groups that share an FP8 scale-of-scales: members feed the same input
+# (Q/K/V) or get fused at deployment (gate/up), so divergent global_amax would
+# split their FP8 grids.
+_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = (
+    ("q_proj", "k_proj", "v_proj"),
+    ("gate_proj", "up_proj"),  # Llama/Qwen/Mistral
+    ("w1", "w3"),  # Mixtral
+)
+
+
+def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
+    """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
+    groups: list[list[nn.Module]] = []
+    wq_attr = quantizer_attr_names("weight").weight_quantizer
+    for parent in model.modules():
+        for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS:
+            members = []
+            for n in sibling_names:
+                child = getattr(parent, n, None)
+                wq = getattr(child, wq_attr, None) if child is not None else None
+                if (
+                    isinstance(wq, TensorQuantizer)
+                    and not wq._disabled
+                    and wq.is_nvfp4_static
+                    and getattr(wq, "_amax", None) is not None
+                ):
+                    members.append(child)
+            if len(members) >= 2:
+                groups.append(members)
+    return groups
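To make the grid-splitting concern above concrete, here is a small, self-contained illustration. The toy amax values and the weight_scale_2 formula (global_amax / (6 * 448), the usual NVFP4 two-level scaling with FP4/E2M1 max 6 and FP8/E4M3 max 448) are assumptions for illustration, not values taken from this diff.

import torch

# Per-sibling amax after max calibration (toy values).
gate_amax = torch.tensor(3.2)
up_amax = torch.tensor(0.8)

# Assumed NVFP4 convention: the FP8 block scales are themselves scaled by a
# single FP32 scale-of-scales, weight_scale_2 = global_amax / (6 * 448).
def weight_scale_2(global_amax: torch.Tensor) -> torch.Tensor:
    return global_amax / (6.0 * 448.0)

# Divergent global_amax puts gate and up on different FP8 scale grids, so a
# fused gate_up tensor could not be exported with one weight_scale_2.
print(weight_scale_2(gate_amax), weight_scale_2(up_amax))

# Taking the group max (what reduce_amax over the group amounts to) restores a
# single shared grid for all siblings.
unified = torch.maximum(gate_amax, up_amax)
print(weight_scale_2(unified))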
+@torch.no_grad()
+def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
+    """Populate ``_amax`` from weights for quantizers the forward pass didn't reach.
+
+    Dead MoE experts that received no tokens are otherwise skipped by
+    ``mse_calibrate``, leaving export to derive separate per-half amax for
+    gate/up and break the gate == up ``weight_scale_2`` invariant.
+    """
+    n = 0
+    for module in model.modules():
+        if not isinstance(module, QuantModule):
+            continue
+        for weight, q in module.iter_weights_for_calibration():
+            if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
+                continue
+            if q._calibrator is None:
+                continue
+            if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0):
+                continue
+            q.disable_quant()
+            q.enable_calib()
+            q(weight)
+            if q._calibrator.compute_amax() is not None:
+                q.load_calib_amax()
+            q.enable_quant()
+            q.disable_calib()
+            if hasattr(q._calibrator, "reset"):
+                q._calibrator.reset()
+            n += 1
+    return n
Comment on lines +102 to +131

Contributor

Run dead-expert bootstrap under enable_weight_access_and_writeback.

This helper reads weight slices and calibrates them before entering any weight-access context. On FSDP/HF-TP/offloaded modules, that can either calibrate only the local shard or hit an access failure that gets swallowed by the blanket except Exception.

Suggested adjustment:

 @torch.no_grad()
 def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
     """Run a max-style amax collection on weight quantizers whose ``_amax`` is missing."""
     n = 0
+    name_to_module = dict(model.named_modules())
     for module in model.modules():
         if not isinstance(module, QuantModule):
             continue
-        try:
-            pairs = list(module.iter_weights_for_calibration())
-        except Exception:
-            continue
-        for weight, q in pairs:
-            if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
-                continue
-            if q._calibrator is None:
-                continue
-            if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
-                continue
-            q.disable_quant()
-            q.enable_calib()
-            q(weight)
-            if q._calibrator.compute_amax() is not None:
-                q.load_calib_amax()
-            q.enable_quant()
-            q.disable_calib()
-            if hasattr(q._calibrator, "reset"):
-                q._calibrator.reset()
-            n += 1
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            for weight, q in module.iter_weights_for_calibration():
+                if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
+                    continue
+                if q._calibrator is None:
+                    continue
+                if hasattr(q, "_amax") and q._amax is not None and not torch.all(q._amax == 0):
+                    continue
+                q.disable_quant()
+                q.enable_calib()
+                q(weight)
+                if q._calibrator.compute_amax() is not None:
+                    q.load_calib_amax()
+                q.enable_quant()
+                q.disable_calib()
+                if hasattr(q._calibrator, "reset"):
+                    q._calibrator.reset()
+                n += 1
     return n
+@torch.no_grad()
+def sync_grouped_weight_global_amax(model: nn.Module) -> int:
+    """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.
+
+    Reuses ``preprocess_linear_fusion`` (which performs the same unification at
+    export time) to keep the FP8 scale-of-scales consistent across siblings
+    during MSE / local-Hessian search. Must run after ``max_calibrate``.
+    """
+    # Inline import: quant_utils imports enable_stats_collection/finish_stats_collection/svd
+    # from this module, so a top-level import would create a circular import.
+    from modelopt.torch.export.quant_utils import preprocess_linear_fusion
+
+    wq_attr = quantizer_attr_names("weight").weight_quantizer
+    n_groups = 0
+    for group in _collect_grouped_linears(model):
+        for child in group:
+            wq = getattr(child, wq_attr)
+            if not isinstance(wq, NVFP4StaticQuantizer):
+                NVFP4StaticQuantizer.from_tensor_quantizer(
+                    wq, global_amax=reduce_amax(wq._amax, axis=None)
+                )
+        preprocess_linear_fusion(group)
+        n_groups += 1
+    return n_groups
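A minimal usage sketch of the intended call order (mirroring step 1 of mse_calibrate below; model and forward_loop are placeholders):

# After max calibration, bootstrap dead experts so every weight quantizer has
# an _amax, then unify the FP8 scale-of-scales across sibling groups.
max_calibrate(model, forward_loop)
n_boot = _bootstrap_uncalibrated_weight_quantizers(model)
n_groups = sync_grouped_weight_global_amax(model)
print(f"bootstrapped {n_boot} quantizers, unified {n_groups} sibling groups")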
 CalibratorFactory: TypeAlias = Callable[
     [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
 ]
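For orientation, a callable matching this alias receives the max-calibrated initial amax, a reduction axis, and a quant function, and returns a calibrator. The keyword names in this sketch are assumptions for illustration, not MseCalibrator's verified constructor signature:

def make_mse_calibrator(initial_amax, axis, quant_func):
    # Hypothetical kwargs; shown only to illustrate the CalibratorFactory shape.
    return MseCalibrator(amax=initial_amax, axis=axis, quant_func=quant_func)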
@@ -346,32 +437,25 @@ def mse_calibrate(
     See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
     details on the remaining arguments.
     """
-    # Step 1: First get initial amax using max calibration
+    # Step 1: max calibration; then populate _amax for dead experts so step 3
+    # doesn't skip them, and unify NVFP4 global_amax across Q/K/V and gate/up
+    # siblings so MSE searches against a consistent FP8 grid.
     max_calibrate(model, forward_loop, distributed_sync)
+    _bootstrap_uncalibrated_weight_quantizers(model)
+    sync_grouped_weight_global_amax(model)

-    # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
-    # and identify weight quantizers
-    weight_quantizers = []
-    seen_modules = set()
-
+    # Step 2: replace calibrators with MseCalibrator for enabled quantizers.
     for name, module in list(model.named_modules()):
         if isinstance(module, TensorQuantizer) and not module._disabled:
             if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
                 # Get the initial amax from max calibration
                 initial_amax = module._amax.clone().detach()
+                is_nvfp4_static = module.is_nvfp4_static

-                is_nvfp4_static = (
-                    module.is_static_block_quant
-                    and module._num_bits == (2, 1)
-                    and module._block_sizes is not None
-                    and module._block_sizes.get("scale_bits") == (4, 3)
-                )
-
-                if is_nvfp4_static:
-                    # Compute and set global_amax
+                # sync_grouped_weight_global_amax may have already promoted and
+                # unified global_amax across the sibling group; only promote
+                # standalone (non-grouped) NVFP4-static quantizers here.
+                if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer):
                     global_amax = reduce_amax(initial_amax, axis=None)

                     # Convert to NVFP4StaticQuantizer in-place
                     NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

                 if fp8_scale_sweep:

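As background for step 2, a self-contained toy of the kind of search an MSE calibrator performs: shrink the max-calibrated amax through a candidate grid and keep the value minimizing reconstruction error. The INT8 fake-quant stand-in and the candidate grid here are illustrative assumptions, not MseCalibrator's actual internals.

import torch

def fake_quant_int8(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    scale = amax / 127.0
    return torch.clamp(torch.round(x / scale), -128, 127) * scale

@torch.no_grad()
def mse_search(weight: torch.Tensor, initial_amax: torch.Tensor, steps: int = 20) -> torch.Tensor:
    best_amax, best_err = initial_amax, float("inf")
    for frac in torch.linspace(0.5, 1.0, steps):  # candidate shrink factors
        candidate = initial_amax * frac
        err = (weight - fake_quant_int8(weight, candidate)).pow(2).mean().item()
        if err < best_err:
            best_amax, best_err = candidate, err
    return best_amax

w = torch.randn(1024, 1024)
amax0 = w.abs().max()            # what max calibration produces
amax_mse = mse_search(w, amax0)  # typically < amax0: clipping outliers lowers MSE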
@@ -412,52 +496,50 @@ def mse_calibrate(
                     quant_func=partial(_mse_quant_func, quantizer=module),
                 )

-    # Identify weight quantizers by checking if they have corresponding weight parameters
+    # Step 3: calibrate weight quantizers via iter_weights_for_calibration.
+    # The fused-experts override yields one pair per expert per projection, so
+    # every per-expert quantizer is MSE-calibrated (not just routed ones).
     name_to_module = dict(model.named_modules())
+    seen_modules: set[int] = set()
+    pbar = tqdm(desc="MSE weight calibration")
+    n_calibrated = 0
     for parent_module in name_to_module.values():
-        if parent_module in seen_modules:
+        if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule):
             continue
-        for weight_name in weight_attr_names(parent_module):
-            weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
-            weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
-            if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
-                if getattr(weight_quantizer, "_calibrator", None) is not None:
-                    weight_quantizers.append((parent_module, weight_name, weight_quantizer))
-        seen_modules.add(parent_module)
-
-    # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
-    # This prevents massive memory accumulation seen in large models
-    for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
-        tqdm(weight_quantizers, desc="MSE weight calibration")
-    ):
-        # Enable calibration mode for the weight quantizer
-        weight_quantizer.disable_quant()
-        weight_quantizer.enable_calib()
+        seen_modules.add(id(parent_module))
         with enable_weight_access_and_writeback(parent_module, model, name_to_module):
-            weight = getattr(parent_module, weight_name)
-            weight_quantizer(weight)
-
-        # IMMEDIATELY compute amax and reset calibrator to free memory
-        cal = getattr(weight_quantizer, "_calibrator", None)
-        if cal is not None and cal.compute_amax() is not None:
-            weight_quantizer.load_calib_amax()
-
-        weight_quantizer.enable_quant()
-        weight_quantizer.disable_calib()
-
-        # Synchronize ALL CUDA devices before resetting to ensure all async operations complete
-        # This is critical for multi-GPU setups where tensors may be on different devices
-        if torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
-
-        if cal is not None and hasattr(cal, "reset"):
-            cal.reset()
-
-        if (idx + 1) % 10 == 0 and torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
-            torch.cuda.empty_cache()
+            for weight, weight_quantizer in parent_module.iter_weights_for_calibration():
+                if not (
+                    isinstance(weight_quantizer, TensorQuantizer)
+                    and weight_quantizer.is_enabled
+                    and getattr(weight_quantizer, "_calibrator", None) is not None
+                ):
+                    continue
+                weight_quantizer.disable_quant()
+                weight_quantizer.enable_calib()
+                weight_quantizer(weight)
+
+                # IMMEDIATELY compute amax and reset calibrator to free memory
+                cal = weight_quantizer._calibrator
+                if cal.compute_amax() is not None:
+                    weight_quantizer.load_calib_amax()
+
+                weight_quantizer.enable_quant()
+                weight_quantizer.disable_calib()
+
+                # Synchronize ALL CUDA devices before resetting to ensure all async operations complete
+                # This is critical for multi-GPU setups where tensors may be on different devices
+                if torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+
+                if hasattr(cal, "reset"):
+                    cal.reset()
+
+                pbar.update(1)
+                n_calibrated += 1
+                if n_calibrated % 10 == 0 and torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                    torch.cuda.empty_cache()
+    pbar.close()

     if torch.cuda.is_available():
         for dev_id in range(torch.cuda.device_count()):
@@ -612,6 +694,8 @@ def forward(self, input, *args, **kwargs):
     print_rank_0("local_hessian: Running max calibration for all quantizers...")
     max_calibrate(model, forward_loop, distributed_sync)

+    sync_grouped_weight_global_amax(model)
+
     # Setup helpers for all quantized linear modules
     name_to_module = dict(model.named_modules())
     weight_quantizers_info = []

@@ -666,14 +750,9 @@ def quant_func(x, amax, quantizer=weight_quantizer):

                 return xq

-            is_nvfp4_static = (
-                weight_quantizer.is_static_block_quant
-                and weight_quantizer._num_bits == (2, 1)
-                and weight_quantizer._block_sizes is not None
-                and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
-            )
+            is_nvfp4_static = weight_quantizer.is_nvfp4_static

-            if is_nvfp4_static:
+            if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer):
                 global_amax = reduce_amax(initial_amax, axis=None)
                 NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)
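Both deletions above swap the same inline check for the is_nvfp4_static property. Reconstructed from the deleted expression (so this mirrors the diff, though the property's actual definition is not shown here):

@property
def is_nvfp4_static(self) -> bool:
    # FP4 (E2M1) elements with static per-block FP8 (E4M3) scales.
    return (
        self.is_static_block_quant
        and self._num_bits == (2, 1)
        and self._block_sizes is not None
        and self._block_sizes.get("scale_bits") == (4, 3)
    )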