-
Notifications
You must be signed in to change notification settings - Fork 376
Expand file tree
/
Copy pathconfig.py
More file actions
1485 lines (1258 loc) · 51.6 KB
/
config.py
File metadata and controls
1485 lines (1258 loc) · 51.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This document lists the quantization formats supported by Model Optimizer and example quantization configs.
.. _quantization-formats:
Quantization Formats
==========================================
The following table lists the quantization formats supported by Model Optimizer and the corresponding quantization
config. See :ref:`Quantization Configs <example-quantization-configs>` for the
specific quantization config definitions.
Please see :doc:`choosing the right quantization formats <../../guides/_choosing_quant_methods>` to
learn more about the formats and their use-cases.
.. note::
The recommended configs given below are for LLM models. For CNN models, only INT8 quantization
is supported. Please use quantization config ``INT8_DEFAULT_CFG`` for CNN models.
================================= =======================================================
Quantization Format Model Optimizer config
================================= =======================================================
INT8 ``INT8_SMOOTHQUANT_CFG``
FP8 ``FP8_DEFAULT_CFG``
INT4 Weights only AWQ (W4A16) ``INT4_AWQ_CFG``
INT4-FP8 AWQ (W4A8) ``W4A8_AWQ_BETA_CFG``
================================= =======================================================
.. _quantization-configs:
Quantization Configs
================================
A quantization config is a dictionary specifying the values for keys ``"quant_cfg"`` and
``"algorithm"``. The ``"quant_cfg"`` key specifies the quantization configurations. The
``"algorithm"`` key specifies the ``algorithm`` argument to
:meth:`calibrate <modelopt.torch.quantization.model_calib.calibrate>`. Please see :class:`QuantizeConfig`
for the quantization config definition.
'Quantization configurations' is a dictionary mapping wildcards or filter functions
to their 'quantizer attributes'. The wildcards or filter functions are matched
against the quantizer module names. The quantizer modules have names ending with
``weight_quantizer`` and ``input_quantizer`` and they perform weight quantization and
input quantization (or activation quantization) respectively. The quantizer modules are generally
instances of
:class:`TensorQuantizer <modelopt.torch.quantization.nn.modules.tensor_quantizer.TensorQuantizer>`.
The quantizer attributes are defined by :class:`QuantizerAttributeConfig`. See :class:`QuantizerAttributeConfig`
for details on the quantizer attributes and their values.
The key ``"default"`` from the quantization configuration dictionary is applied if no other wildcard or filter function
matches the quantizer module name.
The quantizer attributes are applied in the order they are specified. For the missing attributes, the default attributes
as defined by :class:`QuantizerAttributeConfig` are used.
Quantizer attributes can also be a list of dictionaries. In this case, the matched quantizer module
is replaced with a
:class:`SequentialQuantizer <modelopt.torch.quantization.nn.modules.tensor_quantizer.SequentialQuantizer>`
module which is used to quantize a tensor in multiple formats sequentially. Each quantizer attribute
dictionary in the list specifies the quantization formats for each quantization step of the
sequential quantizer. For example, ``SequentialQuantizer`` is used in 'INT4 Weights, FP8 Activations'
quantization in which the weights are quantized in INT4 followed by FP8.
In addition, the dictionary keys can also be PyTorch module class names that map to class-specific
quantization configurations. These PyTorch modules should have a quantized equivalent.
To get the string representation of a module class, do:
.. code-block::
from modelopt.torch.quantization import QuantModuleRegistry
# Get the class name for nn.Conv2d
class_name = QuantModuleRegistry.get_key(nn.Conv2d)
Here is an example of a quantization config:
.. code-block::
MY_QUANT_CFG = {
"quant_cfg": {
# Quantizer wildcard strings mapping to quantizer attributes
"*weight_quantizer": {"num_bits": 8, "axis": 0},
"*input_quantizer": {"num_bits": 8, "axis": None},
# Module class names mapping to quantizer configurations
"nn.LeakyReLU": {"*input_quantizer": {"enable": False}},
}
}
.. _example-quantization-configs:
Example Quantization Configurations
==========================================
These example configs can be accessed as attributes of ``modelopt.torch.quantization`` and can be given as
input to :meth:`mtq.quantize() <modelopt.torch.quantization.model_quant.quantize>`. For example:
.. code-block::
import modelopt.torch.quantization as mtq
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
You can also create your own config by following these examples.
For instance, if you want to quantize a model with the INT4 AWQ algorithm but need to skip quantizing
the layer named ``lm_head``, you can create a custom config and quantize your model as follows:
.. code-block::
# Create custom config
CUSTOM_INT4_AWQ_CFG = copy.deepcopy(mtq.INT4_AWQ_CFG)
CUSTOM_INT4_AWQ_CFG["quant_cfg"]["*lm_head*"] = {"enable": False}
# quantize model
model = mtq.quantize(model, CUSTOM_INT4_AWQ_CFG, forward_loop)
"""
from collections.abc import Callable
from typing import Literal
from pydantic import ValidationInfo, field_validator, model_validator
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
from modelopt.torch.utils.network import ConstructorLike
# Shared "leave unquantized" entries that most preset configs below spread into their quant_cfg.
# Module class names (e.g. "nn.BatchNorm1d") map to a nested quant_cfg that disables every ("*")
# quantizer of that module class; the remaining keys are quantizer-name wildcards. The trailing
# "default" entry disables any quantizer that no other wildcard matches.
# Key order is kept exactly as listed, because quant_cfg entries are applied in order.
_default_disabled_quantizer_cfg = {
    **{
        module_cls: {"*": {"enable": False}}
        for module_cls in (
            "nn.BatchNorm1d",
            "nn.BatchNorm2d",
            "nn.BatchNorm3d",
            "nn.LeakyReLU",
        )
    },
    **{
        pattern: {"enable": False}
        for pattern in (
            "*lm_head*",
            "*proj_out.*",  # In Whisper model, lm_head has key name proj_out
            "*block_sparse_moe.gate*",  # Skip the MOE router
            "*router*",  # Skip the MOE router
            "*mlp.gate.*",  # Skip the MOE router
            "*mlp.shared_expert_gate.*",  # Skip the MOE router
            "*linear_attn.conv1d*",
            "*mixer.conv1d*",  # Skip mamba conv1d
            "*output_layer*",
            "output.*",
            "default",
        )
    },
}
# Additional layers left unquantized by the MAMBA_MOE_* presets.
_mamba_moe_disabled_quantizer_cfg = {
    pattern: {"enable": False}
    for pattern in (
        "*fc1_latent_proj*",  # Skip Latent MOE
        "*fc2_latent_proj*",  # Skip Latent MOE
        "*q_proj*",  # Skip QKV Linear
        "*k_proj*",  # Skip QKV Linear
        "*v_proj*",  # Skip QKV Linear
        "*o_proj*",  # Skip QKV Output Projection
    )
}
# INT8 presets. Weights: INT8 per-channel (axis=0, one scale per slice of dim 0).
# Activations: INT8 per-tensor (axis=None, scalar scale). Calibrated with "max".
# ``**_default_disabled_quantizer_cfg`` appends the shared "leave unquantized" entries,
# including the "default" fallback that disables every quantizer not matched by another key.
INT8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 8, "axis": 0},
        "*input_quantizer": {"num_bits": 8, "axis": None},
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# Same quantizers as INT8_DEFAULT_CFG, calibrated with the "smoothquant" algorithm.
INT8_SMOOTHQUANT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 8, "axis": 0},
        "*input_quantizer": {"num_bits": 8, "axis": None},
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "smoothquant",
}
# INT8 weight-only: per-channel INT8 weights; input (activation) quantizers are disabled.
INT8_WEIGHT_ONLY_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 8, "axis": 0},
        "*input_quantizer": {"enable": False},
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# FP8 presets. num_bits=(4, 3) is FP8 E4M3 under the (exponent, mantissa) convention of
# QuantizerAttributeConfig.num_bits.
# FP8_DEFAULT_CFG: per-tensor (axis=None) FP8 weights and activations, calibrated with "max".
FP8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# FP8_DEFAULT_CFG quantizers, but also leaving the latent-MoE and q/k/v/o projections
# unquantized via _mamba_moe_disabled_quantizer_cfg.
MAMBA_MOE_FP8_AGGRESSIVE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        **_default_disabled_quantizer_cfg,
        **_mamba_moe_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# As MAMBA_MOE_FP8_AGGRESSIVE_CFG, additionally skipping the mamba mixer in/out projections.
MAMBA_MOE_FP8_CONSERVATIVE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        **_default_disabled_quantizer_cfg,
        **_mamba_moe_disabled_quantizer_cfg,
        "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
        "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
    },
    "algorithm": "max",
}
# FP8 weights per-channel (axis=0). FP8 activations are dynamic (type="dynamic") with
# block_sizes={-1: None}: one scale per full slice along the last axis, i.e. per-token.
FP8_PER_CHANNEL_PER_TOKEN_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": 0},
        "*input_quantizer": {
            "num_bits": (4, 3),
            "type": "dynamic",
            "block_sizes": {-1: None},
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# FP8 2D blockwise fake quantization config for deepseek models
# Weight-only: FP8 weights in 128 x 128 blocks over the last two axes (block type not given,
# so the documented "static" default applies); activations are not quantized.
FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (4, 3),
            "block_sizes": {-1: 128, -2: 128},
            "enable": True,
        },
        "*input_quantizer": {"enable": False},
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# INT4 weight-only: 4-bit integer weights in blocks of 128 along the last axis (block type not
# given, so the documented "static" default applies); activations are not quantized.
INT4_BLOCKWISE_WEIGHT_ONLY_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
        "*input_quantizer": {"enable": False},
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# INT4 weight-only AWQ (W4A16): static 128-element INT4 weight blocks, calibrated with AWQ-lite
# (alpha_step=0.1). The commented-out lines are alternative AWQ algorithm settings.
INT4_AWQ_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": 4,
            "block_sizes": {-1: 128, "type": "static"},
            "enable": True,
        },
        "*input_quantizer": {"enable": False},
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {"method": "awq_lite", "alpha_step": 0.1},
    # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024},
    # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048},
}
# W4A8 currently uses INT4 blockwise quantization (block size = 128) followed by FP8 quantization
# for weights. This could change in the future
# The list value turns "*weight_quantizer" into a SequentialQuantizer: step 1 is INT4 with static
# 128-element blocks, step 2 is per-tensor FP8 (E4M3). Activations are per-tensor FP8.
W4A8_AWQ_BETA_CFG = {
    "quant_cfg": {
        "*weight_quantizer": [
            {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True},
            {"num_bits": (4, 3), "axis": None, "enable": True},
        ],
        "*input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "awq_lite",
}
# MX presets: values quantized dynamically in blocks of 32 along the last axis, with each
# per-block scale stored with scale_bits=(8, 0), i.e. E8M0 under the (E, M) convention.
# NOTE(review): these presets set "algorithm": None -- confirm what that means for calibration
# (W4A8_MXFP4_FP8_CFG, for instance, has a non-dynamic per-tensor input quantizer).
# MXFP8: E4M3 weights and activations.
MXFP8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (4, 3),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (4, 3),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": None,
}
# MXFP6: E3M2 weights and activations.
MXFP6_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (3, 2),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (3, 2),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": None,
}
# MXFP4: E2M1 weights and activations.
MXFP4_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": None,
}
# W4A8 MX: MXFP4 (E2M1) weights with per-tensor FP8 (E4M3) activations.
W4A8_MXFP4_FP8_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": None,
}
# MXINT8: 8-bit integer weights and activations with the same E8M0-scaled 32-element blocks.
MXINT8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": 8,
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": 8,
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": None,
}
# Per-tensor FP8 (E4M3) for quantizers matching "*[kv]_bmm_quantizer" (names ending in
# "k_bmm_quantizer" or "v_bmm_quantizer"). Every other quantizer falls back to "default" and is
# disabled; _default_disabled_quantizer_cfg is not used here.
# NOTE(review): presumably meant to be merged into a weight/activation preset -- confirm with callers.
FP8_KV_CFG = {
    "quant_cfg": {
        "*[kv]_bmm_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
            "enable": True,
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}
# As FP8_KV_CFG plus a static affine ``bias`` entry with integer keys -2 and -4; ``enable`` is
# omitted and so takes the attribute default (True).
# NOTE(review): the ``bias`` field docs describe "enable"/"type"/"method"/"axis" keys rather than
# integer keys -- confirm how -2/-4 are interpreted.
FP8_AFFINE_KV_CFG = {
    "quant_cfg": {
        "*[kv]_bmm_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
            "bias": {-2: None, -4: None, "type": "static"},
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}
# NVFP4 presets: E2M1 values (num_bits=(2, 1)) in blocks of 16 along the last axis, with each
# per-block scale quantized to E4M3 (scale_bits=(4, 3)).
# NVFP4_DEFAULT_CFG: dynamic NVFP4 weights and activations, calibrated with "max".
NVFP4_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# NVFP4 W4A4 with *static* (calibrated) weight blocks and dynamic activation blocks, calibrated
# with the "local_hessian" method with "fp8_scale_sweep" enabled.
NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {
        "method": "local_hessian",
        "fp8_scale_sweep": True,
    },
}
# NVFP4_DEFAULT_CFG quantizers, also leaving the latent-MoE and q/k/v/o projections unquantized
# via _mamba_moe_disabled_quantizer_cfg.
MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
        **_mamba_moe_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# As MAMBA_MOE_NVFP4_AGGRESSIVE_CFG, additionally skipping the mamba mixer in/out projections.
MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
        **_mamba_moe_disabled_quantizer_cfg,
        "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
        "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
    },
    "algorithm": "max",
}
# NVFP4_DEFAULT_CFG quantizers calibrated with AWQ-lite.
NVFP4_AWQ_LITE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "awq_lite",
}
# NVFP4_DEFAULT_CFG quantizers calibrated with AWQ-clip.
NVFP4_AWQ_CLIP_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {"method": "awq_clip"},
}
# NVFP4_DEFAULT_CFG quantizers calibrated with full AWQ (alpha_step=0.1).
NVFP4_AWQ_FULL_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {"method": "awq_full", "alpha_step": 0.1},
}
# NVFP4 (E2M1, dynamic 16-element blocks, E4M3 block scales) for quantizers matching
# "*[kv]_bmm_quantizer", plus a static affine ``bias`` with integer keys -2 and -4.
# NOTE(review): the ``bias`` field docs describe named keys rather than integer keys -- confirm
# how -2/-4 are interpreted. All other quantizers fall back to "default" and are disabled.
NVFP4_AFFINE_KV_CFG = {
    "quant_cfg": {
        "*[kv]_bmm_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
            "bias": {-2: None, -4: None, "type": "static"},
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}
# NVFP4 for "*[kv]_bmm_quantizer" without bias; every other quantizer is disabled via "default".
NVFP4_KV_CFG = {
    "quant_cfg": {
        "*[kv]_bmm_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}
# Moved from examples/diffusers/quantization/config.py to here
# NVFP4 weights/activations, plus per-tensor FP8 (E4M3) for the q/k/v bmm, softmax and
# "transformer_blocks*bmm2_output_quantizer" quantizers. "*output_quantizer" (which also matches
# the bmm2 output quantizer names) is disabled first and the bmm2 entry comes later; this
# presumably relies on entries being applied in the order listed -- keep the order.
# _default_disabled_quantizer_cfg is not used; unmatched quantizers are disabled via "default".
NVFP4_FP8_MHA_CONFIG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*output_quantizer": {"enable": False},
        "*q_bmm_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
        },
        "*k_bmm_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
        },
        "*v_bmm_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
        },
        "*softmax_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
        },
        "transformer_blocks*bmm2_output_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}
# q bmm quantizer disabled but with ``rotate`` set; k and v bmm quantizers use NVFP4, and k also
# sets ``rotate``. The semantics of ``rotate`` are defined outside this file.
# NOTE(review): unlike the other KV presets this dict has no "default" entry -- confirm intent.
NVFP4_KV_ROTATE_CFG = {
    "quant_cfg": {
        "*q_bmm_quantizer": {
            "enable": False,
            "rotate": True,
        },
        "*k_bmm_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
            "rotate": True,
        },
        "*v_bmm_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
    },
    "algorithm": "max",
}
# NVFP4 weights/activations (dynamic 16-element blocks, E4M3 scales), calibrated with SVDQuant
# using lowrank=32.
NVFP4_SVDQUANT_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {"method": "svdquant", "lowrank": 32},
}
# W4A8: E2M1 weights in dynamic blocks of 32 (not 16) with E4M3 block scales; per-tensor FP8
# (E4M3) activations.
W4A8_NVFP4_FP8_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# MXFP4 (E2M1, 32-element dynamic blocks, (8, 0) scales) applied only to quantizers matching
# "*mlp*weight_quantizer"; everything else is disabled through _default_disabled_quantizer_cfg
# and its "default" entry.
# NOTE(review): "pass_through_bwd" presumably lets gradients pass straight through in backward -- confirm.
MXFP4_MLP_WEIGHT_ONLY_CFG = {
    "quant_cfg": {
        "*mlp*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
            "pass_through_bwd": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": None,
}
# E2M1 weights with E4M3 block scales on MLP weight quantizers only, using 32-element blocks.
NVFP4_MLP_WEIGHT_ONLY_CFG = {
    "quant_cfg": {
        "*mlp*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {
                -1: 32,
                "type": "dynamic",
                "scale_bits": (4, 3),
            },  # Note: block_size is 32 here
            "enable": True,
            "pass_through_bwd": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# NVFP4 (16-element blocks) on MLP weight and MLP input quantizers only.
NVFP4_MLP_ONLY_CFG = {
    "quant_cfg": {
        "*mlp*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "enable": True,
            "pass_through_bwd": True,
        },
        "*mlp*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "enable": True,
            "pass_through_bwd": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
# Names of preset config constants defined in this module, grouped by format family.
# NOTE(review): MXFP6_DEFAULT_CFG and NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG are defined in this
# module but are not listed here -- confirm whether that omission is intentional.
choices: set[str] = {
    # Integer formats
    "INT8_DEFAULT_CFG",
    "INT8_SMOOTHQUANT_CFG",
    "INT8_WEIGHT_ONLY_CFG",
    "INT4_AWQ_CFG",
    "INT4_BLOCKWISE_WEIGHT_ONLY_CFG",
    # FP8
    "FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG",
    "FP8_AFFINE_KV_CFG",
    "FP8_DEFAULT_CFG",
    "FP8_KV_CFG",
    "FP8_PER_CHANNEL_PER_TOKEN_CFG",
    # MX formats
    "MXFP4_DEFAULT_CFG",
    "MXFP4_MLP_WEIGHT_ONLY_CFG",
    "MXFP8_DEFAULT_CFG",
    "MXINT8_DEFAULT_CFG",
    # NVFP4
    "NVFP4_AFFINE_KV_CFG",
    "NVFP4_AWQ_CLIP_CFG",
    "NVFP4_AWQ_FULL_CFG",
    "NVFP4_AWQ_LITE_CFG",
    "NVFP4_DEFAULT_CFG",
    "NVFP4_FP8_MHA_CONFIG",
    "NVFP4_KV_CFG",
    "NVFP4_KV_ROTATE_CFG",
    "NVFP4_MLP_ONLY_CFG",
    "NVFP4_MLP_WEIGHT_ONLY_CFG",
    "NVFP4_SVDQUANT_DEFAULT_CFG",
    # Mixed weight/activation formats
    "W4A8_AWQ_BETA_CFG",
    "W4A8_MXFP4_FP8_CFG",
    "W4A8_NVFP4_FP8_CFG",
    # Mamba-MoE variants
    "MAMBA_MOE_FP8_AGGRESSIVE_CFG",
    "MAMBA_MOE_FP8_CONSERVATIVE_CFG",
    "MAMBA_MOE_NVFP4_AGGRESSIVE_CFG",
    "MAMBA_MOE_NVFP4_CONSERVATIVE_CFG",
}
# Accepted values for the "type" entry of a quantizer ``bias`` config.
BiasType = Literal[
    "static",
    "dynamic",
]
# Accepted values for the "method" entry of a quantizer ``bias`` config.
BiasMethod = Literal[
    "mean",
    "max_min",
]
class QuantizerAttributeConfig(ModeloptBaseConfig):
"""Quantizer attribute type."""
enable: bool = ModeloptField(
default=True,
title="Enable quantizer.",
description="""If True, enables the quantizer. If False, by-pass the quantizer and returns the input tensor.""",
)
num_bits: int | tuple[int, int] | str = ModeloptField(
default=8,
title="An integer or a tuple of two integers specifying the number of quantization bits.",
description="""`num_bits` can be:
#. A positive integer argument for integer quantization. `num_bits` specify
the number of bits used for integer quantization.
#. Constant integer tuple (E,M) for floating point quantization emulating
Nvidia's FPx quantization. E is the number of exponent bits and M is the number
of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).
#. String specifying the quantization format. This is current used only for custom backends.""",
)
@model_validator(mode="before")
@classmethod
def validate_config(cls, values):
    """Reject quantizer configs that consist of nothing but ``{"enable": True}``.

    Runs as a ``mode="before"`` validator on the raw input. Every dict reachable through
    nested dicts and lists is checked; any other value (including ``None``) is accepted as-is.

    Returns:
        ``values``, unchanged.

    Raises:
        ValueError: if any reachable dict is exactly ``{"enable": True}``.
    """

    def _walk(node):
        # Decide which children to descend into; non-container leaves end the walk here.
        if isinstance(node, list):
            children = node
        elif isinstance(node, dict):
            if len(node) == 1 and node.get("enable") is True:
                raise ValueError(
                    "Invalid quantizer config: Cannot specify only {'enable': True}. "
                    "Additional parameters are required when enabling quantization."
                )
            children = node.values()
        else:
            return
        for child in children:
            _walk(child)

    _walk(values)
    return values
@model_validator(mode="after")
def validate_num_bits(self):
    """Validate ``num_bits`` against the other, already parsed, attributes.

    Skipped entirely when ``self.backend`` is set, because custom backends interpret
    ``num_bits`` themselves. Otherwise:

    * an ``int`` must be >= 1;
    * a ``tuple`` ``(E, M)`` must contain only positive entries and be one of the supported
      floating-point formats;
    * formats other than E4M3 ``(4, 3)`` and E2M1 ``(2, 1)`` are accepted only together with
      dynamic block quantization (``block_sizes["type"] == "dynamic"``).

    Any other type (e.g. ``str``) is not checked here.

    Returns:
        ``self``, unchanged.

    Raises:
        ValueError: if one of the checks above fails.
    """
    if self.backend is not None:
        # For custom backends, we don't need to validate num_bits
        return self

    num_bits = self.num_bits
    if isinstance(num_bits, int) and num_bits < 1:
        raise ValueError(
            f"num_bits must be a positive integer or a tuple of positive integers. {num_bits}"
        )
    if not isinstance(num_bits, tuple):
        return self

    if not all(x > 0 for x in num_bits):
        raise ValueError("num_bits must be a positive integer or a tuple of positive integers.")

    # NOTE(review): (0, 3) and (3, 0) can never pass the positivity check above, so those two
    # entries are currently unreachable -- confirm which of the two rules is intended.
    supported_fp_formats = ((4, 3), (5, 2), (2, 1), (1, 2), (0, 3), (3, 0), (3, 2), (2, 3))
    if num_bits not in supported_fp_formats:
        raise ValueError(
            "Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)."
        )

    block_sizes = self.block_sizes
    is_dynamic_blockwise = block_sizes is not None and block_sizes.get("type", None) == "dynamic"
    if num_bits not in ((4, 3), (2, 1)) and not is_dynamic_blockwise:
        # f-string so the offending format is actually reported (e.g. "formats E5M2.").
        raise ValueError(
            "Only blockwise dynamic quantization is supported with quantization "
            f"formats E{num_bits[0]}M{num_bits[1]}."
        )
    return self
