docs/src/examples/tensor_parallelism.rst
+4 -2 (4 additions, 2 deletions)
@@ -1,3 +1,5 @@
+.. _tensor_parallelism:
+
 Tensor Parallelism
 ==================
@@ -60,7 +62,7 @@ We can create partial inputs based on rank. For example, for an input with 1024
 layer = nn.ShardedToAllLinear(1024, 1024, bias=False) # initialize the layer
 y = layer(x[part]) # process sharded input
 
-This code splits the 1024 input features into ``world.size()`` different groups which are assigned continuously based on ``world.rank()``. More information about distributed communication can be found in the :doc:`Distributed Communication <../usage/distributed>` page.
+This code splits the 1024 input features into ``world.size()`` different groups which are assigned continuously based on ``world.rank()``. More information about distributed communication can be found in the :ref:`Distributed Communication <usage_distributed>` page.
 
 :class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
 Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
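
(Aside, not part of the diff.) A minimal sketch of how the sharded input used in the hunk above could be constructed. It assumes a 1-D input ``x`` with 1024 features and builds the contiguous slice ``part`` from the group returned by ``mx.distributed.init()``; these names and shapes are illustrative and not taken from the documentation page itself:

   import mlx.core as mx
   import mlx.nn as nn

   world = mx.distributed.init()        # distributed communication group
   x = mx.random.normal((1024,))        # full input, 1024 features (assumed 1-D here)

   # Each rank takes one contiguous block of the feature dimension.
   block = 1024 // world.size()
   part = slice(world.rank() * block, (world.rank() + 1) * block)

   layer = nn.ShardedToAllLinear(1024, 1024, bias=False)  # initialize the layer
   y = layer(x[part])                   # process only this rank's shard of the input

Such a script would be launched once per process (for example with ``mlx.launch`` or ``mpirun``), so ``world.rank()`` and ``world.size()`` differ across the cooperating processes.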
@@ -117,7 +119,7 @@ LLM Inference with Tensor Parallelism
 
 We can apply these TP techniques to LLMs in order to enable inference for much larger models by sharding parameters from huge layers across multiple devices.
 
-To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference <../examples/llama-inference>` example. In this example, we will use the same inference script as the Llama Inference example, which can be found in `mlx-examples`_.
+To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference <llama-inference>` example. In this example, we will use the same inference script as the Llama Inference example, which can be found in `mlx-examples`_.
 
 Our first edit is to initialize the distributed communication group and get the current process rank:
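
(Aside, not part of the diff.) The code block that follows this sentence in the documentation is not included in the hunk above; a minimal sketch of that first edit, using the public ``mx.distributed`` API, could look like the following (variable names are illustrative):

   import mlx.core as mx

   # Initialize the distributed communication group once at startup.
   world = mx.distributed.init()
   rank = world.rank()     # index of this process within the group
   size = world.size()     # total number of cooperating processes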