-
Notifications
You must be signed in to change notification settings - Fork 635
Add examples for MoE models - Mixtral in TE #2642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,192 @@ | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cells": [ | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "# Mixtral MoE with Transformer Engine\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "## Step 1: Wrap MoE Layers with TE Modules\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "This notebook demonstrates wrapping Mixtral's MoE FFN layers with Transformer Engine's `GroupedLinear` for efficient expert processing.\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "Reference: `src/transformers/models/mixtral/modular_mixtral.py`" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "code", | ||||||||||||||||||||||
| "execution_count": null, | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "outputs": [], | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "import torch\n", | ||||||||||||||||||||||
| "import torch.nn as nn\n", | ||||||||||||||||||||||
| "import torch.nn.functional as F\n", | ||||||||||||||||||||||
| "from typing import Optional, Tuple\n", | ||||||||||||||||||||||
| "import transformer_engine.pytorch as te\n", | ||||||||||||||||||||||
| "from transformer_engine.pytorch import GroupedLinear\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "class TEMixtralSparseMoeBlock(nn.Module):\n", | ||||||||||||||||||||||
| " \"\"\"\n", | ||||||||||||||||||||||
| " Transformer Engine optimized MoE block using GroupedLinear for parallel expert processing.\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " Key improvements:\n", | ||||||||||||||||||||||
| " 1. Use te.GroupedLinear to process all experts in a single batched GEMM\n", | ||||||||||||||||||||||
| " 2. Use te.moe_permute/unpermute for efficient token routing\n", | ||||||||||||||||||||||
| " \"\"\"\n", | ||||||||||||||||||||||
| " def __init__(self, config):\n", | ||||||||||||||||||||||
| " super().__init__()\n", | ||||||||||||||||||||||
| " self.hidden_dim = config.hidden_size\n", | ||||||||||||||||||||||
| " self.ffn_dim = config.intermediate_size\n", | ||||||||||||||||||||||
| " self.num_experts = config.num_local_experts\n", | ||||||||||||||||||||||
| " self.top_k = config.num_experts_per_tok\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # Keep HuggingFace router (not in critical path for performance)\n", | ||||||||||||||||||||||
| " self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # Replace individual expert layers with GroupedLinear\n", | ||||||||||||||||||||||
| " # GroupedLinear processes all experts in parallel with a single GEMM\n", | ||||||||||||||||||||||
| " # For SwiGLU: w1 (gate) and w3 (up) are combined, then w2 (down)\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # w1 and w3 combined (gate_proj + up_proj)\n", | ||||||||||||||||||||||
| " self.experts_gate_up = GroupedLinear(\n", | ||||||||||||||||||||||
| " num_gemms=self.num_experts,\n", | ||||||||||||||||||||||
| " in_features=self.hidden_dim,\n", | ||||||||||||||||||||||
| " out_features=2 * self.ffn_dim, # 2x for gate and up proj combined\n", | ||||||||||||||||||||||
| " bias=False,\n", | ||||||||||||||||||||||
| " device='cuda'\n", | ||||||||||||||||||||||
| " )\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # w2 (down_proj)\n", | ||||||||||||||||||||||
| " self.experts_down = GroupedLinear(\n", | ||||||||||||||||||||||
| " num_gemms=self.num_experts,\n", | ||||||||||||||||||||||
| " in_features=self.ffn_dim,\n", | ||||||||||||||||||||||
| " out_features=self.hidden_dim,\n", | ||||||||||||||||||||||
| " bias=False,\n", | ||||||||||||||||||||||
| " device='cuda'\n", | ||||||||||||||||||||||
| " )\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", | ||||||||||||||||||||||
| " \"\"\"\n", | ||||||||||||||||||||||
| " Args:\n", | ||||||||||||||||||||||
| " hidden_states: [batch_size, sequence_length, hidden_dim]\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " Returns:\n", | ||||||||||||||||||||||
| " final_hidden_states: [batch_size, sequence_length, hidden_dim]\n", | ||||||||||||||||||||||
| " router_logits: [batch_size * sequence_length, num_experts]\n", | ||||||||||||||||||||||
| " \"\"\"\n", | ||||||||||||||||||||||
| " batch_size, sequence_length, hidden_dim = hidden_states.shape\n", | ||||||||||||||||||||||
| " hidden_states_flat = hidden_states.view(-1, hidden_dim) # [num_tokens, hidden_dim]\n", | ||||||||||||||||||||||
| " num_tokens = hidden_states_flat.shape[0]\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # Router: Get expert assignments for each token\n", | ||||||||||||||||||||||
| " router_logits = self.gate(hidden_states_flat)\n", | ||||||||||||||||||||||
| " routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n", | ||||||||||||||||||||||
| " routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n", | ||||||||||||||||||||||
| " routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n", | ||||||||||||||||||||||
| " routing_weights = routing_weights.to(hidden_states.dtype)\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # Permute tokens by expert assignment\n", | ||||||||||||||||||||||
| " # moe_permute groups tokens going to the same expert together\n", | ||||||||||||||||||||||
| " permuted_tokens, row_id_map = te.moe_permute(\n", | ||||||||||||||||||||||
| " hidden_states_flat,\n", | ||||||||||||||||||||||
| " selected_experts.to(torch.int32),\n", | ||||||||||||||||||||||
| " num_out_tokens=None, # Auto-calculate\n", | ||||||||||||||||||||||
| " max_token_num=num_tokens\n", | ||||||||||||||||||||||
| " )\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # Calculate m_splits: number of tokens assigned to each expert\n", | ||||||||||||||||||||||
| " m_splits = []\n", | ||||||||||||||||||||||
| " for expert_idx in range(self.num_experts):\n", | ||||||||||||||||||||||
| " expert_mask = (selected_experts == expert_idx).any(dim=-1)\n", | ||||||||||||||||||||||
| " m_splits.append(expert_mask.sum().item() * self.top_k)\n", | ||||||||||||||||||||||
|
Comment on lines
+98
to
+102
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Logic error in The issue: For example, if token 0 selects experts [1, 3] and token 1 selects experts [1, 2], then for expert 1:
Suggested change
|
||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # Process all experts in parallel using GroupedLinear\n", | ||||||||||||||||||||||
| " # Gate and Up projection (combined)\n", | ||||||||||||||||||||||
| " intermediate = self.experts_gate_up(permuted_tokens, m_splits=m_splits)\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # Apply SwiGLU activation: silu(gate) * up\n", | ||||||||||||||||||||||
| " gate, up = intermediate.chunk(2, dim=-1)\n", | ||||||||||||||||||||||
| " intermediate_act = F.silu(gate) * up\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # Down projection\n", | ||||||||||||||||||||||
| " expert_outputs = self.experts_down(intermediate_act, m_splits=m_splits)\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " # Unpermute tokens back to original order and apply routing weights\n", | ||||||||||||||||||||||
| " final_hidden_states = te.moe_unpermute(\n", | ||||||||||||||||||||||
| " expert_outputs,\n", | ||||||||||||||||||||||
| " row_id_map,\n", | ||||||||||||||||||||||
| " probs=routing_weights\n", | ||||||||||||||||||||||
| " )\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| " final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n", | ||||||||||||||||||||||
| " return final_hidden_states, router_logits" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "### Test the Implementation" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "code", | ||||||||||||||||||||||
| "execution_count": null, | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "outputs": [], | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "# Create a mock config for testing\n", | ||||||||||||||||||||||
| "class MixtralConfig:\n", | ||||||||||||||||||||||
| " hidden_size = 4096\n", | ||||||||||||||||||||||
| " intermediate_size = 14336\n", | ||||||||||||||||||||||
| " num_local_experts = 8\n", | ||||||||||||||||||||||
| " num_experts_per_tok = 2\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "config = MixtralConfig()\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "# Initialize TE-optimized MoE block\n", | ||||||||||||||||||||||
| "te_moe_block = TEMixtralSparseMoeBlock(config).cuda()\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "# Test with sample input\n", | ||||||||||||||||||||||
| "batch_size, seq_len = 2, 16\n", | ||||||||||||||||||||||
| "hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device='cuda', dtype=torch.bfloat16)\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "# Forward pass\n", | ||||||||||||||||||||||
| "with torch.no_grad():\n", | ||||||||||||||||||||||
| " output, router_logits = te_moe_block(hidden_states)\n", | ||||||||||||||||||||||
| " \n", | ||||||||||||||||||||||
| "print(f\"Input shape: {hidden_states.shape}\")\n", | ||||||||||||||||||||||
| "print(f\"Output shape: {output.shape}\")\n", | ||||||||||||||||||||||
| "print(f\"Router logits shape: {router_logits.shape}\")\n", | ||||||||||||||||||||||
| "print(f\"Output dtype: {output.dtype}\")\n", | ||||||||||||||||||||||
| "print(\"✓ TE-optimized MoE block working correctly!\")" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [ | ||||||||||||||||||||||
| "### Next: Weight Mapping and Integration\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "To integrate with HuggingFace Mixtral models, you need to:\n", | ||||||||||||||||||||||
| "\n", | ||||||||||||||||||||||
| "1. Map weights from HF `MixtralSparseMoeBlock` to `TEMixtralSparseMoeBlock`\n", | ||||||||||||||||||||||
| "2. Use monkey-patching to replace HF layers during model loading\n", | ||||||||||||||||||||||
| "3. Implement weight loading from HF checkpoints" | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| "cell_type": "markdown", | ||||||||||||||||||||||
| "metadata": {}, | ||||||||||||||||||||||
| "source": [] | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| ], | ||||||||||||||||||||||
| "metadata": { | ||||||||||||||||||||||
| "language_info": { | ||||||||||||||||||||||
| "name": "python" | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| "nbformat": 4, | ||||||||||||||||||||||
| "nbformat_minor": 2 | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Setting num_out_tokens to None is fine for auto-calculation, but when using top_k > 1, the expected output token count should be num_tokens times top_k since each token is routed to multiple experts.