Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/maxtext/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
MaxText is a high performance, highly scalable, open-source LLM written in pure Python/Jax and targeting Google Cloud
TPUs and GPUs for training and inference. MaxText achieves high MFUs and scales from single host to very large clusters
while staying simple and "optimization-free" thanks to the power of Jax and the XLA compiler.
"""

__author__ = "Google LLC"
__version__ = "0.2.0"
# Fixed: a stray markdown "**" had leaked into this user-visible description string.
__description__ = (
    "MaxText is a high performance, highly scalable, open-source LLM written in pure Python/Jax and "
    "targeting Google Cloud TPUs and GPUs for training and inference."
)

# Re-exported for the convenience of `import maxtext` users; not used in this module itself.
from collections.abc import Sequence

import os

# In order to have any effect on the C++ logging this has to be set before we import anything from jax.
# When jax is imported, its `__init__.py` calls `cloud_tpu_init()`, which also initializes the C++ logger.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0")
del os  # keep the package namespace free of incidental names

# Re-exported for convenience; not used in this module itself.
from jax.sharding import Mesh

from maxtext.configs import pyconfig
from maxtext.models import models
from maxtext.trainers.post_train.dpo import dpo_utils
from maxtext.utils import maxtext_utils

# The star-import deliberately populates the package namespace with the public
# model-creation helpers (e.g. `from_pretrained`) — TODO confirm the module
# defines `__all__` so this stays a curated surface.
from maxtext.utils.model_creation_utils import *

# Top-level aliases for the most commonly used model entry points.
Transformer = models.Transformer
transformer_as_linen = models.transformer_as_linen
1 change: 1 addition & 0 deletions src/maxtext/configs/post_train/rl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ checkpoint_storage_use_ocdbt: False # For Pathways
checkpoint_storage_use_zarr3: False # For Pathways
use_pathways: True
log_period: 20
convert_checkpoint_if_possible: True

# ====== Debugging ======
debug:
Expand Down
61 changes: 50 additions & 11 deletions src/maxtext/configs/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,42 @@ def _module_from_path(path: str) -> str | None:
return None


def _resolve_or_infer_config(argv: list[str] | None = None, **kwargs) -> tuple[str, list[str]]:
  """Resolves or infers the config file path from argv.

  Args:
    argv: Command-line arguments. ``argv[0]`` is the invoking script/module path
      and ``argv[1]`` may be an explicit ``.yml`` config path. ``None`` is
      treated like an empty invocation.
    **kwargs: Accepted for call-site compatibility with `initialize_pydantic`;
      unused here.

  Returns:
    A ``(config_path, remaining_argv)`` tuple where ``remaining_argv`` holds
    the CLI override arguments that follow the consumed one(s).
  """
  if argv is None:
    argv = [""]
  # An explicit config file wins: `script path/to/config.yml overrides...`.
  if len(argv) >= 2 and argv[1].endswith(".yml"):
    return resolve_config_path(argv[1]), argv[2:]
  # Otherwise infer a default config from the invoking module's name.
  module = _module_from_path(argv[0]) if argv else None
  if module in _CONFIG_FILE_MAPPING:
    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
    logger.warning("No config file provided, using default config mapping: %s", config_path)
  else:
    # Fall back to base.yml rather than failing hard when no mapping exists.
    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
    logger.warning("No config file provided and no default config found for module '%s', using base.yml", module)
  # `argv[1:]` is already [] for a one-element argv, so no length check is needed
  # (the previous `argv[1:] if len(argv) > 1 else []` was a dead conditional).
  return config_path, argv[1:]

def _resolve_or_infer_addl_config(**kwargs):
"""Resolves or infers more configs from module."""
inferred_kwargs = {}
# if base_output_directory key is not seen
if not kwargs.get("base_output_directory"):
max_logging.warning("base_output_directory is not provided; Using local directory called maxtext_output")
base_output_directory = os.path.abspath("maxtext_output")
inferred_kwargs["base_output_directory"] = base_output_directory

# if hf_access_token key is not seen
if not kwargs.get("hf_access_token"):
hf_access_token = os.environ.get("HF_TOKEN")
if hf_access_token:
inferred_kwargs["hf_access_token"] = hf_access_token


return inferred_kwargs


def yaml_key_to_env_key(s: str) -> str:
Expand Down Expand Up @@ -279,28 +303,36 @@ def get_keys(self) -> dict[str, Any]:
return self._flat_config


def initialize(argv: list[str] | None = None, **kwargs) -> HyperParameters:
  """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
  # Build the pydantic config, then wrap it in the legacy HyperParameters facade.
  return HyperParameters(initialize_pydantic(argv, **kwargs))


def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfig:
"""Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.
Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
"""
# 1. Load base and inherited configs from file(s)
config_path, cli_args = _resolve_or_infer_config(argv)
config_path, cli_args = _resolve_or_infer_config(argv, **kwargs)
base_yml_config = _load_config(config_path)

# 2. Get overrides from CLI and kwargs
cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)

# 3. Handle model-specific config
temp_cfg1 = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
# 3.1. infer more configs if possible
temp_cfg1 = _resolve_or_infer_addl_config(**temp_cfg1)
# update overrides_cfg with temp_cfg1
overrides_cfg = omegaconf.OmegaConf.merge(overrides_cfg, temp_cfg1)
temp_cfg = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)


# 3.2. Handle model-specific config

model_name = temp_cfg.get("model_name", "default")
# The architecture for -Instruct v/s base models are the same, so for identifying the
# architecture we replace "-Instruct" from the model_name and get the base model name
Expand Down Expand Up @@ -418,3 +450,10 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
# Shim for backward compatibility with pyconfig_deprecated_test.py
validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys
# Public API of this module; the backward-compat shim above is intentionally not listed.
__all__ = ["initialize", "initialize_pydantic"]


class _CallablePyconfigModule(sys.modules[__name__].__class__):
  """Allows calling the module directly as mt.pyconfig()."""

  def __call__(self, argv: list[str] | None = None, **kwargs) -> HyperParameters:
    # Delegates to `initialize`, so `pyconfig(...)` behaves exactly like `pyconfig.initialize(...)`.
    return initialize(argv, **kwargs)


# Swap this module object's type to the callable subclass above; must run at the
# bottom of the module so `initialize` is already defined when the module is called.
sys.modules[__name__].__class__ = _CallablePyconfigModule
5 changes: 5 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1793,6 +1793,11 @@ class DerivedValues(BaseModel):
None,
description="The full path to the checkpoint directory, derived from `run_name`.",
)
convert_checkpoint_if_possible:bool = Field(
False,
description="Whether to convert checkpoint on the fly if not provided via\
load_parameters_path or base_output_directory"
)
metrics_dir: None | str = Field(
None,
description="The full path to the metrics directory, derived from `run_name`.",
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/inference/vllm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def main(argv: Sequence[str]) -> None:
config = pyconfig.initialize(argv)

if FLAGS.use_tunix:
maxtext_model, mesh = model_creation_utils.create_nnx_model(config)
maxtext_model, mesh = model_creation_utils.from_pretrained(config)
decode_with_tunix(config, model=maxtext_model, mesh=mesh)
else:
decode_with_vllm(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
return

with self.mesh, nn.logical_axis_rules(""):
model, _ = model_creation_utils.create_nnx_model(
model, _ = model_creation_utils.from_pretrained(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
self.model = nnx.data(model)
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
The loaded MaxText model.
"""
max_logging.log(f"Initializing model: {config.model_name}...")
model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
model, _ = model_creation_utils.from_pretrained(config, mesh=mesh)
return model


Expand Down
Loading
Loading