Reduce RAM usage of quantizing VLM models and fix some issues of quantizing gemma4 #1791
lvliang-intel wants to merge 9 commits into
Conversation
Pull request overview
This PR targets memory-related failures during quantization (VLM/diffusion/MLLM) and addresses a Gemma4 multi-GPU crash by adjusting cache behavior, dispatch logic, and re-applying model-specific patches after accelerate hook removal.
Changes:
- Reduce RAM growth during block-input caching by skipping storage of non-hidden_states, per-sample-constant tensor kwargs and limiting caching to the first block (last_cache_name).
- Improve diffusion multi-device dispatch by reserving primary-device memory for non-target pipeline components before inferring the device map.
- Reduce MLLM calibration fragmentation by forcing cleanup between samples and ensuring OOM is re-raised for CPU fallback handling (sketched below this list); re-apply special model patches (Gemma4) after hook removal.
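For the cleanup/re-raise pattern in the MLLM item above, a minimal sketch follows. It assumes `clear_memory` is the repo's helper wrapping `gc.collect()` and `torch.cuda.empty_cache()`; the loop body and names are illustrative, not the PR's exact code:

```python
import gc

import torch


def calib_forward(model, samples):
    """Illustrative calibration loop: clean up between samples and
    re-raise OOM so an outer handler can fall back to CPU."""
    for sample in samples:
        try:
            with torch.no_grad():
                model(**sample)
        except torch.cuda.OutOfMemoryError:
            # Re-raise so the caller's CPU-fallback path can handle it.
            raise
        except NotImplementedError:
            # Some multimodal forwards are unsupported on this path; skip.
            continue
        finally:
            # Force cleanup between samples to limit fragmentation.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
```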
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| auto_round/compressors_new/calib.py | Adjusts calibration iteration/caching behavior, adds per-sample-constant tensor skip logic, and limits caching via last_cache_name. |
| auto_round/utils/device.py | Updates diffusion pipeline multi-device dispatch to reserve memory for non-main components before device-map inference. |
| auto_round/compressors_new/diffusion_mixin.py | Uses the updated dispatch logic when a multi-device device_map is provided. |
| auto_round/compressors_new/mllm_mixin.py | Adds gc.collect() + clear_memory() between calibration forwards and re-raises OOMs. |
| auto_round/special_model_handler.py | Makes Gemma4 patching idempotent and applies it unconditionally for Gemma4 (idempotency sketched after this table). |
| auto_round/compressors/base.py | Re-applies special model patches after remove_hook_from_submodules in multi-GPU paths. |
| test/test_cpu/models/test_vlm_ram_reduction.py | Adds unit tests intended to cover the RAM-reduction and dispatch changes. |
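The idempotency in special_model_handler.py matters because the patch is re-applied after accelerate's remove_hook_from_submodules. A minimal sketch of the idea; the function name, flag attribute, and patched body are hypothetical, not the repo's actual code:

```python
from types import MethodType

import torch.nn as nn


def apply_gemma4_patch(module: nn.Module) -> None:
    """Sketch of an idempotent patch: re-applying after accelerate's
    remove_hook_from_submodules is a no-op if already patched."""
    if getattr(module, "_autoround_patched", False):
        return  # idempotent: second application does nothing

    orig_forward = module.forward

    def patched_forward(self, *args, **kwargs):
        # ... model-specific fix would go here (hypothetical) ...
        return orig_forward(*args, **kwargs)

    module._orig_forward = orig_forward
    module.forward = MethodType(patched_forward, module)
    module._autoround_patched = True
```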
```python
# (text encoder, VAE, etc.) to avoid OOM.
main_model = getattr(pipe, main_attr)
dispatched = dispatch_model_by_all_available_devices(main_model, _device_map)
setattr(pipe, main_attr, dispatched)
primary_device = devices[0]
```
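For context on the reservation step, the same idea can be sketched with accelerate's public API: infer_auto_device_map accepts a max_memory budget, so subtracting the non-main components' footprint from the primary device's budget keeps room for them. The helper below is an assumption for illustration, not dispatch_model_by_all_available_devices itself; the 1.2 safety buffer mirrors the one used in the tests further down:

```python
import torch
from accelerate import dispatch_model, infer_auto_device_map


def dispatch_with_reservation(pipe, main_attr="transformer", buffer=1.2):
    """Sketch: reserve primary-device memory for non-main pipeline
    components (text encoder, VAE, ...) before inferring a device map."""
    main_model = getattr(pipe, main_attr)

    # Bytes needed by everything except the main component
    # (diffusers pipelines expose their modules via .components).
    reserved = buffer * sum(
        p.numel() * p.element_size()
        for name, comp in pipe.components.items()
        if name != main_attr and isinstance(comp, torch.nn.Module)
        for p in comp.parameters()
    )

    free, _total = torch.cuda.mem_get_info(0)
    max_memory = {0: int(max(0, free - reserved))}
    for i in range(1, torch.cuda.device_count()):
        max_memory[i] = torch.cuda.mem_get_info(i)[0]

    device_map = infer_auto_device_map(main_model, max_memory=max_memory)
    setattr(pipe, main_attr, dispatch_model(main_model, device_map=device_map))
    return pipe
```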
| """hidden_states must always be cached regardless of shared_cache_keys.""" | ||
| # Simulate the new logic: skip only if: | ||
| # 1. key != "hidden_states" | ||
| # 2. isinstance(tensor) | ||
| # 3. key not in shared_cache_keys | ||
| key = "hidden_states" | ||
| is_tensor = True | ||
| shared_cache_keys = set() | ||
|
|
||
| should_skip = key != "hidden_states" and is_tensor and key not in shared_cache_keys | ||
| assert not should_skip, "hidden_states must not be skipped" | ||
|
|
||
| def test_per_sample_constant_tensor_is_skipped(self): | ||
| """Per-sample-constant tensors (e.g. attention_mask) must be skipped.""" | ||
| key = "attention_mask" | ||
| is_tensor = True |
|
|
```python
with patch("torch.cuda.is_available", return_value=False):
    with patch("torch.cuda.device_count", return_value=0):
        # Single device path (falls back to pipe.to)
        result = dispatch_model_by_all_available_devices(pipe, "cpu")
        assert result is pipe

def test_multi_device_respects_non_main_memory(self):
    """With multiple devices, memory reservation must be computed for non-main components."""
    from auto_round.utils.device import dispatch_model_by_all_available_devices

    class FakeComponent(nn.Module):
        def __init__(self, param_count):
            super().__init__()
            self.register_parameter("p", nn.Parameter(torch.empty(param_count, dtype=torch.bfloat16)))

    class FakePipe(nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer = FakeComponent(1000)  # main: 1000 * 2 = 2000 bytes
            self.text_encoder = FakeComponent(100)  # non-main: 100 * 2 = 200 bytes
            self.vae = FakeComponent(50)  # non-main: 50 * 2 = 100 bytes
            # total non-main: 150 params = 300 bytes; * 1.2 buffer = 360 bytes

    pipe = FakePipe()
    pipe.components = {
        "transformer": pipe.transformer,
```
```python
def handle_exception(exc):
    if isinstance(exc, NotImplementedError):
        return "caught"
    elif isinstance(exc, torch.OutOfMemoryError):
        raise exc
    else:
        return "logged"

assert handle_exception(NotImplementedError()) == "caught"
```
| """Without last_cache_name, all blocks must be cached before stopping.""" | ||
|
|
||
| class FakeCompressor: | ||
| def __init__(self): | ||
| self.last_cache_name = None | ||
| self._cache_target_set = {"model.layers.0", "model.layers.1"} | ||
| self._cache_seen_targets = set() | ||
|
|
||
| def _should_stop_cache_forward(self, name): | ||
| if name == self.last_cache_name: | ||
| return True | ||
| if self.last_cache_name is not None: | ||
| return False | ||
| if not hasattr(self, "_cache_target_set") or not hasattr(self, "_cache_seen_targets"): | ||
| return False | ||
| if name in self._cache_target_set: | ||
| self._cache_seen_targets.add(name) | ||
| if not self._cache_target_set.issubset(self._cache_seen_targets): | ||
| return False | ||
| self.last_cache_name = name | ||
| return True | ||
|
|
||
| c = FakeCompressor() | ||
| assert c._should_stop_cache_forward("model.layers.0") is False # not all seen yet | ||
| assert c._should_stop_cache_forward("model.layers.1") is True # all seen → lock and stop | ||
|
|
```python
# Skip non-hidden_states tensor args that are per-sample constants.
# Only hidden_states varies per sample and must be cached.
# However, shared_cache_keys tensors (e.g. position_ids) must still
# be stored once; they are needed by special model patches.
```
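A minimal sketch of how this filter could sit in the block-input caching hook; the helper name and the cached_inputs structure are assumptions, only the skip condition mirrors the comment above:

```python
import torch


def cache_block_kwargs(cached_inputs, kwargs, shared_cache_keys):
    """Sketch: store only kwargs that actually vary per sample."""
    for key, value in kwargs.items():
        if (
            key != "hidden_states"
            and isinstance(value, torch.Tensor)
            and key not in shared_cache_keys
        ):
            # Per-sample-constant tensor (e.g. attention_mask): skip.
            continue
        cached_inputs.setdefault(key, []).append(value)
    return cached_inputs
```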
A simpler fix is removing this https://github.com/intel/auto-round/blob/main/auto_round/special_model_handler.py#L750 and reverting most of the changes in this file.
_PRE_DEFINED_FIXED_ATTR is necessary for transformers >= 5.6, otherwise it will raise "RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 3".
OK, when variable_block_input is enabled, it is expected to cache all block inputs for each block except hidden_ids, for which only the first block should be cached. Please make sure this behavior has not changed. I don't know why your code works, as it seems you only cache hidden_ids.
For transformers>=5.6, I suspect Heng's patch has some issues since transformers made some changes. I suggest updating the patch and only supporting this model with transformers>=5.6.
hidden_states: only block 0's is cached (skipped for blocks 1-N to save RAM). All other inputs (position_embeddings, position_ids, attention_mask, etc.) are cached per block in self.inputs[block_name].
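In other words, the cache ends up shaped roughly like this (an illustrative layout; the block names and the exact self.inputs structure are assumptions based on this thread):

```python
# Illustrative shape of self.inputs after calibration (names assumed).
inputs = {
    "model.layers.0": {
        "hidden_states": ["<tensor per sample>"],  # cached only for block 0
        "position_ids": ["<tensor per sample>"],
        "attention_mask": ["<tensor per sample>"],
    },
    "model.layers.1": {
        # hidden_states deliberately absent: re-derived for blocks 1..N
        "position_ids": ["<tensor per sample>"],
        "attention_mask": ["<tensor per sample>"],
    },
}
```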
Thanks, the RAM usage has been reduced to around 70GB, which I think is expected since the entire model is basically materialized. Since Gemma4 is an important model, would it be possible to refine the monkey-patch approach and disable variable_block_inputs? At least for transformers < 5.6, this should be doable and could save a lot more RAM.
The monkey-patch has been refined.
/azp run Unit-Test-CUDA-AutoRound

Azure Pipelines successfully started running 1 pipeline(s).

/azp run Unit-Test-CUDA-AutoRound

Azure Pipelines successfully started running 1 pipeline(s).
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
force-pushed from 370ac99 to 3280df5
Description
Fix VLM/diffusion quantization OOM issues and Gemma4 multi-GPU crash.
Fixed the following issues:
1. VLM: OOM due to redundant hidden_states caching
2. Diffusion: OOM in multi-device quantization
3. MLLM: memory fragmentation in the calibration loop
```bash
CUDA_VISIBLE_DEVICES=1,2 auto-round /mnt/disk3/lvl/gemma-4-31B-it --device_map "auto" --data_type "int" --group_size 128 --batch_size 2 --nsamples 512 --seqlen 2048 --iters 2000 --to_quant_block_names 'model.language_model.layers' --output_dir gemma-3-27B-it-INT8-AutoRound --scheme W8A16 --dataset NeelNanda/pile-10k --format "auto_round:auto_gptq"
```

```
quantized 6/6 layers in the block, loss iter 0: 0.000566 -> iter 1714: 0.000053
2026-05-09 05:20:41 INFO device.py L1802: 'peak_ram': 70.72GB, 'peak_vram': {'0': 37.81GB, '1': 29.42GB}
Quantizing done: 100%|███████████████████████████████| 60/60 [6:22:13<00:00, 382.23s/it]
2026-05-09 05:20:47 INFO device.py L1802: 'peak_ram': 70.72GB, 'peak_vram': {'0': 37.81GB, '1': 29.42GB}
2026-05-09 05:20:52 INFO shard_writer.py L314: model has been saved to gemma-3-27B-it-INT8-AutoRound/gemma-4-31B-it-w8g128/
2026-05-09 05:20:52 INFO calib.py L1260: quantization tuning time 22938.128324747086
2026-05-09 05:20:52 INFO calib.py L1279: Summary: quantized 410/602 in the model, unquantized layers: lm_head, model.embed_vision.embedding_projection, model.vision_tower.encoder.layers.[0-26].mlp.down_proj.linear, model.vision_tower.encoder.layers.[0-26].mlp.gate_proj.linear, model.vision_tower.encoder.layers.[0-26].mlp.up_proj.linear, model.vision_tower.encoder.layers.[0-26].self_attn.k_proj.linear, model.vision_tower.encoder.layers.[0-26].self_attn.o_proj.linear, model.vision_tower.encoder.layers.[0-26].self_attn.q_proj.linear, model.vision_tower.encoder.layers.[0-26].self_attn.v_proj.linear, model.vision_tower.patch_embedder.input_proj
2026-05-09 05:20:52 INFO device.py L1802: 'peak_ram': 70.72GB, 'peak_vram': {'0': 37.81GB, '1': 29.42GB}
```
Type of Change
Bug fix (#1783)

/azp run Unit-Test-CUDA-AutoRound