From 71ecc5f1e260332b9ca635163f244ee27dfdf813 Mon Sep 17 00:00:00 2001
From: A9isha
Date: Wed, 18 Mar 2026 14:47:34 +0000
Subject: [PATCH 1/6] infer base output dir and HF_TOKEN

---
 src/maxtext/configs/pyconfig.py               | 36 ++++++++++++++++---
 src/maxtext/inference/vllm_decode.py          |  2 +-
 .../vllm/maxtext_vllm_adapter/adapter.py      |  4 ++--
 .../post_train/distillation/train_distill.py  |  2 +-
 .../trainers/post_train/rl/train_rl.py        | 27 +++++++-------
 .../trainers/post_train/sft/train_sft.py      |  2 +-
 src/maxtext/utils/model_creation_utils.py     |  9 +++--
 7 files changed, 57 insertions(+), 25 deletions(-)

diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py
index 78d783270f..b315c4e928 100644
--- a/src/maxtext/configs/pyconfig.py
+++ b/src/maxtext/configs/pyconfig.py
@@ -77,7 +77,7 @@ def _module_from_path(path: str) -> str | None:
   return None
 
 
-def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
+def _resolve_or_infer_config(argv: list[str], **kwargs) -> tuple[str, list[str]]:
   """Resolves or infers config file path from module."""
   if len(argv) >= 2 and argv[1].endswith(".yml"):
     return resolve_config_path(argv[1]), argv[2:]
@@ -88,7 +88,27 @@ def _resolve_or_infer_config(argv: list[str], **kwargs) -> tuple[str, list[str]]
     )
   config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
   logger.warning("No config file provided, using default config mapping: %s", config_path)
-  return config_path, argv[1:]
+  remaining_argv = argv[1:]
+
+  return config_path, remaining_argv
+
+
+def _resolve_or_infer_addl_config(**kwargs):
+  """Infers additional config values (output directory, HF token) that were not supplied explicitly."""
+  inferred_kwargs = {}
+
+  # if base_output_directory key is not seen
+  if "base_output_directory" not in kwargs:
+    max_logging.warning("base_output_directory is not provided; using a local directory named 'maxtext_output'")
+    base_output_directory = os.path.abspath("maxtext_output")
+    inferred_kwargs["base_output_directory"] = base_output_directory
+
+  # if hf_access_token key is not seen
+  if "hf_access_token" not in kwargs:
+    hf_access_token = os.environ.get("HF_TOKEN")
+    if hf_access_token:
+      inferred_kwargs["hf_access_token"] = hf_access_token
+
+  return inferred_kwargs
 
 
 def yaml_key_to_env_key(s: str) -> str:
@@ -291,7 +311,7 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
   Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
   """
   # 1. Load base and inherited configs from file(s)
-  config_path, cli_args = _resolve_or_infer_config(argv)
+  config_path, cli_args = _resolve_or_infer_config(argv, kwargs)
   base_yml_config = _load_config(config_path)
 
   # 2. Get overrides from CLI and kwargs
@@ -299,8 +319,16 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
   kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
   overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)
 
-  # 3. Handle model-specific config
+  temp_cfg1 = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
+  # 3.1. infer more configs if possible
+  temp_cfg1 = _resolve_or_infer_addl_config(**temp_cfg1)
+  # update overrides_cfg with temp_cfg1
+  overrides_cfg = omegaconf.OmegaConf.merge(overrides_cfg, temp_cfg1)
   temp_cfg = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
+
+  # 3.2. Handle model-specific config
   model_name = temp_cfg.get("model_name", "default")
   # The architecture for -Instruct vs. base models is the same, so for identifying the
   # architecture we replace "-Instruct" from the model_name and get the base model name

diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py
index f7df999547..d2e44a0145 100644
--- a/src/maxtext/inference/vllm_decode.py
+++ b/src/maxtext/inference/vllm_decode.py
@@ -242,7 +242,7 @@ def main(argv: Sequence[str]) -> None:
   config = pyconfig.initialize(argv)
 
   if FLAGS.use_tunix:
-    maxtext_model, mesh = model_creation_utils.create_nnx_model(config)
+    maxtext_model, mesh = model_creation_utils.from_pretrained(config)
     decode_with_tunix(config, model=maxtext_model, mesh=mesh)
   else:
     decode_with_vllm(config)

diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
index a0f3afba76..77dd4791d1 100644
--- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
+++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
@@ -233,8 +233,8 @@ def load_weights(self, rng_key: jax.Array) -> None:
       return
 
     with self.mesh, nn.logical_axis_rules(""):
-      model, _ = model_creation_utils.create_nnx_model(
-          self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
+      model = model_creation_utils.from_pretrained(
+          self.maxtext_config, original_mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
       )
       self.model = nnx.data(model)

diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py
index 85eb045bfe..0bad8f3151 100644
--- a/src/maxtext/trainers/post_train/distillation/train_distill.py
+++ b/src/maxtext/trainers/post_train/distillation/train_distill.py
@@ -388,7 +388,7 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
     The loaded MaxText model.
   """
   max_logging.log(f"Initializing model: {config.model_name}...")
-  model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
+  model = model_creation_utils.from_pretrained(config, original_mesh=mesh)
   return model

diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py
index 7f5a33eed6..378ff929a2 100644
--- a/src/maxtext/trainers/post_train/rl/train_rl.py
+++ b/src/maxtext/trainers/post_train/rl/train_rl.py
@@ -77,8 +77,8 @@
 from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
 from maxtext.trainers.post_train.rl import utils_rl
 from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
-from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
-
+from maxtext.utils import max_logging, max_utils, maxtext_utils
+import maxtext as mt
 
 def get_maxtext_model(config, devices=None):
   """
@@ -96,7 +96,7 @@ def get_maxtext_model(config, devices=None):
   # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
   # load_parameters_path=/path/to/your/output/directory/0/items
   """
-  model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
+  model, mesh = mt.from_pretrained(config, devices=devices)
   with mesh:
     use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
     tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)
@@ -149,9 +149,9 @@
   return loaded_dataset
 
 
-def setup_configs_and_devices(argv: list[str]):
+def setup_configs_and_devices(argv: list[str], **kwargs):
   """Setup device allocation and configs for training and inference."""
-  config = pyconfig.initialize_pydantic(argv)
+  config = pyconfig.initialize_pydantic(argv, kwargs)
   devices = jax.devices()
   if config.num_trainer_slices == -1 and config.num_samplers_slices == -1:
     max_logging.log("Running RL on a single slice")
@@ -580,7 +580,7 @@
   return rl_cluster, rl_trainer, optimizer
 
 
-def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
+def rl_train(argv: Sequence[str], **kwargs):
   """
   Run RL training with the provided configuration.
 
@@ -590,13 +590,19 @@ def rl_train(argv: Sequence[str], **kwargs):
     trainer_devices: JAX devices for the trainer.
     sampler_devices: JAX devices for the sampler.
   """
+  trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv, kwargs)
+
+  reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes(
+      trainer_config, sampler_config, trainer_devices, sampler_devices
+  )
+
   if not trainer_config.debug.rl:
     # Apply filter to suppress noisy logs
     noise_filter = max_logging.NoisyLogFilter()
     logging.getLogger().addFilter(noise_filter)
     absl_logging.get_absl_logger().addFilter(noise_filter)
-
+    os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"
   max_logging.log("Starting RL Training")
 
   if not epath.Path(trainer_config.tensorboard_dir).exists():
     epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)
@@ -620,10 +626,6 @@
       break
     pprint(ele)
 
-  reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes(
-      trainer_config, sampler_config, trainer_devices, sampler_devices
-  )
-
   if trainer_config.debug.rl:
     max_logging.log("Reference Model initialized successfully")
     nnx.display(reference_model)
@@ -697,8 +699,7 @@ def main(argv: Sequence[str]) -> None:
   pathwaysutils.initialize()
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
   max_utils.print_system_information()
 
-  trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv)
-  rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
+  rl_train(argv)
 
 
 if __name__ == "__main__":

diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py
index 90595a05fd..e2b9407a63 100644
--- a/src/maxtext/trainers/post_train/sft/train_sft.py
+++ b/src/maxtext/trainers/post_train/sft/train_sft.py
@@ -146,7 +146,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
   tunix_config = get_tunix_config(mt_config)
 
   with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
-    model, mesh = model_creation_utils.create_nnx_model(mt_config)
+    model, mesh = model_creation_utils.from_pretrained(mt_config)
     learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
     # pass in model for muon
     optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)

diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py
index b3057d0518..a23c4dbf27 100644
--- a/src/maxtext/utils/model_creation_utils.py
+++ b/src/maxtext/utils/model_creation_utils.py
@@ -112,9 +112,9 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng
   return model
 
 
-def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None):
+def from_pretrained(config, original_mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None):
   """Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""
-
+  mesh = original_mesh
   def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None):
     if rng_key is None:
       rng_key = jax.random.PRNGKey(config.init_weights_seed)
@@ -229,4 +229,7 @@ def create_sharded_state():
   except Exception as e:
     raise ValueError(f"Checkpoint loading failed: {e}") from e
 
-  return model, mesh
+  if original_mesh:
+    return model
+  else:
+    return model, mesh

From dd535f1254b75c2101fb9191417204d9fcbcc000 Mon Sep 17 00:00:00 2001
From: A9isha
Date: Wed, 18 Mar 2026 15:19:59 +0000
Subject: [PATCH 2/6] fix kwargs

---
 src/maxtext/trainers/post_train/rl/train_rl.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py
index 378ff929a2..ac2c14888a 100644
--- a/src/maxtext/trainers/post_train/rl/train_rl.py
+++ b/src/maxtext/trainers/post_train/rl/train_rl.py
@@ -151,7 +151,7 @@
 def setup_configs_and_devices(argv: list[str], **kwargs):
   """Setup device allocation and configs for training and inference."""
-  config = pyconfig.initialize_pydantic(argv, kwargs)
+  config = pyconfig.initialize_pydantic(argv, **kwargs)
   devices = jax.devices()
   if config.num_trainer_slices == -1 and config.num_samplers_slices == -1:
     max_logging.log("Running RL on a single slice")
@@ -590,7 +590,7 @@ def rl_train(argv: Sequence[str], **kwargs):
     trainer_devices: JAX devices for the trainer.
     sampler_devices: JAX devices for the sampler.
   """
-  trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv, kwargs)
+  trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv, **kwargs)
 
   reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes(
       trainer_config, sampler_config, trainer_devices, sampler_devices
@@ -689,7 +689,7 @@
-def main(argv: Sequence[str]) -> None:
+def main(argv: Sequence[str], **kwargs) -> None:
   """Main function to run RL training.
 
   Args:
     argv: Command-line arguments.
   """
   pathwaysutils.initialize()
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
   max_utils.print_system_information()
 
-  rl_train(argv)
+  rl_train(argv, **kwargs)
 
 
 if __name__ == "__main__":

From 2c9349fb5ff18c72fe6db0a9cfadaf8eeeb6e16a Mon Sep 17 00:00:00 2001
From: A9isha
Date: Wed, 18 Mar 2026 17:42:49 +0000
Subject: [PATCH 3/6] get past kwargs error and _resolve_or_infer_addl_config
 works now

---
 src/maxtext/configs/pyconfig.py                |  9 +++++----
 src/maxtext/trainers/post_train/rl/train_rl.py | 11 ++++++-----
 src/maxtext/utils/model_creation_utils.py      |  2 +-
 3 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py
index b315c4e928..bf3f9d626d 100644
--- a/src/maxtext/configs/pyconfig.py
+++ b/src/maxtext/configs/pyconfig.py
@@ -95,18 +95,19 @@
 def _resolve_or_infer_addl_config(**kwargs):
   """Infers additional config values (output directory, HF token) that were not supplied explicitly."""
   inferred_kwargs = {}
-
   # if base_output_directory key is not seen
-  if "base_output_directory" not in kwargs:
+  if not kwargs.get("base_output_directory"):
     max_logging.warning("base_output_directory is not provided; using a local directory named 'maxtext_output'")
     base_output_directory = os.path.abspath("maxtext_output")
     inferred_kwargs["base_output_directory"] = base_output_directory
 
   # if hf_access_token key is not seen
-  if "hf_access_token" not in kwargs:
+  if not kwargs.get("hf_access_token"):
     hf_access_token = os.environ.get("HF_TOKEN")
     if hf_access_token:
       inferred_kwargs["hf_access_token"] = hf_access_token
 
+  breakpoint()
+
   return inferred_kwargs
@@ -311,7 +312,7 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
   Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
   """
   # 1. Load base and inherited configs from file(s)
-  config_path, cli_args = _resolve_or_infer_config(argv, kwargs)
+  config_path, cli_args = _resolve_or_infer_config(argv, **kwargs)
   base_yml_config = _load_config(config_path)
 
   # 2. Get overrides from CLI and kwargs

diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py
index ac2c14888a..19702cef74 100644
--- a/src/maxtext/trainers/post_train/rl/train_rl.py
+++ b/src/maxtext/trainers/post_train/rl/train_rl.py
@@ -149,7 +149,7 @@
   return loaded_dataset
 
 
-def setup_configs_and_devices(argv: list[str], **kwargs):
+def setup_configs_and_devices(argv: list[str], kwargs):
   """Setup device allocation and configs for training and inference."""
   config = pyconfig.initialize_pydantic(argv, **kwargs)
   devices = jax.devices()
@@ -580,7 +580,7 @@
   return rl_cluster, rl_trainer, optimizer
 
 
-def rl_train(argv: Sequence[str], **kwargs):
+def rl_train(argv: Sequence[str], kwargs: dict):
   """
   Run RL training with the provided configuration.
 
@@ -590,7 +590,7 @@ def rl_train(argv: Sequence[str], kwargs: dict):
     trainer_devices: JAX devices for the trainer.
     sampler_devices: JAX devices for the sampler.
   """
-  trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv, **kwargs)
+  trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv, kwargs)
 
   reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes(
       trainer_config, sampler_config, trainer_devices, sampler_devices
@@ -689,7 +689,7 @@
   max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
 
 
-def main(argv: Sequence[str], **kwargs) -> None:
+def main(argv: Sequence[str], kwargs: dict | None = None) -> None:
   """Main function to run RL training.
 
   Args:
     argv: Command-line arguments.
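+    kwargs: Optional dict of config overrides (defaults to an empty dict).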
""" + kwargs = kwargs or {} pathwaysutils.initialize() os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" max_utils.print_system_information() - rl_train(argv, **kwargs) + rl_train(argv, kwargs) if __name__ == "__main__": diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index a23c4dbf27..3f88dbd39e 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -112,7 +112,7 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng return model -def from_pretrained(config, original_mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): +def from_pretrained(config, original_mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None, convert_checkpoint_if_possible=False): """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" mesh = original_mesh def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): From 1a5144038d6b8a4cabe9ce24c8357841df12135c Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 18 Mar 2026 21:13:45 +0000 Subject: [PATCH 4/6] single host run works --- src/maxtext/configs/post_train/rl.yml | 1 + src/maxtext/configs/pyconfig.py | 1 - src/maxtext/configs/types.py | 5 + .../trainers/post_train/rl/train_rl.py | 130 +------------ src/maxtext/utils/model_creation_utils.py | 179 +++++++++++++++++- 5 files changed, 186 insertions(+), 130 deletions(-) diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index da455a13e2..095935e901 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -81,6 +81,7 @@ checkpoint_storage_use_ocdbt: False # For Pathways checkpoint_storage_use_zarr3: False # For Pathways use_pathways: True log_period: 20 +convert_checkpoint_if_possible: True # ====== Debugging ====== debug: diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index bf3f9d626d..4b4639de78 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -106,7 +106,6 @@ def _resolve_or_infer_addl_config(**kwargs): hf_access_token = os.environ.get("HF_TOKEN") if hf_access_token: inferred_kwargs["hf_access_token"] = hf_access_token - breakpoint() return inferred_kwargs diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index e66ecac8fa..5a125cc0ed 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1793,6 +1793,11 @@ class DerivedValues(BaseModel): None, description="The full path to the checkpoint directory, derived from `run_name`.", ) + convert_checkpoint_if_possible:bool = Field( + False, + description="Whether to convert checkpoint on the fly if not provided via\ + load_parameters_path or base_output_directory" + ) metrics_dir: None | str = Field( None, description="The full path to the metrics directory, derived from `run_name`.", diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 19702cef74..18cdc1e546 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -46,7 +46,6 @@ from __future__ import annotations from typing import Sequence -import collections import grain import jax import json @@ -77,32 +76,9 @@ from maxtext.trainers.post_train.rl.evaluate_rl import evaluate from maxtext.trainers.post_train.rl import utils_rl from 
maxtext.input_pipeline.instruction_data_processing import load_template_from_file -from maxtext.utils import max_logging, max_utils, maxtext_utils +from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils import maxtext as mt -def get_maxtext_model(config, devices=None): - """ - Load MaxText model with Tunix adapter. - # Note: pass the path to your scanned checkpoint for 'load_parameters_path'. - # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if - # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags: - # export USE_PATHWAYS=1 - # python src/MaxText/checkpoint_conversion/to_maxtext.py \ - # --model_name="gemma2-2b" \ - # --base_output_directory="/path/to/your/output/directory" \ - # --scan_layers=True \ - # --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \ - # --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) - # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e., - # load_parameters_path=/path/to/your/output/directory/0/items - """ - model, mesh = mt.from_pretrained(config, devices=devices) - with mesh: - use_no_op_mappings = "maxtext_config" in config.vllm_additional_config - tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings) - tunix_model.config = None - return tunix_model, mesh - def get_dataset( model_tokenizer, tmvp_config, data_dir, split="train", data_files=None, dataset_name=None @@ -149,85 +125,6 @@ def get_dataset( return loaded_dataset -def setup_configs_and_devices(argv: list[str], kwargs): - """Setup device allocation and configs for training and inference.""" - config = pyconfig.initialize_pydantic(argv, **kwargs) - devices = jax.devices() - if config.num_trainer_slices == -1 and config.num_samplers_slices == -1: - max_logging.log("Running RL on a single slice") - num_vms = len(devices) // config.chips_per_vm - trainer_devices = devices - sampler_devices = devices - if num_vms >= 2 and config.use_pathways: - # Multiple hosts with Pathways - potentially split devices for trainer and sampler - # based on trainer_devices_fraction and sampler_devices_fraction - max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") - num_devices = len(devices) - num_trainer_devices = int(num_devices * config.trainer_devices_fraction) - num_sampler_devices = int(num_devices * config.sampler_devices_fraction) - trainer_devices = devices[:num_trainer_devices] - sampler_devices = devices[num_devices - num_sampler_devices :] - if config.trainer_devices_fraction != 1.0: - max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices") - if config.sampler_devices_fraction != 1.0: - max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices") - trainer_config = config - sampler_config = config - elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0: - max_logging.log("Running RL with Multislice") - devices_by_slice = collections.defaultdict(list) - for d in devices: - devices_by_slice[d.slice_index].append(d) - slice_indices = sorted(devices_by_slice.keys()) - - if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices: - raise ValueError("Not enough slices for trainer and samplers") - - trainer_devices = [] - for i in range(config.num_trainer_slices): - trainer_devices.extend(devices_by_slice[slice_indices[i]]) - - sampler_devices = [] - for i in 
range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices): - sampler_devices.extend(devices_by_slice[slice_indices[i]]) - - trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices - trainer_fsdp = trainer_devices_per_slice - tp = config.ici_tensor_parallelism - if tp > 1: - if trainer_devices_per_slice % tp != 0: - raise ValueError( - f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})" - ) - if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice: - raise ValueError( - f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal " - f"devices_per_slice ({trainer_devices_per_slice})" - ) - trainer_fsdp = trainer_devices_per_slice // tp - - trainer_update = { - "num_slices": config.num_trainer_slices, - "ici_fsdp_parallelism": trainer_fsdp, - "ici_tensor_parallelism": tp, - "dcn_data_parallelism": config.num_trainer_slices, - } - - sampler_update = { - "num_slices": config.num_samplers_slices, - "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices, - "ici_tensor_parallelism": -1, - "dcn_data_parallelism": config.num_samplers_slices, - } - - trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update) - sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update) - - else: - raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive") - - return trainer_config, sampler_config, trainer_devices, sampler_devices - def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices): """Get rollout kwargs for vLLM rollout when using data parallelism.""" @@ -400,27 +297,6 @@ def _filter_long_prompts(x): return train_dataset, test_dataset -def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices): - """Create reference and actor models and their respective meshes.""" - max_logging.log("Creating reference model and also meshes for reference and rollout") - reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices) - devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices) - rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes) - - if trainer_config.load_checkpoint_only_once: - max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.") - with reference_mesh: - actor_base_model = nnx.clone(reference_model.base) - use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config - actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings) - actor_model.config = None - actor_mesh = reference_mesh - else: - max_logging.log("Creating policy model with same config as reference model on trainer mesh") - actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices) - - return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh - def create_rl_components( trainer_config, @@ -590,9 +466,9 @@ def rl_train(argv: Sequence[str], kwargs: dict): trainer_devices: JAX devices for the trainer. sampler_devices: JAX devices for the sampler. 
""" - trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv, kwargs) + trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(argv, kwargs) - reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes( + reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes( trainer_config, sampler_config, trainer_devices, sampler_devices ) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 3f88dbd39e..3231aea427 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -15,8 +15,10 @@ # pylint: disable=bare-except, consider-using-generator """ Utils that are only interesting for creating a model in MaxText. """ +import collections from collections.abc import Sequence from functools import partial +import os from typing import overload from etils import epath @@ -28,8 +30,9 @@ from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode from maxtext.layers import quantizations from maxtext.models import models -from maxtext.utils import max_utils +from maxtext.utils import max_utils, max_logging from maxtext.utils import maxtext_utils +from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter from orbax import checkpoint as ocp @@ -112,9 +115,181 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng return model -def from_pretrained(config, original_mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None, convert_checkpoint_if_possible=False): +def setup_configs_and_devices(argv: list[str], kwargs): + """Setup device allocation and configs for training and inference.""" + config = pyconfig.initialize_pydantic(argv, **kwargs) + devices = jax.devices() + if config.num_trainer_slices == -1 and config.num_samplers_slices == -1: + max_logging.log("Running on a single slice") + num_vms = len(devices) // config.chips_per_vm + trainer_devices = devices + sampler_devices = devices + if num_vms >= 2 and config.use_pathways: + # Multiple hosts with Pathways - potentially split devices for trainer and sampler + # based on trainer_devices_fraction and sampler_devices_fraction + max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") + num_devices = len(devices) + num_trainer_devices = int(num_devices * config.trainer_devices_fraction) + num_sampler_devices = int(num_devices * config.sampler_devices_fraction) + trainer_devices = devices[:num_trainer_devices] + sampler_devices = devices[num_devices - num_sampler_devices :] + if config.trainer_devices_fraction != 1.0: + max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices") + if config.sampler_devices_fraction != 1.0: + max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices") + trainer_config = config + sampler_config = config + elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0: + max_logging.log("Running with Multislice") + devices_by_slice = collections.defaultdict(list) + for d in devices: + devices_by_slice[d.slice_index].append(d) + slice_indices = sorted(devices_by_slice.keys()) + + if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices: + raise ValueError("Not enough slices for trainer and samplers") + + trainer_devices = [] + for i in 
range(config.num_trainer_slices): + trainer_devices.extend(devices_by_slice[slice_indices[i]]) + + sampler_devices = [] + for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices): + sampler_devices.extend(devices_by_slice[slice_indices[i]]) + + trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices + trainer_fsdp = trainer_devices_per_slice + tp = config.ici_tensor_parallelism + if tp > 1: + if trainer_devices_per_slice % tp != 0: + raise ValueError( + f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})" + ) + if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice: + raise ValueError( + f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal " + f"devices_per_slice ({trainer_devices_per_slice})" + ) + trainer_fsdp = trainer_devices_per_slice // tp + + trainer_update = { + "num_slices": config.num_trainer_slices, + "ici_fsdp_parallelism": trainer_fsdp, + "ici_tensor_parallelism": tp, + "dcn_data_parallelism": config.num_trainer_slices, + } + + sampler_update = { + "num_slices": config.num_samplers_slices, + "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices, + "ici_tensor_parallelism": -1, + "dcn_data_parallelism": config.num_samplers_slices, + } + + trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update) + sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update) + + else: + raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive") + + return trainer_config, sampler_config, trainer_devices, sampler_devices + + + +def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices): + """Create reference and actor models and their respective meshes.""" + max_logging.log("Creating reference model and also meshes for reference and rollout") + reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices) + devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices) + rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes) + + if trainer_config.load_checkpoint_only_once: + max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.") + with reference_mesh: + actor_base_model = nnx.clone(reference_model.base) + use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config + actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings) + actor_model.config = None + actor_mesh = reference_mesh + else: + max_logging.log("Creating policy model with same config as reference model on trainer mesh") + actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices) + + return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh + +def get_maxtext_model(config, devices=None): + """ + Load MaxText model with Tunix adapter. + # Note: pass the path to your scanned checkpoint for 'load_parameters_path'. 
+ # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if + # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags: + # export USE_PATHWAYS=1 + # python src/MaxText/checkpoint_conversion/to_maxtext.py \ + # --model_name="gemma2-2b" \ + # --base_output_directory="/path/to/your/output/directory" \ + # --scan_layers=True \ + # --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \ + # --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) + # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e., + # load_parameters_path=/path/to/your/output/directory/0/items + """ + model, mesh = from_pretrained(config, devices=devices) + with mesh: + use_no_op_mappings = "maxtext_config" in config.vllm_additional_config + tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings) + tunix_model.config = None + return tunix_model, mesh + + +def from_pretrained(config, original_mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" mesh = original_mesh + if config.convert_checkpoint_if_possible: + if not not (epath.Path(config.base_output_directory) / "0" / "items").exists(): + # Try to convert checkpoint on the fly + if not config.hf_access_token: + raise ValueError("hf_access_token must be provided when not providing a pre-existing checkpoint") + + max_logging.warning("Checkpoint path is not provided, converting checkpoint to orbax format for MaxText") + + simulated_cpu_devices_count = 16 + + import subprocess + import sys + + # Run the conversion in a completely isolated subprocess so its CPU + # JAX/XLA requirements do not interfere with the parent's Pathways TPU mesh. 
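+      # The copied environment passed to subprocess.run below confines the
+      # JAX_PLATFORMS="cpu" override to the child process, so the parent's
+      # backend selection and device mesh are left untouched.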
+      conversion_env = os.environ.copy()
+      conversion_env["JAX_PLATFORMS"] = "cpu"
+      # conversion_env["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={simulated_cpu_devices_count}"
+
+      to_maxtext_cmd = [
+          sys.executable,
+          "-m", "maxtext.checkpoint_conversion.to_maxtext",
+      ] + [
+          f"model_name={config.model_name}",
+          f"base_output_directory={config.base_output_directory}",
+          f"scan_layers={config.scan_layers}",
+          f"hf_access_token={config.hf_access_token}",
+          "use_multimodal=false",
+          "skip_jax_distributed_system=True",
+          "--lazy_load_tensors=True",
+          f"--simulated_cpu_devices_count={simulated_cpu_devices_count}",
+      ]
+
+      try:
+        subprocess.run(to_maxtext_cmd, env=conversion_env, check=True)
+      except subprocess.CalledProcessError as e:
+        raise RuntimeError(f"Checkpoint conversion failed with exit code {e.returncode}") from e
+      load_parameters_path = epath.Path(config.base_output_directory) / "0" / "items"
+      # Create a copied Pydantic model with the updated values
+      pydantic_config = config._pydantic_config if hasattr(config, "_pydantic_config") else config
+      new_config = pydantic_config.model_copy(update={
+          "load_parameters_path": load_parameters_path,
+      })
+      config = pyconfig.HyperParameters(new_config)
+
   def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None):

From a0a80925650387367e2bcfefb0e2076df6d42347 Mon Sep 17 00:00:00 2001
From: A9isha
Date: Wed, 18 Mar 2026 22:36:59 +0000
Subject: [PATCH 5/6] simplify invocations

---
 src/maxtext/configs/pyconfig.py           | 29 ++++++++++++++++-------
 src/maxtext/utils/model_creation_utils.py | 23 +++++++++++-------
 2 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py
index 4b4639de78..d62261523d 100644
--- a/src/maxtext/configs/pyconfig.py
+++ b/src/maxtext/configs/pyconfig.py
@@ -77,18 +77,22 @@ def _module_from_path(path: str) -> str | None:
   return None
 
 
-def _resolve_or_infer_config(argv: list[str], **kwargs) -> tuple[str, list[str]]:
+def _resolve_or_infer_config(argv: list[str] | None = None, **kwargs) -> tuple[str, list[str]]:
   """Resolves or infers config file path from module."""
+  if argv is None:
+    argv = [""]
   if len(argv) >= 2 and argv[1].endswith(".yml"):
     return resolve_config_path(argv[1]), argv[2:]
-  module = _module_from_path(argv[0])
+  module = _module_from_path(argv[0]) if len(argv) > 0 else None
   if module not in _CONFIG_FILE_MAPPING:
-    raise ValueError(
-        f"No config file provided and no default config found for module '{module}'"
+    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
+    logger.warning(
+        "No config file provided and no default config found for module '%s', using base.yml", module
     )
-  config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
-  logger.warning("No config file provided, using default config mapping: %s", config_path)
-  remaining_argv = argv[1:]
+  else:
+    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
+    logger.warning("No config file provided, using default config mapping: %s", config_path)
+  remaining_argv = argv[1:] if len(argv) > 1 else []
 
   return config_path, remaining_argv
@@ -299,14 +303,14 @@ def get_keys(self) -> dict[str, Any]:
   return self._flat_config
 
 
-def initialize(argv: list[str], **kwargs) -> HyperParameters:
+def initialize(argv: list[str] | None = None, **kwargs) -> HyperParameters:
   """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
   pydantic_config = initialize_pydantic(argv, **kwargs)
   config = HyperParameters(pydantic_config)
   return config
 
 
-def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
+def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfig:
   """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.
   Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
   """
@@ -446,3 +450,10 @@
 # Shim for backward compatibility with pyconfig_deprecated_test.py
 validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys
 __all__ = ["initialize", "initialize_pydantic"]
+
+
+class _CallablePyconfigModule(sys.modules[__name__].__class__):
+  """Allows calling the module directly as mt.pyconfig()."""
+
+  def __call__(self, argv: list[str] | None = None, **kwargs) -> HyperParameters:
+    return initialize(argv, **kwargs)
+
+
+sys.modules[__name__].__class__ = _CallablePyconfigModule

diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py
index 3231aea427..2942cc50bc 100644
--- a/src/maxtext/utils/model_creation_utils.py
+++ b/src/maxtext/utils/model_creation_utils.py
@@ -115,9 +115,14 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng
   return model
 
 
-def setup_configs_and_devices(argv: list[str], kwargs):
+def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None = None, **extra_kwargs):
   """Setup device allocation and configs for training and inference."""
-  config = pyconfig.initialize_pydantic(argv, **kwargs)
+  if argv is None:
+    argv = [""]
+
+  combined_kwargs = dict(kwargs) if kwargs else {}
+  combined_kwargs.update(extra_kwargs)
+  config = pyconfig.initialize_pydantic(argv, **combined_kwargs)
   devices = jax.devices()
   if config.num_trainer_slices == -1 and config.num_samplers_slices == -1:
     max_logging.log("Running on a single slice")
@@ -172,22 +177,24 @@ def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None = None, **extra_kwargs):
       trainer_fsdp = trainer_devices_per_slice // tp
 
-    trainer_update = {
+    trainer_kwargs = dict(combined_kwargs)
+    trainer_kwargs.update({
         "num_slices": config.num_trainer_slices,
         "ici_fsdp_parallelism": trainer_fsdp,
         "ici_tensor_parallelism": tp,
        "dcn_data_parallelism": config.num_trainer_slices,
-    }
+    })
 
-    sampler_update = {
+    sampler_kwargs = dict(combined_kwargs)
+    sampler_kwargs.update({
         "num_slices": config.num_samplers_slices,
         "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices,
         "ici_tensor_parallelism": -1,
         "dcn_data_parallelism": config.num_samplers_slices,
-    }
+    })
 
-    trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update)
-    sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update)
+    trainer_config = pyconfig.initialize_pydantic(argv, **trainer_kwargs)
+    sampler_config = pyconfig.initialize_pydantic(argv, **sampler_kwargs)
 
   else:
     raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive")

From 1e13fdc23a7b825e2d13d9b70f0bca7373a05374 Mon Sep 17 00:00:00 2001
From: A9isha
Date: Thu, 19 Mar 2026 01:32:46 +0000
Subject: [PATCH 6/6] add init.py

---
 src/maxtext/__init__.py | 45 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 src/maxtext/__init__.py

diff --git a/src/maxtext/__init__.py b/src/maxtext/__init__.py
new file mode 100644
index 0000000000..64c787352e
--- /dev/null
+++ b/src/maxtext/__init__.py
@@ -0,0 +1,45 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+MaxText is a high performance, highly scalable, open-source LLM written in pure Python/Jax and targeting Google Cloud
+TPUs and GPUs for training and inference. MaxText achieves high MFUs and scales from single host to very large clusters
+while staying simple and "optimization-free" thanks to the power of Jax and the XLA compiler.
+"""
+
+__author__ = "Google LLC"
+__version__ = "0.2.0"
+__description__ = (
+    "MaxText is a high performance, highly scalable, open-source LLM written in pure Python/Jax and "
+    "targeting Google Cloud TPUs and GPUs for training and inference."
+)
+
+from collections.abc import Sequence
+
+import os
+
+# In order to have any effect on the C++ logging this has to be set before we import anything from jax.
+# When jax is imported, its `__init__.py` calls `cloud_tpu_init()`, which also initializes the C++ logger.
+os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0")
+del os
+
+from jax.sharding import Mesh
+
+from maxtext.configs import pyconfig
+from maxtext.models import models
+from maxtext.trainers.post_train.dpo import dpo_utils
+from maxtext.utils import maxtext_utils
+from maxtext.utils.model_creation_utils import *
+
+Transformer = models.Transformer
+transformer_as_linen = models.transformer_as_linen
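
Editor's note on usage: taken together, the six patches are aiming at a zero-boilerplate
invocation along the lines of the sketch below. This is an illustrative reconstruction,
not part of the patches; the model name and token value are placeholders, and it assumes
that kwargs passed to the callable pyconfig module are merged into the config the same
way CLI overrides are.

    import os
    import maxtext as mt
    from maxtext.configs import pyconfig

    # If hf_access_token is not passed explicitly, _resolve_or_infer_addl_config picks
    # it up from $HF_TOKEN; base_output_directory falls back to ./maxtext_output
    # (patches 1-3).
    os.environ["HF_TOKEN"] = "<placeholder-token>"

    # Patch 5 makes the pyconfig module callable; with argv omitted it falls back to
    # base.yml and returns a HyperParameters object.
    config = pyconfig(model_name="gemma2-2b", convert_checkpoint_if_possible=True)

    # mt.from_pretrained is re-exported by the new __init__.py (patch 6). With
    # convert_checkpoint_if_possible=True and no checkpoint under
    # base_output_directory/0/items, it converts the HF checkpoint in a CPU-only
    # subprocess (patch 4) before creating the sharded NNX model.
    model, mesh = mt.from_pretrained(config)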