Skip to content

Commit 7d00eed

Browse files
committed
Tag and save Pallas LSE statistics to eliminate attention recomputation entirely
1 parent b195959 commit 7d00eed

2 files changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/kernels/splash_attention/splash_attention_kernel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,6 +1264,7 @@ def _splash_attention_fwd(
 1264 1264           max_logit_value=max_logit_value,
 1265 1265       )
 1266 1266       logsumexp = stats["logsumexp"]  # save in the config base for the bwd pass
      1267+     logsumexp = jax.checkpoint_name(logsumexp, "pallas_logsumexp")
 1267 1268      if config.use_base2_exp:  # for user, output values in natural base
 1268 1269          stats["logsumexp"] = stats["logsumexp"] / LOG2E
 1269 1270          stats["max_logits"] = stats["max_logits"] / LOG2E

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ def setup(self):
 495 495      # 4. Force strict checkpointing on the Single Wrapper
 496 496      #RemattedSingleWrapper = nn.remat(ScannedSingleBlockWrapper, prevent_cse=True, policy=cp.checkpoint_dots_with_no_batch_dims)
 497 497      #RemattedSingleWrapper = nn.remat(ScannedSingleBlockWrapper, prevent_cse=True, policy=cp.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host"))
 498-         RemattedSingleWrapper = nn.remat(ScannedSingleBlockWrapper, prevent_cse=True, policy=cp.save_only_these_names("attn_output", "lin1_norm_hidden_states"))
     498+    RemattedSingleWrapper = nn.remat(ScannedSingleBlockWrapper, prevent_cse=True, policy=cp.save_only_these_names("attn_output", "lin1_norm_hidden_states", "pallas_logsumexp"))
 499 499
 500 500      self.scanned_single_blocks = nn.scan(
 501 501          RemattedSingleWrapper,

0 commit comments

Comments (0)