Skip to content

Latest commit

 

History

History
91 lines (66 loc) · 2.49 KB

File metadata and controls

91 lines (66 loc) · 2.49 KB

Data Parallelism

MLX enables efficient data parallel distributed training through its distributed communication primitives.

Training Example

In this section we will adapt an MLX training loop to support data parallel distributed training. Namely, we will average the gradients across a set of hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model, dataset, and optimizer initialization.

model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an :func:`all_sum` and divide by the size of the :class:`Group`. Namely we have to :func:`mlx.utils.tree_map` the gradients with following function.

def all_avg(x):
    return mx.distributed.all_sum(x) / mx.distributed.init().size()

Putting everything together our training loop step looks as follows with everything else remaining the same.

from mlx.utils import tree_map

def all_reduce_grads(grads):
    N = mx.distributed.init().size()
    if N == 1:
        return grads
    return tree_map(
        lambda x: mx.distributed.all_sum(x) / N,
        grads
    )

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = all_reduce_grads(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss

Using nn.average_gradients

Although the code example above works correctly; it performs one communication per gradient. It is significantly more efficient to aggregate several gradients together and perform fewer communication steps.

This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks almost identical to the example above:

model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = mx.nn.average_gradients(grads)  # <---- This line was added
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())