diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 415bcfea..90d7c867 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -376,7 +376,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline return model max_logging.log("Quantizing transformer with Qwix.") - batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32) + batch_size = config.global_batch_size_to_train_on latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size) model_inputs = (latents, timesteps, prompt_embeds) with mesh: diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 4262d0cf..31c0846f 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -396,6 +396,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize mock_config.weight_quantization_calibration_method = "fixed,-224,224" mock_config.act_quantization_calibration_method = "fixed,-224,224" mock_config.bwd_quantization_calibration_method = "absmax" + mock_config.global_batch_size_to_train_on = 32 mock_model = Mock(spec=WanModel) mock_pipeline = Mock()