Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 63 additions & 9 deletions llmc/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from .base_model import BaseModel


def _has_legacy_moe(block):
return hasattr(block, 'block_sparse_moe')


@MODEL_REGISTRY
class Mixtral(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
Expand Down Expand Up @@ -36,11 +40,27 @@ def get_layernorms_in_block(self, block):
}

def get_extra_modules(self, block):
    """Return the MoE module of *block*, keyed by its attribute name.

    The key mirrors the attribute path so downstream code can address the
    module by its fully-qualified name:
    - transformers <5.0: ``{'block_sparse_moe': block.block_sparse_moe}``
    - transformers >=5.0 (fused experts): ``{'mlp': block.mlp}``
    """
    if _has_legacy_moe(block):
        return {
            'block_sparse_moe': block.block_sparse_moe
        }
    # Fused layout: the whole MoE (gate + fused experts) lives under `mlp`.
    return {
        'mlp': block.mlp
    }

def get_moe_gate(self, block):
    """Return the MoE router (gate) linear layer of *block*.

    Resolves the MoE container first — ``block_sparse_moe`` on the legacy
    layout, ``mlp`` on the fused one — then reads its ``gate`` attribute.
    """
    if _has_legacy_moe(block):
        moe_module = block.block_sparse_moe
    else:
        moe_module = block.mlp
    return moe_module.gate

def get_subsets_in_block(self, block):
    """Build the quantization subsets for *block*, picking the builder
    that matches the block's MoE layout (legacy vs. fused)."""
    builder = (
        self._get_subsets_legacy
        if _has_legacy_moe(block)
        else self._get_subsets_fused
    )
    return builder(block)

def _get_subsets_legacy(self, block):
"""transformers <5.0: block.block_sparse_moe with ModuleList experts."""
moe = block.block_sparse_moe
return [
{
'layers': {
Expand All @@ -62,25 +82,59 @@ def get_subsets_in_block(self, block):
},
{
'layers': {
**{f'block_sparse_moe.experts.{i}.w1': block.block_sparse_moe.experts[i].w1 for i in range(len(block.block_sparse_moe.experts))}, # noqa
**{f'block_sparse_moe.experts.{i}.w3': block.block_sparse_moe.experts[i].w3 for i in range(len(block.block_sparse_moe.experts))}, # noqa
'block_sparse_moe.gate': block.block_sparse_moe.gate,
**{f'block_sparse_moe.experts.{i}.w1': moe.experts[i].w1 for i in range(len(moe.experts))}, # noqa
**{f'block_sparse_moe.experts.{i}.w3': moe.experts[i].w3 for i in range(len(moe.experts))}, # noqa
'block_sparse_moe.gate': moe.gate,
},
'prev_op': [block.post_attention_layernorm],
'input': ['block_sparse_moe'],
'inspect': block.block_sparse_moe,
'inspect': moe,
'has_kwargs': False,
'is_mlp': True,
},
*[
{
'layers': {f'block_sparse_moe.experts.{i}.w2': block.block_sparse_moe.experts[i].w2}, # noqa
'prev_op': [block.block_sparse_moe.experts[i].w3],
'layers': {f'block_sparse_moe.experts.{i}.w2': moe.experts[i].w2},
'prev_op': [moe.experts[i].w3],
'input': [f'block_sparse_moe.experts.{i}.w2'],
'inspect': block.block_sparse_moe.experts[i].w2,
'inspect': moe.experts[i].w2,
'has_kwargs': False,
'is_mlp': True,
}
for i in range(len(block.block_sparse_moe.experts))
for i in range(len(moe.experts))
],
]

def _get_subsets_fused(self, block):
"""transformers >=5.0: block.mlp with fused MixtralExperts."""
moe = block.mlp
return [
{
'layers': {
'self_attn.q_proj': block.self_attn.q_proj,
'self_attn.k_proj': block.self_attn.k_proj,
'self_attn.v_proj': block.self_attn.v_proj,
},
'prev_op': [block.input_layernorm],
'input': ['self_attn.q_proj'],
'inspect': block.self_attn,
'has_kwargs': True,
},
{
'layers': {'self_attn.o_proj': block.self_attn.o_proj},
'prev_op': [block.self_attn.v_proj],
'input': ['self_attn.o_proj'],
'inspect': block.self_attn.o_proj,
'has_kwargs': False,
},
{
'layers': {
'mlp.gate': moe.gate,
},
'prev_op': [block.post_attention_layernorm],
'input': ['mlp'],
'inspect': moe,
'has_kwargs': False,
'is_mlp': True,
},
]