Skip to content

Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713

Open
cspades wants to merge 5 commits intoNVIDIA:mainfrom
cspades:cye/fsdp2-tp-dcp
Open

Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713
cspades wants to merge 5 commits intoNVIDIA:mainfrom
cspades:cye/fsdp2-tp-dcp

Conversation

@cspades
Copy link
Member

@cspades cspades commented Feb 26, 2026

Summary

  • Support (H/F)SDP2 x TP strided sharding, and DTensor FP8 parameters for Torch DCP checkpointing, across all TransformerEngineBaseModule(s).
    • Except GroupedLinear, pending FSDP2 standalone pipe-cleaning. All other modules under transformer_engine.pytorch.modules are supported.
    • FusibleOperation support is also a WIP, except for LayerNorm or RMSNorm which are TE modules.
  • Associated with BioNeMo-Recipes Llama3 TP: Enable TransformerEngine-backed Tensor Parallelism with Llama3. bionemo-framework#1483
    • Notably, TransformerEngine TP can be easily mixed with DTensor-based TP when unified by Torch DCP! In the Llama3 recipe, we use DTensor-based TP on the torch.nn.Embedding, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the torch.nn.Embedding, which is why we do not need to call set_device_mesh for the LM head!

Usage / Documentation

    def set_device_mesh(
        self,
        tp_mesh: Optional[DeviceMesh] = None,
        weight_mesh: Optional[DeviceMesh] = None,
    ) -> None:
        """
        Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor
        depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2.

        TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless
        integration with Torch DCP checkpointing. This method should only be invoked when
        using DTensor parameters, e.g. when using FSDP2 or DCP.

        When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
        convert them into FSDP-TP strided or non-strided shards depending on the current
        sharding dimension and factor of the DTensor. When the sharding dimension of FSDP
        matches that of TP, FSDP uses a _StridedShard placement type instead of Shard.
        This experimental FSDP-TP logic presides in this FSDP2 initialization function:
        ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param``

        Parameters
        ----------
        tp_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
            Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
        weight_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
            when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
            parameters and if the DTensor DeviceMesh includes dimensions that do not
            shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
            For example:
                - device_mesh["dp"] for FSDP.
                - device_mesh["dp_cp"] if using CP ranks in FSDP.
                - device_mesh["dp_shard"] if using HSDP ("dp_replicate", "dp_shard").
                - device_mesh["tp"] if using TP.
                - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
        """

Details

DTensor Lifecycle in TransformerEngine

  • Initialization
    • __init__
      • TransformerEngine model parameters are initialized either on device or meta device with the appropriate tp_size and TP sharding strategy, e.g. parallel_mode and sequence_parallel.
    • TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)
      • Converts parameters to DTensor with appropriate TP placement(s) based on the TP sharding strategy specified in __init__, using transformer_engine.pytorch.distributed._convert_param_to_dtensor_param.
        • tp_mesh is a 1-D DeviceMesh containing the TP ProcessGroup that will be registered with the TransformerEngine module.
        • weight_mesh is the 1-D DeviceMesh containing the ProcessGroup that shards TransformerEngine module weights, the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP and is mainly required for per-Tensor scaling recipes like Float8CurrentScaling.
      • Needs to be invoked prior to fully_shard (which responds to the TP placements) and prior to reset_parameters(defer_init=False), which quantizes parameters.
      • Can also be directly invoked during __init__(tp_mesh, weight_mesh) for supported TransformerEngine modules.
    • fully_shard shards the TransformerEngine model with FSDP2.
      • When fully_shard encounters TP sharding on dim=0, it will use a _StridedShard for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left-to-right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the DeviceMesh and DTensor.placements. (See Appendix for visualization of this sharding strategy.)
    • reset_parameters is called if using meta device initialization.
  • Training
    • Pre-forward, FSDP2 all-gathers the sharded DTensor "main" weight that it registered during fully_shard. (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as FusedAdam must be used to properly handle high-precision main weights.)
      • When using FSDP2 x TP, the all-gathered Tensor is actually a TP-sharded DTensor, which deviates from the original FSDP2 paradigm where the all-gathered Tensor is fully-unsharded and the DTensor wrapping is discarded. To support these DTensor compute weights in TransformerEngine modules, we utilize transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor to localize the DTensor and also inherit requires_grad attribute from the DTensor parameter as the local Tensor has this un-set during DTensor.from_local(Tensor) for FP8 parameters specifically!
    • Post-backward, the Tensor gradient is converted and attached to the DTensor.grad attribute.
      • NOTE(@cspades, @vthumbe1503): For some reason, FusibleOperation (RMSNorm and LayerNorm) require casting the gradient from Tensor to a DTensor matching the configuration of the DTensor weights. I have confirmed the gradient is installed correctly on RMSNorm weights (same shape and sharding configuration as the sharded optimizer state), and it will not affect normal TransfomerEngine operations, but it is not totally clear why this is necessary with FSDP2 x TP.

