diff --git a/modelopt/torch/quantization/nn/modules/quant_batchnorm.py b/modelopt/torch/quantization/nn/modules/quant_batchnorm.py index 21eed5b82e0..2f07547566b 100644 --- a/modelopt/torch/quantization/nn/modules/quant_batchnorm.py +++ b/modelopt/torch/quantization/nn/modules/quant_batchnorm.py @@ -22,3 +22,4 @@ QuantModuleRegistry.register({nn.BatchNorm1d: "nn.BatchNorm1d"})(QuantInputBase) QuantModuleRegistry.register({nn.BatchNorm2d: "nn.BatchNorm2d"})(QuantInputBase) QuantModuleRegistry.register({nn.BatchNorm3d: "nn.BatchNorm3d"})(QuantInputBase) +QuantModuleRegistry.register({nn.SyncBatchNorm: "nn.SyncBatchNorm"})(QuantInputBase) diff --git a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml index 1508f942776..a527e720b75 100644 --- a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml +++ b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml @@ -45,6 +45,9 @@ - parent_class: 'nn.BatchNorm3d' quantizer_name: '*' enable: false + - parent_class: 'nn.SyncBatchNorm' + quantizer_name: '*' + enable: false - parent_class: 'nn.LeakyReLU' quantizer_name: '*' enable: false diff --git a/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml b/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml index c00aff7d44f..ea3af913cd9 100644 --- a/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml +++ b/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml @@ -84,6 +84,9 @@ quantize: - parent_class: 'nn.BatchNorm3d' quantizer_name: '*' enable: false + - parent_class: 'nn.SyncBatchNorm' + quantizer_name: '*' + enable: false - parent_class: 'nn.LeakyReLU' quantizer_name: '*' enable: false diff --git a/tests/unit/torch/quantization/test_quant_batchnorm.py b/tests/unit/torch/quantization/test_quant_batchnorm.py index c55b4b0b0e4..ef9bdd87222 100644 --- a/tests/unit/torch/quantization/test_quant_batchnorm.py +++ b/tests/unit/torch/quantization/test_quant_batchnorm.py @@ -34,6 +34,7 @@ class TestQuantBatchNormND: (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)), (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)), (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)), + (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)), ], ) def test_no_quant(self, original_cls, input_shape): @@ -60,6 +61,7 @@ def test_no_quant(self, original_cls, input_shape): (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)), (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)), (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)), + (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)), ], ) def test_fake_quant_per_tensor(self, original_cls, input_shape): @@ -86,6 +88,7 @@ def test_fake_quant_per_tensor(self, original_cls, input_shape): (nn.BatchNorm1d, (2, NUM_CHANNELS, 8)), (nn.BatchNorm2d, (2, NUM_CHANNELS, 8, 8)), (nn.BatchNorm3d, (2, NUM_CHANNELS, 8, 8, 8)), + (nn.SyncBatchNorm, (2, NUM_CHANNELS, 8, 8)), ], ) def test_fake_quant_per_channel(self, original_cls, input_shape):