-
Notifications
You must be signed in to change notification settings - Fork 381
Expand file tree
/
Copy pathmodel_calib.py
More file actions
1506 lines (1264 loc) · 61.7 KB
/
model_calib.py
File metadata and controls
1506 lines (1264 loc) · 61.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calibration utilities."""
import math
import os
import warnings
from functools import partial
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from modelopt.torch.opt.searcher import ForwardLoop
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction
from .calib import MseCalibrator, NVFP4MSECalibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import (
disable_calib,
enable_fake_quant,
enable_quant,
enable_weight_access_and_writeback,
is_quantized_column_parallel_linear,
is_quantized_linear,
is_quantized_row_parallel_linear,
quantizer_attr_names,
reduce_amax,
weight_attr_names,
)
# Public API of this module.
# NOTE(review): "svdquant" is not defined in the visible portion of this file; presumably it is
# defined further down — confirm.
__all__ = ["awq", "max_calibrate", "smoothquant", "svdquant"]
def weight_only_quantize(model: nn.Module):
    """Quantize only the weights of ``model``, without any calibration data.

    For every submodule, each weight attribute reported by ``weight_attr_names`` is passed
    through its matching weight quantizer (looked up via ``quantizer_attr_names``) inside
    ``enable_weight_access_and_writeback``.

    Args:
        model: The model whose weight quantizers should observe their weights.
    """
    # Build the name -> module map once and hand it to enable_weight_access_and_writeback;
    # otherwise it is rebuilt for every call, making this loop O(n^2) in the number of modules.
    name_to_module = dict(model.named_modules())
    # ``named_modules`` yields each module object only once, so no extra de-duplication set
    # is needed.
    for module in name_to_module.values():
        for weight_name in weight_attr_names(module):
            with enable_weight_access_and_writeback(module, model, name_to_module):
                weight_quantizer = getattr(
                    module, quantizer_attr_names(weight_name).weight_quantizer
                )
                weight_quantizer(getattr(module, weight_name))
def _has_expert_parallelism(module: nn.Module) -> bool:
"""Check if module has expert parallelism enabled."""
ps = getattr(module, "parallel_state", None)
return ps is not None and ps.expert_model_parallel_group.is_initialized()
def _check_moe_calibration_complete(quantizer, parallel_state):
    """Raise if MoE calibration left ``amax`` populated on some ranks but not on others.

    A ``SequentialQuantizer`` is checked element by element. For each initialized group among
    the data-parallel, expert-parallel and tensor-parallel groups, this rank's "has an amax"
    flag is exchanged via ``DistributedProcessGroup.get_dist_syncd_obj`` (presumably gathering
    one flag per rank). A mix of True and False means some experts saw no calibration tokens.

    Raises:
        RuntimeError: If the ranks of any initialized group disagree on having an ``amax``.
    """
    if isinstance(quantizer, SequentialQuantizer):
        for sub_quantizer in quantizer:
            _check_moe_calibration_complete(sub_quantizer, parallel_state)
        return

    groups_to_check = (
        parallel_state.data_parallel_group,
        parallel_state.expert_model_parallel_group,
        parallel_state.tensor_parallel_group,
    )
    # Whether this rank collected an amax does not depend on which group is being checked.
    has_amax = getattr(quantizer, "_amax", None) is not None
    for group in groups_to_check:
        if not group.is_initialized():
            continue
        amax_states = DistributedProcessGroup.get_dist_syncd_obj(
            has_amax, group, lambda objs: objs
        )
        if any(amax_states) and not all(amax_states):
            raise RuntimeError(
                "MoE calibration incomplete: some experts received no tokens during calibration. "
                "Increase --calib-size to ensure all experts see calibration data."
            )
@torch.no_grad()
def max_calibrate(
    model: nn.Module,
    forward_loop: ForwardLoop | None = None,
    distributed_sync=True,
):
    """Calibrate the model using max.

    Quantizers are put into statistics-collection mode, calibration data is forwarded (or, with
    no ``forward_loop``, only the weights are observed), and the collected amax values are
    loaded. Local MoE experts are then synced. If ``distributed_sync`` is set, amax is also
    synced across the data-parallel / expert-parallel groups and, for parallel linears and
    KV-cache quantizers, across the tensor-parallel group.

    Args:
        model: Model to be calibrated.
        forward_loop: A callable which takes the model as argument and
            forwards calibration data through the model. If None, only the weights are
            quantized (``weight_only_quantize``).
        distributed_sync: Whether to sync quantizer amax across distributed process groups.
            Despite the name of earlier docs, this covers all quantizers that are direct
            children of a ``QuantModule``, not only input quantizers.

    Raises:
        RuntimeError: From ``_check_moe_calibration_complete`` when expert-parallel MoE
            calibration left some ranks without an amax.

    See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
    details on the remaining arguments.
    """
    # Step 1: collect statistics with every enabled quantizer in calibration mode, then load
    # the resulting amax values and switch quantization back on.
    enable_stats_collection(model)
    if forward_loop is None:
        weight_only_quantize(model)
    else:
        forward_loop(model)
    finish_stats_collection(model)

    # Sync input_quantizer amax across local experts within each rank (for SequentialMLP)
    for name, module in model.named_modules():
        if hasattr(module, "layer_sync_moe_local_experts_amax"):
            module.layer_sync_moe_local_experts_amax()

    if not distributed_sync:
        return

    # Check MoE calibration completeness before sync
    for name, module in model.named_modules():
        if isinstance(module, QuantModule) and _has_expert_parallelism(module):
            for child in module.children():
                if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
                    _check_moe_calibration_complete(child, module.parallel_state)

    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
        """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
        if isinstance(quantizer, SequentialQuantizer):
            for _q in quantizer:
                sync_quantizer_amax_across_dp_ep(_q, parallel_state)
            return
        # Quantizers without a collected amax are left untouched.
        if getattr(quantizer, "_amax", None) is not None:
            quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
            quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
        # TODO: create sync_bias_across_distributed_group

    # Step 2: Sync amax across data parallelism (and expert parallelism)
    for name, module in model.named_modules():
        if isinstance(module, QuantModule):
            for child in module.children():
                if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
                    sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
    # Step 3: TP sync
    # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
    # ColumnParallel: X @ [A_1, A_2] (weights split along Cout)
    #     activations:    TPG should have the same amax if axis in [None, -1]
    #     weights:        TPG should have the same amax if axis in [None, -1] (note: we dont use -1 axis for weights)
    # RowParallel:    [X_1, X_2] @ [A_1
    #                               A_2] (weights split along Cin)
    #     activations:    TPG should have the same amax if axis in [None]
    #     weights:        TPG should have the same amax if axis in [None, 0]

    def sync_quantizer_amax_across_tp(
        quantizer: TensorQuantizer | SequentialQuantizer,
        linear_name: str,
        quantizer_type: str,
        axes_for_sync: list,
        parallel_state: ParallelState,
    ):
        """Sync ``quantizer`` amax across the TP group when its axis is in ``axes_for_sync``."""
        # Syncing amax across TP for sequential quantizer
        if isinstance(quantizer, SequentialQuantizer):
            for _q in quantizer:
                # Syncing amax across TP for sequential quantizer
                sync_quantizer_amax_across_tp(
                    _q, linear_name, quantizer_type, axes_for_sync, parallel_state
                )
            return
        # sync is not needed for block quantization
        if quantizer.block_sizes is not None:
            if hasattr(quantizer, "_padding"):
                warnings.warn(
                    f"Found block-quantized padded {quantizer_type} for {linear_name}, amax will"
                    " not be synced correctly."
                )
            # Skip amax sync for INT4 / W4A8 block quantization
            # Sync amax for NVFP4 (dynamic per-block, static per-tensor quantized scale)
            # NOTE(review): elsewhere in this file block sizes are read with dict access
            # (``_block_sizes.get("scale_bits")`` in mse_calibrate). If ``block_sizes`` is a
            # dict here as well, ``getattr(..., "type", None)`` is always None, so this early
            # return never fires. Every block-quantized quantizer whose axis is in
            # ``axes_for_sync`` is then synced. The two comments above also disagree on
            # intent — confirm which behavior is wanted.
            if getattr(quantizer.block_sizes, "type", None) == "dynamic":
                return

        if quantizer.axis in axes_for_sync and quantizer.amax is not None:
            quantizer.sync_amax_across_distributed_group(parallel_state.tensor_parallel_group)

    # Step 4: Sync amax across tensor parallelism for parallel linears and KV-cache quantizers
    for name, module in model.named_modules():
        # Only modules that carry a parallel state take part in TP sync.
        if getattr(module, "_parallel_state", None) is None:
            continue

        if is_quantized_column_parallel_linear(module):
            sync_quantizer_amax_across_tp(
                module.input_quantizer,
                name,
                "input_quantizer",
                axes_for_sync=[None, -1],
                parallel_state=module.parallel_state,
            )

            sync_quantizer_amax_across_tp(
                module.weight_quantizer,
                name,
                "weight_quantizer",
                axes_for_sync=[None, -1],
                parallel_state=module.parallel_state,
            )

        if is_quantized_row_parallel_linear(module):
            sync_quantizer_amax_across_tp(
                module.input_quantizer,
                name,
                "input_quantizer",
                axes_for_sync=[None],
                parallel_state=module.parallel_state,
            )

            sync_quantizer_amax_across_tp(
                module.weight_quantizer,
                name,
                "weight_quantizer",
                axes_for_sync=[None, 0],
                parallel_state=module.parallel_state,
            )

        # KV Cache Quantization
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
            # So we should sync amax across DP and TP for these quantizers (DP is already synced from above)
            for quantizer in [module.k_bmm_quantizer, module.v_bmm_quantizer]:
                if isinstance(quantizer, TensorQuantizer) and quantizer.amax is not None:
                    quantizer.sync_amax_across_distributed_group(
                        module.parallel_state.tensor_parallel_group
                    )
