
Commit 04ccabf

Feature/llava next and onevision variants (#1202)
* Creating Architecture Adapters for Llava Next and Onevision variants
* Format checks
* Updated testing for older models with new fixes
* Improved comments, slight loosening of tolerances
1 parent 42a2d52 commit 04ccabf

11 files changed

Lines changed: 486 additions & 78 deletions


tests/integration/model_bridge/compatibility/test_legacy_hooks.py

Lines changed: 4 additions & 3 deletions
@@ -172,10 +172,11 @@ def test_cache_hook_equality_with_hooked_transformer(
     assert torch.allclose(
         hooked_transformer_activation[unmasked_positions],
         bridge_activation[unmasked_positions],
-        atol=1e-6,
-        rtol=1e-6,
+        atol=1e-4,
+        rtol=1e-4,
     ), (
-        "Unmasked attention scores should match within float32 " "numerical precision"
+        "Unmasked attention scores should match within float32 "
+        "cross-implementation tolerance"
     )

     masked_bridge_values = bridge_activation[masked_positions]
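For context on the loosened tolerances: torch.allclose passes when |a - b| <= atol + rtol * |b| holds elementwise, so raising both values from 1e-6 to 1e-4 widens the acceptance band by roughly two orders of magnitude. A minimal sketch with illustrative values:

import torch

a = torch.tensor([1.0, 2.0])
b = a + 5e-5  # a typical cross-implementation float32 discrepancy

# The per-element band is atol + rtol * |b|: ~3e-6 before, ~3e-4 after.
print(torch.allclose(a, b, atol=1e-6, rtol=1e-6))  # False
print(torch.allclose(a, b, atol=1e-4, rtol=1e-4))  # True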

transformer_lens/benchmarks/main_benchmark.py

Lines changed: 9 additions & 0 deletions
@@ -112,6 +112,8 @@ def _hf_token() -> Optional[str]:
     "Gemma3ForCausalLM",
     "Gemma3ForConditionalGeneration",
     "LlavaForConditionalGeneration",
+    "LlavaNextForConditionalGeneration",
+    "LlavaOnevisionForConditionalGeneration",
 ]

@@ -188,6 +190,8 @@ def _is_multimodal_model(model_name: str, trust_remote_code: bool = False) -> bool:
     """Check if a model is a multimodal (vision-language) model."""
     MULTIMODAL_ARCHITECTURES = [
         "LlavaForConditionalGeneration",
+        "LlavaNextForConditionalGeneration",
+        "LlavaOnevisionForConditionalGeneration",
         "Gemma3ForConditionalGeneration",
     ]
     try:
@@ -1117,6 +1121,11 @@ def cleanup_model(model, model_name_str: str):
         bridge_unprocessed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code)  # type: ignore[attr-defined]
         if verbose:
             print("✓ TransformerBridge loaded (unprocessed)\n")
+        # Apply the adapter's prepare_model() to the HF reference model so
+        # both bridge and reference have the same fixups (e.g., weight tying).
+        # This keeps model-specific logic in the adapter, not the benchmark.
+        if hf_model is not None and hasattr(bridge_unprocessed, "adapter"):
+            bridge_unprocessed.adapter.prepare_model(hf_model)
     except Exception as e:
         import traceback
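The comparison this hunk protects, as a hedged sketch (calling conventions abbreviated; bridge_unprocessed and hf_model as in the hunk above):

# Sketch: the benchmark compares bridge logits against the HF reference.
# Without prepare_model() on hf_model, an untied lm_head would make this
# fail for affected OneVision checkpoints even when the bridge mapping is
# correct, which is why the fixup is mirrored onto the reference model.
bridge_logits = bridge_unprocessed(input_ids, return_type="logits")
hf_logits = hf_model(input_ids).logits
assert torch.allclose(bridge_logits, hf_logits, atol=1e-4, rtol=1e-4)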

transformer_lens/benchmarks/multimodal.py

Lines changed: 26 additions & 10 deletions
@@ -31,7 +31,9 @@ def _create_test_image():
 def _prepare_test_inputs(bridge: TransformerBridge):
     """Prepare multimodal test inputs using the bridge's processor.

-    Returns (input_ids, pixel_values, prompt) or (None, None, None) on failure.
+    Returns (input_ids, extra_kwargs, prompt) where extra_kwargs is a dict
+    containing pixel_values and any other processor outputs (e.g. image_sizes
+    for LlavaNext). Returns (None, None, None) on failure.
     """
     if bridge.processor is None:
         return None, None, None
@@ -51,8 +53,19 @@ def _prepare_test_inputs(bridge: TransformerBridge):
     try:
         inputs = bridge.processor(text=prompt, images=image, return_tensors="pt")
         input_ids = inputs["input_ids"].to(bridge.cfg.device)
-        pixel_values = inputs["pixel_values"].to(bridge.cfg.device)
-        return input_ids, pixel_values, prompt
+
+        # Collect all extra kwargs the model's forward() may need
+        # (pixel_values, image_sizes, pixel_attention_mask, etc.)
+        extra_kwargs = {}
+        for key, val in inputs.items():
+            if key == "input_ids":
+                continue
+            if hasattr(val, "to"):
+                extra_kwargs[key] = val.to(bridge.cfg.device)
+            else:
+                extra_kwargs[key] = val
+
+        return input_ids, extra_kwargs, prompt
     except Exception:
         return None, None, None

@@ -88,7 +101,7 @@ def benchmark_multimodal_forward(
             message="Skipped for tiny/test model",
         )

-    input_ids, pixel_values, prompt = _prepare_test_inputs(bridge)
+    input_ids, extra_kwargs, prompt = _prepare_test_inputs(bridge)
     if input_ids is None:
         return BenchmarkResult(
             name="multimodal_forward",
@@ -98,7 +111,7 @@ def benchmark_multimodal_forward(

     try:
         with torch.no_grad():
-            logits = bridge.forward(input_ids, pixel_values=pixel_values, return_type="logits")
+            logits = bridge.forward(input_ids, return_type="logits", **extra_kwargs)

         if logits is None:
             return BenchmarkResult(
@@ -120,14 +133,17 @@ def benchmark_multimodal_forward(
             passed=False,
         )

+    pixel_values = extra_kwargs.get("pixel_values")
     return BenchmarkResult(
         name="multimodal_forward",
         severity=BenchmarkSeverity.INFO,
         message=f"Multimodal forward pass successful, logits shape: {list(logits.shape)}",
         details={
             "logits_shape": list(logits.shape),
             "input_ids_shape": list(input_ids.shape),
-            "pixel_values_shape": list(pixel_values.shape),
+            "pixel_values_shape": list(pixel_values.shape)
+            if pixel_values is not None
+            else None,
         },
     )

@@ -173,7 +189,7 @@ def benchmark_multimodal_generation(
             message="Skipped for tiny/test model",
         )

-    input_ids, pixel_values, prompt = _prepare_test_inputs(bridge)
+    input_ids, extra_kwargs, prompt = _prepare_test_inputs(bridge)
     if input_ids is None:
         return BenchmarkResult(
             name="multimodal_generation",
@@ -185,8 +201,8 @@ def benchmark_multimodal_generation(
         output = bridge.generate(
             input_ids,
             max_new_tokens=max_new_tokens,
-            pixel_values=pixel_values,
             return_type="tokens",
+            **extra_kwargs,
         )

     if not isinstance(output, torch.Tensor):
@@ -264,7 +280,7 @@ def benchmark_multimodal_cache(
             message="Skipped for tiny/test model",
         )

-    input_ids, pixel_values, prompt = _prepare_test_inputs(bridge)
+    input_ids, extra_kwargs, prompt = _prepare_test_inputs(bridge)
     if input_ids is None:
         return BenchmarkResult(
             name="multimodal_cache",
@@ -274,7 +290,7 @@ def benchmark_multimodal_cache(

     try:
         with torch.no_grad():
-            logits, cache = bridge.run_with_cache(input_ids, pixel_values=pixel_values)
+            logits, cache = bridge.run_with_cache(input_ids, **extra_kwargs)

         if cache is None or len(cache) == 0:
             return BenchmarkResult(
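Condensed, the new input-preparation contract looks like this (a sketch; the processor's key set varies by architecture, and image_sizes appears only for LlavaNext-style processors):

# Sketch: forward everything the processor emits except input_ids, so each
# Llava variant receives whatever its forward() signature requires.
inputs = bridge.processor(text=prompt, images=image, return_tensors="pt")
input_ids = inputs["input_ids"].to(bridge.cfg.device)
extra_kwargs = {
    k: (v.to(bridge.cfg.device) if hasattr(v, "to") else v)
    for k, v in inputs.items()
    if k != "input_ids"
}
logits = bridge.forward(input_ids, return_type="logits", **extra_kwargs)
logits, cache = bridge.run_with_cache(input_ids, **extra_kwargs)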

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,8 @@
     GPTOSSArchitectureAdapter,
     LlamaArchitectureAdapter,
     LlavaArchitectureAdapter,
+    LlavaNextArchitectureAdapter,
+    LlavaOnevisionArchitectureAdapter,
     MingptArchitectureAdapter,
     MistralArchitectureAdapter,
     MixtralArchitectureAdapter,
@@ -55,6 +57,8 @@
     "GPTJForCausalLM": GptjArchitectureAdapter,
     "LlamaForCausalLM": LlamaArchitectureAdapter,
     "LlavaForConditionalGeneration": LlavaArchitectureAdapter,
+    "LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
+    "LlavaOnevisionForConditionalGeneration": LlavaOnevisionArchitectureAdapter,
     "MixtralForCausalLM": MixtralArchitectureAdapter,
     "MistralForCausalLM": MistralArchitectureAdapter,
     "NeoForCausalLM": NeoArchitectureAdapter,

transformer_lens/model_bridge/bridge.py

Lines changed: 10 additions & 4 deletions
@@ -1805,6 +1805,7 @@ def generate(
         verbose: bool = True,
         output_logits: bool = False,
         pixel_values: Optional[torch.Tensor] = None,
+        **multimodal_kwargs,
     ) -> str | list[str] | torch.Tensor | Any:  # Any for transformers.utils.ModelOutput
         # Using Any due to beartype's forward reference resolution limitations.
         # See: https://github.com/beartype/beartype/issues/546
@@ -1920,10 +1921,15 @@ def generate(
                 )
             else:
                 forward_kwargs: Dict[str, Any] = {}
-                # Pass pixel_values only on the first step — the vision encoder
-                # processes the image once, embedding it into the token sequence.
-                if gen_step_idx == 0 and pixel_values is not None:
-                    forward_kwargs["pixel_values"] = pixel_values
+                # Pass multimodal inputs only on the first step — the vision
+                # encoder processes the image once, embedding it into the
+                # token sequence. This includes pixel_values plus any extra
+                # processor outputs (e.g. image_sizes for LlavaNext).
+                if gen_step_idx == 0:
+                    if pixel_values is not None:
+                        forward_kwargs["pixel_values"] = pixel_values
+                    if multimodal_kwargs:
+                        forward_kwargs.update(multimodal_kwargs)
                 logits = self(current_tokens, return_type="logits", **forward_kwargs)
                 final_logits = logits[:, -1, :]
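End to end, the **multimodal_kwargs path lets callers hand processor outputs straight to generate; a hedged usage sketch (prompt and image as in the benchmarks above):

# Sketch: everything the processor returns besides input_ids rides along as
# keyword arguments and is consumed only on generation step 0.
inputs = bridge.processor(text=prompt, images=image, return_tensors="pt")
extra = {k: v for k, v in inputs.items() if k != "input_ids"}
tokens = bridge.generate(
    inputs["input_ids"],
    max_new_tokens=20,
    return_type="tokens",
    **extra,  # pixel_values, plus image_sizes for LlavaNext-style processors
)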

transformer_lens/model_bridge/sources/transformers.py

Lines changed: 46 additions & 1 deletion
@@ -252,6 +252,8 @@ def get_hf_model_class_for_architecture(architecture: str):
     }
     multimodal_architectures = {
         "LlavaForConditionalGeneration",
+        "LlavaNextForConditionalGeneration",
+        "LlavaOnevisionForConditionalGeneration",
         "Gemma3ForConditionalGeneration",
     }
     if architecture in seq2seq_architectures:
@@ -453,7 +455,50 @@ def boot(
                 trust_remote_code=trust_remote_code,
             )
         except Exception:
-            pass  # Processor not available; user can set bridge.processor manually
+            # Some multimodal processors (e.g., LlavaOnevision) require
+            # torchvision for video processing. Conditionally install it
+            # and retry the processor loading.
+            _torchvision_available = False
+            try:
+                import torchvision  # noqa: F401
+
+                _torchvision_available = True
+            except Exception:
+                # torchvision may be missing (ImportError) or broken/version-
+                # mismatched (RuntimeError). Try to install/reinstall it.
+                import shutil
+                import subprocess
+                import sys
+
+                try:
+                    if shutil.which("uv"):
+                        subprocess.check_call(
+                            ["uv", "pip", "install", "torchvision", "-q"],
+                        )
+                    else:
+                        subprocess.check_call(
+                            [sys.executable, "-m", "pip", "install", "torchvision", "-q"],
+                        )
+                    import importlib
+
+                    importlib.invalidate_caches()
+                    _torchvision_available = True
+                except Exception:
+                    pass  # torchvision install failed; processor will be unavailable
+
+            if _torchvision_available:
+                try:
+                    from transformers import AutoProcessor
+
+                    huggingface_token = os.environ.get("HF_TOKEN", "")
+                    token_arg = huggingface_token if len(huggingface_token) > 0 else None
+                    bridge.processor = AutoProcessor.from_pretrained(
+                        model_name,
+                        token=token_arg,
+                        trust_remote_code=trust_remote_code,
+                    )
+                except Exception:
+                    pass  # Processor not available; user can set bridge.processor manually

     return bridge
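If both the initial load and the torchvision retry fail, boot() leaves bridge.processor unset, and the manual fallback named in the final comment is a one-liner; a minimal sketch:

# Sketch: attach a processor by hand when boot() could not load one.
from transformers import AutoProcessor

bridge = TransformerBridge.boot_transformers(model_name)
if bridge.processor is None:
    bridge.processor = AutoProcessor.from_pretrained(model_name)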

transformer_lens/model_bridge/supported_architectures/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -39,6 +39,12 @@
 from transformer_lens.model_bridge.supported_architectures.llava import (
     LlavaArchitectureAdapter,
 )
+from transformer_lens.model_bridge.supported_architectures.llava_next import (
+    LlavaNextArchitectureAdapter,
+)
+from transformer_lens.model_bridge.supported_architectures.llava_onevision import (
+    LlavaOnevisionArchitectureAdapter,
+)
 from transformer_lens.model_bridge.supported_architectures.mingpt import (
     MingptArchitectureAdapter,
 )
@@ -116,6 +122,8 @@
     "GptjArchitectureAdapter",
     "LlamaArchitectureAdapter",
     "LlavaArchitectureAdapter",
+    "LlavaNextArchitectureAdapter",
+    "LlavaOnevisionArchitectureAdapter",
     "MingptArchitectureAdapter",
     "MistralArchitectureAdapter",
     "MixtralArchitectureAdapter",
transformer_lens/model_bridge/supported_architectures/llava_next.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+"""LLaVA-NeXT architecture adapter.
+
+Same module hierarchy as base LLaVA; high-res tiling differences are
+handled internally by HuggingFace's forward().
+"""
+
+from transformer_lens.model_bridge.supported_architectures.llava import (
+    LlavaArchitectureAdapter,
+)
+
+
+class LlavaNextArchitectureAdapter(LlavaArchitectureAdapter):
+    """Architecture adapter for LLaVA-NeXT (1.6) models."""
+
+    pass
transformer_lens/model_bridge/supported_architectures/llava_onevision.py

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+"""LLaVA-OneVision architecture adapter.
+
+Same module hierarchy as base LLaVA; SigLIP encoder and Qwen2 backbone
+are handled dynamically by the base adapter and HuggingFace's forward().
+"""
+
+from typing import Any
+
+from transformer_lens.model_bridge.supported_architectures.llava import (
+    LlavaArchitectureAdapter,
+)
+
+
+class LlavaOnevisionArchitectureAdapter(LlavaArchitectureAdapter):
+    """Architecture adapter for LLaVA-OneVision models."""
+
+    def prepare_model(self, hf_model: Any) -> None:
+        """Fix weight tying when text_config and top-level config disagree.
+
+        Some checkpoints have tie_word_embeddings=True in text_config but False
+        at the top level, leaving lm_head randomly initialized.
+        """
+        if not hasattr(hf_model, "lm_head") or not hasattr(hf_model, "model"):
+            return
+        language_model = getattr(hf_model.model, "language_model", None)
+        if language_model is None:
+            return
+        embed = getattr(language_model, "embed_tokens", None)
+        if embed is None:
+            return
+
+        # Check if text config expects tied weights but top-level config doesn't
+        text_config = getattr(hf_model.config, "text_config", None)
+        if text_config is not None and getattr(text_config, "tie_word_embeddings", False):
+            if not getattr(hf_model.config, "tie_word_embeddings", True):
+                hf_model.lm_head.weight = embed.weight
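A quick way to verify the fixup on an affected checkpoint (a sketch; adapter and hf_model instantiation omitted, and the attribute path follows the method's own traversal):

# Sketch: after prepare_model, lm_head shares storage with embed_tokens on
# checkpoints where only text_config requested weight tying.
adapter.prepare_model(hf_model)
embed = hf_model.model.language_model.embed_tokens
assert hf_model.lm_head.weight.data_ptr() == embed.weight.data_ptr()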
