Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions embeddings_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ def forward(
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
)
indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid)
# "exp" RoPE uses POS_EMBEDDING_EXP_VALUES (sized for inner_dim=3840).
# LTX-2.3 connector has inner_dim=4096 → use "exp_2" (standard formula, scales with inner_dim).
_rope_spacing = "exp" if self.inner_dim == 3840 else "exp_2"
freqs_cis = self.precompute_freqs_cis(indices_grid, _rope_spacing)

# 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks):
Expand Down Expand Up @@ -376,7 +379,7 @@ def load_embeddings_connector(
split_rope=rope_type == LTXRopeType.SPLIT,
double_precision_rope=frequencies_precision == LTXFrequenciesPrecision.FLOAT64,
)
connector.load_state_dict(sd_connector)
connector.load_state_dict(sd_connector, strict=False)
return connector


Expand Down
44 changes: 33 additions & 11 deletions gemma_encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from glob import glob
from pathlib import Path
from typing import List, Optional, Tuple
Expand All @@ -10,6 +11,7 @@
import torch
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoImageProcessor,
AutoTokenizer,
Gemma3Config,
Expand Down Expand Up @@ -51,7 +53,7 @@ def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
class LTXVGemmaTokenizer:
def __init__(self, tokenizer_path: str, max_length: int = 1024):
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, local_files_only=True, model_max_length=max_length
tokenizer_path, local_files_only=True, ignore_mismatched_sizes=True, model_max_length=max_length
)
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
self.tokenizer.padding_side = "left"
Expand Down Expand Up @@ -156,7 +158,7 @@ def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)

def memory_required(self, input_shape):
# Return a conservative estimate in bytesed(input_shape)
# Return a conservative estimate in bytes
return self._model_memory_required


Expand All @@ -168,14 +170,17 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
return _LTXVGemmaTokenizer


def ltxv_gemma_clip(encoder_path, ltxv_path, processor=None, dtype=None):
def ltxv_gemma_clip(encoder_path, ltxv_path, processor=None, dtype=None, gguf_file=None):
class _LTXVGemmaTextEncoderModel(LTXVGemmaTextEncoderModel):
def __init__(self, device="cpu", dtype=dtype, model_options={}):
dtype = torch.bfloat16 # TODO: make this configurable

gemma_model = Gemma3ForConditionalGeneration.from_pretrained(
_kw = {"local_files_only": True}
if gguf_file: _kw["gguf_file"] = gguf_file
gemma_model = AutoModelForCausalLM.from_pretrained(
encoder_path,
local_files_only=True,
dtype=dtype,
**_kw,
torch_dtype=dtype,
)

Expand Down Expand Up @@ -224,7 +229,8 @@ def INPUT_TYPES(s):
{"tooltip": "The name of the text encoder model to load."},
),
"ltxv_path": (
folder_paths.get_filename_list("checkpoints"),
[""] + folder_paths.get_filename_list("checkpoints"),
{"default": ""},
{"tooltip": "The name of the ltxv model to load."},
),
"max_length": (
Expand All @@ -245,15 +251,25 @@ def load_model(self, gemma_path: str, ltxv_path: str, max_length: int):
path = Path(folder_paths.get_full_path("text_encoders", gemma_path))
model_root = path.parents[1]
tokenizer_path = Path(find_matching_dir(model_root, "tokenizer.model"))
gemma_model_path = Path(find_matching_dir(model_root, "model*.safetensors"))
gguf_filename = None
try:
gemma_model_path = Path(find_matching_dir(model_root, "model*.safetensors"))
except Exception:
gguf_dir = Path(find_matching_dir(model_root, "*.gguf"))
gguf_files = sorted(gguf_dir.glob("*.gguf"))
if not gguf_files:
raise ValueError(f"No GGUF found in {gguf_dir}")
gemma_model_path = gguf_dir
gguf_filename = gguf_files[0].name
logger.info(f"Using GGUF: {gguf_dir / gguf_filename}")
processor_path = Path(find_matching_dir(model_root, "preprocessor_config.json"))
tokenizer_class = ltxv_gemma_tokenizer(tokenizer_path, max_length=max_length)

processor = None
try:
image_processor = AutoImageProcessor.from_pretrained(
str(processor_path),
local_files_only=True,
local_files_only=True, ignore_mismatched_sizes=True,
)
processor = Gemma3Processor(
image_processor=image_processor,
Expand All @@ -264,11 +280,17 @@ def load_model(self, gemma_path: str, ltxv_path: str, max_length: int):
logger.warning(f"Could not load processor from {model_root}: {e}")

clip_dtype = torch.bfloat16
ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path)
if ltxv_path:
ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path)
else:
_unet_dirs = folder_paths.get_folder_paths("unet")
_ggufs = [g for d in _unet_dirs for g in glob(os.path.join(d, "*.gguf"))]
ltxv_full_path = _ggufs[0] if _ggufs else str(gemma_model_path / "proj_linear.safetensors")
logger.info(f"GGUF connector path: {ltxv_full_path}")
clip_target = comfy.supported_models_base.ClipTarget(
tokenizer=tokenizer_class,
clip=ltxv_gemma_clip(
gemma_model_path, ltxv_full_path, processor=processor, dtype=clip_dtype
gemma_model_path, ltxv_full_path, processor=processor, dtype=clip_dtype, gguf_file=gguf_filename
),
)