def _mse_quant_func(x, amax, quantizer):
    """Fake-quantize ``x`` with ``quantizer`` using a candidate ``amax`` (MSE calibration helper).

    The quantizer's ``_amax`` is temporarily replaced by ``amax``. For the call, quantization is
    forced on, calibration off and fake-quant on. Afterwards the previous ``_amax`` state is
    restored, and this also happens when the quantizer forward raises.

    Args:
        x: Tensor to quantize.
        amax: Candidate amax to evaluate.
        quantizer: Quantizer used to produce the fake-quantized tensor.

    Returns:
        The fake-quantized tensor, reshaped to ``quantizer._block_reshape_size`` when the
        quantizer defines that attribute.
    """
    original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
    quantizer._amax = amax
    try:
        with (
            enable_quant(quantizer),
            disable_calib(quantizer),
            enable_fake_quant(quantizer),
        ):
            if hasattr(quantizer, "_original_shape"):
                x = quantizer._reset_to_original_shape(x)
            xq = quantizer(x)
            if hasattr(quantizer, "_block_reshape_size"):
                xq = xq.reshape(quantizer._block_reshape_size)
    finally:
        # Restore the amax state even if the forward above failed, so a candidate amax can
        # never leak into the calibrated quantizer.
        if original_amax is not None:
            quantizer._amax = original_amax
        else:
            delattr(quantizer, "_amax")

    return xq
