This library contains plugins to accelerate finetuning with the following optimizations:
- Expert-Parallel MoE with Triton kernels from ScatterMoE, and some extracted from megablocks.
  - megablocks kernels for `gather` and `scatter`.
| Plugin | Description | Depends | Loading | Augmentation | Callbacks |
|---|---|---|---|---|---|
| scattermoe | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | ✅ | ✅ | |
Our ScatterMoE implementation is a module swap; to add new models we need to update the specifications in `scattermoe_constants.py`.
- See the code documentation within to understand how to add new models.
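To illustrate the module-swap idea, here is a minimal, self-contained sketch. All class names below are hypothetical stand-ins, not the real model or plugin classes; the actual per-model specifications live in `scattermoe_constants.py`.

```python
# Minimal sketch of a module swap: walk an object tree and replace
# registered MoE blocks with a drop-in equivalent.
# Hypothetical names throughout; not the real ScatterMoE classes.

class SparseMoeBlock:
    """Stand-in for a model's original MoE block."""
    def __init__(self, num_experts):
        self.num_experts = num_experts

class ScatterMoeBlock:
    """Stand-in for the swapped-in ScatterMoE module."""
    @classmethod
    def from_original(cls, blk):
        new = cls()
        new.num_experts = blk.num_experts  # carry over the config
        return new

# specification: which class names get swapped, and for what
SPEC = {"SparseMoeBlock": ScatterMoeBlock}

def swap_modules(module):
    """Recursively replace registered MoE blocks in place."""
    for name, child in list(vars(module).items()):
        repl = SPEC.get(type(child).__name__)
        if repl is not None:
            setattr(module, name, repl.from_original(child))
        elif hasattr(child, "__dict__"):
            swap_modules(child)

class Layer:
    def __init__(self):
        self.moe = SparseMoeBlock(num_experts=8)

class Model:
    def __init__(self):
        self.layer = Layer()

model = Model()
swap_modules(model)
# model.layer.moe is now a ScatterMoeBlock with num_experts preserved
```

In the real plugin, the swap table is keyed by the specifications in `scattermoe_constants.py`, which is why supporting a new model is a matter of adding its spec there.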
ScatterMoE checkpoints are saved using `torch.distributed.checkpoint` (DCP), which defaults to `StateDictType.SHARDED_STATE_DICT`:
- `DTensor`s have limited support for full state dicts.
- Sharded state dicts are extremely efficient and require little comms overhead when saving.
We provide a script to recover back the original checkpoint:
- currently the script is only tested in the case where DCP has saved the model in a single node.
If the checkpoint is stored in `hf/checkpoint-10`, call the following to have the converted checkpoint written into `output_dir`:
```
python -m fms_acceleration_moe.utils.checkpoint_utils \
    hf/checkpoint-10 output_dir \
    mistralai/Mixtral-8x7B-Instruct-v0.1
```
Notes on code extraction:
- we have only extracted two `autograd` functions, `GatherOp` and `ScatterOp`, and the associated Triton kernels from `backend/kernels.py`; mostly the `_padded_copy`.
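Conceptually, the gather/scatter pattern those kernels accelerate can be sketched in plain Python. This sketch ignores padding, Triton, and autograd entirely, and the function names are illustrative rather than the extracted ops:

```python
# Gather: reorder tokens so that each expert's tokens are contiguous,
# allowing each expert to run as one dense batch.
# Scatter: the inverse permutation, restoring original token order.

def gather(tokens, expert_ids):
    """Sort token indices by their routed expert; return the
    expert-contiguous tokens and the permutation used."""
    order = sorted(range(len(tokens)), key=lambda i: expert_ids[i])
    return [tokens[i] for i in order], order

def scatter(expert_out, order):
    """Apply the inverse permutation to put expert outputs back
    into the original token order."""
    out = [None] * len(expert_out)
    for pos, i in enumerate(order):
        out[i] = expert_out[pos]
    return out

tokens = ["t0", "t1", "t2", "t3"]
expert_ids = [1, 0, 1, 0]  # router assignment per token
gathered, order = gather(tokens, expert_ids)
# gathered == ["t1", "t3", "t0", "t2"]  (expert 0's tokens first)
restored = scatter(gathered, order)
# restored == tokens
```

The extracted Triton kernels (e.g. `_padded_copy`) implement this data movement efficiently on GPU, including the padding details this sketch omits.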
Run the below in the top-level directory of this repo:
- the `scattermoe` dep is not included by default, so the `-x` switch installs it.
- consider disabling the `torch` memory logging to see improved speeds.
```
tox -e run-benches \
    -x testenv:run-benches.setenv+="MEMORY_LOGGING=nvidia" \
    -- \
    "1 2 4" 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-full
```
or run the larger Mixtral-8x7B bench:
```
tox ... \
    8 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-full-mixtral
```
NOTE: if a `FileNotFoundError` is observed on the Triton cache (similar to previously reported issues), then tox is somehow causing problems with Triton and multiprocessing (there is some race condition). The workaround is to first activate the tox env and run the script in bash:
```
# if FileNotFoundError in the triton cache is observed
# - then activate the env and run the script manually
source .tox/run-benches/bin/activate
bash scripts/run_benchmarks.sh \
    ....
```
The Triton kernels are copied into `scattermoe_utils`; they were taken from kernel hyperdrive, which is a fork of cute kernels.
These are currently some known issues not yet resolved:
- we should eventually remove the dependency on the external `kernel-hyperdrive` repository.
- we now support only loading sharded `safetensors` (non-GGUF) MoE checkpoints. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents them being saved into a single checkpoint file.
- when used together with FSDP, FSDP's `clip_grad_norm` will not properly compute for ScatterMoE; see issue here.
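The grad-norm issue stems from sharding: with expert parallel, each rank holds different expert weights, so a correct global gradient norm must combine the per-shard contributions (typically via an all-reduce of squared sums) rather than treating the local shard as the whole model. A non-distributed sketch of that reduction, with illustrative names:

```python
import math

def global_grad_norm(local_sq_sums):
    """Combine per-shard sums of squared gradient entries into one
    global L2 norm. In a real distributed run, the sum would be an
    all-reduce across ranks before the sqrt."""
    return math.sqrt(sum(local_sq_sums))

# Two ranks, each holding gradients for different experts:
rank0_sq_sum = 3.0 ** 2  # sum of squared grad entries on rank 0
rank1_sq_sum = 4.0 ** 2  # sum of squared grad entries on rank 1
norm = global_grad_norm([rank0_sq_sum, rank1_sq_sum])
# norm == 5.0, whereas each rank alone would report 3.0 or 4.0
```

A clipping implementation that only sees its local shard (or double-counts replicated params) will therefore scale gradients by the wrong factor, which is the failure mode referenced above.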