Bugs

  • Fix bug where "shard" was the presumed weight sharding sub-mesh in the DTensor.device_mesh. Now, users can precisely specify their own custom weight-sharding DeviceMesh for per-tensor amax_reduction_group via the set_device_mesh(weight_mesh) API.
  • TransformerEngineBaseModule: self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}

Testing

# TransformerEngine Main
[Rank 0] (after 1 iterations) memory (MB) | allocated: 23511.65 | max allocated: 25189.68 | reserved: 25678.00 | max reserved: 25678.00
 [2026-03-02 09:55:17.189564] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12715.7 | throughput per GPU (TFLOP/s/GPU): 530.6 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124915E+00 | loss scale: 1.0 | grad norm: 5.474 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:55:29.768521] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12578.7 | throughput per GPU (TFLOP/s/GPU): 536.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.143806E+00 | loss scale: 1.0 | grad norm: 5.366 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Post-DCP Modifications (This PR)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 23511.65 | max allocated: 29783.24 | reserved: 25678.00 | max reserved: 31510.00
 [2026-03-02 09:29:36.550070] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12556.5 | throughput per GPU (TFLOP/s/GPU): 537.3 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124463E+00 | loss scale: 1.0 | grad norm: 5.471 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:29:49.216068] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12665.7 | throughput per GPU (TFLOP/s/GPU): 532.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.142863E+00 | loss scale: 1.0 | grad norm: 5.355 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • NOTE(@cspades): DelayedScaling has DCP save/load disparity issues, i.e. on the scale of +/-1 to the uint8 parameter checkpoint!

Appendix

_StridedShard - Using FSDP2 x TP Strided-Sharding

# (DP=4, TP=2)
(_StridedShard(dim=0, sf=2), Shard(dim=0))

┌───┬───┐
│ 0 │ 4 │ ← DP=0
├───┼───┤
│ 1 │ 5 │ ← DP=1
├───┼───┤          FSDP all-gather happens across the DP ranks,
│ 2 │ 6 │ ← DP=2   so we need to form the 0-3 and 4-7 TP shards!
├───┼───┤
│ 3 │ 7 │ ← DP=3
└───┴───┘
  ↑   ↑
TP=0 TP=1

When redistribute'ing a global DTensor to (_StridedShard(dim=0, sf=2), Shard(dim=0)), DTensor will perform the following steps:

  • Pre-shard the Tensor data with respect to the stride / shard factor, which is defined as the product of the parallelism sizes of all Shard placements to the right of _StridedShard. (In the above example, since TP=2, the factor is 2.)
    • [0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7].
    • In the context of this PR and fully_shard, this has already been done via initializing the TransformerEngine module with TP and calling _convert_param_to_dtensor_param!
  • Shard the pre-shards for _StridedShard.
    • [0] [1] [2] [3] and [4] [5] [6] [7]
  • Concatenate the strided shards.
    • [0 4] [1 5] [2 6] [3 7], which are assigned to the _StridedShard ranks.
    • Note that this is very different if we did left-to-right-sharding, which would have given us [0 1] [2 3] [4 5] [6 7]!
  • Subsequently / finally, each strided shard is sharded on the Shard placement.
    • [0] [4] / [1] [5] / [2] [6] / [3] [7], which are assigned to the Shard ranks.
    • Note that this is very different if we did left-to-right sharding, which would have given us [0] [1] / [2] [3] / [4] [5] / [6] [7]!

PyTorch also supports the inverse / un-sharding of this redistribute, which is literally the inverse of these simple operations! (Though things get a bit more complicated with un-even shards from odd-numbered dimension sizes.)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR adds DCP (Distributed Checkpoint) compatibility for FSDP2 × TP sharding across all TransformerEngineBaseModule subclasses by introducing a set_device_mesh(tp_mesh, weight_mesh) API that converts module parameters to TP-sharded DTensors before fully_shard() is called. It also fixes a bug where the amax_reduction_group for Float8CurrentScaling was hard-coded to the "shard" mesh dimension name, and includes a full DCP save/load round-trip test.

