Skip to content

Commit 627f84e

Browse files
authored
Merge pull request #432 from Modalities/improve_data_writeout_perf
Tokenization speedup and llama3-like weight init
2 parents 4190fef + bb1fd90 commit 627f84e

8 files changed

Lines changed: 417 additions & 51 deletions

File tree

src/modalities/config/instantiation_models.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import os
23
from pathlib import Path
34
from typing import Annotated, Any, Optional
@@ -27,6 +28,8 @@
2728
from modalities.util import warn_rank_0
2829
from modalities.utils.profilers.profilers import SteppableNoProfiler
2930

31+
logger = logging.getLogger(__name__)
32+
3033

3134
class CudaEnvSettings(BaseModel):
3235
local_rank: Annotated[int, Field(strict=True, ge=0)]
@@ -46,6 +49,7 @@ class ConsistencyEnforcement(BaseModel):
4649
enforce_last_step_logged: bool = True
4750
enforce_last_step_evaluated: bool = True
4851
enforce_last_step_checkpointed: bool = True
52+
enforce_enough_tokens_in_dataset: bool = True
4953

5054

5155
class Intervals(BaseModel):
@@ -192,15 +196,14 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
192196

193197
@model_validator(mode="after")
194198
def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel":
195-
if (
196-
len(self.train_dataset) * self.settings.step_profile.sequence_length
197-
< self.settings.training_target.num_target_tokens
198-
):
199-
raise ValueError(
200-
"Not enough tokens in the dataset. "
201-
f"Actual: {len(self.train_dataset) * self.settings.step_profile.sequence_length}, "
202-
f"Expected: >={self.settings.training_target.num_target_tokens}"
203-
)
199+
dataset_tokens = len(self.train_dataset) * self.settings.step_profile.sequence_length
200+
expected_tokens = self.settings.training_target.num_target_tokens
201+
if dataset_tokens < expected_tokens:
202+
msg = f"Not enough tokens in dataset. Actual: {dataset_tokens}, Expected: >={expected_tokens}"
203+
if self.settings.consistency_enforcement.enforce_enough_tokens_in_dataset:
204+
raise ValueError(msg)
205+
else:
206+
logger.warning(msg)
204207
return self
205208

206209

src/modalities/dataloader/preprocessing/tokenization/tokenized_file_writer.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import math
22
import os
33
import pickle
4-
from itertools import repeat
54
from pathlib import Path
65
from typing import BinaryIO
76

