Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions build_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Custom build hooks for PyPI."""

import os
import sys
from hatchling.builders.hooks.plugin.interface import BuildHookInterface

TPU_REQUIREMENTS_PATH = "src/dependencies/requirements/generated_requirements/tpu-requirements.txt"
Expand All @@ -33,9 +34,27 @@ def get_tpu_dependencies():


class CustomBuildHook(BuildHookInterface):
"""A custom hook to inject TPU dependencies into the core wheel dependencies."""
"""A custom hook to handle platform-specific package configuration for MaxText."""

def initialize(self, version, build_data): # pylint: disable=unused-argument
tpu_deps = get_tpu_dependencies()
build_data["dependencies"] = tpu_deps
print(f"Successfully injected {len(tpu_deps)} TPU dependencies into the wheel's core requirements.")
"""Adjusts the build_data dictionary to customize the wheel's package structure."""
# The following TPU dependency injection is disabled because TPU-specific requirements
# are now managed via optional dependencies (extras) in pyproject.toml
# (e.g., pip install maxtext[tpu]).
# tpu_deps = get_tpu_dependencies()
# build_data["dependencies"] = tpu_deps
# print(f"Successfully injected {len(tpu_deps)} TPU dependencies into the wheel's core requirements.")

# macOS specific logic to avoid case-sensitivity issues with MaxText and maxtext directories
build_data["force_include"] = build_data.get("force_include", {})
if sys.platform == "darwin":
print("macOS detected. Skipping legacy MaxText shims to avoid case-sensitivity conflicts.")
# Always include the __init__.py in the lowercase 'maxtext' package on macOS.
# This ensures that 'import maxtext' (and thus 'import MaxText' on macOS)
# has the proper version and metadata.
build_data["force_include"]["src/MaxText/__init__.py"] = "maxtext/__init__.py"
else:
# On other platforms, include 'src/MaxText' as its own top-level package for legacy support.
# We do NOT add __init__.py to 'maxtext' here to maintain exact parity with previous builds.
print("Included src/MaxText as a top-level package for non-macOS platforms.")
build_data["force_include"]["src/MaxText"] = "MaxText"
35 changes: 28 additions & 7 deletions docs/install_maxtext.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ MaxText offers following installation modes:
3. maxtext[tpu-post-train]. Used for post-training on TPUs. Currently, this option should also be used for running vllm_decode on TPUs.
4. maxtext[runner]. Used for building MaxText's Docker images and scheduling workloads through XPK.

## From PyPI (Recommended)
## From PyPI (Recommended on Linux)

This is the easiest way to get started with the latest stable version.

Expand All @@ -45,7 +45,7 @@ install_maxtext_tpu_github_deps

# Option 2: Installing maxtext[cuda12]
uv pip install maxtext[cuda12] --resolution=lowest
install_maxtext_cuda12_github_dep
install_maxtext_cuda12_github_deps

# Option 3: Installing maxtext[tpu-post-train]
uv pip install maxtext[tpu-post-train] --resolution=lowest
Expand All @@ -55,12 +55,33 @@ install_maxtext_tpu_post_train_extra_deps
uv pip install maxtext[runner] --resolution=lowest
```

> **Note:** The `install_maxtext_tpu_github_deps`, `install_maxtext_cuda12_github_dep`, and
> **Note:** The `install_maxtext_tpu_github_deps`, `install_maxtext_cuda12_github_deps`, and
> `install_maxtext_tpu_post_train_extra_deps` commands are temporarily required to install dependencies directly from GitHub
> that are not yet available on PyPI. As shown above, choose the one that corresponds to your use case.

## Modern UV Project (Recommended for New Projects)

If you are starting a new project and want to use `uv`'s project management features (with a `pyproject.toml` and `uv.lock` in your own project), you can use `uv add`. MaxText's helper scripts will detect your `uv.lock` and correctly add their extra dependencies to your `pyproject.toml`.

```bash
# 1. Initialize your project
mkdir my-maxtext-project && cd my-maxtext-project
uv init

# 2. Add MaxText as a dependency
uv add maxtext[tpu] --resolution=lowest

# 3. Install MaxText's extra GitHub dependencies
# These will be automatically added to your pyproject.toml
install_maxtext_tpu_github_deps
```

> **Note:** The maxtext package contains a comprehensive list of all direct and transitive dependencies, with lower bounds, generated by [seed-env](https://github.com/google-ml-infra/actions/tree/main/python_seed_env). We highly recommend the `--resolution=lowest` flag. It instructs `uv` to install the specific, tested versions of dependencies defined by MaxText, rather than the latest available ones. This ensures a consistent and reproducible environment, which is critical for stable performance and for running benchmarks.

## macOS Installation

Because the default macOS filesystem is case-insensitive, special care is needed to avoid conflicts between the `maxtext` and legacy `MaxText` package names. On macOS, we recommend installing MaxText from source (see "From Source" below) using the `.[runner]` configuration.

## From Source

If you plan to contribute to MaxText or need the latest unreleased features, install from source.
Expand All @@ -84,7 +105,7 @@ install_maxtext_tpu_github_deps

# Option 2: Installing .[cuda12]
uv pip install -e .[cuda12] --resolution=lowest
install_maxtext_cuda12_github_dep
install_maxtext_cuda12_github_deps

# Option 3: Installing .[tpu-post-train]
uv pip install -e .[tpu-post-train] --resolution=lowest
Expand All @@ -110,7 +131,7 @@ To update dependencies, you will follow these general steps:

1. **Modify Base Requirements**: Update the desired dependencies in `base_requirements/requirements.txt` or the hardware-specific files (`base_requirements/tpu-base-requirements.txt`, `base_requirements/gpu-base-requirements.txt`).
2. **Generate New Files**: Run the `seed-env` CLI tool to generate new, fully-pinned requirements files based on your changes.
3. **Update Project Files**: Copy the newly generated files into the `generated_requirements/` directory.
3. **Update Project Files**: Copy the newly generated files into the `src/dependencies/requirements/generated_requirements/` directory.
4. **Handle GitHub Dependencies**: Move any dependencies that are installed directly from GitHub from the generated files to `src/dependencies/github_deps/pre_train_deps.txt`.
5. **Verify**: Test the new dependencies to ensure the project installs and runs correctly.

Expand Down Expand Up @@ -166,8 +187,8 @@ After generating the new requirements, you need to update the files in the MaxTe

1. **Copy the generated files:**

- Move `generated_tpu_artifacts/tpu-requirements.txt` to `generated_requirements/tpu-requirements.txt`.
- Move `generated_gpu_artifacts/cuda12-requirements.txt` to `generated_requirements/cuda12-requirements.txt`.
- Move `generated_tpu_artifacts/tpu-requirements.txt` to `src/dependencies/requirements/generated_requirements/tpu-requirements.txt`.
- Move `generated_gpu_artifacts/cuda12-requirements.txt` to `src/dependencies/requirements/generated_requirements/cuda12-requirements.txt`.

2. **Update `pre_train_deps.txt` (if necessary):**
Currently, MaxText uses a few dependencies, such as `mlperf-logging` and `google-jetstream`, that are installed directly from GitHub source. These are defined in `base_requirements/requirements.txt`, and the `seed-env` tool will carry them over to the generated requirements files.
Expand Down
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ Repository = "https://github.com/AI-Hypercomputer/maxtext.git"
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/MaxText", "src/maxtext", "src/dependencies"]
packages = ["src/maxtext", "src/dependencies"]

# TODO: Add this hook back when it handles device-type parsing
# [tool.hatch.build.targets.wheel.hooks.custom]
# path = "build_hooks.py"
[tool.hatch.build.targets.wheel.hooks.custom]
path = "build_hooks.py"

[project.scripts]
install_maxtext_tpu_github_deps = "dependencies.github_deps.install_pre_train_deps:main"
Expand Down
64 changes: 10 additions & 54 deletions src/dependencies/github_deps/install_post_train_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,19 @@
"""

import os
import subprocess
import sys

try:
from . import uv_utils
except ImportError:
import uv_utils


def main():
"""
Installs extra dependencies specified in post_train_deps.txt using uv.

This script looks for 'post_train_deps.txt' relative to its own location.
It executes 'uv pip install -r <path_to_extra_deps.txt> --resolution=lowest'.
It executes 'uv add' (if uv.lock is present) or 'uv pip install'.
"""
os.environ["VLLM_TARGET_DEVICE"] = "tpu"

Expand All @@ -40,57 +43,10 @@ def main():
if not os.path.exists(extra_deps_path):
raise FileNotFoundError(f"Dependencies file not found at {extra_deps_path}")

# Check if 'uv' is available in the environment
try:
subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True)
subprocess.run([sys.executable, "-m", "uv", "--version"], check=True, capture_output=True)
except subprocess.CalledProcessError as e:
print(f"Error checking uv version: {e}")
print(f"Stderr: {e.stderr.decode()}")
sys.exit(1)

command = [
sys.executable, # Use the current Python executable's pip to ensure the correct environment
"-m",
"uv",
"pip",
"install",
"-r",
str(extra_deps_path),
"--no-deps",
]

local_vllm_install_command = [
sys.executable, # Use the current Python executable's pip to ensure the correct environment
"-m",
"uv",
"pip",
"install",
f"{repo_root}/maxtext/integration/vllm", # MaxText on vllm installations
"--no-deps",
]

try:
# Run the command to install Github dependencies
print(f"Installing extra dependencies: {' '.join(command)}")
_ = subprocess.run(command, check=True, capture_output=True, text=True)
print("Extra dependencies installed successfully!")

# Run the command to install the MaxText vLLM directory
print(f"Installing MaxText vLLM dependency: {' '.join(local_vllm_install_command)}")
_ = subprocess.run(local_vllm_install_command, check=True, capture_output=True, text=True)
print("MaxText vLLM dependency installed successfully!")
except subprocess.CalledProcessError as e:
print("Failed to install extra dependencies.")
print(f"Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}.")
print("--- Stderr ---")
print(e.stderr)
print("--- Stdout ---")
print(e.stdout)
sys.exit(e.returncode)
except (OSError, FileNotFoundError) as e:
print(f"An OS-level error occurred while trying to run uv: {e}")
sys.exit(1)
# Install both requirements file and the local vLLM integration
uv_utils.run_install(
requirements_files=[extra_deps_path], paths=[f"{repo_root}/maxtext/integration/vllm"], is_editable=True
)


if __name__ == "__main__":
Expand Down
46 changes: 7 additions & 39 deletions src/dependencies/github_deps/install_pre_train_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,58 +21,26 @@
"""

import os
import subprocess
import sys

try:
from . import uv_utils
except ImportError:
import uv_utils


def main():
"""
Installs extra dependencies specified in pre_train_deps.txt using uv.

This script looks for 'pre_train_deps.txt' relative to its own location.
It executes 'uv pip install -r <path_to_extra_deps.txt> --resolution=lowest'.
It executes 'uv add' (if uv.lock is present) or 'uv pip install'.
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
extra_deps_path = os.path.join(current_dir, "pre_train_deps.txt")
if not os.path.exists(extra_deps_path):
raise FileNotFoundError(f"Dependencies file not found at {extra_deps_path}")

# Check if 'uv' is available in the environment
try:
subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True)
subprocess.run([sys.executable, "-m", "uv", "--version"], check=True, capture_output=True)
except subprocess.CalledProcessError as e:
print(f"Error checking uv version: {e}")
print(f"Stderr: {e.stderr.decode()}")
sys.exit(1)

command = [
sys.executable, # Use the current Python executable's pip to ensure the correct environment
"-m",
"uv",
"pip",
"install",
"-r",
str(extra_deps_path),
"--no-deps",
]

try:
# Run the command
print(f"Installing extra dependencies: {' '.join(command)}")
_ = subprocess.run(command, check=True, capture_output=True, text=True)
print("Extra dependencies installed successfully!")
except subprocess.CalledProcessError as e:
print("Failed to install extra dependencies.")
print(f"Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}.")
print("--- Stderr ---")
print(e.stderr)
print("--- Stdout ---")
print(e.stdout)
sys.exit(e.returncode)
except (OSError, FileNotFoundError) as e:
print(f"An OS-level error occurred while trying to run uv: {e}")
sys.exit(1)
uv_utils.run_install(requirements_files=[extra_deps_path])


if __name__ == "__main__":
Expand Down
111 changes: 111 additions & 0 deletions src/dependencies/github_deps/uv_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper utilities for working with uv in installation scripts."""

import os
import shutil
import subprocess
import sys


def get_uv_command():
  """Locate a usable uv invocation, bootstrapping uv with pip as a last resort.

  Resolution order:
    1. A `uv` executable found on PATH.
    2. The `uv` module, run through the current Python interpreter.
    3. `pip install uv` with the current interpreter, after which PATH is
       checked once more before falling back to the module form.

  Returns:
    A list of strings forming the argv prefix to which uv sub-commands are
    appended, e.g. `["/usr/bin/uv"]` or `[sys.executable, "-m", "uv"]`.

  The process exits with status 1 if the pip installation fails.
  """
  # 1. Prefer a uv binary that is already on PATH.
  on_path = shutil.which("uv")
  if on_path:
    return [on_path]

  # 2. Otherwise see whether uv is importable by this interpreter.
  module_form = [sys.executable, "-m", "uv"]
  try:
    subprocess.run(module_form + ["--version"], check=True, capture_output=True)
  except (subprocess.CalledProcessError, FileNotFoundError):
    pass
  else:
    return module_form

  # 3. Last resort: install uv into the current interpreter's environment.
  try:
    print("uv not found in PATH or as a module. Attempting to install it via pip...")
    subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True)
  except subprocess.CalledProcessError as err:
    print(f"Error installing uv via pip: {err}")
    # capture_output without text=True yields bytes, hence the decode.
    print(f"Stderr: {err.stderr.decode()}")
    sys.exit(1)

  # The fresh install may or may not have placed a launcher on PATH.
  installed_binary = shutil.which("uv")
  return [installed_binary] if installed_binary else module_form


def run_install(requirements_files=None, paths=None, editable_paths=None, *, is_editable=False):
  """
  Executes the appropriate uv install command (uv add or uv pip install).

  When a `uv.lock` file exists in the current working directory the project is
  treated as a uv-managed project and `uv add --frozen` is used; otherwise
  `uv pip install --no-deps` is used.

  Args:
    requirements_files: List of paths to requirements.txt files. Each one is
      passed to uv with its own `-r` flag.
    paths: List of paths to local packages or directories. Installed
      non-editable unless `is_editable` is true.
    editable_paths: List of paths to local packages or directories (editable).
    is_editable: Keyword-only. When true, every entry in `paths` is installed
      in editable mode, exactly as if it had been passed via `editable_paths`.
      This keeps calls of the form
      `run_install(requirements_files=[...], paths=[...], is_editable=True)`
      working.
  """
  uv_command = get_uv_command()
  is_uv_project = os.path.exists("uv.lock")

  # Normalise to plain strings so that os.PathLike entries can be joined and
  # logged just like the requirements files are.
  standard_paths = [str(p) for p in (paths or [])]
  editable = [str(p) for p in (editable_paths or [])]
  if is_editable:
    editable = standard_paths + editable
    standard_paths = []

  # We run installations in two steps if we have both standard and editable items,
  # because 'uv add --editable' cannot be mixed with non-local requirements.

  # Step 1: Standard installations
  if requirements_files or standard_paths:
    if is_uv_project:
      cmd = uv_command + ["add", "--frozen"]
    else:
      cmd = uv_command + ["pip", "install", "--no-deps"]

    for req in requirements_files or []:
      cmd.extend(["-r", str(req)])
    cmd.extend(standard_paths)

    _execute_command(cmd)

  # Step 2: Editable installations
  if editable:
    if is_uv_project:
      # For 'uv add', --editable is a switch that applies to every package given.
      cmd = uv_command + ["add", "--frozen", "--editable"] + editable
    else:
      # For 'uv pip install', -e consumes exactly one value, so it has to be
      # repeated; a single leading -e would leave later paths non-editable.
      cmd = uv_command + ["pip", "install", "--no-deps"]
      for path in editable:
        cmd.extend(["-e", path])

    _execute_command(cmd)


def _execute_command(cmd):
  """Run `cmd` as a subprocess, logging it first and exiting on any failure.

  Args:
    cmd: The argv list (of strings) to execute.

  The child's output is captured as text and only printed when the command
  fails. A non-zero exit status terminates this process with the same status;
  an OS-level failure to launch the command terminates it with status 1.
  """
  try:
    print(f"Executing: {' '.join(cmd)}")
    subprocess.run(cmd, check=True, capture_output=True, text=True)
  except subprocess.CalledProcessError as err:
    failure_report = (
        f"Command failed with exit status {err.returncode}.",
        "--- Stderr ---",
        err.stderr,
        "--- Stdout ---",
        err.stdout,
    )
    for entry in failure_report:
      print(entry)
    sys.exit(err.returncode)
  except OSError as err:
    # FileNotFoundError is a subclass of OSError, so a missing executable lands here too.
    print(f"An OS-level error occurred: {err}")
    sys.exit(1)
  else:
    print("Success!")
Loading
Loading