forked from NVIDIA/Model-Optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquantize.py
More file actions
952 lines (801 loc) · 32.3 KB
/
quantize.py
File metadata and controls
952 lines (801 loc) · 32.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import sys
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
import torch
import torch.nn as nn
from config import (
FP8_DEFAULT_CONFIG,
INT8_DEFAULT_CONFIG,
NVFP4_DEFAULT_CONFIG,
NVFP4_FP8_MHA_CONFIG,
reset_set_int8_config,
set_quant_config_attr,
)
from diffusers import (
DiffusionPipeline,
FluxPipeline,
LTXConditionPipeline,
LTXLatentUpsamplePipeline,
StableDiffusion3Pipeline,
)
from onnx_utils.export import generate_fp8_scales, modelopt_export_sd
from tqdm import tqdm
from utils import (
check_conv_and_mha,
check_lora,
filter_func_default,
filter_func_ltx_video,
load_calib_prompts,
)
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
class ModelType(str, Enum):
"""Supported model types."""
SDXL_BASE = "sdxl-1.0"
SDXL_TURBO = "sdxl-turbo"
SD3_MEDIUM = "sd3-medium"
FLUX_DEV = "flux-dev"
FLUX_SCHNELL = "flux-schnell"
LTX_VIDEO_DEV = "ltx-video-dev"
class DataType(str, Enum):
"""Supported data types for model loading."""
HALF = "Half"
BFLOAT16 = "BFloat16"
FLOAT = "Float"
class QuantFormat(str, Enum):
"""Supported quantization formats."""
INT8 = "int8"
FP8 = "fp8"
FP4 = "fp4"
class QuantAlgo(str, Enum):
"""Supported quantization algorithms."""
MAX = "max"
SVDQUANT = "svdquant"
SMOOTHQUANT = "smoothquant"
class CollectMethod(str, Enum):
"""Calibration collection methods."""
GLOBAL_MIN = "global_min"
MIN_MAX = "min-max"
MIN_MEAN = "min-mean"
MEAN_MAX = "mean-max"
DEFAULT = "default"
def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
"""
Get the appropriate filter function for a given model type.
Args:
model_type: The model type enum
Returns:
A filter function appropriate for the model type
"""
filter_func_map = {
ModelType.FLUX_DEV: filter_func_default,
ModelType.FLUX_SCHNELL: filter_func_default,
ModelType.SDXL_BASE: filter_func_default,
ModelType.SDXL_TURBO: filter_func_default,
ModelType.SD3_MEDIUM: filter_func_default,
ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
}
return filter_func_map.get(model_type, filter_func_default)
# Model registry with HuggingFace model IDs
MODEL_REGISTRY: dict[ModelType, str] = {
ModelType.SDXL_BASE: "stabilityai/stable-diffusion-xl-base-1.0",
ModelType.SDXL_TURBO: "stabilityai/sdxl-turbo",
ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers",
ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
}
# Model-specific default arguments for calibration
MODEL_DEFAULTS: dict[ModelType, dict[str, Any]] = {
ModelType.FLUX_DEV: {
"height": 1024,
"width": 1024,
"guidance_scale": 3.5,
"max_sequence_length": 512,
},
ModelType.FLUX_SCHNELL: {
"height": 1024,
"width": 1024,
"guidance_scale": 3.5,
"max_sequence_length": 512,
},
ModelType.LTX_VIDEO_DEV: {
"height": 512,
"width": 704,
"num_frames": 121,
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
},
}
@dataclass
class QuantizationConfig:
"""Configuration for model quantization."""
format: QuantFormat = QuantFormat.INT8
algo: QuantAlgo = QuantAlgo.MAX
percentile: float = 1.0
collect_method: CollectMethod = CollectMethod.DEFAULT
alpha: float = 1.0 # SmoothQuant alpha
lowrank: int = 32 # SVDQuant lowrank
quantize_mha: bool = False
def validate(self) -> None:
"""Validate configuration consistency."""
if self.format == QuantFormat.FP8 and self.collect_method != CollectMethod.DEFAULT:
raise NotImplementedError("Only 'default' collect method is implemented for FP8.")
if self.quantize_mha and self.format == QuantFormat.INT8:
raise ValueError("MHA quantization is only supported for FP8, not INT8.")
@dataclass
class CalibrationConfig:
"""Configuration for calibration process."""
batch_size: int = 2
calib_size: int = 128
n_steps: int = 30
prompts_dataset: str = "Gustavosta/Stable-Diffusion-Prompts"
def validate(self) -> None:
"""Validate calibration configuration."""
if self.batch_size <= 0:
raise ValueError("Batch size must be positive.")
if self.calib_size <= 0:
raise ValueError("Calibration size must be positive.")
if self.n_steps <= 0:
raise ValueError("Number of steps must be positive.")
@property
def num_batches(self) -> int:
"""Calculate number of calibration batches."""
return self.calib_size // self.batch_size
@dataclass
class ModelConfig:
"""Configuration for model loading and inference."""
model_type: ModelType = ModelType.FLUX_DEV
model_dtype: DataType = DataType.HALF
trt_high_precision_dtype: DataType = DataType.HALF
override_model_path: Path | None = None
cpu_offloading: bool = False
ltx_skip_upsampler: bool = False # Skip upsampler for LTX-Video (faster calibration)
@property
def torch_dtype(self) -> torch.dtype:
"""Convert DataType enum to torch.dtype."""
dtype_map = {
DataType.HALF: torch.float16,
DataType.BFLOAT16: torch.bfloat16,
DataType.FLOAT: torch.float32,
}
return dtype_map[self.model_dtype]
@property
def model_path(self) -> str:
"""Get the model path (override or default)."""
if self.override_model_path:
return str(self.override_model_path)
return MODEL_REGISTRY[self.model_type]
@property
def uses_transformer(self) -> bool:
"""Check if model uses transformer backbone (vs UNet)."""
return self.model_type in [
ModelType.SD3_MEDIUM,
ModelType.FLUX_DEV,
ModelType.FLUX_SCHNELL,
ModelType.LTX_VIDEO_DEV,
]
@dataclass
class ExportConfig:
"""Configuration for model export."""
quantized_torch_ckpt_path: Path | None = None
onnx_dir: Path | None = None
restore_from: Path | None = None
def validate(self) -> None:
"""Validate export configuration."""
if self.restore_from and not self.restore_from.exists():
raise FileNotFoundError(f"Restore checkpoint not found: {self.restore_from}")
if self.quantized_torch_ckpt_path:
parent_dir = self.quantized_torch_ckpt_path.parent
if not parent_dir.exists():
parent_dir.mkdir(parents=True, exist_ok=True)
if self.onnx_dir and not self.onnx_dir.exists():
self.onnx_dir.mkdir(parents=True, exist_ok=True)
def setup_logging(verbose: bool = False) -> logging.Logger:
"""
Set up logging configuration.
Args:
verbose: Enable verbose logging
Returns:
Configured logger instance
"""
log_level = logging.DEBUG if verbose else logging.INFO
# Create custom formatter
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
# Set up console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
# Configure root logger
logger = logging.getLogger(__name__)
logger.setLevel(log_level)
logger.addHandler(console_handler)
# Optionally reduce noise from other libraries
logging.getLogger("diffusers").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.WARNING)
return logger
class PipelineManager:
"""Manages diffusion pipeline creation and configuration."""
def __init__(self, config: ModelConfig, logger: logging.Logger):
"""
Initialize pipeline manager.
Args:
config: Model configuration
logger: Logger instance
"""
self.config = config
self.logger = logger
self.pipe: DiffusionPipeline | None = None
self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling
@staticmethod
def create_pipeline_from(
model_type: ModelType,
torch_dtype: torch.dtype = torch.bfloat16,
override_model_path: str | None = None,
) -> DiffusionPipeline:
"""
Create and return an appropriate pipeline based on configuration.
Returns:
Configured diffusion pipeline
Raises:
ValueError: If model type is unsupported
"""
try:
model_id = (
MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
)
if model_type == ModelType.SD3_MEDIUM:
pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]:
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
else:
# SDXL models
pipe = DiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch_dtype,
use_safetensors=True,
)
pipe.set_progress_bar_config(disable=True)
return pipe
except Exception as e:
raise e
def create_pipeline(self) -> DiffusionPipeline:
"""
Create and return an appropriate pipeline based on configuration.
Returns:
Configured diffusion pipeline
Raises:
ValueError: If model type is unsupported
"""
self.logger.info(f"Creating pipeline for {self.config.model_type.value}")
self.logger.info(f"Model path: {self.config.model_path}")
self.logger.info(f"Data type: {self.config.model_dtype.value}")
try:
if self.config.model_type == ModelType.SD3_MEDIUM:
self.pipe = StableDiffusion3Pipeline.from_pretrained(
self.config.model_path, torch_dtype=self.config.torch_dtype
)
elif self.config.model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]:
self.pipe = FluxPipeline.from_pretrained(
self.config.model_path, torch_dtype=self.config.torch_dtype
)
elif self.config.model_type == ModelType.LTX_VIDEO_DEV:
self.pipe = LTXConditionPipeline.from_pretrained(
self.config.model_path, torch_dtype=self.config.torch_dtype
)
# Optionally load the upsampler pipeline for LTX-Video
if not self.config.ltx_skip_upsampler:
self.logger.info("Loading LTX-Video upsampler pipeline...")
self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
"Lightricks/ltxv-spatial-upscaler-0.9.7",
vae=self.pipe.vae,
torch_dtype=self.config.torch_dtype,
)
self.pipe_upsample.set_progress_bar_config(disable=True)
else:
self.logger.info("Skipping upsampler pipeline for faster calibration")
else:
# SDXL models
self.pipe = DiffusionPipeline.from_pretrained(
self.config.model_path,
torch_dtype=self.config.torch_dtype,
use_safetensors=True,
)
self.pipe.set_progress_bar_config(disable=True)
self.logger.info("Pipeline created successfully")
return self.pipe
except Exception as e:
self.logger.error(f"Failed to create pipeline: {e}")
raise
def setup_device(self) -> None:
"""Configure pipeline device placement."""
if not self.pipe:
raise RuntimeError("Pipeline not created. Call create_pipeline() first.")
if self.config.cpu_offloading:
self.logger.info("Enabling CPU offloading for memory efficiency")
self.pipe.enable_model_cpu_offload()
if self.pipe_upsample:
self.pipe_upsample.enable_model_cpu_offload()
else:
self.logger.info("Moving pipeline to CUDA")
self.pipe.to("cuda")
if self.pipe_upsample:
self.logger.info("Moving upsampler pipeline to CUDA")
self.pipe_upsample.to("cuda")
# Enable VAE tiling for LTX-Video to save memory
if self.config.model_type == ModelType.LTX_VIDEO_DEV:
if hasattr(self.pipe, "vae") and hasattr(self.pipe.vae, "enable_tiling"):
self.logger.info("Enabling VAE tiling for LTX-Video")
self.pipe.vae.enable_tiling()
def get_backbone(self) -> torch.nn.Module:
"""
Get the backbone model (transformer or UNet).
Returns:
Backbone model module
"""
if not self.pipe:
raise RuntimeError("Pipeline not created. Call create_pipeline() first.")
if self.config.uses_transformer:
return self.pipe.transformer
return self.pipe.unet
class Calibrator:
"""Handles model calibration for quantization."""
def __init__(
self,
pipeline_manager: PipelineManager,
config: CalibrationConfig,
model_type: ModelType,
logger: logging.Logger,
):
"""
Initialize calibrator.
Args:
pipeline_manager: Pipeline manager with main and upsampler pipelines
config: Calibration configuration
model_type: Type of model being calibrated
logger: Logger instance
"""
self.pipeline_manager = pipeline_manager
self.pipe = pipeline_manager.pipe
self.pipe_upsample = pipeline_manager.pipe_upsample
self.config = config
self.model_type = model_type
self.logger = logger
def load_prompts(self) -> list[str]:
"""
Load calibration prompts from file.
Returns:
List of calibration prompts
"""
self.logger.info(f"Loading calibration prompts from {self.config.prompts_dataset}")
return load_calib_prompts(self.config.batch_size, self.config.prompts_dataset)
def run_calibration(self, prompts: list[str]) -> None:
"""
Run calibration steps on the pipeline.
Args:
prompts: List of calibration prompts
"""
self.logger.info(f"Starting calibration with {self.config.num_batches} batches")
extra_args = MODEL_DEFAULTS.get(self.model_type, {})
with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar:
for i, prompt_batch in enumerate(prompts):
if i >= self.config.num_batches:
break
if self.model_type == ModelType.LTX_VIDEO_DEV:
# Special handling for LTX-Video
self._run_ltx_video_calibration(prompt_batch, extra_args) # type: ignore[arg-type]
else:
common_args = {
"prompt": prompt_batch,
"num_inference_steps": self.config.n_steps,
}
self.pipe(**common_args, **extra_args).images # type: ignore[misc]
pbar.update(1)
self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}")
self.logger.info("Calibration completed successfully")
def _run_ltx_video_calibration(
self, prompt_batch: list[str], extra_args: dict[str, Any]
) -> None:
"""
Run calibration for LTX-Video model using the full multi-stage pipeline.
Args:
prompt_batch: Batch of prompts
extra_args: Model-specific arguments
"""
# Extract specific args for LTX-Video
expected_height = extra_args.get("height", 512)
expected_width = extra_args.get("width", 704)
num_frames = extra_args.get("num_frames", 121)
negative_prompt = extra_args.get(
"negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted"
)
def round_to_nearest_resolution_acceptable_by_vae(height, width):
height = height - (height % self.pipe.vae_spatial_compression_ratio) # type: ignore[union-attr]
width = width - (width % self.pipe.vae_spatial_compression_ratio) # type: ignore[union-attr]
return height, width
downscale_factor = 2 / 3
# Part 1: Generate video at smaller resolution
downscaled_height, downscaled_width = (
int(expected_height * downscale_factor),
int(expected_width * downscale_factor),
)
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
downscaled_height, downscaled_width
)
# Generate initial latents at lower resolution
latents = self.pipe( # type: ignore[misc]
conditions=None,
prompt=prompt_batch,
negative_prompt=negative_prompt,
width=downscaled_width,
height=downscaled_height,
num_frames=num_frames,
num_inference_steps=self.config.n_steps,
output_type="latent",
).frames
# Part 2: Upscale generated video using latent upsampler (if available)
if self.pipe_upsample is not None:
_ = self.pipe_upsample(latents=latents, output_type="latent").frames
# Part 3: Denoise the upscaled video with few steps to improve texture
# However, in this example code, we will omit the upscale step since its optional.
class Quantizer:
"""Handles model quantization operations."""
def __init__(
self, config: QuantizationConfig, model_config: ModelConfig, logger: logging.Logger
):
"""
Initialize quantizer.
Args:
config: Quantization configuration
model_config: Model configuration
logger: Logger instance
"""
self.config = config
self.model_config = model_config
self.logger = logger
def get_quant_config(self, n_steps: int, backbone: nn.Module) -> Any:
"""
Build quantization configuration based on format.
Args:
n_steps: Number of denoising steps
Returns:
Quantization configuration object
"""
self.logger.info(f"Building quantization config for {self.config.format.value}")
if self.config.format == QuantFormat.INT8:
if self.config.algo == QuantAlgo.SMOOTHQUANT:
quant_config = mtq.INT8_SMOOTHQUANT_CFG
else:
quant_config = INT8_DEFAULT_CONFIG
if self.config.collect_method != CollectMethod.DEFAULT:
reset_set_int8_config(
quant_config,
self.config.percentile,
n_steps,
collect_method=self.config.collect_method.value,
backbone=backbone,
)
elif self.config.format == QuantFormat.FP8:
quant_config = FP8_DEFAULT_CONFIG
elif self.config.format == QuantFormat.FP4:
if self.model_config.model_type.value.startswith("flux"):
quant_config = NVFP4_FP8_MHA_CONFIG
else:
quant_config = NVFP4_DEFAULT_CONFIG
else:
raise NotImplementedError(f"Unknown format {self.config.format}")
set_quant_config_attr(
quant_config,
self.model_config.trt_high_precision_dtype.value,
self.config.algo.value,
alpha=self.config.alpha,
lowrank=self.config.lowrank,
)
return quant_config
def quantize_model(
self,
backbone: torch.nn.Module,
quant_config: Any,
forward_loop: callable, # type: ignore[valid-type]
) -> None:
"""
Apply quantization to the model.
Args:
backbone: Model backbone to quantize
quant_config: Quantization configuration
forward_loop: Forward pass function for calibration
"""
self.logger.info("Checking for LoRA layers...")
check_lora(backbone)
self.logger.info("Starting model quantization...")
mtq.quantize(backbone, quant_config, forward_loop)
# Get model-specific filter function
model_filter_func = get_model_filter_func(self.model_config.model_type)
self.logger.info(f"Using filter function for {self.model_config.model_type.value}")
self.logger.info("Disabling specific quantizers...")
mtq.disable_quantizer(backbone, model_filter_func)
self.logger.info("Quantization completed successfully")
class ExportManager:
"""Handles model export operations."""
def __init__(self, config: ExportConfig, logger: logging.Logger):
"""
Initialize export manager.
Args:
config: Export configuration
logger: Logger instance
"""
self.config = config
self.logger = logger
def _has_conv_layers(self, model: torch.nn.Module) -> bool:
"""
Check if the model contains any convolutional layers.
Args:
model: Model to check
Returns:
True if model contains Conv layers, False otherwise
"""
for module in model.modules():
if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)) and (
module.input_quantizer.is_enabled or module.weight_quantizer.is_enabled
):
return True
return False
def save_checkpoint(self, backbone: torch.nn.Module) -> None:
"""
Save quantized model checkpoint.
Args:
backbone: Model backbone to save
"""
if not self.config.quantized_torch_ckpt_path:
return
self.logger.info(f"Saving quantized checkpoint to {self.config.quantized_torch_ckpt_path}")
mto.save(backbone, str(self.config.quantized_torch_ckpt_path))
self.logger.info("Checkpoint saved successfully")
def export_onnx(
self,
pipe: DiffusionPipeline,
backbone: torch.nn.Module,
model_type: ModelType,
quant_format: QuantFormat,
quantize_mha: bool,
) -> None:
"""
Export model to ONNX format.
Args:
pipe: Diffusion pipeline
backbone: Model backbone
model_type: Type of model
quant_format: Quantization format
"""
if not self.config.onnx_dir:
return
self.logger.info(f"Starting ONNX export to {self.config.onnx_dir}")
check_conv_and_mha(backbone, quant_format == QuantFormat.FP4, quantize_mha)
if quant_format == QuantFormat.FP8 and self._has_conv_layers(backbone):
self.logger.info(
"Detected quantizing conv layers in backbone. Generating FP8 scales..."
)
generate_fp8_scales(backbone)
self.logger.info("Preparing models for export...")
pipe.to("cpu")
torch.cuda.empty_cache()
backbone.to("cuda")
# Export to ONNX
backbone.eval()
with torch.no_grad():
self.logger.info("Exporting to ONNX...")
modelopt_export_sd(
backbone, str(self.config.onnx_dir), model_type.value, quant_format.value
)
self.logger.info("ONNX export completed successfully")
def restore_checkpoint(self, backbone: torch.nn.Module) -> None:
"""
Restore a previously quantized model.
Args:
backbone: Model backbone to restore into
"""
if not self.config.restore_from:
return
self.logger.info(f"Restoring model from {self.config.restore_from}")
mto.restore(backbone, str(self.config.restore_from))
self.logger.info("Model restored successfully")
def create_argument_parser() -> argparse.ArgumentParser:
"""
Create and configure argument parser.
Returns:
Configured argument parser
"""
parser = argparse.ArgumentParser(
description="Enhanced Diffusion Model Quantization Tool",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Basic INT8 quantization with SmoothQuant
%(prog)s --model flux-dev --format int8 --quant-algo smoothquant --collect-method global_min
# FP8 quantization with ONNX export
%(prog)s --model sd3-medium --format fp8 --onnx-dir ./onnx_models/
# Quantize LTX-Video model with full multi-stage pipeline
%(prog)s --model ltx-video-dev --format fp8 --batch-size 1 --calib-size 32
# Faster LTX-Video quantization (skip upsampler)
%(prog)s --model ltx-video-dev --format fp8 --batch-size 1 --calib-size 32 --ltx-skip-upsampler
# Restore and export a previously quantized model
%(prog)s --model flux-schnell --restore-from checkpoint.pt --onnx-dir ./exports/
""",
)
model_group = parser.add_argument_group("Model Configuration")
model_group.add_argument(
"--model",
type=str,
default="flux-dev",
choices=[m.value for m in ModelType],
help="Model to load and quantize",
)
model_group.add_argument(
"--model-dtype",
type=str,
default="Half",
choices=[d.value for d in DataType],
help="Precision for loading the model",
)
model_group.add_argument(
"--override-model-path", type=str, help="Custom path to model (overrides default)"
)
model_group.add_argument(
"--cpu-offloading", action="store_true", help="Enable CPU offloading for limited VRAM"
)
model_group.add_argument(
"--ltx-skip-upsampler",
action="store_true",
help="Skip upsampler pipeline for LTX-Video (faster calibration, only quantizes main transformer)",
)
quant_group = parser.add_argument_group("Quantization Configuration")
quant_group.add_argument(
"--format",
type=str,
default="int8",
choices=[f.value for f in QuantFormat],
help="Quantization format",
)
quant_group.add_argument(
"--quant-algo",
type=str,
default="max",
choices=[a.value for a in QuantAlgo],
help="Quantization algorithm",
)
quant_group.add_argument(
"--percentile",
type=float,
default=1.0,
help="Percentile for calibration, works for INT8, not including smoothquant",
)
quant_group.add_argument(
"--collect-method",
type=str,
default="default",
choices=[c.value for c in CollectMethod],
help="Calibration collection method, works for INT8, not including smoothquant",
)
quant_group.add_argument("--alpha", type=float, default=1.0, help="SmoothQuant alpha parameter")
quant_group.add_argument("--lowrank", type=int, default=32, help="SVDQuant lowrank parameter")
quant_group.add_argument(
"--quantize-mha", action="store_true", help="Quantizing MHA into FP8 if its True"
)
calib_group = parser.add_argument_group("Calibration Configuration")
calib_group.add_argument("--batch-size", type=int, default=2, help="Batch size for calibration")
calib_group.add_argument(
"--calib-size", type=int, default=128, help="Total number of calibration samples"
)
calib_group.add_argument("--n-steps", type=int, default=30, help="Number of denoising steps")
export_group = parser.add_argument_group("Export Configuration")
export_group.add_argument(
"--quantized-torch-ckpt-save-path",
type=str,
help="Path to save quantized PyTorch checkpoint",
)
export_group.add_argument("--onnx-dir", type=str, help="Directory for ONNX export")
export_group.add_argument(
"--restore-from", type=str, help="Path to restore from previous checkpoint"
)
export_group.add_argument(
"--trt-high-precision-dtype",
type=str,
default="Half",
choices=[d.value for d in DataType],
help="Precision for TensorRT high-precision layers",
)
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
return parser
def main() -> None:
parser = create_argument_parser()
args = parser.parse_args()
logger = setup_logging(args.verbose)
logger.info("Starting Enhanced Diffusion Model Quantization")
try:
model_config = ModelConfig(
model_type=ModelType(args.model),
model_dtype=DataType(args.model_dtype),
trt_high_precision_dtype=DataType(args.trt_high_precision_dtype),
override_model_path=Path(args.override_model_path)
if args.override_model_path
else None,
cpu_offloading=args.cpu_offloading,
ltx_skip_upsampler=args.ltx_skip_upsampler,
)
quant_config = QuantizationConfig(
format=QuantFormat(args.format),
algo=QuantAlgo(args.quant_algo),
percentile=args.percentile,
collect_method=CollectMethod(args.collect_method),
alpha=args.alpha,
lowrank=args.lowrank,
quantize_mha=args.quantize_mha,
)
calib_config = CalibrationConfig(
batch_size=args.batch_size, calib_size=args.calib_size, n_steps=args.n_steps
)
export_config = ExportConfig(
quantized_torch_ckpt_path=Path(args.quantized_torch_ckpt_save_path)
if args.quantized_torch_ckpt_save_path
else None,
onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None,
restore_from=Path(args.restore_from) if args.restore_from else None,
)
logger.info("Validating configurations...")
quant_config.validate()
export_config.validate()
if not export_config.restore_from:
calib_config.validate()
pipeline_manager = PipelineManager(model_config, logger)
pipe = pipeline_manager.create_pipeline()
pipeline_manager.setup_device()
backbone = pipeline_manager.get_backbone()
export_manager = ExportManager(export_config, logger)
if export_config.restore_from:
export_manager.restore_checkpoint(backbone)
else:
logger.info("Initializing calibration...")
calibrator = Calibrator(pipeline_manager, calib_config, model_config.model_type, logger)
prompts = calibrator.load_prompts()
quantizer = Quantizer(quant_config, model_config, logger)
backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone)
def forward_loop(mod):
if model_config.uses_transformer:
pipe.transformer = mod
else:
pipe.unet = mod
calibrator.run_calibration(prompts)
quantizer.quantize_model(backbone, backbone_quant_config, forward_loop)
export_manager.save_checkpoint(backbone)
export_manager.export_onnx(
pipe,
backbone,
model_config.model_type,
quant_config.format,
quantize_mha=QuantizationConfig.quantize_mha,
)
logger.info("Quantization process completed successfully!")
except Exception as e:
logger.error(f"Quantization failed: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()