Key changes and observations:

  • New set_device_mesh API added to Linear, LayerNormLinear, LayerNormMLP, LayerNorm, RMSNorm, DotProductAttention, MultiheadAttention, and TransformerLayer; propagated recursively through module hierarchies.
  • New utilities _convert_param_to_dtensor_param and _extract_trainable_tensor_from_dtensor in distributed.py handle the DTensor ↔ plain Tensor lifecycle, including requires_grad re-propagation which DTensor suppresses for FP8 parameters.
  • DTensor unwrapping added in ops/basic/layer_norm.py and ops/basic/rmsnorm.py forward/backward paths, with gradients correctly re-wrapped as DTensors in the backward pass.
  • quantizers refactored from {"scaling_fwd": {}, "scaling_bwd": {}} (dict of dicts) to {"scaling_fwd": [], "scaling_bwd": []} (dict of lists) in base.py, consistent with integer-indexed access throughout the codebase.
  • LayerNormMLP backward fix: The condition isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) was always False (comparing a quantizer object to a tensor type); corrected to isinstance(ctx.fc1_weight, QuantizedTensorStorage). This now enables update_usage(columnwise_usage=True) in the backward pass for FP8 LayerNormMLP, a behavioral change that could affect FP8 numerics.
  • Test infrastructure extended to exercise FSDP, HSDP, and (H/F)SDP-TP topologies with DCP round-trip parity checks; however, the skip guard at line 96 in test_torch_fsdp2.py uses < instead of !=, causing hard assertion failures rather than clean skips on machines with even GPU counts that don't align with the HSDP-TP topology (e.g. 6 GPUs).

Confidence Score: 3/5

  • Core DTensor integration is well-structured and safe, but contains a test skip-logic bug that can cause hard failures on certain GPU counts, and introduces a behavioral change in FP8 backward that could affect numerics.
  • The PR introduces solid DTensor integration with proper lifecycle management across TransformerEngine modules. However, two issues reduce confidence: (1) the test skip condition uses < instead of !=, which will cause assertion failures rather than clean skips when running on machines with even GPU counts that don't align with HSDP-TP topology (e.g., 6 GPUs), making test results unpredictable across different hardware configurations; (2) the LayerNormMLP backward fix correctly changes from a dead code path (always-False condition) to an active path that now invokes update_usage(columnwise_usage=True) for FP8 backward passes, which is logically correct but represents an undocumented behavioral change that could affect FP8 numerics for existing models. Both issues are fixable but warrant careful attention.
  • tests/pytorch/distributed/test_torch_fsdp2.py (skip condition logic), transformer_engine/pytorch/module/layernorm_mlp.py (backward behavioral change impact on FP8 numerics)

Comments Outside Diff (2)

  1. tests/pytorch/distributed/test_torch_fsdp2.py, line 96-99 (link)

    Incorrect skip condition for (H/F)SDP-TP topology

    The condition NUM_PROCS < parallel_size does not correctly guard against GPU counts that are even (satisfying the outer skipif) but not aligned with the HSDP-TP topology. For NUM_PROCS = 6, sharding_dims = [6//4, 2, 2] = [1, 2, 2] and parallel_size = 4. Since 6 > 4, the test is not skipped. The test then calls torchrun --nproc_per_node=6 which spawns 6 processes, and get_device_mesh([1, 2, 2]) in run_fsdp2_model.py will immediately hit:

    assert sharding_dims[0] * sharding_dims[1] * sharding_dims[2] == world_size
    # AssertionError: 4 != 6

    The condition should require an exact match between NUM_PROCS and the required process count:

  2. transformer_engine/pytorch/module/layernorm_mlp.py, line 1375-1378 (link)

    Bug fix changes FP8 backward behavior

    This PR corrects a pre-existing bug where isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) was being checked (comparing a quantizer object against a tensor type, always False). The fix changes it to check isinstance(ctx.fc1_weight, QuantizedTensorStorage), which is logically correct.

    However, this means ctx.fc1_weight.update_usage(columnwise_usage=True) will now actually be invoked in the backward pass for quantized FP8 weights, whereas it was silently skipped before. This is a behavioral change that could affect FP8 numerics for existing LayerNormMLP models.

    While the fix is correct, the change from dead code to active code may surprise users who were inadvertently relying on the old behavior. Ensure existing FP8 LayerNormMLP backward tests cover this code path.

Last reviewed commit: bc82f02

@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from 4ec2947 to dbb9d14 Compare March 4, 2026 18:10
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from fcdd5bd to c912f5b Compare March 5, 2026 16:06
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from c912f5b to 2aadb35 Compare March 5, 2026 18:30
cspades and others added 4 commits March 5, 2026 15:50
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch 2 times, most recently from ae064d0 to a7a17c2 Compare March 6, 2026 01:37
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from a7a17c2 to bc82f02 Compare March 6, 2026 17:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant