
support quarot/spinquant rotation before quantization #1797

Open

lkk12014402 wants to merge 25 commits into intel:main from lkk12014402:quarot-quant

Conversation

lkk12014402 (Contributor) commented May 11, 2026

Description

What Problem Does Rotation Solve?

Quantization accuracy degrades when weight/activation distributions have outlier channels —
a few dimensions with magnitudes 10–100× larger than the rest. Rotation applies an orthogonal
transform (Hadamard matrix) to redistribute these outliers uniformly across all channels, making
the distribution more quantization-friendly.

Before rotation:  [ 0.1, 0.2, 50.0, 0.3 ]        ← outlier in channel 3
After rotation:   [ 25.3, 24.8, -25.0, -24.9 ]   ← magnitudes spread evenly (norm preserved)

Since orthogonal transforms preserve mathematical equivalence (Q @ Q^T = I), the model's
FP16 output is unchanged — only quantization behavior improves.
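
A minimal sketch of this equivalence, assuming PyTorch (illustrative, not code from this PR):

```python
import torch

# Toy check: for an orthogonal Q (Q @ Q.T = I), rotating the weight and
# the activations by the same Q leaves a linear layer's output unchanged.
d = 8
w = torch.randn(16, d)                      # linear layer weight
x = torch.randn(4, d)                       # activations
q, _ = torch.linalg.qr(torch.randn(d, d))   # random orthogonal matrix

y_ref = x @ w.T                             # original output
y_rot = (x @ q) @ (w @ q).T                 # = x @ q @ q.T @ w.T = x @ w.T
print(torch.allclose(y_ref, y_rot, atol=1e-5))  # True (up to float error)
```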

QuaRot vs SpinQuant

| Feature | QuaRot | SpinQuant |
| --- | --- | --- |
| Rotation matrix | Fixed (deterministic/random) | Learnable (trained) |
| Training required | No | Yes (10–50 steps) |
| Typical accuracy | Good | Potentially better |

In auto-round, both share the same codebase. The difference is a config flag:

  • trainable_rotation=False → QuaRot
  • trainable_rotation=True → SpinQuant

Quick Start

```python
from transformers import AutoModelForCausalLM
from auto_round import AutoRound

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")  # any HF causal LM

# "quarot" applies R1+R2 with fixed Hadamard, then quantizes
autoround = AutoRound(model, rotation_config="quarot", scheme="W4A16", iters=0)
autoround.quantize()
```

Default: "quarot" and "spinquant" enable R1+R2 only. To enable R3/R4,
use SpinQuantConfig(r3=True, r4=True) explicitly.
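
A sketch of the explicit config path; the SpinQuantConfig import location follows this PR's file layout, and the keyword names are assumptions based on the note above:

```python
from auto_round import AutoRound
# Import path per this PR's package layout; it may change before merge.
from auto_round.algorithms.transforms.spinquant import SpinQuantConfig

# Enable online R3/R4 on top of the default fused R1+R2 rotations.
rotation_config = SpinQuantConfig(r3=True, r4=True)
autoround = AutoRound(model, rotation_config=rotation_config, scheme="W4A16", iters=0)
autoround.quantize()
```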

Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
Copilot AI review requested due to automatic review settings May 11, 2026 03:45
lkk12014402 and others added 2 commits May 11, 2026 03:47
Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
Copilot AI left a comment


Pull request overview

This PR introduces SpinQuant/QuaRot rotation support as a first-class “transform” in AutoRound, enabling orthogonal rotations (R1–R4) to be applied before quantization via unified config normalization and BaseRotation registry dispatch.

Changes:

  • Adds a new spinquant transform package (config, preprocessor, online hook/monkeypatch utilities, and optional training helpers/trainer).
  • Extends rotation config normalization and AutoRound’s new-arch entrypoints to accept "quarot" / "spinquant" shorthands and SpinQuant configs/dicts.
  • Adds CUDA tests covering config normalization, registry integration, hook lifecycle, and end-to-end pipeline integration.

Reviewed changes

Copilot reviewed 16 out of 17 changed files in this pull request and generated 7 comments.

Summary per file:

| File | Description |
| --- | --- |
| test/test_cuda/transform/test_spinquant.py | Adds CUDA tests for SpinQuant/QuaRot config normalization, hook behavior, and pipeline integration. |
| auto_round/compressors_new/spinquant_mixin.py | Introduces a deprecated compatibility mixin that forwards to the unified rotation pipeline. |
| auto_round/compressors_new/entry.py | Extends config resolution / rotation_config handling to accept SpinQuant configs and "quarot"/"spinquant" shorthands. |
| auto_round/algorithms/transforms/spinquant/training.py | Adds experimental SpinQuant training helpers (hooks, callbacks, optimizer/loss utilities, state). |
| auto_round/algorithms/transforms/spinquant/training_core.py | Adds shared primitives for loss computation, reference-model cloning, optimizer creation, and a common training loop. |
| auto_round/algorithms/transforms/spinquant/trainer.py | Adds an experimental "Trainer-like" interface for SpinQuant training + fusion + checkpointing. |
| auto_round/algorithms/transforms/spinquant/rotation_utils.py | Adds SpinQuant rotation math utilities (Hadamard construction, fusion helpers, wrappers). |
| auto_round/algorithms/transforms/spinquant/preprocessor.py | Implements the main SpinQuant/QuaRot preprocessing pipeline (init/train/fuse/hooks/cleanup). |
| auto_round/algorithms/transforms/spinquant/monkeypatch.py | Adds a monkeypatch mechanism to wrap RoPE application for R3 (Q/K rotation after RoPE). |
| auto_round/algorithms/transforms/spinquant/inplace/apply.py | Adds in-place hook registration/removal and a convenience "apply in place" entrypoint. |
| auto_round/algorithms/transforms/spinquant/inplace/__init__.py | Exposes the in-place SpinQuant APIs. |
| auto_round/algorithms/transforms/spinquant/cayley_optimizer.py | Adds/ports the Cayley/SGDG optimizer and a combined Adam+SGDG optimizer. |
| auto_round/algorithms/transforms/spinquant/algorithm.py | Registers the SpinQuant rotation as a BaseRotation algorithm ("spinquant"). |
| auto_round/algorithms/transforms/spinquant/__init__.py | Exposes the SpinQuant public API surface and documents feature status/limitations. |
| auto_round/algorithms/transforms/base.py | Ensures the BaseRotation registry imports rotation and spinquant. |
| auto_round/algorithms/transforms/__init__.py | Extends rotation config normalization to dispatch spinquant and string shorthands. |

lkk12014402 and others added 2 commits May 11, 2026 06:25
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
@chensuyue chensuyue added this to the 0.13.0 milestone May 11, 2026
lkk12014402 and others added 5 commits May 11, 2026 08:09
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
lkk12014402 (Contributor, Author) commented:

Qwen/Qwen3-0.6B

RTN: Average Accuracy (across 4 tasks: hellaswag, piqa, winogrande, lambada_openai)

| Rotation | FP16 | MXFP4 | NVFP4 |
| --- | --- | --- | --- |
| none | 0.5284 | 0.4442 | 0.4750 |
| R1 (rs=16) | | 0.4385 | 0.4671 |
| R1+R2 (rs=16) | | 0.4315 | 0.4506 |
| R1+R2+R3 (rs=16) | | 0.4410 | 0.4541 |
| R1+R2+R3+R4 (rs=16) | | 0.4311 | 0.4518 |
| R1 (rs=32) | | 0.4263 | 0.4434 |
| R1+R2 (rs=32) | | 0.4326 | 0.4481 |
| R1+R2+R3 (rs=32) | | 0.4307 | 0.4513 |
| R1+R2+R3+R4 (rs=32) | | 0.4471 | 0.4584 |
| R1 (rs=64) | | 0.4249 | 0.4656 |
| R1+R2 (rs=64) | | 0.4331 | 0.4610 |
| R1+R2+R3 (rs=64) | | 0.4330 | 0.4633 |
| R1+R2+R3+R4 (rs=64) | | 0.4418 | 0.4596 |
| R1 (rs=128) | | 0.4378 | 0.4588 |
| R1+R2 (rs=128) | | 0.4434 | 0.4706 |
| R1+R2+R3 (rs=128) | | 0.4415 | 0.4688 |
| R1+R2+R3+R4 (rs=128) | | 0.4528 | 0.4621 |
| R1 (rs=auto) | | 0.4320 | 0.4481 |
| R1+R2 (rs=auto) | | 0.4330 | 0.4489 |
| R1+R2+R3 (rs=auto) | | 0.4249 | 0.4508 |
| R1+R2+R3+R4 (rs=auto) | | 0.4237 | 0.4444 |

tuning (iters=200): Average Accuracy (across 4 tasks: hellaswag, piqa, winogrande, lambada_openai)

| Rotation | FP16 | W4A16 | MXFP4 | NVFP4 |
| --- | --- | --- | --- | --- |
| none | 0.5284 | 0.5078 | 0.4496 | 0.4865 |
| R1 (rs=16) | | 0.5051 | 0.4612 | 0.4751 |
| R1+R2 (rs=16) | | 0.5151 | 0.4677 | 0.4834 |
| R1+R2+R3 (rs=16) | | 0.5092 | 0.4616 | 0.4781 |
| R1+R2+R3+R4 (rs=16) | | 0.5138 | 0.4680 | 0.4747 |
| R1 (rs=32) | | 0.5125 | 0.4688 | 0.4743 |
| R1+R2 (rs=32) | | 0.5058 | 0.4710 | 0.4788 |
| R1+R2+R3 (rs=32) | | 0.5085 | 0.4717 | 0.4755 |
| R1+R2+R3+R4 (rs=32) | | 0.5104 | 0.4624 | 0.4833 |
| R1 (rs=64) | | 0.5098 | 0.4695 | 0.4770 |
| R1+R2 (rs=64) | | 0.5093 | 0.4632 | 0.4855 |
| R1+R2+R3 (rs=64) | | 0.5142 | 0.4614 | 0.4849 |
| R1+R2+R3+R4 (rs=64) | | 0.5091 | 0.4630 | 0.4747 |
| R1 (rs=128) | | 0.5118 | 0.4645 | 0.4807 |
| R1+R2 (rs=128) | | 0.5130 | 0.4652 | 0.4800 |
| R1+R2+R3 (rs=128) | | 0.5078 | 0.4627 | 0.4819 |
| R1+R2+R3+R4 (rs=128) | | 0.5133 | 0.4596 | 0.4769 |
| R1 (rs=auto) | | 0.5105 | 0.4584 | 0.4733 |
| R1+R2 (rs=auto) | | 0.5077 | 0.4556 | 0.4743 |
| R1+R2+R3 (rs=auto) | | 0.5118 | 0.4553 | 0.4761 |
| R1+R2+R3+R4 (rs=auto) | | 0.5140 | 0.4556 | 0.4643 |

lkk12014402 and others added 10 commits May 12, 2026 01:13
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
wenhuach21 (Contributor) commented:

Please sync with Heng and try to port it to the new architecture, which uses block-wise quantization. Otherwise, it will be difficult to support multiple algorithms simultaneously since it consumes a large amount of RAM.

lkk12014402 (Contributor, Author) commented:

> Please sync with Heng and try to port it to the new architecture, which uses block-wise quantization. Otherwise, it will be difficult to support multiple algorithms simultaneously since it consumes a large amount of RAM.

The current implementation supports block-wise quantization.

wenhuach21 (Contributor) commented:

> > Please sync with Heng and try to port it to the new architecture, which uses block-wise quantization. Otherwise, it will be difficult to support multiple algorithms simultaneously since it consumes a large amount of RAM.
>
> the current implementation supports block-wise quantization

Nice, then it's better to align with the code, e.g., the transformation should inherit this class: https://github.com/intel/auto-round/blob/main/auto_round/algorithms/quantization/base.py#L255

lkk12014402 (Contributor, Author) commented:

> > > Please sync with Heng and try to port it to the new architecture, which uses block-wise quantization. Otherwise, it will be difficult to support multiple algorithms simultaneously since it consumes a large amount of RAM.
> >
> > the current implementation supports block-wise quantization
>
> Nice, then it's better to align with the code, e.g., the transformation should inherit this class https://github.com/intel/auto-round/blob/main/auto_round/algorithms/quantization/base.py#L255

That code is for quantization, though? This pull request is for rotation, which should align with this class instead: https://github.com/intel/auto-round/blob/main/auto_round/algorithms/transforms/base.py#L51. Right?

wenhuach21 (Contributor) commented May 12, 2026

All algorithms should be compatible with base_compressor or the base quantizer so that different algorithms can be scheduled easily. If it's hard to align with them, it should at least support running apply_rotation + apply_awq + apply_autoround for each block, to save RAM and stay aware of the other algorithms, instead of running apply_rotation over the whole model, then apply_awq, and so on.
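
A hypothetical sketch of the block-wise scheduling pattern described above; apply_rotation, apply_awq, and apply_autoround are illustrative names, not the actual auto-round API:

```python
# Hypothetical block-wise scheduling (illustrative names, not real API):
# run every algorithm on one decoder block before moving to the next,
# so only that block's tensors need to stay resident in RAM.
for block in model.model.layers:
    apply_rotation(block)    # e.g., fuse R1/R2 into this block's weights
    apply_awq(block)         # activation-aware weight scaling
    apply_autoround(block)   # sign-gradient rounding optimization
    # the block can now be quantized/offloaded before the next one
```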

lkk12014402 and others added 4 commits May 12, 2026 08:56
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
Signed-off-by: lkk12014402 <kakao.lv@intel.com>
An inline review comment from a Contributor, left on this diff context:

```python
lr: float = 1e-4,
stiefel: bool = True,
momentum: float = 0.0,
weight_decay: float = 0.0,
```

1. For the accuracy-related part, at minimum the Hadamard rotation case should be included.

2. The accuracy impact should also be properly documented. I left a similar comment on your Hadamard PR, but I still haven't seen the corresponding documentation yet.

3. SpinQuant should be updated to follow the new architecture and switch to the block-wise implementation approach.

lkk12014402 (Contributor, Author) commented May 14, 2026

Answer to comment 1:

The SpinQuant algorithm is still in the experimental stage in this release and has not yet been fully validated for accuracy. There is already a note in the code indicating that enabling the training of rotation matrices is part of this experimental phase.

As we know, apart from the training component, the QuaRot algorithm shares the same rotation structure as SpinQuant, including R1/R2/R3/R4. In this PR, we mainly support QuaRot combined with quantization (AutoRound, RTN, and tuning). The implementation of QuaRot's R1/R2/R3/R4 is largely aligned with community implementations such as Quark and LLMC.

Answer to comment 2:

I didn't notice the "accuracy impact" comment you mentioned; could you explain it a bit more?

Answer to comment 3:

In fact, I have a block-wise rotation implementation that can be combined with block-wise quantization in AutoRound. However, I need to wait until @n1ck-guo finishes the API, after which I will submit another PR to add block-wise rotation support.

wenhuach21 (Contributor) commented May 14, 2026

1. As far as I know, QuaRot + AutoRound has already been supported by your previous PR. I also noticed that you are continuing to fix bugs, and since the implementation is largely self-contained, I don't think we need to rush merging this PR. It would be better to further refine it to product-level quality and provide more accuracy results demonstrating either better accuracy than Hadamard rotation or accuracy comparable to other repositories.

2. We should also document the benchmark data so users can clearly understand the accuracy improvements, computational cost, and potential side effects.

3. Feel free to handle this in a separate PR.

Signed-off-by: lkk12014402 <kakao.lv@intel.com>
lkk12014402 (Contributor, Author) commented:

@copilot resolve the merge conflicts in this pull request
