diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py
index bd4f102e21..52be815772 100644
--- a/src/MaxText/maxtext_utils.py
+++ b/src/MaxText/maxtext_utils.py
@@ -944,6 +944,16 @@ def setup_initial_state(
   return state, state_mesh_annotations, state_mesh_shardings, data_iterator


+def get_logical_annotations(model, tx, config, rng, mesh, is_training=True):
+  """Get the logical (pre-resolution) partition specs of the train state."""
+  init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng)
+
+  with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
+    abstract_state = jax.eval_shape(init_state_partial)
+  logical_annotations = nn.get_partition_spec(abstract_state)
+  return logical_annotations
+
+
 def get_abstract_state(model, tx, config, rng, mesh, is_training=True):
   """Get a shaped abstraction of the state (including optimizer)"""
   init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng)
@@ -1227,15 +1236,36 @@ def schedule(step):
   return optax.join_schedules(pieces, boundaries)


-def print_shardings_params(params, params_sharding, mesh):
-  """Print state shardings."""
+def print_shardings_params(params, params_sharding, mesh, logical_annotations=None):
+  """
+  Print state shardings, comparing each logical annotation with the physical sharding it resolves to.
+  """
+  if not hasattr(params, "params"):
+    params = {"params": params}
+  if not hasattr(params_sharding, "params"):
+    params_sharding = {"params": params_sharding}
+  if logical_annotations is not None and not hasattr(logical_annotations, "params"):
+    logical_annotations = {"params": logical_annotations}
+
   leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
   leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
-  for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
+  if logical_annotations is None:
+    # Flattening None yields no leaves, which would silently empty the zip below.
+    leaves_logical = [(path, None) for path, _ in leaves_params]
+  else:
+    leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
+
+  for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
     path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
     shape = jax.typeof(leaf_val)
     pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
-    max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}")
+    pspec_str = str(tuple(pspec))
+    logical_str = str(leaf_logical_val)
+
+    message = f"  {path_str}\n" f"    Shape: {shape}\n" f"    Logical: {logical_str}\n" f"    Physical: {pspec_str}"
+    max_logging.info(message)
+
+  print(flush=True)


 def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py
index 4a7b42056b..cd86cfc22c 100644
--- a/src/MaxText/model_creation_utils.py
+++ b/src/MaxText/model_creation_utils.py
@@ -157,7 +157,12 @@ def create_sharded_state():
     # print weights sharding info under debug sharding mode
     if config.debug_sharding:
       max_utils.print_non_trivial_mesh_axis(model.mesh)
-      maxtext_utils.print_shardings_params(sharded_state, out_shardings, model.mesh)
+      maxtext_utils.print_shardings_params(
+          params=sharded_state,
+          params_sharding=out_shardings,
+          mesh=model.mesh,
+          logical_annotations=specs,
+      )
     if config.load_parameters_path:
       try:
         ckptr = ocp.Checkpointer(
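Note: with this change, `config.debug_sharding` prints each parameter's logical annotation next to the physical `PartitionSpec` it resolves to. A minimal sketch of that resolution step, using `flax.linen.logical_to_mesh_axes` with hypothetical rules and axis names (not MaxText's actual `logical_axis_rules` config):

```python
# Sketch: how a logical annotation resolves to a physical PartitionSpec.
# The rules and axis names below are illustrative, not MaxText's real config.
from flax import linen as nn

rules = (("embed", "fsdp"), ("mlp", "tensor"))  # hypothetical logical_axis_rules
logical = ("embed", "mlp")  # per-dimension logical names for one weight

# logical -> physical resolution; print_shardings_params reports both sides
physical = nn.logical_to_mesh_axes(logical, rules)
print("Logical: ", logical)   # ('embed', 'mlp')
print("Physical:", physical)  # PartitionSpec('fsdp', 'tensor')
```

Printing both sides makes mis-scoped rules visible: a logical name with no matching rule resolves to `None` (replicated) in the physical column.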
diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py
index 616530e51f..a8f83195eb 100644
--- a/src/MaxText/sharding.py
+++ b/src/MaxText/sharding.py
@@ -31,6 +31,7 @@


 _LOGGED_ACTIVATION_SHARDINGS = set()
+_LOGGED_LOGICAL_AXES = set()


 def get_input_data_sharding(config, mesh):
@@ -51,7 +52,7 @@ def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=Fal
   pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh"))
   log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level)
   if log_key not in _LOGGED_ACTIVATION_SHARDINGS:
-    max_logging.info(f"{log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
+    max_logging.info(f"Physical: {log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
     _LOGGED_ACTIVATION_SHARDINGS.add(log_key)
   if shard_mode == ShardMode.EXPLICIT:
     return reshard(inputs, named_sharding)
@@ -67,9 +68,22 @@ def maybe_shard_with_logical(
   """
   if inputs is None:
     return None
+  named_sharding = create_sharding(mesh, logical_axes, rules=rules)
+
+  if debug_sharding and isinstance(inputs, Tracer):
+    log_key = (str(jax.typeof(inputs)), logical_axes, extra_stack_level)
+
+    if log_key not in _LOGGED_LOGICAL_AXES:
+      max_logging.info(f"Logical:  {log_key[0]:.<80} {log_key[1]}", stacklevel=3 + extra_stack_level)
+      _LOGGED_LOGICAL_AXES.add(log_key)
+
   return maybe_shard_with_name(
-      inputs, named_sharding, shard_mode, debug_sharding=debug_sharding, extra_stack_level=extra_stack_level + 1
+      inputs,
+      named_sharding,
+      shard_mode,
+      debug_sharding=debug_sharding,
+      extra_stack_level=extra_stack_level + 1,
   )
diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py
index cccb03e284..d643e1c672 100644
--- a/src/MaxText/train_compile.py
+++ b/src/MaxText/train_compile.py
@@ -104,12 +104,15 @@ def get_shaped_inputs(topology_mesh, config):
       model, tx, config, example_rng, topology_mesh
   )

+  # unsharded logical annotations
+  logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh)
+
   # Shaped batch
   shaped_batch = maxtext_utils.get_shaped_batch(config)

   shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
   shaped_train_kwargs = {}
-  return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model
+  return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model


 def jit_and_compile(
@@ -160,7 +163,13 @@ def is_oom(argv: Sequence[str]) -> bool:
   max_utils.print_system_information()

   # Get shaped inputs
-  shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
+  (
+      shaped_train_args,
+      shaped_train_kwargs,
+      state_mesh_shardings,
+      _,
+      model,
+  ) = get_shaped_inputs(topology_mesh, config)

   # Get data sharding
   data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -216,7 +225,13 @@ def main(argv: Sequence[str]) -> None:
   max_utils.print_system_information()

   # Get shaped inputs
-  shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
+  (
+      shaped_train_args,
+      shaped_train_kwargs,
+      state_mesh_shardings,
+      logical_annotations,
+      model,
+  ) = get_shaped_inputs(topology_mesh, config)

   # Get data sharding
   data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -231,7 +246,12 @@ def main(argv: Sequence[str]) -> None:
   # print weights sharding info under debug sharding mode
   if config.debug_sharding:
     max_utils.print_non_trivial_mesh_axis(topology_mesh)
-    maxtext_utils.print_shardings_params(shaped_train_args[0].params, state_mesh_shardings.params, topology_mesh)
+    maxtext_utils.print_shardings_params(
+        shaped_train_args[0].params,
+        state_mesh_shardings.params,
+        topology_mesh,
+        logical_annotations.params,
+    )

   # Compile
   print("Jitting and compiling train step...", flush=True)
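Both debug paths deduplicate their output through module-level sets (`_LOGGED_ACTIVATION_SHARDINGS`, `_LOGGED_LOGICAL_AXES`), because the sharding helpers run once per layer on every trace. A self-contained sketch of that log-once pattern, with illustrative names and keys:

```python
# Sketch of the log-once pattern: a module-level set keyed on
# (abstract value, annotation) so each distinct sharding is logged a single
# time, even when the helper is traced many times. Names are illustrative.
_SEEN: set[tuple] = set()

def log_once(key: tuple, message: str) -> None:
  """Print `message` only the first time `key` is observed."""
  if key not in _SEEN:
    print(message)
    _SEEN.add(key)

log_once(("f32[8,128]", ("embed", "mlp")), "Logical: f32[8,128] ('embed', 'mlp')")
log_once(("f32[8,128]", ("embed", "mlp")), "suppressed duplicate")  # not printed
```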
diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py
index a672399eb8..9dba42cb32 100644
--- a/src/MaxText/train_utils.py
+++ b/src/MaxText/train_utils.py
@@ -217,8 +217,11 @@ def setup_train_loop(config, recorder, devices=None):

   # print weights sharding info under debug sharding mode
   if config.debug_sharding:
+    logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True)
     max_utils.print_non_trivial_mesh_axis(model.mesh)
-    maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh)
+    maxtext_utils.print_shardings_params(
+        state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params
+    )

   if config.use_dpo:
     abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py
index ad103d20df..c9b901a509 100644
--- a/tests/unit/sharding_compare_test.py
+++ b/tests/unit/sharding_compare_test.py
@@ -97,7 +97,7 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
   validate_config(config)
   topology_mesh = get_topology_mesh(config)

-  _, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
+  _, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)

   actual_json = named_shardings_to_json(state_mesh_shardings)
   expected_json = load_named_sharding_json(json_path)
diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py
index 454fb31576..c096c98136 100644
--- a/tests/utils/sharding_dump.py
+++ b/tests/utils/sharding_dump.py
@@ -276,7 +276,7 @@ def main(argv: Sequence[str]) -> None:

   try:
     topology_mesh = get_topology_mesh(config)
-    _, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
+    _, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
   except:  # pylint: disable=bare-except
     state_mesh_shardings = {}
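The three-way `zip` in `print_shardings_params` relies on the params, shardings, and logical annotations flattening in the same leaf order, i.e. on the three arguments being pytrees of identical structure (recent JAX treats `jax.sharding.PartitionSpec` as a leaf, which is what makes this hold). A quick illustration of that assumption, with hypothetical values and paths:

```python
# Sketch: the aligned-leaves assumption behind the three-way zip.
# Values, paths, and specs are hypothetical.
import jax
from jax.sharding import PartitionSpec as P

params = {"decoder": {"bias": 2.0, "kernel": 1.0}}
logical = {"decoder": {"bias": P("mlp"), "kernel": P("embed", "mlp")}}

p_leaves, p_def = jax.tree_util.tree_flatten_with_path(params)
l_leaves, l_def = jax.tree_util.tree_flatten_with_path(logical)
assert p_def == l_def  # identical tree structure => leaves align index-by-index
for (path, value), (_, spec) in zip(p_leaves, l_leaves):
  print(jax.tree_util.keystr(path), value, spec)
```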