-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsageAttention_extensions.py
More file actions
executable file
·86 lines (76 loc) · 3.28 KB
/
sageAttention_extensions.py
File metadata and controls
executable file
·86 lines (76 loc) · 3.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
##########################################################################################################################
# SageAttention extension: swaps the UNet's attention processors for SageAttention while denoising.
# From: https://arxiv.org/abs/2410.02367
# Title: SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration
##########################################################################################################################
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.invocations.fields import (
InputField,
LatentsField,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
IPAdapterConditioningInfo,
IPAdapterData,
Range,
SDXLConditioningInfo,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.app.invocations.fields import (
ConditioningField,
Input,
)
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation
import torch
from .extension_classes import GuidanceField, base_guidance_extension, GuidanceDataOutput
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.util.logging import info, warning, error
import random
import einops
from diffusers import UNet2DConditionModel
from typing import Type, Any, Dict, Iterator, List, Optional, Tuple, Union
from .sage_attention import SageAttention
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
from contextlib import contextmanager
@base_guidance_extension("SageAttention")
class SageAttention_Guidance(ExtensionBase):
    """Guidance extension that runs the UNet with SageAttention attention processors.

    While this extension's ``patch_unet`` context is active, every attention
    processor in the UNet is replaced by a fresh ``SageAttention`` instance; the
    original processors are put back when the context exits.
    """

    def __init__(
        self,
        context: InvocationContext,
    ):
        """Create the extension.

        Args:
            context: The invocation context. It is accepted but not used by this
                extension (presumably passed by the guidance-extension factory;
                verify against ``base_guidance_extension``).
        """
        super().__init__()

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        """Temporarily replace every attention processor in `unet` with SageAttention.

        Args:
            unet: The UNet whose attention processors are swapped.
            original_weights: Part of the ``ExtensionBase.patch_unet`` interface.
                It is unused here because no weights are modified, only the
                processor objects.

        Yields:
            None. The original processors are restored on exit, including when
            the body of the ``with`` block raises.
        """
        # Snapshot the current {layer_name: processor} mapping so it can be
        # restored in `finally`.
        unet_orig_processors = unet.attn_processors
        # One SageAttention processor per attention layer, keyed exactly like the
        # originals so set_attn_processor can match them up.
        unet_replacement_processors = {key: SageAttention() for key in unet_orig_processors}
        try:
            unet.set_attn_processor(unet_replacement_processors)
            yield None
        finally:
            unet.set_attn_processor(unet_orig_processors)
@invocation(
    "SageAttention_extInvocation",
    title="SageAttention [Extension]",
    tags=["SageAttention", "attention", "gottagofast"],
    category="extensions",
    version="1.0.0",
)
class SageAttention_ExtensionInvocation(BaseInvocation):
    """Enables the SageAttention extension, which swaps the UNet's attention processors for SageAttention during denoising."""

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> GuidanceDataOutput:
        """Emit guidance data that selects the "SageAttention" guidance extension.

        Args:
            context: The invocation context (unused; no inputs are read).

        Returns:
            A ``GuidanceDataOutput`` naming the "SageAttention" extension, with
            empty extension keyword arguments because it takes no options.
        """
        return GuidanceDataOutput(
            guidance_data_output=GuidanceField(
                guidance_name="SageAttention",
                extension_kwargs={},
            )
        )