# Patch for auto_round/special_model_handler.py: attaches one KV-states dict to every layer in
# text_model.layers and uses it in _prepare_gemma4_replay_inputs when no default_shared_kv_states
# is supplied.
# NOTE(review): this patch was stored with its line breaks collapsed; the indentation and the
# whitespace-only context lines below are reconstructed from the hunk line counts and Python
# syntax -- confirm against the upstream commit before applying.
# NOTE(review): the shared dict is cleared only when the block reports layer_idx == 0 and no
# default_shared_kv_states is given; presumably replay always starts at layer 0 -- verify.
diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py
index 1d94f6354..453149f4e 100644
--- a/auto_round/special_model_handler.py
+++ b/auto_round/special_model_handler.py
@@ -114,6 +114,14 @@ def prepare_special_model_block_inputs(block, rotary_input, input_others, positi
     return input_others, positional_inputs
 
 
+def _get_gemma4_shared_kv_states_global(block):
+    """Return the Gemma4 shared KV-states dict attached to ``block``, or a new empty dict if none is attached."""
+    ref = getattr(block, "_shared_kv_states_global_ref", None)
+    if ref is not None:
+        return ref
+    return {}
+
+
 def _get_gemma4_rotary_emb(block, default_rotary_emb=None):
     rotary_emb_ref = getattr(block, "_rotary_emb_ref", None)
     if rotary_emb_ref:
@@ -152,7 +160,12 @@ def _prepare_gemma4_replay_inputs(
     head_dim = getattr(attn, "head_dim", None)
 
     if attn is not None and hasattr(attn, "store_full_length_kv") and shared_kv_states is None:
-        shared_kv_states = default_shared_kv_states if default_shared_kv_states is not None else {}
+        if default_shared_kv_states is not None:
+            shared_kv_states = default_shared_kv_states
+        else:
+            shared_kv_states = _get_gemma4_shared_kv_states_global(block)
+            if getattr(block, "layer_idx", None) == 0:
+                shared_kv_states.clear()
 
     need_position_embeddings = position_embeddings is None
     if isinstance(position_embeddings, dict):
@@ -1154,10 +1167,15 @@ def _attach_gemma4_rotary_emb(model):
     if text_model is None:
         return
 
+    # Create a single shared dict to propagate KV state between anchor/sharer layers.
+    # Gemma4TextModel.forward in newer transformers uses the same pattern.
+    shared_kv_states_global = {}
+
     for layer in text_model.layers:
-        # Store in a plain list to prevent nn.Module from registering rotary_emb
-        # as a child submodule (which would cause meta-tensor errors during .to(device)).
+        # Wrap rotary_emb in a plain list so nn.Module does not register it as a child submodule
+        # (which would cause meta-tensor errors during .to(device)); the shared KV dict is stored as-is.
         object.__setattr__(layer, "_rotary_emb_ref", [text_model.rotary_emb])
+        object.__setattr__(layer, "_shared_kv_states_global_ref", shared_kv_states_global)
         object.__setattr__(layer, "_autoround_special_replay", "gemma4")
         object.__setattr__(layer, "_gemma4_config_ref", text_model.config)
 