@@ -720,44 +720,42 @@ The following diagram shows this pipeline for a simple function:
720720``` {code-cell} ipython3
721721:tags: [hide-input]
722722
723- fig, ax = plt.subplots(figsize=(9, 3.5))
724- ax.set_xlim(-0.5, 9)
725- ax.set_ylim(-0.5, 3)
726- ax.set_aspect('equal')
723+ fig, ax = plt.subplots(figsize=(7, 2))
724+ ax.set_xlim(-0.2, 7.2)
725+ ax.set_ylim(0.2, 2.2)
727726ax.axis('off')
728727
729728# Boxes for pipeline stages
730729stages = [
731- (0.8 , 1.5 , "Python\nfunction"),
732- (3.2 , 1.5 , "JAX\ntraces →\ncomp. graph "),
733- (5.8 , 1.5 , "XLA\ncompiles →\noptimized \nkernel"),
734- (8.2 , 1.5 , "fast\nexecution"),
730+ (0.7 , 1.2 , "Python\nfunction"),
731+ (2.6 , 1.2 , "computational\ngraph "),
732+ (4.5 , 1.2 , "optimized \nkernel"),
733+ (6.4 , 1.2 , "fast\nexecution"),
735734]
736735
737736colors = ["#e3f2fd", "#fff9c4", "#f3e5f5", "#d4edda"]
738737
739738for (x, y, label), color in zip(stages, colors):
740739 box = mpatches.FancyBboxPatch(
741- (x - 0.9 , y - 0.8 ), 1.8 , 1.6 ,
740+ (x - 0.7 , y - 0.5 ), 1.4 , 1.0 ,
742741 boxstyle="round,pad=0.15",
743742 facecolor=color, edgecolor="black", linewidth=1.5)
744743 ax.add_patch(box)
745- ax.text(x, y, label, ha='center', va='center', fontsize=10 )
744+ ax.text(x, y, label, ha='center', va='center', fontsize=9 )
746745
747- # Arrows
748- for x_start, x_end in [(1.7, 2.3), (4.1, 4.9), (6.7, 7.3)]:
749- ax.annotate("", xy=(x_end, 1.5), xytext=(x_start, 1.5),
750- arrowprops=dict(arrowstyle="->", lw=2, color="gray"))
746+ # Arrows with labels
747+ arrows = [
748+ (1.4, 1.9, "trace"),
749+ (3.3, 3.8, "XLA"),
750+ (5.2, 5.7, "run"),
751+ ]
751752
752- # Labels
753- ax.text(2.0, 0.35, "jax.jit(f)", ha='center', fontsize=9,
754- fontstyle='italic', color='gray')
755- ax.text(4.5, 0.35, "first call", ha='center', fontsize=9,
756- fontstyle='italic', color='gray')
757- ax.text(7.0, 0.35, "subsequent\ncalls", ha='center', fontsize=9,
758- fontstyle='italic', color='gray')
753+ for x_start, x_end, label in arrows:
754+ ax.annotate("", xy=(x_end, 1.2), xytext=(x_start, 1.2),
755+ arrowprops=dict(arrowstyle="->", lw=1.5, color="gray"))
756+ ax.text((x_start + x_end) / 2, 1.55, label,
757+ ha='center', fontsize=8, color='gray')
759758
760- ax.set_title("JIT Compilation Pipeline", fontsize=13, pad=10)
761759plt.tight_layout()
762760plt.show()
763761```
0 commit comments