vllm_inference_local.py
68 lines (57 loc) · 2.08 KB
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import argparse


def main():
    parser = argparse.ArgumentParser(description="Run Vision-R1 model inference.")
    parser.add_argument("--model_path", type=str, default="", help="Path to the model.")
    parser.add_argument("--image_path", type=str, default="", help="Path to the input image.")
    parser.add_argument("--prompt", type=str, default="", help="The input prompt.")
    parser.add_argument("--max_tokens", type=int, default=128, help="Maximum number of tokens to generate.")
    parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature.")
    parser.add_argument("--top_p", type=float, default=0.95, help="Nucleus sampling top_p.")
    args = parser.parse_args()

    # Load the model with vLLM, allowing up to 5 images per prompt.
    llm = LLM(
        model=args.model_path,
        limit_mm_per_prompt={"image": 5},
    )

    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
    )

    # Build a single-image chat message in the Qwen-VL format.
    image_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": args.image_path,
                },
                {
                    "type": "text",
                    "text": args.prompt,
                },
            ],
        },
    ]

    # Here we use image messages as a demonstration.
    messages = image_messages

    # Render the chat template to a prompt string and extract the image inputs.
    processor = AutoProcessor.from_pretrained(args.model_path)
    prompt = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    image_inputs, _ = process_vision_info(messages)

    mm_data = {}
    mm_data["image"] = image_inputs

    llm_inputs = {
        "prompt": prompt,
        "multi_modal_data": mm_data,
    }

    outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
    generated_text = outputs[0].outputs[0].text
    print(generated_text)


if __name__ == "__main__":
    main()
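
A minimal example invocation of this script, assuming a local Vision-R1 checkpoint and a test image (the model and image paths below are placeholders; the flags match the argparse definitions above):

python vllm_inference_local.py \
    --model_path /path/to/Vision-R1-checkpoint \
    --image_path ./example.jpg \
    --prompt "Describe this image." \
    --max_tokens 128 --temperature 0.6 --top_p 0.95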