@torch.no_grad()
def mse_calibrate(
    model: nn.Module,
    forward_loop: ForwardLoop | None = None,
    distributed_sync=True,
    step_size: float = 0.1,
    start_multiplier: float = 0.25,
    stop_multiplier: float = 4.0,
    fp8_scale_sweep: bool = False,
):
    """Calibrate the model using MSE-based amax search.

    This calibration method first uses max calibration to get initial amax values. It then
    searches for better amax values by minimizing the MSE between original and quantized
    tensors. Every eligible quantizer gets an MSE calibrator installed, but only weight
    quantizers are re-calibrated here, once, directly on their weights.

    Args:
        model: Model to be calibrated.
        forward_loop: A callable which takes the model as argument and
            forwards calibration data through the model.
        distributed_sync: Whether to sync amax across distributed processes (used by the
            initial max calibration).
        step_size: Step size for amax search (default: 0.1).
        start_multiplier: Starting multiplier for amax search (default: 0.25).
        stop_multiplier: Ending multiplier for amax search (default: 4.0).
        fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
            for NVFP4 per-block quantization instead of using multipliers.
            This is specifically designed for optimizing the FP8-quantized
            per-block scales in NVFP4 format (default: False).

    See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
    details on the remaining arguments.
    """
    # Step 1: First get initial amax using max calibration
    max_calibrate(model, forward_loop, distributed_sync)

    # Step 2: Replace calibrators with MSE calibrators for enabled, non-dynamic quantizers
    # that already hold an amax from max calibration.
    for name, module in list(model.named_modules()):
        if not isinstance(module, TensorQuantizer) or module._disabled:
            continue
        if module._calibrator is None or module._dynamic or not hasattr(module, "_amax"):
            continue

        # Get the initial amax from max calibration
        initial_amax = module._amax.clone().detach()

        is_nvfp4_static = (
            module.is_static_block_quant
            and module._num_bits == (2, 1)
            and module._block_sizes is not None
            and module._block_sizes.get("scale_bits") == (4, 3)
        )

        if is_nvfp4_static:
            # Compute and set global_amax
            global_amax = reduce_amax(initial_amax, axis=None)
            # Convert to NVFP4StaticQuantizer in-place
            NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

        if fp8_scale_sweep and is_nvfp4_static:
            # Replace calibrator with NVFP4MSECalibrator
            module._calibrator = NVFP4MSECalibrator(
                amax=initial_amax,
                axis=module._calibrator._axis,
                global_amax=module.global_amax,
                quant_func=partial(_mse_quant_func, quantizer=module),
            )
            continue

        if fp8_scale_sweep and not is_nvfp4_static:
            warnings.warn(
                f"fp8_scale_sweep is enabled but quantizer '{name}' is not NVFP4 static "
                "block quantization. fp8_scale_sweep will be ignored for this quantizer."
            )

        # Create MSE calibrator with quant_func
        module._calibrator = MseCalibrator(
            amax=initial_amax,
            axis=module._calibrator._axis,
            step_size=step_size,
            start_multiplier=start_multiplier,
            stop_multiplier=stop_multiplier,
            quant_func=partial(_mse_quant_func, quantizer=module),
        )

    # Pre-compute the name -> module map once; enable_weight_access_and_writeback would
    # otherwise rebuild it for every weight, making Step 3 O(n^2) in the number of modules.
    name_to_module = dict(model.named_modules())

    # Identify weight quantizers by checking if they have corresponding weight parameters.
    # ``name_to_module`` holds each module once, so no extra de-duplication is needed.
    weight_quantizers = []
    for parent_module in name_to_module.values():
        for weight_name in weight_attr_names(parent_module):
            weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
            weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
            if (
                isinstance(weight_quantizer, TensorQuantizer)
                and weight_quantizer.is_enabled
                and getattr(weight_quantizer, "_calibrator", None) is not None
            ):
                weight_quantizers.append((parent_module, weight_name, weight_quantizer))

    # Step 3: Calibrate weight quantizers once with MSE calibration
    # This ensures weights are only calibrated once, not during every forward pass
    for parent_module, weight_name, weight_quantizer in weight_quantizers:
        # Enable calibration mode for the weight quantizer
        enable_stats_collection(parent_module)
        with enable_weight_access_and_writeback(parent_module, model, name_to_module):
            weight = getattr(parent_module, weight_name)
            weight_quantizer(weight)

        # Load the searched amax and switch the quantizers back to quantization mode
        finish_stats_collection(parent_module, method="mse")
        weight_quantizer._calibrator.reset()

    # TODO: Sync amax across distributed processes
def enable_stats_collection(model: nn.Module):
    """Switch every enabled quantizer in ``model`` into statistics-collection mode.

    Enabled quantizers that own a calibrator stop quantizing and start calibrating. Enabled
    quantizers without a calibrator are disabled entirely. Already-disabled quantizers are
    left alone.
    """
    for _, quantizer in model.named_modules():
        if not isinstance(quantizer, TensorQuantizer) or quantizer._disabled:
            continue
        if quantizer._calibrator is None:
            quantizer.disable()
            continue
        quantizer.disable_quant()
        quantizer.enable_calib()
def finish_stats_collection(model: nn.Module, method: str | None = None, **kwargs):
    """Load collected calibration statistics and switch quantizers back to quantization mode.

    For each enabled ``TensorQuantizer`` in ``model``:
      * If it has a (truthy) calibrator and is not dynamic, the calibrator's amax is computed.
        When the result is not ``None``, it is loaded into the quantizer. With
        ``method="entropy"`` the method is passed to ``compute_amax`` and ``load_calib_amax``;
        otherwise ``kwargs`` are passed to ``compute_amax``.
      * If it has a bias calibrator with ``bias_type == "static"``, the bias is loaded.
      * Quantization is re-enabled and calibration is disabled.

    Args:
        model: Model whose quantizers finished collecting statistics.
        method: Calibration method; only ``"entropy"`` changes behavior here.
        **kwargs: Forwarded to ``load_calib_amax`` (and to ``compute_amax`` on the
            non-entropy path).
    """
    for _, quantizer in model.named_modules():
        if not isinstance(quantizer, TensorQuantizer) or quantizer._disabled:
            continue

        calibrator = getattr(quantizer, "_calibrator", None)
        if calibrator and not getattr(quantizer, "_dynamic", False):
            if method in {"entropy"}:
                entropy_amax = calibrator.compute_amax(method)
                if entropy_amax is not None:
                    quantizer.load_calib_amax("entropy", **kwargs)
            else:
                computed_amax = calibrator.compute_amax(**kwargs)
                if computed_amax is not None:
                    quantizer.load_calib_amax(**kwargs)

        has_static_bias = (
            quantizer.bias_calibrator is not None and quantizer.bias_type == "static"
        )
        if has_static_bias:
            quantizer.load_calib_bias()

        quantizer.enable_quant()
        quantizer.disable_calib()
