-
Notifications
You must be signed in to change notification settings - Fork 7k
Expand file tree
/
Copy pathattention.py
More file actions
468 lines (370 loc) · 18 KB
/
attention.py
File metadata and controls
468 lines (370 loc) · 18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import logging
import pytest
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry, attention_backend
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils import is_kernels_available, is_torch_version
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_attention, torch_device
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Module-level backend parameter sets for AttentionBackendTesterMixin
# ---------------------------------------------------------------------------
# Capability flags, evaluated once when this module is imported; the skip marks below read them.
_CUDA_AVAILABLE = torch.cuda.is_available()
_KERNELS_AVAILABLE = is_kernels_available()

# One `pytest.param` per backend. The native backend carries no skip marks; every other
# backend is skipped at collection time when one of its requirements is missing.
_PARAM_NATIVE = pytest.param(AttentionBackendName.NATIVE, id="native")

_PARAM_NATIVE_CUDNN = pytest.param(
    AttentionBackendName._NATIVE_CUDNN,
    marks=pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _native_cudnn backend."),
    id="native_cudnn",
)

_PARAM_FLASH_HUB = pytest.param(
    AttentionBackendName.FLASH_HUB,
    marks=[
        pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for flash_hub backend."),
        pytest.mark.skipif(
            not _KERNELS_AVAILABLE,
            reason="`kernels` package is required for flash_hub backend. Install with `pip install kernels`.",
        ),
    ],
    id="flash_hub",
)

_PARAM_FLASH_3_HUB = pytest.param(
    AttentionBackendName._FLASH_3_HUB,
    marks=[
        pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _flash_3_hub backend."),
        pytest.mark.skipif(
            not _KERNELS_AVAILABLE,
            reason="`kernels` package is required for _flash_3_hub backend. Install with `pip install kernels`.",
        ),
    ],
    id="flash_3_hub",
)

# Every backend exercised by the parametrized tests in AttentionBackendTesterMixin.
_ALL_BACKEND_PARAMS = [
    _PARAM_NATIVE,
    _PARAM_NATIVE_CUDNN,
    _PARAM_FLASH_HUB,
    _PARAM_FLASH_3_HUB,
]
# Backends that only accept bf16/fp16 inputs: the model and its floating-point inputs
# have to be cast before such a backend is exercised.
_BF16_REQUIRED_BACKENDS = {
    AttentionBackendName.FLASH_HUB,
    AttentionBackendName._FLASH_3_HUB,
    AttentionBackendName._NATIVE_CUDNN,
}

# Backends whose operations are non-deterministic; they cannot run while
# torch.use_deterministic_algorithms(True) is active (e.g. after enable_full_determinism()).
_NON_DETERMINISTIC_BACKENDS = {
    AttentionBackendName._NATIVE_CUDNN,
}
def _maybe_cast_to_bf16(backend, model, inputs_dict):
    """Return ``(model, inputs_dict)``, cast to bfloat16 if ``backend`` requires it.

    Backends outside ``_BF16_REQUIRED_BACKENDS`` get both arguments back untouched.
    Otherwise the model is moved to ``torch.bfloat16`` and a new inputs dict is built in
    which every floating-point tensor is cast; all other values are passed through as-is.
    """
    if backend in _BF16_REQUIRED_BACKENDS:
        model = model.to(dtype=torch.bfloat16)

        cast_inputs = {}
        for name, value in inputs_dict.items():
            # Integer/bool tensors and non-tensor values must keep their original type.
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                value = value.to(dtype=torch.bfloat16)
            cast_inputs[name] = value
        inputs_dict = cast_inputs

    return model, inputs_dict
def _skip_if_backend_requires_nondeterminism(backend):
    """Skip the running test when deterministic mode rules out ``backend``.

    The check happens at test execution time on purpose: enable_full_determinism() is
    typically called at module level in test files *after* the module-level
    pytest.param() objects in this file have been evaluated, so a collection-time
    skipif condition cannot observe it.
    """
    if backend not in _NON_DETERMINISTIC_BACKENDS:
        return
    if not torch.are_deterministic_algorithms_enabled():
        return

    pytest.skip(
        f"Backend '{backend.value}' performs non-deterministic operations and cannot run "
        f"while `torch.use_deterministic_algorithms(True)` is active."
    )
