
Commit 6ffb3cf

Add prototype VLLM inference engine
1 parent 1a27804 commit 6ffb3cf

3 files changed

Lines changed: 76 additions & 4 deletions


align_system/algorithms/abstracts.py

Lines changed: 3 additions & 2 deletions
@@ -16,11 +16,12 @@ def choose_action(self,
 
 class StructuredInferenceEngine(ABC):
     @abstractmethod
-    def dialog_to_prompt(dialog: list[dict]) -> str:
+    def dialog_to_prompt(self, dialog: list[dict]) -> str:
         pass
 
     @abstractmethod
-    def run_inference(prompts: Union[str, list[str]],
+    def run_inference(self,
+                      prompts: Union[str, list[str]],
                       schema: str) -> Union[dict, list[dict]]:
         pass
 
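Adding `self` turns both abstract methods into proper instance methods, so concrete engines can use instance state (such as a loaded model) when implementing them. A minimal sketch of a conforming subclass, purely illustrative (`EchoEngine` is an invented name, not part of this commit):

from typing import Union

from align_system.algorithms.abstracts import StructuredInferenceEngine


class EchoEngine(StructuredInferenceEngine):
    """Toy engine that exists only to illustrate the interface."""

    def dialog_to_prompt(self, dialog: list[dict]) -> str:
        # Naive concatenation; a real engine applies the model's chat template.
        return '\n'.join(m['content'] for m in dialog)

    def run_inference(self,
                      prompts: Union[str, list[str]],
                      schema: str) -> Union[dict, list[dict]]:
        # A real engine would constrain generation to the JSON schema;
        # this stub only honors the single-vs-list return contract.
        if isinstance(prompts, str):
            return {}
        return [{} for _ in prompts]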

Lines changed: 70 additions & 0 deletions (new file)

@@ -0,0 +1,70 @@
+from typing import Union
+import json
+
+import jinja2
+from vllm import LLM, SamplingParams
+from vllm.sampling_params import StructuredOutputsParams
+
+from align_system.algorithms.abstracts import StructuredInferenceEngine
+
+# Sometimes the internal default max_tokens for VLLM is 50,
+# leading to very short (and often invalid JSON) outputs. Setting a
+# somewhat generous default.
+DEFAULT_MAX_TOKENS = 8192
+
+class VLLMInferenceEngine(StructuredInferenceEngine):
+    def __init__(self,
+                 model_name,
+                 sampling_params=None):
+        self.llm = LLM(model=model_name)
+
+        self.sampling_params = sampling_params
+        if self.sampling_params is None:
+            self.sampling_params = {}
+
+        if 'max_tokens' not in self.sampling_params:
+            self.sampling_params['max_tokens'] = DEFAULT_MAX_TOKENS
+
+    def dialog_to_prompt(self, dialog: list[dict]) -> str:
+        tokenizer = self.llm.get_tokenizer()
+
+        try:
+            encoded_dialog = tokenizer.apply_chat_template(dialog)
+        except jinja2.exceptions.TemplateError:
+            # Assume that the tokenizer chat template doesn't accept
+            # system messages; combine the system message with the
+            # first user message
+            # Ensure each dialog element is a dict
+            system_msg, user_msg, *rest = [dict(d) for d in dialog]
+
+            assert user_msg['role'] == 'user'
+
+            updated_content = system_msg['content'] + '\n' + user_msg['content']
+
+            dialog = [{'role': 'user', 'content': updated_content}, *rest]
+
+            encoded_dialog = tokenizer.apply_chat_template(dialog)
+
+        return tokenizer.decode(encoded_dialog)
+
+    def run_inference(self,
+                      prompts: Union[str, list[str]],
+                      schema: str) -> Union[dict, list[dict]]:
+        json_schema = json.loads(schema)
+        schema_params = StructuredOutputsParams(json=json_schema)
+
+        structured_sampling_params = SamplingParams(
+            **self.sampling_params,
+            structured_outputs=schema_params)
+
+        outputs = self.llm.generate(
+            prompts,
+            sampling_params=structured_sampling_params)
+
+        parsed_outputs = [json.loads(o.outputs[0].text) for o in outputs]
+
+        if isinstance(prompts, str):
+            # Return single instance if single prompt provided as a string
+            return parsed_outputs[0]
+        else:
+            return parsed_outputs
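For context, a minimal usage sketch of the new engine. The import path and model name below are hypothetical (the commit page doesn't show where the new file lives), and the schema is an invented example:

import json

# Hypothetical import path -- the commit page doesn't name the new file.
from align_system.algorithms.vllm_inference_engine import VLLMInferenceEngine

engine = VLLMInferenceEngine(model_name='your-org/your-model')  # placeholder

dialog = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Answer in JSON with a single "answer" field.'},
]
prompt = engine.dialog_to_prompt(dialog)

# run_inference expects the JSON schema serialized as a string.
schema = json.dumps({
    'type': 'object',
    'properties': {'answer': {'type': 'string'}},
    'required': ['answer'],
})

result = engine.run_inference(prompt, schema)
print(result['answer'])  # a single dict, since a single prompt string was passed

Because `prompts` was a single string rather than a list, `run_inference` returns one parsed dict instead of a list of dicts.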

pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -12,7 +12,7 @@ url = "https://download.pytorch.org/whl/cu130"
 priority = "supplemental"
 
 [tool.poetry.dependencies]
-python = ">=3.9,<3.13"
+python = ">=3.10,<3.13"
 torch = { version = "^2.0.1", source = "pytorch" }
 transformers = "^4.57.1"
 llama-index = "^0.8.21"
@@ -29,11 +29,12 @@ rouge-score = "^0.1.2"
 swagger-client = {git = "https://github.com/NextCenturyCorporation/itm-evaluation-client.git", rev = "0.5.0"}
 hydra-core = "^1.3.2"
 outlines = "^1.2.7"
-setuptools = "^70.1.1"
+setuptools = "^77.0.3"
 sentencepiece = "^0.2.0"
 protobuf = "^5.28.3"
 datasets = "^3.3.2"
 ubelt = "1.3.6"
+vllm = "^0.11.1"
 
 [tool.poetry.scripts]
 run_align_system = 'align_system.cli.run_align_system:main'
