@@ -882,34 +882,46 @@ This avoids the need to manually write vectorized code or use explicit loops.
882882
883883### A simple example
884884
885- Suppose we have a function that processes a single vector :
885+ Suppose we have a function that computes summary statistics for a single array :
886886
887887``` {code-cell} ipython3
888- def scalar_func (x):
889- return jnp.sum(x ** 2 )
888+ def summary (x):
889+ return jnp.mean(x), jnp.median(x )
890890```
891891
892892We can apply it to a single vector:
893893
894894``` {code-cell} ipython3
895- x = jnp.array([1.0, 2.0, 3 .0])
896- scalar_func (x)
895+ x = jnp.array([1.0, 2.0, 5 .0])
896+ summary (x)
897897```
898898
899- To apply it across a * batch * of vectors (i.e., rows of a matrix), we use ` vmap ` :
899+ Now suppose we have a matrix and want to compute these statistics for each row.
900900
901- ``` {code-cell} ipython3
902- batch_func = jax.vmap(scalar_func)
901+ Without ` vmap ` , we'd need an explicit loop:
903902
904- X = jnp.array([[1.0, 2.0, 3.0],
903+ ``` {code-cell} ipython3
904+ X = jnp.array([[1.0, 2.0, 5.0],
905905 [4.0, 5.0, 6.0],
906- [7.0, 8.0, 9.0]])
906+ [1.0, 8.0, 9.0]])
907+
908+ for row in X:
909+ print(summary(row))
910+ ```
907911
908- batch_func(X)
912+ However, Python loops are slow and cannot be efficiently compiled or
913+ parallelized by JAX.
914+
915+ Using ` vmap ` keeps the computation on the accelerator and composes with other
916+ JAX transformations like ` jit ` and ` grad ` :
917+
918+ ``` {code-cell} ipython3
919+ batch_summary = jax.vmap(summary)
920+ batch_summary(X)
909921```
910922
911- Without ` vmap ` , we would need to write an explicit loop or reshape the
912- computation manually .
923+ The function ` summary ` was written for a single array, and ` vmap ` automatically
924+ lifted it to operate row-wise over a matrix --- no loops, no reshaping .
913925
914926### Combining transformations
915927
@@ -918,8 +930,8 @@ One of JAX's strengths is that transformations compose naturally.
918930For example, we can JIT-compile a vectorized function:
919931
920932``` {code-cell} ipython3
921- fast_batch_func = jax.jit(jax.vmap(scalar_func ))
922- fast_batch_func (X)
933+ fast_batch_summary = jax.jit(jax.vmap(summary ))
934+ fast_batch_summary (X)
923935```
924936
925937This composition of ` jit ` , ` vmap ` , and (as we'll see next) ` grad ` is central to
0 commit comments