
Register SyncBatchNorm as quantization module#1491

Open
5had3z wants to merge 1 commit into NVIDIA:main from 5had3z:fix/register-syncbn

Conversation


@5had3z 5had3z commented May 14, 2026

What does this PR do?

Type of change: Bug fix

Registers the nn.SyncBatchNorm layer for quantization. If a model is set up for distributed training before PTQ, none of its SyncBatchNorm layers are recognised, so they are never quantized. When a checkpoint is later loaded, there is a mismatch between the modelopt state of a model that hasn't had DDP/SyncBN applied to it and the checkpoint that was trained with DDP/SyncBN.

Performing PTQ first and then applying DDP/SyncBN for QAT works fine, but since unwrapping DDP is handled correctly for either ordering of the steps, SyncBN conversion should also work in either order.
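For context on why the layers are skipped: convert_sync_batchnorm swaps the module class entirely, so a quantization registry keyed on the BatchNorm*d classes no longer matches the converted layers. A minimal, self-contained illustration (plain PyTorch, no modelopt needed; the toy model is for demonstration only):

from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
# convert_sync_batchnorm returns a converted copy in which every BatchNorm*d
# is replaced by an nn.SyncBatchNorm instance.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(model[1]))  # <class 'torch.nn.modules.batchnorm.SyncBatchNorm'>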

Usage

## train.py
model = get_model()
# DDP Setup
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # returns a converted copy; assign it
model = nn.parallel.DistributedDataParallel(
    model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
)
# PTQ
mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib)
mtq.print_quant_summary(model) # Missing SyncBN layers
# QAT
train(model)
# Save Checkpoint
torch.save(mto.modelopt_state(model), "modelopt.pt")
torch.save(model.module.state_dict(), "params.pt")

## inference.py
model = get_model()
# Below fails: the nn.BatchNorm2d modules in this fresh model have no matching
# state in the checkpoint, since the nn.SyncBatchNorm modules were skipped over.
model = mto.restore_from_modelopt_state(model, torch.load("modelopt.pt", weights_only=False))
model.load_state_dict(torch.load("params.pt"))

Testing

Added nn.SyncBatchNorm to the quantization tests where other BatchNorm layers appear; a sketch of the change follows.
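A sketch only, assuming the existing tests parametrize over the BatchNorm classes (the test names come from the review walkthrough below; the body here is illustrative, not the actual test code):

import pytest
from torch import nn

# Hypothetical parametrization: nn.SyncBatchNorm joins the existing BatchNorm*d
# cases in test_no_quant, test_fake_quant_per_tensor and test_fake_quant_per_channel.
@pytest.mark.parametrize(
    "norm_cls", [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]
)
def test_no_quant(norm_cls):
    layer = norm_cls(8)  # the real assertions live in test_quant_batchnorm.py
    assert isinstance(layer, nn.Module)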

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ❌ - Models trained while SyncBatchNorm was missing from the modelopt state dict will now include it, depending on initialization order.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ❌ - lmk if you want this
  • Did you get Claude approval on this PR?: ❌

Additional Information

Code for reproducing the issue; run with python3 script.py or torchrun --nproc-per-node=2 script.py.

from pathlib import Path
import torch
import os
from torch import nn
import torch.distributed as dist
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.datasets.cifar import CIFAR10
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms.v2 import ToTensor
import modelopt.torch.quantization as mtq
import modelopt.torch.opt as mto
from rich.progress import track

assert torch.cuda.is_available(), "NVIDIA GPUs required for distributed training"

torch.cuda.set_device(f"cuda:{os.environ.get('LOCAL_RANK', 0)}")
if dist.is_available() and int(os.environ.get("WORLD_SIZE", 1)) > 1:
    dist.init_process_group(backend="nccl")

model = resnet18(weights=ResNet18_Weights.DEFAULT).cuda()

if dist.is_initialized():  # SyncBN and DDP for training
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # use the returned copy
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
    )


def calib(m: nn.Module):
    datapath = Path.cwd() / "data"
    datapath.mkdir(exist_ok=True)
    dataset = CIFAR10(datapath, train=False, download=True, transform=ToTensor())
    if dist.is_initialized():
        sampler = DistributedSampler(dataset)
    else:
        sampler = None
    dataloader = DataLoader(dataset, sampler=sampler, num_workers=4, batch_size=8)
    for img, tgt in track(dataloader):
        m(img.cuda())


# PTQ
mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib)
if not dist.is_initialized() or dist.get_rank() == 0:
    mtq.print_quant_summary(model)

# Do training

mto_ckpt = Path.cwd() / "opt.pt"
torch.save(mto.modelopt_state(model), mto_ckpt)
param_ckpt = Path.cwd() / "params.pt"
if isinstance(model, nn.parallel.DistributedDataParallel):
    params = model.module.state_dict()
else:
    params = model.state_dict()
torch.save(params, param_ckpt)

# Load 'single' GPU for inference
model = resnet18(weights=ResNet18_Weights.DEFAULT).cuda()
model = mto.restore_from_modelopt_state(
    model, torch.load(mto_ckpt, map_location="cuda", weights_only=False)
)
model.load_state_dict(torch.load(param_ckpt))
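One way to get a clearer error than the one load_state_dict raises is to inspect the key mismatch directly; a hypothetical diagnostic, meant to run just before the final load_state_dict call, assuming the checkpoint is the flat state dict saved above:

# Hypothetical diagnostic: report checkpoint keys with no counterpart in the
# restored model, e.g. entries for the SyncBatchNorm layers that were skipped.
ckpt = torch.load(param_ckpt)
missing = sorted(set(ckpt) - set(model.state_dict()))
if missing:
    print(f"{len(missing)} checkpoint keys absent from model, e.g. {missing[:5]}")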

Summary by CodeRabbit

  • New Features
    • Added support for synchronized batch normalization layers in quantization workflows.
  • Tests
    • Extended quantization test coverage for synchronized batch normalization module types.
  • Chores
    • Updated quantization configurations to handle synchronized batch normalization layers consistently with other batch norm types.


Signed-off-by: Bryce Ferenczi <bryce.ferenczi@Arkeus.com>
@5had3z 5had3z requested review from a team as code owners May 14, 2026 04:48
@5had3z 5had3z requested a review from meenchen May 14, 2026 04:48

copy-pr-bot Bot commented May 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai Bot commented May 14, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 208dcebb-65ec-4555-bebc-b8880ec9d5bb

📥 Commits

Reviewing files that changed from the base of the PR and between 229ba61 and f702d84.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/nn/modules/quant_batchnorm.py
  • modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml
  • modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml
  • tests/unit/torch/quantization/test_quant_batchnorm.py

📝 Walkthrough

This PR adds nn.SyncBatchNorm support to the quantization framework by registering the module type, disabling its quantization by default, applying that configuration to a specific model, and extending test coverage to verify the behavior.

Changes

nn.SyncBatchNorm Quantization Support

  • Core module registration and default disabling
    Files: modelopt/torch/quantization/nn/modules/quant_batchnorm.py, modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml, tests/unit/torch/quantization/test_quant_batchnorm.py
    Summary: nn.SyncBatchNorm is registered in QuantModuleRegistry, added to the default disabled quantizers configuration, and included in three parametrized test cases (test_no_quant, test_fake_quant_per_tensor, test_fake_quant_per_channel).
  • Model-specific quantization configuration
    Files: modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml
    Summary: The disable rule for nn.SyncBatchNorm is applied to the Step3.5-Flash model configuration alongside the existing disabled rules for other BatchNorm variants.
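Judging from the walkthrough, the core of the fix is presumably a one-line registration in quant_batchnorm.py. A sketch under the assumption that nn.SyncBatchNorm follows the exact pattern of the existing BatchNorm*d entries; the import paths, decorator signature, and the QuantInputBase target are inferred, not copied from the diff:

from torch import nn
from modelopt.torch.quantization.nn import QuantInputBase, QuantModuleRegistry  # import paths assumed

# Hypothetical mirror of the existing BatchNorm*d registrations: map the module
# class to a registry key so mtq.quantize recognises the converted layers.
QuantModuleRegistry.register({nn.SyncBatchNorm: "nn.SyncBatchNorm"})(QuantInputBase)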

Estimated code review effort: 🎯 1 (Trivial) | ⏱️ ~3 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (5 passed)

  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title clearly and accurately describes the main change: registering SyncBatchNorm for quantization support, which is the core objective across all modified files.
  • Linked Issues Check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns ✅ Passed: No security anti-patterns detected. Changes are minimal registration and configuration updates with no unsafe deserialization, hardcoded credentials, dangerous APIs, or dependency issues.




