From 0e3e727887d7ed9f088119f7f01549031804960e Mon Sep 17 00:00:00 2001 From: lvliang-intel Date: Thu, 21 May 2026 14:19:28 +0800 Subject: [PATCH 1/3] Fix Gemma4 KeyError sliding_attention issue Signed-off-by: lvliang-intel --- auto_round/special_model_handler.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index 1d94f6354..4f6e9b522 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -114,6 +114,14 @@ def prepare_special_model_block_inputs(block, rotary_input, input_others, positi return input_others, positional_inputs +def _get_gemma4_shared_kv_states_global(block): + """Return the shared KV states dict for Gemma4 block-wise quantization.""" + ref = getattr(block, "_shared_kv_states_global_ref", None) + if ref is not None: + return ref + return {} + + def _get_gemma4_rotary_emb(block, default_rotary_emb=None): rotary_emb_ref = getattr(block, "_rotary_emb_ref", None) if rotary_emb_ref: @@ -152,7 +160,7 @@ def _prepare_gemma4_replay_inputs( head_dim = getattr(attn, "head_dim", None) if attn is not None and hasattr(attn, "store_full_length_kv") and shared_kv_states is None: - shared_kv_states = default_shared_kv_states if default_shared_kv_states is not None else {} + shared_kv_states = default_shared_kv_states if default_shared_kv_states is not None else _get_gemma4_shared_kv_states_global(block) need_position_embeddings = position_embeddings is None if isinstance(position_embeddings, dict): @@ -1154,10 +1162,15 @@ def _attach_gemma4_rotary_emb(model): if text_model is None: return + # Create a single shared dict to propagate KV state between anchor/sharer layers. + # Gemma4TextModel.forward in newer transformers uses the same pattern. + shared_kv_states_global = {} + for layer in text_model.layers: - # Store in a plain list to prevent nn.Module from registering rotary_emb - # as a child submodule (which would cause meta-tensor errors during .to(device)). + # Store in a plain list to prevent nn.Module from registering these + # as child submodules (which would cause meta-tensor errors during .to(device)). object.__setattr__(layer, "_rotary_emb_ref", [text_model.rotary_emb]) + object.__setattr__(layer, "_shared_kv_states_global_ref", shared_kv_states_global) object.__setattr__(layer, "_autoround_special_replay", "gemma4") object.__setattr__(layer, "_gemma4_config_ref", text_model.config) From e9cf5f42e40b57d3e4b5e857940a42794781ddcc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 May 2026 06:48:28 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/special_model_handler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index 4f6e9b522..a92dcb9e8 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -160,7 +160,11 @@ def _prepare_gemma4_replay_inputs( head_dim = getattr(attn, "head_dim", None) if attn is not None and hasattr(attn, "store_full_length_kv") and shared_kv_states is None: - shared_kv_states = default_shared_kv_states if default_shared_kv_states is not None else _get_gemma4_shared_kv_states_global(block) + shared_kv_states = ( + default_shared_kv_states + if default_shared_kv_states is not None + else _get_gemma4_shared_kv_states_global(block) + ) need_position_embeddings = position_embeddings is None if isinstance(position_embeddings, dict): From 8be3062eb2a6dffb0d67e94f805118a0cc9638aa Mon Sep 17 00:00:00 2001 From: Liang Lv Date: Thu, 21 May 2026 15:01:28 +0800 Subject: [PATCH 3/3] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- auto_round/special_model_handler.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index a92dcb9e8..453149f4e 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -160,11 +160,12 @@ def _prepare_gemma4_replay_inputs( head_dim = getattr(attn, "head_dim", None) if attn is not None and hasattr(attn, "store_full_length_kv") and shared_kv_states is None: - shared_kv_states = ( - default_shared_kv_states - if default_shared_kv_states is not None - else _get_gemma4_shared_kv_states_global(block) - ) + if default_shared_kv_states is not None: + shared_kv_states = default_shared_kv_states + else: + shared_kv_states = _get_gemma4_shared_kv_states_global(block) + if getattr(block, "layer_idx", None) == 0: + shared_kv_states.clear() need_position_embeddings = position_embeddings is None if isinstance(position_embeddings, dict):