@torch.no_grad()
def disable_pre_quant_scale_and_resmooth(linear: nn.Module, delete_pre_quant_scale: bool = False):
    """Undo a folded input ``pre_quant_scale`` on ``linear`` and recalibrate its weights.

    The weight is multiplied by the input quantizer's ``_pre_quant_scale`` (computed in fp32
    and broadcast along the last weight dimension). The weight quantizer's amax is then reset
    and recomputed with ``max_calibrate``, and pre-quant scaling on the input quantizer is
    switched off. If the input quantizer has an amax, it becomes the per-tensor max of
    ``_amax_for_smoothing``.

    Args:
        linear: A quantized linear module whose input quantizer has pre_quant_scale enabled.
        delete_pre_quant_scale: If True, also remove ``_pre_quant_scale`` from the input
            quantizer.
    """
    assert is_quantized_linear(linear), "Only quantized linear modules are supported"
    input_quantizer = linear.input_quantizer
    assert input_quantizer._enable_pre_quant_scale, "pre_quant_scale should be enabled first!"
    assert hasattr(input_quantizer, "_pre_quant_scale"), "pre_quant_scale should be available"

    scale_fp32 = input_quantizer._pre_quant_scale.to(torch.float32)
    restored_weight = linear.weight * scale_fp32.squeeze()[None, :]
    linear.weight.copy_(restored_weight.to(linear.weight.dtype))

    linear.weight_quantizer.reset_amax()
    max_calibrate(linear, lambda mod: mod.weight_quantizer(mod.weight))

    # Keep ``_pre_quant_scale`` around since it may be useful later; just stop applying it.
    input_quantizer._enable_pre_quant_scale = False

    if input_quantizer.amax is not None:
        assert hasattr(input_quantizer, "_amax_for_smoothing")
        device, dtype = linear.weight.device, linear.weight.dtype
        per_tensor_amax = input_quantizer._amax_for_smoothing.amax()
        input_quantizer.amax = per_tensor_amax.to(device=device, dtype=dtype)

    if delete_pre_quant_scale:
        delattr(input_quantizer, "_pre_quant_scale")
        input_quantizer._enable_pre_quant_scale = False
# Module-level switch read by _apply_weight_pre_quant_scale. auto_quantize sets it to False
# so that pre_quant_scale is not folded into the weights.
_ENABLE_FOLDING_PQS_TO_WEIGHTS = True


@torch.no_grad()
def _apply_weight_pre_quant_scale(linear, pre_quant_scale):
    """Apply ``pre_quant_scale`` to ``linear``'s weight side and recalibrate the weight quantizer.

    If ``_ENABLE_FOLDING_PQS_TO_WEIGHTS`` is True, the scale (broadcast along the last weight
    dimension) is multiplied into the weight in place. Otherwise it is installed as the weight
    quantizer's ``pre_quant_scale`` and the weight tensor itself is left unchanged. In both
    cases the weight quantizer's amax is reset and recomputed with ``max_calibrate``.
    """
    weight_dtype = linear.weight.dtype
    if not _ENABLE_FOLDING_PQS_TO_WEIGHTS:
        linear.weight_quantizer._enable_pre_quant_scale = True
        linear.weight_quantizer.pre_quant_scale = pre_quant_scale.squeeze()[None, :].to(
            weight_dtype
        )
    else:
        row_scale = pre_quant_scale.to(linear.weight.device).squeeze()[None, :]
        linear.weight.data.copy_((linear.weight * row_scale).to(weight_dtype))

    linear.weight_quantizer.reset_amax()
    max_calibrate(linear, lambda mod: mod.weight_quantizer(mod.weight))
@torch.no_grad()
def apply_pre_quant_scale_and_smooth(
    linear: nn.Module, pre_quant_scale: torch.Tensor | None = None
):
    """Apply pre_quant_scale and smooth the quantized linear weights.

    If pre_quant_scale is not provided, the existing ``_pre_quant_scale`` of the input
    quantizer is used.

    The input quantizer is set to multiply activations by ``pre_quant_scale``. The weights
    receive the inverse scale via ``_apply_weight_pre_quant_scale``. If the input quantizer
    already has an amax, it is recomputed as ``max(_amax_for_smoothing * pre_quant_scale)``,
    evaluated in fp32. For column/row-parallel linears it is then synced across the
    tensor-parallel group.

    Args:
        linear: A quantized linear module whose input quantizer has no active pre_quant_scale.
        pre_quant_scale: Strictly positive per-input-channel scale; optional if the input
            quantizer already stores ``_pre_quant_scale``.

    Raises:
        AssertionError: If the module is not a quantized linear, a pre_quant_scale is already
            active, no scale is available, or the scale is not strictly positive.
    """
    assert is_quantized_linear(linear), "Only quantized linear modules are supported"
    assert linear.input_quantizer.pre_quant_scale is None, "pre_quant_scale should be None first!"

    if pre_quant_scale is None:
        pre_quant_scale = getattr(linear.input_quantizer, "_pre_quant_scale", None)

    assert pre_quant_scale is not None, "pre_quant_scale should be provided or already set"
    assert torch.all(pre_quant_scale > 0), "pre_quant_scale should be positive"

    # pre_quant_scale should be in fp32 for the scaling math to be numerically safe
    pre_quant_scale = pre_quant_scale.to(torch.float32)

    linear.input_quantizer._enable_pre_quant_scale = True
    linear.input_quantizer.pre_quant_scale = pre_quant_scale.to(linear.weight.dtype)

    inv_scale = 1.0 / pre_quant_scale
    _apply_weight_pre_quant_scale(linear, inv_scale)

    if linear.input_quantizer.amax is not None:
        assert hasattr(linear.input_quantizer, "_amax_for_smoothing")
        device, dtype = linear.weight.device, linear.weight.dtype
        # Keep the per-channel amax in fp32 while applying the scale. Rounding it to the
        # (possibly low-precision) weight dtype first would undo the fp32 scaling above;
        # only the final per-tensor amax is cast to the weight dtype.
        amax_for_smoothing = linear.input_quantizer._amax_for_smoothing.to(
            device=device, dtype=torch.float32
        )
        linear.input_quantizer.amax = (
            (amax_for_smoothing * pre_quant_scale.to(device)).amax().to(dtype)
        )

        if is_quantized_column_parallel_linear(linear) or is_quantized_row_parallel_linear(linear):
            linear.input_quantizer.sync_amax_across_distributed_group(
                linear.parallel_state.tensor_parallel_group
            )
