|
4 | 4 | #
5 | 5 | # See LICENSE for license information. |
6 | 6 |
|
| 7 | +import argparse |
7 | 8 | import os |
8 | 9 | import sys |
9 | | -import argparse |
| 10 | +import shutil |
| 11 | +from contextlib import nullcontext |
| 12 | +from copy import deepcopy |
10 | 13 | from dataclasses import dataclass |
| 14 | +from pathlib import Path |
11 | 15 |
|
12 | 16 | import transformer_engine.pytorch as te |
13 | 17 | import transformer_engine.common.recipe |
14 | | - |
| 18 | +from transformer_engine.pytorch import QuantizedTensor |
15 | 19 | import torch |
16 | 20 | import torch.distributed as dist |
17 | 21 | from torch.distributed.checkpoint import save, load |
|
27 | 31 | from torch.distributed import DeviceMesh |
28 | 32 | from torch.distributed._composable.fsdp import fully_shard |
29 | 33 | from torch.distributed.device_mesh import init_device_mesh |
30 | | -from transformer_engine.pytorch import QuantizedTensor |
31 | | -from contextlib import nullcontext |
32 | 34 |
|
33 | 35 | LOCAL_RANK = None |
34 | 36 |
|
| 37 | +# Needed for `torch.distributed.checkpoint.{save,load}` because |
| 38 | +# multiple processes need to write to the same directory. |
| 39 | +SHARED_TMP_DIR = "/tmp/pytest-shared-tmp" |
| 40 | + |
35 | 41 |
|
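DCP writes one shard per rank into the same `checkpoint_id` directory, so the path must be reachable by every process. Below is a minimal sketch of the create-then-synchronize pattern this constant supports; the helper name is an illustration, not part of the diff:

```python
import os

import torch.distributed as dist

SHARED_TMP_DIR = "/tmp/pytest-shared-tmp"


def ensure_shared_dir(path: str = SHARED_TMP_DIR) -> str:
    # Rank 0 creates the directory; the barrier guarantees it exists
    # before any rank starts writing checkpoint shards into it.
    if not dist.is_initialized() or dist.get_rank() == 0:
        os.makedirs(path, exist_ok=True)
    if dist.is_initialized():
        dist.barrier()
    return path
```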
36 | 42 | @dataclass |
37 | 43 | class AppState(Stateful): |
@@ -63,7 +69,7 @@ def state_dict(self): |
63 | 69 | # yet get_state_dict / _init_optim_state produce empty Tensors. |
64 | 70 | # TransformerEngine uses empty Tensors for dummy Parameters. |
65 | 71 | optimizer_state_dict["state"][fqn] = {} |
66 | | - if fqn.endswith("._extra_state"): |
| 72 | + if fqn.endswith("_extra_state"): |
67 | 73 | # Evict `_extra_state` quantization data from model checkpoint. |
68 | 74 | model_state_dict.pop(fqn) |
69 | 75 | return { |
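The suffix check is relaxed from `._extra_state` to `_extra_state` so that both dotted FQNs (e.g. `layer1._extra_state`) and a bare top-level `_extra_state` key are evicted. A condensed sketch of the `Stateful` pattern that `AppState` follows, assuming PyTorch's `get_state_dict`/`set_state_dict` helpers; this is an illustrative stand-in, not the upstream implementation:

```python
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful


class AppStateSketch(Stateful):
    """Illustrative stand-in for AppState; not the upstream implementation."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        model_sd, optim_sd = get_state_dict(self.model, self.optimizer)
        # Evict TE quantization metadata from the model checkpoint.
        for fqn in [k for k in model_sd if k.endswith("_extra_state")]:
            model_sd.pop(fqn)
        return {"model": model_sd, "optim": optim_sd}

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
```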
@@ -352,7 +358,9 @@ def test_fp8_fsdp2_allgather(model): |
352 | 358 | # FP32 manual weight allgather |
353 | 359 | fp32_allgathered_params = {} |
354 | 360 | for name, param in model.named_parameters(): |
355 | | - assert isinstance(param, DTensor) |
| 361 | + assert isinstance( |
| 362 | + param, DTensor |
| 363 | + ), f"[test_fp8_fsdp2_allgather] {param} should be a DTensor." |
356 | 364 | local_tensor = param._local_tensor |
357 | 365 | device_mesh = param.device_mesh |
358 | 366 | dist_group = ( |
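For reference, DTensor can also materialize the unsharded weight directly: `DTensor.full_tensor()` performs the all-gather that the manual FP32 path here reproduces shard by shard. A short sketch of the same check using that API (the helper name is illustrative):

```python
from torch.distributed.tensor import DTensor


def gather_full_params(model):
    # full_tensor() all-gathers each parameter's local shards across
    # the device mesh, returning a regular, replicated torch.Tensor.
    gathered = {}
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} should be a DTensor."
        gathered[name] = param.full_tensor()
    return gathered
```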
@@ -471,7 +479,7 @@ def _train(args): |
471 | 479 | optimizer = optim.Adam(model.parameters(), lr=1e-3) |
472 | 480 |
|
473 | 481 | """ |
474 | | - Pre-Save Training |
| 482 | + FSDP2 Training |
475 | 483 | """ |
476 | 484 | for iteration in range(args.iter): |
477 | 485 | # Zero the parameter gradients |
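The checkpoint test below repeatedly wraps each forward pass in `torch.autocast` for NVFP4 and a `nullcontext` otherwise. A small helper capturing that pattern, as a sketch only; the diff inlines the conditional expression instead:

```python
from contextlib import nullcontext

import torch


def maybe_bf16_autocast(recipe_name: str):
    # NVFP4 block scaling runs its forward in bf16 autocast;
    # every other recipe gets a no-op context.
    if recipe_name == "NVFP4BlockScaling":
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    return nullcontext()
```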
@@ -499,6 +507,148 @@ def _train(args): |
499 | 507 | if args.fp8_init: |
500 | 508 | test_fp8_fsdp2_allgather(model) |
501 | 509 |
|
| 510 | + """ |
| 511 | + DCP Checkpoint Testing |
| 512 | + """ |
| 513 | +    # Compute the pre-save model loss on the last random input
| 514 | +    # with respect to the last random target.
| 515 | + model.eval() |
| 516 | + with te.autocast(enabled=True, recipe=fp8_recipe): |
| 517 | + output = model(input_data) |
| 518 | + pre_save_loss = F.mse_loss(output, target) |
| 519 | + |
| 520 | +    # Save a copy of the model and optimizer state before checkpointing.
| 521 | + # NOTE(@cspades): deepcopy has issues with DTensors. Just clone(). |
| 522 | + s1 = {} |
| 523 | + for key, val in model.state_dict().items(): |
| 524 | + s1[key] = val.clone() |
| 525 | + optim_state_dict = optimizer.state_dict() |
| 526 | + o1 = {"state": {}} |
| 527 | + for idx, state in optim_state_dict["state"].items(): |
| 528 | + o1_state = o1["state"].setdefault(idx, {}) |
| 529 | + for key, val in state.items(): |
| 530 | + o1_state[key] = val.clone() |
| 531 | + o1["param_groups"] = deepcopy(optim_state_dict["param_groups"]) |
| 532 | + |
| 533 | + # Write model to checkpoint. |
| 534 | + CKPT_DIR = ( |
| 535 | + Path(SHARED_TMP_DIR) |
| 536 | + / "run_fsdp2_model" |
| 537 | + / f"dcp-{'_'.join(str(x) for x in args.sharding_dims)}-{args.layer_type}-{args.recipe}-fp8_init_{args.fp8_init}" |
| 538 | + ) |
| 539 | + CKPT_DIR.mkdir(parents=True, exist_ok=True, mode=0o777) |
| 540 | + state_dict = {"app": AppState(model=model, optimizer=optimizer)} |
| 541 | + torch.distributed.checkpoint.save(state_dict, checkpoint_id=str(CKPT_DIR)) |
| 542 | + |
| 543 | +    # Perform extra training steps to perturb the weights so that the
| 544 | +    # state parity tests below fail unless the checkpoint is loaded
| 545 | +    # without any errors or divergence from the saved model state.
| 546 | + model.train() |
| 547 | + for iteration in range(args.iter): |
| 548 | + optimizer.zero_grad() |
| 549 | + with ( |
| 550 | + torch.autocast(device_type="cuda", dtype=torch.bfloat16) |
| 551 | + if args.recipe == "NVFP4BlockScaling" |
| 552 | + else nullcontext() |
| 553 | + ): |
| 554 | + with te.autocast(enabled=True, recipe=fp8_recipe): |
| 555 | + output = model(torch.randn(inp_shape).to(device)) |
| 556 | + loss = F.mse_loss(output, torch.randn(out_shape).to(device)) |
| 557 | + loss.backward() |
| 558 | + optimizer.step() |
| 559 | + |
| 560 | + # Load the checkpoint. |
| 561 | + state_dict = {"app": AppState(model=model, optimizer=optimizer)} |
| 562 | + torch.distributed.checkpoint.load(state_dict=state_dict, checkpoint_id=str(CKPT_DIR)) |
| 563 | + |
| 564 | +    # FIXME(@cspades): DelayedScaling checkpointing has tiny uint8 parity issues
| 565 | +    # that affect the dequantized model state. Only test loss parity.
| 566 | + if args.recipe != "DelayedScaling" and args.fp8_init: |
| 567 | + # Validate checkpoint parity with pre-save state dictionaries. |
| 568 | + # Compare pre-save and post-load model state dictionaries. |
| 569 | + s2 = model.state_dict() |
| 570 | + nonempty_model_state = False |
| 571 | + for key in s1.keys() | s2.keys(): |
| 572 | + if key.endswith("_extra_state"): |
| 573 | + # Don't parity test _extra_state. Shape can change after reset_parameters(). |
| 574 | + continue |
| 575 | + v1 = s1.get(key, None) |
| 576 | + if isinstance(v1, DTensor): |
| 577 | + v1 = v1.to_local() |
| 578 | + v2 = s2.get(key, None) |
| 579 | + if isinstance(v2, DTensor): |
| 580 | + v2 = v2.to_local() |
| 581 | + assert ( |
| 582 | + v1 is not None and v2 is not None |
| 583 | + ), f"[{key} Not Found] Original Param: {v1} | Checkpoint Param: {v2}" |
| 584 | + assert ( |
| 585 | + v1.shape == v2.shape |
| 586 | + ), f"[Checkpoint Param {key} Shape Mismatch] {v1.shape} != {v2.shape}" |
| 587 | + assert torch.allclose(v1, v2), f"[Checkpoint Param {key} Value Mismatch] {v1} != {v2}" |
| 588 | + nonempty_model_state = True |
| 589 | + assert nonempty_model_state, "Model state should not be empty for evenly-sharded DTensors!" |
| 590 | + |
| 591 | + # Compare pre-save and post-load optimizer state dictionaries. |
| 592 | + o2 = optimizer.state_dict() |
| 593 | + nonempty_optim_state = False |
| 594 | + for param_id in o1["state"].keys() | o2["state"].keys(): |
| 595 | + param_state_1 = o1["state"].get(param_id, None) |
| 596 | + param_state_2 = o2["state"].get(param_id, None) |
| 597 | + assert param_state_1 is not None and param_state_2 is not None, ( |
| 598 | + f"[{param_id} Not Found] Original Optim State: {param_state_1} | Checkpoint Optim" |
| 599 | + f" State: {param_state_2}" |
| 600 | + ) |
| 601 | + for key in param_state_1.keys() | param_state_2.keys(): |
| 602 | + v1 = param_state_1.get(key, None) |
| 603 | + if isinstance(v1, DTensor): |
| 604 | + v1 = v1.to_local() |
| 605 | + v2 = param_state_2.get(key, None) |
| 606 | + if isinstance(v2, DTensor): |
| 607 | + v2 = v2.to_local() |
| 608 | + assert v1 is not None and v2 is not None, ( |
| 609 | + f"[{param_id} {key} Not Found] Original Optim State: {v1} | Checkpoint Optim" |
| 610 | + f" State: {v2}" |
| 611 | + ) |
| 612 | + assert ( |
| 613 | + v1.shape == v2.shape |
| 614 | + ), f"[Optim State {param_id} {key} Shape Mismatch] {v1.shape} != {v2.shape}" |
| 615 | + assert torch.allclose( |
| 616 | + v1, v2 |
| 617 | + ), f"[Optim State {param_id} {key} Value Mismatch] {v1} != {v2}" |
| 618 | + nonempty_optim_state = True # Optimizer state depends on wgrad, verify this! |
| 619 | + assert ( |
| 620 | + nonempty_optim_state |
| 621 | + ), "Optimizer state should not be empty for evenly-sharded DTensors!" |
| 622 | + assert len(o1["param_groups"]) == len(o2["param_groups"]), ( |
| 623 | + f"[Optim State Param Groups Length Mismatch] {o1['param_groups']} !=" |
| 624 | + f" {o2['param_groups']}" |
| 625 | + ) |
| 626 | + for i in range(len(o2["param_groups"])): |
| 627 | + for key in o1["param_groups"][i].keys(): |
| 628 | + v1 = o1["param_groups"][i][key] |
| 629 | + v2 = o2["param_groups"][i][key] |
| 630 | + assert v1 == v2, f"[Optim State Param Group {i} {key} Value Mismatch] {v1} != {v2}" |
| 631 | + |
| 632 | + # Validate post-load model loss. |
| 633 | + model.eval() |
| 634 | + with ( |
| 635 | + torch.autocast(device_type="cuda", dtype=torch.bfloat16) |
| 636 | + if args.recipe == "NVFP4BlockScaling" |
| 637 | + else nullcontext() |
| 638 | + ): |
| 639 | + with te.autocast(enabled=True, recipe=fp8_recipe): |
| 640 | + output = model(input_data) |
| 641 | + post_load_loss = F.mse_loss(output, target) |
| 642 | +    # Allow up to 1% relative disparity due to the evicted _extra_state.
| 643 | + assert torch.allclose( |
| 644 | + pre_save_loss, post_load_loss, rtol=1e-2 |
| 645 | + ), f"Pre-Save Loss: {pre_save_loss} != Post-Load Loss: {post_load_loss}" |
| 646 | + |
| 647 | + # Clean up temporary checkpoint directory. |
| 648 | + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: |
| 649 | + shutil.rmtree(CKPT_DIR) |
| 650 | + torch.distributed.barrier() |
| 651 | + |
502 | 652 | dist.destroy_process_group() |
503 | 653 | return 0 |
504 | 654 |
|
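Condensed, the round trip the new test exercises: save sharded state via DCP, perturb the weights, load the checkpoint back onto the live DTensors, then assert state and loss parity. A sketch using the `AppState`, `model`, `optimizer`, and `CKPT_DIR` defined above:

```python
import torch.distributed.checkpoint as dcp

# Every rank writes its shards into the shared checkpoint directory.
dcp.save({"app": AppState(model=model, optimizer=optimizer)}, checkpoint_id=str(CKPT_DIR))

# ... extra optimizer steps perturb the weights ...

# Shards are read back and redistributed onto the live DTensors in place.
dcp.load({"app": AppState(model=model, optimizer=optimizer)}, checkpoint_id=str(CKPT_DIR))
```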
|