From f702d8463b9fd4e3b513dbf4d887406f798812c1 Mon Sep 17 00:00:00 2001
From: Bryce Ferenczi
Date: Thu, 14 May 2026 13:41:44 +1000
Subject: [PATCH] Register SyncBatchNorm module for quantization.

Signed-off-by: Bryce Ferenczi
---
 modelopt/torch/quantization/nn/modules/quant_batchnorm.py | 1 +
 .../configs/ptq/units/default_disabled_quantizers.yaml    | 3 +++
 modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml | 3 +++
 tests/unit/torch/quantization/test_quant_batchnorm.py     | 3 +++
 4 files changed, 10 insertions(+)

diff --git a/modelopt/torch/quantization/nn/modules/quant_batchnorm.py b/modelopt/torch/quantization/nn/modules/quant_batchnorm.py
index 21eed5b82e0..2f07547566b 100644
--- a/modelopt/torch/quantization/nn/modules/quant_batchnorm.py
+++ b/modelopt/torch/quantization/nn/modules/quant_batchnorm.py
@@ -22,3 +22,4 @@
 QuantModuleRegistry.register({nn.BatchNorm1d: "nn.BatchNorm1d"})(QuantInputBase)
 QuantModuleRegistry.register({nn.BatchNorm2d: "nn.BatchNorm2d"})(QuantInputBase)
 QuantModuleRegistry.register({nn.BatchNorm3d: "nn.BatchNorm3d"})(QuantInputBase)
+QuantModuleRegistry.register({nn.SyncBatchNorm: "nn.SyncBatchNorm"})(QuantInputBase)
diff --git a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml
index 1508f942776..a527e720b75 100644
--- a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml
+++ b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml
@@ -45,6 +45,9 @@
   - parent_class: 'nn.BatchNorm3d'
     quantizer_name: '*'
     enable: false
+  - parent_class: 'nn.SyncBatchNorm'
+    quantizer_name: '*'
+    enable: false
   - parent_class: 'nn.LeakyReLU'
     quantizer_name: '*'
     enable: false
diff --git a/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml b/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml
index c00aff7d44f..ea3af913cd9 100644
--- a/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml
+++ b/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml
@@ -84,6 +84,9 @@ quantize:
   - parent_class: 'nn.BatchNorm3d'
     quantizer_name: '*'
     enable: false
+  - parent_class: 'nn.SyncBatchNorm'
+    quantizer_name: '*'
+    enable: false
   - parent_class: 'nn.LeakyReLU'
     quantizer_name: '*'
     enable: false
diff --git a/tests/unit/torch/quantization/test_quant_batchnorm.py b/tests/unit/torch/quantization/test_quant_batchnorm.py
index c55b4b0b0e4..ef9bdd87222 100644
--- a/tests/unit/torch/quantization/test_quant_batchnorm.py
+++ b/tests/unit/torch/quantization/test_quant_batchnorm.py
@@ -34,6 +34,7 @@ class TestQuantBatchNormND:
             (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)),
             (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)),
             (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)),
+            (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)),
         ],
     )
     def test_no_quant(self, original_cls, input_shape):
@@ -60,6 +61,7 @@ def test_no_quant(self, original_cls, input_shape):
             (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)),
             (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)),
             (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)),
+            (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)),
         ],
     )
     def test_fake_quant_per_tensor(self, original_cls, input_shape):
@@ -86,6 +88,7 @@ def test_fake_quant_per_tensor(self, original_cls, input_shape):
             (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)),
             (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)),
             (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)),
+            (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)),
         ],
     )
     def test_fake_quant_per_channel(self, original_cls, input_shape):
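
A minimal usage sketch of what the new registration enables, assuming the
stock modelopt PTQ entry points (mtq.quantize with INT8_DEFAULT_CFG); the toy
model and calibration loop are illustrative and not part of this patch:

    import torch
    import torch.nn as nn

    import modelopt.torch.quantization as mtq

    # Toy model. Run single-process: SyncBatchNorm falls back to ordinary
    # batch-norm statistics when torch.distributed is not initialized.
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SyncBatchNorm(8))

    def forward_loop(m):
        # One calibration pass over random data (illustrative only).
        m(torch.randn(2, 3, 16, 16))

    model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

    # With this patch, nn.SyncBatchNorm is converted to its QuantInputBase
    # wrapper like the BatchNorm1d/2d/3d variants, so the converted module
    # is expected to expose an input quantizer.
    assert hasattr(model[1], "input_quantizer")

Since SyncBatchNorm reuses the same QuantInputBase wrapper as the BatchNormNd
modules, only its input is quantized (the affine weight/bias are untouched),
and the recipe YAMLs above disable its quantizers by default, consistent with
the existing BatchNorm1d/2d/3d entries.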