@torch.no_grad()
def smoothquant(model: nn.Module, forward_loop: ForwardLoop | None = None, alpha=1.0):
    """Smooth-Quant variant with per-channel weight scaling.

    Enabled per-tensor input quantizers of quantized linears are switched to per-channel
    (``axis=-1``) and calibrated with ``max_calibrate``. Each eligible int8 linear is then
    smoothed with ``pre_quant_scale = weight_scale**(1 - alpha) / act_amax**alpha``, and its
    input quantizer goes back to per-tensor.

    Args:
        model: Model to be calibrated.
        forward_loop: A callable which takes the model as argument and
            forwards calibration data through the model. Required.
        alpha: Exponent balancing activation vs. weight magnitude in the smoothing scale.

    See :class:`SmoothQuantCalibConfig <modelopt.torch.quantization.config.SmoothQuantCalibConfig>` for
    details on the remaining arguments.
    """
    # distributed synchronization
    # max_calibrate performs amax sync for data parallel

    # Column parallel:
    #   activations:  TPG should have the same pre_quant_scale
    #                 This is achieved by syncing act_amax and weight_scale across TPG which is used to
    #                 compute pre_quant_scale
    #   weights:      no-op

    # Row parallel:
    #   activations:  TPG should have same activation amax
    #   weights:      TPG should have the same weight amax

    assert forward_loop is not None, "forward_loop must be provided for smoothquant"

    # Collect per-channel activation amax for enabled input quantizers that are per-tensor.
    for _, module in model.named_modules():
        if (
            is_quantized_linear(module)
            and module.input_quantizer.is_enabled
            and module.input_quantizer.axis is None
        ):
            module.input_quantizer.axis = -1

    max_calibrate(model, forward_loop)

    def postprocess(module):
        # It is important to keep scaling math in fp32 to be numerically safe; this applies to
        # both the activation and the weight statistics.
        act_amax = module.input_quantizer.amax.float()
        weight_scale = module.weight.abs().amax(dim=0, keepdim=True).float()
        device, dtype = module.weight.device, module.weight.dtype

        parallel_group = module.parallel_state.tensor_parallel_group
        if is_quantized_column_parallel_linear(module) and parallel_group.is_initialized():
            dist.all_reduce(act_amax, op=dist.ReduceOp.MAX, group=parallel_group.group)
            dist.all_reduce(weight_scale, op=dist.ReduceOp.MAX, group=parallel_group.group)

        scale_a = (weight_scale.pow(1 - alpha) / act_amax.pow(alpha)).squeeze()

        # Now that activation per-channel amax have been collected, use per-tensor quantization for activation
        # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
        module.input_quantizer._amax_for_smoothing = act_amax.cpu()
        module.input_quantizer.reset_amax()
        module.input_quantizer.axis = None
        module.input_quantizer.amax = act_amax.amax().to(dtype=dtype, device=device)

        # Channels whose activation amax is (near) zero make scale_a overflow to inf, or to NaN
        # when the matching weight column is also zero. NaN would survive clamp() below and
        # fail the positivity check downstream. Those channels carry no activation signal, so
        # give them a neutral scale of 1. The mask is derived from act_amax directly: checking
        # for a *small* scale_a never catches the overflow case.
        epsilon = 1.0 / (1 << 31)
        zero_mask = (act_amax <= epsilon).reshape(scale_a.shape)
        scale_a[zero_mask] = 1
        scale_a = scale_a.clamp(min=1e-4, max=1e4)

        apply_pre_quant_scale_and_smooth(module, scale_a)

    smoothed_modules = 0
    for name, module in model.named_modules():
        if is_quantized_linear(module):
            if not hasattr(module.input_quantizer, "_amax"):
                warnings.warn(f"{name} is not calibrated, skip smoothing")
                continue
            if module.input_quantizer.num_bits != 8 or module.weight_quantizer.num_bits != 8:
                warnings.warn(f"Only int8 smoothing is supported, skip {name}")
                continue
            if module.input_quantizer.axis != -1:
                warnings.warn(f"Only per-channel smoothing is supported, skip {name}")
                continue

            assert module.input_quantizer._amax.numel() > 1, (
                f"Error: {name} has only one channel to smooth"
            )

            with enable_weight_access_and_writeback(module, model):
                postprocess(module)

            smoothed_modules += 1
    print_rank_0(f"Smoothed {smoothed_modules} modules")
