
Commit 36ae2d0 (parent d4a259d)

Adding chat template to vllm decode.

1 file changed: 70 additions & 33 deletions

src/MaxText/vllm_decode.py
@@ -44,6 +44,7 @@
 import jax
 import transformers
 
+from MaxText import max_logging
 from MaxText import model_creation_utils
 from MaxText import pyconfig
 from MaxText.common_types import Config
@@ -70,6 +71,8 @@
 flags.DEFINE_string("model_name", "qwen3-30b-a3b", "Model name for MaxText.")
 flags.DEFINE_string("hf_model_name", "Qwen/Qwen3-30B-A3B", "Path to the Hugging Face model.")
 flags.DEFINE_string("hf_config_path", None, "Path to the local Hugging Face model config.")
+flags.DEFINE_string("hf_access_token", None, "Hugging Face access token for private models.")
+flags.DEFINE_string("tokenizer_path", None, "Path to the tokenizer. If None, use hf_model_name.")
 flags.DEFINE_string("load_parameters_path", None, "Path to load model parameters from.")
 flags.DEFINE_bool("enable_expert_parallel", False, "Whether to enable expert parallelism.")
 
@@ -80,51 +83,63 @@
 
 # Decoding
 flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
+flags.DEFINE_bool("use_chat_template", False, "Whether to format the prompt using chat template.")
 flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
-flags.DEFINE_integer("decode_sampling_temperature", 0, "Temperature for sampling.")
-flags.DEFINE_integer("decode_sampling_nucleus_p", 1, "Nucleus sampling probability.")
+flags.DEFINE_float("decode_sampling_temperature", 0, "Temperature for sampling.")
+flags.DEFINE_float("decode_sampling_nucleus_p", 1.0, "Nucleus sampling probability.")
 flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.")
-
-# Mark required flags
-flags.mark_flag_as_required("hf_config_path")
+flags.DEFINE_integer("seed", 42, "Random seed for sampling.")
 
 
 def decode_with_vllm(
     model_name: str,
     hf_model_name: str,
-    hf_config_path: str,
-    load_parameters_path: str,
     ici_data_parallelism: int,
     ici_tensor_parallelism: int,
     ici_expert_parallelism: int,
-    max_prefill_length: int,
     max_target_length: int,
     gpu_memory_utilization: float,
     enable_expert_parallel: bool,
     prompt: str,
-    decode_sampling_temperature: float,
-    decode_sampling_nucleus_p: float,
-    decode_sampling_top_k: float,
+    use_chat_template: bool = False,
+    decode_sampling_temperature: float = 0.0,
+    decode_sampling_nucleus_p: float = 1.0,
+    decode_sampling_top_k: int = 1,
+    hf_config_path: str | None = None,
+    hf_access_token: str | None = None,
+    tokenizer_path: str | None = None,
+    load_parameters_path: str | None = None,
+    seed: int = 42,
 ) -> None:
   """Decode using vLLM with a MaxText model implementation.
 
   Args:
     model_name: Name of the model for MaxText.
     hf_model_name: Path to the Hugging Face model.
-    hf_config_path: Path to the local Hugging Face model config.
-    load_parameters_path: Path to load model parameters from.
     ici_data_parallelism: Size of the data parallelism dimension.
     ici_tensor_parallelism: Size of the non-expert tensor parallelism dimension.
     ici_expert_parallelism: Size of the MoE expert parallelism dimension.
-    max_prefill_length: Maximum prefill length.
     max_target_length: Maximum total context length (MCL).
     gpu_memory_utilization: Fraction of GPU memory to be used for the model executor.
     enable_expert_parallel: Whether to enable expert parallelism.
     prompt: The prompt to decode.
+    use_chat_template: Whether to format the prompt using chat template.
     decode_sampling_temperature: Temperature for sampling.
     decode_sampling_nucleus_p: Nucleus sampling probability.
     decode_sampling_top_k: Top-k sampling probability.
+    hf_config_path: Path to the local Hugging Face model config.
+    hf_access_token: Hugging Face access token for private models.
+    tokenizer_path: Path to the tokenizer. If None, use hf_model_name.
+    load_parameters_path: Path to load model parameters from.
   """
+  if not model_name:
+    raise ValueError("model_name must be provided")
+
+  if not hf_model_name:
+    raise ValueError("hf_model_name must be provided")
+
+  if not hf_config_path:
+    raise ValueError("hf_config_path must be provided")
 
   # Prepare vLLM Arguments
   vllm_args = {}
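
Editor's note on the flag-type change above: absl's DEFINE_integer rejects fractional values at parse time, so a temperature like 0.7 was previously unrepresentable; DEFINE_float accepts it. A minimal standalone sketch (the flag name matches the diff; the parse call is illustrative):

from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_float("decode_sampling_temperature", 0, "Temperature for sampling.")

# With DEFINE_integer this raised absl.flags.IllegalFlagValueError;
# as a float flag it parses cleanly.
FLAGS(["prog", "--decode_sampling_temperature=0.7"])
print(FLAGS.decode_sampling_temperature)  # 0.7
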
@@ -142,7 +157,6 @@ def decode_with_vllm(
   # Prepare MaxText and sharding configs (Parallelism is dynamic)
   vllm_args["additional_config"]["maxtext_config"] = {
       "model_name": model_name,
-      "max_target_length": max_target_length,
       "weight_dtype": "bfloat16",
       "allow_split_physical_axes": True,
   }
@@ -154,24 +168,14 @@
   vllm_args["additional_config"]["sharding"] = {
       "sharding_strategy": {
           "tensor_parallelism": ici_tensor_parallelism,
-          "expert_parallelism": ici_expert_parallelism,
           "data_parallelism": ici_data_parallelism,
       },
   }
 
   if enable_expert_parallel:
     vllm_args["additional_config"]["sharding"]["sharding_strategy"].update({"expert_parallelism": ici_expert_parallelism})
 
-  # Initialize and Run LLM
-  max_tokens = max_target_length - max_prefill_length
-  sampling_params = SamplingParams(
-      temperature=decode_sampling_temperature,
-      max_tokens=max_tokens,
-      top_k=decode_sampling_top_k,
-      top_p=decode_sampling_nucleus_p,
-  )
-
-  print(
+  max_logging.log(
       f"Initializing LLM with DP={vllm_args['data_parallel_size']}, TP={vllm_args['tensor_parallel_size']} "
       f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
   )
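
Editor's note: the SamplingParams construction removed here reappears in the next hunk, after tokenizer loading, because max_tokens is now derived from the measured prompt length rather than a fixed max_prefill_length budget chosen before the prompt was ever tokenized. A worked example with assumed numbers:

# Assumed values for illustration only.
max_target_length = 2048       # total context budget (prompt + generation)
old_max_prefill_length = 1024

old_budget = max_target_length - old_max_prefill_length  # 1024, regardless of prompt
measured_prompt_length = 25                              # tokens after encoding
new_budget = max_target_length - measured_prompt_length  # 2023
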
@@ -183,14 +187,44 @@ def decode_with_vllm(
   with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
     llm = LLM(**vllm_args)
 
-  print("Generating output...")
-  outputs = llm.generate([prompt], sampling_params)
+  max_logging.log("Generating output...")
+  tokenizer = transformers.AutoTokenizer.from_pretrained(
+      tokenizer_path if tokenizer_path is not None else hf_model_name,
+      token=hf_access_token,
+  )
+
+  prompts = [prompt]
+  if use_chat_template:
+    # Format the prompt using chat template if specified
+    messages = [
+        {"role": "user", "content": prompt},
+    ]
+    input_with_chat_template = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,  # Set to False to get the string
+        add_generation_prompt=True,
+        add_special_tokens=False,  # Prevent adding special tokens
+    )
+    prompts = [input_with_chat_template]
+
+  max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
+  max_tokens_to_generate = max_target_length - max_prompt_length
+
+  sampling_params = SamplingParams(
+      temperature=decode_sampling_temperature,
+      max_tokens=max_tokens_to_generate,
+      top_k=decode_sampling_top_k,
+      top_p=decode_sampling_nucleus_p,
+      seed=seed,
+  )
+
+  outputs = llm.generate(prompts, sampling_params)
 
-  # Print Outputs
+  # max_logging.log Outputs
   for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    max_logging.log(f"Prompt: {prompt}, Generated text: {generated_text}")
 
 
 def decode_with_tunix(
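
Editor's note: for reference, a minimal standalone sketch of what the new use_chat_template path produces, assuming the Qwen/Qwen3-30B-A3B tokenizer from the flags above is reachable (the exact output depends on the tokenizer's template):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B")
messages = [{"role": "user", "content": "Suggest some famous landmarks in London."}]
formatted = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return the formatted string rather than token ids
    add_generation_prompt=True,  # end with the assistant turn header
)
print(formatted)
# Qwen3's ChatML-style template yields roughly:
#   <|im_start|>user
#   Suggest some famous landmarks in London.<|im_end|>
#   <|im_start|>assistant
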
@@ -253,8 +287,8 @@ def decode_with_tunix(
 
   # Generate text
   output = vllm_rollout.generate(prompts, rollout_config)
-  print(f"Prompt: {config.prompt}")
-  print(f"Output: {output.text[0]}")
+  max_logging.log(f"Prompt: {config.prompt}")
+  max_logging.log(f"Output: {output.text[0]}")
 
 
 def main(argv: Sequence[str]) -> None:
@@ -274,18 +308,21 @@ def main(argv: Sequence[str]) -> None:
       model_name=FLAGS.model_name,
       hf_model_name=FLAGS.hf_model_name,
       hf_config_path=FLAGS.hf_config_path,
+      hf_access_token=FLAGS.hf_access_token,
+      tokenizer_path=FLAGS.tokenizer_path,
       load_parameters_path=FLAGS.load_parameters_path,
       ici_data_parallelism=FLAGS.ici_data_parallelism,
       ici_tensor_parallelism=FLAGS.ici_tensor_parallelism,
       ici_expert_parallelism=FLAGS.ici_expert_parallelism,
       max_target_length=FLAGS.max_target_length,
-      max_prefill_length=FLAGS.max_prefill_length,
       gpu_memory_utilization=FLAGS.gpu_memory_utilization,
       enable_expert_parallel=FLAGS.enable_expert_parallel,
       prompt=FLAGS.prompt,
+      use_chat_template=FLAGS.use_chat_template,
       decode_sampling_temperature=FLAGS.decode_sampling_temperature,
       decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
       decode_sampling_top_k=FLAGS.decode_sampling_top_k,
+      seed=FLAGS.seed,
   )
 
 
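Editor's note: a hypothetical invocation exercising the new flags, assuming the script is launched as a module and using absl's --flag=value syntax (the config path is a placeholder):

python3 -m MaxText.vllm_decode \
  --model_name=qwen3-30b-a3b \
  --hf_model_name=Qwen/Qwen3-30B-A3B \
  --hf_config_path=/path/to/hf_config \
  --use_chat_template=true \
  --decode_sampling_temperature=0.7 \
  --decode_sampling_nucleus_p=0.95 \
  --decode_sampling_top_k=40 \
  --seed=42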