From e9bd2cb87822c4aa488b5a451092d2f9a2c23bfb Mon Sep 17 00:00:00 2001 From: Sharon Yu Date: Wed, 21 Jan 2026 16:17:20 +0000 Subject: [PATCH 1/9] print out logic axes --- src/MaxText/max_utils.py | 11 +++++++ src/MaxText/maxtext_utils.py | 48 +++++++++++++++++++++++++++-- src/MaxText/model_creation_utils.py | 8 ++++- src/MaxText/train_compile.py | 30 ++++++++++++++---- src/MaxText/train_utils.py | 6 ++-- tests/unit/sharding_compare_test.py | 2 +- tests/utils/sharding_dump.py | 2 +- 7 files changed, 94 insertions(+), 13 deletions(-) diff --git a/src/MaxText/max_utils.py b/src/MaxText/max_utils.py index 510878f9be..38c0e43495 100644 --- a/src/MaxText/max_utils.py +++ b/src/MaxText/max_utils.py @@ -1032,3 +1032,14 @@ def transformer_engine_context(): yield except (ImportError, AttributeError): yield + + +def print_mesh_axes_info(mesh: jax.sharding.Mesh): + """Prints all mesh axes and their sizes in a single comma-separated line.""" + if not mesh.shape: + max_logging.info("Mesh Axes: (Empty Mesh)") + return + + axis_info = [f"{axis_name}: {axis_size}" for axis_name, axis_size in mesh.shape.items()] + info_str = "Mesh Axes: (" + ", ".join(axis_info) + ")" + max_logging.info(info_str) diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index bd4f102e21..c793d1490c 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -19,6 +19,8 @@ import pickle import os +from collections import defaultdict + from flax import linen as nn from flax.linen import partitioning as nn_partitioning from flax.training import train_state @@ -27,6 +29,7 @@ from jax.experimental import mesh_utils from jax.experimental.serialize_executable import deserialize_and_load +from jax.sharding import PartitionSpec as P import jax import jax.numpy as jnp @@ -1227,11 +1230,52 @@ def schedule(step): return optax.join_schedules(pieces, boundaries) -def print_shardings_params(params, params_sharding, mesh): +def print_shardings_params(params, params_sharding, mesh, state_logical_annotations=None, logical_axis_rules=None): """Print state shardings.""" leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) - for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding): + + leaves_rule_values = [] + if state_logical_annotations and hasattr(state_logical_annotations, "params"): + leaves_rule_values, _ = jax.tree_util.tree_flatten_with_path(state_logical_annotations.params) + else: + leaves_rule_values = [(None, None)] * len(leaves_params) + + if not len(leaves_params) == len(leaves_sharding) == len(leaves_rule_values): + max_logging.warning( + "Warning: Parameter tree structure mismatch between params, shardings," " and logical annotations." 
+ ) + return + + # Build a reverse map + rule_value_to_semantic = defaultdict(list) + if logical_axis_rules: + rules_iter = logical_axis_rules.items() if isinstance(logical_axis_rules, dict) else logical_axis_rules + for name, potentials in rules_iter: + if isinstance(potentials, str): + key = (potentials,) + elif potentials is None: + key = (None,) + elif isinstance(potentials, list): + key = tuple(potentials) + elif isinstance(potentials, tuple): + key = potentials + else: + key = (potentials,) + + key = tuple(p for p in key) + rule_value_to_semantic[key].append(name) + + # Header for the entire block ( + max_logging.info("Parameter Path") + max_logging.info("Shape") + max_logging.info("Logical Axes") + max_logging.info("Physical PartitionSpec") + max_logging.info("-" * 120) + + for (path, leaf_val), (_, leaf_sharding), (_, leaf_rule_value) in zip( + leaves_params, leaves_sharding, leaves_rule_values + ): path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) shape = jax.typeof(leaf_val) pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index 4a7b42056b..cae4af607b 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -157,7 +157,13 @@ def create_sharded_state(): # print weights sharding info under debug sharding mode if config.debug_sharding: max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params(sharded_state, out_shardings, model.mesh) + maxtext_utils.print_state_mesh_shardings_params( + state=sharded_state, + state_sharding=out_shardings, + state_logical_annotations=specs, + mesh=model.mesh, + logical_axis_rules=config.logical_axis_rules, + ) if config.load_parameters_path: try: ckptr = ocp.Checkpointer( diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py index cccb03e284..05c69d8731 100644 --- a/src/MaxText/train_compile.py +++ b/src/MaxText/train_compile.py @@ -100,7 +100,7 @@ def get_shaped_inputs(topology_mesh, config): shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) # Shaped state - abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( + abstract_state, state_logical_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( model, tx, config, example_rng, topology_mesh ) @@ -109,7 +109,7 @@ def get_shaped_inputs(topology_mesh, config): shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} - return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, state_logical_annotations, model def jit_and_compile( @@ -160,7 +160,13 @@ def is_oom(argv: Sequence[str]) -> bool: max_utils.print_system_information() # Get shaped inputs - shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config) + ( + shaped_train_args, + shaped_train_kwargs, + state_mesh_shardings, + _, + model, + ) = get_shaped_inputs(topology_mesh, config) # Get data sharding data_sharding = sharding.get_input_data_sharding(config, topology_mesh) @@ -216,7 +222,13 @@ def main(argv: Sequence[str]) -> None: max_utils.print_system_information() # Get shaped inputs - shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config) + ( + shaped_train_args, + shaped_train_kwargs, + state_mesh_shardings, + state_logical_annotations, + model, 
+ ) = get_shaped_inputs(topology_mesh, config) # Get data sharding data_sharding = sharding.get_input_data_sharding(config, topology_mesh) @@ -230,8 +242,14 @@ def main(argv: Sequence[str]) -> None: # print weights sharding info under debug sharding mode if config.debug_sharding: - max_utils.print_non_trivial_mesh_axis(topology_mesh) - maxtext_utils.print_shardings_params(shaped_train_args[0].params, state_mesh_shardings.params, topology_mesh) + max_utils.print_mesh_axes_info(topology_mesh) + maxtext_utils.print_shardings_params( + shaped_train_args[0].params, + state_mesh_shardings.params, + topology_mesh, + state_logical_annotations, + config.logical_axis_rules, + ) # Compile print("Jitting and compiling train step...", flush=True) diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index a672399eb8..e0698eff8b 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -206,7 +206,7 @@ def setup_train_loop(config, recorder, devices=None): eval_data_iterator, ) - state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( + state, state_mesh_annotations, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager ) @@ -218,7 +218,9 @@ def setup_train_loop(config, recorder, devices=None): # print weights sharding info under debug sharding mode if config.debug_sharding: max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh) + maxtext_utils.print_shardings_params( + state.params, state_mesh_shardings.params, model.mesh, state_mesh_annotations, config.logical_axis_rules + ) if config.use_dpo: abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py index ad103d20df..c9b901a509 100644 --- a/tests/unit/sharding_compare_test.py +++ b/tests/unit/sharding_compare_test.py @@ -97,7 +97,7 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) validate_config(config) topology_mesh = get_topology_mesh(config) - _, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config) + _, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config) actual_json = named_shardings_to_json(state_mesh_shardings) expected_json = load_named_sharding_json(json_path) diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py index 454fb31576..c096c98136 100644 --- a/tests/utils/sharding_dump.py +++ b/tests/utils/sharding_dump.py @@ -276,7 +276,7 @@ def main(argv: Sequence[str]) -> None: try: topology_mesh = get_topology_mesh(config) - _, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config) + _, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config) except: # pylint: disable=bare-except state_mesh_shardings = {} From d26681104accfa16e5b43fcbfab7042e61a43832 Mon Sep 17 00:00:00 2001 From: Sharon Yu Date: Thu, 22 Jan 2026 02:50:09 +0000 Subject: [PATCH 2/9] print activation --- src/MaxText/sharding.py | 94 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 4 deletions(-) diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 616530e51f..870cd944b1 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -29,8 +29,10 @@ from MaxText import max_logging from MaxText.common_types import ShardMode +import inspect 
_LOGGED_ACTIVATION_SHARDINGS = set() +_GLOBAL_LOGICAL_RULES = None def get_input_data_sharding(config, mesh): @@ -64,13 +66,91 @@ def maybe_shard_with_logical( ): """ A wrapper of maybe_shard_with_name when logical axes are inputs + Features: + - Auto-fetches global rules if not provided. + - Prints "Logical Rules -> Physical Spec" mapping logs. """ if inputs is None: return None - named_sharding = create_sharding(mesh, logical_axes, rules=rules) - return maybe_shard_with_name( - inputs, named_sharding, shard_mode, debug_sharding=debug_sharding, extra_stack_level=extra_stack_level + 1 - ) + + active_rules = rules if rules is not None else _GLOBAL_LOGICAL_RULES + named_sharding = create_sharding(mesh, logical_axes, rules=active_rules) + if debug_sharding and hasattr(inputs, "shape"): + max_logging.info("=" * 120) + max_logging.info(" Tracing logical axes for activations during JIT compilation.") + max_logging.info("=" * 120) + caller_info = "Unknown" + # ----(Caller Detection) --- + try: + stack = inspect.stack() + frame_idx = 2 + extra_stack_level + if len(stack) > frame_idx: + frame_basic = stack[frame_idx] + filename = frame_basic.filename.split("/")[-1] + lineno = frame_basic.lineno + caller_info = f"{filename}:{lineno}" + + # search up to 10 frames for better caller context + for i in range(frame_idx, min(frame_idx + 10, len(stack))): + frame = stack[i].frame + if "self" in frame.f_locals: + obj = frame.f_locals["self"] + if hasattr(obj, "__class__"): + cls_name = obj.__class__.__name__ + if cls_name not in ["Module", "object", "PjitFunction"]: + caller_info = cls_name + break + + func_name = stack[i].function + if func_name not in ["", "__call__", "setup", "apply", "wrapper", "maybe_shard_with_logical"]: + if ":" in caller_info: + caller_info = func_name + except Exception as e: + caller_info = f"Err:{str(e)}" + + # ---(Unresolved Rules) --- + unresolved_str = "" + if active_rules is None: + unresolved_str = "Implicit/Global Rules (Not Found)" + else: + try: + rules_dict = active_rules if isinstance(active_rules, dict) else dict(active_rules) + + mapping_parts = [] + if logical_axes: + for axis in logical_axes: + target = rules_dict.get(axis, "NoRule") + mapping_parts.append(f"{axis}={target}") + unresolved_str = ", ".join(mapping_parts) + else: + unresolved_str = "No Logical Axes" + except: + unresolved_str = "Rules Format Error" + + # --- (Physical Spec) --- + phys_spec = getattr(named_sharding, "spec", None) + resolved_str = "" + try: + if phys_spec is not None and logical_axes is not None and len(logical_axes) == len(phys_spec): + pairs = [] + for logic_name, phys_axis in zip(logical_axes, phys_spec): + axis_str = str(phys_axis) if phys_axis else "None" + pairs.append(f"{logic_name}={axis_str}") + resolved_str = ", ".join(pairs) + else: + resolved_str = f"Axes={logical_axes} -> Spec={phys_spec}" + except: + resolved_str = f"Axes={logical_axes} -> Spec={phys_spec}" + + shape_str = str(jax.typeof(inputs)) + # format: Caller [RULES] -> [FINAL] Shape + # max_logging.info(f"{caller_info:<20} [RULES: {unresolved_str}] -> [FINAL: {resolved_str}] shape={shape_str}") + max_logging.info(f"{caller_info}") + max_logging.info(f"[RULES: {unresolved_str}]") + max_logging.info(f"[FINAL: {resolved_str}] shape={shape_str}") + max_logging.info("=" * 120) + + return maybe_shard_with_name(inputs, named_sharding, shard_mode, extra_stack_level=extra_stack_level + 1) def remove_size_one_mesh_axis(spec, mesh): @@ -586,3 +666,9 @@ def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, 
sha # Apply the constraint to the model's current variables. This tells JAX to # gather the weights into this layout. return maybe_shard_with_name(variables, physical_constraint_no_fsdp, shard_mode=shard_mode) + + +def set_global_logical_rules(rules): + """Allows external modules (like train_compile) to inject logical rules.""" + global _GLOBAL_LOGICAL_RULES + _GLOBAL_LOGICAL_RULES = rules From e77f07d3121530fbb729e9e61b79126e156ccba5 Mon Sep 17 00:00:00 2001 From: Sharon Yu Date: Thu, 22 Jan 2026 03:20:24 +0000 Subject: [PATCH 3/9] fix format --- src/MaxText/model_creation_utils.py | 8 +- src/MaxText/sharding.py | 148 ++++++++++++++++------------ src/MaxText/train_compile.py | 1 + 3 files changed, 90 insertions(+), 67 deletions(-) diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index cae4af607b..942a7e99b1 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -157,11 +157,11 @@ def create_sharded_state(): # print weights sharding info under debug sharding mode if config.debug_sharding: max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_state_mesh_shardings_params( - state=sharded_state, - state_sharding=out_shardings, - state_logical_annotations=specs, + maxtext_utils.print_shardings_params( + params=sharded_state, + params_sharding=out_shardings, mesh=model.mesh, + state_logical_annotations=specs, logical_axis_rules=config.logical_axis_rules, ) if config.load_parameters_path: diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 870cd944b1..25ef6fb711 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -61,6 +61,85 @@ def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=Fal return jax.lax.with_sharding_constraint(inputs, named_sharding) +def _get_caller_info(stack, frame_idx): + """Helper function to extract caller information to reduce nesting.""" + try: + if len(stack) <= frame_idx: + return "Unknown" + + frame_basic = stack[frame_idx] + filename = frame_basic.filename.split("/")[-1] + lineno = frame_basic.lineno + caller_name = f"{filename}:{lineno}" + + # Search up to 10 frames for better caller context + for i in range(frame_idx, min(frame_idx + 10, len(stack))): + frame = stack[i].frame + + # Check for Class Name + if "self" in frame.f_locals: + obj = frame.f_locals["self"] + if hasattr(obj, "__class__"): + cls_name = obj.__class__.__name__ + if cls_name not in ["Module", "object", "PjitFunction"]: + return cls_name + + # Check for Function Name + func_name = stack[i].function + ignored_funcs = ["", "__call__", "setup", "apply", "wrapper", "maybe_shard_with_logical"] + if func_name not in ignored_funcs: + # If we only have file:line, prefer the function name + if ":" in caller_name: + caller_name = func_name + + return caller_name + + except (AttributeError, KeyError, IndexError): + return "Unknown(Error)" + + +def _format_rule_info(active_rules, logical_axes): + """Helper to format rule string to avoid broad exception catch.""" + if active_rules is None: + return "Implicit/Global Rules (Not Found)" + + try: + rules_dict = active_rules if isinstance(active_rules, dict) else dict(active_rules) + except (ValueError, TypeError): + return "Rules Format Error" + + if not logical_axes: + return "No Logical Axes" + + mapping_parts = [] + for axis in logical_axes: + target = rules_dict.get(axis, "NoRule") + mapping_parts.append(f"{axis}={target}") + + return ", ".join(mapping_parts) + + +def _format_phys_spec_info(named_sharding, 
logical_axes): + """Helper to format physical spec string.""" + phys_spec = getattr(named_sharding, "spec", None) + + if phys_spec is None: + return f"Axes={logical_axes} -> Spec=None" + + if logical_axes is None: + return f"Axes=None -> Spec={phys_spec}" + + if len(logical_axes) != len(phys_spec): + return f"Axes={logical_axes} -> Spec={phys_spec}" + + pairs = [] + for logic_name, phys_axis in zip(logical_axes, phys_spec): + axis_str = str(phys_axis) if phys_axis else "None" + pairs.append(f"{logic_name}={axis_str}") + + return ", ".join(pairs) + + def maybe_shard_with_logical( inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0 ): @@ -75,76 +154,19 @@ def maybe_shard_with_logical( active_rules = rules if rules is not None else _GLOBAL_LOGICAL_RULES named_sharding = create_sharding(mesh, logical_axes, rules=active_rules) + if debug_sharding and hasattr(inputs, "shape"): max_logging.info("=" * 120) max_logging.info(" Tracing logical axes for activations during JIT compilation.") max_logging.info("=" * 120) - caller_info = "Unknown" - # ----(Caller Detection) --- - try: - stack = inspect.stack() - frame_idx = 2 + extra_stack_level - if len(stack) > frame_idx: - frame_basic = stack[frame_idx] - filename = frame_basic.filename.split("/")[-1] - lineno = frame_basic.lineno - caller_info = f"{filename}:{lineno}" - - # search up to 10 frames for better caller context - for i in range(frame_idx, min(frame_idx + 10, len(stack))): - frame = stack[i].frame - if "self" in frame.f_locals: - obj = frame.f_locals["self"] - if hasattr(obj, "__class__"): - cls_name = obj.__class__.__name__ - if cls_name not in ["Module", "object", "PjitFunction"]: - caller_info = cls_name - break - - func_name = stack[i].function - if func_name not in ["", "__call__", "setup", "apply", "wrapper", "maybe_shard_with_logical"]: - if ":" in caller_info: - caller_info = func_name - except Exception as e: - caller_info = f"Err:{str(e)}" - - # ---(Unresolved Rules) --- - unresolved_str = "" - if active_rules is None: - unresolved_str = "Implicit/Global Rules (Not Found)" - else: - try: - rules_dict = active_rules if isinstance(active_rules, dict) else dict(active_rules) - - mapping_parts = [] - if logical_axes: - for axis in logical_axes: - target = rules_dict.get(axis, "NoRule") - mapping_parts.append(f"{axis}={target}") - unresolved_str = ", ".join(mapping_parts) - else: - unresolved_str = "No Logical Axes" - except: - unresolved_str = "Rules Format Error" - - # --- (Physical Spec) --- - phys_spec = getattr(named_sharding, "spec", None) - resolved_str = "" - try: - if phys_spec is not None and logical_axes is not None and len(logical_axes) == len(phys_spec): - pairs = [] - for logic_name, phys_axis in zip(logical_axes, phys_spec): - axis_str = str(phys_axis) if phys_axis else "None" - pairs.append(f"{logic_name}={axis_str}") - resolved_str = ", ".join(pairs) - else: - resolved_str = f"Axes={logical_axes} -> Spec={phys_spec}" - except: - resolved_str = f"Axes={logical_axes} -> Spec={phys_spec}" + stack = inspect.stack() + caller_info = _get_caller_info(stack, 2 + extra_stack_level) + unresolved_str = _format_rule_info(active_rules, logical_axes) + resolved_str = _format_phys_spec_info(named_sharding, logical_axes) shape_str = str(jax.typeof(inputs)) + # format: Caller [RULES] -> [FINAL] Shape - # max_logging.info(f"{caller_info:<20} [RULES: {unresolved_str}] -> [FINAL: {resolved_str}] shape={shape_str}") max_logging.info(f"{caller_info}") max_logging.info(f"[RULES: {unresolved_str}]") 
max_logging.info(f"[FINAL: {resolved_str}] shape={shape_str}") diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py index 05c69d8731..19ec404448 100644 --- a/src/MaxText/train_compile.py +++ b/src/MaxText/train_compile.py @@ -214,6 +214,7 @@ def main(argv: Sequence[str]) -> None: config = pyconfig.initialize(argv) validate_config(config) + sharding.set_global_logical_rules(config.logical_axis_rules) # Create target mesh topology_mesh = get_topology_mesh(config) From 9fa76b706c9b0b1adf443053f623bc576a48e549 Mon Sep 17 00:00:00 2001 From: Sharon Yu Date: Fri, 23 Jan 2026 00:03:28 +0000 Subject: [PATCH 4/9] resolve comments --- src/MaxText/maxtext_utils.py | 78 +++++++-------- src/MaxText/model_creation_utils.py | 3 +- src/MaxText/sharding.py | 141 +++++----------------------- src/MaxText/train_compile.py | 13 +-- src/MaxText/train_utils.py | 7 +- 5 files changed, 75 insertions(+), 167 deletions(-) diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index c793d1490c..b1565cadd0 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -19,8 +19,6 @@ import pickle import os -from collections import defaultdict - from flax import linen as nn from flax.linen import partitioning as nn_partitioning from flax.training import train_state @@ -29,7 +27,6 @@ from jax.experimental import mesh_utils from jax.experimental.serialize_executable import deserialize_and_load -from jax.sharding import PartitionSpec as P import jax import jax.numpy as jnp @@ -947,6 +944,16 @@ def setup_initial_state( return state, state_mesh_annotations, state_mesh_shardings, data_iterator +def get_logical_annotations(model, tx, config, rng, mesh, is_training=True): + init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) + + with nn_partitioning.axis_rules(config.logical_axis_rules): + abstract_state = jax.eval_shape(init_state_partial) + + logical_annotations = nn.get_partition_spec(abstract_state) + return logical_annotations + + def get_abstract_state(model, tx, config, rng, mesh, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) @@ -1230,48 +1237,45 @@ def schedule(step): return optax.join_schedules(pieces, boundaries) -def print_shardings_params(params, params_sharding, mesh, state_logical_annotations=None, logical_axis_rules=None): - """Print state shardings.""" +def print_shardings_params(params, params_sharding, mesh, logical_annotations=None): + """ + Print state shardings comparing Logical Definition vs Physical Result. + Simplified version: Directly prints logical annotations without reverse mapping. 
+ """ + if not hasattr(params, "params"): + params = {"params": params} + if not hasattr(params_sharding, "params"): + params_sharding = {"params": params_sharding} + leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) - leaves_rule_values = [] - if state_logical_annotations and hasattr(state_logical_annotations, "params"): - leaves_rule_values, _ = jax.tree_util.tree_flatten_with_path(state_logical_annotations.params) - else: - leaves_rule_values = [(None, None)] * len(leaves_params) + leaves_logical = [] + has_logical = False + if logical_annotations and hasattr(logical_annotations, "params"): + try: + leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params) + if len(leaves_params) == len(leaves_logical): + has_logical = True + else: + max_logging.warning("Warning: Logical annotations tree structure mismatch. Skipping logical info.") + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.warning(f"Warning: Failed to process logical annotations: {e}. Skipping logical info.") - if not len(leaves_params) == len(leaves_sharding) == len(leaves_rule_values): - max_logging.warning( - "Warning: Parameter tree structure mismatch between params, shardings," " and logical annotations." - ) + if not has_logical: + leaves_logical = [(None, None)] * len(leaves_params) + + if len(leaves_params) != len(leaves_sharding): + max_logging.warning("Warning: Params and Sharding tree mismatch.") return - # Build a reverse map - rule_value_to_semantic = defaultdict(list) - if logical_axis_rules: - rules_iter = logical_axis_rules.items() if isinstance(logical_axis_rules, dict) else logical_axis_rules - for name, potentials in rules_iter: - if isinstance(potentials, str): - key = (potentials,) - elif potentials is None: - key = (None,) - elif isinstance(potentials, list): - key = tuple(potentials) - elif isinstance(potentials, tuple): - key = potentials - else: - key = (potentials,) + for i, (path, leaf_val) in enumerate(leaves_params): + _, leaf_sharding = leaves_sharding[i] + leaf_logical_val = leaves_logical[i][1] if has_logical else None - key = tuple(p for p in key) - rule_value_to_semantic[key].append(name) + path_str = "/".join(str(p.key if hasattr(p, "key") else getattr(p, "name", "?")) for p in path) - # Header for the entire block ( - max_logging.info("Parameter Path") - max_logging.info("Shape") - max_logging.info("Logical Axes") - max_logging.info("Physical PartitionSpec") - max_logging.info("-" * 120) + shape = str(jax.typeof(leaf_val)) for (path, leaf_val), (_, leaf_sharding), (_, leaf_rule_value) in zip( leaves_params, leaves_sharding, leaves_rule_values diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index 942a7e99b1..3ff16b7eff 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -160,9 +160,8 @@ def create_sharded_state(): maxtext_utils.print_shardings_params( params=sharded_state, params_sharding=out_shardings, + logical_annotations=specs, mesh=model.mesh, - state_logical_annotations=specs, - logical_axis_rules=config.logical_axis_rules, ) if config.load_parameters_path: try: diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 25ef6fb711..19f28229d2 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -29,10 +29,8 @@ from MaxText import max_logging from MaxText.common_types import ShardMode -import inspect _LOGGED_ACTIVATION_SHARDINGS = set() 
-_GLOBAL_LOGICAL_RULES = None def get_input_data_sharding(config, mesh): @@ -40,20 +38,28 @@ def get_input_data_sharding(config, mesh): return create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules) -def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0): +def maybe_shard_with_name( + inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0, logical_axes=None +): """ In auto shardmode, this function hints inputs follow given named_sharding. In explicit shardmode, this function enforces inputs following named_sharding. """ if inputs is None: return None - if ( - debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding) - ): # only print pspec for JitTracer + if debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding): pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh")) - log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level) + if logical_axes is not None: + logical_str = str(logical_axes) + else: + logical_str = "None" + shape_str = str(jax.typeof(inputs)) + log_key = (shape_str, tuple(pspec), extra_stack_level, logical_str) + if log_key not in _LOGGED_ACTIVATION_SHARDINGS: - max_logging.info(f"{log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level) + max_logging.info( + f"Activation: {logical_str:<40} -> {str(tuple(pspec)):<30} {shape_str}", stacklevel=3 + extra_stack_level + ) _LOGGED_ACTIVATION_SHARDINGS.add(log_key) if shard_mode == ShardMode.EXPLICIT: return reshard(inputs, named_sharding) @@ -61,118 +67,23 @@ def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=Fal return jax.lax.with_sharding_constraint(inputs, named_sharding) -def _get_caller_info(stack, frame_idx): - """Helper function to extract caller information to reduce nesting.""" - try: - if len(stack) <= frame_idx: - return "Unknown" - - frame_basic = stack[frame_idx] - filename = frame_basic.filename.split("/")[-1] - lineno = frame_basic.lineno - caller_name = f"{filename}:{lineno}" - - # Search up to 10 frames for better caller context - for i in range(frame_idx, min(frame_idx + 10, len(stack))): - frame = stack[i].frame - - # Check for Class Name - if "self" in frame.f_locals: - obj = frame.f_locals["self"] - if hasattr(obj, "__class__"): - cls_name = obj.__class__.__name__ - if cls_name not in ["Module", "object", "PjitFunction"]: - return cls_name - - # Check for Function Name - func_name = stack[i].function - ignored_funcs = ["", "__call__", "setup", "apply", "wrapper", "maybe_shard_with_logical"] - if func_name not in ignored_funcs: - # If we only have file:line, prefer the function name - if ":" in caller_name: - caller_name = func_name - - return caller_name - - except (AttributeError, KeyError, IndexError): - return "Unknown(Error)" - - -def _format_rule_info(active_rules, logical_axes): - """Helper to format rule string to avoid broad exception catch.""" - if active_rules is None: - return "Implicit/Global Rules (Not Found)" - - try: - rules_dict = active_rules if isinstance(active_rules, dict) else dict(active_rules) - except (ValueError, TypeError): - return "Rules Format Error" - - if not logical_axes: - return "No Logical Axes" - - mapping_parts = [] - for axis in logical_axes: - target = rules_dict.get(axis, "NoRule") - mapping_parts.append(f"{axis}={target}") - - return ", ".join(mapping_parts) - - -def 
_format_phys_spec_info(named_sharding, logical_axes): - """Helper to format physical spec string.""" - phys_spec = getattr(named_sharding, "spec", None) - - if phys_spec is None: - return f"Axes={logical_axes} -> Spec=None" - - if logical_axes is None: - return f"Axes=None -> Spec={phys_spec}" - - if len(logical_axes) != len(phys_spec): - return f"Axes={logical_axes} -> Spec={phys_spec}" - - pairs = [] - for logic_name, phys_axis in zip(logical_axes, phys_spec): - axis_str = str(phys_axis) if phys_axis else "None" - pairs.append(f"{logic_name}={axis_str}") - - return ", ".join(pairs) - - def maybe_shard_with_logical( inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0 ): """ A wrapper of maybe_shard_with_name when logical axes are inputs - Features: - - Auto-fetches global rules if not provided. - - Prints "Logical Rules -> Physical Spec" mapping logs. """ if inputs is None: return None - - active_rules = rules if rules is not None else _GLOBAL_LOGICAL_RULES - named_sharding = create_sharding(mesh, logical_axes, rules=active_rules) - - if debug_sharding and hasattr(inputs, "shape"): - max_logging.info("=" * 120) - max_logging.info(" Tracing logical axes for activations during JIT compilation.") - max_logging.info("=" * 120) - - stack = inspect.stack() - caller_info = _get_caller_info(stack, 2 + extra_stack_level) - unresolved_str = _format_rule_info(active_rules, logical_axes) - resolved_str = _format_phys_spec_info(named_sharding, logical_axes) - shape_str = str(jax.typeof(inputs)) - - # format: Caller [RULES] -> [FINAL] Shape - max_logging.info(f"{caller_info}") - max_logging.info(f"[RULES: {unresolved_str}]") - max_logging.info(f"[FINAL: {resolved_str}] shape={shape_str}") - max_logging.info("=" * 120) - - return maybe_shard_with_name(inputs, named_sharding, shard_mode, extra_stack_level=extra_stack_level + 1) + named_sharding = create_sharding(mesh, logical_axes, rules=rules) + return maybe_shard_with_name( + inputs, + named_sharding, + shard_mode, + debug_sharding=debug_sharding, + extra_stack_level=extra_stack_level + 1, + logical_axes=logical_axes, + ) def remove_size_one_mesh_axis(spec, mesh): @@ -688,9 +599,3 @@ def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, sha # Apply the constraint to the model's current variables. This tells JAX to # gather the weights into this layout. 
return maybe_shard_with_name(variables, physical_constraint_no_fsdp, shard_mode=shard_mode) - - -def set_global_logical_rules(rules): - """Allows external modules (like train_compile) to inject logical rules.""" - global _GLOBAL_LOGICAL_RULES - _GLOBAL_LOGICAL_RULES = rules diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py index 19ec404448..2a365a1074 100644 --- a/src/MaxText/train_compile.py +++ b/src/MaxText/train_compile.py @@ -100,16 +100,19 @@ def get_shaped_inputs(topology_mesh, config): shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) # Shaped state - abstract_state, state_logical_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( + abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( model, tx, config, example_rng, topology_mesh ) + # unsharded logical annotations + logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh) + # Shaped batch shaped_batch = maxtext_utils.get_shaped_batch(config) shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} - return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, state_logical_annotations, model + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model def jit_and_compile( @@ -214,7 +217,6 @@ def main(argv: Sequence[str]) -> None: config = pyconfig.initialize(argv) validate_config(config) - sharding.set_global_logical_rules(config.logical_axis_rules) # Create target mesh topology_mesh = get_topology_mesh(config) @@ -227,7 +229,7 @@ def main(argv: Sequence[str]) -> None: shaped_train_args, shaped_train_kwargs, state_mesh_shardings, - state_logical_annotations, + logical_annotations, model, ) = get_shaped_inputs(topology_mesh, config) @@ -248,8 +250,7 @@ def main(argv: Sequence[str]) -> None: shaped_train_args[0].params, state_mesh_shardings.params, topology_mesh, - state_logical_annotations, - config.logical_axis_rules, + logical_annotations, ) # Compile diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index e0698eff8b..016a8b242e 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -206,7 +206,7 @@ def setup_train_loop(config, recorder, devices=None): eval_data_iterator, ) - state, state_mesh_annotations, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( + state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager ) @@ -217,10 +217,9 @@ def setup_train_loop(config, recorder, devices=None): # print weights sharding info under debug sharding mode if config.debug_sharding: + logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True) max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - state.params, state_mesh_shardings.params, model.mesh, state_mesh_annotations, config.logical_axis_rules - ) + maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh, logical_annotations) if config.use_dpo: abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) From 883f518d49fffc219d43992e85dc0ea3394f841f Mon Sep 17 00:00:00 2001 From: Sharon Yu Date: Mon, 26 Jan 2026 22:24:43 +0000 Subject: [PATCH 5/9] fix comments --- src/MaxText/maxtext_utils.py | 19 +++++++-------- src/MaxText/model_creation_utils.py | 
2 +- src/MaxText/sharding.py | 37 ++++++++++++++++------------- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index b1565cadd0..5e55fdf7f7 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -1246,21 +1246,18 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No params = {"params": params} if not hasattr(params_sharding, "params"): params_sharding = {"params": params_sharding} + if logical_annotations and not hasattr(logical_annotations, "params"): + logical_annotations = {"params": logical_annotations} leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) + leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params) - leaves_logical = [] - has_logical = False - if logical_annotations and hasattr(logical_annotations, "params"): - try: - leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params) - if len(leaves_params) == len(leaves_logical): - has_logical = True - else: - max_logging.warning("Warning: Logical annotations tree structure mismatch. Skipping logical info.") - except Exception as e: # pylint: disable=broad-exception-caught - max_logging.warning(f"Warning: Failed to process logical annotations: {e}. Skipping logical info.") + for i, ((path, leaf_val), (_, leaf_sharding)) in enumerate(zip(leaves_params, leaves_sharding)): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) if not has_logical: leaves_logical = [(None, None)] * len(leaves_params) diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index 3ff16b7eff..cd86cfc22c 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -160,8 +160,8 @@ def create_sharded_state(): maxtext_utils.print_shardings_params( params=sharded_state, params_sharding=out_shardings, - logical_annotations=specs, mesh=model.mesh, + logical_annotations=specs, ) if config.load_parameters_path: try: diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 19f28229d2..6d74587171 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -31,6 +31,7 @@ _LOGGED_ACTIVATION_SHARDINGS = set() +_LOGGED_LOGICAL_AXES = set() def get_input_data_sharding(config, mesh): @@ -38,28 +39,20 @@ def get_input_data_sharding(config, mesh): return create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules) -def maybe_shard_with_name( - inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0, logical_axes=None -): +def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0): """ In auto shardmode, this function hints inputs follow given named_sharding. In explicit shardmode, this function enforces inputs following named_sharding. 
""" if inputs is None: return None - if debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding): + if ( + debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding) + ): # only print pspec for JitTracer pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh")) - if logical_axes is not None: - logical_str = str(logical_axes) - else: - logical_str = "None" - shape_str = str(jax.typeof(inputs)) - log_key = (shape_str, tuple(pspec), extra_stack_level, logical_str) - + log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level) if log_key not in _LOGGED_ACTIVATION_SHARDINGS: - max_logging.info( - f"Activation: {logical_str:<40} -> {str(tuple(pspec)):<30} {shape_str}", stacklevel=3 + extra_stack_level - ) + max_logging.info(f"{log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level) _LOGGED_ACTIVATION_SHARDINGS.add(log_key) if shard_mode == ShardMode.EXPLICIT: return reshard(inputs, named_sharding) @@ -75,14 +68,26 @@ def maybe_shard_with_logical( """ if inputs is None: return None + named_sharding = create_sharding(mesh, logical_axes, rules=rules) + + if debug_sharding and isinstance(inputs, Tracer): + log_key = (str(jax.typeof(inputs)), logical_axes, extra_stack_level) + + if log_key not in _LOGGED_LOGICAL_AXES: + pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh")) + pspec_str = str(tuple(pspec)) if pspec else "None" + + max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level) + max_logging.info(f"{log_key[0]:.<80} {pspec_str}.", stacklevel=3 + extra_stack_level) + _LOGGED_LOGICAL_AXES.add(log_key) + return maybe_shard_with_name( inputs, named_sharding, shard_mode, - debug_sharding=debug_sharding, + debug_sharding=False, extra_stack_level=extra_stack_level + 1, - logical_axes=logical_axes, ) From 3bd533e12f10e0a77e9648d601e67462e0d422fe Mon Sep 17 00:00:00 2001 From: Sharon Yu Date: Tue, 27 Jan 2026 01:29:34 +0000 Subject: [PATCH 6/9] resolve conflict --- src/MaxText/maxtext_utils.py | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index 5e55fdf7f7..70e2235cd4 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -1253,34 +1253,17 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params) - for i, ((path, leaf_val), (_, leaf_sharding)) in enumerate(zip(leaves_params, leaves_sharding)): + for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical): path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) shape = jax.typeof(leaf_val) pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) pspec_str = str(tuple(pspec)) + logical_str = str(leaf_logical_val) - if not has_logical: - leaves_logical = [(None, None)] * len(leaves_params) + message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" + max_logging.info(message) - if len(leaves_params) != len(leaves_sharding): - max_logging.warning("Warning: Params and Sharding tree mismatch.") - return - - for i, (path, leaf_val) in enumerate(leaves_params): - _, leaf_sharding = leaves_sharding[i] 
- leaf_logical_val = leaves_logical[i][1] if has_logical else None - - path_str = "/".join(str(p.key if hasattr(p, "key") else getattr(p, "name", "?")) for p in path) - - shape = str(jax.typeof(leaf_val)) - - for (path, leaf_val), (_, leaf_sharding), (_, leaf_rule_value) in zip( - leaves_params, leaves_sharding, leaves_rule_values - ): - path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) - shape = jax.typeof(leaf_val) - pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) - max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}") + print(flush=True) def maybe_dump_jaxpr(config, p_train_step, train_step_inputs): From 22dfcbee8e9dc7e1015c627b9101277bccf2f922 Mon Sep 17 00:00:00 2001 From: Sharon Yu Date: Tue, 27 Jan 2026 23:57:18 +0000 Subject: [PATCH 7/9] fix comments --- src/MaxText/max_utils.py | 11 ----------- src/MaxText/maxtext_utils.py | 6 ++---- src/MaxText/sharding.py | 8 ++------ src/MaxText/train_compile.py | 2 +- 4 files changed, 5 insertions(+), 22 deletions(-) diff --git a/src/MaxText/max_utils.py b/src/MaxText/max_utils.py index 38c0e43495..510878f9be 100644 --- a/src/MaxText/max_utils.py +++ b/src/MaxText/max_utils.py @@ -1032,14 +1032,3 @@ def transformer_engine_context(): yield except (ImportError, AttributeError): yield - - -def print_mesh_axes_info(mesh: jax.sharding.Mesh): - """Prints all mesh axes and their sizes in a single comma-separated line.""" - if not mesh.shape: - max_logging.info("Mesh Axes: (Empty Mesh)") - return - - axis_info = [f"{axis_name}: {axis_size}" for axis_name, axis_size in mesh.shape.items()] - info_str = "Mesh Axes: (" + ", ".join(axis_info) + ")" - max_logging.info(info_str) diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index 70e2235cd4..21d5046a6a 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -947,10 +947,9 @@ def setup_initial_state( def get_logical_annotations(model, tx, config, rng, mesh, is_training=True): init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) - with nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) - - logical_annotations = nn.get_partition_spec(abstract_state) + logical_annotations = nn.get_partition_spec(abstract_state) return logical_annotations @@ -1240,7 +1239,6 @@ def schedule(step): def print_shardings_params(params, params_sharding, mesh, logical_annotations=None): """ Print state shardings comparing Logical Definition vs Physical Result. - Simplified version: Directly prints logical annotations without reverse mapping. 
""" if not hasattr(params, "params"): params = {"params": params} diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 6d74587171..a8f83195eb 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -52,7 +52,7 @@ def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=Fal pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh")) log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level) if log_key not in _LOGGED_ACTIVATION_SHARDINGS: - max_logging.info(f"{log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level) + max_logging.info(f"Physical: {log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level) _LOGGED_ACTIVATION_SHARDINGS.add(log_key) if shard_mode == ShardMode.EXPLICIT: return reshard(inputs, named_sharding) @@ -75,18 +75,14 @@ def maybe_shard_with_logical( log_key = (str(jax.typeof(inputs)), logical_axes, extra_stack_level) if log_key not in _LOGGED_LOGICAL_AXES: - pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh")) - pspec_str = str(tuple(pspec)) if pspec else "None" - max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level) - max_logging.info(f"{log_key[0]:.<80} {pspec_str}.", stacklevel=3 + extra_stack_level) _LOGGED_LOGICAL_AXES.add(log_key) return maybe_shard_with_name( inputs, named_sharding, shard_mode, - debug_sharding=False, + debug_sharding=debug_sharding, extra_stack_level=extra_stack_level + 1, ) diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py index 2a365a1074..8428655ba1 100644 --- a/src/MaxText/train_compile.py +++ b/src/MaxText/train_compile.py @@ -245,7 +245,7 @@ def main(argv: Sequence[str]) -> None: # print weights sharding info under debug sharding mode if config.debug_sharding: - max_utils.print_mesh_axes_info(topology_mesh) + max_utils.print_non_trivial_mesh_axis(topology_mesh) maxtext_utils.print_shardings_params( shaped_train_args[0].params, state_mesh_shardings.params, From d84d4cac74351e87d35ae1bf1e2759c524029f30 Mon Sep 17 00:00:00 2001 From: Sharon Yu Date: Wed, 28 Jan 2026 17:22:59 +0000 Subject: [PATCH 8/9] fix CL unit test issue --- src/MaxText/maxtext_utils.py | 2 +- src/MaxText/train_compile.py | 2 +- src/MaxText/train_utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index 21d5046a6a..52be815772 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -1249,7 +1249,7 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) - leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations.params) + leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical): path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py index 8428655ba1..d643e1c672 100644 --- a/src/MaxText/train_compile.py +++ b/src/MaxText/train_compile.py @@ -250,7 +250,7 @@ def main(argv: Sequence[str]) -> None: shaped_train_args[0].params, state_mesh_shardings.params, topology_mesh, - logical_annotations, + logical_annotations.params, ) # 
Compile diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index 016a8b242e..5f3256aac7 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -219,7 +219,7 @@ def setup_train_loop(config, recorder, devices=None): if config.debug_sharding: logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True) max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh, logical_annotations) + maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params) if config.use_dpo: abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) From 87a6b93b5318cf90d6c337cbd06a7407d7647d93 Mon Sep 17 00:00:00 2001 From: Sharon Yu Date: Wed, 28 Jan 2026 17:26:02 +0000 Subject: [PATCH 9/9] fix format --- src/MaxText/train_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index 5f3256aac7..9dba42cb32 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -219,7 +219,9 @@ def setup_train_loop(config, recorder, devices=None): if config.debug_sharding: logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True) max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params) + maxtext_utils.print_shardings_params( + state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params + ) if config.use_dpo: abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
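
Taken together, the series makes `config.debug_sharding` print two mappings: per-parameter logging of path, shape, logical annotation, and physical PartitionSpec via `print_shardings_params`, and per-activation logging of logical axes and their resolved spec inside `maybe_shard_with_logical` / `maybe_shard_with_name`. The standalone sketch below only approximates the logical-to-physical resolution step that this logging reports; the rules, parameter names, and logical axes in it are hypothetical examples, and it does not reuse the MaxText helpers changed above.

# Illustrative sketch only (not MaxText code): approximates the logical -> physical
# resolution that the debug_sharding logging reports. The rules, parameter names,
# and logical axes below are hypothetical.
import jax
from jax.sharding import PartitionSpec as P

# Hypothetical logical-axis rules in the (logical_name, mesh_axis) pair form
# that MaxText-style configs use.
logical_axis_rules = [
    ("embed", "tensor"),
    ("mlp", "tensor"),
    ("batch", "data"),
    ("vocab", None),
]

def resolve_logical_axes(logical_axes, rules):
  """Map a tuple of logical axis names to a physical PartitionSpec via the rules."""
  rules_dict = dict(rules)
  return P(*(rules_dict.get(axis) for axis in logical_axes))

# Hypothetical parameter tree annotated with logical axes, standing in for the
# logical_annotations pytree that print_shardings_params walks.
logical_annotations = {
    "decoder": {
        "mlp_wi": ("embed", "mlp"),
        "embedding": ("vocab", "embed"),
    }
}

# Flatten with paths, treating the logical-axis tuples themselves as leaves.
leaves, _ = jax.tree_util.tree_flatten_with_path(
    logical_annotations, is_leaf=lambda x: isinstance(x, tuple)
)
for path, logical_axes in leaves:
  path_str = "/".join(str(p.key) for p in path)
  pspec = resolve_logical_axes(logical_axes, logical_axis_rules)
  # Mirrors the "path ... logical ... physical" triple the new logging emits.
  print(f"{path_str:<25} Logical: {str(logical_axes):<25} Physical: {pspec}")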