float8.py (forked from pytorch/torchtitan)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# [Note] Getting the 'torchao' package:
# This script requires the 'torchao' package to function correctly.
# Please ensure you have this package installed from the appropriate repository.
# You can obtain it from https://github.com/pytorch/ao by following the
# installation instructions.

# Note: Performance
# Float8 experimental is intended to be run under `torch.compile` for competitive performance.
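
# A minimal illustrative sketch (not from the original file): convert first, then compile,
# so the float8 casts and scaled matmuls can be fused by the compiler:
#
#   converter.convert(model)
#   model = torch.compile(model)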
from typing import List, Union

import torch
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.model_converter import ModelConverter, register_model_converter
from torchtitan.parallelisms import ParallelDims


def _is_sm89_or_later():
    # Float8 is only supported on SM89 or later (H100+ GPUs)
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


class Float8Converter(ModelConverter):
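    """Swap eligible `nn.Linear` modules in a model for torchao `Float8Linear`.

    Configuration comes from `job_config.float8`. The converter disables itself on
    GPUs older than SM89 and raises ImportError if torchao is not installed.
    """
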
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        self.enabled = False

        float8_config = job_config.float8
        if not _is_sm89_or_later():
            logger.warning(
                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
            )
            return
        try:
            from torchao.float8 import Float8LinearConfig
        except ImportError as e:
            raise ImportError(
                "torchao is not installed. Please install it to use float8 linear layers."
            ) from e

        # Build the config used later to swap instances of torch.nn.Linear with Float8Linear
        enable_fsdp_float8_all_gather = (
            parallel_dims.dp_shard_enabled
            and float8_config.enable_fsdp_float8_all_gather
        )
        self.config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
            force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
        )

        self.enabled = True

        # for precompute_float8_dynamic_scale_for_fsdp
        self.precompute_scale = (
            enable_fsdp_float8_all_gather
            and float8_config.precompute_float8_dynamic_scale_for_fsdp
        )

        logger.info("Float8 training active")

    def convert(self, model: nn.Module):
        return self.convert_to_float8_training(model)

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        return self.precompute_float8_dynamic_scale_for_fsdp(model)

    def convert_to_float8_training(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Note that today, only dynamic tensor scaling (the default) is supported.
        This will mutate the model inplace.
        """
        if not self.enabled:
            return

        from torchao.float8 import convert_to_float8_training

        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
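        # module_filter_fn returns False for the module whose fully-qualified name is
        # "output", so that final projection (the output layer in torchtitan's Llama
        # models) is left as a plain nn.Linear rather than swapped to Float8Linear.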
        convert_to_float8_training(
            model,
            config=self.config,
            module_filter_fn=lambda mod, fqn: fqn != "output",
        )
        logger.info(
            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
            f"{self.config.enable_fsdp_float8_all_gather}"
        )

    def precompute_float8_dynamic_scale_for_fsdp(
        self, model: Union[nn.Module, List[nn.Module]]
    ):
        if not self.enabled:
            return
        if not self.precompute_scale:
            return

        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

        models = [model] if isinstance(model, nn.Module) else model
        for m in models:
            precompute_float8_dynamic_scale_for_fsdp(m)

register_model_converter(Float8Converter, "float8")
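
# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal sketch of the converter lifecycle, assuming `job_config`, `parallel_dims`,
# `model`, `optimizer`, and `dataloader` come from a torchtitan training setup (the
# names below are placeholders, not a definitive API):
#
#   converter = Float8Converter(job_config, parallel_dims)
#   converter.convert(model)                  # swap nn.Linear -> Float8Linear in place
#   model = torch.compile(model)              # float8 relies on compile for performance
#   for batch in dataloader:
#       loss = model(batch).mean()
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       converter.post_optimizer_hook(model)  # precompute fp8 scales for FSDP all-gather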