9 changes: 7 additions & 2 deletions modelopt/torch/quantization/model_calib.py
@@ -95,13 +95,18 @@ def _check_moe_calibration_complete(quantizer, parallel_state):


@torch.no_grad()
def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
def max_calibrate(
model: nn.Module,
forward_loop: ForwardLoop | None = None,
distributed_sync=True,
):
"""Calibrate the model using max.

Args:
model: Model to be calibrated.
forward_loop: A callable which takes the model as argument and
forwards calibration data through the model.
distributed_sync: Whether to sync input_quantizer amax across distributed processes.

See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
details on the remaining arguments.
@@ -113,7 +118,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
forward_loop(model)
finish_stats_collection(model)

# Sync amax across local experts within each rank (for SequentialMLP)
# Sync input_quantizer amax across local experts within each rank (for SequentialMLP)
for name, module in model.named_modules():
if hasattr(module, "layer_sync_moe_local_experts_amax"):
module.layer_sync_moe_local_experts_amax()
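
For orientation, here is a minimal sketch of how a calibration `forward_loop` is typically supplied. In practice `max_calibrate` is usually driven through `mtq.quantize` with a max-calibration config such as `FP8_DEFAULT_CFG` (as the new test below does); the toy model, calibration batches, and device handling are illustrative stand-ins and not part of this change.

```python
import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy stand-ins for a real model and calibration dataset (hypothetical).
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)).to(device)
calib_batches = [torch.randn(4, 16, device=device) for _ in range(8)]


def forward_loop(m: nn.Module) -> None:
    # Forward calibration data through the model so every enabled
    # quantizer records its max statistics (amax).
    for batch in calib_batches:
        m(batch)


# FP8_DEFAULT_CFG uses max calibration, so quantize() drives max_calibrate
# under the hood; distributed_sync defaults to True, which also syncs
# input_quantizer amax across distributed processes after local collection.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```
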
16 changes: 10 additions & 6 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -575,26 +575,30 @@ def _setup(self):
expert.linear_fc2.parallel_state = self.parallel_state

def layer_sync_moe_local_experts_amax(self):
"""Sync amax across local experts in a SequentialMLP.
"""Sync input quantizer amax across local experts in a SequentialMLP.

Synchronize the amax values across local experts in a lyaer such that all local experts will
share the same amax. This function operates on a single rank and does not require distributed sync.
Ensures all experts have the same input quantizer amax. This function operates
on a single rank and does not require distributed sync.

Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().
This function should be called before the distributed sync to ensure the amax values
are synchronized across the layer first.

Note:
Because there are logic which calls collective communication based on whether amax is not None,
We need to garuantee that all experts must have amax. Otherwise, there will be deadlock
when synchroizing over EP since some ranks may have amax None and not calling the collective
We need to guarantee that all experts must have amax. Otherwise, there will be deadlock
when synchronizing over EP since some ranks may have amax None and not calling the collective
communication.
"""
# Collect amax from all local experts
amax_dict = {}
for expert in self.local_experts:
for name, module in expert.named_modules():
if isinstance(module, TensorQuantizer) and module.amax is not None:
if (
isinstance(module, TensorQuantizer)
and module.amax is not None
and "input_quantizer" in name
):
stored_amax = amax_dict.get(name)
amax_tensor = module.amax.detach().clone()
amax_dict[name] = (
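
To make the collect-then-broadcast pattern of `layer_sync_moe_local_experts_amax` concrete, here is a simplified, self-contained sketch. `ToyQuantizer` and the `nn.ModuleDict` experts are hypothetical stand-ins for `TensorQuantizer` and the real expert modules; only the shape of the logic mirrors the method above.

```python
import torch
import torch.nn as nn


class ToyQuantizer(nn.Module):
    """Stand-in for TensorQuantizer: only carries an optional amax tensor."""

    def __init__(self, amax=None):
        super().__init__()
        self.amax = amax


def sync_local_experts_input_amax(local_experts):
    # Pass 1: collect the element-wise max of every input quantizer's amax
    # across all local experts, keyed by the quantizer's module name.
    amax_dict = {}
    for expert in local_experts:
        for name, module in expert.named_modules():
            if (
                isinstance(module, ToyQuantizer)
                and module.amax is not None
                and "input_quantizer" in name
            ):
                amax = module.amax.detach().clone()
                stored = amax_dict.get(name)
                amax_dict[name] = amax if stored is None else torch.maximum(stored, amax)

    # Pass 2: write the shared amax back to every expert. This also guarantees
    # no expert is left with amax=None, so every rank later participates in the
    # EP collective instead of deadlocking the ranks that do call it.
    for expert in local_experts:
        for name, module in expert.named_modules():
            if isinstance(module, ToyQuantizer) and name in amax_dict:
                module.amax = amax_dict[name].clone()


# Three experts with different (or missing) input amax all end up sharing
# the per-layer maximum after the sync.
experts = [
    nn.ModuleDict({"input_quantizer": ToyQuantizer(torch.tensor(v))})
    for v in (1.0, 3.0, 2.0)
]
experts.append(nn.ModuleDict({"input_quantizer": ToyQuantizer(None)}))
sync_local_experts_input_amax(experts)
assert all(e["input_quantizer"].amax.item() == 3.0 for e in experts)
```

Distributed reconciliation of amax across EP/ETP ranks remains a separate step handled in model_calib.max_calibrate; the per-layer pass sketched here only removes per-rank divergence between local experts and the amax-is-None case that would otherwise skip the collective.
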
83 changes: 76 additions & 7 deletions tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
@@ -473,10 +473,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device,

@pytest.mark.parametrize(
"config",
[
NVFP4_GEMM_KV_CFG,
FP8_GEMM_KV_CFG,
],
[NVFP4_GEMM_KV_CFG, FP8_GEMM_KV_CFG, mtq.MAMBA_MOE_NVFP4_CONSERVATIVE_CFG],
)
def test_homogeneous_sharded_state_dict_hybrid(tmp_path, config):
"""Test sharded state dict for hybrid Mamba MOE models."""
@@ -735,6 +732,81 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
)


@pytest.mark.parametrize("ep_size", [1, 2])
@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
"""Test expert model parallel synchronization."""
size = torch.cuda.device_count()
if size < ep_size:
pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test")

spawn_multiprocess_job(
size=size,
job=partial(
_test_layer_sync_moe_local_experts_amax,
ep_size,
moe_grouped_gemm,
),
backend="nccl",
)


def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size):
initialize_for_megatron(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
expert_model_parallel_size=ep_size,
expert_tensor_parallel_size=1,
seed=SEED,
)
model = _gpt_model_provider(
tp_size=1,
ep_size=ep_size,
etp_size=1,
hidden_size=256,
moe_grouped_gemm=moe_grouped_gemm,
use_te=moe_grouped_gemm,
num_moe_experts=8,
transformer_impl="modelopt",
)
quant_cfg = mtq.FP8_DEFAULT_CFG
model = mtq.quantize(model, quant_cfg, get_forward(model))

for layer in model.decoder.layers:
layer.mlp.experts.layer_sync_moe_local_experts_amax()

for layer in model.decoder.layers:
# Check input quantizer amax is synced across local experts
fc1_amax = None
fc2_amax = None
for expert in layer.mlp.experts.local_experts:
assert expert.linear_fc1.input_quantizer.amax is not None
assert expert.linear_fc2.input_quantizer.amax is not None
if fc1_amax is None:
fc1_amax = expert.linear_fc1.input_quantizer.amax
else:
assert torch.allclose(fc1_amax, expert.linear_fc1.input_quantizer.amax)
if fc2_amax is None:
fc2_amax = expert.linear_fc2.input_quantizer.amax
else:
assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)

# Check weight quantizer amax is different across local experts
fc1_amax = None
fc2_amax = None
for expert in layer.mlp.experts.local_experts:
assert expert.linear_fc1.weight_quantizer.amax is not None
assert expert.linear_fc2.weight_quantizer.amax is not None
if fc1_amax is None:
fc1_amax = expert.linear_fc1.weight_quantizer.amax
else:
assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
if fc2_amax is None:
fc2_amax = expert.linear_fc2.weight_quantizer.amax
else:
assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)


def _test_expert_model_parallel_amax_sync(
tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
):
@@ -815,9 +887,6 @@ def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
if size < ep_size * etp_size:
pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")

if moe_grouped_gemm:
pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")

spawn_multiprocess_job(
size=size,
job=partial(