
Reduce RAM usage of quantizing VLM models and fix some issues of quantizing gemma4 #1791

Open
lvliang-intel wants to merge 9 commits into main from lvl/fix_vlm_large_ram_issue

Conversation

@lvliang-intel
Contributor

Description

Fix VLM/diffusion quantization OOM issues and Gemma4 multi-GPU crash.

Fixed the following issues:
1. VLM: OOM due to redundant hidden_states caching
2. Diffusion: OOM in multi-device quantization
3. MLLM: memory fragmentation in the calibration loop

CUDA_VISIBLE_DEVICES=1,2 auto-round /mnt/disk3/lvl/gemma-4-31B-it --device_map "auto" --data_type "int" --group_size 128 --batch_size 2 --nsamples 512 --seqlen 2048 --iters 2000 --to_quant_block_names 'model.language_model.layers' --output_dir gemma-3-27B-it-INT8-AutoRound --scheme W8A16 --dataset NeelNanda/pile-10k --format "auto_round:auto_gptq"

quantized 6/6 layers in the block, loss iter 0: 0.000566 -> iter 1714: 0.000053
2026-05-09 05:20:41 INFO device.py L1802: 'peak_ram': 70.72GB, 'peak_vram': {'0': 37.81GB, '1': 29.42GB}
Quantizing done: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [6:22:13<00:00, 382.23s/it]
2026-05-09 05:20:47 INFO device.py L1802: 'peak_ram': 70.72GB, 'peak_vram': {'0': 37.81GB, '1': 29.42GB}
2026-05-09 05:20:52 INFO shard_writer.py L314: model has been saved to gemma-3-27B-it-INT8-AutoRound/gemma-4-31B-it-w8g128/
2026-05-09 05:20:52 INFO calib.py L1260: quantization tuning time 22938.128324747086
2026-05-09 05:20:52 INFO calib.py L1279: Summary: quantized 410/602 in the model, unquantized layers: lm_head, model.embed_vision.embedding_projection, model.vision_tower.encoder.layers.[0-26].mlp.down_proj.linear, model.vision_tower.encoder.layers.[0-26].mlp.gate_proj.linear, model.vision_tower.encoder.layers.[0-26].mlp.up_proj.linear, model.vision_tower.encoder.layers.[0-26].self_attn.k_proj.linear, model.vision_tower.encoder.layers.[0-26].self_attn.o_proj.linear, model.vision_tower.encoder.layers.[0-26].self_attn.q_proj.linear, model.vision_tower.encoder.layers.[0-26].self_attn.v_proj.linear, model.vision_tower.patch_embedder.input_proj
2026-05-09 05:20:52 INFO device.py L1802: 'peak_ram': 70.72GB, 'peak_vram': {'0': 37.81GB, '1': 29.42GB}

Type of Change

Bug fix

#1783

Checklist Before Submitting

  • My code has been tested locally.
  • Documentation has been updated as needed.
  • New or updated tests are included where applicable.
  • The CUDA CI has passed. You can trigger it by commenting /azp run Unit-Test-CUDA-AutoRound.

Copilot AI review requested due to automatic review settings May 9, 2026 02:34
Contributor

Copilot AI left a comment


Pull request overview

This PR targets memory-related failures during quantization (VLM/diffusion/MLLM) and addresses a Gemma4 multi-GPU crash by adjusting cache behavior, dispatch logic, and re-applying model-specific patches after accelerate hook removal.

Changes:

  • Reduce RAM growth during block-input caching by skipping storage of non-hidden_states per-sample-constant tensor kwargs and limiting caching to the first block (last_cache_name); see the sketch after this list.
  • Improve diffusion multi-device dispatch by reserving primary-device memory for non-target pipeline components before inferring the device map.
  • Reduce MLLM calibration fragmentation by forcing cleanup between samples and ensuring OOM is re-raised for CPU fallback handling; re-apply special model patches (Gemma4) after hook removal.
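
A minimal sketch of the skip rule from the first bullet, mirroring the condition the new unit tests exercise; the helper name should_skip_kwarg and the example shared_cache_keys contents are illustrative assumptions rather than the exact identifiers in calib.py:

import torch

def should_skip_kwarg(key, value, shared_cache_keys):
    # hidden_states varies per sample and is always cached; other tensor kwargs
    # are per-sample constants and are skipped, unless they are shared keys
    # (e.g. position_ids) that special model patches still need.
    return (
        key != "hidden_states"
        and isinstance(value, torch.Tensor)
        and key not in shared_cache_keys
    )

# attention_mask is skipped; hidden_states and shared position_ids are kept.
assert should_skip_kwarg("attention_mask", torch.ones(1, 8), {"position_ids"})
assert not should_skip_kwarg("hidden_states", torch.randn(1, 8, 16), {"position_ids"})
assert not should_skip_kwarg("position_ids", torch.arange(8), {"position_ids"})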

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.

Summary per file:
  • auto_round/compressors_new/calib.py: Adjusts calibration iteration/caching behavior, adds per-sample-constant tensor skip logic, and limits caching via last_cache_name.
  • auto_round/utils/device.py: Updates diffusion pipeline multi-device dispatch to reserve memory for non-main components before device-map inference.
  • auto_round/compressors_new/diffusion_mixin.py: Uses the updated dispatch logic when a multi-device device_map is provided.
  • auto_round/compressors_new/mllm_mixin.py: Adds gc.collect() + clear_memory() between calibration forwards and re-raises OOMs (sketched after this list).
  • auto_round/special_model_handler.py: Makes Gemma4 patching idempotent and applies it unconditionally for Gemma4.
  • auto_round/compressors/base.py: Re-applies special model patches after remove_hook_from_submodules in multi-GPU paths.
  • test/test_cpu/models/test_vlm_ram_reduction.py: Adds unit tests intended to cover the RAM-reduction and dispatch changes.
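
For the mllm_mixin change, a minimal sketch of the per-sample cleanup and OOM re-raise described above; the loop shape and the way clear_memory is obtained are assumptions based on this summary, not the actual mixin code:

import gc
import torch

def calibrate(model, samples, clear_memory):
    # clear_memory is the project's helper named in the summary above; it is
    # passed in here because its exact import path is not shown on this page.
    for sample in samples:
        try:
            model(**sample)
        except NotImplementedError:
            # The caching hook intentionally aborts the forward pass once the
            # needed block inputs have been captured.
            pass
        except torch.OutOfMemoryError:
            # Re-raise so the caller can trigger the CPU fallback path instead
            # of swallowing the error.
            raise
        finally:
            # Force cleanup between samples to limit memory fragmentation.
            gc.collect()
            clear_memory()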

Comment thread auto_round/compressors_new/calib.py Outdated
Comment on lines +1886 to +1889
# (text encoder, VAE, etc.) to avoid OOM.
main_model = getattr(pipe, main_attr)
dispatched = dispatch_model_by_all_available_devices(main_model, _device_map)
setattr(pipe, main_attr, dispatched)
primary_device = devices[0]

Comment on lines +46 to +61
"""hidden_states must always be cached regardless of shared_cache_keys."""
# Simulate the new logic: skip only if:
# 1. key != "hidden_states"
# 2. isinstance(tensor)
# 3. key not in shared_cache_keys
key = "hidden_states"
is_tensor = True
shared_cache_keys = set()

should_skip = key != "hidden_states" and is_tensor and key not in shared_cache_keys
assert not should_skip, "hidden_states must not be skipped"

def test_per_sample_constant_tensor_is_skipped(self):
"""Per-sample-constant tensors (e.g. attention_mask) must be skipped."""
key = "attention_mask"
is_tensor = True
Comment on lines +127 to +153

        with patch("torch.cuda.is_available", return_value=False):
            with patch("torch.cuda.device_count", return_value=0):
                # Single device path (falls back to pipe.to)
                result = dispatch_model_by_all_available_devices(pipe, "cpu")
                assert result is pipe

    def test_multi_device_respects_non_main_memory(self):
        """With multiple devices, memory reservation must be computed for non-main components."""
        from auto_round.utils.device import dispatch_model_by_all_available_devices

        class FakeComponent(nn.Module):
            def __init__(self, param_count):
                super().__init__()
                self.register_parameter("p", nn.Parameter(torch.empty(param_count, dtype=torch.bfloat16)))

        class FakePipe(nn.Module):
            def __init__(self):
                super().__init__()
                self.transformer = FakeComponent(1000)  # main: 1000 * 2 = 2000 bytes
                self.text_encoder = FakeComponent(100)  # non-main: 100 * 2 = 200 bytes
                self.vae = FakeComponent(50)  # non-main: 50 * 2 = 100 bytes
                # total non-main: 150 params = 300 bytes * 1.2 buffer = 360 bytes

        pipe = FakePipe()
        pipe.components = {
            "transformer": pipe.transformer,
Comment on lines +232 to +240
        def handle_exception(exc):
            if isinstance(exc, NotImplementedError):
                return "caught"
            elif isinstance(exc, torch.OutOfMemoryError):
                raise exc
            else:
                return "logged"

        assert handle_exception(NotImplementedError()) == "caught"
Comment on lines +288 to +313
"""Without last_cache_name, all blocks must be cached before stopping."""

class FakeCompressor:
def __init__(self):
self.last_cache_name = None
self._cache_target_set = {"model.layers.0", "model.layers.1"}
self._cache_seen_targets = set()

def _should_stop_cache_forward(self, name):
if name == self.last_cache_name:
return True
if self.last_cache_name is not None:
return False
if not hasattr(self, "_cache_target_set") or not hasattr(self, "_cache_seen_targets"):
return False
if name in self._cache_target_set:
self._cache_seen_targets.add(name)
if not self._cache_target_set.issubset(self._cache_seen_targets):
return False
self.last_cache_name = name
return True

c = FakeCompressor()
assert c._should_stop_cache_forward("model.layers.0") is False # not all seen yet
assert c._should_stop_cache_forward("model.layers.1") is True # all seen → lock and stop

Comment thread auto_round/compressors/base.py Outdated
Comment thread auto_round/compressors_new/calib.py Outdated
Comment thread auto_round/compressors_new/calib.py Outdated
Comment thread auto_round/utils/device.py
Comment thread auto_round/special_model_handler.py
Comment thread auto_round/compressors_new/calib.py Outdated
Comment thread auto_round/compressors_new/calib.py Outdated
# Skip non-hidden_states tensor args that are per-sample constants.
# Only hidden_states varies per sample and must be cached.
# However, shared_cache_keys tensors (e.g. position_ids) must still
# be stored once — they are needed by special model patches.
Contributor

@wenhuach21 May 9, 2026


a simpler fix is removing this https://github.com/intel/auto-round/blob/main/auto_round/special_model_handler.py#L750, and reverting most changes in this file

Contributor Author


_PRE_DEFINED_FIXED_ATTR is necessary for transformers >= 5.6, otherwise it will raise "RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 3".

Contributor

@wenhuach21 May 9, 2026


OK, when variable_block_input is enabled, it is expected to cache all block inputs for each block except hidden_states, for which only the first block should be cached. Please make sure this behavior has not changed. I don't know why your code works, as it seems you only cache hidden_states.

For transformers>=5.6, I suspect Heng's patch has some issues, as transformers made some changes. I suggest updating the patch and only supporting this model with transformers>=5.6.
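
For reference, one way to express that version gate; a hypothetical sketch, since neither the helper name nor its placement in special_model_handler.py comes from this PR:

from packaging import version
import transformers

def gemma4_patch_supported():
    # Apply the Gemma4 monkey-patch only on transformers releases it targets.
    return version.parse(transformers.__version__) >= version.parse("5.6")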

Contributor Author


hidden_states: only block 0's input is cached (skipped for blocks 1-N to save RAM). All other inputs (position_embeddings, position_ids, attention_mask, etc.) are cached per-block in self.inputs[block_name].
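
To make that layout concrete, a hypothetical illustration of the resulting cache (block names and placeholder values are made up; only the structure follows the comment above):

inputs = {
    "model.language_model.layers.0": {
        "hidden_states": "<per-sample tensors, cached only for block 0>",
        "position_ids": "<tensor>",
        "attention_mask": "<tensor>",
        "position_embeddings": "<tensor>",
    },
    "model.language_model.layers.1": {
        # no hidden_states entry: recomputed during tuning instead of cached
        "position_ids": "<tensor>",
        "attention_mask": "<tensor>",
        "position_embeddings": "<tensor>",
    },
}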

Contributor


Thanks, the RAM usage has been reduced to around 70GB, which I think is expected since the entire model is basically materialized. Since Gemma4 is an important model, would it be possible to refine the monkey-patch approach and disable variable_block_inputs? At least for transformers < 5.6, this should be doable and could save a lot more RAM.

Contributor Author


The monkey-patch has been refined.

Comment thread auto_round/compressors_new/mllm_mixin.py Outdated
@lvliang-intel
Contributor Author

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@lvliang-intel
Contributor Author

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@chensuyue added this to the 0.13.0 milestone May 14, 2026
lvliang-intel and others added 8 commits May 14, 2026 13:48
…g gemma4

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
@lvliang-intel force-pushed the lvl/fix_vlm_large_ram_issue branch from 370ac99 to 3280df5 on May 14, 2026 06:37
