diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 834bce942e43..066f93a2a126 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1865,9 +1865,12 @@ def forward( out = out.to(torch.float32) lse = lse.to(torch.float32) - # Refer to: - # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 - if is_torch_version("<", "2.9.0"): + # lse must be 4-D to broadcast with out (B, S, H, D). + # Some backends (e.g. cuDNN on torch>=2.9) already return a + # trailing-1 dim; others (e.g. flash-hub / native-flash) always + # return 3-D lse, so we add the dim here when needed. + # See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if lse.ndim == 3: lse = lse.unsqueeze(-1) if prev_out is not None: out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) @@ -2154,10 +2157,11 @@ def _templated_unified_attention( scatter_idx, ) if return_lse: - # lse is of shape (B, S, H_LOCAL, 1) - # Refer to: - # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 - if is_torch_version("<", "2.9.0"): + # lse from TemplatedRingAttention is 3-D (B, S, H_LOCAL) after its + # final squeeze(-1). SeqAllToAllDim requires a 4-D input, so we add + # the trailing dim here and remove it after the collective. 
+ # See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if lse.ndim == 3: lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) lse = lse.squeeze(-1) diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index ea076b3ec774..d012114da85e 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -13,7 +13,7 @@ from .ip_adapter import IPAdapterTesterMixin from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin -from .parallelism import ContextParallelTesterMixin +from .parallelism import ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin from .quantization import ( BitsAndBytesCompileTesterMixin, BitsAndBytesConfigMixin, @@ -45,6 +45,7 @@ "BitsAndBytesTesterMixin", "CacheTesterMixin", "ContextParallelTesterMixin", + "ContextParallelAttentionBackendsTesterMixin", "CPUOffloadTesterMixin", "FasterCacheConfigMixin", "FasterCacheTesterMixin", diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index e05b36799e66..b2e7b92d8231 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -23,10 +23,8 @@ from diffusers.models._modeling_parallel import ContextParallelConfig -from ...testing_utils import ( - is_context_parallel, - require_torch_multi_accelerator, -) +from ...testing_utils import is_context_parallel, is_kernels_available, require_torch_multi_accelerator +from .utils import _maybe_cast_to_bf16 def _find_free_port(): @@ -38,7 +36,9 @@ def _find_free_port(): return port -def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict): +def _context_parallel_worker( + rank, world_size, master_port, 
model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None +): """Worker function for context parallel testing.""" try: # Set up distributed environment @@ -59,6 +59,9 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di model.to(device) model.eval() + # Cast as needed. + model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict) + # Move inputs to device inputs_on_device = {} for key, value in inputs_dict.items(): @@ -67,6 +70,13 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di else: inputs_on_device[key] = value + # Enable attention backend + if attention_backend: + try: + model.set_attention_backend(attention_backend) + except Exception as e: + pytest.skip(f"Skipping test because of exception: {e}.") + # Enable context parallelism cp_config = ContextParallelConfig(**cp_dict) model.enable_parallelism(config=cp_config) @@ -126,3 +136,76 @@ def test_context_parallel_inference(self, cp_type): assert return_dict.get("status") == "success", ( f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) + + +@is_context_parallel +@require_torch_multi_accelerator +class ContextParallelAttentionBackendsTesterMixin: + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"]) + @pytest.mark.parametrize( + "attention_backend", + [ + "native", + pytest.param( + "flash_hub", + marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."), + ), + pytest.param( + "_flash_3_hub", + marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."), + ), + ], + ) + @pytest.mark.parametrize("ulysses_anything", [True, False]) + @torch.no_grad() + def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything): + if not torch.distributed.is_available(): + pytest.skip("torch.distributed is not available.") + + if getattr(self.model_class, "_cp_plan", 
None) is None: + pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + + if cp_type == "ring_degree": + if attention_backend == "native": + pytest.skip("Skipping test because ring attention isn't supported with the native attention backend.") + + if ulysses_anything and "ulysses" not in cp_type: + pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.") + + world_size = 2 + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + # Move all tensors to CPU for multiprocessing + inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} + cp_dict = {cp_type: world_size} + if ulysses_anything: + cp_dict.update({"ulysses_anything": ulysses_anything}) + + # Find a free port for distributed communication + master_port = _find_free_port() + + # Use multiprocessing manager for cross-process communication + manager = mp.Manager() + return_dict = manager.dict() + + # Spawn worker processes + mp.spawn( + _context_parallel_worker, + args=( + world_size, + master_port, + self.model_class, + init_dict, + cp_dict, + inputs_dict, + return_dict, + attention_backend, + ), + nprocs=world_size, + join=True, + ) + + assert return_dict.get("status") == "success", ( + f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" + ) diff --git a/tests/models/testing_utils/utils.py b/tests/models/testing_utils/utils.py new file mode 100644 index 000000000000..7bec37db2496 --- /dev/null +++ b/tests/models/testing_utils/utils.py @@ -0,0 +1,22 @@ +import torch + +from diffusers.models.attention_dispatch import AttentionBackendName + + +_BF16_REQUIRED_BACKENDS = { + AttentionBackendName._NATIVE_CUDNN, + AttentionBackendName.FLASH_HUB, + AttentionBackendName._FLASH_3_HUB, +} + + +def _maybe_cast_to_bf16(backend, model, inputs_dict): + """Cast model and floating-point inputs to bfloat16 when the backend requires it.""" + if not backend or backend not in
_BF16_REQUIRED_BACKENDS: + return model, inputs_dict + model = model.to(dtype=torch.bfloat16) + inputs_dict = { + k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in inputs_dict.items() + } + return model, inputs_dict diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 2d39dadfcad1..c8b68f36307a 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -29,6 +29,7 @@ BaseModelTesterConfig, BitsAndBytesCompileTesterMixin, BitsAndBytesTesterMixin, + ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, @@ -228,6 +229,12 @@ class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextPar """Context Parallel inference tests for Flux Transformer""" +class TestFluxTransformerContextParallelAttnBackends( + FluxTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin +): + """Context Parallel inference x attention backends tests for Flux Transformer""" + + class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin): """IP Adapter tests for Flux Transformer.""" diff --git a/utils/generate_model_tests.py b/utils/generate_model_tests.py index 11acd2175e21..d27ced15afba 100644 --- a/utils/generate_model_tests.py +++ b/utils/generate_model_tests.py @@ -72,6 +72,7 @@ # Other testers ("SingleFileTesterMixin", "single_file"), ("IPAdapterTesterMixin", "ip_adapter"), + ("ContextParallelAttentionBackendsTesterMixin", "cp_attn"), ] @@ -229,7 +230,14 @@ def determine_testers(model_info: dict, include_optional: list[str], imports: se for tester, flag in OPTIONAL_TESTERS: if flag in include_optional: - if tester not in testers: + if tester == "ContextParallelAttentionBackendsTesterMixin": + if ( + "cp_attn" in include_optional + and "_cp_plan" in 
model_info["attributes"] + and model_info["attributes"]["_cp_plan"] is not None + ): + testers.append(tester) + elif tester not in testers: testers.append(tester) return testers @@ -530,6 +538,7 @@ def main(): "faster_cache", "single_file", "ip_adapter", + "cp_attn", "all", ], help="Optional testers to include",