198 changes: 119 additions & 79 deletions src/maxtext/vllm_decode.py
@@ -45,6 +45,7 @@
import transformers

from maxtext.utils import model_creation_utils
from maxtext.utils import max_logging
from MaxText import pyconfig
from MaxText.common_types import Config
from MaxText.globals import MAXTEXT_CONFIGS_DIR
@@ -69,119 +70,112 @@
flags.DEFINE_bool("debug_sharding", False, "Debug Shardings")

# Model
flags.DEFINE_string("model_name", "qwen3-30b-a3b", "Model name for MaxText.")
flags.DEFINE_string("hf_model_name", "Qwen/Qwen3-30B-A3B", "Path to the Hugging Face model.")
flags.DEFINE_string("model_name", None, "Model name for MaxText.")
flags.DEFINE_string("hf_model_name", None, "Path to the Hugging Face model.")
flags.DEFINE_string("hf_config_path", None, "Path to the local Hugging Face model config.")
flags.DEFINE_string("hf_access_token", None, "Hugging Face access token for private models.")
flags.DEFINE_string("tokenizer_path", None, "Path to the tokenizer. If None, use hf_model_name.")
flags.DEFINE_string("load_parameters_path", None, "Path to load model parameters from.")
flags.DEFINE_bool("enable_expert_parallel", False, "Whether to enable expert parallelism.")

# Length/Throughput
flags.DEFINE_integer("max_target_length", 1024, "Maximum total context length (MCL).")
flags.DEFINE_integer("max_prefill_length", 512, "Maximum prefill length.")
flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.")

# Decoding
flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
flags.DEFINE_bool("use_chat_template", False, "Whether to format the prompt using chat template.")
flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
flags.DEFINE_integer("decode_sampling_temperature", 0, "Temperature for sampling.")
flags.DEFINE_integer("decode_sampling_nucleus_p", 1, "Nucleus sampling probability.")
flags.DEFINE_float("decode_sampling_temperature", 0.0, "Temperature for sampling.")
flags.DEFINE_float("decode_sampling_nucleus_p", 1.0, "Nucleus sampling probability.")
flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.")
flags.DEFINE_integer("seed", 42, "Random seed for sampling.")

# Mark required flags
flags.mark_flag_as_required("hf_config_path")
# Set mandatory flags
flags.mark_flag_as_required("model_name")
flags.mark_flag_as_required("hf_model_name")


def decode_with_vllm(
model_name: str,
hf_model_name: str,
hf_config_path: str,
load_parameters_path: str,
ici_data_parallelism: int,
ici_tensor_parallelism: int,
ici_expert_parallelism: int,
enable_dp_attention: bool,
max_prefill_length: int,
max_target_length: int,
gpu_memory_utilization: float,
enable_expert_parallel: bool,
hf_config_path: str | None,
prompt: str,
decode_sampling_temperature: float,
decode_sampling_nucleus_p: float,
decode_sampling_top_k: float,
debug_sharding: bool,
vllm_config_path: str | None = None,
ici_data_parallelism: int = 1,
ici_tensor_parallelism: int = 1,
ici_expert_parallelism: int = 1,
enable_dp_attention: bool = False,
max_target_length: int = 1024,
gpu_memory_utilization: float = 0.72,
use_chat_template: bool = False,
decode_sampling_temperature: float = 0.0,
decode_sampling_nucleus_p: float = 1.0,
decode_sampling_top_k: int = 1,
hf_access_token: str | None = None,
tokenizer_path: str | None = None,
load_parameters_path: str | None = None,
debug_sharding: bool = False,
seed: int = 42,
) -> None:
"""Decode using vLLM with a MaxText model implementation.

Args:
model_name: Name of the model for MaxText.
hf_model_name: Path to the Hugging Face model.
hf_config_path: Path to the local Hugging Face model config.
load_parameters_path: Path to load model parameters from.
prompt: The prompt to decode.
ici_data_parallelism: Size of the data parallelism dimension.
ici_tensor_parallelism: Size of the non-expert tensor parallelism dimension.
ici_expert_parallelism: Size of the MoE expert parallelism dimension
enable_dp_attention: Enable DP attention
max_prefill_length: Maximum prefill length.
ici_expert_parallelism: Size of the MoE expert parallelism dimension.
enable_dp_attention: Enable attention DP parallelism.
max_target_length: Maximum total context length (MCL).
gpu_memory_utilization: Fraction of GPU memory to be used for the model executor.
enable_expert_parallel: Whether to enable expert parallelism.
prompt: The prompt to decode.
use_chat_template: Whether to format the prompt using chat template.
decode_sampling_temperature: Temperature for sampling.
decode_sampling_nucleus_p: Nucleus sampling probability.
decode_sampling_top_k: Number of highest-probability tokens to sample from.
vllm_config_path: Path to vLLM config file. Defaults to MAXTEXT_PKG_DIR/configs/vllm.yml.
hf_access_token: Hugging Face access token for private models.
tokenizer_path: Path to the tokenizer. If None, use hf_model_name.
load_parameters_path: Path to load model parameters from.
debug_sharding: Whether to debug shardings.
seed: Random seed for sampling.
"""

# Prepare vLLM Arguments
vllm_args = {}
vllm_args["additional_config"] = {}

# Core vLLM Arguments
vllm_args["model"] = hf_model_name
vllm_args["max_model_len"] = max_target_length
vllm_args["tensor_parallel_size"] = ici_tensor_parallelism
vllm_args["data_parallel_size"] = ici_data_parallelism
vllm_args["enable_expert_parallel"] = enable_expert_parallel
vllm_args["hf_config_path"] = hf_config_path
vllm_args["gpu_memory_utilization"] = gpu_memory_utilization

# Prepare MaxText and sharding configs (Parallelism is dynamic)
vllm_args["additional_config"]["maxtext_config"] = {
"model_name": model_name,
"max_target_length": max_target_length,
"weight_dtype": "bfloat16",
"allow_split_physical_axes": True,
"debug_sharding": debug_sharding,
vllm_args = {
"model": hf_model_name,
"max_model_len": max_target_length,
"tensor_parallel_size": ici_tensor_parallelism,
"data_parallel_size": ici_data_parallelism,
"hf_config_path": hf_config_path,
"gpu_memory_utilization": gpu_memory_utilization,
"additional_config": {
"maxtext_config": {
"model_name": model_name,
"weight_dtype": "bfloat16",
"allow_split_physical_axes": True,
"debug_sharding": debug_sharding,
},
"sharding": {
"sharding_strategy": {
"enable_dp_attention": enable_dp_attention,
},
},
},
}
if load_parameters_path is not None:

if load_parameters_path:
vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = load_parameters_path
else:
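# With no checkpoint path, fall back to vLLM's "dummy" load format, which initializes random weights (for profiling/smoke tests).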
vllm_args["load_format"] = "dummy"

sharding_strategy = {
"enable_dp_attention": enable_dp_attention,
}
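# Expert parallelism is now inferred from the mesh shape: any ici_expert_parallelism > 1 enables it.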
enable_expert_parallel = ici_expert_parallelism > 1
if enable_expert_parallel:
sharding_strategy["expert_parallelism"] = ici_expert_parallelism
vllm_args["additional_config"]["sharding"] = {
"sharding_strategy": sharding_strategy,
}

if enable_expert_parallel:
vllm_args["additional_config"]["sharding"]["sharding_strategy"].update({"expert_parallelism": ici_expert_parallelism})

# Initialize and Run LLM
max_tokens = max_target_length - max_prefill_length
sampling_params = SamplingParams(
temperature=decode_sampling_temperature,
max_tokens=max_tokens,
top_k=decode_sampling_top_k,
top_p=decode_sampling_nucleus_p,
)
vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = ici_expert_parallelism
vllm_args["enable_expert_parallel"] = enable_expert_parallel

print(
f"Initializing LLM with DP={vllm_args['data_parallel_size']}, TP={vllm_args['tensor_parallel_size']} "
max_logging.log(
f"Initializing LLM with DP={ici_data_parallelism}, TP={ici_tensor_parallelism} "
f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
)

@@ -192,22 +186,62 @@ def decode_with_vllm(
with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
llm = LLM(**vllm_args)

print("Generating output...")
outputs = llm.generate([prompt], sampling_params)
max_logging.log("Generating output...")
tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer_path if tokenizer_path is not None else hf_model_name,
token=hf_access_token,
)

# Print Outputs
prompts = [prompt]
if use_chat_template:
# Format the prompt using chat template if specified
messages = [
{"role": "user", "content": prompt},
]
input_with_chat_template = tokenizer.apply_chat_template(
messages,
tokenize=False, # Set to False to get the string
add_generation_prompt=True,
add_special_tokens=False, # Prevent adding special tokens
)
prompts = [input_with_chat_template]
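# The templated string now carries the model's chat formatting (role markers plus a trailing generation header from add_generation_prompt=True).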

max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
max_tokens_to_generate = max_target_length - max_prompt_length
if max_tokens_to_generate <= 0:
raise ValueError(
f"max_target_length ({max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
)
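# Example: with max_target_length=1024 and a 40-token prompt, up to 984 new tokens can be generated.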

sampling_params = SamplingParams(
temperature=decode_sampling_temperature,
max_tokens=max_tokens_to_generate,
top_k=decode_sampling_top_k,
top_p=decode_sampling_nucleus_p,
seed=seed,
)
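# Note: the flag defaults (temperature=0.0, top_k=1) make decoding effectively greedy and deterministic.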

outputs = llm.generate(prompts, sampling_params)

# Log outputs
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
max_logging.log(f"Prompt: {prompt}, Generated text: {generated_text}")


def decode_with_tunix(
config: Config,
model: Any,
mesh: jax.sharding.Mesh,
) -> None:
"""Decode using vLLM with a MaxText model."""
"""Decode using vLLM with a MaxText model via Tunix adapter.

Args:
config: MaxText config.
model: The MaxText model instance.
mesh: The JAX mesh for parallelism.
"""
# Wrap the model for Tunix
tunix_model = TunixMaxTextAdapter(base_model=model)

@@ -235,6 +269,10 @@ def decode_with_tunix(

max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
max_tokens_to_generate = config.max_target_length - max_prompt_length
if max_tokens_to_generate <= 0:
raise ValueError(
f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
)

# Create vLLM rollout for inference
rollout_config = base_rollout.RolloutConfig(
@@ -262,8 +300,8 @@

# Generate text
output = vllm_rollout.generate(prompts, rollout_config)
print(f"Prompt: {config.prompt}")
print(f"Output: {output.text[0]}")
max_logging.log(f"Prompt: {config.prompt}")
max_logging.log(f"Output: {output.text[0]}")


def main(argv: Sequence[str]) -> None:
@@ -283,20 +321,22 @@ def main(argv: Sequence[str]) -> None:
model_name=FLAGS.model_name,
hf_model_name=FLAGS.hf_model_name,
hf_config_path=FLAGS.hf_config_path,
hf_access_token=FLAGS.hf_access_token,
tokenizer_path=FLAGS.tokenizer_path,
load_parameters_path=FLAGS.load_parameters_path,
ici_data_parallelism=FLAGS.ici_data_parallelism,
ici_tensor_parallelism=FLAGS.ici_tensor_parallelism,
ici_expert_parallelism=FLAGS.ici_expert_parallelism,
enable_dp_attention=FLAGS.enable_dp_attention,
max_target_length=FLAGS.max_target_length,
max_prefill_length=FLAGS.max_prefill_length,
gpu_memory_utilization=FLAGS.gpu_memory_utilization,
enable_expert_parallel=FLAGS.enable_expert_parallel,
prompt=FLAGS.prompt,
use_chat_template=FLAGS.use_chat_template,
decode_sampling_temperature=FLAGS.decode_sampling_temperature,
decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
decode_sampling_top_k=FLAGS.decode_sampling_top_k,
debug_sharding=FLAGS.debug_sharding,
seed=FLAGS.seed,
)

