 import jax
 import transformers
 
+from MaxText import max_logging
 from MaxText import model_creation_utils
 from MaxText import pyconfig
 from MaxText.common_types import Config
 flags.DEFINE_string("model_name", "qwen3-30b-a3b", "Model name for MaxText.")
 flags.DEFINE_string("hf_model_name", "Qwen/Qwen3-30B-A3B", "Path to the Hugging Face model.")
 flags.DEFINE_string("hf_config_path", None, "Path to the local Hugging Face model config.")
+flags.DEFINE_string("hf_access_token", None, "Hugging Face access token for private models.")
+flags.DEFINE_string("tokenizer_path", None, "Path to the tokenizer. If None, use hf_model_name.")
 flags.DEFINE_string("load_parameters_path", None, "Path to load model parameters from.")
 flags.DEFINE_bool("enable_expert_parallel", False, "Whether to enable expert parallelism.")
 
 
 # Decoding
 flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
+flags.DEFINE_bool("use_chat_template", False, "Whether to format the prompt using a chat template.")
 flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
-flags.DEFINE_integer("decode_sampling_temperature", 0, "Temperature for sampling.")
-flags.DEFINE_integer("decode_sampling_nucleus_p", 1, "Nucleus sampling probability.")
+flags.DEFINE_float("decode_sampling_temperature", 0.0, "Temperature for sampling.")
+flags.DEFINE_float("decode_sampling_nucleus_p", 1.0, "Nucleus sampling probability.")
 flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.")
-
-# Mark required flags
-flags.mark_flag_as_required("hf_config_path")
+flags.DEFINE_integer("seed", 42, "Random seed for sampling.")
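+# Example invocation (a sketch; the script filename and config path are
+# illustrative placeholders, not part of this change):
+#   python decode_vllm.py --model_name=qwen3-30b-a3b \
+#       --hf_model_name=Qwen/Qwen3-30B-A3B \
+#       --hf_config_path=/path/to/hf_config.json \
+#       --use_chat_template --decode_sampling_temperature=0.7 --seed=0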
 
 
 def decode_with_vllm(
     model_name: str,
     hf_model_name: str,
-    hf_config_path: str,
-    load_parameters_path: str,
     ici_data_parallelism: int,
     ici_tensor_parallelism: int,
     ici_expert_parallelism: int,
-    max_prefill_length: int,
     max_target_length: int,
     gpu_memory_utilization: float,
     enable_expert_parallel: bool,
     prompt: str,
-    decode_sampling_temperature: float,
-    decode_sampling_nucleus_p: float,
-    decode_sampling_top_k: float,
+    use_chat_template: bool = False,
+    decode_sampling_temperature: float = 0.0,
+    decode_sampling_nucleus_p: float = 1.0,
+    decode_sampling_top_k: int = 1,
+    hf_config_path: str | None = None,
+    hf_access_token: str | None = None,
+    tokenizer_path: str | None = None,
+    load_parameters_path: str | None = None,
+    seed: int = 42,
 ) -> None:
   """Decode using vLLM with a MaxText model implementation.
 
   Args:
     model_name: Name of the model for MaxText.
     hf_model_name: Path to the Hugging Face model.
-    hf_config_path: Path to the local Hugging Face model config.
-    load_parameters_path: Path to load model parameters from.
     ici_data_parallelism: Size of the data parallelism dimension.
     ici_tensor_parallelism: Size of the non-expert tensor parallelism dimension.
     ici_expert_parallelism: Size of the MoE expert parallelism dimension.
-    max_prefill_length: Maximum prefill length.
     max_target_length: Maximum total context length (MCL).
     gpu_memory_utilization: Fraction of GPU memory to be used for the model executor.
     enable_expert_parallel: Whether to enable expert parallelism.
     prompt: The prompt to decode.
126+ use_chat_template: Whether to format the prompt using chat template.
     decode_sampling_temperature: Temperature for sampling.
     decode_sampling_nucleus_p: Nucleus sampling probability.
     decode_sampling_top_k: Top-k sampling probability.
+    hf_config_path: Path to the local Hugging Face model config.
+    hf_access_token: Hugging Face access token for private models.
+    tokenizer_path: Path to the tokenizer. If None, use hf_model_name.
+    load_parameters_path: Path to load model parameters from.
+    seed: Random seed for sampling.
   """
+  if not model_name:
+    raise ValueError("model_name must be provided")
+
+  if not hf_model_name:
+    raise ValueError("hf_model_name must be provided")
+
+  if not hf_config_path:
+    raise ValueError("hf_config_path must be provided")
 
   # Prepare vLLM Arguments
   vllm_args = {}
@@ -142,7 +157,6 @@ def decode_with_vllm(
   # Prepare MaxText and sharding configs (Parallelism is dynamic)
   vllm_args["additional_config"]["maxtext_config"] = {
       "model_name": model_name,
-      "max_target_length": max_target_length,
       "weight_dtype": "bfloat16",
       "allow_split_physical_axes": True,
   }
@@ -154,24 +168,14 @@ def decode_with_vllm(
   vllm_args["additional_config"]["sharding"] = {
       "sharding_strategy": {
           "tensor_parallelism": ici_tensor_parallelism,
-          "expert_parallelism": ici_expert_parallelism,
           "data_parallelism": ici_data_parallelism,
       },
   }
 
   if enable_expert_parallel:
     vllm_args["additional_config"]["sharding"]["sharding_strategy"].update({"expert_parallelism": ici_expert_parallelism})
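+  # Sizing note (an assumption, not enforced here): the parallelism degrees are
+  # expected to multiply to the local device count, e.g. on 8 chips with expert
+  # parallelism enabled, the resulting strategy might be:
+  #   {"tensor_parallelism": 2, "data_parallelism": 1, "expert_parallelism": 4}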
 
-  # Initialize and Run LLM
-  max_tokens = max_target_length - max_prefill_length
-  sampling_params = SamplingParams(
-      temperature=decode_sampling_temperature,
-      max_tokens=max_tokens,
-      top_k=decode_sampling_top_k,
-      top_p=decode_sampling_nucleus_p,
-  )
-
-  print(
+  max_logging.log(
       f"Initializing LLM with DP={vllm_args['data_parallel_size']}, TP={vllm_args['tensor_parallel_size']} "
       f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
   )
@@ -183,14 +187,44 @@ def decode_with_vllm(
   with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
     llm = LLM(**vllm_args)
 
-  print("Generating output...")
-  outputs = llm.generate([prompt], sampling_params)
+  max_logging.log("Generating output...")
+  tokenizer = transformers.AutoTokenizer.from_pretrained(
+      tokenizer_path if tokenizer_path is not None else hf_model_name,
+      token=hf_access_token,
+  )
+
+  prompts = [prompt]
+  if use_chat_template:
+    # Format the prompt using the chat template if requested
+    messages = [
+        {"role": "user", "content": prompt},
+    ]
+    input_with_chat_template = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,  # Return the formatted string, not token ids
+        add_generation_prompt=True,
+        add_special_tokens=False,  # The template already inserts special tokens
+    )
+    prompts = [input_with_chat_template]
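+    # For Qwen-style chat templates the result typically wraps the message in
+    # <|im_start|>/<|im_end|> markers; the exact format is tokenizer-defined.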
+
+  max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
+  max_tokens_to_generate = max_target_length - max_prompt_length
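+  # Budget sketch: prompt tokens + generated tokens must fit within
+  # max_target_length, so the longest prompt bounds the generation length.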
+
+  sampling_params = SamplingParams(
+      temperature=decode_sampling_temperature,
+      max_tokens=max_tokens_to_generate,
+      top_k=decode_sampling_top_k,
+      top_p=decode_sampling_nucleus_p,
+      seed=seed,
+  )
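+  # With the flag defaults (temperature=0, top_k=1) this is effectively greedy
+  # decoding; the seed only matters once temperature/top_p/top_k allow sampling.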
+
+  outputs = llm.generate(prompts, sampling_params)
 
-  # Print Outputs
+  # Log outputs
   for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    max_logging.log(f"Prompt: {prompt}, Generated text: {generated_text}")
 
 
 def decode_with_tunix(
@@ -253,8 +287,8 @@ def decode_with_tunix(
 
   # Generate text
   output = vllm_rollout.generate(prompts, rollout_config)
-  print(f"Prompt: {config.prompt}")
-  print(f"Output: {output.text[0]}")
+  max_logging.log(f"Prompt: {config.prompt}")
+  max_logging.log(f"Output: {output.text[0]}")
 
 
 def main(argv: Sequence[str]) -> None:
@@ -274,18 +308,21 @@ def main(argv: Sequence[str]) -> None:
         model_name=FLAGS.model_name,
         hf_model_name=FLAGS.hf_model_name,
         hf_config_path=FLAGS.hf_config_path,
+        hf_access_token=FLAGS.hf_access_token,
+        tokenizer_path=FLAGS.tokenizer_path,
         load_parameters_path=FLAGS.load_parameters_path,
         ici_data_parallelism=FLAGS.ici_data_parallelism,
         ici_tensor_parallelism=FLAGS.ici_tensor_parallelism,
         ici_expert_parallelism=FLAGS.ici_expert_parallelism,
         max_target_length=FLAGS.max_target_length,
-        max_prefill_length=FLAGS.max_prefill_length,
         gpu_memory_utilization=FLAGS.gpu_memory_utilization,
         enable_expert_parallel=FLAGS.enable_expert_parallel,
         prompt=FLAGS.prompt,
+        use_chat_template=FLAGS.use_chat_template,
         decode_sampling_temperature=FLAGS.decode_sampling_temperature,
         decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
         decode_sampling_top_k=FLAGS.decode_sampling_top_k,
+        seed=FLAGS.seed,
     )
 
 