@is_attention
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.

    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get)
    - QKV projection fusion/unfusion

    Expected from config mixin:
    - model_class: The model class to test

    Expected methods from config mixin:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: attention
        Use `pytest -m "not attention"` to skip these tests
    """

    def setup_method(self):
        """Run garbage collection and `backend_empty_cache` for the test device before each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        """Run garbage collection and `backend_empty_cache` for the test device after each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    @torch.no_grad()
    def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
        """Fusing and then unfusing QKV projections must leave the model output unchanged.

        Skipped when the model has no `fuse_qkv_projections` method. When fusion produced no
        fused projection on any attention module, there is nothing further to compare.
        """
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        output_before_fusion = model(**inputs_dict, return_dict=False)[0]

        model.fuse_qkv_projections()

        # Look for the first attention module that actually received a fused projection.
        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
                    has_fused_projections = True
                    assert module.fused_projections, "fused_projections flag should be True"
                    break

        if has_fused_projections:
            output_after_fusion = model(**inputs_dict, return_dict=False)[0]

            assert_tensors_close(
                output_before_fusion,
                output_after_fusion,
                atol=atol,
                rtol=rtol,
                msg="Output should not change after fusing projections",
            )

            model.unfuse_qkv_projections()

            # Unfusing must remove the fused layers and clear the flag on every attention module.
            for module in model.modules():
                if isinstance(module, AttentionModuleMixin):
                    assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                    assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                    assert not module.fused_projections, "fused_projections flag should be False"

            output_after_unfusion = model(**inputs_dict, return_dict=False)[0]

            assert_tensors_close(
                output_before_fusion,
                output_after_unfusion,
                atol=atol,
                rtol=rtol,
                msg="Output should match original after unfusing projections",
            )

    def test_get_set_processor(self):
        """Every attention module must expose its processor and accept a replacement."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # Check if model has attention processors
        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        # Test getting processors
        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        # Test that all processors can be retrieved via get_processor
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Test setting a new processor
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

    def test_attention_processor_dict(self):
        """`set_attn_processor` must accept a dict keyed like `attn_processors` and apply every entry."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict of new processors
        new_processors = {key: AttnProcessor() for key in current_processors}

        # Set processors using dict
        model.set_attn_processor(new_processors)

        # Verify all processors were set
        updated_processors = model.attn_processors
        for key in current_processors:
            # Exact type match on purpose: the processor set above is a plain AttnProcessor.
            assert type(updated_processors[key]) is AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        """`set_attn_processor` must raise ValueError for a dict with the wrong number of processors."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # The single-entry dict built below only has a mismatched count when the model owns at
        # least two processors. With exactly one the counts agree and no error is expected;
        # with none there is no first key to pick.
        if len(current_processors) < 2:
            pytest.skip("Model needs at least two attention processors to test a processor count mismatch.")

        # Create a dict with wrong number of processors
        first_key = next(iter(current_processors))
        wrong_processors = {first_key: AttnProcessor()}

        # Verify error is raised
        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
