-
Notifications
You must be signed in to change notification settings - Fork 510
Expand file tree
/
Copy pathdiloco_test.py
More file actions
305 lines (274 loc) · 11.3 KB
/
diloco_test.py
File metadata and controls
305 lines (274 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the DiLoCo implementation in diloco.py"""
import os
import unittest
from tempfile import TemporaryDirectory, gettempdir

import chex
from flax.experimental import nnx
from flax.training import train_state
import jax
import jax.numpy as jnp
import jax.sharding
import numpy as np
import optax
import pytest

from maxtext.configs.pyconfig import initialize_pydantic
from maxtext.trainers.pre_train.train_compile import main as train_compile_main
from maxtext.trainers.diloco import diloco
from tests.utils.test_helpers import get_test_config_path
class SimpleNNXModel(nnx.Module):
  """Minimal one-layer NNX model used to check DiLoCo numerics by hand.

  Holds a single 2 -> 1 `nnx.Linear` layer whose kernel starts at
  [[2.0], [1.0]] and whose bias starts at ones, so every forward pass and
  gradient in the tests can be computed on paper.
  """

  def __init__(self, *, rngs: nnx.Rngs):
    starting_kernel = jnp.asarray([[2.0], [1.0]])
    self.dense = nnx.Linear(
        in_features=2,
        out_features=1,
        kernel_init=nnx.initializers.constant(starting_kernel),
        bias_init=nnx.initializers.ones_init(),
        rngs=rngs,
    )

  def __call__(self, x):
    """Returns `x . kernel + bias` computed by the single dense layer."""
    return self.dense(x)
class DiLoCoTest(unittest.TestCase):
  """Tests DiLoCo train-step numerics on a small mesh and DiLoCo AOT compilation."""

  @pytest.mark.tpu_only
  def test_diloco_training_simulation_with_mesh(self):
    """Runs a simulation of DiLoCo training on a mesh and asserts correctness.

    Two replicas each train SimpleNNXModel with inner SGD (lr=0.1) on their own
    fixed batch for four steps. With diloco_sync_period=3 the outer (global)
    parameters stay at their initial values for steps one and two, receive a
    momentum-based outer update after step three, and then stay fixed again
    during step four.
    """
    num_replicas = 2
    num_steps = 4
    devices = jax.devices()
    if len(devices) < num_replicas:
      self.skipTest(f"Test requires {num_replicas} devices, but only {len(devices)} are available.")
    mesh_devices = np.array(devices[:num_replicas]).reshape(1, num_replicas)
    mesh = jax.sharding.Mesh(mesh_devices, axis_names=("data", "diloco"))
    test_config = initialize_pydantic(
        [
            "",
            get_test_config_path(),
            f"dcn_diloco_parallelism={num_replicas}",
            "ici_diloco_parallelism=1",
            "diloco_outer_momentum=0.9",
            "diloco_outer_lr=1.0",
            f"diloco_sync_period={num_steps-1}",
        ]
    )

    with mesh:
      tx = optax.sgd(learning_rate=0.1)
      rngs = nnx.Rngs(params=jax.random.key(seed=42))
      model = SimpleNNXModel(rngs=rngs)
      graphdef, params = nnx.split(model)

      def nnx_apply_fn(params, inputs):
        # Rebuild the NNX module from the given params pytree so the function
        # is usable as a pure apply_fn under jax transformations.
        model_replica = nnx.merge(graphdef, params)
        return model_replica(inputs)

      # Map over the leading axis of `inputs`; `params` is shared (in_axes=None).
      vmapped_apply = jax.vmap(nnx_apply_fn, in_axes=(None, 0))

      def _test_train_step(state: train_state.TrainState, batch, prng_key: diloco.PRNGKey):
        """A simple MSE loss train step to enable numerics testing."""
        del prng_key

        def loss_fn(params, batch):
          inputs, labels = batch
          logits = vmapped_apply(params, inputs)
          residual = logits - labels
          sq_residual = jnp.square(residual)
          msq_residual = jnp.mean(sq_residual)
          return msq_residual

        loss, grad = jax.value_and_grad(loss_fn)(state.params, batch)
        return state.apply_gradients(grads=grad), loss

      initial_test_state = train_state.TrainState.create(
          apply_fn=vmapped_apply,
          params=params,
          tx=tx,
      )
      diloco_test_state, _ = diloco.build_diloco_state(test_config, lambda: initial_test_state)
      chex.assert_equal(diloco_test_state.step, 0)
      chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
      diloco_train_step = diloco.build_diloco_train_step(test_config, _test_train_step)

      inputs = jnp.array(
          [
              [[0.0, 1.0], [1.0, 0.0]],  # First replica inputs.
              [[1.0, 0.0], [0.0, 1.0]],  # Second replica inputs.
          ]
      )
      labels = jnp.array(
          [
              [[1.0], [2.0]],  # First replica labels.
              [[2.0], [3.0]],  # Second replica labels.
          ]
      )
      # NOTE(review): this spec shards axis 1 (the per-replica batch axis) over
      # "diloco" rather than the leading replica axis; presumably the DiLoCo
      # train step reshards as needed -- confirm.
      sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "diloco"))
      inputs = jax.device_put(inputs, sharding)
      labels = jax.device_put(labels, sharding)

      # Run the first step (no synchronization).
      # Replica 0:
      #   Data: [[0, 1], [1, 0]]
      #   Labels: [[1], [2]]
      #   Weights: w = [[2], [1]]
      #   Bias: b = [1]
      #   Loss = mean((y - pred)^2)
      #        = mean(([[1], [2]] - (x . w + b))^2)
      #        = mean(([[1], [2]] - ([[0, 1], [1, 0]] . [[2], [1]] + [1]))^2)
      #        = mean(([[1], [2]] - [[2], [3]])^2)
      #        = mean([-1, -1]^2) = mean([1, 1])
      #        = 1.0
      #
      # Replica 1:
      #   Data: [[1, 0], [0, 1]]
      #   Labels: [[2], [3]]
      #   Weights: w = [[2], [1]]
      #   Bias: b = [1]
      #   Loss = mean((y - pred)^2)
      #        = mean(([[2], [3]] - (x . w + b))^2)
      #        = mean(([[2], [3]] - ([[1, 0], [0, 1]] . [[2], [1]] + [1]))^2)
      #        = mean(([[2], [3]] - [[3], [2]])^2)
      #        = mean([-1, 1]^2) = mean([1, 1])
      #        = 1.0
      #
      # The reported loss is the mean over replicas: (1.0 + 1.0) / 2 = 1.0.
      diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42))
      chex.assert_equal(diloco_test_state.step, 1.0)
      chex.assert_equal(loss, 1.0)
      # Assert no updates to the global model yet (no synchronization).
      chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)

      # Run the second step (no synchronization).
      # Replica 0:
      #   Data: [[0, 1], [1, 0]]
      #   Labels: [[1], [2]]
      #   Weights: w = [[1.9], [0.9]]
      #   Bias: b = [0.8]
      #   Loss = mean((y - pred)^2)
      #        = mean(([[1], [2]] - (x . w + b))^2)
      #        = mean(([[1], [2]] - ([[0, 1], [1, 0]] . [[1.9], [0.9]] + [0.8]))^2)
      #        = mean(([[1], [2]] - [[1.7], [2.7]])^2)
      #        = mean([-0.7, -0.7]^2) = mean([0.49, 0.49])
      #        = 0.49
      #
      # Replica 1:
      #   Data: [[1, 0], [0, 1]]
      #   Labels: [[2], [3]]
      #   Weights: w = [[1.9], [1.1]]
      #   Bias: b = [1]
      #   Loss = mean((y - pred)^2)
      #        = mean(([[2], [3]] - (x . w + b))^2)
      #        = mean(([[2], [3]] - ([[1, 0], [0, 1]] . [[1.9], [1.1]] + [1]))^2)
      #        = mean(([[2], [3]] - [[2.9], [2.1]])^2)
      #        = mean([-0.9, 0.9]^2) = mean([0.81, 0.81])
      #        = 0.81
      #
      # Mean over replicas: (0.49 + 0.81) / 2 = 0.65.
      diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42))
      chex.assert_equal(diloco_test_state.step, 2.0)
      chex.assert_trees_all_close(loss, 0.65)
      # Assert no updates to the global model yet (no synchronization).
      chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)

      # Run the third step, which synchronizes afterwards.
      # Replica 0:
      #   Data: [[0, 1], [1, 0]]
      #   Labels: [[1], [2]]
      #   Weights: w = [[1.83], [0.83]]
      #   Bias: b = [0.66]
      #   Loss = mean((y - pred)^2)
      #        = mean(([[1], [2]] - (x . w + b))^2)
      #        = mean(([[1], [2]] - ([[0, 1], [1, 0]] . [[1.83], [0.83]] + [0.66]))^2)
      #        = mean(([[1], [2]] - [[1.49], [2.49]])^2)
      #        = mean([-0.49, -0.49]^2) = mean([0.2401, 0.2401])
      #        = 0.2401
      #
      # Replica 1:
      #   Data: [[1, 0], [0, 1]]
      #   Labels: [[2], [3]]
      #   Weights: w = [[1.81], [1.19]]
      #   Bias: b = [1.0]
      #   Loss = mean((y - pred)^2)
      #        = mean(([[2], [3]] - (x . w + b))^2)
      #        = mean(([[2], [3]] - ([[1, 0], [0, 1]] . [[1.81], [1.19]] + [1]))^2)
      #        = mean(([[2], [3]] - [[2.81], [2.19]])^2)
      #        = mean([-0.81, 0.81]^2) = mean([0.6561, 0.6561])
      #        = 0.6561
      #
      # Mean over replicas: (0.2401 + 0.6561) / 2 = 0.4481.
      #
      # After this step's inner SGD update the replicas hold
      #   replica 0: w = [[1.781], [0.781]], b = [0.562]
      #   replica 1: w = [[1.729], [1.271]], b = [1.0]
      # whose average is w = [[1.755], [1.026]], b = [0.781]. The difference
      # between the outer params (w = [[2], [1]], b = [1]) and this average is
      # the pseudo-gradient w: [[0.245], [-0.026]], b: [0.219], which the
      # momentum-based outer optimizer (lr=1.0, momentum=0.9) applies to the
      # outer params. The step-four weights below equal
      # outer - 1.9 * pseudo_gradient, consistent with a Nesterov-style first
      # momentum step. Every replica is then reset to the new outer params.
      diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42))
      chex.assert_equal(diloco_test_state.step, 3.0)
      chex.assert_trees_all_close(loss, 0.4481)
      # Assert that inner and outer parameters are all equal now that
      # synchronization has happened.
      chex.assert_trees_all_equal(
          diloco_test_state.params,
          jax.tree.map(lambda arr: arr[0, ...], diloco_test_state.inner_state.params),
      )
      chex.assert_trees_all_equal(
          diloco_test_state.params,
          jax.tree.map(lambda arr: arr[1, ...], diloco_test_state.inner_state.params),
      )

      # Run the fourth step (no synchronization).
      # Replica 0:
      #   Data: [[0, 1], [1, 0]]
      #   Labels: [[1], [2]]
      #   Weights: w = [[1.5345], [1.0494]]
      #   Bias: b = [0.5839]
      #   Loss = mean((y - pred)^2)
      #        = mean(([[1], [2]] - (x . w + b))^2)
      #        = mean(([[1], [2]] - ([[0, 1], [1, 0]] . [[1.5345], [1.0494]] + [0.5839]))^2)
      #        = mean(([[1], [2]] - [[1.6333], [2.1184]])^2)
      #        = mean([-0.6333, -0.1184]^2) = mean([0.4011, 0.0140])
      #        ~ 0.2075
      #
      # Replica 1:
      #   Data: [[1, 0], [0, 1]]
      #   Labels: [[2], [3]]
      #   Weights: w = [[1.5345], [1.0494]]
      #   Bias: b = [0.5839]
      #   Loss = mean((y - pred)^2)
      #        = mean(([[2], [3]] - (x . w + b))^2)
      #        = mean(([[2], [3]] - ([[1, 0], [0, 1]] . [[1.5345], [1.0494]] + [0.5839]))^2)
      #        = mean(([[2], [3]] - [[2.1184], [1.6333]])^2)
      #        = mean([-0.1184, 1.3667]^2) = mean([0.0140, 1.8679])
      #        ~ 0.9409
      #
      # Mean over replicas: (0.2075 + 0.9409) / 2 ~ 0.5742.
      step_three_outer_params = diloco_test_state.params
      diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42))
      chex.assert_equal(diloco_test_state.step, 4.0)
      chex.assert_trees_all_close(loss, 0.574244)
      # Assert no updates to the global model since the previous step (no
      # synchronization).
      chex.assert_trees_all_equal(diloco_test_state.params, step_three_outer_params)

  def _compile_two_slice_diloco(self, artifact_name: str, compile_topology: str, model_name: str):
    """AOT-compiles a DiLoCo train step across two slices via train_compile.

    Args:
      artifact_name: File name used for the compiled train step artifact.
      compile_topology: Value passed as the `compile_topology` option.
      model_name: Value passed as the `model_name` option.

    The artifact is written into a fresh temporary directory that is deleted
    when compilation returns. A fixed file name in the shared system temp dir
    could collide between concurrent or multi-user runs, and would be left
    behind after the test.
    """
    with TemporaryDirectory() as temp_dir:
      compiled_trainstep_file = os.path.join(temp_dir, artifact_name)
      train_compile_main(
          (
              None,
              get_test_config_path(),
              f"compiled_trainstep_file={compiled_trainstep_file}",
              f"compile_topology={compile_topology}",
              "compile_topology_num_slices=2",
              "ici_fsdp_parallelism=-1",
              "dcn_diloco_parallelism=2",
              "enable_diloco=true",
              f"model_name={model_name}",
          )
      )

  @pytest.mark.cpu_only
  def test_diloco_qwen3_moe_two_slices(self):
    """AOT-compiles a two-slice DiLoCo train step for qwen3-30b-a3b on tpu7x-16."""
    self._compile_two_slice_diloco(
        artifact_name="test_compiled_diloco_qwen3_moe.pickle",
        compile_topology="tpu7x-16",
        model_name="qwen3-30b-a3b",
    )

  @pytest.mark.tpu_only
  def test_diloco_two_slices(self):
    """AOT-compiles a two-slice DiLoCo train step for gemma2-2b on tpu7x-8."""
    self._compile_two_slice_diloco(
        artifact_name="test_compiled_diloco.pickle",
        compile_topology="tpu7x-8",
        model_name="gemma2-2b",
    )