Skip to content

Commit 72a8f72

Browse files
committed
add jaxpr dump support
1 parent 8204907 commit 72a8f72

8 files changed

Lines changed: 452 additions & 177 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ profile_periodically_period: -1 # If set to a positive integer, profile every pr
696696
managed_mldiagnostics: False # Whether to enable the managed diagnostics
697697
managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs.
698698

699-
# Dump HLO options
699+
# Dump HLO and jaxpr options
700700
dump_hlo: False
701701
dump_step: -1 # Dump modules at the given step if set to a positive integer.
702702
dump_hlo_local_dir: "/tmp/xla_dump/"
@@ -708,6 +708,10 @@ dump_hlo_xla_flags: "" # Defaults to "--xla_dump_to={dump_hlo_local_dir} --xla_d
708708
dump_hlo_upload_all: False # If true all hosts dump HLO, false only jax.process_index()==0
709709
# All hosts should have identical HLO for SPMD programs, however we have encountered some bugs
710710
# where this is not the case and it is helpful to compare HLO across hosts.
711+
dump_jaxpr: False
712+
dump_jaxpr_local_dir: "/tmp/jaxpr_dump/"
713+
dump_jaxpr_delete_local_after: True
714+
dump_jaxpr_gcs_dir: "" # Defaults to {base_output_directory}/{run_name}/jaxpr_dump
711715

712716
# When dropout is false the model is a deterministic function of the
713717
# data_shuffle_seed and init_weights_seed (i.e. reproducible losses)

src/MaxText/configs/types.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,13 @@ class HloDump(BaseModel):
12891289
dump_hlo_local_module_name: str = Field("jit_train_step", description="Filter modules to save locally by this name.")
12901290
dump_hlo_xla_flags: str = Field("", description="Pass custom XLA flags for HLO dumping.")
12911291
dump_hlo_upload_all: bool = Field(False, description="Upload HLO from all hosts.")
1292+
dump_jaxpr: bool = Field(False, description="Enable jaxpr dumping.")
1293+
dump_jaxpr_local_dir: PathStr = Field(
1294+
os.path.join(gettempdir(), "jaxpr_dump", ""),
1295+
description="Local directory to dump jaxpr.",
1296+
)
1297+
dump_jaxpr_delete_local_after: bool = Field(True, description="Delete local jaxpr dump after uploading to GCS.")
1298+
dump_jaxpr_gcs_dir: PathStr = Field("", description="GCS directory to upload jaxpr dumps.")
12921299

12931300

12941301
class StackTrace(BaseModel):
@@ -1837,6 +1844,10 @@ def validate_and_set_hlo_dump_defaults():
18371844
self.dump_hlo_gcs_dir = os.path.join(self.base_output_directory, self.run_name, "xla_dump")
18381845
else:
18391846
self.dump_hlo_gcs_dir = gcs_utils.add_trailing_slash(self.dump_hlo_gcs_dir)
1847+
if not self.dump_jaxpr_gcs_dir:
1848+
self.dump_jaxpr_gcs_dir = os.path.join(self.base_output_directory, self.run_name, "jaxpr_dump")
1849+
else:
1850+
self.dump_jaxpr_gcs_dir = gcs_utils.add_trailing_slash(self.dump_jaxpr_gcs_dir)
18401851
if not os.environ.get("XLA_FLAGS"):
18411852
os.environ["XLA_FLAGS"] = self.dump_hlo_xla_flags
18421853

src/MaxText/maxtext_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import functools
1919
import pickle
20+
import os
2021

2122
from flax import linen as nn
2223
from flax.linen import partitioning as nn_partitioning
@@ -41,6 +42,7 @@
4142
from MaxText import multimodal_utils
4243
from MaxText import sharding
4344
from MaxText.configs import types
45+
from MaxText.utils import gcs_utils
4446
from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
4547
from MaxText.inference.page_manager import PageState
4648

