Skip to content

Commit 57946ac

Browse files
jstacclaude
andcommitted
Fix JIT pipeline figure: make compact, drop italic labels
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4e439c5 commit 57946ac

1 file changed

Lines changed: 20 additions & 22 deletions

File tree

lectures/jax_intro.md

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -720,44 +720,42 @@ The following diagram shows this pipeline for a simple function:
720720
```{code-cell} ipython3
721721
:tags: [hide-input]
722722
723-
fig, ax = plt.subplots(figsize=(9, 3.5))
724-
ax.set_xlim(-0.5, 9)
725-
ax.set_ylim(-0.5, 3)
726-
ax.set_aspect('equal')
723+
fig, ax = plt.subplots(figsize=(7, 2))
724+
ax.set_xlim(-0.2, 7.2)
725+
ax.set_ylim(0.2, 2.2)
727726
ax.axis('off')
728727
729728
# Boxes for pipeline stages
730729
stages = [
731-
(0.8, 1.5, "Python\nfunction"),
732-
(3.2, 1.5, "JAX\ntraces →\ncomp. graph"),
733-
(5.8, 1.5, "XLA\ncompiles →\noptimized\nkernel"),
734-
(8.2, 1.5, "fast\nexecution"),
730+
(0.7, 1.2, "Python\nfunction"),
731+
(2.6, 1.2, "computational\ngraph"),
732+
(4.5, 1.2, "optimized\nkernel"),
733+
(6.4, 1.2, "fast\nexecution"),
735734
]
736735
737736
colors = ["#e3f2fd", "#fff9c4", "#f3e5f5", "#d4edda"]
738737
739738
for (x, y, label), color in zip(stages, colors):
740739
box = mpatches.FancyBboxPatch(
741-
(x - 0.9, y - 0.8), 1.8, 1.6,
740+
(x - 0.7, y - 0.5), 1.4, 1.0,
742741
boxstyle="round,pad=0.15",
743742
facecolor=color, edgecolor="black", linewidth=1.5)
744743
ax.add_patch(box)
745-
ax.text(x, y, label, ha='center', va='center', fontsize=10)
744+
ax.text(x, y, label, ha='center', va='center', fontsize=9)
746745
747-
# Arrows
748-
for x_start, x_end in [(1.7, 2.3), (4.1, 4.9), (6.7, 7.3)]:
749-
ax.annotate("", xy=(x_end, 1.5), xytext=(x_start, 1.5),
750-
arrowprops=dict(arrowstyle="->", lw=2, color="gray"))
746+
# Arrows with labels
747+
arrows = [
748+
(1.4, 1.9, "trace"),
749+
(3.3, 3.8, "XLA"),
750+
(5.2, 5.7, "run"),
751+
]
751752
752-
# Labels
753-
ax.text(2.0, 0.35, "jax.jit(f)", ha='center', fontsize=9,
754-
fontstyle='italic', color='gray')
755-
ax.text(4.5, 0.35, "first call", ha='center', fontsize=9,
756-
fontstyle='italic', color='gray')
757-
ax.text(7.0, 0.35, "subsequent\ncalls", ha='center', fontsize=9,
758-
fontstyle='italic', color='gray')
753+
for x_start, x_end, label in arrows:
754+
ax.annotate("", xy=(x_end, 1.2), xytext=(x_start, 1.2),
755+
arrowprops=dict(arrowstyle="->", lw=1.5, color="gray"))
756+
ax.text((x_start + x_end) / 2, 1.55, label,
757+
ha='center', fontsize=8, color='gray')
759758
760-
ax.set_title("JIT Compilation Pipeline", fontsize=13, pad=10)
761759
plt.tight_layout()
762760
plt.show()
763761
```

0 commit comments

Comments
 (0)