axis: int | tuple[int, ...] | None = ModeloptField(
default=None,
title="None, integer or a tuple of integers specifying the axis to quantize.",
description="""This field is for static per-channel quantization. *It cannot coexist with `block_sizes`*.
You should set axis if you want a fixed shape of scale factor.
For example, if axis is set to None, the scale factor will be a scalar (per-tensor quantization)
if the axis is set to 0, the scale factor will be a vector of shape (dim0, ) (per-channel quantization).
if the axis is set to (-2, -1), the scale factor will be a vector of shape (dim-2, dim-1)
axis value must be in the range [-rank(input_tensor), rank(input_tensor))
""",
)
fake_quant: bool = ModeloptField(
default=True,
title="Enable fake quantization.",
description="""If True, enable fake quantization.""",
)
unsigned: bool = ModeloptField(
default=False,
title="Enable unsigned quantization.",
description="""If True, enable unsigned quantization. Used only for integer quantization.""",
)
narrow_range: bool = ModeloptField(
default=False,
title="Enable narrow range quantization.",
description="""If True, enable narrow range quantization. Used only for integer quantization.""",
)
learn_amax: bool = ModeloptField(
default=False,
title="Enable learning amax.",
description="""``learn_amax`` is deprecated and reserved for backward compatibility.""",
)
@field_validator("learn_amax")
@classmethod
def validate_learn_amax(cls, v):
    """Reject ``learn_amax=True``; the option is deprecated.

    Only a non-``True`` value (the default ``False``) is accepted, so configs that still carry
    the key keep loading. A ``ValueError`` is raised instead of using ``assert``, because
    ``assert`` statements are stripped when Python runs with ``-O``. Pydantic reports either
    exception as a validation error.
    """
    if v is True:
        raise ValueError("learn_amax is deprecated and reserved for backward compatibility.")
    return v