@@ -1213,3 +1215,33 @@ def print_shardings_params(params, params_sharding, mesh):
12131215
shape = jax.typeof(leaf_val)
12141216
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
12151217
max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}")
1218+
1219+
1220+
def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
  """Trace a step function to a jaxpr, write it to disk, and optionally upload it.

  Does nothing unless `config.dump_jaxpr` is truthy. Otherwise the jaxpr text is
  written to `<config.dump_jaxpr_local_dir>/train_step.jaxpr`, and when
  `config.dump_jaxpr_gcs_dir` is non-empty the local dump directory is handed to
  `gcs_utils.upload_dump`.

  Args:
    config: Config object. Reads `dump_jaxpr`, `dump_jaxpr_local_dir`,
      `dump_jaxpr_gcs_dir` and `dump_jaxpr_delete_local_after`.
    p_train_step: Callable to trace (callers pass the jitted step function).
    train_step_inputs: Sequence of positional arguments used for tracing; it is
      unpacked into the traced call.

  Returns:
    None.
  """
  if not config.dump_jaxpr:
    return
  max_logging.log("Tracing train_step to jaxpr...")

  # Trace the (already jitted) function itself rather than an undecorated copy.
  p_train_jaxpr = jax.make_jaxpr(p_train_step)(*train_step_inputs)

  local_filename = "train_step.jaxpr"
  local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename)

  os.makedirs(config.dump_jaxpr_local_dir, exist_ok=True)

  # Explicit encoding keeps the dump independent of the host's locale settings.
  with open(local_path, "w", encoding="utf-8") as f:
    f.write(str(p_train_jaxpr))

  max_logging.log(f"Jaxpr dumped locally to {local_path}")

  if config.dump_jaxpr_gcs_dir:
    gcs_utils.upload_dump(
        config.dump_jaxpr_local_dir,
        config.dump_jaxpr_gcs_dir,
        # Restrict the upload to the file written above.
        module_name=local_filename,
        # When set, upload_dump removes the local dump directory after uploading.
        delete_local_after=config.dump_jaxpr_delete_local_after,
        # Do not request an upload from every host.
        all_host_upload=False,
    )

src/MaxText/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ def train_loop(config, recorder, state=None):
427427
shaped_batch = maxtext_utils.get_shaped_batch(config)
428428
if config.shard_optimizer_over_data:
429429
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
430+
maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (state, shaped_batch, init_rng))
430431
if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded
431432
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
432433
compiled_stats = compiled.memory_analysis()

src/MaxText/train_compile.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def jit_and_compile(
121121
out_shardings,
122122
static_argnums,
123123
donate_argnums,
124+
config,
124125
logical_axis_rules,
125126
):
126127
"""Jit, lower, and compile func."""
@@ -132,6 +133,7 @@ def jit_and_compile(
132133
static_argnums=static_argnums,
133134
donate_argnums=donate_argnums,
134135
)
136+
maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
135137
lowered = jitted.lower(*func_input_args, **func_input_kwargs)
136138
compiled = lowered.compile()
137139
return compiled
@@ -180,6 +182,7 @@ def is_oom(argv: Sequence[str]) -> bool:
180182
out_shard,
181183
static_argnums,
182184
donate_argnums,
185+
config,
183186
nn_partitioning.axis_rules(config.logical_axis_rules),
184187
)
185188
return False
@@ -241,6 +244,7 @@ def main(argv: Sequence[str]) -> None:
241244
out_shard,
242245
static_argnums,
243246
donate_argnums,
247+
config,
244248
nn_partitioning.axis_rules(config.logical_axis_rules),
245249
)
246250
print("Jitting and compilation complete!", flush=True)

src/MaxText/utils/gcs_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def upload_dump(local_dir, target_dir, module_name=None, delete_local_after=True
7979
hostname = socket.gethostname() # Alternatively can use jax.process_index()
8080
prefix_name = os.path.join(prefix_name, hostname)
8181
target_dir = os.path.join(target_dir, hostname)
82-
max_logging.log(f"Uploading HLO Dump to {target_dir}...")
82+
max_logging.log(f"Uploading Dump to {target_dir}...")
8383
for root, _, files in os.walk(local_dir):
8484
for file in files:
8585
if module_name and module_name not in file:
@@ -91,7 +91,7 @@ def upload_dump(local_dir, target_dir, module_name=None, delete_local_after=True
9191
blob_name = os.path.join(prefix_name, relative_path)
9292
blob = bucket.blob(blob_name)
9393
blob.upload_from_filename(local_path)
94-
max_logging.log(f"HLO Dump Uploaded to {target_dir}!")
94+
max_logging.log(f"Dump Uploaded to {target_dir}!")
9595
if delete_local_after:
9696
shutil.rmtree(local_dir)
9797

tests/integration/aot_hlo_identical_test.py

Lines changed: 0 additions & 174 deletions
This file was deleted.

0 commit comments

Comments
 (0)