
Commit a7345e2

generalize across sharding parallelisms
1 parent 5e21db3 commit a7345e2

4 files changed

Lines changed: 23 additions & 15 deletions


src/maxdiffusion/max_utils.py

Lines changed: 13 additions & 1 deletion
@@ -654,4 +654,16 @@ def maybe_initialize_jax_distributed_system(raw_keys):
     initialize_jax_for_gpu()
     max_logging.log("Jax distributed system initialized on GPU!")
   else:
-    jax.distributed.initialize()
+    jax.distributed.initialize()
+
+def get_axis_names(axis_key: str, config=None) -> str:
+  """Returns the mesh axis names given the logical axis key from config.logical_axis_rules."""
+  axis_name = ''
+  if config:
+    axis_rules = config.logical_axis_rules
+  else:
+    axis_rules = nn.get_logical_axis_rules()
+  for rules in axis_rules:
+    if rules[0] == axis_key:
+      axis_name = rules[1]
+  return axis_name
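
For context, a minimal usage sketch of the new helper. Only the call signature comes from the diff above; the rule list and config object below are hypothetical stand-ins for a maxdiffusion config.

from types import SimpleNamespace

from maxdiffusion import max_utils

# Hypothetical logical-axis rules mapping logical axis names to mesh axis names.
config = SimpleNamespace(logical_axis_rules=[("activation_batch", "data"), ("activation_length", "fsdp")])

# Resolves the logical axis "activation_length" to the mesh axis it is mapped to.
assert max_utils.get_axis_names("activation_length", config=config) == "fsdp"

When no config is passed, the helper instead reads whatever rules are currently active via flax's nn.get_logical_axis_rules().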

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 4 deletions
@@ -27,6 +27,7 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging
+from .. import max_utils

 from . import quantizations


@@ -205,10 +206,7 @@ def _tpu_flash_attention(
       block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
       use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
   )
-  fsdp_key = "fsdp"
-  if "fsdp_tpu" in mesh.shape.keys():
-    fsdp_key = "fsdp_tpu"
-
+  fsdp_key = max_utils.get_axis_names("activation_length")
   num_fsdp_shards = mesh.shape[fsdp_key]
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
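
A hedged sketch of how the flash-attention path now resolves its sharding axis: instead of hardcoding "fsdp"/"fsdp_tpu", it looks up the mesh axis bound to the logical "activation_length" axis from the active flax axis rules. The mesh layout and rule tuple below are illustrative, not taken from the repository's configs.

import jax
import numpy as np
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh

from maxdiffusion import max_utils

# Illustrative mesh: all local devices placed on a single "fsdp" axis.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("data", "fsdp"))

# With no explicit config, get_axis_names reads the rules installed by axis_rules().
with nn_partitioning.axis_rules((("activation_length", "fsdp"),)):
  fsdp_key = max_utils.get_axis_names("activation_length")  # -> "fsdp"
  num_fsdp_shards = mesh.shape[fsdp_key]  # number of devices along that axis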

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 5 additions & 1 deletion
@@ -20,6 +20,7 @@
 import jax.numpy as jnp
 from flax import nnx
 from ...configuration_utils import ConfigMixin
+from ... import max_utils
 from ..modeling_flax_utils import FlaxModelMixin, get_activation
 from ... import common_types
 from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput)

@@ -72,7 +73,10 @@ def __init__(
     self._depth_padding_before = self._causal_padding[1][0]  # 2 * padding_tuple[0]

     # Set sharding dynamically based on out_channels.
-    num_fsdp_axis_devices = mesh.device_ids.shape[2]
+    fsdp_key = max_utils.get_axis_names("activation_length")
+    if not fsdp_key:
+      fsdp_key = "fsdp"
+    num_fsdp_axis_devices = mesh.shape[fsdp_key]
     kernel_sharding = (None, None, None, None, None)
     if out_channels % num_fsdp_axis_devices == 0:
       kernel_sharding = (None, None, None, None, "conv_out")
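
A small worked example of the divisibility guard above (values are illustrative): on a 4-way fsdp axis, a conv with 96 output channels can shard its output-channel dimension, while one with 3 output channels keeps the fully replicated kernel sharding.

# Illustrative only: 4 devices on the resolved fsdp axis, two out_channels values.
num_fsdp_axis_devices = 4
for out_channels in (96, 3):
  kernel_sharding = (None, None, None, None, None)
  if out_channels % num_fsdp_axis_devices == 0:
    kernel_sharding = (None, None, None, None, "conv_out")
  print(out_channels, kernel_sharding)  # 96 -> sharded on "conv_out"; 3 -> replicated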

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 3 additions & 9 deletions
@@ -211,8 +211,8 @@ def prepare_sample_eval(features):
     return data_iterator

   def start_training(self):
-
-    pipeline, opt_state, step = self.checkpointer.load_checkpoint()
+    with nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      pipeline, opt_state, step = self.checkpointer.load_checkpoint()
     restore_args = {}
     if opt_state and step:
       restore_args = {"opt_state": opt_state, "step": step}

@@ -362,13 +362,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     example_batch = load_next_batch(train_data_iterator, None, self.config)

     # Designate the context parallel axis for sharding
-    cp_resource = ''
-    for rules in self.config.logical_axis_rules:
-      if rules[0] == "activation_length":
-        if isinstance(rules[1], list):
-          cp_resource = rules[1][0]
-        else:
-          cp_resource = rules[1]
+    cp_resource = max_utils.get_axis_names("activation_length", config=self.config)
     mesh_resource = MeshResource(cp_resource=cp_resource)

     with ThreadPoolExecutor(max_workers=1) as executor:
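
A sketch of how the trainer now derives the context-parallel resource. The config below is hypothetical, and the MeshResource import path is an assumption (the hunk above does not show where it comes from); only the get_axis_names call and the cp_resource keyword are taken from the diff.

from types import SimpleNamespace

from maxdiffusion import max_utils
# Assumed import path for MeshResource; not shown in the hunk above.
from transformer_engine.jax.sharding import MeshResource

# Hypothetical config: the "activation_length" (context-parallel) axis maps to "fsdp".
config = SimpleNamespace(logical_axis_rules=[("activation_length", "fsdp")])

cp_resource = max_utils.get_axis_names("activation_length", config=config)  # -> "fsdp"
mesh_resource = MeshResource(cp_resource=cp_resource)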
