From 0726da7eb0713bfd348c0578bb09d8eeb70c4697 Mon Sep 17 00:00:00 2001
From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com>
Date: Wed, 21 Jan 2026 00:24:10 -0500
Subject: [PATCH] [src/MaxText/inference/offline_engine.py]
 `OfflineEngine.__init__` takes many configuration flags
 (`enable_batch_prefill`, `min_decode_steps`, `prefill_lengths`, `eos_ids`)
 to configure internal workers and prefill helpers, and the setup logic
 branches on these flags (e.g., choosing `BatchedPrefillProcessor` vs
 `PrefillProcessor`); replace them with an `OfflineEngineConfig` dataclass
 and a fluent `OfflineEngineBuilder`.
 [tests/{grpo_trainer_correctness_test.py,inference/benchmark_offline_engine.py,offline_engine_test.py}]
 Update tests for new fluent interface
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
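
All constructor flags move onto `OfflineEngineConfig`, and validation moves
into `OfflineEngineConfig.validate()`. A typical call site migrates as below
(illustrative sketch: `maxtext_config` and `inference_mesh` stand in for
caller-provided values):

    # Before
    engine = OfflineEngine(config=maxtext_config, mesh=inference_mesh)

    # After
    engine = (
        OfflineEngineBuilder(maxtext_config)
        .set_mesh(inference_mesh)
        .build()
    )
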
---
 src/MaxText/experimental/rl/grpo_trainer.py |   7 +-
 src/MaxText/inference/offline_engine.py     | 406 ++++++++++++++++-------
 tests/grpo_trainer_correctness_test.py      |   5 +-
 tests/inference/benchmark_offline_engine.py |  16 +-
 tests/offline_engine_test.py                |   7 +-
 5 files changed, 300 insertions(+), 141 deletions(-)

diff --git a/src/MaxText/experimental/rl/grpo_trainer.py b/src/MaxText/experimental/rl/grpo_trainer.py
index 926f1e1b67..47dceb615c 100644
--- a/src/MaxText/experimental/rl/grpo_trainer.py
+++ b/src/MaxText/experimental/rl/grpo_trainer.py
@@ -691,10 +691,9 @@ def train_loop(config, config_inference, recorder, state=None):
 
   data_sharding = sharding.get_input_data_sharding(config, mesh)
 
-  inference_engine = offline_engine.OfflineEngine(
-      config=config_inference,
-      mesh=inference_mesh,
-  )
+  inference_engine = (
+      offline_engine.OfflineEngineBuilder(config_inference).set_mesh(inference_mesh).build()
+  )
 
   data_buffer = []
   data_buffer_lock = threading.Lock()
diff --git a/src/MaxText/inference/offline_engine.py b/src/MaxText/inference/offline_engine.py
index 255b2a0791..f23fde7edb 100644
--- a/src/MaxText/inference/offline_engine.py
+++ b/src/MaxText/inference/offline_engine.py
@@ -16,11 +16,8 @@
 Offline Inference Engine
 
 Example usage:
-  offline_engine = OfflineEngine(
-      config=maxtext_config,
-      params=None,
-      enable_batch_prefill=True,
-  )
+  builder = OfflineEngineBuilder(maxtext_config)
+  offline_engine = builder.build()
 
   input_data = [
       jax.numpy.arange(80),
@@ -73,7 +70,7 @@ class InputData:
     true_length: Actual length of the input before padding
   """
 
-  id: str
+  id: str | int
   tokens: jax.Array | np.ndarray
   true_length: int
 
@@ -97,7 +94,12 @@
 
 @dataclasses.dataclass
 class TokenOutput:
-  """Container for individual token generation result."""
+  """Container for individual token generation result.
+
+  Attributes:
+    token: The generated token ID.
+    log_prob: The log probability of the token.
+  """
 
   token: np.ndarray
   log_prob: np.ndarray
@@ -105,7 +107,18 @@
 
 @dataclasses.dataclass
 class DetokenizationTask:
-  """Container for detokenization work to be done on background thread."""
+  """Container for detokenization work to be done on background thread.
+
+  Attributes:
+    task_type: Type of task ("prefill" or "decode").
+    result_tokens: List of result tokens (for prefill).
+    log_prob: Log probabilities (for prefill).
+    prompt_logp: Prompt log probabilities (for prefill).
+    prompt_ids: List of prompt IDs (for prefill).
+    slots: List of slots (for prefill).
+    tokens_buffer: Buffer of tokens (for decode).
+    logprob_buffer: Buffer of log probabilities (for decode).
+  """
 
   task_type: str  # "prefill" or "decode"
   # For prefill tasks
@@ -123,6 +136,7 @@ class SafeThread(threading.Thread):
   """Thread class with exception handling to prevent silent failures."""
 
   def run(self):
+    """Executes the thread's activity with exception capturing."""
     try:
       super().run()
     except Exception as _:  # pylint: disable=broad-exception-caught
@@ -140,7 +154,13 @@ class PrefillType(Enum):
 
 @dataclasses.dataclass
 class PrefillResult:
-  """Result from prefill processing operation."""
+  """Result from prefill processing operation.
+
+  Attributes:
+    result_tokens: The result tokens object from the engine.
+    slot: The slot index associated with this result.
+    prompt_logp: Optional log probabilities for the prompt.
+  """
 
   result_tokens: "jetstream.engine_api.ResultTokens"
   slot: int
@@ -165,11 +185,12 @@
     """Initialize the PrefillHelper.
 
     Args:
-      type: The type of prefill processor to use ("default" or "batch")
+      prefill_type: The type of prefill processor to use ("default" or "batch")
       engine: The MaxEngine instance to use for prefill operations
       prefill_lengths: list of prompt lengths to support
       batch_prefill_max_batch_size: Maximum number of prompts in one packed
-        sequence for batch prefill
+        sequence for batch prefill.
+      rng: Optional random number generator.
     """
     self._type = prefill_type
     self.engine = engine
@@ -286,38 +307,12 @@
 
 class InferenceWorker:
   """
-  InferenceWorker runs continuous batching over
-  a queue of inputs.
-
-  Continuous batching workflow:
-  1. Process inputs one at a time from queue
-  2. Prefill input and insert into KV cache
-  3. Continue prefilling until enough samples for batch decode
-  4. Decode until at least one sequence completes
-  5. Refill newly available decode slots with prefill
-  6. Repeat until all sequences complete
-
-  Prefill Packing:
-  When enable_batch_prefill is True, the prefill processor
-  will pack multiple inputs into a single sequence before
-  doing the prefill.
-
-  There are multiple buckets for packed sequences, where each bucket
-  contains inputs with the same padded length. Only inputs with the same
-  padded length can be packed together.
-
-  It is important to sort the inputs by padded length so that the
-  buckets fill up quickly.
-
-  When a decode slot frees up, the prefill processor will add the
-  sequence to a bucket. If the bucket becomes full, the packed sequence
-  will be prefilled.
-
-  E.g.
-  Bucket for length 64: [...seq1, ...seq2, ...seq3, ...seq4]
-  Bucket for length 128: [...seq1, ...seq2]
-  Bucket for length 256: [...seq1]
-
+  InferenceWorker runs continuous batching over a queue of inputs.
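+
+  Each input is prefilled into a free decode slot; decoding then runs until
+  at least one sequence completes, and freed slots are refilled with new
+  prefills. When enable_batch_prefill is True, inputs with the same padded
+  length are packed into a single sequence and prefilled together.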
""" def __init__( @@ -349,10 +339,10 @@ def __init__( prefill_lengths: list of supported prefill lengths max_decode_length: Maximum tokens to generate per sequence batch_prefill_max_batch_size: Maximum batch size for batch prefill - run_as_a_thread: Whether to run in a separate thread + is_pw_reshard: Whether to use Pathways for resharding rng: Random number generator key mesh: JAX mesh for distributed computation - is_pw_reshard: Whether to use Pathways for resharding + debug: Whether to run in debug mode """ # Configurations self.config = config @@ -444,15 +434,15 @@ def update_params( self, params: Params, ): - """Update the model parameters""" + """Update the model parameters. + + Args: + params: New model parameters. + """ self.params = params def reset_state(self): - """Reset all worker state for a new inference run. - - This allows reusing the same InferenceWorker instance across multiple - batch_inference calls without recreating the expensive engine components. - """ + """Reset all worker state for a new inference run.""" max_logging.log("Resetting InferenceWorker state") # Reset inference state @@ -475,6 +465,9 @@ def run_inference(self, data: list[InputData], rng=None): Args: data: list of InputData objects containing input sequences rng: Random number generator key. If None, the previous key will be used. + + Returns: + List of CompletionOutput objects. """ # Reset state for new inference run @@ -547,7 +540,14 @@ def _run_continuous_batching( max_logging.log(f"Inference worker: detokenization thread joined in {time.time() - start_time} seconds") def _build_final_outputs(self, input_data: list[InputData]) -> list[CompletionOutput]: - """Build the final list of CompletionOutput.""" + """Build the final list of CompletionOutput. + + Args: + input_data: list of input data items. + + Returns: + list of CompletionOutput objects. + """ with jax.profiler.TraceAnnotation("offline_engine.batch_inference.return_final_output"): completion_outputs = [] @@ -582,7 +582,12 @@ def _build_final_outputs(self, input_data: list[InputData]) -> list[CompletionOu ) return completion_outputs - def prefill_done(self, prefill_result: list[PrefillResult], prompt_ids: list[int], decode_state: DecodeState): + def prefill_done( + self, + prefill_result: list[PrefillResult], + prompt_ids: list[int], + decode_state: DecodeState, + ): """Callback function called when prefill completes. This function queues the prefill data for background processing. @@ -751,72 +756,210 @@ def emit_token( return should_terminate +@dataclasses.dataclass +class OfflineEngineConfig: + """Configuration for OfflineEngine.""" + + config: Any + params: Any = None + enable_batch_prefill: bool = False + min_decode_steps: int = 10 + tokenizer: Any = None + eos_ids: list[int] | None = None + prefill_lengths: list[int] | str = "auto" + batch_prefill_max_batch_size: int = 16 + mesh: Mesh = None + rng: Any = None + debug: bool = False + max_decode_length: int | None = None + + def validate(self): + """Validates the configuration.""" + if self.enable_batch_prefill and self.config.scan_layers: + raise ValueError("scan_layers must be False if enable_batch_prefill is True") + if not self.config.return_log_prob: + raise ValueError("return_log_prob must be True when using OfflineEngine") + if self.config.scan_layers: + max_logging.log( + "WARNING: scan_layers=True will result in slow step time. " "It is recommended for debugging purposes only." 
+
+  def __init__(self, config: Any):
+    # Initialize with default config options
+    self._config = OfflineEngineConfig(config=config)
+
+  def enable_batch_prefill(self, max_batch_size: int = 16):
+    """Enables batch prefill with specified max batch size.
+
+    Args:
+      max_batch_size: Maximum batch size for batch prefill.
+
+    Returns:
+      self
+    """
+    self._config.enable_batch_prefill = True
+    self._config.batch_prefill_max_batch_size = max_batch_size
+    return self
+
+  def set_decoding_params(self, min_steps: int = 10, max_len: int | None = None):
+    """Sets decoding parameters.
+
+    Args:
+      min_steps: Minimum number of decode steps to run at once.
+      max_len: Maximum decode length override.
+
+    Returns:
+      self
+    """
+    self._config.min_decode_steps = min_steps
+    self._config.max_decode_length = max_len
+    return self
+
+  def set_tokenizer(self, path: str):
+    """Sets tokenizer path.
+
+    Args:
+      path: Path to tokenizer model.
+
+    Returns:
+      self
+    """
+    self._config.config.tokenizer_path = path
+    # Tokenizer instance will be built by Engine using config.tokenizer_path
+    return self
+
+  def set_params(self, params: Any):
+    """Sets the model parameters.
+
+    Args:
+      params: Model parameters.
+
+    Returns:
+      self
+    """
+    self._config.params = params
+    return self
+
+  def set_mesh(self, mesh: Mesh):
+    """Sets the mesh.
+
+    Args:
+      mesh: JAX mesh.
+
+    Returns:
+      self
+    """
+    self._config.mesh = mesh
+    return self
+
+  def set_batch_prefill_max_batch_size(self, size: int):
+    """Sets batch prefill max batch size.
+
+    Args:
+      size: Maximum number of prompts in one packed sequence.
+
+    Returns:
+      self
+    """
+    self._config.batch_prefill_max_batch_size = size
+    return self
+
+  def set_eos_ids(self, eos_ids: list[int]):
+    """Sets EOS IDs.
+
+    Args:
+      eos_ids: list of EOS token IDs for checking sequence completion.
+
+    Returns:
+      self
+    """
+    self._config.eos_ids = eos_ids
+    return self
+
+  def set_prefill_lengths(self, lengths: list[int] | str):
+    """Sets prefill lengths.
+
+    Args:
+      lengths: List of lengths or "auto".
+
+    Returns:
+      self
+    """
+    self._config.prefill_lengths = lengths
+    return self
+
+  def set_rng(self, rng: Any):
+    """Sets random number generator.
+
+    Args:
+      rng: PRNG Key.
+
+    Returns:
+      self
+    """
+    self._config.rng = rng
+    return self
+
+  def set_debug(self, debug: bool):
+    """Sets debug flag.
+
+    Args:
+      debug: Debug boolean.
+
+    Returns:
+      self
+    """
+    self._config.debug = debug
+    return self
+
+  def build(self):
+    """Builds the OfflineEngine.
+
+    Returns:
+      Initialized OfflineEngine.
+    """
+    return OfflineEngine(self._config)
+
+
 class OfflineEngine:
   """Class for handling offline inference on batches of inputs."""
 
-  def __init__(
-      self,
-      config: Any,
-      params: None | Params = None,
-      enable_batch_prefill: bool = False,
-      min_decode_steps: int = 10,
-      tokenizer: Any = None,
-      eos_ids: list[int] | None = None,
-      prefill_lengths: list[int] | str = "auto",
-      batch_prefill_max_batch_size: int = 16,
-      mesh: Mesh = None,
-      rng: jax.random.PRNGKey = None,
-      debug: bool = False,
-  ):
+  def __init__(self, config: OfflineEngineConfig):
     """Initialize the OfflineEngine.
 
     Args:
-      config: The MaxText config object which will be used to
-        create MaxEngine instance(s).
-      params: Model parameters (loaded from engine if None)
-      enable_batch_prefill: Whether to use prefill packing.
-        config.scan_layers must be False if this is True
-      min_decode_steps: Number of decode steps to perform at a time,
-        before checking for completion.
-      eos_ids: list of EOS token IDs for checking sequence completion.
-        If None, the tokenizer's EOS token will be used.
-      tokenizer: Tokenizer instance for encoding/decoding text. If None,
-        will be created using the config if eos_ids is not provided.
-      prefill_lengths: list of expected prefill lengths, or "auto" to
-        automatically determine appropriate lengths from the engine
-        config. Input sequences will be padded to the nearest length
-        in this list.
-      batch_prefill_max_batch_size: Maximum number of inputs to pack
-        into a single prefill. This is only used when enable_batch_prefill
-        is True.
-      mesh: JAX Mesh object. Use this
-        argument if you want to use only some of the devices for OfflineEngine and
-        reserve the rest for other tasks. If None, OfflineEngine will create the mesh
-        automatically.
-      rng: Random number generator key. If None, a new key will be created.
+      config: The OfflineEngineConfig object containing all settings.
     """
     max_logging.log("Initializing OfflineEngine")
 
-    # Configurations
-    self.config = config
-    self.params = params
-    self.min_decode_steps = min_decode_steps
-    self.enable_batch_prefill = enable_batch_prefill
-    self.mesh = mesh
-    self.tokenizer = tokenizer
-    self.eos_ids = eos_ids
-    self.prefill_lengths = prefill_lengths
-    self.batch_prefill_max_batch_size = batch_prefill_max_batch_size
+    # Centralized validation
+    config.validate()
+
+    self.config = config.config  # The inner MaxText config
+    self.params = config.params
+    self.min_decode_steps = config.min_decode_steps
+    self.enable_batch_prefill = config.enable_batch_prefill
+    self.mesh = config.mesh
+    self.tokenizer = config.tokenizer
+    self.eos_ids = config.eos_ids
+    self.prefill_lengths = config.prefill_lengths
+    self.batch_prefill_max_batch_size = config.batch_prefill_max_batch_size
     self.max_prefill_length = self.config.max_prefill_predict_length
-    self.max_decode_length = self.config.max_target_length - self.max_prefill_length
-    self.rng = jax.random.PRNGKey(0) if rng is None else rng
-    self.debug = debug
-    self._validate_config()
+    self.rng = jax.random.PRNGKey(0) if config.rng is None else config.rng
+    self.debug = config.debug
+
+    # Calculate max decode length
+    if config.max_decode_length is not None:
+      self.max_decode_length = config.max_decode_length
+    else:
+      self.max_decode_length = self.config.max_target_length - self.max_prefill_length
+    if self.max_decode_length <= 0:
+      raise ValueError("Make sure max_target_length - max_prefill_predict_length is greater than 0")
 
     # Create prefill buckets: [0, 64], (64, 128], (128, 256], ..., [max_length//2, max_length]
-    if prefill_lengths == "auto":
+    if self.prefill_lengths == "auto":
       self.prefill_lengths = [2**i for i in range(6, max(6, (self.max_prefill_length - 1).bit_length()) + 1)]
     else:
-      self.prefill_lengths = sorted(prefill_lengths)
+      self.prefill_lengths = sorted(self.prefill_lengths)
 
     # Create meshes
     if not self.mesh:
@@ -844,7 +1006,11 @@ def update_params(
     self,
     params: Params,
   ):
-    """Update model weights."""
+    """Update model weights.
+
+    Args:
+      params: New model parameters.
+    """
     self.worker.update_params(params)
 
   def batch_inference(
@@ -858,7 +1024,7 @@
     Args:
       data: list of InputData objects, or JAX or numpy arrays. If input is
         JAX or numpy array, it must not contain padding tokens.
-      desc: Description string for logging
+      desc: Description string for logging.
       rng: Random number generator key. If None, the previous key will be used.
 
     Returns:
@@ -940,7 +1106,15 @@ def pad_data(self, data: list[InputData]) -> list[InputData]:
 
   @staticmethod
   def create_mesh(devices, config):
-    """Create data parallelism meshes for each Inference worker."""
+    """Create data parallelism meshes for each Inference worker.
+
+    Args:
+      devices: A list of JAX devices.
+      config: The MaxText configuration object.
+
+    Returns:
+      A JAX Mesh object.
+    """
     ici_parallelism = max_utils.fill_unspecified_mesh_axes(config.ici_parallelism.copy(), len(devices), "ICI")
     devices_array = mesh_utils.create_device_mesh(
         ici_parallelism,
@@ -950,17 +1124,3 @@
     )
     mesh = Mesh(devices_array.reshape(ici_parallelism), config.mesh_axes)
     return mesh
-
-  def _validate_config(self):
-    """Validate configuration parameters and check for incompatible settings."""
-    if not self.config.return_log_prob:
-      raise ValueError("return_log_prob must be True when using OfflineEngine")
-    if self.enable_batch_prefill and self.config.scan_layers:
-      raise ValueError("scan_layers must be False if enable_batch_prefill is True")
-
-    if self.max_decode_length <= 0:
-      raise ValueError("Make sure max_target_length - max_prefill_predict_length is greater than 0")
-    if self.config.scan_layers:
-      max_logging.log(
-          "WARNING: scan_layers=True will result in slow step time. " "It is recommended for debugging purposes only."
-      )
diff --git a/tests/grpo_trainer_correctness_test.py b/tests/grpo_trainer_correctness_test.py
index 18502ca9f5..ffdb94bff8 100644
--- a/tests/grpo_trainer_correctness_test.py
+++ b/tests/grpo_trainer_correctness_test.py
@@ -142,9 +142,8 @@ def setUp(self):
     devices_array = maxtext_utils.create_device_mesh(self.config_inference)
     self.mesh = Mesh(devices_array, self.config_inference.mesh_axes)
     self.tokenizer_model.add_special_tokens({"pad_token": ""})
-    self.inference_engine = offline_engine.OfflineEngine(
-        config=self.config_inference,
-        mesh=self.inference_model.mesh,
+    self.inference_engine = (
+        offline_engine.OfflineEngineBuilder(self.config_inference).set_mesh(self.inference_model.mesh).build()
     )
 
   @pytest.mark.skip(reason="Logit output test fragile, failing on jax upgrade to 0.6.2 - see b/425997645")
diff --git a/tests/inference/benchmark_offline_engine.py b/tests/inference/benchmark_offline_engine.py
index c503de1b98..34dd685727 100644
--- a/tests/inference/benchmark_offline_engine.py
+++ b/tests/inference/benchmark_offline_engine.py
@@ -31,7 +31,7 @@
 from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText import max_logging
 from MaxText import pyconfig
-from MaxText.inference.offline_engine import OfflineEngine, InputData, CompletionOutput
+from MaxText.inference.offline_engine import OfflineEngineBuilder, InputData, CompletionOutput
 
 
 def get_metrics(results: list[CompletionOutput], start_time, end_time):
@@ -97,13 +97,13 @@ def run(
     profile_path="",
 ):
   """Run offline engine"""
-  inference_engine = OfflineEngine(
-      config,
-      params=None,
-      enable_batch_prefill=False,
-      rng=jax.random.PRNGKey(0),
-      eos_ids=[1002],
-      debug=False,
+  inference_engine = (
+      OfflineEngineBuilder(config)
+      .set_params(None)
+      .set_rng(jax.random.PRNGKey(0))
+      .set_eos_ids([1002])
+      .set_debug(False)
+      .build()
   )
   max_logging.log("Starting Warmup")
   _ = [inference_engine.batch_inference(input_data) for _ in range(4)]
diff --git a/tests/offline_engine_test.py b/tests/offline_engine_test.py
index 0e599ed93a..62d3a3e147 100644
--- a/tests/offline_engine_test.py
+++ b/tests/offline_engine_test.py
@@ -21,7 +21,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from MaxText.inference.offline_engine import OfflineEngine, InputData, CompletionOutput
+from MaxText.inference.offline_engine import OfflineEngineBuilder, InputData, CompletionOutput
 from MaxText import pyconfig
 from MaxText.globals import MAXTEXT_PKG_DIR
 
@@ -68,7 +68,8 @@ def test_mcjax_tp(self):
     config = self.cfg
     rng = jax.random.PRNGKey(0)
-    inference_engine = OfflineEngine(config=config, params=None, enable_batch_prefill=False, rng=rng, eos_ids=[])
+    inference_engine = OfflineEngineBuilder(config).set_rng(rng).set_eos_ids([]).build()
+
     input_lengths = list(range(10, 600, 100))
     input_data = [
         InputData(id=f"input_{i}", tokens=np.arange(length), true_length=length) for i, length in enumerate(input_lengths)
     ]
@@ -87,7 +88,7 @@ def test_multi_sampling(self):
     config = self.cfg
     rng = jax.random.PRNGKey(0)
-    inference_engine = OfflineEngine(config=config, params=None, enable_batch_prefill=False, rng=rng, eos_ids=[])
+    inference_engine = OfflineEngineBuilder(config).set_rng(rng).set_eos_ids([]).build()
     rng1, rng_2 = jax.random.split(rng, 2)
 
     input_data = [InputData(id=f"input_{i}", tokens=jnp.arange(128), true_length=128) for i in range(4)]