def awq(
    model: nn.Module,
    forward_loop: ForwardLoop | None = None,
    algorithm: str = "awq_lite",
    **kwargs,
):
    """Apply AWQ to the model.

    Inside ``SequentialQuantizer.convert_to_single_quantizer``, the steps run as follows:
    ``awq_lite`` for ``"awq_lite"`` or ``"awq_full"``, and ``awq_clip`` for ``"awq_clip"`` or
    ``"awq_full"``. Afterwards, every quantized linear whose weight quantizer is a
    ``SequentialQuantizer`` gets its weights re-observed with ``max_calibrate``.

    Args:
        model: Model to be calibrated.
        forward_loop: A callable which takes the model as argument and
            forwards calibration data through the model.
        algorithm: One of ``"awq_lite"``, ``"awq_clip"`` or ``"awq_full"``.

    See :class:`AWQFullCalibConfig <modelopt.torch.quantization.config.AWQFullCalibConfig>` for
    details on the remaining arguments.
    """
    with SequentialQuantizer.convert_to_single_quantizer(model):
        run_lite = algorithm in ["awq_full", "awq_lite"]
        run_clip = algorithm in ["awq_full", "awq_clip"]
        if run_lite:
            awq_lite(model, forward_loop, **kwargs)
        if run_clip:
            awq_clip(model, forward_loop, **kwargs)

    # Special handling for SequentialQuantizer weight quantizers, once outside the
    # single-quantizer context above.
    # Build the name -> module map once so enable_weight_access_and_writeback does not rebuild it
    # for every module (which would be O(n^2)).
    name_to_module = dict(model.named_modules())
    for module in name_to_module.values():
        if not is_quantized_linear(module):
            continue
        if not isinstance(module.weight_quantizer, SequentialQuantizer):
            continue
        with enable_weight_access_and_writeback(module, model, name_to_module):
            max_calibrate(module, lambda mod: mod.weight_quantizer(mod.weight))
@torch.no_grad()
def awq_lite(
model: nn.Module,
forward_loop: ForwardLoop,
alpha_step: float = 0.1,
debug: bool = False,
**kwargs,
):
"""Lite version of AWQ.
Args:
model: Model to be calibrated.
forward_loop: A callable which takes the model as argument and
forwards calibration data through the model.
See :class:`AWQLiteCalibConfig <modelopt.torch.quantization.config.AWQLiteCalibConfig>` for
details on the remaining arguments.
"""
if forward_loop is None:
warnings.warn("forward_loop must be provided for awq_lite; skipping awq_lite")
return
class AWQLiteHelper:
cache_mode: bool = False
def __init__(self, module, name):
self.name = name
self.act_scale = 0.0
self.num_cache_steps = 0
self.num_search_steps = 0
self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)
self.weight_scale = get_weight_scale(module.weight, self.block_size)
self.loss = {
k.item(): torch.zeros((), device=module.weight.device, dtype=torch.float32)
for k in torch.arange(0, 1.0 + alpha_step, alpha_step)
}
self.best_scale = None
self.best_alpha = None
self.is_input_quantized = module.input_quantizer.is_enabled
self.num_tokens = 0
self.module = module
self.is_enabled = True
def setup(self):
module = self.module
bind_forward_method(module, forward, "_forward_no_awq")
if module.input_quantizer.is_enabled:
module.input_quantizer.disable()
if module.input_quantizer.axis not in [None, -1]:
self.is_enabled = False
return
module.input_quantizer.axis = -1
def cleanup(self):
module = self.module
if hasattr(module, "_if_calib"):
delattr(module, "_if_calib")
unpatch_forward_method(module, "_forward_no_awq")
def get_weight_scale(weight, block_size=None):
org_shape = weight.shape
slice_after_padding = None
if block_size:
if org_shape[-1] % block_size != 0:
slice_after_padding = slice(org_shape[-1])
weight = F.pad(weight, (0, block_size - org_shape[-1] % block_size), "constant", 0)
org_shape = weight.shape
weight = weight.contiguous().view(-1, block_size)
weight_abs = weight.abs() # Cache to avoid redundant computation
weight_abs_amax = weight_abs.amax(dim=1, keepdim=True)
scale = weight_abs / (weight_abs_amax + torch.finfo(weight.dtype).tiny)
scale = scale.view(org_shape)
if slice_after_padding is not None:
scale = scale[..., slice_after_padding]
scale = scale.mean(0).to(torch.float32)
return scale
def get_act_scale(x):
return x.abs().contiguous().view(-1, x.shape[-1]).mean(0).to(torch.float32)
def get_scale(x_max, w_max, alpha, tensor_parallel_group=None):
scales = (
(
x_max.pow(alpha)
/ (w_max.to(x_max.device).pow(1 - alpha) + torch.finfo(torch.float32).tiny)
)
.clamp(min=1e-4, max=1e4)
.view(-1)
)
scales = (scales / (scales.max() * scales.min()).sqrt()).view(-1)
if tensor_parallel_group and tensor_parallel_group.is_initialized():
dist.all_reduce(scales, op=dist.ReduceOp.SUM, group=tensor_parallel_group.group)
scales /= tensor_parallel_group.world_size()
return scales
def update_loss(self, out, out_actual, alpha):
out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual
out = out[0] if isinstance(out, tuple) else out
loss = (out - out_actual).float().pow(2).mean()
self.awq_lite.loss[alpha] += loss.to(self.awq_lite.loss[alpha].device)
def update_best_params(self):
if not self.awq_lite.is_enabled:
return
self.awq_lite.loss.update({k: float(v) for k, v in self.awq_lite.loss.items()})
self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
self.awq_lite.best_scale = get_scale(
self.awq_lite.act_scale,
self.awq_lite.weight_scale,
self.awq_lite.best_alpha,
(
self.parallel_state.tensor_parallel_group
if is_quantized_column_parallel_linear(self)
else None
),
)
def forward(self, input, *args, **kwargs):
# Collect actual output without quantization
self.weight_quantizer.disable()
if hasattr(self.input_quantizer, "_pre_quant_scale"):
delattr(self.input_quantizer, "_pre_quant_scale")
if hasattr(self.weight_quantizer, "_pre_quant_scale"):
delattr(self.weight_quantizer, "_pre_quant_scale")
out_actual = self._forward_no_awq(input, *args, **kwargs)
self.weight_quantizer.enable()
if input.numel() == 0 or not self.awq_lite.is_enabled:
# For MoEs, some experts might see 0 tokens
return out_actual
if AWQLiteHelper.cache_mode:
# Get local tensor from Dtensor
input = input.to_local() if hasattr(input, "to_local") else input
self.awq_lite.act_scale += get_act_scale(self.input_quantizer(input))
self.awq_lite.num_cache_steps += 1
self.awq_lite.num_tokens += input.numel() / input.shape[-1]
if self.awq_lite.is_input_quantized:
with set_quantizer_by_cfg_context(self.input_quantizer, {"*": {"enable": True}}):
max_calibrate(self.input_quantizer, lambda quantizer: quantizer(input), False)
return out_actual
for alpha in self.awq_lite.loss:
awq_scale = get_scale(
self.awq_lite.act_scale,
self.awq_lite.weight_scale,
alpha,
(
self.parallel_state.tensor_parallel_group
if is_quantized_column_parallel_linear(self)
else None
),
)
self.input_quantizer.pre_quant_scale = (1 / awq_scale).to(self.weight.dtype)
self.weight_quantizer.pre_quant_scale = awq_scale.to(self.weight.dtype)
out = self._forward_no_awq(input, *args, **kwargs)
update_loss(self, out, out_actual, alpha)
self.awq_lite.num_search_steps += 1
# Now forward the actual output without any quantization
return out_actual
# Pre-compute name_to_module dict ONCE to avoid O(n^2) complexity in enable_weight_access_and_writeback
name_to_module = dict(model.named_modules())
for name, module in name_to_module.items():
if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
with enable_weight_access_and_writeback(module, model, name_to_module):
module.awq_lite = AWQLiteHelper(module, name)
module.awq_lite.setup()
# Collect activation scale values
AWQLiteHelper.cache_mode = True
print_rank_0("awq_lite: Caching activation statistics...")
# Lets enable stats collection
# This will collect amax for input_quantizers and KV quantizers during the caching mode forward pass
enable_stats_collection(model)
forward_loop(model)
# Call max_calibrate to load the amax values collected during the caching mode forward pass
# This will also perform distributed amax sync for input_quantizers
max_calibrate(model, lambda model: None)
def sync_act_scale_across_dp(module, data_parallel_group):
    """Average ``module.awq_lite.act_scale`` in place across the data-parallel group.

    Args:
        module: Module carrying an ``awq_lite`` helper whose ``act_scale`` tensor is reduced.
        data_parallel_group: Group wrapper exposing ``is_initialized()`` and the underlying
            process group as ``.group``.

    Nothing happens when the group reports that it is not initialized (e.g. single process).
    """
    if not data_parallel_group.is_initialized():
        return
    # Collective call: every rank in the group must reach this point.
    dist.all_reduce(
        module.awq_lite.act_scale,
        op=dist.ReduceOp.AVG,
        group=data_parallel_group.group,
    )
