
Commit b4c6d61

charliewwdev and claude committed
fix: pipeline memory management in shot chaining to prevent OOM
Three issues fixed:

1. current_pipeline held a stale reference to pipeline_t2v, preventing GC when switching T2V→I2V, so both models sat in RAM simultaneously → OOM
2. UnboundLocalError on pipeline_t2v after del in repeated iterations
3. VACE 14B was selected even on 24GB GPUs — now checks VRAM and uses 1.3B if <48GB

Tested: 16 shots + 2 cards generated on an RTX 4090 24GB in 14.8 min

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent cf0b184 commit b4c6d61

1 file changed: scripts/produce_trailer_v4.py (38 additions & 9 deletions)
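The first fix is the subtle one: CPython frees an object only when its reference count reaches zero, so del on one name releases nothing while an alias survives. A minimal sketch of the failure and the fix, independent of this repo's classes (BigModel is a hypothetical stand-in for a multi-GB pipeline):

    import gc

    class BigModel:
        """Stand-in for a multi-GB pipeline object."""
        def __del__(self):
            print("weights released")

    pipeline_t2v = BigModel()
    current_pipeline = pipeline_t2v   # alias: the object now has two referrers

    del pipeline_t2v                  # one referrer remains; nothing is freed
    gc.collect()                      # still alive: current_pipeline holds it

    del current_pipeline              # last referrer gone -> "weights released"
    gc.collect()

This is why the patch below clears current_pipeline at every site that frees pipeline_t2v, pipeline_i2v, or pipeline_vace: any surviving alias keeps the whole model resident.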
@@ -1601,21 +1601,24 @@ def _generate_shots_chained(
         try:
             if pipeline_vace is None:
                 print(" Loading VACE pipeline for continuation...")
-                # Free current pipeline first
-                if current_pipeline is not pipeline_t2v:
+                # Free ALL existing pipelines to avoid OOM
+                if current_pipeline is not None:
                     del current_pipeline
+                    current_pipeline = None
                 if pipeline_i2v is not None:
                     del pipeline_i2v
                     pipeline_i2v = None
                 gc.collect()
                 if device == "cuda":
                     torch.cuda.empty_cache()
+                    time.sleep(1)
 
-                # Load VACE model params
+                # Load VACE model params — use 14B only if >=48GB VRAM
                 vace_params = storyboard.get("model_params_vace", {})
-                vace_variant = "1.3B"  # VACE is available in 1.3B and 14B
+                vace_variant = "1.3B"
                 if device == "cuda":
-                    vace_variant = "14B"
+                    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+                    vace_variant = "14B" if gpu_mem_gb >= 48 else "1.3B"
 
                 pipeline_vace = VACEBackendClass.load(
                     model_variant=vace_variant,
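The VRAM gate added above can be read as a small standalone check. A sketch of the same logic (the pick_vace_variant name is illustrative; the 48GB threshold and the 1.3B/14B variants come from the diff):

    import torch

    def pick_vace_variant(min_gb_for_14b: float = 48.0) -> str:
        """Choose "14B" only when the visible GPU has enough VRAM."""
        if not torch.cuda.is_available():
            return "1.3B"
        total_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        return "14B" if total_gb >= min_gb_for_14b else "1.3B"

An RTX 4090 reports roughly 24GB, so it now falls through to the 1.3B variant instead of OOMing on the 14B weights.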
@@ -1656,14 +1659,22 @@ def _generate_shots_chained(
             # Ensure I2V pipeline is loaded
             if pipeline_i2v is None:
                 print(" Loading I2V pipeline...")
-                # Free T2V / VACE
+                # Free ALL existing pipelines to avoid OOM
                 if pipeline_vace is not None:
                     del pipeline_vace
                     pipeline_vace = None
-                del pipeline_t2v
+                # current_pipeline may hold a ref to pipeline_t2v — clear both
+                if current_pipeline is not None:
+                    del current_pipeline
+                    current_pipeline = None
+                try:
+                    del pipeline_t2v
+                except UnboundLocalError:
+                    pass
                 gc.collect()
                 if device == "cuda":
                     torch.cuda.empty_cache()
+                    time.sleep(1)  # give CUDA allocator time to reclaim
 
                 pipeline_i2v = BackendClass.load(
                     model_variant=model_variant,
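The try/except around del pipeline_t2v addresses the second issue from the commit message: once a local name has been deleted, a later del (or read) of it raises UnboundLocalError. A standalone illustration of the guard:

    def release(loaded: bool) -> None:
        if loaded:
            model = [0] * 1_000_000   # stand-in for pipeline weights
        try:
            del model                 # the name may already be unbound
        except UnboundLocalError:
            pass                      # nothing to free on this pass

    release(True)    # unbinds the name; the list becomes collectable
    release(False)   # del on an unbound local raises; the guard absorbs it

A sentinel would work too (rebind pipeline_t2v = None after freeing and test for None), but the guard keeps the diff minimal.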
@@ -1711,16 +1722,20 @@ def _generate_shots_chained(
             # Note: pipeline_t2v may have been freed during I2V/VACE loading.
             # In chained mode this can happen. We need to reload.
             if pipeline_i2v is not None or pipeline_vace is not None:
-                # Need to reload T2V
+                # Free ALL existing pipelines before reloading T2V
                 if pipeline_i2v is not None:
                     del pipeline_i2v
                     pipeline_i2v = None
                 if pipeline_vace is not None:
                     del pipeline_vace
                     pipeline_vace = None
+                if current_pipeline is not None:
+                    del current_pipeline
+                    current_pipeline = None
                 gc.collect()
                 if device == "cuda":
                     torch.cuda.empty_cache()
+                    time.sleep(1)
 
                 print(" Reloading T2V pipeline...")
                 pipeline_t2v_reload = BackendClass.load(
@@ -1735,7 +1750,21 @@ def _generate_shots_chained(
                 )
                 current_pipeline = pipeline_t2v_reload
             else:
-                current_pipeline = pipeline_t2v
+                try:
+                    current_pipeline = pipeline_t2v
+                except UnboundLocalError:
+                    # pipeline_t2v was freed earlier — reload
+                    print(" Reloading T2V pipeline...")
+                    current_pipeline = BackendClass.load(
+                        model_variant=model_variant,
+                        mode="t2v",
+                        torch_dtype=torch_dtype,
+                        device=device,
+                        quantization=quantization_override,
+                        offload_strategy=offload,
+                        enable_vae_slicing=True,
+                        enable_vae_tiling=True,
+                    )
 
             gen_kwargs = dict(
                 prompt=prompt,
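All three sites now repeat the same teardown tail: drop references, gc.collect(), torch.cuda.empty_cache(), then a short sleep. That tail could be factored into one helper; a sketch under the assumption that callers unbind their own names first (flush_cuda_memory is a hypothetical name, not part of this script):

    import gc
    import time

    import torch

    def flush_cuda_memory(pause_s: float = 1.0) -> None:
        """Run the GC and hand cached CUDA blocks back to the driver.

        Callers must drop their own pipeline references first (del or
        rebind to None); a helper cannot unbind names in the caller's frame.
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            time.sleep(pause_s)  # let the allocator settle before the next load

    # Usage mirroring the pattern in this commit:
    #     pipeline_i2v = None        # drop the last reference
    #     flush_cuda_memory()
    #     pipeline_i2v = BackendClass.load(...)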
