gradio_app.py
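"""Gradio app for radiology image analysis.

Loads the Llama 3.2 11B Vision Instruct model with a LoRA adapter via Unsloth
and serves a simple web UI that generates a description of an uploaded image.
"""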
import gradio as gr
import torch
from unsloth import FastVisionModel
# Load model and tokenizer once
model_path = "./lora_model"
device_map = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[LOG] Loading model on device: {device_map}")
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Llama-3.2-11B-Vision-Instruct",
    adapter_name=model_path,
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
    device_map=device_map,
)
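# Switch the model into inference mode (enables Unsloth's faster generation path)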
FastVisionModel.for_inference(model)
print("[LOG] Model and tokenizer loaded.")
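
# Prompt sent to the model along with every uploaded image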
instruction = "You are an expert radiographer. Describe accurately what you see in this image."
def analyze_image(image):
    """Run the vision model on an uploaded image and return the generated description."""
    if image is None:
        return "No image selected!"
    try:
        # Chat-style prompt: one image placeholder plus the text instruction
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]
        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        inputs = tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(device_map)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            use_cache=True,
            temperature=1,
            min_p=0.1,
        )
        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # Drop the echoed prompt and any boilerplate prefixes from the decoded output
        if instruction in output_text:
            output_text = output_text.split(instruction, 1)[1]
        prefixes_to_remove = [
            "You are an expert radiographer. Describe accurately what you see in this image.",
            "Describe accurately what you see in this image.",
            "assistant",
            "Assistant:",
            "I am an expert radiographer.",
        ]
        for prefix in prefixes_to_remove:
            if output_text.startswith(prefix):
                output_text = output_text[len(prefix):]
                output_text = output_text.lstrip(":., \n")
        return output_text.strip()
    except Exception as e:
        return f"Error: {e}"
with gr.Blocks() as demo:
    gr.Markdown("# Radiology Image Analyzer\nUpload an image to get an AI analysis.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
        with gr.Column():
            output = gr.Textbox(label="Model Output", lines=10)
    analyze_btn = gr.Button("Analyze", variant="primary")
    analyze_btn.click(analyze_image, inputs=image_input, outputs=output)

demo.launch(share=True)