
Commit 7ca2716

Copilot and lstein authored
[Feature] CPU execution for text encoders with automatic device management (#47)
* Initial plan
* Fix TypeScript linting errors for cpu_only field
  Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
* chore(frontend) eslint
* chore(frontend): prettier
* Add missing popover translation for cpuOnly feature
  Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
* Improve cpuOnly popover help text based on code review
  Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
* Simplify CPU-only UI and add encoder support with device mismatch fix
  Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
* Limit CPU-only execution to text encoders and ensure conditioning is moved to CPU for storage
  Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
* Fix CPU-only execution to properly check model-specific compute device

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent bb2797f commit 7ca2716

15 files changed

Lines changed: 7148 additions & 3827 deletions


invokeai/app/invocations/cogview4_text_encoder.py

Lines changed: 3 additions & 1 deletion
@@ -37,6 +37,8 @@ class CogView4TextEncoderInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CogView4ConditioningOutput:
         glm_embeds = self._glm_encode(context, max_seq_len=COGVIEW4_GLM_MAX_SEQ_LEN)
+        # Move embeddings to CPU for storage to save VRAM
+        glm_embeds = glm_embeds.detach().to("cpu")
         conditioning_data = ConditioningFieldData(conditionings=[CogView4ConditioningInfo(glm_embeds=glm_embeds)])
         conditioning_name = context.conditioning.save(conditioning_data)
         return CogView4ConditioningOutput.build(conditioning_name)
@@ -85,7 +87,7 @@ def _glm_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
         )
         text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
         prompt_embeds = glm_text_encoder(
-            text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
+            text_input_ids.to(glm_text_encoder.device), output_hidden_states=True
         ).hidden_states[-2]

         assert isinstance(prompt_embeds, torch.Tensor)
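
This hunk establishes the pattern repeated in the flux, sd3, and z_image encoders below: detach the embeddings and park them on CPU immediately after encoding, so saved conditioning no longer pins VRAM, then move them back to a compute device only when the denoiser consumes them. A minimal sketch of the two halves; the function names here are illustrative, not the repo's API:

import torch

def store_conditioning(embeds: torch.Tensor) -> torch.Tensor:
    # detach() drops any autograd graph; moving to CPU releases the VRAM copy
    # once nothing else references it.
    return embeds.detach().to("cpu")

def fetch_conditioning(stored: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Move back to the denoiser's device only at the moment of use.
    return stored.to(device)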

invokeai/app/invocations/compel.py

Lines changed: 2 additions & 2 deletions
@@ -103,7 +103,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
                 textual_inversion_manager=ti_manager,
                 dtype_for_device_getter=TorchDevice.choose_torch_dtype,
                 truncate_long_prompts=False,
-                device=TorchDevice.choose_torch_device(),
+                device=text_encoder.device,  # Use the device the model is actually on
                 split_long_text_mode=SplitLongTextMode.SENTENCES,
             )

@@ -212,7 +212,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
                 truncate_long_prompts=False,  # TODO:
                 returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
                 requires_pooled=get_pooled,
-                device=TorchDevice.choose_torch_device(),
+                device=text_encoder.device,  # Use the device the model is actually on
                 split_long_text_mode=SplitLongTextMode.SENTENCES,
             )
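
Both hunks replace the globally preferred device with the encoder's actual placement. This is the device-mismatch fix from the commit message: on a CUDA machine with a cpu_only text encoder, TorchDevice.choose_torch_device() returns cuda while the weights sit on CPU, and the encoder call fails with a device-mismatch error. A hedged sketch of the principle against a generic nn.Module (the Compel API itself is not reproduced here):

import torch

def encode(encoder: torch.nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
    # Derive the target device from the model itself, not from a global
    # preference; this holds even when the encoder was pinned to CPU.
    # (HF transformers models also expose .device directly, as the diff uses.)
    model_device = next(encoder.parameters()).device
    return encoder(input_ids.to(model_device))

emb = encode(torch.nn.Embedding(100, 16), torch.randint(0, 100, (1, 8)))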

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 6 additions & 0 deletions
@@ -58,6 +58,12 @@ def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
         # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
         t5_embeddings = self._t5_encode(context)
         clip_embeddings = self._clip_encode(context)
+
+        # Move embeddings to CPU for storage to save VRAM
+        # They will be moved to the appropriate device when used by the denoiser
+        t5_embeddings = t5_embeddings.detach().to("cpu")
+        clip_embeddings = clip_embeddings.detach().to("cpu")
+
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )

invokeai/app/invocations/sd3_text_encoder.py

Lines changed: 11 additions & 2 deletions
@@ -69,6 +69,15 @@ def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
         if self.t5_encoder is not None:
             t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)

+        # Move all embeddings to CPU for storage to save VRAM
+        # They will be moved to the appropriate device when used by the denoiser
+        clip_l_embeddings = clip_l_embeddings.detach().to("cpu")
+        clip_l_pooled_embeddings = clip_l_pooled_embeddings.detach().to("cpu")
+        clip_g_embeddings = clip_g_embeddings.detach().to("cpu")
+        clip_g_pooled_embeddings = clip_g_pooled_embeddings.detach().to("cpu")
+        if t5_embeddings is not None:
+            t5_embeddings = t5_embeddings.detach().to("cpu")
+
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SD3ConditioningInfo(
@@ -117,7 +126,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
                 f" {max_seq_len} tokens: {removed_text}"
             )

-        prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0]
+        prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]

         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
@@ -180,7 +189,7 @@ def _clip_encode(
                 f" {tokenizer_max_length} tokens: {removed_text}"
             )
         prompt_embeds = clip_text_encoder(
-            input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
+            input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
         )
         pooled_prompt_embeds = prompt_embeds[0]
         prompt_embeds = prompt_embeds.hidden_states[-2]

invokeai/app/invocations/z_image_text_encoder.py

Lines changed: 5 additions & 1 deletion
@@ -57,6 +57,8 @@ class ZImageTextEncoderInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
         prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
+        # Move embeddings to CPU for storage to save VRAM
+        prompt_embeds = prompt_embeds.detach().to("cpu")
         conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
         conditioning_name = context.conditioning.save(conditioning_data)
         return ZImageConditioningOutput(
@@ -69,7 +71,6 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
         Based on the ZImagePipeline._encode_prompt method from diffusers.
         """
         prompt = self.prompt
-        device = TorchDevice.choose_torch_device()

         text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
         tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
@@ -78,6 +79,9 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
             (_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
             (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

+            # Use the device that the text_encoder is actually on
+            device = text_encoder.device
+
             # Apply LoRA models to the text encoder
             lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
             exit_stack.enter_context(
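
Beyond the storage change, note the ordering fix in the second and third hunks: device is now read inside the model_on_device() scope, after the cache has placed the encoder, rather than up front from the global chooser, which could bake in a device the model never occupies. A toy illustration; this model_on_device is a hypothetical stand-in, not the repo's implementation:

import contextlib
import torch

@contextlib.contextmanager
def model_on_device(model: torch.nn.Module, device: torch.device):
    # Hypothetical stand-in: place the model for the duration of the scope.
    model.to(device)
    try:
        yield model
    finally:
        model.to("cpu")

encoder = torch.nn.Linear(8, 8)
with model_on_device(encoder, torch.device("cpu")) as enc:
    device = next(enc.parameters()).device  # read AFTER placement, not before
    out = enc(torch.randn(2, 8).to(device))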

invokeai/backend/model_manager/configs/qwen3_encoder.py

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,7 @@ class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -87,6 +88,7 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
+    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -130,6 +132,7 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:

invokeai/backend/model_manager/configs/t5_encoder.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@ class T5Encoder_T5Encoder_Config(Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
     format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
+    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -50,6 +51,7 @@ class T5Encoder_BnBLLMint8_Config(Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
     format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
+    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
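
In both this file and qwen3_encoder.py above, the new field is a tri-state bool | None rather than a plain bool: None means "unset" and lets the loader fall through to the cache default, so existing configs stay valid. A self-contained Pydantic sketch of the same shape (the class name is illustrative):

from pydantic import BaseModel, Field

class EncoderConfigSketch(BaseModel):
    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

assert EncoderConfigSketch().cpu_only is None                # unset: defer to the cache default
assert EncoderConfigSketch(cpu_only=True).cpu_only is True   # pin execution to CPU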

invokeai/backend/model_manager/load/load_default.py

Lines changed: 16 additions & 6 deletions
@@ -68,16 +68,26 @@ def _get_model_path(self, config: AnyModelConfig) -> Path:
         model_base = self._app_config.models_path
         return (model_base / config.path).resolve()

-    def _get_execution_device(self, config: AnyModelConfig) -> Optional[torch.device]:
+    def _get_execution_device(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> Optional[torch.device]:
         """Determine the execution device for a model based on its configuration.
-
+
+        CPU-only execution is only applied to text encoder submodels to save VRAM while keeping
+        the denoiser on GPU for performance. Conditioning tensors are moved to GPU after encoding.
+
         Returns:
             torch.device("cpu") if the model should run on CPU only, None otherwise (use cache default).
         """
-        # Check if this is a main model with default settings that specify cpu_only
+        # Check if this is a text encoder submodel of a main model with cpu_only setting
         if hasattr(config, "default_settings") and config.default_settings is not None:
             if hasattr(config.default_settings, "cpu_only") and config.default_settings.cpu_only is True:
-                return torch.device("cpu")
+                # Only apply CPU execution to text encoder submodels
+                if submodel_type in [SubModelType.TextEncoder, SubModelType.TextEncoder2, SubModelType.TextEncoder3]:
+                    return torch.device("cpu")
+
+        # Check if this is a standalone text encoder config with cpu_only field (T5Encoder, Qwen3Encoder, etc.)
+        if hasattr(config, "cpu_only") and config.cpu_only is True:
+            return torch.device("cpu")
+
         return None

     def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
@@ -91,8 +101,8 @@ def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
         self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
         loaded_model = self._load_model(config, submodel_type)

-        # Determine execution device from model config
-        execution_device = self._get_execution_device(config)
+        # Determine execution device from model config, considering submodel type
+        execution_device = self._get_execution_device(config, submodel_type)

         self._ram_cache.put(
             get_model_cache_key(config.key, submodel_type),
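
The selection logic reduces to a small decision table: a main model's default_settings.cpu_only pins only its text encoder submodels to CPU (the denoiser stays on the cache default), while a standalone encoder config's own cpu_only pins itself; everything else returns None and defers to the cache. A self-contained sketch of that logic, with the enum and config shapes simplified:

from enum import Enum
from typing import Optional

import torch

class SubModelType(str, Enum):
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    TextEncoder3 = "text_encoder_3"
    UNet = "unet"

TEXT_ENCODERS = {SubModelType.TextEncoder, SubModelType.TextEncoder2, SubModelType.TextEncoder3}

def execution_device(config, submodel_type: Optional[SubModelType] = None) -> Optional[torch.device]:
    settings = getattr(config, "default_settings", None)
    if settings is not None and getattr(settings, "cpu_only", None) is True:
        # Main-model setting: only the text encoder submodels are pinned to CPU.
        if submodel_type in TEXT_ENCODERS:
            return torch.device("cpu")
    # Standalone encoder configs (T5Encoder, Qwen3Encoder, ...) carry their own flag.
    if getattr(config, "cpu_only", None) is True:
        return torch.device("cpu")
    return None  # defer to the model cache's default device

class _StandaloneEncoderCfg:
    cpu_only = True

assert execution_device(_StandaloneEncoderCfg()) == torch.device("cpu")
assert execution_device(object(), SubModelType.UNet) is None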

invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py

Lines changed: 5 additions & 0 deletions
@@ -60,6 +60,11 @@ def is_in_vram(self) -> bool:
         """Return true if the model is currently in VRAM."""
         return self._is_in_vram

+    @property
+    def compute_device(self) -> torch.device:
+        """Return the compute device for this model."""
+        return self._compute_device
+
     def full_load_to_vram(self) -> int:
         """Load all weights into VRAM (if supported by the model).
         Returns:

invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py

Lines changed: 5 additions & 0 deletions
@@ -136,6 +136,11 @@ def cur_vram_bytes(self) -> int:
         )
         return self._cur_vram_bytes

+    @property
+    def compute_device(self) -> torch.device:
+        """Return the compute device for this model."""
+        return self._compute_device
+
     def full_load_to_vram(self) -> int:
         """Load all weights into VRAM."""
         return self.partial_load_to_vram(self.total_bytes())
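
Both cached-model classes above gain the same read-only property, which lets callers ask where a specific cached model will execute instead of consulting the global device chooser (the "model-specific compute device" check named in the commit message). The shape of the pattern, on an illustrative class:

import torch

class CachedModelSketch:
    """Illustrative only: a cache record that remembers its model's compute device."""

    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
        self._model = model
        self._compute_device = compute_device

    @property
    def compute_device(self) -> torch.device:
        """Return the compute device for this model."""
        return self._compute_device

record = CachedModelSketch(torch.nn.Linear(4, 4), torch.device("cpu"))
assert record.compute_device == torch.device("cpu")  # regardless of the global default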
