Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions auto_round/special_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ def prepare_special_model_block_inputs(block, rotary_input, input_others, positi
return input_others, positional_inputs


def _get_gemma4_shared_kv_states_global(block):
"""Return the shared KV states dict for Gemma4 block-wise quantization."""
ref = getattr(block, "_shared_kv_states_global_ref", None)
if ref is not None:
return ref
return {}


def _get_gemma4_rotary_emb(block, default_rotary_emb=None):
rotary_emb_ref = getattr(block, "_rotary_emb_ref", None)
if rotary_emb_ref:
Expand Down Expand Up @@ -152,7 +160,12 @@ def _prepare_gemma4_replay_inputs(
head_dim = getattr(attn, "head_dim", None)

if attn is not None and hasattr(attn, "store_full_length_kv") and shared_kv_states is None:
shared_kv_states = default_shared_kv_states if default_shared_kv_states is not None else {}
if default_shared_kv_states is not None:
shared_kv_states = default_shared_kv_states
else:
shared_kv_states = _get_gemma4_shared_kv_states_global(block)
if getattr(block, "layer_idx", None) == 0:
shared_kv_states.clear()

need_position_embeddings = position_embeddings is None
if isinstance(position_embeddings, dict):
Expand Down Expand Up @@ -1154,10 +1167,15 @@ def _attach_gemma4_rotary_emb(model):
if text_model is None:
return

# Create a single shared dict to propagate KV state between anchor/sharer layers.
# Gemma4TextModel.forward in newer transformers uses the same pattern.
shared_kv_states_global = {}

for layer in text_model.layers:
# Store in a plain list to prevent nn.Module from registering rotary_emb
# as a child submodule (which would cause meta-tensor errors during .to(device)).
# Store in a plain list to prevent nn.Module from registering these
# as child submodules (which would cause meta-tensor errors during .to(device)).
object.__setattr__(layer, "_rotary_emb_ref", [text_model.rotary_emb])
object.__setattr__(layer, "_shared_kv_states_global_ref", shared_kv_states_global)
object.__setattr__(layer, "_autoround_special_replay", "gemma4")
object.__setattr__(layer, "_gemma4_config_ref", text_model.config)

Expand Down