
Commit 2aadb35

Extend FSDP2 unit tests to include DCP checkpointing and parity tests.

Signed-off-by: Cory Ye <cye@nvidia.com>

1 parent 018543c
6 files changed: 179 additions & 19 deletions
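The new test exercises PyTorch Distributed Checkpoint (DCP) through a `Stateful` wrapper that bundles model and optimizer state, mirroring the `AppState` class in the diff below. Here is a minimal single-process sketch of that save/load flow (DCP also runs without an initialized process group); the `nn.Linear` stand-in and the `/tmp/dcp-sketch` path are illustrative assumptions, not taken from the commit:

import torch
import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful


class AppState(Stateful):
    """Bundle model + optimizer so DCP saves/loads them as one unit."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # get_state_dict returns checkpoint-friendly model/optimizer state,
        # handling FSDP/DTensor parameters transparently.
        model_sd, optim_sd = get_state_dict(self.model, self.optimizer)
        return {"model": model_sd, "optim": optim_sd}

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


model = nn.Linear(8, 8)  # hypothetical stand-in for the FSDP2-sharded TE model
optimizer = torch.optim.Adam(model.parameters())
dcp.save({"app": AppState(model, optimizer)}, checkpoint_id="/tmp/dcp-sketch")
dcp.load({"app": AppState(model, optimizer)}, checkpoint_id="/tmp/dcp-sketch")

DCP calls `state_dict()` on save and `load_state_dict()` on load for any `Stateful` value in the top-level dictionary; the parity tests below rely on exactly this round trip.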


tests/pytorch/distributed/run_fsdp2_model.py

Lines changed: 157 additions & 7 deletions
@@ -4,14 +4,18 @@
 #
 # See LICENSE for license information.
 
+import argparse
 import os
 import sys
-import argparse
+import shutil
+from contextlib import nullcontext
+from copy import deepcopy
 from dataclasses import dataclass
+from pathlib import Path
 
 import transformer_engine.pytorch as te
 import transformer_engine.common.recipe
-
+from transformer_engine.pytorch import QuantizedTensor
 import torch
 import torch.distributed as dist
 from torch.distributed.checkpoint import save, load
@@ -27,11 +31,13 @@
 from torch.distributed import DeviceMesh
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import init_device_mesh
-from transformer_engine.pytorch import QuantizedTensor
-from contextlib import nullcontext
 
 LOCAL_RANK = None
 
+# Needed for `torch.distributed.checkpoint.{save,load}` because
+# multiple processes need to write to the same directory.
+SHARED_TMP_DIR = "/tmp/pytest-shared-tmp"
+
 
 @dataclass
 class AppState(Stateful):
@@ -63,7 +69,7 @@ def state_dict(self):
                 # yet get_state_dict / _init_optim_state produce empty Tensors.
                 # TransformerEngine uses empty Tensors for dummy Parameters.
                 optimizer_state_dict["state"][fqn] = {}
-            if fqn.endswith("._extra_state"):
+            if fqn.endswith("_extra_state"):
                 # Evict `_extra_state` quantization data from model checkpoint.
                 model_state_dict.pop(fqn)
         return {
@@ -352,7 +358,9 @@ def test_fp8_fsdp2_allgather(model):
     # FP32 manual weight allgather
     fp32_allgathered_params = {}
     for name, param in model.named_parameters():
-        assert isinstance(param, DTensor)
+        assert isinstance(
+            param, DTensor
+        ), f"[test_fp8_fsdp2_allgather] {param} should be a DTensor."
         local_tensor = param._local_tensor
         device_mesh = param.device_mesh
         dist_group = (
@@ -471,7 +479,7 @@ def _train(args):
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
     """
-    Pre-Save Training
+    FSDP2 Training
     """
     for iteration in range(args.iter):
         # Zero the parameter gradients
@@ -499,6 +507,148 @@ def _train(args):
     if args.fp8_init:
         test_fp8_fsdp2_allgather(model)
 
+    """
+    DCP Checkpoint Testing
+    """
+    # Compute the pre-save model loss to the last random input
+    # with respect to the last random target.
+    model.eval()
+    with te.autocast(enabled=True, recipe=fp8_recipe):
+        output = model(input_data)
+        pre_save_loss = F.mse_loss(output, target)
+
+    # Save deep copy of the model and optimizer state before checkpointing.
+    # NOTE(@cspades): deepcopy has issues with DTensors. Just clone().
+    s1 = {}
+    for key, val in model.state_dict().items():
+        s1[key] = val.clone()
+    optim_state_dict = optimizer.state_dict()
+    o1 = {"state": {}}
+    for idx, state in optim_state_dict["state"].items():
+        o1_state = o1["state"].setdefault(idx, {})
+        for key, val in state.items():
+            o1_state[key] = val.clone()
+    o1["param_groups"] = deepcopy(optim_state_dict["param_groups"])
+
+    # Write model to checkpoint.
+    CKPT_DIR = (
+        Path(SHARED_TMP_DIR)
+        / "run_fsdp2_model"
+        / f"dcp-{'_'.join(str(x) for x in args.sharding_dims)}-{args.layer_type}-{args.recipe}-fp8_init_{args.fp8_init}"
+    )
+    CKPT_DIR.mkdir(parents=True, exist_ok=True, mode=0o777)
+    state_dict = {"app": AppState(model=model, optimizer=optimizer)}
+    torch.distributed.checkpoint.save(state_dict, checkpoint_id=str(CKPT_DIR))
+
+    # Perform an extra training step to change the weights such that
+    # state parity tests will fail unless the checkpoint is loaded
+    # without any errors or incongruities vs. the saved model state.
+    model.train()
+    for iteration in range(args.iter):
+        optimizer.zero_grad()
+        with (
+            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+            if args.recipe == "NVFP4BlockScaling"
+            else nullcontext()
+        ):
+            with te.autocast(enabled=True, recipe=fp8_recipe):
+                output = model(torch.randn(inp_shape).to(device))
+                loss = F.mse_loss(output, torch.randn(out_shape).to(device))
+        loss.backward()
+        optimizer.step()
+
+    # Load the checkpoint.
+    state_dict = {"app": AppState(model=model, optimizer=optimizer)}
+    torch.distributed.checkpoint.load(state_dict=state_dict, checkpoint_id=str(CKPT_DIR))
+
+    # FIXME(@cspades): DelayedScaling checkpointing has tiny uint8 parity issues
+    # that affects the dequantized model state. Only test loss parity.
+    if args.recipe != "DelayedScaling" and args.fp8_init:
+        # Validate checkpoint parity with pre-save state dictionaries.
+        # Compare pre-save and post-load model state dictionaries.
+        s2 = model.state_dict()
+        nonempty_model_state = False
+        for key in s1.keys() | s2.keys():
+            if key.endswith("_extra_state"):
+                # Don't parity test _extra_state. Shape can change after reset_parameters().
+                continue
+            v1 = s1.get(key, None)
+            if isinstance(v1, DTensor):
+                v1 = v1.to_local()
+            v2 = s2.get(key, None)
+            if isinstance(v2, DTensor):
+                v2 = v2.to_local()
+            assert (
+                v1 is not None and v2 is not None
+            ), f"[{key} Not Found] Original Param: {v1} | Checkpoint Param: {v2}"
+            assert (
+                v1.shape == v2.shape
+            ), f"[Checkpoint Param {key} Shape Mismatch] {v1.shape} != {v2.shape}"
+            assert torch.allclose(v1, v2), f"[Checkpoint Param {key} Value Mismatch] {v1} != {v2}"
+            nonempty_model_state = True
+        assert nonempty_model_state, "Model state should not be empty for evenly-sharded DTensors!"
+
+        # Compare pre-save and post-load optimizer state dictionaries.
+        o2 = optimizer.state_dict()
+        nonempty_optim_state = False
+        for param_id in o1["state"].keys() | o2["state"].keys():
+            param_state_1 = o1["state"].get(param_id, None)
+            param_state_2 = o2["state"].get(param_id, None)
+            assert param_state_1 is not None and param_state_2 is not None, (
+                f"[{param_id} Not Found] Original Optim State: {param_state_1} | Checkpoint Optim"
+                f" State: {param_state_2}"
+            )
+            for key in param_state_1.keys() | param_state_2.keys():
+                v1 = param_state_1.get(key, None)
+                if isinstance(v1, DTensor):
+                    v1 = v1.to_local()
+                v2 = param_state_2.get(key, None)
+                if isinstance(v2, DTensor):
+                    v2 = v2.to_local()
+                assert v1 is not None and v2 is not None, (
+                    f"[{param_id} {key} Not Found] Original Optim State: {v1} | Checkpoint Optim"
+                    f" State: {v2}"
+                )
+                assert (
+                    v1.shape == v2.shape
+                ), f"[Optim State {param_id} {key} Shape Mismatch] {v1.shape} != {v2.shape}"
+                assert torch.allclose(
+                    v1, v2
+                ), f"[Optim State {param_id} {key} Value Mismatch] {v1} != {v2}"
+                nonempty_optim_state = True  # Optimizer state depends on wgrad, verify this!
+        assert (
+            nonempty_optim_state
+        ), "Optimizer state should not be empty for evenly-sharded DTensors!"
+        assert len(o1["param_groups"]) == len(o2["param_groups"]), (
+            f"[Optim State Param Groups Length Mismatch] {o1['param_groups']} !="
+            f" {o2['param_groups']}"
+        )
+        for i in range(len(o2["param_groups"])):
+            for key in o1["param_groups"][i].keys():
+                v1 = o1["param_groups"][i][key]
+                v2 = o2["param_groups"][i][key]
+                assert v1 == v2, f"[Optim State Param Group {i} {key} Value Mismatch] {v1} != {v2}"
+
+    # Validate post-load model loss.
+    model.eval()
+    with (
+        torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+        if args.recipe == "NVFP4BlockScaling"
+        else nullcontext()
+    ):
+        with te.autocast(enabled=True, recipe=fp8_recipe):
+            output = model(input_data)
+            post_load_loss = F.mse_loss(output, target)
+    # Allow for 1% disparity due to _extra_state disparity.
+    assert torch.allclose(
+        pre_save_loss, post_load_loss, rtol=1e-2
+    ), f"Pre-Save Loss: {pre_save_loss} != Post-Load Loss: {post_load_loss}"
+
+    # Clean up temporary checkpoint directory.
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        shutil.rmtree(CKPT_DIR)
+    torch.distributed.barrier()
+
     dist.destroy_process_group()
     return 0

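The core of the new parity check is DTensor-aware: each rank compares only its local shard, obtained with `to_local()`, and `_extra_state` entries (serialized quantizer metadata whose shape may legitimately change) are skipped. A condensed sketch of that comparison logic as a hypothetical helper (`assert_state_parity` does not appear in the commit):

import torch
from torch.distributed.tensor import DTensor


def assert_state_parity(sd1: dict, sd2: dict) -> None:
    """Compare two possibly DTensor-valued state dicts shard-by-shard."""
    for key in sd1.keys() | sd2.keys():
        if key.endswith("_extra_state"):
            continue  # quantizer metadata; excluded from parity testing
        v1, v2 = sd1.get(key), sd2.get(key)
        assert v1 is not None and v2 is not None, f"{key} missing on one side"
        # Each rank holds one shard of a DTensor; compare local shards only.
        v1 = v1.to_local() if isinstance(v1, DTensor) else v1
        v2 = v2.to_local() if isinstance(v2, DTensor) else v2
        assert v1.shape == v2.shape, f"{key}: {v1.shape} != {v2.shape}"
        assert torch.allclose(v1, v2), f"{key}: value mismatch"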
tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 8 additions & 4 deletions
@@ -74,7 +74,7 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
     subprocess.run(test_cmd, env=os.environ, check=True)
 
 
-@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
+@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs.")
 @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize(
     "sharding_dims",
@@ -83,16 +83,20 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
         [NUM_PROCS],
         # HSDP
         [2, NUM_PROCS // 2],
-        # FSDP-TP
-        [1, 2, NUM_PROCS // 2],
-        # HSDP-TP
+        # (H/F)SDP-TP
         [NUM_PROCS // 4, 2, 2],
     ),
 )
 @pytest.mark.parametrize("fp8_init", (False, True))
 @pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
 def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
 
+    parallel_size = math.prod(x for x in sharding_dims if x != 0)
+    if NUM_PROCS < parallel_size:
+        pytest.skip(
+            f"Insufficient devices ({NUM_PROCS}) to test sharding configuration: {sharding_dims}"
+        )
+
     if fp_recipe in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init:
         pytest.xfail(f"{fp_recipe} + fp8_init: test_fp8_fsdp2_allgather is currently failing.")

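The new guard computes how many devices a sharding configuration actually needs and skips the test when the host has fewer GPUs. Zero entries, which integer division produces on small hosts, are dropped before taking the product. A worked example, assuming a 2-GPU host:

import math

NUM_PROCS = 2  # assumed GPU count for illustration
sharding_dims = [NUM_PROCS // 4, 2, 2]  # -> [0, 2, 2]
# Drop the zero dims before computing the required device count.
parallel_size = math.prod(x for x in sharding_dims if x != 0)  # 2 * 2 = 4
print(NUM_PROCS < parallel_size)  # True -> the test is skipped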
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 2 additions & 1 deletion
@@ -575,7 +575,8 @@ def set_device_mesh(
         weight_mesh : Optional[DeviceMesh]
             Not used for DotProductAttention as there are no quantized weights.
         """
-        warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
+        if weight_mesh is not None:
+            warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
         if tp_mesh is not None:
             # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
             assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (
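The identical guard is applied to LayerNorm and RMSNorm below: these modules keep no quantized weights, so the warning should fire only when a caller actually passes a `weight_mesh` that will be ignored, not on every `set_device_mesh` call. A minimal sketch of the pattern:

import warnings

def set_device_mesh_sketch(weight_mesh=None):
    # Warn only on a needlessly supplied mesh instead of unconditionally.
    if weight_mesh is not None:
        warnings.warn(f"weight_mesh not necessary here: {weight_mesh}")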

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 8 additions & 5 deletions
@@ -11,6 +11,7 @@
 import torch
 from torch.distributed import DeviceMesh
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard
 
 import transformer_engine_torch as tex
 
@@ -800,13 +801,17 @@ def make_grouped_weights(self, defer_init=False) -> None:
                 weight_quantizers[0] is None or not weight_quantizers[0].internal
             ), "Found internal quantizer with `single_grouped_parameter=True`."
             grouped_param = torch.nn.Parameter(grouped_weights)
-            if isinstance(getattr(self, f"weight0", None), DTensor):
+            if isinstance(getattr(self, "weight0", None), DTensor):
                 # Convert to DTensor with properties equivalent to the original DTensor.
-                dtensor_member_param = getattr(self, f"weight0")
+                dtensor_member_param = getattr(self, "weight0")
+                grouped_3d_placements = tuple(
+                    type(p)(p.dim + 1) if isinstance(p, (Shard, _StridedShard)) else p
+                    for p in dtensor_member_param.placements
+                )
                 grouped_param = _convert_param_to_dtensor_param(
                     grouped_param,
                     device_mesh=dtensor_member_param.device_mesh,
-                    placements=dtensor_member_param.placements,
+                    placements=grouped_3d_placements,
                     # DTensor / DCP will view this as a TP-sharded 3-D Tensor.
                     shape=(self.num_gemms, self.out_features, self.in_features),
                     # Default Stride: (out*in, in, 1)
@@ -878,8 +883,6 @@ def set_device_mesh(
             self.set_tensor_parallel_group(tp_mesh.get_group())
 
             # Construct TP-sharded DTensors.
-            from torch.distributed.tensor.placement_types import Replicate, Shard
-
             for weight in self.weight_names:
                 param = getattr(self, weight)
                 placements = (Replicate(),)
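`make_grouped_weights` stacks the per-GEMM 2-D weights into a single 3-D grouped parameter, so every sharded placement must shift one dimension to the right: dim 0 becomes the GEMM index. A standalone sketch of that remap, simplified to plain `Shard` (the commit also handles `_StridedShard`, whose constructor takes an additional split factor):

from torch.distributed.tensor.placement_types import Replicate, Shard

# A per-GEMM (out_features, in_features) weight sharded on dim 0.
placements_2d = (Shard(0),)

# After stacking into (num_gemms, out_features, in_features), the same
# shard lives on dim 1; Replicate placements pass through unchanged.
placements_3d = tuple(
    Shard(p.dim + 1) if isinstance(p, Shard) else p for p in placements_2d
)
print(placements_3d)  # (Shard(dim=1),)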

transformer_engine/pytorch/module/layernorm.py

Lines changed: 2 additions & 1 deletion
@@ -168,7 +168,8 @@ def set_device_mesh(
             Quantized DTensor parameters are currently not supported for FusibleOperation(s),
             and this mesh is not used.
         """
-        warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
+        if weight_mesh is not None:
+            warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
         if tp_mesh is not None:
             # Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility
             # with DTensor parameters in TP layers to support DTensor operations.

transformer_engine/pytorch/module/rmsnorm.py

Lines changed: 2 additions & 1 deletion
@@ -171,7 +171,8 @@ def set_device_mesh(
             Quantized DTensor parameters are currently not supported for FusibleOperation(s),
             and this mesh is not used.
         """
-        warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
+        if weight_mesh is not None:
+            warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
         if tp_mesh is not None:
             # Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility
             # with DTensor parameters in TP layers to support DTensor operations.
