From 182d59135fbe10440db1c74767759fa8444486f3 Mon Sep 17 00:00:00 2001
From: Elisa Tsai
Date: Mon, 12 Jan 2026 22:49:06 +0000
Subject: [PATCH] Revert "Integrate tokamax ring attention as optional attention kernel for WAN 2.1"

This reverts commit f68c7b0c6e4b8030685d43b22b2a22b8f0b9da40.
---
 src/maxdiffusion/models/attention_flax.py | 104 ++++++++--------------
 src/maxdiffusion/pyconfig.py              |   8 +-
 2 files changed, 40 insertions(+), 72 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 218b3b79..f8ba9310 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -27,7 +27,6 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
-from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
 
 from einops import rearrange
 from .. import common_types, max_logging
@@ -305,16 +304,7 @@ def wrap_flash_attention(query, key, value):
         mask=mask,
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        save_residuals=True if "ring" in attention_kernel else False,
-    )
-  elif attention_kernel == "tokamax_ring":
-    mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
-    splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-        mask=mask,
-        is_mqa=False,
-        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        save_residuals=True,
-        ring_axis="fsdp",
+        save_residuals=True if attention_kernel == "ring" else False,
     )
   else:
     splash_kernel = splash_attention_kernel.make_splash_mha(
@@ -322,75 +312,54 @@ def wrap_flash_attention(query, key, value):
         head_shards=1,  # the sizes of the axis is sharding over heads
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
-        save_residuals=True if "ring" in attention_kernel else False,
+        save_residuals=True if attention_kernel == "ring" else False,
         residual_checkpoint_name=residual_checkpoint_name
     )
+  vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
-  if attention_kernel == "tokamax_ring":
-    # For tokamax_ring, use the kernel directly without vmap
-    # The ring attention kernel handles the ring topology internally
-    if not mask_padding_tokens:
-      segment_ids = None
-    attention_output = splash_kernel(
-        fwd_mask_info=None,
-        dkv_mask_info=None,
-        q=query,
-        k=key,
-        v=value,
-        segment_ids=segment_ids,
-        is_mqa=False,
-        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        mask_value=-jnp.inf,
-        mask_function=None,
-        fwd_mask_sparsity=1.0,
-        save_residuals=True,
-    )
+  if not mask_padding_tokens:
+    segment_ids = None
+  if attention_kernel in ["flash", "tokamax_flash"]:
+    attention_output = vmapped_splash(query, key, value, segment_ids)
   else:
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-
-    if not mask_padding_tokens:
-      segment_ids = None
-    if attention_kernel in ["flash", "tokamax_flash"]:
-      attention_output = vmapped_splash(query, key, value, segment_ids)
-    else:
-      if num_fsdp_shards > 1:
-        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-        m = lse.astype(jnp.float32)
-        l = jnp.exp(lse - m)
-        o = out.astype(jnp.float32) * l[..., None]
+    if num_fsdp_shards > 1:
+      out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+      m = lse.astype(jnp.float32)
+      l = jnp.exp(lse - m)
+      o = out.astype(jnp.float32) * l[..., None]
 
-        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+      perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
 
-        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
-        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
+      k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+      v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
 
-        def ring_scan_body(carry, _):
-          m, l, o, k_current, v_current = carry
-          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
-          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+      def ring_scan_body(carry, _):
+        m, l, o, k_current, v_current = carry
+        k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+        v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
 
-          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+        out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
 
-          m_chunk = lse_chunk.astype(jnp.float32)
-          m_old = m
-          m = jnp.maximum(m_old, m_chunk)
+        m_chunk = lse_chunk.astype(jnp.float32)
+        m_old = m
+        m = jnp.maximum(m_old, m_chunk)
 
-          exp_m_diff = jnp.exp(m_old - m)
-          exp_m_chunk_diff = jnp.exp(m_chunk - m)
+        exp_m_diff = jnp.exp(m_old - m)
+        exp_m_chunk_diff = jnp.exp(m_chunk - m)
 
-          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-          o = o * exp_m_diff[..., None]
-          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+        l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+        o = o * exp_m_diff[..., None]
+        o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
 
-          # Return the updated state for the next iteration
-          return (m, l, o, k_next, v_next), None
+        # Return the updated state for the next iteration
+        return (m, l, o, k_next, v_next), None
 
-        initial_carry = (m, l, o, k1, v1)
-        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+      initial_carry = (m, l, o, k1, v1)
+      (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
 
-        attention_output = o_final / l_final[..., None]
-      else:
-        raise ValueError("ring attention requires fsdp > 1")
+      attention_output = o_final / l_final[..., None]
+    else:
+      raise ValueError("ring attention requires fsdp > 1")
 
   return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
@@ -566,7 +535,7 @@ def _apply_attention(
       mask_padding_tokens=mask_padding_tokens,
       residual_checkpoint_name=residual_checkpoint_name,
     )
-  elif "ring" in attention_kernel:
+  elif attention_kernel == "ring":
     return _tpu_flash_attention(
       query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
       mask_padding_tokens=mask_padding_tokens,
@@ -577,7 +546,6 @@ def _apply_attention(
     raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")
 
 
-
 def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
   """Multi-head dot product attention with a limited number of queries."""
   num_kv, num_heads, k_features = key.shape[-3:]
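Note on the restored ring attention above: the deleted and restored blocks differ only by one indentation level, so the algorithm itself is easy to lose in the hunk. The following is a minimal, single-device JAX sketch of the same accumulation rule used by ring_scan_body: each chunk contributes a partial output together with its log-sum-exp, and chunks are merged under a running maximum so the rescaling stays numerically stable. It chunks over the key/value length instead of rotating shards around the "fsdp" ring with jax.lax.ppermute, and the shapes, names, and final allclose check are illustrative only, not part of the patch.

import jax
import jax.numpy as jnp


def chunk_attention(q, k_chunk, v_chunk):
  # Per-chunk attention output plus the log-sum-exp residual needed for merging.
  logits = jnp.einsum("qd,kd->qk", q, k_chunk) / jnp.sqrt(q.shape[-1])
  lse = jax.nn.logsumexp(logits, axis=-1)                 # [q_len]
  out = jnp.einsum("qk,kd->qd", jax.nn.softmax(logits, axis=-1), v_chunk)
  return out, lse


def merged_attention(q, k, v, num_chunks):
  k_chunks = jnp.split(k, num_chunks)
  v_chunks = jnp.split(v, num_chunks)

  out, lse = chunk_attention(q, k_chunks[0], v_chunks[0])
  m = lse                                                 # running max of the lse values
  l = jnp.exp(lse - m)                                    # running softmax normalizer
  o = out * l[..., None]                                  # running unnormalized output

  for k_c, v_c in zip(k_chunks[1:], v_chunks[1:]):
    out_c, lse_c = chunk_attention(q, k_c, v_c)
    m_old, m = m, jnp.maximum(m, lse_c)
    l = l * jnp.exp(m_old - m) + jnp.exp(lse_c - m)
    o = o * jnp.exp(m_old - m)[..., None] + jnp.exp(lse_c - m)[..., None] * out_c
  return o / l[..., None]


q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 16))
reference, _ = chunk_attention(q, k, v)                   # one big "chunk" == full attention
print(jnp.allclose(merged_attention(q, k, v, num_chunks=4), reference, atol=1e-5))

Keeping the running maximum m outside the exponentials is what lets the merged result match a single unchunked attention pass even when individual logits are large; the restored code applies the same update to the per-shard outputs and log-sum-exps returned by the vmapped splash kernel.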
diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py
index 060cc1bf..27c9f645 100644
--- a/src/maxdiffusion/pyconfig.py
+++ b/src/maxdiffusion/pyconfig.py
@@ -195,8 +195,8 @@ def user_init(raw_keys):
   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
 
   # Verify qkv is sharded across sequence.
-  if "ring" in raw_keys["attention"] or raw_keys["attention_sharding_uniform"]:
-    max_logging.log(f"Adding sequence sharding to q and kv if not already present because '{raw_keys['attention']}' contains 'ring' or {raw_keys['attention_sharding_uniform']} is set.")
+  if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
+    max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.")
     logical_axis_rules = list(raw_keys["logical_axis_rules"])
     max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
     new_rules = []
@@ -206,12 +206,12 @@ def user_init(raw_keys):
       logical_axis_rules.append(q_seq_sharding)
     if kv_seq_sharding not in logical_axis_rules:
       logical_axis_rules.append(kv_seq_sharding)
-    if "ring" in raw_keys["attention"]:
+    if raw_keys["attention"] == "ring":
      for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
        if ring_attention_axis_rule not in logical_axis_rules:
          max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
          new_rules.append(ring_attention_axis_rule)
-    else:  # attention contains 'flash' but sequence parallel sharding requested for both self and cross attention
+    else:  # attention =flash but sequence parallel sharding requested for both self and cross attention
      for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
        if seq_parallel_axis_rule not in logical_axis_rules:
          max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")
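On the pyconfig.py side, the restored branch only appends a sharding rule when an identical (logical_axis, mesh_axes) tuple is not already in logical_axis_rules, so user-supplied rules are never duplicated or reordered. Below is a condensed sketch of that selection logic under assumed placeholder values; the rule tuples and the helper name are illustrative, not the actual RING_ATTENTION_AXIS_RULES or SEQUENCE_PARALLEL_AXIS_RULES constants from pyconfig.py.

# Placeholder rule tuples; the real constants live in pyconfig.py.
Q_SEQ_SHARDING = ("activation_q_length", "fsdp")
KV_SEQ_SHARDING = ("activation_kv_length", "fsdp")
RING_RULES = (("ring_attention_kv", "fsdp"),)
SEQ_PARALLEL_RULES = (("self_attention_seq", "fsdp"), ("cross_attention_seq", "fsdp"))


def resolve_axis_rules(attention, attention_sharding_uniform, logical_axis_rules):
  """Return logical_axis_rules plus whatever sharding rules the attention mode needs."""
  rules = list(logical_axis_rules)
  if attention == "ring" or attention_sharding_uniform:
    extra = RING_RULES if attention == "ring" else SEQ_PARALLEL_RULES
    for rule in (Q_SEQ_SHARDING, KV_SEQ_SHARDING) + extra:
      if rule not in rules:  # never duplicate a rule the user already supplied
        rules.append(rule)
  return rules


# Example: q sequence sharding is already present, so only the missing rules are added.
print(resolve_axis_rules("ring", False, [("activation_q_length", "fsdp")]))

With attention == "ring" the ring rules follow the q/kv sequence-sharding rules; any other attention value with attention_sharding_uniform set falls through to the sequence-parallel rules, mirroring the restored if/else in user_init.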