type: str = ModeloptField(
default="static",
title="""Specify whether the quantization is static or dynamic.""",
description="""The value is a string from ``["static", "dynamic"]``.
If ``"dynamic"``, dynamic quantization will be enabled which does not collect any statistics during
calibration.""",
pattern=r"^static$|^dynamic$",
)
block_sizes: dict[int | str, int | tuple[int, int] | str | dict[int, int] | None] | None = (
ModeloptField(
default=None,
title="Optional dictionary specifying block quantization parameters.",
description="""This field is for static or dynamic block quantization. *It cannot coexist with ``axis``*.
You should set block_sizes if you want fixed number of elements to share every scale factor.
The keys are the axes for block quantization and the
values are block sizes for quantization along the respective axes. Keys must be in the
range ``[-tensor.dim(), tensor.dim())``. Values, which are the block sizes for quantization must be
positive integers or ``None``. A positive block size specifies the block size for quantization along that
axis. ``None`` means that the block size will be the maximum possible size in that dimension - this is
useful for specifying certain quantization formats such per-token dynamic quantization which has the `amax`
shared along the last dimension.
In addition, there can be special string keys ``"type"``, ``"scale_bits"`` and ``"scale_block_sizes"``.
Key ``"type"`` should map to ``"dynamic"`` or ``"static"`` where ``"dynamic"``
indicates dynamic block quantization and "static"
indicates static calibrated block quantization. By default, the type is ``"static"``.
Key ``"scale_bits"`` specify the quantization bits for the per-block quantization scale factor
(i.e a double quantization scheme).
Key ``"scale_block_sizes"`` specify the block size for double quantization.
By default per-block quantization scale is not quantized.
For example, ``block_sizes = {-1: 32}`` will quantize the last axis of the input tensor in
blocks of size 32 with static calibration, with a total of ``numel(tensor) / 32`` scale factors.
``block_sizes = {-1: 32, "type": "dynamic"}`` will perform dynamic block quantization.
``block_sizes = {-1: None, "type": "dynamic"}`` can be used to
specify per-token dynamic quantization.
""",
)
)
bias: dict[int | str, BiasType | BiasMethod | tuple[int, ...] | bool | int | None] | None = (
ModeloptField(
default=None,
title="Bias configuration.",
description="""Configuration for bias handling in affine quantization. The keys are:
- "enable": Boolean to enable/disable bias handling, default is False
- "type": Specify the type of bias ["static", "dynamic"], default is "static"
- "method": Specify the method of bias calibration ["mean", "max_min"], default is "mean"
- "axis": Tuple of integers specifying axes for bias computation, default is None
Examples:
bias = {"enable": True}
bias = {"enable": True, "type": "static", "axis": -1}
bias = {"enable": True, "type": "dynamic", "axis": (-1, -3)}
""",
)
)
@staticmethod
def _get_block_quant_axes_and_sizes(block_sizes):
    """Return only the axis -> block-size entries of ``block_sizes``.

    The special string keys ``"type"``, ``"scale_bits"`` and ``"scale_block_sizes"`` are
    dropped and the remaining entries keep their original order. ``None`` is passed through
    unchanged.
    """
    if block_sizes is None:
        return None
    axes_and_sizes = dict(block_sizes)
    for special_key in ("type", "scale_bits", "scale_block_sizes"):
        axes_and_sizes.pop(special_key, None)
    return axes_and_sizes
