
Commit 33d4211

stefpi authored and Awni Hannun committed
docs: extract data parallel training into seperate example doc

1 parent 8d93b91 · commit 33d4211

4 files changed

Lines changed: 99 additions & 84 deletions

File tree

- docs/src/examples/data_parallelism.rst
- docs/src/examples/tensor_parallelism.rst
- docs/src/index.rst
- docs/src/usage/distributed.rst

docs/src/examples/data_parallelism.rst

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
.. _data_parallelism:

Data Parallelism
================

MLX enables efficient data parallel distributed training through its
distributed communication primitives.
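The snippets below rely on the distributed group returned by
:func:`mlx.distributed.init`. As a quick orientation, here is a minimal
sketch (assuming the script is launched with several processes) showing how
each process can query its rank and the size of the group:

.. code:: python

    import mlx.core as mx

    group = mx.distributed.init()  # initialize (or fetch) the global group
    print(f"This is rank {group.rank()} of {group.size()} processes")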
.. _training_example:

Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.

.. code:: python

    model = ...
    optimizer = ...
    dataset = ...

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    for x, y in dataset:
        loss = step(model, x, y)
        mx.eval(loss, model.parameters())
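Note that for data parallelism to help, each process should also iterate over
a different portion of the data; otherwise every host computes identical
gradients. How to do this depends entirely on your data loader. A minimal
sketch, assuming ``dataset`` is a sliceable sequence of ``(x, y)`` pairs:

.. code:: python

    group = mx.distributed.init()

    # Hypothetical sharding: every rank takes a strided slice of the samples.
    dataset = dataset[group.rank()::group.size()]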
All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we
have to :func:`mlx.utils.tree_map` the gradients with the following function.

.. code:: python

    def all_avg(x):
        return mx.distributed.all_sum(x) / mx.distributed.init().size()
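If you need to average over an explicit sub-group rather than the global
group, :func:`all_sum` also accepts a ``group`` argument. A small sketch (the
helper name is ours):

.. code:: python

    def all_avg_in(x, group):
        # Average over the given group instead of the global one.
        return mx.distributed.all_sum(x, group=group) / group.size()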
Putting everything together, our training loop step looks as follows, with
everything else remaining the same.

.. code:: python

    from mlx.utils import tree_map

    def all_reduce_grads(grads):
        N = mx.distributed.init().size()
        if N == 1:
            return grads
        return tree_map(
            lambda x: mx.distributed.all_sum(x) / N,
            grads
        )

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        grads = all_reduce_grads(grads)  # <--- This line was added
        optimizer.update(model, grads)
        return loss
Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.
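To make the idea concrete, the following is a rough sketch of such batched
aggregation: flatten the gradient tree, concatenate everything into one
buffer, perform a single :func:`all_sum`, and split the result back. This is
purely illustrative and says nothing about how MLX implements it internally.

.. code:: python

    from itertools import accumulate

    import mlx.core as mx
    from mlx.utils import tree_flatten, tree_unflatten

    def naive_batched_average(grads):
        # Flatten the gradient tree into (path, array) pairs.
        flat = tree_flatten(grads)
        shapes = [v.shape for _, v in flat]
        sizes = [v.size for _, v in flat]

        # One communication for all gradients instead of one per gradient.
        buf = mx.concatenate([v.reshape(-1) for _, v in flat])
        buf = mx.distributed.all_sum(buf) / mx.distributed.init().size()

        # Split the summed buffer back into the original shapes.
        pieces = mx.split(buf, list(accumulate(sizes))[:-1])
        return tree_unflatten(
            [(k, p.reshape(s)) for (k, _), p, s in zip(flat, pieces, shapes)]
        )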
This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:

.. code:: python

    model = ...
    optimizer = ...
    dataset = ...

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        grads = mlx.nn.average_gradients(grads)  # <---- This line was added
        optimizer.update(model, grads)
        return loss

    for x, y in dataset:
        loss = step(model, x, y)
        mx.eval(loss, model.parameters())
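One practical detail when running this loop across several hosts: because all
processes apply the same averaged gradients, the model parameters stay
synchronized, so it is common to let only one rank handle logging and
checkpointing. A minimal sketch, assuming ``model`` is an
:class:`mlx.nn.Module` and the file name is a placeholder:

.. code:: python

    group = mx.distributed.init()

    if group.rank() == 0:
        # Only one process writes logs and checkpoints to avoid duplicates.
        print(f"loss: {loss.item():.3f}")
        model.save_weights("checkpoint.safetensors")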

docs/src/examples/tensor_parallelism.rst

Lines changed: 4 additions & 2 deletions
@@ -1,3 +1,5 @@
+.. _tensor_parallelism:
+
 Tensor Parallelism
 ==================

@@ -60,7 +62,7 @@ We can create partial inputs based on rank. For example, for an input with 1024
     layer = nn.ShardedToAllLinear(1024, 1024, bias=False)  # initialize the layer
     y = layer(x[part])  # process sharded input

-This code splits the 1024 input features into ``world.size()`` different groups which are assigned continuously based on ``world.rank()``. More information about distributed communication can be found in the :doc:`Distributed Communication <../usage/distributed>` page.
+This code splits the 1024 input features into ``world.size()`` different groups which are assigned continuously based on ``world.rank()``. More information about distributed communication can be found in the :ref:`Distributed Communication <usage_distributed>` page.

 :class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
 Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
@@ -117,7 +119,7 @@ LLM Inference with Tensor Parallelism
 We can apply these TP techniques to LLMs in order to enable inference for much larger models by sharding parameters from huge layers across multiple devices.

-To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference <../examples/llama-inference>` example. In this example, we will use the same inference script as the Llama Inference example, which can be found in `mlx-examples`_.
+To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference <llama-inference>` example. In this example, we will use the same inference script as the Llama Inference example, which can be found in `mlx-examples`_.

 Our first edit is to initialize the distributed communication group and get the current process rank:

docs/src/index.rst

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ are the CPU and GPU.
    examples/linear_regression
    examples/mlp
    examples/llama-inference
+   examples/data_parallelism
    examples/tensor_parallelism

 .. toctree::

docs/src/usage/distributed.rst

Lines changed: 4 additions & 82 deletions
@@ -117,89 +117,11 @@ The following examples aim to clarify the backend initialization logic in MLX:
     world_ring = mx.distributed.init(backend="ring")
     world_any = mx.distributed.init()  # same as MPI because it was initialized first!

-.. _training_example:
+Distributed Program Examples
+----------------------------

-Training Example
-----------------
-
-In this section we will adapt an MLX training loop to support data parallel
-distributed training. Namely, we will average the gradients across a set of
-hosts before applying them to the model.
-
-Our training loop looks like the following code snippet if we omit the model,
-dataset and optimizer initialization.
-
-.. code:: python
-
-    model = ...
-    optimizer = ...
-    dataset = ...
-
-    def step(model, x, y):
-        loss, grads = loss_grad_fn(model, x, y)
-        optimizer.update(model, grads)
-        return loss
-
-    for x, y in dataset:
-        loss = step(model, x, y)
-        mx.eval(loss, model.parameters())
-
-All we have to do to average the gradients across machines is perform an
-:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
-have to :func:`mlx.utils.tree_map` the gradients with following function.
-
-.. code:: python
-
-    def all_avg(x):
-        return mx.distributed.all_sum(x) / mx.distributed.init().size()
-
-Putting everything together our training loop step looks as follows with
-everything else remaining the same.
-
-.. code:: python
-
-    from mlx.utils import tree_map
-
-    def all_reduce_grads(grads):
-        N = mx.distributed.init().size()
-        if N == 1:
-            return grads
-        return tree_map(
-            lambda x: mx.distributed.all_sum(x) / N,
-            grads
-        )
-
-    def step(model, x, y):
-        loss, grads = loss_grad_fn(model, x, y)
-        grads = all_reduce_grads(grads) # <--- This line was added
-        optimizer.update(model, grads)
-        return loss
-
-Utilizing ``nn.average_gradients``
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Although the code example above works correctly; it performs one communication
-per gradient. It is significantly more efficient to aggregate several gradients
-together and perform fewer communication steps.
-
-This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
-almost identical to the example above:
-
-.. code:: python
-
-    model = ...
-    optimizer = ...
-    dataset = ...
-
-    def step(model, x, y):
-        loss, grads = loss_grad_fn(model, x, y)
-        grads = mx.nn.average_gradients(grads) # <---- This line was added
-        optimizer.update(model, grads)
-        return loss
-
-    for x, y in dataset:
-        loss = step(model, x, y)
-        mx.eval(loss, model.parameters())
+- :ref:`Data Parallelism <data_parallelism>`
+- :ref:`Tensor Parallelism <tensor_parallelism>`

 .. _ring_section: