support mtp for gemma4 #1316
…into support_gemma4
Code Review
This pull request introduces comprehensive support for the Gemma-4 model family, including multimodal vision capabilities and Multi-Token Prediction (MTP) assistant models. Key technical additions include heterogeneous attention mechanisms for sliding window and full attention layers, tanh-approximate GELU activations in MoE kernels, and a specialized eagle_frozen_kv MTP mode. The implementation also features a new reasoning parser for Gemma-4's Harmony-like format and updates to various Triton kernels. Feedback on the code changes suggests adopting more idiomatic PyTorch advanced indexing for row selection in the MTP post-layer inference and improving robustness by replacing bare except blocks with except Exception in configuration utilities.
        token_num, num_selected, H
    )
    # Sparse logits: dot product per token vs its selected rows.
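Advanced indexing is more idiomatic and readable here than index_select followed by a view when selecting rows from a weight matrix. Assuming selected_vocab is an integer tensor of shape (token_num, num_selected), lm_head_w[selected_vocab] returns the (token_num, num_selected, H) tensor directly, with no separate flatten or reshape step.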
        token_num, num_selected, H
    )
    # Sparse logits: dot product per token vs its selected rows.
    selected_embeddings = lm_head_w[selected_vocab]
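To make the suggestion concrete, here is a minimal sketch comparing the two row-selection styles. The tensor shapes, the function names, and the einsum used for the per-token dot product are assumptions for illustration; only lm_head_w, selected_vocab, token_num, num_selected, and H appear in the diff.

```python
import torch


def select_rows_index_select(lm_head_w, selected_vocab):
    # Flatten the indices, gather rows, then reshape back to [T, K, H].
    token_num, num_selected = selected_vocab.shape
    H = lm_head_w.shape[1]
    flat = lm_head_w.index_select(0, selected_vocab.reshape(-1))
    return flat.view(token_num, num_selected, H)


def select_rows_advanced_indexing(lm_head_w, selected_vocab):
    # Indexing a [V, H] matrix with a [T, K] integer tensor yields [T, K, H].
    return lm_head_w[selected_vocab]


def sparse_logits(hidden, selected_embeddings):
    # Assumed form of "dot product per token vs its selected rows":
    # hidden is [T, H], selected_embeddings is [T, K, H], result is [T, K].
    return torch.einsum("th,tkh->tk", hidden, selected_embeddings)


torch.manual_seed(0)
vocab_size, H, token_num, num_selected = 32, 8, 4, 5  # toy sizes (assumed)
lm_head_w = torch.randn(vocab_size, H)
hidden = torch.randn(token_num, H)
selected_vocab = torch.randint(0, vocab_size, (token_num, num_selected))

rows_a = select_rows_index_select(lm_head_w, selected_vocab)
rows_b = select_rows_advanced_indexing(lm_head_w, selected_vocab)
logits = sparse_logits(hidden, rows_b)
```

Both helpers return the same tensor, so the choice is about readability rather than results.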
            return [eos_token_id]
        elif isinstance(eos_token_id, list):
            return list(eos_token_id)
    except:
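The review summary recommends replacing bare except blocks with except Exception. The sketch below shows why that matters: a bare except also catches BaseException subclasses such as KeyboardInterrupt and SystemExit, while except Exception lets them propagate. The helper names and the InterruptingConfig class are hypothetical and are not taken from the PR.

```python
class InterruptingConfig(dict):
    """Hypothetical config that simulates Ctrl-C arriving during a lookup."""

    def __getitem__(self, key):
        raise KeyboardInterrupt


def read_eos_token_ids_bare(config):
    # Bare except: swallows every exception, including KeyboardInterrupt.
    try:
        eos_token_id = config["eos_token_id"]
        if isinstance(eos_token_id, int):
            return [eos_token_id]
        elif isinstance(eos_token_id, list):
            return list(eos_token_id)
    except:  # noqa: E722 (kept on purpose to show the problem)
        pass
    return None


def read_eos_token_ids(config):
    # except Exception: handles KeyError, TypeError, and similar errors,
    # but lets KeyboardInterrupt and SystemExit reach the caller.
    try:
        eos_token_id = config["eos_token_id"]
        if isinstance(eos_token_id, int):
            return [eos_token_id]
        elif isinstance(eos_token_id, list):
            return list(eos_token_id)
    except Exception:
        pass
    return None
```

With the bare form, an interrupt raised inside the try block is silently turned into a None return; with except Exception, the interrupt still stops the program as the user intended.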
        if model_type in ["gemma4"]:
            logger.info("Gemma4 uses tanh-approximate-gelu for FFN")
            return True
    except:
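For readers unfamiliar with the "tanh-approximate-gelu" named in the log message, the following scalar sketch contrasts the standard tanh approximation of GELU with the exact erf form. It uses only the Python standard library; the function names are illustrative and it does not reproduce the PR's Triton kernels.

```python
import math


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Largest absolute gap between the two forms on a grid over [-5, 5].
grid = [i / 100.0 for i in range(-500, 501)]
max_gap = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in grid)
```

The two forms agree closely but not exactly, which is why an inference engine needs to know which variant a model was trained with.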
No description provided.