Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
90619c5
Reapply "Attention bug fixes, tokamax splash defaulting logic (#282)"…
Dec 15, 2025
d848983
Reapply "Cross self attention switch (#251)" (#288)
Dec 15, 2025
c29fdc4
Disable unsafe rng
Dec 15, 2025
f68c7b0
Integrate tokamax ring attention as optional attention kernel for WAN…
Dec 17, 2025
8a18686
Merge branch 'main' into elisatsai_disable_unsafe_rng
eltsai Dec 29, 2025
a7fa4f0
Fixed formatting issue
Dec 30, 2025
41d9353
Updated scheduler test values
Dec 30, 2025
d128e32
Updated values based on v5p-8 tests
Dec 30, 2025
70ce989
Fixing ring attention
Jan 5, 2026
ed47e5f
moving kernel init outside the sharding map
Feb 10, 2026
65e7f93
Revert "moving kernel init outside the sharding map"
Feb 15, 2026
a0c377f
jitting and sharding vae, refactored for loops in jitted VAE, 132 sec…
Feb 23, 2026
e7cd3c4
Renaming VAE sharding axis to vae_spatial
Feb 26, 2026
c236d56
Renaming VAE sharding axis to vae_spatial
Feb 26, 2026
9bcd458
ring-attention
coolkp Mar 2, 2026
0e60bbb
Merge remote-tracking branch 'origin/kunjanp-ring-attention' into eli…
Mar 4, 2026
10f2f33
Merge remote-tracking branch 'origin/main' into elisatsai_ring_attention
Mar 4, 2026
ffd7933
fixing attention from merging main
Mar 5, 2026
62e3b06
Fix attention_flax API regression from manual edits regarding context…
Mar 5, 2026
0a7d593
Merge branch 'elisatsai_ring_attention' of https://github.com/AI-Hype…
Mar 5, 2026
115fffa
Added sharding on ROPE
Mar 10, 2026
e04e78d
cfg cache
Mar 9, 2026
5b91824
Merged CFG cache, 220 sec using tokamax_flash
Mar 11, 2026
2d4eae1
Changed profiling logic
Mar 12, 2026
438fefd
Format fix
Mar 16, 2026
dff5c30
Merge remote-tracking branch 'origin/main' into elisatsai_ring_attention
Mar 16, 2026
7293017
updated vae config logic to be consistent, updated xprof logic
Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'

# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
vae_spatial: -1 # default to total_device * 2 // (dp)

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
Expand All @@ -60,7 +61,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
attention: 'tokamax_flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
Expand All @@ -81,9 +82,7 @@ flash_block_sizes: {
"block_q_dkv" : 512,
"block_kv_dkv" : 512,
"block_kv_dkv_compute" : 512,
"block_q_dq" : 512,
"block_kv_dq" : 512,
"use_fused_bwd_kernel": False,
"use_fused_bwd_kernel": True
}
# Use on v6e
# flash_block_sizes: {
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'

# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
vae_spatial: -1

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
Expand Down
60 changes: 53 additions & 7 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,20 @@ def get_git_commit_hash():
jax.config.update("jax_use_shardy_partitioner", True)


def call_pipeline(config, pipeline, prompt, negative_prompt):
def call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=None):
"""Call the pipeline with optional num_inference_steps override.

Args:
config: The configuration object.
pipeline: The pipeline to call.
prompt: The prompt(s) to use.
negative_prompt: The negative prompt(s) to use.
num_inference_steps: Optional override for number of inference steps.
If None, uses config.num_inference_steps.
"""
model_key = config.model_name
model_type = config.model_type
steps = num_inference_steps if num_inference_steps is not None else config.num_inference_steps
if model_type == "I2V":
image = load_image(config.image_url)
if model_key == WAN2_1:
Expand All @@ -98,7 +109,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
height=config.height,
width=config.width,
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
num_inference_steps=steps,
guidance_scale=config.guidance_scale,
)
elif model_key == WAN2_2:
Expand All @@ -109,7 +120,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
height=config.height,
width=config.width,
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
num_inference_steps=steps,
guidance_scale_low=config.guidance_scale_low,
guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
Expand All @@ -124,7 +135,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
height=config.height,
width=config.width,
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
num_inference_steps=steps,
guidance_scale=config.guidance_scale,
use_cfg_cache=config.use_cfg_cache,
)
Expand All @@ -135,7 +146,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
height=config.height,
width=config.width,
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
num_inference_steps=steps,
guidance_scale_low=config.guidance_scale_low,
guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
Expand Down Expand Up @@ -248,6 +259,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log(f"hardware: {jax.devices()[0].platform}")
max_logging.log(f"number of devices: {jax.device_count()}")
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
max_logging.log(f"vae_spatial: {config.vae_spatial}")
max_logging.log("============================================================")

compile_time = time.perf_counter() - s0
Expand Down Expand Up @@ -276,15 +288,49 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log(f"generation time per video: {generation_time_per_video}")
else:
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
s0 = time.perf_counter()

if config.enable_profiler:
skip_steps = getattr(config, 'skip_first_n_steps_for_profiler', 0)
profiler_steps = getattr(config, 'profiler_steps', config.num_inference_steps)
profile_all = profiler_steps == -1
steps_for_profile = config.num_inference_steps if profile_all else profiler_steps

if profile_all:
max_logging.log(f"Profiler: profiling all {steps_for_profile} inference steps (profiler_steps=-1)")
else:
max_logging.log(f"Profiler: profiling {steps_for_profile} steps out of {config.num_inference_steps} total")
max_logging.log(f"Profiler: skip_first_n_steps={skip_steps}")

def block_if_jax(x):
"""Block until ready if x is a JAX array, otherwise no-op."""
if hasattr(x, 'block_until_ready'):
x.block_until_ready()
return x

for i in range(skip_steps):
max_logging.log(f"Profiler warmup iteration {i + 1}/{skip_steps}")
warmup_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=steps_for_profile)
# Block until warmup completes
jax.tree_util.tree_map(block_if_jax, warmup_videos)

# Warm up GCS connection by flushing writer before starting profiler
if writer and jax.process_index() == 0:
max_logging.log("Flushing writer to warm up GCS connection before profiler...")
writer.flush()

s0 = time.perf_counter()
max_utils.activate_profiler(config)
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
max_logging.log(f"Profiler: starting profiled run with {steps_for_profile} steps")
profiled_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=steps_for_profile)
# Wait for all computation to finish before stopping profiler
jax.tree_util.tree_map(block_if_jax, profiled_videos)
max_utils.deactivate_profiler(config)
max_utils.upload_profiler_traces(config)
generation_time_with_profiler = time.perf_counter() - s0
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
max_logging.log("Profiler: completed (video not saved)")

return saved_video_path

Expand Down
Empty file.
Loading
Loading