Expand Down Expand Up @@ -662,7 +684,7 @@ def transformers_gemma3_from_encoder(encoder):
tokenizer_class = ltxv_gemma_tokenizer(jsons_path, max_length=1024)
image_processor = AutoImageProcessor.from_pretrained(
str(jsons_path),
local_files_only=True,
local_files_only=True, ignore_mismatched_sizes=True,
)
processor = Gemma3Processor(
image_processor=image_processor,
Expand Down
78 changes: 75 additions & 3 deletions text_embeddings_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
3. Embeddings Processor (Video / AV) -- wraps Embeddings1DConnector(s)
"""

import glob
import importlib.util
import json
import logging
import math
import os
from pathlib import Path

import folder_paths
import torch
from comfy.utils import load_torch_file
from einops import rearrange
Expand All @@ -21,6 +26,8 @@
load_video_embeddings_connector,
)

logger = logging.getLogger(__name__)

_PREFIX_BASE = "model.diffusion_model."
_PREFIX_TEXT_PROJ = "text_embedding_projection."

Expand Down Expand Up @@ -313,6 +320,38 @@ def _load_single_aggregate_embed_from_file(path, dtype):
# ---------------------------------------------------------------------------



def _load_gguf_connector_sd(gguf_path):
    """Load connector + projection tensors from a GGUF file.

    Used by LTXVGemmaCLIPModelLoader. Scans the GGUF at ``gguf_path`` and
    returns a state dict mapping ``model.diffusion_model.<tensor name>`` to
    torch tensors, restricted to the connector/projection weights (video and
    audio embeddings connectors, text embedding projection, audio adaln).

    F32/F16 tensors are decoded via numpy; BF16 is reinterpreted through
    torch (numpy has no bfloat16 dtype). Quantized tensor types are
    dequantized to float32 via ComfyUI-GGUF's ``dequant.py`` when that
    extension is installed next to this one; otherwise such tensors are
    skipped with an explicit warning rather than dropped silently.

    Args:
        gguf_path: path-like pointing at the GGUF file to read.

    Returns:
        dict[str, torch.Tensor] keyed with the ``model.diffusion_model.``
        prefix expected by the connector loaders.
    """
    import gguf as _gguf
    import numpy as np

    reader = _gguf.GGUFReader(str(gguf_path))
    sd = {}
    prefixes = (
        'video_embeddings_connector',
        'audio_embeddings_connector',
        'text_embedding_projection',
        'audio_adaln_single',
    )

    # ComfyUI-GGUF dequantizer module, loaded at most once for the whole
    # file instead of once per quantized tensor. None = not yet attempted;
    # False = attempted and unavailable.
    _dequant_cache = [None]

    def _get_dequant():
        # Locate and exec ComfyUI-GGUF's dequant.py; cache the result.
        if _dequant_cache[0] is None:
            dq = os.path.join(
                os.path.dirname(__file__), '..', 'ComfyUI-GGUF', 'dequant.py'
            )
            if os.path.exists(dq):
                spec = importlib.util.spec_from_file_location("dequant", dq)
                mod = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(mod)
                _dequant_cache[0] = mod
            else:
                _dequant_cache[0] = False
        return _dequant_cache[0] or None

    _np_dtypes = {'F32': np.float32, 'F16': np.float16}
    for t in reader.tensors:
        if not any(t.name.startswith(p) for p in prefixes):
            continue
        key = f"model.diffusion_model.{t.name}"
        try:
            ttype = t.tensor_type.name
            # GGUF records dimensions innermost-first; torch expects the
            # outermost dimension first, hence the reversal.
            shape = list(reversed(t.shape))
            raw = bytes(t.data)
            if ttype in _np_dtypes:
                arr = np.frombuffer(raw, dtype=_np_dtypes[ttype]).copy()
                sd[key] = torch.from_numpy(arr).reshape(shape)
            elif ttype == 'BF16':
                # Reinterpret raw bytes as bfloat16 directly in torch; copy
                # into a mutable bytearray so frombuffer owns writable memory.
                sd[key] = (
                    torch.frombuffer(bytearray(raw), dtype=torch.bfloat16)
                    .reshape(shape)
                    .contiguous()
                )
            else:
                dequant = _get_dequant()
                if dequant is None:
                    # Previously this case fell through silently, leaving the
                    # weight missing with no trace. Surface it instead.
                    logger.warning(
                        "Skipping quantized GGUF tensor %s (%s): "
                        "ComfyUI-GGUF dequant.py not found",
                        t.name,
                        ttype,
                    )
                    continue
                sd[key] = torch.tensor(
                    dequant.dequantize(t.data, t.tensor_type),
                    dtype=torch.float32,
                ).reshape(shape)
        except Exception as e:
            logger.warning("Skipping GGUF tensor %s: %s", t.name, e)
    logger.info("GGUF connector: loaded %d tensors", len(sd))
    return sd

def load_text_embeddings_pipeline(
ltxv_path, dtype=torch.bfloat16, fallback_proj_path=None
):
Expand All @@ -330,9 +369,42 @@ def load_text_embeddings_pipeline(
Returns:
(feature_extractor, embeddings_processor)
"""
sd, metadata = load_torch_file(str(ltxv_path), return_metadata=True)
config = json.loads(metadata.get("config", "{}"))
transformer_config = config.get("transformer", {})
if ltxv_path and str(ltxv_path).endswith('.gguf'):
sd = _load_gguf_connector_sd(ltxv_path)
transformer_config = {
"caption_projection_first_linear": False,
"caption_proj_input_norm": False,
"caption_projection_second_linear": False,
"caption_proj_before_connector": True,
"text_encoder_norm_type": "per_token_rms",
"prompt_embedding_dim": 3840,
"connector_num_layers": 8,
"connector_num_attention_heads": 32,
"connector_attention_head_dim": 128,
"connector_apply_gated_attention": True,
"connector_positional_embedding_max_pos": [4096],
"audio_connector_attention_head_dim": 64,
}
# Merge text_embedding_projection keys from proj_linear.safetensors
_proj_candidates = [
f for d in folder_paths.get_folder_paths("text_encoders")
for f in glob.glob(os.path.join(d, "*/proj_linear.safetensors"))
]
_proj_path = fallback_proj_path or (_proj_candidates[0] if _proj_candidates else None)
if _proj_path is not None:
try:
from comfy.utils import load_torch_file as _ltf2
proj_sd = _ltf2(str(_proj_path))
sd.update(proj_sd)
logger.info("Merged %d proj_linear keys from %s", len(proj_sd), _proj_path)
except Exception as e:
logger.warning("Could not merge proj_linear keys: %s", e)
else:
logger.warning("proj_linear.safetensors not found; text projection may be missing")
else:
sd, metadata = load_torch_file(str(ltxv_path), return_metadata=True)
config = json.loads(metadata.get("config", "{}"))
transformer_config = config.get("transformer", {})

is_av = f"{_PREFIX_BASE}audio_adaln_single.linear.weight" in sd
has_dual_aggregate = f"{_PREFIX_TEXT_PROJ}video_aggregate_embed.weight" in sd
Expand Down