for name, module in model.named_modules():
if (
is_quantized_linear(module)
and hasattr(module, "awq_lite")
and module.awq_lite.num_cache_steps > 0
):
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
)
has_nan = DistributedProcessGroup.get_dist_syncd_obj(
has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs)
)
if has_nan:
module.awq_lite.is_enabled = False
else:
sync_act_scale_across_dp(
module,
module.parallel_state.data_parallel_group,
)
AWQLiteHelper.cache_mode = False
print_rank_0("awq_lite: Searching parameters...")
with torch.no_grad():
forward_loop(model)
def postprocess(module, name):
    """Finalize AWQ-lite for a single module after the search pass.

    Steps, in order:
      1. ``update_best_params(module)`` (defined elsewhere in ``awq_lite``; presumably
         selects ``awq_lite.best_scale`` from the search losses).
      2. Remove the temporary ``_pre_quant_scale`` used during the search from the
         weight quantizer and then from the input quantizer.
      3. If the input was quantized, re-enable the input quantizer. When it has an
         amax, first stash a CPU copy in ``_amax_for_smoothing`` and collapse the
         amax to a single value (``axis = None``, ``amax = act_amax.amax()``).
      4. If AWQ-lite is still enabled, apply ``1 / best_scale`` as the pre-quant
         scale. Otherwise warn and max-calibrate the weight quantizer on the weight.
    """
    update_best_params(module)

    for quantizer in (module.weight_quantizer, module.input_quantizer):
        if hasattr(quantizer, "_pre_quant_scale"):
            delattr(quantizer, "_pre_quant_scale")

    if module.awq_lite.is_input_quantized:
        in_quantizer = module.input_quantizer
        if in_quantizer.amax is not None:
            act_amax = in_quantizer.amax
            # TODO: make this a buffer after we support only heterogeneous checkpointing for MCore
            in_quantizer._amax_for_smoothing = act_amax.cpu()
            in_quantizer.reset_amax()
            in_quantizer.axis = None
            in_quantizer.amax = act_amax.amax()
        # Dynamic quantization has no amax; in both cases the quantizer is re-enabled.
        in_quantizer.enable()

    if not module.awq_lite.is_enabled:
        warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
        max_calibrate(module, lambda mod: mod.weight_quantizer(mod.weight))
        return
    apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
