34 changes: 30 additions & 4 deletions src/MaxText/maxtext_utils.py
@@ -944,6 +944,15 @@ def setup_initial_state(
   return state, state_mesh_annotations, state_mesh_shardings, data_iterator
 
 
+def get_logical_annotations(model, tx, config, rng, mesh, is_training=True):
+  init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng)
+
+  with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
+    abstract_state = jax.eval_shape(init_state_partial)
+  logical_annotations = nn.get_partition_spec(abstract_state)
+  return logical_annotations
+
+
 def get_abstract_state(model, tx, config, rng, mesh, is_training=True):
   """Get a shaped abstraction of the state (including optimizer)"""
   init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng)
@@ -1227,15 +1236,32 @@ def schedule(step):
   return optax.join_schedules(pieces, boundaries)
 
 
-def print_shardings_params(params, params_sharding, mesh):
-  """Print state shardings."""
+def print_shardings_params(params, params_sharding, mesh, logical_annotations=None):
+  """
+  Print state shardings comparing Logical Definition vs Physical Result.
+  """
   if not hasattr(params, "params"):
     params = {"params": params}
   if not hasattr(params_sharding, "params"):
     params_sharding = {"params": params_sharding}
+  if logical_annotations and not hasattr(logical_annotations, "params"):
+    logical_annotations = {"params": logical_annotations}
 
   leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
   leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
-  for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
+  leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
+
+  for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
     path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
     shape = jax.typeof(leaf_val)
     pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
-    max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}")
+    pspec_str = str(tuple(pspec))
+    logical_str = str(leaf_logical_val)
+
+    message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
+    max_logging.info(message)
+
+  print(flush=True)
 
 
 def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
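
The heart of this change is get_logical_annotations: it traces state initialization under jax.eval_shape, so no parameter memory is ever allocated, then reads the logical axis names off the abstract pytree with nn.get_partition_spec. A minimal, self-contained sketch of that pattern follows; the toy model, axis names, and logical-to-physical rules are illustrative assumptions, not MaxText's actual configuration.

import functools

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning


class TinyMlp(nn.Module):
  """Toy stand-in for a real model; only the kernel carries annotations."""

  @nn.compact
  def __call__(self, x):
    # Logical axis names; they resolve to mesh axes only under axis_rules.
    kernel = self.param(
        "kernel",
        nn.with_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
        (8, 32),
    )
    return x @ kernel


def init_fn(model, rng):
  return model.init(rng, jnp.ones((1, 8)))


model = TinyMlp()
rng = jax.random.PRNGKey(0)
rules = (("embed", "fsdp"), ("mlp", "tensor"))  # assumed logical->physical rules

with nn_partitioning.axis_rules(rules):
  # eval_shape traces init_fn without materializing any arrays.
  abstract_params = jax.eval_shape(functools.partial(init_fn, model, rng))

logical_annotations = nn.get_partition_spec(abstract_params)
print(logical_annotations)
# {'params': {'kernel': PartitionSpec('embed', 'mlp')}}

Keeping the annotations in logical form is the point: print_shardings_params can then show the author's intent (Logical) next to what the compiler actually produced (Physical) for each parameter.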
7 changes: 6 additions & 1 deletion src/MaxText/model_creation_utils.py
@@ -157,7 +157,12 @@ def create_sharded_state():
   # print weights sharding info under debug sharding mode
   if config.debug_sharding:
     max_utils.print_non_trivial_mesh_axis(model.mesh)
-    maxtext_utils.print_shardings_params(sharded_state, out_shardings, model.mesh)
+    maxtext_utils.print_shardings_params(
+        params=sharded_state,
+        params_sharding=out_shardings,
+        mesh=model.mesh,
+        logical_annotations=specs,
+    )
   if config.load_parameters_path:
     try:
       ckptr = ocp.Checkpointer(
18 changes: 16 additions & 2 deletions src/MaxText/sharding.py
@@ -31,6 +31,7 @@


 _LOGGED_ACTIVATION_SHARDINGS = set()
+_LOGGED_LOGICAL_AXES = set()
 
 
 def get_input_data_sharding(config, mesh):
@@ -51,7 +52,7 @@ def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=Fal
     pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh"))
     log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level)
     if log_key not in _LOGGED_ACTIVATION_SHARDINGS:
-      max_logging.info(f"{log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
+      max_logging.info(f"Physical: {log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
       _LOGGED_ACTIVATION_SHARDINGS.add(log_key)
   if shard_mode == ShardMode.EXPLICIT:
     return reshard(inputs, named_sharding)
@@ -67,9 +68,22 @@ def maybe_shard_with_logical(
"""
if inputs is None:
return None

named_sharding = create_sharding(mesh, logical_axes, rules=rules)

if debug_sharding and isinstance(inputs, Tracer):
log_key = (str(jax.typeof(inputs)), logical_axes, extra_stack_level)

if log_key not in _LOGGED_LOGICAL_AXES:
max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level)
_LOGGED_LOGICAL_AXES.add(log_key)

return maybe_shard_with_name(
inputs, named_sharding, shard_mode, debug_sharding=debug_sharding, extra_stack_level=extra_stack_level + 1
inputs,
named_sharding,
shard_mode,
debug_sharding=debug_sharding,
extra_stack_level=extra_stack_level + 1,
)


Expand Down
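
The new logical-axis logging mirrors the existing physical-sharding logging one level up: a module-level set deduplicates messages, because maybe_shard_with_logical runs on every layer call during tracing and would otherwise flood the log, and the isinstance(inputs, Tracer) guard restricts logging to traced (jit) execution. A minimal sketch of the log-once pattern, with a stand-in logger and invented shapes and axis names:

import logging

# Module-level registry: one entry per unique (abstract value, logical axes) pair.
_LOGGED_LOGICAL_AXES = set()


def log_logical_once(abstract_value, logical_axes):
  log_key = (abstract_value, logical_axes)
  if log_key not in _LOGGED_LOGICAL_AXES:
    logging.info("Logical: %s %s", abstract_value, logical_axes)
    _LOGGED_LOGICAL_AXES.add(log_key)


# The second call is a no-op: each unique sharding is reported once per process.
log_logical_once("f32[4,2048,8192]", ("activation_batch", "activation_length", "activation_embed"))
log_logical_once("f32[4,2048,8192]", ("activation_batch", "activation_length", "activation_embed"))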
28 changes: 24 additions & 4 deletions src/MaxText/train_compile.py
@@ -104,12 +104,15 @@ def get_shaped_inputs(topology_mesh, config):
       model, tx, config, example_rng, topology_mesh
   )
 
+  # unsharded logical annotations
+  logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh)
+
   # Shaped batch
   shaped_batch = maxtext_utils.get_shaped_batch(config)
 
   shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
   shaped_train_kwargs = {}
-  return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model
+  return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model
 
 
 def jit_and_compile(
@@ -160,7 +163,13 @@ def is_oom(argv: Sequence[str]) -> bool:
   max_utils.print_system_information()
 
   # Get shaped inputs
-  shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
+  (
+      shaped_train_args,
+      shaped_train_kwargs,
+      state_mesh_shardings,
+      _,
+      model,
+  ) = get_shaped_inputs(topology_mesh, config)

   # Get data sharding
   data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -216,7 +225,13 @@ def main(argv: Sequence[str]) -> None:
   max_utils.print_system_information()
 
   # Get shaped inputs
-  shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
+  (
+      shaped_train_args,
+      shaped_train_kwargs,
+      state_mesh_shardings,
+      logical_annotations,
+      model,
+  ) = get_shaped_inputs(topology_mesh, config)

   # Get data sharding
   data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -231,7 +246,12 @@
   # print weights sharding info under debug sharding mode
   if config.debug_sharding:
     max_utils.print_non_trivial_mesh_axis(topology_mesh)
-    maxtext_utils.print_shardings_params(shaped_train_args[0].params, state_mesh_shardings.params, topology_mesh)
+    maxtext_utils.print_shardings_params(
+        shaped_train_args[0].params,
+        state_mesh_shardings.params,
+        topology_mesh,
+        logical_annotations.params,
+    )

   # Compile
   print("Jitting and compiling train step...", flush=True)
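
Because get_shaped_inputs now returns five values, every call site changes arity, including the two tests further down. A hypothetical call-site sketch; with debug_sharding enabled, the per-parameter output then looks roughly like the trailing comment (parameter name, shape, and axes are invented for illustration):

# Callers that don't need the annotations (like is_oom above) discard them with "_".
shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model = get_shaped_inputs(
    topology_mesh, config
)

# Illustrative debug output pairing both views of one parameter:
#   decoder/layers/mlp/wi_0/kernel
#     Shape: float32[8192,28672]
#     Logical: PartitionSpec('embed', 'mlp')
#     Physical: ('fsdp', 'tensor')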
5 changes: 4 additions & 1 deletion src/MaxText/train_utils.py
@@ -217,8 +217,11 @@ def setup_train_loop(config, recorder, devices=None):

   # print weights sharding info under debug sharding mode
   if config.debug_sharding:
+    logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True)
     max_utils.print_non_trivial_mesh_axis(model.mesh)
-    maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh)
+    maxtext_utils.print_shardings_params(
+        state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params
+    )
 
   if config.use_dpo:
     abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
2 changes: 1 addition & 1 deletion tests/unit/sharding_compare_test.py
@@ -97,7 +97,7 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
   validate_config(config)
 
   topology_mesh = get_topology_mesh(config)
-  _, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
+  _, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
   actual_json = named_shardings_to_json(state_mesh_shardings)
   expected_json = load_named_sharding_json(json_path)
2 changes: 1 addition & 1 deletion tests/utils/sharding_dump.py
@@ -276,7 +276,7 @@ def main(argv: Sequence[str]) -> None:

   try:
     topology_mesh = get_topology_mesh(config)
-    _, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
+    _, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
   except:  # pylint: disable=bare-except
     state_mesh_shardings = {}