Commit ac5b39e

Merge pull request #1893 from roboflow/add_qwen3vl
Fix paligemma sdpa issue
2 parents: 5ac2102 + 3404242

2 files changed: 106 additions & 0 deletions

First changed file: 86 additions & 0 deletions
@@ -1,18 +1,104 @@
+import os
+
+from peft import LoraConfig
+from peft.peft_model import PeftModel
 from transformers import PaliGemmaForConditionalGeneration
+from transformers.utils import is_flash_attn_2_available
 
+from inference.core.env import DEVICE, MODEL_CACHE_DIR
 from inference.models.transformers import LoRATransformerModel, TransformerModel
 
 
+def _get_paligemma_attn_implementation():
+    """Use flash_attention_2 if available, otherwise eager.
+
+    SDPA has dtype mismatch issues with token_type_ids in transformers 4.57+.
+    """
+    if is_flash_attn_2_available() and DEVICE and "cuda" in DEVICE:
+        # Verify flash_attn can actually be imported (not just installed)
+        try:
+            import flash_attn  # noqa: F401
+
+            return "flash_attention_2"
+        except ImportError:
+            pass
+    return "eager"
+
+
 class PaliGemma(TransformerModel):
     """By using you agree to the terms listed at https://ai.google.dev/gemma/terms"""
 
     generation_includes_input = True
     transformers_class = PaliGemmaForConditionalGeneration
 
+    def initialize_model(self, **kwargs):
+        if not self.load_base_from_roboflow:
+            model_id = self.dataset_id
+        else:
+            model_id = self.cache_dir
+
+        self.model = (
+            self.transformers_class.from_pretrained(
+                model_id,
+                cache_dir=self.cache_dir,
+                device_map=DEVICE,
+                token=self.huggingface_token,
+                torch_dtype=self.default_dtype,
+                attn_implementation=_get_paligemma_attn_implementation(),
+            )
+            .eval()
+            .to(self.dtype)
+        )
+
+        self.processor = self.processor_class.from_pretrained(
+            model_id, cache_dir=self.cache_dir, token=self.huggingface_token
+        )
+
 
 class LoRAPaliGemma(LoRATransformerModel):
     """By using you agree to the terms listed at https://ai.google.dev/gemma/terms"""
 
     generation_includes_input = True
     transformers_class = PaliGemmaForConditionalGeneration
     load_base_from_roboflow = True
+
+    def initialize_model(self, **kwargs):
+        import torch
+
+        lora_config = LoraConfig.from_pretrained(self.cache_dir, device_map=DEVICE)
+        model_id = lora_config.base_model_name_or_path
+        revision = lora_config.revision
+        if revision is not None:
+            try:
+                self.dtype = getattr(torch, revision)
+            except AttributeError:
+                pass
+        if not self.load_base_from_roboflow:
+            model_load_id = model_id
+            cache_dir = os.path.join(MODEL_CACHE_DIR, "huggingface")
+            revision = revision
+            token = self.huggingface_token
+        else:
+            model_load_id = self.get_lora_base_from_roboflow(model_id, revision)
+            cache_dir = model_load_id
+            revision = None
+            token = None
+        self.base_model = self.transformers_class.from_pretrained(
+            model_load_id,
+            revision=revision,
+            device_map=DEVICE,
+            cache_dir=cache_dir,
+            token=token,
+            attn_implementation=_get_paligemma_attn_implementation(),
+        ).to(self.dtype)
+        self.model = (
+            PeftModel.from_pretrained(self.base_model, self.cache_dir)
+            .eval()
+            .to(self.dtype)
+        )
+
+        self.model.merge_and_unload()
+
+        self.processor = self.processor_class.from_pretrained(
+            model_load_id, revision=revision, cache_dir=cache_dir, token=token
+        )
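The helper above prefers flash_attention_2 and otherwise falls back to eager because of the SDPA dtype mismatch with token_type_ids noted in its docstring. For comparison, here is a minimal standalone sketch of the same workaround when loading PaliGemma directly through transformers; the checkpoint id and dtype are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Illustrative checkpoint id, not taken from this commit.
MODEL_ID = "google/paligemma-3b-mix-224"

# Forcing the eager backend sidesteps the SDPA dtype mismatch with
# token_type_ids seen in transformers 4.57+; flash_attention_2 remains
# the preferred choice when the flash_attn package is importable on CUDA.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
).eval()
processor = AutoProcessor.from_pretrained(MODEL_ID)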

inference_experimental/inference_exp/models/paligemma/paligemma_hf.py: 20 additions & 0 deletions
@@ -19,6 +19,23 @@
     BitsAndBytesConfig,
     PaliGemmaForConditionalGeneration,
 )
+from transformers.utils import is_flash_attn_2_available
+
+
+def _get_paligemma_attn_implementation(device: torch.device) -> str:
+    """Use flash_attention_2 if available, otherwise eager.
+
+    SDPA has dtype mismatch issues with token_type_ids in transformers 4.57+.
+    """
+    if is_flash_attn_2_available() and device.type == "cuda":
+        # Verify flash_attn can actually be imported (not just installed)
+        try:
+            import flash_attn  # noqa: F401
+
+            return "flash_attention_2"
+        except ImportError:
+            pass
+    return "eager"
 
 
 class PaliGemmaHF:
@@ -59,6 +76,7 @@ def from_pretrained(
                 bnb_4bit_quant_type="nf4",
                 bnb_4bit_compute_dtype=torch.bfloat16,
             )
+        attn_implementation = _get_paligemma_attn_implementation(device)
         adapter_config_path = os.path.join(model_name_or_path, "adapter_config.json")
         if os.path.exists(adapter_config_path):
             base_model_path = os.path.join(model_name_or_path, "base")
@@ -68,6 +86,7 @@ def from_pretrained(
                 trust_remote_code=trust_remote_code,
                 local_files_only=local_files_only,
                 quantization_config=quantization_config,
+                attn_implementation=attn_implementation,
             )
             model = PeftModel.from_pretrained(model, model_name_or_path)
             if quantization_config is None:
@@ -88,6 +107,7 @@ def from_pretrained(
                 trust_remote_code=trust_remote_code,
                 local_files_only=local_files_only,
                 quantization_config=quantization_config,
+                attn_implementation=attn_implementation,
             ).eval()
         processor = AutoProcessor.from_pretrained(
             model_name_or_path,
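For reference, a small usage sketch of the device-aware helper added above; the device strings are examples only, and the CUDA result depends on whether flash_attn is importable:

import torch

# Always "eager" on CPU; "flash_attention_2" on CUDA only when the
# flash_attn package can actually be imported, otherwise "eager".
print(_get_paligemma_attn_implementation(torch.device("cpu")))
print(_get_paligemma_attn_implementation(torch.device("cuda:0")))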
