forked from invoke-ai/InvokeAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_cache.py
More file actions
870 lines (736 loc) · 41 KB
/
model_cache.py
File metadata and controls
870 lines (736 loc) · 41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
import gc
import logging
import threading
import time
from dataclasses import dataclass
from functools import wraps
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Protocol
import psutil
import torch
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
apply_custom_layers_to_model,
)
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter
# Number of bytes in one GB (binary gigabyte, 2**30).
GB = 1 << 30
# Number of bytes in one MB (binary megabyte, 2**20).
MB = 1 << 20
# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
    """Build the cache key for a model, optionally qualified by a submodel type.

    :param model_key: Key identifying the model.
    :param submodel_type: Optional submodel type. When it is truthy, its ``value`` is appended to the key.
    :return: ``"<model_key>:<submodel_type.value>"`` when a submodel type is given, otherwise ``model_key`` unchanged.
    """
    if not submodel_type:
        return model_key
    return f"{model_key}:{submodel_type.value}"
def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate an instance method so that it runs while holding the instance's ``self._lock``."""

    @wraps(method)
    def locked_call(self, *args, **kwargs):
        # The lock is used as a context manager, so it is released on normal return and on exception alike.
        with self._lock:
            outcome = method(self, *args, **kwargs)
        return outcome

    return locked_call
def record_activity(method: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate an instance method so that ``self._record_activity()`` runs after it returns successfully.

    If the wrapped method raises, the exception propagates and no activity is recorded.

    Note: This decorator should be applied to methods that already hold self._lock.
    """

    @wraps(method)
    def tracked_call(self, *args, **kwargs):
        outcome = method(self, *args, **kwargs)
        # Only reached when `method` did not raise.
        self._record_activity()
        return outcome

    return tracked_call
@dataclass
class CacheEntrySnapshot:
    """Size summary of a single cache entry, as passed to cache event callbacks."""

    # Key under which the entry is stored in the cache.
    cache_key: str
    # Total size of the cached model, in bytes.
    total_bytes: int
    # Number of the model's bytes currently held in VRAM.
    current_vram_bytes: int
class CacheMissCallback(Protocol):
    """Call signature of the callbacks that ``ModelCache.get()`` invokes when the requested key is not cached.

    Callbacks are called with keyword arguments:
    - model_key: The cache key that was requested.
    - cache_snapshot: A per-entry size summary of the cache, keyed by cache key.
    """

    def __call__(self, model_key: str, cache_snapshot: dict[str, CacheEntrySnapshot]) -> None: ...
class CacheHitCallback(Protocol):
    """Call signature of the callbacks that ``ModelCache.get()`` invokes when the requested key is found.

    Callbacks are called with keyword arguments:
    - model_key: The cache key that was requested.
    - cache_snapshot: A per-entry size summary of the cache, keyed by cache key.
    """

    def __call__(self, model_key: str, cache_snapshot: dict[str, CacheEntrySnapshot]) -> None: ...
class CacheModelsClearedCallback(Protocol):
    """Call signature of the callbacks registered with ``ModelCache.on_cache_models_cleared()``.

    NOTE(review): the call site is not part of this section of the file; the parameter meanings below are read off
    the parameter names and should be confirmed against the caller.
    - models_cleared: presumably the number of models removed from the cache.
    - bytes_requested: presumably the number of bytes the cache was asked to make available.
    - bytes_freed: presumably the number of bytes actually released.
    - cache_snapshot: A per-entry size summary of the cache, keyed by cache key.
    """

    def __call__(
        self, models_cleared: int, bytes_requested: int, bytes_freed: int, cache_snapshot: dict[str, CacheEntrySnapshot]
    ) -> None: ...
class ModelCache:
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
The model cache is based on the following assumptions:
- storage_device_mem_size > execution_device_mem_size
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
the execution_device.
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
configuration.
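Callers add a model with put(), retrieve its CacheRecord with get(), and bracket every use of the model with lock()
and unlock(). lock() moves the model into VRAM (unless the model is configured for CPU execution), and unlock()
releases the lock again.
Example usage:
```
cache = ModelCache(
    execution_device_working_mem_gb=3.0,
    enable_partial_loading=True,
    keep_ram_copy_of_weights=True,
)
cache.put(key, model)
cache_entry = cache.get(key)
cache.lock(cache_entry, working_mem_bytes=None)
try:
    do_something_on_gpu(cache_entry.cached_model.model)
finally:
    cache.unlock(cache_entry)
```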
"""
def __init__(
    self,
    execution_device_working_mem_gb: float,
    enable_partial_loading: bool,
    keep_ram_copy_of_weights: bool,
    max_ram_cache_size_gb: float | None = None,
    max_vram_cache_size_gb: float | None = None,
    execution_device: torch.device | str = "cuda",
    storage_device: torch.device | str = "cpu",
    log_memory_usage: bool = False,
    logger: Optional[Logger] = None,
    keep_alive_minutes: float = 0,
):
    """Create an empty model cache.

    :param execution_device_working_mem_gb: The amount of working memory to keep on the GPU (in GB) i.e. non-model
        VRAM.
    :param enable_partial_loading: Whether to enable partial loading of models.
    :param keep_ram_copy_of_weights: Passed as ``keep_ram_copy`` to the cached-model wrapper created for every model
        added to the cache.
    :param max_ram_cache_size_gb: The maximum amount of CPU RAM to use for model caching in GB. This parameter is
        kept to maintain compatibility with previous versions of the model cache, but should be deprecated in the
        future. If set, this parameter overrides the default cache size logic.
    :param max_vram_cache_size_gb: The amount of VRAM to use for model caching in GB. This parameter is kept to
        maintain compatibility with previous versions of the model cache, but should be deprecated in the future.
        If set, this parameter overrides the default cache size logic.
    :param execution_device: Torch device to load active model into [torch.device('cuda')]
    :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
    :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
        operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
        snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
        behaviour.
    :param logger: InvokeAILogger to use (otherwise creates one)
    :param keep_alive_minutes: How long to keep models in cache after last use (in minutes). 0 means keep indefinitely.
    """
    # Devices.
    self._execution_device: torch.device = torch.device(execution_device)
    self._storage_device: torch.device = torch.device(storage_device)

    # Loading behaviour and memory limits.
    self._enable_partial_loading = enable_partial_loading
    self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
    self._execution_device_working_mem_gb = execution_device_working_mem_gb
    self._max_ram_cache_size_gb = max_ram_cache_size_gb
    self._max_vram_cache_size_gb = max_vram_cache_size_gb

    # Logging and statistics.
    base_logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
    self._logger = PrefixedLoggerAdapter(base_logger, "MODEL CACHE")
    self._log_memory_usage = log_memory_usage
    self._stats: Optional[CacheStats] = None

    # Cache contents. _cache_stack holds the cache keys, with the most recently used key at the right end.
    self._cached_models: Dict[str, CacheRecord] = {}
    self._cache_stack: List[str] = []

    # The size calculation reads the devices, limits and logger assigned above, so it must come after them.
    self._ram_cache_size_bytes = self._calc_ram_available_to_model_cache()

    # A lock applied to all public method calls to make the ModelCache thread-safe.
    # At the time of writing, the ModelCache should only be accessed from two threads:
    # - The graph execution thread
    # - Requests to empty the cache from a separate thread
    self._lock = threading.RLock()

    # Event subscribers.
    self._on_cache_hit_callbacks: set[CacheHitCallback] = set()
    self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
    self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()

    # Keep-alive timeout support
    self._keep_alive_minutes = keep_alive_minutes
    self._last_activity_time: Optional[float] = None
    self._timeout_timer: Optional[threading.Timer] = None
    self._shutdown_event = threading.Event()
def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
self._on_cache_hit_callbacks.add(cb)
def unsubscribe() -> None:
self._on_cache_hit_callbacks.discard(cb)
return unsubscribe
def on_cache_miss(self, cb: CacheMissCallback) -> Callable[[], None]:
self._on_cache_miss_callbacks.add(cb)
def unsubscribe() -> None:
self._on_cache_miss_callbacks.discard(cb)
return unsubscribe
def on_cache_models_cleared(self, cb: CacheModelsClearedCallback) -> Callable[[], None]:
self._on_cache_models_cleared_callbacks.add(cb)
def unsubscribe() -> None:
self._on_cache_models_cleared_callbacks.discard(cb)
return unsubscribe
@property
@synchronized
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
return self._stats
@stats.setter
@synchronized
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collecting cache statistics."""
self._stats = stats
# Populate the cache size in the stats object when it's set
if self._stats is not None:
self._stats.cache_size = self._ram_cache_size_bytes
def _record_activity(self) -> None:
"""Record model activity and reset the timeout timer if configured.
Note: This method should only be called when self._lock is already held.
"""
if self._keep_alive_minutes <= 0:
return
self._last_activity_time = time.time()
# Cancel any existing timer
if self._timeout_timer is not None:
self._timeout_timer.cancel()
# Start a new timer
timeout_seconds = self._keep_alive_minutes * 60
self._timeout_timer = threading.Timer(timeout_seconds, self._on_timeout)
# Set as daemon so it doesn't prevent application shutdown
self._timeout_timer.daemon = True
self._timeout_timer.start()
self._logger.debug(f"Model cache activity recorded. Timeout set to {self._keep_alive_minutes} minutes.")
@synchronized
@record_activity
def _on_timeout(self) -> None:
"""Called when the keep-alive timeout expires. Clears the model cache."""
if self._shutdown_event.is_set():
return
# Double-check if there has been activity since the timer was set
# This handles the race condition where activity occurred just before the timer fired
if self._last_activity_time is not None and self._keep_alive_minutes > 0:
elapsed_minutes = (time.time() - self._last_activity_time) / 60
if elapsed_minutes < self._keep_alive_minutes:
# Activity occurred, don't clear cache
self._logger.debug(
f"Model cache timeout fired but activity detected {elapsed_minutes:.2f} minutes ago. "
f"Skipping cache clear."
)
return
# Check if there are any unlocked models that can be cleared
unlocked_models = [key for key, entry in self._cached_models.items() if not entry.is_locked]
if len(unlocked_models) > 0:
self._logger.info(
f"Model cache keep-alive timeout of {self._keep_alive_minutes} minutes expired. "
f"Clearing {len(unlocked_models)} unlocked model(s) from cache."
)
# Clear the cache by requesting a very large amount of space.
# This is the same logic used by the "Clear Model Cache" button.
# Using 1000 GB ensures all unlocked models are removed.
self._make_room_internal(1000 * GB)
elif len(self._cached_models) > 0:
# All models are locked, don't log at info level
self._logger.debug(
f"Model cache timeout fired but all {len(self._cached_models)} model(s) are locked. "
f"Skipping cache clear."
)
else:
self._logger.debug("Model cache timeout fired but cache is already empty.")
@synchronized
def shutdown(self) -> None:
"""Shutdown the model cache, cancelling any pending timers."""
self._shutdown_event.set()
if self._timeout_timer is not None:
self._timeout_timer.cancel()
self._timeout_timer = None
@synchronized
@record_activity
def put(self, key: str, model: AnyModel, execution_device: Optional[torch.device] = None) -> None:
    """Add a model to the cache. Does nothing if the key is already present.

    Room is made in the cache for the model's size before it is wrapped and stored, and the new entry becomes the
    most recently used one.

    Args:
        key: Cache key for the model
        model: The model to cache
        execution_device: Optional device to use for this specific model. If None, uses the cache's default
            execution_device. Use torch.device("cpu") to force a model to run on CPU.
    """
    if key in self._cached_models:
        self._logger.debug(
            f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary."
        )
        return

    size = calc_model_size_by_data(self._logger, model)
    self._make_room_internal(size)

    is_torch_module = isinstance(model, torch.nn.Module)

    # Inject custom modules into the model.
    if is_torch_module:
        apply_custom_layers_to_model(model)

    # Use the provided execution device, or fall back to the cache's default
    effective_execution_device = self._execution_device if execution_device is None else execution_device

    # Partial loading only makes sense on CUDA.
    # - When running on CPU, there is no 'loading' to do.
    # - When running on MPS, memory is shared with the CPU, so the default OS memory management already handles this
    #   well.
    running_with_cuda = effective_execution_device.type == "cuda"
    use_partial_loading = is_torch_module and running_with_cuda and self._enable_partial_loading

    # Wrap model.
    keep_ram_copy = self._keep_ram_copy_of_weights
    if use_partial_loading:
        wrapped_model = CachedModelWithPartialLoad(model, effective_execution_device, keep_ram_copy=keep_ram_copy)
    else:
        wrapped_model = CachedModelOnlyFullLoad(model, effective_execution_device, size, keep_ram_copy=keep_ram_copy)

    self._cached_models[key] = CacheRecord(key=key, cached_model=wrapped_model)
    self._cache_stack.append(key)
    self._logger.debug(
        f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size / MB:.2f}MB)"
    )
@synchronized
def _get_cache_snapshot(self) -> dict[str, CacheEntrySnapshot]:
overview: dict[str, CacheEntrySnapshot] = {}
for cache_key, cache_entry in self._cached_models.items():
total_bytes = cache_entry.cached_model.total_bytes()
current_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
overview[cache_key] = CacheEntrySnapshot(
cache_key=cache_key,
total_bytes=total_bytes,
current_vram_bytes=current_vram_bytes,
)
return overview
@synchronized
@record_activity
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
"""Retrieve a model from the cache.
:param key: Model key
:param stats_name: A human-readable id for the model for the purposes of stats reporting.
Raises IndexError if the model is not in the cache.
"""
if key in self._cached_models:
if self.stats:
self.stats.hits += 1
else:
for cb in self._on_cache_miss_callbacks:
cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
if self.stats:
self.stats.misses += 1
self._logger.debug(f"Cache miss: {key}")
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
)
# This moves the entry to the top (right end) of the stack.
self._cache_stack = [k for k in self._cache_stack if k != key]
self._cache_stack.append(key)
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
for cb in self._on_cache_hit_callbacks:
cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
return cache_entry
@synchronized
@record_activity
def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
"""Lock a model for use and move it into VRAM."""
if cache_entry.key not in self._cached_models:
self._logger.info(
f"Locking model cache entry {cache_entry.key} "
f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from "
"the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code "
"(See https://github.com/invoke-ai/InvokeAI/issues/7513)."
)
# cache_entry = self._cached_models[key]
cache_entry.lock()
self._logger.debug(
f"Locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)
# Check if the model's specific compute_device is CPU, not just the cache's default execution_device
model_compute_device = cache_entry.cached_model.compute_device
if model_compute_device.type == "cpu":
# Models configured for CPU execution don't need to be loaded into VRAM
self._logger.debug(
f"Model {cache_entry.key} is configured for CPU execution, skipping VRAM load"
)
return
try:
self._load_locked_model(cache_entry, working_mem_bytes)
self._logger.debug(
f"Finished locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)
except torch.cuda.OutOfMemoryError:
self._logger.warning("Insufficient GPU memory to load model. Aborting")
cache_entry.unlock()
raise
except Exception:
cache_entry.unlock()
raise
self._log_cache_state()
@synchronized
@record_activity
def unlock(self, cache_entry: CacheRecord) -> None:
"""Unlock a model."""
if cache_entry.key not in self._cached_models:
self._logger.info(
f"Unlocking model cache entry {cache_entry.key} "
f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from "
"the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code "
"(See https://github.com/invoke-ai/InvokeAI/issues/7513)."
)
# cache_entry = self._cached_models[key]
cache_entry.unlock()
self._logger.debug(
f"Unlocked model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)
def _load_locked_model(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int] = None) -> None:
    """Helper function for self.lock(). Loads a locked model into VRAM.

    Steps:
    1. Offload other, unlocked models until the part of this model that is not yet in VRAM would fit (or nothing
       more can be offloaded).
    2. If the available VRAM is still negative, unload part of this model itself as a last resort.
    3. Load as much of this model as fits into the available VRAM (plus a 1 MB tolerance).

    :param cache_entry: The (already locked) cache record to load.
    :param working_mem_bytes: Working memory to reserve, passed through to the available-VRAM calculation.
    """
    start_time = time.time()

    # Calculate model_vram_needed, the amount of additional VRAM that will be used if we fully load the model into
    # VRAM.
    model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
    model_total_bytes = cache_entry.cached_model.total_bytes()
    model_vram_needed = model_total_bytes - model_cur_vram_bytes

    vram_available = self._get_vram_available(working_mem_bytes)
    self._logger.debug(
        f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
    )

    # Make room for the model in VRAM.
    # 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
    # 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
    #    possible.
    vram_bytes_freed = self._offload_unlocked_models(model_vram_needed, working_mem_bytes)
    self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed / MB):.2f}MB")

    # Check the updated vram_available after offloading.
    vram_available = self._get_vram_available(working_mem_bytes)
    self._logger.debug(
        f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
    )

    if vram_available < 0:
        # There is insufficient VRAM available. As a last resort, try to unload the model being locked from VRAM,
        # as it may still be loaded from a previous use.
        # -vram_available is the size of the shortfall, i.e. the number of bytes to free.
        vram_bytes_freed_from_own_model = self._move_model_to_ram(cache_entry, -vram_available)
        vram_available = self._get_vram_available(working_mem_bytes)
        self._logger.debug(
            f"Unloaded {vram_bytes_freed_from_own_model / MB:.2f}MB from the model being locked ({cache_entry.key})."
        )

    # Move as much of the model as possible into VRAM.
    # For testing, only allow 10% of the model to be loaded into VRAM.
    # vram_available = int(model_vram_needed * 0.1)
    # We add 1 MB to the available VRAM to account for small errors in memory tracking (e.g. off-by-one). A fully
    # loaded model is much faster than a 95% loaded model.
    model_bytes_loaded = self._move_model_to_vram(cache_entry, vram_available + MB)

    # Re-read the model's VRAM usage and the available VRAM so that the log messages below describe the state
    # after loading.
    model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
    vram_available = self._get_vram_available(working_mem_bytes)
    # Guard against division by zero for a model that reports a size of 0 bytes.
    loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0

    # Use the model's actual compute_device for logging, not the cache's default
    model_device = cache_entry.cached_model.compute_device

    self._logger.info(
        f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto "
        f"{model_device.type} device in {(time.time() - start_time):.2f}s. "
        f"Total model size: {model_total_bytes / MB:.2f}MB, "
        f"VRAM: {model_cur_vram_bytes / MB:.2f}MB ({loaded_percent:.1%})"
    )
    self._logger.debug(
        f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded / MB):.2f}MB, "
    )
    self._logger.debug(
        f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
    )
def _move_model_to_vram(self, cache_entry: CacheRecord, vram_available: int) -> int:
    """Load a cached model into VRAM and return the value reported by the model's load method.

    Models that support partial loading are loaded up to ``vram_available`` bytes; all other models are loaded in
    full regardless of ``vram_available``. If anything goes wrong, the entry is deleted from the cache and the
    exception is re-raised.
    """
    try:
        cached_model = cache_entry.cached_model
        if isinstance(cached_model, CachedModelWithPartialLoad):
            return cached_model.partial_load_to_vram(vram_available)
        if isinstance(cached_model, CachedModelOnlyFullLoad):  # type: ignore
            # Partial load is not supported, so we have not choice but to try and fit it all into VRAM.
            return cached_model.full_load_to_vram()
        raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
    except torch.cuda.OutOfMemoryError:
        self._logger.warning("Insufficient GPU memory to load model. Aborting")
        # If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely.
        self._delete_cache_entry(cache_entry)
        raise
    except Exception:
        # If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely.
        self._delete_cache_entry(cache_entry)
        raise
def _move_model_to_ram(self, cache_entry: CacheRecord, vram_bytes_to_free: int) -> int:
    """Unload a cached model from VRAM and return the value reported by the model's unload method.

    Models that support partial loading are asked to free ``vram_bytes_to_free`` bytes (keeping their required
    weights in VRAM while the entry is locked); all other models are unloaded in full. If anything goes wrong, the
    entry is deleted from the cache and the exception is re-raised.
    """
    try:
        cached_model = cache_entry.cached_model
        if isinstance(cached_model, CachedModelWithPartialLoad):
            return cached_model.partial_unload_from_vram(
                vram_bytes_to_free, keep_required_weights_in_vram=cache_entry.is_locked
            )
        if isinstance(cached_model, CachedModelOnlyFullLoad):  # type: ignore
            return cached_model.full_unload_from_vram()
        raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
    except Exception:
        # If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely.
        self._delete_cache_entry(cache_entry)
        raise
def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the amount of additional VRAM available for the cache to use (takes into account the working
memory).
"""
# If self._max_vram_cache_size_gb is set, then it overrides the default logic.
if self._max_vram_cache_size_gb is not None:
vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
return vram_total_available_to_cache - self._get_vram_in_use()
working_mem_bytes_default = int(self._execution_device_working_mem_gb * GB)
working_mem_bytes = max(working_mem_bytes or working_mem_bytes_default, working_mem_bytes_default)
if self._execution_device.type == "cuda":
# TODO(ryand): It is debatable whether we should use memory_reserved() or memory_allocated() here.
# memory_reserved() includes memory reserved by the torch CUDA memory allocator that may or may not be
# re-used for future allocations. For now, we use memory_allocated() to be conservative.
# vram_reserved = torch.cuda.memory_reserved(self._execution_device)
vram_allocated = torch.cuda.memory_allocated(self._execution_device)
vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
vram_available_to_process = vram_free + vram_allocated
elif self._execution_device.type == "mps":
vram_reserved = torch.mps.driver_allocated_memory()
# TODO(ryand): Is it accurate that MPS shares memory with the CPU?
vram_free = psutil.virtual_memory().available
vram_available_to_process = vram_free + vram_reserved
else:
raise ValueError(f"Unsupported execution device: {self._execution_device.type}")
vram_total_available_to_cache = vram_available_to_process - working_mem_bytes
vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
return vram_cur_available_to_cache
def _get_vram_in_use(self) -> int:
"""Get the amount of VRAM currently in use by the cache."""
if self._execution_device.type == "cuda":
return torch.cuda.memory_allocated()
elif self._execution_device.type == "mps":
return torch.mps.current_allocated_memory()
else:
raise ValueError(f"Unsupported execution device type: {self._execution_device.type}")
# Alternative definition of VRAM in use:
# return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
def _calc_ram_available_to_model_cache(self) -> int:
"""Calculate the amount of RAM available for the cache to use."""
# If self._max_ram_cache_size_gb is set, then it overrides the default logic.
if self._max_ram_cache_size_gb is not None:
self._logger.info(f"Using user-defined RAM cache size: {self._max_ram_cache_size_gb} GB.")
return int(self._max_ram_cache_size_gb * GB)
# Heuristics for dynamically calculating the RAM cache size, **in order of increasing priority**:
# 1. As an initial default, use 50% of the total RAM for InvokeAI.
# - Assume a 2GB baseline for InvokeAI's non-model RAM usage, and use the rest of the RAM for the model cache.
# 2. On a system with a lot of RAM, users probably don't want InvokeAI to eat up too much RAM.
# There are diminishing returns to storing more and more models. So, we apply an upper bound. (Keep in mind
# that most OSes have some amount of disk caching, which we still benefit from if there is excess memory,
# even if we drop models from the cache.)
# - On systems without a CUDA device, the upper bound is 32GB.
# - On systems with a CUDA device, the upper bound is 1x the amount of VRAM (less the working memory).
# 3. Absolute minimum of 4GB.
# NOTE(ryand): We explored dynamically adjusting the RAM cache size based on memory pressure (using psutil), but
# decided against it for now, for the following reasons:
# - It was surprisingly difficult to get memory metrics with consistent definitions across OSes. (If you go
# down this path again, don't underestimate the amount of complexity here and be sure to test rigorously on all
# OSes.)
# - Making the RAM cache size dynamic opens the door for performance regressions that are hard to diagnose and
# hard for users to understand. It is better for users to see that their RAM is maxed out, and then override
# the default value if desired.
# Lookup the total VRAM size for the CUDA execution device.
total_cuda_vram_bytes: int | None = None
if self._execution_device.type == "cuda":
_, total_cuda_vram_bytes = torch.cuda.mem_get_info(self._execution_device)
# Apply heuristic 1.
# ------------------
heuristics_applied = [1]
total_system_ram_bytes = psutil.virtual_memory().total
# Assumed baseline RAM used by InvokeAI for non-model stuff.
baseline_ram_used_by_invokeai = 2 * GB
ram_available_to_model_cache = int(total_system_ram_bytes * 0.5 - baseline_ram_used_by_invokeai)
# Apply heuristic 2.
# ------------------
max_ram_cache_size_bytes = 32 * GB
if total_cuda_vram_bytes is not None:
if self._max_vram_cache_size_gb is not None:
max_ram_cache_size_bytes = int(self._max_vram_cache_size_gb * GB)
else:
max_ram_cache_size_bytes = total_cuda_vram_bytes - int(self._execution_device_working_mem_gb * GB)
if ram_available_to_model_cache > max_ram_cache_size_bytes:
heuristics_applied.append(2)
ram_available_to_model_cache = max_ram_cache_size_bytes
# Apply heuristic 3.
# ------------------
if ram_available_to_model_cache < 4 * GB:
heuristics_applied.append(3)
ram_available_to_model_cache = 4 * GB
self._logger.info(
f"Calculated model RAM cache size: {ram_available_to_model_cache / MB:.2f} MB. Heuristics applied: {heuristics_applied}."
)
return ram_available_to_model_cache
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
def _get_ram_available(self) -> int:
"""Get the amount of RAM available for the cache to use."""
return self._ram_cache_size_bytes - self._get_ram_in_use()
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str:
"""Helper function for preparing a VRAM state log string."""
model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
return (
f"model_total={model_total_bytes / MB:.0f} MB, "
+ f"model_vram={model_cur_vram_bytes / MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
# + f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
+ f"vram_available={(vram_available / MB):.0f} MB, "
)
def _offload_unlocked_models(self, vram_bytes_required: int, working_mem_bytes: Optional[int] = None) -> int:
    """Move unlocked models off the execution device until enough VRAM is available.

    Cached models are visited from smallest to largest total size. Locked models are never offloaded. The loop ends as
    soon as `vram_bytes_required` bytes are available, or once every cached model has been considered.

    Args:
        vram_bytes_required: The amount of available VRAM, in bytes, that this call tries to reach.
        working_mem_bytes: Forwarded unchanged to `self._get_vram_available(...)` on every availability check.

    Returns:
        int: The number of bytes freed based on believed model sizes. The actual change in VRAM may be different.
    """
    self._logger.debug(
        f"Offloading unlocked models with goal of making room for {vram_bytes_required / MB:.2f}MB of VRAM."
    )

    # TODO(ryand): Give more thought to the offloading policy used here.
    candidates = list(self._cached_models.values())
    candidates.sort(key=lambda entry: entry.cached_model.total_bytes())

    total_freed = 0
    for entry in candidates:
        # The reported freed-byte counts are not fully trusted, so the shortfall is re-measured on every iteration
        # rather than derived from the running total.
        shortfall = vram_bytes_required - self._get_vram_available(working_mem_bytes)
        if shortfall <= 0:
            break

        # TODO(ryand): In the future, we may want to partially unload locked models, but this requires careful
        # handling of model patches (e.g. LoRA).
        if not entry.is_locked:
            freed = self._move_model_to_ram(entry, shortfall)
            if freed > 0:
                self._logger.debug(f"Unloaded {entry.key} from VRAM to free {(freed / MB):.0f} MB.")
            total_freed += freed

    TorchDevice.empty_cache()
    return total_freed
def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
    """Write a DEBUG-level summary of the cache's memory usage to the logger.

    The summary has one row for the storage device, one row for the execution device (omitted when that device is the
    CPU), the CUDA allocated-memory figure when CUDA is available, the model count, and optionally one row per model.

    Args:
        title: First line of the log message.
        include_entry_details: When True and the cache is non-empty, append a per-model breakdown.
    """
    # Some of the lookups below can take a non-negligible amount of time, so skip all of the work unless the message
    # would actually be emitted.
    if self._logger.getEffectiveLevel() > logging.DEBUG:
        return

    def ratio(part, whole):
        # Fraction of `whole` taken by `part`; 0 when `whole` is not positive, to avoid dividing by zero.
        return part / whole if whole > 0 else 0

    device_row = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"
    parts = [f"{title}\n"]

    # Storage device row. The reported limit is reconstructed as used + available.
    ram_used = self._get_ram_in_use()
    ram_free = self._get_ram_available()
    ram_limit = ram_used + ram_free
    parts.append(
        device_row.format(
            f"Storage Device ({self._storage_device.type})",
            ram_limit / MB,
            ram_used / MB,
            ratio(ram_used, ram_limit),
            ram_free / MB,
            ratio(ram_free, ram_limit),
        )
    )

    # Execution device row, only when it is not the CPU.
    if self._execution_device.type != "cpu":
        vram_used = self._get_vram_in_use()
        vram_free = self._get_vram_available(None)
        vram_limit = vram_used + vram_free
        parts.append(
            device_row.format(
                f"Compute Device ({self._execution_device.type})",
                vram_limit / MB,
                vram_used / MB,
                ratio(vram_used, vram_limit),
                vram_free / MB,
                ratio(vram_free, vram_limit),
            )
        )

    if torch.cuda.is_available():
        parts.append(" {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB))

    parts.append(" {:<30} {}\n".format("Total models:", len(self._cached_models)))

    if include_entry_details and len(self._cached_models) > 0:
        parts.append(" Models:\n")
        model_row = (
            " {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n"
        )
        for record in self._cached_models.values():
            size_bytes = record.cached_model.total_bytes()
            in_vram_bytes = record.cached_model.cur_vram_bytes()
            # Whatever is not in VRAM is counted as being in RAM.
            in_ram_bytes = size_bytes - in_vram_bytes
            parts.append(
                model_row.format(
                    f"{record.key} ({record.cached_model.model.__class__.__name__}):",
                    size_bytes / MB,
                    in_vram_bytes / MB,
                    ratio(in_vram_bytes, size_bytes),
                    in_ram_bytes / MB,
                    ratio(in_ram_bytes, size_bytes),
                    record.is_locked,
                )
            )

    self._logger.debug("".join(parts))
@synchronized
def make_room(self, bytes_needed: int) -> None:
    """Free enough space in the cache to hold a new model of `bytes_needed` bytes.

    This is the `@synchronized` entry point; the actual work is delegated to `_make_room_internal()`.

    Note: Freeing a model means dropping every reference that the cache itself holds to it. References held outside
    the cache are beyond its control, and a model that is still referenced elsewhere will not be garbage-collected.
    """
    self._make_room_internal(bytes_needed)
def _make_room_internal(self, bytes_needed: int) -> None:
    """Drop unlocked models from the RAM cache until `bytes_needed` bytes fit. Assumes the lock is already held.

    Entries are examined in `self._cache_stack` order, starting at index 0. Locked entries are skipped. When at least
    one model was dropped, the stats and the registered callbacks are updated, and a garbage collection is forced.
    """
    self._logger.debug(f"Making room for {bytes_needed / MB:.2f}MB of RAM.")
    self._log_cache_state(title="Before dropping models:")

    shortfall = max(0, bytes_needed - self._get_ram_available())
    freed = 0
    dropped = 0
    idx = 0
    while freed < shortfall and idx < len(self._cache_stack):
        key = self._cache_stack[idx]
        entry = self._cached_models[key]

        if entry.is_locked:
            # Locked entries stay in the cache; step past them.
            idx += 1
            continue

        freed += entry.cached_model.total_bytes()
        self._logger.debug(
            f"Dropping {key} from RAM cache to free {(entry.cached_model.total_bytes() / MB):.2f}MB."
        )
        # Deleting the entry rebuilds the stack without `key`, so `idx` already points at the next candidate.
        self._delete_cache_entry(entry)
        # Release the local reference as well, so that it does not keep the dropped entry alive.
        del entry
        dropped += 1

    if dropped > 0:
        if self.stats:
            self.stats.cleared = dropped
        for callback in self._on_cache_models_cleared_callbacks:
            callback(
                models_cleared=dropped,
                bytes_requested=bytes_needed,
                bytes_freed=freed,
                cache_snapshot=self._get_cache_snapshot(),
            )

        # `gc.collect()` is expensive even when it finds nothing, so it is only run after models were dropped:
        # - Having to drop models signals that we are close to the memory limit.
        # - Dropped models make it likely that there is a significant amount of garbage to collect.
        # Keep in mind that gc only handles reference cycles. Most objects are cleaned up immediately when their
        # reference count hits 0.
        gc.collect()

    TorchDevice.empty_cache()
    self._logger.debug(f"Dropped {dropped} models to free {freed / MB:.2f}MB of RAM.")
    self._log_cache_state(title="After dropping models:")
def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
    """Remove `cache_entry` from the cache's bookkeeping. A missing entry is silently ignored.

    The entry's key is filtered out of `self._cache_stack` and popped from `self._cached_models`.
    """
    doomed_key = cache_entry.key
    self._cache_stack = [stack_key for stack_key in self._cache_stack if stack_key != doomed_key]
    self._cached_models.pop(doomed_key, None)