diff --git a/src/maxdiffusion/configs/base_wan_lora_14b.yml b/src/maxdiffusion/configs/base_wan_lora_14b.yml new file mode 100644 index 00000000..2935453c --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_lora_14b.yml @@ -0,0 +1,347 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' +model_name: wan2.1 + +# Overrides the transformer from pretrained_model_name_or_path +wan_transformer_pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# Replicates vae across devices instead of using the model's sharding annotations for sharding. +replicate_vae: False + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" +# Use jax.lax.scan for transformer layers +scan_layers: False + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +flash_min_seq_length: 4096 + +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True +dropout: 0.1 + +#flash_block_sizes: { +# "block_q" : 1024, +# "block_kv_compute" : 256, +# "block_kv" : 1024, +# "block_q_dkv" : 1024, +# "block_kv_dkv" : 1024, +# "block_kv_dkv_compute" : 256, +# "block_q_dq" : 1024, +# "block_kv_dq" : 1024 +#} +# Use on v6e +flash_block_sizes: { + "block_q" : 3024, + "block_kv_compute" : 1024, + "block_kv" : 2048, + "block_q_dkv" : 3024, + "block_kv_dkv" : 2048, + "block_kv_dkv_compute" : 2048, + "block_q_dq" : 3024, + "block_kv_dq" : 2048, + "use_fused_bwd_kernel": False, +} +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', 'data'], + ['activation_self_attn_heads', ['fsdp', 'tensor']], + ['activation_cross_attn_q_length', ['fsdp', 'tensor']], + ['activation_length', 'fsdp'], + ['activation_heads', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: True + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tfrecord' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '' +load_tfrecord_cached: True +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. +remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +checkpoint_dir: "" +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1.0 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 + +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 +enable_eval_timesteps: False +timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Enable JAX named scopes for detailed profiling and debugging +# When enabled, adds named scopes around key operations in transformer and attention layers +enable_jax_named_scopes: False + +# Generation parameters +prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +do_classifier_free_guidance: True +height: 720 +width: 1280 +num_frames: 81 +guidance_scale: 5.0 +flow_shift: 3.0 + +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 30 +fps: 16 +save_final_checkpoint: False + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +lora_rank: 64 +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"], + weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"], + adapter_name: ["wan21-distill-lora"], + scale: [1.0], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. +use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. +# Quantization calibration method used for weights, activations and bwd. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 +weight_quantization_calibration_method: "absmax" +act_quantization_calibration_method: "absmax" +bwd_quantization_calibration_method: "absmax" +qwix_module_path: ".*" + +# Eval model on per eval_every steps. -1 means don't eval. +eval_every: -1 +eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). + +enable_ssim: False \ No newline at end of file diff --git a/src/maxdiffusion/configs/base_wan_lora_27b.yml b/src/maxdiffusion/configs/base_wan_lora_27b.yml new file mode 100644 index 00000000..ed76a963 --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_lora_27b.yml @@ -0,0 +1,358 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers' +model_name: wan2.2 + +# Overrides the transformer from pretrained_model_name_or_path +wan_transformer_pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# Replicates vae across devices instead of using the model's sharding annotations for sharding. +replicate_vae: False + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" +# Use jax.lax.scan for transformer layers +scan_layers: False + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +flash_min_seq_length: 4096 +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True +dropout: 0.1 + +#flash_block_sizes: { +# "block_q" : 1024, +# "block_kv_compute" : 256, +# "block_kv" : 1024, +# "block_q_dkv" : 1024, +# "block_kv_dkv" : 1024, +# "block_kv_dkv_compute" : 256, +# "block_q_dq" : 1024, +# "block_kv_dq" : 1024 +#} +# Use on v6e +flash_block_sizes: { + "block_q" : 3024, + "block_kv_compute" : 1024, + "block_kv" : 2048, + "block_q_dkv" : 3024, + "block_kv_dkv" : 2048, + "block_kv_dkv_compute" : 2048, + "block_q_dq" : 3024, + "block_kv_dq" : 2048, + "use_fused_bwd_kernel": False, +} +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', 'data'], + ['activation_self_attn_heads', ['fsdp', 'tensor']], + ['activation_cross_attn_q_length', ['fsdp', 'tensor']], + ['activation_length', 'fsdp'], + ['activation_heads', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: True + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tfrecord' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '' +load_tfrecord_cached: True +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. +remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +checkpoint_dir: "" +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1.0 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 + +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 +enable_eval_timesteps: False +timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Enable JAX named scopes for detailed profiling and debugging +# When enabled, adds named scopes around key operations in transformer and attention layers +enable_jax_named_scopes: False + +# Generation parameters +prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +do_classifier_free_guidance: True +height: 720 +width: 1280 +num_frames: 81 +flow_shift: 3.0 + +# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py +# guidance scale factor for low noise transformer +guidance_scale_low: 3.0 + +# guidance scale factor for high noise transformer +guidance_scale_high: 4.0 + +# The timestep threshold. If `t` is at or above this value, +# the `high_noise_model` is considered as the required model. +# timestep to switch between low noise and high noise transformer +boundary_timestep: 875 + +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 30 +fps: 16 +save_final_checkpoint: False + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +lora_rank: 64 +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"], + high_noise_weight_name: ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + low_noise_weight_name: ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + adapter_name: ["wan22-distill-lora"], + scale: [1.0], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. +use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. +# Quantization calibration method used for weights, activations and bwd. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 +weight_quantization_calibration_method: "absmax" +act_quantization_calibration_method: "absmax" +bwd_quantization_calibration_method: "absmax" +qwix_module_path: ".*" + +# Eval model on per eval_every steps. -1 means don't eval. +eval_every: -1 +eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). + +enable_ssim: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index e3365e96..a8e94802 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -25,6 +25,7 @@ from google.cloud import storage import flax from maxdiffusion.common_types import WAN2_1, WAN2_2 +from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader def upload_video_to_gcs(output_dir: str, video_path: str): @@ -148,6 +149,42 @@ def run(config, pipeline=None, filename_prefix=""): else: raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") pipeline, _, _ = checkpoint_loader.load_checkpoint() + + # If LoRA is specified, inject layers and load weights. + if hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]: + if model_key == WAN2_1: + lora_loader = Wan2_1NnxLoraLoader() + lora_config = config.lora_config + + if len(lora_config["lora_model_name_or_path"]) > 1: + max_logging.log("Found multiple LoRAs in config, but only loading the first one.") + + pipeline = lora_loader.load_lora_weights( + pipeline, + lora_config["lora_model_name_or_path"][0], + transformer_weight_name=lora_config["weight_name"][0], + rank=config.lora_rank, + scale=lora_config["scale"][0], + scan_layers=config.scan_layers, + ) + + if model_key == WAN2_2: + lora_loader = Wan2_2NnxLoraLoader() + lora_config = config.lora_config + + if len(lora_config["lora_model_name_or_path"]) > 1: + max_logging.log("Found multiple LoRAs in config, but only loading the first one.") + + pipeline = lora_loader.load_lora_weights( + pipeline, + lora_config["lora_model_name_or_path"][0], + high_noise_weight_name=lora_config["high_noise_weight_name"][0], + low_noise_weight_name=lora_config["low_noise_weight_name"][0], + rank=config.lora_rank, + scale=lora_config["scale"][0], + scan_layers=config.scan_layers, + ) + s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py index 2c9e973d..e7abb88a 100644 --- a/src/maxdiffusion/loaders/__init__.py +++ b/src/maxdiffusion/loaders/__init__.py @@ -14,3 +14,4 @@ from .lora_pipeline import StableDiffusionLoraLoaderMixin from .flux_lora_pipeline import FluxLoraLoaderMixin +from .wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py index 5f9e72a6..f53428d3 100644 --- a/src/maxdiffusion/loaders/lora_conversion_utils.py +++ b/src/maxdiffusion/loaders/lora_conversion_utils.py @@ -608,3 +608,77 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") return new_state_dict + + +def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): + """ + Translates WAN NNX path to Diffusers/LoRA keys. + Verified against wan_utils.py mappings. + """ + + # --- 1. Embeddings (Exact Matches) --- + if nnx_path_str == 'condition_embedder.text_embedder.linear_1': + return 'diffusion_model.text_embedding.0' + if nnx_path_str == 'condition_embedder.text_embedder.linear_2': + return 'diffusion_model.text_embedding.2' + if nnx_path_str == 'condition_embedder.time_embedder.linear_1': + return 'diffusion_model.time_embedding.0' + if nnx_path_str == 'condition_embedder.time_embedder.linear_2': + return 'diffusion_model.time_embedding.2' + if nnx_path_str == 'patch_embedding': + return 'diffusion_model.patch_embedding' + if nnx_path_str == 'proj_out': + return 'diffusion_model.head.head' + if nnx_path_str == 'condition_embedder.time_proj': + return 'diffusion_model.time_projection.1' + + + # --- 2. Map NNX Suffixes to LoRA Suffixes --- + suffix_map = { + # Self Attention (attn1) + "attn1.query": "self_attn.q", + "attn1.key": "self_attn.k", + "attn1.value": "self_attn.v", + "attn1.proj_attn": "self_attn.o", + + # Self Attention Norms (QK Norm) - Added per your request + "attn1.norm_q": "self_attn.norm_q", + "attn1.norm_k": "self_attn.norm_k", + + # Cross Attention (attn2) + "attn2.query": "cross_attn.q", + "attn2.key": "cross_attn.k", + "attn2.value": "cross_attn.v", + "attn2.proj_attn": "cross_attn.o", + + # Cross Attention Norms (QK Norm) - Added per your request + "attn2.norm_q": "cross_attn.norm_q", + "attn2.norm_k": "cross_attn.norm_k", + + # Feed Forward (ffn) + "ffn.act_fn.proj": "ffn.0", # Up proj + "ffn.proj_out": "ffn.2", # Down proj + + # Global Norms & Modulation + "norm2.layer_norm": "norm3", + "scale_shift_table": "modulation", + "proj_out": "head.head" + } + + # --- 3. Translation Logic --- + if scan_layers: + # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q" + if nnx_path_str.startswith("blocks."): + inner_suffix = nnx_path_str[len("blocks."):] + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}" + else: + # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q" + m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str) + if m: + idx, inner_suffix = m.group(1), m.group(2) + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}" + + return None + diff --git a/src/maxdiffusion/loaders/wan_lora_nnx_loader.py b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py new file mode 100644 index 00000000..2ae83888 --- /dev/null +++ b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py @@ -0,0 +1,109 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NNX-based LoRA loader for WAN models.""" + +from flax import nnx +import jax +import re +from .lora_base import LoRABaseMixin +from .lora_pipeline import StableDiffusionLoraLoaderMixin +from ..models import lora_nnx +from .. import max_logging +from . import lora_conversion_utils + +class Wan2_1NnxLoraLoader(LoRABaseMixin): + """ + Handles loading LoRA weights into NNX-based WAN 2.1 model. + Assumes WAN pipeline contains 'transformer' + attributes that are NNX Modules. + """ + + def load_lora_weights( + self, + pipeline: nnx.Module, + lora_model_path: str, + transformer_weight_name: str, + rank: int, + scale: float = 1.0, + scan_layers: bool = False, + **kwargs, + ): + """ + Merges LoRA weights into the pipeline from a checkpoint. + """ + lora_loader = StableDiffusionLoraLoaderMixin() + + merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora + translate_fn = lambda nnx_path_str: lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) + + # Handle high noise model + if hasattr(pipeline, "transformer") and transformer_weight_name: + max_logging.log(f"Merging LoRA into transformer with rank={rank}") + h_state_dict, _ = lora_loader.lora_state_dict( + lora_model_path, weight_name=transformer_weight_name, **kwargs + ) + merge_fn(pipeline.transformer, h_state_dict, scale, translate_fn) + else: + max_logging.log("transformer not found or no weight name provided for LoRA.") + + return pipeline + +class Wan2_2NnxLoraLoader(LoRABaseMixin): + """ + Handles loading LoRA weights into NNX-based WAN 2.2 model. + Assumes WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer' + attributes that are NNX Modules. + """ + + def load_lora_weights( + self, + pipeline: nnx.Module, + lora_model_path: str, + high_noise_weight_name: str, + low_noise_weight_name: str, + rank: int, + scale: float = 1.0, + scan_layers: bool = False, + **kwargs, + ): + """ + Merges LoRA weights into the pipeline from a checkpoint. + """ + lora_loader = StableDiffusionLoraLoaderMixin() + + merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora + translate_fn = lambda nnx_path_str: lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) + + # Handle high noise model + if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name: + max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}") + h_state_dict, _ = lora_loader.lora_state_dict( + lora_model_path, weight_name=high_noise_weight_name, **kwargs + ) + merge_fn(pipeline.high_noise_transformer, h_state_dict, scale, translate_fn) + else: + max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.") + + # Handle low noise model + if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name: + max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}") + l_state_dict, _ = lora_loader.lora_state_dict( + lora_model_path, weight_name=low_noise_weight_name, **kwargs + ) + merge_fn(pipeline.low_noise_transformer, l_state_dict, scale, translate_fn) + else: + max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.") + + return pipeline diff --git a/src/maxdiffusion/models/lora_nnx.py b/src/maxdiffusion/models/lora_nnx.py new file mode 100644 index 00000000..bf4ee793 --- /dev/null +++ b/src/maxdiffusion/models/lora_nnx.py @@ -0,0 +1,495 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Union, Tuple, Optional +import re +import torch +import jax +from jax import dlpack +import jax.numpy as jnp +from flax import nnx +from .. import max_logging +import numpy as np + +# ----------------------------------------------------------------------------- +# JIT Helpers (The Fix for Sharding & Device-Side Computation) +# ----------------------------------------------------------------------------- + +@jax.jit +def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff): + """ + Applies LoRA + Weight Diff + Bias Diff on device. + """ + # 1. Apply LoRA (if valid) + if down is not None and up is not None: + # down: (Rank, In), up: (Out, Rank) -> Result: (In, Out) + # Note: We reshape to kernel shape to handle 1x1 convs + delta = (down.T @ up.T).reshape(kernel.shape) + kernel = kernel + (delta * scale).astype(kernel.dtype) + + # 2. Apply Full Weight Diff (if valid) + if w_diff is not None: + kernel = kernel + w_diff.astype(kernel.dtype) + + # 3. Apply Bias Diff (if valid and bias exists) + if bias is not None and b_diff is not None: + bias = bias + b_diff.astype(bias.dtype) + + return kernel, bias + +@jax.jit +def _compute_and_add_scanned_jit(kernel, downs, ups, alphas, global_scale, w_diffs=None, b_diffs=None, bias=None): + """ + Applies scanned LoRA + Diffs. + """ + # 1. Apply LoRA + if downs is not None and ups is not None: + rank = downs.shape[1] + scales = (global_scale * alphas / rank) + # Batch Matmul: (L, In, Out) + delta = jnp.matmul(jnp.swapaxes(downs, 1, 2), jnp.swapaxes(ups, 1, 2)) + delta = (delta * scales).astype(kernel.dtype) + kernel = kernel + delta.reshape(kernel.shape) + + # 2. Apply Scanned Weight Diffs (L, ...) + if w_diffs is not None: + kernel = kernel + w_diffs.astype(kernel.dtype) + + # 3. Apply Scanned Bias Diffs (L, ...) + # Note: Scanned bias is usually shape (L, Out) + if bias is not None and b_diffs is not None: + bias = bias + b_diffs.astype(bias.dtype) + + return kernel, bias + +# ----------------------------------------------------------------------------- + +def _to_jax_array(v): + if isinstance(v, torch.Tensor): + return dlpack.from_dlpack(v) + return jnp.array(v) + +def parse_lora_dict(state_dict): + """Helper to parse state_dict into structured params including diffs.""" + lora_params = {} + for k, v in state_dict.items(): + # Alpha + if k.endswith(".alpha"): + key_base = k[:-len(".alpha")] + if key_base not in lora_params: lora_params[key_base] = {} + lora_params[key_base]["alpha"] = _to_jax_array(v) + continue + + # Bias Diff (e.g., "layer.diff_b") + if k.endswith(".diff_b"): + key_base = k[:-len(".diff_b")] + if key_base not in lora_params: lora_params[key_base] = {} + lora_params[key_base]["diff_b"] = _to_jax_array(v) + continue + + # Weight Diff (e.g., "layer.diff") + if k.endswith(".diff"): + key_base = k[:-len(".diff")] + if key_base not in lora_params: lora_params[key_base] = {} + lora_params[key_base]["diff"] = _to_jax_array(v) + continue + + # Standard LoRA + m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k) + if not m: m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k) + if not m: m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k) + + if m: + key_base, weight_type = m.group(1), m.group(2).replace("lora_", "") + if key_base not in lora_params: lora_params[key_base] = {} + lora_params[key_base][weight_type] = _to_jax_array(v) + else: + # Fallback for exact matches of diffs if regex failed above + pass + + return lora_params + +def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None): + """ + Merges weights for non-scanned layers (Embeddings, singular Dense, etc). + Now supports diff and diff_b. + """ + lora_params = parse_lora_dict(state_dict) + max_logging.log(f"Parsed {len(lora_params)} unique module keys.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key = translate_fn(nnx_path_str) if translate_fn else None + + if lora_key and lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + + is_conv_kxk_locon = False + if isinstance(module, nnx.Conv) and module.kernel_size != (1,1) and "down" in weights and "up" in weights: + is_conv_kxk_locon = True + + # Handle Embeddings + if isinstance(module, nnx.Embed): + if "diff" in weights and hasattr(module, 'embedding'): + module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) + assigned_count += 1 + continue + # Handle Norms + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + scale_diff = weights.get("diff", None) + bias_diff = weights.get("diff_b", None) + updated = False + if scale_diff is not None and hasattr(module, 'scale') and module.scale is not None: + module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) + updated = True + if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None: + module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) + updated = True + if updated: + assigned_count += 1 + continue + + # Prepare LoRA terms + down_w, up_w, current_scale = None, None, None + if "down" in weights and "up" in weights and not is_conv_kxk_locon: + down_w, up_w = weights["down"], weights["up"] + down_w, up_w = np.array(down_w), np.array(up_w) # CPU convert + + # Squeeze dimensions if needed (Conv 1x1 or Linear) + if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1): + down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) + + rank = down_w.shape[0] if down_w.ndim > 0 else 0 + alpha = float(weights.get("alpha", rank)) + current_scale = scale * alpha / rank + + # Prepare Diff terms + w_diff = weights.get("diff", None) + b_diff = weights.get("diff_b", None) + + if w_diff is not None: + w_diff = np.array(w_diff) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. + if isinstance(module, nnx.Conv): + if w_diff.ndim == 5: + w_diff = w_diff.transpose((2,3,4,1,0)) + elif w_diff.ndim == 4: + w_diff = w_diff.transpose((2,3,1,0)) + elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: + w_diff = w_diff.transpose((1,0)) + if b_diff is not None: b_diff = np.array(b_diff) + + # If LoCON, compute delta and add to w_diff + if is_conv_kxk_locon: + dw, uw = np.array(weights['down']), np.array(weights['up']) + rank, in_c, *k_dims = dw.shape + out_c = uw.shape[0] + alpha = float(weights.get("alpha", rank)) + + delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + + # Transpose to flax + if delta_pt.ndim == 5: delta_fx = delta_pt.transpose((2,3,4,1,0)) + else: delta_fx = delta_pt.transpose((2,3,1,0)) + + lora_delta = delta_fx * (scale * alpha / rank) + if w_diff is None: + w_diff = lora_delta.astype(np.float32) + else: + w_diff += lora_delta.astype(w_diff.dtype) + + # Check for Bias existence + bias_val = module.bias.value if module.bias is not None else None + + # --- EXECUTE JIT UPDATE --- + if down_w is not None or w_diff is not None or b_diff is not None: + new_kernel, new_bias = _compute_and_add_single_jit( + module.kernel.value, + bias_val, + down_w, up_w, current_scale, + w_diff, b_diff + ) + + module.kernel.value = new_kernel + if new_bias is not None: + module.bias.value = new_bias + + assigned_count +=1 + else: + max_logging.log(f"Matched key {lora_key} but found no actionable weights.") + + max_logging.log(f"Merged weights into {assigned_count} layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}") + + +def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None): + """ + Device-Side Optimized Merge for Scanned Layers. + Now supports diff and diff_b. + """ + lora_params = parse_lora_dict(state_dict) + max_logging.log(f"Parsed {len(lora_params)} keys for scanned merge.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key_template = translate_fn(nnx_path_str) if translate_fn else None + + if not lora_key_template: + continue + + # Determine if layer is scanned based on parameter dimensions + is_scanned = False + if isinstance(module, nnx.Embed) and hasattr(module, 'embedding'): + is_scanned = module.embedding.ndim > 2 + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)) and hasattr(module, 'scale') and module.scale is not None: + is_scanned = module.scale.ndim > 1 + elif isinstance(module, nnx.Linear): + is_scanned = module.kernel.ndim == 3 + elif isinstance(module, nnx.Conv): + is_scanned = module.kernel.ndim == 5 + + # If layer is not scanned, merge it using single-layer logic + if not is_scanned: + lora_key = lora_key_template + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + is_conv_kxk_locon = isinstance(module, nnx.Conv) and module.kernel_size != (1,1) and "down" in weights and "up" in weights + + if isinstance(module, nnx.Embed): + if "diff" in weights and hasattr(module, 'embedding'): + module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) + assigned_count += 1 + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + scale_diff = weights.get("diff", None) + bias_diff = weights.get("diff_b", None) + updated = False + if scale_diff is not None and hasattr(module, 'scale') and module.scale is not None: + module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) + updated = True + if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None: + module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) + updated = True + if updated: + assigned_count += 1 + elif isinstance(module, (nnx.Linear, nnx.Conv)): + down_w, up_w, current_scale_ = None, None, None + if "down" in weights and "up" in weights and not is_conv_kxk_locon: + down_w, up_w = np.array(weights["down"]), np.array(weights["up"]) + if isinstance(module, nnx.Conv): down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) + rank, alpha = down_w.shape[0], float(weights.get("alpha", down_w.shape[0])) + current_scale_ = scale * alpha / rank + + w_diff, b_diff = weights.get("diff", None), weights.get("diff_b", None) + if w_diff is not None: + w_diff = np.array(w_diff) + if isinstance(module, nnx.Conv): + if w_diff.ndim == 5: w_diff = w_diff.transpose((2,3,4,1,0)) + elif w_diff.ndim == 4: w_diff = w_diff.transpose((2,3,1,0)) + elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: + w_diff = w_diff.transpose((1,0)) + if b_diff is not None: b_diff = np.array(b_diff) + if is_conv_kxk_locon: + dw, uw = np.array(weights['down']), np.array(weights['up']) + rank, in_c, *k_dims = dw.shape + out_c = uw.shape[0] + alpha = float(weights.get("alpha", rank)) + delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + if delta_pt.ndim == 5: delta_fx = delta_pt.transpose((2,3,4,1,0)) + else: delta_fx = delta_pt.transpose((2,3,1,0)) + lora_delta = delta_fx * (scale * alpha / rank) + if w_diff is None: w_diff = lora_delta.astype(np.float32) + else: w_diff += lora_delta.astype(w_diff.dtype) + + bias_val = module.bias.value if module.bias is not None else None + if down_w is not None or w_diff is not None or b_diff is not None: + k, b = _compute_and_add_single_jit(module.kernel.value, bias_val, down_w, up_w, current_scale_, w_diff, b_diff) + module.kernel.value = k + if b is not None: module.bias.value = b + assigned_count +=1 + continue + + # If we reach here, layer is SCANNED + if isinstance(module, nnx.Embed): + num_layers = module.embedding.shape[0] + embed_diffs_to_add = np.zeros_like(module.embedding.value) + updated = False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + if "diff" in lora_params[lora_key]: + embed_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.embedding.shape[1:]) + updated = True + if updated: + module.embedding.value += embed_diffs_to_add.astype(module.embedding.dtype) + assigned_count += 1 + continue + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + num_layers = module.scale.shape[0] + scale_diffs_to_add = np.zeros_like(module.scale.value) + bias_diffs_to_add = np.zeros_like(module.bias.value) if isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None else None + updated_scale, updated_bias = False, False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + if "diff" in weights: + scale_diffs_to_add[i] = np.array(weights["diff"]).reshape(module.scale.shape[1:]) + updated_scale = True + if "diff_b" in weights and bias_diffs_to_add is not None: + bias_diffs_to_add[i] = np.array(weights["diff_b"]).reshape(module.bias.shape[1:]) + updated_bias = True + if updated_scale: + module.scale.value += scale_diffs_to_add.astype(module.scale.dtype) + if updated_bias and bias_diffs_to_add is not None: + module.bias.value += bias_diffs_to_add.astype(module.bias.dtype) + if updated_scale or updated_bias: + assigned_count += 1 + continue + elif isinstance(module, (nnx.Linear, nnx.Conv)): + is_linear = isinstance(module, nnx.Linear) + is_conv = isinstance(module, nnx.Conv) + is_conv_kxk = isinstance(module, nnx.Conv) and module.kernel_size != (1,1) + if is_linear: + num_layers, in_feat, out_feat = module.kernel.shape + else: # Conv + num_layers = module.kernel.shape[0] + in_feat, out_feat = module.kernel.shape[3], module.kernel.shape[4] + else: + # Should not happen based on is_scanned logic + continue + + # 1. Scan for Rank (Fallback 64) + found_rank = 64 + for i in range(num_layers): + k = lora_key_template.format(i) + if k in lora_params and "down" in lora_params[k]: + found_rank = lora_params[k]["down"].shape[0] + break + + # 2. Pre-allocate Buffers (CPU) + # LoRA Buffers + stack_down = np.zeros((num_layers, found_rank, in_feat), dtype=np.float32) + stack_up = np.zeros((num_layers, out_feat, found_rank), dtype=np.float32) + stack_alpha = np.zeros((num_layers, 1, 1), dtype=np.float32) + + # Diff Buffers + # Initialize as None, allocate only if found to save memory + stack_w_diff = None + stack_b_diff = None + + has_lora = False + has_diff = False + + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + w = lora_params[lora_key] + + # --- Fill LoRA --- + if "down" in w: + d, u = np.array(w["down"]), np.array(w["up"]) + alpha = float(w.get("alpha", d.shape[0])) + rank = d.shape[0] + + if is_conv_kxk: + # For LoCON kxk, compute delta and merge into stack_w_diff + rank, in_c, *k_dims = d.shape + out_c = u.shape[0] + delta_pt = (u.reshape(out_c, rank) @ d.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + if delta_pt.ndim == 5: delta_fx = delta_pt.transpose((2,3,4,1,0)) + else: delta_fx = delta_pt.transpose((2,3,1,0)) + + lora_delta = delta_fx * (scale * alpha / rank) + if stack_w_diff is None: stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + stack_w_diff[i] += lora_delta.reshape(stack_w_diff[i].shape).astype(stack_w_diff.dtype) + has_diff = True # Mark as having diff because we merged LoRA into w_diff + else: + # For Linear or 1x1 Conv, prepare for JIT + if d.ndim > 2: d = np.squeeze(d) + if u.ndim > 2: u = np.squeeze(u) + stack_down[i] = d + stack_up[i] = u + stack_alpha[i] = alpha + has_lora = True + + # --- Fill Weight Diff --- + if "diff" in w: + if stack_w_diff is None: + stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + wd = np.array(w["diff"]) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. + if is_conv: + if wd.ndim == 5: + wd = wd.transpose((2,3,4,1,0)) + elif wd.ndim == 4: + wd = wd.transpose((2,3,1,0)) + elif is_linear and wd.ndim == 2: + wd = wd.transpose((1,0)) + + stack_w_diff[i] += wd.reshape(stack_w_diff[i].shape) + has_diff = True + + # --- Fill Bias Diff --- + if "diff_b" in w: + if stack_b_diff is None: + # Bias shape: Linear (L, Out), Conv (L, Out) usually + stack_b_diff = np.zeros((num_layers, out_feat), dtype=np.float32) + bd = np.array(w["diff_b"]) + stack_b_diff[i] = bd.flatten() + has_diff = True + + if has_lora or has_diff: + bias_val = module.bias.value if module.bias is not None else None + + # Call JIT + new_k, new_b = _compute_and_add_scanned_jit( + module.kernel.value, + stack_down if has_lora else None, + stack_up if has_lora else None, + stack_alpha if has_lora else None, + scale, + stack_w_diff, + stack_b_diff, + bias_val + ) + + module.kernel.value = new_k + if new_b is not None: + module.bias.value = new_b + + assigned_count += 1 + + max_logging.log(f"Merged weights into {assigned_count} scanned layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}")