@is_attention
class AttentionBackendTesterMixin:
    """
    Mixin that exercises attention backends on the model supplied by the host test class.

    What is covered:
        1. A backend can be selected both through the `attention_backend` context manager
           and through the model's `set_attention_backend()` method.
        2. Backend outputs stay numerically close to the native SDPA output.
        3. Backends work together with (regional) compilation.
        4. The previously active backend is restored afterwards.

    The set of backends under test is `_ALL_BACKEND_PARAMS`.

    The host test class must provide:
        - model_class: the model class to instantiate.
        - get_init_dict(): returns the kwargs dict used to construct the model.
        - get_dummy_inputs(): returns the inputs dict for the model's forward pass.

    Pytest mark: attention
        Use `pytest -m "not attention"` to skip these tests.
    """

    # -----------------------------------------------------------------------
    # Tolerances. A host class may override these to loosen or tighten checks.
    # -----------------------------------------------------------------------

    # Used by test_output_close_to_native: alternate backends (flash, cuDNN) may accumulate
    # small numerical errors relative to the reference PyTorch SDPA kernel.
    backend_vs_native_atol: float = 1e-2
    backend_vs_native_rtol: float = 1e-2

    # Used by test_compile: regional compilation is expected to introduce the same kind of
    # numerical error as the eager backend path, hence the same loose tolerance.
    compile_vs_native_atol: float = 1e-2
    compile_vs_native_rtol: float = 1e-2

    def setup_method(self):
        """Run garbage collection and `backend_empty_cache` for the test device before each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        """Run garbage collection and `backend_empty_cache` for the test device after each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_set_attention_backend_matches_context_manager(self, backend):
        """Selecting a backend via set_attention_backend() or via the context manager must give identical outputs."""
        _skip_if_backend_requires_nondeterminism(backend)

        init_kwargs = self.get_init_dict()
        dummy_inputs = self.get_dummy_inputs()
        model = self.model_class(**init_kwargs)
        model.to(torch_device)
        model.eval()
        model, dummy_inputs = _maybe_cast_to_bf16(backend, model, dummy_inputs)

        # Reference run: backend chosen through the context manager.
        with attention_backend(backend):
            ctx_output = model(**dummy_inputs, return_dict=False)[0]

        # Remember the registry state so it can be put back after the model-level switch.
        previous_backend = _AttentionBackendRegistry.get_active_backend()[0]

        try:
            model.set_attention_backend(backend.value)
        except Exception as err:
            logger.warning("Skipping test for backend '%s': %s", backend.value, err)
            pytest.skip(str(err))
        else:
            try:
                set_output = model(**dummy_inputs, return_dict=False)[0]
            finally:
                model.reset_attention_backend()
                _AttentionBackendRegistry.set_active_backend(previous_backend)

        mismatch_msg = (
            f"Output from model.set_attention_backend('{backend.value}') should be identical "
            f"to the output from `with attention_backend('{backend.value}'):`."
        )
        assert_tensors_close(set_output, ctx_output, atol=0, rtol=0, msg=mismatch_msg)

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_output_close_to_native(self, backend):
        """Each backend's model output must be numerically close to the native SDPA reference."""
        _skip_if_backend_requires_nondeterminism(backend)

        init_kwargs = self.get_init_dict()
        dummy_inputs = self.get_dummy_inputs()
        model = self.model_class(**init_kwargs)
        model.to(torch_device)
        model.eval()
        model, dummy_inputs = _maybe_cast_to_bf16(backend, model, dummy_inputs)

        # Reference run with the native backend.
        with attention_backend(AttentionBackendName.NATIVE):
            native_output = model(**dummy_inputs, return_dict=False)[0]

        previous_backend = _AttentionBackendRegistry.get_active_backend()[0]

        try:
            model.set_attention_backend(backend.value)
        except Exception as err:
            logger.warning("Skipping test for backend '%s': %s", backend.value, err)
            pytest.skip(str(err))
        else:
            try:
                backend_output = model(**dummy_inputs, return_dict=False)[0]
            finally:
                model.reset_attention_backend()
                _AttentionBackendRegistry.set_active_backend(previous_backend)

        assert_tensors_close(
            backend_output,
            native_output,
            atol=self.backend_vs_native_atol,
            rtol=self.backend_vs_native_rtol,
            msg=f"Output from {backend} should be numerically close to native SDPA.",
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_context_manager_switches_and_restores_backend(self, backend):
        """The context manager must activate the requested backend and restore the prior one on exit."""
        backend_before = _AttentionBackendRegistry.get_active_backend()[0]

        with attention_backend(backend):
            backend_inside = _AttentionBackendRegistry.get_active_backend()[0]
            assert backend_inside == backend, (
                f"Backend should be {backend} inside the context manager, got {backend_inside}."
            )

        backend_after = _AttentionBackendRegistry.get_active_backend()[0]
        assert backend_after == backend_before, (
            f"Backend should be restored to {backend_before} after exiting the context manager, "
            f"got {backend_after}."
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_compile(self, backend):
        """
        Check that a backend works under `torch.compile`: full-graph compilation, no
        recompilation on a repeated forward pass, and output close to eager native SDPA.

        Regional compilation (`model.compile_repeated_blocks()`) is used instead of
        `model.compile` to keep the test fast.
        """
        _skip_if_backend_requires_nondeterminism(backend)

        repeated_blocks = getattr(self.model_class, "_repeated_blocks", None)
        if repeated_blocks is None:
            pytest.skip("Skipping tests as regional compilation is not supported.")

        if backend == AttentionBackendName.NATIVE:
            if not is_torch_version(">=", "2.9.0"):
                pytest.xfail(
                    "test_compile with the native backend requires torch >= 2.9.0 for stable "
                    "fullgraph compilation with error_on_recompile=True."
                )

        init_kwargs = self.get_init_dict()
        dummy_inputs = self.get_dummy_inputs()
        model = self.model_class(**init_kwargs)
        model.to(torch_device)
        model.eval()
        model, dummy_inputs = _maybe_cast_to_bf16(backend, model, dummy_inputs)

        # Eager reference run with the native backend.
        with torch.no_grad():
            with attention_backend(AttentionBackendName.NATIVE):
                native_output = model(**dummy_inputs, return_dict=False)[0]

        previous_backend = _AttentionBackendRegistry.get_active_backend()[0]

        try:
            model.set_attention_backend(backend.value)
        except Exception as err:
            logger.warning("Skipping test for backend '%s': %s", backend.value, err)
            pytest.skip(str(err))

        try:
            model.compile_repeated_blocks(fullgraph=True)
            torch.compiler.reset()
            with torch._inductor.utils.fresh_inductor_cache():
                with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
                    compile_output = model(**dummy_inputs, return_dict=False)[0]
                    # Second forward pass with the same inputs: with error_on_recompile=True
                    # this is meant to surface any recompilation.
                    model(**dummy_inputs, return_dict=False)
        finally:
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(previous_backend)

        assert_tensors_close(
            compile_output,
            native_output,
            atol=self.compile_vs_native_atol,
            rtol=self.compile_vs_native_rtol,
            msg=f"Compiled output with backend '{backend.value}' should be numerically close to eager native SDPA.",
        )