@@ -82,30 +81,56 @@ def _write_index_segment(file_descriptor: BinaryIO, index_list: list[tuple[int,
8281
def _write_data_segment(
8382
file_descriptor: BinaryIO, token_data: list[np.ndarray], token_size_in_bytes: int, write_batch_size: int
8483
) -> list[tuple[int, int]]:
85-
def encoded_token_to_bytes(encoded_token: int, token_size_in_bytes: int) -> bytes:
86-
# Converts a token id to its byte representation.
87-
try:
88-
token_bytes = encoded_token.to_bytes(token_size_in_bytes, byteorder="little", signed=False)
89-
except OverflowError as e:
90-
raise ValueError(f"Token {encoded_token} cannot be represented by {token_size_in_bytes} bytes.") from e
91-
return token_bytes
92-
93-
samples = []
94-
index_list = []
84+
# Fast path: vectorized cast + tobytes (no per-token Python work).
85+
# Preserves little-endian unsigned representation and overflow checks.
86+
87+
if token_size_in_bytes == 1:
88+
dtype = np.dtype("u1")
89+
elif token_size_in_bytes == 2:
90+
dtype = np.dtype("<u2") # force little-endian
91+
elif token_size_in_bytes == 4:
92+
dtype = np.dtype("<u4") # force little-endian
93+
else:
94+
raise ValueError("Currently only support token byte sizes of 1, 2, and 4.")
95+
96+
max_allowed = 2 ** (8 * token_size_in_bytes) - 1
97+
98+
samples: list[bytes] = []
99+
index_list: list[tuple[int, int]] = []
95100
curr_offset = 0
101+
pending = 0
102+
96103
for sample_tokens in token_data:
97-
# convert token_ids to byte representation
98-
sample_token_byte_string = b"".join(
99-
map(encoded_token_to_bytes, sample_tokens.tolist(), repeat(token_size_in_bytes))
100-
)
104+
arr = np.asarray(sample_tokens)
105+
106+
# ---- Overflow / range check (preserves original semantics) ----
107+
if arr.size:
108+
min_val = int(arr.min())
109+
max_val = int(arr.max())
110+
if min_val < 0 or max_val > max_allowed:
111+
raise ValueError(
112+
f"Token values out of range for {token_size_in_bytes} bytes: "
113+
f"min={min_val}, max={max_val}, allowed=[0, {max_allowed}]"
114+
)
115+
# ----------------------------------------------------------------
116+
117+
# Cast to correct unsigned little-endian dtype
118+
arr = np.asarray(arr, dtype=dtype, order="C")
119+
sample_token_byte_string = arr.tobytes(order="C")
120+
101121
samples.append(sample_token_byte_string)
102122
index_list.append((curr_offset, len(sample_token_byte_string)))
103123
curr_offset += len(sample_token_byte_string)
104-
if len(samples) % write_batch_size == 0:
124+
125+
pending += 1
126+
if pending >= write_batch_size:
105127
file_descriptor.write(b"".join(samples))
106-
samples = []
128+
samples.clear()
129+
pending = 0
130+
107131
if len(samples) > 0:
108132
file_descriptor.write(b"".join(samples))
133+
109134
return index_list
110135

111136
@staticmethod
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import math
2+
import re
3+
from typing import Annotated, Callable
4+
5+
import torch
6+
import torch.nn as nn
7+
from pydantic import BaseModel, Field
8+
9+
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
10+
from modalities.utils.logger_utils import get_logger
11+
12+
logger = get_logger(name="llama3 initialization")
13+
14+
15+
class Llama3InitializerConfig(BaseModel):
16+
num_layers: Annotated[int, Field(strict=True, gt=0)]
17+
n_embd: Annotated[int, Field(strict=True, gt=0)]
18+
depth_init: bool = True
19+
20+
21+
class Llama3Initializer(ModelInitializationIF):
22+
"""
23+
Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.
24+
"""
25+
26+
def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
27+
"""
28+
Initializes the Llama3Initializer.
29+
Args:
30+
num_layers: The number of transformer layers in the model. Used to calculate std for certain parameters.
31+
n_embd: The embedding dimension of the model. Used to calculate std and truncation for certain parameters.
32+
depth_init: Whether to use depth-aware initialization for certain parameters, where the std
33+
is scaled based on the layer's depth in the model. If False, a constant std is
34+
used for all layers based on num_layers.
35+
"""
36+
super().__init__()
37+
self.depth_init = depth_init
38+
39+
self.regex_to_init = {
40+
# embedding weights
41+
r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
42+
# lm head weights
43+
r"transformer\.lm_head\.weight": (
44+
trunc_normal_,
45+
{
46+
"mean": 0.0,
47+
"std": 1 / math.sqrt(n_embd),
48+
"a": -3 / math.sqrt(n_embd),
49+
"b": 3 / math.sqrt(n_embd),
50+
},
51+
),
52+
# qkv projections
53+
r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
54+
trunc_normal_,
55+
{
56+
"mean": 0.0,
57+
"std": 0.02,
58+
"a": -2,
59+
"b": 2,
60+
},
61+
),
62+
# final attention projection in attention block
63+
r"transformer\.h\.\d+\.attn\.c_proj\.weight": (
64+
trunc_normal_,
65+
{
66+
"mean": 0.0,
67+
"std": (
68+
(lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
69+
if depth_init
70+
else 0.02 / math.sqrt(2 * num_layers)
71+
),
72+
"a": -2,
73+
"b": 2,
74+
},
75+
),
76+
# SwiGLU
77+
r"transformer\.h\.\d+\.mlp\.(W)\.weight": (
78+
trunc_normal_,
79+
{
80+
"mean": 0.0,
81+
"std": 0.02,
82+
"a": -2,
83+
"b": 2,
84+
},
85+
),
86+
r"transformer\.h\.\d+\.mlp\.(V|W_2)\.weight": (
87+
trunc_normal_,
88+
{
89+
"mean": 0.0,
90+
"std": (
91+
(lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
92+
if depth_init
93+
else 0.02 / math.sqrt(2 * num_layers)
94+
),
95+
"a": -2,
96+
"b": 2,
97+
},
98+
),
99+
}
100+
101+
def initialize_in_place(self, model: nn.Module):
102+
self._init_by_fqn_regex(model, self.regex_to_init, depth_init=self.depth_init)
103+
104+
@staticmethod
105+
def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]], depth_init: bool):
106+
hits = {k: 0 for k in regex_to_init.keys()}
107+
108+
for parameter_name, p in model.named_parameters():
109+
if parameter_name.endswith("bias"):
110+
raise ValueError(
111+
f"Bias initialization is not allowed for Llama3Initializer. Found bias parameter: {parameter_name}"
112+
)
113+
match_count = 0
114+
for weight_regex in regex_to_init.keys():
115+
if re.fullmatch(weight_regex, parameter_name):
116+
init_fn, arg_dict = regex_to_init[weight_regex]
117+
if arg_dict["std"] is not None and callable(arg_dict["std"]):
118+
# If std is a function, call it with the layer_id
119+
layer_id_match = re.search(r"transformer\.h\.(\d+)\.", parameter_name)
120+
if layer_id_match is not None:
121+
layer_id = int(layer_id_match.group(1))
122+
arg_dict = arg_dict.copy() # create a copy of the arg_dict to avoid mutating the original
123+
arg_dict["std"] = arg_dict["std"](layer_id)
124+
else:
125+
raise ValueError(
126+
f"Could not extract layer_id from parameter name {parameter_name} "
127+
"for dynamic std calculation"
128+
)
129+
init_fn(p, **arg_dict)
130+
match_count += 1
131+
hits[weight_regex] += 1
132+
133+
if match_count == 0:
134+
logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
135+
elif match_count > 1:
136+
raise ValueError(
137+
f"Parameter {parameter_name} matched multiple regexes for initialization, which is not allowed"
138+
)
139+
140+
for k, count in hits.items():
141+
if count == 0:
142+
raise ValueError(
143+
f"Regex {k} did not match any FQNs. The model specification probably does not match LLama3."
144+
)
145+
146+
147+
def trunc_normal_(
148+
tensor: torch.Tensor,
149+
mean: float = 0.0,
150+
std: float = 1.0,
151+
a: float = -2.0,
152+
b: float = 2.0,
153+
):
154+
"""
155+
Fills the input tensor with values sampled from a truncated normal distribution.
156+
Values are drawn from a normal distribution with the given mean and standard
157+
deviation. Any sampled values outside the range defined by a and b are resampled
158+
until they fall within the bounds.
159+
160+
To avoid numerical instability in torch.nn.init.trunc_normal_, the initialization
161+
is always performed using float32 precision. The result is then cast back to the
162+
original data type of the input tensor.
163+
164+
Args:
165+
tensor: an n dimensional torch Tensor
166+
mean: the mean of the normal distribution
167+
std: the standard deviation of the normal distribution
168+
a: the lower bound for truncation
169+
b: the upper bound for truncation
170+
171+
Returns:
172+
The input tensor filled with values from the truncated normal distribution.
173+
"""
174+
# This function is copied from Meta's open-source project TorchTitan,
175+
# licensed under the BSD 3-Clause License.
176+
tmp = tensor.float()
177+
nn.init.trunc_normal_(tmp, mean=mean, std=std, a=a, b=b)
178+
tensor.copy_(tmp)

src/modalities/registry/components.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
)
9393
from modalities.models.gpt2.collator import GPT2LLMCollateFn
9494
from modalities.models.gpt2.gpt2_model import GPT2LLMConfig
95+
from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig
9596
from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig
9697
from modalities.models.model_factory import GPT2ModelFactory, ModelFactory
9798
from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory
@@ -240,6 +241,12 @@ class ComponentEntity:
240241
ComposedInitializationRoutines.get_composed_model_initializer,
241242
ComposedModelInitializationConfig,
242243
),
244+
ComponentEntity(
245+
"model_initialization",
246+
"gpt2_llama3_like",
247+
Llama3Initializer,
248+
Llama3InitializerConfig,
249+
),
243250
# losses
244251
ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig),
245252
# optimizers

tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ app_state_raw:
177177
component_key: app_state
178178
variant_key: raw
179179
config:
180-
model:
180+
model:
181181
instance_key: initialized_model
182182
pass_type: BY_REFERENCE
183183
optimizer:
@@ -288,7 +288,7 @@ optimizer:
288288
eps: 1e-8
289289
weight_decay: 1e-1
290290
weight_decay_groups_excluded: [embedding, layernorm]
291-
wrapped_model:
291+
wrapped_model:
292292
instance_key: initialized_model
293293
pass_type: BY_REFERENCE
294294

0 commit comments

Comments
 (0)