Skip to content

Commit 4e439c5

Browse files
jstacclaude
andcommitted
Improve vmap example and consolidate imports
- Replace artificial sum-of-squares vmap example with mean/median statistics - Add explanation of why Python loops are inefficient with JAX - Move all imports to initial cell per lecture conventions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 90963e1 commit 4e439c5

1 file changed

Lines changed: 27 additions & 15 deletions

File tree

lectures/jax_intro.md

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -882,34 +882,46 @@ This avoids the need to manually write vectorized code or use explicit loops.
882882

883883
### A simple example
884884

885-
Suppose we have a function that processes a single vector:
885+
Suppose we have a function that computes summary statistics for a single array:
886886

887887
```{code-cell} ipython3
888-
def scalar_func(x):
889-
return jnp.sum(x ** 2)
888+
def summary(x):
889+
return jnp.mean(x), jnp.median(x)
890890
```
891891

892892
We can apply it to a single vector:
893893

894894
```{code-cell} ipython3
895-
x = jnp.array([1.0, 2.0, 3.0])
896-
scalar_func(x)
895+
x = jnp.array([1.0, 2.0, 5.0])
896+
summary(x)
897897
```
898898

899-
To apply it across a *batch* of vectors (i.e., rows of a matrix), we use `vmap`:
899+
Now suppose we have a matrix and want to compute these statistics for each row.
900900

901-
```{code-cell} ipython3
902-
batch_func = jax.vmap(scalar_func)
901+
Without `vmap`, we'd need an explicit loop:
903902

904-
X = jnp.array([[1.0, 2.0, 3.0],
903+
```{code-cell} ipython3
904+
X = jnp.array([[1.0, 2.0, 5.0],
905905
[4.0, 5.0, 6.0],
906-
[7.0, 8.0, 9.0]])
906+
[1.0, 8.0, 9.0]])
907+
908+
for row in X:
909+
print(summary(row))
910+
```
907911

908-
batch_func(X)
912+
However, Python loops are slow and cannot be efficiently compiled or
913+
parallelized by JAX.
914+
915+
Using `vmap` keeps the computation on the accelerator and composes with other
916+
JAX transformations like `jit` and `grad`:
917+
918+
```{code-cell} ipython3
919+
batch_summary = jax.vmap(summary)
920+
batch_summary(X)
909921
```
910922

911-
Without `vmap`, we would need to write an explicit loop or reshape the
912-
computation manually.
923+
The function `summary` was written for a single array, and `vmap` automatically
924+
lifted it to operate row-wise over a matrix --- no loops, no reshaping.
913925

914926
### Combining transformations
915927

@@ -918,8 +930,8 @@ One of JAX's strengths is that transformations compose naturally.
918930
For example, we can JIT-compile a vectorized function:
919931

920932
```{code-cell} ipython3
921-
fast_batch_func = jax.jit(jax.vmap(scalar_func))
922-
fast_batch_func(X)
933+
fast_batch_summary = jax.jit(jax.vmap(summary))
934+
fast_batch_summary(X)
923935
```
924936

925937
This composition of `jit`, `vmap`, and (as we'll see next) `grad` is central to

0 commit comments

Comments
 (0)