import os

from peft import LoraConfig
from peft.peft_model import PeftModel
from transformers import PaliGemmaForConditionalGeneration
from transformers.utils import is_flash_attn_2_available

from inference.core.env import DEVICE, MODEL_CACHE_DIR
from inference.models.transformers import LoRATransformerModel, TransformerModel

def _get_paligemma_attn_implementation():
    """Use flash_attention_2 if available, otherwise eager.

    SDPA has dtype mismatch issues with token_type_ids in transformers 4.57+.
    """
    if is_flash_attn_2_available() and DEVICE and "cuda" in DEVICE:
        # Verify flash_attn can actually be imported (not just installed)
        try:
            import flash_attn  # noqa: F401

            return "flash_attention_2"
        except ImportError:
            pass
    return "eager"

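# Hedged usage note: the helper above can be called directly to see which
# attention backend this environment would select, e.g.
#
#     >>> _get_paligemma_attn_implementation()
#     'eager'  # on CPU, or when flash_attn is installed but fails to import
#
# "flash_attention_2" is returned only when DEVICE names a CUDA device and
# the flash_attn package imports cleanly.
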
class PaliGemma(TransformerModel):
    """By using you agree to the terms listed at https://ai.google.dev/gemma/terms"""

    generation_includes_input = True
    transformers_class = PaliGemmaForConditionalGeneration

    def initialize_model(self, **kwargs):
        # Base weights either come straight from the Hugging Face hub
        # (model id = dataset id) or from the local Roboflow weight cache.
        if not self.load_base_from_roboflow:
            model_id = self.dataset_id
        else:
            model_id = self.cache_dir

        self.model = (
            self.transformers_class.from_pretrained(
                model_id,
                cache_dir=self.cache_dir,
                device_map=DEVICE,
                token=self.huggingface_token,
                torch_dtype=self.default_dtype,
                attn_implementation=_get_paligemma_attn_implementation(),
            )
            .eval()
            .to(self.dtype)
        )

        self.processor = self.processor_class.from_pretrained(
            model_id, cache_dir=self.cache_dir, token=self.huggingface_token
        )

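# Hedged example (the constructor signature follows the general pattern of
# Roboflow `inference` models and is an assumption here, not defined in this
# file):
#
#     model = PaliGemma("paligemma-3b-mix-224", api_key="...")
#     model.initialize_model()  # populates self.model and self.processor
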
class LoRAPaliGemma(LoRATransformerModel):
    """By using you agree to the terms listed at https://ai.google.dev/gemma/terms"""

    generation_includes_input = True
    transformers_class = PaliGemmaForConditionalGeneration
    load_base_from_roboflow = True

    def initialize_model(self, **kwargs):
        import torch

        lora_config = LoraConfig.from_pretrained(self.cache_dir, device_map=DEVICE)
        model_id = lora_config.base_model_name_or_path
        revision = lora_config.revision
        # The adapter config's revision field may carry a torch dtype name
        # (e.g. "float16"); adopt it as the model dtype when it does.
        if revision is not None:
            try:
                self.dtype = getattr(torch, revision)
            except AttributeError:
                pass
        if not self.load_base_from_roboflow:
            model_load_id = model_id
            cache_dir = os.path.join(MODEL_CACHE_DIR, "huggingface")
            token = self.huggingface_token
        else:
            model_load_id = self.get_lora_base_from_roboflow(model_id, revision)
            cache_dir = model_load_id
            revision = None
            token = None
        self.base_model = self.transformers_class.from_pretrained(
            model_load_id,
            revision=revision,
            device_map=DEVICE,
            cache_dir=cache_dir,
            token=token,
            attn_implementation=_get_paligemma_attn_implementation(),
        ).to(self.dtype)
        self.model = (
            PeftModel.from_pretrained(self.base_model, self.cache_dir)
            .eval()
            .to(self.dtype)
        )

        # Merge the LoRA weights into the base model so inference runs
        # without adapter indirection.
        self.model = self.model.merge_and_unload()

        self.processor = self.processor_class.from_pretrained(
            model_load_id, revision=revision, cache_dir=cache_dir, token=token
        )
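
# Hedged end-to-end sketch (illustrative only: the constructor arguments and
# the attribute names below follow Roboflow's `inference` model conventions
# and are assumptions, not definitions from this file):
#
#     model = LoRAPaliGemma("my-workspace/paligemma-lora/1", api_key="...")
#     model.initialize_model()
#     # `model.model` is now the base PaliGemma with LoRA weights merged in,
#     # and `model.processor` is the matching processor for pre/post-processing.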