for name, module in model.named_modules():
if hasattr(module, "awq_lite"):
if module.awq_lite.num_cache_steps == 0:
module.awq_lite.is_enabled = False
elif module.awq_lite.num_search_steps == 0:
module.awq_lite.is_enabled = False
warnings.warn(
"awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
f" {name}. Please provide a valid `forward_loop` function that can be used to"
" forward data through the model many times."
)
with enable_weight_access_and_writeback(module, model, name_to_module):
postprocess(module, name)
module.awq_lite.cleanup()
if not debug:
delattr(module, "awq_lite")
@torch.no_grad()
def awq_clip(
model: nn.Module,
forward_loop: ForwardLoop,
max_co_batch_size: int = 1024,
max_tokens_per_batch: int = 64,
min_clip_ratio: float = 0.5,
shrink_step: float = 0.05,
debug: bool = False,
**kwargs,
):
"""AWQ-Clip variant.
Args:
model: Model to calibrate.
forward_loop: A callable that runs the forward pass of the model.
See :class:`AWQClipCalibConfig <modelopt.torch.quantization.config.AWQClipCalibConfig>` for
details on the remaining arguments.
"""
assert forward_loop is not None, "forward_loop must be provided for awq_clip"
class AWQClipHelper:
    """Per-module state for the AWQ-clip search.

    Construction has side effects on ``module.weight_quantizer``:
      - resets its amax and runs one stats-collection pass over ``module.weight``;
      - records the resulting amax in ``w_amax``;
      - disables the quantizer.

    It also allocates one loss accumulator per candidate clip ratio. The candidates
    run from ``min_clip_ratio`` (inclusive) towards 1.0 in steps of ``shrink_step``,
    are rounded to two decimals, and always include 1.0 (no clipping).

    ``min_clip_ratio`` and ``shrink_step`` come from the enclosing ``awq_clip`` call.
    """

    def __init__(self, module):
        self.num_tokens = 0
        self.block_size = _get_awq_quantizer_block_size(module.weight, module.weight_quantizer)

        # Cache the original amax
        module.weight_quantizer.reset_amax()
        enable_stats_collection(module.weight_quantizer)
        module.weight_quantizer(module.weight)
        finish_stats_collection(module.weight_quantizer)
        self.w_amax = module.weight_quantizer.amax.clone()

        co, ci = module.weight.shape
        # Duplicate ratios (e.g. a float-rounded 1.0 from arange) collapse as dict keys below.
        clip_ratios = [
            round(float(k), 2) for k in torch.arange(min_clip_ratio, 1.0, shrink_step)
        ] + [1.0]
        if self.is_per_tensor_clip(module):
            # A single scalar loss per candidate ratio.
            self.loss = {k: torch.tensor(0.0, device=module.weight.device) for k in clip_ratios}
        else:
            # One loss per (output channel, input block) per candidate ratio.
            self.loss = {
                k: torch.zeros(
                    (co, math.ceil(ci / self.block_size)),
                    device=module.weight.device,
                )
                for k in clip_ratios
            }
        self.best_clip_val = None
        self.best_loss = None
        self.is_input_quantized = module.input_quantizer.is_enabled
        module.weight_quantizer.disable()

    def is_per_tensor_clip(self, module):
        """Return True when ``module``'s weight quantizer uses a single (per-tensor) clip amax.

        This holds in two cases:
          - dynamic block quantization with ``axis is None``;
          - plain per-tensor quantization (``axis is None`` and no ``block_sizes``).

        ``block_sizes`` may be absent or ``None``; both count as "no block sizes".
        """
        quantizer = module.weight_quantizer
        # Read once; a None block_sizes must not be dereferenced with .get().
        block_sizes = getattr(quantizer, "block_sizes", None)
        is_dynamic_w_per_tensor = (
            block_sizes is not None
            and block_sizes.get("type", None) == "dynamic"
            and quantizer.axis is None
        )
        is_per_tensor = quantizer.axis is None and block_sizes is None
        return is_dynamic_w_per_tensor or is_per_tensor
def update_best_params(self):
    """Pick, element-wise over ``w_amax``, the clip ratio with the lowest accumulated loss.

    ``self`` is a module carrying an ``awq_clip`` helper. Each ``awq_clip.loss`` entry maps a
    clip ratio to its accumulated loss, and that loss must be reshapeable to ``w_amax``
    with ``view_as``.

    On return:
      - ``awq_clip.best_loss`` holds the minimum loss per element;
      - ``awq_clip.best_clip_val`` holds ``w_amax * best_ratio``.

    Ties keep the ratio encountered first in ``loss``. Elements whose loss is never
    strictly below +inf (NaN/inf for every ratio) keep the unclipped ``w_amax`` rather
    than a zero amax.
    """
    helper = self.awq_clip
    w_amax = helper.w_amax
    best_loss = torch.full_like(w_amax, float("inf"))
    # Start from the unclipped amax so degenerate losses can never leave a zero amax behind.
    best_clip_val = w_amax.clone()
    for shrink, loss in helper.loss.items():
        loss = loss.view_as(w_amax)
        improved = loss < best_loss
        best_loss = torch.where(improved, loss, best_loss)
        best_clip_val = torch.where(improved, w_amax * shrink, best_clip_val)
    helper.best_loss = best_loss
    helper.best_clip_val = best_clip_val
def _clip_search(self, inputs, co_bsz=256, max_tokens=16):
weight = self.weight
self.weight_quantizer.enable()
if self.awq_clip.is_per_tensor_clip(self):
# In NVFP4, only the per-tensor amax is clipped
out_actual = inputs @ self.weight.T
original_amax = self.weight_quantizer.amax.clone()
self.awq_clip.num_tokens += inputs.shape[0]
for shrink in self.awq_clip.loss:
self.weight_quantizer.amax = original_amax * shrink
out = inputs @ self.weight_quantizer(self.weight).T
loss = (out - out_actual).float().pow(2).mean()
self.awq_clip.loss[shrink] += loss
else:
# weight [co, ci] -> [co, 1, n_block, block_size]
# inputs [..., ci] -> [1, max_tokens, n_block, block_size]
inputs = inputs.view(-1, inputs.shape[-1]) # _, ci
# Select max_tokens from the total input tokens of count batch * n_token
inputs = inputs[0 :: max(1, inputs.shape[0] // max_tokens)] # max_tokens, ci
self.awq_clip.num_tokens += inputs.shape[0]
block_size = self.awq_clip.block_size