|
17 | 17 |
|
18 | 18 | import functools |
19 | 19 | import pickle |
| 20 | +import os |
20 | 21 |
|
21 | 22 | from flax import linen as nn |
22 | 23 | from flax.linen import partitioning as nn_partitioning |
|
41 | 42 | from MaxText import multimodal_utils |
42 | 43 | from MaxText import sharding |
43 | 44 | from MaxText.configs import types |
| 45 | +from MaxText.utils import gcs_utils |
44 | 46 | from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE |
45 | 47 | from MaxText.inference.page_manager import PageState |
46 | 48 |
|
@@ -1213,3 +1215,33 @@ def print_shardings_params(params, params_sharding, mesh): |
1213 | 1215 | shape = jax.typeof(leaf_val) |
1214 | 1216 | pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) |
1215 | 1217 | max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}") |
| 1218 | + |
| 1219 | + |
def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
  """Trace the train step to a jaxpr, write it to a local file, and optionally upload it.

  Does nothing unless `config.dump_jaxpr` is truthy. Otherwise the jaxpr text is
  written to `<config.dump_jaxpr_local_dir>/train_step.jaxpr` and, when
  `config.dump_jaxpr_gcs_dir` is set, handed to `gcs_utils.upload_dump`.

  Args:
    config: Object exposing `dump_jaxpr`, `dump_jaxpr_local_dir`,
      `dump_jaxpr_gcs_dir` and `dump_jaxpr_delete_local_after`.
    p_train_step: Callable passed to `jax.make_jaxpr` (the JIT-decorated train
      step, per the original author's note).
    train_step_inputs: Iterable of positional arguments unpacked into the
      traced call.
  """
  if not config.dump_jaxpr:
    return
  max_logging.log("Tracing train_step to jaxpr...")

  # We use the p_train_step (the JIT-decorated function)
  p_train_jaxpr = jax.make_jaxpr(p_train_step)(*train_step_inputs)

  local_filename = "train_step.jaxpr"
  local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename)

  os.makedirs(config.dump_jaxpr_local_dir, exist_ok=True)

  # Pin the encoding so the dump does not depend on the host's locale default.
  with open(local_path, "w", encoding="utf-8") as f:
    f.write(str(p_train_jaxpr))

  max_logging.log(f"Jaxpr dumped locally to {local_path}")

  if config.dump_jaxpr_gcs_dir:
    gcs_utils.upload_dump(
        config.dump_jaxpr_local_dir,
        config.dump_jaxpr_gcs_dir,
        module_name=local_filename,
        # Whether the local copy survives the upload is decided by config.
        delete_local_after=config.dump_jaxpr_delete_local_after,
        # Original author's note: only upload from lead host (Host 0).
        all_host_upload=False,
    )
0 commit comments