Add DCP compatibility for FSDP2-TP sharding in TransformerEngine. #2713
cspades wants to merge 5 commits into NVIDIA:main
Greptile Summary
This PR adds DCP (Distributed Checkpoint) compatibility for FSDP2 × TP sharding across all [...]
Confidence Score: 3/5
Signed-off-by: Cory Ye <cye@nvidia.com>
Summary
- Adds support for `(H/F)SDP2 x TP` strided sharding and `DTensor` FP8 parameters for Torch DCP checkpointing, across all `TransformerEngineBaseModule`(s).
  - Except `GroupedLinear`, pending FSDP2 standalone pipe-cleaning. All other modules under `transformer_engine.pytorch.modules` are supported.
  - `FusibleOperation` support is also a WIP, except for `LayerNorm` or `RMSNorm`, which are TE modules.
- [...] `DTensor`-based TP when unified by Torch DCP! In the Llama3 recipe, we use `DTensor`-based TP on the `torch.nn.Embedding`, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the `torch.nn.Embedding`, which is why we do not need to call `set_device_mesh` for the LM head!

Usage / Documentation
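A minimal sketch of the call order this PR describes (construct, `set_device_mesh`, `fully_shard`, `reset_parameters`, then a DCP save), not the PR's reference example. It assumes a `torchrun` launch, a 2-D `(dp, tp)` mesh with no DP-Replicate dimension, and that flattening the whole mesh with the private `DeviceMesh._flatten` gives a suitable `weight_mesh`. The helper name, layer sizes, and mesh dimension names are hypothetical.

```python
import os


def shard_te_linear_sketch(checkpoint_dir, dp_size, tp_size):
    """Hypothetical helper: TP + FSDP2 shard one te.Linear, then save with DCP."""
    # Imports are local so this file can be imported without GPUs or TE installed.
    import torch
    import torch.distributed.checkpoint as dcp
    import transformer_engine.pytorch as te
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import fully_shard

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
    # Assumption: without DP-Replicate, the weight-sharding mesh is the whole
    # mesh flattened to 1-D (DeviceMesh._flatten is a private PyTorch API).
    weight_mesh = mesh._flatten("weight")

    # 1. Construct on the meta device with the TP sharding strategy.
    layer = te.Linear(1024, 4096, parallel_mode="column", tp_size=tp_size, device="meta")
    # 2. Convert parameters to TP-sharded DTensors (API added by this PR).
    layer.set_device_mesh(mesh["tp"], weight_mesh)
    # 3. FSDP2-shard over DP; TP sharding on dim 0 leads to _StridedShard.
    fully_shard(layer, mesh=mesh["dp"])
    # 4. Materialize the meta-device parameters (this step also quantizes them).
    layer.reset_parameters()

    # Save the sharded DTensor state with Torch DCP.
    dcp.save({"model": layer.state_dict()}, checkpoint_id=checkpoint_dir)
    return layer
```

With HSDP, the DP-Replicate dimension would have to be left out of `weight_mesh`, so the flattening step would cover only the DP-Shard and TP dimensions.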
Details
DTensor Lifecycle in TransformerEngine
- `__init__`
  - The module is constructed (optionally on the `meta` device) with the appropriate `tp_size` and TP sharding strategy, e.g. `parallel_mode` and `sequence_parallel`.
- `TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)`
  - Converts parameters to `DTensor` with appropriate TP `placement`(s) based on the TP sharding strategy specified in `__init__`, using `transformer_engine.pytorch.distributed._convert_param_to_dtensor_param`.
    - `tp_mesh` is a 1-D `DeviceMesh` containing the TP `ProcessGroup` that will be registered with the TransformerEngine module.
    - `weight_mesh` is the 1-D `DeviceMesh` containing the `ProcessGroup` that shards TransformerEngine module weights, i.e. the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP, and it is mainly required for per-tensor scaling recipes like `Float8CurrentScaling`.
  - Must be called prior to `fully_shard` (which responds to the TP placements) and prior to `reset_parameters(defer_init=False)`, which quantizes parameters.
  - Can also be applied by passing the meshes to `__init__(tp_mesh, weight_mesh)` for supported TransformerEngine modules.
- `fully_shard` shards the TransformerEngine model with FSDP2.
  - When `fully_shard` encounters TP sharding on `dim=0`, it will use a `_StridedShard` for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, then concatenates the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left to right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the `DeviceMesh` and `DTensor.placements`. (See Appendix for a visualization of this sharding strategy.)
- `reset_parameters` is called if using meta device initialization.
- [...] `fully_shard`.
  (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as `FusedAdam` must be used to properly handle high-precision main weights.)
- The all-gathered `Tensor` is actually a TP-sharded `DTensor`, which deviates from the original FSDP2 paradigm where the all-gathered `Tensor` is fully unsharded and the `DTensor` wrapping is discarded. To support these `DTensor` compute weights in TransformerEngine modules, we use `transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor` to localize the `DTensor` and to inherit the `requires_grad` attribute from the `DTensor` parameter, since the local `Tensor` has it unset during `DTensor.from_local(Tensor)`, for FP8 parameters specifically!
- [...] `Tensor` gradient is converted and attached to the `DTensor.grad` attribute.
  - `FusibleOperation` (RMSNorm and LayerNorm) require casting the gradient from `Tensor` to a `DTensor` matching the configuration of the `DTensor` weights. I have confirmed the gradient is installed correctly on `RMSNorm` weights (same shape and sharding configuration as the sharded optimizer state), and it will not affect normal TransformerEngine operations, but it is not totally clear why this is necessary with FSDP2 x TP.

Bugs
- Previously, `"shard"` was the presumed weight sharding sub-mesh in the `DTensor.device_mesh`. Now, users can precisely specify their own custom weight-sharding `DeviceMesh` for the per-tensor `amax_reduction_group` via the `set_device_mesh(weight_mesh)` API.
- [...] `TransformerEngineBaseModule`: `self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}`

Testing
- [...] `num_zeros` test failure that is common to both `main` and `cspades:cye/fsdp2-tp-dcp`, so we can assume it is not associated with my change: https://github.com/NVIDIA/Megatron-LM/actions/runs/22637904520/job/65636890955?pr=3661 (TransformerEngine `main`)
- [...] `main` vs. `cspades:cye/fsdp2-tp-dcp` with Megatron-LM `main` on PyTorch `25.11`
- `DelayedScaling` has DCP save/load disparity issues, i.e. on the scale of `+/-1` to the `uint8` parameter checkpoint!

Appendix
`_StridedShard` - Using FSDP2 x TP Strided-Sharding
When `redistribute`'ing a global DTensor to `(_StridedShard(dim=0, sf=2), Shard(dim=0))`, `DTensor` will perform the following steps:
- Pre-shard the tensor by the split factor (`sf`), the product of the mesh dimension sizes of all `Shard` placements to the right of `_StridedShard`. (In the above example, since TP=2, the factor is 2.)
  - `[0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7]`.
  - In the case of `fully_shard`, this has already been done via initializing the TransformerEngine module with TP and calling `_convert_param_to_dtensor_param`!
- Shard each pre-shard for `_StridedShard`.
  - `[0] [1] [2] [3]` and `[4] [5] [6] [7]`
- Concatenate the matching pieces of each pre-shard into the strided shards `[0 4] [1 5] [2 6] [3 7]`, which are assigned to the `_StridedShard` ranks.
  - By contrast, a plain `Shard` would give `[0 1] [2 3] [4 5] [6 7]`!
- Shard each strided shard on the `Shard` placement.
  - `[0] [4]` / `[1] [5]` / `[2] [6]` / `[3] [7]`, which are assigned to the `Shard` ranks.
  - By contrast, plain `(Shard, Shard)` would give `[0] [1]` / `[2] [3]` / `[4] [5]` / `[6] [7]`!

PyTorch also supports the inverse / un-sharding of this `redistribute`, which is literally the inverse of these simple operations! (Though things get a bit more complicated with uneven shards from odd-numbered dimension sizes.)
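The strided-sharding steps in this appendix can be reproduced with plain Python lists. This is a minimal sketch that only mimics the index arithmetic for evenly divisible sizes; it does not call PyTorch, and the function names are hypothetical.

```python
def chunk(seq, n):
    """Split a list into n contiguous, equal chunks (assumes len(seq) % n == 0)."""
    size = len(seq) // n
    return [seq[i * size:(i + 1) * size] for i in range(n)]


def strided_shard_then_shard(data, dp_size, tp_size):
    """Local shard per (dp_rank, tp_rank) for (_StridedShard(sf=tp_size), Shard)."""
    # Step 1: pre-shard by the split factor, which equals the TP size here.
    pre_shards = chunk(data, tp_size)
    # Step 2: shard every pre-shard across the DP ranks; pieces[t][d].
    pieces = [chunk(p, dp_size) for p in pre_shards]
    # Step 3: concatenate the d-th piece of every pre-shard into one strided shard.
    strided = [sum((pieces[t][d] for t in range(tp_size)), []) for d in range(dp_size)]
    # Step 4: shard each strided shard across the TP ranks.
    return {(d, t): chunk(strided[d], tp_size)[t]
            for d in range(dp_size) for t in range(tp_size)}


def plain_shard_then_shard(data, dp_size, tp_size):
    """Local shard per (dp_rank, tp_rank) for ordinary (Shard, Shard)."""
    return {(d, t): chunk(chunk(data, dp_size)[d], tp_size)[t]
            for d in range(dp_size) for t in range(tp_size)}


data = list(range(8))
strided = strided_shard_then_shard(data, dp_size=4, tp_size=2)
plain = plain_shard_then_shard(data, dp_size=4, tp_size=2)
print(strided[(0, 0)], strided[(0, 1)])  # [0] [4]
print(plain[(0, 0)], plain[(0, 1)])      # [0] [1]
```

Collecting the strided result for one TP rank over all DP ranks gives back a contiguous TP shard (`[0 1 2 3]` for TP rank 0), which is the "TP first, then FSDP" layout described above.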