104 changes: 36 additions & 68 deletions src/maxdiffusion/models/attention_flax.py
@@ -27,7 +27,6 @@
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
from einops import rearrange
from .. import common_types, max_logging

@@ -305,92 +304,62 @@ def wrap_flash_attention(query, key, value):
mask=mask,
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
save_residuals=True if "ring" in attention_kernel else False,
)
elif attention_kernel == "tokamax_ring":
mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
mask=mask,
is_mqa=False,
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
save_residuals=True,
ring_axis="fsdp",
save_residuals=True if attention_kernel == "ring" else False,
)
else:
splash_kernel = splash_attention_kernel.make_splash_mha(
mask=multi_head_mask,
head_shards=1, # the sizes of the axis is sharding over heads
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
block_sizes=block_sizes,
save_residuals=True if "ring" in attention_kernel else False,
save_residuals=True if attention_kernel == "ring" else False,
residual_checkpoint_name=residual_checkpoint_name
)
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

if attention_kernel == "tokamax_ring":
# For tokamax_ring, use the kernel directly without vmap
# The ring attention kernel handles the ring topology internally
if not mask_padding_tokens:
segment_ids = None
attention_output = splash_kernel(
fwd_mask_info=None,
dkv_mask_info=None,
q=query,
k=key,
v=value,
segment_ids=segment_ids,
is_mqa=False,
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
mask_value=-jnp.inf,
mask_function=None,
fwd_mask_sparsity=1.0,
save_residuals=True,
)
if not mask_padding_tokens:
segment_ids = None
if attention_kernel in ["flash", "tokamax_flash"]:
attention_output = vmapped_splash(query, key, value, segment_ids)
else:
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

if not mask_padding_tokens:
segment_ids = None
if attention_kernel in ["flash", "tokamax_flash"]:
attention_output = vmapped_splash(query, key, value, segment_ids)
else:
if num_fsdp_shards > 1:
out, (lse,) = vmapped_splash(query, key, value, segment_ids)
m = lse.astype(jnp.float32)
l = jnp.exp(lse - m)
o = out.astype(jnp.float32) * l[..., None]
if num_fsdp_shards > 1:
out, (lse,) = vmapped_splash(query, key, value, segment_ids)
m = lse.astype(jnp.float32)
l = jnp.exp(lse - m)
o = out.astype(jnp.float32) * l[..., None]

perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]

k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)

def ring_scan_body(carry, _):
m, l, o, k_current, v_current = carry
k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
def ring_scan_body(carry, _):
m, l, o, k_current, v_current = carry
k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)

out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)

m_chunk = lse_chunk.astype(jnp.float32)
m_old = m
m = jnp.maximum(m_old, m_chunk)
m_chunk = lse_chunk.astype(jnp.float32)
m_old = m
m = jnp.maximum(m_old, m_chunk)

exp_m_diff = jnp.exp(m_old - m)
exp_m_chunk_diff = jnp.exp(m_chunk - m)
exp_m_diff = jnp.exp(m_old - m)
exp_m_chunk_diff = jnp.exp(m_chunk - m)

l = l * exp_m_diff + jnp.exp(lse_chunk - m)
o = o * exp_m_diff[..., None]
o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
l = l * exp_m_diff + jnp.exp(lse_chunk - m)
o = o * exp_m_diff[..., None]
o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)

# Return the updated state for the next iteration
return (m, l, o, k_next, v_next), None
# Return the updated state for the next iteration
return (m, l, o, k_next, v_next), None

initial_carry = (m, l, o, k1, v1)
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
initial_carry = (m, l, o, k1, v1)
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)

attention_output = o_final / l_final[..., None]
else:
raise ValueError("ring attention requires fsdp > 1")
attention_output = o_final / l_final[..., None]
else:
raise ValueError("ring attention requires fsdp > 1")

return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
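The ring branch above accumulates per-shard flash-attention outputs together with their log-sum-exp (lse) residuals, which is the standard online-softmax merge. Below is a minimal, self-contained sketch of that merge (the function names, shapes, and seeds are illustrative, not maxdiffusion APIs): splitting K/V into two chunks and folding them together with the same (m, l, o) update reproduces full-sequence attention.

```python
import jax
import jax.numpy as jnp
import numpy as np

def chunk_attention(q, k, v):
  """Softmax attention over a single K/V chunk, returning the chunk's log-sum-exp."""
  s = q @ k.T                                    # [q_len, chunk_len] scores
  lse = jax.scipy.special.logsumexp(s, axis=-1)  # [q_len]
  out = jnp.exp(s - lse[:, None]) @ v            # locally normalized output
  return out, lse

def merge(m, l, o, out_chunk, lse_chunk):
  """Fold one more chunk into the running (max, denominator, numerator) statistics."""
  m_new = jnp.maximum(m, lse_chunk)
  o = o * jnp.exp(m - m_new)[:, None] + jnp.exp(lse_chunk - m_new)[:, None] * out_chunk
  l = l * jnp.exp(m - m_new) + jnp.exp(lse_chunk - m_new)
  return m_new, l, o

rng = np.random.default_rng(0)
q = jnp.asarray(rng.standard_normal((4, 8)), jnp.float32)
k = jnp.asarray(rng.standard_normal((6, 8)), jnp.float32)
v = jnp.asarray(rng.standard_normal((6, 8)), jnp.float32)

# Two K/V chunks stand in for the shards that travel around the ring.
out0, lse0 = chunk_attention(q, k[:3], v[:3])
out1, lse1 = chunk_attention(q, k[3:], v[3:])

m, l, o = lse0, jnp.ones_like(lse0), out0  # initialise from the local shard (l = exp(lse0 - m) = 1)
m, l, o = merge(m, l, o, out1, lse1)       # fold in the shard received from the ring
merged = o / l[:, None]

reference, _ = chunk_attention(q, k, v)    # full attention computed in one pass
assert jnp.allclose(merged, reference, atol=1e-5)
```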

@@ -566,7 +535,7 @@ def _apply_attention(
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name=residual_checkpoint_name,
)
elif "ring" in attention_kernel:
elif attention_kernel == "ring":
return _tpu_flash_attention(
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
mask_padding_tokens=mask_padding_tokens,
@@ -577,7 +546,6 @@
raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")



def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
num_kv, num_heads, k_features = key.shape[-3:]
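The ring schedule added to `_tpu_flash_attention` above rotates key/value shards one hop around the "fsdp" mesh axis each scan step via `jax.lax.ppermute`. Here is a minimal sketch of that rotation in isolation, assuming a 1-D mesh built from whatever devices are available; the mesh construction and axis name are illustrative, and depending on the JAX version `shard_map` may live at `jax.shard_map` rather than `jax.experimental.shard_map`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# 1-D mesh over the available devices; "fsdp" mirrors the axis name the kernel
# permutes over (purely illustrative here).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("fsdp",))
n = len(devices)

x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)  # one row block per shard

def rotate(block):
  # Each device sends its block to the next device in the ring and receives
  # the previous device's block (a no-op when only one device is present).
  perm = [(j, (j + 1) % n) for j in range(n)]
  return jax.lax.ppermute(block, axis_name="fsdp", perm=perm)

rotated = shard_map(rotate, mesh=mesh, in_specs=P("fsdp"), out_specs=P("fsdp"))(x)
print(rotated)  # row blocks shifted one position around the ring
```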
8 changes: 4 additions & 4 deletions src/maxdiffusion/pyconfig.py
@@ -195,8 +195,8 @@ def user_init(raw_keys):

raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
# Verify qkv is sharded across sequence.
if "ring" in raw_keys["attention"] or raw_keys["attention_sharding_uniform"]:
max_logging.log(f"Adding sequence sharding to q and kv if not already present because '{raw_keys['attention']}' contains 'ring' or {raw_keys['attention_sharding_uniform']} is set.")
if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.")
logical_axis_rules = list(raw_keys["logical_axis_rules"])
max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
new_rules = []
@@ -206,12 +206,12 @@
logical_axis_rules.append(q_seq_sharding)
if kv_seq_sharding not in logical_axis_rules:
logical_axis_rules.append(kv_seq_sharding)
if "ring" in raw_keys["attention"]:
if raw_keys["attention"] == "ring":
for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
if ring_attention_axis_rule not in logical_axis_rules:
max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
new_rules.append(ring_attention_axis_rule)
else: # attention contains 'flash' but sequence parallel sharding requested for both self and cross attention
else: # attention == 'flash' but sequence parallel sharding requested for both self and cross attention
for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
if seq_parallel_axis_rule not in logical_axis_rules:
max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")