@field_validator("block_sizes")
@classmethod
def validate_block_sizes(cls, v, info: ValidationInfo):
    """Validate ``block_sizes``.

    Rules enforced when ``block_sizes`` is not ``None``:

    * ``axis`` must be ``None`` (the two options are mutually exclusive);
    * with ``"type": "dynamic"``, exactly one axis entry is allowed;
    * string keys must be one of ``"type"``, ``"scale_bits"``, ``"scale_block_sizes"``;
    * every other key must be an ``int`` axis mapped to an ``int`` block size or ``None``.

    Violations raise ``ValueError`` (reported by pydantic as a validation error) rather than
    using ``assert``, which is stripped when Python runs with ``-O``.

    Returns:
        ``v``, unchanged.
    """
    if v is None:
        return v

    # ``axis`` is declared before ``block_sizes``, so its validated value is normally present in
    # ``info.data``. If ``axis`` itself failed validation, the key is absent; ``.get`` keeps that
    # original error from being masked by a KeyError.
    if info.data.get("axis") is not None:
        raise ValueError("axis must be None when block_sizes is not None.")

    if v.get("type", None) == "dynamic" and len(cls._get_block_quant_axes_and_sizes(v)) != 1:
        raise ValueError("Dynamic block quantization only supports quantization last axis.")

    special_keys = ("type", "scale_bits", "scale_block_sizes")
    for key, size in v.items():
        if isinstance(key, str):
            if key not in special_keys:
                raise ValueError(
                    f"Unsupported string key {key!r} in block_sizes; expected one of {special_keys}."
                )
        elif not (isinstance(key, int) and (size is None or isinstance(size, int))):
            raise ValueError(
                "block_sizes keys must be int axes mapped to an int block size or None; "
                f"got {key!r}: {size!r}."
            )
    return v
@field_validator("bias")
@classmethod
def validate_bias(cls, v):
"""Validate bias."""
if v is None:
return v
if "type" in v and v["type"] not in ["static", "dynamic"]:
raise ValueError(f"Invalid bias type: {v['type']}, expected 'static' or 'dynamic'")
if "method" in v and v["method"] not in ["mean", "max_min"]:
raise ValueError(f"Invalid bias method: {v['method']}, expected 'mean' or 'max_min'")