diff --git a/.github/workflows/package-install.yml b/.github/workflows/package-install.yml index 1bd34a35c..3665c1a84 100644 --- a/.github/workflows/package-install.yml +++ b/.github/workflows/package-install.yml @@ -27,7 +27,7 @@ jobs: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Build wheel - run: uv build --wheel --out-dir dist + run: python scripts/build_package.py --wheel - name: Smoke test uv add + sync for backend extra run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 68047e1d7..2ceed52e4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,9 +10,11 @@ permissions: id-token: write jobs: - release: + build-package: runs-on: ubuntu-latest if: github.event.pull_request.merged == true && startsWith(github.event.pull_request.head.ref, 'release/') + outputs: + version: ${{ steps.get_version.outputs.VERSION }} steps: - uses: actions/checkout@v4 with: @@ -21,44 +23,113 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.11' + python-version: "3.11" - name: Install uv run: | curl -LsSf https://astral.sh/uv/install.sh | sh - echo "$HOME/.cargo/bin" >> $GITHUB_PATH - - - name: Install dependencies - run: | - uv venv - uv pip install -e . - uv pip install hatch + echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Build package - run: uv run hatch build + run: python scripts/build_package.py - name: Get version from pyproject.toml id: get_version run: | VERSION=$(python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])") - echo "VERSION=$VERSION" >> $GITHUB_OUTPUT + echo "VERSION=$VERSION" >> "$GITHUB_OUTPUT" + + - name: Upload package artifact + uses: actions/upload-artifact@v4 + with: + name: python-distributions + path: dist/* + + runtime-smoke: + runs-on: art-large-runner + needs: build-package + steps: + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" + + - name: Download package artifact + uses: actions/download-artifact@v4 + with: + name: python-distributions + path: dist + + - name: Smoke test managed vLLM runtime install + run: | + export ART_VLLM_RUNTIME_CACHE_DIR="${RUNNER_TEMP}/art-vllm-runtime-cache" + export UV_LINK_MODE=copy + wheel_path="$(python - <<'PY' + from pathlib import Path + + print(next(Path("dist").glob("openpipe_art-*.whl")).resolve()) + PY + )" + + project_dir="$(mktemp -d)" + cd "$project_dir" + uv init --name art-runtime-smoke --python 3.11 --bare + uv add "openpipe-art[backend] @ file://${wheel_path}" + uv sync + uv run python - <<'PY' + from pathlib import Path + import subprocess + + from art.vllm_runtime import ensure_vllm_runtime + + runtime_bin = ensure_vllm_runtime() + runtime_python = Path(runtime_bin).parent / "python" + subprocess.run([str(runtime_bin), "--help"], check=True) + subprocess.run( + [ + str(runtime_python), + "-c", + "import art_vllm_runtime, torch, vllm; print('runtime imports ok')", + ], + check=True, + ) + print(runtime_bin) + PY + + publish: + runs-on: ubuntu-latest + needs: [build-package, runtime-smoke] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Download package artifact + uses: actions/download-artifact@v4 + with: + name: python-distributions + path: dist - name: Create git tag run: | git config --local user.email "action@github.com" git config --local user.name "GitHub Action" - git tag 
v${{ steps.get_version.outputs.VERSION }} - git push origin v${{ steps.get_version.outputs.VERSION }} + git tag v${{ needs.build-package.outputs.version }} + git push origin v${{ needs.build-package.outputs.version }} - name: Publish draft release env: GH_TOKEN: ${{ github.token }} run: | - # Check if draft release exists and publish it - if gh release view v${{ steps.get_version.outputs.VERSION }} --json isDraft | jq -r '.isDraft' | grep -q true; then - gh release edit v${{ steps.get_version.outputs.VERSION }} --draft=false + if gh release view v${{ needs.build-package.outputs.version }} --json isDraft | jq -r '.isDraft' | grep -q true; then + gh release edit v${{ needs.build-package.outputs.version }} --draft=false else - echo "::error::No draft release found for v${{ steps.get_version.outputs.VERSION }}" + echo "::error::No draft release found for v${{ needs.build-package.outputs.version }}" exit 1 fi @@ -66,7 +137,7 @@ jobs: env: GH_TOKEN: ${{ github.token }} run: | - gh release upload v${{ steps.get_version.outputs.VERSION }} dist/* + gh release upload v${{ needs.build-package.outputs.version }} dist/* - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index bc0764abb..d1f4ebd59 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,5 @@ trajectories/ .ruff_cache/ !/src/art/wandb/ !/src/art/wandb/** -/src/art/wandb/__pycache__/ \ No newline at end of file +/src/art/wandb/__pycache__/ +scratch/ diff --git a/dev/bench_cute_grouped_lora.py b/dev/bench_cute_grouped_lora.py index 770332768..4cb838dcd 100644 --- a/dev/bench_cute_grouped_lora.py +++ b/dev/bench_cute_grouped_lora.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator import torch -from art.megatron.cute_grouped_lora_quack import quack_grouped_lora +from art.megatron.kernels.cute_grouped_lora_quack import quack_grouped_lora GroupedLoraFn = Callable[ [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], diff --git a/pyproject.toml b/pyproject.toml index a255a01aa..11e5893c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,12 +39,13 @@ backend = [ "nbmake>=1.5.5", "gql<4", "nvidia-cudnn-frontend<1.21 ; sys_platform == 'linux'", + "nvidia-nccl-cu12==2.28.9 ; sys_platform == 'linux'", "nvidia-resiliency-ext<0.5 ; sys_platform == 'linux'", - "vllm @ https://github.com/vivekkalyan/vllm/releases/download/v0.17.0-art1/vllm-0.17.0%2Bart1-cp38-abi3-manylinux_2_31_x86_64.whl ; sys_platform == 'linux'", ] megatron = [ "numpy<2", "torch==2.10.0", + "quack-kernels==0.2.5", "apex @ git+https://github.com/NVIDIA/apex.git@25.09", "transformer-engine==2.11.0", "transformer-engine-cu12==2.11.0", @@ -56,6 +57,7 @@ megatron = [ "causal-conv1d @ https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1%2Bcu12torch2.10cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; sys_platform == 'linux' and platform_machine == 'x86_64' and python_full_version < '3.12'", "mamba-ssm @ https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1%2Bcu12torch2.10cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; sys_platform == 'linux' and platform_machine == 'x86_64' and python_full_version < '3.12'", "nvidia-ml-py==13.580.82", + "nvidia-nccl-cu12==2.28.9 ; sys_platform == 'linux'", "nvidia-resiliency-ext<0.5 ; sys_platform == 'linux'", "ml-dtypes>=0.5.0 ; python_full_version < '3.13'", ] @@ -74,6 +76,7 @@ tinker = [ "pydantic>=2.12.5", "tinker-cookbook>=0.3.0,<0.4", "tinker>=0.18.2,<0.19", + "nvidia-nccl-cu12==2.28.9 ; 
sys_platform == 'linux'", "torch==2.10.0", "transformers==5.2.0", "uvicorn>=0.35.0", @@ -83,9 +86,6 @@ tinker = [ [project.scripts] art = "art.cli:app" -[project.entry-points."vllm.general_plugins"] -art = "art.vllm.patches:patch_transformers_v5_compat" - [build-system] requires = ["hatchling"] build-backend = "hatchling.build" @@ -103,17 +103,30 @@ packages = ["src/art", "src/mp_actors"] sources = ["src"] [tool.hatch.build.targets.sdist] +sources = [] +only-include = [ + ".agents/skills", + "LICENSE", + "README.md", + "THIRD-PARTY-NOTICES", + "pyproject.toml", + "src", +] exclude = [ "/dev", "/wandb", "/.art", + "/.local", "/.ruff_cache", "/.venv", "/dist", + "/scratch", + "/unsloth_compiled_cache", "/.git", "/.github", "/examples/*/data", "/examples/*/wandb", + "/tests/unsloth_compiled_cache", "**/__pycache__", "**/*.pyc", ] @@ -138,6 +151,7 @@ required-version = ">=0.11.7" override-dependencies = [ "flashinfer-python==0.6.1", "numpy<2", + "nvidia-nccl-cu12==2.28.9 ; sys_platform == 'linux'", "nvidia-resiliency-ext<0.5", "quack-kernels==0.2.5", "transformer-engine==2.11.0", @@ -217,7 +231,6 @@ allowed-unresolved-imports = [ "unsloth.**", "unsloth_zoo.**", "uvicorn.**", - "vllm.**", "wandb.**", # langgraph deps "langchain_core.**", diff --git a/scripts/build_package.py b/scripts/build_package.py new file mode 100644 index 000000000..d4f0a4f12 --- /dev/null +++ b/scripts/build_package.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import hashlib +import json +from pathlib import Path +import shutil +import subprocess +import sys +import tarfile +import tempfile +import tomllib +import zipfile + +ROOT = Path(__file__).resolve().parents[1] +BUNDLE_DIR = ROOT / "src" / "art" / "_vllm_runtime" +BUNDLE_MARKER = BUNDLE_DIR / ".art_generated" +PROTOCOL_VERSION = 1 + + +def run(command: list[str], *, cwd: Path = ROOT) -> None: + print("+", " ".join(command), flush=True) + subprocess.run(command, cwd=cwd, check=True) + + +def read_pyproject(path: Path) -> dict: + return tomllib.loads(path.read_text()) + + +def sha256_file(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as file: + for chunk in iter(lambda: file.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def clean_bundle_dir() -> None: + if not BUNDLE_DIR.exists(): + return + if not BUNDLE_MARKER.exists(): + raise RuntimeError( + f"Refusing to remove non-generated runtime bundle directory: {BUNDLE_DIR}" + ) + shutil.rmtree(BUNDLE_DIR) + + +def build_runtime_wheel(runtime_dist: Path) -> Path: + run(["uv", "lock", "--project", "vllm_runtime", "--check"]) + run( + [ + "uv", + "build", + "--wheel", + "vllm_runtime", + "--out-dir", + str(runtime_dist), + "--no-progress", + ] + ) + wheels = sorted(runtime_dist.glob("art_vllm_runtime-*.whl")) + if len(wheels) != 1: + raise RuntimeError(f"Expected one art-vllm-runtime wheel, found {wheels}") + return wheels[0] + + +def write_bundle(runtime_wheel: Path) -> None: + root_project = read_pyproject(ROOT / "pyproject.toml")["project"] + runtime_project = read_pyproject(ROOT / "vllm_runtime" / "pyproject.toml")[ + "project" + ] + pyproject = ROOT / "vllm_runtime" / "pyproject.toml" + lockfile = ROOT / "vllm_runtime" / "uv.lock" + + BUNDLE_DIR.mkdir(parents=True) + shutil.copy2(pyproject, BUNDLE_DIR / "pyproject.toml") + shutil.copy2(lockfile, BUNDLE_DIR / "uv.lock") + shutil.copy2(runtime_wheel, BUNDLE_DIR / runtime_wheel.name) + + manifest = { + "art_package": root_project["name"], + "art_version": 
root_project["version"], + "runtime_package": runtime_project["name"], + "runtime_version": runtime_project["version"], + "protocol_version": PROTOCOL_VERSION, + "python": runtime_project["requires-python"], + "runtime_wheel": runtime_wheel.name, + "runtime_wheel_sha256": sha256_file(runtime_wheel), + "pyproject": "pyproject.toml", + "pyproject_sha256": sha256_file(pyproject), + "lockfile": "uv.lock", + "lockfile_sha256": sha256_file(lockfile), + } + (BUNDLE_DIR / "manifest.json").write_text( + json.dumps(manifest, indent=2, sort_keys=True) + "\n" + ) + BUNDLE_MARKER.write_text("generated by scripts/build_package.py\n") + + +def build_root_package(*, wheel_only: bool, out_dir: Path) -> None: + if out_dir.exists(): + shutil.rmtree(out_dir) + command = ["uv", "build", "--out-dir", str(out_dir), "--no-progress"] + if wheel_only: + command.append("--wheel") + run(command) + + +def wheel_metadata(wheel: Path) -> str: + with zipfile.ZipFile(wheel) as archive: + metadata_names = [ + name for name in archive.namelist() if name.endswith(".dist-info/METADATA") + ] + if len(metadata_names) != 1: + raise RuntimeError(f"Expected one METADATA file in {wheel}") + return archive.read(metadata_names[0]).decode() + + +def verify_wheel(wheel: Path) -> None: + expected = { + "art/_vllm_runtime/manifest.json", + "art/_vllm_runtime/pyproject.toml", + "art/_vllm_runtime/uv.lock", + } + with zipfile.ZipFile(wheel) as archive: + names = set(archive.namelist()) + missing = expected - names + runtime_wheels = [ + name + for name in names + if name.startswith("art/_vllm_runtime/art_vllm_runtime-") + and name.endswith(".whl") + ] + if missing: + raise RuntimeError(f"Wheel missing runtime bundle files: {sorted(missing)}") + if len(runtime_wheels) != 1: + raise RuntimeError( + f"Expected one bundled runtime wheel, found {runtime_wheels}" + ) + + bad_dependencies: list[str] = [] + for line in wheel_metadata(wheel).splitlines(): + if not line.startswith("Requires-Dist:"): + continue + requirement = line.removeprefix("Requires-Dist:").strip().lower() + if requirement.startswith("vllm") or requirement.startswith("art-vllm-runtime"): + bad_dependencies.append(line) + if bad_dependencies: + raise RuntimeError( + "Root wheel must not depend on vLLM runtime packages: " + + "; ".join(bad_dependencies) + ) + + +def verify_sdist(sdist: Path) -> None: + expected = { + "src/art/_vllm_runtime/manifest.json", + "src/art/_vllm_runtime/pyproject.toml", + "src/art/_vllm_runtime/uv.lock", + } + with tarfile.open(sdist) as archive: + names = set(archive.getnames()) + prefix = next(iter(names)).split("/", 1)[0] + missing = {f"{prefix}/{name}" for name in expected} - names + runtime_wheels = [ + name + for name in names + if name.startswith(f"{prefix}/src/art/_vllm_runtime/art_vllm_runtime-") + and name.endswith(".whl") + ] + if missing: + raise RuntimeError(f"sdist missing runtime bundle files: {sorted(missing)}") + if len(runtime_wheels) != 1: + raise RuntimeError( + f"Expected one bundled runtime wheel, found {runtime_wheels}" + ) + + +def verify_dist(out_dir: Path, *, wheel_only: bool) -> None: + root_wheels = sorted(out_dir.glob("openpipe_art-*.whl")) + if len(root_wheels) != 1: + raise RuntimeError(f"Expected one openpipe-art wheel, found {root_wheels}") + verify_wheel(root_wheels[0]) + + if wheel_only: + return + sdists = sorted(out_dir.glob("openpipe_art-*.tar.gz")) + if len(sdists) != 1: + raise RuntimeError(f"Expected one openpipe-art sdist, found {sdists}") + verify_sdist(sdists[0]) + + +def parse_args() -> argparse.Namespace: + 
parser = argparse.ArgumentParser(description="Build ART package artifacts") + parser.add_argument("--wheel", action="store_true", help="Build only the wheel") + parser.add_argument("--out-dir", default="dist", type=Path) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + out_dir = args.out_dir + if not out_dir.is_absolute(): + out_dir = ROOT / out_dir + + clean_bundle_dir() + try: + with tempfile.TemporaryDirectory() as temp_dir: + runtime_wheel = build_runtime_wheel(Path(temp_dir)) + write_bundle(runtime_wheel) + build_root_package(wheel_only=args.wheel, out_dir=out_dir) + verify_dist(out_dir, wheel_only=args.wheel) + finally: + clean_bundle_dir() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/publish.sh b/scripts/publish.sh index 5b614660a..e5cca6f57 100755 --- a/scripts/publish.sh +++ b/scripts/publish.sh @@ -15,7 +15,7 @@ fi rm -rf dist # Build the package -uv run hatch build +python scripts/build_package.py # If the token is set, proceed with publishing diff --git a/src/art/__init__.py b/src/art/__init__.py index 8e494e6c4..6cdc18667 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -29,28 +29,15 @@ suppress_litellm_serialization_warnings() -# Create a dummy GuidedDecodingParams class and inject it into vllm.sampling_params for trl compatibility -try: - import vllm.sampling_params - - class GuidedDecodingParams: - """Shim for vLLM 0.13+ where GuidedDecodingParams was removed.""" - - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - vllm.sampling_params.GuidedDecodingParams = GuidedDecodingParams # type: ignore -except ImportError: - pass # vllm not installed - # torch.cuda.MemPool doesn't currently support expandable_segments which is used in sleep mode conf = os.getenv("PYTORCH_CUDA_ALLOC_CONF", "").split(",") if "expandable_segments:True" in conf: conf.remove("expandable_segments:True") os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(conf) -# Import unsloth before transformers, peft, and trl to maximize Unsloth optimizations +# Import unsloth before transformers, peft, and trl only in backend processes that +# explicitly request it. Unsloth is an optional backend dependency, not a base ART +# import dependency. 
if os.environ.get("IMPORT_UNSLOTH", "0") == "1": import unsloth # noqa: F401 diff --git a/src/art/cli.py b/src/art/cli.py index 1d3da12de..9fed1e74f 100644 --- a/src/art/cli.py +++ b/src/art/cli.py @@ -230,6 +230,8 @@ def migrate( def run(host: str = "0.0.0.0", port: int = 7999) -> None: """Run the ART CLI.""" + from contextlib import asynccontextmanager + from fastapi import Body, FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse import pydantic @@ -264,7 +266,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: TrajectoryGroup.__init__ = __init__ # ty:ignore[invalid-assignment] backend = LocalBackend() - app = FastAPI() + + @asynccontextmanager + async def lifespan(_: FastAPI): + try: + yield + finally: + await backend.close() + + app = FastAPI(lifespan=lifespan) # Add exception handler for ARTError @app.exception_handler(ARTError) diff --git a/src/art/dev/engine.py b/src/art/dev/engine.py index d79384f72..517bc83ab 100644 --- a/src/art/dev/engine.py +++ b/src/art/dev/engine.py @@ -72,6 +72,7 @@ class EngineArgs(TypedDict, total=False): max_prompt_adapters: int max_prompt_adapter_token: int fully_sharded_loras: bool + lora_target_modules: list[str] lora_extra_vocab_size: int long_lora_scaling_factors: Tuple[float] | None lora_dtype: str | None @@ -123,6 +124,7 @@ class EngineArgs(TypedDict, total=False): generation_config: str | None override_generation_config: dict[str, Any] | None enable_sleep_mode: bool + enable_expert_parallel: bool model_impl: str calculate_kv_scales: bool | None diff --git a/src/art/dev/get_model_config.py b/src/art/dev/get_model_config.py index 550f97e4f..850008ae0 100644 --- a/src/art/dev/get_model_config.py +++ b/src/art/dev/get_model_config.py @@ -1,31 +1,11 @@ +from ..megatron.model_support import default_target_modules_for_model from .engine import EngineArgs from .model import InitArgs, InternalModelConfig, PeftArgs, TrainerArgs -from .validate import QWEN3_5_MOE_MODELS, is_dedicated_mode +from .validate import is_dedicated_mode def default_target_modules(base_model: str) -> list[str]: - if base_model in QWEN3_5_MOE_MODELS: - return [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "in_proj_qkv", - "in_proj_z", - "out_proj", - "gate_proj", - "up_proj", - "down_proj", - ] - return [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ] + return default_target_modules_for_model(base_model, allow_unvalidated_arch=True) def get_model_config( @@ -51,10 +31,7 @@ def get_model_config( max_seq_length=32768, model_name=base_model, ) - # fast_inference triggers in-process vLLM via Unsloth; dedicated mode runs vLLM as a subprocess - if not dedicated: - init_args["fast_inference"] = False - + target_modules = default_target_modules(base_model) engine_args = EngineArgs( allowed_local_media_path="/tmp", enable_sleep_mode=enable_sleep_mode, @@ -69,10 +46,14 @@ def get_model_config( lora_alpha=16, r=8, random_state=3407, - target_modules=default_target_modules(base_model), + target_modules=target_modules, use_gradient_checkpointing="unsloth", ) peft_args.update(config.get("peft_args", {})) + if rollout_weights_mode == "lora" and "lora_target_modules" not in config.get( + "engine_args", {} + ): + engine_args["lora_target_modules"] = peft_args["target_modules"] trainer_args = TrainerArgs( adam_beta1=0.9, adam_beta2=0.99, @@ -100,6 +81,8 @@ def get_model_config( tinker_args=config.get("tinker_args"), trainer_args=trainer_args, ) + if "allow_unvalidated_arch" in config: + 
result["allow_unvalidated_arch"] = config["allow_unvalidated_arch"] if "trainer_gpu_ids" in config: result["trainer_gpu_ids"] = config["trainer_gpu_ids"] if "inference_gpu_ids" in config: diff --git a/src/art/dev/model.py b/src/art/dev/model.py index e55b35d18..1c0f18f1f 100644 --- a/src/art/dev/model.py +++ b/src/art/dev/model.py @@ -127,6 +127,8 @@ class InternalModelConfig(TypedDict, total=False): - "lora": load LoRA adapters into vLLM directly - "merged": keep training LoRA adapters, but push merged weights into vLLM for inference + allow_unvalidated_arch: Permit model-support validation workflows to run + architectures that are not yet in the supported-model registry. """ init_args: "InitArgs" @@ -138,6 +140,7 @@ class InternalModelConfig(TypedDict, total=False): trainer_gpu_ids: list[int] inference_gpu_ids: list[int] rollout_weights_mode: "RolloutWeightsMode" + allow_unvalidated_arch: bool class TinkerArgs(TypedDict, total=False): diff --git a/src/art/dev/validate.py b/src/art/dev/validate.py index 7ab8c6a1f..56e91c1df 100644 --- a/src/art/dev/validate.py +++ b/src/art/dev/validate.py @@ -2,11 +2,6 @@ from .model import InternalModelConfig, RolloutWeightsMode -QWEN3_5_MOE_MODELS = { - "Qwen/Qwen3.5-35B-A3B", - "Qwen/Qwen3.5-397B-A17B", -} - def is_dedicated_mode(config: InternalModelConfig) -> bool: """Return True if the config specifies dedicated mode (separate training and inference GPUs).""" @@ -20,11 +15,6 @@ def _rollout_weights_mode(config: InternalModelConfig) -> RolloutWeightsMode: raise ValueError("rollout_weights_mode must be either 'lora' or 'merged'") -def _is_qwen3_5_moe_model(config: InternalModelConfig) -> bool: - model_name = config.get("engine_args", {}).get("model") - return model_name in QWEN3_5_MOE_MODELS - - def validate_dedicated_config(config: InternalModelConfig) -> None: """Validate dedicated mode GPU configuration. 
@@ -46,6 +36,12 @@ def validate_dedicated_config(config: InternalModelConfig) -> None: "(set both trainer_gpu_ids and inference_gpu_ids)" ) + if "fast_inference" in config.get("init_args", {}): + raise ValueError( + "fast_inference is no longer supported; ART always uses an external " + "vLLM runtime" + ) + if not has_trainer: return @@ -77,21 +73,8 @@ def validate_dedicated_config(config: InternalModelConfig) -> None: "trainer_gpu_ids must be contiguous starting from 0 (e.g., [0], [0,1])" ) - # Reject settings that are incompatible with dedicated mode - if config.get("init_args", {}).get("fast_inference"): - raise ValueError( - "fast_inference is incompatible with dedicated mode " - "(dedicated mode runs vLLM as a subprocess, not in-process)" - ) - if config.get("engine_args", {}).get("enable_sleep_mode"): raise ValueError( "enable_sleep_mode is incompatible with dedicated mode " - "(dedicated mode runs vLLM on a separate GPU, sleep/wake is not needed)" - ) - - if _is_qwen3_5_moe_model(config) and rollout_weights_mode == "lora": - raise ValueError( - "Qwen3.5-MoE models require rollout_weights_mode='merged' with the " - "current vLLM version because direct LoRA inference is currently broken" + "(shared-GPU mode uses runtime sleep/wake; dedicated mode does not)" ) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index daa490204..180bc08a2 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -1,11 +1,10 @@ -import asyncio +import gc import json import logging import math import os import shutil import socket -import subprocess import time from types import TracebackType from typing import AsyncIterator, Iterable, Literal, cast @@ -17,7 +16,6 @@ "H200": 3.0, } -import aiohttp import numpy as np import polars as pl import torch @@ -105,7 +103,6 @@ def __init__( # Other initialization self._services: dict[str, ModelService] = {} - self._monitor_tasks: dict[str, asyncio.Task[None]] = {} self._tokenizers: dict[str, PreTrainedTokenizerBase] = {} self._image_processors: dict[str, BaseImageProcessor | None] = {} self._requires_explicit_packed_sequence_length = False @@ -189,16 +186,7 @@ async def close(self) -> None: """ If running vLLM in a separate process, this will kill that process and close the communication threads. 
""" - tasks = list(self._monitor_tasks.values()) - for task in tasks: - task.cancel() - for task in tasks: - try: - await task - except (asyncio.CancelledError, Exception): - pass - self._monitor_tasks.clear() - for service in list(self._services.values()): + for service in self._services.values(): aclose = getattr(service, "aclose", None) if aclose is None: close = getattr(service, "close", None) @@ -207,16 +195,23 @@ async def close(self) -> None: else: await aclose() close_proxy(service) + self._services.clear() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def _close(self) -> None: - for task in list(self._monitor_tasks.values()): - task.cancel() - self._monitor_tasks.clear() - for service in list(self._services.values()): + for service in self._services.values(): close = getattr(service, "close", None) if close is not None: close() close_proxy(service) + self._services.clear() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() async def register( self, @@ -324,8 +319,6 @@ async def _get_service(self, model: TrainableModel) -> ModelService: output_dir=get_model_dir(model=model, art_path=self._path), ) if not dedicated and not self._in_process: - # Kill all "model-service" processes to free up GPU memory - subprocess.run(["pkill", "-9", "model-service"]) self._services[model.name] = move_to_child_process( self._services[model.name], process_name="tinker-service" if is_tinker else "model-service", @@ -498,86 +491,8 @@ async def _prepare_backend_for_training( base_url = f"http://{host}:{port}/v1" api_key = server_args.get("api_key") or "default" - def done_callback(_: asyncio.Task[None]) -> None: - self._monitor_tasks.pop(model.name, None) - close_proxy(self._services.pop(model.name)) - - task = asyncio.create_task( - self._monitor_openai_server(model, base_url, api_key) - ) - task.add_done_callback(done_callback) - self._monitor_tasks[model.name] = task - return base_url, api_key - async def _monitor_openai_server( - self, model: AnyTrainableModel, base_url: str, api_key: str - ) -> None: - model_name = model.name - consecutive_failures = 0 - max_consecutive_failures = 3 - async with aiohttp.ClientSession() as session: - while True: - # Wait 30 seconds before checking again - await asyncio.sleep(30) - try: - # If the server is sleeping, skip the check - if await self._services[model_name].vllm_engine_is_sleeping(): - consecutive_failures = 0 - continue - # Check the metrics with a timeout - async with session.get( - f"{base_url.split('/v1')[0]}/metrics", - timeout=aiohttp.ClientTimeout(total=10), - ) as response: - metrics = await response.text() - # Parse Prometheus metrics for running requests - running_requests = 0 - pending_requests = 0 - for line in metrics.split("\n"): - if line.startswith("vllm:num_requests_running"): - running_requests = int(float(line.split()[1])) - elif line.startswith("vllm:num_requests_waiting"): - pending_requests = int(float(line.split()[1])) - # If there are no running or pending requests, send a cheap liveness - # probe rather than a real generation request. Large models can take - # longer than a short completion-based probe while still being healthy. 
- if running_requests == 0 and pending_requests == 0: - try: - async with session.get( - f"{base_url.split('/v1')[0]}/health", - timeout=float( - os.environ.get("ART_SERVER_MONITOR_TIMEOUT", 5.0) - ), - ) as health_response: - if health_response.status >= 400: - raise RuntimeError( - "OpenAI server health check failed with " - f"status {health_response.status}" - ) - except Exception as e: - # If the server is sleeping, a failed health check is okay - if await self._services[ - model_name - ].vllm_engine_is_sleeping(): - consecutive_failures = 0 - continue - raise e - # Reset failure counter on success - consecutive_failures = 0 - except Exception: - # If the server is sleeping during an exception, it's okay - try: - if await self._services[model_name].vllm_engine_is_sleeping(): - consecutive_failures = 0 - continue - except Exception: - pass # If we can't check sleeping status, count it as a failure - consecutive_failures += 1 - if consecutive_failures >= max_consecutive_failures: - raise - # Otherwise, continue and try again - # Note: _log() method has been moved to the Model class (frontend) def _trajectory_log(self, trajectory: Trajectory) -> str: @@ -648,8 +563,9 @@ async def train( # type: ignore[override] "cispo" and "ppo". loss_fn_config: Additional loss-function config. Not supported by LocalBackend. - normalize_advantages: Whether to normalize advantages. LocalBackend - currently requires True. + normalize_advantages: Backward-compatible alias for reward std scaling. + When False, LocalBackend centers rewards but does not divide by + group reward std dev. adam_params: Custom optimizer params. Not supported by LocalBackend. kl_penalty_coef: Coefficient for KL-penalized advantage adjustment. @@ -712,7 +628,7 @@ async def train( # type: ignore[override] if loss_fn_config is not None: raise ValueError("LocalBackend requires loss_fn_config=None.") if not normalize_advantages: - raise ValueError("LocalBackend requires normalize_advantages=True.") + scale_rewards = False if adam_params is not None: raise ValueError("LocalBackend requires adam_params=None.") if ( @@ -1035,6 +951,7 @@ async def _train_sft( **result, "data/step_num_trajectories": float(total_trajectories), "data/step_trainer_tokens": float(total_trainable_tokens), + "data/step_num_dropped_trajectories": float(total_dropped_trajectories), TRAIN_GRADIENT_STEPS_KEY: float(len(batches)), } diff --git a/src/art/megatron/__init__.py b/src/art/megatron/__init__.py index 07107df61..720e3a88f 100644 --- a/src/art/megatron/__init__.py +++ b/src/art/megatron/__init__.py @@ -1,3 +1,11 @@ -from .backend import MegatronBackend +from typing import Any __all__ = ["MegatronBackend"] + + +def __getattr__(name: str) -> Any: + if name == "MegatronBackend": + from .runtime.backend import MegatronBackend + + return MegatronBackend + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/art/megatron/bridge_adapter_compat.py b/src/art/megatron/bridge_adapter_compat.py deleted file mode 100644 index 71e3ce6c7..000000000 --- a/src/art/megatron/bridge_adapter_compat.py +++ /dev/null @@ -1,318 +0,0 @@ -import math - -from megatron.bridge.models.conversion.model_bridge import MegatronWeightTuple -from megatron.bridge.models.conversion.peft_bridge import AdapterWeight -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_layer import TransformerLayer -import torch - -from art.megatron.lora import ( - GatedDeltaNetInProjLoRA, - LoRA, - MLPExpertsLinearFC1LoRA, - 
MLPExpertsLinearFC2LoRA, - SelfAttentionLinearProjLoRA, - SelfAttentionLinearQKVLoRA, - SharedExpertsLinearFC1LoRA, - SharedExpertsLinearFC2LoRA, -) - - -def _is_language_transformer_layer_name(module_name: str) -> bool: - while module_name.startswith("module."): - module_name = module_name.removeprefix("module.") - return module_name.startswith(("decoder.layers.", "language_model.decoder.layers.")) - - -def _language_layer_prefix(module_name: str, layer_number: int) -> str: - while module_name.startswith("module."): - module_name = module_name.removeprefix("module.") - if module_name.startswith("decoder.layers."): - return f"decoder.layers.{layer_number - 1}" - return f"language_model.decoder.layers.{layer_number - 1}" - - -def _adapter_alpha_dim(lora: LoRA) -> tuple[int, int]: - dim = int(lora.A_T.shape[-1]) - alpha = float(lora.scale) * dim - rounded_alpha = round(alpha) - assert math.isclose(alpha, rounded_alpha) - return rounded_alpha, dim - - -def _adapter_tensors( - lora: LoRA, expert_idx: int | None = None -) -> tuple[torch.Tensor, torch.Tensor]: - a_t = lora.A_T if expert_idx is None else lora.A_T[expert_idx] - b_t = lora.B_T if expert_idx is None else lora.B_T[expert_idx] - return a_t.transpose(-1, -2).contiguous(), b_t.transpose(-1, -2).contiguous() - - -def _adapter_param_prefix(base_prefix: str, adapter_key: str | None) -> str: - if adapter_key is None: - return f"{base_prefix}.adapter" - return f"{base_prefix}.adapter.{adapter_key}" - - -def _adapter_weight( - *, - base_prefix: str, - adapter_key: str | None, - alpha: int, - dim: int, - linear_in: torch.Tensor, - linear_out: torch.Tensor, -) -> AdapterWeight: - param_prefix = _adapter_param_prefix(base_prefix, adapter_key) - return AdapterWeight( - global_base_prefix=base_prefix, - adapter_key=adapter_key, - alpha=alpha, - dim=dim, - linear_in_weight=MegatronWeightTuple( - param_name=f"{param_prefix}.linear_in.weight", - weight=linear_in, - vp_stage=0, - ), - linear_out_weight=MegatronWeightTuple( - param_name=f"{param_prefix}.linear_out.weight", - weight=linear_out, - vp_stage=0, - ), - ) - - -def _simple_adapter_weight( - base_prefix: str, - lora: LoRA, - *, - adapter_key: str | None = None, - expert_idx: int | None = None, -) -> AdapterWeight: - alpha, dim = _adapter_alpha_dim(lora) - linear_in, linear_out = _adapter_tensors(lora, expert_idx) - return _adapter_weight( - base_prefix=base_prefix, - adapter_key=adapter_key, - alpha=alpha, - dim=dim, - linear_in=linear_in, - linear_out=linear_out, - ) - - -def _fused_gdn_adapter_weight( - base_prefix: str, - handler: GatedDeltaNetInProjLoRA, -) -> AdapterWeight: - qkv_linear_in, qkv_linear_out = _adapter_tensors(handler.qkv_lora) - z_linear_in, z_linear_out = _adapter_tensors(handler.z_lora) - assert math.isclose(float(handler.qkv_lora.scale), float(handler.z_lora.scale)) - total_dim = int(qkv_linear_in.shape[0] + z_linear_in.shape[0]) - alpha = round(float(handler.qkv_lora.scale) * total_dim) - - qkv_rank = int(qkv_linear_in.shape[0]) - z_rank = int(z_linear_in.shape[0]) - qkv_out = int(qkv_linear_out.shape[0]) - z_out = int(z_linear_out.shape[0]) - beta_alpha_out = int(handler.num_value_heads_per_partition) - - qkv_padding = qkv_linear_out.new_zeros((qkv_out, z_rank)) - z_padding = z_linear_out.new_zeros((z_out, qkv_rank)) - zeros = qkv_linear_out.new_zeros((beta_alpha_out, total_dim)) - - return _adapter_weight( - base_prefix=base_prefix, - adapter_key=None, - alpha=alpha, - dim=total_dim, - linear_in=torch.cat([qkv_linear_in, z_linear_in], dim=0), - linear_out=torch.cat( - 
[ - torch.cat([qkv_linear_out, qkv_padding], dim=1), - torch.cat([z_padding, z_linear_out], dim=1), - zeros, - zeros.clone(), - ], - dim=0, - ), - ) - - -def _fused_pair_adapter_weight( - base_prefix: str, - first_lora: LoRA, - second_lora: LoRA, - *, - first_expert_idx: int | None = None, - second_expert_idx: int | None = None, -) -> AdapterWeight: - first_linear_in, first_linear_out = _adapter_tensors(first_lora, first_expert_idx) - second_linear_in, second_linear_out = _adapter_tensors( - second_lora, second_expert_idx - ) - assert math.isclose(float(first_lora.scale), float(second_lora.scale)) - total_dim = int(first_linear_in.shape[0] + second_linear_in.shape[0]) - alpha = round(float(first_lora.scale) * total_dim) - - first_rank = int(first_linear_in.shape[0]) - second_rank = int(second_linear_in.shape[0]) - first_out = int(first_linear_out.shape[0]) - second_out = int(second_linear_out.shape[0]) - - first_padding = first_linear_out.new_zeros((first_out, second_rank)) - second_padding = second_linear_out.new_zeros((second_out, first_rank)) - - return _adapter_weight( - base_prefix=base_prefix, - adapter_key=None, - alpha=alpha, - dim=total_dim, - linear_in=torch.cat([first_linear_in, second_linear_in], dim=0), - linear_out=torch.cat( - [ - torch.cat([first_linear_out, first_padding], dim=1), - torch.cat([second_padding, second_linear_out], dim=1), - ], - dim=0, - ), - ) - - -def build_adapter_weights_by_base( - model_chunks: list[MegatronModule], -) -> dict[str, list[AdapterWeight]]: - adapter_weights_by_base: dict[str, list[AdapterWeight]] = {} - for chunk in model_chunks: - for module_name, module in chunk.named_modules(): - if not isinstance(module, TransformerLayer): - continue - if not _is_language_transformer_layer_name(module_name): - continue - - layer_prefix = _language_layer_prefix(module_name, module.layer_number) - self_attention = module.self_attention - - linear_proj = getattr(self_attention, "linear_proj", None) - if isinstance(linear_proj, SelfAttentionLinearProjLoRA): - base_prefix = f"{layer_prefix}.self_attention.linear_proj" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, linear_proj.lora) - ] - - linear_qkv = getattr(self_attention, "linear_qkv", None) - if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): - base_prefix = f"{layer_prefix}.self_attention.linear_qkv" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, linear_qkv.q_proj_lora, adapter_key="adapter_q" - ), - _simple_adapter_weight( - base_prefix, linear_qkv.k_proj_lora, adapter_key="adapter_k" - ), - _simple_adapter_weight( - base_prefix, linear_qkv.v_proj_lora, adapter_key="adapter_v" - ), - ] - - out_proj = getattr(self_attention, "out_proj", None) - if isinstance(out_proj, SelfAttentionLinearProjLoRA): - base_prefix = f"{layer_prefix}.self_attention.out_proj" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, out_proj.lora) - ] - - in_proj = getattr(self_attention, "in_proj", None) - if isinstance(in_proj, GatedDeltaNetInProjLoRA): - base_prefix = f"{layer_prefix}.self_attention.in_proj" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _fused_gdn_adapter_weight(base_prefix, in_proj) - ] - - experts = getattr(module.mlp, "experts", None) - if experts is not None: - if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA): - base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" - for local_expert_idx in range( - experts.linear_fc1.gate_lora.num_local_experts 
- ): - global_expert_idx = ( - local_expert_idx - + experts.linear_fc1.gate_lora._expert_offset - ) - adapter_weights_by_base[ - f"{base_prefix}.weight{global_expert_idx}" - ] = [ - _fused_pair_adapter_weight( - base_prefix, - experts.linear_fc1.gate_lora, - experts.linear_fc1.up_lora, - first_expert_idx=local_expert_idx, - second_expert_idx=local_expert_idx, - ) - ] - if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA): - base_prefix = f"{layer_prefix}.mlp.experts.linear_fc2" - for local_expert_idx in range( - experts.linear_fc2.lora.num_local_experts - ): - global_expert_idx = ( - local_expert_idx + experts.linear_fc2.lora._expert_offset - ) - adapter_weights_by_base[ - f"{base_prefix}.weight{global_expert_idx}" - ] = [ - _simple_adapter_weight( - base_prefix, - experts.linear_fc2.lora, - expert_idx=local_expert_idx, - ) - ] - else: - linear_fc1 = getattr(module.mlp, "linear_fc1", None) - if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): - base_prefix = f"{layer_prefix}.mlp.linear_fc1" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, - linear_fc1.gate_lora, - adapter_key="adapter_gate", - ), - _simple_adapter_weight( - base_prefix, linear_fc1.up_lora, adapter_key="adapter_up" - ), - ] - linear_fc2 = getattr(module.mlp, "linear_fc2", None) - if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): - base_prefix = f"{layer_prefix}.mlp.linear_fc2" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, linear_fc2.row_parallel_lora.lora - ) - ] - - shared_experts = getattr(module.mlp, "shared_experts", None) - if shared_experts is not None: - if isinstance(shared_experts.linear_fc1, SharedExpertsLinearFC1LoRA): - base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc1" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, - shared_experts.linear_fc1.gate_lora, - adapter_key="adapter_gate", - ), - _simple_adapter_weight( - base_prefix, - shared_experts.linear_fc1.up_lora, - adapter_key="adapter_up", - ), - ] - if isinstance(shared_experts.linear_fc2, SharedExpertsLinearFC2LoRA): - base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc2" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, - shared_experts.linear_fc2.row_parallel_lora.lora, - ) - ] - return adapter_weights_by_base diff --git a/src/art/megatron/compile_workarounds.py b/src/art/megatron/compile_workarounds.py index 5016c99bb..70e11bcf9 100644 --- a/src/art/megatron/compile_workarounds.py +++ b/src/art/megatron/compile_workarounds.py @@ -1,8 +1,22 @@ from __future__ import annotations +import os +from typing import Any + import torch -_INSTALLED = False +from art.megatron.model_support.spec import CompileWorkaroundConfig + +_INSTALLED_CONFIG: tuple[frozenset[str], str] | None = None + + +def _require_attr(obj: Any, name: str) -> Any: + value = getattr(obj, name, None) + if value is None: + raise RuntimeError( + f"Required compile workaround target is missing: {obj}.{name}" + ) + return value def _disable(fn): @@ -13,26 +27,141 @@ def _disable(fn): return wrapped -def install_torch_compile_workarounds() -> None: - global _INSTALLED - if _INSTALLED: +def _disable_attr(obj: Any, name: str) -> None: + setattr(obj, name, _disable(_require_attr(obj, name))) + + +def _selected_workaround_flags( + config: CompileWorkaroundConfig | None, +) -> set[str]: + raw = os.environ.get("ART_MEGATRON_COMPILE_WORKAROUNDS", "").strip() + if not raw: + return 
set(() if config is None else config.flags) + if raw.lower() in {"none", "off"}: + return set() + return {part.strip() for part in raw.split(",") if part.strip()} + + +def install_torch_compile_workarounds( + config: CompileWorkaroundConfig | None = None, +) -> None: + global _INSTALLED_CONFIG + flags = _selected_workaround_flags(config) + shared_expert_state = "none" if config is None else config.shared_expert_state + installed_config = (frozenset(flags), shared_expert_state) + if _INSTALLED_CONFIG is not None: + if _INSTALLED_CONFIG != installed_config: + raise RuntimeError( + "torch.compile workarounds already installed with a different config" + ) return - from megatron.core.transformer.moe import moe_utils, token_dispatcher - from megatron.core.transformer.moe.moe_layer import MoELayer - - moe_utils.maybe_move_tensor_to_cpu = _disable(moe_utils.maybe_move_tensor_to_cpu) - token_dispatcher.MoEAlltoAllTokenDispatcher._maybe_dtoh_and_synchronize = _disable( - token_dispatcher.MoEAlltoAllTokenDispatcher._maybe_dtoh_and_synchronize - ) - MoELayer.preprocess = _disable(MoELayer.preprocess) - deepep_manager = getattr(token_dispatcher, "_DeepepManager", None) - if deepep_manager is not None: - deepep_manager.dispatch = _disable(deepep_manager.dispatch) - deepep_manager.combine = _disable(deepep_manager.combine) - deepep_manager.get_permuted_hidden_states_by_experts = _disable( - deepep_manager.get_permuted_hidden_states_by_experts - ) - deepep_manager.get_restored_hidden_states_by_experts = _disable( - deepep_manager.get_restored_hidden_states_by_experts - ) - _INSTALLED = True + from megatron.core.extensions import transformer_engine as te_ext + from megatron.core.transformer.moe import experts as moe_experts + from megatron.core.transformer.moe import moe_layer, moe_utils, token_dispatcher + + if "fake_sync_dealloc" in flags: + try: + + @torch.library.register_fake("streams::sync_dealloc") + def _sync_dealloc_fake( + wait_event_index: int, + src_stream_index: int, + to_dealloc: torch.Tensor, + ) -> None: + del wait_event_index, src_stream_index, to_dealloc + return None + except RuntimeError as exc: + if "already has a fake impl registered" not in str(exc): + raise + + deepep_flags = {"deepep_permute_restore", "deepep_dispatch_combine"} & flags + if deepep_flags: + deepep_manager = _require_attr(token_dispatcher, "_DeepepManager") + if "deepep_permute_restore" in flags: + _disable_attr(deepep_manager, "get_permuted_hidden_states_by_experts") + _disable_attr(deepep_manager, "get_restored_hidden_states_by_experts") + if "deepep_dispatch_combine" in flags: + _disable_attr(deepep_manager, "dispatch") + _disable_attr(deepep_manager, "combine") + if "alltoall_dtoh" in flags: + token_dispatcher.MoEAlltoAllTokenDispatcher._maybe_dtoh_and_synchronize = ( + _disable( + token_dispatcher.MoEAlltoAllTokenDispatcher._maybe_dtoh_and_synchronize + ) + ) + if "alltoall_dispatch_preprocess" in flags: + token_dispatcher.MoEAlltoAllTokenDispatcher.dispatch_preprocess = _disable( + token_dispatcher.MoEAlltoAllTokenDispatcher.dispatch_preprocess + ) + if "alltoall_combine_postprocess" in flags: + token_dispatcher.MoEAlltoAllTokenDispatcher.combine_postprocess = _disable( + token_dispatcher.MoEAlltoAllTokenDispatcher.combine_postprocess + ) + if "te_moe_permute_with_probs" in flags: + from transformer_engine.pytorch import permutation as te_permutation + + te_permutation.moe_permute_with_probs = _disable( + te_permutation.moe_permute_with_probs + ) + if te_ext.fused_permute_with_probs is not None: + 
te_ext.fused_permute_with_probs = _disable(te_ext.fused_permute_with_probs) + if moe_utils.fused_permute_with_probs is not None: + moe_utils.fused_permute_with_probs = _disable( + moe_utils.fused_permute_with_probs + ) + if "te_triton_permute_with_mask_map" in flags: + from transformer_engine.pytorch.triton import ( + permutation as te_triton_permutation, + ) + + te_triton_permutation.permute_with_mask_map = _disable( + te_triton_permutation.permute_with_mask_map + ) + if "te_moe_unpermute" in flags: + from transformer_engine.pytorch import permutation as te_permutation + + te_permutation.moe_unpermute = _disable(te_permutation.moe_unpermute) + if te_ext.fused_unpermute is not None: + te_ext.fused_unpermute = _disable(te_ext.fused_unpermute) + if moe_utils.fused_unpermute is not None: + moe_utils.fused_unpermute = _disable(moe_utils.fused_unpermute) + if "moe_utils_permute" in flags: + moe_utils.permute = _disable(moe_utils.permute) + if "moe_utils_unpermute" in flags: + moe_utils.unpermute = _disable(moe_utils.unpermute) + if "te_moe_unpermute_backward" in flags: + from transformer_engine.pytorch import permutation as te_permutation + + setattr( + te_permutation._moe_unpermute_mask_map, + "backward", + staticmethod(_disable(te_permutation._moe_unpermute_mask_map.backward)), + ) + if "te_triton_unpermute_bwd_with_merging_probs" in flags: + from transformer_engine.pytorch.triton import ( + permutation as te_triton_permutation, + ) + + te_triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs = _disable( + te_triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs + ) + if "flex_token_dispatch_combine" in flags: + token_dispatcher.MoEFlexTokenDispatcher.token_dispatch = _disable( + token_dispatcher.MoEFlexTokenDispatcher.token_dispatch + ) + token_dispatcher.MoEFlexTokenDispatcher.token_combine = _disable( + token_dispatcher.MoEFlexTokenDispatcher.token_combine + ) + if "moe_preprocess" in flags: + moe_layer.MoELayer.preprocess = _disable(moe_layer.MoELayer.preprocess) + if "moe_forward" in flags: + moe_layer.MoELayer.forward = _disable(moe_layer.MoELayer.forward) + if "moe_routed_experts_compute" in flags: + moe_layer.MoELayer.routed_experts_compute = _disable( + moe_layer.MoELayer.routed_experts_compute + ) + if "grouped_mlp_forward" in flags: + _disable_attr(_require_attr(moe_experts, "GroupedMLP"), "forward") + if "te_grouped_mlp_forward" in flags: + moe_experts.TEGroupedMLP.forward = _disable(moe_experts.TEGroupedMLP.forward) + _INSTALLED_CONFIG = installed_config diff --git a/src/art/megatron/context_parallel/__init__.py b/src/art/megatron/context_parallel/__init__.py new file mode 100644 index 000000000..4818a0639 --- /dev/null +++ b/src/art/megatron/context_parallel/__init__.py @@ -0,0 +1 @@ +"""Minimal context-parallel shared types used by GDN planning.""" diff --git a/src/art/megatron/context_parallel/layout_index.py b/src/art/megatron/context_parallel/layout_index.py new file mode 100644 index 000000000..99fb2c35b --- /dev/null +++ b/src/art/megatron/context_parallel/layout_index.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict + + +class TokenLayoutIndex(BaseModel): + model_config = ConfigDict(frozen=True) + + ownership_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] + token_counts_by_rank: tuple[int, ...] 
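
The new TokenLayoutIndex added above is a frozen value type, so instances built during GDN context-parallel planning can be shared freely without defensive copies. Below is a minimal, self-contained construction sketch: the class body is repeated verbatim from the diff so the snippet runs on its own with pydantic installed, while the concrete values and the assumed (start, end, owner_rank) meaning of each ownership triple are illustrative assumptions, not something this change specifies.

from pydantic import BaseModel, ConfigDict


class TokenLayoutIndex(BaseModel):
    model_config = ConfigDict(frozen=True)

    ownership_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...]
    token_counts_by_rank: tuple[int, ...]


# Two context-parallel ranks, each owning one contiguous 512-token range.
# The (start, end, owner_rank) reading of each triple is an assumption for illustration.
index = TokenLayoutIndex(
    ownership_ranges_by_rank=(((0, 512, 0),), ((512, 1024, 1),)),
    token_counts_by_rank=(512, 512),
)
# frozen=True makes instances immutable (and hashable, since all fields are tuples of ints).
assert index.token_counts_by_rank == (512, 512)
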
diff --git a/src/art/megatron/flex_attention.py b/src/art/megatron/flex_attention.py index e5f583d37..690300762 100644 --- a/src/art/megatron/flex_attention.py +++ b/src/art/megatron/flex_attention.py @@ -1,8 +1,7 @@ """Flex attention plumbing for ART's Megatron backend.""" -from collections.abc import Callable import math -from typing import Any, ClassVar, TypeAlias, cast +from typing import Any, ClassVar, cast from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection @@ -25,32 +24,19 @@ class SharedPrefixAttentionState(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) block_mask: BlockMask - - -CompileOptions: TypeAlias = dict[str, str | int | bool | Callable[..., Any]] + group_ids: Tensor + parent_ids: Tensor class FlexAttentionWrapper(torch.nn.Module): """Compiled `flex_attention` wrapper with Torchtitan-style inductor options.""" - # Torchtitan inductor options for compiling flex attention. - _compile_options: ClassVar[CompileOptions] = { - "max_autotune": True, - "coordinate_descent_tuning": True, - "triton.cudagraphs": False, - } - # Skip Inductor's flex_decoding specialization: it has triggered both - # shared-memory OOMs (triton_flex_decoding) and symbolic-shape assertion - # failures (create_flex_decoding_kernel). The regular flex_attention - # kernel autotunes against the actual hardware smem budget, so this - # stays GPU-agnostic. + # Force the regular flex kernel. The flex-decoding specialization has hit + # shared-memory OOMs and symbolic-shape assertions on long packed training sequences. _kernel_options: ClassVar[FlexKernelOptions] = { "FORCE_USE_FLEX_ATTENTION": True, } - _compiled_flex_attention: ClassVar = torch.compile( - flex_attention, - options=_compile_options, - ) + _compiled_flex_attention: ClassVar = torch.compile(flex_attention) def forward( self, @@ -117,7 +103,11 @@ def _shared_prefix_mask( group_ids.shape[1], device=group_ids.device, ) - return SharedPrefixAttentionState(block_mask=block_mask) + return SharedPrefixAttentionState( + block_mask=block_mask, + group_ids=group_ids, + parent_ids=parent_ids, + ) class FlexDotProductAttention(torch.nn.Module): diff --git a/src/art/megatron/gdn/__init__.py b/src/art/megatron/gdn/__init__.py new file mode 100644 index 000000000..0c62a558d --- /dev/null +++ b/src/art/megatron/gdn/__init__.py @@ -0,0 +1,15 @@ +"""ART helpers for Megatron GatedDeltaNet integration.""" + +from .gdn_shared_prefix import ( + GdnPackedExecutionSpec, + GdnPackedFamilySpec, + GdnSegmentSpec, + parse_gdn_shared_prefix_segments, +) + +__all__ = [ + "GdnPackedExecutionSpec", + "GdnPackedFamilySpec", + "GdnSegmentSpec", + "parse_gdn_shared_prefix_segments", +] diff --git a/src/art/megatron/gdn/conv_gelu.py b/src/art/megatron/gdn/conv_gelu.py new file mode 100644 index 000000000..2da562d3b --- /dev/null +++ b/src/art/megatron/gdn/conv_gelu.py @@ -0,0 +1,1283 @@ +from __future__ import annotations + +from enum import IntEnum +from typing import Any, cast + +import torch +from torch import Tensor +import triton +import triton.language as tl + + +class PackedConvActivation(IntEnum): + NONE = 0 + SILU = 1 + SWISH = 1 + GELU = 2 + + +@triton.jit +def _gelu(x): + return 0.5 * x * (1.0 + tl.erf(x * 0.70710678118654752440)) + + +@triton.jit +def _gelu_grad(x): + cdf = 0.5 * (1.0 + tl.erf(x * 0.70710678118654752440)) + pdf = 0.39894228040143267794 * tl.exp(-0.5 * x * x) + return cdf + x * pdf + + +@triton.jit +def _apply_activation(x, ACTIVATION: tl.constexpr): + if 
ACTIVATION == 0: + return x + if ACTIVATION == 1: + sigmoid = tl.sigmoid(x) + return x * sigmoid + return _gelu(x) + + +@triton.jit +def _activation_grad(x, ACTIVATION: tl.constexpr): + if ACTIVATION == 0: + return x * 0.0 + 1.0 + if ACTIVATION == 1: + sigmoid = tl.sigmoid(x) + return sigmoid + x * sigmoid * (1.0 - sigmoid) + return _gelu_grad(x) + + +@triton.jit(do_not_specialize=["SEGMENTS"]) +def _segment_for_token( + cu_seqlens, + token, + SEGMENTS, + SEARCH_STEPS: tl.constexpr, +): + lo = tl.zeros(token.shape, dtype=tl.int64) + hi = lo + SEGMENTS.to(tl.int64) - 1 + for _ in tl.static_range(0, SEARCH_STEPS): + mid = (lo + hi + 1) // 2 + mid_start = tl.load(cu_seqlens + mid) + take_upper = mid_start <= token + lo = tl.where(take_upper, mid, lo) + hi = tl.where(take_upper, hi, mid - 1) + return lo + + +@triton.jit(do_not_specialize=["TOTAL_TOKENS", "SEGMENTS"]) +def _packed_conv_token_metadata_kernel( + cu_seqlens, + token_segment, + token_local_t, + TOTAL_TOKENS, + SEGMENTS, + SEARCH_STEPS: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_n = tl.program_id(0) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + token = offs_n.to(tl.int64) + mask = offs_n < TOTAL_TOKENS + segment = _segment_for_token(cu_seqlens, token, SEGMENTS, SEARCH_STEPS) + start = tl.load(cu_seqlens + segment).to(tl.int64) + tl.store(token_segment + token, segment, mask=mask) + tl.store(token_local_t + token, token - start, mask=mask) + + +@triton.jit +def _conv_gelu_fwd_kernel( + qkv, + conv_initial, + weight, + bias, + lengths, + out, + final, + C: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + HAS_BIAS: tl.constexpr, + OUTPUT_FINAL: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_T: tl.constexpr, +): + pid_t = tl.program_id(0) + pid_c = tl.program_id(1) + b = tl.program_id(2) + tail: tl.constexpr = K - 1 + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + offs_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + c = offs_c[:, None] + t = offs_t[None, :] + b64 = b.to(tl.int64) + c64 = c.to(tl.int64) + t64 = t.to(tl.int64) + offs_c64 = offs_c.to(tl.int64) + mask = (offs_c[:, None] < C) & (offs_t[None, :] < T) + acc = tl.zeros((BLOCK_C, BLOCK_T), dtype=tl.float32) + if HAS_BIAS: + acc += tl.load(bias + offs_c, mask=offs_c < C, other=0.0)[:, None].to( + tl.float32 + ) + for j in tl.static_range(0, K): + ext = t + j + ext64 = ext.to(tl.int64) + from_initial = ext < tail + init_idx = (b64 * C + c64) * tail + ext64 + qkv_idx = (b64 * C + c64) * T + (ext64 - tail) + x_init = tl.load(conv_initial + init_idx, mask=mask & from_initial, other=0.0) + x_qkv = tl.load(qkv + qkv_idx, mask=mask & ~from_initial, other=0.0) + w = tl.load(weight + offs_c * K + j, mask=offs_c < C, other=0.0).to(tl.float32) + acc += (x_init + x_qkv).to(tl.float32) * w[:, None] + tl.store(out + (b64 * C + c64) * T + t64, _gelu(acc), mask=mask) + + if OUTPUT_FINAL: + length = tl.load(lengths + b) + for r in tl.static_range(0, tail): + ext = length + r + ext64 = ext.to(tl.int64) + from_initial = ext < tail + init_idx = (b64 * C + offs_c64) * tail + ext64 + qkv_idx = (b64 * C + offs_c64) * T + (ext64 - tail) + x_init = tl.load( + conv_initial + init_idx, + mask=(pid_t == 0) & (offs_c < C) & from_initial, + other=0.0, + ) + x_qkv = tl.load( + qkv + qkv_idx, + mask=(pid_t == 0) & (offs_c < C) & ~from_initial, + other=0.0, + ) + tl.store( + final + (b64 * C + offs_c64) * tail + r, + x_init + x_qkv, + mask=(pid_t == 0) & (offs_c < C), + ) + + +@triton.jit +def _conv_gelu_grad_preact_kernel( + qkv, + conv_initial, + weight, + bias, + grad_out, + grad_preact, + 
C: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_T: tl.constexpr, +): + pid_t = tl.program_id(0) + pid_c = tl.program_id(1) + b = tl.program_id(2) + tail: tl.constexpr = K - 1 + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + offs_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + c = offs_c[:, None] + t = offs_t[None, :] + b64 = b.to(tl.int64) + c64 = c.to(tl.int64) + t64 = t.to(tl.int64) + mask = (offs_c[:, None] < C) & (offs_t[None, :] < T) + acc = tl.zeros((BLOCK_C, BLOCK_T), dtype=tl.float32) + if HAS_BIAS: + acc += tl.load(bias + offs_c, mask=offs_c < C, other=0.0)[:, None].to( + tl.float32 + ) + for j in tl.static_range(0, K): + ext = t + j + ext64 = ext.to(tl.int64) + from_initial = ext < tail + init_idx = (b64 * C + c64) * tail + ext64 + qkv_idx = (b64 * C + c64) * T + (ext64 - tail) + x_init = tl.load(conv_initial + init_idx, mask=mask & from_initial, other=0.0) + x_qkv = tl.load(qkv + qkv_idx, mask=mask & ~from_initial, other=0.0) + w = tl.load(weight + offs_c * K + j, mask=offs_c < C, other=0.0).to(tl.float32) + acc += (x_init + x_qkv).to(tl.float32) * w[:, None] + out_idx = (b64 * C + c64) * T + t64 + go = tl.load(grad_out + out_idx, mask=mask, other=0.0).to(tl.float32) + tl.store(grad_preact + out_idx, go * _gelu_grad(acc), mask=mask) + + +@triton.jit +def _conv_gelu_bwd_input_kernel( + grad_preact, + weight, + lengths, + grad_final, + grad_qkv, + grad_initial, + C: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + HAS_FINAL_GRAD: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_E: tl.constexpr, +): + pid_e = tl.program_id(0) + pid_c = tl.program_id(1) + b = tl.program_id(2) + tail: tl.constexpr = K - 1 + ext_len: tl.constexpr = T + K - 1 + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + offs_e = pid_e * BLOCK_E + tl.arange(0, BLOCK_E) + c = offs_c[:, None] + e = offs_e[None, :] + b64 = b.to(tl.int64) + c64 = c.to(tl.int64) + e64 = e.to(tl.int64) + mask = (offs_c[:, None] < C) & (offs_e[None, :] < ext_len) + acc = tl.zeros((BLOCK_C, BLOCK_E), dtype=tl.float32) + for j in tl.static_range(0, K): + t = e - j + t64 = t.to(tl.int64) + valid = mask & (t >= 0) & (t < T) + gz = tl.load(grad_preact + (b64 * C + c64) * T + t64, mask=valid, other=0.0) + w = tl.load(weight + offs_c * K + j, mask=offs_c < C, other=0.0).to(tl.float32) + acc += gz.to(tl.float32) * w[:, None] + if HAS_FINAL_GRAD: + length = tl.load(lengths + b) + r = e - length + r64 = r.to(tl.int64) + valid_final = mask & (r >= 0) & (r < tail) + gf = tl.load( + grad_final + (b64 * C + c64) * tail + r64, + mask=valid_final, + other=0.0, + ) + acc += gf.to(tl.float32) + + init_mask = mask & (e < tail) + qkv_mask = mask & (e >= tail) + tl.store(grad_initial + (b64 * C + c64) * tail + e64, acc, mask=init_mask) + tl.store(grad_qkv + (b64 * C + c64) * T + (e64 - tail), acc, mask=qkv_mask) + + +@triton.jit +def _conv_gelu_bwd_weight_kernel( + qkv, + conv_initial, + grad_preact, + grad_weight, + grad_bias, + C: tl.constexpr, + B: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_BT: tl.constexpr, +): + c = tl.program_id(0) + tail: tl.constexpr = K - 1 + bt_total: tl.constexpr = B * T + offsets = tl.arange(0, BLOCK_BT) + bias_acc = tl.zeros((BLOCK_BT,), dtype=tl.float32) + for j in tl.static_range(0, K): + weight_acc = tl.zeros((BLOCK_BT,), dtype=tl.float32) + for start in range(0, bt_total, BLOCK_BT): + bt = start + offsets + mask = bt < bt_total + b = bt // T + t = bt - b * T + b64 = b.to(tl.int64) + t64 = t.to(tl.int64) + c64 = 
c.to(tl.int64) + gz = tl.load(grad_preact + (b64 * C + c64) * T + t64, mask=mask, other=0.0) + ext = t + j + ext64 = ext.to(tl.int64) + from_initial = ext < tail + init_idx = (b64 * C + c64) * tail + ext64 + qkv_idx = (b64 * C + c64) * T + (ext64 - tail) + x_init = tl.load( + conv_initial + init_idx, mask=mask & from_initial, other=0.0 + ) + x_qkv = tl.load(qkv + qkv_idx, mask=mask & ~from_initial, other=0.0) + weight_acc += gz.to(tl.float32) * (x_init + x_qkv).to(tl.float32) + if HAS_BIAS and j == 0: + bias_acc += gz.to(tl.float32) + tl.store(grad_weight + c * K + j, tl.sum(weight_acc, axis=0)) + if HAS_BIAS: + tl.store(grad_bias + c, tl.sum(bias_acc, axis=0)) + + +@triton.jit(do_not_specialize=["TOTAL_TOKENS"]) +def _packed_conv_fwd_kernel( + conv_in, + token_segment, + token_local_t, + conv_initial, + weight, + bias, + out, + C: tl.constexpr, + TOTAL_TOKENS, + K: tl.constexpr, + HAS_BIAS: tl.constexpr, + ACTIVATION: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + tail: tl.constexpr = K - 1 + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + token = offs_n.to(tl.int64) + segment = tl.load(token_segment + token, mask=offs_n < TOTAL_TOKENS, other=0).to( + tl.int64 + ) + local_t = tl.load(token_local_t + token, mask=offs_n < TOTAL_TOKENS, other=0).to( + tl.int64 + ) + n = offs_n[:, None].to(tl.int64) + c = offs_c[None, :].to(tl.int64) + segment_bc = segment[:, None].to(tl.int64) + local_t_bc = local_t[:, None] + mask = (offs_n[:, None] < TOTAL_TOKENS) & (offs_c[None, :] < C) + acc = tl.zeros((BLOCK_N, BLOCK_C), dtype=tl.float32) + if HAS_BIAS: + acc += tl.load(bias + offs_c, mask=offs_c < C, other=0.0)[None, :].to( + tl.float32 + ) + for j in tl.static_range(0, K): + ext = local_t_bc + j + from_initial = ext < tail + init_idx = (segment_bc * C + c) * tail + ext + in_idx = (n + j - tail) * C + c + x_init = tl.load(conv_initial + init_idx, mask=mask & from_initial, other=0.0) + x_in = tl.load(conv_in + in_idx, mask=mask & ~from_initial, other=0.0) + w = tl.load(weight + offs_c * K + j, mask=offs_c < C, other=0.0).to(tl.float32) + acc += (x_init + x_in).to(tl.float32) * w[None, :] + tl.store(out + n * C + c, _apply_activation(acc, ACTIVATION), mask=mask) + + +@triton.jit +def _packed_conv_final_kernel( + conv_in, + cu_seqlens, + conv_initial, + final, + C: tl.constexpr, + K: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_R: tl.constexpr, +): + pid_r = tl.program_id(0) + pid_c = tl.program_id(1) + segment = tl.program_id(2) + tail: tl.constexpr = K - 1 + offs_r = pid_r * BLOCK_R + tl.arange(0, BLOCK_R) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + start = tl.load(cu_seqlens + segment).to(tl.int64) + end = tl.load(cu_seqlens + segment + 1).to(tl.int64) + length = end - start + r = offs_r[:, None].to(tl.int64) + c = offs_c[None, :].to(tl.int64) + ext = length + r + from_initial = ext < tail + mask = (offs_r[:, None] < tail) & (offs_c[None, :] < C) + init_idx = (segment.to(tl.int64) * C + c) * tail + ext + in_idx = (start + ext - tail) * C + c + x_init = tl.load(conv_initial + init_idx, mask=mask & from_initial, other=0.0) + x_in = tl.load(conv_in + in_idx, mask=mask & ~from_initial, other=0.0) + tl.store( + final + (segment.to(tl.int64) * C + c) * tail + r, + x_init + x_in, + mask=mask, + ) + + +@triton.jit(do_not_specialize=["TOTAL_TOKENS"]) +def _packed_conv_grad_preact_weight_partial_kernel( + conv_in, + token_segment, + token_local_t, + conv_initial, + weight, 
+ bias, + grad_out, + grad_preact, + grad_weight_partial, + grad_bias_partial, + C: tl.constexpr, + TOTAL_TOKENS, + CHANNEL_TILES: tl.constexpr, + K: tl.constexpr, + HAS_BIAS: tl.constexpr, + ACTIVATION: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + tail: tl.constexpr = K - 1 + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + token = offs_n.to(tl.int64) + segment = tl.load(token_segment + token, mask=offs_n < TOTAL_TOKENS, other=0).to( + tl.int64 + ) + local_t = tl.load(token_local_t + token, mask=offs_n < TOTAL_TOKENS, other=0).to( + tl.int64 + ) + n = offs_n[:, None].to(tl.int64) + c = offs_c[None, :].to(tl.int64) + segment_bc = segment[:, None].to(tl.int64) + local_t_bc = local_t[:, None] + mask = (offs_n[:, None] < TOTAL_TOKENS) & (offs_c[None, :] < C) + acc = tl.zeros((BLOCK_N, BLOCK_C), dtype=tl.float32) + if HAS_BIAS: + acc += tl.load(bias + offs_c, mask=offs_c < C, other=0.0)[None, :].to( + tl.float32 + ) + for j in tl.static_range(0, K): + ext = local_t_bc + j + from_initial = ext < tail + init_idx = (segment_bc * C + c) * tail + ext + in_idx = (n + j - tail) * C + c + x_init = tl.load(conv_initial + init_idx, mask=mask & from_initial, other=0.0) + x_in = tl.load(conv_in + in_idx, mask=mask & ~from_initial, other=0.0) + w = tl.load(weight + offs_c * K + j, mask=offs_c < C, other=0.0).to(tl.float32) + acc += (x_init + x_in).to(tl.float32) * w[None, :] + go = tl.load(grad_out + n * C + c, mask=mask, other=0.0).to(tl.float32) + gz = go * _activation_grad(acc, ACTIVATION) + tl.store( + grad_preact + n * C + c, + gz, + mask=mask, + ) + partial_base = (pid_n * CHANNEL_TILES + pid_c) * K * BLOCK_C + partial_c = tl.arange(0, BLOCK_C) + for j in tl.static_range(0, K): + ext = local_t_bc + j + from_initial = ext < tail + init_idx = (segment_bc * C + c) * tail + ext + in_idx = (n + j - tail) * C + c + x_init = tl.load(conv_initial + init_idx, mask=mask & from_initial, other=0.0) + x_in = tl.load(conv_in + in_idx, mask=mask & ~from_initial, other=0.0) + weight_partial = tl.sum(gz * (x_init + x_in).to(tl.float32), axis=0) + tl.store( + grad_weight_partial + partial_base + j * BLOCK_C + partial_c, + weight_partial, + mask=offs_c < C, + ) + if HAS_BIAS: + bias_partial = tl.sum(gz, axis=0) + tl.store( + grad_bias_partial + (pid_n * CHANNEL_TILES + pid_c) * BLOCK_C + partial_c, + bias_partial, + mask=offs_c < C, + ) + + +@triton.jit(do_not_specialize=["TOTAL_TOKENS"]) +def _packed_conv_bwd_input_kernel( + cu_seqlens, + token_segment, + weight, + grad_preact, + grad_final, + grad_conv_in, + C: tl.constexpr, + TOTAL_TOKENS, + K: tl.constexpr, + HAS_FINAL_GRAD: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + tail: tl.constexpr = K - 1 + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + token = offs_n.to(tl.int64) + segment = tl.load(token_segment + token, mask=offs_n < TOTAL_TOKENS, other=0).to( + tl.int64 + ) + end = tl.load(cu_seqlens + segment + 1).to(tl.int64) + out_token_base = token[:, None] + tail + c = offs_c[None, :].to(tl.int64) + mask = (offs_n[:, None] < TOTAL_TOKENS) & (offs_c[None, :] < C) + acc = tl.zeros((BLOCK_N, BLOCK_C), dtype=tl.float32) + for j in tl.static_range(0, K): + out_token = out_token_base - j + valid = mask & (out_token < end[:, None]) + gz = tl.load( + grad_preact + out_token * C + c, + mask=valid, + other=0.0, + ) 
+ w = tl.load(weight + offs_c * K + j, mask=offs_c < C, other=0.0).to(tl.float32) + acc += gz.to(tl.float32) * w[None, :] + if HAS_FINAL_GRAD: + r = out_token_base - end[:, None] + valid_final = mask & (r >= 0) & (r < tail) + gf = tl.load( + grad_final + (segment[:, None].to(tl.int64) * C + c) * tail + r, + mask=valid_final, + other=0.0, + ) + acc += gf.to(tl.float32) + tl.store(grad_conv_in + token[:, None] * C + c, acc, mask=mask) + + +@triton.jit +def _packed_conv_bwd_initial_kernel( + cu_seqlens, + weight, + grad_preact, + grad_final, + grad_initial, + C: tl.constexpr, + K: tl.constexpr, + HAS_FINAL_GRAD: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_R: tl.constexpr, +): + pid_r = tl.program_id(0) + pid_c = tl.program_id(1) + segment = tl.program_id(2) + tail: tl.constexpr = K - 1 + offs_r = pid_r * BLOCK_R + tl.arange(0, BLOCK_R) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + start = tl.load(cu_seqlens + segment).to(tl.int64) + end = tl.load(cu_seqlens + segment + 1).to(tl.int64) + length = end - start + e = offs_r[:, None].to(tl.int64) + c = offs_c[None, :].to(tl.int64) + mask = (offs_r[:, None] < tail) & (offs_c[None, :] < C) + acc = tl.zeros((BLOCK_R, BLOCK_C), dtype=tl.float32) + for j in tl.static_range(0, K): + out_t = e - j + valid = mask & (out_t >= 0) & (out_t < length) + gz = tl.load(grad_preact + (start + out_t) * C + c, mask=valid, other=0.0) + w = tl.load(weight + offs_c * K + j, mask=offs_c < C, other=0.0).to(tl.float32) + acc += gz.to(tl.float32) * w[None, :] + if HAS_FINAL_GRAD: + r = e - length + valid_final = mask & (r >= 0) & (r < tail) + gf = tl.load( + grad_final + (segment.to(tl.int64) * C + c) * tail + r, + mask=valid_final, + other=0.0, + ) + acc += gf.to(tl.float32) + tl.store(grad_initial + (segment.to(tl.int64) * C + c) * tail + e, acc, mask=mask) + + +@triton.jit(do_not_specialize=["TOKEN_TILES"]) +def _packed_conv_bwd_weight_reduce_kernel( + grad_weight_partial, + grad_weight, + C: tl.constexpr, + TOKEN_TILES, + CHANNEL_TILES: tl.constexpr, + K: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_TILES: tl.constexpr, +): + pid_c = tl.program_id(0) + j = tl.program_id(1) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + c_mask = offs_c < C + partial_c = tl.arange(0, BLOCK_C) + tile_offsets = tl.arange(0, BLOCK_TILES) + weight_acc = tl.zeros((BLOCK_TILES, BLOCK_C), dtype=tl.float32) + start_tile = 0 + while start_tile < TOKEN_TILES: + tile = start_tile + tile_offsets + partial_idx = ( + (tile[:, None] * CHANNEL_TILES + pid_c) * K + j + ) * BLOCK_C + partial_c[None, :] + weight_acc += tl.load( + grad_weight_partial + partial_idx, + mask=(tile[:, None] < TOKEN_TILES) & c_mask[None, :], + other=0.0, + ) + start_tile += BLOCK_TILES + tl.store(grad_weight + offs_c * K + j, tl.sum(weight_acc, axis=0), mask=c_mask) + + +@triton.jit(do_not_specialize=["TOKEN_TILES"]) +def _packed_conv_bwd_bias_reduce_kernel( + grad_bias_partial, + grad_bias, + C: tl.constexpr, + TOKEN_TILES, + CHANNEL_TILES: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_TILES: tl.constexpr, +): + pid_c = tl.program_id(0) + offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + c_mask = offs_c < C + partial_c = tl.arange(0, BLOCK_C) + tile_offsets = tl.arange(0, BLOCK_TILES) + bias_acc = tl.zeros((BLOCK_TILES, BLOCK_C), dtype=tl.float32) + start_tile = 0 + while start_tile < TOKEN_TILES: + tile = start_tile + tile_offsets + partial_idx = (tile[:, None] * CHANNEL_TILES + pid_c) * BLOCK_C + partial_c[ + None, : + ] + bias_acc += tl.load( + grad_bias_partial + partial_idx, + mask=(tile[:, None] < 
TOKEN_TILES) & c_mask[None, :], + other=0.0, + ) + start_tile += BLOCK_TILES + tl.store(grad_bias + offs_c, tl.sum(bias_acc, axis=0), mask=c_mask) + + +class _VarlenCausalConvGelu(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + qkv: Tensor, + conv_initial: Tensor, + weight: Tensor, + bias: Tensor | None, + lengths: Tensor, + output_final_state: bool, + ) -> tuple[Tensor, Tensor | None]: + _validate_inputs(qkv, conv_initial, weight, bias, lengths) + qkv = qkv.contiguous() + conv_initial = conv_initial.contiguous() + weight = weight.contiguous() + bias_tensor = ( + bias.contiguous() + if bias is not None + else torch.empty((0,), device=qkv.device, dtype=qkv.dtype) + ) + lengths = lengths.contiguous() + batch, channels, max_len = qkv.shape + kernel_width = int(weight.shape[1]) + out = torch.empty_like(qkv) + final = ( + torch.empty( + (batch, channels, kernel_width - 1), + device=qkv.device, + dtype=qkv.dtype, + ) + if output_final_state + else None + ) + block_c, block_t, num_warps = _tile_config(channels, max_len) + grid = (triton.cdiv(max_len, block_t), triton.cdiv(channels, block_c), batch) + cast(Any, _conv_gelu_fwd_kernel)[grid]( + qkv, + conv_initial, + weight, + bias_tensor, + lengths, + out, + out if final is None else final, + channels, + max_len, + kernel_width, + HAS_BIAS=bias is not None, + OUTPUT_FINAL=output_final_state, + BLOCK_C=block_c, + BLOCK_T=block_t, + num_warps=num_warps, + ) + ctx.save_for_backward(qkv, conv_initial, weight, bias_tensor, lengths) + ctx.has_bias = bias is not None + ctx.output_final_state = bool(output_final_state) + ctx.tile = (block_c, block_t, num_warps) + return out, final + + @staticmethod + def backward( + ctx: Any, *grad_outputs: Any + ) -> tuple[Tensor, Tensor, Tensor, Tensor | None, None, None]: + grad_out, grad_final = grad_outputs + qkv, conv_initial, weight, bias, lengths = ctx.saved_tensors + grad_out = grad_out.contiguous() + grad_final_tensor = ( + grad_final.contiguous() + if grad_final is not None + else torch.empty((0,), device=qkv.device, dtype=qkv.dtype) + ) + batch, channels, max_len = qkv.shape + kernel_width = int(weight.shape[1]) + grad_qkv = torch.empty_like(qkv) + grad_initial = torch.empty_like(conv_initial) + grad_weight = torch.empty_like(weight) + grad_bias = torch.empty_like(bias) if bool(ctx.has_bias) else None + grad_preact = torch.empty(qkv.shape, device=qkv.device, dtype=torch.float32) + block_c, block_t, num_warps = ctx.tile + grid_t = ( + triton.cdiv(max_len, block_t), + triton.cdiv(channels, block_c), + batch, + ) + cast(Any, _conv_gelu_grad_preact_kernel)[grid_t]( + qkv, + conv_initial, + weight, + bias, + grad_out, + grad_preact, + channels, + max_len, + kernel_width, + HAS_BIAS=bool(ctx.has_bias), + BLOCK_C=block_c, + BLOCK_T=block_t, + num_warps=num_warps, + ) + ext_len = max_len + kernel_width - 1 + grid_e = ( + triton.cdiv(ext_len, block_t), + triton.cdiv(channels, block_c), + batch, + ) + cast(Any, _conv_gelu_bwd_input_kernel)[grid_e]( + grad_preact, + weight, + lengths, + grad_final_tensor, + grad_qkv, + grad_initial, + channels, + max_len, + kernel_width, + HAS_FINAL_GRAD=grad_final is not None, + BLOCK_C=block_c, + BLOCK_E=block_t, + num_warps=num_warps, + ) + reduce_block = 1024 + cast(Any, _conv_gelu_bwd_weight_kernel)[(channels,)]( + qkv, + conv_initial, + grad_preact, + grad_weight, + grad_bias if grad_bias is not None else grad_weight, + channels, + batch, + max_len, + kernel_width, + HAS_BIAS=bool(ctx.has_bias), + BLOCK_BT=reduce_block, + num_warps=8, + ) + return grad_qkv, 
grad_initial, grad_weight, grad_bias, None, None + + +class _PackedVarlenCausalConv(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + conv_in: Tensor, + cu_seqlens: Tensor, + conv_initial: Tensor, + weight: Tensor, + bias: Tensor | None, + output_final_state: bool, + activation: str | PackedConvActivation, + ) -> tuple[Tensor, Tensor | None]: + activation_code = _activation_code(activation) + _validate_packed_inputs(conv_in, cu_seqlens, conv_initial, weight, bias) + conv_in = conv_in.contiguous() + cu_seqlens = cu_seqlens.contiguous() + conv_initial = conv_initial.contiguous() + weight = weight.contiguous() + bias_tensor = ( + bias.contiguous() + if bias is not None + else torch.empty((0,), device=conv_in.device, dtype=conv_in.dtype) + ) + _assert_valid_cu_seqlens(cu_seqlens, int(conv_in.shape[0])) + total_tokens, channels = conv_in.shape + segments = int(cu_seqlens.numel()) - 1 + kernel_width = int(weight.shape[1]) + out = torch.empty_like(conv_in) + final = ( + torch.empty( + (segments, channels, kernel_width - 1), + device=conv_in.device, + dtype=conv_in.dtype, + ) + if output_final_state + else None + ) + block_n, block_c, num_warps = _packed_tile_config(channels) + search_steps = _search_steps(segments) + metadata_dtype = ( + torch.long + if max(total_tokens, segments) > torch.iinfo(torch.int32).max + else torch.int32 + ) + token_segment = torch.empty( + (total_tokens,), device=conv_in.device, dtype=metadata_dtype + ) + token_local_t = torch.empty_like(token_segment) + if total_tokens > 0: + metadata_block_n = 256 + cast(Any, _packed_conv_token_metadata_kernel)[ + (triton.cdiv(total_tokens, metadata_block_n),) + ]( + cu_seqlens, + token_segment, + token_local_t, + total_tokens, + segments, + search_steps, + BLOCK_N=metadata_block_n, + num_warps=4, + ) + cast(Any, _packed_conv_fwd_kernel)[ + (triton.cdiv(total_tokens, block_n), triton.cdiv(channels, block_c)) + ]( + conv_in, + token_segment, + token_local_t, + conv_initial, + weight, + bias_tensor, + out, + channels, + total_tokens, + kernel_width, + HAS_BIAS=bias is not None, + ACTIVATION=activation_code, + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + ) + if final is not None and kernel_width > 1 and segments > 0: + block_r = _tail_block(kernel_width - 1) + cast(Any, _packed_conv_final_kernel)[ + ( + triton.cdiv(kernel_width - 1, block_r), + triton.cdiv(channels, block_c), + segments, + ) + ]( + conv_in, + cu_seqlens, + conv_initial, + final, + channels, + kernel_width, + BLOCK_C=block_c, + BLOCK_R=block_r, + num_warps=num_warps, + ) + ctx.save_for_backward( + conv_in, + cu_seqlens, + token_segment, + token_local_t, + conv_initial, + weight, + bias_tensor, + ) + ctx.has_bias = bias is not None + ctx.has_final = bool(output_final_state) + ctx.activation = activation_code + ctx.tile = (block_n, block_c, num_warps) + return out, final + + @staticmethod + def backward( + ctx: Any, *grad_outputs: Any + ) -> tuple[Tensor, None, Tensor, Tensor, Tensor | None, None, None]: + grad_out, grad_final = grad_outputs + ( + conv_in, + cu_seqlens, + token_segment, + token_local_t, + conv_initial, + weight, + bias, + ) = ctx.saved_tensors + grad_out = grad_out.contiguous() + grad_final_tensor = ( + grad_final.contiguous() + if grad_final is not None + else torch.empty((0,), device=conv_in.device, dtype=conv_in.dtype) + ) + total_tokens, channels = conv_in.shape + segments = int(cu_seqlens.numel()) - 1 + kernel_width = int(weight.shape[1]) + grad_conv_in = torch.empty_like(conv_in) + grad_initial = 
torch.empty_like(conv_initial) + grad_weight = torch.empty_like(weight) + grad_bias = torch.empty_like(bias) if bool(ctx.has_bias) else None + block_n, block_c, num_warps = ctx.tile + grad_preact = torch.empty( + conv_in.shape, device=conv_in.device, dtype=torch.float32 + ) + if total_tokens > 0: + token_tiles = triton.cdiv(total_tokens, block_n) + channel_tiles = triton.cdiv(channels, block_c) + grad_weight_partial = torch.empty( + (token_tiles, channel_tiles, kernel_width, block_c), + device=conv_in.device, + dtype=torch.float32, + ) + grad_bias_partial = ( + torch.empty( + (token_tiles, channel_tiles, block_c), + device=conv_in.device, + dtype=torch.float32, + ) + if bool(ctx.has_bias) + else torch.empty((0,), device=conv_in.device, dtype=torch.float32) + ) + grid_n = ( + token_tiles, + channel_tiles, + ) + cast(Any, _packed_conv_grad_preact_weight_partial_kernel)[grid_n]( + conv_in, + token_segment, + token_local_t, + conv_initial, + weight, + bias, + grad_out, + grad_preact, + grad_weight_partial, + grad_bias_partial, + channels, + total_tokens, + channel_tiles, + kernel_width, + HAS_BIAS=bool(ctx.has_bias), + ACTIVATION=int(ctx.activation), + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + ) + cast(Any, _packed_conv_bwd_input_kernel)[grid_n]( + cu_seqlens, + token_segment, + weight, + grad_preact, + grad_final_tensor, + grad_conv_in, + channels, + total_tokens, + kernel_width, + HAS_FINAL_GRAD=grad_final is not None, + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + ) + cast(Any, _packed_conv_bwd_weight_reduce_kernel)[ + (channel_tiles, kernel_width) + ]( + grad_weight_partial, + grad_weight, + channels, + token_tiles, + channel_tiles, + kernel_width, + BLOCK_C=block_c, + BLOCK_TILES=64, + num_warps=4, + ) + if grad_bias is not None: + cast(Any, _packed_conv_bwd_bias_reduce_kernel)[(channel_tiles,)]( + grad_bias_partial, + grad_bias, + channels, + token_tiles, + channel_tiles, + BLOCK_C=block_c, + BLOCK_TILES=64, + num_warps=4, + ) + else: + grad_conv_in = torch.zeros_like(conv_in) + grad_weight = torch.zeros_like(weight) + if grad_bias is not None: + grad_bias = torch.zeros_like(bias) + if kernel_width > 1 and segments > 0: + block_r = _tail_block(kernel_width - 1) + cast(Any, _packed_conv_bwd_initial_kernel)[ + ( + triton.cdiv(kernel_width - 1, block_r), + triton.cdiv(channels, block_c), + segments, + ) + ]( + cu_seqlens, + weight, + grad_preact, + grad_final_tensor, + grad_initial, + channels, + kernel_width, + HAS_FINAL_GRAD=grad_final is not None, + BLOCK_C=block_c, + BLOCK_R=block_r, + num_warps=num_warps, + ) + else: + grad_initial = torch.zeros_like(conv_initial) + return grad_conv_in, None, grad_initial, grad_weight, grad_bias, None, None + + +def packed_varlen_causal_conv( + conv_in: Tensor, + cu_seqlens: Tensor, + conv_initial: Tensor, + weight: Tensor, + bias: Tensor | None, + *, + activation: str | PackedConvActivation = PackedConvActivation.GELU, + output_final_state: bool = True, +) -> tuple[Tensor, Tensor | None]: + """Run packed-varlen causal depthwise conv over real tokens only. + + ``conv_in`` is compact ``[total_real_tokens, channels]`` data and + ``cu_seqlens`` is the exclusive prefix sum for segment lengths. The returned + output has the same compact token layout. ``conv_initial`` and the optional + final state keep the recurrent tail layout ``[segments, channels, K - 1]``. 
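+
+    A minimal shape sketch (illustrative only; the concrete sizes, CUDA
+    placement, and bfloat16 dtype below are arbitrary choices, not
+    requirements of this API)::
+
+        total_tokens, channels, kernel_width = 10, 64, 4
+        conv_in = torch.randn(total_tokens, channels, device="cuda", dtype=torch.bfloat16)
+        cu_seqlens = torch.tensor([0, 6, 10], device="cuda", dtype=torch.int32)
+        conv_initial = torch.zeros(2, channels, kernel_width - 1, device="cuda", dtype=torch.bfloat16)
+        weight = torch.randn(channels, kernel_width, device="cuda", dtype=torch.bfloat16)
+        out, final = packed_varlen_causal_conv(
+            conv_in, cu_seqlens, conv_initial, weight, None, activation="gelu"
+        )
+        # out: [10, 64] compact real tokens; final: [2, 64, 3] recurrent tail state.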
+ """ + + return _PackedVarlenCausalConv.apply( + conv_in, + cu_seqlens, + conv_initial, + weight, + bias, + output_final_state, + activation, + ) + + +def packed_varlen_causal_conv_gelu( + conv_in: Tensor, + cu_seqlens: Tensor, + conv_initial: Tensor, + weight: Tensor, + bias: Tensor | None, + *, + output_final_state: bool = True, +) -> tuple[Tensor, Tensor | None]: + return packed_varlen_causal_conv( + conv_in, + cu_seqlens, + conv_initial, + weight, + bias, + activation=PackedConvActivation.GELU, + output_final_state=output_final_state, + ) + + +def varlen_causal_conv_gelu( + qkv: Tensor, + conv_initial: Tensor, + weight: Tensor, + bias: Tensor | None, + lengths: Tensor, + *, + output_final_state: bool = True, +) -> tuple[Tensor, Tensor | None]: + """Run ART GDN's prepared-varlen causal depthwise conv followed by GELU. + + Inputs use the existing prepared GDN layout: ``qkv`` is ``[segments, channels, + max_len]`` with padded positions already zeroed, ``conv_initial`` is + ``[segments, channels, kernel_width - 1]``, and ``lengths`` contains each + segment's real token count. The dense output intentionally matches the + current production conv path over the padded tensor; callers can keep using + the existing real-token mask after this fused operation. + """ + + return _VarlenCausalConvGelu.apply( + qkv, conv_initial, weight, bias, lengths, output_final_state + ) + + +def gdn_varlen_causal_conv_gelu( + gdn: Any, + qkv: Tensor, + conv_initial: Tensor, + lengths: Tensor, + *, + output_final_state: bool = True, +) -> tuple[Tensor, Tensor | None]: + if str(getattr(gdn, "activation", "")) != "gelu": + raise ValueError( + "fused varlen causal conv is only defined for GDN GELU activation, " + f"got {getattr(gdn, 'activation', None)!r}" + ) + return varlen_causal_conv_gelu( + qkv, + conv_initial, + gdn.conv1d.weight.squeeze(1), + gdn.conv1d.bias, + lengths, + output_final_state=output_final_state, + ) + + +def _tile_config(channels: int, max_len: int) -> tuple[int, int, int]: + del channels + if max_len >= 512: + return 2, 128, 4 + return 4, 64, 4 + + +def _packed_tile_config(channels: int) -> tuple[int, int, int]: + del channels + return 128, 16, 4 + + +def _tail_block(tail: int) -> int: + return max(1, min(16, 1 << (tail - 1).bit_length())) + + +def _search_steps(segments: int) -> int: + return max(1, (segments - 1).bit_length()) + + +def _activation_code(activation: str | PackedConvActivation) -> int: + if isinstance(activation, PackedConvActivation): + return int(activation) + activation_key = str(activation).lower() + if activation_key == "none": + return int(PackedConvActivation.NONE) + if activation_key in ("silu", "swish"): + return int(PackedConvActivation.SILU) + if activation_key == "gelu": + return int(PackedConvActivation.GELU) + raise ValueError( + "packed varlen causal conv activation must be one of " + "'none', 'silu', 'swish', or 'gelu'; got " + f"{activation!r}" + ) + + +def _assert_valid_cu_seqlens(cu_seqlens: Tensor, total_tokens: int) -> None: + torch._assert_async(cu_seqlens[0] == 0) + torch._assert_async(cu_seqlens[-1] == total_tokens) + if cu_seqlens.numel() > 1: + torch._assert_async(torch.all(cu_seqlens[1:] >= cu_seqlens[:-1])) + + +def _validate_inputs( + qkv: Tensor, + conv_initial: Tensor, + weight: Tensor, + bias: Tensor | None, + lengths: Tensor, +) -> None: + if not qkv.is_cuda: + raise ValueError("qkv must be a CUDA tensor") + if qkv.ndim != 3: + raise ValueError(f"qkv must be [segments, channels, max_len], got {qkv.shape}") + if conv_initial.ndim != 3: + raise 
ValueError( + "conv_initial must be [segments, channels, kernel_width - 1], " + f"got {conv_initial.shape}" + ) + if weight.ndim != 2: + raise ValueError(f"weight must be [channels, kernel_width], got {weight.shape}") + batch, channels, _ = qkv.shape + kernel_width = int(weight.shape[1]) + if kernel_width < 1: + raise ValueError("kernel_width must be at least 1") + if tuple(conv_initial.shape) != (batch, channels, kernel_width - 1): + raise ValueError( + "conv_initial shape must match qkv and weight tail, got " + f"qkv={tuple(qkv.shape)} conv_initial={tuple(conv_initial.shape)} " + f"weight={tuple(weight.shape)}" + ) + if int(weight.shape[0]) != channels: + raise ValueError( + f"weight channels {int(weight.shape[0])} must match qkv channels {channels}" + ) + if bias is not None and tuple(bias.shape) != (channels,): + raise ValueError(f"bias must be [channels], got {tuple(bias.shape)}") + if tuple(lengths.shape) != (batch,): + raise ValueError(f"lengths must be [segments], got {tuple(lengths.shape)}") + if lengths.device != qkv.device: + raise ValueError("lengths must be on the same CUDA device as qkv") + if lengths.dtype not in (torch.int32, torch.int64): + raise ValueError(f"lengths must be int32 or int64, got {lengths.dtype}") + for name, tensor in ( + ("conv_initial", conv_initial), + ("weight", weight), + ("bias", bias), + ): + if tensor is not None and tensor.device != qkv.device: + raise ValueError(f"{name} must be on the same CUDA device as qkv") + if tensor is not None and tensor.dtype != qkv.dtype: + raise ValueError(f"{name} dtype {tensor.dtype} must match qkv {qkv.dtype}") + + +def _validate_packed_inputs( + conv_in: Tensor, + cu_seqlens: Tensor, + conv_initial: Tensor, + weight: Tensor, + bias: Tensor | None, +) -> None: + if not conv_in.is_cuda: + raise ValueError("conv_in must be a CUDA tensor") + if conv_in.ndim != 2: + raise ValueError( + f"conv_in must be [total_real_tokens, channels], got {conv_in.shape}" + ) + if cu_seqlens.ndim != 1: + raise ValueError(f"cu_seqlens must be [segments + 1], got {cu_seqlens.shape}") + if cu_seqlens.numel() < 1: + raise ValueError("cu_seqlens must contain at least the leading zero") + if cu_seqlens.device != conv_in.device: + raise ValueError("cu_seqlens must be on the same CUDA device as conv_in") + if cu_seqlens.dtype not in (torch.int32, torch.int64): + raise ValueError(f"cu_seqlens must be int32 or int64, got {cu_seqlens.dtype}") + if conv_initial.ndim != 3: + raise ValueError( + "conv_initial must be [segments, channels, kernel_width - 1], " + f"got {conv_initial.shape}" + ) + if weight.ndim != 2: + raise ValueError(f"weight must be [channels, kernel_width], got {weight.shape}") + total_tokens, channels = conv_in.shape + segments = int(cu_seqlens.numel()) - 1 + if total_tokens > 0 and segments == 0: + raise ValueError("cu_seqlens must describe at least one segment for conv_in") + kernel_width = int(weight.shape[1]) + if kernel_width < 1: + raise ValueError("kernel_width must be at least 1") + if tuple(conv_initial.shape) != (segments, channels, kernel_width - 1): + raise ValueError( + "conv_initial shape must match conv_in, cu_seqlens, and weight tail, got " + f"conv_in={tuple(conv_in.shape)} " + f"cu_seqlens={tuple(cu_seqlens.shape)} " + f"conv_initial={tuple(conv_initial.shape)} weight={tuple(weight.shape)}" + ) + if int(weight.shape[0]) != channels: + raise ValueError( + f"weight channels {int(weight.shape[0])} must match conv_in channels " + f"{channels}" + ) + if bias is not None and tuple(bias.shape) != (channels,): + raise 
ValueError(f"bias must be [channels], got {tuple(bias.shape)}") + for name, tensor in ( + ("conv_initial", conv_initial), + ("weight", weight), + ("bias", bias), + ): + if tensor is not None and tensor.device != conv_in.device: + raise ValueError(f"{name} must be on the same CUDA device as conv_in") + if tensor is not None and tensor.dtype != conv_in.dtype: + raise ValueError( + f"{name} dtype {tensor.dtype} must match conv_in {conv_in.dtype}" + ) diff --git a/src/art/megatron/gdn/gdn_shared_prefix.py b/src/art/megatron/gdn/gdn_shared_prefix.py new file mode 100644 index 000000000..86f39fdd2 --- /dev/null +++ b/src/art/megatron/gdn/gdn_shared_prefix.py @@ -0,0 +1,3523 @@ +from __future__ import annotations + +from bisect import bisect_left +from typing import Any, Literal, TypeVar + +from pydantic import BaseModel, ConfigDict, Field +import torch + +from art.megatron.context_parallel.layout_index import TokenLayoutIndex + +GdnSegmentKind = Literal["prefix", "completion"] +# FLA's public chunk_gated_delta_rule hard-codes 64-token WY chunks. +FLA_CHUNK_SIZE = 64 +_PydanticModelT = TypeVar("_PydanticModelT", bound=BaseModel) + + +class GdnSegmentSpec(BaseModel): + """Contiguous logical GDN segment in one packed row.""" + + model_config = ConfigDict(frozen=True) + + row_index: int = Field(ge=0) + family_index: int = Field(ge=0) + group_id: int + parent_id: int + start: int = Field(ge=0) + end: int = Field(ge=1) + kind: GdnSegmentKind + child_index: int | None = Field(default=None, ge=0) + + @property + def length(self) -> int: + return self.end - self.start + + def linear_indices(self, sequence_length: int) -> tuple[int, ...]: + base = self.row_index * sequence_length + return tuple(range(base + self.start, base + self.end)) + + +class GdnPackedFamilySpec(BaseModel): + """One shared-prefix family plus child completion segments.""" + + model_config = ConfigDict(frozen=True) + + row_index: int = Field(ge=0) + family_index: int = Field(ge=0) + prefix: GdnSegmentSpec + completions: tuple[GdnSegmentSpec, ...] + + @property + def completion_count(self) -> int: + return len(self.completions) + + @property + def token_count(self) -> int: + return self.prefix.length + sum(segment.length for segment in self.completions) + + +class GdnPackedExecutionSpec(BaseModel): + """Parsed shared-prefix GDN execution metadata for a packed batch.""" + + model_config = ConfigDict(frozen=True) + + batch_size: int = Field(ge=1) + sequence_length: int = Field(ge=1) + valid_lengths: tuple[int, ...] + families: tuple[GdnPackedFamilySpec, ...] 
+ + @property + def family_count(self) -> int: + return len(self.families) + + @property + def completion_count(self) -> int: + return sum(family.completion_count for family in self.families) + + @property + def real_token_count(self) -> int: + return sum(self.valid_lengths) + + @property + def max_segment_length(self) -> int: + lengths = [ + segment.length + for family in self.families + for segment in (family.prefix, *family.completions) + ] + return max(lengths, default=0) + + def segments(self) -> tuple[GdnSegmentSpec, ...]: + return tuple( + segment + for family in self.families + for segment in (family.prefix, *family.completions) + ) + + +_GDN_SEGMENT_SPEC_FIELDS = frozenset( + { + "row_index", + "family_index", + "group_id", + "parent_id", + "start", + "end", + "kind", + "child_index", + } +) +_GDN_PACKED_FAMILY_SPEC_FIELDS = frozenset( + { + "row_index", + "family_index", + "prefix", + "completions", + } +) + + +def _trusted_pydantic_construct( + model_type: type[_PydanticModelT], + fields_set: frozenset[str], + **values: Any, +) -> _PydanticModelT: + model = model_type.__new__(model_type) + object.__setattr__(model, "__dict__", values) + object.__setattr__(model, "__pydantic_fields_set__", fields_set) + object.__setattr__(model, "__pydantic_extra__", None) + object.__setattr__(model, "__pydantic_private__", None) + return model + + +class GdnSegmentBucketPlan(BaseModel): + """Device-local index tensors for a variable-length GDN segment batch.""" + + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) + + length: int = Field(ge=1) + lengths: torch.Tensor + real_mask: torch.Tensor + cu_seqlens: torch.Tensor + row_indices: torch.Tensor + position_indices: torch.Tensor + family_indices: torch.Tensor + real_token_count_static: int = Field(ge=0) + output_mask: torch.Tensor | None = None + + @property + def segment_count(self) -> int: + return int(self.family_indices.numel()) + + @property + def real_token_count(self) -> int: + return self.real_token_count_static + + +class GdnParentStateTransferPlan(BaseModel): + """Prefix-state rows transferred from one CP rank to another.""" + + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) + + source_rank: int = Field(ge=0) + dest_rank: int = Field(ge=0) + family_indices: tuple[int, ...] + family_indices_tensor: torch.Tensor | None = None + + +class GdnCpPeerTransfer(BaseModel): + """Token rows sent from one source rank to one destination rank.""" + + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) + + source_rank: int = Field(ge=0) + dest_rank: int = Field(ge=0) + token_count: int = Field(ge=0) + source_positions_tensor: torch.Tensor | None = None + dest_positions_tensor: torch.Tensor | None = None + + +class GdnCpExchangePlan(BaseModel): + """Minimal exchange metadata for local GDN plans.""" + + model_config = ConfigDict(frozen=True) + + cp_size: int = Field(ge=1) + source_token_counts_by_rank: tuple[int, ...] + dest_token_counts_by_rank: tuple[int, ...] + transfers: tuple[GdnCpPeerTransfer, ...] 
+ cross_rank_token_count_override: int | None = Field(default=None, ge=0) + + @property + def cross_rank_token_count(self) -> int: + if self.cross_rank_token_count_override is not None: + return int(self.cross_rank_token_count_override) + return sum( + int(transfer.token_count) + for transfer in self.transfers + if transfer.source_rank != transfer.dest_rank + ) + + +class GdnPlannerConfig(BaseModel): + """Tunable cost coefficients for one packed-row GDN execution plan.""" + + model_config = ConfigDict(frozen=True) + + max_padding_ratio: float = Field(default=2.0, gt=1.0) + max_segments_per_batch: int = Field(default=4096, ge=1) + cp_chain_min_tokens_per_rank: int = Field(default=32, ge=1) + cp_chain_min_total_tokens: int = Field(default=32768, ge=1) + cp_chain_min_prefix_only_tokens: int = Field(default=32768, ge=1) + local_fork_launch_penalty_tokens: int = Field(default=256, ge=0) + cp_collective_latency_tokens: int = Field(default=512, ge=0) + parent_state_exchange_penalty_tokens: int = Field(default=2048, ge=0) + layout_cross_rank_token_cost: float = Field(default=2.0, ge=0.0) + rank_idle_token_cost: float = Field(default=1.0, ge=0.0) + empty_rank_penalty_tokens: int = Field(default=65536, ge=0) + max_zero_exchange_load_imbalance: float = Field(default=1.5, ge=1.0) + local_completion_rebalance_min_imbalance: float = Field(default=1.08, ge=1.0) + cp_schedule_improve_iters: int = Field(default=0, ge=0) + + +class GdnRankExecutionPlan(BaseModel): + """Rank-local planned execution metadata for shared-prefix GDN.""" + + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) + + cp_rank: int = Field(ge=0) + cp_size: int = Field(ge=1) + batch_size: int = Field(ge=1) + sequence_length: int = Field(ge=0) + packed_batch_size: int | None = Field(default=None, ge=1) + packed_sequence_length: int | None = Field(default=None, ge=1) + real_token_mask: torch.Tensor + family_count: int = Field(ge=0) + completion_count: int = Field(ge=0) + prefix_buckets: tuple[GdnSegmentBucketPlan, ...] + completion_buckets: tuple[GdnSegmentBucketPlan, ...] + local_prefix_buckets: tuple[GdnSegmentBucketPlan, ...] = () + local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () + ready_local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () + remote_local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () + chain_prefix_buckets: tuple[GdnSegmentBucketPlan, ...] = () + chain_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () + prefix_table_is_dense_ordered: bool + attention_to_gdn: Any | None = None + gdn_to_attention: Any | None = None + attention_token_ranges: tuple[tuple[int, int, int], ...] = () + gdn_token_ranges: tuple[tuple[int, int, int], ...] = () + attention_token_count: int = Field(default=0, ge=0) + gdn_token_count: int = Field(default=0, ge=0) + parent_state_exchange_family_indices: tuple[int, ...] = () + parent_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () + prefix_boundary_buckets: tuple[GdnSegmentBucketPlan, ...] = () + prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () + completion_warmup_buckets: tuple[GdnSegmentBucketPlan, ...] 
= () + + @property + def attention_token_indices(self) -> tuple[int, ...]: + return _tokens_from_rank_ranges(self.attention_token_ranges) + + @property + def gdn_token_indices(self) -> tuple[int, ...]: + return _tokens_from_rank_ranges(self.gdn_token_ranges) + + +class GdnCpSegmentSchedule(BaseModel): + """CPU-side ownership and bucket schedule for one CP GDN plan.""" + + model_config = ConfigDict(frozen=True) + + gdn_token_counts_by_rank: tuple[int, ...] + gdn_token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] = () + cross_rank_token_count: int = Field(ge=0) + chain_prefix_buckets: tuple[tuple[GdnSegmentSpec, ...], ...] + chain_completion_buckets: tuple[tuple[GdnSegmentSpec, ...], ...] + local_prefix_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...] + local_completion_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...] + parent_state_exchange_family_indices: tuple[int, ...] = () + parent_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () + + +class _ExplicitBucketColumn(BaseModel): + model_config = ConfigDict(frozen=True) + + row_index: int + family_index: int + positions: tuple[int, ...] + output_mask: tuple[bool, ...] + + @property + def length(self) -> int: + return len(self.positions) + + +class _AttentionLayoutIndex(BaseModel): + """Counting index for CP attention token ownership.""" + + model_config = ConfigDict(frozen=True) + + token_ranges_by_rank: tuple[tuple[tuple[int, int], ...], ...] + token_range_ends_by_rank: tuple[tuple[int, ...], ...] + range_count: int = Field(ge=0) + + +def _layout_cp_size(layout: TokenLayoutIndex) -> int: + return len(layout.token_counts_by_rank) + + +def _layout_token_count(layout: TokenLayoutIndex) -> int: + return sum(int(count) for count in layout.token_counts_by_rank) + + +def _tokens_from_rank_ranges( + ranges: tuple[tuple[int, int, int], ...], +) -> tuple[int, ...]: + return tuple(token for start, end, _ in ranges for token in range(start, end)) + + +def _token_layout_from_rank_ranges( + ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...], +) -> TokenLayoutIndex: + return TokenLayoutIndex( + ownership_ranges_by_rank=ranges_by_rank, + token_counts_by_rank=tuple( + _ranges_token_count(ranges) for ranges in ranges_by_rank + ), + ) + + +def _ranges_token_count(ranges: tuple[tuple[int, int, int], ...]) -> int: + return sum(int(end) - int(start) for start, end, _ in ranges) + + +def build_gdn_rank_execution_plan( + spec: GdnPackedExecutionSpec, + *, + device: torch.device | str, + cp_rank: int = 0, + cp_size: int = 1, + attention_token_layout_index: TokenLayoutIndex | None = None, + cp_segment_schedule: GdnCpSegmentSchedule | None = None, + planner_config: GdnPlannerConfig | None = None, +) -> GdnRankExecutionPlan: + """Build rank-local tensor metadata from a parsed shared-prefix DAG. + + Planning is CPU-bound and must run once per packed training sequence. CP>1 + emits mixed work: native FLA CP chain buckets for long segments and local + fork buckets for short work where CP collectives would be inefficient. 
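+
+    A minimal single-rank sketch (illustrative only; the segment layout,
+    lengths, and id values below are arbitrary)::
+
+        prefix = GdnSegmentSpec(
+            row_index=0, family_index=0, group_id=0, parent_id=-1,
+            start=0, end=128, kind="prefix",
+        )
+        completion = GdnSegmentSpec(
+            row_index=0, family_index=0, group_id=1, parent_id=0,
+            start=128, end=192, kind="completion", child_index=0,
+        )
+        spec = GdnPackedExecutionSpec(
+            batch_size=1,
+            sequence_length=256,
+            valid_lengths=(192,),
+            families=(
+                GdnPackedFamilySpec(
+                    row_index=0, family_index=0,
+                    prefix=prefix, completions=(completion,),
+                ),
+            ),
+        )
+        plan = build_gdn_rank_execution_plan(spec, device="cpu")
+        # plan.real_token_mask is a [1, 256] bool tensor with the first 192
+        # positions (the real tokens) set to True.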
+ """ + + planner_config = planner_config or GdnPlannerConfig() + target_device = torch.device(device) + if target_device.type != "cpu": + cpu_plan = build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=cp_rank, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, + cp_segment_schedule=cp_segment_schedule, + planner_config=planner_config, + ) + return move_gdn_rank_execution_plan_to_device(cpu_plan, target_device) + if cp_size != 1 or cp_rank != 0: + return _build_cp_rank_execution_plan( + spec, + device=device, + cp_rank=cp_rank, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, + cp_segment_schedule=cp_segment_schedule, + planner_config=planner_config, + ) + ( + prefix_boundary_buckets, + prefix_tail_buckets, + completion_warmup_buckets, + ) = _build_chunk_aligned_cp1_bucket_plans( + spec, + device=device, + planner_config=planner_config, + ) + valid_lengths = torch.tensor( + spec.valid_lengths, + device=device, + dtype=torch.long, + ) + positions = torch.arange(spec.sequence_length, device=device, dtype=torch.long) + local_range_list: list[tuple[int, int, int]] = [] + local_position = 0 + for row_index, length in enumerate(spec.valid_lengths): + if length: + start = row_index * spec.sequence_length + local_range_list.append((start, start + length, local_position)) + local_position += length + local_ranges = tuple(local_range_list) + return GdnRankExecutionPlan.model_construct( + cp_rank=cp_rank, + cp_size=cp_size, + batch_size=spec.batch_size, + sequence_length=spec.sequence_length, + packed_batch_size=spec.batch_size, + packed_sequence_length=spec.sequence_length, + real_token_mask=positions.unsqueeze(0) < valid_lengths.unsqueeze(1), + family_count=spec.family_count, + completion_count=spec.completion_count, + prefix_buckets=(), + completion_buckets=(), + local_prefix_buckets=(), + local_completion_buckets=(), + ready_local_completion_buckets=(), + remote_local_completion_buckets=(), + chain_prefix_buckets=(), + chain_completion_buckets=(), + prefix_table_is_dense_ordered=False, + attention_token_ranges=local_ranges, + gdn_token_ranges=local_ranges, + attention_token_count=spec.real_token_count, + gdn_token_count=spec.real_token_count, + prefix_boundary_buckets=prefix_boundary_buckets, + prefix_tail_buckets=prefix_tail_buckets, + completion_warmup_buckets=completion_warmup_buckets, + ) + + +def move_gdn_rank_execution_plan_to_device( + plan: GdnRankExecutionPlan, + device: torch.device | str, +) -> GdnRankExecutionPlan: + """Move planner tensors to the execution device after CPU planning.""" + + from art.megatron.gdn.layout import move_cp_exchange_plan_to_device + + return GdnRankExecutionPlan.model_construct( + cp_rank=plan.cp_rank, + cp_size=plan.cp_size, + batch_size=plan.batch_size, + sequence_length=plan.sequence_length, + packed_batch_size=plan.packed_batch_size, + packed_sequence_length=plan.packed_sequence_length, + real_token_mask=_move_planner_tensor(plan.real_token_mask, device), + family_count=plan.family_count, + completion_count=plan.completion_count, + prefix_buckets=_move_bucket_plans(plan.prefix_buckets, device), + completion_buckets=_move_bucket_plans(plan.completion_buckets, device), + local_prefix_buckets=_move_bucket_plans(plan.local_prefix_buckets, device), + local_completion_buckets=_move_bucket_plans( + plan.local_completion_buckets, device + ), + ready_local_completion_buckets=_move_bucket_plans( + plan.ready_local_completion_buckets, device + ), + 
remote_local_completion_buckets=_move_bucket_plans( + plan.remote_local_completion_buckets, device + ), + chain_prefix_buckets=_move_bucket_plans(plan.chain_prefix_buckets, device), + chain_completion_buckets=_move_bucket_plans( + plan.chain_completion_buckets, device + ), + prefix_table_is_dense_ordered=plan.prefix_table_is_dense_ordered, + attention_to_gdn=move_cp_exchange_plan_to_device(plan.attention_to_gdn, device), + gdn_to_attention=move_cp_exchange_plan_to_device(plan.gdn_to_attention, device), + attention_token_ranges=plan.attention_token_ranges, + gdn_token_ranges=plan.gdn_token_ranges, + attention_token_count=plan.attention_token_count, + gdn_token_count=plan.gdn_token_count, + parent_state_exchange_family_indices=plan.parent_state_exchange_family_indices, + parent_state_transfers=_move_parent_state_transfers( + plan.parent_state_transfers, device + ), + prefix_boundary_buckets=_move_bucket_plans( + plan.prefix_boundary_buckets, device + ), + prefix_tail_buckets=_move_bucket_plans(plan.prefix_tail_buckets, device), + completion_warmup_buckets=_move_bucket_plans( + plan.completion_warmup_buckets, device + ), + ) + + +def _move_bucket_plans( + buckets: tuple[GdnSegmentBucketPlan, ...], + device: torch.device | str, +) -> tuple[GdnSegmentBucketPlan, ...]: + return tuple( + GdnSegmentBucketPlan.model_construct( + length=bucket.length, + lengths=_move_planner_tensor(bucket.lengths, device), + real_mask=_move_planner_tensor(bucket.real_mask, device), + cu_seqlens=_move_planner_tensor(bucket.cu_seqlens, device), + row_indices=_move_planner_tensor(bucket.row_indices, device), + position_indices=_move_planner_tensor(bucket.position_indices, device), + family_indices=_move_planner_tensor(bucket.family_indices, device), + real_token_count_static=bucket.real_token_count, + output_mask=( + _move_planner_tensor(bucket.output_mask, device) + if bucket.output_mask is not None + else None + ), + ) + for bucket in buckets + ) + + +def _move_parent_state_transfers( + transfers: tuple[GdnParentStateTransferPlan, ...], + device: torch.device | str, +) -> tuple[GdnParentStateTransferPlan, ...]: + return tuple( + GdnParentStateTransferPlan.model_construct( + source_rank=transfer.source_rank, + dest_rank=transfer.dest_rank, + family_indices=transfer.family_indices, + family_indices_tensor=( + _move_planner_tensor(transfer.family_indices_tensor, device) + if transfer.family_indices_tensor is not None + else None + ), + ) + for transfer in transfers + ) + + +def build_gdn_chain_only_rank_execution_plan( + spec: GdnPackedExecutionSpec, + *, + device: torch.device | str, + cp_rank: int, + cp_size: int, + planner_config: GdnPlannerConfig | None = None, +) -> GdnRankExecutionPlan | None: + """Build the rank-local plan for rows that are entirely native CP chains. + + This avoids a large Python-object schedule broadcast for long pure-chain rows + such as `64k + 8x64k`. Mixed local/chain rows still use the general planner. 
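+
+    A hedged calling sketch (the fallback pattern below is indicative, not
+    prescriptive)::
+
+        plan = build_gdn_chain_only_rank_execution_plan(
+            spec, device="cpu", cp_rank=cp_rank, cp_size=cp_size
+        )
+        if plan is None:
+            # Mixed local/chain rows: use the general planner instead.
+            plan = build_gdn_rank_execution_plan(
+                spec, device="cpu", cp_rank=cp_rank, cp_size=cp_size
+            )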
+ """ + + planner_config = planner_config or GdnPlannerConfig() + if cp_size <= 1: + return None + if cp_rank < 0 or cp_rank >= cp_size: + raise ValueError(f"cp_rank must be in [0, {cp_size}), got {cp_rank}") + if not spec.families: + return None + for family in spec.families: + if not _can_chain_prefix_segment( + family.prefix, cp_size=cp_size, planner_config=planner_config + ): + return None + if any( + not _can_chain_segment( + completion, cp_size=cp_size, planner_config=planner_config + ) + for completion in family.completions + ): + return None + + local_tokens: list[int] = [] + prefix_segments: list[GdnSegmentSpec] = [] + completion_segments: list[GdnSegmentSpec] = [] + for family in spec.families: + prefix_segments.append(family.prefix) + local_tokens.extend( + _chain_rank_token_indices( + family.prefix, + spec, + cp_rank=cp_rank, + cp_size=cp_size, + ) + ) + for completion in family.completions: + completion_segments.append(completion) + local_tokens.extend( + _chain_rank_token_indices( + completion, + spec, + cp_rank=cp_rank, + cp_size=cp_size, + ) + ) + local_token_tuple = tuple(local_tokens) + local_token_ranges = _local_token_ranges(local_token_tuple) + token_counts_by_rank = tuple( + len(local_token_tuple) if rank == cp_rank else 0 for rank in range(cp_size) + ) + identity_exchange = GdnCpExchangePlan.model_construct( + cp_size=cp_size, + source_token_counts_by_rank=token_counts_by_rank, + dest_token_counts_by_rank=token_counts_by_rank, + transfers=tuple( + GdnCpPeerTransfer.model_construct( + source_rank=rank, + dest_rank=rank, + token_count=count, + source_positions_tensor=None, + dest_positions_tensor=None, + ) + for rank, count in enumerate(token_counts_by_rank) + if count + ), + ) + chain_prefix_buckets = _batch_segments_by_padded_work( + tuple(prefix_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + chain_completion_buckets = _batch_segments_by_padded_work( + tuple(completion_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + prefix_family_order = tuple( + segment.family_index for bucket in chain_prefix_buckets for segment in bucket + ) + return GdnRankExecutionPlan.model_construct( + cp_rank=cp_rank, + cp_size=cp_size, + batch_size=1, + sequence_length=len(local_token_tuple), + packed_batch_size=spec.batch_size, + packed_sequence_length=spec.sequence_length, + real_token_mask=torch.ones( + 1, len(local_token_tuple), device=device, dtype=torch.bool + ), + family_count=spec.family_count, + completion_count=spec.completion_count, + prefix_buckets=(), + completion_buckets=(), + local_prefix_buckets=(), + local_completion_buckets=(), + ready_local_completion_buckets=(), + remote_local_completion_buckets=(), + chain_prefix_buckets=_build_position_bucket_plans( + chain_prefix_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + chain_completion_buckets=_build_position_bucket_plans( + chain_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + prefix_table_is_dense_ordered=( + prefix_family_order == tuple(range(spec.family_count)) + ), + attention_to_gdn=identity_exchange, + gdn_to_attention=identity_exchange, + attention_token_ranges=local_token_ranges, + gdn_token_ranges=local_token_ranges, + attention_token_count=len(local_token_tuple), + gdn_token_count=len(local_token_tuple), + 
parent_state_exchange_family_indices=(), + parent_state_transfers=(), + ) + + +def _build_chain_attention_layout_rank_execution_plan( + spec: GdnPackedExecutionSpec, + *, + device: torch.device | str, + cp_rank: int, + cp_size: int, + attention_token_layout_index: TokenLayoutIndex | None, + planner_config: GdnPlannerConfig, +) -> GdnRankExecutionPlan | None: + if cp_size <= 1 or not spec.families: + return None + for family in spec.families: + if not _can_chain_prefix_segment( + family.prefix, cp_size=cp_size, planner_config=planner_config + ): + return None + if any( + not _can_chain_segment( + completion, cp_size=cp_size, planner_config=planner_config + ) + for completion in family.completions + ): + return None + + from art.megatron.gdn.layout import ( + _reverse_exchange_plan, + build_local_rank_cp_exchange_plan_from_dest_ranges, + ) + + source_layout = _attention_source_layout( + spec, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, + planner_config=planner_config, + ) + attention_layout_index = _build_attention_layout_index_from_token_layout( + source_layout, + max_ranges=max(1, 2 * spec.real_token_count // len(tuple(spec.segments()))), + ) + rank_loads = [0] * cp_size + gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] + prefix_segments: list[GdnSegmentSpec] = [] + completion_segments: list[GdnSegmentSpec] = [] + cross_rank_token_count = 0 + for family in spec.families: + for segment in (family.prefix, *family.completions): + if segment.kind == "prefix": + prefix_segments.append(segment) + else: + completion_segments.append(segment) + token_start = _segment_token_start(segment, spec.sequence_length) + shards = _attention_contiguous_chain_shards( + token_start, + segment.length, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + ) + if shards is None: + shards = tuple( + _chain_rank_token_indices( + segment, + spec, + cp_rank=rank, + cp_size=cp_size, + ) + for rank in range(cp_size) + ) + for rank, shard in enumerate(shards): + position_start = rank_loads[rank] + gdn_ranges_by_rank[rank].append( + (shard.start, shard.stop, position_start) + ) + rank_loads[rank] += len(shard) + cross_rank_token_count += len(shard) - _attention_overlap_count( + attention_layout_index, + rank, + shard.start, + shard.stop, + ) + local_token_ranges = tuple(gdn_ranges_by_rank[cp_rank]) + local_token_count = rank_loads[cp_rank] + attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( + source_layout=source_layout, + device=device, + dest_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), + local_rank=cp_rank, + cross_rank_token_count=cross_rank_token_count, + ) + gdn_to_attention = _reverse_exchange_plan(attention_to_gdn) + chain_prefix_buckets = _batch_segments_by_padded_work( + tuple(prefix_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + chain_completion_buckets = _batch_segments_by_padded_work( + tuple(completion_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + prefix_family_order = tuple( + segment.family_index for bucket in chain_prefix_buckets for segment in bucket + ) + return GdnRankExecutionPlan.model_construct( + cp_rank=cp_rank, + cp_size=cp_size, + batch_size=1, + sequence_length=local_token_count, + packed_batch_size=spec.batch_size, + packed_sequence_length=spec.sequence_length, + 
real_token_mask=torch.ones( + 1, local_token_count, device=device, dtype=torch.bool + ), + family_count=spec.family_count, + completion_count=spec.completion_count, + prefix_buckets=(), + completion_buckets=(), + local_prefix_buckets=(), + local_completion_buckets=(), + ready_local_completion_buckets=(), + remote_local_completion_buckets=(), + chain_prefix_buckets=_build_position_bucket_plans( + chain_prefix_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + chain_completion_buckets=_build_position_bucket_plans( + chain_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + prefix_table_is_dense_ordered=( + prefix_family_order == tuple(range(spec.family_count)) + ), + attention_to_gdn=attention_to_gdn, + gdn_to_attention=gdn_to_attention, + attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], + gdn_token_ranges=local_token_ranges, + attention_token_count=source_layout.token_counts_by_rank[cp_rank], + gdn_token_count=local_token_count, + parent_state_exchange_family_indices=(), + parent_state_transfers=(), + ) + + +def _build_local_attention_layout_rank_execution_plan( + spec: GdnPackedExecutionSpec, + *, + device: torch.device | str, + cp_rank: int, + cp_size: int, + attention_token_layout_index: TokenLayoutIndex | None, + planner_config: GdnPlannerConfig, +) -> GdnRankExecutionPlan | None: + if cp_size <= 1 or not spec.families: + return None + if any( + _can_chain_family(family, cp_size=cp_size, planner_config=planner_config) + for family in spec.families + ): + return None + + from art.megatron.gdn.layout import ( + _reverse_exchange_plan, + build_local_rank_cp_exchange_plan_from_dest_ranges, + ) + + source_layout = _attention_source_layout( + spec, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, + planner_config=planner_config, + ) + attention_layout_index = _build_attention_layout_index_from_token_layout( + source_layout, + max_ranges=max(1, 2 * spec.real_token_count // len(tuple(spec.segments()))), + ) + segment_attention_counts = _segment_attention_rank_counts( + spec, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + ) + best = _assign_local_attention_segments( + spec, + cp_size=cp_size, + segment_attention_counts=segment_attention_counts, + co_locate_local_families=False, + planner_config=planner_config, + ) + if _can_zero_exchange_colocate_families( + spec, + cp_size=cp_size, + segment_attention_counts=segment_attention_counts, + ): + co_located = _assign_local_attention_segments( + spec, + cp_size=cp_size, + segment_attention_counts=segment_attention_counts, + co_locate_local_families=True, + planner_config=planner_config, + ) + if co_located[3] == 0 and co_located[4] < best[4]: + best = co_located + ( + prefix_owner_by_family, + completion_owners_by_family, + _, + cross_rank_token_count, + _, + ) = best + + local_prefix_segments: list[GdnSegmentSpec] = [] + local_completion_segments: list[GdnSegmentSpec] = [] + gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] + rank_loads = [0] * cp_size + parent_state_exchange_families: set[int] = set() + parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} + + def append_segment(rank: int, segment: GdnSegmentSpec) -> None: + token_start = _segment_token_start(segment, spec.sequence_length) + position_start = rank_loads[rank] + gdn_ranges_by_rank[rank].append( + (token_start, token_start + segment.length, position_start) + ) + 
rank_loads[rank] += segment.length + + for family in spec.families: + prefix_owner = prefix_owner_by_family[family.family_index] + if prefix_owner == cp_rank: + local_prefix_segments.append(family.prefix) + append_segment(prefix_owner, family.prefix) + completion_owners = completion_owners_by_family[family.family_index] + for completion, completion_owner in zip( + family.completions, completion_owners, strict=True + ): + if completion_owner == cp_rank: + local_completion_segments.append(completion) + append_segment(completion_owner, completion) + if completion_owner != prefix_owner: + parent_state_exchange_families.add(family.family_index) + parent_state_transfer_families.setdefault( + (prefix_owner, completion_owner), set() + ).add(family.family_index) + + local_token_ranges = tuple(gdn_ranges_by_rank[cp_rank]) + local_token_count = rank_loads[cp_rank] + attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( + source_layout=source_layout, + device=device, + dest_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), + local_rank=cp_rank, + cross_rank_token_count=cross_rank_token_count, + ) + gdn_to_attention = _reverse_exchange_plan(attention_to_gdn) + local_prefix_family_indices = { + segment.family_index for segment in local_prefix_segments + } + local_prefix_buckets = _batch_segments_by_padded_work( + (), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + chunk_local_completion_segments = tuple( + segment + for segment in local_completion_segments + if segment.family_index in local_prefix_family_indices + ) + plain_local_completion_segments = tuple( + segment + for segment in local_completion_segments + if segment.family_index not in local_prefix_family_indices + ) + ready_completion_segments, remote_completion_segments = ( + _split_ready_and_remote_completion_segments( + plain_local_completion_segments, + local_prefix_segments=(), + chain_prefix_buckets=(), + ) + ) + ready_completion_buckets = _batch_segments_by_padded_work( + ready_completion_segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + remote_completion_buckets = _batch_segments_by_padded_work( + remote_completion_segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + prefix_family_order = tuple( + segment.family_index for bucket in local_prefix_buckets for segment in bucket + ) + ready_completion_bucket_plans = _build_position_bucket_plans( + ready_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ) + remote_completion_bucket_plans = _build_position_bucket_plans( + remote_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ) + ( + prefix_boundary_buckets, + prefix_tail_buckets, + completion_warmup_buckets, + ) = _build_chunk_aligned_position_bucket_plans( + tuple(local_prefix_segments), + chunk_local_completion_segments, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + ) + return GdnRankExecutionPlan.model_construct( + cp_rank=cp_rank, + cp_size=cp_size, + batch_size=1, + sequence_length=local_token_count, + packed_batch_size=spec.batch_size, + packed_sequence_length=spec.sequence_length, + real_token_mask=torch.ones( + 1, local_token_count, device=device, dtype=torch.bool + ), + 
family_count=spec.family_count, + completion_count=spec.completion_count, + prefix_buckets=(), + completion_buckets=(), + local_prefix_buckets=_build_position_bucket_plans( + local_prefix_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + local_completion_buckets=( + ready_completion_bucket_plans + remote_completion_bucket_plans + ), + ready_local_completion_buckets=ready_completion_bucket_plans, + remote_local_completion_buckets=remote_completion_bucket_plans, + chain_prefix_buckets=(), + chain_completion_buckets=(), + prefix_table_is_dense_ordered=( + not local_prefix_segments + and prefix_family_order == tuple(range(spec.family_count)) + ), + attention_to_gdn=attention_to_gdn, + gdn_to_attention=gdn_to_attention, + attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], + gdn_token_ranges=local_token_ranges, + attention_token_count=source_layout.token_counts_by_rank[cp_rank], + gdn_token_count=local_token_count, + parent_state_exchange_family_indices=tuple( + sorted(parent_state_exchange_families) + ), + parent_state_transfers=_transfer_plans_to_device( + _build_parent_state_transfer_plans(parent_state_transfer_families), + device=device, + ), + prefix_boundary_buckets=prefix_boundary_buckets, + prefix_tail_buckets=prefix_tail_buckets, + completion_warmup_buckets=completion_warmup_buckets, + ) + + +def _assign_local_attention_segments( + spec: GdnPackedExecutionSpec, + *, + cp_size: int, + segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], + co_locate_local_families: bool, + planner_config: GdnPlannerConfig, +) -> tuple[ + tuple[int, ...], + tuple[tuple[int, ...], ...], + tuple[int, ...], + int, + float, +]: + rank_loads = [0] * cp_size + has_prefix = [False] * cp_size + has_completion = [False] * cp_size + prefix_owner_by_family: list[int] = [] + completion_owners_by_family: list[tuple[int, ...]] = [] + parent_state_exchange_families: set[int] = set() + cross_rank_token_count = 0 + + def append_owner(rank: int, segment: GdnSegmentSpec) -> None: + nonlocal cross_rank_token_count + rank_loads[rank] += segment.length + cross_rank_token_count += ( + segment.length - segment_attention_counts[_segment_key(segment)][rank] + ) + + for family in spec.families: + if co_locate_local_families: + owner = _best_segment_owner( + (family.prefix, *family.completions), + rank_loads, + segment_attention_counts=segment_attention_counts, + planner_config=planner_config, + ) + prefix_owner_by_family.append(owner) + completion_owners = tuple(owner for _ in family.completions) + completion_owners_by_family.append(completion_owners) + has_prefix[owner] = True + for segment in (family.prefix, *family.completions): + append_owner(owner, segment) + if family.completions: + has_completion[owner] = True + continue + + prefix_owner = _best_segment_owner( + (family.prefix,), + rank_loads, + segment_attention_counts=segment_attention_counts, + planner_config=planner_config, + ) + prefix_owner_by_family.append(prefix_owner) + has_prefix[prefix_owner] = True + append_owner(prefix_owner, family.prefix) + completion_owners = [] + for completion in family.completions: + owner = _best_segment_owner( + (completion,), + rank_loads, + segment_attention_counts=segment_attention_counts, + planner_config=planner_config, + ) + completion_owners.append(owner) + has_completion[owner] = True + append_owner(owner, completion) + if owner != prefix_owner: + parent_state_exchange_families.add(family.family_index) + 
completion_owners_by_family.append(tuple(completion_owners)) + + max_load = max(rank_loads, default=0) + idle_tokens = sum(max_load - load for load in rank_loads) + empty_rank_count = sum(1 for load in rank_loads if load == 0) + local_launches = sum(has_prefix) + sum(has_completion) + score = ( + max_load + + planner_config.rank_idle_token_cost * idle_tokens + + planner_config.empty_rank_penalty_tokens * empty_rank_count + + planner_config.local_fork_launch_penalty_tokens * local_launches + + planner_config.layout_cross_rank_token_cost * cross_rank_token_count + + planner_config.parent_state_exchange_penalty_tokens + * len(parent_state_exchange_families) + ) + return ( + tuple(prefix_owner_by_family), + tuple(completion_owners_by_family), + tuple(sorted(parent_state_exchange_families)), + cross_rank_token_count, + score, + ) + + +def _can_zero_exchange_colocate_families( + spec: GdnPackedExecutionSpec, + *, + cp_size: int, + segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], +) -> bool: + for family in spec.families: + family_rank_counts = [0] * cp_size + for segment in (family.prefix, *family.completions): + segment_counts = segment_attention_counts[_segment_key(segment)] + for rank in range(cp_size): + family_rank_counts[rank] += segment_counts[rank] + if max(family_rank_counts, default=0) != family.token_count: + return False + return True + + +def parse_gdn_shared_prefix_segments( + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + *, + min_completions_per_family: int = 0, +) -> GdnPackedExecutionSpec: + """Parse ART packed shared-prefix metadata into a GDN segment DAG. + + The parser is intentionally strict: GDN state routing depends on prompt-family + boundaries, so malformed metadata should fail before execution can silently + leak recurrent or conv state across siblings or independent families. 
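+
+    Illustrative call shape only (the tensors below are placeholders, not a
+    specification of valid metadata):
+
+        spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids)
+        # spec.batch_size == group_ids.shape[0]
+        # spec.sequence_length == group_ids.shape[1]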
+ """ + + groups = _rank2_long_cpu("group_ids", group_ids) + parents = _rank2_long_cpu("parent_ids", parent_ids) + if tuple(groups.shape) != tuple(parents.shape): + raise ValueError( + "group_ids and parent_ids must have the same shape, got " + f"{tuple(groups.shape)} and {tuple(parents.shape)}" + ) + + batch_size, sequence_length = (int(groups.shape[0]), int(groups.shape[1])) + valid_lengths: list[int] = [] + families: list[GdnPackedFamilySpec] = [] + for row_index in range(batch_size): + row_group_ids = groups[row_index] + row_parent_ids = parents[row_index] + valid_length = _validate_padding_tensor( + row_index, row_group_ids, row_parent_ids + ) + valid_lengths.append(valid_length) + if valid_length == 0: + continue + families.extend( + _parse_row_tensor( + row_index=row_index, + group_ids=row_group_ids, + parent_ids=row_parent_ids, + valid_length=valid_length, + first_family_index=len(families), + min_completions_per_family=min_completions_per_family, + ) + ) + + return GdnPackedExecutionSpec( + batch_size=batch_size, + sequence_length=sequence_length, + valid_lengths=tuple(valid_lengths), + families=tuple(families), + ) + + +def _build_segment_bucket_plans( + segment_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], + *, + device: torch.device | str, +) -> tuple[GdnSegmentBucketPlan, ...]: + return tuple( + _build_segment_bucket_plan(bucket[0].length, bucket, device=device) + for bucket in segment_buckets + ) + + +def _build_chunk_aligned_cp1_bucket_plans( + spec: GdnPackedExecutionSpec, + *, + device: torch.device | str, + planner_config: GdnPlannerConfig, +) -> tuple[ + tuple[GdnSegmentBucketPlan, ...], + tuple[GdnSegmentBucketPlan, ...], + tuple[GdnSegmentBucketPlan, ...], +]: + boundary_segments: list[GdnSegmentSpec] = [] + tail_segments: list[GdnSegmentSpec] = [] + completion_columns: list[_ExplicitBucketColumn] = [] + for family in spec.families: + prefix = family.prefix + boundary_end = _prefix_chunk_boundary_end(prefix) + if boundary_end > prefix.start: + boundary_segments.append( + _segment_with_bounds(prefix, prefix.start, boundary_end) + ) + if boundary_end < prefix.end and not family.completions: + tail_segments.append(_segment_with_bounds(prefix, boundary_end, prefix.end)) + warmup_positions = tuple(range(boundary_end, prefix.end)) + for completion in family.completions: + warmup_mask = (completion.child_index == 0,) * len(warmup_positions) + completion_positions = tuple(range(completion.start, completion.end)) + completion_columns.append( + _ExplicitBucketColumn( + row_index=completion.row_index, + family_index=completion.family_index, + positions=warmup_positions + completion_positions, + output_mask=warmup_mask + (True,) * len(completion_positions), + ) + ) + boundary_buckets = _batch_segments_by_padded_work( + tuple(boundary_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + tail_buckets = _batch_segments_by_padded_work( + tuple(tail_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + completion_buckets = _batch_explicit_bucket_columns( + tuple(completion_columns), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + return ( + _build_segment_bucket_plans(boundary_buckets, device=device), + _build_segment_bucket_plans(tail_buckets, device=device), + _build_explicit_bucket_plans(completion_buckets, device=device), + ) + + +def 
_build_chunk_aligned_position_bucket_plans( + prefix_segments: tuple[GdnSegmentSpec, ...], + completion_segments: tuple[GdnSegmentSpec, ...], + local_token_ranges: tuple[tuple[int, int, int], ...], + *, + sequence_length: int, + device: torch.device | str, + planner_config: GdnPlannerConfig, +) -> tuple[ + tuple[GdnSegmentBucketPlan, ...], + tuple[GdnSegmentBucketPlan, ...], + tuple[GdnSegmentBucketPlan, ...], +]: + local_range_ends = tuple(token_end for _, token_end, _ in local_token_ranges) + completions_by_family: dict[int, list[GdnSegmentSpec]] = {} + for completion in completion_segments: + completions_by_family.setdefault(completion.family_index, []).append(completion) + boundary_segments: list[GdnSegmentSpec] = [] + tail_segments: list[GdnSegmentSpec] = [] + completion_columns: list[_ExplicitBucketColumn] = [] + for prefix in prefix_segments: + boundary_end = _prefix_chunk_boundary_end(prefix) + if boundary_end > prefix.start: + boundary_segments.append( + _segment_with_bounds(prefix, prefix.start, boundary_end) + ) + family_completions = tuple( + sorted( + completions_by_family.get(prefix.family_index, ()), + key=lambda segment: segment.child_index or 0, + ) + ) + if boundary_end < prefix.end and not family_completions: + tail_segments.append(_segment_with_bounds(prefix, boundary_end, prefix.end)) + warmup_positions = _local_positions_for_span( + prefix.row_index, + boundary_end, + prefix.end, + sequence_length=sequence_length, + local_token_ranges=local_token_ranges, + local_range_ends=local_range_ends, + ) + for completion in family_completions: + completion_positions = _local_positions_for_span( + completion.row_index, + completion.start, + completion.end, + sequence_length=sequence_length, + local_token_ranges=local_token_ranges, + local_range_ends=local_range_ends, + ) + completion_columns.append( + _ExplicitBucketColumn( + row_index=0, + family_index=completion.family_index, + positions=warmup_positions + completion_positions, + output_mask=(completion.child_index == 0,) * len(warmup_positions) + + (True,) * len(completion_positions), + ) + ) + boundary_buckets = _batch_segments_by_padded_work( + tuple(boundary_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + tail_buckets = _batch_segments_by_padded_work( + tuple(tail_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + completion_buckets = _batch_explicit_bucket_columns( + tuple(completion_columns), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + return ( + _build_position_bucket_plans( + boundary_buckets, + local_token_ranges, + sequence_length=sequence_length, + device=device, + ), + _build_position_bucket_plans( + tail_buckets, + local_token_ranges, + sequence_length=sequence_length, + device=device, + ), + _build_explicit_bucket_plans(completion_buckets, device=device), + ) + + +def _local_positions_for_span( + row_index: int, + start: int, + end: int, + *, + sequence_length: int, + local_token_ranges: tuple[tuple[int, int, int], ...], + local_range_ends: tuple[int, ...], +) -> tuple[int, ...]: + if start == end: + return () + segment = _trusted_pydantic_construct( + GdnSegmentSpec, + _GDN_SEGMENT_SPEC_FIELDS, + row_index=row_index, + family_index=0, + group_id=0, + parent_id=0, + start=start, + end=end, + kind="prefix", + child_index=None, + ) + return tuple( + int(position) + for 
position in _local_positions_for_segment( + segment, + sequence_length=sequence_length, + local_token_ranges=local_token_ranges, + local_range_ends=local_range_ends, + ).tolist() + ) + + +def _prefix_chunk_boundary_end(prefix: GdnSegmentSpec) -> int: + aligned_length = (prefix.length // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE + return prefix.start + aligned_length + + +def _segment_with_bounds( + segment: GdnSegmentSpec, start: int, end: int +) -> GdnSegmentSpec: + return _trusted_pydantic_construct( + GdnSegmentSpec, + _GDN_SEGMENT_SPEC_FIELDS, + row_index=segment.row_index, + family_index=segment.family_index, + group_id=segment.group_id, + parent_id=segment.parent_id, + start=start, + end=end, + kind=segment.kind, + child_index=segment.child_index, + ) + + +def _batch_explicit_bucket_columns( + columns: tuple[_ExplicitBucketColumn, ...], + *, + max_padding_ratio: float = 1.25, + max_segments_per_batch: int = 128, +) -> tuple[tuple[_ExplicitBucketColumn, ...], ...]: + if not columns: + return () + ordered = sorted( + columns, + key=lambda column: (column.length, column.family_index, column.row_index), + ) + batches: list[list[_ExplicitBucketColumn]] = [] + current: list[_ExplicitBucketColumn] = [] + current_tokens = 0 + current_max = 0 + for column in ordered: + next_count = len(current) + 1 + next_tokens = current_tokens + column.length + next_max = max(current_max, column.length) + padded = next_max * next_count + can_extend = not current or ( + next_count <= max_segments_per_batch + and padded <= max_padding_ratio * next_tokens + ) + if not can_extend: + batches.append(current) + current = [] + current_tokens = 0 + current_max = 0 + current.append(column) + current_tokens += column.length + current_max = max(current_max, column.length) + if current: + batches.append(current) + return tuple(tuple(batch) for batch in batches) + + +def _build_explicit_bucket_plans( + bucket_columns: tuple[tuple[_ExplicitBucketColumn, ...], ...], + *, + device: torch.device | str, +) -> tuple[GdnSegmentBucketPlan, ...]: + return tuple( + _build_explicit_bucket_plan(columns, device=device) + for columns in bucket_columns + ) + + +def _build_explicit_bucket_plan( + columns: tuple[_ExplicitBucketColumn, ...], + *, + device: torch.device | str, +) -> GdnSegmentBucketPlan: + max_length = max(column.length for column in columns) + lengths_cpu = torch.tensor([column.length for column in columns], dtype=torch.long) + offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) + real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) + row_indices_cpu = torch.zeros(max_length, len(columns), dtype=torch.long) + position_indices_cpu = torch.zeros(max_length, len(columns), dtype=torch.long) + output_mask_cpu = torch.zeros(max_length, len(columns), dtype=torch.bool) + for column_index, column in enumerate(columns): + length = column.length + row_indices_cpu[:length, column_index] = column.row_index + position_indices_cpu[:length, column_index] = torch.tensor( + column.positions, dtype=torch.long + ) + output_mask_cpu[:length, column_index] = torch.tensor( + column.output_mask, dtype=torch.bool + ) + family_indices_cpu = torch.tensor( + [column.family_index for column in columns], dtype=torch.long + ) + return GdnSegmentBucketPlan.model_construct( + length=max_length, + lengths=_move_planner_tensor(lengths_cpu, device), + real_mask=_move_planner_tensor(real_mask_cpu, device), + cu_seqlens=_move_planner_tensor( + torch.cat([lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)]), + device, + ), + 
row_indices=_move_planner_tensor(row_indices_cpu, device), + position_indices=_move_planner_tensor(position_indices_cpu, device), + family_indices=_move_planner_tensor(family_indices_cpu, device), + real_token_count_static=int(lengths_cpu.sum().item()), + output_mask=_move_planner_tensor(output_mask_cpu, device), + ) + + +def _attention_source_layout( + spec: GdnPackedExecutionSpec, + *, + cp_size: int, + attention_token_layout_index: TokenLayoutIndex | None, + planner_config: GdnPlannerConfig, +) -> TokenLayoutIndex: + if attention_token_layout_index is not None: + if _layout_cp_size(attention_token_layout_index) != cp_size: + raise ValueError( + "attention token layout index cp_size must match GDN cp_size, got " + f"{_layout_cp_size(attention_token_layout_index)} and {cp_size}" + ) + if _layout_token_count(attention_token_layout_index) != spec.real_token_count: + raise ValueError( + "attention token layout index token count must match GDN real token " + f"count, got {_layout_token_count(attention_token_layout_index)} and " + f"{spec.real_token_count}" + ) + return attention_token_layout_index + return _token_layout_from_rank_ranges( + _default_attention_layout_ranges( + spec, + cp_size=cp_size, + planner_config=planner_config, + ) + ) + + +def _build_cp_rank_execution_plan( + spec: GdnPackedExecutionSpec, + *, + device: torch.device | str, + cp_rank: int, + cp_size: int, + attention_token_layout_index: TokenLayoutIndex | None, + cp_segment_schedule: GdnCpSegmentSchedule | None, + planner_config: GdnPlannerConfig, +) -> GdnRankExecutionPlan: + if cp_size < 1: + raise ValueError(f"cp_size must be >= 1, got {cp_size}") + if cp_rank < 0 or cp_rank >= cp_size: + raise ValueError(f"cp_rank must be in [0, {cp_size}), got {cp_rank}") + if ( + attention_token_layout_index is not None + and _layout_cp_size(attention_token_layout_index) != cp_size + ): + raise ValueError( + "attention token layout index cp_size must match GDN cp_size, got " + f"{_layout_cp_size(attention_token_layout_index)} and {cp_size}" + ) + + has_explicit_attention_layout = attention_token_layout_index is not None + if cp_segment_schedule is None and not has_explicit_attention_layout: + chain_only_plan = build_gdn_chain_only_rank_execution_plan( + spec, + device=device, + cp_rank=cp_rank, + cp_size=cp_size, + planner_config=planner_config, + ) + if chain_only_plan is not None: + return chain_only_plan + local_family_plan = _build_local_family_rank_execution_plan( + spec, + device=device, + cp_rank=cp_rank, + cp_size=cp_size, + planner_config=planner_config, + ) + if local_family_plan is not None: + return local_family_plan + if cp_segment_schedule is None and has_explicit_attention_layout: + chain_layout_plan = _build_chain_attention_layout_rank_execution_plan( + spec, + device=device, + cp_rank=cp_rank, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, + planner_config=planner_config, + ) + if chain_layout_plan is not None: + return chain_layout_plan + local_layout_plan = _build_local_attention_layout_rank_execution_plan( + spec, + device=device, + cp_rank=cp_rank, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, + planner_config=planner_config, + ) + if local_layout_plan is not None: + return local_layout_plan + + from art.megatron.gdn.layout import ( + _reverse_exchange_plan, + build_local_rank_cp_exchange_plan_from_dest_ranges, + ) + + source_layout = _attention_source_layout( + spec, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, 
+ planner_config=planner_config, + ) + if cp_segment_schedule is None: + schedule = _build_cp_segment_schedule( + spec, + cp_size=cp_size, + attention_layout_index=_build_attention_layout_index_from_token_layout( + source_layout, + max_ranges=max( + 1, + (2 * spec.real_token_count) // max(1, len(spec.segments())), + ), + ), + planner_config=planner_config, + ) + else: + schedule = cp_segment_schedule + if len(schedule.gdn_token_counts_by_rank) != cp_size: + raise ValueError(f"CP GDN schedule must contain {cp_size} ranks") + attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( + source_layout=source_layout, + device=device, + local_rank=cp_rank, + dest_ranges_by_rank=schedule.gdn_token_ranges_by_rank, + cross_rank_token_count=schedule.cross_rank_token_count, + ) + gdn_to_attention = _reverse_exchange_plan(attention_to_gdn) + local_token_ranges = schedule.gdn_token_ranges_by_rank[cp_rank] + local_gdn_token_count = schedule.gdn_token_counts_by_rank[cp_rank] + + chain_prefix_buckets = tuple( + bucket for bucket in schedule.chain_prefix_buckets if bucket + ) + chain_completion_buckets = tuple( + bucket for bucket in schedule.chain_completion_buckets if bucket + ) + local_prefix_segments = tuple(schedule.local_prefix_segments_by_rank[cp_rank]) + local_prefix_family_indices = { + segment.family_index for segment in local_prefix_segments + } + local_prefix_buckets = _batch_segments_by_padded_work( + () if local_prefix_segments else (), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + local_completion_segments = tuple( + schedule.local_completion_segments_by_rank[cp_rank] + ) + chunk_local_completion_segments = tuple( + segment + for segment in local_completion_segments + if segment.family_index in local_prefix_family_indices + ) + plain_local_completion_segments = tuple( + segment + for segment in local_completion_segments + if segment.family_index not in local_prefix_family_indices + ) + ready_completion_segments, remote_completion_segments = ( + _split_ready_and_remote_completion_segments( + plain_local_completion_segments, + local_prefix_segments=(), + chain_prefix_buckets=chain_prefix_buckets, + ) + ) + ready_local_completion_buckets = _batch_segments_by_padded_work( + ready_completion_segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + remote_local_completion_buckets = _batch_segments_by_padded_work( + remote_completion_segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + local_completion_buckets = ( + ready_local_completion_buckets + remote_local_completion_buckets + ) + prefix_family_order = tuple( + segment.family_index + for bucket in ( + *chain_prefix_buckets, + *local_prefix_buckets, + ) + for segment in bucket + ) + ( + prefix_boundary_buckets, + prefix_tail_buckets, + completion_warmup_buckets, + ) = _build_chunk_aligned_position_bucket_plans( + local_prefix_segments, + chunk_local_completion_segments, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + ) + return GdnRankExecutionPlan.model_construct( + cp_rank=cp_rank, + cp_size=cp_size, + batch_size=1, + sequence_length=local_gdn_token_count, + packed_batch_size=spec.batch_size, + packed_sequence_length=spec.sequence_length, + real_token_mask=torch.ones( + 1, local_gdn_token_count, device=device, dtype=torch.bool + ), + 
family_count=spec.family_count, + completion_count=spec.completion_count, + prefix_buckets=(), + completion_buckets=(), + local_prefix_buckets=_build_position_bucket_plans( + local_prefix_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + local_completion_buckets=_build_position_bucket_plans( + local_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + ready_local_completion_buckets=_build_position_bucket_plans( + ready_local_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + remote_local_completion_buckets=_build_position_bucket_plans( + remote_local_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + chain_prefix_buckets=_build_position_bucket_plans( + chain_prefix_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + chain_completion_buckets=_build_position_bucket_plans( + chain_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ), + prefix_table_is_dense_ordered=( + not local_prefix_segments + and prefix_family_order == tuple(range(spec.family_count)) + ), + attention_to_gdn=attention_to_gdn, + gdn_to_attention=gdn_to_attention, + attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], + gdn_token_ranges=local_token_ranges, + attention_token_count=source_layout.token_counts_by_rank[cp_rank], + gdn_token_count=local_gdn_token_count, + parent_state_exchange_family_indices=( + schedule.parent_state_exchange_family_indices + ), + parent_state_transfers=_transfer_plans_to_device( + schedule.parent_state_transfers, device=device + ), + prefix_boundary_buckets=prefix_boundary_buckets, + prefix_tail_buckets=prefix_tail_buckets, + completion_warmup_buckets=completion_warmup_buckets, + ) + + +def build_gdn_cp_segment_schedule( + spec: GdnPackedExecutionSpec, + *, + cp_size: int, + attention_token_layout_index: TokenLayoutIndex | None = None, + planner_config: GdnPlannerConfig | None = None, +) -> GdnCpSegmentSchedule: + planner_config = planner_config or GdnPlannerConfig() + source_layout = _attention_source_layout( + spec, + cp_size=cp_size, + attention_token_layout_index=attention_token_layout_index, + planner_config=planner_config, + ) + return _build_cp_segment_schedule( + spec, + cp_size=cp_size, + attention_layout_index=_build_attention_layout_index_from_token_layout( + source_layout, + max_ranges=max( + 1, (2 * spec.real_token_count) // max(1, len(spec.segments())) + ), + ), + planner_config=planner_config, + ) + + +def _build_cp_segment_schedule( + spec: GdnPackedExecutionSpec, + *, + cp_size: int, + attention_layout_index: _AttentionLayoutIndex, + planner_config: GdnPlannerConfig, +) -> GdnCpSegmentSchedule: + segment_attention_counts = _segment_attention_rank_counts( + spec, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + ) + legal_chain_families = tuple( + family.family_index + for family in spec.families + if _can_chain_family(family, cp_size=cp_size, planner_config=planner_config) + ) + chain_family_indices = frozenset(legal_chain_families) + best = _materialize_cp_segment_schedule( + spec, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + segment_attention_counts=segment_attention_counts, + chain_family_indices=chain_family_indices, + co_locate_local_families=False, + planner_config=planner_config, + ) + best_score = 
_score_cp_segment_schedule( + best, + planner_config=planner_config, + ) + has_local_families = len(chain_family_indices) != spec.family_count + if has_local_families: + local_family_trial = _materialize_cp_segment_schedule( + spec, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + segment_attention_counts=segment_attention_counts, + chain_family_indices=chain_family_indices, + co_locate_local_families=True, + planner_config=planner_config, + ) + local_family_score = _score_cp_segment_schedule( + local_family_trial, + planner_config=planner_config, + ) + if ( + local_family_trial.cross_rank_token_count == 0 + and local_family_score < best_score + ): + best = local_family_trial + best_score = local_family_score + if _is_balanced_zero_exchange_schedule( + best, + planner_config=planner_config, + ): + return best + candidate_sets = _candidate_chain_family_sets( + spec, + legal_chain_families=legal_chain_families, + cp_size=cp_size, + ) + for trial_chain in candidate_sets: + if trial_chain == chain_family_indices: + continue + trial = _materialize_cp_segment_schedule( + spec, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + segment_attention_counts=segment_attention_counts, + chain_family_indices=trial_chain, + co_locate_local_families=False, + planner_config=planner_config, + ) + trial_score = _score_cp_segment_schedule( + trial, + planner_config=planner_config, + ) + if trial.cross_rank_token_count == 0 and trial_score < best_score: + best = trial + best_score = trial_score + chain_family_indices = trial_chain + trial = _materialize_cp_segment_schedule( + spec, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + segment_attention_counts=segment_attention_counts, + chain_family_indices=trial_chain, + co_locate_local_families=True, + planner_config=planner_config, + ) + trial_score = _score_cp_segment_schedule( + trial, + planner_config=planner_config, + ) + if trial_score < best_score: + best = trial + best_score = trial_score + chain_family_indices = trial_chain + for _ in range(planner_config.cp_schedule_improve_iters): + improved = False + for family_index in legal_chain_families: + for trial_chain in ( + chain_family_indices - {family_index}, + chain_family_indices | {family_index}, + ): + if trial_chain == chain_family_indices: + continue + trial = _materialize_cp_segment_schedule( + spec, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + segment_attention_counts=segment_attention_counts, + chain_family_indices=trial_chain, + co_locate_local_families=False, + planner_config=planner_config, + ) + trial_score = _score_cp_segment_schedule( + trial, + planner_config=planner_config, + ) + if trial_score < best_score: + best = trial + best_score = trial_score + chain_family_indices = trial_chain + improved = True + break + if improved: + break + if not improved: + break + return best + + +def _is_balanced_zero_exchange_schedule( + schedule: GdnCpSegmentSchedule, + *, + planner_config: GdnPlannerConfig, +) -> bool: + rank_loads = list(schedule.gdn_token_counts_by_rank) + if not rank_loads or any(load == 0 for load in rank_loads): + return False + if schedule.cross_rank_token_count: + return False + if schedule.parent_state_exchange_family_indices: + return False + if max(rank_loads) > planner_config.max_zero_exchange_load_imbalance * ( + sum(rank_loads) / len(rank_loads) + ): + return False + return True + + +def _materialize_cp_segment_schedule( + spec: GdnPackedExecutionSpec, + *, + cp_size: int, + 
attention_layout_index: _AttentionLayoutIndex, + segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], + chain_family_indices: frozenset[int], + co_locate_local_families: bool, + planner_config: GdnPlannerConfig, +) -> GdnCpSegmentSchedule: + gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] + rank_loads = [0] * cp_size + local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ + [] for _ in range(cp_size) + ] + local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ + [] for _ in range(cp_size) + ] + chain_prefix_segments: list[GdnSegmentSpec] = [] + chain_completion_segments: list[GdnSegmentSpec] = [] + parent_state_exchange_families: set[int] = set() + parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} + cross_rank_token_count = 0 + + for family in spec.families: + if family.family_index in chain_family_indices: + chain_prefix_segments.append(family.prefix) + cross_rank_token_count += _append_chain_segment( + gdn_ranges_by_rank, + rank_loads, + family.prefix, + spec, + attention_layout_index=attention_layout_index, + ) + for completion in family.completions: + if _can_chain_segment( + completion, cp_size=cp_size, planner_config=planner_config + ): + chain_completion_segments.append(completion) + cross_rank_token_count += _append_chain_segment( + gdn_ranges_by_rank, + rank_loads, + completion, + spec, + attention_layout_index=attention_layout_index, + ) + continue + owner = _best_segment_owner( + (completion,), + rank_loads, + segment_attention_counts=segment_attention_counts, + planner_config=planner_config, + ) + local_completion_segments_by_rank[owner].append(completion) + cross_rank_token_count += _append_local_segment( + gdn_ranges_by_rank, + rank_loads, + owner, + completion, + spec, + segment_attention_counts=segment_attention_counts, + ) + else: + if co_locate_local_families: + owner = _best_segment_owner( + (family.prefix, *family.completions), + rank_loads, + segment_attention_counts=segment_attention_counts, + planner_config=planner_config, + ) + local_prefix_segments_by_rank[owner].append(family.prefix) + cross_rank_token_count += _append_local_segment( + gdn_ranges_by_rank, + rank_loads, + owner, + family.prefix, + spec, + segment_attention_counts=segment_attention_counts, + ) + for completion in family.completions: + local_completion_segments_by_rank[owner].append(completion) + cross_rank_token_count += _append_local_segment( + gdn_ranges_by_rank, + rank_loads, + owner, + completion, + spec, + segment_attention_counts=segment_attention_counts, + ) + continue + prefix_owner = _best_segment_owner( + (family.prefix,), + rank_loads, + segment_attention_counts=segment_attention_counts, + planner_config=planner_config, + ) + local_prefix_segments_by_rank[prefix_owner].append(family.prefix) + cross_rank_token_count += _append_local_segment( + gdn_ranges_by_rank, + rank_loads, + prefix_owner, + family.prefix, + spec, + segment_attention_counts=segment_attention_counts, + ) + for completion in family.completions: + owner = _best_segment_owner( + (completion,), + rank_loads, + segment_attention_counts=segment_attention_counts, + planner_config=planner_config, + ) + if owner != prefix_owner: + parent_state_exchange_families.add(family.family_index) + parent_state_transfer_families.setdefault( + (prefix_owner, owner), set() + ).add(family.family_index) + local_completion_segments_by_rank[owner].append(completion) + cross_rank_token_count += _append_local_segment( + gdn_ranges_by_rank, + rank_loads, + 
owner, + completion, + spec, + segment_attention_counts=segment_attention_counts, + ) + + return GdnCpSegmentSchedule.model_construct( + gdn_token_counts_by_rank=tuple(rank_loads), + gdn_token_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), + cross_rank_token_count=cross_rank_token_count, + chain_prefix_buckets=_batch_segments_by_padded_work( + tuple(chain_prefix_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ), + chain_completion_buckets=_batch_segments_by_padded_work( + tuple(chain_completion_segments), + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ), + local_prefix_segments_by_rank=tuple( + tuple(segments) for segments in local_prefix_segments_by_rank + ), + local_completion_segments_by_rank=tuple( + tuple(segments) for segments in local_completion_segments_by_rank + ), + parent_state_exchange_family_indices=tuple( + sorted(parent_state_exchange_families) + ), + parent_state_transfers=_build_parent_state_transfer_plans( + parent_state_transfer_families + ), + ) + + +def _build_local_family_rank_execution_plan( + spec: GdnPackedExecutionSpec, + *, + device: torch.device | str, + cp_rank: int, + cp_size: int, + planner_config: GdnPlannerConfig, +) -> GdnRankExecutionPlan | None: + if cp_size <= 1 or not spec.families: + return None + target_rank_load = spec.real_token_count / cp_size + loads = [0] * cp_size + prefix_owner_by_family: list[int] = [] + completion_owner_by_family: list[int] = [] + for family in spec.families: + if _can_chain_family(family, cp_size=cp_size, planner_config=planner_config): + return None + if ( + family.prefix.length + > planner_config.max_zero_exchange_load_imbalance * target_rank_load + ): + return None + owner = _least_loaded_rank(loads) + prefix_owner_by_family.append(owner) + completion_owner_by_family.append(owner) + loads[owner] += family.token_count + + if max(loads, default=0) > ( + planner_config.local_completion_rebalance_min_imbalance * target_rank_load + ): + completion_owner_by_family = list( + _rebalance_local_completion_bundles( + spec, + prefix_owner_by_family=tuple(prefix_owner_by_family), + completion_owner_by_family=tuple(completion_owner_by_family), + initial_loads=tuple(loads), + planner_config=planner_config, + ) + ) + local_tokens, prefix_segments, completion_segments = ( + _materialize_local_family_rank_assignment( + spec, + cp_rank=cp_rank, + prefix_owner_by_family=tuple(prefix_owner_by_family), + completion_owner_by_family=tuple(completion_owner_by_family), + ) + ) + parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} + for family in spec.families: + prefix_owner = prefix_owner_by_family[family.family_index] + completion_owner = completion_owner_by_family[family.family_index] + if completion_owner != prefix_owner and family.completions: + parent_state_transfer_families.setdefault( + (prefix_owner, completion_owner), set() + ).add(family.family_index) + + token_indices_by_rank = tuple( + local_tokens if rank == cp_rank else () for rank in range(cp_size) + ) + identity_exchange = GdnCpExchangePlan.model_construct( + cp_size=cp_size, + source_token_counts_by_rank=tuple( + len(tokens) for tokens in token_indices_by_rank + ), + dest_token_counts_by_rank=tuple( + len(tokens) for tokens in token_indices_by_rank + ), + transfers=tuple( + GdnCpPeerTransfer.model_construct( + source_rank=rank, + dest_rank=rank, + token_count=len(tokens), + 
source_positions_tensor=None, + dest_positions_tensor=None, + ) + for rank, tokens in enumerate(token_indices_by_rank) + if tokens + ), + ) + local_token_ranges = _local_token_ranges(local_tokens) + prefix_buckets = _batch_segments_by_padded_work( + prefix_segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + ready_completion_segments, remote_completion_segments = ( + _split_ready_and_remote_completion_segments( + completion_segments, + local_prefix_segments=prefix_segments, + chain_prefix_buckets=(), + ) + ) + ready_completion_buckets = _batch_segments_by_padded_work( + ready_completion_segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + remote_completion_buckets = _batch_segments_by_padded_work( + remote_completion_segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + completion_buckets = ready_completion_buckets + remote_completion_buckets + prefix_family_order = tuple( + segment.family_index for bucket in prefix_buckets for segment in bucket + ) + local_prefix_bucket_plans = _build_position_bucket_plans( + prefix_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ) + ready_completion_bucket_plans = _build_position_bucket_plans( + ready_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ) + remote_completion_bucket_plans = _build_position_bucket_plans( + remote_completion_buckets, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + ) + local_completion_bucket_plans = ( + ready_completion_bucket_plans + remote_completion_bucket_plans + ) + ( + prefix_boundary_buckets, + prefix_tail_buckets, + completion_warmup_buckets, + ) = _build_chunk_aligned_position_bucket_plans( + prefix_segments, + completion_segments, + local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + ) + return GdnRankExecutionPlan.model_construct( + cp_rank=cp_rank, + cp_size=cp_size, + batch_size=1, + sequence_length=len(local_tokens), + packed_batch_size=spec.batch_size, + packed_sequence_length=spec.sequence_length, + real_token_mask=torch.ones( + 1, len(local_tokens), device=device, dtype=torch.bool + ), + family_count=spec.family_count, + completion_count=spec.completion_count, + prefix_buckets=(), + completion_buckets=(), + local_prefix_buckets=local_prefix_bucket_plans, + local_completion_buckets=local_completion_bucket_plans, + ready_local_completion_buckets=ready_completion_bucket_plans, + remote_local_completion_buckets=remote_completion_bucket_plans, + chain_prefix_buckets=(), + chain_completion_buckets=(), + prefix_table_is_dense_ordered=( + prefix_family_order == tuple(range(spec.family_count)) + ), + attention_to_gdn=identity_exchange, + gdn_to_attention=identity_exchange, + attention_token_ranges=local_token_ranges, + gdn_token_ranges=local_token_ranges, + attention_token_count=len(local_tokens), + gdn_token_count=len(local_tokens), + parent_state_exchange_family_indices=tuple( + sorted( + family.family_index + for family in spec.families + if completion_owner_by_family[family.family_index] + != prefix_owner_by_family[family.family_index] + and family.completions + ) + ), + parent_state_transfers=_transfer_plans_to_device( + _build_parent_state_transfer_plans(parent_state_transfer_families), + 
device=device, + ), + prefix_boundary_buckets=prefix_boundary_buckets, + prefix_tail_buckets=prefix_tail_buckets, + completion_warmup_buckets=completion_warmup_buckets, + ) + + +def _rebalance_local_completion_bundles( + spec: GdnPackedExecutionSpec, + *, + prefix_owner_by_family: tuple[int, ...], + completion_owner_by_family: tuple[int, ...], + initial_loads: tuple[int, ...], + planner_config: GdnPlannerConfig, +) -> tuple[int, ...]: + owners = list(completion_owner_by_family) + loads = list(initial_loads) + + def score(candidate_loads: list[int], candidate_owners: list[int]) -> float: + max_load = max(candidate_loads, default=0) + idle_tokens = sum(max_load - load for load in candidate_loads) + transfer_count = sum( + 1 + for index, owner in enumerate(candidate_owners) + if owner != prefix_owner_by_family[index] + and spec.families[index].completions + ) + return ( + max_load + + planner_config.rank_idle_token_cost * idle_tokens + + planner_config.parent_state_exchange_penalty_tokens * transfer_count + ) + + best_score = score(loads, owners) + while True: + best_move: tuple[int, int, list[int], list[int], float] | None = None + for family in spec.families: + completion_tokens = sum(segment.length for segment in family.completions) + if completion_tokens <= 0: + continue + source = owners[family.family_index] + for dest in range(len(loads)): + if dest == source: + continue + candidate_loads = list(loads) + candidate_owners = list(owners) + candidate_loads[source] -= completion_tokens + candidate_loads[dest] += completion_tokens + candidate_owners[family.family_index] = dest + candidate_score = score(candidate_loads, candidate_owners) + if candidate_score >= best_score: + continue + if best_move is None or candidate_score < best_move[4]: + best_move = ( + family.family_index, + dest, + candidate_loads, + candidate_owners, + candidate_score, + ) + if best_move is None: + return tuple(owners) + _, _, loads, owners, best_score = best_move + + +def _materialize_local_family_rank_assignment( + spec: GdnPackedExecutionSpec, + *, + cp_rank: int, + prefix_owner_by_family: tuple[int, ...], + completion_owner_by_family: tuple[int, ...], +) -> tuple[tuple[int, ...], tuple[GdnSegmentSpec, ...], tuple[GdnSegmentSpec, ...]]: + token_indices: list[int] = [] + prefix_segments: list[GdnSegmentSpec] = [] + completion_segments: list[GdnSegmentSpec] = [] + for family in spec.families: + prefix_owner = prefix_owner_by_family[family.family_index] + completion_owner = completion_owner_by_family[family.family_index] + if prefix_owner == cp_rank: + prefix_segments.append(family.prefix) + token_indices.extend(family.prefix.linear_indices(spec.sequence_length)) + for completion in family.completions: + if completion_owner == cp_rank: + completion_segments.append(completion) + token_indices.extend(completion.linear_indices(spec.sequence_length)) + return tuple(token_indices), tuple(prefix_segments), tuple(completion_segments) + + +def _empty_local_family_rank_execution_plan( + spec: GdnPackedExecutionSpec, + *, + device: torch.device | str, + cp_rank: int, + cp_size: int, +) -> GdnRankExecutionPlan: + identity_exchange = GdnCpExchangePlan.model_construct( + cp_size=cp_size, + source_token_counts_by_rank=tuple(0 for _ in range(cp_size)), + dest_token_counts_by_rank=tuple(0 for _ in range(cp_size)), + transfers=(), + ) + return GdnRankExecutionPlan.model_construct( + cp_rank=cp_rank, + cp_size=cp_size, + batch_size=1, + sequence_length=0, + packed_batch_size=spec.batch_size, + 
packed_sequence_length=spec.sequence_length, + real_token_mask=torch.ones(1, 0, device=device, dtype=torch.bool), + family_count=spec.family_count, + completion_count=spec.completion_count, + prefix_buckets=(), + completion_buckets=(), + local_prefix_buckets=(), + local_completion_buckets=(), + ready_local_completion_buckets=(), + remote_local_completion_buckets=(), + chain_prefix_buckets=(), + chain_completion_buckets=(), + prefix_table_is_dense_ordered=False, + attention_to_gdn=identity_exchange, + gdn_to_attention=identity_exchange, + attention_token_ranges=(), + gdn_token_ranges=(), + attention_token_count=0, + gdn_token_count=0, + parent_state_exchange_family_indices=(), + parent_state_transfers=(), + ) + + +def _can_chain_segment( + segment: GdnSegmentSpec, + *, + cp_size: int, + planner_config: GdnPlannerConfig, +) -> bool: + if segment.length < cp_size: + return False + per_rank = segment.length / cp_size + if per_rank < planner_config.cp_chain_min_tokens_per_rank: + return False + return segment.length >= planner_config.cp_chain_min_total_tokens + + +def _build_parent_state_transfer_plans( + families_by_peer: dict[tuple[int, int], set[int]], +) -> tuple[GdnParentStateTransferPlan, ...]: + return tuple( + GdnParentStateTransferPlan( + source_rank=source_rank, + dest_rank=dest_rank, + family_indices=tuple(sorted(family_indices)), + ) + for (source_rank, dest_rank), family_indices in sorted(families_by_peer.items()) + if source_rank != dest_rank and family_indices + ) + + +def _split_ready_and_remote_completion_segments( + completion_segments: tuple[GdnSegmentSpec, ...], + *, + local_prefix_segments: tuple[GdnSegmentSpec, ...], + chain_prefix_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], +) -> tuple[tuple[GdnSegmentSpec, ...], tuple[GdnSegmentSpec, ...]]: + ready_family_indices = { + segment.family_index for segment in local_prefix_segments + } | {segment.family_index for bucket in chain_prefix_buckets for segment in bucket} + ready = [] + remote = [] + for segment in completion_segments: + if segment.family_index in ready_family_indices: + ready.append(segment) + else: + remote.append(segment) + return tuple(ready), tuple(remote) + + +def _transfer_plans_to_device( + transfers: tuple[GdnParentStateTransferPlan, ...], + *, + device: torch.device | str, +) -> tuple[GdnParentStateTransferPlan, ...]: + return tuple( + transfer.model_copy( + update={ + "family_indices_tensor": _move_planner_tensor( + torch.tensor(transfer.family_indices, dtype=torch.long), + device, + ) + } + ) + for transfer in transfers + ) + + +def _can_chain_family( + family: GdnPackedFamilySpec, + *, + cp_size: int, + planner_config: GdnPlannerConfig, +) -> bool: + if not _can_chain_prefix_segment( + family.prefix, cp_size=cp_size, planner_config=planner_config + ): + return False + if any( + _can_chain_segment(completion, cp_size=cp_size, planner_config=planner_config) + for completion in family.completions + ): + return True + return family.prefix.length >= planner_config.cp_chain_min_prefix_only_tokens + + +def _can_chain_prefix_segment( + segment: GdnSegmentSpec, + *, + cp_size: int, + planner_config: GdnPlannerConfig, +) -> bool: + if segment.length < cp_size: + return False + per_rank = segment.length / cp_size + if per_rank < planner_config.cp_chain_min_tokens_per_rank: + return False + return segment.length >= planner_config.cp_chain_min_prefix_only_tokens + + +def _candidate_chain_family_sets( + spec: GdnPackedExecutionSpec, + *, + legal_chain_families: tuple[int, ...], + cp_size: int, +) -> 
tuple[frozenset[int], ...]: + if not legal_chain_families: + return (frozenset(),) + candidates: set[frozenset[int]] = {frozenset(), frozenset(legal_chain_families)} + if len(legal_chain_families) <= 4: + for mask in range(1, 1 << len(legal_chain_families)): + candidates.add( + frozenset( + family_index + for bit, family_index in enumerate(legal_chain_families) + if mask & (1 << bit) + ) + ) + else: + by_chain_value = sorted( + legal_chain_families, + key=lambda family_index: ( + _family_chain_candidate_tokens(spec.families[family_index]), + spec.families[family_index].prefix.length, + ), + reverse=True, + ) + for count in range(1, min(len(by_chain_value), cp_size * 2) + 1): + candidates.add(frozenset(by_chain_value[:count])) + for family_index in by_chain_value[: max(cp_size * 2, 1)]: + candidates.add(frozenset((family_index,))) + return tuple(sorted(candidates, key=lambda item: (len(item), tuple(sorted(item))))) + + +def _family_chain_candidate_tokens(family: GdnPackedFamilySpec) -> int: + return family.prefix.length + sum( + completion.length for completion in family.completions + ) + + +def _score_cp_segment_schedule( + schedule: GdnCpSegmentSchedule, + *, + planner_config: GdnPlannerConfig, +) -> float: + rank_loads = list(schedule.gdn_token_counts_by_rank) + max_load = max(rank_loads, default=0) + idle_tokens = sum(max_load - load for load in rank_loads) + empty_rank_count = sum(1 for load in rank_loads if load == 0) + local_launches = sum( + 1 for segments in schedule.local_prefix_segments_by_rank if segments + ) + sum(1 for segments in schedule.local_completion_segments_by_rank if segments) + return ( + max_load + + planner_config.rank_idle_token_cost * idle_tokens + + planner_config.empty_rank_penalty_tokens * empty_rank_count + + planner_config.local_fork_launch_penalty_tokens * local_launches + + planner_config.layout_cross_rank_token_cost * schedule.cross_rank_token_count + + planner_config.parent_state_exchange_penalty_tokens + * len(schedule.parent_state_exchange_family_indices) + + planner_config.cp_collective_latency_tokens + * (len(schedule.chain_prefix_buckets) + len(schedule.chain_completion_buckets)) + ) + + +def _best_segment_owner( + segments: tuple[GdnSegmentSpec, ...], + rank_loads: list[int], + *, + segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], + planner_config: GdnPlannerConfig, +) -> int: + del planner_config + if len(segments) == 1: + on_rank_tokens = segment_attention_counts[_segment_key(segments[0])] + else: + rank_count = len(rank_loads) + counts_by_rank = [0] * rank_count + for segment in segments: + segment_counts = segment_attention_counts[_segment_key(segment)] + for rank in range(rank_count): + counts_by_rank[rank] += segment_counts[rank] + on_rank_tokens = tuple(counts_by_rank) + best_locality = max(on_rank_tokens, default=0) + if best_locality <= 0: + return _least_loaded_rank(rank_loads) + best_rank = 0 + best_load = None + for rank, tokens in enumerate(on_rank_tokens): + if tokens != best_locality: + continue + load = rank_loads[rank] + if best_load is None or load < best_load: + best_rank = rank + best_load = load + return best_rank + + +def _build_attention_layout_index_from_token_layout( + layout: TokenLayoutIndex, + *, + max_ranges: int, +) -> _AttentionLayoutIndex: + del max_ranges + ranges_by_rank = tuple( + tuple(sorted((int(start), int(end)) for start, end, _ in rank_ranges)) + for rank_ranges in layout.ownership_ranges_by_rank + ) + range_count = sum(len(ranges) for ranges in ranges_by_rank) + return 
_AttentionLayoutIndex.model_construct( + token_ranges_by_rank=ranges_by_rank, + token_range_ends_by_rank=tuple( + tuple(end for _, end in ranges) for ranges in ranges_by_rank + ), + range_count=range_count, + ) + + +def _segment_attention_rank_counts( + spec: GdnPackedExecutionSpec, + *, + cp_size: int, + attention_layout_index: _AttentionLayoutIndex, +) -> dict[tuple[int, int, int], tuple[int, ...]]: + del cp_size + segments = tuple(spec.segments()) + if not segments: + return {} + starts = torch.tensor( + [_segment_token_start(segment, spec.sequence_length) for segment in segments], + dtype=torch.long, + ) + lengths = torch.tensor([segment.length for segment in segments], dtype=torch.long) + ends = starts + lengths + counts_by_rank = [] + for ranges in attention_layout_index.token_ranges_by_rank: + counts_by_rank.append(_rank_range_overlap_counts(starts, ends, ranges)) + counts_tensor = torch.stack(counts_by_rank, dim=1) + totals = counts_tensor.sum(dim=1) + if not torch.equal(totals, lengths): + bad_index = int(torch.nonzero(totals != lengths, as_tuple=False)[0].item()) + raise ValueError( + "attention layout is missing a real token required by GDN; " + f"segment={_segment_key(segments[bad_index])}" + ) + counts = counts_tensor.tolist() + return { + _segment_key(segment): tuple(int(value) for value in counts[index]) + for index, segment in enumerate(segments) + } + + +def _rank_range_overlap_counts( + starts: torch.Tensor, + ends: torch.Tensor, + ranges: tuple[tuple[int, int], ...], +) -> torch.Tensor: + if not ranges: + return torch.zeros_like(starts) + range_starts = torch.tensor([start for start, _ in ranges], dtype=torch.long) + range_ends = torch.tensor([end for _, end in ranges], dtype=torch.long) + range_lengths = range_ends - range_starts + prefix = torch.cat((range_lengths.new_zeros(1), torch.cumsum(range_lengths, dim=0))) + + def owned_before(points: torch.Tensor) -> torch.Tensor: + indices = torch.searchsorted(range_ends, points, right=False) + counts = prefix.index_select(0, indices) + active = indices < int(range_starts.numel()) + if bool(active.any().item()): + active_indices = indices[active] + active_starts = range_starts.index_select(0, active_indices) + active_ends = range_ends.index_select(0, active_indices) + counts[active] += torch.minimum( + torch.clamp(points[active] - active_starts, min=0), + active_ends - active_starts, + ) + return counts + + return owned_before(ends) - owned_before(starts) + + +def _segment_key(segment: GdnSegmentSpec) -> tuple[int, int, int]: + return (segment.row_index, segment.start, segment.end) + + +def _default_attention_layout_ranges( + spec: GdnPackedExecutionSpec, + *, + cp_size: int, + planner_config: GdnPlannerConfig, +) -> tuple[tuple[tuple[int, int, int], ...], ...]: + ranks: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] + loads = [0] * cp_size + + def append_segment(rank: int, token_start: int, token_count: int) -> None: + ranks[rank].append((token_start, token_start + token_count, loads[rank])) + loads[rank] += token_count + + for family in spec.families: + chain_family = _can_chain_family( + family, cp_size=cp_size, planner_config=planner_config + ) + if not chain_family: + if _should_co_locate_non_chain_family( + family, + total_real_tokens=spec.real_token_count, + cp_size=cp_size, + planner_config=planner_config, + ): + owner = _least_loaded_rank(loads) + for segment in (family.prefix, *family.completions): + token_start = _segment_token_start(segment, spec.sequence_length) + append_segment(owner, 
token_start, segment.length) + continue + for segment in (family.prefix, *family.completions): + token_start = _segment_token_start(segment, spec.sequence_length) + owner = _least_loaded_rank(loads) + append_segment(owner, token_start, segment.length) + continue + for segment in (family.prefix, *family.completions): + token_start = _segment_token_start(segment, spec.sequence_length) + if ( + segment.kind == "prefix" + and _can_chain_prefix_segment( + segment, cp_size=cp_size, planner_config=planner_config + ) + ) or _can_chain_segment( + segment, cp_size=cp_size, planner_config=planner_config + ): + _append_split_default_attention_segment( + ranks, loads, token_start, segment.length + ) + continue + owner = _least_loaded_rank(loads) + append_segment(owner, token_start, segment.length) + return tuple(tuple(ranges) for ranges in ranks) + + +def _should_co_locate_non_chain_family( + family: GdnPackedFamilySpec, + *, + total_real_tokens: int, + cp_size: int, + planner_config: GdnPlannerConfig, +) -> bool: + target_rank_load = total_real_tokens / cp_size + return family.token_count <= ( + planner_config.max_zero_exchange_load_imbalance * target_rank_load + ) + + +def _append_split_default_attention_segment( + ranks: list[list[tuple[int, int, int]]], + loads: list[int], + token_start: int, + token_count: int, +) -> None: + cp_size = len(ranks) + for rank in range(cp_size): + start = (token_count * rank) // cp_size + end = (token_count * (rank + 1)) // cp_size + ranks[rank].append((token_start + start, token_start + end, loads[rank])) + loads[rank] += end - start + + +def _append_chain_segment( + gdn_ranges_by_rank: list[list[tuple[int, int, int]]], + rank_loads: list[int], + segment: GdnSegmentSpec, + spec: GdnPackedExecutionSpec, + *, + attention_layout_index: _AttentionLayoutIndex | None = None, +) -> int: + token_start = _segment_token_start(segment, spec.sequence_length) + cp_size = len(gdn_ranges_by_rank) + attention_shards = _attention_contiguous_chain_shards( + token_start, + segment.length, + cp_size=cp_size, + attention_layout_index=attention_layout_index, + ) + if attention_shards is not None: + for rank, shard in enumerate(attention_shards): + position_start = rank_loads[rank] + gdn_ranges_by_rank[rank].append((shard.start, shard.stop, position_start)) + rank_loads[rank] += len(shard) + return 0 + cross_rank_tokens = 0 + shard_lengths = tuple( + (segment.length * (rank + 1)) // cp_size - (segment.length * rank) // cp_size + for rank in range(cp_size) + ) + start = 0 + for rank, shard_length in enumerate(shard_lengths): + end = start + shard_length + if start >= end: + raise ValueError( + "CP chain planning requires non-empty shards; " + f"segment={segment.kind}:{segment.family_index} " + f"length={segment.length} cp_size={cp_size}" + ) + shard_start = token_start + start + position_start = rank_loads[rank] + gdn_ranges_by_rank[rank].append( + (shard_start, shard_start + shard_length, position_start) + ) + rank_loads[rank] += shard_length + if attention_layout_index is not None: + cross_rank_tokens += shard_length - _attention_overlap_count( + attention_layout_index, + rank, + shard_start, + shard_start + shard_length, + ) + start = end + return cross_rank_tokens + + +def _chain_rank_token_indices( + segment: GdnSegmentSpec, + spec: GdnPackedExecutionSpec, + *, + cp_rank: int, + cp_size: int, +) -> range: + token_start = _segment_token_start(segment, spec.sequence_length) + start = (segment.length * cp_rank) // cp_size + end = (segment.length * (cp_rank + 1)) // cp_size + if start >= 
end: + raise ValueError( + "CP chain planning requires non-empty shards; " + f"segment={segment.kind}:{segment.family_index} " + f"length={segment.length} cp_size={cp_size}" + ) + return range(token_start + start, token_start + end) + + +def _attention_contiguous_chain_shards( + token_start: int, + token_count: int, + *, + cp_size: int, + attention_layout_index: _AttentionLayoutIndex | None, +) -> tuple[range, ...] | None: + if attention_layout_index is None: + return None + segment_end = token_start + token_count + shards: list[range] = [] + cursor = token_start + for rank in range(cp_size): + overlap = _attention_single_contiguous_overlap( + attention_layout_index, + rank, + token_start, + segment_end, + ) + if overlap is None: + return None + start, end = overlap + if start != cursor or end <= start: + return None + shards.append(range(start, end)) + cursor = end + if cursor != segment_end: + return None + return tuple(shards) + + +def _attention_single_contiguous_overlap( + index: _AttentionLayoutIndex, + rank: int, + start: int, + end: int, +) -> tuple[int, int] | None: + overlaps = _range_overlaps(start, end, index.token_ranges_by_rank[rank]) + if len(overlaps) != 1: + return None + return overlaps[0] + + +def _append_local_segment( + gdn_ranges_by_rank: list[list[tuple[int, int, int]]], + rank_loads: list[int], + rank: int, + segment: GdnSegmentSpec, + spec: GdnPackedExecutionSpec, + *, + segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], +) -> int: + token_start = _segment_token_start(segment, spec.sequence_length) + position_start = rank_loads[rank] + gdn_ranges_by_rank[rank].append( + (token_start, token_start + segment.length, position_start) + ) + rank_loads[rank] += segment.length + return segment.length - segment_attention_counts[_segment_key(segment)][rank] + + +def _least_loaded_rank(rank_loads: list[int]) -> int: + return min(range(len(rank_loads)), key=lambda rank: (rank_loads[rank], rank)) + + +def _owner_rank( + local_prefix_segments_by_rank: list[list[GdnSegmentSpec]], + prefix: GdnSegmentSpec, +) -> int: + for rank, segments in enumerate(local_prefix_segments_by_rank): + if prefix in segments: + return rank + raise RuntimeError("local prefix owner was not recorded") + + +def _build_position_bucket_plans( + segment_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], + local_token_ranges: tuple[tuple[int, int, int], ...], + *, + sequence_length: int, + device: torch.device | str, +) -> tuple[GdnSegmentBucketPlan, ...]: + return tuple( + _build_position_bucket_plan( + bucket, + local_token_ranges, + sequence_length=sequence_length, + device=device, + ) + for bucket in segment_buckets + ) + + +def _build_position_bucket_plan( + segments: tuple[GdnSegmentSpec, ...], + local_token_ranges: tuple[tuple[int, int, int], ...], + *, + sequence_length: int, + device: torch.device | str, +) -> GdnSegmentBucketPlan: + exact_plan = _build_exact_range_position_bucket_plan( + segments, + local_token_ranges, + sequence_length=sequence_length, + device=device, + ) + if exact_plan is not None: + return exact_plan + local_positions_by_segment = [] + lengths = [] + local_range_ends = tuple(token_end for _, token_end, _ in local_token_ranges) + for segment in segments: + positions = _local_positions_for_segment( + segment, + sequence_length=sequence_length, + local_token_ranges=local_token_ranges, + local_range_ends=local_range_ends, + ) + length = int(positions.numel()) + if not length: + raise ValueError( + "planned GDN bucket contains a segment with no local tokens; " + 
f"family={segment.family_index} kind={segment.kind}" + ) + local_positions_by_segment.append(positions) + lengths.append(length) + max_length = max(lengths) + lengths_cpu = torch.tensor(lengths, dtype=torch.long) + offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) + real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) + position_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) + for column, positions in enumerate(local_positions_by_segment): + position_indices_cpu[: int(positions.numel()), column] = positions + cu_seqlens_cpu = torch.cat( + [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] + ) + row_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) + family_indices_cpu = torch.tensor( + [segment.family_index for segment in segments], + dtype=torch.long, + ) + return GdnSegmentBucketPlan.model_construct( + length=max_length, + lengths=_move_planner_tensor(lengths_cpu, device), + real_mask=_move_planner_tensor(real_mask_cpu, device), + cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), + row_indices=_move_planner_tensor(row_indices_cpu, device), + position_indices=_move_planner_tensor(position_indices_cpu, device), + family_indices=_move_planner_tensor(family_indices_cpu, device), + real_token_count_static=sum(lengths), + ) + + +def _build_exact_range_position_bucket_plan( + segments: tuple[GdnSegmentSpec, ...], + local_token_ranges: tuple[tuple[int, int, int], ...], + *, + sequence_length: int, + device: torch.device | str, +) -> GdnSegmentBucketPlan | None: + range_positions = { + (start, end): position for start, end, position in local_token_ranges + } + starts = [] + lengths = [] + for segment in segments: + token_start = _segment_token_start(segment, sequence_length) + token_end = token_start + segment.length + position_start = range_positions.get((token_start, token_end)) + if position_start is None: + return None + starts.append(position_start) + lengths.append(segment.length) + max_length = max(lengths) + starts_cpu = torch.tensor(starts, dtype=torch.long) + lengths_cpu = torch.tensor(lengths, dtype=torch.long) + offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) + real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) + position_indices_cpu = torch.where( + real_mask_cpu, + starts_cpu.unsqueeze(0) + offsets_cpu, + torch.zeros_like(offsets_cpu), + ) + cu_seqlens_cpu = torch.cat( + [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] + ) + row_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) + family_indices_cpu = torch.tensor( + [segment.family_index for segment in segments], + dtype=torch.long, + ) + return GdnSegmentBucketPlan.model_construct( + length=max_length, + lengths=_move_planner_tensor(lengths_cpu, device), + real_mask=_move_planner_tensor(real_mask_cpu, device), + cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), + row_indices=_move_planner_tensor(row_indices_cpu, device), + position_indices=_move_planner_tensor(position_indices_cpu, device), + family_indices=_move_planner_tensor(family_indices_cpu, device), + real_token_count_static=sum(lengths), + ) + + +def _move_planner_tensor( + tensor: torch.Tensor, device: torch.device | str +) -> torch.Tensor: + target = torch.device(device) + if target.type == "cpu": + return tensor + return tensor.to(device=target) + + +def _batch_segments_by_padded_work( + segments: tuple[GdnSegmentSpec, ...], + *, + max_padding_ratio: float = 1.25, + max_segments_per_batch: int = 128, +) -> 
tuple[tuple[GdnSegmentSpec, ...], ...]: + if not segments: + return () + ordered = sorted( + segments, key=lambda segment: (segment.length, segment.family_index) + ) + batches: list[list[GdnSegmentSpec]] = [] + current: list[GdnSegmentSpec] = [] + current_tokens = 0 + current_max = 0 + for segment in ordered: + next_count = len(current) + 1 + next_tokens = current_tokens + segment.length + next_max = max(current_max, segment.length) + padded = next_max * next_count + can_extend = not current or ( + next_count <= max_segments_per_batch + and padded <= max_padding_ratio * next_tokens + ) + if not can_extend: + batches.append(current) + current = [] + current_tokens = 0 + current_max = 0 + current.append(segment) + current_tokens += segment.length + current_max = max(current_max, segment.length) + if current: + batches.append(current) + return tuple(tuple(batch) for batch in batches) + + +def _build_segment_bucket_plan( + length: int, segments: tuple[GdnSegmentSpec, ...], *, device: torch.device | str +) -> GdnSegmentBucketPlan: + max_length = max(segment.length for segment in segments) + lengths = torch.tensor( + [segment.length for segment in segments], device=device, dtype=torch.long + ) + starts = torch.tensor( + [segment.start for segment in segments], device=device, dtype=torch.long + ) + rows = torch.tensor( + [segment.row_index for segment in segments], device=device, dtype=torch.long + ) + offsets = torch.arange(max_length, device=device, dtype=torch.long).unsqueeze(1) + real_mask = offsets < lengths.unsqueeze(0) + positions = starts.unsqueeze(0) + offsets + return GdnSegmentBucketPlan.model_construct( + length=max_length, + lengths=lengths, + real_mask=real_mask, + cu_seqlens=torch.cat([lengths.new_zeros(1), torch.cumsum(lengths, dim=0)]), + row_indices=rows.unsqueeze(0).expand(max_length, -1).contiguous(), + position_indices=positions, + family_indices=torch.tensor( + [segment.family_index for segment in segments], + device=device, + dtype=torch.long, + ), + real_token_count_static=sum(segment.length for segment in segments), + ) + + +def _segment_token_start(segment: GdnSegmentSpec, sequence_length: int) -> int: + return segment.row_index * sequence_length + segment.start + + +def _attention_overlap_count( + index: _AttentionLayoutIndex, + rank: int, + start: int, + end: int, +) -> int: + return _range_overlap_count( + start, + end, + index.token_ranges_by_rank[rank], + index.token_range_ends_by_rank[rank], + ) + + +def _range_overlap_count( + start: int, + end: int, + ranges: tuple[tuple[int, int], ...], + range_ends: tuple[int, ...], +) -> int: + count = 0 + range_index = bisect_left(range_ends, start + 1) + for range_start, range_end in ranges[range_index:]: + if range_start >= end: + break + count += min(end, range_end) - max(start, range_start) + return count + + +def _range_overlaps( + start: int, + end: int, + ranges: tuple[tuple[int, int], ...], +) -> list[tuple[int, int]]: + overlaps = [ + (max(start, range_start), min(end, range_end)) + for range_start, range_end in ranges + if max(start, range_start) < min(end, range_end) + ] + overlaps.sort() + return overlaps + + +def _local_token_ranges( + local_gdn_tokens: tuple[int, ...], +) -> tuple[tuple[int, int, int], ...]: + if not local_gdn_tokens: + return () + ranges = [] + token_start = local_gdn_tokens[0] + token_end = token_start + 1 + position_start = 0 + for position, token in enumerate(local_gdn_tokens[1:], start=1): + if token == token_end: + token_end += 1 + continue + ranges.append((token_start, token_end, 
position_start)) + token_start = token + token_end = token + 1 + position_start = position + ranges.append((token_start, token_end, position_start)) + return tuple(ranges) + + +def _local_positions_for_segment( + segment: GdnSegmentSpec, + *, + sequence_length: int, + local_token_ranges: tuple[tuple[int, int, int], ...], + local_range_ends: tuple[int, ...], +) -> torch.Tensor: + segment_start = _segment_token_start(segment, sequence_length) + segment_end = segment_start + segment.length + pieces = [] + range_index = bisect_left(local_range_ends, segment_start + 1) + for token_start, token_end, position_start in local_token_ranges[range_index:]: + if token_start >= segment_end: + break + overlap_start = max(segment_start, token_start) + overlap_end = min(segment_end, token_end) + if overlap_start >= overlap_end: + continue + pieces.append( + torch.arange( + position_start + overlap_start - token_start, + position_start + overlap_end - token_start, + dtype=torch.long, + ) + ) + if not pieces: + return torch.empty((0,), dtype=torch.long) + if len(pieces) == 1: + return pieces[0] + return torch.cat(pieces) + + +def _rank2_long_cpu(name: str, tensor: torch.Tensor) -> torch.Tensor: + if not torch.is_tensor(tensor): + raise TypeError(f"{name} must be a torch.Tensor") + if tensor.ndim != 2: + raise ValueError(f"{name} must be rank 2 [batch, sequence], got {tensor.ndim}") + if tensor.dtype not in ( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.long, + ): + raise TypeError(f"{name} must contain integer ids, got dtype={tensor.dtype}") + return tensor.detach().to(device="cpu", dtype=torch.long) + + +def _validate_padding_tensor( + row_index: int, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, +) -> int: + padding_positions = torch.nonzero(group_ids == -1, as_tuple=False) + valid_length = ( + int(padding_positions[0].item()) + if int(padding_positions.numel()) > 0 + else int(group_ids.numel()) + ) + if valid_length == 0: + if bool(torch.any(parent_ids != -1).item()): + raise ValueError(f"row {row_index}: padding parent_ids must be -1") + return 0 + if bool(torch.any(group_ids[valid_length:] != -1).item()): + raise ValueError( + f"row {row_index}: valid tokens must be contiguous before padding" + ) + if bool(torch.any(parent_ids[:valid_length] == -1).item()): + raise ValueError( + f"row {row_index}: valid tokens must have non-padding parent_ids" + ) + if bool(torch.any(parent_ids[valid_length:] != -1).item()): + raise ValueError(f"row {row_index}: padding parent_ids must be -1") + return valid_length + + +def _validate_padding( + row_index: int, + group_ids: list[int], + parent_ids: list[int], +) -> int: + valid_length = 0 + for group_id in group_ids: + if group_id == -1: + break + valid_length += 1 + if valid_length == 0: + if any(parent_id != -1 for parent_id in parent_ids): + raise ValueError(f"row {row_index}: padding parent_ids must be -1") + return 0 + if any(group_id != -1 for group_id in group_ids[valid_length:]): + raise ValueError( + f"row {row_index}: valid tokens must be contiguous before padding" + ) + if any(parent_id == -1 for parent_id in parent_ids[:valid_length]): + raise ValueError( + f"row {row_index}: valid tokens must have non-padding parent_ids" + ) + if any(parent_id != -1 for parent_id in parent_ids[valid_length:]): + raise ValueError(f"row {row_index}: padding parent_ids must be -1") + return valid_length + + +def _parse_row_tensor( + *, + row_index: int, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + valid_length: int, + 
first_family_index: int, + min_completions_per_family: int, +) -> list[GdnPackedFamilySpec]: + valid_groups = group_ids[:valid_length] + valid_parents = parent_ids[:valid_length] + if valid_length > 1: + same_group = valid_groups[1:] == valid_groups[:-1] + parent_changed = same_group & (valid_parents[1:] != valid_parents[:-1]) + if bool(torch.any(parent_changed).item()): + position = int(torch.nonzero(parent_changed, as_tuple=False)[0].item()) + 1 + group_id = int(valid_groups[position].item()) + previous_parent = int(valid_parents[position - 1].item()) + current_parent = int(valid_parents[position].item()) + raise ValueError( + f"row {row_index}: group {group_id} changes parent from " + f"{previous_parent} to {current_parent}" + ) + boundaries = torch.nonzero(~same_group, as_tuple=False).flatten() + 1 + starts_tensor = torch.cat( + (valid_groups.new_zeros(1), boundaries.to(valid_groups.dtype)) + ) + ends_tensor = torch.cat( + ( + boundaries.to(valid_groups.dtype), + valid_groups.new_tensor([valid_length]), + ) + ) + else: + starts_tensor = valid_groups.new_zeros(1) + ends_tensor = valid_groups.new_tensor([valid_length]) + + starts = tuple(int(value) for value in starts_tensor.tolist()) + ends = tuple(int(value) for value in ends_tensor.tolist()) + segment_group_ids = tuple(int(valid_groups[start].item()) for start in starts) + segment_parent_ids = tuple(int(valid_parents[start].item()) for start in starts) + families: list[GdnPackedFamilySpec] = [] + seen_groups: set[int] = set() + segment_cursor = 0 + while segment_cursor < len(starts): + group_id = segment_group_ids[segment_cursor] + parent_id = segment_parent_ids[segment_cursor] + start = starts[segment_cursor] + end = ends[segment_cursor] + if group_id in seen_groups: + raise ValueError(f"row {row_index}: group_id {group_id} is non-contiguous") + if group_id != parent_id: + raise ValueError( + f"row {row_index}: completion group {group_id} appears before " + f"its prefix parent {parent_id}" + ) + seen_groups.add(group_id) + family_index = first_family_index + len(families) + prefix = _trusted_pydantic_construct( + GdnSegmentSpec, + _GDN_SEGMENT_SPEC_FIELDS, + row_index=row_index, + family_index=family_index, + group_id=group_id, + parent_id=parent_id, + start=start, + end=end, + kind="prefix", + child_index=None, + ) + segment_cursor += 1 + completions: list[GdnSegmentSpec] = [] + while segment_cursor < len(starts): + child_group_id = segment_group_ids[segment_cursor] + child_parent_id = segment_parent_ids[segment_cursor] + child_start = starts[segment_cursor] + child_end = ends[segment_cursor] + if child_group_id == child_parent_id: + break + if child_parent_id != group_id: + raise ValueError( + f"row {row_index}: completion group {child_group_id} has " + f"parent {child_parent_id}, expected active prefix {group_id}" + ) + if child_group_id in seen_groups: + raise ValueError( + f"row {row_index}: group_id {child_group_id} is non-contiguous" + ) + seen_groups.add(child_group_id) + completions.append( + _trusted_pydantic_construct( + GdnSegmentSpec, + _GDN_SEGMENT_SPEC_FIELDS, + row_index=row_index, + family_index=family_index, + group_id=child_group_id, + parent_id=child_parent_id, + start=child_start, + end=child_end, + kind="completion", + child_index=len(completions), + ) + ) + segment_cursor += 1 + if len(completions) < min_completions_per_family: + raise ValueError( + f"row {row_index}: prefix group {group_id} has {len(completions)} " + f"completion(s), expected at least {min_completions_per_family}" + ) + families.append( + 
_trusted_pydantic_construct( + GdnPackedFamilySpec, + _GDN_PACKED_FAMILY_SPEC_FIELDS, + row_index=row_index, + family_index=family_index, + prefix=prefix, + completions=tuple(completions), + ) + ) + return families + + +def _parse_row( + *, + row_index: int, + group_ids: list[int], + parent_ids: list[int], + valid_length: int, + first_family_index: int, + min_completions_per_family: int, +) -> list[GdnPackedFamilySpec]: + families: list[GdnPackedFamilySpec] = [] + seen_groups: set[int] = set() + cursor = 0 + while cursor < valid_length: + group_id, parent_id, start, end = _read_segment( + row_index, group_ids, parent_ids, valid_length, cursor + ) + if group_id in seen_groups: + raise ValueError(f"row {row_index}: group_id {group_id} is non-contiguous") + if group_id != parent_id: + raise ValueError( + f"row {row_index}: completion group {group_id} appears before " + f"its prefix parent {parent_id}" + ) + seen_groups.add(group_id) + family_index = first_family_index + len(families) + prefix = GdnSegmentSpec( + row_index=row_index, + family_index=family_index, + group_id=group_id, + parent_id=parent_id, + start=start, + end=end, + kind="prefix", + ) + cursor = end + completions: list[GdnSegmentSpec] = [] + while cursor < valid_length: + child_group_id, child_parent_id, child_start, child_end = _read_segment( + row_index, group_ids, parent_ids, valid_length, cursor + ) + if child_group_id == child_parent_id: + break + if child_parent_id != group_id: + raise ValueError( + f"row {row_index}: completion group {child_group_id} has " + f"parent {child_parent_id}, expected active prefix {group_id}" + ) + if child_group_id in seen_groups: + raise ValueError( + f"row {row_index}: group_id {child_group_id} is non-contiguous" + ) + seen_groups.add(child_group_id) + completions.append( + GdnSegmentSpec( + row_index=row_index, + family_index=family_index, + group_id=child_group_id, + parent_id=child_parent_id, + start=child_start, + end=child_end, + kind="completion", + child_index=len(completions), + ) + ) + cursor = child_end + if len(completions) < min_completions_per_family: + raise ValueError( + f"row {row_index}: prefix group {group_id} has {len(completions)} " + f"completion(s), expected at least {min_completions_per_family}" + ) + families.append( + GdnPackedFamilySpec( + row_index=row_index, + family_index=family_index, + prefix=prefix, + completions=tuple(completions), + ) + ) + return families + + +def _read_segment( + row_index: int, + group_ids: list[int], + parent_ids: list[int], + valid_length: int, + cursor: int, +) -> tuple[int, int, int, int]: + group_id = int(group_ids[cursor]) + parent_id = int(parent_ids[cursor]) + if group_id < 0 or parent_id < 0: + raise ValueError(f"row {row_index}: segment ids must be non-negative") + start = cursor + cursor += 1 + while cursor < valid_length and int(group_ids[cursor]) == group_id: + current_parent = int(parent_ids[cursor]) + if current_parent != parent_id: + raise ValueError( + f"row {row_index}: group {group_id} changes parent from " + f"{parent_id} to {current_parent}" + ) + cursor += 1 + return group_id, parent_id, start, cursor diff --git a/src/art/megatron/gdn/layout.py b/src/art/megatron/gdn/layout.py new file mode 100644 index 000000000..0af2961c5 --- /dev/null +++ b/src/art/megatron/gdn/layout.py @@ -0,0 +1,882 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, model_validator +import torch +from torch import Tensor +from 
torch.distributed import ( + all_to_all_single, + get_world_size, +) +from torch.distributed import ( + is_available as dist_is_available, +) +from torch.distributed import ( + is_initialized as dist_is_initialized, +) + +from art.megatron.context_parallel.layout_index import TokenLayoutIndex + + +class GdnCpPeerTransfer(BaseModel): + """Token rows sent from one source rank to one destination rank.""" + + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) + + source_rank: int = Field(ge=0) + dest_rank: int = Field(ge=0) + token_count: int = Field(ge=0) + source_positions_tensor: Tensor | None = None + dest_positions_tensor: Tensor | None = None + + @model_validator(mode="after") + def _same_lengths(self) -> "GdnCpPeerTransfer": + lengths = {int(self.token_count)} + if self.source_positions_tensor is not None: + lengths.add(int(self.source_positions_tensor.numel())) + if self.dest_positions_tensor is not None: + lengths.add(int(self.dest_positions_tensor.numel())) + if len(lengths) != 1: + raise ValueError("token, source, and destination position counts differ") + return self + + +class GdnCpExchangePlan(BaseModel): + """Permutation/all-to-all metadata between two distributed token layouts.""" + + model_config = ConfigDict(frozen=True) + + cp_size: int = Field(ge=1) + source_token_counts_by_rank: tuple[int, ...] + dest_token_counts_by_rank: tuple[int, ...] + transfers: tuple[GdnCpPeerTransfer, ...] + cross_rank_token_count_override: int | None = Field(default=None, ge=0) + + @model_validator(mode="after") + def _rank_counts(self) -> "GdnCpExchangePlan": + if len(self.source_token_counts_by_rank) != self.cp_size: + raise ValueError("source token count length must equal cp_size") + if len(self.dest_token_counts_by_rank) != self.cp_size: + raise ValueError("destination token count length must equal cp_size") + return self + + @property + def cross_rank_token_count(self) -> int: + if self.cross_rank_token_count_override is not None: + return int(self.cross_rank_token_count_override) + return sum( + _transfer_token_count(transfer) + for transfer in self.transfers + if transfer.source_rank != transfer.dest_rank + ) + + +def _layout_cp_size(layout: TokenLayoutIndex) -> int: + return len(layout.token_counts_by_rank) + + +def _token_layout_from_rank_ranges( + ranges_by_rank: Sequence[Sequence[tuple[int, int, int]]], +) -> TokenLayoutIndex: + ranges = _normalize_rank_ranges( + "ranges_by_rank", + ranges_by_rank, + cp_size=len(ranges_by_rank), + ) + return TokenLayoutIndex( + ownership_ranges_by_rank=ranges, + token_counts_by_rank=tuple( + _rank_range_count(rank_ranges) for rank_ranges in ranges + ), + ) + + +def _normalize_rank_ranges( + name: str, + values: Sequence[Sequence[tuple[int, int, int]]], + *, + cp_size: int, +) -> tuple[tuple[tuple[int, int, int], ...], ...]: + if len(values) != cp_size: + raise ValueError(f"{name} must have {cp_size} ranks, got {len(values)}") + normalized = [] + for rank, rank_ranges in enumerate(values): + cursor = 0 + normalized_rank = [] + for start, end, position in rank_ranges: + start = int(start) + end = int(end) + position = int(position) + if start < 0 or end < start: + raise ValueError(f"{name}[{rank}] has invalid range {(start, end)}") + if position != cursor: + raise ValueError( + f"{name}[{rank}] positions must be contiguous; " + f"expected {cursor}, got {position}" + ) + normalized_rank.append((start, end, position)) + cursor += end - start + normalized.append(tuple(normalized_rank)) + return tuple(normalized) + + +def 
_rank_range_count(ranges: Sequence[tuple[int, int, int]]) -> int: + return sum(int(end) - int(start) for start, end, _ in ranges) + + +def _intersection_position_tensors( + source_ranges: Sequence[tuple[int, int, int]], + dest_ranges: Sequence[tuple[int, int, int]], +) -> tuple[Tensor, Tensor]: + source_sorted = sorted(source_ranges, key=lambda item: (item[0], item[1])) + dest_sorted = sorted(dest_ranges, key=lambda item: (item[0], item[1])) + source_starts: list[int] = [] + dest_starts: list[int] = [] + lengths: list[int] = [] + source_index = 0 + dest_index = 0 + while source_index < len(source_sorted) and dest_index < len(dest_sorted): + source_start, source_end, source_pos = source_sorted[source_index] + dest_start, dest_end, dest_pos = dest_sorted[dest_index] + overlap_start = max(source_start, dest_start) + overlap_end = min(source_end, dest_end) + if overlap_start < overlap_end: + source_starts.append(source_pos + overlap_start - source_start) + dest_starts.append(dest_pos + overlap_start - dest_start) + lengths.append(overlap_end - overlap_start) + if source_end <= dest_end: + source_index += 1 + else: + dest_index += 1 + if not lengths: + empty = torch.empty((0,), dtype=torch.long) + return empty, empty + lengths_tensor = torch.tensor(lengths, dtype=torch.long) + total = int(lengths_tensor.sum().item()) + range_offsets = torch.cumsum(lengths_tensor, dim=0) - lengths_tensor + item_offsets = torch.arange(total, dtype=torch.long) - torch.repeat_interleave( + range_offsets, + lengths_tensor, + ) + return ( + torch.repeat_interleave( + torch.tensor(source_starts, dtype=torch.long), + lengths_tensor, + ) + + item_offsets, + torch.repeat_interleave( + torch.tensor(dest_starts, dtype=torch.long), + lengths_tensor, + ) + + item_offsets, + ) + + +def _merged_token_ranges( + ranges_by_rank: Sequence[Sequence[tuple[int, int, int]]], +) -> tuple[tuple[int, int], ...]: + ranges = sorted( + (int(start), int(end)) + for rank_ranges in ranges_by_rank + for start, end, _ in rank_ranges + if int(start) < int(end) + ) + if not ranges: + return () + merged = [ranges[0]] + for start, end in ranges[1:]: + prev_start, prev_end = merged[-1] + if start <= prev_end: + merged[-1] = (prev_start, max(prev_end, end)) + else: + merged.append((start, end)) + return tuple(merged) + + +def _range_list_count(ranges: Sequence[tuple[int, int]]) -> int: + return sum(int(end) - int(start) for start, end in ranges) + + +def build_cp_exchange_plan_from_layout_index( + *, + source_layout: TokenLayoutIndex, + dest_layout: TokenLayoutIndex, + device: torch.device | str | None, + validate: bool = True, + local_rank: int | None = None, +) -> GdnCpExchangePlan: + cp_size = _layout_cp_size(source_layout) + if _layout_cp_size(dest_layout) != cp_size: + raise ValueError( + "source and destination cp_size differ: " + f"{cp_size} and {_layout_cp_size(dest_layout)}" + ) + if local_rank is not None and (local_rank < 0 or local_rank >= cp_size): + raise ValueError(f"local_rank must be in [0, {cp_size}), got {local_rank}") + if validate: + _validate_layout_token_sets_match(source_layout, dest_layout) + source_counts = source_layout.token_counts_by_rank + dest_counts = dest_layout.token_counts_by_rank + transfers: list[GdnCpPeerTransfer] = [] + cross_rank_token_count = 0 + for source_rank, source_ranges in enumerate(source_layout.ownership_ranges_by_rank): + for dest_rank, dest_ranges in enumerate(dest_layout.ownership_ranges_by_rank): + source_positions, dest_positions = _intersection_position_tensors( + source_ranges, + dest_ranges, + 
) + token_count = int(source_positions.numel()) + if token_count == 0: + continue + if source_rank != dest_rank: + cross_rank_token_count += token_count + if ( + local_rank is not None + and source_rank != local_rank + and dest_rank != local_rank + ): + continue + transfers.append( + _make_peer_transfer( + source_rank=source_rank, + dest_rank=dest_rank, + source_positions=source_positions, + dest_positions=dest_positions, + source_count=source_counts[source_rank], + dest_count=dest_counts[dest_rank], + device=device, + ) + ) + return GdnCpExchangePlan.model_construct( + cp_size=cp_size, + source_token_counts_by_rank=source_counts, + dest_token_counts_by_rank=dest_counts, + transfers=tuple( + sorted(transfers, key=lambda item: (item.source_rank, item.dest_rank)) + ), + cross_rank_token_count_override=cross_rank_token_count, + ) + + +def build_local_rank_cp_exchange_plan_from_dest_ranges( + *, + source_layout: TokenLayoutIndex, + dest_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...], + device: torch.device | str | None, + local_rank: int, + cross_rank_token_count: int, +) -> GdnCpExchangePlan: + cp_size = _layout_cp_size(source_layout) + if len(dest_ranges_by_rank) != cp_size: + raise ValueError("destination range rank count must equal cp_size") + if local_rank < 0 or local_rank >= cp_size: + raise ValueError(f"local_rank must be in [0, {cp_size}), got {local_rank}") + dest_ranges_by_rank = _normalize_rank_ranges( + "dest_ranges_by_rank", + dest_ranges_by_rank, + cp_size=cp_size, + ) + dest_counts = tuple( + sum(int(end) - int(start) for start, end, _ in ranges) + for ranges in dest_ranges_by_rank + ) + transfers = [] + for dest_rank, ranges in enumerate(dest_ranges_by_rank): + source_ranks = range(cp_size) if dest_rank == local_rank else (local_rank,) + for source_rank in source_ranks: + source_positions, dest_positions = _intersection_position_tensors( + source_layout.ownership_ranges_by_rank[source_rank], + ranges, + ) + if not int(source_positions.numel()): + continue + transfers.append( + _make_peer_transfer( + source_rank=source_rank, + dest_rank=dest_rank, + source_positions=source_positions, + dest_positions=dest_positions, + source_count=source_layout.token_counts_by_rank[source_rank], + dest_count=dest_counts[dest_rank], + device=device, + ) + ) + return GdnCpExchangePlan.model_construct( + cp_size=cp_size, + source_token_counts_by_rank=source_layout.token_counts_by_rank, + dest_token_counts_by_rank=dest_counts, + transfers=tuple( + sorted(transfers, key=lambda item: (item.source_rank, item.dest_rank)) + ), + cross_rank_token_count_override=int(cross_rank_token_count), + ) + + +def _validate_layout_token_sets_match( + source_layout: TokenLayoutIndex, + dest_layout: TokenLayoutIndex, +) -> None: + source_ranges = _merged_token_ranges(source_layout.ownership_ranges_by_rank) + dest_ranges = _merged_token_ranges(dest_layout.ownership_ranges_by_rank) + if ( + source_ranges != dest_ranges + or sum(source_layout.token_counts_by_rank) != _range_list_count(source_ranges) + or sum(dest_layout.token_counts_by_rank) != _range_list_count(dest_ranges) + ): + raise ValueError( + "source and destination token layouts must cover the same tokens" + ) + + +def _make_peer_transfer( + *, + source_rank: int, + dest_rank: int, + source_positions: Tensor, + dest_positions: Tensor, + source_count: int, + dest_count: int, + device: torch.device | str | None, +) -> GdnCpPeerTransfer: + token_count = int(source_positions.numel()) + if token_count != int(dest_positions.numel()): + raise 
ValueError("source and destination position counts differ") + if _is_full_identity_transfer( + source_rank=source_rank, + dest_rank=dest_rank, + source_positions=source_positions, + dest_positions=dest_positions, + source_count=source_count, + dest_count=dest_count, + ): + source_tensor = None + dest_tensor = None + else: + target = torch.device(device) if device is not None else torch.device("cpu") + source_tensor = source_positions.to( + device=target, dtype=torch.long + ).contiguous() + dest_tensor = dest_positions.to(device=target, dtype=torch.long).contiguous() + return GdnCpPeerTransfer.model_construct( + source_rank=source_rank, + dest_rank=dest_rank, + token_count=token_count, + source_positions_tensor=source_tensor, + dest_positions_tensor=dest_tensor, + ) + + +def _is_full_identity_transfer( + *, + source_rank: int, + dest_rank: int, + source_positions: Tensor, + dest_positions: Tensor, + source_count: int, + dest_count: int, +) -> bool: + if source_rank != dest_rank or source_count != dest_count: + return False + if int(source_positions.numel()) != int(source_count): + return False + if int(dest_positions.numel()) != int(dest_count): + return False + expected = torch.arange(int(source_count), dtype=torch.long) + return bool(torch.equal(source_positions.cpu(), expected)) and bool( + torch.equal(dest_positions.cpu(), expected) + ) + + +def _reverse_exchange_plan(plan: GdnCpExchangePlan) -> GdnCpExchangePlan: + return GdnCpExchangePlan.model_construct( + cp_size=plan.cp_size, + source_token_counts_by_rank=_dest_counts_by_rank(plan), + dest_token_counts_by_rank=_source_counts_by_rank(plan), + cross_rank_token_count_override=plan.cross_rank_token_count_override, + transfers=tuple( + GdnCpPeerTransfer.model_construct( + source_rank=transfer.dest_rank, + dest_rank=transfer.source_rank, + token_count=_transfer_token_count(transfer), + source_positions_tensor=transfer.dest_positions_tensor, + dest_positions_tensor=transfer.source_positions_tensor, + ) + for transfer in sorted( + plan.transfers, key=lambda item: (item.dest_rank, item.source_rank) + ) + ), + ) + + +def move_cp_exchange_plan_to_device( + plan: GdnCpExchangePlan | None, + device: torch.device | str, +) -> GdnCpExchangePlan | None: + if plan is None: + return None + target = torch.device(device) + return GdnCpExchangePlan.model_construct( + cp_size=plan.cp_size, + source_token_counts_by_rank=_source_counts_by_rank(plan), + dest_token_counts_by_rank=_dest_counts_by_rank(plan), + transfers=tuple( + GdnCpPeerTransfer.model_construct( + source_rank=transfer.source_rank, + dest_rank=transfer.dest_rank, + token_count=transfer.token_count, + source_positions_tensor=_move_index_tensor_if_present( + transfer.source_positions_tensor, target + ), + dest_positions_tensor=_move_index_tensor_if_present( + transfer.dest_positions_tensor, target + ), + ) + for transfer in plan.transfers + ), + cross_rank_token_count_override=plan.cross_rank_token_count_override, + ) + + +def _move_index_tensor_if_present( + tensor: Tensor | None, device: torch.device +) -> Tensor | None: + if tensor is None or tensor.device == device: + return tensor + return tensor.to(device=device) + + +def send_split_sizes_for_rank(plan: GdnCpExchangePlan, rank: int) -> tuple[int, ...]: + _check_rank(plan, rank) + return tuple( + _transfer_token_count(_transfer(plan, source_rank=rank, dest_rank=dest_rank)) + for dest_rank in range(plan.cp_size) + ) + + +def recv_split_sizes_for_rank(plan: GdnCpExchangePlan, rank: int) -> tuple[int, ...]: + _check_rank(plan, rank) + return 
tuple( + _transfer_token_count(_transfer(plan, source_rank=source_rank, dest_rank=rank)) + for source_rank in range(plan.cp_size) + ) + + +def pack_rank_send_tensor( + local_tensor: Tensor, + plan: GdnCpExchangePlan, + *, + source_rank: int, +) -> Tensor: + """Pack one rank's local tensor in peer order for `all_to_all_single`.""" + + _check_rank(plan, source_rank) + expected_rows = _source_count_for_rank(plan, source_rank) + if int(local_tensor.shape[0]) != expected_rows: + raise ValueError( + f"rank {source_rank} tensor has {int(local_tensor.shape[0])} rows, " + f"expected {expected_rows}" + ) + pieces = [] + for dest_rank in range(plan.cp_size): + transfer = _transfer(plan, source_rank=source_rank, dest_rank=dest_rank) + if _transfer_token_count(transfer): + if _is_implicit_full_identity_transfer( + transfer, + source_count=_source_count_for_rank(plan, source_rank), + dest_count=_dest_count_for_rank(plan, dest_rank), + ): + pieces.append(local_tensor) + else: + index = _transfer_index_tensor( + transfer.source_positions_tensor, + device=local_tensor.device, + ) + pieces.append(local_tensor.index_select(0, index)) + if not pieces: + return local_tensor.new_empty((0, *local_tensor.shape[1:])) + return torch.cat(pieces, dim=0) + + +def unpack_rank_recv_tensor( + recv_buffer: Tensor, + plan: GdnCpExchangePlan, + *, + dest_rank: int, +) -> Tensor: + """Unpack one rank's `all_to_all_single` receive buffer into destination order.""" + + _check_rank(plan, dest_rank) + expected_rows = sum(recv_split_sizes_for_rank(plan, dest_rank)) + if int(recv_buffer.shape[0]) != expected_rows: + raise ValueError( + f"rank {dest_rank} recv buffer has {int(recv_buffer.shape[0])} rows, " + f"expected {expected_rows}" + ) + dest_rows = _dest_count_for_rank(plan, dest_rank) + output = recv_buffer.new_empty((dest_rows, *recv_buffer.shape[1:])) + offset = 0 + for source_rank in range(plan.cp_size): + transfer = _transfer(plan, source_rank=source_rank, dest_rank=dest_rank) + rows = _transfer_token_count(transfer) + peer_rows = recv_buffer[offset : offset + rows] + offset += rows + if rows == 0: + continue + if _is_implicit_full_identity_transfer( + transfer, + source_count=_source_count_for_rank(plan, source_rank), + dest_count=dest_rows, + ): + output.copy_(peer_rows) + continue + dest_index = _transfer_index_tensor( + transfer.dest_positions_tensor, + device=recv_buffer.device, + ) + output.index_copy_(0, dest_index, peer_rows) + if dest_rows == 0: + return recv_buffer.new_empty((0, *recv_buffer.shape[1:])) + return output + + +@torch.compiler.disable +def exchange_rank_tensor_all_to_all( + local_tensor: Tensor, + plan: GdnCpExchangePlan, + *, + rank: int, + group: Any | None = None, + backward_plan: GdnCpExchangePlan | None = None, +) -> Tensor: + """Redistribute one rank tensor with real `dist.all_to_all_single`. + + This is the eager distributed/autograd boundary for attention-layout to + GDN-layout token exchange. Backward applies the inverse exchange plan. 
+ """ + + _check_rank(plan, rank) + if plan.cross_rank_token_count == 0: + return _exchange_rank_tensor_local(local_tensor, plan, rank=rank) + if not dist_is_available() or not dist_is_initialized(): + raise RuntimeError("torch.distributed must be initialized for GDN CP exchange") + world_size = get_world_size(group) + if world_size != plan.cp_size: + raise ValueError( + f"process group world size {world_size} must match plan cp_size " + f"{plan.cp_size}" + ) + if backward_plan is None: + raise ValueError("cross-rank GDN CP exchange requires a prebuilt backward_plan") + return _GdnCpExchangeFunction.apply(local_tensor, plan, backward_plan, rank, group) + + +def _transfer_token_count(transfer: GdnCpPeerTransfer) -> int: + return int(transfer.token_count) + + +def _is_implicit_full_identity_transfer( + transfer: GdnCpPeerTransfer, + *, + source_count: int, + dest_count: int, +) -> bool: + return ( + transfer.source_rank == transfer.dest_rank + and _transfer_token_count(transfer) == int(source_count) == int(dest_count) + and transfer.source_positions_tensor is None + and transfer.dest_positions_tensor is None + ) + + +def _transfer_positions_tuple(tensor: Tensor | None) -> tuple[int, ...]: + if tensor is None: + return () + return tuple(int(value) for value in tensor.detach().cpu().tolist()) + + +def _transfer_index_tensor( + tensor: Tensor | None, + *, + device: torch.device, +) -> Tensor: + if tensor is None: + raise ValueError("non-identity GDN CP transfer requires an index tensor") + if tensor.device == device: + return tensor + return tensor.to(device=device, non_blocking=True) + + +def _source_counts_by_rank(plan: GdnCpExchangePlan) -> tuple[int, ...]: + return plan.source_token_counts_by_rank + + +def _dest_counts_by_rank(plan: GdnCpExchangePlan) -> tuple[int, ...]: + return plan.dest_token_counts_by_rank + + +def _source_count_for_rank(plan: GdnCpExchangePlan, rank: int) -> int: + return _source_counts_by_rank(plan)[rank] + + +def _dest_count_for_rank(plan: GdnCpExchangePlan, rank: int) -> int: + return _dest_counts_by_rank(plan)[rank] + + +def _check_rank(plan: GdnCpExchangePlan, rank: int) -> None: + if rank < 0 or rank >= plan.cp_size: + raise ValueError(f"rank must be in [0, {plan.cp_size}), got {rank}") + + +def _transfer( + plan: GdnCpExchangePlan, + *, + source_rank: int, + dest_rank: int, +) -> GdnCpPeerTransfer: + for transfer in plan.transfers: + if transfer.source_rank == source_rank and transfer.dest_rank == dest_rank: + return transfer + return GdnCpPeerTransfer( + source_rank=source_rank, + dest_rank=dest_rank, + token_count=0, + ) + + +class _GdnCpExchangeFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + local_tensor: Tensor, + plan: GdnCpExchangePlan, + backward_plan: GdnCpExchangePlan, + rank: int, + group: Any | None, + ) -> Tensor: + ctx.rank = rank + ctx.group = group + ctx.reverse_plan = backward_plan + return _exchange_rank_tensor_all_to_all_forward( + local_tensor, + plan, + rank=rank, + group=group, + ) + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Any: + (grad_output,) = grad_outputs + grad_input = _exchange_rank_tensor_all_to_all_forward( + grad_output.contiguous(), + ctx.reverse_plan, + rank=ctx.rank, + group=ctx.group, + ) + return grad_input, None, None, None, None + + +def _exchange_rank_tensor_all_to_all_forward( + local_tensor: Tensor, + plan: GdnCpExchangePlan, + *, + rank: int, + group: Any | None, +) -> Tensor: + if plan.cross_rank_token_count == 0: + return 
_exchange_rank_tensor_local(local_tensor, plan, rank=rank) + accumulate = _rank_recv_requires_accumulation(plan, rank) + output = _init_rank_exchange_output( + local_tensor, plan, rank=rank, accumulate=accumulate + ) + send_buffer = _pack_rank_cross_send_tensor(local_tensor, plan, source_rank=rank) + send_buffer = send_buffer.contiguous() + recv_rows = sum(_cross_recv_split_sizes_for_rank(plan, rank)) + recv_buffer = local_tensor.new_empty((recv_rows, *local_tensor.shape[1:])) + all_to_all_single( + recv_buffer, + send_buffer, + output_split_sizes=list(_cross_recv_split_sizes_for_rank(plan, rank)), + input_split_sizes=list(_cross_send_split_sizes_for_rank(plan, rank)), + group=group, + ) + _unpack_rank_cross_recv_tensor_into( + output, recv_buffer, plan, dest_rank=rank, accumulate=accumulate + ) + return output + + +def _exchange_rank_tensor_local( + local_tensor: Tensor, + plan: GdnCpExchangePlan, + *, + rank: int, +) -> Tensor: + transfer = _transfer(plan, source_rank=rank, dest_rank=rank) + if _is_implicit_full_identity_transfer( + transfer, + source_count=_source_count_for_rank(plan, rank), + dest_count=_dest_count_for_rank(plan, rank), + ): + return local_tensor + return unpack_rank_recv_tensor( + pack_rank_send_tensor(local_tensor, plan, source_rank=rank), + plan, + dest_rank=rank, + ) + + +def _init_rank_exchange_output( + local_tensor: Tensor, + plan: GdnCpExchangePlan, + *, + rank: int, + accumulate: bool, +) -> Tensor: + dest_rows = _dest_count_for_rank(plan, rank) + output_shape = (dest_rows, *local_tensor.shape[1:]) + output = ( + local_tensor.new_zeros(output_shape) + if accumulate + else local_tensor.new_empty(output_shape) + ) + transfer = _transfer(plan, source_rank=rank, dest_rank=rank) + if not _transfer_token_count(transfer): + return output + if _is_implicit_full_identity_transfer( + transfer, + source_count=_source_count_for_rank(plan, rank), + dest_count=dest_rows, + ): + if accumulate: + output.add_(local_tensor) + else: + output.copy_(local_tensor) + return output + source_index = _transfer_index_tensor( + transfer.source_positions_tensor, + device=local_tensor.device, + ) + dest_index = _transfer_index_tensor( + transfer.dest_positions_tensor, + device=local_tensor.device, + ) + values = local_tensor.index_select(0, source_index) + if accumulate: + output.index_add_(0, dest_index, values) + else: + output.index_copy_(0, dest_index, values) + return output + + +def _pack_rank_cross_send_tensor( + local_tensor: Tensor, + plan: GdnCpExchangePlan, + *, + source_rank: int, +) -> Tensor: + pieces = [] + for dest_rank in range(plan.cp_size): + if dest_rank == source_rank: + continue + transfer = _transfer(plan, source_rank=source_rank, dest_rank=dest_rank) + if _transfer_token_count(transfer): + index = _transfer_index_tensor( + transfer.source_positions_tensor, + device=local_tensor.device, + ) + pieces.append(local_tensor.index_select(0, index)) + if not pieces: + return local_tensor.new_empty((0, *local_tensor.shape[1:])) + return torch.cat(pieces, dim=0) + + +def _unpack_rank_cross_recv_tensor_into( + output: Tensor, + recv_buffer: Tensor, + plan: GdnCpExchangePlan, + *, + dest_rank: int, + accumulate: bool, +) -> None: + expected_rows = sum(_cross_recv_split_sizes_for_rank(plan, dest_rank)) + if int(recv_buffer.shape[0]) != expected_rows: + raise ValueError( + f"recv buffer for rank {dest_rank} has {int(recv_buffer.shape[0])} rows; " + f"expected {expected_rows}" + ) + offset = 0 + for source_rank in range(plan.cp_size): + if source_rank == dest_rank: + continue + 
transfer = _transfer(plan, source_rank=source_rank, dest_rank=dest_rank) + rows = _transfer_token_count(transfer) + peer_rows = recv_buffer[offset : offset + rows] + offset += rows + if rows == 0: + continue + dest_index = _transfer_index_tensor( + transfer.dest_positions_tensor, + device=recv_buffer.device, + ) + if accumulate: + output.index_add_(0, dest_index, peer_rows) + else: + output.index_copy_(0, dest_index, peer_rows) + + +def _rank_recv_requires_accumulation(plan: GdnCpExchangePlan, rank: int) -> bool: + positions: list[int] = [] + for source_rank in range(plan.cp_size): + transfer = _transfer(plan, source_rank=source_rank, dest_rank=rank) + if not _transfer_token_count(transfer): + continue + positions.extend(_transfer_dest_positions_for_duplicate_check(plan, transfer)) + return len(positions) != len(set(positions)) + + +def _transfer_dest_positions_for_duplicate_check( + plan: GdnCpExchangePlan, transfer: GdnCpPeerTransfer +) -> tuple[int, ...]: + token_count = _transfer_token_count(transfer) + if token_count == 0: + return () + if _is_implicit_full_identity_transfer( + transfer, + source_count=_source_count_for_rank(plan, transfer.source_rank), + dest_count=_dest_count_for_rank(plan, transfer.dest_rank), + ): + return tuple(range(token_count)) + positions = _transfer_positions_tuple(transfer.dest_positions_tensor) + if len(positions) != token_count: + raise ValueError("GDN CP transfer destination positions must match token_count") + return positions + + +def _cross_send_split_sizes_for_rank( + plan: GdnCpExchangePlan, + rank: int, +) -> tuple[int, ...]: + return tuple( + 0 + if dest_rank == rank + else _transfer_token_count( + _transfer(plan, source_rank=rank, dest_rank=dest_rank) + ) + for dest_rank in range(plan.cp_size) + ) + + +def _cross_recv_split_sizes_for_rank( + plan: GdnCpExchangePlan, + rank: int, +) -> tuple[int, ...]: + return tuple( + 0 + if source_rank == rank + else _transfer_token_count( + _transfer(plan, source_rank=source_rank, dest_rank=rank) + ) + for source_rank in range(plan.cp_size) + ) diff --git a/src/art/megatron/gdn/operator.py b/src/art/megatron/gdn/operator.py new file mode 100644 index 000000000..034065cdb --- /dev/null +++ b/src/art/megatron/gdn/operator.py @@ -0,0 +1,2553 @@ +from __future__ import annotations + +from contextlib import contextmanager +from contextvars import ContextVar +import importlib +from types import MethodType +from typing import Any, Callable, Iterator, Literal, Sequence, cast + +from causal_conv1d import causal_conv1d_fn +from fla.modules.l2norm import l2norm +from fla.ops.gated_delta_rule import chunk_gated_delta_rule +from megatron.core.ssm.gated_delta_net import GatedDeltaNet +from megatron.core.transformer.transformer_layer import TransformerLayer +from pydantic import BaseModel, ConfigDict +import torch +from torch import Tensor +import torch.nn.functional as F + +from .conv_gelu import gdn_varlen_causal_conv_gelu, packed_varlen_causal_conv +from .gdn_shared_prefix import ( + GdnPackedExecutionSpec, + GdnParentStateTransferPlan, + GdnRankExecutionPlan, + GdnSegmentBucketPlan, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, +) +from .segment_layout import ( + gather_bucket_streams_compact as _gather_bucket_streams_compact_fused, +) +from .segment_layout import ( + prepare_packed_recurrent_inputs as _prepare_packed_recurrent_inputs_fused, +) +from .segment_layout import ( + scatter_bucket_output_compact as _scatter_bucket_output_fused, +) + +_NVTX_ENABLED: ContextVar[bool] = 
ContextVar("art_gdn_nvtx_enabled", default=False) + + +class _BucketFlatLayout(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) + + padded_indices: Tensor + padded_mask: Tensor + real_indices: Tensor + output_indices: Tensor + output_selector: Tensor | None + + +def install_shared_prefix_gdn_hooks(model_chunks: Sequence[Any]) -> None: + """Patch Megatron GatedDeltaNet modules to honor ART shared-prefix packing.""" + + for chunk in model_chunks: + if not hasattr(chunk, "modules"): + continue + for module in chunk.modules(): + if not isinstance(module, GatedDeltaNet): + continue + if getattr(module, "_art_shared_prefix_gdn_hooked", False): + continue + original_forward = module.forward + module._art_physical_forward = original_forward + module.forward = MethodType(_shared_prefix_forward, module) + module._art_shared_prefix_gdn_hooked = True + + +def install_gdn_island_hooks(model_chunks: Sequence[Any]) -> None: + """Hoist CP layout conversion across consecutive Transformer GDN layers.""" + + for chunk in model_chunks: + if not hasattr(chunk, "modules"): + continue + _install_empty_safe_norm_hooks(chunk) + layers = [ + module + for module in chunk.modules() + if isinstance(module, TransformerLayer) + and hasattr(module, "self_attention") + ] + layer_is_gdn = [ + isinstance(layer.self_attention, GatedDeltaNet) for layer in layers + ] + for index, layer in enumerate(layers): + is_gdn = layer_is_gdn[index] + layer._art_gdn_island_is_gdn = is_gdn + layer._art_gdn_island_prev_is_gdn = index > 0 and layer_is_gdn[index - 1] + layer._art_gdn_island_next_is_gdn = ( + index + 1 < len(layers) and layer_is_gdn[index + 1] + ) + if getattr(layer, "_art_gdn_island_hooked", False): + continue + layer._art_gdn_island_physical_forward = layer.forward + layer.forward = MethodType(_gdn_island_layer_forward, layer) + layer._art_gdn_island_hooked = True + + +def _gdn_island_layer_forward(self: Any, *args: Any, **kwargs: Any) -> Any: + attention_bias = kwargs.get("attention_bias") + plan = getattr(attention_bias, "gdn_execution_plan", None) + original_forward = cast(Callable[..., Any], self._art_gdn_island_physical_forward) + if plan is None or int(getattr(plan, "cp_size", 1)) <= 1: + return original_forward(*args, **kwargs) + + hidden_states = _layer_forward_hidden_states(args, kwargs) + if hidden_states is None: + return original_forward(*args, **kwargs) + + is_gdn = bool(getattr(self, "_art_gdn_island_is_gdn", False)) + if not is_gdn: + if getattr(attention_bias, "gdn_hidden_layout", "attention") == "gdn": + _mark_attention_layout_active(attention_bias) + return original_forward(*args, **kwargs) + + prev_is_gdn = bool(getattr(self, "_art_gdn_island_prev_is_gdn", False)) + next_is_gdn = bool(getattr(self, "_art_gdn_island_next_is_gdn", False)) + if prev_is_gdn: + _mark_gdn_layout_active(attention_bias, hidden_states) + else: + hidden_states = _enter_gdn_island_layout( + hidden_states, attention_bias, force=True + ) + args, kwargs = _replace_layer_hidden_states(args, kwargs, hidden_states) + + output = ( + _empty_gdn_island_layer_forward(self, hidden_states, kwargs) + if int(hidden_states.shape[0]) == 0 + else original_forward(*args, **kwargs) + ) + if next_is_gdn: + _mark_gdn_layout_active(attention_bias, _layer_output_hidden_states(output)) + return output + + hidden_out = _leave_gdn_island_layout( + _layer_output_hidden_states(output), attention_bias + ) + return _replace_layer_output_hidden_states(output, hidden_out) + + +def _layer_forward_hidden_states( + args: tuple[Any, 
...], kwargs: dict[str, Any] +) -> Tensor | None: + hidden_states = kwargs.get("hidden_states") + if isinstance(hidden_states, Tensor): + return hidden_states + if args and isinstance(args[0], Tensor): + return args[0] + return None + + +def _replace_layer_hidden_states( + args: tuple[Any, ...], kwargs: dict[str, Any], hidden_states: Tensor +) -> tuple[tuple[Any, ...], dict[str, Any]]: + if "hidden_states" in kwargs: + kwargs = dict(kwargs) + kwargs["hidden_states"] = hidden_states + return args, kwargs + if args: + return (hidden_states, *args[1:]), kwargs + kwargs = dict(kwargs) + kwargs["hidden_states"] = hidden_states + return args, kwargs + + +def _layer_output_hidden_states(output: Any) -> Tensor: + if isinstance(output, tuple): + return cast(Tensor, output[0]) + return cast(Tensor, output) + + +def _replace_layer_output_hidden_states(output: Any, hidden_states: Tensor) -> Any: + if isinstance(output, tuple): + return (hidden_states, *output[1:]) + return hidden_states + + +def _install_empty_safe_norm_hooks(root: Any) -> None: + if not isinstance(root, torch.nn.Module): + return + for module in root.modules(): + if getattr(module, "_art_empty_safe_norm_hooked", False): + continue + if not _is_empty_safe_norm_target(module): + continue + module._art_empty_safe_norm_physical_forward = module.forward + module.forward = MethodType(_empty_safe_norm_forward, module) + module._art_empty_safe_norm_hooked = True + + +def _is_empty_safe_norm_target(module: Any) -> bool: + if not isinstance(getattr(module, "weight", None), Tensor): + return False + module_name = type(module).__name__ + module_path = type(module).__module__ + return module_name in {"RMSNorm", "LayerNorm"} and module_path.startswith( + "transformer_engine." + ) + + +def _empty_safe_norm_forward( + self: Any, input_: Tensor, *args: Any, **kwargs: Any +) -> Any: + if isinstance(input_, Tensor) and int(input_.numel()) == 0: + return _apply_explicit_norm( + self, + input_, + config=None, + weight_name="weight", + bias_name="bias", + ) + original_forward = cast( + Callable[..., Any], self._art_empty_safe_norm_physical_forward + ) + return original_forward(input_, *args, **kwargs) + + +def _empty_gdn_island_layer_forward( + layer: Any, hidden_states: Tensor, kwargs: dict[str, Any] +) -> tuple[Tensor, Tensor | None]: + with _nvtx_range("art_gdn_empty_island_layer", hidden_states): + attention_output = layer.self_attention( + hidden_states, + attention_mask=kwargs.get("attention_mask"), + inference_context=kwargs.get( + "inference_context", kwargs.get("inference_params") + ), + rotary_pos_emb=kwargs.get("rotary_pos_emb"), + rotary_pos_cos=kwargs.get("rotary_pos_cos"), + rotary_pos_sin=kwargs.get("rotary_pos_sin"), + rotary_pos_cos_sin=kwargs.get("rotary_pos_cos_sin"), + attention_bias=kwargs.get("attention_bias"), + packed_seq_params=kwargs.get("packed_seq_params"), + sequence_len_offset=kwargs.get("sequence_len_offset"), + ) + context = kwargs.get("context") + if isinstance(attention_output, dict) and "context" in attention_output: + context = attention_output["context"] + attention_hidden = ( + attention_output[0] if isinstance(attention_output, tuple) else attention_output + ) + return hidden_states + cast(Tensor, attention_hidden), context + + +def _shared_prefix_forward( + self: Any, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Tensor | None = None, + inference_context: Any | None = None, + attention_bias: Any | None = None, + packed_seq_params: Any | None = None, + sequence_len_offset: int | None = None, 
+ *, + inference_params: Any | None = None, + **kwargs: Any, +) -> tuple[Tensor, Tensor | None]: + group_ids = getattr(attention_bias, "group_ids", None) + parent_ids = getattr(attention_bias, "parent_ids", None) + execution_spec = getattr(attention_bias, "gdn_execution_spec", None) + execution_plan = getattr(attention_bias, "gdn_execution_plan", None) + if group_ids is None or parent_ids is None: + original_forward = cast( + Callable[..., tuple[Tensor, Tensor | None]], self._art_physical_forward + ) + return original_forward( + hidden_states, + attention_mask, + key_value_states=key_value_states, + inference_context=inference_context, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + inference_params=inference_params, + **kwargs, + ) + + del attention_mask, key_value_states, sequence_len_offset, kwargs + if inference_context is not None or inference_params is not None: + raise NotImplementedError("ART shared-prefix GDN does not support inference.") + if packed_seq_params is not None: + raise NotImplementedError( + "PackedSeqParams is not used in ART shared-prefix GDN." + ) + return gdn_shared_prefix_forward( + self, + hidden_states, + group_ids=cast(Tensor, group_ids), + parent_ids=cast(Tensor, parent_ids), + execution_spec=cast(GdnPackedExecutionSpec | None, execution_spec), + execution_plan=cast(GdnRankExecutionPlan | None, execution_plan), + input_layout=( + "gdn" + if getattr(attention_bias, "gdn_hidden_layout", "attention") == "gdn" + else "attention" + ), + output_layout=( + "gdn" + if getattr(attention_bias, "gdn_hidden_layout", "attention") == "gdn" + else "attention" + ), + require_prebuilt_plan=False, + ) + + +@torch.compiler.disable +def gdn_shared_prefix_forward( + gdn: Any, + hidden_states: Tensor, + *, + group_ids: Tensor, + parent_ids: Tensor, + execution_spec: GdnPackedExecutionSpec | None = None, + execution_plan: GdnRankExecutionPlan | None = None, + cp_group: Any | None = None, + require_prebuilt_plan: bool = False, + input_layout: Literal["attention", "gdn"] = "attention", + output_layout: Literal["attention", "gdn"] = "attention", +) -> tuple[Tensor, Tensor | None]: + """Run one GDN layer over ART shared-prefix packed rows.""" + + return run_gdn_layer( + gdn, + hidden_states, + group_ids=group_ids, + parent_ids=parent_ids, + execution_spec=execution_spec, + execution_plan=execution_plan, + cp_group=cp_group, + require_prebuilt_plan=require_prebuilt_plan, + input_layout=input_layout, + output_layout=output_layout, + ) + + +@torch.compiler.disable +def run_gdn_layer( + gdn: Any, + hidden_states: Tensor, + *, + group_ids: Tensor, + parent_ids: Tensor, + execution_spec: GdnPackedExecutionSpec | None = None, + execution_plan: GdnRankExecutionPlan | None = None, + cp_group: Any | None = None, + require_prebuilt_plan: bool = False, + input_layout: Literal["attention", "gdn"] = "attention", + output_layout: Literal["attention", "gdn"] = "attention", +) -> tuple[Tensor, Tensor | None]: + """Run one production shared-prefix GDN layer.""" + + _disable_reentrant_te_linear_transpose_cache(gdn) + if hidden_states.ndim != 3: + raise ValueError( + f"hidden_states must be [S, B, D], got {tuple(hidden_states.shape)}" + ) + seq_len, batch_size, _ = hidden_states.shape + requested_cp_size = ( + execution_plan.cp_size if execution_plan is not None else _default_cp_size() + ) + cp_rank = ( + execution_plan.cp_rank + if execution_plan is not None + else _default_cp_rank(requested_cp_size) + ) + full_shape_required = 
requested_cp_size == 1 + expected_group_seq_len = seq_len + if full_shape_required and _gdn_uses_sequence_parallel(gdn): + expected_group_seq_len *= int(getattr(gdn, "sp_size", 1)) + if full_shape_required and ( + int(group_ids.shape[0]) != batch_size + or int(group_ids.shape[1]) != expected_group_seq_len + ): + raise ValueError( + "shared-prefix GDN group_ids shape must match the logical sequence " + "processed by Megatron GDN after sequence-parallel input gather, got " + f"hidden={tuple(hidden_states.shape)} group_ids={tuple(group_ids.shape)} " + f"expected_group_shape={(batch_size, expected_group_seq_len)}" + ) + + if require_prebuilt_plan and execution_plan is None: + raise ValueError( + "ART shared-prefix GDN production path requires a prebuilt " + "GDN execution plan on SharedPrefixAttentionState. Build it once " + "per packed sequence via create_shared_prefix_state(..., " + "build_gdn_execution_spec=True)." + ) + + if execution_spec is None and execution_plan is None: + with _nvtx_range("art_gdn_parse_shared_prefix_layout", hidden_states): + execution_spec = parse_gdn_shared_prefix_segments( + group_ids, parent_ids, min_completions_per_family=0 + ) + if ( + execution_spec is not None + and requested_cp_size == 1 + and ( + execution_spec.batch_size != batch_size + or execution_spec.sequence_length != expected_group_seq_len + ) + ): + raise ValueError( + "GDN execution spec shape must match hidden_states, got " + f"spec={(execution_spec.batch_size, execution_spec.sequence_length)} " + f"expected={(batch_size, expected_group_seq_len)} " + f"hidden={(batch_size, seq_len)}" + ) + if execution_plan is None: + if execution_spec is None: + raise ValueError("GDN execution spec is required to build a missing plan") + with _nvtx_range("art_gdn_plan_shared_prefix_layout", hidden_states): + execution_plan = build_gdn_rank_execution_plan( + execution_spec, + device=hidden_states.device, + cp_rank=cp_rank, + cp_size=requested_cp_size, + ) + elif execution_plan.cp_size == 1 and ( + execution_plan.batch_size != batch_size + or execution_plan.sequence_length != seq_len + ): + raise ValueError( + "GDN execution plan shape must match hidden_states, got " + f"plan={(execution_plan.batch_size, execution_plan.sequence_length)} " + f"hidden={(batch_size, seq_len)}" + ) + if execution_plan.cp_size != 1: + return _run_cp_planned_prefixes_and_completions( + gdn, + hidden_states, + execution_plan, + group=cp_group or _default_cp_group(execution_plan.cp_size), + input_layout=input_layout, + output_layout=output_layout, + ) + if input_layout != "attention" or output_layout != "attention": + raise ValueError("GDN layout controls require a CP execution plan") + return _run_planned_prefixes_and_completions(gdn, hidden_states, execution_plan) + + +def _run_planned_prefixes_and_completions( + gdn: Any, + hidden_states: Tensor, + plan: GdnRankExecutionPlan, +) -> tuple[Tensor, Tensor | None]: + if _has_chunk_aligned_local_plan(plan): + return _run_chunk_aligned_prefixes_and_completions(gdn, hidden_states, plan) + raise ValueError( + "shared-prefix GDN requires a chunk-aligned execution plan; " + "prefix/completion bucket execution has been removed" + ) + + +def _has_chunk_aligned_local_plan(plan: GdnRankExecutionPlan) -> bool: + return bool( + plan.prefix_boundary_buckets + or plan.prefix_tail_buckets + or plan.completion_warmup_buckets + ) + + +def _run_chunk_aligned_prefixes_and_completions( + gdn: Any, + hidden_states: Tensor, + plan: GdnRankExecutionPlan, +) -> tuple[Tensor, Tensor | None]: + with 
_nvtx_range("art_gdn_in_proj", hidden_states): + qkv, gate, beta, recurrent_g = _project_gdn_inputs(gdn, hidden_states) + gate = gate.clone() + recurrent_output = torch.zeros_like(gate) + boundary_family_chunks: list[Tensor] = [] + boundary_conv_chunks: list[Tensor] = [] + boundary_rec_chunks: list[Tensor] = [] + + for bucket in plan.prefix_boundary_buckets: + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + prefix_qkv, prefix_beta, prefix_g = _gather_compact_bucket_streams( + qkv, beta, recurrent_g, bucket + ) + zero_conv = _zero_conv_state( + gdn, hidden_states, batch_size=bucket.segment_count + ) + zero_rec = _zero_recurrent_state( + gdn, hidden_states, batch_size=bucket.segment_count + ) + with _nvtx_range("art_gdn_prefix_boundary_segment", prefix_qkv): + prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( + bucket, + (prefix_qkv, prefix_beta, prefix_g), + (zero_conv, zero_rec), + gdn=gdn, + output_final_state=True, + ) + if prefix_conv is None or prefix_rec is None: + raise RuntimeError("prefix boundary GDN execution must return final states") + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, prefix_out + ) + boundary_family_chunks.append(bucket.family_indices) + boundary_conv_chunks.append(prefix_conv) + boundary_rec_chunks.append(prefix_rec) + + boundary_conv_table = _materialize_indexed_family_state_table( + plan=plan, + family_chunks=boundary_family_chunks, + state_chunks=boundary_conv_chunks, + zero_state=_zero_conv_state(gdn, hidden_states, batch_size=plan.family_count), + ) + boundary_rec_table = _materialize_indexed_family_state_table( + plan=plan, + family_chunks=boundary_family_chunks, + state_chunks=boundary_rec_chunks, + zero_state=_zero_recurrent_state( + gdn, hidden_states, batch_size=plan.family_count + ), + ) + + tail_family_chunks: list[Tensor] = [] + tail_conv_chunks: list[Tensor] = [] + tail_rec_chunks: list[Tensor] = [] + for bucket in plan.prefix_tail_buckets: + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + tail_qkv, tail_beta, tail_g = _gather_compact_bucket_streams( + qkv, beta, recurrent_g, bucket + ) + with _nvtx_range("art_gdn_state_fanout", tail_qkv): + tail_conv = boundary_conv_table.index_select(0, bucket.family_indices) + tail_rec = boundary_rec_table.index_select(0, bucket.family_indices) + with _nvtx_range("art_gdn_prefix_tail_segment", tail_qkv): + tail_out, tail_conv, tail_rec = run_gdn_bucket( + bucket, + (tail_qkv, tail_beta, tail_g), + (tail_conv, tail_rec), + gdn=gdn, + output_final_state=True, + ) + if tail_conv is None or tail_rec is None: + raise RuntimeError("prefix tail GDN execution must return final states") + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, tail_out + ) + tail_family_chunks.append(bucket.family_indices) + tail_conv_chunks.append(tail_conv) + tail_rec_chunks.append(tail_rec) + + prefix_conv_table = _replace_indexed_family_states( + boundary_conv_table, + family_chunks=tail_family_chunks, + state_chunks=tail_conv_chunks, + ) + prefix_rec_table = _replace_indexed_family_states( + boundary_rec_table, + family_chunks=tail_family_chunks, + state_chunks=tail_rec_chunks, + ) + + for bucket in plan.completion_warmup_buckets: + with _nvtx_range("art_gdn_state_fanout", hidden_states): + completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) + completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + completion_qkv, completion_beta, 
completion_g = ( + _gather_compact_bucket_streams(qkv, beta, recurrent_g, bucket) + ) + with _nvtx_range("art_gdn_completion_warmup_segment", completion_qkv): + completion_out, _, _ = run_gdn_bucket( + bucket, + (completion_qkv, completion_beta, completion_g), + (completion_conv, completion_rec), + gdn=gdn, + output_final_state=False, + ) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, completion_out + ) + + return _project_gdn_output(gdn, recurrent_output, gate, plan) + + +def _iter_prepared_bucket_columns( + bucket: GdnSegmentBucketPlan, + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + conv_initial: Tensor, + recurrent_initial: Tensor, +) -> Iterator[tuple[GdnSegmentBucketPlan, Tensor, Tensor, Tensor, Tensor, Tensor]]: + for column in range(int(bucket.lengths.numel())): + length = int(bucket.lengths[column].item()) + if length == 0: + continue + column_bucket = _slice_bucket_column(bucket, column=column, length=length) + yield ( + column_bucket, + qkv[column : column + 1, :, :length], + beta[column : column + 1, :length], + recurrent_g[column : column + 1, :length], + conv_initial[column : column + 1], + recurrent_initial[column : column + 1], + ) + + +def _slice_bucket_column( + bucket: GdnSegmentBucketPlan, *, column: int, length: int +) -> GdnSegmentBucketPlan: + lengths = bucket.lengths[column : column + 1] + cu_seqlens = torch.stack((lengths.new_zeros(()), lengths[0])) + output_mask = ( + None + if bucket.output_mask is None + else bucket.output_mask[:length, column : column + 1] + ) + return GdnSegmentBucketPlan.model_construct( + length=length, + lengths=lengths, + real_mask=bucket.real_mask[:length, column : column + 1], + cu_seqlens=cu_seqlens, + row_indices=bucket.row_indices[:length, column : column + 1], + position_indices=bucket.position_indices[:length, column : column + 1], + family_indices=bucket.family_indices[column : column + 1], + real_token_count_static=length, + output_mask=output_mask, + ) + + +def _run_cp_planned_prefixes_and_completions( + gdn: Any, + hidden_states: Tensor, + plan: GdnRankExecutionPlan, + *, + group: Any, + input_layout: Literal["attention", "gdn"], + output_layout: Literal["attention", "gdn"], +) -> tuple[Tensor, Tensor | None]: + if plan.attention_to_gdn is None or plan.gdn_to_attention is None: + raise ValueError("CP GDN execution requires prebuilt exchange plans") + if input_layout not in ("attention", "gdn") or output_layout not in ( + "attention", + "gdn", + ): + raise ValueError( + f"unsupported GDN CP layouts: {input_layout=} {output_layout=}" + ) + run_gdn_prepared_varlen_native_fla_cp = importlib.import_module( + "art.megatron.gdn.cp_runtime" + ).run_gdn_prepared_varlen_native_fla_cp + + if input_layout == "attention": + gdn_hidden, original_shape = gdn_cp_attention_to_gdn_layout( + hidden_states, plan, group + ) + else: + gdn_hidden = _validate_gdn_hidden_for_cp_plan(hidden_states, plan) + original_shape = _attention_original_shape_from_plan(hidden_states, plan) + with _nvtx_range("art_gdn_in_proj", gdn_hidden): + qkv, gate, beta, recurrent_g = _project_gdn_inputs(gdn, gdn_hidden) + gate = gate.clone() + recurrent_output = torch.zeros_like(gate) + prefix_family_chunks: list[Tensor] = [] + prefix_conv_chunks: list[Tensor] = [] + prefix_rec_chunks: list[Tensor] = [] + cp_dependency = _empty_autograd_dependency(qkv) + + for bucket in plan.chain_prefix_buckets: + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( + qkv, beta, 
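+ # gathers this chain-prefix bucket's tokens from the packed qkv/beta/recurrent-g streams into per-segment padded tensors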
recurrent_g, bucket + ) + zero_conv = _zero_conv_state(gdn, gdn_hidden, batch_size=prefix_qkv.shape[0]) + zero_rec = _zero_recurrent_state( + gdn, gdn_hidden, batch_size=prefix_qkv.shape[0] + ) + with _nvtx_range("art_gdn_cp_prefix_segment", prefix_qkv): + prefix_out, prefix_conv, prefix_rec = run_gdn_prepared_varlen_native_fla_cp( + gdn, + prefix_qkv, + beta=prefix_beta, + recurrent_g=prefix_g, + lengths=bucket.lengths, + cu_seqlens=bucket.cu_seqlens, + conv_initial=zero_conv, + recurrent_initial=zero_rec, + group=group, + output_final_state=True, + ) + if prefix_conv is None or prefix_rec is None: + raise RuntimeError("CP prefix GDN execution must return final states") + prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) + prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) + prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) + cp_dependency = _make_autograd_dependency(prefix_out, prefix_conv, prefix_rec) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, prefix_out + ) + prefix_family_chunks.append(bucket.family_indices) + prefix_conv_chunks.append(prefix_conv) + prefix_rec_chunks.append(prefix_rec) + + boundary_family_chunks: list[Tensor] = [] + boundary_conv_chunks: list[Tensor] = [] + boundary_rec_chunks: list[Tensor] = [] + for bucket in plan.prefix_boundary_buckets: + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( + qkv, beta, recurrent_g, bucket + ) + zero_conv = _zero_conv_state(gdn, gdn_hidden, batch_size=prefix_qkv.shape[0]) + zero_rec = _zero_recurrent_state( + gdn, gdn_hidden, batch_size=prefix_qkv.shape[0] + ) + with _nvtx_range("art_gdn_local_prefix_segment", prefix_qkv): + prefix_out, prefix_conv, prefix_rec = _run_gdn_prepared_varlen_batch( + gdn, + prefix_qkv, + beta=prefix_beta, + recurrent_g=prefix_g, + bucket=bucket, + conv_initial=zero_conv, + recurrent_initial=zero_rec, + output_final_state=True, + ) + if prefix_conv is None or prefix_rec is None: + raise RuntimeError("local prefix GDN execution must return final states") + prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) + prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) + prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, prefix_out + ) + boundary_family_chunks.append(bucket.family_indices) + boundary_conv_chunks.append(prefix_conv) + boundary_rec_chunks.append(prefix_rec) + prefix_family_chunks.append(bucket.family_indices) + prefix_conv_chunks.append(prefix_conv) + prefix_rec_chunks.append(prefix_rec) + + if plan.prefix_tail_buckets or plan.completion_warmup_buckets: + boundary_conv_table = _materialize_indexed_family_state_table( + plan=plan, + family_chunks=boundary_family_chunks, + state_chunks=boundary_conv_chunks, + zero_state=_zero_conv_state(gdn, gdn_hidden, batch_size=plan.family_count), + ) + boundary_rec_table = _materialize_indexed_family_state_table( + plan=plan, + family_chunks=boundary_family_chunks, + state_chunks=boundary_rec_chunks, + zero_state=_zero_recurrent_state( + gdn, gdn_hidden, batch_size=plan.family_count + ), + ) + tail_family_chunks: list[Tensor] = [] + tail_conv_chunks: list[Tensor] = [] + tail_rec_chunks: list[Tensor] = [] + for bucket in plan.prefix_tail_buckets: + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + tail_qkv, tail_beta, tail_g = _gather_bucket_streams( + qkv, beta, 
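+ # same gather for prefix-tail buckets; their initial conv/recurrent states are pulled from the boundary state tables below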
recurrent_g, bucket + ) + tail_conv = boundary_conv_table.index_select(0, bucket.family_indices) + tail_rec = boundary_rec_table.index_select(0, bucket.family_indices) + with _nvtx_range("art_gdn_local_prefix_segment", tail_qkv): + tail_out, tail_conv, tail_rec = _run_gdn_prepared_varlen_batch( + gdn, + tail_qkv, + beta=tail_beta, + recurrent_g=tail_g, + bucket=bucket, + conv_initial=tail_conv, + recurrent_initial=tail_rec, + output_final_state=True, + ) + if tail_conv is None or tail_rec is None: + raise RuntimeError("local prefix tail GDN execution must return states") + tail_out = _add_autograd_dependency(tail_out, cp_dependency) + tail_conv = _add_autograd_dependency(tail_conv, cp_dependency) + tail_rec = _add_autograd_dependency(tail_rec, cp_dependency) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, tail_out + ) + tail_family_chunks.append(bucket.family_indices) + tail_conv_chunks.append(tail_conv) + tail_rec_chunks.append(tail_rec) + prefix_family_chunks.append(bucket.family_indices) + prefix_conv_chunks.append(tail_conv) + prefix_rec_chunks.append(tail_rec) + prefix_conv_table = _replace_indexed_family_states( + boundary_conv_table, + family_chunks=tail_family_chunks, + state_chunks=tail_conv_chunks, + ) + prefix_rec_table = _replace_indexed_family_states( + boundary_rec_table, + family_chunks=tail_family_chunks, + state_chunks=tail_rec_chunks, + ) + for bucket in plan.completion_warmup_buckets: + completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) + completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) + completion_conv, completion_rec = _couple_parent_states( + completion_conv, completion_rec + ) + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + completion_qkv, completion_beta, completion_g = _gather_bucket_streams( + qkv, beta, recurrent_g, bucket + ) + for ( + column_bucket, + qkv_col, + beta_col, + g_col, + conv_col, + rec_col, + ) in _iter_prepared_bucket_columns( + bucket, + completion_qkv, + completion_beta, + completion_g, + completion_conv, + completion_rec, + ): + with _nvtx_range("art_gdn_local_completion_segment", qkv_col): + completion_out, _, _ = _run_gdn_prepared_varlen_batch( + gdn, + qkv_col, + beta=beta_col, + recurrent_g=g_col, + bucket=column_bucket, + conv_initial=conv_col, + recurrent_initial=rec_col, + output_final_state=False, + ) + completion_out = _add_autograd_dependency(completion_out, cp_dependency) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, column_bucket, completion_out + ) + + for bucket in plan.local_prefix_buckets: + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( + qkv, beta, recurrent_g, bucket + ) + zero_conv = _zero_conv_state(gdn, gdn_hidden, batch_size=prefix_qkv.shape[0]) + zero_rec = _zero_recurrent_state( + gdn, gdn_hidden, batch_size=prefix_qkv.shape[0] + ) + with _nvtx_range("art_gdn_local_prefix_segment", prefix_qkv): + prefix_out, prefix_conv, prefix_rec = _run_gdn_prepared_varlen_batch( + gdn, + prefix_qkv, + beta=prefix_beta, + recurrent_g=prefix_g, + bucket=bucket, + conv_initial=zero_conv, + recurrent_initial=zero_rec, + output_final_state=True, + ) + if prefix_conv is None or prefix_rec is None: + raise RuntimeError("local prefix GDN execution must return final states") + prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) + prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) + prefix_rec = 
_add_autograd_dependency(prefix_rec, cp_dependency) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, prefix_out + ) + prefix_family_chunks.append(bucket.family_indices) + prefix_conv_chunks.append(prefix_conv) + prefix_rec_chunks.append(prefix_rec) + + if not prefix_conv_chunks and not plan.parent_state_exchange_family_indices: + projected, out_bias = _project_gdn_output(gdn, recurrent_output, gate, plan) + if output_layout == "gdn": + return projected, out_bias + return _cp_output_to_attention(projected, plan, original_shape, group), out_bias + + prefix_conv_table = _materialize_ordered_family_state_table( + family_chunks=prefix_family_chunks, + state_chunks=prefix_conv_chunks, + zero_state=_zero_conv_state(gdn, gdn_hidden, batch_size=plan.family_count), + ) + prefix_rec_table = _materialize_ordered_family_state_table( + family_chunks=prefix_family_chunks, + state_chunks=prefix_rec_chunks, + zero_state=_zero_recurrent_state(gdn, gdn_hidden, batch_size=plan.family_count), + ) + for bucket in plan.chain_completion_buckets: + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + completion_qkv, completion_beta, completion_g = _gather_bucket_streams( + qkv, beta, recurrent_g, bucket + ) + completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) + completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) + completion_conv, completion_rec = _couple_parent_states( + completion_conv, completion_rec + ) + completion_conv = _scale_state_gradient(completion_conv, 1.0 / plan.cp_size) + completion_rec = _scale_state_gradient(completion_rec, 1.0 / plan.cp_size) + with _nvtx_range("art_gdn_cp_completion_segment", completion_qkv): + completion_out, _, _ = run_gdn_prepared_varlen_native_fla_cp( + gdn, + completion_qkv, + beta=completion_beta, + recurrent_g=completion_g, + lengths=bucket.lengths, + cu_seqlens=bucket.cu_seqlens, + conv_initial=completion_conv, + recurrent_initial=completion_rec, + group=group, + output_final_state=False, + ) + completion_out = _add_autograd_dependency(completion_out, cp_dependency) + cp_dependency = _make_autograd_dependency(completion_out) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, completion_out + ) + + ready_completion_buckets = ( + plan.ready_local_completion_buckets + if plan.ready_local_completion_buckets or plan.remote_local_completion_buckets + else plan.local_completion_buckets + ) + for bucket in ready_completion_buckets: + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + completion_qkv, completion_beta, completion_g = _gather_bucket_streams( + qkv, beta, recurrent_g, bucket + ) + completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) + completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) + completion_conv, completion_rec = _couple_parent_states( + completion_conv, completion_rec + ) + with _nvtx_range("art_gdn_local_completion_segment", completion_qkv): + completion_out, _, _ = _run_gdn_prepared_varlen_batch( + gdn, + completion_qkv, + beta=completion_beta, + recurrent_g=completion_g, + bucket=bucket, + conv_initial=completion_conv, + recurrent_initial=completion_rec, + output_final_state=False, + ) + completion_out = _add_autograd_dependency(completion_out, cp_dependency) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, completion_out + ) + + if plan.parent_state_exchange_family_indices: + if not plan.parent_state_transfers: + raise ValueError("CP parent-state 
exchange requires planned transfers") + with _nvtx_range("art_gdn_cp_parent_state_exchange", prefix_conv_table): + prefix_conv_table, prefix_rec_table, exchange_dependency = ( + _exchange_parent_state_rows( + prefix_conv_table, + prefix_rec_table, + transfers=plan.parent_state_transfers, + group=group, + ) + ) + cp_dependency = cp_dependency + exchange_dependency + + for bucket in plan.remote_local_completion_buckets: + with _nvtx_range("art_gdn_input_layout_gather_reorder", qkv): + completion_qkv, completion_beta, completion_g = _gather_bucket_streams( + qkv, beta, recurrent_g, bucket + ) + completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) + completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) + completion_conv, completion_rec = _couple_parent_states( + completion_conv, completion_rec + ) + with _nvtx_range("art_gdn_local_completion_segment", completion_qkv): + completion_out, _, _ = _run_gdn_prepared_varlen_batch( + gdn, + completion_qkv, + beta=completion_beta, + recurrent_g=completion_g, + bucket=bucket, + conv_initial=completion_conv, + recurrent_initial=completion_rec, + output_final_state=False, + ) + completion_out = _add_autograd_dependency(completion_out, cp_dependency) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, bucket, completion_out + ) + + projected, out_bias = _project_gdn_output(gdn, recurrent_output, gate, plan) + projected = _add_autograd_dependency(projected, cp_dependency) + if output_layout == "gdn": + return projected, out_bias + return _cp_output_to_attention(projected, plan, original_shape, group), out_bias + + +@torch.compiler.disable +def gdn_cp_attention_to_gdn_layout( + hidden_states: Tensor, + plan: GdnRankExecutionPlan, + group: Any, +) -> tuple[Tensor, tuple[int, int, int]]: + from .layout import exchange_rank_tensor_all_to_all + + if plan.attention_to_gdn is None or plan.gdn_to_attention is None: + raise ValueError("CP GDN layout conversion requires prebuilt exchange plans") + attention_flat, original_shape = _flatten_hidden_for_cp_plan(hidden_states, plan) + with _nvtx_range("art_gdn_cp_attention_to_gdn_exchange", attention_flat): + gdn_flat = exchange_rank_tensor_all_to_all( + attention_flat, + plan.attention_to_gdn, + rank=plan.cp_rank, + group=group, + backward_plan=plan.gdn_to_attention, + ) + return gdn_flat.unsqueeze(1).contiguous(), original_shape + + +@torch.compiler.disable +def gdn_cp_gdn_to_attention_layout( + gdn_hidden: Tensor, + plan: GdnRankExecutionPlan, + original_shape: tuple[int, int, int] | None, + group: Any, +) -> Tensor: + original_shape = original_shape or _attention_original_shape_from_plan( + gdn_hidden, plan + ) + return _cp_output_to_attention(gdn_hidden, plan, original_shape, group) + + +def _enter_gdn_island_layout( + hidden_states: Tensor, attention_bias: Any, *, force: bool = False +) -> Tensor: + plan = _require_gdn_cp_plan(attention_bias) + if not force and getattr(attention_bias, "gdn_hidden_layout", "attention") == "gdn": + return _validate_gdn_hidden_for_cp_plan(hidden_states, plan) + gdn_hidden, original_shape = gdn_cp_attention_to_gdn_layout( + hidden_states, + plan, + _default_cp_group(plan.cp_size), + ) + attention_bias.gdn_hidden_layout = "gdn" + attention_bias.gdn_attention_original_shape = original_shape + return gdn_hidden + + +def _mark_attention_layout_active(attention_bias: Any) -> None: + attention_bias.gdn_hidden_layout = "attention" + attention_bias.gdn_attention_original_shape = None + + +def _leave_gdn_island_layout(hidden_states: 
Tensor, attention_bias: Any) -> Tensor: + plan = _require_gdn_cp_plan(attention_bias) + gdn_hidden = _validate_gdn_hidden_for_cp_plan(hidden_states, plan) + attention_hidden = gdn_cp_gdn_to_attention_layout( + gdn_hidden, + plan, + getattr(attention_bias, "gdn_attention_original_shape", None), + _default_cp_group(plan.cp_size), + ) + _mark_attention_layout_active(attention_bias) + return attention_hidden + + +def _mark_gdn_layout_active(attention_bias: Any, hidden_states: Tensor) -> None: + plan = _require_gdn_cp_plan(attention_bias) + _validate_gdn_hidden_for_cp_plan(hidden_states, plan) + attention_bias.gdn_hidden_layout = "gdn" + if getattr(attention_bias, "gdn_attention_original_shape", None) is None: + attention_bias.gdn_attention_original_shape = ( + _attention_original_shape_from_plan(hidden_states, plan) + ) + + +def _require_gdn_cp_plan(attention_bias: Any) -> GdnRankExecutionPlan: + plan = getattr(attention_bias, "gdn_execution_plan", None) + if plan is None or int(getattr(plan, "cp_size", 1)) <= 1: + raise ValueError("GDN island layout conversion requires a CP execution plan") + return cast(GdnRankExecutionPlan, plan) + + +def _cp_output_to_attention( + gdn_output: Tensor, + plan: GdnRankExecutionPlan, + original_shape: tuple[int, int, int], + group: Any, +) -> Tensor: + from .layout import exchange_rank_tensor_all_to_all + + if plan.gdn_to_attention is None: + raise ValueError("CP GDN execution requires a GDN-to-attention exchange plan") + gdn_flat = gdn_output.squeeze(1).contiguous() + with _nvtx_range("art_gdn_cp_gdn_to_attention_exchange", gdn_flat): + attention_flat = exchange_rank_tensor_all_to_all( + gdn_flat, + plan.gdn_to_attention, + rank=plan.cp_rank, + group=group, + backward_plan=plan.attention_to_gdn, + ) + return _restore_hidden_from_cp_flat(attention_flat, original_shape) + + +def _flatten_hidden_for_cp_plan( + hidden_states: Tensor, plan: GdnRankExecutionPlan +) -> tuple[Tensor, tuple[int, int, int]]: + seq_len, batch_size, hidden_size = hidden_states.shape + flat = hidden_states.transpose(0, 1).reshape(seq_len * batch_size, hidden_size) + expected = int(plan.attention_token_count) + if int(flat.shape[0]) != expected: + raise ValueError( + "CP GDN hidden token count must match the rank-local attention plan, " + f"got {int(flat.shape[0])} tokens and expected {expected}" + ) + return flat.contiguous(), (seq_len, batch_size, hidden_size) + + +def _validate_gdn_hidden_for_cp_plan( + hidden_states: Tensor, plan: GdnRankExecutionPlan +) -> Tensor: + expected = int(plan.gdn_token_count) + if hidden_states.ndim != 3 or int(hidden_states.shape[0]) != expected: + raise ValueError( + "CP GDN-layout hidden_states must be [rank_gdn_tokens, 1, D], " + f"got {tuple(hidden_states.shape)} for {expected} planned tokens" + ) + if int(hidden_states.shape[1]) != 1: + raise ValueError( + "CP GDN-layout hidden_states must use a flattened local batch, " + f"got batch dimension {int(hidden_states.shape[1])}" + ) + return hidden_states.contiguous() + + +def _attention_original_shape_from_plan( + hidden_states: Tensor, plan: GdnRankExecutionPlan +) -> tuple[int, int, int]: + return (int(plan.attention_token_count), 1, int(hidden_states.shape[-1])) + + +def _restore_hidden_from_cp_flat( + flat: Tensor, original_shape: tuple[int, int, int] +) -> Tensor: + seq_len, batch_size, hidden_size = original_shape + if int(flat.shape[0]) != seq_len * batch_size: + raise ValueError( + "CP GDN output token count changed across layout exchange, got " + f"{int(flat.shape[0])} for original shape 
{original_shape}" + ) + return flat.reshape(batch_size, seq_len, hidden_size).transpose(0, 1).contiguous() + + +def _empty_autograd_dependency(reference: Tensor) -> Tensor: + return reference.new_zeros(()) + + +def _make_autograd_dependency(*tensors: Tensor | None) -> Tensor: + dependency: Tensor | None = None + for tensor in tensors: + if tensor is None or int(tensor.numel()) == 0: + continue + piece = tensor.reshape(-1)[:1].sum() * 0 + dependency = piece if dependency is None else dependency + piece + if dependency is None: + raise ValueError("at least one non-empty tensor is required") + return dependency + + +def _add_autograd_dependency(tensor: Tensor, dependency: Tensor) -> Tensor: + return tensor + dependency.to(dtype=tensor.dtype) + + +def _couple_parent_states( + conv_state: Tensor, recurrent_state: Tensor +) -> tuple[Tensor, Tensor]: + return _CoupledParentStates.apply(conv_state, recurrent_state) + + +class _CoupledParentStates(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, conv_state: Tensor, recurrent_state: Tensor + ) -> tuple[Tensor, Tensor]: + del ctx + return conv_state, recurrent_state + + @staticmethod + def backward( + ctx: Any, *grad_outputs: Tensor | None + ) -> tuple[Tensor | None, Tensor | None]: + del ctx + grad_conv, grad_recurrent = grad_outputs + return grad_conv, grad_recurrent + + +def _scale_state_gradient(tensor: Tensor, scale: float) -> Tensor: + return _ScaleStateGradient.apply(tensor, scale) + + +class _ScaleStateGradient(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, tensor: Tensor, scale: float) -> Tensor: + ctx.scale = scale + return tensor + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor | None) -> tuple[Tensor | None, None]: + (grad_output,) = grad_outputs + if grad_output is None: + return None, None + return grad_output * ctx.scale, None + + +def _gather_flat_bucket_streams( + qkv_flat: Tensor, + beta_flat: Tensor, + recurrent_g_flat: Tensor, + *, + layout: _BucketFlatLayout, + length: int, + segment_count: int, +) -> tuple[Tensor, Tensor, Tensor]: + return _FlatBucketStreamGather.apply( + qkv_flat, + beta_flat, + recurrent_g_flat, + layout.padded_indices, + layout.padded_mask, + length, + segment_count, + ) + + +def _gather_compact_bucket_streams( + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + bucket: GdnSegmentBucketPlan, +) -> tuple[Tensor, Tensor, Tensor]: + return _gather_bucket_streams_compact_fused( + qkv.reshape(-1, int(qkv.shape[-1])), + beta.reshape(-1, int(beta.shape[-1])), + recurrent_g.reshape(-1, int(recurrent_g.shape[-1])), + bucket.row_indices, + bucket.position_indices, + bucket.cu_seqlens, + token_count=int(bucket.real_token_count), + segment_count=int(bucket.segment_count), + sequence_length=int(qkv.shape[1]), + ) + + +class _FlatBucketStreamGather(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + qkv_flat: Tensor, + beta_flat: Tensor, + recurrent_g_flat: Tensor, + padded_indices: Tensor, + padded_mask: Tensor, + length: int, + segment_count: int, + ) -> tuple[Tensor, Tensor, Tensor]: + flat_indices = padded_indices.reshape(-1) + flat_mask = padded_mask.reshape(-1) + safe_indices = torch.where( + flat_mask, + flat_indices, + torch.zeros((), device=flat_indices.device, dtype=flat_indices.dtype), + ) + qkv = qkv_flat.index_select(0, safe_indices).reshape( + length, segment_count, int(qkv_flat.shape[-1]) + ) + beta = beta_flat.index_select(0, safe_indices).reshape( + length, segment_count, int(beta_flat.shape[-1]) + ) + recurrent_g = 
recurrent_g_flat.index_select(0, safe_indices).reshape( + length, segment_count, int(recurrent_g_flat.shape[-1]) + ) + qkv = qkv.masked_fill(~padded_mask.unsqueeze(-1), 0) + beta = beta.masked_fill(~padded_mask.unsqueeze(-1), 0) + recurrent_g = recurrent_g.masked_fill(~padded_mask.unsqueeze(-1), 0) + ctx.save_for_backward(safe_indices, flat_mask) + ctx.qkv_flat_count = int(qkv_flat.shape[0]) + ctx.beta_flat_count = int(beta_flat.shape[0]) + ctx.recurrent_g_flat_count = int(recurrent_g_flat.shape[0]) + return ( + qkv.permute(1, 2, 0).contiguous(), + beta.transpose(0, 1).contiguous(), + recurrent_g.transpose(0, 1).contiguous(), + ) + + @staticmethod + def backward( + ctx: Any, *grad_outputs: Tensor | None + ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None, None, None]: + grad_qkv_bucket, grad_beta_bucket, grad_g_bucket = grad_outputs + safe_indices, flat_mask = ctx.saved_tensors + grad_qkv = ( + _bucket_stream_grad_to_flat( + grad_qkv_bucket.permute(2, 0, 1).contiguous() + if grad_qkv_bucket is not None + else None, + safe_indices, + flat_mask, + ctx.qkv_flat_count, + ) + if ctx.needs_input_grad[0] + else None + ) + grad_beta = ( + _bucket_stream_grad_to_flat( + grad_beta_bucket.transpose(0, 1).contiguous() + if grad_beta_bucket is not None + else None, + safe_indices, + flat_mask, + ctx.beta_flat_count, + ) + if ctx.needs_input_grad[1] + else None + ) + grad_g = ( + _bucket_stream_grad_to_flat( + grad_g_bucket.transpose(0, 1).contiguous() + if grad_g_bucket is not None + else None, + safe_indices, + flat_mask, + ctx.recurrent_g_flat_count, + ) + if ctx.needs_input_grad[2] + else None + ) + return grad_qkv, grad_beta, grad_g, None, None, None, None + + +def _bucket_stream_grad_to_flat( + grad: Tensor | None, + safe_indices: Tensor, + flat_mask: Tensor, + flat_count: int, +) -> Tensor | None: + if grad is None: + return None + grad_flat_values = grad.reshape(int(safe_indices.numel()), int(grad.shape[-1])) + grad_flat_values = grad_flat_values.masked_fill(~flat_mask.unsqueeze(-1), 0) + grad_flat = grad.new_zeros(flat_count, int(grad.shape[-1])) + return grad_flat.index_add(0, safe_indices, grad_flat_values) + + +def _scatter_compact_hidden( + compact: Tensor, + indices: Tensor, + *, + batch_size: int, + sequence_length: int, +) -> Tensor: + return _CompactHiddenScatter.apply(compact, indices, batch_size, sequence_length) + + +class _CompactHiddenScatter(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + compact: Tensor, + indices: Tensor, + batch_size: int, + sequence_length: int, + ) -> Tensor: + hidden_size = int(compact.shape[-1]) + flat = compact.new_zeros(batch_size * sequence_length, hidden_size) + if int(indices.numel()): + flat = flat.index_copy(0, indices, compact.reshape(-1, hidden_size)) + ctx.save_for_backward(indices) + ctx.batch_size = batch_size + ctx.sequence_length = sequence_length + return ( + flat.reshape(batch_size, sequence_length, hidden_size) + .transpose(0, 1) + .contiguous() + ) + + @staticmethod + def backward( + ctx: Any, *grad_outputs: Any + ) -> tuple[Tensor | None, None, None, None]: + (grad_output,) = grad_outputs + if grad_output is None: + return None, None, None, None + (indices,) = ctx.saved_tensors + flat_grad = grad_output.transpose(0, 1).reshape( + ctx.batch_size * ctx.sequence_length, int(grad_output.shape[-1]) + ) + return flat_grad.index_select(0, indices), None, None, None + + +def _project_gdn_inputs( + gdn: Any, hidden_states: Tensor +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + seq_len, batch_size, _ = 
hidden_states.shape + seq_len *= int(getattr(gdn, "sp_size", 1)) + qkvzba, _ = _in_proj(gdn, hidden_states) + qkvzba = qkvzba.transpose(0, 1) + qkv, gate, beta, alpha = torch.split( + qkvzba, + [ + (gdn.qk_dim * 2 + gdn.v_dim) // gdn.tp_size, + gdn.v_dim // gdn.tp_size, + gdn.num_value_heads // gdn.tp_size, + gdn.num_value_heads // gdn.tp_size, + ], + dim=-1, + ) + value_heads = _local_value_heads(gdn) + gate = gate.reshape( + batch_size, seq_len, value_heads, gdn.value_head_dim + ).contiguous() + beta = beta.reshape(batch_size, seq_len, value_heads).sigmoid().contiguous() + alpha = alpha.reshape(batch_size, seq_len, value_heads) + recurrent_g = ( + -gdn.A_log.exp() * F.softplus(alpha.float() + gdn.dt_bias) + ).contiguous() + return qkv.contiguous(), gate, beta, recurrent_g + + +def _in_proj(gdn: Any, hidden_states: Tensor) -> tuple[Tensor, Tensor | None]: + return gdn.in_proj(hidden_states) + + +def _gather_bucket_streams( + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + bucket: GdnSegmentBucketPlan, +) -> tuple[Tensor, Tensor, Tensor]: + layout = _bucket_flat_layout( + bucket, + sequence_length=int(qkv.shape[1]), + ) + return _gather_flat_bucket_streams( + qkv.reshape(-1, int(qkv.shape[-1])), + beta.reshape(-1, int(beta.shape[-1])), + recurrent_g.reshape(-1, int(recurrent_g.shape[-1])), + layout=layout, + length=int(bucket.length), + segment_count=int(bucket.segment_count), + ) + + +def _bucket_flat_layout( + bucket: GdnSegmentBucketPlan, *, sequence_length: int +) -> _BucketFlatLayout: + positions = bucket.position_indices.clamp_max(sequence_length - 1) + padded_indices = (bucket.row_indices * sequence_length + positions).contiguous() + padded_mask = bucket.real_mask.contiguous() + segment_major_indices = padded_indices.transpose(0, 1).contiguous() + segment_major_mask = padded_mask.transpose(0, 1).contiguous() + real_indices = segment_major_indices[segment_major_mask].contiguous() + output_mask = _bucket_output_mask(bucket).transpose(0, 1).contiguous() + output_indices = segment_major_indices[output_mask].contiguous() + output_selector = None + if bucket.output_mask is not None: + output_selector = output_mask[segment_major_mask].contiguous() + return _BucketFlatLayout( + padded_indices=padded_indices, + padded_mask=padded_mask, + real_indices=real_indices, + output_indices=output_indices, + output_selector=output_selector, + ) + + +def _project_gdn_output( + gdn: Any, + recurrent_output: Tensor, + gate: Tensor, + plan: GdnRankExecutionPlan, +) -> tuple[Tensor, Tensor | None]: + batch_size, seq_len, _, _ = recurrent_output.shape + with _nvtx_range("art_gdn_output_norm_gate", recurrent_output): + norm_out = _apply_gated_rms_norm(gdn, recurrent_output, gate) + norm_out = norm_out.reshape(batch_size, seq_len, _local_value_dim(gdn)) + norm_out = norm_out.transpose(0, 1).contiguous() + with _nvtx_range("art_gdn_out_proj", norm_out): + if plan.cp_size > 1: + out, out_bias = _out_proj_cp_full_shape(gdn, norm_out, plan) + else: + out, out_bias = _out_proj(gdn, norm_out) + return _mask_gdn_output(gdn, out, plan), out_bias + + +def _mask_gdn_output(gdn: Any, out: Tensor, plan: GdnRankExecutionPlan) -> Tensor: + real_mask = plan.real_token_mask.transpose(0, 1).unsqueeze(-1) + if tuple(real_mask.shape[:2]) == tuple(out.shape[:2]): + return out.masked_fill(~real_mask, 0) + full_batch = int(plan.packed_batch_size or plan.batch_size) + full_seq = int(plan.packed_sequence_length or plan.sequence_length) + full_count = full_batch * full_seq + local_indices = torch.tensor( + 
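+ # flat positions of this rank's GDN tokens within the packed (batch, seq) output, used to rebuild the full-shape mask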
plan.gdn_token_indices, device=out.device, dtype=torch.long + ) + full_flat = torch.zeros(full_count, device=out.device, dtype=torch.bool) + if int(local_indices.numel()): + full_flat = full_flat.index_fill(0, local_indices, True) + full_mask = full_flat.reshape(full_batch, full_seq).transpose(0, 1).unsqueeze(-1) + if tuple(full_mask.shape[:2]) == tuple(out.shape[:2]): + return out.masked_fill(~full_mask, 0) + rank = _tp_rank(getattr(gdn.out_proj, "linear_proj", gdn.out_proj)) + start = rank * int(out.shape[0]) + end = start + int(out.shape[0]) + if end <= int(full_mask.shape[0]) and int(full_mask.shape[1]) == int(out.shape[1]): + return out.masked_fill(~full_mask[start:end], 0) + raise ValueError( + "GDN output mask shape must match projected output, got " + f"mask={tuple(real_mask.shape)} full_mask={tuple(full_mask.shape)} " + f"out={tuple(out.shape)}" + ) + + +def _out_proj_cp_full_shape( + gdn: Any, hidden_states: Tensor, plan: GdnRankExecutionPlan +) -> tuple[Tensor, Tensor | None]: + full_batch = int(plan.packed_batch_size or plan.batch_size) + full_seq = int(plan.packed_sequence_length or plan.sequence_length) + full_count = full_batch * full_seq + if full_count == int(hidden_states.shape[0]): + return _out_proj(gdn, hidden_states) + if int(hidden_states.shape[1]) != 1: + raise ValueError( + "CP GDN full-shape output projection expects flattened local batch, got " + f"{tuple(hidden_states.shape)}" + ) + local_indices = torch.tensor( + plan.gdn_token_indices, device=hidden_states.device, dtype=torch.long + ) + if int(local_indices.numel()) != int(hidden_states.shape[0]): + raise ValueError( + "CP GDN token index count must match local projection input, got " + f"{int(local_indices.numel())} indices for {tuple(hidden_states.shape)}" + ) + if int(local_indices.numel()) and int(local_indices.max().item()) >= full_count: + raise ValueError( + "CP GDN token index exceeds packed output shape, got " + f"max_index={int(local_indices.max().item())} full_count={full_count}" + ) + full_flat = hidden_states.new_zeros(full_count, int(hidden_states.shape[-1])) + if int(local_indices.numel()): + full_flat = full_flat.index_copy(0, local_indices, hidden_states.squeeze(1)) + full_hidden = ( + full_flat.reshape(full_batch, full_seq, int(hidden_states.shape[-1])) + .transpose(0, 1) + .contiguous() + ) + full_out, out_bias = _out_proj(gdn, full_hidden) + local_out = ( + full_out.transpose(0, 1) + .reshape(full_count, int(full_out.shape[-1])) + .index_select(0, local_indices) + .unsqueeze(1) + .contiguous() + ) + return local_out, out_bias + + +def _apply_gated_rms_norm(gdn: Any, x: Tensor, gate: Tensor) -> Tensor: + x_dtype = x.dtype + hidden = gdn.out_norm(x.reshape(-1, int(x.shape[-1]))) + gate = gate.reshape(-1, int(gate.shape[-1])) + return (hidden * gdn.act_fn(gate.float())).to(x_dtype) + + +def _out_proj(gdn: Any, hidden_states: Tensor) -> tuple[Tensor, Tensor | None]: + return gdn.out_proj(hidden_states) + + +def _apply_explicit_norm( + module: Any, + x: Tensor, + *, + config: Any, + weight_name: str, + bias_name: str, +) -> Tensor: + del config + x_dtype = x.dtype + x_float = x.float() + normalization = str(module.normalization) + if normalization == "RMSNorm": + normed = x_float * torch.rsqrt( + x_float.square().mean(dim=-1, keepdim=True) + float(module.eps) + ) + bias = None + elif normalization == "LayerNorm": + centered = x_float - x_float.mean(dim=-1, keepdim=True) + normed = centered * torch.rsqrt( + centered.square().mean(dim=-1, keepdim=True) + float(module.eps) + ) + bias = 
getattr(module, bias_name) + else: + raise ValueError(f"unsupported GDN normalization '{normalization}'") + + scale = getattr(module, weight_name).float() + if bool(module.zero_centered_gamma): + scale = scale + 1.0 + normed = normed * scale + if isinstance(bias, Tensor): + normed = normed + bias.float() + return normed.to(dtype=x_dtype) + + +def _uses_sequence_parallel(projection: Any) -> bool: + return bool(getattr(projection, "sequence_parallel", False)) and ( + _tp_world_size(projection) > 1 + ) + + +def _gdn_uses_sequence_parallel(gdn: Any) -> bool: + projection = getattr(gdn, "in_proj", None) + base_projection = getattr(projection, "in_proj", projection) + return _uses_sequence_parallel(base_projection) + + +def _tp_world_size(projection: Any) -> int: + del projection + from megatron.core import parallel_state as ps + + return int(ps.get_tensor_model_parallel_world_size()) + + +def _tp_rank(projection: Any) -> int: + del projection + from megatron.core import parallel_state as ps + + return int(ps.get_tensor_model_parallel_rank()) + + +def _local_key_heads(gdn: Any) -> int: + return int(gdn.num_key_heads // gdn.tp_size) + + +def _local_value_heads(gdn: Any) -> int: + return int(gdn.num_value_heads // gdn.tp_size) + + +def _local_value_dim(gdn: Any) -> int: + return _local_value_heads(gdn) * int(gdn.value_head_dim) + + +def _scatter_bucket_recurrent_output( + output: Tensor, bucket: GdnSegmentBucketPlan, bucket_output: Tensor +) -> Tensor: + return _scatter_bucket_output_fused( + output, + bucket_output, + bucket.row_indices, + bucket.position_indices, + _bucket_output_mask(bucket), + bucket.cu_seqlens, + ) + + +def _bucket_output_mask(bucket: GdnSegmentBucketPlan) -> Tensor: + output_mask = bucket.output_mask + return bucket.real_mask if output_mask is None else output_mask + + +def _materialize_indexed_family_state_table( + *, + plan: GdnRankExecutionPlan, + family_chunks: list[Tensor], + state_chunks: list[Tensor], + zero_state: Tensor, +) -> Tensor: + table = zero_state.detach() + if not state_chunks: + return table.requires_grad_(True) + values = torch.cat(state_chunks, dim=0) + family_indices = torch.cat(family_chunks, dim=0) + return table.index_copy(0, family_indices, values) + + +def _materialize_ordered_family_state_table( + *, + family_chunks: list[Tensor], + state_chunks: list[Tensor], + zero_state: Tensor, +) -> Tensor: + if len(family_chunks) != len(state_chunks): + raise RuntimeError("family and state chunk counts must match") + table = zero_state.detach().requires_grad_(True) + for family_indices, states in zip(family_chunks, state_chunks, strict=True): + table = table.index_copy(0, family_indices, states) + return table + + +def _replace_indexed_family_states( + table: Tensor, + *, + family_chunks: list[Tensor], + state_chunks: list[Tensor], +) -> Tensor: + if not state_chunks: + return table + return table.index_copy( + 0, + torch.cat(family_chunks, dim=0), + torch.cat(state_chunks, dim=0), + ) + + +def _exchange_parent_state_rows( + conv_table: Tensor, + rec_table: Tensor, + *, + transfers: tuple[GdnParentStateTransferPlan, ...], + group: Any, +) -> tuple[Tensor, Tensor, Tensor]: + if not transfers: + return conv_table, rec_table, _empty_autograd_dependency(conv_table) + conv_table, rec_table = _ParentStateExchange.apply( + conv_table, rec_table, transfers, group + ) + return conv_table, rec_table, _make_autograd_dependency(conv_table, rec_table) + + +class _ParentStateExchange(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + conv_table: 
Tensor, + rec_table: Tensor, + transfers: tuple[GdnParentStateTransferPlan, ...], + group: Any, + ) -> tuple[Tensor, Tensor]: + ctx.group = group + ctx.transfers = transfers + ctx.save_for_backward(conv_table, rec_table) + return ( + _exchange_parent_state_tensor_forward( + conv_table, + transfers, + group=group, + ), + _exchange_parent_state_tensor_forward( + rec_table, + transfers, + group=group, + ), + ) + + @staticmethod + def backward( + ctx: Any, *grad_outputs: Tensor | None + ) -> tuple[Tensor | None, Tensor | None, None, None]: + grad_conv, grad_rec = grad_outputs + conv_ref, rec_ref = ctx.saved_tensors + return ( + _exchange_parent_state_tensor_backward( + _zero_if_none(grad_conv, conv_ref), + ctx.transfers, + group=ctx.group, + ), + _exchange_parent_state_tensor_backward( + _zero_if_none(grad_rec, rec_ref), + ctx.transfers, + group=ctx.group, + ), + None, + None, + ) + + +def _exchange_parent_state_tensor_forward( + table: Tensor, + transfers: tuple[GdnParentStateTransferPlan, ...], + *, + group: Any, +) -> Tensor: + rank = torch.distributed.get_rank(group) # ty: ignore[possibly-missing-attribute] + output = table.clone() + recvs = _exchange_parent_state_rows_all_to_all( + table, transfers, rank=rank, reverse=False, group=group + ) + for transfer, rows in recvs: + index = _parent_state_index_tensor(transfer, device=table.device) + output.index_copy_(0, index, rows) + return output + + +def _exchange_parent_state_tensor_backward( + grad_output: Tensor, + transfers: tuple[GdnParentStateTransferPlan, ...], + *, + group: Any, +) -> Tensor: + rank = torch.distributed.get_rank(group) # ty: ignore[possibly-missing-attribute] + grad_input = grad_output.clone() + for transfer in transfers: + if transfer.dest_rank != rank: + continue + index = _parent_state_index_tensor(transfer, device=grad_output.device) + grad_input.index_fill_(0, index, 0) + recvs = _exchange_parent_state_rows_all_to_all( + grad_output, transfers, rank=rank, reverse=True, group=group + ) + for transfer, rows in recvs: + index = _parent_state_index_tensor(transfer, device=grad_output.device) + grad_input.index_add_(0, index, rows) + return grad_input + + +def _zero_if_none(grad: Tensor | None, reference: Tensor) -> Tensor: + if grad is None: + return reference.new_zeros(reference.shape) + return grad.contiguous() + + +def _exchange_parent_state_rows_all_to_all( + table: Tensor, + transfers: tuple[GdnParentStateTransferPlan, ...], + *, + rank: int, + reverse: bool, + group: Any, +) -> list[tuple[GdnParentStateTransferPlan, Tensor]]: + world_size = torch.distributed.get_world_size(group) # ty: ignore[possibly-missing-attribute] + send_counts = [0 for _ in range(world_size)] + recv_counts = [0 for _ in range(world_size)] + send_pieces: list[Tensor] = [] + for peer_rank in range(world_size): + for transfer in transfers: + send_rank = transfer.dest_rank if reverse else transfer.source_rank + recv_rank = transfer.source_rank if reverse else transfer.dest_rank + if send_rank == recv_rank: + continue + row_count = len(transfer.family_indices) + if rank == send_rank and peer_rank == recv_rank: + index = _parent_state_index_tensor(transfer, device=table.device) + send_pieces.append(table.index_select(0, index).contiguous()) + send_counts[peer_rank] += row_count + if rank == recv_rank and peer_rank == send_rank: + recv_counts[peer_rank] += row_count + + trailing_shape = tuple(table.shape[1:]) + send_buffer = ( + torch.cat(send_pieces, dim=0) + if send_pieces + else table.new_empty((0, *trailing_shape)) + ) + recv_buffer = 
table.new_empty((sum(recv_counts), *trailing_shape)) + work = torch.distributed.all_to_all_single( # ty: ignore[possibly-missing-attribute] + recv_buffer, + send_buffer, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=group, + async_op=True, + ) + work.wait() + + recvs: list[tuple[GdnParentStateTransferPlan, Tensor]] = [] + offset = 0 + for peer_rank, count in enumerate(recv_counts): + peer_end = offset + count + for transfer in transfers: + send_rank = transfer.dest_rank if reverse else transfer.source_rank + recv_rank = transfer.source_rank if reverse else transfer.dest_rank + if send_rank == recv_rank: + continue + if rank != recv_rank or peer_rank != send_rank: + continue + rows = len(transfer.family_indices) + recvs.append((transfer, recv_buffer[offset : offset + rows])) + offset += rows + if offset != peer_end: + raise RuntimeError( + "parent-state exchange unpack mismatch: " + f"rank={rank} peer={peer_rank} consumed={offset} expected={peer_end}" + ) + return recvs + + +def _parent_state_index_tensor( + transfer: GdnParentStateTransferPlan, + *, + device: torch.device, +) -> Tensor: + if ( + transfer.family_indices_tensor is not None + and transfer.family_indices_tensor.device == device + ): + return transfer.family_indices_tensor + return torch.tensor(transfer.family_indices, device=device, dtype=torch.long) + + +def _run_gdn_segment( + gdn: Any, + hidden_states: Tensor, + *, + conv_initial: Tensor, + recurrent_initial: Tensor, + output_final_state: bool = True, +) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None]: + _disable_reentrant_te_linear_transpose_cache(gdn) + seq_len, batch_size, _ = hidden_states.shape + if int(conv_initial.shape[0]) != batch_size: + raise ValueError( + "conv_initial batch must match hidden_states batch, got " + f"{tuple(conv_initial.shape)} for hidden {tuple(hidden_states.shape)}" + ) + if int(recurrent_initial.shape[0]) != batch_size: + raise ValueError( + "recurrent_initial batch must match hidden_states batch, got " + f"{tuple(recurrent_initial.shape)} for hidden {tuple(hidden_states.shape)}" + ) + + with _nvtx_range("art_gdn_in_proj", hidden_states): + qkvzba, _ = _in_proj(gdn, hidden_states) + qkvzba = qkvzba.transpose(0, 1) + + with _nvtx_range("art_gdn_qkv_gate_beta_alpha_split_reshape", qkvzba): + qkv, gate, beta, alpha = torch.split( + qkvzba, + [ + (gdn.qk_dim * 2 + gdn.v_dim) // gdn.tp_size, + gdn.v_dim // gdn.tp_size, + gdn.num_value_heads // gdn.tp_size, + gdn.num_value_heads // gdn.tp_size, + ], + dim=-1, + ) + key_heads = _local_key_heads(gdn) + value_heads = _local_value_heads(gdn) + gate = gate.reshape(batch_size, seq_len, value_heads, gdn.value_head_dim) + beta = beta.reshape(batch_size, seq_len, value_heads) + alpha = alpha.reshape(batch_size, seq_len, value_heads) + + with _nvtx_range("art_gdn_causal_conv_forward", qkv): + qkv = qkv.transpose(1, 2) + qkv, conv_final = _causal_conv1d_with_state( + gdn, + qkv, + conv_initial, + output_final_state=output_final_state, + ) + qkv = qkv.transpose(1, 2) + + with _nvtx_range("art_gdn_qkv_head_prepare", qkv): + query, key, value = torch.split( + qkv, + [ + gdn.qk_dim // gdn.tp_size, + gdn.qk_dim // gdn.tp_size, + gdn.v_dim // gdn.tp_size, + ], + dim=-1, + ) + query = query.reshape(batch_size, seq_len, key_heads, gdn.key_head_dim) + key = key.reshape(batch_size, seq_len, key_heads, gdn.key_head_dim) + value = value.reshape(batch_size, seq_len, value_heads, gdn.value_head_dim) + if gdn.use_qk_l2norm: + query = _l2norm(query.contiguous()) + key = 
_l2norm(key.contiguous()) + if gdn.num_value_heads // gdn.num_key_heads > 1: + repeat = gdn.num_value_heads // gdn.num_key_heads + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + gate = gate.contiguous() + beta = beta.contiguous() + alpha = alpha.contiguous() + + with _nvtx_range("art_gdn_recurrent_gate_prepare", alpha): + g = -gdn.A_log.exp() * F.softplus(alpha.float() + gdn.dt_bias) + beta = beta.sigmoid() + + with _nvtx_range("art_gdn_recurrent_forward", query): + recurrent_out, recurrent_final = _chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_initial, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=False, + ) + + with _nvtx_range("art_gdn_output_norm_gate", recurrent_out): + norm_out = _apply_gated_rms_norm(gdn, recurrent_out, gate) + norm_out = norm_out.reshape(batch_size, seq_len, _local_value_dim(gdn)) + norm_out = norm_out.transpose(0, 1).contiguous() + with _nvtx_range("art_gdn_out_proj", norm_out): + out, out_bias = _out_proj(gdn, norm_out) + return out, out_bias, conv_final, recurrent_final + + +def _run_gdn_prepared_varlen_batch( + gdn: Any, + qkv: Tensor, + *, + beta: Tensor, + recurrent_g: Tensor, + bucket: GdnSegmentBucketPlan, + conv_initial: Tensor, + recurrent_initial: Tensor, + output_final_state: bool = True, +) -> tuple[Tensor, Tensor | None, Tensor | None]: + _disable_reentrant_te_linear_transpose_cache(gdn) + batch_size, _, max_len = qkv.shape + if int(bucket.length) != max_len or int(bucket.segment_count) != batch_size: + raise ValueError( + "GDN prepared varlen bucket shape mismatch, got " + f"qkv={tuple(qkv.shape)} bucket_len={bucket.length} " + f"segments={bucket.segment_count}" + ) + if int(conv_initial.shape[0]) != batch_size: + raise ValueError( + "conv_initial batch must match bucket segment count, got " + f"{tuple(conv_initial.shape)} for {batch_size} segments" + ) + if int(recurrent_initial.shape[0]) != batch_size: + raise ValueError( + "recurrent_initial batch must match bucket segment count, got " + f"{tuple(recurrent_initial.shape)} for {batch_size} segments" + ) + + with _nvtx_range("art_gdn_causal_conv_forward", qkv): + qkv, conv_final = _causal_conv1d_varlen_with_state( + gdn, + qkv, + conv_initial, + bucket.lengths, + output_final_state=output_final_state, + ) + qkv = qkv.transpose(1, 2) + + with _nvtx_range("art_gdn_qkv_head_prepare", qkv): + query, key, value = torch.split( + qkv, + [ + gdn.qk_dim // gdn.tp_size, + gdn.qk_dim // gdn.tp_size, + gdn.v_dim // gdn.tp_size, + ], + dim=-1, + ) + key_heads = _local_key_heads(gdn) + value_heads = _local_value_heads(gdn) + query = query.reshape(batch_size, max_len, key_heads, gdn.key_head_dim) + key = key.reshape(batch_size, max_len, key_heads, gdn.key_head_dim) + value = value.reshape(batch_size, max_len, value_heads, gdn.value_head_dim) + if gdn.use_qk_l2norm: + query = _l2norm(query.contiguous()) + key = _l2norm(key.contiguous()) + if gdn.num_value_heads // gdn.num_key_heads > 1: + repeat = gdn.num_value_heads // gdn.num_key_heads + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + + real_mask = bucket.real_mask.transpose(0, 1) + query = query[real_mask].unsqueeze(0).contiguous() + key = key[real_mask].unsqueeze(0).contiguous() + value = value[real_mask].unsqueeze(0).contiguous() + beta = beta[real_mask].unsqueeze(0).contiguous() + recurrent_g = 
recurrent_g[real_mask].unsqueeze(0).contiguous() + + with _nvtx_range("art_gdn_recurrent_forward", query): + recurrent_out, recurrent_final = _chunk_gated_delta_rule( + query, + key, + value, + g=recurrent_g, + beta=beta, + initial_state=recurrent_initial, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=False, + cu_seqlens=bucket.cu_seqlens, + ) + return recurrent_out, conv_final, recurrent_final + + +def _run_gdn_varlen_batch( + gdn: Any, + hidden_states: Tensor, + *, + bucket: GdnSegmentBucketPlan, + conv_initial: Tensor, + recurrent_initial: Tensor, + output_final_state: bool = True, +) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None]: + _disable_reentrant_te_linear_transpose_cache(gdn) + max_len, batch_size, _ = hidden_states.shape + if int(bucket.length) != max_len or int(bucket.segment_count) != batch_size: + raise ValueError( + "GDN varlen bucket shape mismatch, got " + f"hidden={tuple(hidden_states.shape)} bucket_len={bucket.length} " + f"segments={bucket.segment_count}" + ) + if int(conv_initial.shape[0]) != batch_size: + raise ValueError( + "conv_initial batch must match bucket segment count, got " + f"{tuple(conv_initial.shape)} for {batch_size} segments" + ) + if int(recurrent_initial.shape[0]) != batch_size: + raise ValueError( + "recurrent_initial batch must match bucket segment count, got " + f"{tuple(recurrent_initial.shape)} for {batch_size} segments" + ) + + with _nvtx_range("art_gdn_in_proj", hidden_states): + qkvzba, _ = _in_proj(gdn, hidden_states) + qkvzba = qkvzba.transpose(0, 1) + + with _nvtx_range("art_gdn_qkv_gate_beta_alpha_split_reshape", qkvzba): + qkv, gate, beta, alpha = torch.split( + qkvzba, + [ + (gdn.qk_dim * 2 + gdn.v_dim) // gdn.tp_size, + gdn.v_dim // gdn.tp_size, + gdn.num_value_heads // gdn.tp_size, + gdn.num_value_heads // gdn.tp_size, + ], + dim=-1, + ) + key_heads = _local_key_heads(gdn) + value_heads = _local_value_heads(gdn) + gate = gate.reshape(batch_size, max_len, value_heads, gdn.value_head_dim) + beta = beta.reshape(batch_size, max_len, value_heads) + alpha = alpha.reshape(batch_size, max_len, value_heads) + + with _nvtx_range("art_gdn_causal_conv_forward", qkv): + qkv = qkv.transpose(1, 2).contiguous() + qkv, conv_final = _causal_conv1d_varlen_with_state( + gdn, + qkv, + conv_initial, + bucket.lengths, + output_final_state=output_final_state, + ) + qkv = qkv.transpose(1, 2) + + with _nvtx_range("art_gdn_qkv_head_prepare", qkv): + query, key, value = torch.split( + qkv, + [ + gdn.qk_dim // gdn.tp_size, + gdn.qk_dim // gdn.tp_size, + gdn.v_dim // gdn.tp_size, + ], + dim=-1, + ) + query = query.reshape(batch_size, max_len, key_heads, gdn.key_head_dim) + key = key.reshape(batch_size, max_len, key_heads, gdn.key_head_dim) + value = value.reshape(batch_size, max_len, value_heads, gdn.value_head_dim) + if gdn.use_qk_l2norm: + query = _l2norm(query.contiguous()) + key = _l2norm(key.contiguous()) + if gdn.num_value_heads // gdn.num_key_heads > 1: + repeat = gdn.num_value_heads // gdn.num_key_heads + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + + with _nvtx_range("art_gdn_recurrent_gate_prepare", alpha): + g = -gdn.A_log.exp() * F.softplus(alpha.float() + gdn.dt_bias) + beta = beta.sigmoid() + + real_mask = bucket.real_mask.transpose(0, 1) + query = query[real_mask].unsqueeze(0).contiguous() + key = key[real_mask].unsqueeze(0).contiguous() + value = value[real_mask].unsqueeze(0).contiguous() + gate = gate[real_mask].unsqueeze(0).contiguous() + beta = 
beta[real_mask].unsqueeze(0).contiguous() + g = g[real_mask].unsqueeze(0).contiguous() + + with _nvtx_range("art_gdn_recurrent_forward", query): + recurrent_out, recurrent_final = _chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_initial, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=False, + cu_seqlens=bucket.cu_seqlens, + ) + + with _nvtx_range("art_gdn_output_norm_gate", recurrent_out): + norm_out = _apply_gated_rms_norm(gdn, recurrent_out, gate) + if norm_out.ndim == 4: + norm_out = norm_out.flatten(2).transpose(0, 1).contiguous() + elif norm_out.ndim == 3: + norm_out = ( + norm_out.transpose(0, 1).contiguous() + if int(norm_out.shape[0]) == 1 + else norm_out.reshape( + norm_out.shape[0], 1, _local_value_dim(gdn) + ).contiguous() + ) + elif norm_out.ndim == 2: + norm_out = norm_out.reshape( + 1, recurrent_out.shape[1], _local_value_dim(gdn) + ) + norm_out = norm_out.transpose(0, 1).contiguous() + else: + raise RuntimeError( + f"unexpected GDN norm output shape {tuple(norm_out.shape)}" + ) + with _nvtx_range("art_gdn_out_proj", norm_out): + out, out_bias = _out_proj(gdn, norm_out) + return out, out_bias, conv_final, recurrent_final + + +def _conv_final_from_varlen_qkv( + qkv: Tensor, conv_initial: Tensor, lengths: Tensor +) -> Tensor: + tail_width = int(conv_initial.shape[-1]) + if tail_width == 0: + return conv_initial + batch_size, channel_count, max_len = qkv.shape + arange = torch.arange(batch_size, device=qkv.device) + pieces = [] + for tail_offset in range(tail_width): + source = lengths - tail_width + tail_offset + from_qkv = source >= 0 + qkv_index = source.clamp(min=0, max=max_len - 1) + init_index = (source + tail_width).clamp(min=0, max=tail_width - 1) + qkv_piece = qkv[arange, :, qkv_index] + init_piece = conv_initial[arange, :, init_index] + pieces.append(torch.where(from_qkv.unsqueeze(1), qkv_piece, init_piece)) + return torch.stack(pieces, dim=-1).reshape(batch_size, channel_count, tail_width) + + +def _causal_conv1d_varlen_with_state( + gdn: Any, + qkv: Tensor, + conv_initial: Tensor, + lengths: Tensor, + *, + output_final_state: bool, +) -> tuple[Tensor, Tensor | None]: + if str(getattr(gdn, "activation", "")) == "gelu": + return gdn_varlen_causal_conv_gelu( + gdn, + qkv, + conv_initial, + lengths, + output_final_state=output_final_state, + ) + conv_final = ( + _conv_final_from_varlen_qkv(qkv, conv_initial, lengths) + if output_final_state + else None + ) + out, _ = _causal_conv1d_with_state( + gdn, + qkv, + conv_initial, + output_final_state=False, + ) + return out, conv_final + + +def _causal_conv1d_packed_varlen_with_state( + gdn: Any, + qkv: Tensor, + conv_initial: Tensor, + cu_seqlens: Tensor, + *, + output_final_state: bool, +) -> tuple[Tensor, Tensor | None]: + return packed_varlen_causal_conv( + qkv, + cu_seqlens, + conv_initial, + gdn.conv1d.weight.squeeze(1), + gdn.conv1d.bias, + activation=str(getattr(gdn, "activation", "gelu")), + output_final_state=output_final_state, + ) + + +def _causal_conv1d_with_state( + gdn: Any, + qkv: Tensor, + conv_initial: Tensor, + *, + output_final_state: bool, +) -> tuple[Tensor, Tensor | None]: + weight = gdn.conv1d.weight.squeeze(1) + bias = gdn.conv1d.bias + if not bool( + getattr(gdn.config, "deterministic_mode", False) + ) and gdn.activation in ("silu", "swish"): + qkv_fast = _channel_last_conv1d_layout(qkv) + conv_initial_fast = _channel_last_conv1d_layout(conv_initial) + if qkv_fast is not None and conv_initial_fast is not None: + conv_result = 
causal_conv1d_fn( + x=qkv_fast, + weight=weight, + bias=bias, + initial_states=conv_initial_fast, + return_final_states=output_final_state, + activation=gdn.activation, + ) + if output_final_state: + out, final = conv_result + else: + out, final = conv_result, None + return out, final + + qkv_dtype = qkv.dtype + if not bool(getattr(gdn.config, "deterministic_mode", False)): + final = ( + _conv_final_from_dense_qkv(qkv, conv_initial, weight.shape[1]) + if output_final_state + else None + ) + qkv_fast = _channel_last_conv1d_layout(qkv) + conv_initial_fast = _channel_last_conv1d_layout(conv_initial) + if qkv_fast is not None and conv_initial_fast is not None: + out = causal_conv1d_fn( + x=qkv_fast, + weight=weight, + bias=bias, + initial_states=conv_initial_fast, + return_final_states=False, + activation=None, + ) + out = gdn.act_fn(out).to(dtype=qkv_dtype) + return out, final + + extended = torch.cat([conv_initial, qkv], dim=-1) + out = F.conv1d( + extended, weight.unsqueeze(1), bias, padding=0, groups=extended.shape[1] + ) + out = out[..., : qkv.shape[-1]] + out = gdn.act_fn(out).to(dtype=qkv_dtype) + final = ( + extended[..., -(weight.shape[1] - 1) :].to(dtype=qkv_dtype) + if output_final_state + else None + ) + return out, final + + +def _conv_final_from_dense_qkv( + qkv: Tensor, conv_initial: Tensor, kernel_width: int +) -> Tensor: + tail_width = int(kernel_width) - 1 + if tail_width <= 0: + return conv_initial[..., :0].to(dtype=qkv.dtype) + if int(qkv.shape[-1]) >= tail_width: + return qkv[..., -tail_width:].to(dtype=qkv.dtype) + initial_width = tail_width - int(qkv.shape[-1]) + return torch.cat([conv_initial[..., -initial_width:], qkv], dim=-1).to( + dtype=qkv.dtype + ) + + +def _channel_last_conv1d_layout(tensor: Tensor) -> Tensor | None: + if _causal_conv1d_layout_supported(tensor): + return tensor + channel_last = tensor.transpose(1, 2).contiguous().transpose(1, 2) + if _causal_conv1d_layout_supported(channel_last): + return channel_last + return None + + +def _causal_conv1d_layout_supported(tensor: Tensor) -> bool: + return ( + int(tensor.shape[-1]) >= 8 + and int(tensor.stride(1)) == 1 + and all(int(tensor.stride(dim)) % 8 == 0 for dim in (0, 2)) + ) + + +def _disable_reentrant_te_linear_transpose_cache(gdn: Any) -> None: + if getattr(gdn, "_art_reentrant_te_linear_transpose_cache_disabled", False): + return + for root in (getattr(gdn, "in_proj", None), getattr(gdn, "out_proj", None)): + if isinstance(root, torch.nn.Module): + linears = root.modules() + else: + linears = (root,) + for linear in linears: + if hasattr(linear, "disable_parameter_transpose_cache"): + linear.disable_parameter_transpose_cache = True + gdn._art_reentrant_te_linear_transpose_cache_disabled = True + + +def run_gdn_bucket( + bucket: GdnSegmentBucketPlan, + projected_streams: tuple[Tensor, Tensor, Tensor], + parent_states: tuple[Tensor, Tensor], + *, + gdn: Any, + output_final_state: bool = True, +) -> tuple[Tensor, Tensor | None, Tensor | None]: + _disable_reentrant_te_linear_transpose_cache(gdn) + qkv, beta, recurrent_g = projected_streams + conv_initial, recurrent_initial = parent_states + token_count = int(qkv.shape[0]) if qkv.ndim == 2 else -1 + segment_count = int(bucket.segment_count) + if qkv.ndim != 2: + raise ValueError( + "GDN bucket execution requires compact projected streams; " + f"got qkv shape {tuple(qkv.shape)}" + ) + if token_count != int(bucket.real_token_count): + raise ValueError( + "GDN packed varlen token count mismatch, got " + f"qkv={tuple(qkv.shape)} and bucket 
tokens={bucket.real_token_count}" + ) + if tuple(beta.shape[:1]) != (token_count,) or tuple(recurrent_g.shape) != tuple( + beta.shape + ): + raise ValueError( + "packed beta/recurrent_g must be [tokens, heads], got " + f"{tuple(beta.shape)} and {tuple(recurrent_g.shape)}" + ) + if int(conv_initial.shape[0]) != segment_count: + raise ValueError( + "conv_initial batch must match bucket segment count, got " + f"{tuple(conv_initial.shape)} for {segment_count} segments" + ) + if int(recurrent_initial.shape[0]) != segment_count: + raise ValueError( + "recurrent_initial batch must match bucket segment count, got " + f"{tuple(recurrent_initial.shape)} for {segment_count} segments" + ) + + with _nvtx_range("art_gdn_causal_conv_forward", qkv): + qkv, conv_final = _causal_conv1d_packed_varlen_with_state( + gdn, + qkv, + conv_initial, + bucket.cu_seqlens, + output_final_state=output_final_state, + ) + + with _nvtx_range("art_gdn_qkv_head_prepare", qkv): + query, key, value, beta, recurrent_g = _prepare_packed_recurrent_inputs_fused( + qkv, + beta, + recurrent_g, + key_heads=_local_key_heads(gdn), + value_heads=_local_value_heads(gdn), + key_dim=int(gdn.key_head_dim), + value_dim=int(gdn.value_head_dim), + ) + if gdn.use_qk_l2norm: + query = l2norm(query.contiguous()) + key = l2norm(key.contiguous()) + + with _nvtx_range("art_gdn_recurrent_forward", query): + recurrent_out, recurrent_final = _chunk_gated_delta_rule( + query, + key, + value, + g=recurrent_g, + beta=beta, + initial_state=recurrent_initial, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=False, + cu_seqlens=bucket.cu_seqlens, + ) + return recurrent_out, conv_final, recurrent_final + + +def _zero_conv_state( + gdn: Any, + hidden_states: Tensor, + row: int | None = None, + *, + batch_size: int = 1, +) -> Tensor: + del row + return hidden_states.new_zeros( + batch_size, + gdn.conv_dim_local_tp, + gdn.conv_kernel_dim - 1, + ) + + +def _zero_recurrent_state( + gdn: Any, + hidden_states: Tensor, + row: int | None = None, + *, + batch_size: int = 1, +) -> Tensor: + del row + return hidden_states.new_zeros( + batch_size, + gdn.num_v_heads_local_tp, + gdn.key_head_dim, + gdn.value_head_dim, + dtype=torch.float32, + ) + + +def _default_cp_rank(cp_size: int) -> int: + del cp_size + from megatron.core import parallel_state as ps + + return int(ps.get_context_parallel_rank()) + + +def _default_cp_size() -> int: + from megatron.core import parallel_state as ps + + return max(1, int(ps.get_context_parallel_world_size())) + + +def _default_cp_group(cp_size: int) -> Any: + del cp_size + from megatron.core import parallel_state as ps + + return ps.get_context_parallel_group() + + +def _l2norm(x: Tensor) -> Tensor: + return l2norm(x) + + +def _chunk_gated_delta_rule(*args: Any, **kwargs: Any) -> tuple[Tensor, Tensor | None]: + return chunk_gated_delta_rule(*args, **kwargs) + + +@contextmanager +def _nvtx_range(label: str, tensor: Tensor | None = None) -> Iterator[None]: + if _NVTX_ENABLED.get() and tensor is not None and tensor.is_cuda: + torch.cuda.nvtx.range_push(label) + try: + yield + finally: + torch.cuda.nvtx.range_pop() + return + yield + + +@contextmanager +def gdn_nvtx_ranges(enabled: bool = True) -> Iterator[None]: + token = _NVTX_ENABLED.set(bool(enabled)) + try: + yield + finally: + _NVTX_ENABLED.reset(token) diff --git a/src/art/megatron/gdn/segment_layout.py b/src/art/megatron/gdn/segment_layout.py new file mode 100644 index 000000000..0dc4bdfdf --- /dev/null +++ b/src/art/megatron/gdn/segment_layout.py @@ -0,0 +1,940 @@ 
+from __future__ import annotations + +from typing import Any + +import torch +from torch import Tensor +import triton +import triton.language as tl + + +@triton.jit(do_not_specialize=["segment_count"]) +def _segment_from_cu(cu_seqlens, n, segment_count): + lo = n * 0 + hi = lo + segment_count + for _ in tl.static_range(0, 16): + mid = (lo + hi) // 2 + start = tl.load(cu_seqlens + mid) + take_upper = start <= n + lo = tl.where(take_upper, mid, lo) + hi = tl.where(take_upper, hi, mid) + return lo, n - tl.load(cu_seqlens + lo) + + +@triton.jit(do_not_specialize=["token_count", "segment_count", "sequence_length"]) +def _gather_compact_qkv_kernel( + qkv_flat, + row_indices, + position_indices, + cu_seqlens, + out, + token_count, + segment_count, + sequence_length, + channels: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + d = tl.program_id(1) * BLOCK_D + tl.arange(0, BLOCK_D) + token_mask = n < token_count + segment, offset = _segment_from_cu(cu_seqlens, n, segment_count) + p = offset * segment_count + segment + row = tl.load(row_indices + p, mask=token_mask, other=0) + pos = tl.load(position_indices + p, mask=token_mask, other=0) + src = row * sequence_length + pos + n64 = n.to(tl.int64) + d64 = d.to(tl.int64) + src64 = src.to(tl.int64) + mask = token_mask[:, None] & (d[None, :] < channels) + values = tl.load( + qkv_flat + src64[:, None] * channels + d64[None, :], + mask=mask, + other=0.0, + ) + tl.store(out + n64[:, None] * channels + d64[None, :], values, mask=mask) + + +@triton.jit(do_not_specialize=["token_count", "segment_count", "sequence_length"]) +def _gather_compact_aux_kernel( + x_flat, + row_indices, + position_indices, + cu_seqlens, + out, + token_count, + segment_count, + sequence_length, + width: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + d = tl.program_id(1) * BLOCK_D + tl.arange(0, BLOCK_D) + token_mask = n < token_count + segment, offset = _segment_from_cu(cu_seqlens, n, segment_count) + p = offset * segment_count + segment + row = tl.load(row_indices + p, mask=token_mask, other=0) + pos = tl.load(position_indices + p, mask=token_mask, other=0) + src = row * sequence_length + pos + n64 = n.to(tl.int64) + d64 = d.to(tl.int64) + src64 = src.to(tl.int64) + mask = token_mask[:, None] & (d[None, :] < width) + values = tl.load( + x_flat + src64[:, None] * width + d64[None, :], + mask=mask, + other=0.0, + ) + tl.store(out + n64[:, None] * width + d64[None, :], values, mask=mask) + + +@triton.jit(do_not_specialize=["token_count", "segment_count", "sequence_length"]) +def _scatter_compact_qkv_grad_kernel( + grad_out, + row_indices, + position_indices, + cu_seqlens, + grad_flat, + token_count, + segment_count, + sequence_length, + channels: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + d = tl.program_id(1) * BLOCK_D + tl.arange(0, BLOCK_D) + token_mask = n < token_count + segment, offset = _segment_from_cu(cu_seqlens, n, segment_count) + p = offset * segment_count + segment + row = tl.load(row_indices + p, mask=token_mask, other=0) + pos = tl.load(position_indices + p, mask=token_mask, other=0) + dst = row * sequence_length + pos + n64 = n.to(tl.int64) + d64 = d.to(tl.int64) + dst64 = dst.to(tl.int64) + mask = token_mask[:, None] & (d[None, :] < channels) + values = tl.load( + grad_out + n64[:, None] * channels + d64[None, :], + mask=mask, + 
other=0.0, + ) + tl.atomic_add( + grad_flat + dst64[:, None] * channels + d64[None, :], + values, + sem="relaxed", + mask=mask, + ) + + +@triton.jit(do_not_specialize=["token_count", "segment_count", "sequence_length"]) +def _scatter_compact_aux_grad_kernel( + grad_out, + row_indices, + position_indices, + cu_seqlens, + grad_flat, + token_count, + segment_count, + sequence_length, + width: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + d = tl.program_id(1) * BLOCK_D + tl.arange(0, BLOCK_D) + token_mask = n < token_count + segment, offset = _segment_from_cu(cu_seqlens, n, segment_count) + p = offset * segment_count + segment + row = tl.load(row_indices + p, mask=token_mask, other=0) + pos = tl.load(position_indices + p, mask=token_mask, other=0) + dst = row * sequence_length + pos + n64 = n.to(tl.int64) + d64 = d.to(tl.int64) + dst64 = dst.to(tl.int64) + mask = token_mask[:, None] & (d[None, :] < width) + values = tl.load( + grad_out + n64[:, None] * width + d64[None, :], + mask=mask, + other=0.0, + ) + tl.atomic_add( + grad_flat + dst64[:, None] * width + d64[None, :], + values, + sem="relaxed", + mask=mask, + ) + + +@triton.jit(do_not_specialize=["token_count"]) +def _prepare_packed_qkv_kernel( + qkv, + query, + key, + value, + token_count, + channels: tl.constexpr, + key_heads: tl.constexpr, + value_heads: tl.constexpr, + key_dim: tl.constexpr, + value_dim: tl.constexpr, + repeat: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + vh = tl.program_id(1) + kind = tl.program_id(2) + d = tl.arange(0, BLOCK_D) + token_mask = n < token_count + n64 = n.to(tl.int64) + d64 = d.to(tl.int64) + kh = vh // repeat + if kind == 0: + mask = d < key_dim + channel = kh * key_dim + d + channel64 = channel.to(tl.int64) + values = tl.load( + qkv + n64[:, None] * channels + channel64[None, :], + mask=token_mask[:, None] & mask[None, :], + other=0.0, + ) + tl.store( + query + (n64[:, None] * value_heads + vh) * key_dim + d64[None, :], + values, + mask=token_mask[:, None] & mask[None, :], + ) + elif kind == 1: + mask = d < key_dim + base = key_heads * key_dim + channel = base + kh * key_dim + d + channel64 = channel.to(tl.int64) + values = tl.load( + qkv + n64[:, None] * channels + channel64[None, :], + mask=token_mask[:, None] & mask[None, :], + other=0.0, + ) + tl.store( + key + (n64[:, None] * value_heads + vh) * key_dim + d64[None, :], + values, + mask=token_mask[:, None] & mask[None, :], + ) + else: + mask = d < value_dim + base = 2 * key_heads * key_dim + channel = base + vh * value_dim + d + channel64 = channel.to(tl.int64) + values = tl.load( + qkv + n64[:, None] * channels + channel64[None, :], + mask=token_mask[:, None] & mask[None, :], + other=0.0, + ) + tl.store( + value + (n64[:, None] * value_heads + vh) * value_dim + d64[None, :], + values, + mask=token_mask[:, None] & mask[None, :], + ) + + +@triton.jit(do_not_specialize=["token_count"]) +def _prepare_packed_qkv_backward_kernel( + grad_query, + grad_key, + grad_value, + grad_qkv, + token_count, + channels: tl.constexpr, + key_heads: tl.constexpr, + value_heads: tl.constexpr, + key_dim: tl.constexpr, + value_dim: tl.constexpr, + repeat: tl.constexpr, + HAS_QUERY: tl.constexpr, + HAS_KEY: tl.constexpr, + HAS_VALUE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + c = tl.program_id(1) * BLOCK_C + tl.arange(0, BLOCK_C) 
+ q_channels: tl.constexpr = key_heads * key_dim + k_channels: tl.constexpr = q_channels + v_base: tl.constexpr = q_channels + k_channels + n64 = n.to(tl.int64) + c64 = c.to(tl.int64) + mask = (n[:, None] < token_count) & (c[None, :] < channels) + is_query = c < q_channels + is_key = (c >= q_channels) & (c < v_base) + is_value = c >= v_base + values = tl.zeros((BLOCK_N, BLOCK_C), dtype=tl.float32) + + if HAS_QUERY: + q_kh = c // key_dim + q_d = c - q_kh * key_dim + q_mask = mask & is_query[None, :] + q_values = tl.zeros((BLOCK_N, BLOCK_C), dtype=tl.float32) + for r in tl.static_range(0, repeat): + vh = q_kh * repeat + r + q_values += tl.load( + grad_query + + (n64[:, None] * value_heads + vh[None, :].to(tl.int64)) * key_dim + + q_d[None, :], + mask=q_mask, + other=0.0, + ) + values = tl.where(q_mask, q_values, values) + + if HAS_KEY: + k_channel = c - q_channels + k_kh = k_channel // key_dim + k_d = k_channel - k_kh * key_dim + k_mask = mask & is_key[None, :] + k_values = tl.zeros((BLOCK_N, BLOCK_C), dtype=tl.float32) + for r in tl.static_range(0, repeat): + vh = k_kh * repeat + r + k_values += tl.load( + grad_key + + (n64[:, None] * value_heads + vh[None, :].to(tl.int64)) * key_dim + + k_d[None, :], + mask=k_mask, + other=0.0, + ) + values = tl.where(k_mask, k_values, values) + + if HAS_VALUE: + v_channel = c - v_base + vh = v_channel // value_dim + v_d = v_channel - vh * value_dim + v_mask = mask & is_value[None, :] + v_values = tl.load( + grad_value + + (n64[:, None] * value_heads + vh[None, :].to(tl.int64)) * value_dim + + v_d[None, :], + mask=v_mask, + other=0.0, + ) + values = tl.where(v_mask, v_values, values) + + tl.store(grad_qkv + n64[:, None] * channels + c64[None, :], values, mask=mask) + + +@triton.jit( + do_not_specialize=["token_count", "segment_count", "output_sequence_length"] +) +def _scatter_bucket_output_compact_forward_kernel( + output, + bucket_output, + row_indices, + position_indices, + output_mask, + cu_seqlens, + token_count, + segment_count, + output_sequence_length, + heads: tl.constexpr, + dim: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + hd = tl.program_id(1) * BLOCK_D + tl.arange(0, BLOCK_D) + segment, offset = _segment_from_cu(cu_seqlens, n, segment_count) + p = offset * segment_count + segment + token_mask = n < token_count + write = tl.load(output_mask + p, mask=token_mask, other=0).to(tl.int1) + row = tl.load(row_indices + p, mask=token_mask, other=0) + pos = tl.load(position_indices + p, mask=token_mask, other=0) + h = hd // dim + d = hd - h * dim + n64 = n.to(tl.int64) + row64 = row.to(tl.int64) + pos64 = pos.to(tl.int64) + h64 = h.to(tl.int64) + d64 = d.to(tl.int64) + mask = token_mask[:, None] & (hd[None, :] < heads * dim) & write[:, None] + values = tl.load( + bucket_output + (n64[:, None] * heads + h64[None, :]) * dim + d64[None, :], + mask=mask, + other=0.0, + ) + tl.store( + output + + ((row64[:, None] * output_sequence_length + pos64[:, None]) * heads + h64) + * dim + + d64, + values, + mask=mask, + ) + + +@triton.jit( + do_not_specialize=["token_count", "segment_count", "output_sequence_length"] +) +def _scatter_bucket_output_compact_backward_kernel( + grad_output, + grad_base, + grad_bucket_output, + row_indices, + position_indices, + output_mask, + cu_seqlens, + token_count, + segment_count, + output_sequence_length, + heads: tl.constexpr, + dim: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + n = tl.program_id(0) * BLOCK_N + tl.arange(0, 
BLOCK_N) + hd = tl.program_id(1) * BLOCK_D + tl.arange(0, BLOCK_D) + segment, offset = _segment_from_cu(cu_seqlens, n, segment_count) + p = offset * segment_count + segment + token_mask = n < token_count + write = tl.load(output_mask + p, mask=token_mask, other=0).to(tl.int1) + row = tl.load(row_indices + p, mask=token_mask, other=0) + pos = tl.load(position_indices + p, mask=token_mask, other=0) + h = hd // dim + d = hd - h * dim + n64 = n.to(tl.int64) + row64 = row.to(tl.int64) + pos64 = pos.to(tl.int64) + h64 = h.to(tl.int64) + d64 = d.to(tl.int64) + mask = token_mask[:, None] & (hd[None, :] < heads * dim) & write[:, None] + output_offset = ( + (row64[:, None] * output_sequence_length + pos64[:, None]) * heads + h64 + ) * dim + d64 + values = tl.load(grad_output + output_offset, mask=mask, other=0.0) + tl.store( + grad_bucket_output + (n64[:, None] * heads + h64[None, :]) * dim + d64[None, :], + values, + mask=mask, + ) + tl.store( + grad_base + output_offset, + tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32), + mask=mask, + ) + + +class _CompactBucketStreamGather(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + qkv_flat: Tensor, + beta_flat: Tensor, + recurrent_g_flat: Tensor, + row_indices: Tensor, + position_indices: Tensor, + cu_seqlens: Tensor, + token_count: int, + segment_count: int, + sequence_length: int, + ) -> tuple[Tensor, Tensor, Tensor]: + _validate_cuda("qkv_flat", qkv_flat) + qkv_flat = qkv_flat.contiguous() + beta_flat = beta_flat.contiguous() + recurrent_g_flat = recurrent_g_flat.contiguous() + row_indices = row_indices.contiguous() + position_indices = position_indices.contiguous() + cu_seqlens = cu_seqlens.contiguous() + token_count = int(token_count) + segment_count = int(segment_count) + sequence_length = int(sequence_length) + qkv_channels = int(qkv_flat.shape[-1]) + aux_width = int(beta_flat.shape[-1]) + qkv = torch.empty( + (token_count, qkv_channels), + device=qkv_flat.device, + dtype=qkv_flat.dtype, + ) + beta = torch.empty( + (token_count, aux_width), device=beta_flat.device, dtype=beta_flat.dtype + ) + recurrent_g = torch.empty( + (token_count, aux_width), + device=recurrent_g_flat.device, + dtype=recurrent_g_flat.dtype, + ) + block_n, block_qkv, block_aux = 32, 64, 32 + _gather_compact_qkv_kernel[ + (triton.cdiv(token_count, block_n), triton.cdiv(qkv_channels, block_qkv)) + ]( + qkv_flat, + row_indices, + position_indices, + cu_seqlens, + qkv, + token_count, + segment_count, + sequence_length, + qkv_channels, + BLOCK_N=block_n, + BLOCK_D=block_qkv, + num_warps=4, + ) + grid_aux = ( + triton.cdiv(token_count, block_n), + triton.cdiv(aux_width, block_aux), + ) + _gather_compact_aux_kernel[grid_aux]( + beta_flat, + row_indices, + position_indices, + cu_seqlens, + beta, + token_count, + segment_count, + sequence_length, + aux_width, + BLOCK_N=block_n, + BLOCK_D=block_aux, + num_warps=4, + ) + _gather_compact_aux_kernel[grid_aux]( + recurrent_g_flat, + row_indices, + position_indices, + cu_seqlens, + recurrent_g, + token_count, + segment_count, + sequence_length, + aux_width, + BLOCK_N=block_n, + BLOCK_D=block_aux, + num_warps=4, + ) + ctx.save_for_backward(row_indices, position_indices, cu_seqlens) + ctx.token_count = token_count + ctx.segment_count = segment_count + ctx.sequence_length = sequence_length + ctx.qkv_flat_count = int(qkv_flat.shape[0]) + ctx.beta_flat_count = int(beta_flat.shape[0]) + ctx.recurrent_g_flat_count = int(recurrent_g_flat.shape[0]) + ctx.qkv_channels = qkv_channels + ctx.aux_width = aux_width + return qkv, beta, 
recurrent_g + + @staticmethod + def backward( + ctx: Any, *grad_outputs: Tensor | None + ) -> tuple[ + Tensor | None, + Tensor | None, + Tensor | None, + None, + None, + None, + None, + None, + None, + ]: + grad_qkv_bucket, grad_beta_bucket, grad_g_bucket = grad_outputs + row_indices, position_indices, cu_seqlens = ctx.saved_tensors + block_n, block_qkv, block_aux = 32, 64, 32 + grad_qkv = None + if ctx.needs_input_grad[0] and grad_qkv_bucket is not None: + grad_qkv_bucket = grad_qkv_bucket.contiguous() + grad_qkv = grad_qkv_bucket.new_zeros(ctx.qkv_flat_count, ctx.qkv_channels) + _scatter_compact_qkv_grad_kernel[ + ( + triton.cdiv(ctx.token_count, block_n), + triton.cdiv(ctx.qkv_channels, block_qkv), + ) + ]( + grad_qkv_bucket, + row_indices, + position_indices, + cu_seqlens, + grad_qkv, + ctx.token_count, + ctx.segment_count, + ctx.sequence_length, + ctx.qkv_channels, + BLOCK_N=block_n, + BLOCK_D=block_qkv, + num_warps=4, + ) + grad_beta = None + if ctx.needs_input_grad[1] and grad_beta_bucket is not None: + grad_beta_bucket = grad_beta_bucket.contiguous() + grad_beta = grad_beta_bucket.new_zeros(ctx.beta_flat_count, ctx.aux_width) + _scatter_compact_aux_grad_kernel[ + ( + triton.cdiv(ctx.token_count, block_n), + triton.cdiv(ctx.aux_width, block_aux), + ) + ]( + grad_beta_bucket, + row_indices, + position_indices, + cu_seqlens, + grad_beta, + ctx.token_count, + ctx.segment_count, + ctx.sequence_length, + ctx.aux_width, + BLOCK_N=block_n, + BLOCK_D=block_aux, + num_warps=4, + ) + grad_g = None + if ctx.needs_input_grad[2] and grad_g_bucket is not None: + grad_g_bucket = grad_g_bucket.contiguous() + grad_g = grad_g_bucket.new_zeros(ctx.recurrent_g_flat_count, ctx.aux_width) + _scatter_compact_aux_grad_kernel[ + ( + triton.cdiv(ctx.token_count, block_n), + triton.cdiv(ctx.aux_width, block_aux), + ) + ]( + grad_g_bucket, + row_indices, + position_indices, + cu_seqlens, + grad_g, + ctx.token_count, + ctx.segment_count, + ctx.sequence_length, + ctx.aux_width, + BLOCK_N=block_n, + BLOCK_D=block_aux, + num_warps=4, + ) + return grad_qkv, grad_beta, grad_g, None, None, None, None, None, None + + +class _PreparePackedRecurrentInputs(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + key_heads: int, + value_heads: int, + key_dim: int, + value_dim: int, + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + _validate_cuda("qkv", qkv) + qkv = qkv.contiguous() + beta = beta.contiguous() + recurrent_g = recurrent_g.contiguous() + token_count, channels = qkv.shape + key_heads = int(key_heads) + value_heads = int(value_heads) + key_dim = int(key_dim) + value_dim = int(value_dim) + if value_heads % key_heads != 0: + raise ValueError( + f"value_heads must be divisible by key_heads, got {value_heads} and {key_heads}" + ) + expected_channels = 2 * key_heads * key_dim + value_heads * value_dim + if int(channels) != expected_channels: + raise ValueError( + "packed qkv channel count mismatch, got " + f"{channels} and expected {expected_channels}" + ) + if tuple(beta.shape) != (token_count, value_heads): + raise ValueError( + f"beta must be [tokens, value_heads], got {tuple(beta.shape)}" + ) + if tuple(recurrent_g.shape) != tuple(beta.shape): + raise ValueError( + "recurrent_g shape must match beta, got " + f"{tuple(recurrent_g.shape)} and {tuple(beta.shape)}" + ) + repeat = value_heads // key_heads + query = torch.empty( + (1, token_count, value_heads, key_dim), device=qkv.device, dtype=qkv.dtype + ) + key = torch.empty_like(query) + 
value = torch.empty( + (1, token_count, value_heads, value_dim), device=qkv.device, dtype=qkv.dtype + ) + block_n = 16 + block_d = triton.next_power_of_2(max(key_dim, value_dim)) + if block_d > 128: + raise ValueError( + f"unsupported GDN head dimension {block_d}; expected <= 128" + ) + _prepare_packed_qkv_kernel[(triton.cdiv(token_count, block_n), value_heads, 3)]( + qkv, + query, + key, + value, + token_count, + channels, + key_heads, + value_heads, + key_dim, + value_dim, + repeat, + BLOCK_N=block_n, + BLOCK_D=block_d, + num_warps=1, + ) + ctx.input_shape = tuple(qkv.shape) + ctx.beta_shape = tuple(beta.shape) + ctx.input_dtype = qkv.dtype + ctx.beta_dtype = beta.dtype + ctx.g_dtype = recurrent_g.dtype + ctx.device = qkv.device + ctx.key_heads = key_heads + ctx.value_heads = value_heads + ctx.key_dim = key_dim + ctx.value_dim = value_dim + ctx.repeat = repeat + return query, key, value, beta.unsqueeze(0), recurrent_g.unsqueeze(0) + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: Any, + ) -> tuple[ + Tensor | None, + Tensor | None, + Tensor | None, + None, + None, + None, + None, + ]: + grad_query, grad_key, grad_value, grad_beta_out, grad_g_out = grad_outputs + token_count, channels = ctx.input_shape + grad_qkv = None + device = None + if grad_query is not None: + device = grad_query.device + elif grad_key is not None: + device = grad_key.device + elif grad_value is not None: + device = grad_value.device + elif grad_beta_out is not None: + device = grad_beta_out.device + elif grad_g_out is not None: + device = grad_g_out.device + if ctx.needs_input_grad[0]: + if device is None: + raise RuntimeError("missing device for packed qkv gradient") + grad_qkv = torch.empty( + (token_count, channels), device=device, dtype=ctx.input_dtype + ) + grad_query_arg = ( + grad_query.contiguous() if grad_query is not None else grad_qkv + ) + grad_key_arg = grad_key.contiguous() if grad_key is not None else grad_qkv + grad_value_arg = ( + grad_value.contiguous() if grad_value is not None else grad_qkv + ) + block_n, block_c = 16, 64 + _prepare_packed_qkv_backward_kernel[ + (triton.cdiv(token_count, block_n), triton.cdiv(channels, block_c)) + ]( + grad_query_arg, + grad_key_arg, + grad_value_arg, + grad_qkv, + token_count, + channels, + ctx.key_heads, + ctx.value_heads, + ctx.key_dim, + ctx.value_dim, + ctx.repeat, + HAS_QUERY=grad_query is not None, + HAS_KEY=grad_key is not None, + HAS_VALUE=grad_value is not None, + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=4, + ) + grad_beta = None + if ctx.needs_input_grad[1]: + grad_beta = ( + grad_beta_out.reshape(ctx.beta_shape).contiguous() + if grad_beta_out is not None + else torch.zeros( + ctx.beta_shape, device=ctx.device, dtype=ctx.beta_dtype + ) + ) + grad_g = None + if ctx.needs_input_grad[2]: + grad_g = ( + grad_g_out.reshape(ctx.beta_shape).contiguous() + if grad_g_out is not None + else torch.zeros(ctx.beta_shape, device=ctx.device, dtype=ctx.g_dtype) + ) + return grad_qkv, grad_beta, grad_g, None, None, None, None + + +class _CompactScatterBucketOutput(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + output: Tensor, + bucket_output: Tensor, + row_indices: Tensor, + position_indices: Tensor, + output_mask: Tensor, + cu_seqlens: Tensor, + ) -> Tensor: + _validate_cuda("output", output) + output = output.contiguous() + bucket_output = bucket_output.contiguous() + row_indices = row_indices.contiguous() + position_indices = position_indices.contiguous() + output_mask = output_mask.contiguous() + cu_seqlens = 
cu_seqlens.contiguous() + if bucket_output.ndim != 4 or int(bucket_output.shape[0]) != 1: + raise ValueError( + "bucket_output must have shape [1, tokens, heads, dim], got " + f"{tuple(bucket_output.shape)}" + ) + output_batch, output_sequence_length, heads, dim = output.shape + del output_batch + token_count = int(bucket_output.shape[1]) + segment_count = int(cu_seqlens.numel()) - 1 + if tuple(row_indices.shape) != tuple(position_indices.shape): + raise ValueError( + "row_indices and position_indices must have the same shape, got " + f"{tuple(row_indices.shape)} and {tuple(position_indices.shape)}" + ) + if tuple(output_mask.shape) != tuple(row_indices.shape): + raise ValueError( + "output_mask must match row_indices shape, got " + f"{tuple(output_mask.shape)} and {tuple(row_indices.shape)}" + ) + out = output.clone() + block_n, block_d = 16, 64 + _scatter_bucket_output_compact_forward_kernel[ + (triton.cdiv(token_count, block_n), triton.cdiv(heads * dim, block_d)) + ]( + out, + bucket_output, + row_indices, + position_indices, + output_mask, + cu_seqlens, + token_count, + segment_count, + output_sequence_length, + heads, + dim, + BLOCK_N=block_n, + BLOCK_D=block_d, + num_warps=4, + ) + ctx.save_for_backward(row_indices, position_indices, output_mask, cu_seqlens) + ctx.output_shape = tuple(output.shape) + ctx.bucket_output_shape = tuple(bucket_output.shape) + ctx.token_count = token_count + ctx.segment_count = segment_count + return out + + @staticmethod + def backward( + ctx: Any, *grad_outputs: Any + ) -> tuple[Tensor, Tensor, None, None, None, None]: + (grad_out,) = grad_outputs + row_indices, position_indices, output_mask, cu_seqlens = ctx.saved_tensors + _, output_sequence_length, heads, dim = ctx.output_shape + grad_out = grad_out.contiguous() + grad_base = grad_out.clone() + grad_bucket = grad_out.new_zeros(ctx.bucket_output_shape) + block_n, block_d = 16, 64 + _scatter_bucket_output_compact_backward_kernel[ + ( + triton.cdiv(ctx.token_count, block_n), + triton.cdiv(heads * dim, block_d), + ) + ]( + grad_out, + grad_base, + grad_bucket, + row_indices, + position_indices, + output_mask, + cu_seqlens, + ctx.token_count, + ctx.segment_count, + output_sequence_length, + heads, + dim, + BLOCK_N=block_n, + BLOCK_D=block_d, + num_warps=4, + ) + return grad_base, grad_bucket, None, None, None, None + + +def gather_bucket_streams_compact( + qkv_flat: Tensor, + beta_flat: Tensor, + recurrent_g_flat: Tensor, + row_indices: Tensor, + position_indices: Tensor, + cu_seqlens: Tensor, + *, + token_count: int, + segment_count: int, + sequence_length: int, +) -> tuple[Tensor, Tensor, Tensor]: + return _CompactBucketStreamGather.apply( + qkv_flat, + beta_flat, + recurrent_g_flat, + row_indices, + position_indices, + cu_seqlens, + token_count, + segment_count, + sequence_length, + ) + + +def scatter_bucket_output_compact( + output: Tensor, + bucket_output: Tensor, + row_indices: Tensor, + position_indices: Tensor, + output_mask: Tensor, + cu_seqlens: Tensor, +) -> Tensor: + return _CompactScatterBucketOutput.apply( + output, + bucket_output, + row_indices, + position_indices, + output_mask, + cu_seqlens, + ) + + +def prepare_packed_recurrent_inputs( + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + *, + key_heads: int, + value_heads: int, + key_dim: int, + value_dim: int, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + return _PreparePackedRecurrentInputs.apply( + qkv, + beta, + recurrent_g, + key_heads, + value_heads, + key_dim, + value_dim, + ) + + +def _validate_cuda(name: str, tensor: 
Tensor) -> None: + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") diff --git a/src/art/megatron/kernels/__init__.py b/src/art/megatron/kernels/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/art/megatron/kernels/__init__.py @@ -0,0 +1 @@ + diff --git a/src/art/megatron/cute_grouped_lora_quack.py b/src/art/megatron/kernels/cute_grouped_lora_quack.py similarity index 100% rename from src/art/megatron/cute_grouped_lora_quack.py rename to src/art/megatron/kernels/cute_grouped_lora_quack.py diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 9fba022f2..c2a28bffa 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -13,6 +13,7 @@ ) from megatron.core.ssm.gated_delta_net import GatedDeltaNet from megatron.core.tensor_parallel.mappings import ( + gather_from_sequence_parallel_region, reduce_from_tensor_model_parallel_region, reduce_scatter_to_sequence_parallel_region, ) @@ -23,7 +24,10 @@ from pydantic import BaseModel, ConfigDict import torch -from .cute_grouped_lora_quack import quack_grouped_lora, quack_grouped_lora_dual +from .kernels.cute_grouped_lora_quack import ( + quack_grouped_lora, + quack_grouped_lora_dual, +) LORA_RANK = 1 LORA_ALPHA = 32 @@ -98,14 +102,48 @@ def _normalize_axis(axis: int, ndim: int) -> int: return axis +def _shard_weight_by_components( + weight: torch.Tensor, + *, + axis: int, + component_sizes: Sequence[int], + world_size: int, + rank: int, +) -> torch.Tensor: + if sum(component_sizes) != weight.shape[axis]: + raise ValueError( + f"Component sizes {tuple(component_sizes)} do not match axis {axis} " + f"extent {weight.shape[axis]}" + ) + local_components: list[torch.Tensor] = [] + for component in torch.split(weight, list(component_sizes), dim=axis): + if component.shape[axis] % world_size != 0: + raise ValueError( + f"Component shape {tuple(component.shape)} is not divisible by " + f"world size {world_size} on axis {axis}" + ) + local_size = component.shape[axis] // world_size + local_components.append(component.narrow(axis, rank * local_size, local_size)) + return torch.cat(local_components, dim=axis).contiguous() + + def _linear_disables_tensor_parallel_comm(linear: Any) -> bool: - # Shared experts can keep TP-sharded weights while deferring TP comm to the - # overlap path by setting parallel_mode=None / explicit_expert_comm=True. 
return getattr(linear, "parallel_mode", "") is None or getattr( linear, "explicit_expert_comm", False ) +def _column_parallel_lora_input(x: torch.Tensor, linear: Any) -> torch.Tensor: + if _linear_disables_tensor_parallel_comm(linear): + return x + if ( + bool(getattr(linear, "sequence_parallel", False)) + and int(getattr(linear, "tp_size", 1)) > 1 + ): + return gather_from_sequence_parallel_region(x) + return x + + def _set_lora_parallel_metadata( param: torch.nn.Parameter, *, @@ -153,6 +191,37 @@ def _set_lora_parallel_metadata( setattr(param, "partition_stride", 1) +def _set_lora_shard_strategy_metadata( + param: torch.nn.Parameter, + *, + strategy: str, + component_sizes: Sequence[int] | None = None, +) -> None: + setattr(param, "lora_tp_shard_strategy", strategy) + if component_sizes is not None: + setattr( + param, + "lora_tp_component_sizes", + tuple(int(size) for size in component_sizes), + ) + + +def _exported_shard_dim(param: torch.nn.Parameter) -> int: + axis = _normalize_axis(param.lora_tp_shard_dim, param.ndim) # ty: ignore[unresolved-attribute] + # LoRA exports always serialize a 2D tensor: + # - non-expert params export `param.T` + # - expert params export `param[expert].T` + if param.ndim == 3: + if axis == 0: + raise ValueError("LoRA expert shard_dim cannot reference the expert axis") + axis -= 1 + if axis not in (0, 1): + raise ValueError( + f"Unsupported exported LoRA shard axis {axis} for ndim={param.ndim}" + ) + return 1 - axis + + class LoRA(torch.nn.Module): def __init__( self, @@ -284,22 +353,48 @@ def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None axis = _normalize_axis(axis, weight.ndim) world_size = _get_shard_world_size(domain) rank = _get_shard_rank(domain) - if weight.shape[axis] % world_size != 0: - raise ValueError( - f"{self.adapter_model_prefix}: weight shape {tuple(weight.shape)} is not divisible by world size " - f"{world_size} on axis {axis}" + strategy = getattr(into, "lora_tp_shard_strategy", "uniform") + if strategy == "componentwise": + component_sizes = tuple( + int(size) for size in getattr(into, "lora_tp_component_sizes", ()) ) - local_size = weight.shape[axis] // world_size - if into.shape[axis] != local_size: + if not component_sizes: + raise ValueError( + f"{self.adapter_model_prefix}: missing component sizes for shard strategy={strategy}" + ) + weight = _shard_weight_by_components( + weight, + axis=axis, + component_sizes=component_sizes, + world_size=world_size, + rank=rank, + ) + elif strategy == "uniform": + if weight.shape[axis] % world_size != 0: + raise ValueError( + f"{self.adapter_model_prefix}: weight shape {tuple(weight.shape)} is not divisible by world size " + f"{world_size} on axis {axis}" + ) + local_size = weight.shape[axis] // world_size + if into.shape[axis] != local_size: + raise ValueError( + f"{self.adapter_model_prefix}: expected local shard size {into.shape[axis]}, got {local_size}" + ) + weight = weight.narrow(axis, rank * local_size, local_size) + else: raise ValueError( - f"{self.adapter_model_prefix}: expected local shard size {into.shape[axis]}, got {local_size}" + f"{self.adapter_model_prefix}: unsupported shard strategy={strategy}" ) - weight = weight.narrow(axis, rank * local_size, local_size) elif tuple(weight.shape) != tuple(into.shape): raise ValueError( f"{self.adapter_model_prefix}: unsharded load shape mismatch, got {tuple(weight.shape)} " f"expected {tuple(into.shape)}" ) + if tuple(weight.shape) != tuple(into.shape): + raise ValueError( + f"{self.adapter_model_prefix}: 
sharded load shape mismatch, got {tuple(weight.shape)} " + f"expected {tuple(into.shape)}" + ) into.data.copy_(weight) into.requires_grad = True @@ -323,7 +418,7 @@ def _should_export_parameter(self, param: torch.nn.Parameter) -> bool: return _get_shard_rank(param.lora_shard_domain) == 0 # ty: ignore[unresolved-attribute] def _manifest_for_param(self, param: torch.nn.Parameter) -> dict[str, Any]: - return { + manifest = { "domain": param.lora_shard_domain, # ty: ignore[unresolved-attribute] "sharded": param.lora_tp_sharded, # ty: ignore[unresolved-attribute] "shard_dim": param.lora_tp_shard_dim, # ty: ignore[unresolved-attribute] @@ -334,6 +429,17 @@ def _manifest_for_param(self, param: torch.nn.Parameter) -> dict[str, Any]: if param.lora_tp_sharded # ty: ignore[unresolved-attribute] else 0, } + if param.lora_tp_sharded: # ty: ignore[unresolved-attribute] + manifest["export_shard_dim"] = _exported_shard_dim(param) + manifest["export_shard_strategy"] = getattr( + param, + "lora_tp_shard_strategy", + "uniform", + ) + component_sizes = list(getattr(param, "lora_tp_component_sizes", ())) + if component_sizes: + manifest["component_sizes"] = component_sizes + return manifest def _lora_params(self) -> list[tuple[str, torch.nn.Parameter]]: return [ @@ -368,6 +474,22 @@ def sharded_lora_state_dict(self) -> dict[str, torch.Tensor]: state[key] = param.data[expert].T if expert is not None else param.data.T return state + def sharded_lora_grad_dict(self) -> dict[str, torch.Tensor]: + grads: dict[str, torch.Tensor] = {} + for key, param, expert in self._export_items(): + if not hasattr(param, "main_grad"): + raise RuntimeError( + f"LoRA param missing main_grad attribute for key '{key}'" + ) + grad = cast(torch.Tensor, param.main_grad) + if grad is None: + raise RuntimeError(f"LoRA param main_grad is None for key '{key}'") + if hasattr(grad, "_local_tensor"): + grad = cast(Any, grad)._local_tensor + local_grad = grad[expert] if expert is not None else grad + grads[key] = local_grad.T + return grads + def forward( self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor | None = None ) -> torch.Tensor: @@ -378,8 +500,7 @@ def forward( bsz = tokens_per_expert if isinstance(bsz, list): bsz = torch.tensor(bsz, dtype=torch.int64, device="cpu") - # If no tokens routed locally, return zeros. 
- if isinstance(bsz, torch.Tensor) and int(torch.count_nonzero(bsz)) == 0: + if x.shape[0] == 0: return x.new_zeros((x.shape[0], self.B_T.shape[-1])) return quack_grouped_lora(x, self.A_T, self.B_T, bsz, scale=self.scale) out = (x @ self.A_T) @ self.B_T @@ -466,8 +587,9 @@ def __init__( raise ValueError( "num_attention_heads must be divisible by num_query_groups for QKV LoRA" ) - linear_qkv_weight = cast(torch.Tensor, linear_qkv.weight) - total_out_features_per_rank = linear_qkv_weight.shape[0] + weight = linear_qkv.weight + assert isinstance(weight, torch.Tensor) + total_out_features_per_rank = int(weight.shape[0]) kv_out_features = self.provider.kv_channels * self.provider.num_query_groups tp_world_size = ps.get_tensor_model_parallel_world_size() assert kv_out_features % tp_world_size == 0, ( @@ -498,7 +620,6 @@ def __init__( self.provider.num_attention_heads // self.provider.num_query_groups ) self.hidden_size_per_attention_head = self.provider.kv_channels - assert isinstance(linear_qkv.weight, torch.Tensor) self.q_proj_lora = self._build_qkv_lora( adapter_model_prefix=f"{adapter_model_prefix}.q_proj", linear_qkv=linear_qkv, @@ -572,8 +693,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: query_and_gate = self.q_proj_lora(layernorm_output) key = self.k_proj_lora(layernorm_output) value = self.v_proj_lora(layernorm_output) - # Qwen3 packs [query, key, value] per group, while Qwen3.5 packs - # [query, gate, key, value]. query_and_gate_5d = query_and_gate.reshape( query_and_gate.shape[0], query_and_gate.shape[1], @@ -632,6 +751,15 @@ def __init__( alpha=alpha, out_features=qkv_out_features_per_partition, ) + _set_lora_shard_strategy_metadata( + self.qkv_lora.B_T, + strategy="componentwise", + component_sizes=( + gated_delta_net.qk_dim, + gated_delta_net.qk_dim, + gated_delta_net.v_dim, + ), + ) self.z_lora = self._build_in_proj_lora( adapter_model_prefix=f"{adapter_model_prefix}.in_proj_z", in_proj=in_proj, @@ -678,10 +806,7 @@ def _build_in_proj_lora( ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: - ( - linear_output_and_layernorm_output, - bias, - ) = self.in_proj(x) + linear_output_and_layernorm_output, bias = self.in_proj(x) linear_output, layernorm_output = linear_output_and_layernorm_output assert isinstance(linear_output, torch.Tensor) assert isinstance(layernorm_output, torch.Tensor) @@ -775,7 +900,7 @@ def forward( counts = tokens_per_expert if isinstance(counts, list): counts = torch.tensor(counts, dtype=torch.int64, device="cpu") - if isinstance(counts, torch.Tensor) and int(torch.count_nonzero(counts)) == 0: + if x.shape[0] == 0: adapter_out = x.new_zeros((x.shape[0], self.linear_fc1.out_features)) else: adapter_out = quack_grouped_lora_dual( @@ -791,71 +916,53 @@ def forward( return base_out + adapter_out, bias_out -class SharedExpertsLinearFC1LoRA(torch.nn.Module): +class MLPExpertsLinearFC1FusedLoRA(torch.nn.Module): def __init__( self, adapter_model_prefix: str, - linear_fc1: TEColumnParallelLinear | TELayerNormColumnParallelLinear, + linear_fc1: TEColumnParallelGroupedLinear, rank: int, alpha: float, + num_local_experts: int, ) -> None: super().__init__() + assert linear_fc1 is not None + assert isinstance(linear_fc1.weight0, torch.Tensor) self.linear_fc1 = linear_fc1 - self.gate_lora = self._build_fc1_lora( - adapter_model_prefix=f"{adapter_model_prefix}.gate_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - ) - self.up_lora = self._build_fc1_lora( - 
adapter_model_prefix=f"{adapter_model_prefix}.up_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - ) - - @staticmethod - def _build_fc1_lora( - *, - adapter_model_prefix: str, - linear_fc1: TEColumnParallelLinear | TELayerNormColumnParallelLinear, - rank: int, - alpha: float, - ) -> LoRA: - assert isinstance(linear_fc1.weight, torch.Tensor) a_parallel_spec = LoRAParallelSpec( - shard_domain="tp", + shard_domain="expert_tp", sharded=False, shard_dim=None, - grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, + grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, grad_sync_op=GRAD_SYNC_OP_SUM, ) b_parallel_spec = a_parallel_spec.model_copy( update={ "sharded": True, "shard_dim": -1, + "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, "grad_sync_op": GRAD_SYNC_OP_NONE, } ) - return LoRA( - adapter_model_prefix=adapter_model_prefix, + self.lora = LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_up_proj", in_features=linear_fc1.in_features, - out_features=linear_fc1.out_features // 2, + out_features=linear_fc1.out_features, rank=rank, alpha=alpha, - dtype=linear_fc1.weight.dtype, - device=linear_fc1.weight.device, + dtype=linear_fc1.weight0.dtype, + device=linear_fc1.weight0.device, + num_local_experts=num_local_experts, a_parallel_spec=a_parallel_spec, b_parallel_spec=b_parallel_spec, - allreduce=True, + allreduce=False, ) - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: - base_out, bias_out = self.linear_fc1(x) - adapter_out = torch.cat( - [self.gate_lora(x), self.up_lora(x)], - dim=-1, - ) + def forward( + self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None]: + base_out, bias_out = self.linear_fc1(x, tokens_per_expert) + adapter_out = self.lora(x, tokens_per_expert=tokens_per_expert) return base_out + adapter_out, bias_out @@ -912,6 +1019,89 @@ def forward( return base_out + adapter_out, bias_out +class SharedExpertsLinearFC1LoRA(torch.nn.Module): + def __init__( + self, + adapter_model_prefix: str, + linear_fc1: TEColumnParallelLinear | TELayerNormColumnParallelLinear, + rank: int, + alpha: float, + ) -> None: + super().__init__() + if isinstance(linear_fc1, TELayerNormColumnParallelLinear): + linear_fc1.return_layernorm_output = True + linear_fc1.return_layernorm_output_gathered = True + self.linear_fc1 = linear_fc1 + self.gate_lora = self._build_fc1_lora( + adapter_model_prefix=f"{adapter_model_prefix}.gate_proj", + linear_fc1=linear_fc1, + rank=rank, + alpha=alpha, + ) + self.up_lora = self._build_fc1_lora( + adapter_model_prefix=f"{adapter_model_prefix}.up_proj", + linear_fc1=linear_fc1, + rank=rank, + alpha=alpha, + ) + + @staticmethod + def _build_fc1_lora( + *, + adapter_model_prefix: str, + linear_fc1: TEColumnParallelLinear | TELayerNormColumnParallelLinear, + rank: int, + alpha: float, + ) -> LoRA: + assert isinstance(linear_fc1.weight, torch.Tensor) + a_parallel_spec = LoRAParallelSpec( + shard_domain="tp", + sharded=False, + shard_dim=None, + grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_SUM, + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": True, + "shard_dim": -1, + "grad_sync_op": GRAD_SYNC_OP_NONE, + } + ) + return LoRA( + adapter_model_prefix=adapter_model_prefix, + in_features=linear_fc1.in_features, + out_features=linear_fc1.out_features // 2, + rank=rank, + alpha=alpha, + dtype=linear_fc1.weight.dtype, + device=linear_fc1.weight.device, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + 
allreduce=True, + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + base_output, bias_out = self.linear_fc1(x) + if isinstance(base_output, tuple): + base_out, lora_input = base_output + else: + base_out = base_output + lora_input = _column_parallel_lora_input(x, self.linear_fc1) + adapter_out = torch.cat( + [self.gate_lora(lora_input), self.up_lora(lora_input)], + dim=-1, + ) + if adapter_out.shape != base_out.shape: + adapter_model_prefix = self.gate_lora.adapter_model_prefix.rsplit(".", 1)[0] + raise RuntimeError( + f"{adapter_model_prefix}: LoRA adapter output shape " + f"{tuple(adapter_out.shape)} does not match base output shape " + f"{tuple(base_out.shape)}" + ) + return base_out + adapter_out, bias_out + + class SharedExpertsLinearFC2LoRA(torch.nn.Module): def __init__( self, @@ -935,166 +1125,262 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: return self.row_parallel_lora(x) +def _unwrap_attr( + value: Any, + attr_name: str, + expected_type: type[Any] | tuple[type[Any], ...], +) -> Any: + if isinstance(value, expected_type): + return value + unwrapped = getattr(value, attr_name) + assert isinstance(unwrapped, expected_type) + return unwrapped + + +def _adapter_model_prefix(module: TransformerLayer) -> str: + return f"base_model.model.model.layers.{module.layer_number - 1}" + + +def _is_language_transformer_layer_name(module_name: str) -> bool: + while module_name.startswith("module."): + module_name = module_name.removeprefix("module.") + return module_name.startswith(("decoder.layers.", "language_model.decoder.layers.")) + + +def _targets_include(target_modules: set[str], *names: str) -> bool: + return not target_modules or any(name in target_modules for name in names) + + +def wrap_standard_self_attention( + self_attention: SelfAttention, + *, + adapter_model_prefix: str, + provider: GPTModelProvider, + target_modules: set[str], + rank: int, + alpha: int, +) -> None: + if _targets_include(target_modules, "o_proj"): + self_attention_linear_proj = _unwrap_attr( + self_attention.linear_proj, + "linear_proj", + TERowParallelLinear, + ) + self_attention.linear_proj = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", + linear_proj=self_attention_linear_proj, + rank=rank, + alpha=alpha, + provider=provider, + ) + if _targets_include(target_modules, "q_proj", "k_proj", "v_proj"): + self_attention_linear_qkv = _unwrap_attr( + self_attention.linear_qkv, + "linear_qkv", + TELayerNormColumnParallelLinear, + ) + self_attention.linear_qkv = SelfAttentionLinearQKVLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn", + linear_qkv=self_attention_linear_qkv, + rank=rank, + alpha=alpha, + provider=provider, + ) + + +def wrap_gated_delta_net_attention( + self_attention: GatedDeltaNet, + *, + adapter_model_prefix: str, + provider: GPTModelProvider, + target_modules: set[str], + rank: int, + alpha: int, +) -> None: + if _targets_include(target_modules, "out_proj"): + gated_delta_net_out_proj = _unwrap_attr( + self_attention.out_proj, + "out_proj", + TERowParallelLinear, + ) + self_attention.out_proj = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.linear_attn.out_proj", + linear_proj=gated_delta_net_out_proj, + rank=rank, + alpha=alpha, + provider=provider, + ) + if _targets_include(target_modules, "in_proj_qkv", "in_proj_z"): + gated_delta_net_in_proj = _unwrap_attr( + self_attention.in_proj, + "in_proj", + TELayerNormColumnParallelLinear, 
+ ) + self_attention.in_proj = GatedDeltaNetInProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.linear_attn", + in_proj=gated_delta_net_in_proj, + gated_delta_net=self_attention, + rank=rank, + alpha=alpha, + ) + + +def wrap_grouped_moe_experts( + experts: TEGroupedMLP, + *, + adapter_model_prefix: str, + target_modules: set[str], + rank: int, + alpha: int, +) -> None: + if _targets_include(target_modules, "gate_proj", "up_proj"): + mlp_experts_linear_fc1 = _unwrap_attr( + experts.linear_fc1, + "linear_fc1", + TEColumnParallelGroupedLinear, # type: ignore[arg-type] + ) + experts.linear_fc1 = MLPExpertsLinearFC1LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc1=mlp_experts_linear_fc1, + rank=rank, + alpha=alpha, + num_local_experts=experts.num_local_experts, + ) + if _targets_include(target_modules, "down_proj"): + mlp_experts_linear_fc2 = _unwrap_attr( + experts.linear_fc2, + "linear_fc2", + TERowParallelGroupedLinear, # type: ignore[arg-type] + ) + experts.linear_fc2 = MLPExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc2=mlp_experts_linear_fc2, + rank=rank, + alpha=alpha, + num_local_experts=experts.num_local_experts, + ) + + +def wrap_grouped_moe_experts_3d( + experts: TEGroupedMLP, + *, + adapter_model_prefix: str, + target_modules: set[str], + rank: int, + alpha: int, +) -> None: + if _targets_include(target_modules, "experts"): + mlp_experts_linear_fc1 = _unwrap_attr( + experts.linear_fc1, + "linear_fc1", + TEColumnParallelGroupedLinear, # type: ignore[arg-type] + ) + experts.linear_fc1 = MLPExpertsLinearFC1FusedLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc1=mlp_experts_linear_fc1, + rank=rank, + alpha=alpha, + num_local_experts=experts.num_local_experts, + ) + mlp_experts_linear_fc2 = _unwrap_attr( + experts.linear_fc2, + "linear_fc2", + TERowParallelGroupedLinear, # type: ignore[arg-type] + ) + experts.linear_fc2 = MLPExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc2=mlp_experts_linear_fc2, + rank=rank, + alpha=alpha, + num_local_experts=experts.num_local_experts, + ) + + +def wrap_dense_mlp( + mlp: Any, + *, + adapter_model_prefix: str, + provider: GPTModelProvider, + target_modules: set[str], + rank: int, + alpha: int, +) -> None: + if _targets_include(target_modules, "gate_proj", "up_proj"): + mlp_linear_fc1 = _unwrap_attr( + mlp.linear_fc1, + "linear_fc1", + (TEColumnParallelLinear, TELayerNormColumnParallelLinear), + ) + mlp.linear_fc1 = SharedExpertsLinearFC1LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp", + linear_fc1=mlp_linear_fc1, + rank=rank, + alpha=alpha, + ) + if _targets_include(target_modules, "down_proj"): + mlp_linear_fc2 = _unwrap_attr( + mlp.linear_fc2, + "linear_fc2", + TERowParallelLinear, + ) + mlp.linear_fc2 = SharedExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp", + linear_fc2=mlp_linear_fc2, + rank=rank, + alpha=alpha, + provider=provider, + ) + + +def wrap_shared_experts_mlp( + shared_experts: SharedExpertMLP, + *, + adapter_model_prefix: str, + provider: GPTModelProvider, + target_modules: set[str], + rank: int, + alpha: int, +) -> None: + if _targets_include(target_modules, "gate_proj", "up_proj"): + shared_experts_linear_fc1 = _unwrap_attr( + shared_experts.linear_fc1, + "linear_fc1", + (TEColumnParallelLinear, TELayerNormColumnParallelLinear), + ) + shared_experts.linear_fc1 = SharedExpertsLinearFC1LoRA( + 
adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", + linear_fc1=shared_experts_linear_fc1, + rank=rank, + alpha=alpha, + ) + if _targets_include(target_modules, "down_proj"): + shared_experts_linear_fc2 = _unwrap_attr( + shared_experts.linear_fc2, + "linear_fc2", + TERowParallelLinear, + ) + shared_experts.linear_fc2 = SharedExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", + linear_fc2=shared_experts_linear_fc2, + rank=rank, + alpha=alpha, + provider=provider, + ) + + def apply_lora_adapters( model: Sequence[torch.nn.Module], provider: GPTModelProvider, ) -> list[torch.nn.Module]: - def _is_language_transformer_layer(module_name: str) -> bool: - while module_name.startswith("module."): - module_name = module_name.removeprefix("module.") - return module_name.startswith( - ("decoder.layers.", "language_model.decoder.layers.") - ) - - def _unwrap_attr( - value: Any, - attr_name: str, - expected_type: type[Any] | tuple[type[Any], ...], - ) -> Any: - if isinstance(value, expected_type): - return value - unwrapped = getattr(value, attr_name) - assert isinstance(unwrapped, expected_type) - return unwrapped - - for chunk in model: - for module_name, module in chunk.named_modules(): - if isinstance(module, TransformerLayer): - if not _is_language_transformer_layer(module_name): - continue - adapter_model_prefix = ( - f"base_model.model.model.layers.{module.layer_number - 1}" - ) - if isinstance(module.self_attention, SelfAttention): - self_attention_linear_proj = _unwrap_attr( - module.self_attention.linear_proj, - "linear_proj", - TERowParallelLinear, - ) - module.self_attention.linear_proj = SelfAttentionLinearProjLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", - linear_proj=self_attention_linear_proj, - rank=LORA_RANK, - alpha=LORA_ALPHA, - provider=provider, - ) - self_attention_linear_qkv = _unwrap_attr( - module.self_attention.linear_qkv, - "linear_qkv", - TELayerNormColumnParallelLinear, - ) - module.self_attention.linear_qkv = SelfAttentionLinearQKVLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.self_attn", - linear_qkv=self_attention_linear_qkv, - rank=LORA_RANK, - alpha=LORA_ALPHA, - provider=provider, - ) - elif isinstance(module.self_attention, GatedDeltaNet): - gated_delta_net_out_proj = _unwrap_attr( - module.self_attention.out_proj, - "out_proj", - TERowParallelLinear, - ) - module.self_attention.out_proj = SelfAttentionLinearProjLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.linear_attn.out_proj", - linear_proj=gated_delta_net_out_proj, - rank=LORA_RANK, - alpha=LORA_ALPHA, - provider=provider, - ) - gated_delta_net_in_proj = _unwrap_attr( - module.self_attention.in_proj, - "in_proj", - TELayerNormColumnParallelLinear, - ) - module.self_attention.in_proj = GatedDeltaNetInProjLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.linear_attn", - in_proj=gated_delta_net_in_proj, - gated_delta_net=module.self_attention, - rank=LORA_RANK, - alpha=LORA_ALPHA, - ) - else: - raise TypeError( - "Unsupported self_attention module type for Megatron LoRA: " - f"{type(module.self_attention)}" - ) - experts = getattr(module.mlp, "experts", None) - if experts is not None: - assert isinstance(experts, TEGroupedMLP) - mlp_experts_linear_fc1 = _unwrap_attr( - experts.linear_fc1, - "linear_fc1", - TEColumnParallelGroupedLinear, # type: ignore[arg-type] - ) - experts.linear_fc1 = MLPExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc1=mlp_experts_linear_fc1, - 
rank=LORA_RANK, - alpha=LORA_ALPHA, - num_local_experts=experts.num_local_experts, - ) - mlp_experts_linear_fc2 = _unwrap_attr( - experts.linear_fc2, - "linear_fc2", - TERowParallelGroupedLinear, # type: ignore[arg-type] - ) - experts.linear_fc2 = MLPExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc2=mlp_experts_linear_fc2, - rank=LORA_RANK, - alpha=LORA_ALPHA, - num_local_experts=experts.num_local_experts, - ) - else: - mlp_linear_fc1 = _unwrap_attr( - module.mlp.linear_fc1, - "linear_fc1", - (TEColumnParallelLinear, TELayerNormColumnParallelLinear), - ) - module.mlp.linear_fc1 = SharedExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp", - linear_fc1=mlp_linear_fc1, - rank=LORA_RANK, - alpha=LORA_ALPHA, - ) - mlp_linear_fc2 = _unwrap_attr( - module.mlp.linear_fc2, - "linear_fc2", - TERowParallelLinear, - ) - module.mlp.linear_fc2 = SharedExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp", - linear_fc2=mlp_linear_fc2, - rank=LORA_RANK, - alpha=LORA_ALPHA, - provider=provider, - ) - shared_experts = getattr(module.mlp, "shared_experts", None) - if shared_experts is not None: - assert isinstance(shared_experts, SharedExpertMLP) - shared_experts_linear_fc1 = _unwrap_attr( - shared_experts.linear_fc1, - "linear_fc1", - (TEColumnParallelLinear, TELayerNormColumnParallelLinear), - ) - shared_experts.linear_fc1 = SharedExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", - linear_fc1=shared_experts_linear_fc1, - rank=LORA_RANK, - alpha=LORA_ALPHA, - ) - shared_experts_linear_fc2 = _unwrap_attr( - shared_experts.linear_fc2, - "linear_fc2", - TERowParallelLinear, - ) - shared_experts.linear_fc2 = SharedExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", - linear_fc2=shared_experts_linear_fc2, - rank=LORA_RANK, - alpha=LORA_ALPHA, - provider=provider, - ) + provider = cast(Any, provider) + handler = provider._art_model_support_handler + spec = provider._art_model_support_spec + target_modules = list(spec.default_target_modules) + handler.apply_lora_adapters( + model, + provider, + target_modules=target_modules, + rank=LORA_RANK, + alpha=LORA_ALPHA, + ) return list(model) diff --git a/src/art/megatron/model_support/__init__.py b/src/art/megatron/model_support/__init__.py new file mode 100644 index 000000000..60862ac54 --- /dev/null +++ b/src/art/megatron/model_support/__init__.py @@ -0,0 +1,91 @@ +from art.megatron.model_support.registry import ( + DEFAULT_DENSE_SPEC, + PROBE_ONLY_MODEL_SUPPORT_SPECS, + QWEN3_5_DENSE_MODELS, + QWEN3_5_DENSE_SPEC, + QWEN3_5_MODELS, + QWEN3_5_MOE_MODELS, + QWEN3_5_MOE_SPEC, + QWEN3_DENSE_MODELS, + QWEN3_DENSE_SPEC, + QWEN3_MOE_MODELS, + QWEN3_MOE_SPEC, + VALIDATED_MODEL_SUPPORT_SPECS, + UnsupportedModelArchitectureError, + default_target_modules_for_model, + get_model_support_handler, + get_model_support_handler_for_spec, + get_model_support_spec, + is_model_support_registered, + list_model_support_specs, + model_requires_merged_rollout, + model_uses_expert_parallel, + native_vllm_lora_status_for_model, +) +from art.megatron.model_support.spec import ( + ArchitectureReport, + DependencyFloor, + LayerFamilyInstance, + MinimalLayerCoverageReport, + ModelSupportHandler, + ModelSupportSpec, + NativeVllmLoraStatus, + RolloutWeightsMode, + ValidationReport, + ValidationStageResult, +) + +_LAZY_EXPORT_MODULES = { + "inspect_architecture": "art.megatron.model_support.discovery", + "summarize_layer_families": 
"art.megatron.model_support.discovery", +} + + +def __getattr__(name: str): + import importlib + + try: + module_name = _LAZY_EXPORT_MODULES[name] + except KeyError as exc: + raise AttributeError(name) from exc + value = getattr(importlib.import_module(module_name), name) + globals()[name] = value + return value + + +__all__ = [ + "ArchitectureReport", + "DEFAULT_DENSE_SPEC", + "DependencyFloor", + "LayerFamilyInstance", + "MinimalLayerCoverageReport", + "ModelSupportHandler", + "ModelSupportSpec", + "NativeVllmLoraStatus", + "QWEN3_5_DENSE_MODELS", + "QWEN3_5_DENSE_SPEC", + "QWEN3_5_MODELS", + "QWEN3_5_MOE_MODELS", + "QWEN3_DENSE_MODELS", + "QWEN3_DENSE_SPEC", + "QWEN3_MOE_MODELS", + "QWEN3_MOE_SPEC", + "QWEN3_5_MOE_SPEC", + "PROBE_ONLY_MODEL_SUPPORT_SPECS", + "RolloutWeightsMode", + "ValidationReport", + "ValidationStageResult", + "UnsupportedModelArchitectureError", + "VALIDATED_MODEL_SUPPORT_SPECS", + "default_target_modules_for_model", + "get_model_support_handler", + "get_model_support_handler_for_spec", + "get_model_support_spec", + "inspect_architecture", + "is_model_support_registered", + "list_model_support_specs", + "model_uses_expert_parallel", + "model_requires_merged_rollout", + "native_vllm_lora_status_for_model", + "summarize_layer_families", +] diff --git a/src/art/megatron/model_support/discovery.py b/src/art/megatron/model_support/discovery.py new file mode 100644 index 000000000..6f27dd05d --- /dev/null +++ b/src/art/megatron/model_support/discovery.py @@ -0,0 +1,75 @@ +from collections import Counter + +import torch + +from art.megatron.model_support.spec import ArchitectureReport, LayerFamilyInstance +from art.megatron.provider import get_provider_bundle + + +def summarize_layer_families( + layer_families: list[LayerFamilyInstance], +) -> list[LayerFamilyInstance]: + counts = Counter(family.key for family in layer_families) + exemplar_by_key: dict[str, LayerFamilyInstance] = {} + for family in layer_families: + exemplar_by_key.setdefault(family.key, family) + return [ + LayerFamilyInstance( + key=key, + count=count, + layer_index=exemplar_by_key[key].layer_index, + module_path=exemplar_by_key[key].module_path, + module_type=exemplar_by_key[key].module_type, + ) + for key, count in sorted(counts.items()) + ] + + +def recommended_min_layers( + layer_families: list[LayerFamilyInstance], +) -> int: + indexed_layers = [ + family.layer_index + for family in layer_families + if family.layer_index is not None + ] + if indexed_layers: + return max(indexed_layers) + 1 + return max(len(layer_families), 1) + + +def inspect_architecture( + base_model: str, + *, + torch_dtype: torch.dtype = torch.bfloat16, + allow_unvalidated_arch: bool = False, +) -> ArchitectureReport: + provider_bundle = get_provider_bundle( + base_model, + torch_dtype=torch_dtype, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + discovered = provider_bundle.handler.collect_layer_families( + provider_bundle.provider + ) + summarized = summarize_layer_families(discovered) + unresolved_risks: list[str] = [] + if not summarized: + unresolved_risks.append( + "handler did not report any layer families; codex review is required" + ) + if any(family.layer_index is None for family in summarized): + unresolved_risks.append( + "handler did not report representative layer indices for every family; " + "codex review is required" + ) + return ArchitectureReport( + base_model=base_model, + model_key=provider_bundle.spec.key, + handler_key=provider_bundle.handler.key, + 
bridge_type=type(provider_bundle.bridge._model_bridge).__name__, + provider_type=type(provider_bundle.provider).__name__, + layer_families=summarized, + recommended_min_layers=recommended_min_layers(summarized), + unresolved_risks=unresolved_risks, + ) diff --git a/src/art/megatron/model_support/handlers/__init__.py b/src/art/megatron/model_support/handlers/__init__.py new file mode 100644 index 000000000..80b18c7ce --- /dev/null +++ b/src/art/megatron/model_support/handlers/__init__.py @@ -0,0 +1,33 @@ +from art.megatron.model_support.handlers.default_dense import ( + DEFAULT_DENSE_HANDLER, + DefaultDenseHandler, + DefaultMoeHandler, +) +from art.megatron.model_support.handlers.qwen3_5 import ( + QWEN3_5_DENSE_HANDLER, + QWEN3_5_MOE_HANDLER, + Qwen35DenseHandler, + Qwen35MoeHandler, +) +from art.megatron.model_support.handlers.qwen3_dense import ( + QWEN3_DENSE_HANDLER, + Qwen3DenseHandler, +) +from art.megatron.model_support.handlers.qwen3_moe import ( + QWEN3_MOE_HANDLER, + Qwen3MoeHandler, +) + +__all__ = [ + "DEFAULT_DENSE_HANDLER", + "DefaultDenseHandler", + "DefaultMoeHandler", + "QWEN3_5_DENSE_HANDLER", + "Qwen35DenseHandler", + "QWEN3_DENSE_HANDLER", + "Qwen3DenseHandler", + "QWEN3_MOE_HANDLER", + "Qwen3MoeHandler", + "QWEN3_5_MOE_HANDLER", + "Qwen35MoeHandler", +] diff --git a/src/art/megatron/model_support/handlers/default_dense.py b/src/art/megatron/model_support/handlers/default_dense.py new file mode 100644 index 000000000..bb5cffaab --- /dev/null +++ b/src/art/megatron/model_support/handlers/default_dense.py @@ -0,0 +1,387 @@ +import re +from typing import Any, Sequence + +import torch + +from art.megatron.model_support.spec import ( + CompileWorkaroundConfig, + LayerFamilyInstance, + SharedExpertCompileState, +) + + +class DefaultDenseHandler: + key = "default_dense" + is_moe = False + native_vllm_lora_status = "disabled" + + def identity_lora_model_config(self, base_config: Any) -> Any: + return base_config + + def identity_lora_target_parameters( + self, + model: Any, + *, + target_modules: list[str], + ) -> list[str]: + suffixes = self._identity_lora_parameter_suffixes(target_modules) + return [name for name, _ in model.named_parameters() if name.endswith(suffixes)] + + def _identity_lora_parameter_suffixes( + self, + target_modules: list[str], + ) -> tuple[str, ...]: + target_set = set(target_modules) + suffixes: list[str] = [] + if "q_proj" in target_set: + suffixes.append("q_proj.weight") + if "k_proj" in target_set: + suffixes.append("k_proj.weight") + if "v_proj" in target_set: + suffixes.append("v_proj.weight") + if "o_proj" in target_set: + suffixes.append("o_proj.weight") + if "gate_proj" in target_set: + suffixes.extend(("gate_proj.weight", "mlp.experts.gate_up_proj")) + if "up_proj" in target_set: + suffixes.extend(("up_proj.weight", "mlp.experts.gate_up_proj")) + if "down_proj" in target_set: + suffixes.extend(("down_proj.weight", "mlp.experts.down_proj")) + if "experts" in target_set: + suffixes.extend(("mlp.experts.gate_up_proj", "mlp.experts.down_proj")) + return tuple(dict.fromkeys(suffixes)) + + def patch_provider(self, provider: Any, bridge: Any) -> None: + return None + + def patch_bridge(self, bridge: Any) -> None: + del bridge + return None + + def configure_provider_for_runtime(self, provider: Any) -> None: + del provider + return None + + def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None: + del model_chunks + return None + + def hf_tensor_map_to_art_canonical( + self, + hf_tensor_map: dict[str, torch.Tensor], + *, + 
expected_keys: set[str], + ) -> dict[str, torch.Tensor]: + return _unfuse_moe_hf_tensor_map_for_expected_keys( + hf_tensor_map, + expected_keys=expected_keys, + ) + + def to_vllm_lora_tensors( + self, + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], + ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + return tensors, adapter_config + + def from_vllm_lora_tensors( + self, + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], + ) -> dict[str, torch.Tensor]: + del adapter_config + return tensors + + def _shared_expert_compile_state( + self, + provider: Any, + ) -> SharedExpertCompileState: + if int(getattr(provider, "moe_shared_expert_intermediate_size", 0) or 0) <= 0: + return "none" + if bool(getattr(provider, "moe_shared_expert_overlap", False)): + return "shared_expert_overlap" + return "shared_experts" + + def collect_layer_families(self, provider: Any) -> list[LayerFamilyInstance]: + del provider + return [ + LayerFamilyInstance(key="standard_attention", layer_index=0), + LayerFamilyInstance(key="dense_mlp", layer_index=0), + ] + + def apply_lora_adapters( + self, + model_chunks: Sequence[Any], + provider: Any, + *, + target_modules: list[str], + rank: int, + alpha: int, + ) -> None: + from megatron.core.transformer.transformer_layer import TransformerLayer + + from art.megatron.lora import ( + _adapter_model_prefix, + wrap_dense_mlp, + wrap_standard_self_attention, + ) + + target_set = set(target_modules) + for chunk in model_chunks: + for module in chunk.modules(): + if not isinstance(module, TransformerLayer): + continue + wrap_standard_self_attention( + module.self_attention, + adapter_model_prefix=_adapter_model_prefix(module), + provider=provider, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + _require_dense_mlp(module) + wrap_dense_mlp( + module.mlp, + adapter_model_prefix=_adapter_model_prefix(module), + provider=provider, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + + def build_adapter_weights_by_base( + self, + model_chunks: Sequence[Any], + ) -> dict[str, list[Any]]: + from megatron.core.transformer.transformer_layer import TransformerLayer + + from art.megatron.weights.adapter_export import ( + add_dense_mlp_adapter_weights, + add_standard_self_attention_adapter_weights, + layer_base_prefix, + ) + + adapter_weights_by_base: dict[str, list[Any]] = {} + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + if not isinstance(module, TransformerLayer): + continue + layer_prefix = layer_base_prefix(module, module_name=module_name) + _require_dense_mlp(module) + add_standard_self_attention_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + self_attention=module.self_attention, + ) + add_dense_mlp_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + mlp=module.mlp, + ) + return adapter_weights_by_base + + def compile_workaround_config( + self, + provider: Any, + ) -> CompileWorkaroundConfig: + return CompileWorkaroundConfig( + shared_expert_state=self._shared_expert_compile_state(provider) + ) + + def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: + del model + return {"extra_block_kwargs": kwargs} + + +class DefaultMoeHandler(DefaultDenseHandler): + key = "default_moe" + is_moe = True + + def collect_layer_families(self, provider: Any) -> list[LayerFamilyInstance]: + layer_families = [LayerFamilyInstance(key="standard_attention", layer_index=0)] + 
layer_families.append(LayerFamilyInstance(key="grouped_moe_mlp", layer_index=0)) + if int(getattr(provider, "moe_shared_expert_intermediate_size", 0) or 0) > 0: + layer_families.append( + LayerFamilyInstance(key="shared_experts_mlp", layer_index=0) + ) + return layer_families + + def apply_lora_adapters( + self, + model_chunks: Sequence[Any], + provider: Any, + *, + target_modules: list[str], + rank: int, + alpha: int, + ) -> None: + from megatron.core.transformer.transformer_layer import TransformerLayer + + from art.megatron.lora import ( + _adapter_model_prefix, + wrap_grouped_moe_experts, + wrap_shared_experts_mlp, + wrap_standard_self_attention, + ) + + target_set = set(target_modules) + for chunk in model_chunks: + for module in chunk.modules(): + if not isinstance(module, TransformerLayer): + continue + adapter_model_prefix = _adapter_model_prefix(module) + wrap_standard_self_attention( + module.self_attention, + adapter_model_prefix=adapter_model_prefix, + provider=provider, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + wrap_grouped_moe_experts( + _require_moe_experts(module), + adapter_model_prefix=adapter_model_prefix, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + shared_experts = getattr(module.mlp, "shared_experts", None) + if shared_experts is not None: + wrap_shared_experts_mlp( + shared_experts, + adapter_model_prefix=adapter_model_prefix, + provider=provider, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + + def build_adapter_weights_by_base( + self, + model_chunks: Sequence[Any], + ) -> dict[str, list[Any]]: + from megatron.core.transformer.transformer_layer import TransformerLayer + + from art.megatron.weights.adapter_export import ( + add_grouped_moe_adapter_weights, + add_shared_experts_adapter_weights, + add_standard_self_attention_adapter_weights, + layer_base_prefix, + ) + + adapter_weights_by_base: dict[str, list[Any]] = {} + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + if not isinstance(module, TransformerLayer): + continue + layer_prefix = layer_base_prefix(module, module_name=module_name) + add_standard_self_attention_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + self_attention=module.self_attention, + ) + add_grouped_moe_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + experts=_require_moe_experts(module), + ) + shared_experts = getattr(module.mlp, "shared_experts", None) + if shared_experts is not None: + add_shared_experts_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + shared_experts=shared_experts, + ) + return adapter_weights_by_base + + +def _require_dense_mlp(module: Any) -> None: + if getattr(module.mlp, "experts", None) is not None: + raise TypeError( + "Dense model support handler received a MoE TransformerLayer; " + "use a MoE handler for this model." + ) + + +def _require_moe_experts(module: Any) -> Any: + experts = getattr(module.mlp, "experts", None) + if experts is None: + raise TypeError( + "MoE model support handler received a dense TransformerLayer; " + "use a dense handler for this model." 
+    )
+    return experts
+
+
+_FUSED_MOE_EXPERT_PATTERN = re.compile(
+    r"^(?P<prefix>.*\.mlp\.experts)\.(?P<param>gate_up_proj|down_proj)(?:\.weight)?$"
+)
+
+
+def _strip_language_model_prefix(key: str) -> str:
+    if key.startswith("model.language_model."):
+        return f"model.{key.removeprefix('model.language_model.')}"
+    return key
+
+
+def _expected_unfused_experts_for_prefix(
+    expected_keys: set[str],
+    prefix: str,
+    *,
+    param: str,
+) -> bool:
+    simplified_expected_keys = {
+        _strip_language_model_prefix(key) for key in expected_keys
+    }
+    if param == "gate_up_proj":
+        return (
+            f"{prefix}.0.gate_proj.weight" in simplified_expected_keys
+            or f"{prefix}.0.up_proj.weight" in simplified_expected_keys
+        )
+    if param == "down_proj":
+        return f"{prefix}.0.down_proj.weight" in simplified_expected_keys
+    return False
+
+
+def _unfuse_moe_hf_tensor_map_for_expected_keys(
+    hf_tensor_map: dict[str, torch.Tensor],
+    *,
+    expected_keys: set[str],
+) -> dict[str, torch.Tensor]:
+    canonical: dict[str, torch.Tensor] = {}
+    for key, value in hf_tensor_map.items():
+        match = _FUSED_MOE_EXPERT_PATTERN.match(key)
+        if match is None:
+            canonical[key] = value
+            continue
+
+        prefix = match.group("prefix")
+        param = match.group("param")
+        if value.ndim != 3 or not _expected_unfused_experts_for_prefix(
+            expected_keys,
+            prefix,
+            param=param,
+        ):
+            canonical[key] = value
+            continue
+
+        num_experts = int(value.shape[0])
+        if param == "gate_up_proj":
+            if value.shape[1] % 2 != 0:
+                canonical[key] = value
+                continue
+            gate_proj, up_proj = value.chunk(2, dim=1)
+            for expert in range(num_experts):
+                canonical[f"{prefix}.{expert}.gate_proj.weight"] = gate_proj[expert]
+                canonical[f"{prefix}.{expert}.up_proj.weight"] = up_proj[expert]
+            continue
+
+        for expert in range(num_experts):
+            canonical[f"{prefix}.{expert}.down_proj.weight"] = value[expert]
+
+    return canonical
+
+
+DEFAULT_DENSE_HANDLER = DefaultDenseHandler()
diff --git a/src/art/megatron/model_support/handlers/qwen3_5.py b/src/art/megatron/model_support/handlers/qwen3_5.py
new file mode 100644
index 000000000..06ae9392d
--- /dev/null
+++ b/src/art/megatron/model_support/handlers/qwen3_5.py
@@ -0,0 +1,1046 @@
+from copy import copy
+from functools import lru_cache
+import re
+from types import MethodType
+from typing import Any, Sequence, cast
+
+from megatron.core.models.gpt.gpt_model import GPTModel
+from megatron.core.ssm.gated_delta_net import GatedDeltaNet
+import torch
+
+from art.megatron.model_support.handlers.default_dense import (
+    DefaultDenseHandler,
+    _require_dense_mlp,
+    _require_moe_experts,
+)
+from art.megatron.model_support.spec import (
+    CompileWorkaroundConfig,
+    LayerFamilyInstance,
+)
+from art.megatron.provider_common import patch_layer_spec_tree
+from art.megatron.training.model_chunks import ModelChunks
+
+_QWEN35_MOE_COMPILE_WORKAROUND_FLAGS = (
+    "alltoall_dtoh",
+    "alltoall_dispatch_preprocess",
+    "deepep_permute_restore",
+)
+_ART_LAYER_PREFIX = "base_model.model.model.layers."
+_VLLM_LAYER_PREFIX = "base_model.model.model.language_model.layers."
+_ART_MOE_EXPERT_KEY_RE = re.compile(
+    r"^(?P<prefix>.*\.mlp\.experts)\.(?P<expert>\d+)\."
+    r"(?P<module>gate_up_proj|down_proj)\.(?P<lora>lora_[AB])\.weight$"
+)
+_VLLM_MOE_KEY_RE = re.compile(
+    r"^(?P<prefix>.*\.mlp\.experts)\."
+    r"(?:(?P<base_layer>base_layer)\.)?(?P<lora>lora_[AB])\.weight$"
+)
+
+
+class Qwen35BaseHandler(DefaultDenseHandler):
+    key = "qwen3_5_base"
+    native_vllm_lora_status = "validated"
+
+    def identity_lora_model_config(self, base_config: Any) -> Any:
+        return getattr(base_config, "text_config", base_config)
+
+    def _identity_lora_parameter_suffixes(
+        self,
+        target_modules: list[str],
+    ) -> tuple[str, ...]:
+        suffixes = list(super()._identity_lora_parameter_suffixes(target_modules))
+        target_set = set(target_modules)
+        if "in_proj_qkv" in target_set:
+            suffixes.append("linear_attn.in_proj_qkv.weight")
+        if "in_proj_z" in target_set:
+            suffixes.append("linear_attn.in_proj_z.weight")
+        if "out_proj" in target_set:
+            suffixes.append("linear_attn.out_proj.weight")
+        return tuple(dict.fromkeys(suffixes))
+
+    def to_vllm_lora_tensors(
+        self,
+        tensors: dict[str, torch.Tensor],
+        *,
+        adapter_config: dict[str, Any],
+    ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
+        if _group_art_moe_tensors(tensors):
+            raise TypeError("Dense Qwen3.5 handler received MoE LoRA tensors")
+        transformed: dict[str, torch.Tensor] = {}
+        for key, tensor in tensors.items():
+            vllm_key, tensor = _to_vllm_lora_tensor(
+                key,
+                tensor,
+                adapter_config=adapter_config,
+            )
+            transformed[vllm_key] = tensor
+        return (
+            transformed,
+            adapter_config,
+        )
+
+    def from_vllm_lora_tensors(
+        self,
+        tensors: dict[str, torch.Tensor],
+        *,
+        adapter_config: dict[str, Any],
+    ) -> dict[str, torch.Tensor]:
+        if any(_VLLM_MOE_KEY_RE.match(key) for key in tensors):
+            raise TypeError("Dense Qwen3.5 handler received MoE vLLM LoRA tensors")
+        transformed: dict[str, torch.Tensor] = {}
+        for key, tensor in tensors.items():
+            art_key, tensor = _from_vllm_lora_tensor(
+                key,
+                tensor,
+                adapter_config=adapter_config,
+            )
+            transformed[art_key] = tensor
+        return transformed
+
+    def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None:
+        from art.megatron.gdn.operator import (
+            install_gdn_island_hooks,
+            install_shared_prefix_gdn_hooks,
+        )
+
+        install_shared_prefix_gdn_hooks(model_chunks)
+        install_gdn_island_hooks(model_chunks)
+        for chunk in cast(ModelChunks, list(model_chunks)):
+            module: Any = chunk
+            while hasattr(module, "module"):
+                module = module.module
+            gpt_module = (
+                module
+                if isinstance(module, GPTModel)
+                else cast(GPTModel, getattr(module, "language_model"))
+            )
+            if getattr(gpt_module, "mtp_process", False) or hasattr(gpt_module, "mtp"):
+                raise RuntimeError("ART Qwen3.5 Megatron training does not use MTP.")
+            preprocess = gpt_module._preprocess
+
+            def preprocess_hook(*args, _preprocess=preprocess, **kwargs):
+                position_ids = kwargs.get("position_ids")
+                if isinstance(position_ids, torch.Tensor) and position_ids.ndim == 2:
+                    kwargs = dict(kwargs)
+                    kwargs["position_ids"] = position_ids.unsqueeze(0).expand(
+                        3,
+                        position_ids.shape[0],
+                        position_ids.shape[1],
+                    )
+                preproc_output = list(_preprocess(*args, **kwargs))
+                decoder_input = cast(torch.Tensor, preproc_output[0])
+                if not decoder_input.requires_grad and decoder_input.is_leaf:
+                    decoder_input.requires_grad_(True)
+                return tuple(preproc_output)
+
+            gpt_module._preprocess = preprocess_hook  # type: ignore[attr-defined]
+
+    def _attention_layer_families(self, provider: Any) -> list[LayerFamilyInstance]:
+        linear_attention_pattern = _linear_attention_pattern(provider)
+        gated_delta_net_layer_index = (
+            linear_attention_pattern.index(1) if 1 in linear_attention_pattern else 0
+        )
+        standard_attention_layer_index = (
+            linear_attention_pattern.index(0) if 0 in 
linear_attention_pattern else 0 + ) + layer_families = [ + LayerFamilyInstance( + key="standard_attention", + layer_index=standard_attention_layer_index, + ), + LayerFamilyInstance( + key="gated_delta_net_attention", + layer_index=gated_delta_net_layer_index, + ), + ] + return layer_families + + def collect_layer_families(self, provider: Any) -> list[LayerFamilyInstance]: + if int(getattr(provider, "num_moe_experts", 0) or 0) > 0: + raise TypeError("Dense Qwen3.5 handler received a MoE provider") + return [ + *self._attention_layer_families(provider), + LayerFamilyInstance(key="dense_mlp", layer_index=0), + ] + + def patch_bridge(self, bridge: Any) -> None: + del bridge + _ensure_qwen35_text_only_bridge_registered() + + def configure_provider_for_runtime(self, provider: Any) -> None: + provider.mtp_num_layers = None + provider.mtp_loss_scaling_factor = None + + def patch_provider(self, provider: Any, bridge: Any) -> None: + del bridge + ( + qwen3_vl_self_attention, + qwen35_provider_types, + patch_standard_attention_specs, + transformer_block_spec_factory, + ) = _require_qwen35_provider_symbols() + from art.megatron.flex_attention import FlexDotProductAttention + + matched_provider_type = next( + provider_type + for provider_type in qwen35_provider_types + if isinstance(provider, provider_type) + ) + + def _patch_qwen35_block_spec(block_spec: object) -> None: + patch_standard_attention_specs(block_spec, qwen3_vl_self_attention) + for layer_spec in getattr(block_spec, "layer_specs", ()): + patch_layer_spec_tree(layer_spec, FlexDotProductAttention) + + def _qwen35_layer_spec(config: Any, vp_stage: int | None = None) -> object: + block_spec = transformer_block_spec_factory(config, vp_stage=vp_stage) + _patch_qwen35_block_spec(block_spec) + return block_spec + + def _provide_qwen35_with_flex_attention( + self: Any, + pre_process: bool | None = None, + post_process: bool | None = None, + vp_stage: int | None = None, + ) -> Any: + return matched_provider_type.provide_language_model( + self, + pre_process=pre_process, + post_process=post_process, + vp_stage=vp_stage, + ) + + provider.scatter_embedding_sequence_parallel = True + provider.transformer_layer_spec = _qwen35_layer_spec + provider.provide = MethodType(_provide_qwen35_with_flex_attention, provider) + setattr(provider, "_art_text_only_language_model", True) + + def apply_lora_adapters( + self, + model_chunks: Sequence[Any], + provider: Any, + *, + target_modules: list[str], + rank: int, + alpha: int, + ) -> None: + from megatron.core.transformer.attention import SelfAttention + from megatron.core.transformer.transformer_layer import TransformerLayer + + from art.megatron.lora import ( + _adapter_model_prefix, + _is_language_transformer_layer_name, + wrap_gated_delta_net_attention, + wrap_standard_self_attention, + ) + + target_set = set(target_modules) + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + if not isinstance(module, TransformerLayer): + continue + if not _is_language_transformer_layer_name(module_name): + continue + adapter_model_prefix = _adapter_model_prefix(module) + if isinstance(module.self_attention, SelfAttention): + wrap_standard_self_attention( + module.self_attention, + adapter_model_prefix=adapter_model_prefix, + provider=provider, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + elif isinstance(module.self_attention, GatedDeltaNet): + wrap_gated_delta_net_attention( + module.self_attention, + adapter_model_prefix=adapter_model_prefix, + provider=provider, + 
target_modules=target_set, + rank=rank, + alpha=alpha, + ) + else: + raise TypeError( + "Unsupported self_attention module type for Megatron LoRA: " + f"{type(module.self_attention)}" + ) + self._wrap_mlp_lora( + module, + adapter_model_prefix=adapter_model_prefix, + provider=provider, + target_modules=target_set, + rank=rank, + alpha=alpha, + ) + + def build_adapter_weights_by_base( + self, + model_chunks: Sequence[Any], + ) -> dict[str, list[Any]]: + from megatron.core.transformer.attention import SelfAttention + from megatron.core.transformer.transformer_layer import TransformerLayer + + from art.megatron.lora import _is_language_transformer_layer_name + from art.megatron.weights.adapter_export import ( + add_gated_delta_net_adapter_weights, + add_standard_self_attention_adapter_weights, + layer_base_prefix, + ) + + _ensure_bridge_qwen35_adapter_name_map() + adapter_weights_by_base: dict[str, list[Any]] = {} + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + if not isinstance(module, TransformerLayer): + continue + if not _is_language_transformer_layer_name(module_name): + continue + layer_prefix = layer_base_prefix(module, module_name=module_name) + if isinstance(module.self_attention, SelfAttention): + add_standard_self_attention_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + self_attention=module.self_attention, + ) + elif isinstance(module.self_attention, GatedDeltaNet): + add_gated_delta_net_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + self_attention=module.self_attention, + ) + self._add_mlp_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + module=module, + ) + return adapter_weights_by_base + + def _wrap_mlp_lora( + self, + module: Any, + *, + adapter_model_prefix: str, + provider: Any, + target_modules: set[str], + rank: int, + alpha: int, + ) -> None: + from art.megatron.lora import wrap_dense_mlp + + _require_dense_mlp(module) + wrap_dense_mlp( + module.mlp, + adapter_model_prefix=adapter_model_prefix, + provider=provider, + target_modules=target_modules, + rank=rank, + alpha=alpha, + ) + + def _add_mlp_adapter_weights( + self, + adapter_weights_by_base: dict[str, list[Any]], + *, + layer_prefix: str, + module: Any, + ) -> None: + from art.megatron.weights.adapter_export import add_dense_mlp_adapter_weights + + _require_dense_mlp(module) + add_dense_mlp_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + mlp=module.mlp, + ) + + def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: + unwrapped = model + while hasattr(unwrapped, "module"): + unwrapped = unwrapped.module + if type(unwrapped).__name__ == "Qwen3VLModel": + return {"extra_block_kwargs": {"extra_block_kwargs": kwargs}} + return {"extra_block_kwargs": kwargs} + + +class Qwen35DenseHandler(Qwen35BaseHandler): + key = "qwen3_5_dense" + + +class Qwen35MoeHandler(Qwen35BaseHandler): + key = "qwen3_5_moe" + is_moe = True + + def to_vllm_lora_tensors( + self, + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], + ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + return _to_vllm_lora_tensors(tensors, adapter_config=adapter_config) + + def from_vllm_lora_tensors( + self, + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], + ) -> dict[str, torch.Tensor]: + return _from_vllm_lora_tensors(tensors, adapter_config=adapter_config) + + def configure_provider_for_runtime(self, provider: Any) -> None: + 
super().configure_provider_for_runtime(provider) + provider.moe_shared_expert_overlap = False + + def collect_layer_families(self, provider: Any) -> list[LayerFamilyInstance]: + if int(getattr(provider, "num_moe_experts", 0) or 0) <= 0: + raise TypeError("MoE Qwen3.5 handler received a dense provider") + layer_families = [ + *self._attention_layer_families(provider), + LayerFamilyInstance(key="grouped_moe_mlp", layer_index=0), + ] + if int(getattr(provider, "moe_shared_expert_intermediate_size", 0) or 0) > 0: + layer_families.append( + LayerFamilyInstance(key="shared_experts_mlp", layer_index=0) + ) + return layer_families + + def _wrap_mlp_lora( + self, + module: Any, + *, + adapter_model_prefix: str, + provider: Any, + target_modules: set[str], + rank: int, + alpha: int, + ) -> None: + from art.megatron.lora import wrap_grouped_moe_experts_3d + + wrap_grouped_moe_experts_3d( + _require_moe_experts(module), + adapter_model_prefix=adapter_model_prefix, + target_modules=target_modules, + rank=rank, + alpha=alpha, + ) + + def _add_mlp_adapter_weights( + self, + adapter_weights_by_base: dict[str, list[Any]], + *, + layer_prefix: str, + module: Any, + ) -> None: + from art.megatron.weights.adapter_export import ( + add_grouped_moe_adapter_weights, + ) + + add_grouped_moe_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + experts=_require_moe_experts(module), + ) + + def compile_workaround_config( + self, + provider: Any, + ) -> CompileWorkaroundConfig: + if bool(getattr(provider, "moe_shared_expert_overlap", False)): + return CompileWorkaroundConfig( + flags=("moe_forward",), + shared_expert_state="shared_expert_overlap", + disable_compile=True, + ) + return CompileWorkaroundConfig( + flags=_QWEN35_MOE_COMPILE_WORKAROUND_FLAGS, + shared_expert_state="shared_experts", + disable_compile=False, + ) + + +QWEN3_5_DENSE_HANDLER = Qwen35DenseHandler() +QWEN3_5_MOE_HANDLER = Qwen35MoeHandler() + + +def _to_vllm_key(key: str) -> str: + return ( + key.replace(_ART_LAYER_PREFIX, _VLLM_LAYER_PREFIX, 1) + if key.startswith(_ART_LAYER_PREFIX) + else key + ) + + +def _from_vllm_key(key: str) -> str: + return ( + key.replace(_VLLM_LAYER_PREFIX, _ART_LAYER_PREFIX, 1) + if key.startswith(_VLLM_LAYER_PREFIX) + else key + ) + + +def _is_lora_weight_key(key: str) -> bool: + return key.endswith((".lora_A.weight", ".lora_B.weight")) + + +def _is_self_attn_q_proj_lora_b(key: str) -> bool: + return key.endswith(".self_attn.q_proj.lora_B.weight") + + +@lru_cache(maxsize=8) +def _qwen35_text_config(base_model_name_or_path: str) -> Any: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained( + base_model_name_or_path, + local_files_only=True, + trust_remote_code=True, + ) + return getattr(config, "text_config", config) + + +def _qwen35_attention_dims(adapter_config: dict[str, Any]) -> tuple[int, int, int]: + num_heads = adapter_config.get("num_attention_heads") + num_groups = adapter_config.get("num_key_value_heads") + head_dim = adapter_config.get("head_dim") + hidden_size = adapter_config.get("hidden_size") + if num_heads is None: + base_model = adapter_config.get("base_model_name_or_path") + if not base_model: + raise RuntimeError("Qwen3.5 LoRA adapter config is missing base model path") + config = _qwen35_text_config(str(base_model)) + num_heads = getattr(config, "num_attention_heads") + num_groups = getattr(config, "num_key_value_heads", num_heads) + head_dim = getattr(config, "head_dim", None) + hidden_size = getattr(config, "hidden_size", None) + num_heads = 
int(num_heads) + num_groups = int(num_groups if num_groups is not None else num_heads) + if head_dim is None: + if hidden_size is None: + raise RuntimeError("Qwen3.5 config is missing head_dim and hidden_size") + head_dim = int(hidden_size) // num_heads + head_dim = int(head_dim) + if num_heads % num_groups != 0: + raise RuntimeError( + f"Qwen3.5 attention heads {num_heads} are not divisible by " + f"query groups {num_groups}" + ) + return num_heads, num_groups, head_dim + + +def _qwen35_q_proj_lora_b_to_vllm( + tensor: torch.Tensor, + adapter_config: dict[str, Any], +) -> torch.Tensor: + num_heads, num_groups, head_dim = _qwen35_attention_dims(adapter_config) + heads_per_group = num_heads // num_groups + expected_rows = num_groups * 2 * heads_per_group * head_dim + if tensor.shape[0] != expected_rows: + raise RuntimeError( + f"Qwen3.5 q_proj LoRA-B rows {tensor.shape[0]} do not match " + f"attention output rows {expected_rows}" + ) + rank = tensor.shape[1] + grouped = tensor.reshape(num_groups, 2 * heads_per_group, head_dim, rank) + query = grouped[:, :heads_per_group] + gate = grouped[:, heads_per_group:] + return torch.cat((query, gate), dim=2).reshape(tensor.shape).contiguous() + + +def _qwen35_q_proj_lora_b_from_vllm( + tensor: torch.Tensor, + adapter_config: dict[str, Any], +) -> torch.Tensor: + num_heads, num_groups, head_dim = _qwen35_attention_dims(adapter_config) + heads_per_group = num_heads // num_groups + expected_rows = num_groups * heads_per_group * 2 * head_dim + if tensor.shape[0] != expected_rows: + raise RuntimeError( + f"Qwen3.5 q_proj LoRA-B rows {tensor.shape[0]} do not match " + f"attention output rows {expected_rows}" + ) + rank = tensor.shape[1] + per_head = tensor.reshape(num_groups, heads_per_group, 2 * head_dim, rank) + query, gate = per_head.split(head_dim, dim=2) + return torch.cat((query, gate), dim=1).reshape(tensor.shape).contiguous() + + +def _to_vllm_lora_tensor( + key: str, + tensor: torch.Tensor, + *, + adapter_config: dict[str, Any], +) -> tuple[str, torch.Tensor]: + vllm_key = _to_vllm_key(key) + if _is_self_attn_q_proj_lora_b(vllm_key): + tensor = _qwen35_q_proj_lora_b_to_vllm(tensor, adapter_config) + return vllm_key, tensor + + +def _from_vllm_lora_tensor( + key: str, + tensor: torch.Tensor, + *, + adapter_config: dict[str, Any], +) -> tuple[str, torch.Tensor]: + art_key = _from_vllm_key(key) + if _is_self_attn_q_proj_lora_b(art_key): + tensor = _qwen35_q_proj_lora_b_from_vllm(tensor, adapter_config) + return art_key, tensor + + +def _pack_vllm_3d_lora_b(blocks: list[torch.Tensor]) -> torch.Tensor: + stacked = torch.stack(blocks, dim=0) + return stacked.permute(1, 2, 0).reshape(stacked.shape[1], -1).contiguous() + + +def _unpack_vllm_3d_lora_b( + tensor: torch.Tensor, + *, + num_experts: int, + rank: int, +) -> torch.Tensor: + return tensor.reshape(tensor.shape[0], rank, num_experts).permute(2, 0, 1) + + +def _vllm_moe_config(adapter_config: dict[str, Any]) -> dict[str, Any]: + config = dict(adapter_config) + target_modules = [ + module + for module in list(config.get("target_modules") or []) + if module not in {"gate_proj", "up_proj", "down_proj", "gate_up_proj"} + ] + if "experts" not in target_modules: + target_modules.append("experts") + config["target_modules"] = target_modules + return config + + +def _group_art_moe_tensors( + tensors: dict[str, torch.Tensor], +) -> dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]]: + grouped: dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]] = {} + for key, tensor in tensors.items(): + match = 
_ART_MOE_EXPERT_KEY_RE.match(key) + if match is None: + continue + grouped.setdefault(match.group("prefix"), {}).setdefault( + int(match.group("expert")), + {}, + ).setdefault(match.group("module"), {})[match.group("lora")] = tensor + return grouped + + +def _to_vllm_lora_tensors( + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], +) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + grouped = _group_art_moe_tensors(tensors) + if not grouped: + transformed: dict[str, torch.Tensor] = {} + for key, tensor in tensors.items(): + vllm_key, tensor = _to_vllm_lora_tensor( + key, + tensor, + adapter_config=adapter_config, + ) + transformed[vllm_key] = tensor + return transformed, adapter_config + transformed: dict[str, torch.Tensor] = {} + used_keys: set[str] = set() + for prefix, experts in grouped.items(): + vllm_prefix = _to_vllm_key(prefix) + gate_up_a: list[torch.Tensor] = [] + gate_up_b: list[torch.Tensor] = [] + down_a: list[torch.Tensor] = [] + down_b: list[torch.Tensor] = [] + for expert in sorted(experts): + modules = experts[expert] + try: + gate_up_a_tensor = modules["gate_up_proj"]["lora_A"] + gate_up_b_tensor = modules["gate_up_proj"]["lora_B"] + d_a = modules["down_proj"]["lora_A"] + d_b = modules["down_proj"]["lora_B"] + except KeyError as exc: + raise RuntimeError( + f"Incomplete Qwen3.5 MoE LoRA block for {prefix}.{expert}" + ) from exc + gate_up_a.append(gate_up_a_tensor.contiguous()) + gate_up_b.append(gate_up_b_tensor.contiguous()) + down_a.append(d_a.contiguous()) + down_b.append(d_b.contiguous()) + for module_name in ("gate_up_proj", "down_proj"): + for lora_name in ("lora_A", "lora_B"): + used_keys.add(f"{prefix}.{expert}.{module_name}.{lora_name}.weight") + transformed[f"{vllm_prefix}.base_layer.lora_A.weight"] = torch.cat( + gate_up_a, + dim=0, + ).contiguous() + transformed[f"{vllm_prefix}.base_layer.lora_B.weight"] = _pack_vllm_3d_lora_b( + gate_up_b + ) + transformed[f"{vllm_prefix}.lora_A.weight"] = torch.cat( + down_a, + dim=0, + ).contiguous() + transformed[f"{vllm_prefix}.lora_B.weight"] = _pack_vllm_3d_lora_b(down_b) + for key, tensor in tensors.items(): + if key in used_keys: + continue + vllm_key, tensor = _to_vllm_lora_tensor( + key, + tensor, + adapter_config=adapter_config, + ) + transformed[vllm_key] = tensor + return transformed, _vllm_moe_config(adapter_config) + + +def _from_vllm_lora_tensors( + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], +) -> dict[str, torch.Tensor]: + grouped: dict[str, dict[str, torch.Tensor]] = {} + for key, tensor in tensors.items(): + match = _VLLM_MOE_KEY_RE.match(key) + if match is None: + continue + slot = ( + f"{'base_layer.' 
if match.group('base_layer') else ''}{match.group('lora')}" + ) + grouped.setdefault(match.group("prefix"), {})[slot] = tensor + if not grouped: + transformed: dict[str, torch.Tensor] = {} + for key, tensor in tensors.items(): + art_key, tensor = _from_vllm_lora_tensor( + key, + tensor, + adapter_config=adapter_config, + ) + transformed[art_key] = tensor + return transformed + + rank = int(adapter_config["r"]) + transformed: dict[str, torch.Tensor] = {} + used_keys: set[str] = set() + for prefix, slots in grouped.items(): + try: + gate_up_a = slots["base_layer.lora_A"] + gate_up_b = slots["base_layer.lora_B"] + down_a = slots["lora_A"] + down_b = slots["lora_B"] + except KeyError as exc: + raise RuntimeError( + f"Incomplete Qwen3.5 vLLM MoE LoRA block for {prefix}" + ) from exc + if gate_up_a.shape[0] % rank != 0: + raise RuntimeError( + f"{prefix}: gate/up lora_A shape {tuple(gate_up_a.shape)} " + f"is not divisible by rank {rank}" + ) + num_experts = gate_up_a.shape[0] // rank + art_prefix = _from_vllm_key(prefix) + gate_up_b_by_expert = _unpack_vllm_3d_lora_b( + gate_up_b, + num_experts=num_experts, + rank=rank, + ) + down_b_by_expert = _unpack_vllm_3d_lora_b( + down_b, + num_experts=num_experts, + rank=rank, + ) + for expert in range(num_experts): + row = expert * rank + gate_up_a_block = gate_up_a[row : row + rank] + down_a_block = down_a[row : row + rank] + gate_up_b_block = gate_up_b_by_expert[expert] + down_b_block = down_b_by_expert[expert] + transformed[f"{art_prefix}.{expert}.gate_up_proj.lora_A.weight"] = ( + gate_up_a_block.contiguous() + ) + transformed[f"{art_prefix}.{expert}.gate_up_proj.lora_B.weight"] = ( + gate_up_b_block.contiguous() + ) + transformed[f"{art_prefix}.{expert}.down_proj.lora_A.weight"] = ( + down_a_block.contiguous() + ) + transformed[f"{art_prefix}.{expert}.down_proj.lora_B.weight"] = ( + down_b_block.contiguous() + ) + used_keys.update( + { + f"{prefix}.base_layer.lora_A.weight", + f"{prefix}.base_layer.lora_B.weight", + f"{prefix}.lora_A.weight", + f"{prefix}.lora_B.weight", + } + ) + for key, tensor in tensors.items(): + if key in used_keys: + continue + art_key, tensor = _from_vllm_lora_tensor( + key, + tensor, + adapter_config=adapter_config, + ) + transformed[art_key] = tensor + return transformed + + +def _ensure_bridge_qwen35_adapter_name_map() -> None: + from megatron.bridge.models.conversion import peft_bridge + + extra_entries = { + ".in_proj_qkv.weight": "adapter_qkv", + ".in_proj_z.weight": "adapter_z", + ".in_proj_b.weight": "adapter_b", + ".in_proj_a.weight": "adapter_a", + } + for suffix, adapter_key in extra_entries.items(): + peft_bridge.ADAPTER_NAME_MAP.setdefault(suffix, adapter_key) + peft_bridge.ADAPTER_KEY_TO_SUFFIX.setdefault(adapter_key, suffix) + + +def _qwen35_provider_types() -> tuple[type[Any], ...]: + from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( + Qwen35VLModelProvider, + Qwen35VLMoEModelProvider, + ) + + return (Qwen35VLModelProvider, Qwen35VLMoEModelProvider) + + +def _require_qwen35_provider_symbols() -> tuple[Any, ...]: + from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.attention import ( + Qwen3VLSelfAttention, + ) + from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( + Qwen35VLModelProvider, + Qwen35VLMoEModelProvider, + _patch_standard_attention_specs, + ) + from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_transformer_block_with_experimental_attention_variant_spec, + ) + + return ( + Qwen3VLSelfAttention, + (Qwen35VLModelProvider, 
Qwen35VLMoEModelProvider), + _patch_standard_attention_specs, + get_transformer_block_with_experimental_attention_variant_spec, + ) + + +def _register_qwen35_text_only_module_types() -> None: + from megatron.bridge.models.conversion.param_mapping import AutoMapping + + AutoMapping.register_module_type("SharedExpertMLP", "column") + AutoMapping.register_module_type("GatedDeltaNet", "column") + + +def _qwen35_text_only_mapping_registry( + bridge_type: type[Any] | None = None, +) -> Any: + from megatron.bridge.models.conversion.mapping_registry import ( + MegatronMappingRegistry, + ) + from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import ( + Qwen35VLBridge, + Qwen35VLMoEBridge, + ) + + _register_qwen35_text_only_module_types() + upstream_bridge_type = bridge_type or Qwen35VLMoEBridge + assert upstream_bridge_type in {Qwen35VLBridge, Qwen35VLMoEBridge} + upstream_registry = upstream_bridge_type().mapping_registry() + language_mappings = [ + _text_only_qwen35_mapping(mapping) + for mapping in upstream_registry.mappings + if mapping.megatron_param.startswith("language_model.") + and not mapping.megatron_param.startswith("language_model.mtp.") + ] + return MegatronMappingRegistry(*language_mappings) + + +def _text_only_qwen35_mapping(mapping: Any) -> Any: + from megatron.bridge.models.qwen_vl.qwen3_vl_bridge import ( + FusedExpertMapping, + FusedGatedExpertMapping, + ) + + megatron_param = mapping.megatron_param.removeprefix("language_model.") + if isinstance(mapping, FusedGatedExpertMapping): + return _ArtExpertMLPGateUpProjMapping(megatron_param, mapping.hf_param) + if isinstance(mapping, FusedExpertMapping): + return _ArtExpertMLPDownProjMapping(megatron_param, mapping.hf_param) + cloned = copy(mapping) + cloned.megatron_param = megatron_param + return cloned + + +from megatron.bridge.models.qwen_vl.qwen3_vl_bridge import ( + FusedExpertMapping as _BridgeExpertMLPDownProjMapping, +) +from megatron.bridge.models.qwen_vl.qwen3_vl_bridge import ( + FusedGatedExpertMapping as _BridgeExpertMLPGateUpProjMapping, +) + + +class _ArtExpertMLPGateUpProjMapping(_BridgeExpertMLPGateUpProjMapping): + def hf_to_megatron( + self, + hf_weights: torch.Tensor | dict[str, torch.Tensor], + megatron_module: Any, + ) -> torch.Tensor: + from megatron.bridge.models.conversion.param_mapping import ( + _align_expert_weight_to_shape, + ) + from megatron.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + ) + from megatron.bridge.utils.common_utils import ( + extract_expert_number_from_param, + ) + + global_expert_number = extract_expert_number_from_param(self.megatron_param) + expert_weight = ( + hf_weights[global_expert_number] + if isinstance(hf_weights, torch.Tensor) and hf_weights.ndim >= 3 + else hf_weights + ) + normalized_param = self._normalize_expert_param_name(self.megatron_param) + target_param = get_module_and_param_from_name( + megatron_module, normalized_param + )[1] + full_target_shape = ( + target_param.shape[0] * self.tp_size, + target_param.shape[1], + ) + gate_target_shape = ( + full_target_shape[0] // 2, + full_target_shape[1], + ) + if full_target_shape[0] % 2 != 0: + raise ValueError( + f"Expected even fused dim for {self.megatron_param}, got {full_target_shape}." 
+ ) + if ( + isinstance(expert_weight, torch.Tensor) + and expert_weight.ndim == 3 + and expert_weight.shape[0] == 2 + ): + gate = _align_expert_weight_to_shape( + expert_weight[0], torch.Size(gate_target_shape), "gate" + ) + up = _align_expert_weight_to_shape( + expert_weight[1], torch.Size(gate_target_shape), "up" + ) + else: + fused = _align_expert_weight_to_shape( + cast(torch.Tensor, expert_weight), + torch.Size(full_target_shape), + "gate_up", + ) + gate, up = torch.chunk(fused, 2, dim=0) + return self._gated_mapping.hf_to_megatron( + {"gate": gate, "up": up}, + megatron_module, + ) + + +class _ArtExpertMLPDownProjMapping(_BridgeExpertMLPDownProjMapping): + def hf_to_megatron( + self, + hf_weights: torch.Tensor, + megatron_module: Any, + ) -> torch.Tensor: + from megatron.bridge.models.conversion.param_mapping import ( + ColumnParallelMapping, + RowParallelMapping, + _align_expert_weight_to_shape, + ) + from megatron.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + ) + from megatron.bridge.utils.common_utils import ( + extract_expert_number_from_param, + ) + + global_expert_number = extract_expert_number_from_param(self.megatron_param) + expert_weight = ( + hf_weights[global_expert_number] if hf_weights.ndim >= 3 else hf_weights + ) + normalized_param = self._normalize_expert_param_name(self.megatron_param) + target_param = get_module_and_param_from_name( + megatron_module, normalized_param + )[1] + if self._mapping is None: + self._detected_type = self._detect_parallelism_type(megatron_module) + self._mapping = self._get_or_create_mapping(self._detected_type) + if isinstance(self._mapping, ColumnParallelMapping): + full_target_shape = ( + target_param.shape[0] * self.tp_size, + target_param.shape[1], + ) + elif isinstance(self._mapping, RowParallelMapping): + full_target_shape = ( + target_param.shape[0], + target_param.shape[1] * self.tp_size, + ) + else: + full_target_shape = tuple(target_param.shape) + aligned = _align_expert_weight_to_shape( + expert_weight, + torch.Size(full_target_shape), + "down_proj", + ) + return self._mapping.hf_to_megatron(aligned, megatron_module) + + +def _ensure_qwen35_text_only_bridge_registered() -> None: + return None + + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import ( + _QWEN3_5_DENSE_HF_CLASS_NAME, + _QWEN3_5_MOE_HF_CLASS_NAME, + Qwen35VLBridge, + Qwen35VLMoEBridge, +) +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( + Qwen35VLModelProvider, + Qwen35VLMoEModelProvider, +) + + +@MegatronModelBridge.register_bridge( + source=_QWEN3_5_DENSE_HF_CLASS_NAME, + target=GPTModel, + provider=Qwen35VLModelProvider, + model_type="qwen3_5_moe", +) +class _ArtQwen35DenseTextOnlyBridge(Qwen35VLBridge): + def mapping_registry(self) -> Any: + return _qwen35_text_only_mapping_registry(Qwen35VLBridge) + + +@MegatronModelBridge.register_bridge( + source=_QWEN3_5_MOE_HF_CLASS_NAME, + target=GPTModel, + provider=Qwen35VLMoEModelProvider, + model_type="qwen3_5_moe", +) +class _ArtQwen35TextOnlyBridge(Qwen35VLMoEBridge): + def mapping_registry(self) -> Any: + return _qwen35_text_only_mapping_registry(Qwen35VLMoEBridge) + + +def _linear_attention_pattern(provider: Any) -> list[int]: + from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_linear_attention_pattern, + ) + + return list(get_linear_attention_pattern(provider)) diff --git a/src/art/megatron/model_support/handlers/qwen3_common.py 
b/src/art/megatron/model_support/handlers/qwen3_common.py new file mode 100644 index 000000000..d8cca9754 --- /dev/null +++ b/src/art/megatron/model_support/handlers/qwen3_common.py @@ -0,0 +1,42 @@ +from typing import Any, Sequence, cast + +from megatron.core.models.gpt.gpt_model import GPTModel +import torch + +from art.megatron.training.model_chunks import ModelChunks + + +def install_qwen3_text_preprocess_patch(model_chunks: Sequence[Any]) -> None: + for chunk in cast(ModelChunks, list(model_chunks)): + module: Any = chunk + while hasattr(module, "module"): + module = module.module + gpt_module = ( + module + if isinstance(module, GPTModel) + else cast(GPTModel, getattr(module, "language_model")) + ) + preprocess = gpt_module._preprocess + + def preprocess_hook(*args, _preprocess=preprocess, **kwargs): + preproc_output = list(_preprocess(*args, **kwargs)) + decoder_input = cast(torch.Tensor, preproc_output[0]) + if not decoder_input.requires_grad and decoder_input.is_leaf: + decoder_input.requires_grad_(True) + position_ids = cast(torch.Tensor, kwargs["position_ids"]) + table = cast(torch.Tensor, preproc_output[1]) + embedding_dim = int(table.shape[-1]) + batch_size, sequence_length = position_ids.shape + gathered = table.view(table.shape[0], embedding_dim).index_select( + 0, + position_ids.reshape(-1), + ) + preproc_output[1] = ( + gathered.view(batch_size, sequence_length, embedding_dim) + .permute(1, 0, 2) + .contiguous() + .unsqueeze(2) + ) + return tuple(preproc_output) + + gpt_module._preprocess = preprocess_hook # type: ignore[attr-defined] diff --git a/src/art/megatron/model_support/handlers/qwen3_dense.py b/src/art/megatron/model_support/handlers/qwen3_dense.py new file mode 100644 index 000000000..5cf76e222 --- /dev/null +++ b/src/art/megatron/model_support/handlers/qwen3_dense.py @@ -0,0 +1,17 @@ +from typing import Any, Sequence + +from art.megatron.model_support.handlers.default_dense import DefaultDenseHandler +from art.megatron.model_support.handlers.qwen3_common import ( + install_qwen3_text_preprocess_patch, +) + + +class Qwen3DenseHandler(DefaultDenseHandler): + key = "qwen3_dense" + native_vllm_lora_status = "validated" + + def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None: + install_qwen3_text_preprocess_patch(model_chunks) + + +QWEN3_DENSE_HANDLER = Qwen3DenseHandler() diff --git a/src/art/megatron/model_support/handlers/qwen3_moe.py b/src/art/megatron/model_support/handlers/qwen3_moe.py new file mode 100644 index 000000000..45656f774 --- /dev/null +++ b/src/art/megatron/model_support/handlers/qwen3_moe.py @@ -0,0 +1,31 @@ +from typing import Any, Sequence + +from art.megatron.model_support.handlers.default_dense import DefaultMoeHandler +from art.megatron.model_support.handlers.qwen3_common import ( + install_qwen3_text_preprocess_patch, +) +from art.megatron.model_support.spec import CompileWorkaroundConfig + +_QWEN3_MOE_COMPILE_WORKAROUND_FLAGS = ( + "alltoall_dtoh", + "alltoall_dispatch_preprocess", + "deepep_permute_restore", +) + + +class Qwen3MoeHandler(DefaultMoeHandler): + key = "qwen3_moe" + native_vllm_lora_status = "validated" + + def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None: + install_qwen3_text_preprocess_patch(model_chunks) + + def compile_workaround_config( + self, + provider: Any, + ) -> CompileWorkaroundConfig: + del provider + return CompileWorkaroundConfig(flags=_QWEN3_MOE_COMPILE_WORKAROUND_FLAGS) + + +QWEN3_MOE_HANDLER = Qwen3MoeHandler() diff --git a/src/art/megatron/model_support/lora_disk.py 
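install_qwen3_text_preprocess_patch above swaps the decoder's rotary table for rows gathered per position id and reshaped into the [sequence, batch, 1, dim] layout the decoder expects. A standalone sketch of that gather with hypothetical sizes (the real table comes from Megatron's rotary embedding and is flattened with .view first):

import torch

batch_size, sequence_length, embedding_dim, max_positions = 2, 5, 8, 32
table = torch.randn(max_positions, embedding_dim)
position_ids = torch.randint(0, max_positions, (batch_size, sequence_length))

# Gather one table row per position id, then move to [seq, batch, 1, dim].
gathered = table.view(table.shape[0], embedding_dim).index_select(
    0, position_ids.reshape(-1)
)
rotary = (
    gathered.view(batch_size, sequence_length, embedding_dim)
    .permute(1, 0, 2)
    .contiguous()
    .unsqueeze(2)
)
assert rotary.shape == (sequence_length, batch_size, 1, embedding_dim)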
b/src/art/megatron/model_support/lora_disk.py new file mode 100644 index 000000000..9df8d345a --- /dev/null +++ b/src/art/megatron/model_support/lora_disk.py @@ -0,0 +1,122 @@ +import importlib +import json +from pathlib import Path +from typing import Any + +import torch + +safetensors = importlib.import_module("safetensors") +safetensors_torch = importlib.import_module("safetensors.torch") +safe_open = safetensors.safe_open +save_file = safetensors_torch.save_file + + +def _jsonable_config(value: Any) -> Any: + if isinstance(value, dict): + return {key: _jsonable_config(item) for key, item in value.items()} + if isinstance(value, set): + return [_jsonable_config(item) for item in sorted(value, key=str)] + if isinstance(value, (list, tuple)): + return [_jsonable_config(item) for item in value] + return value + + +def load_adapter_config(lora_path: str | Path) -> dict[str, Any]: + config_path = Path(lora_path) / "adapter_config.json" + if not config_path.exists(): + return {} + with config_path.open("r", encoding="utf-8") as config_file: + config = json.load(config_file) + return config if isinstance(config, dict) else {} + + +def save_adapter_config(lora_path: str | Path, adapter_config: dict[str, Any]) -> None: + config_path = Path(lora_path) / "adapter_config.json" + with config_path.open("w", encoding="utf-8") as config_file: + json.dump( + _jsonable_config(adapter_config), + config_file, + indent=2, + sort_keys=True, + ) + config_file.write("\n") + + +def resolve_lora_handler( + lora_path: str | Path, + handler: Any | None = None, + *, + allow_unvalidated_arch: bool = False, +) -> Any: + if handler is not None: + return handler + base_model = load_adapter_config(lora_path).get("base_model_name_or_path") + if not isinstance(base_model, str) or not base_model: + raise RuntimeError(f"Missing base_model_name_or_path in {lora_path}") + from art.megatron.model_support import get_model_support_handler + + return get_model_support_handler( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + + +def load_vllm_lora_tensors( + lora_path: str | Path, +) -> dict[str, torch.Tensor]: + adapter_model_path = Path(lora_path) / "adapter_model.safetensors" + with safe_open(adapter_model_path, framework="pt") as adapter_file: + return {key: adapter_file.get_tensor(key) for key in adapter_file.keys()} + + +def save_vllm_lora_tensors( + lora_path: str | Path, + tensors: dict[str, torch.Tensor], + adapter_config: dict[str, Any], +) -> None: + base_dir = Path(lora_path) + base_dir.mkdir(parents=True, exist_ok=True) + save_file(tensors, base_dir / "adapter_model.safetensors") + save_adapter_config(base_dir, adapter_config) + + +def normalize_lora_checkpoint_to_vllm( + lora_path: str | Path, + *, + handler: Any | None = None, + adapter_config: dict[str, Any] | None = None, + allow_unvalidated_arch: bool = False, +) -> None: + adapter_model_path = Path(lora_path) / "adapter_model.safetensors" + if not adapter_model_path.exists(): + return + resolved_handler = resolve_lora_handler( + lora_path, + handler, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + if adapter_config is None: + adapter_config = load_adapter_config(lora_path) + tensors = load_vllm_lora_tensors(lora_path) + tensors, adapter_config = resolved_handler.to_vllm_lora_tensors( + tensors, + adapter_config=adapter_config, + ) + save_vllm_lora_tensors(lora_path, tensors, adapter_config) + + +def load_lora_tensors_for_megatron( + lora_path: str | Path, + *, + handler: Any | None = None, + allow_unvalidated_arch: bool = False, +) -> 
dict[str, torch.Tensor]: + resolved_handler = resolve_lora_handler( + lora_path, + handler, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + return resolved_handler.from_vllm_lora_tensors( + load_vllm_lora_tensors(lora_path), + adapter_config=load_adapter_config(lora_path), + ) diff --git a/src/art/megatron/model_support/registry.py b/src/art/megatron/model_support/registry.py new file mode 100644 index 000000000..6a9a3c729 --- /dev/null +++ b/src/art/megatron/model_support/registry.py @@ -0,0 +1,254 @@ +from art.megatron.model_support.handlers import ( + DEFAULT_DENSE_HANDLER, + QWEN3_5_DENSE_HANDLER, + QWEN3_5_MOE_HANDLER, + QWEN3_DENSE_HANDLER, + QWEN3_MOE_HANDLER, +) +from art.megatron.model_support.spec import ( + DependencyFloor, + ModelSupportHandler, + ModelSupportSpec, +) + +_DENSE_TARGET_MODULES = ( + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +) + +_QWEN3_5_DENSE_TARGET_MODULES = ( + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "in_proj_qkv", + "in_proj_z", + "out_proj", + "gate_proj", + "up_proj", + "down_proj", +) + +_QWEN3_5_MOE_TARGET_MODULES = ( + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "in_proj_qkv", + "in_proj_z", + "out_proj", + "experts", +) + +DEFAULT_DENSE_SPEC = ModelSupportSpec( + key="default_dense", + handler_key=DEFAULT_DENSE_HANDLER.key, + default_target_modules=_DENSE_TARGET_MODULES, + native_vllm_lora_status=DEFAULT_DENSE_HANDLER.native_vllm_lora_status, +) + +QWEN3_MOE_SPEC = ModelSupportSpec( + key="qwen3_moe", + handler_key=QWEN3_MOE_HANDLER.key, + model_names=( + "Qwen/Qwen3-30B-A3B", + "Qwen/Qwen3-30B-A3B-Base", + "Qwen/Qwen3-30B-A3B-Instruct-2507", + "Qwen/Qwen3-235B-A22B-Instruct-2507", + ), + default_target_modules=_DENSE_TARGET_MODULES, + native_vllm_lora_status=QWEN3_MOE_HANDLER.native_vllm_lora_status, +) + +QWEN3_DENSE_SPEC = ModelSupportSpec( + key="qwen3_dense", + handler_key=QWEN3_DENSE_HANDLER.key, + model_names=( + "Qwen/Qwen3-0.6B", + "Qwen/Qwen3-0.6B-Base", + "Qwen/Qwen3-1.7B", + "Qwen/Qwen3-1.7B-Base", + "Qwen/Qwen3-4B", + "Qwen/Qwen3-4B-Base", + "Qwen/Qwen3-4B-Instruct-2507", + "Qwen/Qwen3-8B", + "Qwen/Qwen3-8B-Base", + "Qwen/Qwen3-14B", + "Qwen/Qwen3-14B-Base", + "Qwen/Qwen3-32B", + "Qwen/Qwen3-32B-Base", + ), + default_target_modules=_DENSE_TARGET_MODULES, + native_vllm_lora_status=QWEN3_DENSE_HANDLER.native_vllm_lora_status, +) + +QWEN3_5_DENSE_SPEC = ModelSupportSpec( + key="qwen3_5_dense", + handler_key=QWEN3_5_DENSE_HANDLER.key, + model_names=( + "Qwen/Qwen3.5-4B", + "Qwen/Qwen3.5-27B", + "Qwen/Qwen3.6-27B", + ), + default_target_modules=_QWEN3_5_DENSE_TARGET_MODULES, + native_vllm_lora_status=QWEN3_5_DENSE_HANDLER.native_vllm_lora_status, + dependency_floor=DependencyFloor( + megatron_bridge="e049cc00c24d03e2ae45d2608c7a44e2d2364e3d", + ), +) + +QWEN3_5_MOE_SPEC = ModelSupportSpec( + key="qwen3_5_moe", + handler_key=QWEN3_5_MOE_HANDLER.key, + model_names=( + "Qwen/Qwen3.5-35B-A3B", + "Qwen/Qwen3.5-397B-A17B", + "Qwen/Qwen3.6-35B-A3B", + ), + default_target_modules=_QWEN3_5_MOE_TARGET_MODULES, + native_vllm_lora_status=QWEN3_5_MOE_HANDLER.native_vllm_lora_status, + dependency_floor=DependencyFloor( + megatron_bridge="e049cc00c24d03e2ae45d2608c7a44e2d2364e3d", + ), +) + +VALIDATED_MODEL_SUPPORT_SPECS = ( + QWEN3_MOE_SPEC, + QWEN3_DENSE_SPEC, + QWEN3_5_MOE_SPEC, + QWEN3_5_DENSE_SPEC, +) +PROBE_ONLY_MODEL_SUPPORT_SPECS = () +_ALL_MODEL_SUPPORT_SPECS = ( + DEFAULT_DENSE_SPEC, + *VALIDATED_MODEL_SUPPORT_SPECS, + *PROBE_ONLY_MODEL_SUPPORT_SPECS, +) +_SPECS_BY_KEY = {spec.key: 
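A hypothetical round trip through the lora_disk helpers above: after a training step writes an adapter directory, normalize it into the vLLM layout once, then read it back in the Megatron keyspace. Both calls resolve the model-support handler from base_model_name_or_path in adapter_config.json unless a handler is passed explicitly; the directory path below is made up.

from art.megatron.model_support.lora_disk import (
    load_lora_tensors_for_megatron,
    normalize_lora_checkpoint_to_vllm,
)

lora_dir = "/tmp/checkpoints/0001/lora"  # hypothetical adapter directory

# No-op if adapter_model.safetensors is missing; otherwise rewrites tensors and config in place.
normalize_lora_checkpoint_to_vllm(lora_dir)

# Load the same adapter back, converted into the Megatron-side tensor keyspace.
megatron_tensors = load_lora_tensors_for_megatron(lora_dir)
print(sorted(megatron_tensors)[:5])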
spec for spec in _ALL_MODEL_SUPPORT_SPECS} +_SPECS_BY_MODEL = { + model_name: spec + for spec in VALIDATED_MODEL_SUPPORT_SPECS + for model_name in spec.model_names +} +_UNVALIDATED_ARCH_SPECS_BY_MODEL = { + model_name: spec + for spec in PROBE_ONLY_MODEL_SUPPORT_SPECS + for model_name in spec.model_names +} +_HANDLERS_BY_KEY: dict[str, ModelSupportHandler] = { + DEFAULT_DENSE_HANDLER.key: DEFAULT_DENSE_HANDLER, + QWEN3_DENSE_HANDLER.key: QWEN3_DENSE_HANDLER, + QWEN3_MOE_HANDLER.key: QWEN3_MOE_HANDLER, + QWEN3_5_DENSE_HANDLER.key: QWEN3_5_DENSE_HANDLER, + QWEN3_5_MOE_HANDLER.key: QWEN3_5_MOE_HANDLER, +} + +QWEN3_DENSE_MODELS = frozenset(QWEN3_DENSE_SPEC.model_names) +QWEN3_MOE_MODELS = frozenset(QWEN3_MOE_SPEC.model_names) +QWEN3_5_DENSE_MODELS = frozenset(QWEN3_5_DENSE_SPEC.model_names) +QWEN3_5_MOE_MODELS = frozenset(QWEN3_5_MOE_SPEC.model_names) +QWEN3_5_MODELS = QWEN3_5_DENSE_MODELS | QWEN3_5_MOE_MODELS + + +class UnsupportedModelArchitectureError(ValueError): + """Raised when a model has not passed the Megatron support workflow.""" + + +def get_model_support_spec( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> ModelSupportSpec: + if spec := _SPECS_BY_MODEL.get(base_model): + return spec + if allow_unvalidated_arch: + return _UNVALIDATED_ARCH_SPECS_BY_MODEL.get(base_model, DEFAULT_DENSE_SPEC) + supported = ", ".join(sorted(_SPECS_BY_MODEL)) + raise UnsupportedModelArchitectureError( + f"{base_model!r} has not passed the Megatron model-support workflow. " + "Pass allow_unvalidated_arch=True only for explicit validation/probing. " + f"Supported models: {supported}." + ) + + +def get_model_support_handler( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> ModelSupportHandler: + return get_model_support_handler_for_spec( + get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + ) + + +def get_model_support_handler_for_spec( + spec: ModelSupportSpec, +) -> ModelSupportHandler: + return _HANDLERS_BY_KEY[spec.handler_key] + + +def default_target_modules_for_model( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> list[str]: + return list( + get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ).default_target_modules + ) + + +def native_vllm_lora_status_for_model( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> str: + return get_model_support_handler( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ).native_vllm_lora_status + + +def model_requires_merged_rollout( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> bool: + return ( + get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ).default_rollout_weights_mode + == "merged" + ) + + +def model_uses_expert_parallel( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> bool: + return bool( + get_model_support_handler( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ).is_moe + ) + + +def is_model_support_registered(base_model: str) -> bool: + return base_model in _SPECS_BY_MODEL + + +def list_model_support_specs() -> list[ModelSupportSpec]: + return list(VALIDATED_MODEL_SUPPORT_SPECS) diff --git a/src/art/megatron/model_support/spec.py b/src/art/megatron/model_support/spec.py new file mode 100644 index 000000000..1e5858c88 --- /dev/null +++ b/src/art/megatron/model_support/spec.py @@ -0,0 +1,153 @@ +from typing import Any, Literal, Protocol, Sequence + +from pydantic import BaseModel, 
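A quick look at how the registry above resolves models: names listed on a validated spec return that spec, anything else raises UnsupportedModelArchitectureError unless allow_unvalidated_arch=True is passed, in which case lookup falls back to the default dense spec (no probe-only specs are registered). The unvalidated model name below is made up.

from art.megatron.model_support.registry import (
    UnsupportedModelArchitectureError,
    default_target_modules_for_model,
    get_model_support_spec,
)

spec = get_model_support_spec("Qwen/Qwen3-30B-A3B")
assert spec.handler_key == "qwen3_moe"
print(default_target_modules_for_model("Qwen/Qwen3-30B-A3B"))

try:
    get_model_support_spec("example-org/unvalidated-model")
except UnsupportedModelArchitectureError as error:
    print(error)

fallback = get_model_support_spec(
    "example-org/unvalidated-model", allow_unvalidated_arch=True
)
assert fallback.key == "default_dense"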
Field + +RolloutWeightsMode = Literal["lora", "merged"] +NativeVllmLoraStatus = Literal["disabled", "wip", "validated"] +SharedExpertCompileState = Literal[ + "none", + "shared_experts", + "shared_expert_overlap", +] + + +class DependencyFloor(BaseModel): + transformers: str | None = None + vllm: str | None = None + megatron_bridge: str | None = None + + +class LayerFamilyInstance(BaseModel): + key: str + count: int = 1 + layer_index: int | None = None + module_path: str | None = None + module_type: str | None = None + + +class ArchitectureReport(BaseModel): + base_model: str + model_key: str + handler_key: str + bridge_type: str | None = None + provider_type: str | None = None + layer_families: list[LayerFamilyInstance] = Field(default_factory=list) + recommended_min_layers: int = 1 + unresolved_risks: list[str] = Field(default_factory=list) + + +class MinimalLayerCoverageReport(BaseModel): + base_model: str + model_key: str + requested_num_layers: int + recommended_min_layers: int + covered: bool + missing_layer_families: list[str] = Field(default_factory=list) + unresolved_risks: list[str] = Field(default_factory=list) + + +class ValidationStageResult(BaseModel): + name: str + passed: bool = False + metrics: dict[str, Any] = Field(default_factory=dict) + artifact_dir: str | None = None + + +class ValidationReport(BaseModel): + base_model: str + model_key: str + dependency_versions: dict[str, str] = Field(default_factory=dict) + stages: list[ValidationStageResult] = Field(default_factory=list) + + +class CompileWorkaroundConfig(BaseModel): + flags: tuple[str, ...] = () + shared_expert_state: SharedExpertCompileState = "none" + disable_compile: bool = False + + +class ModelSupportSpec(BaseModel): + key: str + handler_key: str + model_names: tuple[str, ...] = () + default_target_modules: tuple[str, ...] + default_rollout_weights_mode: RolloutWeightsMode = "lora" + native_vllm_lora_status: NativeVllmLoraStatus = "disabled" + dependency_floor: DependencyFloor = Field(default_factory=DependencyFloor) + + +class ModelSupportHandler(Protocol): + key: str + is_moe: bool + native_vllm_lora_status: NativeVllmLoraStatus + + def identity_lora_model_config(self, base_config: Any) -> Any: ... + + def identity_lora_target_parameters( + self, + model: Any, + *, + target_modules: list[str], + ) -> list[str]: ... + + def patch_bridge(self, bridge: Any) -> None: ... + + def patch_provider(self, provider: Any, bridge: Any) -> None: ... + + def configure_provider_for_runtime(self, provider: Any) -> None: ... + + def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None: ... + + def collect_layer_families(self, provider: Any) -> list[LayerFamilyInstance]: ... + + def apply_lora_adapters( + self, + model_chunks: Sequence[Any], + provider: Any, + *, + target_modules: list[str], + rank: int, + alpha: int, + ) -> None: ... + + def build_adapter_weights_by_base( + self, + model_chunks: Sequence[Any], + ) -> dict[str, list[Any]]: ... + + def hf_tensor_map_to_art_canonical( + self, + hf_tensor_map: dict[str, Any], + *, + expected_keys: set[str], + ) -> dict[str, Any]: + """ + Testing-only hook for canonicalizing raw HuggingFace tensor maps into the + ART tensor-map keyspace expected by model-support probes. + + This currently exists to support validations such as HF parity, where the + raw HF model can expose fused parameter names or layouts that differ from + the canonical names ART compares against. + """ + ... 
+ + def to_vllm_lora_tensors( + self, + tensors: dict[str, Any], + *, + adapter_config: dict[str, Any], + ) -> tuple[dict[str, Any], dict[str, Any]]: ... + + def from_vllm_lora_tensors( + self, + tensors: dict[str, Any], + *, + adapter_config: dict[str, Any], + ) -> dict[str, Any]: ... + + def compile_workaround_config( + self, + provider: Any, + ) -> CompileWorkaroundConfig: ... + + def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: ... diff --git a/src/art/megatron/provider.py b/src/art/megatron/provider.py index 4b23f02c6..7c54eb75c 100644 --- a/src/art/megatron/provider.py +++ b/src/art/megatron/provider.py @@ -1,99 +1,24 @@ -import copy -from dataclasses import dataclass -import inspect -import json import os -from pathlib import Path -from types import MethodType -from typing import Any, Callable, Literal, cast -import warnings +from typing import Any, Literal, cast from megatron.bridge import AutoBridge from megatron.bridge.models.gpt_provider import GPTModelProvider -from megatron.bridge.models.hf_pretrained.state import ( - SafeTensorsStateSource, - StateDict, - StateSource, -) -from megatron.bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge -from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.attention import ( - Qwen3VLSelfAttention, -) -from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel -from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLMoEBridge -from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( - Qwen35VLMoEModelProvider, - _patch_standard_attention_specs, -) from megatron.bridge.training.flex_dispatcher_backend import ( apply_flex_dispatcher_backend, ) -from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( - get_transformer_block_with_experimental_attention_variant_spec, -) from megatron.core.transformer.enums import AttnBackend -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import TransformerBlockSubmodules import torch from art.megatron.flex_attention import FlexDotProductAttention - -_finalized_env_settings_printed = False - - -@dataclass(frozen=True) -class ProviderBundle: - provider: GPTModelProvider - bridge: Any - - -def _resolve_layer_spec( - base_layer_spec: ModuleSpec | Callable[[GPTModelProvider], ModuleSpec], - config: GPTModelProvider, - vp_stage: int | None = None, -) -> ModuleSpec: - if isinstance(base_layer_spec, ModuleSpec): - return copy.deepcopy(base_layer_spec) - kwargs = ( - {"vp_stage": vp_stage} - if vp_stage in inspect.signature(base_layer_spec).parameters - else {} - ) - return base_layer_spec(config, **kwargs) - - -def _patch_core_attention(layer_spec: object) -> None: - submodules = getattr(layer_spec, "submodules", None) - self_attention = getattr(submodules, "self_attention", None) - attention_submodules = getattr(self_attention, "submodules", None) - if attention_submodules is None or not hasattr( - attention_submodules, "core_attention" - ): - return - attention_submodules.core_attention = FlexDotProductAttention - - -class _CastingStateSource(StateSource): - def __init__(self, source: StateSource, *, dtype: torch.dtype): - self._source = source - self._dtype = dtype - - def get_all_keys(self) -> list[str]: - return self._source.get_all_keys() - - def load_tensors(self, keys: list[str]) -> dict[str, torch.Tensor]: - loaded = self._source.load_tensors(keys) - return { - key: ( - value.to(dtype=self._dtype) - if torch.is_floating_point(value) and value.dtype 
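The spec models above are plain pydantic models, so a handler author can construct and inspect them directly. A minimal, hypothetical ModelSupportSpec showing the declared fields and their defaults; the key, model name, and commit hash are illustrative only.

from art.megatron.model_support.spec import (
    CompileWorkaroundConfig,
    DependencyFloor,
    ModelSupportSpec,
)

example_spec = ModelSupportSpec(
    key="example_dense",
    handler_key="default_dense",
    model_names=("example-org/example-7b",),
    default_target_modules=("q_proj", "k_proj", "v_proj", "o_proj"),
    dependency_floor=DependencyFloor(megatron_bridge="0123456789abcdef"),
)
assert example_spec.default_rollout_weights_mode == "lora"  # field default
assert example_spec.native_vllm_lora_status == "disabled"   # field default

workaround = CompileWorkaroundConfig(flags=("alltoall_dtoh",))
print(example_spec.key, workaround.shared_expert_state, workaround.disable_compile)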
!= self._dtype - else value - ) - for key, value in loaded.items() - } - - def has_glob(self, pattern: str) -> bool: - return self._source.has_glob(pattern) +from art.megatron.model_support.registry import ( + get_model_support_handler_for_spec, + get_model_support_spec, +) +from art.megatron.provider_common import ( + ProviderBundle, + patch_layer_spec_tree, + resolve_layer_spec, +) def _env_flag(name: str) -> bool | None: @@ -108,7 +33,7 @@ def _env_flag(name: str) -> bool | None: raise ValueError(f"{name} must be a boolean-like value, got {raw!r}") -def _env_optional_str(name: str) -> tuple[bool, str | None]: +def _env_override_str(name: str) -> tuple[bool, str | None]: raw = os.environ.get(name) if raw is None: return False, None @@ -118,45 +43,25 @@ def _env_optional_str(name: str) -> tuple[bool, str | None]: return True, value -def _env_optional_int(name: str) -> tuple[bool, int | None]: - found, value = _env_optional_str(name) +def _env_override_int(name: str) -> tuple[bool, int | None]: + found, value = _env_override_str(name) if not found or value is None: return found, None return True, int(value) -def _env_default_or_even_positive_int(name: str) -> tuple[bool, int | None]: - raw = os.environ.get(name) - if raw is None: - return False, None - value = raw.strip().lower() - if value == "default": - return True, None - try: - parsed = int(raw.strip()) - except ValueError as exc: - raise ValueError( - f"{name} must be 'default' or a positive, even integer, got {raw!r}" - ) from exc - if parsed <= 0 or parsed % 2 != 0: - raise ValueError( - f"{name} must be 'default' or a positive, even integer, got {raw!r}" - ) - return True, parsed - - -def _env_optional_str_list(name: str) -> tuple[bool, list[str] | None]: - found, value = _env_optional_str(name) +def _env_override_str_list(name: str) -> tuple[bool, list[str] | None]: + found, value = _env_override_str(name) if not found or value is None: return found, None parts = [part.strip() for part in value.split(",")] return True, [part for part in parts if part] -def _env_optional_recompute_granularity( +def _env_override_recompute_granularity( name: str, ) -> tuple[bool, Literal["full", "selective"] | None]: - found, value = _env_optional_str(name) + found, value = _env_override_str(name) if not found or value is None: return found, None if value not in {"full", "selective"}: @@ -164,10 +69,10 @@ def _env_optional_recompute_granularity( return True, cast(Literal["full", "selective"], value) -def _env_optional_recompute_method( +def _env_override_recompute_method( name: str, ) -> tuple[bool, Literal["uniform", "block"] | None]: - found, value = _env_optional_str(name) + found, value = _env_override_str(name) if not found or value is None: return found, None if value not in {"uniform", "block"}: @@ -190,52 +95,35 @@ def _apply_default_parallel_topology(provider: GPTModelProvider) -> None: provider.tensor_model_parallel_size = visible_gpu_count provider.context_parallel_size = 1 provider.pipeline_model_parallel_size = 1 - provider.expert_model_parallel_size = visible_gpu_count + provider.expert_model_parallel_size = ( + visible_gpu_count + if int(getattr(provider, "num_moe_experts", 0) or 0) > 0 + else 1 + ) provider.expert_tensor_parallel_size = 1 -def _tp_ep_parallel_domain_size(provider: GPTModelProvider) -> int: - return int(provider.tensor_model_parallel_size) * int( - provider.expert_model_parallel_size +def _etp_ep_parallel_domain_size(provider: GPTModelProvider) -> int: + return ( + cast(int, provider.expert_tensor_parallel_size) + * 
provider.expert_model_parallel_size ) -def _maybe_print_finalized_env_settings(provider: GPTModelProvider) -> None: - global _finalized_env_settings_printed - if _finalized_env_settings_printed: +def _apply_art_training_runtime_prepare_defaults(provider: GPTModelProvider) -> None: + provider.recompute_granularity = "full" + provider.recompute_method = "uniform" + provider.recompute_num_layers = 1 + provider.moe_shared_expert_overlap = True + _apply_default_parallel_topology(provider) + + +def _apply_art_training_runtime_finalize_defaults(provider: GPTModelProvider) -> None: + if _etp_ep_parallel_domain_size(provider) <= 1: return - if torch.distributed.is_initialized(): # ty: ignore[possibly-missing-attribute] - if torch.distributed.get_rank() != 0: # ty: ignore[possibly-missing-attribute] - _finalized_env_settings_printed = True - return - _finalized_env_settings_printed = True - print( - "Finalized Megatron env settings:", - json.dumps( - { - "tensor_model_parallel_size": provider.tensor_model_parallel_size, - "expert_model_parallel_size": provider.expert_model_parallel_size, - "overlap_moe_expert_parallel_comm": provider.overlap_moe_expert_parallel_comm, - "delay_wgrad_compute": provider.delay_wgrad_compute, - "ep_overlap_early_attn_memory_release": provider.ep_overlap_early_attn_memory_release, - "moe_deepep_num_sms": provider.moe_deepep_num_sms, - "moe_apply_probs_on_input": provider.moe_apply_probs_on_input, - "bias_activation_fusion": provider.bias_activation_fusion, - "fine_grained_activation_offloading": provider.fine_grained_activation_offloading, - "offload_modules": provider.offload_modules, - "recompute_granularity": provider.recompute_granularity, - "recompute_method": provider.recompute_method, - "recompute_num_layers": provider.recompute_num_layers, - "recompute_modules": provider.recompute_modules, - "moe_shared_expert_overlap": provider.moe_shared_expert_overlap, - "moe_flex_dispatcher_backend": ( - "deepep" if _tp_ep_parallel_domain_size(provider) > 1 else None - ), - "sequence_parallel": provider.sequence_parallel, - }, - indent=2, - ), - ) + # use DeepEP for MoE expert comm. 
comm can be the same amount of time as actual MLP + # compute, so these are very beneficial + apply_flex_dispatcher_backend(provider, moe_flex_dispatcher_backend="deepep") def _apply_runtime_env_overrides(provider: GPTModelProvider) -> None: @@ -253,16 +141,10 @@ def _apply_runtime_env_overrides(provider: GPTModelProvider) -> None: if early_attn_release is not None: provider.ep_overlap_early_attn_memory_release = early_attn_release - found, deepep_num_sms = _env_default_or_even_positive_int( - "ART_MEGATRON_MOE_DEEPEP_NUM_SMS" - ) - if found: - provider.moe_deepep_num_sms = ( - _resolve_default_deepep_num_sms(provider) - if deepep_num_sms is None - else deepep_num_sms - ) - else: + found, deepep_num_sms = _env_override_int("ART_MEGATRON_MOE_DEEPEP_NUM_SMS") + if found and deepep_num_sms is not None: + provider.moe_deepep_num_sms = deepep_num_sms + if "ART_MEGATRON_MOE_DEEPEP_NUM_SMS" not in os.environ: provider.moe_deepep_num_sms = _resolve_default_deepep_num_sms(provider) moe_apply_probs_on_input = _env_flag("ART_MEGATRON_MOE_APPLY_PROBS_ON_INPUT") @@ -279,37 +161,53 @@ def _apply_runtime_env_overrides(provider: GPTModelProvider) -> None: if fine_grained_activation_offloading is not None: provider.fine_grained_activation_offloading = fine_grained_activation_offloading - offload_modules_found, offload_modules = _env_optional_str_list( + offload_modules_found, offload_modules = _env_override_str_list( "ART_MEGATRON_OFFLOAD_MODULES" ) if offload_modules_found: provider.offload_modules = [] if offload_modules is None else offload_modules - found, tensor_model_parallel_size = _env_optional_int( + found, tensor_model_parallel_size = _env_override_int( "ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE" ) if found and tensor_model_parallel_size is not None: provider.tensor_model_parallel_size = tensor_model_parallel_size + found, expert_model_parallel_size = _env_override_int( + "ART_MEGATRON_EXPERT_MODEL_PARALLEL_SIZE" + ) + if found and expert_model_parallel_size is not None: + provider.expert_model_parallel_size = expert_model_parallel_size + + found, expert_tensor_parallel_size = _env_override_int( + "ART_MEGATRON_EXPERT_TENSOR_PARALLEL_SIZE" + ) + if not found: + found, expert_tensor_parallel_size = _env_override_int( + "ART_MEGATRON_EXPERT_TENSOR_MODEL_PARALLEL_SIZE" + ) + if found and expert_tensor_parallel_size is not None: + provider.expert_tensor_parallel_size = expert_tensor_parallel_size + recompute_granularity_found, recompute_granularity = ( - _env_optional_recompute_granularity("ART_MEGATRON_RECOMPUTE_GRANULARITY") + _env_override_recompute_granularity("ART_MEGATRON_RECOMPUTE_GRANULARITY") ) if recompute_granularity_found: provider.recompute_granularity = recompute_granularity - recompute_method_found, recompute_method = _env_optional_recompute_method( + recompute_method_found, recompute_method = _env_override_recompute_method( "ART_MEGATRON_RECOMPUTE_METHOD" ) if recompute_method_found: provider.recompute_method = recompute_method - recompute_num_layers_found, recompute_num_layers = _env_optional_int( + recompute_num_layers_found, recompute_num_layers = _env_override_int( "ART_MEGATRON_RECOMPUTE_NUM_LAYERS" ) if recompute_num_layers_found: provider.recompute_num_layers = recompute_num_layers - recompute_modules_found, recompute_modules = _env_optional_str_list( + recompute_modules_found, recompute_modules = _env_override_str_list( "ART_MEGATRON_RECOMPUTE_MODULES" ) if recompute_modules_found: @@ -323,13 +221,6 @@ def _apply_runtime_env_overrides(provider: GPTModelProvider) -> None: # EP 
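The ART_MEGATRON_* variables read above are applied after the ART runtime defaults inside provider preparation, so whatever is exported at launch wins over the built-in topology and recompute defaults. A hedged sketch, assuming the backend extras are installed, the checkpoint is available locally, and the sizes match the launch topology (values here are hypothetical):

import os

os.environ["ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE"] = "2"
os.environ["ART_MEGATRON_EXPERT_MODEL_PARALLEL_SIZE"] = "4"

from art.megatron.provider import get_provider_bundle

bundle = get_provider_bundle("Qwen/Qwen3-30B-A3B")  # a model from the validated registry
provider = bundle.provider
print(provider.tensor_model_parallel_size, provider.expert_model_parallel_size)
print(provider.sequence_parallel)  # derived from tensor_model_parallel_size > 1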
overlap is incompatible with full recompute in Megatron, so treat # overlap as the authoritative request even if a launcher exported the # usual recompute defaults. Selective recompute is still allowed. - if shared_expert_overlap: - warnings.warn( - "ART_MEGATRON_MOE_SHARED_EXPERT_OVERLAP=true is incompatible with " - "ART_MEGATRON_OVERLAP_MOE_EXPERT_PARALLEL_COMM; forcing " - "moe_shared_expert_overlap=False", - stacklevel=2, - ) provider.moe_shared_expert_overlap = False provider.recompute_method = None provider.recompute_num_layers = None @@ -337,122 +228,59 @@ def _apply_runtime_env_overrides(provider: GPTModelProvider) -> None: provider.recompute_granularity = None -def get_provider_bundle( +def _install_art_training_flex_attention(provider: GPTModelProvider) -> None: + base_layer_spec = provider.transformer_layer_spec + + def _flex_attention_layer_spec( + config: GPTModelProvider, vp_stage: int | None = None + ) -> object: + layer_spec = resolve_layer_spec(base_layer_spec, config, vp_stage) + patch_layer_spec_tree(layer_spec, FlexDotProductAttention) + return layer_spec + + provider.transformer_layer_spec = cast(Any, _flex_attention_layer_spec) + + +def _build_provider_bundle( model: str, *, - torch_dtype: torch.dtype = torch.bfloat16, + torch_dtype: torch.dtype, + allow_unvalidated_arch: bool = False, ) -> ProviderBundle: + spec = get_model_support_spec( + model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + handler = get_model_support_handler_for_spec(spec) bridge = AutoBridge.from_hf_pretrained( model, dtype=torch_dtype, trust_remote_code=True, ) - assert isinstance(bridge._model_bridge, (Qwen3MoEBridge, Qwen35VLMoEBridge)), ( - "Only Qwen3 and Qwen3.5 MoE models are supported" + handler.patch_bridge(bridge) + return ProviderBundle( + provider=bridge.to_megatron_provider(), + bridge=bridge, + handler=handler, + spec=spec, ) - if torch_dtype != torch.bfloat16: - model_name_or_path = bridge.hf_pretrained.model_name_or_path - assert model_name_or_path is not None - bridge.hf_pretrained._state_dict_accessor = StateDict( - _CastingStateSource( - SafeTensorsStateSource(cast(str | Path, model_name_or_path)), - dtype=torch_dtype, - ) - ) - provider = bridge.to_megatron_provider() - if isinstance(provider, Qwen35VLMoEModelProvider): - from megatron.bridge.models.gpt_provider import mtp_block_spec - - def _patch_qwen35_block_spec(block_spec: TransformerBlockSubmodules) -> None: - _patch_standard_attention_specs(block_spec, Qwen3VLSelfAttention) - layer_specs = block_spec.layer_specs - if layer_specs is None: - return - for layer_spec in layer_specs: - _patch_core_attention(layer_spec) - - def _qwen35_layer_spec( - config: GPTModelProvider, vp_stage: int | None = None - ) -> ModuleSpec: - block_spec = get_transformer_block_with_experimental_attention_variant_spec( - config, - vp_stage=vp_stage, - ) - _patch_qwen35_block_spec(block_spec) - return cast(ModuleSpec, block_spec) - - provider.transformer_layer_spec = _qwen35_layer_spec - - def _provide_qwen35_with_flex_attention( - self: Qwen35VLMoEModelProvider, - pre_process: bool | None = None, - post_process: bool | None = None, - vp_stage: int | None = None, - ) -> Qwen3VLModel: - language_transformer_config = cast(Any, self) - hf_vision_config = self.vision_config - hf_vision_config.torch_dtype = self.params_dtype - block_spec = cast( - ModuleSpec, - get_transformer_block_with_experimental_attention_variant_spec( - language_transformer_config, - vp_stage=vp_stage, - ), - ) - _patch_qwen35_block_spec(cast(TransformerBlockSubmodules, 
block_spec)) - pre_process_value = True if pre_process is None else pre_process - post_process_value = True if post_process is None else post_process - model = Qwen3VLModel( - language_transformer_config=language_transformer_config, - language_transformer_layer_spec=block_spec, - vision_transformer_config=hf_vision_config, - pre_process=pre_process_value, - post_process=post_process_value, - pg_collection=cast(Any, self._pg_collection), - mtp_block_spec=mtp_block_spec(self, vp_stage=vp_stage), - vp_stage=vp_stage, - ) - if ( - self.freeze_language_model - or self.freeze_vision_model - or self.freeze_vision_projection - ): - model.freeze( - freeze_language_model=self.freeze_language_model, - freeze_vision_model=self.freeze_vision_model, - freeze_vision_projection=self.freeze_vision_projection, - ) - return model - - provider.provide = cast( - Any, MethodType(_provide_qwen35_with_flex_attention, provider) - ) - base_layer_spec = provider.transformer_layer_spec - def _flex_attention_layer_spec( - config: GPTModelProvider, vp_stage: int | None = None - ) -> ModuleSpec: - layer_spec = _resolve_layer_spec(base_layer_spec, config, vp_stage) - layer_specs = getattr(layer_spec, "layer_specs", None) - if layer_specs is None: - _patch_core_attention(layer_spec) - else: - for block_layer_spec in layer_specs: - _patch_core_attention(block_layer_spec) - return layer_spec - provider.transformer_layer_spec = _flex_attention_layer_spec +def prepare_provider_bundle( + model: str, + *, + torch_dtype: torch.dtype = torch.bfloat16, + allow_unvalidated_arch: bool = False, +) -> ProviderBundle: + bundle = _build_provider_bundle( + model, + torch_dtype=torch_dtype, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + provider = bundle.provider + setattr(provider, "_art_model_support_handler", bundle.handler) + setattr(provider, "_art_model_support_spec", bundle.spec) provider.attention_backend = AttnBackend.auto - provider.recompute_granularity = "full" - provider.recompute_method = "uniform" - provider.recompute_num_layers = 1 - provider.moe_shared_expert_overlap = True - _apply_default_parallel_topology(provider) - _apply_runtime_env_overrides(provider) - if _tp_ep_parallel_domain_size(provider) > 1: - # use DeepEP for MoE expert comm. comm can be the same amount of time as actual MLP - # compute, so these are very beneficial - apply_flex_dispatcher_backend(provider, moe_flex_dispatcher_backend="deepep") provider.moe_permute_fusion = True provider.moe_router_dtype = "fp32" # params are disabled anyways, but should know about this if we switch to full FT @@ -460,18 +288,47 @@ def _flex_attention_layer_spec( provider.moe_aux_loss_coeff = 0.0 # effectively just a flag modifying finalize_model_grads behavior for DPxCP provider.calculate_per_token_loss = True + provider.cross_entropy_loss_fusion = True + provider.cross_entropy_fusion_impl = "te" + _apply_art_training_runtime_prepare_defaults(provider) + bundle.handler.configure_provider_for_runtime(provider) + _apply_runtime_env_overrides(provider) provider.sequence_parallel = provider.tensor_model_parallel_size > 1 - # ART computes its own RL loss, so MTP only adds incompatible postprocess work. 
- provider.mtp_enabled = False - provider.mtp_num_layers = 0 - _maybe_print_finalized_env_settings(provider) + _install_art_training_flex_attention(provider) + bundle.handler.patch_provider(provider, bundle.bridge) + return bundle + + +def finalize_provider_bundle(provider_bundle: ProviderBundle) -> ProviderBundle: + provider = cast(GPTModelProvider, provider_bundle.provider) + _apply_art_training_runtime_finalize_defaults(provider) provider.finalize() - return ProviderBundle(provider=provider, bridge=bridge) + return provider_bundle + + +def get_provider_bundle( + model: str, + *, + torch_dtype: torch.dtype = torch.bfloat16, + allow_unvalidated_arch: bool = False, +) -> ProviderBundle: + return finalize_provider_bundle( + prepare_provider_bundle( + model, + torch_dtype=torch_dtype, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + ) def get_provider( model: str, *, torch_dtype: torch.dtype = torch.bfloat16, + allow_unvalidated_arch: bool = False, ) -> GPTModelProvider: - return get_provider_bundle(model, torch_dtype=torch_dtype).provider + return get_provider_bundle( + model, + torch_dtype=torch_dtype, + allow_unvalidated_arch=allow_unvalidated_arch, + ).provider diff --git a/src/art/megatron/provider_common.py b/src/art/megatron/provider_common.py new file mode 100644 index 000000000..701428cff --- /dev/null +++ b/src/art/megatron/provider_common.py @@ -0,0 +1,53 @@ +import copy +import inspect +from typing import Any, Callable + +from megatron.core.transformer.spec_utils import ModuleSpec +from pydantic import BaseModel, ConfigDict + +from art.megatron.model_support.spec import ModelSupportSpec + + +class ProviderBundle(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + provider: Any + bridge: Any + handler: Any + spec: ModelSupportSpec + + +def resolve_layer_spec( + base_layer_spec: Any, + config: Any, + vp_stage: int | None = None, +) -> Any: + if isinstance(base_layer_spec, ModuleSpec): + return copy.deepcopy(base_layer_spec) + kwargs = ( + {"vp_stage": vp_stage} + if vp_stage in inspect.signature(base_layer_spec).parameters + else {} + ) + return base_layer_spec(config, **kwargs) + + +def patch_core_attention(layer_spec: object, core_attention: object) -> None: + submodules = getattr(layer_spec, "submodules", None) + self_attention = getattr(submodules, "self_attention", None) + attention_submodules = getattr(self_attention, "submodules", None) + if attention_submodules is None or not hasattr( + attention_submodules, + "core_attention", + ): + return + attention_submodules.core_attention = core_attention + + +def patch_layer_spec_tree(layer_spec: object, core_attention: object) -> None: + layer_specs = getattr(layer_spec, "layer_specs", None) + if layer_specs is None: + patch_core_attention(layer_spec, core_attention) + return + for block_layer_spec in layer_specs: + patch_core_attention(block_layer_spec, core_attention) diff --git a/src/art/megatron/routing_replay.py b/src/art/megatron/routing_replay.py index 0705a69a7..b30eddd0b 100644 --- a/src/art/megatron/routing_replay.py +++ b/src/art/megatron/routing_replay.py @@ -2,6 +2,7 @@ from collections import defaultdict import json +import logging from pathlib import Path import re import types @@ -16,6 +17,8 @@ from safetensors.torch import load_file, save_file import torch +from art.megatron.weights.param_name_canonicalization import canonical_art_param_name + ROUTER_NAME_TOKEN = ".mlp.router" ROUTER_KEY_FORMAT_VERSION = "moe_routing_replay_v1" GLOBAL_TOKEN_UIDS_KEY = "global_token_uids" @@ -24,6 +27,7 
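resolve_layer_spec and patch_layer_spec_tree above only rely on the submodules/self_attention/core_attention attribute chain, so their behavior can be checked with stand-in objects. A small sketch using SimpleNamespace in place of real Megatron ModuleSpec trees (importing provider_common still requires the backend extras):

from types import SimpleNamespace

from art.megatron.provider_common import patch_layer_spec_tree


class FakeCoreAttention:  # stand-in for FlexDotProductAttention
    pass


attention = SimpleNamespace(submodules=SimpleNamespace(core_attention=None))
layer = SimpleNamespace(submodules=SimpleNamespace(self_attention=attention))
block = SimpleNamespace(layer_specs=[layer])  # block-style spec with a layer_specs list

patch_layer_spec_tree(block, FakeCoreAttention)
assert layer.submodules.self_attention.submodules.core_attention is FakeCoreAttention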
@@ _ROUTER_LAYER_PATTERN = re.compile(r"decoder\.layers\.(?P\d+)\.mlp\.router$") _TRACE_CHUNK_PREFIX_PATTERN = re.compile(r"^chunk(?P\d+)\.(?P.+)$") +logger = logging.getLogger(__name__) def _to_tensor_cpu_contiguous( @@ -112,11 +116,13 @@ def _trace_call_route_metadata( def build_router_key_from_module_name(*, chunk_index: int, module_name: str) -> str: - match = _ROUTER_LAYER_PATTERN.search(module_name) + canonical_name = canonical_art_param_name(module_name) + match = _ROUTER_LAYER_PATTERN.search(canonical_name) if match is None: raise RuntimeError( f"Unable to derive router key from module name '{module_name}'. " - f"Expected suffix matching '{_ROUTER_LAYER_PATTERN.pattern}'." + f"Canonicalized to '{canonical_name}', expected suffix matching " + f"'{_ROUTER_LAYER_PATTERN.pattern}'." ) layer_index = int(match.group("layer")) return f"chunk_{chunk_index:02d}.layer_{layer_index:04d}.mlp.router" @@ -505,11 +511,34 @@ def build_local_token_uids( tp_size = int(ps.get_tensor_model_parallel_world_size()) tp_rank = int(ps.get_tensor_model_parallel_rank()) if tp_size > 1 else 0 if sequence_parallel and tp_size > 1: - tokens_per_tp_rank = local_uids.shape[1] // tp_size - start = tp_rank * tokens_per_tp_rank - local_uids = local_uids[:, start : start + tokens_per_tp_rank] + total_tokens = int(local_uids.shape[1]) + if total_tokens != num_local_tokens: + if total_tokens % tp_size != 0: + raise RuntimeError( + "Routing replay cannot derive sequence-parallel local token " + "uids from merged rows: " + f"total_tokens={total_tokens}, tp_size={tp_size}, " + f"num_local_tokens={num_local_tokens}" + ) + tokens_per_tp_rank = total_tokens // tp_size + if tokens_per_tp_rank != num_local_tokens: + raise RuntimeError( + "Routing replay local token uid count mismatch after " + "context-parallel slicing: " + f"total_tokens={total_tokens}, tp_size={tp_size}, " + f"expected_local_tokens={num_local_tokens}, " + f"tp_local_tokens={tokens_per_tp_rank}" + ) + start = tp_rank * tokens_per_tp_rank + local_uids = local_uids[:, start : start + tokens_per_tp_rank] - return local_uids.reshape(-1).contiguous() + local_uids = local_uids.reshape(-1).contiguous() + if int(local_uids.numel()) != num_local_tokens: + raise RuntimeError( + "Routing replay local token uid count mismatch: " + f"expected={num_local_tokens}, got={int(local_uids.numel())}" + ) + return local_uids _ACTIVE_ROUTING_REPLAY_CONTROLLER: MoeRoutingReplayController | None = None @@ -573,6 +602,43 @@ def _attach_trace_row_uids( setattr(target, TRACE_UID_SPAN_ATTR, uid_span) +@torch._dynamo.disable +def _propagate_grouped_mlp_trace_row_uids(source: Any, linear_fc2: Any) -> None: + row_token_uids, uid_span = _trace_row_uids_from_source(source) + if row_token_uids is None: + return + _attach_trace_row_uids( + linear_fc2, + row_token_uids=row_token_uids, + uid_span=uid_span, + ) + + +@torch._dynamo.disable +def _propagate_fc2_trace_row_uids( + *, + x: Any, + module: Any, + linear_fc2: Any, + lora: Any, +) -> None: + row_token_uids, uid_span = _trace_row_uids_from_source(x) + if row_token_uids is None: + row_token_uids, uid_span = _trace_row_uids_from_source(module) + if row_token_uids is None: + return + _attach_trace_row_uids( + linear_fc2, + row_token_uids=row_token_uids, + uid_span=uid_span, + ) + _attach_trace_row_uids( + lora, + row_token_uids=row_token_uids, + uid_span=uid_span, + ) + + def _canonicalize_expert_token_order( expert_inputs: torch.Tensor, expert_probs: torch.Tensor, @@ -680,6 +746,56 @@ def _canonical_trace_row_uids( return 
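build_router_key_from_module_name above canonicalizes the module name and then formats the chunk and layer indices into a stable key. A hypothetical call, assuming the module name is already in canonical decoder.layers.<n>.mlp.router form:

from art.megatron.routing_replay import build_router_key_from_module_name

key = build_router_key_from_module_name(
    chunk_index=0,
    module_name="decoder.layers.3.mlp.router",
)
print(key)  # expected per the key format above: "chunk_00.layer_0003.mlp.router"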
torch.cat(row_uid_chunks, dim=0).contiguous(), row_uid_span +@torch._dynamo.disable +def _build_dispatch_postprocess_trace( + *, + dispatcher: Any, + controller: Any, + global_input_token_uids: torch.Tensor, + expert_inputs: torch.Tensor, + expert_probs: torch.Tensor, + tokens_per_expert: torch.Tensor | list[int], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + expert_token_uids = global_input_token_uids + if dispatcher.num_local_experts > 1: + sorted_token_uids = sort_chunks_by_idxs( + expert_token_uids.unsqueeze(-1), + dispatcher.num_global_tokens_per_local_expert.ravel(), + dispatcher.sort_input_by_local_experts, + fused=False, + )[0] + expert_token_uids = sorted_token_uids.reshape(-1).contiguous() + + ( + expert_inputs, + expert_probs, + canonical_expert_token_uids, + inverse_order_cpu, + ) = _canonicalize_expert_token_order( + expert_inputs, + expert_probs, + expert_token_uids, + tokens_per_expert=tokens_per_expert, + ) + active_step_routes = controller._active_step_routes + if active_step_routes is None: + raise RuntimeError("MoE replay dispatcher preprocess called before set_step") + trace_row_uids, trace_uid_span = _canonical_trace_row_uids( + canonical_expert_token_uids, + tokens_per_expert=tokens_per_expert, + local_expert_indices=getattr(dispatcher, "local_expert_indices", None), + sample_uid_span=int(active_step_routes.global_token_uids.numel()), + num_experts=int(getattr(dispatcher, "num_experts", 1)), + ) + return ( + expert_inputs, + expert_probs, + inverse_order_cpu, + trace_row_uids, + trace_uid_span, + ) + + def _patch_alltoall_dispatcher_preprocess() -> None: try: from megatron.core.transformer.moe.experts import TEGroupedMLP @@ -811,40 +927,21 @@ def patched_dispatch_postprocess( if controller is None or global_input_token_uids is None or self.drop_and_pad: return expert_inputs, tokens_per_expert, expert_probs - expert_token_uids = global_input_token_uids - if self.num_local_experts > 1: - sorted_token_uids = sort_chunks_by_idxs( - expert_token_uids.unsqueeze(-1), - self.num_global_tokens_per_local_expert.ravel(), - self.sort_input_by_local_experts, - fused=False, - )[0] - expert_token_uids = sorted_token_uids.reshape(-1).contiguous() - ( expert_inputs, expert_probs, - canonical_expert_token_uids, inverse_order_cpu, - ) = _canonicalize_expert_token_order( - expert_inputs, - expert_probs, - expert_token_uids, + trace_row_uids, + trace_uid_span, + ) = _build_dispatch_postprocess_trace( + dispatcher=self, + controller=controller, + global_input_token_uids=global_input_token_uids, + expert_inputs=expert_inputs, + expert_probs=expert_probs, tokens_per_expert=tokens_per_expert, ) self._art_replay_expert_input_inverse_permutation = inverse_order_cpu - active_step_routes = controller._active_step_routes - if active_step_routes is None: - raise RuntimeError( - "MoE replay dispatcher preprocess called before set_step" - ) - trace_row_uids, trace_uid_span = _canonical_trace_row_uids( - canonical_expert_token_uids, - tokens_per_expert=tokens_per_expert, - local_expert_indices=getattr(self, "local_expert_indices", None), - sample_uid_span=int(active_step_routes.global_token_uids.numel()), - num_experts=int(getattr(self, "num_experts", 1)), - ) _attach_trace_row_uids( expert_inputs, row_token_uids=trace_row_uids, @@ -870,15 +967,10 @@ def patched_te_grouped_mlp_forward( tokens_per_expert: torch.Tensor, permuted_probs: torch.Tensor, ): - row_token_uids, uid_span = _trace_row_uids_from_source( - permuted_local_hidden_states + 
_propagate_grouped_mlp_trace_row_uids( + permuted_local_hidden_states, + self.linear_fc2, ) - if row_token_uids is not None: - _attach_trace_row_uids( - self.linear_fc2, - row_token_uids=row_token_uids, - uid_span=uid_span, - ) return original_te_grouped_mlp_forward( self, permuted_local_hidden_states, @@ -891,20 +983,12 @@ def patched_fc2_forward( x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor | None]: - row_token_uids, uid_span = _trace_row_uids_from_source(x) - if row_token_uids is None: - row_token_uids, uid_span = _trace_row_uids_from_source(self) - if row_token_uids is not None: - _attach_trace_row_uids( - self.linear_fc2, - row_token_uids=row_token_uids, - uid_span=uid_span, - ) - _attach_trace_row_uids( - self.lora, - row_token_uids=row_token_uids, - uid_span=uid_span, - ) + _propagate_fc2_trace_row_uids( + x=x, + module=self, + linear_fc2=self.linear_fc2, + lora=self.lora, + ) return original_fc2_forward(self, x, tokens_per_expert) setattr(MoEAlltoAllTokenDispatcher, "preprocess", patched_preprocess) @@ -936,9 +1020,11 @@ def __init__( bundle: MoeRoutingReplayBundle, strict: bool, local_token_indexer: LocalTokenIndexer | None = None, + allow_recompute_reuse: bool = True, ) -> None: self.bundle = bundle self.strict = strict + self.allow_recompute_reuse = allow_recompute_reuse self.local_token_indexer = ( local_token_indexer or TopologyAwareLocalTokenIndexer() ) @@ -948,6 +1034,9 @@ def __init__( self._active_step_routes: StepRoutes | None = None self._router_call_cursors: dict[str, int] = {} self._router_call_sequences: dict[str, list[int]] = {} + self._router_last_call_indices: dict[str, int] = {} + self._router_last_call_keys: dict[str, tuple[str, int] | None] = {} + self._router_reuse_counts: dict[str, int] = {} self._global_uid_to_row_index: dict[int, int] = {} self._local_router_keys: set[str] = set() self._active_micro_order: int | None = None @@ -1081,6 +1170,9 @@ def set_step( ) self._router_call_cursors = {} self._router_call_sequences = {} + self._router_last_call_indices = {} + self._router_last_call_keys = {} + self._router_reuse_counts = {} local_call_keys = self._build_local_call_keys( sample_index=sample_index, ) @@ -1169,6 +1261,15 @@ def _router_call_key(route: RouterCallRoute) -> tuple[str, int] | None: return ("dummy_micro_slot", int(route.micro_slot)) return None + def _active_router_call_key(self) -> tuple[str, int] | None: + active_micro_order = self._active_micro_order + if active_micro_order is None: + return None + return self._sample_or_dummy_call_key( + global_sample_index=self._active_sample_index, + local_micro_index=active_micro_order, + ) + @staticmethod def _legacy_router_call_sequence( *, @@ -1241,11 +1342,20 @@ def finalize_step(self) -> None: f"step={self._active_step_index}, router='{router_key}', " f"consumed={consumed}, expected={len(call_sequence)}" ) + if self._router_reuse_counts: + logger.info( + "Routing replay reused routes for recompute: step=%s counts=%s", + self._active_step_index, + dict(sorted(self._router_reuse_counts.items())), + ) self._active_step_index = None self._active_sample_index = None self._active_step_routes = None self._router_call_cursors = {} self._router_call_sequences = {} + self._router_last_call_indices = {} + self._router_last_call_keys = {} + self._router_reuse_counts = {} self._global_uid_to_row_index = {} self._active_micro_order = None if _ACTIVE_ROUTING_REPLAY_CONTROLLER is self: @@ -1272,14 +1382,43 @@ def get_route_for_router( 
f"step={self._active_step_index}, router='{router_key}'" ) router_calls = step_routes.routers[router_key].calls - if call_cursor >= len(call_sequence): - raise RuntimeError( - "Routing replay call cursor exceeded local call sequence: " - f"step={self._active_step_index}, router='{router_key}', " - f"call_cursor={call_cursor}, sequence_length={len(call_sequence)}" + active_call_key = self._active_router_call_key() + last_call_index = self._router_last_call_indices.get(router_key) + last_call_key = self._router_last_call_keys.get(router_key) + next_call_key = None + if call_cursor < len(call_sequence): + next_call_key = self._router_call_key( + router_calls[call_sequence[call_cursor]] ) - route = router_calls[call_sequence[call_cursor]] - self._router_call_cursors[router_key] = call_cursor + 1 + + if ( + active_call_key is not None + and last_call_index is not None + and last_call_key == active_call_key + and next_call_key != active_call_key + ): + if not self.allow_recompute_reuse: + raise RuntimeError( + "Routing replay recompute reuse is disabled: " + f"step={self._active_step_index}, router='{router_key}', " + f"call_key={active_call_key}" + ) + route = router_calls[last_call_index] + self._router_reuse_counts[router_key] = ( + self._router_reuse_counts.get(router_key, 0) + 1 + ) + else: + if call_cursor >= len(call_sequence): + raise RuntimeError( + "Routing replay call cursor exceeded local call sequence: " + f"step={self._active_step_index}, router='{router_key}', " + f"call_cursor={call_cursor}, sequence_length={len(call_sequence)}" + ) + route_call_index = call_sequence[call_cursor] + route = router_calls[route_call_index] + self._router_call_cursors[router_key] = call_cursor + 1 + self._router_last_call_indices[router_key] = route_call_index + self._router_last_call_keys[router_key] = self._router_call_key(route) num_local_tokens = int(logits.shape[0]) num_experts = int(logits.shape[1]) diff --git a/src/art/megatron/runtime/__init__.py b/src/art/megatron/runtime/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/art/megatron/runtime/__init__.py @@ -0,0 +1 @@ + diff --git a/src/art/megatron/backend.py b/src/art/megatron/runtime/backend.py similarity index 68% rename from src/art/megatron/backend.py rename to src/art/megatron/runtime/backend.py index 8d9f43e7f..5847d1ecb 100644 --- a/src/art/megatron/backend.py +++ b/src/art/megatron/runtime/backend.py @@ -1,11 +1,9 @@ -import os - from mp_actors import move_to_child_process -from ..local.backend import LocalBackend -from ..local.service import ModelService -from ..model import TrainableModel -from ..utils.output_dirs import get_model_dir +from ...local.backend import LocalBackend +from ...local.service import ModelService +from ...model import TrainableModel +from ...utils.output_dirs import get_model_dir class MegatronBackend(LocalBackend): @@ -21,9 +19,8 @@ def __init__( self._supports_result_packing = True async def _get_service(self, model: TrainableModel) -> ModelService: - from ..dev.get_model_config import get_model_config - from ..dev.validate import is_dedicated_mode, validate_dedicated_config - from .service import MegatronService + from ...dev.get_model_config import get_model_config + from ..service import MegatronService if model.name not in self._services: config = get_model_config( @@ -31,19 +28,13 @@ async def _get_service(self, model: TrainableModel) -> ModelService: output_dir=get_model_dir(model=model, art_path=self._path), config=model._internal_config, ) - 
validate_dedicated_config(config) - dedicated = is_dedicated_mode(config) - if dedicated: - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( - str(gpu_id) for gpu_id in config["trainer_gpu_ids"] - ) self._services[model.name] = MegatronService( model_name=model.name, base_model=model.base_model, config=config, output_dir=get_model_dir(model=model, art_path=self._path), ) - if not dedicated and not self._in_process: + if not self._in_process: self._services[model.name] = move_to_child_process( self._services[model.name], process_name="megatron-service", diff --git a/src/art/megatron/runtime/bridge_runtime.py b/src/art/megatron/runtime/bridge_runtime.py new file mode 100644 index 000000000..7e801691d --- /dev/null +++ b/src/art/megatron/runtime/bridge_runtime.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +import contextlib +import fnmatch +from typing import Any, cast + +from megatron.bridge.models.common.unimodal import to_empty_if_meta_device +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + ColumnParallelMapping, + MegatronParamMapping, + ReplicatedMapping, + get_module_and_param_from_name, +) +from megatron.bridge.models.model_provider import ModelProviderMixin +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.enums import ModelType +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.module import Float16Module, MegatronModule +from megatron.core.utils import get_model_config +import torch + + +def _pin_cpu_tensor(tensor: torch.Tensor) -> torch.Tensor: + if tensor.device.type != "cpu" or not torch.cuda.is_available(): + return tensor + try: + return tensor if tensor.is_pinned() else tensor.pin_memory() + except RuntimeError: + return tensor + + +def _iter_hf_param_names(hf_param: Any) -> Iterable[str]: + if isinstance(hf_param, str): + yield hf_param + return + if isinstance(hf_param, Mapping): + for value in hf_param.values(): + yield from _iter_hf_param_names(value) + + +def _needs_local_hf_prefetch(task: Any) -> bool: + if task is None or task.megatron_module is None: + return False + mapping = task.mapping + tp_size = int(getattr(mapping, "tp_size", 1)) + if tp_size <= 1: + return True + if type(mapping).__name__ == "DirectMapping": + return True + return int(getattr(mapping, "tp_rank", 0)) == 0 + + +def load_unique_hf_keys_once( + tasks: Iterable[Any], + hf_state_dict: Mapping[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + keys = sorted( + { + key + for task in tasks + if _needs_local_hf_prefetch(task) + for key in _iter_hf_param_names(task.mapping.hf_param) + } + ) + if not keys: + return {} + if hasattr(hf_state_dict, "__getitem__"): + hf_state_dict_getter = cast(Any, hf_state_dict) + loaded = ( + hf_state_dict_getter[keys] + if not isinstance(hf_state_dict, dict) + else {key: hf_state_dict[key] for key in keys} + ) + else: + loaded = {key: hf_state_dict[key] for key in keys} + return { + key: _pin_cpu_tensor(value) + for key, value in cast(Mapping[str, torch.Tensor], loaded).items() + } + + +class _CachedStateLookup(Mapping[str, torch.Tensor]): + def __init__( + self, + *, + cache: Mapping[str, torch.Tensor], + source: Mapping[str, torch.Tensor], + ) -> None: + self._cache = cache + self._source = source + + def __getitem__(self, key: str) -> torch.Tensor: + if key in self._cache: + return self._cache[key] + return 
_pin_cpu_tensor(self._source[key]) + + def __iter__(self): + seen = set(self._cache) + yield from self._cache + for key in self._source: + if key not in seen: + yield key + + def __len__(self) -> int: + return len(set(self._cache).union(self._source)) + + +def _materialization_device() -> torch.device: + return torch.device("cuda", torch.cuda.current_device()) + + +def _apply_pre_wrap_hook( + model: list[MegatronModule], + pre_wrap_hook: Any, +) -> list[MegatronModule]: + if pre_wrap_hook is None: + return model + if not callable(pre_wrap_hook): + raise RuntimeError("pre_wrap_hook must be callable") + updated = pre_wrap_hook(model) + return model if updated is None else updated + + +def _set_tp_attrs(model: list[MegatronModule]) -> None: + from megatron.core import tensor_parallel + + for model_module in model: + for param in model_module.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes( + param + ) + + +def _wrap_with_mp_wrapper( + model: list[MegatronModule], + model_config: Any, + mixed_precision_wrapper: Any, +) -> list[MegatronModule]: + if not (model_config.fp16 or model_config.bf16) or mixed_precision_wrapper is None: + return model + keep_in_fp32: list[tuple[Any, torch.Tensor]] = [] + for model_module in model: + for submodule in model_module.modules(): + if hasattr(submodule, "_maintain_float32_expert_bias"): + expert_bias = getattr(submodule, "expert_bias", None) + if expert_bias is not None: + keep_in_fp32.append((submodule, expert_bias.data.clone())) + wrapped = [ + mixed_precision_wrapper(model_config, model_module) for model_module in model + ] + for submodule, fp32_data in keep_in_fp32: + submodule.expert_bias.data = fp32_data + return wrapped + + +def _art_get_model( + model_provider: ModelProviderMixin, + ddp_config: DistributedDataParallelConfig, + model_type=ModelType.encoder_or_decoder, + overlap_param_gather_with_optimizer_step: bool = False, + fp16: bool | None = None, + bf16: bool | None = None, + use_megatron_fsdp: bool = False, + use_torch_fsdp2: bool = False, + wrap_with_ddp: bool = True, + data_parallel_random_init: bool = False, + use_cpu_initialization: None | bool = False, + init_model_with_meta_device: bool | None = None, + pre_wrap_hook: Any = None, + mixed_precision_wrapper: Any = Float16Module, + *, + pg_collection: ProcessGroupCollection, +) -> list[MegatronModule]: + from megatron.bridge.models import model_provider as model_provider_module + + if fp16: + setattr(model_provider, "fp16", fp16) + if bf16: + setattr(model_provider, "bf16", bf16) + + setattr(model_provider, "use_cpu_initialization", bool(use_cpu_initialization)) + if init_model_with_meta_device: + setattr(model_provider, "init_model_with_meta_device", True) + with torch.device("meta"): + model = model_provider_module._create_model( + model_provider, + model_type, + pg_collection=pg_collection, + ) + else: + model = model_provider_module._create_model( + model_provider, + model_type, + pg_collection=pg_collection, + ) + + if init_model_with_meta_device and not use_torch_fsdp2 and not use_megatron_fsdp: + device = _materialization_device() + model = [ + to_empty_if_meta_device(model_module, device=device) + for model_module in model + ] + + model = _apply_pre_wrap_hook(model, pre_wrap_hook) + _set_tp_attrs(model) + model_provider_module._print_num_params(model, pg_collection=pg_collection) + model_config = get_model_config(model[0]) + + if ( + not use_torch_fsdp2 + and not model_config.use_cpu_initialization + and not 
model_config.init_model_with_meta_device + ): + for model_module in model: + model_module.cuda(torch.cuda.current_device()) + + model = _wrap_with_mp_wrapper(model, model_config, mixed_precision_wrapper) + if model_provider_module.correct_amax_history_if_needed is not None: + model_provider_module.correct_amax_history_if_needed(cast(Any, model)) + if wrap_with_ddp: + model = model_provider_module._ddp_wrap( + model, + data_parallel_random_init, + ddp_config, + overlap_param_gather_with_optimizer_step, + use_megatron_fsdp=use_megatron_fsdp, + use_torch_fsdp2=use_torch_fsdp2, + pg_collection=pg_collection, + ) + return model + + +def _column_parallel_hf_to_megatron( + self: ColumnParallelMapping, + hf_weights: torch.Tensor, + megatron_module: torch.nn.Module, +) -> torch.Tensor: + if self.tp_size == 1: + return hf_weights + normalized_param = self._normalize_expert_param_name(self.megatron_param) + target_param = get_module_and_param_from_name( + cast(Any, megatron_module), normalized_param + )[1] + if self.tp_rank == 0: + full_size = hf_weights.shape[0] + if full_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" + ) + splits = list(torch.chunk(hf_weights, self.tp_size, dim=0)) + else: + splits = None + return self.scatter_to_tp_ranks( + splits, + target_param.shape, + target_param.dtype, + target_param.device, + ) + + +def _scatter_to_tp_ranks( + self: MegatronParamMapping, + splits: list[torch.Tensor] | None, + output_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + src_rank: int = 0, +) -> torch.Tensor: + if self.tp_size == 1: + return cast(list[torch.Tensor], splits)[0].to( + device=device, dtype=dtype, non_blocking=True + ) + output = torch.empty(output_shape, dtype=dtype, device=device) + dist = cast(Any, torch.distributed) + global_src = dist.get_global_rank(group=self.tp_group, group_rank=src_rank) + scatter_list = None + if self.tp_rank == src_rank and splits: + scatter_list = [ + shard.to(device=device, dtype=dtype, non_blocking=True) for shard in splits + ] + dist.scatter(output, scatter_list, src=global_src, group=self.tp_group) + return output + + +def _replicated_hf_to_megatron( + self: ReplicatedMapping, + hf_weights: torch.Tensor, + megatron_module: torch.nn.Module, +) -> torch.Tensor: + if hasattr(megatron_module, "weight"): + target_device = cast(Any, megatron_module).weight.device + else: + target_device = next(megatron_module.parameters()).device + if self.tp_size == 1: + return hf_weights.to(device=target_device, non_blocking=True) + broadcast_device = target_device + if ( + broadcast_device.type != "cuda" + or broadcast_device.index != torch.cuda.current_device() + ): + broadcast_device = _materialization_device() + if self.tp_rank == 0: + tensor = hf_weights.to(device=cast(Any, broadcast_device), non_blocking=True) + else: + tensor = torch.empty_like(hf_weights, device=cast(Any, broadcast_device)) + return self.broadcast_tensor_to_tp_ranks(tensor, src_rank=0) + + +def _optimized_load_weights_hf_to_megatron( + self: MegatronModelBridge, + hf_pretrained: Any, + megatron_model: Any, + allowed_mismatched_params: list[str] | None = None, +) -> list[Any]: + if not isinstance(megatron_model, list): + megatron_model = [megatron_model] + with contextlib.ExitStack() as stack: + if hasattr(megatron_model[0], "hide_teacher_model"): + stack.enter_context(megatron_model[0].hide_teacher_model()) + if hasattr(megatron_model[0], "hide_loss_modules"): + 
stack.enter_context(megatron_model[0].hide_loss_modules()) + tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + hf_state_dict = hf_pretrained.state + raw_cache = load_unique_hf_keys_once(tasks, hf_state_dict) + cached_state = _CachedStateLookup(cache=raw_cache, source=hf_state_dict) + description = f"Loading from {hf_pretrained.model_name_or_path}" + pending_device_copy = False + for task in self._with_progress_tracking(tasks, description): + if task is None or task.megatron_module is None: + continue + hf_weights = self.maybe_modify_loaded_hf_weight( + task.mapping.hf_param, cached_state + ) + converted_weights = task.mapping.hf_to_megatron( + hf_weights, task.megatron_module + ) + if converted_weights is None: + continue + assert task.param_weight is not None, ( + "param_weight is required for HF->Megatron conversion" + ) + if converted_weights.shape != task.param_weight.shape: + is_whitelisted = False + if allowed_mismatched_params: + for pattern in allowed_mismatched_params: + if fnmatch.fnmatch( + task.mapping.megatron_param, pattern + ) or fnmatch.fnmatch(task.param_name, pattern): + is_whitelisted = True + break + if is_whitelisted: + continue + raise ValueError( + f"Shape mismatch for megatron param {task.mapping.megatron_param}:\n" + f" Expected shape: {task.param_weight.shape}\n" + f" Got shape: {converted_weights.shape}\n" + f" Bridge type: {type(task.mapping).__name__}\n" + f" HF mapping: {task.mapping.hf_param}" + ) + task.param_weight.data.copy_(converted_weights, non_blocking=True) + if task.param_weight.device.type == "cuda": + pending_device_copy = True + if pending_device_copy and torch.cuda.is_available(): + torch.cuda.synchronize() + self._broadcast_shared_embeddings(megatron_model) + return megatron_model + + +def install_art_bridge_runtime_patches() -> None: + from megatron.bridge.models import model_provider as model_provider_module + + if not getattr( + model_provider_module.get_model, "__art_meta_materialization__", False + ): + setattr(_art_get_model, "__art_meta_materialization__", True) + setattr(model_provider_module, "get_model", _art_get_model) + if not getattr( + MegatronParamMapping.scatter_to_tp_ranks, "__art_non_blocking__", False + ): + setattr(_scatter_to_tp_ranks, "__art_non_blocking__", True) + setattr(MegatronParamMapping, "scatter_to_tp_ranks", _scatter_to_tp_ranks) + if not getattr(ColumnParallelMapping.hf_to_megatron, "__art_cast_last__", False): + setattr(_column_parallel_hf_to_megatron, "__art_cast_last__", True) + setattr( + ColumnParallelMapping, "hf_to_megatron", _column_parallel_hf_to_megatron + ) + if not getattr(ReplicatedMapping.hf_to_megatron, "__art_cast_last__", False): + setattr(_replicated_hf_to_megatron, "__art_cast_last__", True) + setattr(ReplicatedMapping, "hf_to_megatron", _replicated_hf_to_megatron) + if not getattr( + MegatronModelBridge.load_weights_hf_to_megatron, "__art_cached_load__", False + ): + setattr(_optimized_load_weights_hf_to_megatron, "__art_cached_load__", True) + setattr( + MegatronModelBridge, + "load_weights_hf_to_megatron", + _optimized_load_weights_hf_to_megatron, + ) diff --git a/src/art/megatron/client.py b/src/art/megatron/runtime/client.py similarity index 67% rename from src/art/megatron/client.py rename to src/art/megatron/runtime/client.py index 79fcfeef5..34efafa63 100644 --- a/src/art/megatron/client.py +++ b/src/art/megatron/runtime/client.py @@ -4,8 +4,9 @@ import os from typing import Any, AsyncIterator -from .jobs import DEFAULT_JOBS_DIR, MegatronJob -from .merge import 
merge_lora_adapter +from art.megatron.weights.merge import merge_lora_adapter + +from .jobs import DEFAULT_JOBS_DIR, MegatronJob, MegatronSyncJob, dump_megatron_job DEFAULT_TRAINING_LOG_DIR = "/tmp/megatron_training_logs" @@ -27,19 +28,26 @@ def create_megatron_job_paths( def write_megatron_job(job: MegatronJob, *, job_path: str) -> None: os.makedirs(os.path.dirname(job_path), exist_ok=True) with open(job_path, "w", encoding="utf-8") as handle: - handle.write(job.model_dump_json()) + handle.write(dump_megatron_job(job)) async def stream_megatron_job( job: MegatronJob, *, job_path: str, + process: Any | None = None, + process_log_path: str | None = None, poll_interval: float = 0.1, ) -> AsyncIterator[dict[str, Any]]: num_lines = 0 try: while True: await asyncio.sleep(poll_interval) + if process is not None and process.returncode is not None: + raise RuntimeError( + f"Megatron worker exited with code {process.returncode}. " + f"Check logs at {process_log_path or job.log_path}" + ) try: with open(job.log_path, "a+", encoding="utf-8") as log_file: log_file.seek(0) @@ -51,7 +59,11 @@ async def stream_megatron_job( if not (line := line.strip()): continue if line == "all done": - merge_lora_adapter(job.lora_path) + if not isinstance(job, MegatronSyncJob): + merge_lora_adapter( + job.lora_path, + allow_unvalidated_arch=job.allow_unvalidated_arch, + ) return num_lines += 1 yield json.loads(line) diff --git a/src/art/megatron/jobs.py b/src/art/megatron/runtime/jobs.py similarity index 57% rename from src/art/megatron/jobs.py rename to src/art/megatron/runtime/jobs.py index bd20d0748..0044d210b 100644 --- a/src/art/megatron/jobs.py +++ b/src/art/megatron/runtime/jobs.py @@ -1,26 +1,15 @@ -from typing import Any, Literal +from typing import Annotated, Any, Literal, TypeAlias -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter -from .. import types -from ..preprocessing.pack import DiskPackedTensors +from ... 
import types +from ...preprocessing.pack import DiskPackedTensors DEFAULT_TRAINING_LOG_PATH = "/tmp/megatron_training_log.jsonl" DEFAULT_JOBS_DIR = "/tmp/megatron_training_jobs" DEFAULT_VLLM_WAKE_LOCK_PATH = "/tmp/megatron_vllm_waking" -class MegatronTrainingJob(BaseModel): - lora_path: str - optimizer_state_path: str - disk_packed_tensors: DiskPackedTensors - config: types.TrainConfig - experimental_config: dict[str, Any] - moe_routing_replay_path: str | None = None - moe_routing_replay_strict: bool = True - log_path: str = DEFAULT_TRAINING_LOG_PATH - - class MergedWeightTransferInitInfo(BaseModel): master_address: str master_port: int @@ -32,23 +21,42 @@ class MergedWeightTransferSpec(BaseModel): init_info: MergedWeightTransferInitInfo vllm_base_url: str served_model_name: str + api_key: str | None = None -class MegatronMergedTrainJob(MegatronTrainingJob): - job_type: Literal["merged"] = "merged" +class _MegatronTrainingJobBase(BaseModel): + lora_path: str + allow_unvalidated_arch: bool = False + optimizer_state_path: str + disk_packed_tensors: DiskPackedTensors + config: types.TrainConfig + experimental_config: dict[str, Any] + moe_routing_replay_path: str | None = None + moe_routing_replay_strict: bool = True + log_path: str = DEFAULT_TRAINING_LOG_PATH + + +class MegatronTrainingJob(_MegatronTrainingJobBase): + kind: Literal["train_lora"] = "train_lora" + + +class MegatronMergedTrainingJob(_MegatronTrainingJobBase): + kind: Literal["train_merged"] = "train_merged" merged_weight_transfer: MergedWeightTransferSpec class MegatronSyncJob(BaseModel): - job_type: Literal["sync"] = "sync" + kind: Literal["sync"] = "sync" lora_path: str + allow_unvalidated_arch: bool = False merged_weight_transfer: MergedWeightTransferSpec log_path: str = DEFAULT_TRAINING_LOG_PATH class MegatronSFTTrainingJob(BaseModel): - job_type: Literal["sft"] = "sft" + kind: Literal["sft"] = "sft" lora_path: str + allow_unvalidated_arch: bool = False optimizer_state_path: str sft_data_dir: str num_batches: int @@ -60,9 +68,20 @@ class MegatronSFTTrainingJob(BaseModel): log_path: str = DEFAULT_TRAINING_LOG_PATH -MegatronJob = ( +MegatronJob: TypeAlias = Annotated[ MegatronTrainingJob - | MegatronMergedTrainJob + | MegatronMergedTrainingJob | MegatronSyncJob - | MegatronSFTTrainingJob -) + | MegatronSFTTrainingJob, + Field(discriminator="kind"), +] + +_MEGATRON_JOB_ADAPTER = TypeAdapter(MegatronJob) + + +def dump_megatron_job(job: MegatronJob) -> str: + return _MEGATRON_JOB_ADAPTER.dump_json(job).decode() + + +def load_megatron_job(raw: str | bytes) -> MegatronJob: + return _MEGATRON_JOB_ADAPTER.validate_json(raw) diff --git a/src/art/megatron/runtime_env.py b/src/art/megatron/runtime/runtime_env.py similarity index 100% rename from src/art/megatron/runtime_env.py rename to src/art/megatron/runtime/runtime_env.py diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index a1ce338d0..87a5d65ea 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -1,24 +1,16 @@ import asyncio from dataclasses import dataclass, field -from functools import cached_property import importlib -import json -import logging import os from pathlib import Path -import shlex import shutil -import signal import socket import subprocess import sys -from typing import Any, AsyncIterator, Literal, cast +from typing import Any, AsyncIterator, Literal, TypedDict, cast from peft.tuners.lora.config import LoraConfig import torch -from vllm import AsyncEngineArgs -from vllm.lora.request import LoRARequest -from 
vllm.v1.engine.async_llm import AsyncLLM from .. import dev, types from ..dev.get_model_config import default_target_modules @@ -26,34 +18,55 @@ from ..local.checkpoints import get_last_checkpoint_dir from ..preprocessing.pack import DiskPackedTensors from ..preprocessing.tokenize import SFTBatch -from ..unsloth.service import do_sleep, do_wake_up, gc_and_empty_cuda_cache +from ..unsloth.train import gc_and_empty_cuda_cache from ..utils.convert_moe_lora import convert_checkpoint_if_needed from ..utils.get_model_step import get_step_from_dir -from ..utils.network import find_free_tcp_port +from ..utils.lifecycle import ( + ChildProcessSupervisor, + ServiceLifecycle, + managed_process_cmd, + terminate_asyncio_process_group, + terminate_popen_process_group, +) from ..utils.output_dirs import get_step_checkpoint_dir -from ..vllm import get_llm, openai_server_task, run_on_workers -from .client import create_megatron_job_paths, stream_megatron_job, write_megatron_job -from .jobs import ( - MegatronMergedTrainJob, +from ..vllm_runtime import ( + VllmRuntimeLaunchConfig, + build_vllm_runtime_server_cmd, + get_vllm_runtime_working_dir, + wait_for_vllm_runtime, +) +from .lora import LORA_ALPHA, LORA_RANK +from .model_support.lora_disk import normalize_lora_checkpoint_to_vllm +from .runtime.client import ( + create_megatron_job_paths, + stream_megatron_job, + write_megatron_job, +) +from .runtime.jobs import ( + MegatronMergedTrainingJob, MegatronSFTTrainingJob, MegatronSyncJob, MegatronTrainingJob, MergedWeightTransferInitInfo, MergedWeightTransferSpec, ) -from .lora import LORA_ALPHA, LORA_RANK -from .sft_batches import materialize_sft_batches +from .training.sft_batches import materialize_sft_batches safetensors = importlib.import_module("safetensors") safe_open = safetensors.safe_open +class _RuntimeRequestKwargs(TypedDict, total=False): + headers: dict[str, str] + + def create_identity_lora( base_model: str, lora_path: str, rank: int = LORA_RANK, lora_alpha: int = LORA_ALPHA, random_state: int | None = None, + allow_unvalidated_arch: bool = False, ) -> None: """Create an identity LoRA adapter for a Megatron model. 
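# A minimal usage sketch (assumed call site; the model id and paths below are
# hypothetical, not taken from this diff), showing how the new
# allow_unvalidated_arch flag reaches create_identity_lora when seeding a
# fresh checkpoint directory:
#
#     from art.megatron.service import create_identity_lora
#
#     create_identity_lora(
#         base_model="org/some-moe-model",        # hypothetical base model id
#         lora_path="/tmp/art/checkpoints/0000",  # hypothetical checkpoint dir
#         random_state=42,
#         allow_unvalidated_arch=True,  # opt in for an architecture ART has not validated
#     )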
@@ -73,15 +86,17 @@ def create_identity_lora( from peft import get_peft_model from transformers import AutoConfig, AutoModelForCausalLM + from .model_support import get_model_support_handler + if random_state is not None: torch.manual_seed(random_state) + target_modules = default_target_modules(base_model) + handler = get_model_support_handler( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) base_config = AutoConfig.from_pretrained(base_model, trust_remote_code=True) - model_config = base_config - nested_text_config = getattr(base_config, "text_config", None) - if not hasattr(base_config, "vocab_size") and hasattr( - nested_text_config, "vocab_size" - ): - model_config = nested_text_config + model_config = handler.identity_lora_model_config(base_config) with init_empty_weights(): model = AutoModelForCausalLM.from_config( model_config, torch_dtype=torch.bfloat16, trust_remote_code=True @@ -93,26 +108,10 @@ def create_identity_lora( r=rank, lora_alpha=lora_alpha, target_modules=[], - target_parameters=[ - name - for name, _ in model.named_parameters() - if name.endswith( - ( - "q_proj.weight", - "k_proj.weight", - "v_proj.weight", - "o_proj.weight", - "linear_attn.in_proj_qkv.weight", - "linear_attn.in_proj_z.weight", - "linear_attn.out_proj.weight", - "mlp.experts.gate_up_proj", - "mlp.experts.down_proj", - "mlp.shared_expert.gate_proj.weight", - "mlp.shared_expert.up_proj.weight", - "mlp.shared_expert.down_proj.weight", - ) - ) - ], + target_parameters=handler.identity_lora_target_parameters( + model, + target_modules=target_modules, + ), bias="none", ) @@ -134,24 +133,24 @@ def _skip_meta_to( peft_model.save_pretrained(lora_path) convert_checkpoint_if_needed(lora_path) - # Write final adapter_config with per-expert target_modules - LoraConfig( + final_config = LoraConfig( base_model_name_or_path=base_model, r=rank, lora_alpha=lora_alpha, - target_modules=default_target_modules(base_model), + target_modules=target_modules, bias="none", - ).save_pretrained(lora_path) - + ).to_dict() + normalize_lora_checkpoint_to_vllm( + lora_path, + handler=handler, + adapter_config=final_config, + ) del peft_model, model if torch.cuda.is_available(): torch.cuda.synchronize() torch.cuda.empty_cache() -logger = logging.getLogger(__name__) - - @dataclass class MegatronService: model_name: str @@ -160,16 +159,32 @@ class MegatronService: output_dir: str _is_sleeping: bool = False _latest_step: int = 0 - _lora_id_counter: int = 1 _megatron_process: asyncio.subprocess.Process | None = None - _vllm_process: subprocess.Popen | None = field(default=None, repr=False) # type: ignore[type-arg] - _vllm_log_file: Any = field(default=None, repr=False) + _megatron_log_file: Any = None + _megatron_log_path: str | None = None + _vllm_process: subprocess.Popen[Any] | None = None + _vllm_log_file: Any = None + _vllm_log_path: str | None = None _vllm_host: str = "127.0.0.1" _vllm_port: int = 0 - _merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None = field( - default=None, + _vllm_api_key: str | None = None + _merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None = None + _lifecycle: ServiceLifecycle = field( + default_factory=ServiceLifecycle, + init=False, repr=False, ) + _child_processes: ChildProcessSupervisor = field(init=False, repr=False) + + def __post_init__(self) -> None: + self._child_processes = ChildProcessSupervisor(self._on_child_process_exit) + self._validate_megatron_dependencies() + + def _on_child_process_exit(self, _error: RuntimeError) -> None: + 
self.close() + + def _raise_if_child_failed(self) -> None: + self._child_processes.raise_if_failed() @property def is_dedicated(self) -> bool: @@ -192,6 +207,10 @@ def _megatron_random_state(self) -> int | None: return int(random_state) return None + @property + def _allow_unvalidated_arch(self) -> bool: + return bool(self.config.get("allow_unvalidated_arch", False)) + def _megatron_runtime_paths(self) -> tuple[str, str, str]: runtime_dir = Path(self.output_dir) / "megatron_runtime" jobs_dir = runtime_dir / "jobs" @@ -204,14 +223,72 @@ def _megatron_runtime_paths(self) -> tuple[str, str, str]: str(runtime_dir / "vllm_waking.lock"), ) + def _clear_wake_lock(self) -> None: + _, _, wake_lock_path = self._megatron_runtime_paths() + if os.path.exists(wake_lock_path): + os.remove(wake_lock_path) + def _allocate_master_port(self) -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.bind(("", 0)) return int(sock.getsockname()[1]) - def _next_lora_id(self) -> int: - self._lora_id_counter += 1 - return self._lora_id_counter + def _install_parent_signal_cleanup(self) -> None: + self._lifecycle.install_parent_cleanup(self.close) + + def _restore_parent_signal_cleanup(self) -> None: + self._lifecycle.restore_parent_cleanup() + + def _runtime_cuda_visible_devices(self) -> str: + if self.is_dedicated: + return ",".join(str(gpu_id) for gpu_id in self.config["inference_gpu_ids"]) + if visible := os.environ.get("CUDA_VISIBLE_DEVICES"): + return visible + return ",".join(str(index) for index in range(torch.cuda.device_count())) + + def _runtime_engine_args( + self, config: dev.OpenAIServerConfig | None + ) -> dict[str, object]: + engine_args = dict(self.config.get("engine_args", {})) + if config and "engine_args" in config: + engine_args.update(dict(config["engine_args"])) + engine_args.setdefault("generation_config", "vllm") + if self.rollout_weights_mode == "merged": + engine_args["weight_transfer_config"] = {"backend": "nccl"} + engine_args.pop("enable_lora", None) + engine_args.pop("max_loras", None) + else: + engine_args["enable_lora"] = True + engine_args.setdefault("max_loras", 2) + for key in ("model", "served_model_name"): + engine_args.pop(key, None) + return engine_args + + def _runtime_server_args( + self, config: dev.OpenAIServerConfig | None + ) -> dict[str, object]: + server_args: dict[str, object] = { + "return_tokens_as_token_ids": True, + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + if config and "server_args" in config: + server_args.update(dict(config["server_args"])) + for key in ("port", "host", "lora_modules"): + server_args.pop(key, None) + return server_args + + def _runtime_headers(self) -> dict[str, str]: + if self._vllm_api_key is None: + return {} + return {"Authorization": f"Bearer {self._vllm_api_key}"} + + def _runtime_request_kwargs(self) -> _RuntimeRequestKwargs: + headers = self._runtime_headers() + return {"headers": headers} if headers else {} + + def _sleep_mode_enabled(self) -> bool: + return bool(self.config.get("engine_args", {}).get("enable_sleep_mode", True)) def _get_optimizer_state_path(self, job_type: Literal["rl", "sft"]) -> str: optimizer_state_path = os.path.join( @@ -229,29 +306,28 @@ def _default_lora_adapter_config(self) -> LoraConfig: bias="none", ) - def _adapter_has_weights(self, lora_path: str) -> bool: + def _adapter_exists_and_loads(self, lora_path: str) -> bool: adapter_path = os.path.join(lora_path, "adapter_model.safetensors") if not os.path.exists(adapter_path): return False - try: - with 
safe_open(adapter_path, framework="pt") as adapter_file: - for key in adapter_file.keys(): - tensor = adapter_file.get_tensor(key) - if torch.any(tensor != 0): - return True - except Exception: - return False - return False + with safe_open(adapter_path, framework="pt") as adapter_file: + keys = list(adapter_file.keys()) + if not keys: + raise RuntimeError(f"LoRA adapter contains no tensors: {adapter_path}") + for key in keys: + adapter_file.get_tensor(key) + return True def _create_identity_lora(self, lora_path: str) -> None: create_identity_lora( self.base_model, lora_path, random_state=self._megatron_random_state(), + allow_unvalidated_arch=self._allow_unvalidated_arch, ) def _ensure_identity_lora(self, lora_path: str) -> None: - if self._adapter_has_weights(lora_path): + if self._adapter_exists_and_loads(lora_path): return self._create_identity_lora(lora_path) @@ -276,6 +352,7 @@ def _build_merged_weight_transfer_spec(self, step: int) -> MergedWeightTransferS init_info=init_info, vllm_base_url=self._vllm_base_url, served_model_name=f"{self.model_name}@{step}", + api_key=self._vllm_api_key, ) def _resolve_active_lora_path(self) -> str: @@ -285,18 +362,19 @@ def _resolve_active_lora_path(self) -> str: self._latest_step = 0 else: self._latest_step = get_step_from_dir(self.output_dir) - if self.is_dedicated or self.rollout_weights_mode == "lora": - self._ensure_identity_lora(lora_path) + self._ensure_identity_lora(lora_path) self._ensure_lora_adapter_config(lora_path) return lora_path async def _set_served_model_name(self, step: int) -> None: import httpx + self._raise_if_child_failed() async with httpx.AsyncClient() as client: response = await client.post( f"{self._vllm_base_url}/art/set_served_model_name", json={"name": f"{self.model_name}@{step}"}, + **self._runtime_request_kwargs(), timeout=30.0, ) response.raise_for_status() @@ -305,19 +383,20 @@ async def _set_served_model_name(self, step: int) -> None: async def _init_merged_weight_transfer(self) -> None: import httpx + self._raise_if_child_failed() if self._merged_weight_transfer_init_info is not None: return - assert len(self.config["trainer_gpu_ids"]) == 1 async with httpx.AsyncClient() as client: response = await client.get( f"{self._vllm_base_url}/get_world_size", + **self._runtime_request_kwargs(), timeout=30.0, ) response.raise_for_status() inference_world_size = int(response.json()["world_size"]) self._merged_weight_transfer_init_info = MergedWeightTransferInitInfo( master_address="127.0.0.1", - master_port=find_free_tcp_port(), + master_port=self._allocate_master_port(), rank_offset=1, world_size=inference_world_size + 1, ) @@ -328,100 +407,91 @@ async def _start_vllm_subprocess( port: int, config: dev.OpenAIServerConfig | None, ) -> tuple[str, int]: - import atexit - import httpx - inference_gpu_ids = self.config["inference_gpu_ids"] - cuda_devices = ",".join(str(gpu_id) for gpu_id in inference_gpu_ids) - - server_args: dict[str, object] = { - "return_tokens_as_token_ids": True, - "enable_auto_tool_choice": True, - "tool_call_parser": "hermes", - } - if config and "server_args" in config: - server_args.update(dict(config["server_args"])) - for key in ("port", "host", "lora_modules", "api_key"): - server_args.pop(key, None) - - engine_args = dict(self.config.get("engine_args", {})) - if config and "engine_args" in config: - engine_args.update(dict(config["engine_args"])) - engine_args.setdefault("generation_config", "vllm") - if self.rollout_weights_mode == "merged": - engine_args["weight_transfer_config"] = {"backend": 
"nccl"} - engine_args.pop("enable_lora", None) - engine_args.pop("max_loras", None) - else: - engine_args["enable_lora"] = True - engine_args.setdefault("max_loras", 2) - for key in ("model", "served_model_name", "enable_sleep_mode"): - engine_args.pop(key, None) - - cmd = [ - sys.executable, - "-m", - "art.vllm.dedicated_server", - f"--model={self.base_model}", - f"--port={port}", - f"--host={self._vllm_host}", - f"--cuda-visible-devices={cuda_devices}", - f"--lora-path={lora_path}", - f"--served-model-name={self.model_name}@{self._latest_step}", - f"--rollout-weights-mode={self.rollout_weights_mode}", - f"--engine-args-json={json.dumps(engine_args)}", - f"--server-args-json={json.dumps(server_args)}", - ] + self._raise_if_child_failed() + server_args = self._runtime_server_args(config) + api_key = server_args.get("api_key") + self._vllm_api_key = api_key if isinstance(api_key, str) else None + cmd = build_vllm_runtime_server_cmd( + VllmRuntimeLaunchConfig( + base_model=self.base_model, + port=port, + host=self._vllm_host, + cuda_visible_devices=self._runtime_cuda_visible_devices(), + lora_path=lora_path, + served_model_name=f"{self.model_name}@{self._latest_step}", + rollout_weights_mode=self.rollout_weights_mode, + engine_args=self._runtime_engine_args(config), + server_args=server_args, + ) + ) log_dir = os.path.join(self.output_dir, "logs") os.makedirs(log_dir, exist_ok=True) - self._vllm_log_file = open( - os.path.join(log_dir, "vllm-dedicated.log"), "w", buffering=1 - ) + self._vllm_log_path = os.path.join(log_dir, "vllm-runtime.log") + self._vllm_log_file = open(self._vllm_log_path, "w", buffering=1) self._vllm_process = subprocess.Popen( - cmd, + managed_process_cmd(cmd), + cwd=str(get_vllm_runtime_working_dir()), + env=os.environ.copy(), stdout=self._vllm_log_file, stderr=subprocess.STDOUT, bufsize=1, + start_new_session=True, ) + self._install_parent_signal_cleanup() self._vllm_port = port - timeout = float(os.environ.get("ART_DEDICATED_VLLM_TIMEOUT", 600)) - elapsed = 0.0 + timeout = float(os.environ.get("ART_DEDICATED_VLLM_TIMEOUT", 1200)) async with httpx.AsyncClient() as client: - while elapsed < timeout: - if self._vllm_process.poll() is not None: - raise RuntimeError( - "vLLM subprocess exited with code " - f"{self._vllm_process.returncode}. " - f"Check logs at {log_dir}/vllm-dedicated.log" - ) - try: - response = await client.get( - f"{self._vllm_base_url}/v1/models", - timeout=5.0, - ) - if response.status_code == 200: - break - except (httpx.ConnectError, httpx.ReadTimeout): - pass - await asyncio.sleep(1.0) - elapsed += 1.0 - else: + try: + await wait_for_vllm_runtime( + process=self._vllm_process, + host=self._vllm_host, + port=self._vllm_port, + timeout=timeout, + ) + except TimeoutError as exc: self._stop_vllm_subprocess() raise TimeoutError( f"vLLM subprocess did not become ready within {timeout}s. " - f"Check logs at {log_dir}/vllm-dedicated.log" - ) + f"Check logs at {log_dir}/vllm-runtime.log" + ) from exc + except RuntimeError as exc: + returncode = self._vllm_process.returncode + self._stop_vllm_subprocess() + raise RuntimeError( + f"vLLM subprocess exited with code {returncode}. 
" + f"Check logs at {log_dir}/vllm-runtime.log" + ) from exc - atexit.register(self.close) - logger.info("vLLM subprocess ready on port %d (GPUs: %s)", port, cuda_devices) + try: + response = await client.get( + f"{self._vllm_base_url}/v1/models", + **self._runtime_request_kwargs(), + timeout=5.0, + ) + response.raise_for_status() + except httpx.HTTPError as exc: + self._stop_vllm_subprocess() + raise RuntimeError( + "vLLM passed /health but /v1/models was not reachable. " + f"Check logs at {log_dir}/vllm-runtime.log" + ) from exc + assert self._vllm_process is not None + assert self._vllm_log_path is not None + self._child_processes.watch_popen( + "vLLM runtime", + self._vllm_process, + log_path=self._vllm_log_path, + ) return self._vllm_host, self._vllm_port async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: import httpx + self._raise_if_child_failed() async with httpx.AsyncClient() as client: response = await client.post( f"{self._vllm_base_url}/v1/load_lora_adapter", @@ -430,6 +500,7 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: "lora_path": checkpoint_path, "load_inplace": True, }, + **self._runtime_request_kwargs(), timeout=60.0, ) response.raise_for_status() @@ -441,116 +512,143 @@ async def _sync_dedicated_merged_weights( lora_path: str, step: int, ) -> None: + self._raise_if_child_failed() await self._ensure_megatron_running() await self._init_merged_weight_transfer() + self._clear_pending_jobs() job_path, log_path = self._create_megatron_job_paths() job = MegatronSyncJob( lora_path=lora_path, + allow_unvalidated_arch=self._allow_unvalidated_arch, merged_weight_transfer=self._build_merged_weight_transfer_spec(step), log_path=log_path, ) write_megatron_job(job, job_path=job_path) - async for _ in stream_megatron_job(job, job_path=job_path): + async for _ in stream_megatron_job( + job, + job_path=job_path, + process=self._megatron_process, + process_log_path=self._megatron_log_path, + ): pass self._latest_step = step - def _stop_vllm_subprocess(self) -> None: - if self._vllm_process is not None: - self._vllm_process.terminate() - try: - self._vllm_process.wait(timeout=5) - except subprocess.TimeoutExpired: - self._vllm_process.kill() - self._vllm_process.wait() - self._vllm_process = None - if self._vllm_log_file is not None: - self._vllm_log_file.close() - self._vllm_log_file = None - self._merged_weight_transfer_init_info = None + async def _sleep_runtime(self) -> None: + import httpx - def _stop_megatron_process(self) -> None: - if self._megatron_process is None: - return - if self._megatron_process.returncode is None: - os.killpg(os.getpgid(self._megatron_process.pid), signal.SIGTERM) - self._megatron_process = None + self._raise_if_child_failed() + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self._vllm_base_url}/sleep", + params={"level": 1, "mode": "wait"}, + **self._runtime_request_kwargs(), + timeout=300.0, + ) + response.raise_for_status() + self._is_sleeping = True - async def _add_lora_aliases( - self, llm: AsyncLLM, step: int, checkpoint_dir: str - ) -> None: - added = await llm.add_lora( - LoRARequest( - lora_name=f"{self.model_name}@{step}", - lora_int_id=self._next_lora_id(), - lora_path=checkpoint_dir, + async def _wake_runtime(self) -> None: + import httpx + + self._raise_if_child_failed() + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self._vllm_base_url}/wake_up", + **self._runtime_request_kwargs(), + timeout=300.0, ) - ) - if not added: - 
raise RuntimeError(f"Failed to add LoRA adapter for step {step}") - self._latest_step = step + response.raise_for_status() + self._is_sleeping = False async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: - if self.is_dedicated: - if self.rollout_weights_mode == "merged": - await self._set_served_model_name(step) - else: - await self._reload_adapter(checkpoint_dir, step) - return - llm = await self.llm - await llm.pause_generation() - await self._add_lora_aliases(llm, step, checkpoint_dir) - await llm.resume_generation() + self._raise_if_child_failed() + if self.rollout_weights_mode == "merged": + await self._set_served_model_name(step) + else: + await self._reload_adapter(checkpoint_dir, step) + self._latest_step = step + + def _validate_megatron_dependencies(self) -> None: + try: + import megatron.bridge # type: ignore + except ImportError as exc: + raise RuntimeError( + "Megatron dependencies are not available in the active ART environment. " + "Run `setup.sh` for this worktree and build the project venv with " + "`uv sync --extra backend --extra megatron` before starting Megatron " + "training." + ) from exc async def _ensure_megatron_running(self) -> None: """Lazily start Megatron training process if not running.""" + self._raise_if_child_failed() if self._megatron_process is not None: if self._megatron_process.returncode is None: return self._megatron_process = None - try: - import megatron.bridge # type: ignore - - setup_cmd = "" - except ImportError: - setup_script = Path(__file__).parent / "setup.sh" - setup_cmd = f"bash {setup_script} && " + self._validate_megatron_dependencies() train_script = Path(__file__).parent / "train.py" project_root = Path(__file__).resolve().parents[3] - jobs_dir, _training_log_dir, wake_lock_path = self._megatron_runtime_paths() - launch_env = os.environ.copy() + env = os.environ.copy() if self.is_dedicated: trainer_gpu_ids = self.config["trainer_gpu_ids"] num_gpus = len(trainer_gpu_ids) - launch_env["CUDA_VISIBLE_DEVICES"] = ",".join( + env["CUDA_VISIBLE_DEVICES"] = ",".join( str(gpu_id) for gpu_id in trainer_gpu_ids ) else: num_gpus = torch.cuda.device_count() - launch_env["MODEL_IDENTIFIER"] = self.base_model - launch_env["ART_MEGATRON_JOBS_DIR"] = jobs_dir - launch_env["ART_MEGATRON_WAKE_LOCK_PATH"] = wake_lock_path - master_addr = launch_env.get("MASTER_ADDR", "127.0.0.1") + jobs_dir, _training_log_dir, wake_lock_path = self._megatron_runtime_paths() + env["MODEL_IDENTIFIER"] = self.base_model + if self._allow_unvalidated_arch: + env["ART_MEGATRON_ALLOW_UNVALIDATED_ARCH"] = "1" + env["ART_MEGATRON_JOBS_DIR"] = jobs_dir + env["ART_MEGATRON_WAKE_LOCK_PATH"] = wake_lock_path + master_addr = env.get("MASTER_ADDR", "127.0.0.1") master_port = str(self._allocate_master_port()) - launch_env["MASTER_ADDR"] = master_addr - launch_env["MASTER_PORT"] = master_port + env["MASTER_ADDR"] = master_addr + env["MASTER_PORT"] = master_port random_state = self._megatron_random_state() if random_state is not None: - launch_env["ART_MEGATRON_RANDOM_STATE"] = str(random_state) + env["ART_MEGATRON_RANDOM_STATE"] = str(random_state) - command = ( - f"{setup_cmd}uv run --project {shlex.quote(str(project_root))} " - f"torchrun --master-addr {shlex.quote(master_addr)} " - f"--master-port {shlex.quote(master_port)} " - f"--nproc_per_node {num_gpus} {shlex.quote(str(train_script))}" + command = [ + sys.executable, + "-m", + "torch.distributed.run", + "--master-addr", + master_addr, + "--master-port", + master_port, + "--nproc_per_node", + 
str(num_gpus), + str(train_script), + ] + log_dir = Path(self.output_dir) / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + megatron_log_path = str(log_dir / "megatron-runtime.log") + self._megatron_log_path = megatron_log_path + self._megatron_log_file = open( + megatron_log_path, + "w", + buffering=1, ) - self._megatron_process = await asyncio.create_subprocess_shell( - command, + self._megatron_process = await asyncio.create_subprocess_exec( + *managed_process_cmd(command), cwd=str(project_root), - env=launch_env, + env=env, + stdout=self._megatron_log_file, + stderr=self._megatron_log_file, start_new_session=True, ) + self._install_parent_signal_cleanup() + self._child_processes.watch_asyncio_process( + "Megatron worker", + self._megatron_process, + log_path=megatron_log_path, + ) def _clear_pending_jobs(self) -> None: jobs_dir, _training_log_dir, _wake_lock_path = self._megatron_runtime_paths() @@ -567,25 +665,28 @@ def _create_megatron_job_paths(self) -> tuple[str, str]: ) def _resolve_training_lora_path(self) -> str: - return self._resolve_active_lora_path() + lora_path = get_last_checkpoint_dir(self.output_dir) + if lora_path is None: + lora_path = get_step_checkpoint_dir(self.output_dir, 0) + self._latest_step = 0 + self._ensure_identity_lora(lora_path) + self._ensure_lora_adapter_config(lora_path) + return lora_path - async def _prepare_for_training(self) -> tuple[AsyncLLM, str]: - llm = await self.llm - await llm.pause_generation() - await llm.reset_prefix_cache() - await run_on_workers(llm, do_sleep, level=2) - self._is_sleeping = True + async def _prepare_for_training(self) -> str: + self._raise_if_child_failed() + self._validate_megatron_dependencies() + await self._ensure_megatron_running() + await self._sleep_runtime() gc_and_empty_cuda_cache() - await self._ensure_megatron_running() lora_path = self._resolve_training_lora_path() self._clear_pending_jobs() - return llm, lora_path + return lora_path async def _publish_training_checkpoint( self, *, - llm: AsyncLLM, lora_path: str, ) -> None: next_step = self._latest_step + 1 @@ -601,70 +702,41 @@ async def _publish_training_checkpoint( try: with open(wake_lock_path, "w") as lock_file: lock_file.write("waking vllm\n") - await run_on_workers(llm, do_wake_up) - self._is_sleeping = False + await self._wake_runtime() finally: if os.path.exists(wake_lock_path): os.remove(wake_lock_path) - await self._add_lora_aliases(llm, next_step, new_checkpoint_dir) - await llm.resume_generation() - - async def _publish_dedicated_training_checkpoint(self, *, lora_path: str) -> None: - next_step = self._latest_step + 1 - new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) - os.makedirs(new_checkpoint_dir, exist_ok=True) - shutil.copy( - f"{lora_path}/adapter_model.safetensors", - f"{new_checkpoint_dir}/adapter_model.safetensors", - ) - self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path) await self._reload_adapter(new_checkpoint_dir, next_step) async def start_openai_server( self, config: dev.OpenAIServerConfig | None ) -> tuple[str, int]: + self._raise_if_child_failed() lora_path = self._resolve_active_lora_path() - if self.is_dedicated: - port = (config or {}).get("server_args", {}).get("port", 8000) - location = await self._start_vllm_subprocess(lora_path, port, config) + if not self.is_dedicated and not self._sleep_mode_enabled(): + raise ValueError( + "Shared-GPU mode requires engine_args.enable_sleep_mode=True " + "for the external vLLM runtime" + ) + + port = (config or 
{}).get("server_args", {}).get("port", 8000) + location = await self._start_vllm_subprocess(lora_path, port, config) + try: if self.rollout_weights_mode == "merged": - self._clear_pending_jobs() await self._sync_dedicated_merged_weights( lora_path=lora_path, step=self._latest_step, ) - return location - - lora_path_for_server = ( - lora_path if self._adapter_has_weights(lora_path) else None - ) - server_config = dev.get_openai_server_config( - model_name=self.model_name, - base_model=self.base_model, - log_file=f"{self.output_dir}/logs/vllm.log", - lora_path=lora_path_for_server, - config=config, - ) - await openai_server_task(engine=await self.llm, config=server_config) - return ( - server_config.get("server_args", {}).get("host") or "0.0.0.0", - server_config.get("server_args", {}).get("port", 8000), - ) + except BaseException: + await self.aclose() + raise + return location async def vllm_engine_is_sleeping(self) -> bool: - if self.is_dedicated: - return False return self._is_sleeping - async def aclose(self) -> None: - self.close() - - def close(self) -> None: - self._stop_vllm_subprocess() - self._stop_megatron_process() - async def train( self, disk_packed_tensors: DiskPackedTensors, @@ -672,54 +744,66 @@ async def train( _config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: - if self.is_dedicated: - await self._ensure_megatron_running() - - lora_path = self._resolve_active_lora_path() - self._clear_pending_jobs() + try: + self._raise_if_child_failed() if _config.get("moe_routing_replay_bundle") is not None: raise RuntimeError( "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " "MegatronService subprocess jobs must use moe_routing_replay_path." ) - job_path, log_path = self._create_megatron_job_paths() - next_step = self._latest_step + 1 - if self.rollout_weights_mode == "merged": - await self._init_merged_weight_transfer() - job = MegatronMergedTrainJob( - lora_path=lora_path, - optimizer_state_path=self._get_optimizer_state_path("rl"), - disk_packed_tensors=disk_packed_tensors, - config=config, - experimental_config=cast(dict[str, Any], _config), - moe_routing_replay_path=_config.get("moe_routing_replay_path"), - moe_routing_replay_strict=_config.get( - "moe_routing_replay_strict", True - ), - merged_weight_transfer=self._build_merged_weight_transfer_spec( - next_step - ), - log_path=log_path, - ) - else: - job = MegatronTrainingJob( - lora_path=lora_path, - optimizer_state_path=self._get_optimizer_state_path("rl"), - disk_packed_tensors=disk_packed_tensors, - config=config, - experimental_config=cast(dict[str, Any], _config), - moe_routing_replay_path=_config.get("moe_routing_replay_path"), - moe_routing_replay_strict=_config.get( - "moe_routing_replay_strict", True - ), - log_path=log_path, - ) - write_megatron_job(job, job_path=job_path) - - async for result in stream_megatron_job(job, job_path=job_path): - yield {key: float(value) for key, value in result.items()} + if self.is_dedicated: + await self._ensure_megatron_running() + lora_path = self._resolve_active_lora_path() + self._clear_pending_jobs() + next_step = self._latest_step + 1 + job_path, log_path = self._create_megatron_job_paths() + if self.rollout_weights_mode == "merged": + await self._init_merged_weight_transfer() + job: MegatronTrainingJob | MegatronMergedTrainingJob = ( + MegatronMergedTrainingJob( + lora_path=lora_path, + allow_unvalidated_arch=self._allow_unvalidated_arch, + optimizer_state_path=self._get_optimizer_state_path("rl"), + 
disk_packed_tensors=disk_packed_tensors, + config=config, + experimental_config=cast(dict[str, Any], _config), + moe_routing_replay_path=_config.get( + "moe_routing_replay_path" + ), + moe_routing_replay_strict=_config.get( + "moe_routing_replay_strict", + True, + ), + merged_weight_transfer=self._build_merged_weight_transfer_spec( + next_step + ), + log_path=log_path, + ) + ) + else: + job = MegatronTrainingJob( + lora_path=lora_path, + allow_unvalidated_arch=self._allow_unvalidated_arch, + optimizer_state_path=self._get_optimizer_state_path("rl"), + disk_packed_tensors=disk_packed_tensors, + config=config, + experimental_config=cast(dict[str, Any], _config), + moe_routing_replay_path=_config.get("moe_routing_replay_path"), + moe_routing_replay_strict=_config.get( + "moe_routing_replay_strict", + True, + ), + log_path=log_path, + ) + write_megatron_job(job, job_path=job_path) + async for result in stream_megatron_job( + job, + job_path=job_path, + process=self._megatron_process, + process_log_path=self._megatron_log_path, + ): + yield {key: float(value) for key, value in result.items()} - if self.rollout_weights_mode == "merged": new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) os.makedirs(new_checkpoint_dir, exist_ok=True) shutil.copy( @@ -727,36 +811,43 @@ async def train( f"{new_checkpoint_dir}/adapter_model.safetensors", ) self._ensure_lora_adapter_config( - new_checkpoint_dir, - source_path=lora_path, + new_checkpoint_dir, source_path=lora_path ) - self._latest_step = next_step - else: - await self._publish_dedicated_training_checkpoint(lora_path=lora_path) - return - llm, lora_path = await self._prepare_for_training() - if _config.get("moe_routing_replay_bundle") is not None: - raise RuntimeError( - "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " - "MegatronService subprocess jobs must use moe_routing_replay_path." 
+ if self.rollout_weights_mode == "merged": + self._latest_step = next_step + else: + await self._reload_adapter(new_checkpoint_dir, next_step) + return + + lora_path = await self._prepare_for_training() + job_path, log_path = self._create_megatron_job_paths() + job = MegatronTrainingJob( + lora_path=lora_path, + allow_unvalidated_arch=self._allow_unvalidated_arch, + optimizer_state_path=self._get_optimizer_state_path("rl"), + disk_packed_tensors=disk_packed_tensors, + config=config, + experimental_config=cast(dict[str, Any], _config), + moe_routing_replay_path=_config.get("moe_routing_replay_path"), + moe_routing_replay_strict=_config.get( + "moe_routing_replay_strict", True + ), + log_path=log_path, ) - job_path, log_path = self._create_megatron_job_paths() - job = MegatronTrainingJob( - lora_path=lora_path, - optimizer_state_path=self._get_optimizer_state_path("rl"), - disk_packed_tensors=disk_packed_tensors, - config=config, - experimental_config=cast(dict[str, Any], _config), - moe_routing_replay_path=_config.get("moe_routing_replay_path"), - moe_routing_replay_strict=_config.get("moe_routing_replay_strict", True), - log_path=log_path, - ) - write_megatron_job(job, job_path=job_path) + write_megatron_job(job, job_path=job_path) - async for result in stream_megatron_job(job, job_path=job_path): - yield {key: float(value) for key, value in result.items()} + async for result in stream_megatron_job( + job, + job_path=job_path, + process=self._megatron_process, + process_log_path=self._megatron_log_path, + ): + yield {key: float(value) for key, value in result.items()} - await self._publish_training_checkpoint(llm=llm, lora_path=lora_path) + await self._publish_training_checkpoint(lora_path=lora_path) + except BaseException: + await self.aclose() + raise async def train_sft( self, @@ -764,43 +855,81 @@ async def train_sft( config: types.TrainSFTConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: - if self.is_dedicated: - raise NotImplementedError( - "SFT training is not supported for dedicated MegatronService" + try: + self._raise_if_child_failed() + if self.is_dedicated: + raise NotImplementedError( + "train_sft is not yet supported in dedicated mode" + ) + lora_path = await self._prepare_for_training() + serialized_batches = materialize_sft_batches(batches) + job_path, log_path = self._create_megatron_job_paths() + grad_accumulation_sequences = ( + config.batch_size if isinstance(config.batch_size, int) else None ) - llm, lora_path = await self._prepare_for_training() - serialized_batches = materialize_sft_batches(batches) - job_path, log_path = self._create_megatron_job_paths() - grad_accumulation_sequences = ( - config.batch_size if isinstance(config.batch_size, int) else None - ) - job = MegatronSFTTrainingJob( - lora_path=lora_path, - optimizer_state_path=self._get_optimizer_state_path("sft"), - sft_data_dir=serialized_batches.sft_data_dir, - num_batches=serialized_batches.num_batches, - learning_rates=serialized_batches.learning_rates, - grad_accumulation_sequences=grad_accumulation_sequences, - log_path=log_path, - ) - write_megatron_job(job, job_path=job_path) + job = MegatronSFTTrainingJob( + lora_path=lora_path, + allow_unvalidated_arch=self._allow_unvalidated_arch, + optimizer_state_path=self._get_optimizer_state_path("sft"), + sft_data_dir=serialized_batches.sft_data_dir, + num_batches=serialized_batches.num_batches, + learning_rates=serialized_batches.learning_rates, + grad_accumulation_sequences=grad_accumulation_sequences, + log_path=log_path, + ) + 
write_megatron_job(job, job_path=job_path) - async for result in stream_megatron_job(job, job_path=job_path): - yield { - "loss/train": float(result["loss"]), - "loss/learning_rate": float(result["learning_rate"]), - "loss/grad_norm": float(result["grad_norm"]), - } - - await self._publish_training_checkpoint(llm=llm, lora_path=lora_path) - - @cached_property - def llm(self) -> asyncio.Task[AsyncLLM]: - engine_args = { - **self.config.get("engine_args", {}), - "enable_lora": True, - "max_loras": self.config.get("engine_args", {}).get("max_loras", 2), - } - for key in ["enable_log_requests", "disable_log_requests"]: - engine_args.pop(key, None) - return asyncio.create_task(get_llm(AsyncEngineArgs(**engine_args))) # type: ignore + async for result in stream_megatron_job( + job, + job_path=job_path, + process=self._megatron_process, + process_log_path=self._megatron_log_path, + ): + yield { + "loss/train": float(result["loss"]), + "loss/learning_rate": float(result["learning_rate"]), + "loss/grad_norm": float(result["grad_norm"]), + } + + await self._publish_training_checkpoint(lora_path=lora_path) + except BaseException: + await self.aclose() + raise + + async def aclose(self) -> None: + self.close() + + def _stop_vllm_subprocess(self) -> None: + if self._vllm_process is not None: + terminate_popen_process_group(self._vllm_process) + self._vllm_process = None + if self._vllm_log_file is not None: + self._vllm_log_file.close() + self._vllm_log_file = None + self._vllm_log_path = None + self._merged_weight_transfer_init_info = None + + def _stop_megatron_process(self) -> None: + if self._megatron_process is None: + if self._megatron_log_file is not None: + self._megatron_log_file.close() + self._megatron_log_file = None + self._megatron_log_path = None + return + terminate_asyncio_process_group(self._megatron_process) + self._megatron_process = None + if self._megatron_log_file is not None: + self._megatron_log_file.close() + self._megatron_log_file = None + self._megatron_log_path = None + + def close(self) -> None: + if not self._lifecycle.begin_close(): + return + try: + self._child_processes.close() + self._stop_vllm_subprocess() + self._stop_megatron_process() + self._clear_wake_lock() + finally: + self._restore_parent_signal_cleanup() diff --git a/src/art/megatron/setup.sh b/src/art/megatron/setup.sh index 3fc8bddc6..6d3a5548c 100755 --- a/src/art/megatron/setup.sh +++ b/src/art/megatron/setup.sh @@ -3,28 +3,36 @@ set -euo pipefail export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda-12.8}" export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0}" -# install missing cudnn headers, DeepEP RDMA headers, and ninja build tools -if ! dpkg-query -W libcudnn9-headers-cuda-12 libibverbs-dev ninja-build >/dev/null 2>&1; then - if [[ "$(id -u)" -eq 0 ]]; then +# Install missing cudnn headers, DeepEP RDMA headers, and ninja build tools. +missing_packages=() +for package in libcudnn9-headers-cuda-12 libibverbs-dev ninja-build; do + if ! 
dpkg-query -W "${package}" >/dev/null 2>&1; then + missing_packages+=("${package}") + fi +done + +if [ "${#missing_packages[@]}" -gt 0 ]; then + if [ "$(id -u)" -eq 0 ]; then apt-get update - apt-get install -y libcudnn9-headers-cuda-12 libibverbs-dev ninja-build + apt-get install -y "${missing_packages[@]}" elif command -v sudo >/dev/null 2>&1 && sudo -n true >/dev/null 2>&1; then sudo apt-get update - sudo apt-get install -y libcudnn9-headers-cuda-12 libibverbs-dev ninja-build + sudo apt-get install -y "${missing_packages[@]}" else - echo "Missing required system packages; install libcudnn9-headers-cuda-12, libibverbs-dev, and ninja-build or rerun with passwordless sudo." >&2 + echo "Missing required packages: ${missing_packages[*]}" >&2 + echo "Install them as root or run with passwordless sudo available." >&2 exit 1 fi fi # Python dependencies are declared in pyproject.toml extras. -# Keep backend + megatron together so setup does not prune runtime deps (e.g. vllm). +# Megatron setup still needs the shared backend extras, but the vLLM runtime now +# lives in its own project and venv under vllm_runtime/. script_dir="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" repo_root="$(cd -- "${script_dir}/../../.." && pwd)" cd "${repo_root}" -uv_bin="${HOME}/.local/bin/uv" -if [[ -x "${uv_bin}" ]]; then - "${uv_bin}" sync --extra backend --extra megatron --frozen --active -else - uv sync --extra backend --extra megatron --frozen --active +uv_bin="uv" +if [ -x "${HOME}/.local/bin/uv" ]; then + uv_bin="${HOME}/.local/bin/uv" fi +"${uv_bin}" sync --extra backend --extra megatron --frozen --active diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 048f6fd6b..d40a7215a 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -1,5 +1,5 @@ # isort: off -from art.megatron.runtime_env import configure_megatron_runtime_env +from art.megatron.runtime.runtime_env import configure_megatron_runtime_env configure_megatron_runtime_env() # isort: on @@ -12,69 +12,68 @@ - merge_lora_adapter """ -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass import gc import importlib import json import math import os import random -import re import shutil import time -from typing import Any, Callable, cast +from typing import Any, Callable, Literal, cast from megatron.core import parallel_state as ps from megatron.core.distributed import DistributedDataParallelConfig -from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_layer import TransformerLayer from pydantic import BaseModel, ConfigDict, field_validator import torch from torch._inductor.runtime.cache_dir_utils import cache_dir as inductor_cache_dir -from torch.distributed import all_reduce from art import dev, types -from art.dev.validate import QWEN3_5_MOE_MODELS from art.loss import loss_fn, shift_tensor -from art.megatron.bridge_adapter_compat import build_adapter_weights_by_base +from art.megatron.runtime.bridge_runtime import install_art_bridge_runtime_patches + +install_art_bridge_runtime_patches() + from art.megatron.compile_workarounds import install_torch_compile_workarounds -from art.megatron.finalize_grads import finalize_model_grads_extended from art.megatron.flex_attention import create_shared_prefix_attention_state -from art.megatron.jobs import ( +from art.megatron.lora import apply_lora_adapters 
+from art.megatron.provider import finalize_provider_bundle, prepare_provider_bundle +from art.megatron.provider_common import ProviderBundle +from art.megatron.routing_replay import ( + MoeRoutingReplayBundle, + MoeRoutingReplayController, +) +from art.megatron.runtime.jobs import ( DEFAULT_JOBS_DIR, DEFAULT_VLLM_WAKE_LOCK_PATH, MegatronJob, - MegatronMergedTrainJob, + MegatronMergedTrainingJob, MegatronSFTTrainingJob, MegatronSyncJob, MegatronTrainingJob, MergedWeightTransferInitInfo, MergedWeightTransferSpec, + load_megatron_job, ) -from art.megatron.lora import ( - apply_lora_adapters, -) -from art.megatron.merge import load_lora_adapter_state_dict, merge_lora_adapter -from art.megatron.model_chunks import ( +from art.megatron.training.finalize_grads import finalize_model_grads_extended +from art.megatron.training.model_chunks import ( ModelChunks, as_megatron_api_chunks, - unwrap_megatron_chunk, validate_model_chunks, ) -from art.megatron.offload import ( +from art.megatron.training.offload import ( OffloadState, offload_to_cpu, reload_to_gpu, ) -from art.megatron.provider import _env_flag, get_provider_bundle -from art.megatron.routing_replay import ( - MoeRoutingReplayBundle, - MoeRoutingReplayController, +from art.megatron.training.sft_batches import load_sft_batch_from_disk +from art.megatron.weights.merge import load_lora_adapter_state_dict, merge_lora_adapter +from art.megatron.weights.merged_weight_export import ( + sync_merged_weights_to_vllm, ) -from art.megatron.sft_batches import load_sft_batch_from_disk from art.metrics_taxonomy import TRAIN_GRADIENT_STEPS_KEY from art.preprocessing.pack import ( PackedTensors, @@ -104,8 +103,8 @@ class TrainingRuntime(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) + provider_bundle: ProviderBundle provider: Any - bridge: Any model: ModelChunks optimizer: Any | None optimizer_config: OptimizerConfig @@ -121,6 +120,18 @@ def _validate_model(cls, value: ModelChunks) -> ModelChunks: validate_model_chunks(value) return value + @property + def bridge(self) -> Any: + return self.provider_bundle.bridge + + @property + def model_support_handler(self) -> Any: + return self.provider_bundle.handler + + @property + def model_support_spec(self) -> Any: + return self.provider_bundle.spec + class TrainStepResult(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -133,15 +144,6 @@ class TrainStepResult(BaseModel): num_zeros_in_grad: int | None -@dataclass -class MergedWeightExport: - bridge: Any - model: list[MegatronModule] - model_config: Any - conversion_tasks: list[Any] - adapter_weights_by_base: dict[str, list[Any]] - - def print0(rank: int, *values: Any) -> None: if rank == 0: print(*values) @@ -154,6 +156,25 @@ def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]: return model_chunks +def _register_trainable_parameter_mode( + provider: Any, + *, + trainable_parameter_mode: Literal["lora", "base_model"], +) -> None: + if trainable_parameter_mode == "lora": + provider.register_pre_wrap_hook(freeze_model) + provider.register_pre_wrap_hook( + lambda chunks: apply_lora_adapters(chunks, provider) + ) + return + if trainable_parameter_mode == "base_model": + return + raise ValueError( + "trainable_parameter_mode must be 'lora' or 'base_model', got " + f"{trainable_parameter_mode!r}" + ) + + def _frozen_linear_grad_input( grad_output: torch.Tensor, weight: torch.Tensor, @@ -178,7 +199,10 @@ def _fast_backward( (weight,) = ctx.saved_tensors grad_input = _frozen_linear_grad_input(grad_output, 
weight) if ctx.allreduce_dgrad: - all_reduce(grad_input, group=ctx.tp_group) + torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] + grad_input, + group=ctx.tp_group, + ) return grad_input, None, None, None, None setattr(_fast_backward, "__art_fast_output_backward__", True) @@ -197,44 +221,12 @@ def _eager_initialize_optimizer_state(optimizer: Any) -> None: init_state_fn(inner_optimizer, getattr(optimizer, "config", None)) -def _compile_enabled(model_identifier: str) -> bool: - disabled = _env_flag("ART_DISABLE_MEGATRON_COMPILE") - if disabled is not None: - return disabled is not True - return model_identifier not in QWEN3_5_MOE_MODELS - - -def _install_gpt_preprocess_hook(model_chunks: ModelChunks) -> None: - for chunk in model_chunks: - module: Any = unwrap_megatron_chunk(chunk) - while not isinstance(module, GPTModel) and hasattr(module, "module"): - module = module.module - if not isinstance(module, GPTModel): - module = getattr(module, "language_model", None) - if not isinstance(module, GPTModel): - continue - preprocess = module._preprocess - - def preprocess_hook(*args, _preprocess=preprocess, **kwargs): - preproc_output = list(_preprocess(*args, **kwargs)) - preproc_output[0].requires_grad = True # type: ignore[index] - table = preproc_output[1] # [S, B, 1, D] # type: ignore[index] - embedding_dim = table.size(-1) - table_flat = table.view(table.size(0), embedding_dim) - position_ids = kwargs["position_ids"] # [B, S] - if position_ids.ndim != 2: - return tuple(preproc_output) - batch_size, sequence_length = position_ids.shape - gathered = table_flat.index_select(0, position_ids.reshape(-1)) - gathered = ( - gathered.view(batch_size, sequence_length, embedding_dim) - .permute(1, 0, 2) - .contiguous() - ) - preproc_output[1] = gathered.unsqueeze(2) # [S, B, 1, D] - return tuple(preproc_output) - - module._preprocess = preprocess_hook # type: ignore[attr-defined] +def _compile_enabled() -> bool: + return os.environ.get("ART_DISABLE_MEGATRON_COMPILE", "0") in { + "0", + "false", + "False", + } def _default_optimizer_config() -> OptimizerConfig: @@ -322,6 +314,7 @@ def build_training_runtime( *, model_identifier: str | None = None, provider_torch_dtype: torch.dtype = torch.bfloat16, + provider_bundle_configure: Callable[[ProviderBundle], None] | None = None, provider_configure: Callable[[Any], None] | None = None, optimizer_config: OptimizerConfig | None = None, moe_routing_replay_path: str | None = None, @@ -329,10 +322,9 @@ def build_training_runtime( moe_routing_replay_strict: bool = True, print_env: bool = True, build_optimizer: bool = True, + trainable_parameter_mode: Literal["lora", "base_model"] = "lora", + allow_unvalidated_arch: bool = False, ) -> TrainingRuntime: - resolved_model_identifier = model_identifier or os.environ.get( - "MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER - ) if random_state := os.environ.get("ART_MEGATRON_RANDOM_STATE"): seed = int(random_state) random.seed(seed) @@ -340,16 +332,21 @@ def build_training_runtime( if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) _install_fast_frozen_output_backward() - provider_bundle = get_provider_bundle( - resolved_model_identifier, + provider_bundle = prepare_provider_bundle( + model_identifier + or os.environ.get("MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER), torch_dtype=provider_torch_dtype, + allow_unvalidated_arch=allow_unvalidated_arch, ) + if provider_bundle_configure is not None: + provider_bundle_configure(provider_bundle) provider = provider_bundle.provider if provider_configure 
is not None: provider_configure(provider) - provider.register_pre_wrap_hook(freeze_model) - provider.register_pre_wrap_hook( - lambda chunks: apply_lora_adapters(chunks, provider) + finalize_provider_bundle(provider_bundle) + _register_trainable_parameter_mode( + provider, + trainable_parameter_mode=trainable_parameter_mode, ) model = cast( @@ -361,6 +358,7 @@ def build_training_runtime( average_in_collective=False, ), data_parallel_random_init=False, + init_model_with_meta_device=True, ), ) @@ -376,37 +374,21 @@ def build_training_runtime( print("Resolved inductor cache_dir():", inductor_cache_dir()) print("TRITON_CACHE_DIR:", os.environ["TRITON_CACHE_DIR"]) - _install_gpt_preprocess_hook(model) - if _compile_enabled(resolved_model_identifier): - if rank == 0: - print("Enabling torch.compile for", resolved_model_identifier) - install_torch_compile_workarounds() + provider_bundle.handler.install_preprocess_patch(model) + compile_workaround_config = provider_bundle.handler.compile_workaround_config( + provider + ) + if _compile_enabled() and not compile_workaround_config.disable_compile: + install_torch_compile_workarounds(compile_workaround_config) for chunk in model: _compile_transformer_layers(chunk) - elif ( - rank == 0 - and _env_flag("ART_DISABLE_MEGATRON_COMPILE") is None - and resolved_model_identifier in QWEN3_5_MOE_MODELS - ): - print( - "Disabling torch.compile for", - resolved_model_identifier, - "because Qwen3.5 MoE currently fails in PyTorch compiled backward stream ops.", - ) optimizer_config = optimizer_config or _default_optimizer_config() - optimizer = ( - _build_optimizer( - model, - optimizer_config, - ) - if build_optimizer - else None - ) + optimizer = _build_optimizer(model, optimizer_config) if build_optimizer else None runtime = TrainingRuntime( + provider_bundle=provider_bundle, provider=provider, - bridge=provider_bundle.bridge, model=model, optimizer=optimizer, optimizer_config=optimizer_config, @@ -467,7 +449,7 @@ def run_megatron_worker_loop( def run_megatron_rl_job( runtime: TrainingRuntime, - job: MegatronTrainingJob | MegatronMergedTrainJob, + job: MegatronTrainingJob | MegatronMergedTrainingJob, ) -> None: packed_tensors = None adapter_model = None @@ -512,6 +494,7 @@ def run_megatron_rl_job( ) step_result = run_training_step( model_chunks=runtime.model, + model_support_handler=runtime.model_support_handler, optimizer=runtime.optimizer, learning_rate=job.config.learning_rate, inputs=micro_inputs, @@ -547,12 +530,6 @@ def run_megatron_rl_job( lora_path=job.lora_path, optimizer_state_path=job.optimizer_state_path, ) - if isinstance(job, MegatronMergedTrainJob): - _sync_merged_weights_to_vllm( - runtime, - job.merged_weight_transfer, - pause_generation=True, - ) finally: if packed_tensors is not None: del packed_tensors @@ -568,29 +545,6 @@ def run_megatron_rl_job( torch.cuda.empty_cache() -def run_megatron_sync_job( - runtime: TrainingRuntime, - job: MegatronSyncJob, -) -> None: - adapter_model = None - try: - adapter_model = maybe_load_adapter_into_model( - runtime.model, - f"{job.lora_path}/adapter_model.safetensors", - rank=runtime.rank, - ) - _sync_merged_weights_to_vllm( - runtime, - job.merged_weight_transfer, - pause_generation=False, - ) - finally: - if adapter_model is not None: - del adapter_model - gc.collect() - torch.cuda.empty_cache() - - def _flush_param_grads_to_main_grads(model_chunks: ModelChunks) -> None: """Fallback for direct SFT jobs when DDP post-hooks leave grads in param.grad. 
@@ -672,6 +626,7 @@ def run_megatron_sft_job( ) step_result = run_megatron_sft_step( model_chunks=runtime.model, + model_support_handler=runtime.model_support_handler, optimizer=runtime.optimizer, learning_rate=job.learning_rates[batch_idx], inputs=micro_inputs, @@ -710,9 +665,9 @@ def run_megatron_sft_job( "learning_rate": job.learning_rates[batch_idx], "grad_norm": grad_norm, "num_trajectories": float(num_trajectories), + "num_dropped_trajectories": float(num_dropped_trajectories), "num_tokens": float(global_tokens), "num_trainable_tokens": float(global_trainable_tokens), - "num_dropped_trajectories": float(num_dropped_trajectories), "tokens_per_second": tokens_per_second, } ) @@ -734,34 +689,44 @@ def run_megatron_sft_job( def _load_megatron_job(job_path: str, *, supports_sft: bool) -> MegatronJob: with open(job_path, "rb") as handle: - job_data = json.loads(handle.read()) - job_type = job_data.get("job_type") - if job_type == "sft": - if not supports_sft: - raise NotImplementedError("SFT jobs are not supported in this worker loop") - return MegatronSFTTrainingJob.model_validate(job_data) - if job_type == "merged": - return MegatronMergedTrainJob.model_validate(job_data) - if job_type == "sync": - return MegatronSyncJob.model_validate(job_data) - return MegatronTrainingJob.model_validate(job_data) + job = load_megatron_job(handle.read()) + if isinstance(job, MegatronSFTTrainingJob) and not supports_sft: + raise NotImplementedError("SFT jobs are not supported in this worker loop") + return job def _run_megatron_job(runtime: TrainingRuntime, job: MegatronJob) -> None: + if isinstance(job, MegatronSyncJob): + adapter_model = _load_adapter_into_model( + runtime.model, + job.lora_path, + runtime.rank, + handler=runtime.model_support_handler, + ) + del adapter_model + _sync_merged_weights_to_vllm( + runtime, + job.merged_weight_transfer, + pause_generation=False, + ) + return if isinstance(job, MegatronSFTTrainingJob): run_megatron_sft_job(runtime, job) return - if isinstance(job, MegatronSyncJob): - run_megatron_sync_job(runtime, job) - return run_megatron_rl_job(runtime, job) + if isinstance(job, MegatronMergedTrainingJob): + _sync_merged_weights_to_vllm( + runtime, + job.merged_weight_transfer, + pause_generation=True, + ) def _job_cleanup_path(job: MegatronJob) -> str | None: - if isinstance(job, MegatronSFTTrainingJob): - return job.sft_data_dir if isinstance(job, MegatronSyncJob): return None + if isinstance(job, MegatronSFTTrainingJob): + return job.sft_data_dir return job.disk_packed_tensors["dir"] @@ -771,9 +736,12 @@ def _load_lora_and_optimizer( lora_path: str, optimizer_state_path: str, ) -> dict[str, torch.Tensor]: - print0(runtime.rank, "Loading adapter model from", lora_path) - adapter_model = load_lora_adapter_state_dict(lora_path) - load_adapter_into_model(runtime.model, adapter_model) + adapter_model = _load_adapter_into_model( + runtime.model, + lora_path, + runtime.rank, + handler=runtime.model_support_handler, + ) runtime.optimizer = _build_optimizer( runtime.model, runtime.optimizer_config, @@ -798,6 +766,20 @@ def _load_lora_and_optimizer( return adapter_model +def _load_adapter_into_model( + model_chunks: ModelChunks, + lora_path: str, + rank: int, + *, + handler: Any | None = None, + optimizer: Any | None = None, +) -> dict[str, torch.Tensor]: + print0(rank, "Loading adapter model from", lora_path) + adapter_model = load_lora_adapter_state_dict(lora_path, handler=handler) + load_adapter_into_model(model_chunks, adapter_model, optimizer) + return adapter_model + + def 
_save_lora_and_optimizer( runtime: TrainingRuntime, *, @@ -892,18 +874,6 @@ def iter_modules(model_chunks: ModelChunks) -> Any: yield module -def iter_named_modules(model_chunks: list[MegatronModule]) -> Any: - for chunk in model_chunks: - for module_name, module in chunk.named_modules(): - yield module_name, module - - -def _is_language_transformer_layer_name(module_name: str) -> bool: - while module_name.startswith("module."): - module_name = module_name.removeprefix("module.") - return module_name.startswith(("decoder.layers.", "language_model.decoder.layers.")) - - def load_adapter_into_model( model_chunks: ModelChunks, adapter_model: dict[str, torch.Tensor], @@ -919,25 +889,6 @@ def load_adapter_into_model( optimizer.reload_model_params() -def maybe_load_adapter_into_model( - model_chunks: ModelChunks, - adapter_model_path: str, - optimizer: Any | None = None, - *, - rank: int, -) -> dict[str, torch.Tensor]: - if not os.path.exists(adapter_model_path): - print0(rank, "No adapter model found at", adapter_model_path) - return {} - print0(rank, "Loading adapter model from", adapter_model_path) - with safe_open(adapter_model_path, framework="pt") as adapter_file: - adapter_model = { - key: adapter_file.get_tensor(key) for key in adapter_file.keys() - } - load_adapter_into_model(model_chunks, adapter_model, optimizer) - return adapter_model - - def collect_sharded_lora_state( model_chunks: ModelChunks, adapter_model: dict[str, torch.Tensor], @@ -1101,18 +1052,6 @@ def _move_inputs_to_device(inputs: PackedTensors, device: torch.device) -> None: inputs[key] = value.to(device) # type: ignore[index] -def _attention_block_kwargs( - model_chunk: torch.nn.Module, - attention_state: Any, -) -> dict[str, Any]: - model = model_chunk - while hasattr(model, "module"): - model = model.module # type: ignore[assignment] - if type(model).__name__ == "Qwen3VLModel": - return {"extra_block_kwargs": {"attention_bias": attention_state}} - return {"attention_bias": attention_state} - - def _optimizer_step( optimizer: Any, learning_rate: float, @@ -1191,6 +1130,7 @@ def _prepare_sft_micro_inputs( def run_megatron_sft_step( *, model_chunks: ModelChunks, + model_support_handler: Any, optimizer: Any, learning_rate: float, inputs: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]], @@ -1214,8 +1154,6 @@ def run_megatron_sft_step( assert len(micro_inputs) == 1 micro_sample_indices = [sample_index] - device = next(model_chunks[0].parameters()).device - if moe_routing_replay_controller is not None: resolved_global_grad_accumulation_sequences = ( resolve_global_grad_accumulation_sequences( @@ -1228,13 +1166,20 @@ def run_megatron_sft_step( global_grad_accumulation_sequences=resolved_global_grad_accumulation_sequences, ) + device = next(model_chunks[0].parameters()).device + for chunk in model_chunks: chunk.zero_grad_buffer() # ty: ignore[call-non-callable] raw_loss_sum: torch.Tensor | None = None num_tokens = _local_trainable_sft_token_count_tensor(micro_inputs, device=device) - for micro in micro_inputs: + for micro_order, micro in enumerate(micro_inputs): + if moe_routing_replay_controller is not None: + moe_routing_replay_controller.begin_micro( + micro_sample_indices[micro_order], + micro_order, + ) input_ids, position_ids, shifted_labels, mask, seq_len = ( _prepare_sft_micro_inputs(micro, device) ) @@ -1243,9 +1188,10 @@ def run_megatron_sft_step( position_ids=position_ids, attention_mask=_placeholder_attention_mask(device), labels=shifted_labels, - extra_block_kwargs={ - "attention_bias": 
_causal_attention_state(seq_len, device), - }, + **model_support_handler.get_forward_kwargs( + model_chunks[0], + attention_bias=_causal_attention_state(seq_len, device), + ), ) masked_loss = per_token_loss[mask].sum() masked_loss.backward() @@ -1289,6 +1235,7 @@ def run_megatron_sft_step( def run_training_step( *, model_chunks: ModelChunks, + model_support_handler: Any, optimizer: Any, learning_rate: float, inputs: PackedTensors | list[PackedTensors], @@ -1337,25 +1284,36 @@ def run_training_step( probs_corr_sum = 0.0 new_logprobs_list: list[torch.Tensor] = [] - for micro in micro_inputs: + for micro_order, micro in enumerate(micro_inputs): + if moe_routing_replay_controller is not None: + moe_routing_replay_controller.begin_micro( + micro_sample_indices[micro_order], + micro_order, + ) _move_inputs_to_device(micro, device) attention_state = create_shared_prefix_attention_state( group_ids=micro["group_ids"], parent_ids=micro["parent_ids"], ) attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) + shifted_labels = shift_tensor(micro["tokens"], -100) + shifted_assistant_mask = shift_tensor(micro["assistant_mask"], False) + shifted_labels = torch.where( + shifted_assistant_mask, + shifted_labels, + torch.full_like(shifted_labels, -100), + ) - model_output = model_chunks[0]( + new_logprobs = -model_chunks[0]( input_ids=micro["tokens"], position_ids=micro["input_pos"], attention_mask=attention_mask, - labels=shift_tensor(micro["tokens"], 0), - extra_block_kwargs=_attention_block_kwargs( + labels=shifted_labels, + **model_support_handler.get_forward_kwargs( model_chunks[0], - attention_state, + attention_bias=attention_state, ), ) - new_logprobs = -model_output loss_info = loss_fn( micro, # ty: ignore[invalid-argument-type] @@ -1366,6 +1324,15 @@ def run_training_step( reduction="sum", ) micro_loss = loss_info.policy_loss + if not micro_loss.requires_grad: + raise RuntimeError( + "RL micro_loss is detached before backward: " + f"new_logprobs.requires_grad={new_logprobs.requires_grad}, " + f"policy_loss_sum_requires_grad={loss_info.policy_loss_sum.requires_grad}, " + f"assistant_tokens={int(shift_tensor(micro['assistant_mask'], False).sum().item())}, " + f"nonzero_weights={int(torch.count_nonzero(shift_tensor(micro['weights'], 0.0)).item())}, " + f"nonzero_advantages={int(torch.count_nonzero(shift_tensor(micro['advantages'], 0.0)).item())}" + ) micro_loss.backward() probs_corr_sum += float(loss_info.probs_corr.item()) detached_micro_loss = micro_loss.detach() @@ -1414,289 +1381,26 @@ def run_training_step( ) -def _is_art_adapter_param_name(name: str) -> bool: - return any( - segment in name - for segment in ( - ".lora.", - ".q_proj_lora.", - ".k_proj_lora.", - ".v_proj_lora.", - ".qkv_lora.", - ".z_lora.", - ".gate_lora.", - ".up_lora.", - ) - ) - - -def _canonical_art_param_name(name: str) -> str: - while name.startswith("module."): - name = name[len("module.") :] - while name.startswith("_orig_mod."): - name = name[len("_orig_mod.") :] - while "._orig_mod." 
in name: - name = name.replace("._orig_mod.", ".") - if name.endswith("._orig_mod"): - name = name[: -len("._orig_mod")] - segments = name.split(".") - canonical: list[str] = [] - i = 0 - while i < len(segments): - if i + 1 < len(segments): - current = segments[i] - nxt = segments[i + 1] - if ( - current - in { - "linear_proj", - "linear_qkv", - "in_proj", - "linear_fc1", - "linear_fc2", - } - and nxt == current - ): - canonical.append(current) - i += 2 - continue - if current == "out_proj" and nxt == "linear_proj": - canonical.append(current) - i += 2 - continue - if current == "row_parallel_lora" and nxt == "linear_proj": - i += 2 - continue - canonical.append(segments[i]) - i += 1 - return ".".join(canonical) - - -def _unwrap_art_wrapper_name(name: str) -> str: - return _canonical_art_param_name(name) - - -def _mapping_hf_weights_exist(mapping: Any, hf_keys: set[str]) -> bool: - if getattr(mapping, "allow_hf_name_mismatch", False): - return True - hf_param = mapping.hf_param - if isinstance(hf_param, str): - return hf_param in hf_keys - if isinstance(hf_param, dict): - return all(param in hf_keys for param in hf_param.values()) - return False - - -def _build_art_conversion_tasks(runtime: TrainingRuntime) -> list[Any]: - from itertools import chain - - from megatron.bridge.models.conversion.model_bridge import ( - WeightConversionTask, - _megatron_local_name_to_global, - ) - from megatron.bridge.models.conversion.utils import ( - get_module_and_param_from_name, - persistent_buffers, - ) - - bridge = runtime.bridge - mapping_registry = bridge._model_bridge.mapping_registry() - hf_source = bridge.hf_pretrained.state.source - hf_keys = set(hf_source.get_all_keys()) - megatron_chunks = as_megatron_api_chunks(runtime.model) - model_config = megatron_chunks[0].config - tasks: list[Any] = [] - for vp_stage, model in enumerate(megatron_chunks): - for local_name, _ in chain(model.named_parameters(), persistent_buffers(model)): - if "_extra_state" in local_name or _is_art_adapter_param_name(local_name): - continue - global_name = _megatron_local_name_to_global( - megatron_chunks, - model_config, - _canonical_art_param_name(local_name), - vp_stage, - ) - mapping = mapping_registry.megatron_to_hf_lookup(global_name) - if mapping is None or not _mapping_hf_weights_exist(mapping, hf_keys): - continue - module_and_param = get_module_and_param_from_name( - megatron_chunks, - local_name, - vp_stage, - ) - local_module = module_and_param[0] - local_weights = module_and_param[1] - if local_module is not None and not hasattr(local_module, "config"): - setattr(local_module, "config", model_config) - tasks.append( - WeightConversionTask( - pp_rank=0, - vp_stage=vp_stage, - param_name=local_name, - global_param_name=global_name, - megatron_module=local_module, - param_weight=local_weights, - mapping=mapping, - ) - ) - return tasks - - -def _build_merged_weight_export(runtime: TrainingRuntime) -> MergedWeightExport: - megatron_chunks = as_megatron_api_chunks(runtime.model) - return MergedWeightExport( - bridge=runtime.bridge, - model=megatron_chunks, - model_config=megatron_chunks[0].config, - conversion_tasks=_build_art_conversion_tasks(runtime), - adapter_weights_by_base=build_adapter_weights_by_base(megatron_chunks), - ) - - -def _iter_merged_vllm_weights(weight_export: MergedWeightExport) -> Any: - # vLLM expects HF checkpoint names, but Megatron only has live trainer weights. - # Convert through Bridge here, then merge ART's LoRA deltas into those tensors. 
- bridge = weight_export.bridge - model_bridge = bridge._model_bridge - hf_state_dict = bridge.hf_pretrained.state - grouped_buffers: dict[str, dict[int, torch.Tensor]] = {} - for task in weight_export.conversion_tasks: - converted_weights_dict = task.mapping.megatron_to_hf( - task.param_weight, - task.megatron_module, - ) - adapter_weights = weight_export.adapter_weights_by_base.get( - task.global_param_name - ) - if adapter_weights is not None: - converted_weights_dict = model_bridge._merge_lora_adapter_weights( - weight_export.model, - converted_weights_dict, - adapter_weights, - ) - if getattr(task.mapping, "is_grouped_export", False): - merged_result = model_bridge._accumulate_grouped_export( - task, - converted_weights_dict, - weight_export.model_config, - grouped_buffers, - hf_state_dict, - ) - if merged_result is None: - continue - converted_weights_dict = merged_result - else: - converted_weights_dict = model_bridge.maybe_modify_converted_hf_weight( - task, - converted_weights_dict, - hf_state_dict, - ) - for hf_name, tensor in converted_weights_dict.items(): - yield hf_name, tensor - - -def _ensure_merged_weight_transfer_group( - runtime: TrainingRuntime, - spec: MergedWeightTransferSpec, -) -> None: - assert runtime.rank == 0 - assert runtime.world_size == 1 - if runtime.merged_weight_transfer_init_info == spec.init_info: - assert runtime.merged_weight_transfer_group is not None - return - import httpx - from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine - - def _remote_init() -> None: - response = httpx.post( - f"{spec.vllm_base_url}/init_weight_transfer_engine", - json={"init_info": spec.init_info.model_dump()}, - timeout=300.0, - ) - response.raise_for_status() - - with ThreadPoolExecutor(max_workers=1) as executor: - remote_future = executor.submit(_remote_init) - time.sleep(1.0) - runtime.merged_weight_transfer_group = NCCLWeightTransferEngine.trainer_init( - { - "master_address": spec.init_info.master_address, - "master_port": spec.init_info.master_port, - "world_size": spec.init_info.world_size, - } - ) - remote_future.result() - runtime.merged_weight_transfer_init_info = spec.init_info - - def _sync_merged_weights_to_vllm( runtime: TrainingRuntime, spec: MergedWeightTransferSpec, *, pause_generation: bool, ) -> None: - assert runtime.rank == 0 - assert runtime.world_size == 1 - - import httpx - from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine - - _ensure_merged_weight_transfer_group(runtime, spec) - weight_export = _build_merged_weight_export(runtime) - - def _send_weights() -> None: - NCCLWeightTransferEngine.trainer_send_weights( - _iter_merged_vllm_weights(weight_export), - {"group": runtime.merged_weight_transfer_group}, - ) - - with httpx.Client() as client: - if pause_generation: - response = client.post( - f"{spec.vllm_base_url}/pause", - params={"mode": "wait"}, - timeout=300.0, - ) - response.raise_for_status() - try: - torch.cuda.synchronize() - names: list[str] = [] - dtype_names: list[str] = [] - shapes: list[list[int]] = [] - for name, tensor in _iter_merged_vllm_weights(weight_export): - names.append(name) - dtype_names.append(str(tensor.dtype).removeprefix("torch.")) - shapes.append(list(tensor.shape)) - with ThreadPoolExecutor(max_workers=1) as executor: - send_future = executor.submit(_send_weights) - response = client.post( - f"{spec.vllm_base_url}/update_weights", - json={ - "update_info": { - "names": names, - "dtype_names": dtype_names, - "shapes": shapes, - "is_checkpoint_format": True, 
- } - }, - timeout=600.0, - ) - response.raise_for_status() - send_future.result() - response = client.post( - f"{spec.vllm_base_url}/art/set_served_model_name", - json={"name": spec.served_model_name}, - timeout=30.0, - ) - response.raise_for_status() - torch.cuda.synchronize() - finally: - if pause_generation: - response = client.post( - f"{spec.vllm_base_url}/resume", - timeout=30.0, - ) - response.raise_for_status() + ( + runtime.merged_weight_transfer_group, + runtime.merged_weight_transfer_init_info, + ) = sync_merged_weights_to_vllm( + bridge=runtime.bridge, + model=runtime.model, + model_support_handler=runtime.model_support_handler, + rank=runtime.rank, + world_size=runtime.world_size, + merged_weight_transfer_group=runtime.merged_weight_transfer_group, + merged_weight_transfer_init_info=runtime.merged_weight_transfer_init_info, + spec=spec, + pause_generation=pause_generation, + ) def _run_service_loop(runtime: TrainingRuntime) -> None: @@ -1732,6 +1436,10 @@ def main() -> None: runtime = build_training_runtime( model_identifier=os.environ.get("MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER), build_optimizer=False, + allow_unvalidated_arch=os.environ.get( + "ART_MEGATRON_ALLOW_UNVALIDATED_ARCH", "" + ).lower() + in {"1", "true", "yes", "on"}, ) _run_service_loop(runtime) diff --git a/src/art/megatron/training/__init__.py b/src/art/megatron/training/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/art/megatron/training/__init__.py @@ -0,0 +1 @@ + diff --git a/src/art/megatron/finalize_grads.py b/src/art/megatron/training/finalize_grads.py similarity index 100% rename from src/art/megatron/finalize_grads.py rename to src/art/megatron/training/finalize_grads.py diff --git a/src/art/megatron/model_chunks.py b/src/art/megatron/training/model_chunks.py similarity index 100% rename from src/art/megatron/model_chunks.py rename to src/art/megatron/training/model_chunks.py diff --git a/src/art/megatron/offload.py b/src/art/megatron/training/offload.py similarity index 98% rename from src/art/megatron/offload.py rename to src/art/megatron/training/offload.py index 5f6d4bf2f..a25b9f120 100644 --- a/src/art/megatron/offload.py +++ b/src/art/megatron/training/offload.py @@ -6,7 +6,7 @@ from megatron.core.distributed import DistributedDataParallel import torch -from art.megatron.model_chunks import unwrap_megatron_chunk +from .model_chunks import unwrap_megatron_chunk @dataclass diff --git a/src/art/megatron/sft_batches.py b/src/art/megatron/training/sft_batches.py similarity index 98% rename from src/art/megatron/sft_batches.py rename to src/art/megatron/training/sft_batches.py index 9dfd113a3..9c20640f2 100644 --- a/src/art/megatron/sft_batches.py +++ b/src/art/megatron/training/sft_batches.py @@ -12,7 +12,7 @@ save_file = safetensors_torch.save_file if TYPE_CHECKING: - from ..preprocessing.tokenize import SFTBatch + from ...preprocessing.tokenize import SFTBatch DEFAULT_SFT_DATA_DIR = "/tmp/megatron_sft_data" diff --git a/src/art/megatron/weights/__init__.py b/src/art/megatron/weights/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/art/megatron/weights/__init__.py @@ -0,0 +1 @@ + diff --git a/src/art/megatron/weights/adapter_export.py b/src/art/megatron/weights/adapter_export.py new file mode 100644 index 000000000..9f989f7de --- /dev/null +++ b/src/art/megatron/weights/adapter_export.py @@ -0,0 +1,347 @@ +import math +from typing import Any + +from megatron.bridge.models.conversion.model_bridge import MegatronWeightTuple 
+from megatron.bridge.models.conversion.peft_bridge import AdapterWeight +from megatron.core.transformer.transformer_layer import TransformerLayer +import torch + +from art.megatron.lora import ( + GatedDeltaNetInProjLoRA, + LoRA, + MLPExpertsLinearFC1FusedLoRA, + MLPExpertsLinearFC1LoRA, + MLPExpertsLinearFC2LoRA, + SelfAttentionLinearProjLoRA, + SelfAttentionLinearQKVLoRA, + SharedExpertsLinearFC1LoRA, + SharedExpertsLinearFC2LoRA, +) +from art.megatron.weights.param_name_canonicalization import canonical_art_param_name + + +def layer_base_prefix( + module: TransformerLayer, + *, + module_name: str | None = None, +) -> str: + if module_name is not None: + canonical_name = canonical_art_param_name(module_name) + if canonical_name.startswith( + ("decoder.layers.", "language_model.decoder.layers.") + ): + return canonical_name + return f"language_model.decoder.layers.{module.layer_number - 1}" + + +def _adapter_alpha_dim(lora: LoRA) -> tuple[int, int]: + dim = int(lora.A_T.shape[-1]) + alpha = float(lora.scale) * dim + rounded_alpha = round(alpha) + assert math.isclose(alpha, rounded_alpha) + return rounded_alpha, dim + + +def _adapter_tensors( + lora: LoRA, + expert_idx: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + a_t = lora.A_T if expert_idx is None else lora.A_T[expert_idx] + b_t = lora.B_T if expert_idx is None else lora.B_T[expert_idx] + return a_t.transpose(-1, -2).contiguous(), b_t.transpose(-1, -2).contiguous() + + +def _adapter_param_prefix(base_prefix: str, adapter_key: str | None) -> str: + if adapter_key is None: + return f"{base_prefix}.adapter" + return f"{base_prefix}.adapter.{adapter_key}" + + +def _adapter_weight( + *, + base_prefix: str, + adapter_key: str | None, + alpha: int, + dim: int, + linear_in: torch.Tensor, + linear_out: torch.Tensor, +) -> AdapterWeight: + param_prefix = _adapter_param_prefix(base_prefix, adapter_key) + return AdapterWeight( + global_base_prefix=base_prefix, + adapter_key=adapter_key, + alpha=alpha, + dim=dim, + linear_in_weight=MegatronWeightTuple( + param_name=f"{param_prefix}.linear_in.weight", + weight=linear_in, + vp_stage=0, + ), + linear_out_weight=MegatronWeightTuple( + param_name=f"{param_prefix}.linear_out.weight", + weight=linear_out, + vp_stage=0, + ), + ) + + +def _simple_adapter_weight( + base_prefix: str, + lora: LoRA, + *, + adapter_key: str | None = None, + expert_idx: int | None = None, +) -> AdapterWeight: + alpha, dim = _adapter_alpha_dim(lora) + linear_in, linear_out = _adapter_tensors(lora, expert_idx) + return _adapter_weight( + base_prefix=base_prefix, + adapter_key=adapter_key, + alpha=alpha, + dim=dim, + linear_in=linear_in, + linear_out=linear_out, + ) + + +def _zero_adapter_weight( + *, + base_prefix: str, + adapter_key: str, + input_dim: int, + output_dim: int, + like: torch.Tensor, +) -> AdapterWeight: + return _adapter_weight( + base_prefix=base_prefix, + adapter_key=adapter_key, + alpha=1, + dim=1, + linear_in=like.new_zeros((1, input_dim)), + linear_out=like.new_zeros((output_dim, 1)), + ) + + +def _fused_pair_adapter_weight( + base_prefix: str, + first_lora: LoRA, + second_lora: LoRA, + *, + first_expert_idx: int | None = None, + second_expert_idx: int | None = None, +) -> AdapterWeight: + first_linear_in, first_linear_out = _adapter_tensors(first_lora, first_expert_idx) + second_linear_in, second_linear_out = _adapter_tensors( + second_lora, + second_expert_idx, + ) + assert math.isclose(float(first_lora.scale), float(second_lora.scale)) + total_dim = int(first_linear_in.shape[0] + 
second_linear_in.shape[0]) + alpha = round(float(first_lora.scale) * total_dim) + + first_rank = int(first_linear_in.shape[0]) + second_rank = int(second_linear_in.shape[0]) + first_out = int(first_linear_out.shape[0]) + second_out = int(second_linear_out.shape[0]) + + first_padding = first_linear_out.new_zeros((first_out, second_rank)) + second_padding = second_linear_out.new_zeros((second_out, first_rank)) + return _adapter_weight( + base_prefix=base_prefix, + adapter_key=None, + alpha=alpha, + dim=total_dim, + linear_in=torch.cat([first_linear_in, second_linear_in], dim=0), + linear_out=torch.cat( + [ + torch.cat([first_linear_out, first_padding], dim=1), + torch.cat([second_padding, second_linear_out], dim=1), + ], + dim=0, + ), + ) + + +def add_standard_self_attention_adapter_weights( + adapter_weights_by_base: dict[str, list[Any]], + *, + layer_prefix: str, + self_attention: Any, +) -> None: + linear_proj = getattr(self_attention, "linear_proj", None) + if isinstance(linear_proj, SelfAttentionLinearProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.linear_proj" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight(base_prefix, linear_proj.lora) + ] + + linear_qkv = getattr(self_attention, "linear_qkv", None) + if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): + base_prefix = f"{layer_prefix}.self_attention.linear_qkv" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight( + base_prefix, + linear_qkv.q_proj_lora, + adapter_key="adapter_q", + ), + _simple_adapter_weight( + base_prefix, + linear_qkv.k_proj_lora, + adapter_key="adapter_k", + ), + _simple_adapter_weight( + base_prefix, + linear_qkv.v_proj_lora, + adapter_key="adapter_v", + ), + ] + + +def add_gated_delta_net_adapter_weights( + adapter_weights_by_base: dict[str, list[Any]], + *, + layer_prefix: str, + self_attention: Any, +) -> None: + out_proj = getattr(self_attention, "out_proj", None) + if isinstance(out_proj, SelfAttentionLinearProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.out_proj" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight(base_prefix, out_proj.lora) + ] + + in_proj = getattr(self_attention, "in_proj", None) + if isinstance(in_proj, GatedDeltaNetInProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.in_proj" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight( + base_prefix, + in_proj.qkv_lora, + adapter_key="adapter_qkv", + ), + _simple_adapter_weight( + base_prefix, + in_proj.z_lora, + adapter_key="adapter_z", + ), + _zero_adapter_weight( + base_prefix=base_prefix, + adapter_key="adapter_b", + input_dim=int(in_proj.qkv_lora.A_T.shape[-1]), + output_dim=int(in_proj.num_value_heads_per_partition), + like=in_proj.qkv_lora.B_T, + ), + _zero_adapter_weight( + base_prefix=base_prefix, + adapter_key="adapter_a", + input_dim=int(in_proj.qkv_lora.A_T.shape[-1]), + output_dim=int(in_proj.num_value_heads_per_partition), + like=in_proj.qkv_lora.B_T, + ), + ] + + +def add_grouped_moe_adapter_weights( + adapter_weights_by_base: dict[str, list[Any]], + *, + layer_prefix: str, + experts: Any, +) -> None: + linear_fc1 = getattr(experts, "linear_fc1", None) + if isinstance(linear_fc1, MLPExpertsLinearFC1FusedLoRA): + base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" + for local_expert_idx in range(linear_fc1.lora.num_local_experts): + global_expert_idx = local_expert_idx + linear_fc1.lora._expert_offset + adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ + 
_simple_adapter_weight( + base_prefix, + linear_fc1.lora, + expert_idx=local_expert_idx, + ) + ] + elif isinstance(linear_fc1, MLPExpertsLinearFC1LoRA): + base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" + for local_expert_idx in range(linear_fc1.gate_lora.num_local_experts): + global_expert_idx = local_expert_idx + linear_fc1.gate_lora._expert_offset + adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ + _fused_pair_adapter_weight( + base_prefix, + linear_fc1.gate_lora, + linear_fc1.up_lora, + first_expert_idx=local_expert_idx, + second_expert_idx=local_expert_idx, + ) + ] + + linear_fc2 = getattr(experts, "linear_fc2", None) + if isinstance(linear_fc2, MLPExpertsLinearFC2LoRA): + base_prefix = f"{layer_prefix}.mlp.experts.linear_fc2" + for local_expert_idx in range(linear_fc2.lora.num_local_experts): + global_expert_idx = local_expert_idx + linear_fc2.lora._expert_offset + adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ + _simple_adapter_weight( + base_prefix, + linear_fc2.lora, + expert_idx=local_expert_idx, + ) + ] + + +def add_dense_mlp_adapter_weights( + adapter_weights_by_base: dict[str, list[Any]], + *, + layer_prefix: str, + mlp: Any, +) -> None: + linear_fc1 = getattr(mlp, "linear_fc1", None) + if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): + base_prefix = f"{layer_prefix}.mlp.linear_fc1" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight( + base_prefix, + linear_fc1.gate_lora, + adapter_key="adapter_gate", + ), + _simple_adapter_weight( + base_prefix, + linear_fc1.up_lora, + adapter_key="adapter_up", + ), + ] + + linear_fc2 = getattr(mlp, "linear_fc2", None) + if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): + base_prefix = f"{layer_prefix}.mlp.linear_fc2" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight(base_prefix, linear_fc2.row_parallel_lora.lora) + ] + + +def add_shared_experts_adapter_weights( + adapter_weights_by_base: dict[str, list[Any]], + *, + layer_prefix: str, + shared_experts: Any, +) -> None: + linear_fc1 = getattr(shared_experts, "linear_fc1", None) + if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): + base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc1" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight( + base_prefix, + linear_fc1.gate_lora, + adapter_key="adapter_gate", + ), + _simple_adapter_weight( + base_prefix, + linear_fc1.up_lora, + adapter_key="adapter_up", + ), + ] + + linear_fc2 = getattr(shared_experts, "linear_fc2", None) + if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): + base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc2" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight(base_prefix, linear_fc2.row_parallel_lora.lora) + ] diff --git a/src/art/megatron/merge.py b/src/art/megatron/weights/merge.py similarity index 51% rename from src/art/megatron/merge.py rename to src/art/megatron/weights/merge.py index a77c22cf3..00a4b601b 100644 --- a/src/art/megatron/merge.py +++ b/src/art/megatron/weights/merge.py @@ -5,12 +5,100 @@ import torch +from art.megatron.model_support.lora_disk import ( + load_lora_tensors_for_megatron, + normalize_lora_checkpoint_to_vllm, +) + safetensors = importlib.import_module("safetensors") safetensors_torch = importlib.import_module("safetensors.torch") safe_open = safetensors.safe_open save_file = safetensors_torch.save_file +def _merge_sharded_tensor( + key: str, + *, + ordered_shards: list[torch.Tensor], + 
manifest: dict[str, Any], +) -> torch.Tensor: + strategy = manifest.get("export_shard_strategy") + assert strategy is not None + axis = int(manifest.get("export_shard_dim", 1 if "lora_A" in key else 0)) + if strategy == "componentwise": + component_sizes = [int(size) for size in manifest.get("component_sizes", [])] + world_size = int(manifest["shard_world_size"]) + if not component_sizes: + raise RuntimeError( + f"Missing component_sizes for key={key} shard strategy={strategy}" + ) + local_sizes = [] + for size in component_sizes: + if size % world_size != 0: + raise RuntimeError( + f"Component size {size} is not divisible by shard_world_size={world_size} for key={key}" + ) + local_sizes.append(size // world_size) + split_shards = [ + torch.split(shard, local_sizes, dim=axis) for shard in ordered_shards + ] + merged_components = [ + torch.cat([parts[index] for parts in split_shards], dim=axis) + for index in range(len(local_sizes)) + ] + return torch.cat(merged_components, dim=axis).contiguous() + if strategy != "uniform": + raise RuntimeError(f"Unsupported shard strategy={strategy} for key={key}") + return torch.cat(ordered_shards, dim=axis).contiguous() + + +def merge_sharded_adapter_entries( + entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]], +) -> dict[str, torch.Tensor]: + adapter_model: dict[str, torch.Tensor] = {} + for key, key_entries in entries_by_key.items(): + first_manifest = key_entries[0][0] + sharded = bool(first_manifest["sharded"]) + shard_world_size = int(first_manifest["shard_world_size"]) + for manifest_entry, _tensor in key_entries: + if bool(manifest_entry["sharded"]) != sharded: + raise RuntimeError(f"Inconsistent sharded flag for key={key}") + if int(manifest_entry["shard_world_size"]) != shard_world_size: + raise RuntimeError(f"Inconsistent shard world size for key={key}") + + if not sharded: + if len(key_entries) != 1: + raise RuntimeError( + f"Replicated key={key} expected 1 shard, got {len(key_entries)}" + ) + adapter_model[key] = key_entries[0][1] + continue + + shard_rank_to_tensor: dict[int, torch.Tensor] = {} + for manifest_entry, shard_tensor in key_entries: + shard_rank = int(manifest_entry["shard_rank"]) + if shard_rank in shard_rank_to_tensor: + raise RuntimeError(f"Duplicate shard_rank={shard_rank} for key={key}") + shard_rank_to_tensor[shard_rank] = shard_tensor + + expected_shard_ranks = set(range(shard_world_size)) + if set(shard_rank_to_tensor) != expected_shard_ranks: + raise RuntimeError( + f"Shard rank coverage mismatch for key={key}: " + f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" + ) + + ordered_shards = [ + shard_rank_to_tensor[shard_rank] for shard_rank in range(shard_world_size) + ] + adapter_model[key] = _merge_sharded_tensor( + key, + ordered_shards=ordered_shards, + manifest=first_manifest, + ) + return adapter_model + + def _load_adapter_shards( base_dir: Path, ) -> tuple[ @@ -57,56 +145,24 @@ def _load_adapter_shards( for key, tensor in shard_tensors.items(): entries_by_key.setdefault(key, []).append((shard_manifest[key], tensor)) - adapter_model: dict[str, torch.Tensor] = {} - for key, key_entries in entries_by_key.items(): - first_manifest = key_entries[0][0] - sharded = bool(first_manifest["sharded"]) - shard_world_size = int(first_manifest["shard_world_size"]) - for manifest_entry, _tensor in key_entries: - if bool(manifest_entry["sharded"]) != sharded: - raise RuntimeError(f"Inconsistent sharded flag for key={key}") - if int(manifest_entry["shard_world_size"]) != 
shard_world_size: - raise RuntimeError(f"Inconsistent shard world size for key={key}") - - if not sharded: - if len(key_entries) != 1: - raise RuntimeError( - f"Replicated key={key} expected 1 shard, got {len(key_entries)}" - ) - tensor = key_entries[0][1] - else: - shard_rank_to_tensor: dict[int, torch.Tensor] = {} - for manifest_entry, shard_tensor in key_entries: - shard_rank = int(manifest_entry["shard_rank"]) - if shard_rank in shard_rank_to_tensor: - raise RuntimeError( - f"Duplicate shard_rank={shard_rank} for key={key}" - ) - shard_rank_to_tensor[shard_rank] = shard_tensor - - expected_shard_ranks = set(range(shard_world_size)) - if set(shard_rank_to_tensor) != expected_shard_ranks: - raise RuntimeError( - f"Shard rank coverage mismatch for key={key}: " - f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" - ) - - ordered_shards = [ - shard_rank_to_tensor[shard_rank] - for shard_rank in range(shard_world_size) - ] - concat_dim = 1 if "lora_A" in key else 0 - tensor = torch.cat(ordered_shards, dim=concat_dim) - adapter_model[key] = tensor + adapter_model = merge_sharded_adapter_entries(entries_by_key) return adapter_model, shard_filenames, manifest_filenames -def load_lora_adapter_state_dict(lora_path: str) -> dict[str, torch.Tensor]: +def load_lora_adapter_state_dict( + lora_path: str, + *, + handler: Any | None = None, + allow_unvalidated_arch: bool = False, +) -> dict[str, torch.Tensor]: base_dir = Path(lora_path) adapter_model_path = base_dir / "adapter_model.safetensors" if adapter_model_path.exists(): - with safe_open(adapter_model_path, framework="pt") as file: - return {key: file.get_tensor(key) for key in file.keys()} + return load_lora_tensors_for_megatron( + lora_path, + handler=handler, + allow_unvalidated_arch=allow_unvalidated_arch, + ) adapter_model, _shard_filenames, _manifest_filenames = _load_adapter_shards( base_dir @@ -114,17 +170,20 @@ def load_lora_adapter_state_dict(lora_path: str) -> dict[str, torch.Tensor]: return adapter_model -def merge_lora_adapter(lora_path: str) -> None: +def merge_lora_adapter( + lora_path: str, + *, + allow_unvalidated_arch: bool = False, +) -> None: base_dir = Path(lora_path) - try: - adapter_model, shard_filenames, manifest_filenames = _load_adapter_shards( - base_dir - ) - except FileNotFoundError: - return + adapter_model, shard_filenames, manifest_filenames = _load_adapter_shards(base_dir) adapter_model_path = base_dir / "adapter_model.safetensors" save_file(adapter_model, adapter_model_path) + normalize_lora_checkpoint_to_vllm( + base_dir, + allow_unvalidated_arch=allow_unvalidated_arch, + ) for filename in shard_filenames: filename.unlink() for filename in manifest_filenames: diff --git a/src/art/megatron/weights/merged_weight_export.py b/src/art/megatron/weights/merged_weight_export.py new file mode 100644 index 000000000..0ae2b766c --- /dev/null +++ b/src/art/megatron/weights/merged_weight_export.py @@ -0,0 +1,501 @@ +from concurrent.futures import ThreadPoolExecutor +from itertools import chain +import time +from typing import Any, Iterator, cast + +from pydantic import BaseModel, ConfigDict +import torch + +from art.megatron.runtime.jobs import ( + MergedWeightTransferInitInfo, + MergedWeightTransferSpec, +) +from art.megatron.training.model_chunks import ModelChunks, as_megatron_api_chunks +from art.megatron.weights.param_name_canonicalization import ( + canonical_art_param_name, + is_art_adapter_param_name, +) +from art.weight_transfer import ( + DEFAULT_PACKED_BUFFER_SIZE_BYTES, + 
DEFAULT_PACKED_NUM_BUFFERS, + trainer_init, + trainer_send_weights, +) + + +class MergedWeightExport(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + bridge: Any + model: ModelChunks + model_config_value: Any + conversion_tasks: list[Any] + adapter_weights_by_base: dict[str, list[Any]] + + +def _hf_param_names(hf_param: Any) -> list[str]: + if isinstance(hf_param, str): + return [hf_param] + return list(hf_param.values()) + + +def build_art_conversion_tasks(*, bridge: Any, model: ModelChunks) -> list[Any]: + from megatron.bridge.models.conversion.model_bridge import ( + WeightConversionTask, + _megatron_local_name_to_global, + ) + from megatron.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + persistent_buffers, + ) + + mapping_registry = bridge._model_bridge.mapping_registry() + hf_source = bridge.hf_pretrained.state.source + hf_keys = set(hf_source.get_all_keys()) + megatron_models = as_megatron_api_chunks(model) + model_config = cast(Any, model[0].config) + tasks: list[Any] = [] + for vp_stage, chunk in enumerate(model): + for local_name, _ in chain( + chunk.named_parameters(), + persistent_buffers(chunk), + ): + if "_extra_state" in local_name or is_art_adapter_param_name(local_name): + continue + global_name = _megatron_local_name_to_global( + megatron_models, + model_config, + canonical_art_param_name(local_name), + vp_stage, + ) + mapping = mapping_registry.megatron_to_hf_lookup(global_name) + hf_params = _hf_param_names(mapping.hf_param) + missing_hf_params = sorted(set(hf_params) - hf_keys) + if missing_hf_params: + raise RuntimeError( + f"Missing HF checkpoint weights for Megatron param {global_name}: " + f"{missing_hf_params}" + ) + local_module, local_weights = cast( + tuple[Any, torch.Tensor], + get_module_and_param_from_name( + megatron_models, + local_name, + vp_stage, + ), + ) + if local_module is not None and not hasattr(local_module, "config"): + setattr(local_module, "config", model_config) + tasks.append( + WeightConversionTask( + pp_rank=0, + vp_stage=vp_stage, + param_name=local_name, + global_param_name=global_name, + megatron_module=local_module, + param_weight=local_weights, + mapping=mapping, + ) + ) + return tasks + + +def build_merged_weight_export( + *, + bridge: Any, + model: ModelChunks, + model_support_handler: Any, +) -> MergedWeightExport: + return MergedWeightExport( + bridge=bridge, + model=model, + model_config_value=model[0].config, + conversion_tasks=build_art_conversion_tasks( + bridge=bridge, + model=model, + ), + adapter_weights_by_base=model_support_handler.build_adapter_weights_by_base( + model + ), + ) + + +def iter_merged_vllm_weights( + weight_export: MergedWeightExport, +) -> Iterator[tuple[str, torch.Tensor]]: + bridge = weight_export.bridge + model_bridge = bridge._model_bridge + hf_state_dict = bridge.hf_pretrained.state + grouped_buffers: dict[str, dict[int, torch.Tensor]] = {} + for task in weight_export.conversion_tasks: + converted_weights_dict = task.mapping.megatron_to_hf( + task.param_weight, + task.megatron_module, + ) + adapter_weights = weight_export.adapter_weights_by_base.get( + task.global_param_name + ) + if adapter_weights is not None: + try: + converted_weights_dict = model_bridge._merge_lora_adapter_weights( + weight_export.model, + converted_weights_dict, + adapter_weights, + ) + except Exception as exc: + converted_shapes = { + key: tuple(value.shape) + for key, value in converted_weights_dict.items() + } + adapter_summaries = [ + { + "base_prefix": 
adapter_weight.global_base_prefix, + "adapter_key": adapter_weight.adapter_key, + "linear_in": tuple( + adapter_weight.linear_in_weight.weight.shape + ), + "linear_out": tuple( + adapter_weight.linear_out_weight.weight.shape + ), + } + for adapter_weight in adapter_weights + ] + raise RuntimeError( + "Failed merged LoRA export for " + f"{task.global_param_name}: converted={converted_shapes} " + f"adapter_weights={adapter_summaries}" + ) from exc + if getattr(task.mapping, "is_grouped_export", False): + merged_result = model_bridge._accumulate_grouped_export( + task, + converted_weights_dict, + weight_export.model_config_value, + grouped_buffers, + hf_state_dict, + ) + if merged_result is None: + continue + converted_weights_dict = merged_result + else: + converted_weights_dict = model_bridge.maybe_modify_converted_hf_weight( + task, + converted_weights_dict, + hf_state_dict, + ) + yield from converted_weights_dict.items() + + +def _is_sender_rank(rank: int) -> bool: + return rank == 0 + + +def _maybe_distributed_barrier(world_size: int) -> None: + if world_size <= 1: + return + dist = cast(Any, torch.distributed) + if not dist.is_available() or not dist.is_initialized(): + return + dist.barrier() + + +def _runtime_headers(spec: MergedWeightTransferSpec) -> dict[str, str]: + if spec.api_key is None: + return {} + return {"Authorization": f"Bearer {spec.api_key}"} + + +def _post_with_retry( + post: Any, + url: str, + *, + phase: str, + retry_seconds: float = 10.0, + **kwargs: Any, +) -> Any: + if kwargs.get("headers") == {}: + kwargs = {key: value for key, value in kwargs.items() if key != "headers"} + deadline = time.monotonic() + retry_seconds + while True: + try: + response = post(url, **kwargs) + response.raise_for_status() + return response + except Exception as exc: + if time.monotonic() >= deadline: + raise RuntimeError( + f"{phase} failed after retrying for {retry_seconds:g}s" + ) from exc + time.sleep(0.5) + + +def _sync_rank_zero_status( + *, + rank: int, + world_size: int, + phase: str, + error: BaseException | None, +) -> None: + dist = cast(Any, torch.distributed) + if world_size <= 1 or not (dist.is_available() and dist.is_initialized()): + if error is not None: + raise RuntimeError(f"{phase} failed on rank 0") from error + return + payload = [ + f"{type(error).__name__}: {error}" + if _is_sender_rank(rank) and error is not None + else None + ] + dist.broadcast_object_list(payload, src=0) + if payload[0] is None: + return + if _is_sender_rank(rank): + raise RuntimeError(f"{phase} failed on rank 0: {payload[0]}") from error + raise RuntimeError(f"{phase} failed on rank 0: {payload[0]}") + + +def _drain_merged_vllm_weights( + weight_export: MergedWeightExport, + *, + names: list[str] | None = None, + dtype_names: list[str] | None = None, + shapes: list[list[int]] | None = None, +) -> None: + for name, tensor in iter_merged_vllm_weights(weight_export): + if names is not None: + assert dtype_names is not None + assert shapes is not None + names.append(name) + dtype_names.append(str(tensor.dtype).removeprefix("torch.")) + shapes.append(list(tensor.shape)) + + +def ensure_merged_weight_transfer_group( + *, + rank: int, + world_size: int, + merged_weight_transfer_group: Any | None, + merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None, + spec: MergedWeightTransferSpec, +) -> tuple[Any, MergedWeightTransferInitInfo]: + if merged_weight_transfer_init_info == spec.init_info: + if _is_sender_rank(rank): + assert merged_weight_transfer_group is not None + assert 
merged_weight_transfer_init_info is not None + _maybe_distributed_barrier(world_size) + return merged_weight_transfer_group, merged_weight_transfer_init_info + + import httpx + + error: BaseException | None = None + if _is_sender_rank(rank): + init_kwargs = { + "master_address": spec.init_info.master_address, + "master_port": spec.init_info.master_port, + "world_size": spec.init_info.world_size, + } + executor = ThreadPoolExecutor(max_workers=1) + try: + trainer_future = executor.submit(trainer_init, init_kwargs) + _post_with_retry( + httpx.post, + f"{spec.vllm_base_url}/init_weight_transfer_engine", + phase="initialize merged weight transfer", + json={"init_info": spec.init_info.model_dump()}, + headers=_runtime_headers(spec), + timeout=300.0, + ) + merged_weight_transfer_group = trainer_future.result() + except BaseException as exc: + error = exc + finally: + executor.shutdown(wait=error is None, cancel_futures=error is not None) + _sync_rank_zero_status( + rank=rank, + world_size=world_size, + phase="initialize merged weight transfer", + error=error, + ) + return merged_weight_transfer_group, spec.init_info + + +def sync_merged_weights_to_vllm( + *, + bridge: Any, + model: ModelChunks, + model_support_handler: Any, + rank: int, + world_size: int, + merged_weight_transfer_group: Any | None, + merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None, + spec: MergedWeightTransferSpec, + pause_generation: bool, +) -> tuple[Any, MergedWeightTransferInitInfo]: + import httpx + + ( + merged_weight_transfer_group, + merged_weight_transfer_init_info, + ) = ensure_merged_weight_transfer_group( + rank=rank, + world_size=world_size, + merged_weight_transfer_group=merged_weight_transfer_group, + merged_weight_transfer_init_info=merged_weight_transfer_init_info, + spec=spec, + ) + weight_export = build_merged_weight_export( + bridge=bridge, + model=model, + model_support_handler=model_support_handler, + ) + + def _send_weights() -> None: + assert merged_weight_transfer_group is not None + trainer_send_weights( + iter_merged_vllm_weights(weight_export), + { + "group": merged_weight_transfer_group, + "packed": True, + "packed_buffer_size_bytes": DEFAULT_PACKED_BUFFER_SIZE_BYTES, + "packed_num_buffers": DEFAULT_PACKED_NUM_BUFFERS, + }, + ) + + torch.cuda.synchronize() + names: list[str] = [] + dtype_names: list[str] = [] + shapes: list[list[int]] = [] + _drain_merged_vllm_weights( + weight_export, + names=names if _is_sender_rank(rank) else None, + dtype_names=dtype_names if _is_sender_rank(rank) else None, + shapes=shapes if _is_sender_rank(rank) else None, + ) + _maybe_distributed_barrier(world_size) + + pause_error: BaseException | None = None + update_error: BaseException | None = None + resume_error: BaseException | None = None + + if _is_sender_rank(rank): + with httpx.Client() as client: + if pause_generation: + try: + _post_with_retry( + client.post, + f"{spec.vllm_base_url}/pause", + phase="pause generation", + params={"mode": "wait"}, + headers=_runtime_headers(spec), + timeout=300.0, + ) + except BaseException as exc: + pause_error = exc + + _sync_rank_zero_status( + rank=rank, + world_size=world_size, + phase="pause generation", + error=pause_error, + ) + try: + _post_with_retry( + client.post, + f"{spec.vllm_base_url}/start_weight_update", + phase="start merged weight update", + json={"is_checkpoint_format": True}, + headers=_runtime_headers(spec), + timeout=300.0, + ) + with ThreadPoolExecutor(max_workers=1) as executor: + send_future = executor.submit(_send_weights) + 
_post_with_retry( + client.post, + f"{spec.vllm_base_url}/update_weights", + phase="update merged weights", + json={ + "update_info": { + "names": names, + "dtype_names": dtype_names, + "shapes": shapes, + "packed": True, + "packed_buffer_size_bytes": DEFAULT_PACKED_BUFFER_SIZE_BYTES, + "packed_num_buffers": DEFAULT_PACKED_NUM_BUFFERS, + } + }, + headers=_runtime_headers(spec), + timeout=600.0, + ) + send_future.result() + _post_with_retry( + client.post, + f"{spec.vllm_base_url}/finish_weight_update", + phase="finish merged weight update", + headers=_runtime_headers(spec), + timeout=600.0, + ) + _post_with_retry( + client.post, + f"{spec.vllm_base_url}/art/set_served_model_name", + phase="set served model name", + json={"name": spec.served_model_name}, + headers=_runtime_headers(spec), + timeout=30.0, + ) + torch.cuda.synchronize() + except BaseException as exc: + update_error = exc + finally: + if pause_generation: + try: + _post_with_retry( + client.post, + f"{spec.vllm_base_url}/resume", + phase="resume generation", + headers=_runtime_headers(spec), + timeout=30.0, + ) + except BaseException as exc: + resume_error = exc + _sync_rank_zero_status( + rank=rank, + world_size=world_size, + phase="update merged weights", + error=update_error, + ) + _sync_rank_zero_status( + rank=rank, + world_size=world_size, + phase="resume generation", + error=resume_error, + ) + else: + _sync_rank_zero_status( + rank=rank, + world_size=world_size, + phase="pause generation", + error=None, + ) + _drain_merged_vllm_weights(weight_export) + _sync_rank_zero_status( + rank=rank, + world_size=world_size, + phase="update merged weights", + error=None, + ) + _sync_rank_zero_status( + rank=rank, + world_size=world_size, + phase="resume generation", + error=None, + ) + return merged_weight_transfer_group, merged_weight_transfer_init_info + + +__all__ = [ + "MergedWeightExport", + "build_art_conversion_tasks", + "build_merged_weight_export", + "ensure_merged_weight_transfer_group", + "iter_merged_vllm_weights", + "sync_merged_weights_to_vllm", +] diff --git a/src/art/megatron/weights/param_name_canonicalization.py b/src/art/megatron/weights/param_name_canonicalization.py new file mode 100644 index 000000000..7e20624dd --- /dev/null +++ b/src/art/megatron/weights/param_name_canonicalization.py @@ -0,0 +1,54 @@ +def is_art_adapter_param_name(name: str) -> bool: + return any( + segment in name + for segment in ( + ".lora.", + ".q_proj_lora.", + ".k_proj_lora.", + ".v_proj_lora.", + ".qkv_lora.", + ".z_lora.", + ".gate_lora.", + ".up_lora.", + ) + ) + + +def canonical_art_param_name(name: str) -> str: + segments = name.split(".") + while segments and segments[0] == "module": + segments = segments[1:] + + canonical: list[str] = [] + i = 0 + while i < len(segments): + if segments[i] == "_orig_mod": + i += 1 + continue + if i + 1 < len(segments): + current = segments[i] + nxt = segments[i + 1] + if ( + current + in { + "linear_proj", + "linear_qkv", + "in_proj", + "linear_fc1", + "linear_fc2", + } + and nxt == current + ): + canonical.append(current) + i += 2 + continue + if current == "out_proj" and nxt == "linear_proj": + canonical.append(current) + i += 2 + continue + if current == "row_parallel_lora" and nxt == "linear_proj": + i += 2 + continue + canonical.append(segments[i]) + i += 1 + return ".".join(canonical) diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index 99ec77d76..2aa7fd992 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ 
-306,11 +306,6 @@ def _validate_backend_support(self) -> None: "PipelineTrainer + LocalBackend(dedicated) requires " "loss_fn_config=None." ) - if not self.normalize_advantages: - raise ValueError( - "PipelineTrainer + LocalBackend(dedicated) requires " - "normalize_advantages=True." - ) if self.adam_params is not None: raise ValueError( "PipelineTrainer + LocalBackend(dedicated) requires adam_params=None." @@ -767,7 +762,6 @@ def _trigger_collapse(self) -> None: async def _log_zero_variance_groups(self, step: int) -> None: if not self._discard_queue: return - # Cap logged groups to avoid huge parquet files when many discards accumulate. discarded = list(self._discard_queue[:50]) await self.model.log(discarded, split="discarded", step=step) self._discard_queue.clear() diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index 54f81d81a..7fe6b215c 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -14,6 +14,7 @@ from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages +from ..types import MessagesAndChoices ChatTemplateTool = dict[Any, Any] | Callable[..., Any] @@ -71,6 +72,14 @@ def _normalize_tool_call_arguments_for_chat_template( return normalized_messages +def _messages_for_chat_template( + tokenizer: PreTrainedTokenizerBase, + messages_and_choices: MessagesAndChoices, +) -> list[dict[str, Any]]: + messages = cast(list[dict[str, Any]], get_messages(messages_and_choices)) + return _normalize_tool_call_arguments_for_chat_template(tokenizer, messages) + + @dataclass class TokenizedResult: advantage: float @@ -274,23 +283,21 @@ def tokenize_trajectory( if last_assistant_index == -1: return None messages_and_choices = history.messages_and_choices[: last_assistant_index + 1] - messages = cast(list[dict[str, Any]], get_messages(messages_and_choices)) - # Qwen3.5's chat template uses `tool_call.arguments|items`, so it needs a - # mapping here instead of the OpenAI JSON string. 
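As context for the normalization this hunk centralizes: templates that iterate `tool_call.arguments|items` need a mapping, while OpenAI-style messages carry a JSON string. A minimal sketch of that conversion (not the ART implementation; names are placeholders):

```python
# Minimal sketch (not the ART implementation): convert OpenAI-style tool calls,
# whose "arguments" field is a JSON string, into the mapping form expected by
# chat templates that iterate `tool_call.arguments|items`.
import json
from typing import Any


def normalize_tool_call_arguments(
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    normalized: list[dict[str, Any]] = []
    for message in messages:
        message = dict(message)
        tool_calls = message.get("tool_calls") or []
        if tool_calls:
            converted = []
            for call in tool_calls:
                call = dict(call)
                function = dict(call.get("function", {}))
                arguments = function.get("arguments")
                if isinstance(arguments, str):
                    # The template iterates key/value pairs, so a dict is needed here.
                    function["arguments"] = json.loads(arguments)
                call["function"] = function
                converted.append(call)
            message["tool_calls"] = converted
        normalized.append(message)
    return normalized
```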
- messages = _normalize_tool_call_arguments_for_chat_template(tokenizer, messages) + messages = _messages_for_chat_template(tokenizer, messages_and_choices) tools = _normalize_tools_for_chat_template(history.tools) + chat_template_kwargs = ( + {"enable_thinking": False} + if _chat_template_disables_thinking(tokenizer) + else {} + ) chat = cast( str, - tokenizer.apply_chat_template( + cast(Any, tokenizer).apply_chat_template( messages, tools=tools, continue_final_message=True, tokenize=False, - **( - {"enable_thinking": False} - if _chat_template_disables_thinking(tokenizer) - else {} - ), # type: ignore[arg-type] + **chat_template_kwargs, ), ) original_token_ids = _apply_chat_template_token_ids( @@ -298,11 +305,7 @@ def tokenize_trajectory( messages, tools=tools, continue_final_message=True, - **( - {"enable_thinking": False} - if _chat_template_disables_thinking(tokenizer) - else {} - ), + **chat_template_kwargs, ) sentinel_token_id = max(set(range(tokenizer.vocab_size)) - set(original_token_ids)) sentinel_token = tokenizer.decode(sentinel_token_id) @@ -334,11 +337,7 @@ def tokenize_trajectory( token_template_messages, tools=tools, continue_final_message=True, - **( - {"enable_thinking": False} - if _chat_template_disables_thinking(tokenizer) - else {} - ), + **chat_template_kwargs, ) assistant_mask: list[int] = [0] * len(token_ids) logprobs = [float("nan")] * len(token_ids) @@ -529,16 +528,25 @@ def tokenize_sft_batch( num_trainable_tokens = 0 num_dropped_trajectories = 0 for trajectory in trajectory_batch: - messages = trajectory.messages_and_choices - tools = trajectory.tools + messages = _messages_for_chat_template( + tokenizer, + trajectory.messages_and_choices, + ) + tools = _normalize_tools_for_chat_template(trajectory.tools) + chat_template_kwargs = ( + {"enable_thinking": False} + if _chat_template_disables_thinking(tokenizer) + else {} + ) # Single-step tokenization: apply_chat_template with tokenize=True input_ids = _apply_chat_template_token_ids( tokenizer, - cast(Any, messages), - tools=cast(Any, tools), + messages, + tools=tools, tokenize=True, add_generation_prompt=False, + **chat_template_kwargs, ) if max_seq_length is not None and len(input_ids) > max_seq_length: num_dropped_trajectories += 1 diff --git a/src/art/tinker/renderers.py b/src/art/tinker/renderers.py index d4de2946b..6ad8bcd81 100644 --- a/src/art/tinker/renderers.py +++ b/src/art/tinker/renderers.py @@ -1,7 +1,11 @@ +def is_qwen3_dot_family_model(base_model: str) -> bool: + return base_model.startswith("Qwen/Qwen3.") + + def get_renderer_name(base_model: str) -> str: if base_model.startswith("meta-llama/"): return "llama3" - elif base_model.startswith("Qwen/Qwen3."): + elif is_qwen3_dot_family_model(base_model): # print("Defaulting to Qwen3.5 renderer with thinking for", base_model) # print(renderer_name_message) return "qwen3_5_disable_thinking" diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py index 1c7bdbb9a..328d9a976 100644 --- a/src/art/tinker/server.py +++ b/src/art/tinker/server.py @@ -34,7 +34,7 @@ import uvicorn from art.tinker.prefix_cache import LRUTrieCache -from art.tinker.renderers import get_renderer_name +from art.tinker.renderers import get_renderer_name, is_qwen3_dot_family_model from art.types import Message, Tools from mp_actors import close_proxy, move_to_child_process @@ -72,11 +72,11 @@ def _normalize_message_or_choice( return cast(Message, _MESSAGE_ADAPTER.validate_python(message_or_choice)) -def _normalize_qwen3_5_messages( +def _normalize_qwen3_dot_messages( 
base_model: str, messages: list[ChatCompletionMessageParam] ) -> list[dict[str, Any]]: normalized_messages = [cast(dict[str, Any], message) for message in messages] - if not base_model.startswith("Qwen/Qwen3."): + if not is_qwen3_dot_family_model(base_model): return normalized_messages for i, message in enumerate(normalized_messages): tool_calls = message.get("tool_calls") @@ -114,7 +114,7 @@ def _normalize_qwen3_5_messages( def _chat_template_disables_thinking(base_model: str) -> bool: - return base_model.startswith("Qwen/Qwen3.") + return is_qwen3_dot_family_model(base_model) @dataclass @@ -145,36 +145,47 @@ def models(self, models: dict[str, str]) -> None: async def start(self) -> tuple[str, int]: host = self.host or "0.0.0.0" port = self.port or get_free_port(host) - self._workers = [ - move_to_child_process( - OpenAICompatibleTinkerServerWorker(), - process_name=f"openai-compatible-tinker-server-worker-{i}", - ) - for i in range(self.num_workers or self._default_num_workers()) - ] - self._task = asyncio.create_task(self._run(host, port)) - client = AsyncOpenAI(api_key="default", base_url=f"http://{host}:{port}/v1") - start = time.time() - while True: - timeout = float(os.environ.get("ART_SERVER_TIMEOUT", 300.0)) - if time.time() - start > timeout: - raise TimeoutError( - f"Unable to reach OpenAI-compatible server within {timeout} seconds. You can increase this timeout by setting the ART_SERVER_TIMEOUT environment variable." + try: + self._workers = [] + for i in range(self.num_workers or self._default_num_workers()): + self._workers.append( + move_to_child_process( + OpenAICompatibleTinkerServerWorker(), + process_name=f"openai-compatible-tinker-server-worker-{i}", + ) ) - try: - await client.completions.create(model="", prompt="") - break # Server is ready - except Exception: - await asyncio.sleep(0.1) - return host, port + self._task = asyncio.create_task(self._run(host, port)) + client = AsyncOpenAI(api_key="default", base_url=f"http://{host}:{port}/v1") + start = time.time() + while True: + timeout = float(os.environ.get("ART_SERVER_TIMEOUT", 300.0)) + if time.time() - start > timeout: + raise TimeoutError( + f"Unable to reach OpenAI-compatible server within {timeout} seconds. You can increase this timeout by setting the ART_SERVER_TIMEOUT environment variable." 
+ ) + try: + await client.completions.create(model="", prompt="") + break # Server is ready + except Exception: + await asyncio.sleep(0.1) + return host, port + except BaseException: + await self.stop() + raise async def stop(self) -> None: - if self._task is not None: - self._task.cancel() - await self._task - self._task = None - for worker in self._workers: - close_proxy(worker) + try: + if self._task is not None: + self._task.cancel() + try: + await self._task + except (asyncio.CancelledError, Exception): + pass + self._task = None + finally: + for worker in self._workers: + close_proxy(worker) + self._workers.clear() def _get_request_tenant( self, request: Request @@ -543,7 +554,7 @@ async def prompt_tokens( messages: list[ChatCompletionMessageParam], tools: list[ChatCompletionToolUnionParam] | None, ) -> list[int]: - normalized_messages = _normalize_qwen3_5_messages(base_model, messages) + normalized_messages = _normalize_qwen3_dot_messages(base_model, messages) tokenizer = self._get_renderer(base_model).tokenizer if _chat_template_disables_thinking(base_model): encoding = tokenizer.apply_chat_template( diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index eed41810b..eff922d6b 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -48,13 +48,34 @@ async def start_openai_server( host=config.get("host") if config else None, port=config.get("port") if config else None, ) - self._server.models = state.models - with log_timing("Starting OpenAI-compatible Tinker server"): - return await self._server.start() + try: + self._server.models = state.models + with log_timing("Starting OpenAI-compatible Tinker server"): + return await self._server.start() + except BaseException: + await self.aclose() + raise async def vllm_engine_is_sleeping(self) -> bool: return False + async def aclose(self) -> None: + if self._server is not None: + await self._server.stop() + self._server = None + + def close(self) -> None: + if self._server is None: + return + if self._server._task is not None: + self._server._task.cancel() + from mp_actors import close_proxy + + for worker in self._server._workers: + close_proxy(worker) + self._server._workers.clear() + self._server = None + async def train( self, disk_packed_tensors: DiskPackedTensors, diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py index 87c904fe1..2a9564f67 100644 --- a/src/art/tinker_native/backend.py +++ b/src/art/tinker_native/backend.py @@ -175,9 +175,14 @@ async def _tinker_sample_call(self, label: str, awaitable: Awaitable[T]) -> T: ) async def close(self) -> None: + tasks: list[asyncio.Task[None]] = [] for state in self._model_state.values(): if state.server_task is not None: state.server_task.cancel() + tasks.append(state.server_task) + state.server_task = None + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) async def register(self, model: Model) -> None: model.base_path = self._path diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index 2a5a60abf..a9c7a8078 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -3,18 +3,14 @@ import asyncio from dataclasses import dataclass, field from functools import cached_property -import json import logging import os +import socket import subprocess -import sys -from typing import Any, AsyncIterator, Literal, cast +from typing import Any, AsyncIterator, Literal, TypedDict, cast import torch from trl import GRPOTrainer -from vllm import AsyncEngineArgs -from vllm.lora.request import 
LoRARequest -from vllm.v1.engine.async_llm import AsyncLLM from .. import dev, types from ..dev.validate import is_dedicated_mode @@ -24,9 +20,25 @@ from ..preprocessing.tokenize import SFTBatch from ..utils.convert_moe_lora import convert_checkpoint_if_needed from ..utils.get_model_step import get_step_from_dir -from ..utils.network import find_free_tcp_port +from ..utils.lifecycle import ( + ChildProcessSupervisor, + ServiceLifecycle, + managed_process_cmd, + terminate_popen_process_group, +) from ..utils.output_dirs import get_step_checkpoint_dir -from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers +from ..vllm_runtime import ( + VllmRuntimeLaunchConfig, + build_vllm_runtime_server_cmd, + get_vllm_runtime_working_dir, + wait_for_vllm_runtime, +) +from ..weight_transfer import ( + DEFAULT_PACKED_BUFFER_SIZE_BYTES, + DEFAULT_PACKED_NUM_BUFFERS, + trainer_init, + trainer_send_weights, +) from .train import ( UnslothTrainContext, create_unsloth_train_context, @@ -38,6 +50,10 @@ logger = logging.getLogger(__name__) +class _RuntimeRequestKwargs(TypedDict, total=False): + headers: dict[str, str] + + def save_checkpoint( trainer: "GRPOTrainer", output_dir: str, @@ -83,6 +99,12 @@ def save_checkpoint( return checkpoint_dir +def _find_free_tcp_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + def _normalize_merged_checkpoint_name(name: str) -> str: # PEFT wraps adapted modules under `.base_layer`, but vLLM expects the # original checkpoint parameter names during update_weights(). @@ -92,9 +114,6 @@ def _normalize_merged_checkpoint_name(name: str) -> str: return normalized -_find_free_tcp_port = find_free_tcp_port - - # ============================================================================ # Service # ============================================================================ @@ -108,16 +127,30 @@ class UnslothService: output_dir: str _is_sleeping: bool = False _latest_step: int = 0 - _lora_id_counter: int = 1 # Start from 1 since 0 is reserved # Dedicated mode subprocess state _vllm_process: subprocess.Popen | None = field(default=None, repr=False) # type: ignore[type-arg] _vllm_log_file: Any = field(default=None, repr=False) + _vllm_log_path: str | None = None _vllm_host: str = "127.0.0.1" _vllm_port: int = 0 + _vllm_api_key: str | None = None _weight_transfer_group: Any = field(default=None, init=False, repr=False) - _server_task: asyncio.Task[None] | None = field( - default=None, init=False, repr=False + _lifecycle: ServiceLifecycle = field( + default_factory=ServiceLifecycle, + init=False, + repr=False, ) + _child_processes: ChildProcessSupervisor = field(init=False, repr=False) + + def __post_init__(self) -> None: + self._child_processes = ChildProcessSupervisor(self._on_child_process_exit) + + def _on_child_process_exit(self, error: RuntimeError) -> None: + logger.error("%s", error) + self.close() + + def _raise_if_child_failed(self) -> None: + self._child_processes.raise_if_failed() @property def is_dedicated(self) -> bool: @@ -133,22 +166,61 @@ def rollout_weights_mode(self) -> Literal["lora", "merged"]: def _vllm_base_url(self) -> str: return f"http://{self._vllm_host}:{self._vllm_port}" - def _next_lora_id(self) -> int: - """Return a new unique LoRA ID to avoid collisions in vLLM.""" - self._lora_id_counter += 1 - return self._lora_id_counter + def _runtime_cuda_visible_devices(self) -> str: + if self.is_dedicated: + return ",".join(str(gpu_id) for gpu_id in 
self.config["inference_gpu_ids"]) + if visible := os.environ.get("CUDA_VISIBLE_DEVICES"): + return visible + return ",".join(str(index) for index in range(torch.cuda.device_count())) + + def _runtime_engine_args( + self, config: dev.OpenAIServerConfig | None + ) -> dict[str, object]: + engine_args = dict(self.config.get("engine_args", {})) + if config and "engine_args" in config: + engine_args.update(dict(config["engine_args"])) + engine_args.setdefault("generation_config", "vllm") + if self.rollout_weights_mode == "merged": + engine_args["weight_transfer_config"] = {"backend": "nccl"} + engine_args.pop("enable_lora", None) + engine_args.pop("max_loras", None) + else: + engine_args["enable_lora"] = True + engine_args.setdefault("max_loras", 2) + for key in ("model", "served_model_name"): + engine_args.pop(key, None) + return engine_args + + def _runtime_server_args( + self, config: dev.OpenAIServerConfig | None + ) -> dict[str, object]: + server_args: dict[str, object] = { + "return_tokens_as_token_ids": True, + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + if config and "server_args" in config: + server_args.update(dict(config["server_args"])) + for key in ("port", "host", "lora_modules"): + server_args.pop(key, None) + return server_args + + def _runtime_headers(self) -> dict[str, str]: + if self._vllm_api_key is None: + return {} + return {"Authorization": f"Bearer {self._vllm_api_key}"} + + def _runtime_request_kwargs(self) -> _RuntimeRequestKwargs: + headers = self._runtime_headers() + return {"headers": headers} if headers else {} + + def _sleep_mode_enabled(self) -> bool: + return bool(self.config.get("engine_args", {}).get("enable_sleep_mode", True)) async def aclose(self) -> None: state = self.__dict__.get("_state") if isinstance(state, UnslothTrainContext): await state.stop_background_training() - if self._server_task is not None: - self._server_task.cancel() - try: - await self._server_task - except asyncio.CancelledError: - pass - self._server_task = None self.close() # ========================================================================= @@ -161,107 +233,104 @@ async def _start_vllm_subprocess( port: int, config: dev.OpenAIServerConfig | None = None, ) -> tuple[str, int]: - """Launch vLLM as a subprocess on inference GPUs. 
Returns (host, port).""" - import atexit - - inference_gpu_ids = self.config["inference_gpu_ids"] - cuda_devices = ",".join(str(g) for g in inference_gpu_ids) - - # Build server_args: ART defaults, then user overrides, strip CLI-handled keys - server_args: dict[str, object] = { - "return_tokens_as_token_ids": True, - "enable_auto_tool_choice": True, - "tool_call_parser": "hermes", - } - if config and "server_args" in config: - server_args.update(dict(config["server_args"])) - for key in ("port", "host", "lora_modules", "api_key"): - server_args.pop(key, None) - - # Build engine_args: model-level config, then user server overrides, - # add dedicated-mode defaults, strip CLI-handled keys - engine_args = dict(self.config.get("engine_args", {})) - if config and "engine_args" in config: - engine_args.update(dict(config["engine_args"])) - engine_args.setdefault("generation_config", "vllm") - if self.rollout_weights_mode == "merged": - engine_args["weight_transfer_config"] = {"backend": "nccl"} - engine_args.pop("enable_lora", None) - engine_args.pop("max_loras", None) - else: - engine_args["enable_lora"] = True - engine_args.setdefault("max_loras", 2) - for key in ("model", "served_model_name", "enable_sleep_mode"): - engine_args.pop(key, None) - - cmd = [ - sys.executable, - "-m", - "art.vllm.dedicated_server", - f"--model={self.base_model}", - f"--port={port}", - f"--host={self._vllm_host}", - f"--cuda-visible-devices={cuda_devices}", - f"--lora-path={lora_path}", - f"--served-model-name={self.model_name}@{self._latest_step}", - f"--rollout-weights-mode={self.rollout_weights_mode}", - f"--engine-args-json={json.dumps(engine_args)}", - f"--server-args-json={json.dumps(server_args)}", - ] + self._raise_if_child_failed() + server_args = self._runtime_server_args(config) + api_key = server_args.get("api_key") + self._vllm_api_key = api_key if isinstance(api_key, str) else None + cmd = build_vllm_runtime_server_cmd( + VllmRuntimeLaunchConfig( + base_model=self.base_model, + port=port, + host=self._vllm_host, + cuda_visible_devices=self._runtime_cuda_visible_devices(), + lora_path=lora_path, + served_model_name=f"{self.model_name}@{self._latest_step}", + rollout_weights_mode=self.rollout_weights_mode, + engine_args=self._runtime_engine_args(config), + server_args=server_args, + ) + ) + self._lifecycle.install_parent_cleanup(self.close) log_dir = os.path.join(self.output_dir, "logs") os.makedirs(log_dir, exist_ok=True) - self._vllm_log_file = open( - os.path.join(log_dir, "vllm-dedicated.log"), "w", buffering=1 - ) + self._vllm_log_path = os.path.join(log_dir, "vllm-runtime.log") + self._vllm_log_file = open(self._vllm_log_path, "w", buffering=1) self._vllm_process = subprocess.Popen( - cmd, stdout=self._vllm_log_file, stderr=subprocess.STDOUT, bufsize=1 + managed_process_cmd(cmd), + cwd=str(get_vllm_runtime_working_dir()), + env=os.environ.copy(), + stdout=self._vllm_log_file, + stderr=subprocess.STDOUT, + bufsize=1, + start_new_session=True, ) self._vllm_port = port import httpx - timeout = float(os.environ.get("ART_DEDICATED_VLLM_TIMEOUT", 600)) - poll_interval = 1.0 - elapsed = 0.0 + timeout = float(os.environ.get("ART_DEDICATED_VLLM_TIMEOUT", 1200)) async with httpx.AsyncClient() as client: - while elapsed < timeout: - if self._vllm_process.poll() is not None: - raise RuntimeError( - f"vLLM subprocess exited with code {self._vllm_process.returncode}. 
" - f"Check logs at {log_dir}/vllm-dedicated.log" - ) - try: - resp = await client.get( - f"http://{self._vllm_host}:{self._vllm_port}/v1/models", - timeout=5.0, - ) - if resp.status_code == 200: - break - except (httpx.ConnectError, httpx.ReadTimeout): - pass - await asyncio.sleep(poll_interval) - elapsed += poll_interval - else: + try: + await wait_for_vllm_runtime( + process=self._vllm_process, + host=self._vllm_host, + port=self._vllm_port, + timeout=timeout, + ) + except TimeoutError as exc: self.close() raise TimeoutError( f"vLLM subprocess did not become ready within {timeout}s. " - f"Check logs at {log_dir}/vllm-dedicated.log" + f"Check logs at {log_dir}/vllm-runtime.log" + ) from exc + except RuntimeError as exc: + returncode = self._vllm_process.returncode + self.close() + raise RuntimeError( + f"vLLM subprocess exited with code {returncode}. " + f"Check logs at {log_dir}/vllm-runtime.log" + ) from exc + + try: + resp = await client.get( + f"http://{self._vllm_host}:{self._vllm_port}/v1/models", + **self._runtime_request_kwargs(), + timeout=5.0, ) + resp.raise_for_status() + except httpx.HTTPError as exc: + self.close() + raise RuntimeError( + "vLLM passed /health but /v1/models was not reachable. " + f"Check logs at {log_dir}/vllm-runtime.log" + ) from exc - atexit.register(self.close) - logger.info("vLLM subprocess ready on port %d (GPUs: %s)", port, cuda_devices) + assert self._vllm_process is not None + assert self._vllm_log_path is not None + self._child_processes.watch_popen( + "vLLM runtime", + self._vllm_process, + log_path=self._vllm_log_path, + ) + logger.info( + "vLLM runtime ready on port %d (GPUs: %s)", + port, + self._runtime_cuda_visible_devices(), + ) return self._vllm_host, self._vllm_port async def _set_served_model_name(self, step: int) -> None: import httpx + self._raise_if_child_failed() served_model_name = f"{self.model_name}@{step}" async with httpx.AsyncClient() as client: response = await client.post( f"{self._vllm_base_url}/art/set_served_model_name", json={"name": served_model_name}, + **self._runtime_request_kwargs(), timeout=30.0, ) response.raise_for_status() @@ -272,16 +341,15 @@ async def _set_served_model_name(self, step: int) -> None: async def _init_merged_weight_transfer(self) -> None: import httpx - from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLWeightTransferEngine, - ) + self._raise_if_child_failed() if self._weight_transfer_group is not None: return async with httpx.AsyncClient() as client: world_size_response = await client.get( f"{self._vllm_base_url}/get_world_size", + **self._runtime_request_kwargs(), timeout=30.0, ) try: @@ -293,7 +361,7 @@ async def _init_merged_weight_transfer(self) -> None: ) from exc inference_world_size = int(world_size_response.json()["world_size"]) - master_port = find_free_tcp_port() + master_port = _find_free_tcp_port() init_info = { "master_address": "127.0.0.1", "master_port": master_port, @@ -305,13 +373,12 @@ async def _init_merged_weight_transfer(self) -> None: client.post( f"{self._vllm_base_url}/init_weight_transfer_engine", json={"init_info": init_info}, + **self._runtime_request_kwargs(), timeout=300.0, ) ) - # TODO: replace this with a real readiness handshake if this ever flakes. 
- await asyncio.sleep(1.0) self._weight_transfer_group = await asyncio.to_thread( - NCCLWeightTransferEngine.trainer_init, + trainer_init, { "master_address": init_info["master_address"], "master_port": init_info["master_port"], @@ -359,10 +426,8 @@ async def _sync_merged_weights( pause_generation: bool, ) -> None: import httpx - from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLWeightTransferEngine, - ) + self._raise_if_child_failed() assert self._weight_transfer_group is not None peft_model = self._state.peft_model @@ -376,6 +441,7 @@ async def _sync_merged_weights( response = await client.post( f"{self._vllm_base_url}/pause", params={"mode": "wait"}, + **self._runtime_request_kwargs(), timeout=300.0, ) response.raise_for_status() @@ -385,6 +451,13 @@ async def _sync_merged_weights( torch.cuda.synchronize() weights = self._merged_checkpoint_weights_for_vllm() + response = await client.post( + f"{self._vllm_base_url}/start_weight_update", + json={"is_checkpoint_format": True}, + **self._runtime_request_kwargs(), + timeout=300.0, + ) + response.raise_for_status() update_info = { "names": [name for name, _ in weights], "dtype_names": [ @@ -392,18 +465,26 @@ async def _sync_merged_weights( for _, tensor in weights ], "shapes": [list(tensor.shape) for _, tensor in weights], - "is_checkpoint_format": True, + "packed": True, + "packed_buffer_size_bytes": DEFAULT_PACKED_BUFFER_SIZE_BYTES, + "packed_num_buffers": DEFAULT_PACKED_NUM_BUFFERS, } _, update_response = await asyncio.gather( asyncio.to_thread( - NCCLWeightTransferEngine.trainer_send_weights, + trainer_send_weights, iter(weights), - {"group": self._weight_transfer_group}, + { + "group": self._weight_transfer_group, + "packed": True, + "packed_buffer_size_bytes": DEFAULT_PACKED_BUFFER_SIZE_BYTES, + "packed_num_buffers": DEFAULT_PACKED_NUM_BUFFERS, + }, ), client.post( f"{self._vllm_base_url}/update_weights", json={"update_info": update_info}, + **self._runtime_request_kwargs(), timeout=600.0, ), ) @@ -414,6 +495,12 @@ async def _sync_merged_weights( "Merged rollout weights require a vLLM build with the " "/update_weights endpoint" ) from exc + response = await client.post( + f"{self._vllm_base_url}/finish_weight_update", + **self._runtime_request_kwargs(), + timeout=600.0, + ) + response.raise_for_status() self._latest_step = step await self._set_served_model_name(step) except Exception as exc: @@ -427,6 +514,7 @@ async def _sync_merged_weights( try: response = await client.post( f"{self._vllm_base_url}/resume", + **self._runtime_request_kwargs(), timeout=30.0, ) response.raise_for_status() @@ -446,6 +534,7 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: """Reload LoRA adapter in vLLM subprocess via HTTP.""" import httpx + self._raise_if_child_failed() lora_name = f"{self.model_name}@{step}" logger.info( f"[DEDICATED] _reload_adapter START: lora_name={lora_name} " @@ -459,6 +548,7 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: "lora_path": checkpoint_path, "load_inplace": True, }, + **self._runtime_request_kwargs(), timeout=60.0, ) response.raise_for_status() @@ -468,23 +558,21 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: ) def close(self) -> None: - """Terminate vLLM subprocess and cancel server task if running.""" - self._weight_transfer_group = None - if self._server_task is not None: - self._server_task.cancel() - self._server_task = None - if self._vllm_process is None: + """Terminate vLLM subprocess if running.""" + if not 
self._lifecycle.begin_close(): return - self._vllm_process.terminate() + self._weight_transfer_group = None try: - self._vllm_process.wait(timeout=5) - except subprocess.TimeoutExpired: - self._vllm_process.kill() - self._vllm_process.wait() - self._vllm_process = None - if self._vllm_log_file is not None: - self._vllm_log_file.close() - self._vllm_log_file = None + self._child_processes.close() + if self._vllm_process is not None: + terminate_popen_process_group(self._vllm_process) + self._vllm_process = None + if self._vllm_log_file is not None: + self._vllm_log_file.close() + self._vllm_log_file = None + self._vllm_log_path = None + finally: + self._lifecycle.restore_parent_cleanup() # ========================================================================= # start_openai_server @@ -493,6 +581,7 @@ def close(self) -> None: async def start_openai_server( self, config: dev.OpenAIServerConfig | None ) -> tuple[str, int]: + self._raise_if_child_failed() lora_path = get_last_checkpoint_dir(self.output_dir) if lora_path is None: lora_path = get_step_checkpoint_dir(self.output_dir, 0) @@ -503,73 +592,66 @@ async def start_openai_server( else: self._latest_step = get_step_from_dir(self.output_dir) - if self.is_dedicated: - port = (config or {}).get("server_args", {}).get("port", 8000) - vllm_location = await self._start_vllm_subprocess( - lora_path, - port, - config=config, - ) + if not self.is_dedicated: + if not self._sleep_mode_enabled(): + raise ValueError( + "Shared-GPU mode requires engine_args.enable_sleep_mode=True " + "for the external vLLM runtime" + ) + self._state.offload_to_cpu() + + port = (config or {}).get("server_args", {}).get("port", 8000) + vllm_location = await self._start_vllm_subprocess( + lora_path, + port, + config=config, + ) + try: if self.rollout_weights_mode == "merged": _ = self._state await self._init_merged_weight_transfer() await self._sync_merged_weights(self._latest_step, False) - return vllm_location - - # Shared mode: in-process vLLM - self._state.offload_to_cpu() - - server_config = dev.get_openai_server_config( - model_name=self.model_name, - base_model=self.base_model, - log_file=f"{self.output_dir}/logs/vllm.log", - lora_path=lora_path, - config=config, - ) - self._server_task = await openai_server_task( - engine=await self.llm, - config=server_config, - ) - return server_config.get("server_args", {}).get( - "host" - ) or "0.0.0.0", server_config.get("server_args", {}).get("port", 8000) + except BaseException: + await self.aclose() + raise + return vllm_location async def vllm_engine_is_sleeping(self) -> bool: - if self.is_dedicated: - return False return self._is_sleeping - async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: - """Register a LoRA adapter for a specific checkpoint step. - This is called when training is skipped but the checkpoint is renamed. 
- """ - logger.info( - f"[DEDICATED] register_lora_for_step called: step={step} " - f"checkpoint_dir={checkpoint_dir} is_dedicated={self.is_dedicated}" - ) - if self.is_dedicated: - if self.rollout_weights_mode == "merged": - await self._set_served_model_name(step) - else: - await self._reload_adapter(checkpoint_dir, step) - self._latest_step = step - return + async def _sleep_runtime(self) -> None: + import httpx - llm = await self.llm - await llm.pause_generation() - added = await llm.add_lora( - LoRARequest( - lora_name=f"{self.model_name}@{step}", - lora_int_id=self._next_lora_id(), - lora_path=checkpoint_dir, + self._raise_if_child_failed() + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self._vllm_base_url}/sleep", + params={"level": 1, "mode": "wait"}, + **self._runtime_request_kwargs(), + timeout=300.0, ) - ) - if not added: - raise RuntimeError( - f"Failed to add LoRA adapter for step {step} at {checkpoint_dir}" + response.raise_for_status() + self._is_sleeping = True + + async def _wake_runtime(self) -> None: + import httpx + + self._raise_if_child_failed() + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self._vllm_base_url}/wake_up", + **self._runtime_request_kwargs(), + timeout=300.0, ) + response.raise_for_status() + self._is_sleeping = False + + async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: + if self.rollout_weights_mode == "merged": + await self._set_served_model_name(step) + else: + await self._reload_adapter(checkpoint_dir, step) self._latest_step = step - await llm.resume_generation() async def train( self, @@ -578,17 +660,22 @@ async def train( _config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: - if self.is_dedicated: - async for result in self._train_dedicated( + try: + self._raise_if_child_failed() + if self.is_dedicated: + async for result in self._train_dedicated( + disk_packed_tensors, config, _config, verbose + ): + yield result + return + + async for result in self._train_shared( disk_packed_tensors, config, _config, verbose ): yield result - return - - async for result in self._train_shared( - disk_packed_tensors, config, _config, verbose - ): - yield result + except BaseException: + await self.aclose() + raise async def _train_dedicated( self, @@ -638,29 +725,8 @@ async def _train_shared( _config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: - """Train in shared mode — sleep/wake cycle with in-process vLLM.""" - llm = await self.llm - - # Pause generation to prevent new requests during training - await llm.pause_generation() - - # Determine sleep level based on outstanding requests: - # - level 1: offload KV cache to CPU (can resume with existing KV state) - # - level 2: discard KV cache (fresh start after wake) - has_unfinished = llm.output_processor.has_unfinished_requests() - if has_unfinished: - sleep_level = 1 - else: - # Reset prefix cache before discarding KV cache - await llm.reset_prefix_cache() - sleep_level = 2 - - # Put workers to sleep - await run_on_workers(llm, do_sleep, level=sleep_level) - self._is_sleeping = True + await self._sleep_runtime() gc_and_empty_cuda_cache() - - # Reload training model to GPU (after vLLM is asleep) self._state.reload_to_gpu() async for result in run_unsloth_rl_training( @@ -672,48 +738,21 @@ async def _train_shared( ): yield result - # Save checkpoint after training checkpoint_dir = save_checkpoint( trainer=self._state.trainer, 
output_dir=self.output_dir, verbose=verbose, ) - # Offload training model to CPU before waking vLLM self._state.offload_to_cpu() - - # Free memory before waking up vLLM gc_and_empty_cuda_cache() - await asyncio.sleep( - 0.5 - ) # Longer delay to allow memory cleanup and pending ops to complete - - # Wake up workers - await run_on_workers(llm, do_wake_up) - self._is_sleeping = False + await asyncio.sleep(0.5) + await self._wake_runtime() - # Determine the new step from the checkpoint directory - # checkpoint_dir format is: {output_dir}/checkpoints/{step:04d} new_step = int(os.path.basename(checkpoint_dir)) - - # Add the new LoRA adapter - # We keep old LoRAs loaded - vLLM will page them out as needed - added = await llm.add_lora( - LoRARequest( - lora_name=f"{self.model_name}@{new_step}", - lora_int_id=self._next_lora_id(), - lora_path=checkpoint_dir, - ) - ) - if not added: - raise RuntimeError( - f"Failed to add LoRA adapter for step {new_step} at {checkpoint_dir}" - ) + await self._reload_adapter(checkpoint_dir, new_step) self._latest_step = new_step - # Resume generation after LoRA add is complete - await llm.resume_generation() - if verbose: print("UnslothService.train complete") @@ -737,223 +776,58 @@ async def train_sft( Yields: Dictionary containing training metrics for each batch. """ - if self.is_dedicated: - raise NotImplementedError( - "train_sft is not yet supported in dedicated mode" - ) - import time - - llm = await self.llm - - # === Setup === - # Pause generation to prevent new requests during training - await llm.pause_generation() - - # Determine sleep level based on outstanding requests - has_unfinished = llm.output_processor.has_unfinished_requests() - if has_unfinished: - sleep_level = 1 - else: - await llm.reset_prefix_cache() - sleep_level = 2 - - # Put workers to sleep - await run_on_workers(llm, do_sleep, level=sleep_level) - self._is_sleeping = True - gc_and_empty_cuda_cache() - - # Reload training model to GPU (after vLLM is asleep) - self._state.reload_to_gpu() - if verbose: - print("SFT training started") - - async for result in run_unsloth_sft_training( - self._state, - batches, - verbose=verbose, - max_grad_norm=1.0, - ): - yield { - "loss/train": result["loss"], - "loss/learning_rate": result["learning_rate"], - "loss/grad_norm": result["grad_norm"], - } - - # === Cleanup === - # Save checkpoint after training - checkpoint_dir = save_checkpoint( - trainer=self._state.trainer, - output_dir=self.output_dir, - verbose=verbose, - ) - - # Offload training model to CPU before waking vLLM - self._state.offload_to_cpu() - - # Free memory before waking up vLLM - gc_and_empty_cuda_cache() - await asyncio.sleep(0.5) + try: + self._raise_if_child_failed() + if self.is_dedicated: + raise NotImplementedError( + "train_sft is not yet supported in dedicated mode" + ) - # Wake up workers - await run_on_workers(llm, do_wake_up) - self._is_sleeping = False + await self._sleep_runtime() + gc_and_empty_cuda_cache() + self._state.reload_to_gpu() + if verbose: + print("SFT training started") + + async for result in run_unsloth_sft_training( + self._state, + batches, + verbose=verbose, + max_grad_norm=1.0, + ): + yield { + "loss/train": result["loss"], + "loss/learning_rate": result["learning_rate"], + "loss/grad_norm": result["grad_norm"], + } - # Add the new LoRA adapter - new_step = int(os.path.basename(checkpoint_dir)) - added = await llm.add_lora( - LoRARequest( - lora_name=f"{self.model_name}@{new_step}", - lora_int_id=self._next_lora_id(), - lora_path=checkpoint_dir, + 
checkpoint_dir = save_checkpoint( + trainer=self._state.trainer, + output_dir=self.output_dir, + verbose=verbose, ) - ) - if not added: - raise RuntimeError( - f"Failed to add LoRA adapter for step {new_step} at {checkpoint_dir}" - ) - self._latest_step = new_step - # Resume generation after LoRA swap is complete - await llm.resume_generation() + self._state.offload_to_cpu() + gc_and_empty_cuda_cache() + await asyncio.sleep(0.5) + await self._wake_runtime() + new_step = int(os.path.basename(checkpoint_dir)) + await self._reload_adapter(checkpoint_dir, new_step) + self._latest_step = new_step - if verbose: - print("SFT training finished") + if verbose: + print("SFT training finished") + except BaseException: + await self.aclose() + raise @cached_property def _state(self) -> UnslothTrainContext: init_args = dict(self.config.get("init_args", {})) checkpoint_dir = get_last_checkpoint_dir(self.output_dir) - if checkpoint_dir: - init_args["model_name"] = checkpoint_dir - else: - init_args["model_name"] = self.base_model + init_args["model_name"] = checkpoint_dir or self.base_model return create_unsloth_train_context( init_args=init_args, peft_args=cast(dict[str, Any], self.config.get("peft_args", {})), trainer_args=cast(dict[str, Any], self.config.get("trainer_args", {})), ) - - @cached_property - def llm(self) -> asyncio.Task[AsyncLLM]: - # Filter engine args to remove incompatible boolean flags - engine_args = { - **self.config.get("engine_args", {}), - "enable_lora": True, - "max_loras": self.config.get("engine_args", {}).get("max_loras", 2), - } - # Remove boolean flags that vLLM's argparse doesn't accept as =False - for key in ["enable_log_requests", "disable_log_requests"]: - engine_args.pop(key, None) - return asyncio.create_task(get_llm(AsyncEngineArgs(**engine_args))) # ty:ignore[invalid-argument-type] - - -# ============================================================================ -# Worker Sleep/Wake Functions -# ============================================================================ - - -def do_sleep(*, level: int) -> None: - """ - Put the worker to sleep, offloading both weights and KV cache. 
- - Args: - level: The sleep level: - - 1: offload KV cache to CPU (can resume with existing KV state) - - 2: discard KV cache (fresh start after wake) - """ - import ctypes - import gc - - import torch - from vllm.device_allocator.cumem import ( - CuMemAllocator, - libcudart, - unmap_and_release, - ) - - try: - from vllm.utils.platform_utils import is_pin_memory_available - except ImportError: - from vllm.utils import is_pin_memory_available - - worker = get_worker() - allocator = CuMemAllocator.get_instance() - - # Determine what to offload based on level: - # level=1: offload both weights and kv_cache to CPU - # level=2: offload weights, discard kv_cache - offload_to = "cpu" if level == 1 else "none" - tags_to_process = {"weights", "kv_cache"} - - # Save buffers before level 2 sleep (like vLLM does) - if level == 2: - model = worker.model_runner.model - worker._sleep_saved_buffers = { - name: buffer.cpu().clone() for name, buffer in model.named_buffers() - } - - for ptr, data in allocator.pointer_to_data.items(): - if data.tag not in tags_to_process: - continue - handle = data.handle - size_in_bytes = handle[1] - - # Always backup weights; backup kv_cache only at level 1 - if offload_to != "none" or data.tag == "weights": - cpu_backup_tensor = torch.empty( - size_in_bytes, - dtype=torch.uint8, - device="cpu", - pin_memory=is_pin_memory_available(), - ) - cpu_ptr = cpu_backup_tensor.data_ptr() - libcudart.cudaMemcpy( # ty:ignore[possibly-missing-attribute] - ctypes.c_void_p(cpu_ptr), ctypes.c_void_p(ptr), size_in_bytes - ) - data.cpu_backup_tensor = cpu_backup_tensor - - unmap_and_release(handle) - - gc.collect() - torch.cuda.empty_cache() - - -def do_wake_up() -> None: - """ - Wake up the worker from sleep, restoring offloaded weights and KV cache. 
- """ - import ctypes - - from vllm.device_allocator.cumem import ( - CuMemAllocator, - create_and_map, - libcudart, - ) - - worker = get_worker() - allocator = CuMemAllocator.get_instance() - - tags_to_process = {"weights", "kv_cache"} - - for ptr, data in allocator.pointer_to_data.items(): - if data.tag not in tags_to_process: - continue - create_and_map(data.handle) - if data.cpu_backup_tensor is not None: - cpu_backup_tensor = data.cpu_backup_tensor - size_in_bytes = cpu_backup_tensor.numel() * cpu_backup_tensor.element_size() - cpu_ptr = cpu_backup_tensor.data_ptr() - libcudart.cudaMemcpy( # ty:ignore[possibly-missing-attribute] - ctypes.c_void_p(ptr), - ctypes.c_void_p(cpu_ptr), - size_in_bytes, - ) - data.cpu_backup_tensor = None - - # Restore buffers after level 2 sleep (like vLLM does) - if hasattr(worker, "_sleep_saved_buffers") and worker._sleep_saved_buffers: - model = worker.model_runner.model - for name, buffer in model.named_buffers(): - if name in worker._sleep_saved_buffers: - buffer.copy_(worker._sleep_saved_buffers[name].to(buffer.device)) - worker._sleep_saved_buffers = {} diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index 716ff3c23..46d4e410f 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -814,18 +814,7 @@ async def run_unsloth_rl_training( param_group["weight_decay"] = 0.1 packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) - - if verbose: - unfinished = getattr(ctx.results_queue, "_unfinished_tasks", "?") - qsize = ctx.results_queue.qsize() - print( - f"[run_unsloth_rl_training] before join: warmup_pending={ctx.warmup_pending} " - f"train_task={'none' if ctx.train_task is None else ('done' if ctx.train_task.done() else 'running')} " - f"results_queue.qsize={qsize} results_queue._unfinished_tasks={unfinished}" - ) await ctx.results_queue.join() - if verbose: - print("[run_unsloth_rl_training] join complete") if ctx.train_task is None: ctx.train_task = asyncio.create_task( @@ -853,49 +842,23 @@ async def run_unsloth_rl_training( ctx.inputs_queue.put_nowait( create_train_inputs(packed_tensors, offset, config, _config, warmup) ) - if verbose: - print( - f"[run_unsloth_rl_training] put input (warmup={warmup}, offset={offset}), " - f"inputs_queue.qsize={ctx.inputs_queue.qsize()}" - ) + result_task = asyncio.create_task(ctx.results_queue.get()) done, pending = await asyncio.wait( [ - asyncio.create_task(ctx.results_queue.get()), + result_task, ctx.train_task, ], return_when=asyncio.FIRST_COMPLETED, ) - if verbose: - print( - f"[run_unsloth_rl_training] asyncio.wait returned: " - f"done={len(done)} tasks, train_task_in_done={ctx.train_task in done}" - ) - # Cancel any abandoned results_queue.get() tasks that didn't complete - for t in pending: - if t is not ctx.train_task: - t.cancel() + if result_task in pending: + result_task.cancel() if verbose: print( "Done waiting for a result from the queue or for the training task to, presumably, raise an exception" ) for task in done: - if task is ctx.train_task: - try: - result = task.result() - except Exception as e: - print( - f"[run_unsloth_rl_training] train_task raised exception: " - f"{type(e).__name__}: {e}" - ) - raise - print( - f"[run_unsloth_rl_training] train_task returned (should never happen): " - f"result={result!r}" - ) - assert result is not None, "The training task should never finish." - else: - result = task.result() + result = task.result() assert result is not None, "The training task should never finish." 
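The simplified wait loop above follows a common pattern: block on either the next queue item or the long-running worker task, so a worker failure surfaces immediately. A minimal standalone sketch (placeholder names, not the ART code):

```python
# Minimal sketch, not the ART implementation: wait for either the next queue
# result or the long-running worker task, so a worker exception surfaces
# immediately instead of leaving the consumer waiting forever.
import asyncio
from typing import Any


async def next_result_or_failure(
    results_queue: "asyncio.Queue[Any]", worker_task: "asyncio.Task[Any]"
) -> Any:
    get_task = asyncio.create_task(results_queue.get())
    done, pending = await asyncio.wait(
        {get_task, worker_task}, return_when=asyncio.FIRST_COMPLETED
    )
    if get_task in pending:
        # The worker finished (or failed) first; drop the abandoned get().
        get_task.cancel()
    # .result() re-raises the worker's exception if it failed.
    return next(iter(done)).result()
```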
ctx.results_queue.task_done() if warmup: diff --git a/src/art/utils/convert_moe_lora.py b/src/art/utils/convert_moe_lora.py index 0ea80f63a..8f1bd982c 100644 --- a/src/art/utils/convert_moe_lora.py +++ b/src/art/utils/convert_moe_lora.py @@ -1,15 +1,14 @@ -"""Convert fused MoE LoRA adapters to per-expert format for vLLM compatibility. +"""Convert PEFT fused MoE LoRA target-parameter adapters for vLLM. Unsloth with transformers v5 saves MoE expert LoRA as fused 2D tensors: - mlp.experts.base_layer.lora_A [num_experts*rank, intermediate*2] (gate_up_proj) - mlp.experts.base_layer.lora_B [hidden, num_experts*rank] (gate_up_proj) - mlp.experts.lora_A [num_experts*rank, hidden] (down_proj) - mlp.experts.lora_B [intermediate, num_experts*rank] (down_proj) - -vLLM expects per-expert keys: - mlp.experts.0.gate_proj.lora_A [rank, hidden] - mlp.experts.0.gate_proj.lora_B [intermediate, rank] - ... + mlp.experts.base_layer.lora_A [num_experts*rank, intermediate*2] + mlp.experts.base_layer.lora_B [hidden, num_experts*rank] + mlp.experts.lora_A [num_experts*rank, hidden] + mlp.experts.lora_B [intermediate, num_experts*rank] + +vLLM's 3D MoE LoRA path expects the same fused keys with standard LoRA +orientation, so conversion swaps/transposes each A/B pair and keeps target +modules at "experts". """ import json @@ -20,67 +19,26 @@ import torch -def _has_fused_moe_lora(tensors: dict[str, torch.Tensor]) -> bool: - """Check if the adapter contains fused MoE LoRA tensors.""" +def _has_peft_fused_moe_lora( + tensors: dict[str, torch.Tensor], + adapter_config: dict, +) -> bool: + """Check if the adapter contains PEFT target-parameter fused MoE tensors.""" + if not adapter_config.get("target_parameters"): + return False return any( re.search(r"mlp\.experts\.(base_layer\.)?lora_[AB]\.weight$", key) for key in tensors ) -def _infer_moe_params( - tensors: dict[str, torch.Tensor], - adapter_config: dict, -) -> tuple[int, int, int, int]: - """Infer num_experts, rank, intermediate_size, hidden_size from tensor shapes.""" - rank = adapter_config.get("r", adapter_config.get("lora_rank", 8)) - - for key, tensor in tensors.items(): - # gate_up_proj lora_A: [num_experts*rank, intermediate*2] - if re.search(r"mlp\.experts\.base_layer\.lora_A\.weight$", key): - num_experts_times_rank = tensor.shape[0] - intermediate_times_2 = tensor.shape[1] - num_experts = num_experts_times_rank // rank - intermediate_size = intermediate_times_2 // 2 - break - # down_proj lora_B: [intermediate, num_experts*rank] - if re.search(r"mlp\.experts\.lora_B\.weight$", key): - intermediate_size = tensor.shape[0] - num_experts = tensor.shape[1] // rank - break - else: - raise ValueError("Could not find fused MoE tensors to infer parameters") - - # Get hidden_size from gate_up_proj lora_B: [hidden, num_experts*rank] - # or from down_proj lora_A: [num_experts*rank, hidden] - for key, tensor in tensors.items(): - if re.search(r"mlp\.experts\.base_layer\.lora_B\.weight$", key): - hidden_size = tensor.shape[0] - break - if re.search(r"mlp\.experts\.lora_A\.weight$", key): - hidden_size = tensor.shape[1] - break - else: - raise ValueError("Could not infer hidden_size from fused MoE tensors") - - return num_experts, rank, intermediate_size, hidden_size - - def convert_fused_moe_lora( tensors: dict[str, torch.Tensor], - num_experts: int, - rank: int, - intermediate_size: int, - hidden_size: int, ) -> dict[str, torch.Tensor]: - """Convert fused MoE LoRA tensors to per-expert format. - - Non-expert tensors (e.g. self_attn) are passed through unchanged. 
- """ + """Convert PEFT fused MoE LoRA tensors to vLLM's fused experts layout.""" new_tensors: dict[str, torch.Tensor] = {} for key, tensor in tensors.items(): - # Non-expert tensors: keep as-is m = re.match( r"(.*\.mlp\.experts)\.(base_layer\.lora_(A|B)|lora_(A|B))\.weight$", key, @@ -90,53 +48,16 @@ def convert_fused_moe_lora( continue prefix = m.group(1) - is_base_layer = "base_layer" in key - is_A = "lora_A" in key - - if is_base_layer: - # gate_up_proj (fused gate + up) - if is_A: - # [num_experts*rank, intermediate*2] → per expert - per_expert = tensor.reshape(num_experts, rank, intermediate_size * 2) - for e in range(num_experts): - expert_a = per_expert[e] # [rank, intermediate*2] - gate_a = expert_a[:, :intermediate_size] - up_a = expert_a[:, intermediate_size:] - new_tensors[f"{prefix}.{e}.gate_proj.lora_B.weight"] = ( - gate_a.T.contiguous() - ) - new_tensors[f"{prefix}.{e}.up_proj.lora_B.weight"] = ( - up_a.T.contiguous() - ) - else: - # [hidden, num_experts*rank] → per expert - per_expert = tensor.reshape(hidden_size, num_experts, rank) - for e in range(num_experts): - expert_b = per_expert[:, e, :] # [hidden, rank] - new_tensors[f"{prefix}.{e}.gate_proj.lora_A.weight"] = ( - expert_b.T.contiguous() - ) - new_tensors[f"{prefix}.{e}.up_proj.lora_A.weight"] = ( - expert_b.T.contiguous() - ) + if m.group(2) == "base_layer.lora_A": + new_tensors[f"{prefix}.base_layer.lora_B.weight"] = tensor.T.contiguous() + elif m.group(2) == "base_layer.lora_B": + new_tensors[f"{prefix}.base_layer.lora_A.weight"] = tensor.T.contiguous() + elif m.group(2) == "lora_A": + new_tensors[f"{prefix}.lora_B.weight"] = tensor.T.contiguous() + elif m.group(2) == "lora_B": + new_tensors[f"{prefix}.lora_A.weight"] = tensor.T.contiguous() else: - # down_proj - if is_A: - # [num_experts*rank, hidden] → per expert - per_expert = tensor.reshape(num_experts, rank, hidden_size) - for e in range(num_experts): - expert_a = per_expert[e] # [rank, hidden] - new_tensors[f"{prefix}.{e}.down_proj.lora_B.weight"] = ( - expert_a.T.contiguous() - ) - else: - # [intermediate, num_experts*rank] → per expert - per_expert = tensor.reshape(intermediate_size, num_experts, rank) - for e in range(num_experts): - expert_b = per_expert[:, e, :] # [intermediate, rank] - new_tensors[f"{prefix}.{e}.down_proj.lora_A.weight"] = ( - expert_b.T.contiguous() - ) + raise AssertionError(f"Unhandled MoE LoRA tensor key: {key}") return new_tensors @@ -153,28 +74,23 @@ def convert_checkpoint_if_needed(checkpoint_dir: str) -> None: return tensors = safetensors.torch.load_file(adapter_path) - if not _has_fused_moe_lora(tensors): - return - with open(config_path) as f: adapter_config = json.load(f) - num_experts, rank, intermediate_size, hidden_size = _infer_moe_params( - tensors, adapter_config - ) + if not _has_peft_fused_moe_lora(tensors, adapter_config): + return - new_tensors = convert_fused_moe_lora( - tensors, num_experts, rank, intermediate_size, hidden_size - ) + new_tensors = convert_fused_moe_lora(tensors) # Overwrite the adapter with the converted tensors safetensors.torch.save_file(new_tensors, adapter_path) # Update adapter_config.json target_modules adapter_config["target_modules"] = [ - m for m in adapter_config.get("target_modules", []) if "experts" not in m - ] + ["gate_proj", "up_proj", "down_proj"] - # Remove target_parameters if present (not needed for per-expert format) + m + for m in adapter_config.get("target_modules", []) + if m not in {"experts", "gate_proj", "up_proj", "down_proj"} + ] + ["experts"] 
adapter_config.pop("target_parameters", None) with open(config_path, "w") as f: diff --git a/src/art/utils/lifecycle.py b/src/art/utils/lifecycle.py new file mode 100644 index 000000000..6fe315659 --- /dev/null +++ b/src/art/utils/lifecycle.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import asyncio +import atexit +from collections.abc import Awaitable, Callable, Sequence +import os +from pathlib import Path +import signal +import subprocess +import sys +import time +from typing import Any + + +def managed_process_cmd( + command: Sequence[str], *, parent_pid: int | None = None +) -> list[str]: + return [ + sys.executable, + str(Path(__file__).resolve().with_name("managed_process.py")), + "--parent-pid", + str(parent_pid or os.getpid()), + "--", + *command, + ] + + +def kill_process_group(pid: int, sig: signal.Signals) -> None: + try: + os.killpg(os.getpgid(pid), sig) + except ProcessLookupError: + pass + + +def terminate_popen_process_group( + process: subprocess.Popen[Any], + *, + timeout: float = 5.0, +) -> None: + if process.poll() is None: + kill_process_group(process.pid, signal.SIGTERM) + try: + process.wait(timeout=timeout) + except subprocess.TimeoutExpired: + kill_process_group(process.pid, signal.SIGKILL) + process.wait() + + +def terminate_asyncio_process_group(process: Any, *, timeout: float = 5.0) -> None: + if process.returncode is None: + kill_process_group(process.pid, signal.SIGTERM) + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + finished_pid, _ = os.waitpid(process.pid, os.WNOHANG) + except ChildProcessError: + return + if finished_pid: + return + time.sleep(0.05) + kill_process_group(process.pid, signal.SIGKILL) + try: + os.waitpid(process.pid, 0) + except ChildProcessError: + pass + + +class ChildProcessSupervisor: + def __init__(self, on_unexpected_exit: Callable[[RuntimeError], None]) -> None: + self._on_unexpected_exit = on_unexpected_exit + self._tasks: dict[str, asyncio.Task[None]] = {} + self._failure: RuntimeError | None = None + self._closing = False + + def watch_popen( + self, + name: str, + process: subprocess.Popen[Any], + *, + log_path: str, + ) -> None: + self._watch(name, lambda: self._wait_popen(process), log_path=log_path) + + def watch_asyncio_process( + self, + name: str, + process: Any, + *, + log_path: str, + ) -> None: + self._watch(name, process.wait, log_path=log_path) + + def raise_if_failed(self) -> None: + if self._failure is not None: + raise self._failure + + def close(self) -> None: + self._closing = True + current = self._current_task() + for task in self._tasks.values(): + if task is not current: + task.cancel() + self._tasks.clear() + + def _watch( + self, + name: str, + wait: Callable[[], Awaitable[int]], + *, + log_path: str, + ) -> None: + previous = self._tasks.pop(name, None) + if previous is not None: + previous.cancel() + self._tasks[name] = asyncio.create_task( + self._watch_exit(name, wait, log_path=log_path) + ) + + async def _watch_exit( + self, + name: str, + wait: Callable[[], Awaitable[int]], + *, + log_path: str, + ) -> None: + try: + returncode = await wait() + except asyncio.CancelledError: + return + if self._closing: + return + error = RuntimeError( + f"{name} exited with code {returncode}. 
Check logs at {log_path}" + ) + self._failure = error + self._on_unexpected_exit(error) + + async def _wait_popen(self, process: subprocess.Popen[Any]) -> int: + return int(await asyncio.to_thread(process.wait)) + + def _current_task(self) -> asyncio.Task[Any] | None: + try: + return asyncio.current_task() + except RuntimeError: + return None + + +class ServiceLifecycle: + def __init__(self) -> None: + self.closing = False + self._close_callback: Callable[[], None] | None = None + self._previous_signal_handlers: dict[int, Any] = {} + + def begin_close(self) -> bool: + if self.closing: + return False + self.closing = True + return True + + def install_parent_cleanup(self, close: Callable[[], None]) -> None: + if self._close_callback is not None: + return + self._close_callback = close + atexit.register(close) + + def _default_signal_exit(signum: int) -> None: + if signum == signal.SIGINT: + raise KeyboardInterrupt + raise SystemExit(128 + signum) + + for signum in (signal.SIGINT, signal.SIGTERM): + previous = signal.getsignal(signum) + self._previous_signal_handlers[signum] = previous + + def _handler(received_signum, frame, *, _previous=previous): + close() + if callable(_previous): + _previous(received_signum, frame) + return + if _previous == signal.SIG_IGN: + return + _default_signal_exit(received_signum) + + signal.signal(signum, _handler) + + def restore_parent_cleanup(self) -> None: + if self._close_callback is not None: + try: + atexit.unregister(self._close_callback) + except ValueError: + pass + self._close_callback = None + for signum, previous in self._previous_signal_handlers.items(): + signal.signal(signum, previous) + self._previous_signal_handlers.clear() diff --git a/src/art/utils/managed_process.py b/src/art/utils/managed_process.py new file mode 100644 index 000000000..88aa51fc1 --- /dev/null +++ b/src/art/utils/managed_process.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import argparse +import ctypes +import os +import signal +import subprocess +import sys +import time + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run an ART-owned child process") + parser.add_argument("--parent-pid", type=int, required=True) + parser.add_argument("command", nargs=argparse.REMAINDER) + args = parser.parse_args() + if args.command[:1] == ["--"]: + args.command = args.command[1:] + if not args.command: + parser.error("missing command") + return args + + +def set_parent_death_signal(parent_pid: int, sig: signal.Signals) -> None: + if sys.platform != "linux": + return + libc = ctypes.CDLL(None, use_errno=True) + if libc.prctl(1, int(sig), 0, 0, 0) != 0: + errno = ctypes.get_errno() + raise OSError(errno, os.strerror(errno)) + if os.getppid() != parent_pid: + os._exit(1) + + +def main() -> None: + args = parse_args() + if hasattr(os, "setsid") and os.getpgrp() != os.getpid(): + os.setsid() + + process: subprocess.Popen[bytes] | None = None + child_pgid: int | None = None + shutting_down = False + requested_shutdown: tuple[signal.Signals, int] | None = None + + def signal_child_group(sig: signal.Signals) -> None: + if child_pgid is None: + return + try: + os.killpg(child_pgid, sig) + except ProcessLookupError: + pass + + def sweep_child_group() -> None: + signal_child_group(signal.SIGTERM) + time.sleep(float(os.environ.get("ART_MANAGED_PROCESS_SWEEP_GRACE", 0.5))) + signal_child_group(signal.SIGKILL) + + def shutdown(sig: signal.Signals, exit_code: int) -> None: + nonlocal shutting_down + if shutting_down: + return + shutting_down = 
True + signal_child_group(sig) + if process is not None: + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + signal_child_group(signal.SIGKILL) + process.wait() + sweep_child_group() + os._exit(exit_code) + + def handle_signal(signum: int, _frame: object | None) -> None: + nonlocal requested_shutdown + requested_shutdown = (signal.Signals(signum), 128 + signum) + + signal.signal(signal.SIGINT, handle_signal) + signal.signal(signal.SIGTERM, handle_signal) + + wrapper_pid = os.getpid() + process = subprocess.Popen( + args.command, + start_new_session=True, + preexec_fn=lambda: set_parent_death_signal(wrapper_pid, signal.SIGTERM), + ) + child_pgid = process.pid + + while True: + if requested_shutdown is not None: + shutdown(*requested_shutdown) + if os.getppid() != args.parent_pid: + shutdown(signal.SIGTERM, 1) + return_code = process.poll() + if return_code is not None: + sweep_child_group() + sys.exit(return_code) + time.sleep(0.5) + + +if __name__ == "__main__": + main() diff --git a/src/art/vllm/__init__.py b/src/art/vllm/__init__.py deleted file mode 100644 index 9ae9c5efb..000000000 --- a/src/art/vllm/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -"""vLLM integration module for art.""" - -# Server functionality -# Engine and worker management -from .engine import ( - WorkerExtension, - get_llm, - get_worker, - run_on_workers, -) - -# Patches - these are typically imported for their side effects -from .patches import ( - patch_listen_for_disconnect, - patch_tool_parser_manager, - subclass_chat_completion_request, -) -from .server import ( - get_uvicorn_logging_config, - openai_server_task, - set_vllm_log_file, -) - -__all__ = [ - # Server - "openai_server_task", - "get_uvicorn_logging_config", - "set_vllm_log_file", - # Engine - "get_llm", - "run_on_workers", - "get_worker", - "WorkerExtension", - # Patches - "subclass_chat_completion_request", - "patch_listen_for_disconnect", - "patch_tool_parser_manager", -] diff --git a/src/art/vllm/engine.py b/src/art/vllm/engine.py deleted file mode 100644 index c8da5c55b..000000000 --- a/src/art/vllm/engine.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Engine and worker management for vLLM.""" - -import asyncio -import contextlib -import contextvars -from dataclasses import replace -import os -import time -from typing import Any, Callable, Generator, ParamSpec, TypeVar, cast - -import cloudpickle -import vllm -from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.worker.gpu_worker import Worker - - -async def get_llm(args: vllm.AsyncEngineArgs) -> AsyncLLM: # ty:ignore[unresolved-attribute] - """ - Create an AsyncLLM engine with model download and patches applied. - - Args: - args: The engine arguments including model name and configuration. - - Returns: - A configured AsyncLLM instance. - """ - # Download model only if it's not a local path - if not os.path.exists(args.model): - process = await asyncio.create_subprocess_shell( - f"HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download {args.model}" - ) - await process.wait() - - llm = AsyncLLM.from_engine_args( - replace( - args, - worker_extension_cls=f"{WorkerExtension.__module__}.{WorkerExtension.__qualname__}", - enable_sleep_mode=True, - ) - ) - return llm - - -P = ParamSpec("P") -R = TypeVar("R") - - -async def run_on_workers( - llm: AsyncLLM, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs -) -> list[R]: - """ - Run a function on all workers in a distributed setup. - - Args: - llm: The AsyncLLM instance with workers. - func: The function to run on each worker. 
- *args: Positional arguments for the function. - **kwargs: Keyword arguments for the function. - - Returns: - List of results from each worker. - """ - return await llm.collective_rpc( - "run", args=(cloudpickle.dumps(func), *args), kwargs=kwargs - ) - - -# Context variable to hold the current worker -_worker: contextvars.ContextVar["ExtendedWorker"] = contextvars.ContextVar("worker") - - -def get_worker() -> "ExtendedWorker": - """Get the current worker instance""" - return _worker.get() - - -class WorkerExtension: - """Extension for running arbitrary functions on vLLM workers.""" - - def run(self, pickled_func: bytes, *args: Any, **kwargs: Any) -> Any: - func = cloudpickle.loads(pickled_func) - token = _worker.set(cast(ExtendedWorker, self)) - try: - return func(*args, **kwargs) - finally: - _worker.reset(token) - - @contextlib.contextmanager - def time(self, name: str) -> Generator[None, None, None]: - from vllm.v1.worker.gpu_worker import logger - - start_time = time.perf_counter() - yield - end_time = time.perf_counter() - logger.info(f"{name}: {end_time - start_time:.2f} seconds") - - -class ExtendedWorker(Worker, WorkerExtension): - pass diff --git a/src/art/vllm/patches.py b/src/art/vllm/patches.py deleted file mode 100644 index cf6284fc1..000000000 --- a/src/art/vllm/patches.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Monkey patches and modifications for vLLM.""" - -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from torch import Tensor - - -def patch_transformers_v5_compat() -> None: - _patch_rope_validation_ignore_keys() - _patch_qwen3_vl_moe_tie_word_embeddings() - _patch_qwen3_5_lora() - - -def _patch_rope_validation_ignore_keys() -> None: - from transformers.configuration_utils import PretrainedConfig - - original = PretrainedConfig.convert_rope_params_to_dict # type: ignore[attr-defined] - - # Return if already patched - if getattr(original, "__art_patched__", False): - return - - def patched(self: Any, ignore_keys_at_rope_validation: Any = None, **kwargs: Any): - if ignore_keys_at_rope_validation is not None: - ignore_keys_at_rope_validation = set(ignore_keys_at_rope_validation) - return original( - self, - ignore_keys_at_rope_validation=ignore_keys_at_rope_validation, - **kwargs, - ) - - patched.__art_patched__ = True # type: ignore[attr-defined] - PretrainedConfig.convert_rope_params_to_dict = patched # type: ignore[method-assign] - - -def _patch_qwen3_vl_moe_tie_word_embeddings() -> None: - from transformers import Qwen3VLMoeTextConfig - - setattr(Qwen3VLMoeTextConfig, "tie_word_embeddings", False) - - -def _patch_qwen3_5_lora() -> None: - from vllm.lora.layers.column_parallel_linear import ( - MergedColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithShardedLoRA, - ) - from vllm.lora.layers.utils import _not_fully_sharded_can_replace - from vllm.model_executor.models.qwen3_5 import ( - Qwen3_5ForCausalLMBase, - Qwen3_5ForConditionalGeneration, - ) - - projections = ["in_proj_q", "in_proj_k", "in_proj_v", "in_proj_z"] - Qwen3_5ForCausalLMBase.packed_modules_mapping["in_proj_qkvz"] = projections - Qwen3_5ForConditionalGeneration.packed_modules_mapping["in_proj_qkvz"] = projections - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: Any, - lora_config: Any, - packed_modules_list: list[str], - model_config: Any = None, - ) -> bool: - from vllm.model_executor.layers.linear import MergedColumnParallelLinear - - return type(source_layer) is MergedColumnParallelLinear and len( - packed_modules_list - ) == 
len(source_layer.output_sizes) - - MergedColumnParallelLinearWithLoRA.can_replace_layer = can_replace_layer - - def slice_lora_a( - self: Any, - lora_a: "list[Tensor | None]", - ) -> "list[Tensor | None]": - output_shard_size = self.lora_a_stacked[0].shape[2] - output_start_idx = self.tp_rank * output_shard_size - return [ - a[output_start_idx : output_start_idx + output_shard_size, :] - if a is not None - else None - for a in lora_a - ] - - MergedColumnParallelLinearWithShardedLoRA.slice_lora_a = slice_lora_a # ty:ignore[invalid-assignment] - - -def subclass_chat_completion_request() -> None: - """ - Subclass ChatCompletionRequest so that logprobs are always returned. - """ - from vllm.entrypoints.openai.chat_completion import protocol - - class ChatCompletionRequest(protocol.ChatCompletionRequest): - def __init__(self, *args: object, **kwargs: object) -> None: - super().__init__(*args, **kwargs) # ty:ignore[invalid-argument-type] - self.logprobs = True - if self.top_logprobs is None: - self.top_logprobs = 0 - - protocol.ChatCompletionRequest = ChatCompletionRequest # ty:ignore[invalid-assignment] - - -def patch_listen_for_disconnect() -> None: - async def patched_listen_for_disconnect(request): - try: - while True: - message = await request.receive() - if message["type"] == "http.disconnect": - break - except UnboundLocalError: - pass - - # Replace the original function - import vllm.entrypoints.utils - - vllm.entrypoints.utils.listen_for_disconnect = patched_listen_for_disconnect # ty:ignore[invalid-assignment] - - -def patch_tool_parser_manager() -> None: - """ - Patch ToolParserManager to support streaming tool call logprobs. - """ - from vllm.entrypoints.openai.engine.protocol import DeltaMessage - from vllm.tool_parsers.abstract_tool_parser import ToolParserManager - - get_tool_parser = ToolParserManager.get_tool_parser - - def patched_get_tool_parser(name: str) -> type: - tool_parser_class = get_tool_parser(name) - original = tool_parser_class.extract_tool_calls_streaming - - def patch( - *args: Any, - **kwargs: Any, - ) -> Any: - return original(*args, **kwargs) or DeltaMessage() - - tool_parser_class.extract_tool_calls_streaming = patch # ty:ignore[invalid-assignment] - return tool_parser_class - - ToolParserManager.get_tool_parser = patched_get_tool_parser # ty:ignore[invalid-assignment] diff --git a/src/art/vllm/server.py b/src/art/vllm/server.py deleted file mode 100644 index f6d2b82d3..000000000 --- a/src/art/vllm/server.py +++ /dev/null @@ -1,210 +0,0 @@ -"""OpenAI-compatible server functionality for vLLM.""" - -import asyncio -from contextlib import asynccontextmanager -import logging -import os -from typing import Any, AsyncIterator, Coroutine, cast - -from openai import AsyncOpenAI -from uvicorn.config import LOGGING_CONFIG -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args -from vllm.logger import _DATE_FORMAT, _FORMAT -from vllm.logging_utils import NewLineFormatter -from vllm.utils.argparse_utils import FlexibleArgumentParser - -from ..dev.openai_server import OpenAIServerConfig - -_openai_serving_models: Any | None = None - - -async def openai_server_task( - engine: EngineClient, - config: OpenAIServerConfig, -) -> asyncio.Task[None]: - """ - Starts an asyncio task that runs an OpenAI-compatible server. - - Args: - engine: The vLLM engine client. - config: The configuration for the OpenAI-compatible server. - - Returns: - A running asyncio task for the OpenAI-compatible server. 
Cancel the task - to stop the server. - """ - # Import patches before importing api_server - from .patches import ( - patch_listen_for_disconnect, - patch_tool_parser_manager, - subclass_chat_completion_request, - ) - - # We must subclass ChatCompletionRequest before importing api_server - # or logprobs will not always be returned - subclass_chat_completion_request() - # Capture the OpenAIServingModels instance so dynamically added LoRAs - # are reflected in the model list. - from vllm.entrypoints.openai import api_server - from vllm.entrypoints.openai.models import serving as serving_models - - serving_models_any = cast(Any, serving_models) - if not getattr(serving_models_any, "_art_openai_serving_models_patched", False): - serving_models_any._art_openai_serving_models_patched = True - original_init = serving_models.OpenAIServingModels.__init__ - - def _init(self, *args: Any, **kwargs: Any) -> None: - original_init(self, *args, **kwargs) - global _openai_serving_models - _openai_serving_models = self - - serving_models.OpenAIServingModels.__init__ = _init # ty:ignore[invalid-assignment] - - patch_listen_for_disconnect() - patch_tool_parser_manager() - set_vllm_log_file(config.get("log_file", "vllm.log")) - - # Patch engine.add_lora to normalize requests across vLLM schema changes. - add_lora = engine.add_lora - - async def _add_lora(lora_request) -> bool: - from vllm.lora.request import LoRARequest - - if not isinstance(lora_request, LoRARequest): - lora_request = LoRARequest( - lora_name=lora_request.lora_name, - lora_int_id=lora_request.lora_int_id, - lora_path=lora_request.lora_path, - base_model_name=getattr(lora_request, "base_model_name", None), - load_inplace=getattr(lora_request, "load_inplace", False), - ) - added = await add_lora(lora_request) - if added and _openai_serving_models is not None: - _openai_serving_models.lora_requests[lora_request.lora_name] = lora_request - return added - - engine.add_lora = _add_lora # ty:ignore[invalid-assignment] - - @asynccontextmanager - async def build_async_engine_client( - *args: Any, - **kwargs: Any, - ) -> AsyncIterator[EngineClient]: - yield engine - - api_server.build_async_engine_client = build_async_engine_client - openai_server_task = asyncio.create_task(_openai_server_coroutine(config)) - server_args = config.get("server_args", {}) - client = AsyncOpenAI( - api_key=server_args.get("api_key"), - base_url=f"http://{server_args.get('host', '0.0.0.0')}:{server_args.get('port', 8000)}/v1", - ) - - async def test_client() -> None: - while True: - try: - async for _ in client.models.list(): - return - except: # noqa: E722 - await asyncio.sleep(0.1) - - test_client_task = asyncio.create_task(test_client()) - try: - timeout = float(os.environ.get("ART_SERVER_TIMEOUT", 30.0)) - done, _ = await asyncio.wait( - [openai_server_task, test_client_task], - timeout=timeout, - return_when="FIRST_COMPLETED", - ) - if not done: - raise TimeoutError( - f"Unable to reach OpenAI-compatible server within {timeout} seconds. You can increase this timeout by setting the ART_SERVER_TIMEOUT environment variable." - ) - for task in done: - task.result() - - return openai_server_task - except Exception: - openai_server_task.cancel() - test_client_task.cancel() - raise - - -def _openai_server_coroutine( - config: OpenAIServerConfig, -) -> Coroutine[Any, Any, None]: - from vllm.entrypoints.openai import api_server - - parser = FlexibleArgumentParser( - description="vLLM OpenAI-Compatible RESTful API server." 
- ) - parser = make_arg_parser(parser) - engine_args = config.get("engine_args", {}) - server_args = config.get("server_args", {}) - args = [ - *[ - f"--{key.replace('_', '-')}{f'={item}' if item is not True else ''}" - for args in [engine_args, server_args] - for key, value in args.items() - for item in (value if isinstance(value, list) else [value]) - if item is not None - ], - ] - namespace = parser.parse_args(args) - assert namespace is not None - validate_parsed_serve_args(namespace) - return api_server.run_server( - namespace, - log_config=get_uvicorn_logging_config(config.get("log_file", "vllm.log")), - ) - - -def get_uvicorn_logging_config(path: str) -> dict[str, Any]: - """ - Returns a Uvicorn logging config that writes to the given path. - """ - return { - **LOGGING_CONFIG, - "handlers": { - "default": { - "formatter": "default", - "class": "logging.FileHandler", - "filename": path, - }, - "access": { - "formatter": "default", - "class": "logging.FileHandler", - "filename": path, - }, - }, - } - - -def set_vllm_log_file(path: str) -> None: - """ - Sets the vLLM log file to the given path. - """ - - # Create directory for the log file if it doesn't exist - os.makedirs(os.path.dirname(path), exist_ok=True) - - # Get the vLLM logger - vllm_logger = logging.getLogger("vllm") - - # Remove existing handlers - for handler in vllm_logger.handlers[:]: - vllm_logger.removeHandler(handler) - - # Create a file handler - file_handler = logging.FileHandler(path) - - # Use vLLM's NewLineFormatter which adds the fileinfo field - formatter = NewLineFormatter(fmt=_FORMAT, datefmt=_DATE_FORMAT) - file_handler.setFormatter(formatter) - - # Add the handler to the logger - vllm_logger.addHandler(file_handler) - - # Set log level to filter out DEBUG messages - vllm_logger.setLevel(logging.INFO) diff --git a/src/art/vllm_runtime.py b/src/art/vllm_runtime.py new file mode 100644 index 000000000..58a081921 --- /dev/null +++ b/src/art/vllm_runtime.py @@ -0,0 +1,411 @@ +import asyncio +from contextlib import contextmanager +import fcntl +import hashlib +import json +import math +import os +from pathlib import Path +import shlex +import shutil +import subprocess +import tempfile +from typing import Any, Literal + +import httpx +from pydantic import BaseModel, ConfigDict, Field + +RUNTIME_SERVER = "art-vllm-runtime-server" +RUNTIME_PACKAGE = "art-vllm-runtime" +RUNTIME_PROTOCOL_VERSION = 1 +RUNTIME_INSTALL_MARKER = "openpipe-art-vllm-runtime" + + +class VllmRuntimeLaunchConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + base_model: str + port: int + host: str = "127.0.0.1" + cuda_visible_devices: str + lora_path: str + served_model_name: str + rollout_weights_mode: Literal["lora", "merged"] + engine_args: dict[str, object] = Field(default_factory=dict) + server_args: dict[str, object] = Field(default_factory=dict) + + +class VllmRuntimeManifest(BaseModel): + model_config = ConfigDict(extra="forbid") + + art_package: str = "openpipe-art" + art_version: str + runtime_package: str = RUNTIME_PACKAGE + runtime_version: str + protocol_version: int = RUNTIME_PROTOCOL_VERSION + python: str + runtime_wheel: str + runtime_wheel_sha256: str + pyproject: str = "pyproject.toml" + pyproject_sha256: str + lockfile: str = "uv.lock" + lockfile_sha256: str + + +class VllmRuntimeInstallMarker(BaseModel): + model_config = ConfigDict(extra="forbid") + + managed_by: str = RUNTIME_INSTALL_MARKER + runtime_package: str = RUNTIME_PACKAGE + runtime_version: str + protocol_version: int = 
RUNTIME_PROTOCOL_VERSION + manifest_hash: str + runtime_wheel_sha256: str + cache_root: str + + +def get_vllm_runtime_project_root() -> Path: + override = os.environ.get("ART_VLLM_RUNTIME_PROJECT_ROOT") + if override: + return Path(override).resolve() + return Path(__file__).resolve().parents[2] / "vllm_runtime" + + +def get_vllm_runtime_working_dir() -> Path: + runtime_root = get_vllm_runtime_project_root() + if runtime_root.exists(): + return runtime_root + return Path.cwd() + + +def get_vllm_runtime_cache_root() -> Path: + override = os.environ.get("ART_VLLM_RUNTIME_CACHE_DIR") + if override: + return Path(override).expanduser() + return Path.home() / ".cache" / "art" / "vllm_runtime" + + +def _bundled_runtime_dir() -> Path: + return Path(__file__).resolve().parent / "_vllm_runtime" + + +def _source_runtime_bin() -> Path: + return get_vllm_runtime_project_root() / ".venv" / "bin" / RUNTIME_SERVER + + +def _runtime_bin(runtime_dir: Path) -> Path: + return runtime_dir / ".venv" / "bin" / RUNTIME_SERVER + + +def _runtime_python(runtime_dir: Path) -> Path: + return runtime_dir / ".venv" / "bin" / "python" + + +def _is_executable_file(path: Path) -> bool: + return path.is_file() and os.access(path, os.X_OK) + + +def _sha256_file(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as file: + for chunk in iter(lambda: file.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def _manifest_hash(manifest: VllmRuntimeManifest) -> str: + payload = json.dumps(manifest.model_dump(), sort_keys=True).encode() + return hashlib.sha256(payload).hexdigest() + + +def _load_bundled_manifest(bundle_dir: Path | None = None) -> VllmRuntimeManifest: + bundle_dir = bundle_dir or _bundled_runtime_dir() + manifest_path = bundle_dir / "manifest.json" + if not manifest_path.exists(): + raise RuntimeError( + "ART vLLM runtime bundle is missing. Reinstall openpipe-art from a " + "wheel built with scripts/build_package.py or set ART_VLLM_RUNTIME_BIN." + ) + return VllmRuntimeManifest.model_validate_json(manifest_path.read_text()) + + +def _run_install_command(command: list[str], *, cwd: Path | None = None) -> None: + try: + result = subprocess.run(command, cwd=cwd, capture_output=True, text=True) + except FileNotFoundError as exc: + raise RuntimeError( + "uv is required to install ART's managed vLLM runtime. Install uv or " + "set ART_VLLM_RUNTIME_BIN to an existing runtime server." 
+ ) from exc + if result.returncode == 0: + return + output = (result.stdout + result.stderr)[-4000:] + raise RuntimeError( + "Failed to install ART's managed vLLM runtime with command " + f"{shlex.join(command)}.\n{output}" + ) + + +@contextmanager +def _runtime_install_lock(cache_root: Path): + cache_root.mkdir(parents=True, exist_ok=True) + lock_path = cache_root / ".install.lock" + with lock_path.open("w") as lock_file: + fcntl.flock(lock_file, fcntl.LOCK_EX) + try: + yield + finally: + fcntl.flock(lock_file, fcntl.LOCK_UN) + + +def _install_marker_path(runtime_dir: Path) -> Path: + return runtime_dir / "install.json" + + +def _read_install_marker(runtime_dir: Path) -> VllmRuntimeInstallMarker | None: + marker_path = _install_marker_path(runtime_dir) + if not marker_path.exists(): + return None + try: + return VllmRuntimeInstallMarker.model_validate_json(marker_path.read_text()) + except ValueError: + return None + + +def _is_managed_runtime_dir( + runtime_dir: Path, + *, + cache_root: Path, + expected_hash: str | None = None, +) -> bool: + if not runtime_dir.is_dir(): + return False + if runtime_dir.resolve().parent != cache_root.resolve(): + return False + if len(runtime_dir.name) != 64 or any( + c not in "0123456789abcdef" for c in runtime_dir.name + ): + return False + if expected_hash is not None and runtime_dir.name != expected_hash: + return False + marker = _read_install_marker(runtime_dir) + if marker is None: + return False + if marker.managed_by != RUNTIME_INSTALL_MARKER: + return False + if marker.runtime_package != RUNTIME_PACKAGE: + return False + if marker.manifest_hash != runtime_dir.name: + return False + if marker.cache_root != str(cache_root.resolve()): + return False + if not (runtime_dir / ".venv" / "pyvenv.cfg").exists(): + return False + return True + + +def _validate_managed_runtime( + runtime_dir: Path, + *, + cache_root: Path, + manifest: VllmRuntimeManifest, + manifest_hash: str, +) -> Path | None: + if not _is_managed_runtime_dir( + runtime_dir, cache_root=cache_root, expected_hash=manifest_hash + ): + return None + marker = _read_install_marker(runtime_dir) + if marker is None: + return None + if marker.runtime_version != manifest.runtime_version: + return None + if marker.protocol_version != manifest.protocol_version: + return None + if marker.runtime_wheel_sha256 != manifest.runtime_wheel_sha256: + return None + runtime_bin = _runtime_bin(runtime_dir) + if not _is_executable_file(runtime_bin): + return None + return runtime_bin + + +def _cleanup_old_managed_runtimes(cache_root: Path, *, keep_hash: str) -> None: + if os.environ.get("ART_VLLM_RUNTIME_KEEP_OLD"): + return + if not cache_root.exists(): + return + for child in cache_root.iterdir(): + if child.name == keep_hash: + continue + if not _is_managed_runtime_dir(child, cache_root=cache_root): + continue + shutil.rmtree(child) + + +def _install_managed_runtime( + *, + bundle_dir: Path, + cache_root: Path, + manifest: VllmRuntimeManifest, + manifest_hash: str, +) -> Path: + runtime_wheel = bundle_dir / manifest.runtime_wheel + if _sha256_file(runtime_wheel) != manifest.runtime_wheel_sha256: + raise RuntimeError(f"Bundled vLLM runtime wheel hash mismatch: {runtime_wheel}") + + cache_root.mkdir(parents=True, exist_ok=True) + stage = Path( + tempfile.mkdtemp(prefix=f".{manifest_hash}.tmp-", dir=str(cache_root.resolve())) + ) + runtime_dir = cache_root / manifest_hash + promoted = False + try: + shutil.copy2(bundle_dir / manifest.pyproject, stage / "pyproject.toml") + shutil.copy2(bundle_dir / 
manifest.lockfile, stage / "uv.lock") + _run_install_command( + [ + "uv", + "sync", + "--project", + str(stage), + "--frozen", + "--no-install-project", + "--no-dev", + ] + ) + if runtime_dir.exists(): + existing = _validate_managed_runtime( + runtime_dir, + cache_root=cache_root, + manifest=manifest, + manifest_hash=manifest_hash, + ) + if existing is not None: + shutil.rmtree(stage) + return existing + raise RuntimeError( + f"Refusing to replace invalid vLLM runtime cache directory: {runtime_dir}" + ) + stage.rename(runtime_dir) + promoted = True + runtime_python = _runtime_python(runtime_dir) + _run_install_command( + [ + "uv", + "pip", + "install", + "--no-deps", + "--python", + str(runtime_python), + str(runtime_wheel), + ] + ) + runtime_bin = _runtime_bin(runtime_dir) + if not _is_executable_file(runtime_bin): + raise RuntimeError(f"vLLM runtime server was not installed: {runtime_bin}") + + marker = VllmRuntimeInstallMarker( + runtime_version=manifest.runtime_version, + protocol_version=manifest.protocol_version, + manifest_hash=manifest_hash, + runtime_wheel_sha256=manifest.runtime_wheel_sha256, + cache_root=str(cache_root.resolve()), + ) + _install_marker_path(runtime_dir).write_text( + json.dumps(marker.model_dump(), indent=2, sort_keys=True) + "\n" + ) + _cleanup_old_managed_runtimes(cache_root, keep_hash=manifest_hash) + return runtime_bin + except Exception: + shutil.rmtree(runtime_dir if promoted else stage, ignore_errors=True) + raise + + +def ensure_vllm_runtime() -> Path: + bundle_dir = _bundled_runtime_dir() + manifest = _load_bundled_manifest(bundle_dir) + manifest_hash = _manifest_hash(manifest) + cache_root = get_vllm_runtime_cache_root() + cache_root.mkdir(parents=True, exist_ok=True) + cache_root = cache_root.resolve() + runtime_dir = cache_root / manifest_hash + + with _runtime_install_lock(cache_root): + existing = _validate_managed_runtime( + runtime_dir, + cache_root=cache_root, + manifest=manifest, + manifest_hash=manifest_hash, + ) + if existing is not None: + _cleanup_old_managed_runtimes(cache_root, keep_hash=manifest_hash) + return existing + return _install_managed_runtime( + bundle_dir=bundle_dir, + cache_root=cache_root, + manifest=manifest, + manifest_hash=manifest_hash, + ) + + +def _runtime_command_prefix() -> list[str]: + override = os.environ.get("ART_VLLM_RUNTIME_BIN") + if override: + return shlex.split(override) + runtime_bin = _source_runtime_bin() + if runtime_bin.exists(): + return [str(runtime_bin)] + runtime_root = get_vllm_runtime_project_root() + if ( + runtime_root.exists() + and not (_bundled_runtime_dir() / "manifest.json").exists() + ): + raise RuntimeError( + "vLLM runtime env is not built. Run `uv sync` in " + f"{runtime_root} or set ART_VLLM_RUNTIME_BIN." 
+ ) + return [str(ensure_vllm_runtime())] + + +def build_vllm_runtime_server_cmd(config: VllmRuntimeLaunchConfig) -> list[str]: + return [ + *_runtime_command_prefix(), + f"--model={config.base_model}", + f"--port={config.port}", + f"--host={config.host}", + f"--cuda-visible-devices={config.cuda_visible_devices}", + f"--lora-path={config.lora_path}", + f"--served-model-name={config.served_model_name}", + f"--rollout-weights-mode={config.rollout_weights_mode}", + f"--engine-args-json={json.dumps(config.engine_args)}", + f"--server-args-json={json.dumps(config.server_args)}", + ] + + +async def wait_for_vllm_runtime( + *, + process: subprocess.Popen[Any], + host: str, + port: int, + timeout: float, +) -> None: + deadline = asyncio.get_running_loop().time() + timeout + url = f"http://{host}:{port}/health" + async with httpx.AsyncClient() as client: + while True: + if process.poll() is not None: + raise RuntimeError( + f"vLLM runtime exited with code {process.returncode}" + ) + try: + response = await client.get(url, timeout=5.0) + if response.status_code == 200: + return + except httpx.HTTPError: + pass + if asyncio.get_running_loop().time() >= deadline: + raise TimeoutError( + f"vLLM runtime did not become ready within {math.ceil(timeout)}s" + ) + await asyncio.sleep(0.5) diff --git a/src/art/weight_transfer/__init__.py b/src/art/weight_transfer/__init__.py new file mode 100644 index 000000000..f8140bd78 --- /dev/null +++ b/src/art/weight_transfer/__init__.py @@ -0,0 +1,15 @@ +from .nccl import ( + DEFAULT_PACKED_BUFFER_SIZE_BYTES, + DEFAULT_PACKED_NUM_BUFFERS, + TrainerNcclCommunicator, + trainer_init, + trainer_send_weights, +) + +__all__ = [ + "DEFAULT_PACKED_BUFFER_SIZE_BYTES", + "DEFAULT_PACKED_NUM_BUFFERS", + "TrainerNcclCommunicator", + "trainer_init", + "trainer_send_weights", +] diff --git a/src/art/weight_transfer/nccl.py b/src/art/weight_transfer/nccl.py new file mode 100644 index 000000000..a0b3b7e4a --- /dev/null +++ b/src/art/weight_transfer/nccl.py @@ -0,0 +1,363 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Trainer-side NCCL transport subset extracted from vLLM.""" + +import ctypes +from datetime import timedelta +import importlib.util +import os +from pathlib import Path +import pickle +import socket +from typing import Any, cast + +from pydantic import BaseModel, ConfigDict +import torch +from torch.distributed import TCPStore + +from .packed_tensor import ( + DEFAULT_PACKED_BUFFER_SIZE_BYTES, + DEFAULT_PACKED_NUM_BUFFERS, + packed_broadcast_producer, +) + + +class TrainerNcclSendWeightsArgs(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + group: Any + src: int = 0 + post_iter_func: Any = None + packed: bool = False + stream: Any = None + packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES + packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS + + +class _NcclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +_nccl_result_t = ctypes.c_int +_nccl_comm_t = ctypes.c_void_p +_cuda_stream_t = ctypes.c_void_p +_buffer_type = ctypes.c_void_p + + +class _NcclDataType: + INT8 = 0 + UINT8 = 1 + INT32 = 2 + INT64 = 4 + FLOAT16 = 6 + FLOAT32 = 7 + FLOAT64 = 8 + BFLOAT16 = 9 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.INT8 + if dtype == torch.uint8: + return cls.UINT8 + if dtype == torch.int32: + return cls.INT32 + if dtype == torch.int64: + return cls.INT64 + if dtype == torch.float16: + 
return cls.FLOAT16 + if dtype == torch.float32: + return cls.FLOAT32 + if dtype == torch.float64: + return cls.FLOAT64 + if dtype == torch.bfloat16: + return cls.BFLOAT16 + raise ValueError(f"Unsupported NCCL dtype: {dtype}") + + +class _NcclRedOp: + SUM = 0 + + +class _NcclLibrary: + def __init__(self, so_file: str | None = None): + self._lib = ctypes.CDLL(so_file or _find_nccl_library()) + self._configure("ncclGetErrorString", ctypes.c_char_p, [_nccl_result_t]) + self._configure( + "ncclGetUniqueId", _nccl_result_t, [ctypes.POINTER(_NcclUniqueId)] + ) + self._configure( + "ncclCommInitRank", + _nccl_result_t, + [ctypes.POINTER(_nccl_comm_t), ctypes.c_int, _NcclUniqueId, ctypes.c_int], + ) + self._configure( + "ncclAllReduce", + _nccl_result_t, + [ + _buffer_type, + _buffer_type, + ctypes.c_size_t, + ctypes.c_int, + ctypes.c_int, + _nccl_comm_t, + _cuda_stream_t, + ], + ) + self._configure( + "ncclBroadcast", + _nccl_result_t, + [ + _buffer_type, + _buffer_type, + ctypes.c_size_t, + ctypes.c_int, + ctypes.c_int, + _nccl_comm_t, + _cuda_stream_t, + ], + ) + + def _configure(self, name: str, restype: Any, argtypes: list[Any]) -> None: + function = getattr(self._lib, name) + function.restype = restype + function.argtypes = argtypes + + def _check(self, result: int) -> None: + if result != 0: + error = self._lib.ncclGetErrorString(result).decode("utf-8") + raise RuntimeError(f"NCCL error: {error}") + + def get_unique_id(self) -> _NcclUniqueId: + unique_id = _NcclUniqueId() + self._check(self._lib.ncclGetUniqueId(ctypes.byref(unique_id))) + return unique_id + + def init_rank(self, world_size: int, unique_id: _NcclUniqueId, rank: int) -> Any: + comm = _nccl_comm_t() + self._check( + self._lib.ncclCommInitRank(ctypes.byref(comm), world_size, unique_id, rank) + ) + return comm + + def all_reduce( + self, + tensor: torch.Tensor, + comm: Any, + stream: torch.cuda.Stream, + ) -> None: + self._check( + self._lib.ncclAllReduce( + _buffer_type(tensor.data_ptr()), + _buffer_type(tensor.data_ptr()), + tensor.numel(), + _NcclDataType.from_torch(tensor.dtype), + _NcclRedOp.SUM, + comm, + _cuda_stream_t(stream.cuda_stream), + ) + ) + + def broadcast( + self, + tensor: torch.Tensor, + comm: Any, + *, + rank: int, + src: int, + stream: torch.cuda.Stream, + ) -> None: + send_buffer = _buffer_type(tensor.data_ptr()) if rank == src else _buffer_type() + self._check( + self._lib.ncclBroadcast( + send_buffer, + _buffer_type(tensor.data_ptr()), + tensor.numel(), + _NcclDataType.from_torch(tensor.dtype), + src, + comm, + _cuda_stream_t(stream.cuda_stream), + ) + ) + + +def _nccl_unique_id_to_bytes(unique_id: _NcclUniqueId) -> bytes: + return ctypes.string_at(ctypes.byref(unique_id), ctypes.sizeof(unique_id)) + + +def _nccl_unique_id_from_bytes(payload: bytes) -> _NcclUniqueId: + assert len(payload) == ctypes.sizeof(_NcclUniqueId) + unique_id = _NcclUniqueId() + ctypes.memmove(ctypes.byref(unique_id), payload, len(payload)) + return unique_id + + +class _BootstrapGroup: + def __init__( + self, + *, + host: str, + port: int, + rank: int, + world_size: int, + store_timeout: int = 300, + ) -> None: + launch_server = rank == 0 + listen_socket = None + listen_fd = None + if launch_server: + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listen_socket.bind((host, port)) + listen_socket.listen() + listen_fd = listen_socket.fileno() + self.rank = rank + self.world_size = world_size + self.socket = listen_socket + self.store = TCPStore( 
+ host_name=host, + port=port, + world_size=world_size, + is_master=launch_server, + timeout=timedelta(seconds=store_timeout), + use_libuv=False, + master_listen_fd=listen_fd, + ) + self._broadcast_send_counter = 0 + self._broadcast_recv_counter = {value: 0 for value in range(world_size)} + + def broadcast_obj(self, obj: Any | None, *, src: int) -> Any: + if self.rank == src: + key = f"broadcast_from/{src}/{self._broadcast_send_counter}" + self.store.set(key, cast(Any, pickle.dumps(obj))) + self._broadcast_send_counter += 1 + return obj + key = f"broadcast_from/{src}/{self._broadcast_recv_counter[src]}" + received = pickle.loads(self.store.get(key)) + self._broadcast_recv_counter[src] += 1 + return received + + +class TrainerNcclCommunicator: + def __init__( + self, + *, + host: str, + port: int, + rank: int, + world_size: int, + device: int | torch.device, + ) -> None: + bootstrap_group = _BootstrapGroup( + host=host, + port=port, + rank=rank, + world_size=world_size, + ) + self._bootstrap_group = bootstrap_group + self.rank = rank + self.world_size = world_size + self.device = ( + torch.device(f"cuda:{device}") if isinstance(device, int) else device + ) + self._nccl = _NcclLibrary() + unique_id_bytes = ( + _nccl_unique_id_to_bytes(self._nccl.get_unique_id()) if rank == 0 else None + ) + unique_id = _nccl_unique_id_from_bytes( + bootstrap_group.broadcast_obj(unique_id_bytes, src=0) + ) + with torch.cuda.device(self.device): + self._comm = self._nccl.init_rank(world_size, unique_id, rank) + stream = torch.cuda.current_stream(self.device) + warmup = torch.zeros(1, device=self.device) + self.all_reduce(warmup, stream=stream) + stream.synchronize() + + def all_reduce( + self, + tensor: torch.Tensor, + *, + stream: torch.cuda.Stream | None = None, + ) -> None: + assert tensor.device == self.device + self._nccl.all_reduce( + tensor, + self._comm, + stream=stream or torch.cuda.current_stream(self.device), + ) + + def broadcast( + self, + tensor: torch.Tensor, + *, + src: int, + stream: torch.cuda.Stream | None = None, + ) -> None: + assert tensor.device == self.device + self._nccl.broadcast( + tensor, + self._comm, + rank=self.rank, + src=src, + stream=stream or torch.cuda.current_stream(self.device), + ) + + +def _find_nccl_library() -> str: + if override := os.environ.get("VLLM_NCCL_SO_PATH"): + return override + if torch.version.cuda is not None: + spec = importlib.util.find_spec("nvidia.nccl") + if spec is None or spec.submodule_search_locations is None: + raise RuntimeError( + "CUDA weight transfer requires the nvidia-nccl-cu12 package." 
+ ) + nccl_library = ( + Path(next(iter(spec.submodule_search_locations))) / "lib" / "libnccl.so.2" + ) + if not nccl_library.exists(): + raise RuntimeError(f"nvidia-nccl-cu12 is missing {nccl_library}") + return str(nccl_library) + if torch.version.hip is not None: + return "librccl.so.1" + raise ValueError("NCCL only supports CUDA and ROCm backends.") + + +def trainer_init(init_info: dict[str, object]) -> TrainerNcclCommunicator: + return TrainerNcclCommunicator( + host=str(init_info["master_address"]), + port=int(cast(Any, init_info["master_port"])), + rank=0, + world_size=int(cast(Any, init_info["world_size"])), + device=torch.cuda.current_device(), + ) + + +def trainer_send_weights( + iterator: Any, + trainer_args: dict[str, Any] | TrainerNcclSendWeightsArgs, +) -> None: + args = ( + TrainerNcclSendWeightsArgs(**trainer_args) + if isinstance(trainer_args, dict) + else trainer_args + ) + post_iter_func = args.post_iter_func or (lambda item: item[1]) + if args.packed: + packed_broadcast_producer( + iterator=iterator, + group=args.group, + src=args.src, + post_iter_func=post_iter_func, + buffer_size_bytes=args.packed_buffer_size_bytes, + num_buffers=args.packed_num_buffers, + ) + return + for item in iterator: + tensor = post_iter_func(item) + args.group.broadcast( + tensor, + src=args.src, + stream=args.stream or torch.cuda.current_stream(tensor.device), + ) diff --git a/src/art/weight_transfer/packed_tensor.py b/src/art/weight_transfer/packed_tensor.py new file mode 100644 index 000000000..100bb5008 --- /dev/null +++ b/src/art/weight_transfer/packed_tensor.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Packed tensor utilities for efficient trainer-side weight transfer.""" + +from collections.abc import Callable, Iterator +import math +from typing import Any + +import torch + +DEFAULT_PACKED_BUFFER_SIZE_BYTES = 1024 * 1024 * 1024 +DEFAULT_PACKED_NUM_BUFFERS = 2 + + +def packed_broadcast_producer( + iterator: Iterator[tuple[str, torch.Tensor]], + group: Any, + src: int, + post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor], + buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, + num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, +) -> None: + target_packed_tensor_size = buffer_size_bytes + streams = [torch.cuda.Stream() for _ in range(num_buffers)] + buffer_idx = 0 + packing_tensor_list: list[list[torch.Tensor]] = [[] for _ in range(num_buffers)] + packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)] + packed_tensors: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers) + ] + + while True: + streams[buffer_idx].synchronize() + with torch.cuda.stream(streams[buffer_idx]): + try: + packing_tensor_list[buffer_idx] = [] + packing_tensor_sizes[buffer_idx] = 0 + while True: + tensor = ( + post_iter_func(next(iterator)) + .contiguous() + .view(torch.uint8) + .view(-1) + ) + packing_tensor_list[buffer_idx].append(tensor) + packing_tensor_sizes[buffer_idx] += tensor.numel() + if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size: + break + packed_tensors[buffer_idx] = torch.cat( + packing_tensor_list[buffer_idx], dim=0 + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + buffer_idx = (buffer_idx + 1) % num_buffers + except StopIteration: + if packing_tensor_list[buffer_idx]: + packed_tensors[buffer_idx] = torch.cat( + packing_tensor_list[buffer_idx], dim=0 + ) + group.broadcast(packed_tensors[buffer_idx], src=src) 
+ break + for stream in streams: + stream.synchronize() + + +def packed_broadcast_consumer( + iterator: Iterator[tuple[str, tuple[list[int], torch.dtype]]], + group: Any, + src: int, + post_unpack_func: Callable[[list[tuple[str, torch.Tensor]]], None], + buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, + num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, +) -> None: + def unpack_tensor( + packed_tensor: torch.Tensor, + names: list[str], + shapes: list[list[int]], + dtypes: list[torch.dtype], + tensor_sizes: list[int], + ) -> list[tuple[str, torch.Tensor]]: + unpacked_tensors = packed_tensor.split(tensor_sizes) + return [ + (name, tensor.contiguous().view(dtype).view(*shape)) + for name, shape, dtype, tensor in zip( + names, shapes, dtypes, unpacked_tensors + ) + ] + + target_packed_tensor_size = buffer_size_bytes + streams = [torch.cuda.Stream() for _ in range(num_buffers)] + buffer_idx = 0 + packing_tensor_meta_data: list[list[tuple[str, list[int], torch.dtype, int]]] = [ + [] for _ in range(num_buffers) + ] + packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)] + packed_tensors: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers) + ] + + while True: + streams[buffer_idx].synchronize() + with torch.cuda.stream(streams[buffer_idx]): + packing_tensor_meta_data[buffer_idx] = [] + packing_tensor_sizes[buffer_idx] = 0 + try: + while True: + name, (shape, dtype) = next(iterator) + tensor_size = math.prod(shape) * dtype.itemsize + packing_tensor_meta_data[buffer_idx].append( + (name, shape, dtype, tensor_size) + ) + packing_tensor_sizes[buffer_idx] += tensor_size + if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size: + break + packed_tensors[buffer_idx] = torch.empty( + packing_tensor_sizes[buffer_idx], dtype=torch.uint8, device="cuda" + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + names, shapes, dtypes, tensor_sizes = zip( + *packing_tensor_meta_data[buffer_idx] + ) + post_unpack_func( + unpack_tensor( + packed_tensors[buffer_idx], + list(names), + list(shapes), + list(dtypes), + list(tensor_sizes), + ) + ) + buffer_idx = (buffer_idx + 1) % num_buffers + except StopIteration: + if packing_tensor_meta_data[buffer_idx]: + packed_tensors[buffer_idx] = torch.empty( + packing_tensor_sizes[buffer_idx], + dtype=torch.uint8, + device="cuda", + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + names, shapes, dtypes, tensor_sizes = zip( + *packing_tensor_meta_data[buffer_idx] + ) + post_unpack_func( + unpack_tensor( + packed_tensors[buffer_idx], + list(names), + list(shapes), + list(dtypes), + list(tensor_sizes), + ) + ) + break diff --git a/src/mp_actors/move.py b/src/mp_actors/move.py index b1e5a4399..0831201b6 100644 --- a/src/mp_actors/move.py +++ b/src/mp_actors/move.py @@ -7,8 +7,10 @@ import multiprocessing as mp import os import queue +import signal import sys import threading +import time from typing import Any, AsyncGenerator, TypeVar, cast import weakref @@ -109,11 +111,51 @@ def __init__( self._process_name = process_name self._requests = mp.Queue() self._responses = mp.Queue() + ready = mp.Queue() self._process = mp.Process( target=_target, - args=(obj, self._requests, self._responses, log_file, process_name), + args=( + obj, + self._requests, + self._responses, + ready, + os.getpid(), + log_file, + process_name, + ), ) self._process.start() + startup_timeout = float(os.environ.get("ART_MP_ACTOR_START_TIMEOUT", 300.0)) + deadline = time.monotonic() + startup_timeout + try: + while True: + 
try: + ready_status, ready_payload = ready.get(timeout=0.1) + break + except queue.Empty as exc: + if not self._process.is_alive(): + self._process.join(timeout=1) + raise self._process_error() from exc + if time.monotonic() >= deadline: + raise RuntimeError( + f"Child process did not enter its process group within {startup_timeout:.1f}s" + ) from exc + except BaseException: + self._process.terminate() + self._process.join(timeout=1) + if self._process.is_alive(): + self._process.kill() + self._process.join(timeout=1) + raise + if ready_status != "ok": + self._process.terminate() + self._process.join(timeout=1) + raise RuntimeError( + f"Child process failed to enter process group: {ready_payload}" + ) + self._process_group_id = int(ready_payload) + ready.close() + ready.cancel_join_thread() self._futures: dict[int, Future] = {} self._futures_lock = threading.Lock() self._dead_process_error: RuntimeError | None = None @@ -257,11 +299,17 @@ def close(self): self._closing = True self._fail_pending(RuntimeError("Proxy is closing")) - # terminate child process and force kill if needed - self._process.terminate() + # terminate child process group and force kill if needed + try: + os.killpg(self._process_group_id, signal.SIGTERM) + except ProcessLookupError: + pass self._process.join(timeout=1) if self._process.is_alive(): - self._process.kill() + try: + os.killpg(self._process_group_id, signal.SIGKILL) + except ProcessLookupError: + pass self._process.join(timeout=1) # close and cancel queue feeder threads @@ -276,9 +324,45 @@ def _target( obj: object, requests: mp.Queue, responses: mp.Queue, + ready: mp.Queue, + parent_pid: int, log_file: str | None = None, process_name: str | None = None, ) -> None: + try: + if hasattr(os, "setsid") and os.getpgrp() != os.getpid(): + os.setsid() + ready.put_nowait(("ok", os.getpgrp())) + except BaseException as exc: + ready.put_nowait(("error", repr(exc))) + raise + + def monitor_parent() -> None: + while True: + if os.getppid() != parent_pid: + + def force_exit() -> None: + time.sleep(5) + try: + os.killpg(os.getpgrp(), signal.SIGKILL) + except ProcessLookupError: + pass + + threading.Thread(target=force_exit, daemon=True).start() + try: + close = getattr(obj, "close", None) + if callable(close): + close() + except BaseException: + pass + try: + os.killpg(os.getpgrp(), signal.SIGKILL) + except ProcessLookupError: + pass + os._exit(1) + time.sleep(0.5) + + threading.Thread(target=monitor_parent, daemon=True).start() if process_name: setproctitle.setproctitle(process_name) if log_file: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..eafb9af57 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test helpers and integration modules.""" diff --git a/tests/integration/megatron/__init__.py b/tests/integration/megatron/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/integration/megatron/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/integration/megatron/lora/__init__.py b/tests/integration/megatron/lora/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/integration/megatron/lora/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/integration/megatron/lora/merged_vllm_serving.py b/tests/integration/megatron/lora/merged_vllm_serving.py new file mode 100644 index 000000000..2d63c996e --- /dev/null +++ b/tests/integration/megatron/lora/merged_vllm_serving.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import asyncio +import os +from pathlib import 
Path +import socket + +from pydantic import BaseModel, Field +import torch + +from art import dev +from art.megatron.service import MegatronService + +from ..model_support.oracle_harness import ( + ORACLE_TOPOLOGY, + OracleCaseConfig, + ensure_case_artifacts, +) +from ..model_support.oracle_worker import provider_topology_env + +_TRAINER_GPU_IDS_ENV = "ART_MODEL_SUPPORT_TRAINER_GPU_IDS" +_INFERENCE_GPU_IDS_ENV = "ART_MODEL_SUPPORT_INFERENCE_GPU_IDS" + + +class MergedVllmServingReport(BaseModel): + base_model: str + output_dir: str + host: str + port: int + trainer_gpu_ids: list[int] + inference_gpu_ids: list[int] + served_model_name: str + model_ids: list[str] = Field(default_factory=list) + completion_text: str = "" + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _parse_gpu_id_env(name: str) -> list[int] | None: + raw = os.environ.get(name) + if raw is None or raw.strip() == "": + return None + return [int(part.strip()) for part in raw.split(",") if part.strip()] + + +def _resolve_dedicated_gpu_ids() -> tuple[list[int], list[int]]: + trainer_gpu_ids = _parse_gpu_id_env(_TRAINER_GPU_IDS_ENV) + inference_gpu_ids = _parse_gpu_id_env(_INFERENCE_GPU_IDS_ENV) + if trainer_gpu_ids is not None or inference_gpu_ids is not None: + if trainer_gpu_ids is None or inference_gpu_ids is None: + raise RuntimeError( + f"{_TRAINER_GPU_IDS_ENV} and {_INFERENCE_GPU_IDS_ENV} must both be set" + ) + return trainer_gpu_ids, inference_gpu_ids + + visible_gpu_count = int(torch.cuda.device_count()) + if visible_gpu_count < 2: + raise RuntimeError( + f"Need at least 2 visible GPUs for merged serving, found {visible_gpu_count}" + ) + return [0], [1] + + +async def _run_merged_vllm_serving( + case_config: OracleCaseConfig, +) -> MergedVllmServingReport: + trainer_gpu_ids, inference_gpu_ids = _resolve_dedicated_gpu_ids() + service_name = "model_support_merged_validation" + case_artifacts = ensure_case_artifacts(case_config) + output_dir = str(Path(case_artifacts.case_dir) / "merged_vllm_serving") + os.makedirs(output_dir, exist_ok=True) + internal_config = dev.InternalModelConfig( + trainer_gpu_ids=trainer_gpu_ids, + inference_gpu_ids=inference_gpu_ids, + rollout_weights_mode="merged", + allow_unvalidated_arch=case_config.allow_unvalidated_arch, + ) + dev.validate_dedicated_config(internal_config) + with provider_topology_env(ORACLE_TOPOLOGY): + service = MegatronService( + model_name=service_name, + base_model=case_config.base_model, + config=internal_config, + output_dir=output_dir, + ) + port = _find_free_port() + try: + host, resolved_port = await service.start_openai_server( + {"server_args": {"port": port}} + ) + import httpx + + async with httpx.AsyncClient() as client: + models_response = await client.get( + f"http://{host}:{resolved_port}/v1/models", + timeout=60.0, + ) + models_response.raise_for_status() + model_ids = [ + str(model_info["id"]) + for model_info in models_response.json().get("data", []) + if isinstance(model_info, dict) and "id" in model_info + ] + + served_model_name = f"{service_name}@{service._latest_step}" + completion_response = await client.post( + f"http://{host}:{resolved_port}/v1/completions", + json={ + "model": served_model_name, + "prompt": "Hello", + "max_tokens": 1, + "temperature": 0.0, + }, + timeout=900.0, + ) + completion_response.raise_for_status() + completion_json = completion_response.json() + completion_text = str( + 
completion_json.get("choices", [{}])[0].get("text", "") + ) + return MergedVllmServingReport( + base_model=case_config.base_model, + output_dir=output_dir, + host=host, + port=resolved_port, + trainer_gpu_ids=trainer_gpu_ids, + inference_gpu_ids=inference_gpu_ids, + served_model_name=served_model_name, + model_ids=model_ids, + completion_text=completion_text, + ) + finally: + service.close() + + +def run_merged_vllm_serving( + case_config: OracleCaseConfig, +) -> MergedVllmServingReport: + return asyncio.run(_run_merged_vllm_serving(case_config)) diff --git a/tests/integration/megatron/lora/native_vllm_lora.py b/tests/integration/megatron/lora/native_vllm_lora.py new file mode 100644 index 000000000..e28597bbc --- /dev/null +++ b/tests/integration/megatron/lora/native_vllm_lora.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import asyncio +import os +from pathlib import Path +import shutil +import socket +import tempfile + +from pydantic import BaseModel, Field +import torch + +from art import dev +from art.megatron.service import MegatronService +from art.utils.output_dirs import get_step_checkpoint_dir + +from ..model_support.oracle_harness import ( + ORACLE_TOPOLOGY, + OracleCaseConfig, + ensure_case_artifacts, +) +from ..model_support.oracle_worker import provider_topology_env + +_TRAINER_GPU_IDS_ENV = "ART_MODEL_SUPPORT_TRAINER_GPU_IDS" +_INFERENCE_GPU_IDS_ENV = "ART_MODEL_SUPPORT_INFERENCE_GPU_IDS" + + +class NativeVllmLoraServingReport(BaseModel): + base_model: str + output_dir: str + host: str + port: int + trainer_gpu_ids: list[int] + inference_gpu_ids: list[int] + rollout_weights_mode: str = "lora" + step0_name: str + step1_name: str + model_ids_before: list[str] = Field(default_factory=list) + model_ids_after: list[str] = Field(default_factory=list) + step0_served: bool + step1_served: bool + step0_completion_text: str = "" + step1_completion_text: str = "" + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _parse_gpu_id_env(name: str) -> list[int] | None: + raw = os.environ.get(name) + if raw is None or raw.strip() == "": + return None + return [int(part.strip()) for part in raw.split(",") if part.strip()] + + +def _resolve_dedicated_gpu_ids() -> tuple[list[int], list[int]]: + trainer_gpu_ids = _parse_gpu_id_env(_TRAINER_GPU_IDS_ENV) + inference_gpu_ids = _parse_gpu_id_env(_INFERENCE_GPU_IDS_ENV) + if trainer_gpu_ids is not None or inference_gpu_ids is not None: + if trainer_gpu_ids is None or inference_gpu_ids is None: + raise RuntimeError( + f"{_TRAINER_GPU_IDS_ENV} and {_INFERENCE_GPU_IDS_ENV} must both be set" + ) + return trainer_gpu_ids, inference_gpu_ids + + visible_gpu_count = int(torch.cuda.device_count()) + if visible_gpu_count < 2: + raise RuntimeError( + f"Need at least 2 visible GPUs for native LoRA serving, found {visible_gpu_count}" + ) + return [0], [1] + + +async def _model_ids(client, base_url: str) -> list[str]: + response = await client.get(f"{base_url}/v1/models", timeout=60.0) + response.raise_for_status() + return [ + str(model_info["id"]) + for model_info in response.json().get("data", []) + if isinstance(model_info, dict) and "id" in model_info + ] + + +async def _completion_text(client, base_url: str, model_name: str) -> str: + response = await client.post( + f"{base_url}/v1/completions", + json={ + "model": model_name, + "prompt": "Hello", + "max_tokens": 1, + "temperature": 0.0, + }, + timeout=900.0, + ) + 
response.raise_for_status() + return str(response.json().get("choices", [{}])[0].get("text", "")) + + +def _copy_adapter_checkpoint(source_dir: str, dest_dir: str) -> None: + os.makedirs(dest_dir, exist_ok=True) + for filename in ("adapter_model.safetensors", "adapter_config.json"): + shutil.copy(Path(source_dir) / filename, Path(dest_dir) / filename) + + +async def _run_native_vllm_lora( + case_config: OracleCaseConfig, +) -> NativeVllmLoraServingReport: + trainer_gpu_ids, inference_gpu_ids = _resolve_dedicated_gpu_ids() + service_name = "model_support_native_lora_validation" + case_artifacts = ensure_case_artifacts(case_config) + output_root = Path(case_artifacts.case_dir) / "native_vllm_lora" + output_root.mkdir(parents=True, exist_ok=True) + output_dir = tempfile.mkdtemp(prefix="run_", dir=output_root) + internal_config = dev.InternalModelConfig( + trainer_gpu_ids=trainer_gpu_ids, + inference_gpu_ids=inference_gpu_ids, + rollout_weights_mode="lora", + allow_unvalidated_arch=case_config.allow_unvalidated_arch, + ) + dev.validate_dedicated_config(internal_config) + with provider_topology_env(ORACLE_TOPOLOGY): + service = MegatronService( + model_name=service_name, + base_model=case_config.base_model, + config=internal_config, + output_dir=output_dir, + ) + port = _find_free_port() + try: + host, resolved_port = await service.start_openai_server( + {"server_args": {"port": port}} + ) + import httpx + + base_url = f"http://{host}:{resolved_port}" + step0_name = f"{service_name}@0" + step1_name = f"{service_name}@1" + async with httpx.AsyncClient() as client: + model_ids_before = await _model_ids(client, base_url) + step0_completion_text = await _completion_text( + client, + base_url, + step0_name, + ) + step0_dir = get_step_checkpoint_dir(output_dir, 0) + step1_dir = get_step_checkpoint_dir(output_dir, 1) + _copy_adapter_checkpoint(step0_dir, step1_dir) + await service.register_lora_for_step(1, step1_dir) + model_ids_after = await _model_ids(client, base_url) + step1_completion_text = await _completion_text( + client, + base_url, + step1_name, + ) + + return NativeVllmLoraServingReport( + base_model=case_config.base_model, + output_dir=output_dir, + host=host, + port=resolved_port, + trainer_gpu_ids=trainer_gpu_ids, + inference_gpu_ids=inference_gpu_ids, + step0_name=step0_name, + step1_name=step1_name, + model_ids_before=model_ids_before, + model_ids_after=model_ids_after, + step0_served=True, + step1_served=True, + step0_completion_text=step0_completion_text, + step1_completion_text=step1_completion_text, + ) + finally: + service.close() + + +def run_native_vllm_lora( + case_config: OracleCaseConfig, +) -> NativeVllmLoraServingReport: + return asyncio.run(_run_native_vllm_lora(case_config)) diff --git a/tests/integration/megatron/lora/test_lora_disk_codecs.py b/tests/integration/megatron/lora/test_lora_disk_codecs.py new file mode 100644 index 000000000..be6075f5d --- /dev/null +++ b/tests/integration/megatron/lora/test_lora_disk_codecs.py @@ -0,0 +1,514 @@ +import json +from pathlib import Path +import subprocess +import sys + +from safetensors.torch import load_file, save_file +import torch + +from art.megatron.model_support.handlers import ( + DEFAULT_DENSE_HANDLER, + QWEN3_5_MOE_HANDLER, + QWEN3_MOE_HANDLER, +) +from art.megatron.weights.merge import load_lora_adapter_state_dict, merge_lora_adapter +from art.utils.convert_moe_lora import convert_checkpoint_if_needed + +REPO_ROOT = Path(__file__).parents[4] +VLLM_PYTHON = REPO_ROOT / "vllm_runtime/.venv/bin/python" + + +def 
_config(base_model: str, rank: int = 2, alpha: int = 4) -> dict: + return { + "base_model_name_or_path": base_model, + "r": rank, + "lora_alpha": alpha, + "target_modules": [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "in_proj_qkv", + "in_proj_z", + "out_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + "bias": "none", + } + + +def _qwen35_config(base_model: str, rank: int = 2, alpha: int = 4) -> dict: + config = _config(base_model, rank=rank, alpha=alpha) + config.update( + { + "num_attention_heads": 2, + "num_key_value_heads": 1, + "head_dim": 3, + } + ) + return config + + +def _assert_tensors_equal( + actual: dict[str, torch.Tensor], + expected: dict[str, torch.Tensor], +) -> None: + assert set(actual) == set(expected) + for key, tensor in expected.items(): + assert torch.equal(actual[key], tensor), key + + +def _save_adapter(path: Path, tensors: dict[str, torch.Tensor], config: dict) -> None: + path.mkdir(parents=True, exist_ok=True) + save_file(tensors, path / "adapter_model.safetensors") + (path / "adapter_config.json").write_text(json.dumps(config), encoding="utf-8") + + +def _assert_stock_vllm_loads( + path: Path, + *, + expected_modules: set[str], + mapper: str = "none", +) -> list[str]: + script = r""" +import json +import sys +from vllm.lora.lora_model import LoRAModel +from vllm.lora.peft_helper import PEFTHelper + +path = sys.argv[1] +expected = set(json.loads(sys.argv[2])) +mapper_name = sys.argv[3] +weights_mapper = None +if mapper_name == "qwen35": + from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration + weights_mapper = Qwen3VLForConditionalGeneration.hf_to_vllm_mapper +peft = PEFTHelper.from_local_dir(path, max_position_embeddings=None) +lora = LoRAModel.from_local_checkpoint( + path, + expected, + peft, + lora_model_id=1, + device="cpu", + weights_mapper=weights_mapper, +) +print(json.dumps(sorted(lora.loras))) +""" + result = subprocess.run( + [ + str(VLLM_PYTHON), + "-c", + script, + str(path), + json.dumps(sorted(expected_modules)), + mapper, + ], + check=True, + text=True, + capture_output=True, + ) + return json.loads(result.stdout.strip().splitlines()[-1]) + + +def _qwen35_moe_art_tensors(prefix: str, *, rank: int = 2) -> dict[str, torch.Tensor]: + hidden = 3 + q_out = 12 + intermediate = 4 + tensors: dict[str, torch.Tensor] = { + f"{prefix}.self_attn.q_proj.lora_A.weight": torch.arange( + rank * hidden, + dtype=torch.float32, + ).reshape(rank, hidden), + f"{prefix}.self_attn.q_proj.lora_B.weight": torch.arange( + q_out * rank, + dtype=torch.float32, + ).reshape(q_out, rank) + + 100, + } + offset = 200 + for expert in range(2): + for module in ("gate_up_proj", "down_proj"): + out_dim = hidden if module == "down_proj" else 2 * intermediate + in_dim = intermediate if module == "down_proj" else hidden + tensors[f"{prefix}.mlp.experts.{expert}.{module}.lora_A.weight"] = ( + torch.arange(rank * in_dim, dtype=torch.float32).reshape(rank, in_dim) + + offset + ) + offset += 100 + tensors[f"{prefix}.mlp.experts.{expert}.{module}.lora_B.weight"] = ( + torch.arange(out_dim * rank, dtype=torch.float32).reshape(out_dim, rank) + + offset + ) + offset += 100 + return tensors + + +def _qwen3_dense_lora_tensors(prefix: str, *, rank: int = 2) -> dict[str, torch.Tensor]: + module_dims = { + "self_attn.q_proj": (rank, 3, 3), + "self_attn.k_proj": (rank, 3, 3), + "self_attn.v_proj": (rank, 3, 3), + "self_attn.o_proj": (rank, 3, 3), + "mlp.gate_proj": (rank, 3, 4), + "mlp.up_proj": (rank, 3, 4), + "mlp.down_proj": (rank, 4, 3), + } + tensors: 
dict[str, torch.Tensor] = {} + offset = 0 + for module, (rank_dim, in_dim, out_dim) in module_dims.items(): + tensors[f"{prefix}.{module}.lora_A.weight"] = ( + torch.arange(rank_dim * in_dim, dtype=torch.float32).reshape( + rank_dim, + in_dim, + ) + + offset + ) + offset += 100 + tensors[f"{prefix}.{module}.lora_B.weight"] = ( + torch.arange(out_dim * rank_dim, dtype=torch.float32).reshape( + out_dim, + rank_dim, + ) + + offset + ) + offset += 100 + return tensors + + +def _qwen3_moe_lora_tensors(prefix: str, *, rank: int = 2) -> dict[str, torch.Tensor]: + tensors = { + key: value + for key, value in _qwen3_dense_lora_tensors(prefix, rank=rank).items() + if ".mlp." not in key + } + offset = 1000 + for expert in range(2): + for module, in_dim, out_dim in ( + ("gate_proj", 3, 4), + ("up_proj", 3, 4), + ("down_proj", 4, 3), + ): + expert_prefix = f"{prefix}.mlp.experts.{expert}.{module}" + tensors[f"{expert_prefix}.lora_A.weight"] = ( + torch.arange(rank * in_dim, dtype=torch.float32).reshape(rank, in_dim) + + offset + ) + offset += 100 + tensors[f"{expert_prefix}.lora_B.weight"] = ( + torch.arange(out_dim * rank, dtype=torch.float32).reshape(out_dim, rank) + + offset + ) + offset += 100 + return tensors + + +def test_peft_fused_moe_checkpoint_converts_to_vllm_3d_layout(tmp_path: Path) -> None: + prefix = "base_model.model.model.layers.0.mlp.experts" + peft_tensors = { + f"{prefix}.base_layer.lora_A.weight": torch.arange( + 2 * 8, + dtype=torch.float32, + ).reshape(2, 8), + f"{prefix}.base_layer.lora_B.weight": torch.arange( + 3 * 2, + dtype=torch.float32, + ).reshape(3, 2) + + 100, + f"{prefix}.lora_A.weight": torch.arange( + 2 * 3, + dtype=torch.float32, + ).reshape(2, 3) + + 200, + f"{prefix}.lora_B.weight": torch.arange( + 4 * 2, + dtype=torch.float32, + ).reshape(4, 2) + + 300, + } + _save_adapter( + tmp_path, + peft_tensors, + { + "r": 1, + "lora_alpha": 1, + "target_modules": ["q_proj"], + "target_parameters": [ + "model.layers.0.mlp.experts.gate_up_proj", + "model.layers.0.mlp.experts.down_proj", + ], + }, + ) + + convert_checkpoint_if_needed(str(tmp_path)) + + converted = load_file(tmp_path / "adapter_model.safetensors") + _assert_tensors_equal( + converted, + { + f"{prefix}.base_layer.lora_A.weight": peft_tensors[ + f"{prefix}.base_layer.lora_B.weight" + ].T.contiguous(), + f"{prefix}.base_layer.lora_B.weight": peft_tensors[ + f"{prefix}.base_layer.lora_A.weight" + ].T.contiguous(), + f"{prefix}.lora_A.weight": peft_tensors[ + f"{prefix}.lora_B.weight" + ].T.contiguous(), + f"{prefix}.lora_B.weight": peft_tensors[ + f"{prefix}.lora_A.weight" + ].T.contiguous(), + }, + ) + adapter_config = json.loads((tmp_path / "adapter_config.json").read_text()) + assert adapter_config["target_modules"] == ["q_proj", "experts"] + assert "target_parameters" not in adapter_config + + +def test_qwen35_and_qwen36_vllm_canonical_roundtrip_and_stock_loader(tmp_path: Path): + art_prefix = "base_model.model.model.layers.0" + original = _qwen35_moe_art_tensors(art_prefix) + for base_model in ("Qwen/Qwen3.5-35B-A3B", "Qwen/Qwen3.6-35B-A3B"): + vllm_tensors, vllm_config = QWEN3_5_MOE_HANDLER.to_vllm_lora_tensors( + original, + adapter_config=_qwen35_config(base_model), + ) + assert vllm_config["r"] == 2 + assert vllm_config["lora_alpha"] == 4 + assert vllm_config["target_modules"] == [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "in_proj_qkv", + "in_proj_z", + "out_proj", + "experts", + ] + assert all("language_model.layers" in key for key in vllm_tensors) + roundtrip = 
QWEN3_5_MOE_HANDLER.from_vllm_lora_tensors( + vllm_tensors, + adapter_config=vllm_config, + ) + _assert_tensors_equal(roundtrip, original) + adapter_dir = tmp_path / base_model.replace("/", "_") + _save_adapter(adapter_dir, vllm_tensors, vllm_config) + loaded_modules = _assert_stock_vllm_loads( + adapter_dir, + expected_modules=set(vllm_config["target_modules"]), + mapper="qwen35", + ) + assert "language_model.model.layers.0.mlp.experts" in loaded_modules + assert "language_model.model.layers.0.mlp.experts.base_layer" in loaded_modules + + +def test_qwen35_and_qwen36_dense_prefix_roundtrip_and_stock_loader(tmp_path: Path): + original = { + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.ones( + 2, + 3, + ), + "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.ones( + 12, + 2, + ), + } + for base_model in ("Qwen/Qwen3.5-4B", "Qwen/Qwen3.6-4B"): + vllm_tensors, vllm_config = QWEN3_5_MOE_HANDLER.to_vllm_lora_tensors( + original, + adapter_config=_qwen35_config(base_model), + ) + assert set(vllm_tensors) == { + key.replace( + "base_model.model.model.layers.", + "base_model.model.model.language_model.layers.", + ) + for key in original + } + roundtrip = QWEN3_5_MOE_HANDLER.from_vllm_lora_tensors( + vllm_tensors, + adapter_config=vllm_config, + ) + _assert_tensors_equal(roundtrip, original) + adapter_dir = tmp_path / base_model.replace("/", "_") + _save_adapter(adapter_dir, vllm_tensors, vllm_config) + loaded_modules = _assert_stock_vllm_loads( + adapter_dir, + expected_modules={"q_proj"}, + mapper="qwen35", + ) + assert loaded_modules == ["language_model.model.layers.0.self_attn.q_proj"] + + +def test_qwen3_dense_and_moe_are_already_vllm_canonical(tmp_path: Path): + dense = _qwen3_dense_lora_tensors("base_model.model.model.layers.0") + assert ( + DEFAULT_DENSE_HANDLER.to_vllm_lora_tensors( + dense, + adapter_config=_config("Qwen/Qwen3-0.6B"), + )[0] + == dense + ) + dense_dir = tmp_path / "qwen3_dense" + _save_adapter(dense_dir, dense, _config("Qwen/Qwen3-0.6B")) + assert _assert_stock_vllm_loads( + dense_dir, + expected_modules={ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + }, + ) == [ + "model.layers.0.mlp.down_proj", + "model.layers.0.mlp.gate_proj", + "model.layers.0.mlp.up_proj", + "model.layers.0.self_attn.k_proj", + "model.layers.0.self_attn.o_proj", + "model.layers.0.self_attn.q_proj", + "model.layers.0.self_attn.v_proj", + ] + + moe = _qwen3_moe_lora_tensors("base_model.model.model.layers.0") + assert ( + QWEN3_MOE_HANDLER.to_vllm_lora_tensors( + moe, + adapter_config=_config("Qwen/Qwen3-30B-A3B"), + )[0] + == moe + ) + moe_dir = tmp_path / "qwen3_moe" + _save_adapter(moe_dir, moe, _config("Qwen/Qwen3-30B-A3B")) + assert _assert_stock_vllm_loads( + moe_dir, + expected_modules={ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "experts.0.gate_proj", + "experts.0.up_proj", + "experts.0.down_proj", + "experts.1.gate_proj", + "experts.1.up_proj", + "experts.1.down_proj", + }, + ) == [ + "model.layers.0.mlp.experts.0.down_proj", + "model.layers.0.mlp.experts.0.gate_proj", + "model.layers.0.mlp.experts.0.up_proj", + "model.layers.0.mlp.experts.1.down_proj", + "model.layers.0.mlp.experts.1.gate_proj", + "model.layers.0.mlp.experts.1.up_proj", + "model.layers.0.self_attn.k_proj", + "model.layers.0.self_attn.o_proj", + "model.layers.0.self_attn.q_proj", + "model.layers.0.self_attn.v_proj", + ] + + +def test_qwen35_megatron_shards_merge_to_vllm_checkpoint_and_roundtrip( + tmp_path: Path, +): + prefix = 
"base_model.model.model.layers.0.mlp.experts.0" + rank = 1 + hidden = 2 + intermediate = 4 + full = { + f"{prefix}.gate_up_proj.lora_A.weight": torch.tensor([[1.0, 2.0]]), + f"{prefix}.gate_up_proj.lora_B.weight": torch.arange( + 2 * intermediate * rank, + dtype=torch.float32, + ).reshape(2 * intermediate, rank), + f"{prefix}.down_proj.lora_A.weight": torch.arange( + rank * intermediate, + dtype=torch.float32, + ).reshape(rank, intermediate) + + 20, + f"{prefix}.down_proj.lora_B.weight": torch.arange( + hidden * rank, + dtype=torch.float32, + ).reshape(hidden, rank) + + 30, + } + + def unsharded() -> dict: + return {"sharded": False, "shard_world_size": 1, "shard_rank": 0} + + def sharded(rank_id: int, dim: int) -> dict: + return { + "sharded": True, + "shard_world_size": 2, + "shard_rank": rank_id, + "export_shard_dim": dim, + "export_shard_strategy": "uniform", + } + + shard0 = { + f"{prefix}.gate_up_proj.lora_A.weight": full[ + f"{prefix}.gate_up_proj.lora_A.weight" + ], + f"{prefix}.down_proj.lora_B.weight": full[f"{prefix}.down_proj.lora_B.weight"], + f"{prefix}.gate_up_proj.lora_B.weight": full[ + f"{prefix}.gate_up_proj.lora_B.weight" + ][:4], + f"{prefix}.down_proj.lora_A.weight": full[f"{prefix}.down_proj.lora_A.weight"][ + :, :2 + ], + } + manifest0 = { + f"{prefix}.gate_up_proj.lora_A.weight": unsharded(), + f"{prefix}.down_proj.lora_B.weight": unsharded(), + f"{prefix}.gate_up_proj.lora_B.weight": sharded(0, 0), + f"{prefix}.down_proj.lora_A.weight": sharded(0, 1), + } + shard1 = { + f"{prefix}.gate_up_proj.lora_B.weight": full[ + f"{prefix}.gate_up_proj.lora_B.weight" + ][4:], + f"{prefix}.down_proj.lora_A.weight": full[f"{prefix}.down_proj.lora_A.weight"][ + :, 2: + ], + } + manifest1 = { + f"{prefix}.gate_up_proj.lora_B.weight": sharded(1, 0), + f"{prefix}.down_proj.lora_A.weight": sharded(1, 1), + } + adapter_dir = tmp_path / "qwen35_megatron_shards" + adapter_dir.mkdir() + (adapter_dir / "adapter_config.json").write_text( + json.dumps(_config("Qwen/Qwen3.5-35B-A3B", rank=rank, alpha=rank)), + encoding="utf-8", + ) + save_file(shard0, adapter_dir / "adapter_model-01-of-02.safetensors") + save_file(shard1, adapter_dir / "adapter_model-02-of-02.safetensors") + (adapter_dir / "adapter_manifest-01-of-02.json").write_text( + json.dumps(manifest0), + encoding="utf-8", + ) + (adapter_dir / "adapter_manifest-02-of-02.json").write_text( + json.dumps(manifest1), + encoding="utf-8", + ) + + merge_lora_adapter(str(adapter_dir)) + + assert not list(adapter_dir.glob("adapter_model-*-of-*.safetensors")) + assert not list(adapter_dir.glob("adapter_manifest-*-of-*.json")) + roundtrip = load_lora_adapter_state_dict( + str(adapter_dir), + handler=QWEN3_5_MOE_HANDLER, + ) + _assert_tensors_equal(roundtrip, full) + final_config = json.loads((adapter_dir / "adapter_config.json").read_text()) + loaded_modules = _assert_stock_vllm_loads( + adapter_dir, + expected_modules=set(final_config["target_modules"]), + mapper="qwen35", + ) + assert "language_model.model.layers.0.mlp.experts" in loaded_modules + assert "language_model.model.layers.0.mlp.experts.base_layer" in loaded_modules diff --git a/tests/integration/megatron/lora/test_merged_weight_export.py b/tests/integration/megatron/lora/test_merged_weight_export.py new file mode 100644 index 000000000..a495f8ce9 --- /dev/null +++ b/tests/integration/megatron/lora/test_merged_weight_export.py @@ -0,0 +1,264 @@ +from typing import Any, cast + +import httpx +import torch + +from art.megatron.runtime.jobs import ( + MergedWeightTransferInitInfo, + 
MergedWeightTransferSpec, +) +import art.megatron.weights.merged_weight_export as export + + +def _spec() -> MergedWeightTransferSpec: + return MergedWeightTransferSpec( + init_info=MergedWeightTransferInitInfo( + master_address="127.0.0.1", + master_port=23456, + rank_offset=1, + world_size=3, + ), + vllm_base_url="http://runtime.test", + served_model_name="model@7", + ) + + +class _OkResponse: + def raise_for_status(self) -> None: + return None + + +def test_ensure_merged_weight_transfer_group_rank_zero_initializes_runtime_and_trainer( + monkeypatch, +) -> None: + spec = _spec() + calls: list[tuple[str, object]] = [] + + def fake_trainer_init(init_info: dict[str, object]) -> str: + calls.append(("trainer_init", init_info)) + return "trainer-group" + + def fake_post(url: str, *, json: dict[str, object], timeout: float) -> _OkResponse: + calls.append(("post", (url, json, timeout))) + return _OkResponse() + + monkeypatch.setattr(export, "trainer_init", fake_trainer_init) + monkeypatch.setattr(httpx, "post", fake_post) + monkeypatch.setattr(export, "_maybe_distributed_barrier", lambda world_size: None) + + group, init_info = export.ensure_merged_weight_transfer_group( + rank=0, + world_size=2, + merged_weight_transfer_group=None, + merged_weight_transfer_init_info=None, + spec=spec, + ) + + assert group == "trainer-group" + assert init_info == spec.init_info + assert sorted(calls, key=lambda item: item[0]) == [ + ( + "post", + ( + "http://runtime.test/init_weight_transfer_engine", + {"init_info": spec.init_info.model_dump()}, + 300.0, + ), + ), + ( + "trainer_init", + { + "master_address": "127.0.0.1", + "master_port": 23456, + "world_size": 3, + }, + ), + ] + + +def test_ensure_merged_weight_transfer_group_non_sender_skips_runtime_init( + monkeypatch, +) -> None: + spec = _spec() + barriers: list[int] = [] + + monkeypatch.setattr( + export, + "trainer_init", + lambda init_info: (_ for _ in ()).throw( + AssertionError("unexpected trainer_init") + ), + ) + monkeypatch.setattr( + httpx, + "post", + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("unexpected post") + ), + ) + monkeypatch.setattr(export, "_maybe_distributed_barrier", barriers.append) + + group, init_info = export.ensure_merged_weight_transfer_group( + rank=1, + world_size=2, + merged_weight_transfer_group=None, + merged_weight_transfer_init_info=None, + spec=spec, + ) + + assert group is None + assert init_info == spec.init_info + assert barriers == [] + + +def test_sync_merged_weights_to_vllm_non_sender_only_drains_export( + monkeypatch, +) -> None: + spec = _spec() + barrier_calls: list[int] = [] + iter_passes: list[int] = [] + + monkeypatch.setattr( + export, + "ensure_merged_weight_transfer_group", + lambda **kwargs: (None, spec.init_info), + ) + monkeypatch.setattr(export, "build_merged_weight_export", lambda **kwargs: object()) + + def fake_iter(_weight_export: object): + iter_passes.append(len(iter_passes) + 1) + yield ("layer.weight", torch.zeros((2, 3), dtype=torch.float16)) + yield ("layer.bias", torch.zeros((3,), dtype=torch.float32)) + + monkeypatch.setattr(export, "iter_merged_vllm_weights", fake_iter) + monkeypatch.setattr(export, "_maybe_distributed_barrier", barrier_calls.append) + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) + monkeypatch.setattr( + export, + "trainer_send_weights", + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("unexpected send") + ), + ) + monkeypatch.setattr( + httpx, + "Client", + lambda: (_ for _ in ()).throw(AssertionError("unexpected http 
client")), + ) + + group, init_info = export.sync_merged_weights_to_vllm( + bridge=object(), + model=cast(Any, object()), + model_support_handler=object(), + rank=1, + world_size=2, + merged_weight_transfer_group=None, + merged_weight_transfer_init_info=None, + spec=spec, + pause_generation=True, + ) + + assert group is None + assert init_info == spec.init_info + assert iter_passes == [1, 2] + assert barrier_calls == [2] + + +def test_sync_merged_weights_to_vllm_sender_controls_runtime_and_sends( + monkeypatch, +) -> None: + spec = _spec() + barrier_calls: list[int] = [] + sent_items: list[list[tuple[str, torch.Tensor]]] = [] + posts: list[ + tuple[str, dict[str, object] | None, dict[str, object] | None, float] + ] = [] + + monkeypatch.setattr( + export, + "ensure_merged_weight_transfer_group", + lambda **kwargs: ("trainer-group", spec.init_info), + ) + monkeypatch.setattr(export, "build_merged_weight_export", lambda **kwargs: object()) + + def fake_iter(_weight_export: object): + yield ("layer.weight", torch.zeros((2, 3), dtype=torch.float16)) + yield ("layer.bias", torch.zeros((3,), dtype=torch.float32)) + + def fake_send(iterator, trainer_args): + sent_items.append(list(iterator)) + assert trainer_args["group"] == "trainer-group" + assert trainer_args["packed"] is True + + class FakeClient: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + def post( + self, + url: str, + *, + json: dict[str, object] | None = None, + params: dict[str, object] | None = None, + timeout: float, + ) -> _OkResponse: + posts.append((url, json, params, timeout)) + return _OkResponse() + + monkeypatch.setattr(export, "iter_merged_vllm_weights", fake_iter) + monkeypatch.setattr(export, "trainer_send_weights", fake_send) + monkeypatch.setattr(export, "_maybe_distributed_barrier", barrier_calls.append) + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) + monkeypatch.setattr(httpx, "Client", FakeClient) + + group, init_info = export.sync_merged_weights_to_vllm( + bridge=object(), + model=cast(Any, object()), + model_support_handler=object(), + rank=0, + world_size=2, + merged_weight_transfer_group=None, + merged_weight_transfer_init_info=None, + spec=spec, + pause_generation=True, + ) + + assert group == "trainer-group" + assert init_info == spec.init_info + assert [name for name, _ in sent_items[0]] == ["layer.weight", "layer.bias"] + assert posts == [ + ("http://runtime.test/pause", None, {"mode": "wait"}, 300.0), + ( + "http://runtime.test/start_weight_update", + {"is_checkpoint_format": True}, + None, + 300.0, + ), + ( + "http://runtime.test/update_weights", + { + "update_info": { + "names": ["layer.weight", "layer.bias"], + "dtype_names": ["float16", "float32"], + "shapes": [[2, 3], [3]], + "packed": True, + "packed_buffer_size_bytes": export.DEFAULT_PACKED_BUFFER_SIZE_BYTES, + "packed_num_buffers": export.DEFAULT_PACKED_NUM_BUFFERS, + } + }, + None, + 600.0, + ), + ("http://runtime.test/finish_weight_update", None, None, 600.0), + ( + "http://runtime.test/art/set_served_model_name", + {"name": "model@7"}, + None, + 30.0, + ), + ("http://runtime.test/resume", None, None, 30.0), + ] + assert barrier_calls == [2] diff --git a/tests/integration/megatron/lora/test_weight_transfer_bootstrap_contract.py b/tests/integration/megatron/lora/test_weight_transfer_bootstrap_contract.py new file mode 100644 index 000000000..07676bd1b --- /dev/null +++ b/tests/integration/megatron/lora/test_weight_transfer_bootstrap_contract.py @@ -0,0 +1,60 @@ +from contextlib 
import nullcontext +from types import SimpleNamespace + +import pytest +import torch + +import art.weight_transfer.nccl as nccl + + +def test_trainer_nccl_unique_id_round_trips_as_raw_bytes() -> None: + payload = bytes(range(128)) + unique_id = nccl._nccl_unique_id_from_bytes(payload) + assert nccl._nccl_unique_id_to_bytes(unique_id) == payload + + +def test_trainer_nccl_communicator_retains_bootstrap_group( + monkeypatch: pytest.MonkeyPatch, +) -> None: + payload = bytes(range(128)) + bootstrap_group = SimpleNamespace( + broadcast_obj=lambda obj, src: obj if obj is not None else payload + ) + + class FakeNcclLibrary: + def get_unique_id(self): + return nccl._nccl_unique_id_from_bytes(payload) + + def init_rank(self, world_size, unique_id, rank): + assert world_size == 2 + assert rank == 0 + assert nccl._nccl_unique_id_to_bytes(unique_id) == payload + return "comm" + + monkeypatch.setattr(nccl, "_BootstrapGroup", lambda **kwargs: bootstrap_group) + monkeypatch.setattr(nccl, "_NcclLibrary", FakeNcclLibrary) + monkeypatch.setattr(torch.cuda, "device", lambda device: nullcontext()) + monkeypatch.setattr( + torch.cuda, + "current_stream", + lambda device=None: SimpleNamespace(synchronize=lambda: None), + ) + monkeypatch.setattr( + nccl.TrainerNcclCommunicator, + "all_reduce", + lambda self, tensor, *, stream=None: None, + ) + monkeypatch.setattr( + torch, + "zeros", + lambda *args, **kwargs: SimpleNamespace(device=torch.device("cuda:0")), + ) + + communicator = nccl.TrainerNcclCommunicator( + host="127.0.0.1", + port=12345, + rank=0, + world_size=2, + device=0, + ) + assert communicator._bootstrap_group is bootstrap_group diff --git a/tests/integration/megatron/model_support/__init__.py b/tests/integration/megatron/model_support/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/integration/megatron/model_support/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/integration/megatron/model_support/chat_template_rollout.py b/tests/integration/megatron/model_support/chat_template_rollout.py new file mode 100644 index 000000000..84311755a --- /dev/null +++ b/tests/integration/megatron/model_support/chat_template_rollout.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +from pathlib import Path + +from pydantic import BaseModel, Field + +import art +from art.local import LocalBackend +from art.preprocessing.pack import PackedTensors +from art.preprocessing.tokenize import ( + TokenizedResult, + _apply_chat_template_token_ids, + _messages_for_chat_template, + tokenize_trajectory, + tokenize_trajectory_groups, +) +from art.trajectories import History +from tests.support.chat_template_conformance_cases import ( + build_chat_template_conformance_inputs, +) + + +def _slugify(value: str) -> str: + return value.lower().replace("/", "_").replace(".", "_").replace("-", "_") + + +def _artifact_dir(base_model: str) -> Path: + root = Path(__file__).resolve().parents[4] / ".local" / "model_support_validation" + path = root / _slugify(base_model) / "chat_template_rollout" + path.mkdir(parents=True, exist_ok=True) + return path + + +def _history(trajectory: art.Trajectory) -> History: + return History( + messages_and_choices=trajectory.messages_and_choices, + tools=trajectory.tools, + ) + + +def _pack_trajectory_group( + backend: LocalBackend, + model: art.TrainableModel, + trajectory_group: art.TrajectoryGroup, +) -> PackedTensors: + packed_tensors = backend._get_packed_tensors( + model, + [trajectory_group], + advantage_balance=0.0, + 
allow_training_without_logprobs=False, + scale_rewards=True, + plot_tensors=False, + packed_sequence_length=512, + logprob_calculation_chunk_size=256, + ) + if packed_tensors is None: + raise RuntimeError("chat template conformance produced no packed tensors") + return packed_tensors + + +def _assistant_prefix_tokens( + result: TokenizedResult, + *, + choice_index: int = 0, +) -> list[int]: + if not result.choice_offsets: + raise RuntimeError("Expected at least one trainable assistant span") + return result.token_ids[: result.choice_offsets[choice_index]] + + +class ChatTemplateScenarioReport(BaseModel): + name: str + entrypoint: str + passed: bool + assistant_token_count: int = 0 + packed_num_sequences: int = 0 + packed_sequence_length: int = 0 + result_count: int = 0 + num_tokens: int = 0 + num_trainable_tokens: int = 0 + mutation_changed_prompt: bool = False + expected_error_substring: str | None = None + observed_error: str | None = None + + +class ChatTemplateRolloutReport(BaseModel): + base_model: str + output_dir: str + passed: bool + scenario_count: int + failed_scenarios: list[str] = Field(default_factory=list) + scenarios: list[ChatTemplateScenarioReport] = Field(default_factory=list) + + +def run_chat_template_rollout(base_model: str) -> ChatTemplateRolloutReport: + output_dir = _artifact_dir(base_model) + backend = LocalBackend(path=str(output_dir)) + model = art.TrainableModel( + name="model-support-chat-template", + project="model-support-validation", + base_model=base_model, + _internal_config={"init_args": {"max_seq_length": 2048}}, + ) + tokenizer = backend._tokenizers.get(base_model) + if tokenizer is None: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(base_model) + backend._tokenizers[base_model] = tokenizer + + inputs = build_chat_template_conformance_inputs(tokenizer) + scenarios: list[ChatTemplateScenarioReport] = [] + + text_pack = _pack_trajectory_group(backend, model, inputs.text_pack_group) + scenarios.append( + ChatTemplateScenarioReport( + name="rl_text_pack", + entrypoint="LocalBackend._get_packed_tensors", + passed=int(text_pack["assistant_mask"].sum().item()) > 0, + assistant_token_count=int(text_pack["assistant_mask"].sum().item()), + packed_num_sequences=int(text_pack["tokens"].shape[0]), + packed_sequence_length=int(text_pack["tokens"].shape[1]), + ) + ) + + non_final_tool_call_base = tokenize_trajectory( + tokenizer=tokenizer, + image_processor=None, + history=_history(inputs.non_final_tool_call_base), + advantage=1.0, + allow_training_without_logprobs=False, + trajectory=inputs.non_final_tool_call_base, + ) + non_final_tool_call_mutated = tokenize_trajectory( + tokenizer=tokenizer, + image_processor=None, + history=_history(inputs.non_final_tool_call_mutated), + advantage=1.0, + allow_training_without_logprobs=False, + trajectory=inputs.non_final_tool_call_mutated, + ) + if non_final_tool_call_base is None or non_final_tool_call_mutated is None: + raise RuntimeError("tool-call tokenization produced no trainable tokens") + if ( + len(non_final_tool_call_base.choice_offsets) < 2 + or len(non_final_tool_call_mutated.choice_offsets) < 2 + ): + raise RuntimeError("expected non-final tool call and final assistant answer") + non_final_tool_call_prefix_changed = _assistant_prefix_tokens( + non_final_tool_call_base, + choice_index=-1, + ) != _assistant_prefix_tokens( + non_final_tool_call_mutated, + choice_index=-1, + ) + scenarios.append( + ChatTemplateScenarioReport( + name="rl_non_final_tool_call_prefill_mutation", + 
entrypoint="tokenize_trajectory", + passed=non_final_tool_call_prefix_changed + and int(sum(non_final_tool_call_base.assistant_mask)) > 0, + assistant_token_count=int(sum(non_final_tool_call_base.assistant_mask)), + mutation_changed_prompt=non_final_tool_call_prefix_changed, + ) + ) + + tool_conversation_pack = _pack_trajectory_group( + backend, + model, + inputs.tool_conversation_group, + ) + scenarios.append( + ChatTemplateScenarioReport( + name="rl_tool_conversation_pack", + entrypoint="LocalBackend._get_packed_tensors", + passed=int(tool_conversation_pack["assistant_mask"].sum().item()) > 0, + assistant_token_count=int( + tool_conversation_pack["assistant_mask"].sum().item() + ), + packed_num_sequences=int(tool_conversation_pack["tokens"].shape[0]), + packed_sequence_length=int(tool_conversation_pack["tokens"].shape[1]), + ) + ) + + additional_history_results = list( + tokenize_trajectory_groups( + tokenizer, + [inputs.additional_histories_group], + allow_training_without_logprobs=False, + scale_rewards=True, + ) + ) + additional_histories_pack = _pack_trajectory_group( + backend, + model, + inputs.additional_histories_group, + ) + scenarios.append( + ChatTemplateScenarioReport( + name="additional_histories_pack", + entrypoint="tokenize_trajectory_groups + LocalBackend._get_packed_tensors", + passed=len(additional_history_results) >= 4 + and int(additional_histories_pack["assistant_mask"].sum().item()) > 0, + assistant_token_count=int( + additional_histories_pack["assistant_mask"].sum().item() + ), + packed_num_sequences=int(additional_histories_pack["tokens"].shape[0]), + packed_sequence_length=int(additional_histories_pack["tokens"].shape[1]), + result_count=len(additional_history_results), + ) + ) + + full_conversation_messages = _messages_for_chat_template( + tokenizer, + inputs.sft_tool_conversation.messages_and_choices, + ) + full_conversation_mutated_messages = _messages_for_chat_template( + tokenizer, + inputs.sft_tool_conversation_mutated.messages_and_choices, + ) + full_conversation_input_ids = _apply_chat_template_token_ids( + tokenizer, + full_conversation_messages, + tools=inputs.sft_tool_conversation.tools, + tokenize=True, + add_generation_prompt=False, + ) + full_conversation_mutated_input_ids = _apply_chat_template_token_ids( + tokenizer, + full_conversation_mutated_messages, + tools=inputs.sft_tool_conversation_mutated.tools, + tokenize=True, + add_generation_prompt=False, + ) + scenarios.append( + ChatTemplateScenarioReport( + name="full_conversation_token_mutation", + entrypoint="_apply_chat_template_token_ids", + passed=full_conversation_input_ids != full_conversation_mutated_input_ids + and len(full_conversation_input_ids) > 0, + num_tokens=len(full_conversation_input_ids), + mutation_changed_prompt=( + full_conversation_input_ids != full_conversation_mutated_input_ids + ), + ) + ) + + expected_error = "Assistant message has tool_calls" + observed_error: str | None = None + try: + tokenize_trajectory( + tokenizer=tokenizer, + image_processor=None, + history=_history(inputs.unsupported_assistant_tool_calls), + advantage=1.0, + allow_training_without_logprobs=True, + trajectory=inputs.unsupported_assistant_tool_calls, + ) + except ValueError as exc: + observed_error = str(exc) + scenarios.append( + ChatTemplateScenarioReport( + name="unsupported_assistant_tool_calls_without_logprobs", + entrypoint="tokenize_trajectory", + passed=observed_error is not None and expected_error in observed_error, + expected_error_substring=expected_error, + 
observed_error=observed_error, + ) + ) + + failed_scenarios = [scenario.name for scenario in scenarios if not scenario.passed] + report = ChatTemplateRolloutReport( + base_model=base_model, + output_dir=str(output_dir), + passed=not failed_scenarios, + scenario_count=len(scenarios), + failed_scenarios=failed_scenarios, + scenarios=scenarios, + ) + (output_dir / "report.json").write_text( + report.model_dump_json(indent=2), + encoding="utf-8", + ) + return report diff --git a/tests/integration/megatron_forward_trace.py b/tests/integration/megatron/model_support/forward_trace.py similarity index 72% rename from tests/integration/megatron_forward_trace.py rename to tests/integration/megatron/model_support/forward_trace.py index 98f43fc65..41bbd0d38 100644 --- a/tests/integration/megatron_forward_trace.py +++ b/tests/integration/megatron/model_support/forward_trace.py @@ -7,6 +7,14 @@ import torch CAPTURE_NAME_TOKENS = ( + ".self_attention", + ".self_attention.in_proj", + ".self_attention.in_proj.in_proj", + ".self_attention.in_proj.qkv_lora", + ".self_attention.in_proj.z_lora", + ".self_attention.out_norm", + ".self_attention.out_proj", + ".self_attention.out_proj.lora", ".self_attention.linear_qkv", ".self_attention.linear_qkv.q_proj_lora", ".self_attention.linear_qkv.k_proj_lora", @@ -19,6 +27,12 @@ ".mlp.experts.linear_fc1.up_lora", ".mlp.experts.linear_fc2", ".mlp.experts.linear_fc2.lora", + ".mlp.linear_fc1", + ".mlp.linear_fc1.gate_lora", + ".mlp.linear_fc1.up_lora", + ".mlp.linear_fc2", + ".mlp.linear_fc2.row_parallel_lora", + ".mlp.linear_fc2.row_parallel_lora.lora", ) ROUTER_NAME_TOKEN = ".mlp.router" PRIMARY_OUTPUT_CANONICAL_KEY = "primary_output__is_canonical" @@ -127,6 +141,14 @@ def _shard_world_size_for_domain(domain: Any) -> int: return 1 +def _world_size_key_for_domain(domain: Any) -> str | None: + if domain == "tp": + return "tp_world_size" + if domain == "expert_tp": + return "etp_world_size" + return None + + def _extract_primary_tensor(value: Any) -> torch.Tensor | None: if isinstance(value, torch.Tensor): return value @@ -186,6 +208,7 @@ def _extract_tensor_attr(value: Any, attr_name: str) -> Any: return None +@torch._dynamo.disable def _extract_router_topk(output: Any) -> tuple[torch.Tensor, torch.Tensor] | None: if not isinstance(output, tuple) or len(output) < 2: return None @@ -214,10 +237,12 @@ def __init__( enabled: bool, capture_name_tokens: tuple[str, ...] 
= CAPTURE_NAME_TOKENS, micro_start_callback: Callable[[int | None, int], None] | None = None, + strict_output_match: bool = True, ) -> None: self.enabled = enabled self.capture_name_tokens = capture_name_tokens self.micro_start_callback = micro_start_callback + self.strict_output_match = strict_output_match self.current_step_index: int | None = None self.current_step_trace: dict[str, list[dict[str, Any]]] = {} self.current_micro_sample_index: int | None = None @@ -227,6 +252,7 @@ def __init__( self.current_step_outputs: list[ tuple[int | None, int, int | None, torch.Tensor] ] = [] + self._trace_metadata_by_name: dict[str, dict[str, Any]] = {} self._next_micro_order = 0 self._hook_handles: list[Any] = [] if not enabled: @@ -244,8 +270,17 @@ def _register_hooks(self, model_chunks: list[Any]) -> None: root_module.register_forward_hook(self._root_post_hook) ) for chunk_index, chunk in enumerate(model_chunks): - for module_name, module in chunk.named_modules(): + named_modules = list(chunk.named_modules()) + module_by_name = dict(named_modules) + for module_name, module in named_modules: trace_module_name = f"chunk{chunk_index}.{module_name}" + metadata = self._build_module_trace_metadata( + module_name=module_name, + module=module, + module_by_name=module_by_name, + ) + if metadata: + self._trace_metadata_by_name[trace_module_name] = metadata is_layer_output = ( ".decoder.layers." in module_name and module_name.rsplit(".", 1)[-1].isdigit() @@ -260,6 +295,45 @@ def _register_hooks(self, model_chunks: list[Any]) -> None: ) ) + @classmethod + def _build_module_trace_metadata( + cls, + *, + module_name: str, + module: Any, + module_by_name: dict[str, Any], + ) -> dict[str, Any]: + if module_name.endswith(".self_attention.in_proj"): + return { + "component_sizes": cls._gdn_in_proj_component_sizes(module), + } + if module_name.endswith(".self_attention.in_proj.in_proj"): + parent_module = module_by_name[module_name.rsplit(".", 1)[0]] + return { + "component_sizes": cls._gdn_in_proj_component_sizes(parent_module), + } + if module_name.endswith(".self_attention.out_norm"): + gdn_module = module_by_name[module_name.removesuffix(".out_norm")] + return { + "local_heads": int(gdn_module.num_value_heads // gdn_module.tp_size), + } + return {} + + @staticmethod + def _gdn_in_proj_component_sizes(module: Any) -> tuple[int, ...]: + qkv_sizes = tuple( + int(size) + for size in getattr(module.qkv_lora.B_T, "lora_tp_component_sizes") + ) + z_world_size = _shard_world_size_for_domain(module.z_lora.B_T.lora_shard_domain) + tp_world_size = _safe_ps_stat("get_tensor_model_parallel_world_size", 1) + return ( + *qkv_sizes, + int(module.z_lora.B_T.shape[-1]) * z_world_size, + int(module.num_value_heads_per_partition) * tp_world_size, + int(module.num_value_heads_per_partition) * tp_world_size, + ) + @staticmethod def _sequence_parallel_enabled(module: Any) -> bool: """Returns sequence-parallel flag from module/provider/config when present.""" @@ -289,7 +363,21 @@ def _lora_primary_output_merge_hint(module: Any) -> dict[str, Any] | None: if bool(getattr(b_param, "lora_tp_sharded", False)) and b_world_size > 1: shard_dim = getattr(b_param, "lora_tp_shard_dim", None) if isinstance(shard_dim, int): - return {"op": "concat", "dim": shard_dim} + hint: dict[str, Any] = {"op": "concat", "dim": shard_dim} + component_sizes = tuple( + int(size) + for size in getattr(b_param, "lora_tp_component_sizes", ()) + ) + world_size_key = _world_size_key_for_domain(b_domain) + if component_sizes and world_size_key is not None: + 
hint.update( + { + "layout": "componentwise", + "component_sizes": component_sizes, + "world_size_key": world_size_key, + } + ) + return hint a_param = getattr(lora_module, "A_T", None) if a_param is None: return None @@ -310,6 +398,27 @@ def _infer_primary_output_merge_hint( if lora_hint is not None: return lora_hint + trace_metadata = self._trace_metadata_by_name.get(name, {}) + component_sizes = trace_metadata.get("component_sizes") + if isinstance(component_sizes, tuple) and component_sizes: + return { + "op": "concat", + "dim": -1, + "layout": "componentwise", + "component_sizes": component_sizes, + "world_size_key": "tp_world_size", + } + + local_heads = trace_metadata.get("local_heads") + if isinstance(local_heads, int) and local_heads > 0: + return { + "op": "concat", + "dim": 0, + "layout": "rank_blocked_token_heads", + "local_heads": local_heads, + "world_size_key": "tp_world_size", + } + # Base MoE expert linears need expert-TP aware merge semantics. # With etp>1: # - FC1 (column-parallel) shards output features -> concat on feature dim. @@ -322,6 +431,7 @@ def _infer_primary_output_merge_hint( "op": "concat", "dim": -1, "layout": "gate_up_rank_interleaved", + "world_size_key": "etp_world_size", } return {"op": "concat", "dim": 0} if ".mlp.experts.linear_fc2" in name and ".lora" not in name: @@ -329,12 +439,42 @@ def _infer_primary_output_merge_hint( return {"op": "sum"} return {"op": "concat", "dim": 0} + if ".mlp.linear_fc1" in name and ".lora" not in name: + tp_world_size = _safe_ps_stat("get_tensor_model_parallel_world_size", 1) + if tp_world_size > 1: + return { + "op": "concat", + "dim": -1, + "layout": "gate_up_rank_interleaved", + "world_size_key": "tp_world_size", + } + return {"op": "concat", "dim": -1} + if ".mlp.linear_fc2.row_parallel_lora" in name and ".lora" not in name: + if self._sequence_parallel_enabled(module): + return {"op": "concat", "dim": 0} + return None + if ".mlp.linear_fc2" in name and ".lora" not in name: + row_parallel_lora = getattr(module, "row_parallel_lora", None) + if row_parallel_lora is not None and self._sequence_parallel_enabled( + row_parallel_lora + ): + return {"op": "concat", "dim": 0} + return None + gather_output = getattr(module, "gather_output", None) if isinstance(gather_output, bool) and not gather_output: return {"op": "concat", "dim": -1} if ".self_attention.linear_qkv" in name: return {"op": "concat", "dim": -1} + if name.endswith(".self_attention.in_proj"): + return {"op": "concat", "dim": -1} + if name.endswith( + ".self_attention.out_proj" + ) and self._sequence_parallel_enabled(module): + return {"op": "concat", "dim": 0} + if name.endswith(".self_attention") and self._sequence_parallel_enabled(module): + return {"op": "concat", "dim": 0} if ".mlp.experts." 
in name: return {"op": "concat", "dim": 0} @@ -359,40 +499,46 @@ def _build_merge_hints(self, name: str, module: Any) -> dict[str, dict[str, Any] hints["router_topk_scores"] = concat_dim0 return hints + @torch._dynamo.disable + def _record_module_hook( + self, name: str, module: Any, inputs: Any, output: Any + ) -> None: + if self.current_step_index is None: + return + micro_call_index = self.current_micro_module_call_counts.get(name, 0) + self.current_micro_module_call_counts[name] = micro_call_index + 1 + trace_item: dict[str, Any] = { + "micro_call_index": micro_call_index, + "micro_order": self.current_micro_order, + "micro_sample_index": self.current_micro_sample_index, + "module_type": module.__class__.__name__, + "rank_meta": _rank_metadata(), + "merge_hints": self._build_merge_hints(name, module), + "inputs": _materialize_trace_value(inputs), + "output": _materialize_trace_value(output), + "primary_input": self.guess_primary_tensor(inputs), + "primary_output": self.guess_primary_tensor(output), + } + if ROUTER_NAME_TOKEN in name: + router_topk = _extract_router_topk(output) + if router_topk is not None: + topk_ids, topk_scores = router_topk + trace_item["router_topk_ids"] = topk_ids + trace_item["router_topk_scores"] = topk_scores + trace_items = self._split_expert_trace_items( + module_name=name, + module=module, + inputs=inputs, + trace_item=trace_item, + ) + trace_calls = self.current_step_trace.setdefault(name, []) + for split_item in trace_items: + split_item["call_index"] = len(trace_calls) + trace_calls.append(split_item) + def _make_hook(self, name: str, module: Any): def _hook(_module: Any, inputs: Any, output: Any) -> None: - if self.current_step_index is None: - return - micro_call_index = self.current_micro_module_call_counts.get(name, 0) - self.current_micro_module_call_counts[name] = micro_call_index + 1 - trace_item: dict[str, Any] = { - "micro_call_index": micro_call_index, - "micro_order": self.current_micro_order, - "micro_sample_index": self.current_micro_sample_index, - "module_type": module.__class__.__name__, - "rank_meta": _rank_metadata(), - "merge_hints": self._build_merge_hints(name, module), - "inputs": _materialize_trace_value(inputs), - "output": _materialize_trace_value(output), - "primary_input": self.guess_primary_tensor(inputs), - "primary_output": self.guess_primary_tensor(output), - } - if ROUTER_NAME_TOKEN in name: - router_topk = _extract_router_topk(output) - if router_topk is not None: - topk_ids, topk_scores = router_topk - trace_item["router_topk_ids"] = topk_ids - trace_item["router_topk_scores"] = topk_scores - trace_items = self._split_expert_trace_items( - module_name=name, - module=module, - inputs=inputs, - trace_item=trace_item, - ) - trace_calls = self.current_step_trace.setdefault(name, []) - for split_item in trace_items: - split_item["call_index"] = len(trace_calls) - trace_calls.append(split_item) + self._record_module_hook(name, module, inputs, output) return _hook @@ -408,6 +554,7 @@ def _sample_index_for_micro(self, micro_order: int) -> int | None: return self.current_step_sample_indices[micro_order] return None + @torch._dynamo.disable def _root_pre_hook(self, _module: Any, _args: Any) -> None: if self.current_step_index is None: return @@ -415,6 +562,7 @@ def _root_pre_hook(self, _module: Any, _args: Any) -> None: sample_index = self._sample_index_for_micro(micro_order) self.begin_micro(sample_index=sample_index, micro_order=micro_order) + @torch._dynamo.disable def _root_post_hook(self, _module: Any, _inputs: Any, output: 
Any) -> None: if self.current_step_index is None: return @@ -604,41 +752,163 @@ def _primary_output_merge_hint(call: dict[str, Any]) -> dict[str, Any] | None: return primary_hint @classmethod - def _canonicalize_etp_fc1_feature_layout( + def _canonicalize_gate_up_rank_interleaved_feature_layout( cls, *, module_name: str, tensor: torch.Tensor, call: dict[str, Any], ) -> torch.Tensor: - """Normalizes expert-TP fc1 feature order to a topology-independent layout.""" - if ".mlp.experts.linear_fc1" not in module_name or ".lora" in module_name: - return tensor - if tensor.ndim != 2: - return tensor + """Normalizes TP/ETP fused gate-up fc1 output feature order.""" + del module_name primary_hint = cls._primary_output_merge_hint(call) if not isinstance(primary_hint, dict): return tensor if primary_hint.get("layout") != "gate_up_rank_interleaved": return tensor + world_size_key = primary_hint.get("world_size_key") + if not isinstance(world_size_key, str): + raise RuntimeError("gate_up_rank_interleaved hint requires world_size_key") rank_meta = call.get("rank_meta") - etp_world_size = None + rank_world_size = None if isinstance(rank_meta, list) and rank_meta: first_meta = rank_meta[0] if isinstance(first_meta, dict): - etp_world_size = first_meta.get("etp_world_size") + rank_world_size = first_meta.get(world_size_key) elif isinstance(rank_meta, dict): - etp_world_size = rank_meta.get("etp_world_size") - if not isinstance(etp_world_size, int) or etp_world_size <= 1: - return tensor - block_count = 2 * etp_world_size - if tensor.shape[1] % block_count != 0: + rank_world_size = rank_meta.get(world_size_key) + if not isinstance(rank_world_size, int) or rank_world_size <= 1: return tensor - blocks = torch.chunk(tensor, block_count, dim=1) + block_count = 2 * rank_world_size + if tensor.ndim < 1 or tensor.shape[-1] % block_count != 0: + raise RuntimeError( + "gate_up_rank_interleaved tensor feature size must divide by " + f"{block_count}, got shape={tuple(tensor.shape)}" + ) + blocks = torch.chunk(tensor, block_count, dim=-1) reordered = [blocks[index] for index in range(0, block_count, 2)] + [ blocks[index] for index in range(1, block_count, 2) ] - return torch.cat(reordered, dim=1).contiguous() + return torch.cat(reordered, dim=-1).contiguous() + + @classmethod + def _canonicalize_componentwise_feature_layout( + cls, + *, + module_name: str, + tensor: torch.Tensor, + call: dict[str, Any], + ) -> torch.Tensor: + """Normalizes fused componentwise TP output order, e.g. 
GDN q/k/v.""" + del module_name + primary_hint = cls._primary_output_merge_hint(call) + if not isinstance(primary_hint, dict): + return tensor + if primary_hint.get("layout") != "componentwise": + return tensor + dim = primary_hint.get("dim") + component_sizes = primary_hint.get("component_sizes") + world_size_key = primary_hint.get("world_size_key") + if not isinstance(dim, int) or not isinstance(world_size_key, str): + raise RuntimeError("componentwise hint requires dim and world_size_key") + if not isinstance(component_sizes, tuple) or not all( + isinstance(size, int) and size > 0 for size in component_sizes + ): + raise RuntimeError("componentwise hint requires positive component sizes") + rank_meta = call.get("rank_meta") + rank_world_size = None + if isinstance(rank_meta, list) and rank_meta: + first_meta = rank_meta[0] + if isinstance(first_meta, dict): + rank_world_size = first_meta.get(world_size_key) + elif isinstance(rank_meta, dict): + rank_world_size = rank_meta.get(world_size_key) + if not isinstance(rank_world_size, int) or rank_world_size <= 1: + return tensor + axis = dim if dim >= 0 else tensor.ndim + dim + if axis < 0 or axis >= tensor.ndim: + raise RuntimeError( + f"Invalid componentwise axis {dim} for {tensor.ndim}D tensor" + ) + if sum(component_sizes) != tensor.shape[axis]: + raise RuntimeError( + "componentwise component sizes must match tensor extent, got " + f"sizes={component_sizes} shape={tuple(tensor.shape)} axis={axis}" + ) + if any(size % rank_world_size != 0 for size in component_sizes): + raise RuntimeError( + "componentwise component sizes must divide rank world size, got " + f"sizes={component_sizes} world_size={rank_world_size}" + ) + local_sizes = [size // rank_world_size for size in component_sizes] + rank_chunks: list[list[torch.Tensor]] = [] + cursor = 0 + for _rank in range(rank_world_size): + rank_components = [] + for local_size in local_sizes: + rank_components.append(tensor.narrow(axis, cursor, local_size)) + cursor += local_size + rank_chunks.append(rank_components) + ordered = [ + rank_chunks[rank][component_index] + for component_index in range(len(component_sizes)) + for rank in range(rank_world_size) + ] + return torch.cat(ordered, dim=axis).contiguous() + + @classmethod + def _canonicalize_rank_blocked_token_heads( + cls, + *, + module_name: str, + tensor: torch.Tensor, + call: dict[str, Any], + ) -> torch.Tensor: + del module_name + primary_hint = cls._primary_output_merge_hint(call) + if not isinstance(primary_hint, dict): + return tensor + if primary_hint.get("layout") != "rank_blocked_token_heads": + return tensor + local_heads = primary_hint.get("local_heads") + world_size_key = primary_hint.get("world_size_key") + if not isinstance(local_heads, int) or local_heads <= 0: + raise RuntimeError("rank_blocked_token_heads hint requires local_heads") + if not isinstance(world_size_key, str): + raise RuntimeError("rank_blocked_token_heads hint requires world_size_key") + rank_meta = call.get("rank_meta") + rank_world_size = None + if isinstance(rank_meta, list) and rank_meta: + first_meta = rank_meta[0] + if isinstance(first_meta, dict): + rank_world_size = first_meta.get(world_size_key) + elif isinstance(rank_meta, dict): + rank_world_size = rank_meta.get(world_size_key) + if not isinstance(rank_world_size, int) or rank_world_size <= 1: + return tensor + if tensor.ndim != 2: + raise RuntimeError( + "rank_blocked_token_heads expects a 2D [rows, head_dim] tensor, " + f"got shape={tuple(tensor.shape)}" + ) + rows_per_rank, remainder = 
divmod(int(tensor.shape[0]), rank_world_size) + if remainder != 0: + raise RuntimeError( + "rank_blocked_token_heads rows must divide rank world size, got " + f"shape={tuple(tensor.shape)} world_size={rank_world_size}" + ) + token_count, head_remainder = divmod(rows_per_rank, local_heads) + if head_remainder != 0: + raise RuntimeError( + "rank_blocked_token_heads rows per rank must divide local_heads, got " + f"rows_per_rank={rows_per_rank} local_heads={local_heads}" + ) + return ( + tensor.reshape(rank_world_size, token_count, local_heads, tensor.shape[-1]) + .permute(1, 0, 2, 3) + .reshape(tensor.shape) + .contiguous() + ) @classmethod def _canonicalize_moe_expert_row_order( @@ -675,7 +945,17 @@ def _canonicalize_primary_output_tensor( call: dict[str, Any], ) -> torch.Tensor: """Runs all remaining primary-output canonicalization passes for one call.""" - tensor = cls._canonicalize_etp_fc1_feature_layout( + tensor = cls._canonicalize_gate_up_rank_interleaved_feature_layout( + module_name=module_name, + tensor=tensor, + call=call, + ) + tensor = cls._canonicalize_componentwise_feature_layout( + module_name=module_name, + tensor=tensor, + call=call, + ) + tensor = cls._canonicalize_rank_blocked_token_heads( module_name=module_name, tensor=tensor, call=call, @@ -917,7 +1197,9 @@ def _gather_rank_traces( return cast(list[dict[str, list[dict[str, Any]]]], gathered) @staticmethod - def _merge_group_tensor(tensors: list[torch.Tensor]) -> torch.Tensor: + def _merge_group_tensor( + tensors: list[torch.Tensor], *, strict: bool = True + ) -> torch.Tensor: if len(tensors) == 1: return tensors[0] first = tensors[0] @@ -925,6 +1207,8 @@ def _merge_group_tensor(tensors: list[torch.Tensor]) -> torch.Tensor: torch.equal(first, tensor) for tensor in tensors[1:] ): return first + if not strict: + return first raise RuntimeError( "Mismatched output captures for the same micro output across non-DP ranks" ) @@ -965,7 +1249,10 @@ def ordered_step_outputs(self) -> list[torch.Tensor] | None: key=lambda item: _captured_output_sort_key(item[0], item[2], item[1]), ) return [ - self._merge_group_tensor(grouped[group_key]) + self._merge_group_tensor( + grouped[group_key], + strict=self.strict_output_match, + ) for group_key in ordered_group_keys ] diff --git a/tests/integration/megatron/model_support/hf_parity.py b/tests/integration/megatron/model_support/hf_parity.py new file mode 100644 index 000000000..cdb99d92f --- /dev/null +++ b/tests/integration/megatron/model_support/hf_parity.py @@ -0,0 +1,374 @@ +from __future__ import annotations + +import os +from pathlib import Path +import subprocess +import sys +from typing import Any + +from pydantic import BaseModel, Field + +from art.megatron.model_support.spec import MinimalLayerCoverageReport + +from .oracle_harness import ( + NON_FINITE_METRIC_VALUE, + ORACLE_TOPOLOGY, + DiffAccumulator, + DiskPackedTensorsSpec, + OracleCaseConfig, + PhasePassFn, + _default_phase_pass_fns, + _read_json, + _write_json, + ensure_case_artifacts, +) +from .oracle_worker import provider_topology_env +from .workflow import assess_minimal_layer_coverage + +HF_PARITY_ENABLE_ENV = "ART_RUN_HF_PARITY" +HF_PARITY_OUTPUT_DIRNAME = "hf_parity_sft" +HF_PARITY_REPORT_FILENAME = "report.json" + +REPO_ROOT = Path(__file__).resolve().parents[4] + + +class HfParityMetricRow(BaseModel): + phase: str + param: str + numel: float + mean_abs_diff: float + relative_l2: float + typical_abs_scale: float + candidate_abs_scale: float + mean_abs_pct: float + pass_signal: bool = True + failure_reasons: 
list[str] = Field(default_factory=list) + + +class HfParityRunRequest(BaseModel): + case_id: str + case_config: OracleCaseConfig + packed_tensors: DiskPackedTensorsSpec + output_dir: str + coverage: MinimalLayerCoverageReport + + +class HfParityReport(BaseModel): + case_id: str + base_model: str + model_key: str + requested_num_layers: int + coverage: MinimalLayerCoverageReport + signal: str + pass_count: int + fail_count: int + metrics: list[HfParityMetricRow] = Field(default_factory=list) + + +def _hf_parity_phase_pass_fns() -> dict[str, PhasePassFn]: + return _default_phase_pass_fns() + + +def hf_parity_enabled() -> bool: + value = os.environ.get(HF_PARITY_ENABLE_ENV) + if value is None: + return False + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def _inf_summary() -> dict[str, float]: + return { + "numel": 0.0, + "mean_abs_diff": NON_FINITE_METRIC_VALUE, + "relative_l2": NON_FINITE_METRIC_VALUE, + "typical_abs_scale": 0.0, + "candidate_abs_scale": 0.0, + "mean_abs_pct": NON_FINITE_METRIC_VALUE, + } + + +def _build_metric_row( + *, + phase: str, + param: str, + summary: dict[str, float], + structural_failure: str | None = None, +) -> HfParityMetricRow: + row = HfParityMetricRow( + phase=phase, + param=param, + numel=summary["numel"], + mean_abs_diff=summary["mean_abs_diff"], + relative_l2=summary["relative_l2"], + typical_abs_scale=summary["typical_abs_scale"], + candidate_abs_scale=summary["candidate_abs_scale"], + mean_abs_pct=summary["mean_abs_pct"], + ) + pass_fn = _hf_parity_phase_pass_fns().get(phase) + if pass_fn is None: + row.pass_signal = structural_failure is None + if structural_failure is not None: + row.failure_reasons = [structural_failure] + return row + row.pass_signal = bool(pass_fn(summary)) + explain = getattr(pass_fn, "failure_reasons", None) + if callable(explain) and not row.pass_signal: + row.failure_reasons = list(explain(summary)) + if structural_failure is not None: + row.pass_signal = False + row.failure_reasons = [structural_failure, *row.failure_reasons] + return row + + +def summarize_tensor_pair(reference: Any, candidate: Any) -> dict[str, float]: + if tuple(reference.shape) != tuple(candidate.shape): + return _inf_summary() + accumulator = DiffAccumulator() + accumulator.update(reference, candidate) + return accumulator.as_summary() + + +def build_tensor_map_metric_rows( + *, + phase: str, + reference: dict[str, Any], + candidate: dict[str, Any], +) -> list[HfParityMetricRow]: + reference_keys = set(reference.keys()) + candidate_keys = set(candidate.keys()) + if reference_keys != candidate_keys: + missing = sorted(reference_keys - candidate_keys) + extra = sorted(candidate_keys - reference_keys) + return [ + _build_metric_row( + phase=phase, + param="__tensor_set__", + summary=_inf_summary(), + structural_failure=f"missing={missing[:5]} extra={extra[:5]}", + ) + ] + rows: list[HfParityMetricRow] = [] + for key in sorted(reference_keys): + if tuple(reference[key].shape) != tuple(candidate[key].shape): + rows.append( + _build_metric_row( + phase=phase, + param=key, + summary=_inf_summary(), + structural_failure=f"shape mismatch for '{key}'", + ) + ) + continue + rows.append( + _build_metric_row( + phase=phase, + param=key, + summary=summarize_tensor_pair(reference[key], candidate[key]), + ) + ) + return rows + + +def build_parity_sample_indices( + *, + num_sequences: int, + global_grad_accumulation_sequences: int, +) -> list[int | None]: + return [ + index if index < num_sequences else None + for index in 
range(global_grad_accumulation_sequences) + ] + + +def _iter_hf_layer_config_views(config: Any) -> list[tuple[str, Any]]: + views: list[tuple[str, Any]] = [("", config)] + base_config_key = getattr(config, "base_config_key", None) + candidate_names = [ + name + for name in [ + base_config_key if isinstance(base_config_key, str) else None, + "text_config", + "language_config", + "llm_config", + "decoder_config", + ] + if isinstance(name, str) + ] + seen_ids = {id(config)} + for name in candidate_names: + nested = getattr(config, name, None) + if nested is None or id(nested) in seen_ids: + continue + seen_ids.add(id(nested)) + views.append((f"{name}.", nested)) + return views + + +def set_hf_config_num_layers(config: Any, num_layers: int) -> str: + for prefix, config_view in _iter_hf_layer_config_views(config): + for field in ("num_hidden_layers", "num_layers", "n_layer"): + if not hasattr(config_view, field): + continue + setattr(config_view, field, num_layers) + layer_types = getattr(config_view, "layer_types", None) + if isinstance(layer_types, (list, tuple)): + setattr(config_view, "layer_types", list(layer_types[:num_layers])) + mlp_only_layers = getattr(config_view, "mlp_only_layers", None) + if isinstance(mlp_only_layers, (list, tuple)): + setattr( + config_view, + "mlp_only_layers", + [layer for layer in mlp_only_layers if int(layer) < num_layers], + ) + return f"{prefix}{field}" + raise ValueError( + f"Could not find a supported layer-count field on HF config type {type(config)}" + ) + + +def zero_hf_dropout_config(config: Any) -> None: + for field in ( + "attention_dropout", + "hidden_dropout", + "dropout", + "embd_pdrop", + "resid_pdrop", + "attn_pdrop", + "classifier_dropout", + ): + if hasattr(config, field): + setattr(config, field, 0.0) + + +def assert_hf_parity_pass(report: HfParityReport, *, report_path: Path) -> None: + if report.signal == "pass": + return + first_failure = next(row for row in report.metrics if not row.pass_signal) + raise AssertionError( + f"HF parity failed: phase={first_failure.phase} param={first_failure.param} " + f"reasons={'; '.join(first_failure.failure_reasons)} report={report_path}" + ) + + +def run_hf_parity_subprocess(request: HfParityRunRequest, output_dir: Path) -> None: + request_path = output_dir / "run_request.json" + _write_json(request_path, request.model_dump(mode="json")) + worker_cwd = REPO_ROOT / "tests" + command = [ + sys.executable, + "-m", + "integration.megatron.model_support.hf_parity_worker", + "--run-request", + str(request_path), + ] + env = {**os.environ, "PYTHONUNBUFFERED": "1"} + run = subprocess.run( + command, + cwd=str(worker_cwd), + env=env, + capture_output=True, + text=True, + check=False, + ) + combined_output = f"{run.stdout}\n{run.stderr}".strip() + (output_dir / "worker.log").write_text(combined_output + "\n", encoding="utf-8") + if run.returncode != 0: + tail = "\n".join(combined_output.splitlines()[-80:]) + raise RuntimeError( + f"HF parity worker failed with exit code {run.returncode}.\n{tail}" + ) + + +def run_hf_parity( + *, + case_config: OracleCaseConfig, +) -> HfParityReport: + if case_config.precision != "fp32": + raise ValueError("HF parity currently requires fp32 precision") + if case_config.num_steps != 1: + raise ValueError("HF parity currently requires num_steps=1") + + coverage = assess_minimal_layer_coverage( + base_model=case_config.base_model, + num_layers=case_config.num_layers, + allow_unvalidated_arch=case_config.allow_unvalidated_arch, + ) + if not coverage.covered: + raise AssertionError( + 
"HF parity toy model does not cover required layer families: " + f"missing={coverage.missing_layer_families} " + f"risks={coverage.unresolved_risks}" + ) + + case_artifacts = ensure_case_artifacts(case_config) + output_dir = Path(case_artifacts.case_dir) / HF_PARITY_OUTPUT_DIRNAME + report_path = output_dir / HF_PARITY_REPORT_FILENAME + output_dir.mkdir(parents=True, exist_ok=True) + if report_path.exists(): + report_path.unlink() + request = HfParityRunRequest( + case_id=case_artifacts.case_id, + case_config=case_config, + packed_tensors=case_artifacts.packed_tensors, + output_dir=str(output_dir), + coverage=coverage, + ) + with provider_topology_env(ORACLE_TOPOLOGY): + run_hf_parity_subprocess(request, output_dir) + report = HfParityReport.model_validate(_read_json(report_path)) + assert_hf_parity_pass(report, report_path=report_path) + return report + + +def build_hf_parity_report( + *, + request: HfParityRunRequest, + outputs_summary: dict[str, float], + loss_summary: dict[str, float], + grads_rows: list[HfParityMetricRow], +) -> HfParityReport: + rows = [ + _build_metric_row( + phase="outputs", + param="trainable_token_losses", + summary=outputs_summary, + ), + _build_metric_row( + phase="losses", + param="loss", + summary=loss_summary, + ), + *grads_rows, + ] + pass_count = sum(1 for row in rows if row.pass_signal) + fail_count = len(rows) - pass_count + return HfParityReport( + case_id=request.case_id, + base_model=request.case_config.base_model, + model_key=request.coverage.model_key, + requested_num_layers=request.case_config.num_layers, + coverage=request.coverage, + signal="pass" if fail_count == 0 else "fail", + pass_count=pass_count, + fail_count=fail_count, + metrics=rows, + ) + + +__all__ = [ + "HF_PARITY_ENABLE_ENV", + "HF_PARITY_OUTPUT_DIRNAME", + "HF_PARITY_REPORT_FILENAME", + "HfParityMetricRow", + "HfParityReport", + "HfParityRunRequest", + "assert_hf_parity_pass", + "build_hf_parity_report", + "build_parity_sample_indices", + "build_tensor_map_metric_rows", + "hf_parity_enabled", + "run_hf_parity", + "set_hf_config_num_layers", + "summarize_tensor_pair", + "zero_hf_dropout_config", +] diff --git a/tests/integration/megatron/model_support/hf_parity_worker.py b/tests/integration/megatron/model_support/hf_parity_worker.py new file mode 100644 index 000000000..26e1fa1a4 --- /dev/null +++ b/tests/integration/megatron/model_support/hf_parity_worker.py @@ -0,0 +1,872 @@ +from __future__ import annotations + +import argparse +import faulthandler +import os +from pathlib import Path +import re +import sys +import time +from typing import Any, cast + +import torch +import torch.nn.functional as F + +from art.megatron import train as megatron_train +from art.megatron.model_support import get_model_support_handler +from art.megatron.routing_replay import ( + MoeRoutingReplayBundle, + RouterCallRoute, + StepRouterRoutes, + StepRoutes, +) +from art.megatron.routing_replay import ( + ParallelTopology as ReplayParallelTopology, +) +from art.megatron.weights.merged_weight_export import build_art_conversion_tasks +from art.preprocessing.pack import packed_tensors_from_dir + +from .hf_parity import ( + HF_PARITY_REPORT_FILENAME, + HfParityRunRequest, + build_hf_parity_report, + build_parity_sample_indices, + build_tensor_map_metric_rows, + set_hf_config_num_layers, + summarize_tensor_pair, + zero_hf_dropout_config, +) +from .oracle_harness import ORACLE_TOPOLOGY, _read_json, _write_json +from .oracle_worker import ( + _assert_runtime_configuration, + _build_optimizer_config, + 
    _configure_cuda_precision,
+    _configure_provider,
+    _set_deterministic_seed,
+)
+from .test_inputs import build_sft_trajectory_tensors_from_packed_tensors
+
+HF_PARITY_DEBUG_ENV = "ART_HF_PARITY_DEBUG"
+_DEBUG_START_TIME = time.perf_counter()
+_VISUAL_HF_PREFIXES = ("model.visual.", "visual.")
+_HF_MOE_ROUTER_NAME_PATTERN = re.compile(r"^model\.layers\.(?P<layer>\d+)\.mlp\.gate$")
+_REPLAY_ROUTER_LAYER_PATTERN = re.compile(
+    r"^chunk_\d+\.layer_(?P<layer>\d+)\.mlp\.router$"
+)
+_GATE_WEIGHT_PATTERN = re.compile(
+    r"^model(?:\.language_model)?\.layers\.(?P<layer>\d+)\.mlp\.gate\.weight$"
+)
+_EXPERT_WEIGHT_PATTERN = re.compile(
+    r"^model(?:\.language_model)?\.layers\.(?P<layer>\d+)\.mlp\.experts\."
+    r"(?P<expert>\d+)\.(?:down_proj|gate_proj|up_proj)\.weight$"
+)
+
+
+def _hf_moe_router_key(module_name: str) -> str | None:
+    match = _HF_MOE_ROUTER_NAME_PATTERN.match(module_name)
+    if match is None:
+        return None
+    return f"chunk_00.layer_{int(match.group('layer')):04d}.mlp.router"
+
+
+class _HfMoeRoutingCapture:
+    def __init__(self, model: Any) -> None:
+        self._handles: list[Any] = []
+        self._routes: dict[str, dict[int, RouterCallRoute]] = {}
+        self._active_sample_index: int | None = None
+        self._active_micro_slot = 0
+        for module_name, module in model.named_modules():
+            router_key = _hf_moe_router_key(module_name)
+            if router_key is None:
+                continue
+            self._routes[router_key] = {}
+            self._handles.append(
+                module.register_forward_hook(self._make_hook(router_key, module))
+            )
+
+    @property
+    def enabled(self) -> bool:
+        return bool(self._handles)
+
+    def set_active_micro(self, sample_index: int | None, micro_slot: int) -> None:
+        self._active_sample_index = sample_index
+        self._active_micro_slot = micro_slot
+
+    def close(self) -> None:
+        for handle in self._handles:
+            handle.remove()
+        self._handles.clear()
+
+    def build_replay_bundle(
+        self,
+        *,
+        topology: ReplayParallelTopology,
+    ) -> MoeRoutingReplayBundle | None:
+        if not self.enabled:
+            return None
+        routers: dict[str, StepRouterRoutes] = {}
+        max_topk = 0
+        num_global_tokens: int | None = None
+        for router_key in sorted(self._routes):
+            calls = self._routes[router_key]
+            if not calls:
+                raise RuntimeError(f"HF parity captured no routes for '{router_key}'")
+            routers[router_key] = StepRouterRoutes(calls=calls)
+            for route in calls.values():
+                max_topk = max(max_topk, route.max_topk)
+                if num_global_tokens is None:
+                    num_global_tokens = route.num_global_tokens
+                elif num_global_tokens != route.num_global_tokens:
+                    raise RuntimeError(
+                        "HF parity routing capture token count mismatch: "
+                        f"expected={num_global_tokens}, got={route.num_global_tokens}, "
+                        f"router='{router_key}'"
+                    )
+        if num_global_tokens is None:
+            raise RuntimeError("HF parity routing capture produced no route tokens")
+        return MoeRoutingReplayBundle(
+            topology=topology,
+            num_steps=1,
+            max_topk=max_topk,
+            router_keys=sorted(routers),
+            steps={
+                0: StepRoutes(
+                    routers=routers,
+                    global_token_uids=torch.arange(
+                        num_global_tokens, dtype=torch.int64
+                    ),
+                )
+            },
+        )
+
+    def _make_hook(self, router_key: str, module: Any) -> Any:
+        def _hook(_module: Any, _inputs: Any, output: Any) -> None:
+            if not isinstance(output, tuple) or len(output) < 3:
+                raise RuntimeError(
+                    f"Expected HF router tuple output for '{router_key}', got {type(output)}"
+                )
+            router_scores = output[1]
+            router_indices = output[2]
+            if not isinstance(router_scores, torch.Tensor) or not isinstance(
+                router_indices, torch.Tensor
+            ):
+                raise RuntimeError(
+                    f"Expected tensor router outputs for '{router_key}', "
+                    f"got 
scores={type(router_scores)} indices={type(router_indices)}" + ) + route = RouterCallRoute( + expert_indices=router_indices.detach().cpu().to(torch.int32), + expert_probs=router_scores.detach().cpu().to(torch.float32), + expert_mask=torch.ones_like( + router_indices.detach().cpu(), dtype=torch.bool + ), + num_experts=int( + getattr(module, "num_experts", router_scores.shape[-1]) + ), + sample_index=self._active_sample_index, + micro_slot=( + None + if self._active_sample_index is not None + else self._active_micro_slot + ), + ) + self._routes[router_key][len(self._routes[router_key])] = route + + return _hook + + +def _debug(message: str) -> None: + if os.environ.get(HF_PARITY_DEBUG_ENV, "").strip().lower() not in { + "1", + "true", + "yes", + "on", + }: + return + elapsed = time.perf_counter() - _DEBUG_START_TIME + print(f"[hf_parity +{elapsed:8.2f}s] {message}", flush=True) + + +def _enable_debug_traceback_dump() -> None: + if os.environ.get(HF_PARITY_DEBUG_ENV, "").strip().lower() not in { + "1", + "true", + "yes", + "on", + }: + return + faulthandler.enable() + faulthandler.dump_traceback_later(60, repeat=True) + + +def _debug_enabled() -> bool: + return os.environ.get(HF_PARITY_DEBUG_ENV, "").strip().lower() in { + "1", + "true", + "yes", + "on", + } + + +def _install_bridge_timing_debug(provider_bundle: Any) -> None: + if not _debug_enabled(): + return + provider = provider_bundle.provider + pre_wrap_hooks = list(getattr(provider, "_pre_wrap_hooks", [])) + _debug( + "registered pre-wrap hooks: " + + ", ".join( + getattr(hook, "__qualname__", repr(hook)) for hook in pre_wrap_hooks + ) + ) + timed_hooks = [] + for index, hook in enumerate(pre_wrap_hooks): + label = f"pre_wrap_hook[{index}]" + + def _timed_hook( + model: list[Any], _hook: Any = hook, _label: str = label + ) -> list[Any]: + start = time.perf_counter() + _debug(f"{_label}: start") + try: + return _hook(model) + finally: + _debug(f"{_label}: done in {time.perf_counter() - start:.2f}s") + + timed_hooks.append(_timed_hook) + if pre_wrap_hooks: + provider._pre_wrap_hooks = timed_hooks + + model_bridge = getattr(provider_bundle.bridge, "_model_bridge", None) + if model_bridge is None: + return + if getattr(model_bridge, "_art_hf_parity_timing_wrapped", False): + return + original = model_bridge.load_weights_hf_to_megatron + + def _timed_load_weights(*args: Any, **kwargs: Any) -> Any: + start = time.perf_counter() + _debug("bridge.load_weights_hf_to_megatron: start") + try: + return original(*args, **kwargs) + finally: + _debug( + "bridge.load_weights_hf_to_megatron: done in " + f"{time.perf_counter() - start:.2f}s" + ) + + model_bridge.load_weights_hf_to_megatron = _timed_load_weights + model_bridge._art_hf_parity_timing_wrapped = True + + +def _load_hf_model( + *, + base_model: str, + num_layers: int, + device: torch.device, +) -> Any: + from transformers import AutoConfig, AutoModelForCausalLM + + config = AutoConfig.from_pretrained(base_model, trust_remote_code=True) + set_hf_config_num_layers(config, num_layers) + zero_hf_dropout_config(config) + model = AutoModelForCausalLM.from_pretrained( + base_model, + config=config, + trust_remote_code=True, + torch_dtype=torch.float32, + low_cpu_mem_usage=True, + ) + model.train() + return cast(Any, model).to(device) + + +def _collect_hf_grads(model: Any) -> dict[str, torch.Tensor]: + grads: dict[str, torch.Tensor] = {} + for name, param in model.named_parameters(): + grad = param.grad + if grad is None: + grad = torch.zeros_like(param) + grads[name] = 
grad.detach().cpu().to(dtype=torch.float32) + return grads + + +def _bridge_compatible_hf_key(key: str, expected_keys: set[str]) -> str: + if key in expected_keys: + return key + if key.startswith("model."): + prefixed = f"model.language_model.{key.removeprefix('model.')}" + if prefixed in expected_keys: + return prefixed + if key.startswith("model.language_model."): + stripped = f"model.{key.removeprefix('model.language_model.')}" + if stripped in expected_keys: + return stripped + return key + + +def _normalize_hf_tensor_map_for_bridge( + hf_map: dict[str, torch.Tensor], + expected_keys: set[str], +) -> dict[str, torch.Tensor]: + normalized: dict[str, torch.Tensor] = {} + for key, value in hf_map.items(): + normalized_key = _bridge_compatible_hf_key(key, expected_keys) + if normalized_key in normalized: + raise RuntimeError( + f"Duplicate normalized HF key '{normalized_key}' from '{key}'" + ) + normalized[normalized_key] = value + return normalized + + +def _active_embedding_token_rows( + micro_inputs: list[dict[str, torch.Tensor]], +) -> torch.Tensor: + active_token_ids: list[torch.Tensor] = [] + for micro in micro_inputs: + attention_mask = micro["attention_mask"].reshape(-1).to(dtype=torch.bool) + if not bool(attention_mask.any()): + continue + active_token_ids.append(micro["input_ids"].reshape(-1)[attention_mask].cpu()) + if not active_token_ids: + return torch.zeros((0,), dtype=torch.long) + return torch.unique(torch.cat(active_token_ids, dim=0), sorted=True) + + +def _active_router_rows_by_layer( + replay_bundle: MoeRoutingReplayBundle | None, +) -> dict[int, torch.Tensor]: + if replay_bundle is None: + return {} + active_rows: dict[int, torch.Tensor] = {} + step_routes = replay_bundle.steps.get(0) + if step_routes is None: + return {} + for router_key, router_routes in step_routes.routers.items(): + match = _REPLAY_ROUTER_LAYER_PATTERN.match(router_key) + if match is None: + continue + layer_index = int(match.group("layer")) + layer_rows: list[torch.Tensor] = [] + for route in router_routes.calls.values(): + if route.expert_indices.numel() == 0: + continue + layer_rows.append(route.expert_indices[route.expert_mask].to(torch.long)) + if layer_rows: + active_rows[layer_index] = torch.unique( + torch.cat(layer_rows, dim=0), + sorted=True, + ) + return active_rows + + +def _loss_active_last_layer_experts( + replay_bundle: MoeRoutingReplayBundle | None, + micro_inputs: list[dict[str, torch.Tensor]], + sample_indices: list[int | None], + *, + layer_index: int, +) -> set[int]: + if replay_bundle is None: + return set() + experts: set[int] = set() + step_routes = replay_bundle.steps.get(0) + if step_routes is None: + return experts + for router_key, router_routes in step_routes.routers.items(): + match = _REPLAY_ROUTER_LAYER_PATTERN.match(router_key) + if match is None or int(match.group("layer")) != layer_index: + continue + for route in router_routes.calls.values(): + micro_index = ( + sample_indices.index(route.sample_index) + if route.sample_index is not None + else route.micro_slot + ) + if micro_index is None: + continue + micro = micro_inputs[micro_index] + actual_len = max(int(micro["attention_mask"].reshape(-1).sum().item()), 1) + shifted_labels = megatron_train.shift_tensor( + micro["labels"].reshape(-1)[:actual_len].unsqueeze(0), -100 + ).reshape(-1) + loss_mask = (shifted_labels != -100).cpu() + selected = route.expert_indices[loss_mask][route.expert_mask[loss_mask]] + experts.update(int(expert) for expert in selected.reshape(-1).tolist()) + return experts + + +def 
_focus_derivative_tensor_map( + tensor_map: dict[str, torch.Tensor], + *, + active_embedding_rows: torch.Tensor, + active_router_rows: dict[int, torch.Tensor], + last_layer_index: int, + loss_active_last_layer_experts: set[int], +) -> dict[str, torch.Tensor]: + focused: dict[str, torch.Tensor] = {} + for key, value in tensor_map.items(): + if match := _EXPERT_WEIGHT_PATTERN.match(key): + if ( + int(match.group("layer")) == last_layer_index + and int(match.group("expert")) not in loss_active_last_layer_experts + ): + continue + focused_value = value + if ( + key == "model.language_model.embed_tokens.weight" + and active_embedding_rows.numel() > 0 + ): + focused_value = value.index_select(0, active_embedding_rows) + elif match := _GATE_WEIGHT_PATTERN.match(key): + active_rows = active_router_rows.get(int(match.group("layer"))) + if active_rows is not None and active_rows.numel() > 0: + focused_value = value.index_select(0, active_rows) + focused[key] = focused_value + return focused + + +def _run_hf_sft_step( + *, + base_model: str, + num_layers: int, + micro_inputs: list[dict[str, torch.Tensor]], + sample_indices: list[int | None], + topology: ReplayParallelTopology, + device: torch.device, +) -> tuple[ + torch.Tensor, + torch.Tensor, + dict[str, torch.Tensor], + MoeRoutingReplayBundle | None, +]: + _debug("loading HF model") + model = _load_hf_model(base_model=base_model, num_layers=num_layers, device=device) + route_capture = _HfMoeRoutingCapture(model) + _debug("running HF forward/backward") + model.zero_grad(set_to_none=True) + loss_sum = torch.tensor(0.0, device=device) + token_count = 0 + trainable_losses: list[torch.Tensor] = [] + total_token_count = max( + sum( + int(megatron_train._count_sft_trainable_tokens(micro)) + for micro in micro_inputs + ), + 1, + ) + for micro_slot, (micro, sample_index) in enumerate( + zip(micro_inputs, sample_indices, strict=True) + ): + route_capture.set_active_micro(sample_index, micro_slot) + attention_mask = micro["attention_mask"].reshape(-1) + actual_len = max(int(attention_mask.sum().item()), 1) + input_ids = micro["input_ids"].reshape(-1)[:actual_len].unsqueeze(0).to(device) + labels = micro["labels"].reshape(-1)[:actual_len].unsqueeze(0).to(device) + hf_attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device) + logits = model( + input_ids=input_ids, + attention_mask=hf_attention_mask, + use_cache=False, + ).logits + shifted_labels = megatron_train.shift_tensor(labels, -100) + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.shape[-1]), + shifted_labels.reshape(-1), + reduction="none", + ignore_index=-100, + ).reshape(shifted_labels.shape) + mask = shifted_labels != -100 + masked_losses = per_token_loss[mask] + trainable_losses.append(masked_losses.detach().cpu()) + loss_sum = loss_sum + masked_losses.sum() + token_count += int(mask.sum().item()) + (masked_losses.sum() / total_token_count).backward() + grads = _collect_hf_grads(model) + routing_replay_bundle = route_capture.build_replay_bundle(topology=topology) + scalar_loss = (loss_sum / max(token_count, 1)).detach().cpu().reshape(1) + output_vector = torch.cat(trainable_losses, dim=0).to(dtype=torch.float32) + route_capture.close() + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + _debug("finished HF step") + return output_vector, scalar_loss, grads, routing_replay_bundle + + +def _build_megatron_runtime( + request: HfParityRunRequest, +) -> megatron_train.TrainingRuntime: + return megatron_train.build_training_runtime( + 
model_identifier=request.case_config.base_model, + provider_torch_dtype=torch.float32, + provider_bundle_configure=_install_bridge_timing_debug, + provider_configure=lambda provider: _configure_provider( + provider, ORACLE_TOPOLOGY, request.case_config + ), + optimizer_config=_build_optimizer_config(request.case_config), + print_env=False, + trainable_parameter_mode="base_model", + allow_unvalidated_arch=request.case_config.allow_unvalidated_arch, + ) + + +def _megatron_task_tensor( + task: Any, + *, + mode: str, +) -> torch.Tensor: + param = cast(torch.nn.Parameter, task.param_weight) + if mode == "grad": + grad = param.grad + if grad is None: + grad = getattr(param, "main_grad", None) + if grad is None: + grad = torch.zeros_like(param) + if hasattr(grad, "_local_tensor"): + grad = cast(torch.Tensor, grad._local_tensor) + return cast(torch.Tensor, grad) + if mode == "param": + return param.detach() + raise ValueError(f"Unsupported task-tensor mode: {mode}") + + +def _mapping_supports_derivative_parity(mapping: Any) -> bool: + from megatron.bridge.models.conversion.param_mapping import ( + RMSNorm2ZeroCenteredRMSNormMapping, + ) + + return not isinstance(mapping, RMSNorm2ZeroCenteredRMSNormMapping) + + +def _is_language_hf_param_name(name: str) -> bool: + return not name.startswith(_VISUAL_HF_PREFIXES) + + +def _language_hf_param_names(mapping: Any) -> list[str]: + hf_param = mapping.hf_param + if isinstance(hf_param, str): + return [hf_param] + if isinstance(hf_param, dict): + return [value for value in hf_param.values() if isinstance(value, str)] + return [] + + +def _mapping_targets_language_only(mapping: Any) -> bool: + names = _language_hf_param_names(mapping) + if not names: + return True + return all(_is_language_hf_param_name(name) for name in names) + + +def _filter_language_only_tensor_map( + tensor_map: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + return { + key: value + for key, value in tensor_map.items() + if _is_language_hf_param_name(key) + } + + +def _convert_megatron_tasks_to_hf( + runtime: megatron_train.TrainingRuntime, + *, + mode: str, + tasks: list[Any] | None = None, +) -> dict[str, torch.Tensor]: + if tasks is None: + tasks = [ + task + for task in build_art_conversion_tasks( + bridge=runtime.bridge, + model=runtime.model, + ) + if isinstance(task.param_weight, torch.nn.Parameter) + ] + model_bridge = runtime.bridge._model_bridge + hf_state_dict = runtime.bridge.hf_pretrained.state + grouped_buffers: dict[str, dict[int, torch.Tensor]] = {} + converted: dict[str, torch.Tensor] = {} + for task in tasks: + tensor = _megatron_task_tensor(task, mode=mode) + converted_weights_dict = task.mapping.megatron_to_hf( + tensor, + task.megatron_module, + ) + if getattr(task.mapping, "is_grouped_export", False): + merged_result = model_bridge._accumulate_grouped_export( + task, + converted_weights_dict, + runtime.model[0].config, + grouped_buffers, + hf_state_dict, + ) + if merged_result is None: + continue + converted_weights_dict = merged_result + else: + converted_weights_dict = model_bridge.maybe_modify_converted_hf_weight( + task, + converted_weights_dict, + hf_state_dict, + ) + for hf_name, value in converted_weights_dict.items(): + if not _is_language_hf_param_name(hf_name): + continue + if hf_name in converted: + raise RuntimeError(f"Duplicate converted HF key '{hf_name}' in {mode}") + converted[hf_name] = value.detach().cpu().to(dtype=torch.float32) + return converted + + +def _run_megatron_sft_step( + *, + request: HfParityRunRequest, + micro_inputs: 
list[dict[str, torch.Tensor]], + sample_indices: list[int | None], + device: torch.device, + moe_routing_replay_bundle: MoeRoutingReplayBundle | None = None, +) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]: + runtime = _build_megatron_runtime(request) + _assert_runtime_configuration(runtime.model, request.case_config) + assert runtime.optimizer is not None + if moe_routing_replay_bundle is not None: + megatron_train.configure_moe_routing_replay( + runtime, + replay_bundle=moe_routing_replay_bundle, + strict=True, + ) + controller = runtime.moe_routing_replay_controller + if controller is None: + raise RuntimeError( + "Expected MoE routing replay controller to be configured" + ) + controller.set_step( + step_index=0, + sample_index=sample_indices, + global_grad_accumulation_sequences=request.case_config.grad_accumulation_sequences, + ) + _debug("initializing Megatron optimizer state") + megatron_train._eager_initialize_optimizer_state(runtime.optimizer) + tasks = [ + task + for task in build_art_conversion_tasks( + bridge=runtime.bridge, + model=runtime.model, + ) + if isinstance(task.param_weight, torch.nn.Parameter) + ] + _debug(f"built {len(tasks)} Megatron conversion tasks") + for chunk in runtime.model: + if hasattr(chunk, "zero_grad_buffer"): + chunk.zero_grad_buffer() # ty: ignore[call-non-callable] + for param in chunk.parameters(): + param.grad = None + loss_sum = torch.tensor(0.0, device=device) + token_count = 0 + trainable_losses: list[torch.Tensor] = [] + for micro_order, micro in enumerate(micro_inputs): + if runtime.moe_routing_replay_controller is not None: + runtime.moe_routing_replay_controller.begin_micro( + sample_indices[micro_order], + micro_order, + ) + input_ids, position_ids, shifted_labels, mask, seq_len = ( + megatron_train._prepare_sft_micro_inputs(micro, device) + ) + attention_mask = megatron_train._placeholder_attention_mask(device) + forward_kwargs = runtime.model_support_handler.get_forward_kwargs( + runtime.model[0], + attention_bias=megatron_train._causal_attention_state(seq_len, device), + ) + per_token_loss = runtime.model[0]( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=shifted_labels, + **forward_kwargs, + ) + masked_losses = per_token_loss[mask] + trainable_losses.append(masked_losses.detach().cpu()) + loss_sum = loss_sum + masked_losses.sum() + token_count += int(mask.sum().item()) + masked_losses.sum().backward() + _debug("finished Megatron forward/backward") + num_tokens = megatron_train._local_trainable_sft_token_count_tensor( + micro_inputs, + device=device, + ) + megatron_train._flush_param_grads_to_main_grads(runtime.model) + megatron_train.finalize_model_grads_extended( + megatron_train.as_megatron_api_chunks(runtime.model), + num_tokens=num_tokens, + ) + _debug("finalized Megatron grads") + derivative_tasks = [ + task + for task in tasks + if _mapping_supports_derivative_parity(task.mapping) + and _mapping_targets_language_only(task.mapping) + ] + _debug(f"retained {len(derivative_tasks)} derivative-safe conversion tasks") + grads = _convert_megatron_tasks_to_hf( + runtime, + mode="grad", + tasks=derivative_tasks, + ) + _debug("exported Megatron grads") + if runtime.moe_routing_replay_controller is not None: + runtime.moe_routing_replay_controller.finalize_step() + scalar_loss = (loss_sum / max(token_count, 1)).detach().cpu().reshape(1) + output_vector = torch.cat(trainable_losses, dim=0).to(dtype=torch.float32) + _debug("finished Megatron step") + return output_vector, 
scalar_loss, grads + + +def _normalize_hf_grads_for_bridge( + hf_grads: dict[str, torch.Tensor], + *, + expected_grad_keys: set[str], + model_support_handler: Any, +) -> dict[str, torch.Tensor]: + hf_grads = _filter_language_only_tensor_map(hf_grads) + hf_grads = model_support_handler.hf_tensor_map_to_art_canonical( + hf_grads, + expected_keys=expected_grad_keys, + ) + normalized_hf_grads = _normalize_hf_tensor_map_for_bridge( + hf_grads, + expected_grad_keys, + ) + return { + key: normalized_hf_grads[key] + for key in sorted(expected_grad_keys) + if key in normalized_hf_grads + } + + +def _worker_run(request: HfParityRunRequest) -> None: + if not torch.cuda.is_available(): + raise RuntimeError("HF parity requires at least one CUDA device") + torch.cuda.set_device(0) + _set_deterministic_seed(request.case_config.seed) + _configure_cuda_precision(request.case_config) + _enable_debug_traceback_dump() + + packed_tensors = packed_tensors_from_dir( + **request.packed_tensors.model_dump(exclude_none=True) + ) + trajectory_tensors = build_sft_trajectory_tensors_from_packed_tensors( + packed_tensors + ) + zero_template = megatron_train._zero_contribution_sft_inputs(trajectory_tensors[0]) + sample_indices = build_parity_sample_indices( + num_sequences=len(trajectory_tensors), + global_grad_accumulation_sequences=request.case_config.grad_accumulation_sequences, + ) + micro_inputs = megatron_train.select_sft_micro_inputs( + trajectory_tensors, + sample_indices, + zero_template, + ) + replay_topology = ReplayParallelTopology.model_validate( + ORACLE_TOPOLOGY.model_dump( + include={"tp", "ep", "etp", "dp", "sp", "cp", "pp", "vpp"}, + mode="python", + ) + ) + device = torch.device("cuda", 0) + try: + _debug("starting HF parity worker") + model_support_handler = get_model_support_handler( + request.case_config.base_model, + allow_unvalidated_arch=request.case_config.allow_unvalidated_arch, + ) + hf_outputs, hf_loss, hf_grads, moe_routing_replay_bundle = _run_hf_sft_step( + base_model=request.case_config.base_model, + num_layers=request.case_config.num_layers, + micro_inputs=micro_inputs, + sample_indices=sample_indices, + topology=replay_topology, + device=device, + ) + megatron_outputs, megatron_loss, megatron_grads = _run_megatron_sft_step( + request=request, + micro_inputs=micro_inputs, + sample_indices=sample_indices, + device=device, + moe_routing_replay_bundle=moe_routing_replay_bundle, + ) + _debug("finished HF and Megatron steps, building report") + normalized_hf_grads = _normalize_hf_grads_for_bridge( + hf_grads, + expected_grad_keys=set(megatron_grads.keys()), + model_support_handler=model_support_handler, + ) + active_embedding_rows = _active_embedding_token_rows(micro_inputs) + active_router_rows = _active_router_rows_by_layer(moe_routing_replay_bundle) + last_layer_index = request.case_config.num_layers - 1 + loss_active_last_layer_experts = _loss_active_last_layer_experts( + moe_routing_replay_bundle, + micro_inputs, + sample_indices, + layer_index=last_layer_index, + ) + normalized_hf_grads = _focus_derivative_tensor_map( + normalized_hf_grads, + active_embedding_rows=active_embedding_rows, + active_router_rows=active_router_rows, + last_layer_index=last_layer_index, + loss_active_last_layer_experts=loss_active_last_layer_experts, + ) + megatron_grads = _focus_derivative_tensor_map( + megatron_grads, + active_embedding_rows=active_embedding_rows, + active_router_rows=active_router_rows, + last_layer_index=last_layer_index, + loss_active_last_layer_experts=loss_active_last_layer_experts, + 
) + outputs_summary = summarize_tensor_pair(hf_outputs, megatron_outputs) + loss_summary = summarize_tensor_pair(hf_loss, megatron_loss) + grads_rows = build_tensor_map_metric_rows( + phase="grads", + reference=normalized_hf_grads, + candidate=megatron_grads, + ) + report = build_hf_parity_report( + request=request, + outputs_summary=outputs_summary, + loss_summary=loss_summary, + grads_rows=grads_rows, + ) + _write_json( + Path(request.output_dir) / HF_PARITY_REPORT_FILENAME, + report.model_dump(mode="json"), + ) + _debug("wrote HF parity report") + finally: + if torch.distributed.is_initialized(): # ty: ignore[possibly-missing-attribute] + torch.distributed.destroy_process_group() # ty: ignore[possibly-missing-attribute] + + +def run_worker_cli(run_request_path: Path) -> None: + request = HfParityRunRequest.model_validate(_read_json(run_request_path)) + _worker_run(request) + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Megatron HF parity worker") + parser.add_argument("--run-request", type=Path, required=True) + return parser.parse_args(argv) + + +def _main(argv: list[str]) -> int: + args = _parse_args(argv) + run_worker_cli(args.run_request) + return 0 + + +if __name__ == "__main__": + raise SystemExit(_main(sys.argv[1:])) diff --git a/tests/integration/megatron/model_support/lora_coverage.py b/tests/integration/megatron/model_support/lora_coverage.py new file mode 100644 index 000000000..2cfb84ddb --- /dev/null +++ b/tests/integration/megatron/model_support/lora_coverage.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +import socket +from typing import Any + +from megatron.core import parallel_state as ps +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from pydantic import BaseModel, Field +import torch +from torch.distributed import ( + destroy_process_group, + init_process_group, + is_initialized, +) + +from art.megatron import train as megatron_train +from art.megatron.lora import LoRA + +from .oracle_harness import OracleCaseConfig, oracle_topology +from .oracle_worker import _configure_provider, provider_topology_env + +_WRAPPED_TARGET_SUFFIXES: dict[str, tuple[str, ...]] = { + "q_proj": (".self_attn.q_proj",), + "k_proj": (".self_attn.k_proj",), + "v_proj": (".self_attn.v_proj",), + "o_proj": (".self_attn.o_proj",), + "in_proj_qkv": (".linear_attn.in_proj_qkv",), + "in_proj_z": (".linear_attn.in_proj_z",), + "out_proj": (".linear_attn.out_proj",), + "gate_proj": (".gate_proj",), + "up_proj": (".up_proj",), + "down_proj": (".down_proj",), + "experts": ( + ".mlp.experts.{expert}.gate_up_proj", + ".mlp.experts.{expert}.down_proj", + ), +} + + +class LoraCoverageReport(BaseModel): + base_model: str + target_modules: list[str] + wrapped_target_modules: list[str] = Field(default_factory=list) + exported_target_modules: list[str] = Field(default_factory=list) + missing_wrapped_target_modules: list[str] = Field(default_factory=list) + missing_exported_target_modules: list[str] = Field(default_factory=list) + wrapped_adapter_prefix_count: int = 0 + export_base_count: int = 0 + export_adapter_count: int = 0 + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +@contextmanager +def _single_rank_model_parallel() -> Iterator[None]: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA 
is required for Megatron LoRA coverage.") + if is_initialized(): + raise RuntimeError("torch.distributed is already initialized in this process.") + torch.cuda.set_device(0) + init_process_group( + backend="nccl", + init_method=f"tcp://127.0.0.1:{_find_free_port()}", + rank=0, + world_size=1, + ) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + model_parallel_cuda_manual_seed(1234) + yield + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + if is_initialized(): + destroy_process_group() + + +def _covered_wrapped_target_modules(adapter_prefixes: set[str]) -> set[str]: + covered: set[str] = set() + for target_module, suffixes in _WRAPPED_TARGET_SUFFIXES.items(): + if any( + prefix.endswith(suffix) + for prefix in adapter_prefixes + for suffix in suffixes + ): + covered.add(target_module) + if target_module == "experts" and any( + ".mlp.experts." in prefix for prefix in adapter_prefixes + ): + covered.add(target_module) + return covered + + +def _covered_exported_target_modules( + adapter_weights_by_base: dict[str, list[Any]], +) -> set[str]: + covered: set[str] = set() + for base_name, adapter_weights in adapter_weights_by_base.items(): + if base_name.endswith(".self_attention.linear_qkv.weight"): + for adapter_weight in adapter_weights: + adapter_key = getattr(adapter_weight, "adapter_key", None) + if adapter_key == "adapter_q": + covered.add("q_proj") + elif adapter_key == "adapter_k": + covered.add("k_proj") + elif adapter_key == "adapter_v": + covered.add("v_proj") + continue + if base_name.endswith(".self_attention.linear_proj.weight"): + covered.add("o_proj") + continue + if base_name.endswith(".self_attention.in_proj.weight"): + covered.update({"in_proj_qkv", "in_proj_z"}) + continue + if base_name.endswith(".self_attention.out_proj.weight"): + covered.add("out_proj") + continue + if ".mlp.experts.linear_fc" in base_name: + covered.add("experts") + continue + if ".linear_fc1.weight" in base_name: + covered.update({"gate_proj", "up_proj"}) + continue + if ".linear_fc2.weight" in base_name: + covered.add("down_proj") + return covered + + +def run_lora_coverage(case_config: OracleCaseConfig) -> LoraCoverageReport: + topology = oracle_topology(is_moe=case_config.is_moe) + with _single_rank_model_parallel(): + with provider_topology_env(topology): + runtime = megatron_train.build_training_runtime( + model_identifier=case_config.base_model, + provider_torch_dtype=torch.float32, + provider_configure=lambda provider: _configure_provider( + provider, topology, case_config + ), + print_env=False, + build_optimizer=False, + allow_unvalidated_arch=case_config.allow_unvalidated_arch, + ) + adapter_prefixes = { + module.adapter_model_prefix + for chunk in runtime.model + for module in chunk.modules() + if isinstance(module, LoRA) + } + adapter_weights_by_base = ( + runtime.provider_bundle.handler.build_adapter_weights_by_base(runtime.model) + ) + + target_modules = list(runtime.provider_bundle.spec.default_target_modules) + wrapped_target_modules = sorted(_covered_wrapped_target_modules(adapter_prefixes)) + exported_target_modules = sorted( + _covered_exported_target_modules(adapter_weights_by_base) + ) + return LoraCoverageReport( + base_model=case_config.base_model, + target_modules=target_modules, + wrapped_target_modules=wrapped_target_modules, + exported_target_modules=exported_target_modules, + 
missing_wrapped_target_modules=sorted( + set(target_modules) - set(wrapped_target_modules) + ), + missing_exported_target_modules=sorted( + set(target_modules) - set(exported_target_modules) + ), + wrapped_adapter_prefix_count=len(adapter_prefixes), + export_base_count=len(adapter_weights_by_base), + export_adapter_count=sum( + len(adapter_weights) for adapter_weights in adapter_weights_by_base.values() + ), + ) diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron/model_support/oracle_harness.py similarity index 78% rename from tests/integration/megatron_oracle_harness.py rename to tests/integration/megatron/model_support/oracle_harness.py index 9938b3b82..0de3b5a2e 100644 --- a/tests/integration/megatron_oracle_harness.py +++ b/tests/integration/megatron/model_support/oracle_harness.py @@ -16,17 +16,16 @@ from rich.table import Table import torch -from .megatron_forward_trace import ForwardTraceCapture +from .forward_trace import ForwardTraceCapture -REPO_ROOT = Path(__file__).resolve().parents[2] +REPO_ROOT = Path(__file__).resolve().parents[4] ARTIFACT_ROOT = Path(REPO_ROOT / ".local/megatron_lora_correctness") ORACLE_MOE_ROUTING_BUNDLE_DIRNAME = "oracle_moe_routing_replay" -ORACLE_REPLAY_TOPOLOGY_SUFFIX = "oracle_replay" REGENERATE_ENV = "ART_REGENERATE_ORACLE" -EXTENDED_TOPOLOGIES_ENV = "ART_ENABLE_EXTENDED_TOPOLOGIES" SENSITIVITY_MUTATION_ENV = "ART_SENSITIVITY_MUTATIONS" ORACLE_OBJECTIVE_ENV = "ART_ORACLE_OBJECTIVE" +KEEP_TOPOLOGY_ARTIFACTS_ENV = "ART_ORACLE_KEEP_TOPOLOGY_ARTIFACTS" OracleObjective = Literal["rl", "sft"] SUPPORTED_ORACLE_OBJECTIVES: tuple[OracleObjective, ...] = ("rl", "sft") @@ -97,15 +96,23 @@ def oracle_output_slug( def supported_sensitivity_mutations_for_objective( objective: OracleObjective, + *, + is_moe: bool = True, ) -> tuple[SensitivityMutation, ...]: + del is_moe return OBJECTIVE_SENSITIVITY_MUTATIONS[objective] def objective_supports_sensitivity_mutation( objective: OracleObjective, mutation: SensitivityMutation, + *, + is_moe: bool = True, ) -> bool: - return mutation in supported_sensitivity_mutations_for_objective(objective) + return mutation in supported_sensitivity_mutations_for_objective( + objective, + is_moe=is_moe, + ) def selected_oracle_objectives() -> list[OracleObjective]: @@ -172,14 +179,21 @@ def world_size(self) -> int: Topology(tp=2, ep=1, etp=1, dp=1, sp=True), Topology(tp=2, ep=2, etp=1, dp=1, sp=True), Topology(tp=2, ep=1, etp=2, dp=1, sp=True), -] -EXTENDED_TOPOLOGIES = [ Topology(tp=1, ep=1, etp=1, dp=2, sp=False), Topology(tp=1, ep=2, etp=1, dp=2, sp=False), Topology(tp=1, ep=1, etp=2, dp=2, sp=True), ] +DENSE_TOPOLOGIES = [ + Topology(tp=1, ep=1, etp=1, dp=1, sp=False), + Topology(tp=2, ep=1, etp=1, dp=1, sp=True), + Topology(tp=1, ep=1, etp=1, dp=2, sp=False), + Topology(tp=2, ep=1, etp=1, dp=2, sp=True), +] ORACLE_TOPOLOGY = TOPOLOGIES[0] +DENSE_ORACLE_TOPOLOGY = DENSE_TOPOLOGIES[0] SENSITIVITY_TOPOLOGY = Topology(tp=2, ep=2, etp=1, dp=1, sp=True) +DENSE_SENSITIVITY_TOPOLOGY = Topology(tp=2, ep=1, etp=1, dp=1, sp=True) +DENSE_DP_SENSITIVITY_TOPOLOGY = Topology(tp=1, ep=1, etp=1, dp=2, sp=False) SENSITIVITY_TOPOLOGY_BY_MUTATION: dict[SensitivityMutation, Topology] = { mutation: SENSITIVITY_TOPOLOGY for mutation in SUPPORTED_SENSITIVITY_MUTATIONS } @@ -196,14 +210,24 @@ def world_size(self) -> int: } +def oracle_topology(*, is_moe: bool = True) -> Topology: + return ORACLE_TOPOLOGY if is_moe else DENSE_ORACLE_TOPOLOGY + + +def selected_suite_topologies(*, is_moe: bool = True) -> list[Topology]: + 
return list(TOPOLOGIES if is_moe else DENSE_TOPOLOGIES) + + class PackedTensorConfig(BaseModel): """Controls synthetic packed tensor generation used by oracle harness runs.""" num_sequences: int = 4 sequence_length: int = 256 prefill_tokens: int = 64 - decode_tokens: int = 64 + completion_branches_per_prefix: int = Field(default=2, ge=1) decode_tokens_jitter: int = Field(default=32, ge=0) + decode_tokens: int = 64 + packing_mode: Literal["stop_early", "truncate"] = "stop_early" vocab_high: int = 8192 @@ -233,6 +257,7 @@ class MetricThresholdRule(BaseModel): """Callable row pass rule that AND-checks configured metric upper bounds.""" limits: dict[str, float] = Field(default_factory=dict) + minimums: dict[str, float] = Field(default_factory=dict) def failure_reasons(self, summary: MetricSummary) -> list[str]: """Builds readable failure reasons for this threshold rule.""" @@ -244,6 +269,13 @@ def failure_reasons(self, summary: MetricSummary) -> list[str]: continue if float(value) > float(limit): reasons.append(f"{key}={float(value):.6g}>{float(limit):.6g}") + for key, minimum in sorted(self.minimums.items()): + value = summary.get(key) + if not isinstance(value, (int, float)): + reasons.append(f"{key}=missing") + continue + if float(value) <= float(minimum): + reasons.append(f"{key}={float(value):.6g}<={float(minimum):.6g}") return reasons def __call__(self, summary: MetricSummary) -> bool: @@ -255,6 +287,7 @@ class OracleCaseConfig(BaseModel): """Contains all deterministic run parameters for one oracle case.""" base_model: str + is_moe: bool = True precision: Literal["bf16", "fp32"] = "fp32" num_layers: int = 4 seed: int = 20260304 @@ -265,6 +298,7 @@ class OracleCaseConfig(BaseModel): loss_scale: float = 1 packed_tensors: PackedTensorConfig = Field(default_factory=PackedTensorConfig) lora: LoraConfig = Field(default_factory=LoraConfig) + allow_unvalidated_arch: bool = False class DiskPackedTensorsSpec(BaseModel): @@ -404,6 +438,7 @@ def __init__(self) -> None: self.diff_sq_sum = 0.0 self.ref_sq_sum = 0.0 self.ref_abs_sum = 0.0 + self.candidate_abs_sum = 0.0 self.router_topk_total = 0 self.router_topk_mismatch = 0 self.router_top1_total = 0 @@ -421,6 +456,7 @@ def update(self, reference, candidate) -> None: # type: ignore[no-untyped-def] self.diff_sq_sum += float((cand - ref).square().sum().item()) self.ref_sq_sum += float(ref.square().sum().item()) self.ref_abs_sum += float(ref.abs().sum().item()) + self.candidate_abs_sum += float(cand.abs().sum().item()) @staticmethod def layer_averaged_summary(reference_stack, candidate_stack) -> dict[str, float]: # type: ignore[no-untyped-def] @@ -435,6 +471,7 @@ def layer_averaged_summary(reference_stack, candidate_stack) -> dict[str, float] "mean_abs_diff", "relative_l2", "typical_abs_scale", + "candidate_abs_scale", "mean_abs_pct", ] } @@ -477,12 +514,14 @@ def as_summary(self) -> dict[str, float]: "mean_abs_diff": 0.0, "relative_l2": 0.0, "typical_abs_scale": 0.0, + "candidate_abs_scale": 0.0, "mean_abs_pct": 0.0, "topk_mismatch_fraction": topk_fraction, "top1_mismatch_fraction": top1_fraction, } mean_abs = self.abs_sum / self.numel typical_abs = self.ref_abs_sum / self.numel + candidate_abs = self.candidate_abs_sum / self.numel mean_abs_pct = (mean_abs / (typical_abs + 1e-12)) * 100.0 return { "numel": _finite_metric(float(self.numel), default=0.0), @@ -491,6 +530,7 @@ def as_summary(self) -> dict[str, float]: (self.diff_sq_sum**0.5) / max(self.ref_sq_sum**0.5, 1e-12) ), "typical_abs_scale": _finite_metric(typical_abs, default=0.0), + 
"candidate_abs_scale": _finite_metric(candidate_abs, default=0.0), "mean_abs_pct": _finite_metric(mean_abs_pct), "topk_mismatch_fraction": _finite_metric(topk_fraction, default=1.0), "top1_mismatch_fraction": _finite_metric(top1_fraction, default=1.0), @@ -547,37 +587,70 @@ def sensitivity_enabled() -> bool: def selected_sensitivity_mutations_for_objective( objective: OracleObjective, mutations: list[SensitivityMutation], + *, + is_moe: bool = True, + max_world_size: int | None = None, ) -> list[SensitivityMutation]: return [ mutation for mutation in mutations - if objective_supports_sensitivity_mutation(objective, mutation) + if objective_supports_sensitivity_mutation( + objective, + mutation, + is_moe=is_moe, + ) + and ( + max_world_size is None + or sensitivity_topology_for_mutation( + mutation, + is_moe=is_moe, + ).world_size() + <= max_world_size + ) ] -def sensitivity_topology_for_mutation(mutation: SensitivityMutation) -> Topology: +def sensitivity_topology_for_mutation( + mutation: SensitivityMutation, + *, + is_moe: bool = True, +) -> Topology: """Returns the sensitivity topology required for one mutation.""" + if not is_moe: + if mutation in { + "dp_grad_accumulation_seqs", + "dp_local_token_normalization", + "sft_local_token_normalization", + }: + return DENSE_DP_SENSITIVITY_TOPOLOGY + return DENSE_SENSITIVITY_TOPOLOGY return SENSITIVITY_TOPOLOGY_BY_MUTATION[mutation] -def sensitivity_required_world_size(mutations: list[SensitivityMutation]) -> int: +def sensitivity_required_world_size( + mutations: list[SensitivityMutation], + *, + is_moe: bool = True, +) -> int: """Returns the max world-size required by a selected mutation set.""" + if not mutations: + return 0 return max( - sensitivity_topology_for_mutation(mutation).world_size() + sensitivity_topology_for_mutation(mutation, is_moe=is_moe).world_size() for mutation in mutations ) -def extended_topologies_enabled() -> bool: - """Returns whether extended topologies are enabled for the suite.""" - return _truthy(os.environ.get(EXTENDED_TOPOLOGIES_ENV)) - - def regenerate_requested() -> bool: """Returns whether regeneration mode is enabled for oracle artifacts.""" return _truthy(os.environ.get(REGENERATE_ENV)) +def keep_topology_artifacts() -> bool: + """Returns whether oracle topology tensor artifacts should be retained.""" + return _truthy(os.environ.get(KEEP_TOPOLOGY_ARTIFACTS_ENV)) + + def case_config( base_model: str = "Qwen/Qwen3-30B-A3B-Instruct-2507", ) -> OracleCaseConfig: @@ -630,37 +703,23 @@ def _build_packed_tensors( raise ValueError("num_sequences must be greater than 1") shape = (config.num_sequences, config.sequence_length) generator = torch.Generator().manual_seed(seed) - tokens = torch.randint( - low=10, - high=config.vocab_high, - size=shape, - dtype=torch.long, - generator=generator, - ) - # Ensure paired cross-DP rows are never token-identical. 
- half = config.num_sequences // 2 - if half > 0 and config.num_sequences % 2 == 0: - for pair_index in range(half): - left_index = pair_index - right_index = pair_index + half - if torch.equal(tokens[left_index], tokens[right_index]): - token_span = max(1, config.vocab_high - 10) - tokens[right_index] = ((tokens[right_index] - 10 + 1) % token_span) + 10 - group_ids = torch.zeros(shape, dtype=torch.long) + tokens = torch.zeros(shape, dtype=torch.long) + token_low = 10 + token_span = max(1, config.vocab_high - token_low) + group_ids = torch.full(shape, -1, dtype=torch.long) parent_ids = torch.full(shape, -1, dtype=torch.long) - input_pos = ( - torch.arange(config.sequence_length, dtype=torch.long) - .unsqueeze(0) - .expand(config.num_sequences, -1) - .clone() - ) - prefix_length = max(1, min(config.sequence_length - 1, config.prefill_tokens)) assistant_mask = torch.zeros(shape, dtype=torch.bool) - max_decode_tokens = max(1, config.sequence_length - prefix_length) - base_decode_tokens = max(1, min(config.decode_tokens, max_decode_tokens)) - jitter_width = min(config.decode_tokens_jitter, max_decode_tokens - 1) - candidate_decode_lengths: list[int] = [] - for _ in range(config.num_sequences): + input_pos = torch.zeros(shape, dtype=torch.long) + logprobs = torch.full(shape, float("nan"), dtype=torch.float32) + advantages = torch.zeros(shape, dtype=torch.float32) + weights = torch.zeros(shape, dtype=torch.float32) + + prefix_length = max(1, min(config.sequence_length - 1, config.prefill_tokens)) + max_completion_tokens = max(1, config.sequence_length - prefix_length) + base_completion_tokens = max(1, min(config.decode_tokens, max_completion_tokens)) + jitter_width = min(config.decode_tokens_jitter, max_completion_tokens - 1) + + def _sample_completion_length() -> int: if jitter_width > 0: jitter = int( torch.randint( @@ -673,58 +732,159 @@ def _build_packed_tensors( ) else: jitter = 0 - decode_length = max( + return max( 1, - min(max_decode_tokens, base_decode_tokens + jitter), + min(max_completion_tokens, base_completion_tokens + jitter), + ) + + def _sample_token_block(length: int) -> torch.Tensor: + return torch.randint( + low=token_low, + high=config.vocab_high, + size=(length,), + dtype=torch.long, + generator=generator, ) - candidate_decode_lengths.append(decode_length) - # Keep jitter local around the configured decode length, but force pairwise - # differences across halves so default DP rank shards see different lengths. 
+ + def _sample_logprob_block(length: int) -> torch.Tensor: + return ( + torch.randn( + (length,), + generator=generator, + dtype=torch.float32, + ) + * 0.25 + - 1.75 + ) + + def _sample_advantage_value() -> float: + return float( + ( + torch.randn( + (1,), + generator=generator, + dtype=torch.float32, + ) + * 0.5 + ).item() + ) + + for sequence_index in range(config.num_sequences): + cursor = 0 + next_group_id = 0 + while cursor < config.sequence_length: + prompt_group_id = next_group_id + next_group_id += 1 + completion_lengths = [ + _sample_completion_length() + for _ in range(config.completion_branches_per_prefix) + ] + remaining = config.sequence_length - cursor + + if config.packing_mode == "stop_early": + included_completion_lengths = list(completion_lengths) + while ( + included_completion_lengths + and (prefix_length + sum(included_completion_lengths)) > remaining + ): + included_completion_lengths.pop() + if not included_completion_lengths: + break + + prompt_end = cursor + prefix_length + tokens[sequence_index, cursor:prompt_end] = _sample_token_block( + prefix_length + ) + group_ids[sequence_index, cursor:prompt_end] = prompt_group_id + parent_ids[sequence_index, cursor:prompt_end] = prompt_group_id + input_pos[sequence_index, cursor:prompt_end] = torch.arange( + prefix_length, dtype=torch.long + ) + cursor = prompt_end + + for completion_length in included_completion_lengths: + completion_group_id = next_group_id + next_group_id += 1 + completion_end = cursor + completion_length + tokens[sequence_index, cursor:completion_end] = _sample_token_block( + completion_length + ) + group_ids[sequence_index, cursor:completion_end] = ( + completion_group_id + ) + parent_ids[sequence_index, cursor:completion_end] = prompt_group_id + input_pos[sequence_index, cursor:completion_end] = torch.arange( + prefix_length, + prefix_length + completion_length, + dtype=torch.long, + ) + assistant_mask[sequence_index, cursor:completion_end] = True + logprobs[sequence_index, cursor:completion_end] = ( + _sample_logprob_block(completion_length) + ) + advantages[sequence_index, cursor:completion_end] = ( + _sample_advantage_value() + ) + weights[sequence_index, cursor:completion_end] = 1.0 + cursor = completion_end + continue + + prompt_take = min(prefix_length, remaining) + prompt_end = cursor + prompt_take + tokens[sequence_index, cursor:prompt_end] = _sample_token_block(prompt_take) + group_ids[sequence_index, cursor:prompt_end] = prompt_group_id + parent_ids[sequence_index, cursor:prompt_end] = prompt_group_id + input_pos[sequence_index, cursor:prompt_end] = torch.arange( + prompt_take, dtype=torch.long + ) + cursor = prompt_end + if cursor >= config.sequence_length: + break + + for completion_length in completion_lengths: + if cursor >= config.sequence_length: + break + completion_group_id = next_group_id + next_group_id += 1 + remaining = config.sequence_length - cursor + completion_take = min(completion_length, remaining) + completion_end = cursor + completion_take + tokens[sequence_index, cursor:completion_end] = _sample_token_block( + completion_take + ) + group_ids[sequence_index, cursor:completion_end] = completion_group_id + parent_ids[sequence_index, cursor:completion_end] = prompt_group_id + input_pos[sequence_index, cursor:completion_end] = torch.arange( + prefix_length, + prefix_length + completion_take, + dtype=torch.long, + ) + assistant_mask[sequence_index, cursor:completion_end] = True + logprobs[sequence_index, cursor:completion_end] = _sample_logprob_block( + completion_take + ) + 
advantages[sequence_index, cursor:completion_end] = (
+                    _sample_advantage_value()
+                )
+                weights[sequence_index, cursor:completion_end] = 1.0
+                cursor = completion_end
+
+    half = config.num_sequences // 2
+    if half > 0 and config.num_sequences % 2 == 0:
+        valid_lengths = (group_ids != -1).sum(dim=1)
+        for pair_index in range(half):
             left_index = pair_index
             right_index = pair_index + half
-            if (
-                candidate_decode_lengths[left_index]
-                != candidate_decode_lengths[right_index]
-            ):
+            left_valid = int(valid_lengths[left_index].item())
+            right_valid = int(valid_lengths[right_index].item())
+            if left_valid != right_valid or left_valid == 0:
                 continue
-            if candidate_decode_lengths[right_index] < max_decode_tokens:
-                candidate_decode_lengths[right_index] += 1
-            elif candidate_decode_lengths[right_index] > 1:
-                candidate_decode_lengths[right_index] -= 1
-
-    for sequence_index, decode_length in enumerate(candidate_decode_lengths):
-        active_stop = prefix_length + decode_length
-        assistant_mask[sequence_index, prefix_length:active_stop] = True
-        decode_span = max(1, min(config.decode_tokens, decode_length))
-        cursor = prefix_length
-        branch = 1
-        while cursor < active_stop:
-            end = min(active_stop, cursor + decode_span)
-            group_ids[sequence_index, cursor:end] = branch
-            parent_ids[sequence_index, cursor:end] = 0
-            cursor = end
-            branch += 1
-    logprobs = (
-        torch.randn(
-            shape,
-            generator=generator,
-            dtype=torch.float32,
-        )
-        * 0.25
-        - 1.75
-    )
-    advantages = (
-        torch.randn(
-            shape,
-            generator=generator,
-            dtype=torch.float32,
-        )
-        * 0.1
-        + 1.0
-    )
-    weights = torch.ones(shape, dtype=torch.float32)
+            if torch.equal(
+                tokens[left_index, :left_valid], tokens[right_index, :right_valid]
+            ):
+                tokens[right_index, 0] = (
+                    (tokens[right_index, 0] - token_low + 1) % token_span
+                ) + token_low
 
     return {
         "tokens": tokens,
         "group_ids": group_ids,
@@ -790,6 +950,19 @@ def _replace_topology_dir(path: Path) -> None:
     (path / "traces").mkdir(parents=True, exist_ok=True)
 
 
+def _prune_topology_artifacts(path: Path) -> None:
+    """Keeps small diagnostics and removes tensors that are only needed for comparison."""
+    if keep_topology_artifacts() or not path.exists():
+        return
+    for child in path.iterdir():
+        if child.name in {"variant_report.json", "run_request.json", "worker.log"}:
+            continue
+        if child.is_dir():
+            shutil.rmtree(child)
+            continue
+        child.unlink()
+
+
 def _load_manifest(topology_dir: Path) -> RunManifest:
     """Loads one run manifest for a topology output directory."""
     manifest_path = topology_dir / "manifest.json"
@@ -920,7 +1093,8 @@ def __init__(
         self.case_artifacts = ensure_case_artifacts(case_config)
         self.case_id = self.case_artifacts.case_id
         self.case_dir = Path(self.case_artifacts.case_dir)
-        self.oracle_slug = oracle_output_slug(objective, ORACLE_TOPOLOGY)
+        self.oracle_topology = oracle_topology(is_moe=case_config.is_moe)
+        self.oracle_slug = oracle_output_slug(objective, self.oracle_topology)
         self.oracle_dir = self.case_dir / self.oracle_slug
         self.oracle_routing_bundle_dir = (
             self.case_dir / f"{objective}__{ORACLE_MOE_ROUTING_BUNDLE_DIRNAME}"
         )
@@ -964,13 +1138,13 @@ def _run_topology(
                 None if capture_bundle_dir is None else str(capture_bundle_dir)
             ),
         )
-        from .megatron_oracle_worker import run_worker_subprocess
+        from .oracle_worker import run_worker_subprocess
 
         run_worker_subprocess(request, topology_dir, repo_root=REPO_ROOT)
         return topology_dir
 
     def ensure_oracle(self) -> Path:
-        """Ensures oracle capture and canonical replay artifacts exist exactly once per session."""
+        """Ensures routing 
capture and the canonical replay-backed oracle exist once.""" regenerate = regenerate_requested() if self._oracle_initialized and (not regenerate or self._oracle_regenerated): return self.oracle_dir @@ -985,20 +1159,26 @@ def ensure_oracle(self) -> Path: ) run_oracle_topology = partial( self._run_topology, - topology=ORACLE_TOPOLOGY, + topology=self.oracle_topology, mutation=None, regenerate=True, ) - if need_capture: + if self.case_config.is_moe and need_capture: run_oracle_topology( output_slug=f"{self.oracle_slug}__oracle_capture", replay_bundle_dir=None, capture_bundle_dir=self.oracle_routing_bundle_dir, ) - if regenerate or not oracle_manifest.exists(): + if ( + regenerate + or not oracle_manifest.exists() + or not self.shared_init_path.exists() + ): run_oracle_topology( output_slug=self.oracle_slug, - replay_bundle_dir=self.oracle_routing_bundle_dir, + replay_bundle_dir=( + self.oracle_routing_bundle_dir if self.case_config.is_moe else None + ), capture_bundle_dir=None, ) self._oracle_initialized = True @@ -1018,7 +1198,9 @@ def ensure_variant_artifacts( topology=variant.topology, output_slug=output_slug, mutation=variant.mutation, - replay_bundle_dir=self.oracle_routing_bundle_dir, + replay_bundle_dir=( + self.oracle_routing_bundle_dir if self.case_config.is_moe else None + ), capture_bundle_dir=None, regenerate=variant.force_regenerate, ) @@ -1058,6 +1240,7 @@ def _inf_summary() -> dict[str, float]: "mean_abs_diff": NON_FINITE_METRIC_VALUE, "relative_l2": NON_FINITE_METRIC_VALUE, "typical_abs_scale": 0.0, + "candidate_abs_scale": 0.0, "mean_abs_pct": NON_FINITE_METRIC_VALUE, "topk_mismatch_fraction": 1.0, "top1_mismatch_fraction": 1.0, @@ -1409,6 +1592,15 @@ def _write_variant_report(self, topology_dir: Path, report: VariantReport) -> No topology_dir / "variant_report.json", report.model_dump(mode="json") ) + def _prune_reference_artifacts(self) -> None: + """Drops oracle-only tensors after all comparisons that need them are complete.""" + _prune_topology_artifacts(self.oracle_dir) + if self.case_config.is_moe: + _prune_topology_artifacts(self.oracle_routing_bundle_dir) + _prune_topology_artifacts( + self.case_dir / f"{self.oracle_slug}__oracle_capture" + ) + def print_report(self, report: VariantReport) -> None: """Prints a row-level table excluding expert-specific rows.""" table_rows = [ @@ -1463,6 +1655,7 @@ def run_variant( topology_dir = self.ensure_variant_artifacts(variant) report = self.compare_variant(variant) self._write_variant_report(topology_dir, report) + _prune_topology_artifacts(topology_dir) self.print_report(report) return report @@ -1472,16 +1665,19 @@ def run_suite( ) -> list[VariantReport]: """Runs variants in order and stops at the first unexpected signal.""" reports: list[VariantReport] = [] - for variant in variants: - report = self.run_variant(variant) - reports.append(report) - self.assert_expected_signal( - report, - "Megatron correctness suite mismatch", - report_path=self.case_dir - / variant.resolved_output_slug() - / "variant_report.json", - ) + try: + for variant in variants: + report = self.run_variant(variant) + reports.append(report) + self.assert_expected_signal( + report, + "Megatron correctness suite mismatch", + report_path=self.case_dir + / variant.resolved_output_slug() + / "variant_report.json", + ) + finally: + self._prune_reference_artifacts() return reports @@ -1490,10 +1686,19 @@ def _default_phase_pass_fns() -> dict[str, PhasePassFn]: # note the metrics get averaged across layers to reduce noise # we also average across experts to 
reduce noise # we don't expect particular layers to see errors as opposed to the others so this is helpful - fwd_out_loss = MetricThresholdRule( - limits={"relative_l2": 1e-2, "mean_abs_pct": 1.0} + non_zero_scales = {"typical_abs_scale": 0.0, "candidate_abs_scale": 0.0} + fwd_out = MetricThresholdRule( + limits={"relative_l2": 1e-2, "mean_abs_pct": 1.0}, + minimums=non_zero_scales, + ) + loss = MetricThresholdRule( + limits={"relative_l2": 2e-2, "mean_abs_pct": 2.0}, + minimums=non_zero_scales, + ) + grads_deltas = MetricThresholdRule( + limits={"mean_abs_pct": 3.0}, + minimums=non_zero_scales, ) - grads_deltas = MetricThresholdRule(limits={"mean_abs_pct": 3.0}) router_topk_rule = ( MetricThresholdRule( # should be no mismatch due to router replay limits={ @@ -1502,33 +1707,29 @@ def _default_phase_pass_fns() -> dict[str, PhasePassFn]: } ) ) - return {key: fwd_out_loss for key in ["forward", "outputs", "losses"]} | { + return { + "forward": fwd_out, + "outputs": fwd_out, + "losses": loss, + } | { "grads": grads_deltas, "deltas": grads_deltas, "router_topk_ids": router_topk_rule, } -def _suite_variants(objective: OracleObjective) -> list[VariantSpec]: +def _suite_variants( + objective: OracleObjective, + *, + is_moe: bool, + max_world_size: int | None = None, +) -> list[VariantSpec]: """Builds the standard oracle suite variant ordering.""" phase_pass = _default_phase_pass_fns() - variants = [ - VariantSpec( - name=f"{objective}_oracle_replay_parity", - objective=objective, - topology=ORACLE_TOPOLOGY, - output_slug=oracle_output_slug( - objective, - ORACLE_TOPOLOGY, - ORACLE_REPLAY_TOPOLOGY_SUFFIX, - ), - pass_fn_by_phase=phase_pass, - force_regenerate=regenerate_requested(), - ) - ] - for topology in TOPOLOGIES[1:] + ( - EXTENDED_TOPOLOGIES if extended_topologies_enabled() else [] - ): + variants: list[VariantSpec] = [] + for topology in selected_suite_topologies(is_moe=is_moe)[1:]: + if max_world_size is not None and topology.world_size() > max_world_size: + continue variants.append( VariantSpec( name=f"{objective}_topology_{topology.slug()}", @@ -1543,12 +1744,21 @@ def _suite_variants(objective: OracleObjective) -> list[VariantSpec]: def run_suite( *, case_config: OracleCaseConfig, + max_world_size: int | None = None, ) -> list[VariantReport]: - """Runs replay parity and topology variants with fail-fast assertions.""" + """Runs non-oracle topologies against the canonical replay-backed oracle.""" reports: list[VariantReport] = [] for objective in selected_oracle_objectives(): runner = VariantRunner(objective=objective, case_config=case_config) - reports.extend(runner.run_suite(_suite_variants(objective))) + reports.extend( + runner.run_suite( + _suite_variants( + objective, + is_moe=case_config.is_moe, + max_world_size=max_world_size, + ) + ) + ) return reports @@ -1556,16 +1766,28 @@ def run_sensitivity_suite( *, case_config: OracleCaseConfig, mutations: list[SensitivityMutation], + max_world_size: int | None = None, ) -> list[VariantReport]: """Runs a list of sensitivity mutations and expects each to fail.""" phase_pass = _default_phase_pass_fns() reports: list[VariantReport] = [] ran_any_variants = False + matched_any_objective = False for objective in selected_oracle_objectives(): runner = VariantRunner(objective=objective, case_config=case_config) + objective_supported_mutations = selected_sensitivity_mutations_for_objective( + objective, + mutations, + is_moe=case_config.is_moe, + ) + matched_any_objective = matched_any_objective or bool( + objective_supported_mutations + ) 
objective_mutations = selected_sensitivity_mutations_for_objective( objective, mutations, + is_moe=case_config.is_moe, + max_world_size=max_world_size, ) if not objective_mutations: continue @@ -1573,7 +1795,10 @@ def run_sensitivity_suite( VariantSpec( name=f"{objective}_sensitivity_{mutation}", objective=objective, - topology=sensitivity_topology_for_mutation(mutation), + topology=sensitivity_topology_for_mutation( + mutation, + is_moe=case_config.is_moe, + ), mutation=mutation, expected_signal="fail", pass_fn_by_phase=phase_pass, @@ -1582,13 +1807,17 @@ def run_sensitivity_suite( ] ran_any_variants = True reports.extend(runner.run_suite(variants)) - if ran_any_variants: + if ran_any_variants or (max_world_size is not None and matched_any_objective): return reports requested = ", ".join(mutations) - supported = ", ".join( - f"{objective}: {', '.join(supported_sensitivity_mutations_for_objective(objective))}" - for objective in selected_oracle_objectives() - ) + supported_by_objective = [] + for objective in selected_oracle_objectives(): + objective_supported = supported_sensitivity_mutations_for_objective( + objective, + is_moe=case_config.is_moe, + ) + supported_by_objective.append(f"{objective}: {', '.join(objective_supported)}") + supported = ", ".join(supported_by_objective) raise ValueError( "No sensitivity variants matched the selected objectives. " f"Requested mutations: {requested}. Supported by objective: {supported}." diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron/model_support/oracle_worker.py similarity index 81% rename from tests/integration/megatron_oracle_worker.py rename to tests/integration/megatron/model_support/oracle_worker.py index 123756758..712358028 100644 --- a/tests/integration/megatron_oracle_worker.py +++ b/tests/integration/megatron/model_support/oracle_worker.py @@ -2,12 +2,14 @@ import argparse from contextlib import ExitStack, contextmanager +import faulthandler import hashlib import os from pathlib import Path import random import subprocess import sys +import time from types import MethodType from typing import Any, Callable @@ -22,8 +24,8 @@ ) from art.preprocessing.pack import PackedTensors -from .megatron_forward_trace import ForwardTraceCapture -from .megatron_oracle_harness import ( +from .forward_trace import ForwardTraceCapture +from .oracle_harness import ( SUPPORTED_SENSITIVITY_MUTATIONS, OracleCaseConfig, RunManifest, @@ -35,6 +37,38 @@ _require_not_none, _write_json, ) +from .test_inputs import build_sft_trajectory_tensors_from_packed_tensors + +_TOPOLOGY_ENV_VARS = { + "tp": "ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", + "ep": "ART_MEGATRON_EXPERT_MODEL_PARALLEL_SIZE", + "etp": "ART_MEGATRON_EXPERT_TENSOR_PARALLEL_SIZE", +} +_ORACLE_DEBUG_ENV = "ART_ORACLE_DEBUG" +_ORACLE_DEBUG_START_TIME = time.perf_counter() + + +def _oracle_debug_enabled() -> bool: + return os.environ.get(_ORACLE_DEBUG_ENV, "").strip().lower() in { + "1", + "true", + "yes", + "on", + } + + +def _debug(message: str) -> None: + if not _oracle_debug_enabled(): + return + elapsed = time.perf_counter() - _ORACLE_DEBUG_START_TIME + print(f"[oracle-debug +{elapsed:.2f}s] {message}", flush=True) + + +def _enable_debug_traceback_dump() -> None: + if not _oracle_debug_enabled(): + return + faulthandler.enable() + faulthandler.dump_traceback_later(60, repeat=True) def run_worker_subprocess( @@ -46,7 +80,7 @@ def run_worker_subprocess( """Runs one distributed worker subprocess and stores combined logs.""" request_path = topology_dir / 
"run_request.json" _write_json(request_path, request.model_dump(mode="json")) - worker_module = "integration.megatron_oracle_worker" + worker_module = "integration.megatron.model_support.oracle_worker" worker_cwd = repo_root / "tests" command = [ @@ -62,18 +96,45 @@ def run_worker_subprocess( "--run-request", str(request_path), ] - env = {**os.environ, "PYTHONUNBUFFERED": "1"} - env["ART_DISABLE_MEGATRON_COMPILE"] = "1" - run = subprocess.run( - command, - cwd=str(worker_cwd), - env=env, - capture_output=True, - text=True, - check=False, - ) - combined_output = f"{run.stdout}\n{run.stderr}".strip() - (topology_dir / "worker.log").write_text(combined_output + "\n", encoding="utf-8") + combined_lines: list[str] = [] + worker_log_path = topology_dir / "worker.log" + live_log_raw = os.environ.get("ART_ORACLE_LIVE_TRAINING_LOG") + live_log_path = None if not live_log_raw else Path(live_log_raw) + worker_log_path.parent.mkdir(parents=True, exist_ok=True) + with worker_log_path.open("w", encoding="utf-8") as worker_log: + live_log = None + try: + if live_log_path is not None: + live_log_path.parent.mkdir(parents=True, exist_ok=True) + live_log = live_log_path.open("a", encoding="utf-8") + live_log.write( + f"\n=== {request.objective} {request.topology.slug()} ===\n" + ) + live_log.flush() + env = {**os.environ, "PYTHONUNBUFFERED": "1"} + env["ART_DISABLE_MEGATRON_COMPILE"] = "1" + run = subprocess.Popen( + command, + cwd=str(worker_cwd), + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + assert run.stdout is not None + for line in run.stdout: + combined_lines.append(line) + worker_log.write(line) + worker_log.flush() + if live_log is not None: + live_log.write(line) + live_log.flush() + run.returncode = run.wait() + finally: + if live_log is not None: + live_log.close() + combined_output = "".join(combined_lines).strip() if run.returncode != 0: tail = "\n".join(combined_output.splitlines()[-80:]) raise RuntimeError( @@ -93,26 +154,48 @@ def _set_deterministic_seed(seed: int) -> None: torch.backends.cudnn.benchmark = False +def provider_topology_env_vars(topology: Topology) -> dict[str, str]: + return { + _TOPOLOGY_ENV_VARS["tp"]: str(topology.tp), + _TOPOLOGY_ENV_VARS["ep"]: str(topology.ep), + _TOPOLOGY_ENV_VARS["etp"]: str(topology.etp), + } + + +@contextmanager +def provider_topology_env(topology: Topology): + previous = {name: os.environ.get(name) for name in _TOPOLOGY_ENV_VARS.values()} + os.environ.update(provider_topology_env_vars(topology)) + try: + yield + finally: + for name, value in previous.items(): + if value is None: + os.environ.pop(name, None) + continue + os.environ[name] = value + + def _merge_sharded_dicts(shards_by_rank: list[dict[str, Any]]) -> dict[str, Any]: """Merges rank-sharded LoRA tensors into a full state dict on rank 0.""" - import torch - - merged: dict[str, list[Any]] = {} - for rank_shards in shards_by_rank: - for key, tensor in rank_shards.items(): - merged.setdefault(key, []).append(tensor.detach().cpu()) - full_state: dict[str, Any] = {} - for key, shards in merged.items(): - if len(shards) == 1: - full_state[key] = shards[0].contiguous() - continue - concat_dim = 1 if ".lora_A." 
in key else 0 - full_state[key] = torch.cat(shards, dim=concat_dim).contiguous() - return full_state + from art.megatron.weights.merge import merge_sharded_adapter_entries + + entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]] = {} + for rank_entry in shards_by_rank: + rank_state = rank_entry["state"] + rank_manifest = rank_entry["manifest"] + for key, tensor in rank_state.items(): + if key not in rank_manifest: + raise RuntimeError(f"Missing manifest entry for sharded key '{key}'") + entries_by_key.setdefault(key, []).append( + (rank_manifest[key], tensor.detach().cpu()) + ) + return merge_sharded_adapter_entries(entries_by_key) def _gather_full_state( local_state: dict[str, Any], + local_manifest: dict[str, Any], ) -> dict[str, Any] | None: """Gathers local state dicts to rank 0 and merges them.""" import torch @@ -121,7 +204,9 @@ def _gather_full_state( world_size = torch.distributed.get_world_size() # ty: ignore[possibly-missing-attribute] gathered = [None for _ in range(world_size)] if rank == 0 else None torch.distributed.gather_object( # ty: ignore[possibly-missing-attribute] - local_state, gathered, dst=0 + {"state": local_state, "manifest": local_manifest}, + gathered, + dst=0, ) if rank != 0: return None @@ -135,8 +220,17 @@ def _collect_lora_state( ) -> dict[str, Any] | None: """Collects full LoRA adapter state for validation and delta computation.""" local_state: dict[str, Any] = {} + local_manifest: dict[str, Any] = {} for chunk in model_chunks: for module in chunk.modules(): + if hasattr(module, "sharded_lora_manifest"): + module_manifest = module.sharded_lora_manifest() + for key, value in module_manifest.items(): + if key in local_manifest and local_manifest[key] != value: + raise RuntimeError( + f"Duplicate manifest key while collecting state: {key}" + ) + local_manifest[key] = value if not hasattr(module, "sharded_lora_state_dict"): continue module_state = module.sharded_lora_state_dict() @@ -146,33 +240,35 @@ def _collect_lora_state( f"Duplicate LoRA key while collecting state: {key}" ) local_state[key] = value.detach().cpu() - return _gather_full_state(local_state) + return _gather_full_state(local_state, local_manifest) def _collect_lora_grads( model_chunks: list[Any], ) -> dict[str, Any] | None: """Collects full LoRA gradient tensors across all ranks.""" - from art.megatron.lora import LoRA - local_grads: dict[str, Any] = {} + local_manifest: dict[str, Any] = {} for chunk in model_chunks: for module in chunk.modules(): - if not isinstance(module, LoRA): + if hasattr(module, "sharded_lora_manifest"): + module_manifest = module.sharded_lora_manifest() + for key, value in module_manifest.items(): + if key in local_manifest and local_manifest[key] != value: + raise RuntimeError( + f"Duplicate manifest key while collecting grads: {key}" + ) + local_manifest[key] = value + if not hasattr(module, "sharded_lora_grad_dict"): continue - for key, param, expert in module._export_items(): # type: ignore[attr-defined] - if not hasattr(param, "main_grad"): + module_grads = module.sharded_lora_grad_dict() + for key, value in module_grads.items(): + if key in local_grads: raise RuntimeError( - f"LoRA param missing main_grad attribute for key '{key}'" + f"Duplicate LoRA grad key while collecting grads: {key}" ) - grad = param.main_grad - if grad is None: - raise RuntimeError(f"LoRA param main_grad is None for key '{key}'") - if hasattr(grad, "_local_tensor"): - grad = grad._local_tensor - captured_grad = grad[expert] if expert is not None else grad - local_grads[key] 
= captured_grad.detach().cpu().T - return _gather_full_state(local_grads) + local_grads[key] = value.detach().cpu() + return _gather_full_state(local_grads, local_manifest) def _apply_save_mutation_to_tensor_map( @@ -260,13 +356,7 @@ def _configure_provider( case_config: OracleCaseConfig, ) -> None: """Applies deterministic topology/model overrides to provider config.""" - provider.tensor_model_parallel_size = topology.tp - provider.expert_model_parallel_size = topology.ep - provider.expert_tensor_parallel_size = topology.etp - # These are intentionally pinned to 1 for now - provider.pipeline_model_parallel_size = 1 - provider.context_parallel_size = 1 - provider.sequence_parallel = topology.sp + del topology provider.num_layers = case_config.num_layers if case_config.precision == "fp32": provider.bf16 = False @@ -284,42 +374,33 @@ def _configure_provider( @contextmanager -def _patch_get_provider_for_oracle( +def _patch_finalize_provider_bundle_for_oracle( megatron_train_module: Any, - topology: Topology, case_config: OracleCaseConfig, ): - from art.megatron import provider as provider_module - - original_get_provider = megatron_train_module.get_provider - - def _oracle_get_provider( - model: str, - *, - torch_dtype: torch.dtype = torch.bfloat16, - ) -> Any: - provider = original_get_provider(model, torch_dtype=torch_dtype) - _configure_provider(provider, topology, case_config) - provider_module._apply_runtime_env_overrides(provider) + original_finalize_provider_bundle = megatron_train_module.finalize_provider_bundle + + def _oracle_finalize_provider_bundle(provider_bundle: Any) -> Any: + provider = provider_bundle.provider if case_config.precision == "fp32": - provider.moe_token_dispatcher_type = "alltoall" - provider.moe_flex_dispatcher_backend = None - provider.moe_shared_expert_overlap = True - provider.overlap_moe_expert_parallel_comm = False + if case_config.is_moe: + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_flex_dispatcher_backend = None + provider.moe_shared_expert_overlap = True + provider.overlap_moe_expert_parallel_comm = False provider.delay_wgrad_compute = False provider.ep_overlap_early_attn_memory_release = False - elif provider_module._tp_ep_parallel_domain_size(provider) > 1: - provider_module.apply_flex_dispatcher_backend( - provider, moe_flex_dispatcher_backend="deepep" - ) - provider.finalize() - return provider + provider.finalize() + return provider_bundle + return original_finalize_provider_bundle(provider_bundle) - megatron_train_module.get_provider = _oracle_get_provider + megatron_train_module.finalize_provider_bundle = _oracle_finalize_provider_bundle try: yield finally: - megatron_train_module.get_provider = original_get_provider + megatron_train_module.finalize_provider_bundle = ( + original_finalize_provider_bundle + ) def _build_optimizer_config(case_config: OracleCaseConfig): @@ -437,6 +518,8 @@ def _matches_grad_sync_skip_mutation( return ( ".mlp.experts.linear_fc1.gate_lora.A_T" in param_name or ".mlp.experts.linear_fc1.up_lora.A_T" in param_name + or ".mlp.linear_fc1.gate_lora.A_T" in param_name + or ".mlp.linear_fc1.up_lora.A_T" in param_name ) return False @@ -459,8 +542,8 @@ def _apply_grad_sync_skip_mutation( # this only passes lora params atm, so we assume lora params below if not _matches_grad_sync_skip_mutation(param_name, mutation): continue - if ( - mutation == "bwd_skip_sync_fc1_a" and param.grad_sync_domain != "expert_tp" # ty: ignore[unresolved-attribute] + if mutation == "bwd_skip_sync_fc1_a" and ( + ".mlp.experts." 
in param_name and param.grad_sync_domain != "expert_tp" # ty: ignore[unresolved-attribute] ): continue @@ -512,6 +595,11 @@ def _apply_o_proj_forward_mutation( for module in chunk.modules(): if not isinstance(module, SelfAttentionLinearProjLoRA): continue + if not module.reduce_output: + continue + adapter_prefix = module.lora.adapter_model_prefix + if not adapter_prefix.endswith((".o_proj", ".out_proj")): + continue original_forwards.append((module, module.forward)) def _mutated_forward(self: Any, x: Any): @@ -783,23 +871,6 @@ def _scaled_loss_fn(*args: Any, **kwargs: Any): ) -def _build_sft_trajectory_tensors_from_packed_tensors( - packed_tensors: PackedTensors, -) -> list[dict[str, torch.Tensor]]: - tokens = packed_tensors["tokens"] - assistant_mask = packed_tensors["assistant_mask"] - labels = torch.where(assistant_mask, tokens, torch.full_like(tokens, -100)) - attention_mask = torch.ones_like(tokens, dtype=torch.long) - return [ - { - "input_ids": tokens[index].detach().clone(), - "attention_mask": attention_mask[index].detach().clone(), - "labels": labels[index].detach().clone(), - } - for index in range(int(tokens.shape[0])) - ] - - def _worker_run(request: WorkerRunRequest) -> None: """Executes one full distributed training trace generation worker run.""" from safetensors.torch import load_file, save_file # ty: ignore[unresolved-import] @@ -812,22 +883,33 @@ def _worker_run(request: WorkerRunRequest) -> None: local_rank = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(local_rank) torch.distributed.init_process_group(backend="nccl") # ty: ignore[possibly-missing-attribute] + _enable_debug_traceback_dump() _set_deterministic_seed(request.case_config.seed) _configure_cuda_precision(request.case_config) - with _patch_get_provider_for_oracle( - megatron_train, request.topology, request.case_config - ): - runtime = megatron_train.build_training_runtime( - model_identifier=request.case_config.base_model, - provider_torch_dtype=( - torch.float32 - if request.case_config.precision == "fp32" - else torch.bfloat16 - ), - optimizer_config=_build_optimizer_config(request.case_config), - print_env=False, + with provider_topology_env(request.topology): + _debug( + f"starting build_training_runtime objective={request.objective} " + f"topology={request.topology.slug()} local_rank={local_rank}" ) + with _patch_finalize_provider_bundle_for_oracle( + megatron_train, request.case_config + ): + runtime = megatron_train.build_training_runtime( + model_identifier=request.case_config.base_model, + provider_torch_dtype=( + torch.float32 + if request.case_config.precision == "fp32" + else torch.bfloat16 + ), + provider_configure=lambda provider: _configure_provider( + provider, request.topology, request.case_config + ), + optimizer_config=_build_optimizer_config(request.case_config), + print_env=False, + allow_unvalidated_arch=request.case_config.allow_unvalidated_arch, + ) + _debug("finished build_training_runtime") model_chunks = runtime.model optimizer = runtime.optimizer megatron_train.configure_moe_routing_replay( @@ -878,7 +960,7 @@ def _worker_run(request: WorkerRunRequest) -> None: template = megatron_train.select_indexed_inputs(packed_tensors, 0) rl_zero_template = megatron_train._zero_contribution_inputs(template) else: - sft_trajectory_tensors = _build_sft_trajectory_tensors_from_packed_tensors( + sft_trajectory_tensors = build_sft_trajectory_tensors_from_packed_tensors( packed_tensors ) sft_zero_template = megatron_train._zero_contribution_sft_inputs( @@ -905,6 +987,7 @@ def 
_worker_run(request: WorkerRunRequest) -> None: model_chunks, enabled=True, micro_start_callback=micro_start_callback, + strict_output_match=request.mutation is None, ) def _capture_lora_grads() -> None: @@ -922,6 +1005,7 @@ def _capture_lora_grads() -> None: ), _patch_lora_for_fp32(model_chunks, optimizer), ): + _debug("starting training loop") for step_index in range(request.case_config.num_steps): micro_sample_indices = megatron_train.build_micro_sample_indices( step_index=step_index, @@ -930,6 +1014,7 @@ def _capture_lora_grads() -> None: ) forward_trace_capture.set_step(step_index, micro_sample_indices) captured_grads = None + _debug(f"starting step_index={step_index}") if request.objective == "rl": micro_inputs = megatron_train.select_micro_inputs( packed_tensors, @@ -938,6 +1023,7 @@ def _capture_lora_grads() -> None: ) step_result = megatron_train.run_training_step( model_chunks=model_chunks, + model_support_handler=runtime.model_support_handler, optimizer=optimizer, learning_rate=train_config.learning_rate, inputs=micro_inputs, @@ -956,6 +1042,7 @@ def _capture_lora_grads() -> None: ) step_result = megatron_train.run_megatron_sft_step( model_chunks=model_chunks, + model_support_handler=runtime.model_support_handler, optimizer=optimizer, learning_rate=train_config.learning_rate, inputs=micro_inputs, @@ -964,6 +1051,7 @@ def _capture_lora_grads() -> None: global_grad_accumulation_sequences=global_grad_accumulation_sequences, moe_routing_replay_controller=runtime.moe_routing_replay_controller, ) + _debug(f"finished step_index={step_index}") ordered_micro_outputs = forward_trace_capture.ordered_step_outputs() forward_trace_capture.save_current_step(traces_dir) torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] diff --git a/tests/integration/megatron/model_support/packed_position_ids.py b/tests/integration/megatron/model_support/packed_position_ids.py new file mode 100644 index 000000000..e29a0fbf4 --- /dev/null +++ b/tests/integration/megatron/model_support/packed_position_ids.py @@ -0,0 +1,969 @@ +from __future__ import annotations + +import argparse +import os +from pathlib import Path +import subprocess +import sys +import time +from typing import Any, cast + +from megatron.core import parallel_state as ps +from megatron.core.models.gpt.gpt_model import GPTModel +from pydantic import BaseModel, Field +import torch + +from art.megatron import train as megatron_train +from art.megatron.flex_attention import create_shared_prefix_attention_state +from art.megatron.model_support.discovery import inspect_architecture + +from .oracle_harness import ( + ORACLE_TOPOLOGY, + OracleCaseConfig, + PackedTensorConfig, + _read_json, + _write_json, +) +from .oracle_worker import _configure_provider, provider_topology_env + +# Qwen3.5/3.6 hybrid MoE runs show small shape-dependent logit drift between +# the single packed forward and many shorter reference forwards, even when the +# rotary grouping and shared-prefix semantics are correct. Keep the bound tight, +# but above the observed ~0.13% truncate-case jitter. 
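For orientation, the bound defined just below is a relative drift percentage: the check later in this file accumulates the mean absolute difference between the packed-forward logits and the per-completion reference logits, then normalizes it by the mean absolute magnitude of the reference logits. A minimal standalone sketch of that metric follows; the function and tensor names here are illustrative only and are not part of the patch:

    import torch

    def mean_abs_pct_drift(
        packed_logits: torch.Tensor, reference_logits: torch.Tensor
    ) -> float:
        # Mean |packed - reference|, expressed as a percentage of the typical
        # |reference| magnitude, mirroring the normalization used in
        # _logits_equivalence_check below.
        mean_abs = (packed_logits - reference_logits).abs().mean().item()
        typical_abs = reference_logits.abs().mean().item()
        return mean_abs / (typical_abs + 1e-12) * 100.0

    # Against the 0.2 limit below, the observed ~0.13% truncate-case jitter
    # passes, while a 0.5% drift would fail.
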
+_LOGITS_MEAN_ABS_PCT_LIMIT = 0.2 +_DEBUG_ENV = "ART_PACKED_POSITION_IDS_DEBUG" +PACKED_POSITION_IDS_REPORT_FILENAME = "report.json" +REPO_ROOT = Path(__file__).resolve().parents[4] + + +def _slugify(value: str) -> str: + return value.lower().replace("/", "_").replace(".", "_").replace("-", "_") + + +def _artifact_dir(base_model: str) -> Path: + root = Path(__file__).resolve().parents[4] / ".local" / "model_support_validation" + path = root / _slugify(base_model) / "packed_position_ids" + path.mkdir(parents=True, exist_ok=True) + return path + + +def _debug_enabled() -> bool: + value = os.environ.get(_DEBUG_ENV, "") + return value not in ("", "0", "false", "False") + + +def _debug_log(message: str) -> None: + if _debug_enabled(): + print(f"[packed_position_ids] {message}", flush=True) + + +def _env_int(name: str, default: int) -> int: + raw = os.environ.get(name) + if raw is None or raw == "": + return default + return int(raw) + + +def _reset_vllm_compile_overrides() -> None: + """Undo vLLM's global Inductor compile-thread override for this test worker.""" + os.environ.pop("TORCHINDUCTOR_COMPILE_THREADS", None) + torch._inductor.config.compile_threads = ( + torch._inductor.config.decide_compile_threads() + ) + _debug_log( + f"reset inductor compile_threads={torch._inductor.config.compile_threads}" + ) + + +def _cuda_synchronize(device: torch.device | None = None) -> None: + if not torch.cuda.is_available(): + return + if device is None: + torch.cuda.synchronize() + return + torch.cuda.synchronize(device) + + +def _time_block( + label: str, + fn: Any, + *, + device: torch.device | None = None, +) -> Any: + _cuda_synchronize(device) + start = time.perf_counter() + result = fn() + _cuda_synchronize(device) + elapsed = time.perf_counter() - start + _debug_log(f"{label}: {elapsed:.3f}s") + return result + + +def _cleanup_distributed_state() -> None: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + if torch.distributed.is_initialized(): # type: ignore[possibly-missing-attribute] + torch.distributed.destroy_process_group() # type: ignore[possibly-missing-attribute] + + +def _locate_gpt_module(model_chunks: list[Any]) -> GPTModel: + for chunk in model_chunks: + module: Any = chunk + while hasattr(module, "module"): + module = module.module + if isinstance(module, GPTModel): + return module + language_model = getattr(module, "language_model", None) + if isinstance(language_model, GPTModel): + return language_model + raise RuntimeError("Failed to locate GPTModel for packed position id validation") + + +class PackedPositionIdScenario(BaseModel): + name: str + num_sequences: int + sequence_length: int + checked_token_count: int + prompt_family_count: int + repeated_position_key_count: int + rotary_grouping_checked: bool + rotary_grouping_respected: bool + completion_pair_count: int + logits_equivalent: bool + logits_mean_abs_pct: float + logits_max_abs_diff: float + matched: bool + + +class PackedPositionIdsReport(BaseModel): + base_model: str + output_dir: str + num_layers: int + scenarios: list[PackedPositionIdScenario] = Field(default_factory=list) + + +class PackedPositionIdsRunRequest(BaseModel): + base_model: str + num_layers: int + output_dir: str + allow_unvalidated_arch: bool = False + + +def _prompt_family_count(group_ids: torch.Tensor, parent_ids: torch.Tensor) -> int: + families = 0 + for row_index in range(int(group_ids.shape[0])): + valid_tokens = int((group_ids[row_index] != -1).sum().item()) + cursor = 0 + while cursor < valid_tokens: 
+ group_id = int(group_ids[row_index, cursor].item()) + parent_id = int(parent_ids[row_index, cursor].item()) + if group_id == parent_id: + families += 1 + while ( + cursor < valid_tokens + and int(group_ids[row_index, cursor].item()) == group_id + ): + cursor += 1 + return families + + +def _position_keys(position_ids: torch.Tensor) -> list[tuple[int, ...]]: + if position_ids.ndim == 1: + return [(int(value),) for value in position_ids.tolist()] + if position_ids.ndim == 2: + return [ + (int(position_ids[batch_index, token_index].item()),) + for batch_index in range(int(position_ids.shape[0])) + for token_index in range(int(position_ids.shape[1])) + ] + if position_ids.ndim == 3: + channel_first = position_ids.permute(1, 2, 0).contiguous() + return [ + tuple( + int(value) for value in channel_first[batch_index, token_index].tolist() + ) + for batch_index in range(int(channel_first.shape[0])) + for token_index in range(int(channel_first.shape[1])) + ] + raise ValueError( + f"Unsupported position_ids rank for packed position validation: {position_ids.ndim}" + ) + + +def _flatten_rotary_vectors( + rotary_output: torch.Tensor, + *, + position_ids: torch.Tensor, +) -> torch.Tensor: + sequence_length = int(position_ids.shape[-1]) + batch_size = int(position_ids.shape[-2]) if position_ids.ndim >= 2 else 1 + if rotary_output.ndim < 2 or rotary_output.shape[0] != sequence_length: + raise ValueError( + "Unexpected rotary output shape for packed position validation: " + f"{tuple(rotary_output.shape)} with position_ids shape {tuple(position_ids.shape)}" + ) + embedding_dim = int(rotary_output.shape[-1]) + vectors = rotary_output.reshape(sequence_length, -1, embedding_dim) + if vectors.shape[1] != batch_size: + raise ValueError( + "Rotary output batch/slot mismatch for packed position validation: " + f"got {vectors.shape[1]} slots for batch_size={batch_size}" + ) + return vectors.permute(1, 0, 2).reshape(batch_size * sequence_length, embedding_dim) + + +def _rotary_grouping_check( + rotary_output: torch.Tensor | None, + *, + position_ids: torch.Tensor, +) -> tuple[bool, bool, int]: + keys = _position_keys(position_ids) + key_counts: dict[tuple[int, ...], int] = {} + for key in keys: + key_counts[key] = key_counts.get(key, 0) + 1 + repeated_position_key_count = sum(1 for count in key_counts.values() if count > 1) + if rotary_output is None: + return False, True, repeated_position_key_count + vectors = _flatten_rotary_vectors(rotary_output, position_ids=position_ids) + first_vector_by_key: dict[tuple[int, ...], torch.Tensor] = {} + for key, vector in zip(keys, vectors, strict=True): + reference = first_vector_by_key.get(key) + if reference is None: + first_vector_by_key[key] = vector + continue + if not torch.equal(reference, vector): + return True, False, repeated_position_key_count + return True, True, repeated_position_key_count + + +def _build_art_realistic_packed_tensors( + config: PackedTensorConfig, + seed: int, +) -> dict[str, Any]: + if config.num_sequences <= 1: + raise ValueError("num_sequences must be greater than 1") + if config.prefill_tokens < 2: + raise ValueError( + "prefill_tokens must be at least 2 to build ART-style branch context" + ) + if config.sequence_length < 3: + raise ValueError( + "sequence_length must leave room for shared prompt, branch context, " + "and at least one trainable token" + ) + + shape = (config.num_sequences, config.sequence_length) + generator = torch.Generator().manual_seed(seed) + tokens = torch.zeros(shape, dtype=torch.long) + group_ids = torch.full(shape, 
-1, dtype=torch.long) + parent_ids = torch.full(shape, -1, dtype=torch.long) + input_pos = torch.zeros(shape, dtype=torch.long) + assistant_mask = torch.zeros(shape, dtype=torch.bool) + logprobs = torch.full(shape, float("nan"), dtype=torch.float32) + advantages = torch.zeros(shape, dtype=torch.float32) + weights = torch.zeros(shape, dtype=torch.float32) + + first_trainable_pos = max(2, min(config.sequence_length - 1, config.prefill_tokens)) + shared_prompt_length = first_trainable_pos - 1 + max_completion_tokens = max(1, config.sequence_length - first_trainable_pos) + base_completion_tokens = max(1, min(config.decode_tokens, max_completion_tokens)) + jitter_width = min(config.decode_tokens_jitter, max_completion_tokens - 1) + token_low = 10 + token_span = max(1, config.vocab_high - token_low) + + def _sample_completion_length() -> int: + if jitter_width > 0: + jitter = int( + torch.randint( + low=-jitter_width, + high=jitter_width + 1, + size=(1,), + generator=generator, + dtype=torch.long, + ).item() + ) + else: + jitter = 0 + return max(1, min(max_completion_tokens, base_completion_tokens + jitter)) + + def _sample_token_block(length: int) -> torch.Tensor: + return torch.randint( + low=token_low, + high=config.vocab_high, + size=(length,), + dtype=torch.long, + generator=generator, + ) + + def _sample_logprob_block(length: int) -> torch.Tensor: + return ( + torch.randn((length,), generator=generator, dtype=torch.float32) * 0.25 + - 1.75 + ) + + def _sample_advantage_value() -> float: + return float( + (torch.randn((1,), generator=generator, dtype=torch.float32) * 0.5).item() + ) + + def _write_prompt( + sequence_index: int, + cursor: int, + prompt_group_id: int, + ) -> tuple[int, int]: + prompt_tokens = _sample_token_block(first_trainable_pos) + prompt_end = cursor + shared_prompt_length + tokens[sequence_index, cursor:prompt_end] = prompt_tokens[:shared_prompt_length] + group_ids[sequence_index, cursor:prompt_end] = prompt_group_id + parent_ids[sequence_index, cursor:prompt_end] = prompt_group_id + input_pos[sequence_index, cursor:prompt_end] = torch.arange( + shared_prompt_length, + dtype=torch.long, + ) + return prompt_end, int(prompt_tokens[shared_prompt_length].item()) + + def _write_branch( + sequence_index: int, + cursor: int, + completion_group_id: int, + prompt_group_id: int, + context_token: int, + completion_length: int, + ) -> int: + branch_end = cursor + 1 + completion_length + tokens[sequence_index, cursor] = context_token + tokens[sequence_index, cursor + 1 : branch_end] = _sample_token_block( + completion_length + ) + group_ids[sequence_index, cursor:branch_end] = completion_group_id + parent_ids[sequence_index, cursor:branch_end] = prompt_group_id + input_pos[sequence_index, cursor:branch_end] = torch.arange( + shared_prompt_length, + shared_prompt_length + 1 + completion_length, + dtype=torch.long, + ) + trainable_start = cursor + 1 + assistant_mask[sequence_index, trainable_start:branch_end] = True + logprobs[sequence_index, trainable_start:branch_end] = _sample_logprob_block( + completion_length + ) + advantages[sequence_index, trainable_start:branch_end] = ( + _sample_advantage_value() + ) + weights[sequence_index, trainable_start:branch_end] = 1.0 / completion_length + return branch_end + + for sequence_index in range(config.num_sequences): + cursor = 0 + next_group_id = 0 + while cursor < config.sequence_length: + prompt_group_id = next_group_id + next_group_id += 1 + completion_lengths = [ + _sample_completion_length() + for _ in 
range(config.completion_branches_per_prefix) + ] + remaining = config.sequence_length - cursor + if remaining <= shared_prompt_length + 1: + break + + if config.packing_mode == "stop_early": + included_completion_lengths = list(completion_lengths) + while included_completion_lengths and ( + shared_prompt_length + + sum(1 + length for length in included_completion_lengths) + > remaining + ): + included_completion_lengths.pop() + if not included_completion_lengths: + break + + cursor, context_token = _write_prompt( + sequence_index, + cursor, + prompt_group_id, + ) + for completion_length in included_completion_lengths: + completion_group_id = next_group_id + next_group_id += 1 + cursor = _write_branch( + sequence_index, + cursor, + completion_group_id, + prompt_group_id, + context_token, + completion_length, + ) + continue + + cursor, context_token = _write_prompt( + sequence_index, + cursor, + prompt_group_id, + ) + for completion_length in completion_lengths: + remaining = config.sequence_length - cursor + if remaining <= 1: + break + completion_take = min(completion_length, remaining - 1) + completion_group_id = next_group_id + next_group_id += 1 + cursor = _write_branch( + sequence_index, + cursor, + completion_group_id, + prompt_group_id, + context_token, + completion_take, + ) + + half = config.num_sequences // 2 + if half > 0 and config.num_sequences % 2 == 0: + valid_lengths = (group_ids != -1).sum(dim=1) + for pair_index in range(half): + left_index = pair_index + right_index = pair_index + half + left_valid = int(valid_lengths[left_index].item()) + right_valid = int(valid_lengths[right_index].item()) + if left_valid != right_valid or left_valid == 0: + continue + if torch.equal( + tokens[left_index, :left_valid], + tokens[right_index, :right_valid], + ): + tokens[right_index, 0] = ( + (tokens[right_index, 0] - token_low + 1) % token_span + ) + token_low + + weights = torch.where(assistant_mask, weights, torch.zeros_like(weights)) + if bool(assistant_mask.any().item()): + weights[assistant_mask] /= weights[assistant_mask].mean() + advantages = torch.where( + assistant_mask, + advantages, + torch.zeros_like(advantages), + ) + advantage_scale = ( + advantages[assistant_mask].abs() * weights[assistant_mask] + ).mean() + if float(advantage_scale.item()) > 0.0: + advantages[assistant_mask] /= advantage_scale + + return { + "tokens": tokens, + "group_ids": group_ids, + "parent_ids": parent_ids, + "input_pos": input_pos, + "assistant_mask": assistant_mask, + "logprobs": logprobs, + "advantages": advantages, + "weights": weights, + "pixel_values": [None] * config.num_sequences, + "image_grid_thw": [None] * config.num_sequences, + } + + +def _prompt_family_segments( + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + *, + required_completion_count: int = 2, +) -> list[tuple[tuple[int, int], list[tuple[int, int]]]]: + families: list[tuple[tuple[int, int], list[tuple[int, int]]]] = [] + valid_tokens = int((group_ids != -1).sum().item()) + cursor = 0 + while cursor < valid_tokens: + group_id = int(group_ids[cursor].item()) + parent_id = int(parent_ids[cursor].item()) + prompt_start = cursor + while cursor < valid_tokens and int(group_ids[cursor].item()) == group_id: + cursor += 1 + prompt_end = cursor + if group_id != parent_id: + continue + completions: list[tuple[int, int]] = [] + while cursor < valid_tokens: + completion_group_id = int(group_ids[cursor].item()) + completion_parent_id = int(parent_ids[cursor].item()) + if completion_parent_id != group_id or completion_group_id == 
group_id: + break + completion_start = cursor + while ( + cursor < valid_tokens + and int(group_ids[cursor].item()) == completion_group_id + ): + cursor += 1 + completions.append((completion_start, cursor)) + if len(completions) >= required_completion_count: + families.append(((prompt_start, prompt_end), completions)) + return families + + +def _run_logits( + *, + model: Any, + handler: Any, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_bias: Any, +) -> torch.Tensor: + forward_kwargs = handler.get_forward_kwargs( + model, + attention_bias=attention_bias, + ) + with torch.no_grad(): + return cast( + torch.Tensor, + model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=torch.zeros( + (1, 1, 1, 1), + dtype=torch.bool, + device=input_ids.device, + ), + labels=None, + **forward_kwargs, + ), + ) + + +def _logits_equivalence_check( + *, + model: Any, + handler: Any, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, +) -> tuple[int, bool, float, float]: + _debug_log( + "logits_check start " + f"batch={int(input_ids.shape[0])} seq={int(input_ids.shape[1])}" + ) + completion_pair_count = 0 + logits_max_abs_diff = 0.0 + logits_abs_sum = 0.0 + logits_ref_abs_sum = 0.0 + logits_numel = 0 + for row_index in range(int(input_ids.shape[0])): + row_group_ids = group_ids[row_index : row_index + 1] + row_parent_ids = parent_ids[row_index : row_index + 1] + families = _prompt_family_segments(row_group_ids[0], row_parent_ids[0]) + if not families: + _debug_log(f"logits_check row={row_index} skipped no prompt family") + continue + row_input_ids = input_ids[row_index : row_index + 1] + row_position_ids = position_ids[row_index : row_index + 1] + packed_bias = create_shared_prefix_attention_state( + group_ids=row_group_ids, + parent_ids=row_parent_ids, + ) + _debug_log(f"logits_check row={row_index} families={len(families)}") + packed_logits = _time_block( + f"logits_check row={row_index} packed_forward", + lambda: _run_logits( + model=model, + handler=handler, + input_ids=row_input_ids, + position_ids=row_position_ids, + attention_bias=packed_bias, + ), + device=row_input_ids.device, + ) + for family_index, (prompt_segment, completion_segments) in enumerate(families): + prompt_start, prompt_end = prompt_segment + _debug_log( + "logits_check row=" + f"{row_index} family={family_index} " + f"prompt=({prompt_start},{prompt_end}) " + f"completions={completion_segments}" + ) + for completion_index, (completion_start, completion_end) in enumerate( + completion_segments + ): + reference_input_ids = torch.cat( + ( + row_input_ids[:, prompt_start:prompt_end], + row_input_ids[:, completion_start:completion_end], + ), + dim=1, + ) + reference_position_ids = torch.cat( + ( + row_position_ids[:, prompt_start:prompt_end], + row_position_ids[:, completion_start:completion_end], + ), + dim=1, + ) + reference_group_ids = torch.zeros_like(reference_input_ids) + reference_parent_ids = torch.zeros_like(reference_input_ids) + reference_bias = create_shared_prefix_attention_state( + group_ids=reference_group_ids, + parent_ids=reference_parent_ids, + ) + _debug_log( + "logits_check row=" + f"{row_index} family={family_index} " + f"completion={completion_index} " + f"segment=({completion_start},{completion_end}) " + f"reference_seq={int(reference_input_ids.shape[1])}" + ) + reference_logits = _time_block( + ( + f"logits_check row={row_index} " + f"family={family_index} " + f"completion={completion_index} reference_forward" + ), + 
lambda: _run_logits( + model=model, + handler=handler, + input_ids=reference_input_ids, + position_ids=reference_position_ids, + attention_bias=reference_bias, + ), + device=reference_input_ids.device, + ) + if completion_end - completion_start < 2: + continue + packed_completion_logits = packed_logits[ + :, + completion_start : completion_end - 1, + :, + ] + reference_completion_logits = reference_logits[ + :, + prompt_end - prompt_start : -1, + :, + ] + diff = (packed_completion_logits - reference_completion_logits).abs() + logits_abs_sum += float(diff.sum().item()) + logits_ref_abs_sum += float( + reference_completion_logits.abs().sum().item() + ) + logits_numel += int(diff.numel()) + logits_max_abs_diff = max( + logits_max_abs_diff, + float(diff.max().item()), + ) + completion_pair_count += 1 + _debug_log( + "logits_check row=" + f"{row_index} family={family_index} " + f"completion={completion_index} " + f"max_abs_diff={float(diff.max().item()):.6f}" + ) + if completion_pair_count > 0: + mean_abs = logits_abs_sum / max(logits_numel, 1) + typical_abs = logits_ref_abs_sum / max(logits_numel, 1) + logits_mean_abs_pct = (mean_abs / (typical_abs + 1e-12)) * 100.0 + logits_equivalent = logits_mean_abs_pct <= _LOGITS_MEAN_ABS_PCT_LIMIT + _debug_log( + "logits_check done " + f"pairs={completion_pair_count} " + f"equivalent={logits_equivalent} " + f"mean_abs_pct={logits_mean_abs_pct:.6f} " + f"max_abs_diff={logits_max_abs_diff:.6f}" + ) + return ( + completion_pair_count, + logits_equivalent, + logits_mean_abs_pct, + logits_max_abs_diff, + ) + _debug_log("logits_check finished without any prompt family") + return 0, False, float("inf"), float("inf") + + +def _run_packed_position_ids_subprocess( + request: PackedPositionIdsRunRequest, + output_dir: Path, +) -> None: + request_path = output_dir / "run_request.json" + _write_json(request_path, request.model_dump(mode="json")) + worker_cwd = REPO_ROOT / "tests" + command = [ + sys.executable, + "-m", + "integration.megatron.model_support.packed_position_ids", + "--run-request", + str(request_path), + ] + env = {**os.environ, "PYTHONUNBUFFERED": "1"} + run = subprocess.run( + command, + cwd=str(worker_cwd), + env=env, + capture_output=True, + text=True, + check=False, + ) + combined_output = f"{run.stdout}\n{run.stderr}".strip() + (output_dir / "worker.log").write_text(combined_output + "\n", encoding="utf-8") + if run.returncode != 0: + tail = "\n".join(combined_output.splitlines()[-80:]) + raise RuntimeError( + "Packed position ids worker failed with exit code " + f"{run.returncode}.\n{tail}" + ) + + +def _run_packed_position_ids_worker( + *, + base_model: str, + num_layers: int, + output_dir: Path, + allow_unvalidated_arch: bool = False, +) -> PackedPositionIdsReport: + _debug_log(f"run start base_model={base_model} num_layers={num_layers}") + _reset_vllm_compile_overrides() + scenarios = [ + ( + "stop_early", + PackedTensorConfig( + num_sequences=4, + sequence_length=_env_int( + "ART_PACKED_POSITION_IDS_STOP_EARLY_SEQUENCE_LENGTH", 1024 + ), + prefill_tokens=_env_int( + "ART_PACKED_POSITION_IDS_STOP_EARLY_PREFILL_TOKENS", 256 + ), + completion_branches_per_prefix=2, + decode_tokens=_env_int( + "ART_PACKED_POSITION_IDS_STOP_EARLY_DECODE_TOKENS", 128 + ), + decode_tokens_jitter=_env_int( + "ART_PACKED_POSITION_IDS_STOP_EARLY_DECODE_TOKENS_JITTER", 32 + ), + packing_mode="stop_early", + ), + ), + ( + "truncate", + PackedTensorConfig( + num_sequences=4, + sequence_length=_env_int( + "ART_PACKED_POSITION_IDS_TRUNCATE_SEQUENCE_LENGTH", 1024 + ), + 
prefill_tokens=_env_int( + "ART_PACKED_POSITION_IDS_TRUNCATE_PREFILL_TOKENS", 256 + ), + completion_branches_per_prefix=2, + decode_tokens=_env_int( + "ART_PACKED_POSITION_IDS_TRUNCATE_DECODE_TOKENS", 128 + ), + decode_tokens_jitter=_env_int( + "ART_PACKED_POSITION_IDS_TRUNCATE_DECODE_TOKENS_JITTER", 32 + ), + packing_mode="truncate", + ), + ), + ] + report = PackedPositionIdsReport( + base_model=base_model, + output_dir=str(output_dir), + num_layers=num_layers, + ) + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for packed position id validation") + + case_config = OracleCaseConfig( + base_model=base_model, + precision="fp32", + num_layers=num_layers, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + runtime: megatron_train.TrainingRuntime | None = None + try: + with provider_topology_env(ORACLE_TOPOLOGY): + runtime = _time_block( + "build_training_runtime", + lambda: megatron_train.build_training_runtime( + model_identifier=base_model, + provider_torch_dtype=torch.float32, + provider_configure=lambda provider: _configure_provider( + provider, + ORACLE_TOPOLOGY, + case_config, + ), + print_env=False, + build_optimizer=False, + trainable_parameter_mode="base_model", + allow_unvalidated_arch=allow_unvalidated_arch, + ), + ) + model_chunks = cast(list[Any], runtime.model) + gpt_module = _locate_gpt_module(model_chunks) + for chunk in model_chunks: + chunk.eval() + hooked_preprocess = gpt_module._preprocess + + for scenario_name, packed_config in scenarios: + _debug_log( + f"scenario {scenario_name} start seq_len={packed_config.sequence_length}" + ) + packed_tensors = _time_block( + f"scenario {scenario_name} build_packed_tensors", + lambda: _build_art_realistic_packed_tensors( + packed_config, + case_config.seed, + ), + ) + position_ids = cast(torch.Tensor, packed_tensors["input_pos"]).cuda() + input_ids = cast(torch.Tensor, packed_tensors["tokens"]).cuda() + group_ids = cast(torch.Tensor, packed_tensors["group_ids"]).cuda() + parent_ids = cast(torch.Tensor, packed_tensors["parent_ids"]).cuda() + rotary_grouping_checked = False + rotary_grouping_respected = True + repeated_position_key_count = 0 + for row_index in range(int(position_ids.shape[0])): + row_position_ids = position_ids[row_index : row_index + 1] + row_input_ids = input_ids[row_index : row_index + 1] + hooked_output = _time_block( + f"scenario {scenario_name} row={row_index} hooked_preprocess", + lambda: hooked_preprocess( + input_ids=row_input_ids, + position_ids=row_position_ids, + ), + device=row_input_ids.device, + ) + rotary_output = hooked_output[1] + checked, respected, repeated_count = _rotary_grouping_check( + cast(torch.Tensor | None, rotary_output) + if torch.is_tensor(rotary_output) + else None, + position_ids=row_position_ids, + ) + rotary_grouping_checked = rotary_grouping_checked or checked + rotary_grouping_respected = rotary_grouping_respected and respected + repeated_position_key_count += repeated_count + _debug_log( + f"scenario {scenario_name} row={row_index} " + f"checked={checked} respected={respected} " + f"repeated_keys={repeated_count}" + ) + ( + completion_pair_count, + logits_equivalent, + logits_mean_abs_pct, + logits_max_abs_diff, + ) = _time_block( + f"scenario {scenario_name} logits_equivalence_check", + lambda: _logits_equivalence_check( + model=model_chunks[0], + handler=runtime.model_support_handler, + input_ids=input_ids, + position_ids=position_ids, + group_ids=group_ids, + parent_ids=parent_ids, + ), + device=input_ids.device, + ) + matched = ( + 
repeated_position_key_count > 0 + and completion_pair_count > 0 + and rotary_grouping_checked + and rotary_grouping_respected + and logits_equivalent + ) + _debug_log( + f"scenario {scenario_name} done matched={matched} " + f"pairs={completion_pair_count} logits_equivalent={logits_equivalent} " + f"logits_mean_abs_pct={logits_mean_abs_pct:.6f} " + f"logits_max_abs_diff={logits_max_abs_diff:.6f}" + ) + report.scenarios.append( + PackedPositionIdScenario( + name=scenario_name, + num_sequences=int(position_ids.shape[0]), + sequence_length=int(position_ids.shape[1]), + checked_token_count=int((group_ids != -1).sum().item()), + prompt_family_count=_prompt_family_count( + group_ids.cpu(), + parent_ids.cpu(), + ), + repeated_position_key_count=repeated_position_key_count, + rotary_grouping_checked=rotary_grouping_checked, + rotary_grouping_respected=rotary_grouping_respected, + completion_pair_count=completion_pair_count, + logits_equivalent=logits_equivalent, + logits_mean_abs_pct=logits_mean_abs_pct, + logits_max_abs_diff=logits_max_abs_diff, + matched=matched, + ) + ) + del model_chunks + torch.cuda.empty_cache() + _debug_log("run complete; model deleted and cuda cache emptied") + finally: + del runtime + torch.cuda.empty_cache() + _cleanup_distributed_state() + + (output_dir / PACKED_POSITION_IDS_REPORT_FILENAME).write_text( + report.model_dump_json(indent=2), + encoding="utf-8", + ) + return report + + +def run_packed_position_ids( + *, + base_model: str, + num_layers: int | None = None, + allow_unvalidated_arch: bool = False, +) -> PackedPositionIdsReport: + _debug_log(f"run start base_model={base_model} requested_num_layers={num_layers}") + resolved_num_layers = ( + max( + 1, + inspect_architecture( + base_model, + torch_dtype=torch.float32, + allow_unvalidated_arch=allow_unvalidated_arch, + ).recommended_min_layers, + ) + if num_layers is None + else num_layers + ) + _debug_log(f"run resolved_num_layers={resolved_num_layers}") + output_dir = _artifact_dir(base_model) + report_path = output_dir / PACKED_POSITION_IDS_REPORT_FILENAME + if report_path.exists(): + report_path.unlink() + request = PackedPositionIdsRunRequest( + base_model=base_model, + num_layers=resolved_num_layers, + output_dir=str(output_dir), + allow_unvalidated_arch=allow_unvalidated_arch, + ) + with provider_topology_env(ORACLE_TOPOLOGY): + _run_packed_position_ids_subprocess(request, output_dir) + return PackedPositionIdsReport.model_validate(_read_json(report_path)) + + +def run_worker_cli(run_request_path: Path) -> None: + request = PackedPositionIdsRunRequest.model_validate(_read_json(run_request_path)) + _run_packed_position_ids_worker( + base_model=request.base_model, + num_layers=request.num_layers, + output_dir=Path(request.output_dir), + allow_unvalidated_arch=request.allow_unvalidated_arch, + ) + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Megatron packed position ids worker") + parser.add_argument("--run-request", type=Path, required=True) + return parser.parse_args(argv) + + +def _main(argv: list[str]) -> int: + args = _parse_args(argv) + run_worker_cli(args.run_request) + return 0 + + +if __name__ == "__main__": + raise SystemExit(_main(sys.argv[1:])) diff --git a/tests/integration/megatron/model_support/test_compile_flags.py b/tests/integration/megatron/model_support/test_compile_flags.py new file mode 100644 index 000000000..0edac9a94 --- /dev/null +++ b/tests/integration/megatron/model_support/test_compile_flags.py @@ -0,0 +1,21 @@ +from 
art.megatron.model_support.handlers.qwen3_5 import QWEN3_5_MOE_HANDLER +from art.megatron.model_support.handlers.qwen3_moe import QWEN3_MOE_HANDLER + + +def test_qwen3_moe_compile_workarounds_cover_deepep_permute_restore() -> None: + config = QWEN3_MOE_HANDLER.compile_workaround_config(object()) + assert config.flags == ( + "alltoall_dtoh", + "alltoall_dispatch_preprocess", + "deepep_permute_restore", + ) + + +def test_qwen35_moe_compile_workarounds_cover_deepep_permute_restore() -> None: + provider = type("Provider", (), {"moe_shared_expert_overlap": False})() + config = QWEN3_5_MOE_HANDLER.compile_workaround_config(provider) + assert config.flags == ( + "alltoall_dtoh", + "alltoall_dispatch_preprocess", + "deepep_permute_restore", + ) diff --git a/tests/integration/megatron/model_support/test_hf_parity.py b/tests/integration/megatron/model_support/test_hf_parity.py new file mode 100644 index 000000000..631f0acf5 --- /dev/null +++ b/tests/integration/megatron/model_support/test_hf_parity.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import pytest + +from .hf_parity import HF_PARITY_ENABLE_ENV, hf_parity_enabled, run_hf_parity +from .oracle_harness import available_gpu_count, case_config + +HF_PARITY_LOG_PATH = Path(__file__).resolve().parents[4] / ".local" / "hf_parity.log" + + +def test_megatron_hf_sft_parity() -> None: + if not hf_parity_enabled(): + HF_PARITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + HF_PARITY_LOG_PATH.write_text( + f"HF parity skipped. Set {HF_PARITY_ENABLE_ENV}=1 to enable.\n", + encoding="utf-8", + ) + pytest.skip(f"Set {HF_PARITY_ENABLE_ENV}=1 to enable HF parity.") + if available_gpu_count() < 1: + HF_PARITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + HF_PARITY_LOG_PATH.write_text( + "HF parity skipped. Need at least 1 GPU.\n", + encoding="utf-8", + ) + pytest.skip("Need at least 1 GPU for HF parity.") + report = run_hf_parity( + case_config=case_config(base_model="Qwen/Qwen3.5-35B-A3B"), + ) + HF_PARITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + HF_PARITY_LOG_PATH.write_text( + f"HF parity report: {report.model_dump_json(indent=2)}\n", + encoding="utf-8", + ) + assert report.signal == "pass" diff --git a/tests/integration/megatron/model_support/test_hf_parity_invariants.py b/tests/integration/megatron/model_support/test_hf_parity_invariants.py new file mode 100644 index 000000000..24345136c --- /dev/null +++ b/tests/integration/megatron/model_support/test_hf_parity_invariants.py @@ -0,0 +1,347 @@ +from types import SimpleNamespace +from typing import Any, cast + +import pytest +import torch + +from art.megatron.model_support.handlers import QWEN3_5_MOE_HANDLER +from art.megatron.model_support.spec import MinimalLayerCoverageReport + +from . import hf_parity as hf_parity_module +from . 
import hf_parity_worker as hf_parity_worker_module +from .hf_parity import ( + HF_PARITY_OUTPUT_DIRNAME, + HF_PARITY_REPORT_FILENAME, + HfParityReport, + HfParityRunRequest, + build_parity_sample_indices, + build_tensor_map_metric_rows, + run_hf_parity, + set_hf_config_num_layers, +) +from .hf_parity_worker import ( + _build_megatron_runtime, + _filter_language_only_tensor_map, + _is_language_hf_param_name, + _mapping_supports_derivative_parity, + _normalize_hf_grads_for_bridge, + _normalize_hf_tensor_map_for_bridge, +) +from .oracle_harness import DiskPackedTensorsSpec, OracleCaseConfig + + +def test_build_parity_sample_indices_pads_with_none() -> None: + assert build_parity_sample_indices( + num_sequences=2, + global_grad_accumulation_sequences=4, + ) == [0, 1, None, None] + + +def test_set_hf_config_num_layers_updates_supported_field() -> None: + config = SimpleNamespace(num_hidden_layers=28) + + field = set_hf_config_num_layers(config, 4) + + assert field == "num_hidden_layers" + assert config.num_hidden_layers == 4 + + +def test_set_hf_config_num_layers_updates_nested_text_config() -> None: + text_config = SimpleNamespace( + num_hidden_layers=40, + layer_types=["linear_attention", "linear_attention", "full_attention"] * 2, + mlp_only_layers=[1, 4, 7], + ) + config = SimpleNamespace(text_config=text_config) + + field = set_hf_config_num_layers(config, 4) + + assert field == "text_config.num_hidden_layers" + assert text_config.num_hidden_layers == 4 + assert text_config.layer_types == [ + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + ] + assert text_config.mlp_only_layers == [1] + + +def test_run_hf_parity_rejects_uncovered_toy_model(monkeypatch) -> None: + monkeypatch.setattr( + hf_parity_module, + "assess_minimal_layer_coverage", + lambda **_: SimpleNamespace( + covered=False, + missing_layer_families=["standard_attention"], + unresolved_risks=[], + ), + ) + + with pytest.raises( + AssertionError, + match="HF parity toy model does not cover required layer families", + ): + run_hf_parity( + case_config=OracleCaseConfig( + base_model="Qwen/Qwen3.5-35B-A3B", + num_layers=2, + ) + ) + + +def test_run_hf_parity_always_reruns_existing_report( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + coverage = MinimalLayerCoverageReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + requested_num_layers=4, + recommended_min_layers=4, + covered=True, + ) + case_dir = tmp_path / "case" + output_dir = case_dir / HF_PARITY_OUTPUT_DIRNAME + output_dir.mkdir(parents=True) + stale_report = HfParityReport( + case_id="stale", + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + requested_num_layers=4, + coverage=coverage, + signal="pass", + pass_count=99, + fail_count=0, + ) + (output_dir / HF_PARITY_REPORT_FILENAME).write_text( + stale_report.model_dump_json(indent=2), + encoding="utf-8", + ) + + monkeypatch.setattr( + hf_parity_module, + "assess_minimal_layer_coverage", + lambda **_: coverage, + ) + monkeypatch.setattr( + hf_parity_module, + "ensure_case_artifacts", + lambda _: SimpleNamespace( + case_id="fresh-case", + case_dir=str(case_dir), + packed_tensors=DiskPackedTensorsSpec( + dir=str(case_dir / "packed"), + num_sequences=4, + sequence_length=8, + ), + ), + ) + calls: list[str] = [] + + def _fake_subprocess(request, run_output_dir): + calls.append(request.case_id) + fresh_report = HfParityReport( + case_id=request.case_id, + base_model=request.case_config.base_model, + model_key=request.coverage.model_key, + 
requested_num_layers=request.case_config.num_layers, + coverage=request.coverage, + signal="pass", + pass_count=1, + fail_count=0, + ) + (run_output_dir / HF_PARITY_REPORT_FILENAME).write_text( + fresh_report.model_dump_json(indent=2), + encoding="utf-8", + ) + + monkeypatch.setattr(hf_parity_module, "run_hf_parity_subprocess", _fake_subprocess) + + report = run_hf_parity( + case_config=OracleCaseConfig(base_model="Qwen/Qwen3.5-35B-A3B") + ) + + assert calls == ["fresh-case"] + assert report.case_id == "fresh-case" + assert report.pass_count == 1 + + +def test_run_hf_parity_subprocess_does_not_override_recompute( + monkeypatch, tmp_path +) -> None: + request = HfParityRunRequest( + case_id="case-id", + case_config=OracleCaseConfig(base_model="Qwen/Qwen3.5-35B-A3B"), + packed_tensors=DiskPackedTensorsSpec( + dir=str(tmp_path / "packed"), + num_sequences=4, + sequence_length=8, + ), + output_dir=str(tmp_path), + coverage=MinimalLayerCoverageReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + requested_num_layers=4, + recommended_min_layers=4, + covered=True, + ), + ) + captured: dict[str, Any] = {} + + def _fake_run(*args, **kwargs): + del args + captured.update(kwargs) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + monkeypatch.setattr(hf_parity_module.subprocess, "run", _fake_run) + + hf_parity_module.run_hf_parity_subprocess(request, tmp_path) + + env = cast(dict[str, str], captured["env"]) + assert "ART_MEGATRON_RECOMPUTE_GRANULARITY" not in env + assert "ART_MEGATRON_RECOMPUTE_METHOD" not in env + assert "ART_MEGATRON_RECOMPUTE_NUM_LAYERS" not in env + assert "ART_MEGATRON_RECOMPUTE_MODULES" not in env + + +def test_normalize_hf_tensor_map_for_bridge_adds_language_model_prefix() -> None: + normalized = _normalize_hf_tensor_map_for_bridge( + { + "model.layers.0.input_layernorm.weight": torch.ones(1), + "lm_head.weight": torch.ones(1), + }, + { + "model.language_model.layers.0.input_layernorm.weight", + "lm_head.weight", + }, + ) + + assert set(normalized) == { + "model.language_model.layers.0.input_layernorm.weight", + "lm_head.weight", + } + + +def test_build_tensor_map_metric_rows_rejects_tensor_set_mismatch() -> None: + rows = build_tensor_map_metric_rows( + phase="grads", + reference={"a": torch.ones(1)}, + candidate={"b": torch.ones(1)}, + ) + + assert len(rows) == 1 + assert rows[0].param == "__tensor_set__" + assert rows[0].pass_signal is False + assert "missing=['a'] extra=['b']" in rows[0].failure_reasons[0] + + +def test_build_tensor_map_metric_rows_enforces_nonzero_per_tensor() -> None: + rows = build_tensor_map_metric_rows( + phase="grads", + reference={"all_zero": torch.zeros(2), "active": torch.ones(2)}, + candidate={"all_zero": torch.zeros(2), "active": torch.ones(2)}, + ) + by_param = {row.param: row for row in rows} + + assert by_param["all_zero"].pass_signal is False + assert by_param["active"].pass_signal is True + + +def test_language_hf_param_filter_keeps_text_and_drops_visual() -> None: + assert _is_language_hf_param_name("model.layers.0.self_attn.q_proj.weight") is True + assert _is_language_hf_param_name("model.visual.blocks.0.attn.qkv.weight") is False + filtered = _filter_language_only_tensor_map( + { + "model.layers.0.self_attn.q_proj.weight": torch.ones(1), + "model.visual.blocks.0.attn.qkv.weight": torch.ones(1), + } + ) + assert set(filtered) == {"model.layers.0.self_attn.q_proj.weight"} + assert torch.equal( + filtered["model.layers.0.self_attn.q_proj.weight"], + torch.ones(1), + ) + + +def 
test_normalize_hf_grads_for_bridge_keeps_expected_key_set() -> None: + normalized = _normalize_hf_grads_for_bridge( + { + "model.layers.0.input_layernorm.weight": torch.ones(1), + "lm_head.weight": torch.ones(1), + "model.visual.blocks.0.attn.qkv.weight": torch.ones(1), + }, + expected_grad_keys={ + "model.language_model.layers.0.input_layernorm.weight", + "lm_head.weight", + }, + model_support_handler=QWEN3_5_MOE_HANDLER, + ) + + assert set(normalized) == { + "model.language_model.layers.0.input_layernorm.weight", + "lm_head.weight", + } + + +def test_build_megatron_runtime_uses_training_provider_bundle( + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[dict[str, Any]] = [] + runtime = SimpleNamespace(provider="provider", model=["model"]) + + monkeypatch.setattr( + hf_parity_worker_module.megatron_train, + "build_training_runtime", + lambda **kwargs: calls.append(kwargs) or runtime, + ) + + request = HfParityRunRequest( + case_id="case", + case_config=OracleCaseConfig(base_model="Qwen/Qwen3.5-35B-A3B"), + packed_tensors=DiskPackedTensorsSpec( + dir="/tmp", num_sequences=4, sequence_length=8 + ), + output_dir="/tmp/out", + coverage=MinimalLayerCoverageReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + requested_num_layers=4, + recommended_min_layers=4, + covered=True, + ), + ) + + built_runtime = _build_megatron_runtime(request) + + assert built_runtime is runtime + assert len(calls) == 1 + kwargs = calls[0] + assert kwargs["model_identifier"] == "Qwen/Qwen3.5-35B-A3B" + assert kwargs["provider_torch_dtype"] == torch.float32 + assert ( + kwargs["provider_bundle_configure"] + is hf_parity_worker_module._install_bridge_timing_debug + ) + assert kwargs["print_env"] is False + assert kwargs["trainable_parameter_mode"] == "base_model" + configured_provider = SimpleNamespace() + kwargs["provider_configure"](configured_provider) + optimizer_config = kwargs["optimizer_config"] + assert configured_provider.num_layers == request.case_config.num_layers + assert optimizer_config.params_dtype == torch.float32 + + +def test_mapping_supports_derivative_parity_rejects_affine_weight_exports() -> None: + from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + RMSNorm2ZeroCenteredRMSNormMapping, + ) + + assert _mapping_supports_derivative_parity(AutoMapping("a", "b")) is True + assert ( + _mapping_supports_derivative_parity( + RMSNorm2ZeroCenteredRMSNormMapping("a", "b") + ) + is False + ) diff --git a/tests/integration/megatron/model_support/test_inputs.py b/tests/integration/megatron/model_support/test_inputs.py new file mode 100644 index 000000000..817ef18b4 --- /dev/null +++ b/tests/integration/megatron/model_support/test_inputs.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import torch + +from art.preprocessing.pack import PackedTensors + + +@torch.no_grad() +def build_sft_trajectory_tensors_from_packed_tensors( + packed_tensors: PackedTensors, +) -> list[dict[str, torch.Tensor]]: + tokens = packed_tensors["tokens"] + assistant_mask = packed_tensors["assistant_mask"] + labels = torch.where(assistant_mask, tokens, torch.full_like(tokens, -100)) + attention_mask = torch.ones_like(tokens, dtype=torch.long) + return [ + { + "input_ids": tokens[index].detach().clone(), + "attention_mask": attention_mask[index].detach().clone(), + "labels": labels[index].detach().clone(), + } + for index in range(int(tokens.shape[0])) + ] diff --git a/tests/integration/test_megatron_lora_oracle_correctness.py 
b/tests/integration/megatron/model_support/test_lora_oracle_correctness.py similarity index 79% rename from tests/integration/test_megatron_lora_oracle_correctness.py rename to tests/integration/megatron/model_support/test_lora_oracle_correctness.py index 0f02c4052..c66e87482 100644 --- a/tests/integration/test_megatron_lora_oracle_correctness.py +++ b/tests/integration/megatron/model_support/test_lora_oracle_correctness.py @@ -4,21 +4,18 @@ import pytest -from .megatron_oracle_harness import ( - EXTENDED_TOPOLOGIES, +from .oracle_harness import ( + ORACLE_TOPOLOGY, SENSITIVITY_MUTATION_ENV, - TOPOLOGIES, available_gpu_count, case_config, - extended_topologies_enabled, run_sensitivity_suite, run_suite, sensitivity_enabled, sensitivity_mutations, - sensitivity_required_world_size, ) -REPO_ROOT = Path(__file__).resolve().parents[2] +REPO_ROOT = Path(__file__).resolve().parents[4] CORRECTNESS_LOG_PATH = REPO_ROOT / ".local" / "correctness.log" SENSITIVITY_LOG_PATH = REPO_ROOT / ".local" / "sensitivity.log" @@ -51,34 +48,27 @@ def _require_gpus_for(topology_world_size: int) -> None: ) -def _suite_world_size() -> int: - suite_topologies = list(TOPOLOGIES) - if extended_topologies_enabled(): - suite_topologies.extend(EXTENDED_TOPOLOGIES) - return max(topology.world_size() for topology in suite_topologies) - - def test_megatron_lora_topology_suite(capsys: pytest.CaptureFixture[str]) -> None: """ Runs the suite of topologies and expects each to pass (numerical differences within our thresholds) """ _announce_report_log(log_path=CORRECTNESS_LOG_PATH, capsys=capsys) - suite_world_size = _suite_world_size() gpu_count = available_gpu_count() - if gpu_count < suite_world_size: + if gpu_count < ORACLE_TOPOLOGY.world_size(): CORRECTNESS_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) CORRECTNESS_LOG_PATH.write_text( ( "Topology suite skipped. " - f"Need {suite_world_size} GPUs, found {gpu_count}.\n" + f"Need {ORACLE_TOPOLOGY.world_size()} GPUs, found {gpu_count}.\n" ), encoding="utf-8", ) - _require_gpus_for(suite_world_size) + _require_gpus_for(ORACLE_TOPOLOGY.world_size()) _run_suite_with_log( log_path=CORRECTNESS_LOG_PATH, run=lambda: run_suite( case_config=case_config(), + max_world_size=gpu_count, ), ) @@ -105,22 +95,22 @@ def test_megatron_lora_diff_sensitivity(capsys: pytest.CaptureFixture[str]) -> N ) mutations = sensitivity_mutations() assert mutations - sensitivity_world_size = sensitivity_required_world_size(mutations) gpu_count = available_gpu_count() - if gpu_count < sensitivity_world_size: + if gpu_count < ORACLE_TOPOLOGY.world_size(): SENSITIVITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) SENSITIVITY_LOG_PATH.write_text( ( "Sensitivity suite skipped. 
" - f"Need {sensitivity_world_size} GPUs, found {gpu_count}.\n" + f"Need {ORACLE_TOPOLOGY.world_size()} GPUs, found {gpu_count}.\n" ), encoding="utf-8", ) - _require_gpus_for(sensitivity_world_size) + _require_gpus_for(ORACLE_TOPOLOGY.world_size()) _run_suite_with_log( log_path=SENSITIVITY_LOG_PATH, run=lambda: run_sensitivity_suite( case_config=case_config(), mutations=mutations, + max_world_size=gpu_count, ), ) diff --git a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py new file mode 100644 index 000000000..194b4d24d --- /dev/null +++ b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py @@ -0,0 +1,131 @@ +import torch + +from .forward_trace import ForwardTraceCapture +from .oracle_harness import ( + DENSE_ORACLE_TOPOLOGY, + ORACLE_TOPOLOGY, + DiffAccumulator, + MetricThresholdRule, + _default_phase_pass_fns, + _suite_variants, + selected_sensitivity_mutations_for_objective, +) + + +def test_metric_threshold_rule_can_require_strictly_positive_values() -> None: + rule = MetricThresholdRule(minimums={"candidate_abs_scale": 0.0}) + + summary = {"candidate_abs_scale": 0.0} + + assert not rule(summary) + assert rule.failure_reasons(summary) == ["candidate_abs_scale=0<=0"] + + +def test_diff_accumulator_summary_tracks_candidate_abs_scale() -> None: + accumulator = DiffAccumulator() + + accumulator.update( + torch.tensor([1.0, -2.0], dtype=torch.float32), + torch.tensor([0.5, 0.0], dtype=torch.float32), + ) + + summary = accumulator.as_summary() + + assert summary["typical_abs_scale"] == 1.5 + assert summary["candidate_abs_scale"] == 0.25 + + +def test_default_phase_rules_require_non_zero_forward_outputs_losses_grads_and_deltas() -> ( + None +): + phase_pass = _default_phase_pass_fns() + zero_signal_summary = { + "relative_l2": 0.0, + "mean_abs_pct": 0.0, + "typical_abs_scale": 0.0, + "candidate_abs_scale": 0.0, + } + + assert not phase_pass["forward"](zero_signal_summary) + assert not phase_pass["outputs"](zero_signal_summary) + assert not phase_pass["losses"](zero_signal_summary) + assert not phase_pass["grads"](zero_signal_summary) + assert not phase_pass["deltas"](zero_signal_summary) + + +def test_suite_variants_skip_duplicate_oracle_replay_variant() -> None: + variants = _suite_variants("rl", is_moe=True) + + assert variants + assert all(variant.topology != ORACLE_TOPOLOGY for variant in variants) + assert all("oracle_replay" not in variant.name for variant in variants) + + +def test_dense_suite_variants_include_tp2_dp2_without_oracle_duplicate() -> None: + variants = _suite_variants("rl", is_moe=False) + + assert variants + assert all(variant.topology != DENSE_ORACLE_TOPOLOGY for variant in variants) + assert any( + variant.topology.tp == 2 and variant.topology.dp == 2 for variant in variants + ) + + +def test_moe_suite_variants_include_dp2_ep_and_etp_topologies() -> None: + variants = _suite_variants("rl", is_moe=True) + + assert any( + variant.topology.tp == 1 + and variant.topology.ep == 2 + and variant.topology.dp == 2 + for variant in variants + ) + assert any( + variant.topology.tp == 1 + and variant.topology.etp == 2 + and variant.topology.dp == 2 + for variant in variants + ) + + +def test_max_world_size_arg_filters_dense_variants() -> None: + variants = _suite_variants("rl", is_moe=False, max_world_size=2) + + assert variants + assert all(variant.topology.world_size() <= 2 for variant in variants) + assert not any( + variant.topology.tp == 2 and 
variant.topology.dp == 2 for variant in variants + ) + + +def test_max_world_size_arg_filters_sensitivity_mutations() -> None: + mutations = selected_sensitivity_mutations_for_objective( + "rl", + ["skip_finalize", "dp_local_token_normalization"], + is_moe=True, + max_world_size=1, + ) + + assert mutations == [] + + +def test_gate_up_rank_interleaved_trace_layout_canonicalizes_dense_tp() -> None: + canonical = torch.arange(16, dtype=torch.float32).reshape(2, 1, 8) + gate0, gate1, up0, up1 = canonical.chunk(4, dim=-1) + rank_concat = torch.cat((gate0, up0, gate1, up1), dim=-1) + + actual = ForwardTraceCapture._canonicalize_primary_output_tensor( + module_name="chunk0.module.decoder.layers.0.mlp.linear_fc1", + tensor=rank_concat, + call={ + "merge_hints": { + "primary_output": { + "layout": "gate_up_rank_interleaved", + "world_size_key": "tp_world_size", + } + }, + "rank_meta": [{"tp_world_size": 2}, {"tp_world_size": 2}], + }, + ) + + assert torch.equal(actual, canonical) diff --git a/tests/integration/megatron/model_support/test_packed_position_ids.py b/tests/integration/megatron/model_support/test_packed_position_ids.py new file mode 100644 index 000000000..d3f2abf0f --- /dev/null +++ b/tests/integration/megatron/model_support/test_packed_position_ids.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("megatron.bridge") + +from .packed_position_ids import run_packed_position_ids + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA is required for packed position id validation", +) +def test_run_packed_position_ids_qwen35() -> None: + report = run_packed_position_ids( + base_model="Qwen/Qwen3.5-35B-A3B", + ) + + assert len(report.scenarios) == 2 + assert all(scenario.matched for scenario in report.scenarios) + assert all(scenario.checked_token_count > 0 for scenario in report.scenarios) + assert all(scenario.prompt_family_count >= 2 for scenario in report.scenarios) + assert all(scenario.rotary_grouping_checked for scenario in report.scenarios) + assert all( + scenario.repeated_position_key_count > 0 for scenario in report.scenarios + ) + assert all(scenario.completion_pair_count > 0 for scenario in report.scenarios) + assert all(scenario.logits_mean_abs_pct <= 0.2 for scenario in report.scenarios) diff --git a/tests/integration/megatron/model_support/test_provider_support.py b/tests/integration/megatron/model_support/test_provider_support.py new file mode 100644 index 000000000..7f1ce9703 --- /dev/null +++ b/tests/integration/megatron/model_support/test_provider_support.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +pytest.importorskip("megatron.bridge") + +from megatron.core.transformer.enums import AttnBackend + +from art.megatron.flex_attention import FlexDotProductAttention +from art.megatron.model_support.registry import UnsupportedModelArchitectureError +import art.megatron.provider as provider_module + + +class _FakeProvider: + def __init__(self) -> None: + self.transformer_layer_spec = self._base_layer_spec + self.finalized = False + self.overlap_moe_expert_parallel_comm = False + self.num_moe_experts = 0 + + def _base_layer_spec( + self, config: object, vp_stage: int | None = None + ) -> SimpleNamespace: + del config, vp_stage + return SimpleNamespace( + submodules=SimpleNamespace( + self_attention=SimpleNamespace( + submodules=SimpleNamespace(core_attention=object()) + ) + ), + ) + 
+ def finalize(self) -> None: + self.finalized = True + + +class _FakeHybridProvider(_FakeProvider): + def _base_layer_spec( + self, config: object, vp_stage: int | None = None + ) -> SimpleNamespace: + del config, vp_stage + gdn_layer = SimpleNamespace( + submodules=SimpleNamespace( + self_attention=SimpleNamespace(submodules=SimpleNamespace()) + ) + ) + attention_layer = SimpleNamespace( + submodules=SimpleNamespace( + self_attention=SimpleNamespace( + submodules=SimpleNamespace(core_attention=object()) + ) + ), + ) + return SimpleNamespace(layer_specs=[gdn_layer, attention_layer]) + + +class _FakeBridge: + def __init__(self, *, model_bridge: object, provider: _FakeProvider) -> None: + self._model_bridge = model_bridge + self._provider = provider + self.hf_pretrained = SimpleNamespace(model_name_or_path="unused") + + def to_megatron_provider(self) -> _FakeProvider: + return self._provider + + +def test_get_provider_accepts_registry_supported_models( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _FakeProvider() + provider.num_moe_experts = 8 + fake_bridge = _FakeBridge( + model_bridge=object(), + provider=provider, + ) + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 2) + + resolved = provider_module.get_provider("Qwen/Qwen3-30B-A3B-Instruct-2507") + + assert resolved is provider + assert provider.finalized is True + assert resolved.attention_backend is AttnBackend.auto + assert resolved.recompute_granularity == "full" + assert resolved.recompute_method == "uniform" + assert resolved.recompute_num_layers == 1 + assert resolved.tensor_model_parallel_size == 2 + assert resolved.context_parallel_size == 1 + assert resolved.pipeline_model_parallel_size == 1 + assert resolved.expert_model_parallel_size == 2 + assert resolved.expert_tensor_parallel_size == 1 + assert resolved.sequence_parallel is True + assert resolved.moe_shared_expert_overlap is False + assert resolved.moe_router_dtype == "fp32" + assert resolved.moe_aux_loss_coeff == 0.0 + assert resolved.calculate_per_token_loss is True + + layer_spec = cast(Any, resolved.transformer_layer_spec)(resolved, vp_stage=7) + assert ( + layer_spec.submodules.self_attention.submodules.core_attention + is FlexDotProductAttention + ) + + +def test_qwen35_provider_uses_handler_shared_expert_runtime_default( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from art.megatron.model_support.handlers import qwen3_5 as qwen35_handler_module + + provider = _FakeProvider() + fake_bridge = _FakeBridge( + model_bridge=object(), + provider=provider, + ) + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr( + qwen35_handler_module, + "_qwen35_provider_types", + lambda: (_FakeProvider,), + ) + monkeypatch.setattr( + qwen35_handler_module, + "_require_qwen35_provider_symbols", + lambda: ( + object(), + (_FakeProvider,), + lambda block_spec, attention_module: None, + provider._base_layer_spec, + ), + ) + + resolved = provider_module.get_provider("Qwen/Qwen3.5-35B-A3B") + + assert resolved.moe_shared_expert_overlap is False + assert resolved.scatter_embedding_sequence_parallel is True + + +def test_get_provider_rejects_unregistered_model_before_bridge( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def from_hf_pretrained(*args: object, **kwargs: 
object) -> object: + raise AssertionError("AutoBridge should not be called for unsupported models") + + monkeypatch.setattr( + provider_module.AutoBridge, "from_hf_pretrained", from_hf_pretrained + ) + + with pytest.raises( + UnsupportedModelArchitectureError, + match="has not passed the Megatron model-support workflow", + ): + provider_module.get_provider("unsupported/model") + + +def test_get_provider_preserves_hybrid_layer_specs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _FakeHybridProvider() + fake_bridge = _FakeBridge( + model_bridge=object(), + provider=provider, + ) + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 1) + + resolved = provider_module.get_provider( + "unused-qwen", + allow_unvalidated_arch=True, + ) + layer_spec = cast(Any, resolved).transformer_layer_spec(resolved, vp_stage=0) + + assert hasattr(layer_spec, "layer_specs") + gdn_layer, attention_layer = cast(Any, layer_spec).layer_specs + assert not hasattr(gdn_layer.submodules.self_attention.submodules, "core_attention") + assert ( + attention_layer.submodules.self_attention.submodules.core_attention + is FlexDotProductAttention + ) + + +def test_finalize_provider_bundle_uses_post_prepare_topology( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _FakeProvider() + setattr(provider, "num_moe_experts", 8) + fake_bridge = _FakeBridge( + model_bridge=object(), + provider=provider, + ) + dispatcher_calls: list[tuple[int, int, str]] = [] + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr( + provider_module, + "apply_flex_dispatcher_backend", + lambda provider, moe_flex_dispatcher_backend: dispatcher_calls.append( + ( + int(provider.tensor_model_parallel_size), + int(provider.expert_model_parallel_size), + cast(str, moe_flex_dispatcher_backend), + ) + ), + ) + + bundle = provider_module.prepare_provider_bundle("Qwen/Qwen3-30B-A3B-Instruct-2507") + + assert provider.finalized is False + assert getattr(provider, "tensor_model_parallel_size") == 2 + assert getattr(provider, "expert_model_parallel_size") == 2 + + bundle.provider.tensor_model_parallel_size = 1 + bundle.provider.expert_model_parallel_size = 1 + bundle.provider.sequence_parallel = False + provider_module.finalize_provider_bundle(bundle) + + assert dispatcher_calls == [] + assert provider.finalized is True + assert getattr(provider, "sequence_parallel") is False + + +def test_get_provider_bundle_honors_single_gpu_env_topology( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _FakeProvider() + fake_bridge = _FakeBridge( + model_bridge=object(), + provider=provider, + ) + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 2) + monkeypatch.setenv("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + monkeypatch.setenv("ART_MEGATRON_EXPERT_MODEL_PARALLEL_SIZE", "1") + monkeypatch.setenv("ART_MEGATRON_EXPERT_TENSOR_PARALLEL_SIZE", "1") + + bundle = provider_module.get_provider_bundle("Qwen/Qwen3-30B-A3B-Instruct-2507") + resolved = bundle.provider + + assert resolved.tensor_model_parallel_size == 1 + assert resolved.context_parallel_size == 1 + assert resolved.pipeline_model_parallel_size == 
1 + assert resolved.expert_model_parallel_size == 1 + assert resolved.expert_tensor_parallel_size == 1 + assert resolved.sequence_parallel is False + assert resolved.recompute_granularity == "full" + assert resolved.recompute_method == "uniform" + assert resolved.recompute_num_layers == 1 + + layer_spec = resolved.transformer_layer_spec(resolved, vp_stage=0) + assert ( + layer_spec.submodules.self_attention.submodules.core_attention + is FlexDotProductAttention + ) + + +def test_get_provider_bundle_disables_recompute_from_env( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _FakeProvider() + fake_bridge = _FakeBridge( + model_bridge=object(), + provider=provider, + ) + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 1) + monkeypatch.setenv("ART_MEGATRON_RECOMPUTE_GRANULARITY", "disabled") + monkeypatch.setenv("ART_MEGATRON_RECOMPUTE_METHOD", "disabled") + monkeypatch.setenv("ART_MEGATRON_RECOMPUTE_NUM_LAYERS", "disabled") + monkeypatch.setenv("ART_MEGATRON_RECOMPUTE_MODULES", "disabled") + + resolved = provider_module.get_provider("Qwen/Qwen3-30B-A3B-Instruct-2507") + + assert resolved.recompute_granularity is None + assert resolved.recompute_method is None + assert resolved.recompute_num_layers is None + assert resolved.recompute_modules is None + + +def test_get_provider_bundle_honors_expert_parallel_env_overrides( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _FakeProvider() + fake_bridge = _FakeBridge( + model_bridge=object(), + provider=provider, + ) + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 4) + monkeypatch.setenv("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "2") + monkeypatch.setenv("ART_MEGATRON_EXPERT_MODEL_PARALLEL_SIZE", "1") + monkeypatch.setenv("ART_MEGATRON_EXPERT_TENSOR_PARALLEL_SIZE", "2") + + resolved = provider_module.get_provider("Qwen/Qwen3-30B-A3B-Instruct-2507") + + assert resolved.tensor_model_parallel_size == 2 + assert resolved.expert_model_parallel_size == 1 + assert resolved.expert_tensor_parallel_size == 2 + assert resolved.sequence_parallel is True diff --git a/tests/integration/megatron/model_support/test_workflow.py b/tests/integration/megatron/model_support/test_workflow.py new file mode 100644 index 000000000..551578402 --- /dev/null +++ b/tests/integration/megatron/model_support/test_workflow.py @@ -0,0 +1,888 @@ +from types import SimpleNamespace + +from art.megatron.model_support.spec import ( + ArchitectureReport, + LayerFamilyInstance, + ValidationStageResult, +) + +from .workflow import ( + MANDATORY_VALIDATION_STAGES, + NATIVE_VLLM_LORA_STAGE, + SKIP_SENSITIVITY_ENV, + assess_minimal_layer_coverage, + build_validation_report, + build_validation_stage_names, + run_chat_template_rollout_stage, + run_correctness_sensitivity_stage, + run_lora_coverage_stage, + run_merged_vllm_serving_stage, + run_native_vllm_lora_stage, + run_packed_position_ids_stage, + run_train_inf_mismatch_stage, + run_yes_no_trainability_stage, +) + + +def test_build_validation_stage_names_has_fixed_order() -> None: + assert build_validation_stage_names() == list(MANDATORY_VALIDATION_STAGES) + assert build_validation_stage_names(include_native_vllm_lora=True) == [ + *MANDATORY_VALIDATION_STAGES, + NATIVE_VLLM_LORA_STAGE, + ] + assert 
build_validation_stage_names(native_vllm_lora_status="wip") == [ + *MANDATORY_VALIDATION_STAGES, + NATIVE_VLLM_LORA_STAGE, + ] + + +def test_build_validation_report_populates_architecture_stage( + monkeypatch, +) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow.inspect_architecture", + lambda base_model: ArchitectureReport( + base_model=base_model, + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + layer_families=[LayerFamilyInstance(key="standard_attention", count=2)], + recommended_min_layers=1, + ), + ) + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow.detect_dependency_versions", + lambda: {"transformers": "5.2.0"}, + ) + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._run_stage_in_subprocess", + lambda *, stage_name, base_model, architecture, allow_unvalidated_arch=False: { + "hf_parity": ValidationStageResult( + name="hf_parity", + passed=True, + metrics={"signal": "pass", "requested_num_layers": 1}, + artifact_dir="/tmp/hf_parity", + ), + "lora_coverage": ValidationStageResult( + name="lora_coverage", + passed=True, + metrics={"wrapped_adapter_prefix_count": 12}, + ), + "train_inf_mismatch": ValidationStageResult( + name="train_inf_mismatch", + passed=True, + metrics={"passed_count": 1, "failed_count": 0}, + artifact_dir="/tmp/train-inf-mismatch", + ), + "merged_vllm_serving": ValidationStageResult( + name="merged_vllm_serving", + passed=True, + metrics={"served_model_name": "validation@0"}, + artifact_dir="/tmp/merged-serving", + ), + "correctness_sensitivity": ValidationStageResult( + name="correctness_sensitivity", + passed=True, + metrics={ + "correctness_variant_count": 4, + "sensitivity_variant_count": 9, + }, + artifact_dir="/tmp/correctness", + ), + "chat_template_rollout": ValidationStageResult( + name="chat_template_rollout", + passed=True, + metrics={ + "passed": True, + "scenario_count": 6, + "failed_scenarios": [], + }, + artifact_dir="/tmp/chat-template", + ), + "packed_position_ids": ValidationStageResult( + name="packed_position_ids", + passed=True, + metrics={ + "num_layers": 4, + "scenarios": [ + { + "name": "stop_early", + "matched": True, + "checked_token_count": 40, + } + ], + }, + artifact_dir="/tmp/packed-position-ids", + ), + "yes_no_trainability": ValidationStageResult( + name="yes_no_trainability", + passed=True, + metrics={ + "latest_step": 3, + "final_eval_reward": 0.97, + }, + artifact_dir="/tmp/trainability", + ), + "native_vllm_lora": ValidationStageResult( + name="native_vllm_lora", + passed=True, + metrics={ + "rollout_weights_mode": "lora", + "step0_name": "validation@0", + "step1_name": "validation@1", + "model_ids_before": ["validation@0"], + "model_ids_after": ["validation@0", "validation@1"], + "step0_served": True, + "step1_served": True, + }, + artifact_dir="/tmp/native-vllm-lora", + ), + }[stage_name], + ) + + report = build_validation_report(base_model="Qwen/Qwen3.5-35B-A3B") + + assert report.base_model == "Qwen/Qwen3.5-35B-A3B" + assert report.model_key == "qwen3_5_moe" + assert report.dependency_versions == {"transformers": "5.2.0"} + dependency_stage = next( + stage for stage in report.stages if stage.name == "dependency_resolution" + ) + assert dependency_stage.passed is True + assert dependency_stage.metrics == {"transformers": "5.2.0"} + architecture_stage = next( + stage for stage in report.stages if stage.name == "architecture_discovery" + ) + assert architecture_stage.passed is True + assert architecture_stage.metrics == { + 
"recommended_min_layers": 1, + "layer_families": [ + { + "key": "standard_attention", + "count": 2, + "layer_index": None, + "module_path": None, + "module_type": None, + } + ], + "unresolved_risks": [], + } + hf_parity_stage = next( + stage for stage in report.stages if stage.name == "hf_parity" + ) + assert hf_parity_stage.passed is True + assert hf_parity_stage.metrics == {"signal": "pass", "requested_num_layers": 1} + assert hf_parity_stage.artifact_dir == "/tmp/hf_parity" + lora_coverage_stage = next( + stage for stage in report.stages if stage.name == "lora_coverage" + ) + assert lora_coverage_stage.passed is True + assert lora_coverage_stage.metrics == {"wrapped_adapter_prefix_count": 12} + mismatch_stage = next( + stage for stage in report.stages if stage.name == "train_inf_mismatch" + ) + assert mismatch_stage.passed is True + assert mismatch_stage.metrics == {"passed_count": 1, "failed_count": 0} + assert mismatch_stage.artifact_dir == "/tmp/train-inf-mismatch" + correctness_stage = next( + stage for stage in report.stages if stage.name == "correctness_sensitivity" + ) + assert correctness_stage.passed is True + assert correctness_stage.metrics == { + "correctness_variant_count": 4, + "sensitivity_variant_count": 9, + } + assert correctness_stage.artifact_dir == "/tmp/correctness" + merged_stage = next( + stage for stage in report.stages if stage.name == "merged_vllm_serving" + ) + assert merged_stage.passed is True + assert merged_stage.metrics == {"served_model_name": "validation@0"} + assert merged_stage.artifact_dir == "/tmp/merged-serving" + chat_template_stage = next( + stage for stage in report.stages if stage.name == "chat_template_rollout" + ) + assert chat_template_stage.passed is True + assert chat_template_stage.metrics == { + "passed": True, + "scenario_count": 6, + "failed_scenarios": [], + } + assert chat_template_stage.artifact_dir == "/tmp/chat-template" + position_id_stage = next( + stage for stage in report.stages if stage.name == "packed_position_ids" + ) + assert position_id_stage.passed is True + assert position_id_stage.metrics == { + "num_layers": 4, + "scenarios": [ + { + "name": "stop_early", + "matched": True, + "checked_token_count": 40, + } + ], + } + assert position_id_stage.artifact_dir == "/tmp/packed-position-ids" + trainability_stage = next( + stage for stage in report.stages if stage.name == "yes_no_trainability" + ) + assert trainability_stage.passed is True + assert trainability_stage.metrics == { + "latest_step": 3, + "final_eval_reward": 0.97, + } + assert trainability_stage.artifact_dir == "/tmp/trainability" + native_vllm_lora_stage = next( + stage for stage in report.stages if stage.name == "native_vllm_lora" + ) + assert native_vllm_lora_stage.passed is True + assert native_vllm_lora_stage.metrics == { + "rollout_weights_mode": "lora", + "step0_name": "validation@0", + "step1_name": "validation@1", + "model_ids_before": ["validation@0"], + "model_ids_after": ["validation@0", "validation@1"], + "step0_served": True, + "step1_served": True, + } + assert native_vllm_lora_stage.artifact_dir == "/tmp/native-vllm-lora" + + +def test_build_validation_report_captures_hf_parity_failure(monkeypatch) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow.inspect_architecture", + lambda base_model: ArchitectureReport( + base_model=base_model, + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + layer_families=[], + recommended_min_layers=4, + ), + ) + monkeypatch.setattr( + 
"tests.integration.megatron.model_support.workflow.detect_dependency_versions", + lambda: {}, + ) + + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._run_stage_in_subprocess", + lambda *, stage_name, base_model, architecture, allow_unvalidated_arch=False: ( + ValidationStageResult( + name="hf_parity", + passed=False, + metrics={"error": "AssertionError: parity failed"}, + ) + if stage_name == "hf_parity" + else ValidationStageResult( + name=stage_name, + passed=True, + metrics={}, + ) + ), + ) + + report = build_validation_report(base_model="Qwen/Qwen3.5-35B-A3B") + + hf_parity_stage = next( + stage for stage in report.stages if stage.name == "hf_parity" + ) + assert hf_parity_stage.passed is False + assert hf_parity_stage.metrics == {"error": "AssertionError: parity failed"} + assert hf_parity_stage.artifact_dir is None + + +def test_build_validation_report_captures_lora_coverage_failure(monkeypatch) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow.inspect_architecture", + lambda base_model: ArchitectureReport( + base_model=base_model, + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + layer_families=[], + recommended_min_layers=4, + ), + ) + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow.detect_dependency_versions", + lambda: {}, + ) + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._run_stage_in_subprocess", + lambda *, stage_name, base_model, architecture, allow_unvalidated_arch=False: ( + ValidationStageResult( + name="lora_coverage", + passed=False, + metrics={"error": "RuntimeError: missing wrapped targets"}, + ) + if stage_name == "lora_coverage" + else ValidationStageResult( + name=stage_name, + passed=True, + metrics={}, + ) + ), + ) + + report = build_validation_report(base_model="Qwen/Qwen3.5-35B-A3B") + + lora_coverage_stage = next( + stage for stage in report.stages if stage.name == "lora_coverage" + ) + assert lora_coverage_stage.passed is False + assert lora_coverage_stage.metrics == { + "error": "RuntimeError: missing wrapped targets" + } + + +def test_assess_minimal_layer_coverage_reports_missing_families( + monkeypatch, +) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow.inspect_architecture", + lambda base_model: ArchitectureReport( + base_model=base_model, + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + layer_families=[ + LayerFamilyInstance(key="gated_delta_net_attention", layer_index=0), + LayerFamilyInstance(key="standard_attention", layer_index=3), + LayerFamilyInstance(key="grouped_moe_mlp", layer_index=0), + LayerFamilyInstance(key="shared_experts_mlp", layer_index=0), + ], + recommended_min_layers=4, + ), + ) + + coverage = assess_minimal_layer_coverage( + base_model="Qwen/Qwen3.5-35B-A3B", + num_layers=2, + ) + + assert coverage.covered is False + assert coverage.requested_num_layers == 2 + assert coverage.recommended_min_layers == 4 + assert coverage.missing_layer_families == ["standard_attention"] + assert coverage.unresolved_risks == [] + + +def test_run_chat_template_rollout_stage(monkeypatch) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + lambda name: SimpleNamespace( + run_chat_template_rollout=lambda *, base_model: SimpleNamespace( + passed=True, + scenario_count=6, + failed_scenarios=[], + output_dir="/tmp/chat-template", + model_dump=lambda mode="json": { + "passed": True, + "scenario_count": 6, + "failed_scenarios": 
[], + }, + ) + ), + ) + + result = run_chat_template_rollout_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + ), + ) + + assert result.passed is True + assert result.artifact_dir == "/tmp/chat-template" + + +def test_run_correctness_sensitivity_stage_runs_dense_models(monkeypatch) -> None: + case_configs: list[SimpleNamespace] = [] + oracle_module = SimpleNamespace( + OracleCaseConfig=lambda **kwargs: SimpleNamespace(**kwargs), + selected_suite_topologies=lambda *, is_moe: [ + SimpleNamespace(world_size=lambda: 1, slug=lambda: "tp1"), + SimpleNamespace(world_size=lambda: 2, slug=lambda: "tp2"), + SimpleNamespace(world_size=lambda: 2, slug=lambda: "dp2"), + SimpleNamespace(world_size=lambda: 4, slug=lambda: "tp2_dp2"), + ], + oracle_topology=lambda *, is_moe: SimpleNamespace(world_size=lambda: 1), + selected_oracle_objectives=lambda: ["sft"], + supported_sensitivity_mutations_for_objective=lambda objective, *, is_moe: ( + ["skip_finalize"] if objective == "sft" and not is_moe else [] + ), + sensitivity_topology_for_mutation=lambda mutation, *, is_moe: SimpleNamespace( + world_size=lambda: 2 + ), + available_gpu_count=lambda: 4, + run_suite=lambda case_config, max_world_size: ( + case_configs.append(case_config) + or [ + SimpleNamespace( + variant="sft_topology_tp2_dp2", + topology="tp2_dp2", + signal="pass", + fail_count=0, + ) + ] + ), + run_sensitivity_suite=lambda case_config, mutations, max_world_size: [ + SimpleNamespace( + variant="sft_sensitivity_skip_finalize", + topology="tp2", + signal="fail", + expected_signal="fail", + fail_count=1, + ) + ], + ensure_case_artifacts=lambda case_config: SimpleNamespace( + case_dir="/tmp/oracle" + ), + keep_topology_artifacts=lambda: False, + ) + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + lambda name: oracle_module, + ) + monkeypatch.delenv(SKIP_SENSITIVITY_ENV, raising=False) + + result = run_correctness_sensitivity_stage( + base_model="Qwen/Qwen3.5-4B", + architecture=ArchitectureReport( + base_model="Qwen/Qwen3.5-4B", + model_key="qwen3_5_dense", + handler_key="qwen3_5_dense", + layer_families=[ + LayerFamilyInstance(key="dense_mlp", layer_index=0), + LayerFamilyInstance(key="gated_delta_net_attention", layer_index=0), + LayerFamilyInstance(key="standard_attention", layer_index=3), + ], + recommended_min_layers=4, + ), + ) + + assert result.passed is True + assert result.metrics["is_moe"] is False + assert result.metrics["available_gpu_count"] == 4 + assert result.metrics["max_world_size"] == 4 + assert result.metrics["required_gpu_count"] == 1 + assert result.metrics["correctness_variant_count"] == 1 + assert result.metrics["correctness_excluded_topologies"] == [] + assert result.metrics["sensitivity_mutations"] == ["skip_finalize"] + assert case_configs[0].is_moe is False + + +def test_run_yes_no_trainability_stage(monkeypatch) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + lambda name: SimpleNamespace( + run_yes_no_trainability=lambda *, base_model, allow_unvalidated_arch=False: ( + SimpleNamespace( + latest_step=2, + initial_eval_reward=0.4, + final_eval_reward=0.95, + reward_threshold=0.95, + saturated_step=2, + output_dir="/tmp/trainability", + model_dump=lambda mode="json": { + "latest_step": 2, + "initial_eval_reward": 0.4, + "final_eval_reward": 0.95, + "reward_threshold": 0.95, + 
"saturated_step": 2, + }, + ) + ) + ), + ) + + result = run_yes_no_trainability_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + ), + ) + + assert result.passed is True + assert result.artifact_dir == "/tmp/trainability" + + +def test_run_train_inf_mismatch_stage(monkeypatch) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + lambda name: SimpleNamespace( + run_train_inf_mismatch=lambda *, base_model: SimpleNamespace( + passed=True, + artifact_dir="/tmp/train-inf-mismatch", + model_dump=lambda mode="json": { + "base_model": base_model, + "passed": True, + "passed_count": 1, + "failed_count": 0, + }, + ) + ), + ) + + result = run_train_inf_mismatch_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + ), + ) + + assert result.name == "train_inf_mismatch" + assert result.passed is True + assert result.artifact_dir == "/tmp/train-inf-mismatch" + assert result.metrics == { + "base_model": "Qwen/Qwen3.5-35B-A3B", + "passed": True, + "passed_count": 1, + "failed_count": 0, + } + + +def test_run_native_vllm_lora_stage(monkeypatch) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + lambda name: ( + SimpleNamespace( + OracleCaseConfig=lambda **kwargs: SimpleNamespace(**kwargs), + ) + if name == "integration.megatron.model_support.oracle_harness" + else SimpleNamespace( + run_native_vllm_lora=lambda case_config: SimpleNamespace( + rollout_weights_mode="lora", + step0_name="validation@0", + step1_name="validation@1", + model_ids_before=["validation@0"], + model_ids_after=["validation@0", "validation@1"], + step0_served=True, + step1_served=True, + output_dir="/tmp/native-vllm-lora", + model_dump=lambda mode="json": { + "rollout_weights_mode": "lora", + "step0_name": "validation@0", + "step1_name": "validation@1", + "model_ids_before": ["validation@0"], + "model_ids_after": ["validation@0", "validation@1"], + "step0_served": True, + "step1_served": True, + }, + ) + ) + ), + ) + + result = run_native_vllm_lora_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + ), + ) + + assert result.name == "native_vllm_lora" + assert result.passed is True + assert result.artifact_dir == "/tmp/native-vllm-lora" + + +def test_run_packed_position_ids_stage(monkeypatch) -> None: + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + lambda name: SimpleNamespace( + run_packed_position_ids=lambda *, base_model, num_layers, allow_unvalidated_arch=False: ( + SimpleNamespace( + output_dir="/tmp/packed-position-ids", + model_dump=lambda mode="json": { + "base_model": base_model, + "num_layers": num_layers, + "scenarios": [ + { + "name": "stop_early", + "matched": True, + "checked_token_count": 40, + }, + { + "name": "truncate", + "matched": True, + "checked_token_count": 44, + }, + ], + }, + ) + ) + ), + ) + + result = run_packed_position_ids_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + recommended_min_layers=4, + ), + ) + + assert result.passed is True + 
assert result.artifact_dir == "/tmp/packed-position-ids" + + +def test_assess_minimal_layer_coverage_passes_when_prefix_covers_all_families( + monkeypatch, +) -> None: + architecture = ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + layer_families=[ + LayerFamilyInstance(key="gated_delta_net_attention", layer_index=0), + LayerFamilyInstance(key="standard_attention", layer_index=3), + LayerFamilyInstance(key="grouped_moe_mlp", layer_index=0), + LayerFamilyInstance(key="shared_experts_mlp", layer_index=0), + ], + recommended_min_layers=4, + ) + + coverage = assess_minimal_layer_coverage( + base_model=architecture.base_model, + num_layers=4, + architecture=architecture, + ) + + assert coverage.covered is True + assert coverage.missing_layer_families == [] + + +def test_run_lora_coverage_stage_reports_missing_targets(monkeypatch) -> None: + architecture = ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + layer_families=[LayerFamilyInstance(key="grouped_moe_mlp", layer_index=0)], + recommended_min_layers=4, + ) + oracle_module = SimpleNamespace( + OracleCaseConfig=lambda **kwargs: SimpleNamespace(**kwargs) + ) + coverage_report = SimpleNamespace( + missing_wrapped_target_modules=["in_proj_z"], + missing_exported_target_modules=[], + model_dump=lambda mode="json": { + "base_model": "Qwen/Qwen3.5-35B-A3B", + "missing_wrapped_target_modules": ["in_proj_z"], + }, + ) + coverage_module = SimpleNamespace( + run_lora_coverage=lambda case_config: coverage_report + ) + + def _import_integration_module(name: str): + if name == "integration.megatron.model_support.oracle_harness": + return oracle_module + if name == "integration.megatron.model_support.lora_coverage": + return coverage_module + raise AssertionError(name) + + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + _import_integration_module, + ) + + stage = run_lora_coverage_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=architecture, + ) + + assert stage.name == "lora_coverage" + assert stage.passed is False + assert stage.metrics == { + "base_model": "Qwen/Qwen3.5-35B-A3B", + "missing_wrapped_target_modules": ["in_proj_z"], + } + + +def test_run_correctness_sensitivity_stage_summarizes_reports(monkeypatch) -> None: + architecture = ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + layer_families=[LayerFamilyInstance(key="grouped_moe_mlp", layer_index=0)], + recommended_min_layers=4, + ) + oracle_module = SimpleNamespace( + OracleCaseConfig=lambda **kwargs: SimpleNamespace(**kwargs), + selected_suite_topologies=lambda *, is_moe: [ + SimpleNamespace(world_size=lambda: 1, slug=lambda: "tp1"), + SimpleNamespace(world_size=lambda: 2, slug=lambda: "tp2"), + ], + oracle_topology=lambda *, is_moe: SimpleNamespace(world_size=lambda: 1), + selected_oracle_objectives=lambda: ["sft"], + supported_sensitivity_mutations_for_objective=lambda objective, *, is_moe: ( + ["skip_finalize"] if objective == "sft" else [] + ), + sensitivity_topology_for_mutation=lambda mutation, *, is_moe: SimpleNamespace( + world_size=lambda: 2 + ), + available_gpu_count=lambda: 2, + run_suite=lambda case_config, max_world_size: [ + SimpleNamespace( + variant="sft_topology_tp2", + topology="tp2", + signal="pass", + fail_count=0, + ) + ], + run_sensitivity_suite=lambda case_config, mutations, max_world_size: [ + 
SimpleNamespace( + variant="sft_sensitivity_skip_finalize", + topology="tp2", + signal="fail", + expected_signal="fail", + fail_count=1, + ) + ], + ensure_case_artifacts=lambda case_config: SimpleNamespace( + case_dir="/tmp/oracle" + ), + keep_topology_artifacts=lambda: False, + ) + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + lambda name: oracle_module, + ) + + stage = run_correctness_sensitivity_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=architecture, + ) + + assert stage.name == "correctness_sensitivity" + assert stage.passed is True + assert stage.metrics["requested_num_layers"] == 4 + assert stage.metrics["is_moe"] is True + assert stage.metrics["objectives"] == ["sft"] + assert stage.metrics["sensitivity_mutations"] == ["skip_finalize"] + assert stage.metrics["available_gpu_count"] == 2 + assert stage.metrics["required_gpu_count"] == 1 + assert stage.metrics["correctness_variant_count"] == 1 + assert stage.metrics["sensitivity_skipped"] is False + assert stage.metrics["sensitivity_skip_reason"] is None + assert stage.metrics["sensitivity_variant_count"] == 1 + assert stage.artifact_dir == "/tmp/oracle" + + +def test_run_correctness_sensitivity_stage_can_skip_sensitivity_only( + monkeypatch, +) -> None: + architecture = ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + layer_families=[LayerFamilyInstance(key="grouped_moe_mlp", layer_index=0)], + recommended_min_layers=4, + ) + oracle_module = SimpleNamespace( + OracleCaseConfig=lambda **kwargs: SimpleNamespace(**kwargs), + selected_suite_topologies=lambda *, is_moe: [ + SimpleNamespace(world_size=lambda: 1, slug=lambda: "tp1"), + SimpleNamespace(world_size=lambda: 2, slug=lambda: "tp2"), + ], + oracle_topology=lambda *, is_moe: SimpleNamespace(world_size=lambda: 1), + selected_oracle_objectives=lambda: ["sft"], + supported_sensitivity_mutations_for_objective=lambda objective, *, is_moe: ( + ["skip_finalize"] if objective == "sft" else [] + ), + sensitivity_topology_for_mutation=lambda mutation, *, is_moe: SimpleNamespace( + world_size=lambda: 4 + ), + available_gpu_count=lambda: 2, + run_suite=lambda case_config, max_world_size: [ + SimpleNamespace( + variant="sft_topology_tp2", + topology="tp2", + signal="pass", + fail_count=0, + ) + ], + run_sensitivity_suite=lambda case_config, mutations, max_world_size: ( + _ for _ in () + ).throw(AssertionError("sensitivity suite should be skipped")), + ensure_case_artifacts=lambda case_config: SimpleNamespace( + case_dir="/tmp/oracle" + ), + keep_topology_artifacts=lambda: False, + ) + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + lambda name: oracle_module, + ) + monkeypatch.setenv(SKIP_SENSITIVITY_ENV, "1") + + stage = run_correctness_sensitivity_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=architecture, + ) + + assert stage.name == "correctness_sensitivity" + assert stage.passed is True + assert stage.metrics["required_gpu_count"] == 1 + assert stage.metrics["correctness_variant_count"] == 1 + assert stage.metrics["sensitivity_mutations"] == [] + assert stage.metrics["sensitivity_skipped"] is True + assert stage.metrics["sensitivity_skip_reason"] == f"{SKIP_SENSITIVITY_ENV}=1" + assert stage.metrics["sensitivity_variant_count"] == 0 + assert stage.metrics["sensitivity_variants"] == [] + + +def test_run_merged_vllm_serving_stage_reports_served_model(monkeypatch) -> None: + 
architecture = ArchitectureReport( + base_model="Qwen/Qwen3.5-35B-A3B", + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + recommended_min_layers=4, + ) + oracle_module = SimpleNamespace( + OracleCaseConfig=lambda **kwargs: SimpleNamespace(**kwargs) + ) + merged_module = SimpleNamespace( + run_merged_vllm_serving=lambda case_config: SimpleNamespace( + output_dir="/tmp/merged-serving", + model_ids=["validation@0"], + model_dump=lambda mode="json": { + "base_model": "Qwen/Qwen3.5-35B-A3B", + "served_model_name": "validation@0", + }, + ) + ) + + def _import_integration_module(name: str): + if name == "integration.megatron.model_support.oracle_harness": + return oracle_module + if name == "integration.megatron.lora.merged_vllm_serving": + return merged_module + raise AssertionError(name) + + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._import_integration_module", + _import_integration_module, + ) + + stage = run_merged_vllm_serving_stage( + base_model="Qwen/Qwen3.5-35B-A3B", + architecture=architecture, + ) + + assert stage.name == "merged_vllm_serving" + assert stage.passed is True + assert stage.metrics == { + "base_model": "Qwen/Qwen3.5-35B-A3B", + "served_model_name": "validation@0", + } + assert stage.artifact_dir == "/tmp/merged-serving" diff --git a/tests/integration/megatron/model_support/workflow.py b/tests/integration/megatron/model_support/workflow.py new file mode 100644 index 000000000..b7a22af6a --- /dev/null +++ b/tests/integration/megatron/model_support/workflow.py @@ -0,0 +1,732 @@ +from contextlib import contextmanager, redirect_stderr, redirect_stdout +import importlib +import importlib.metadata +import os +from pathlib import Path +import subprocess +import sys +import tempfile +from typing import Any + +from art.megatron.model_support.discovery import inspect_architecture +from art.megatron.model_support.registry import ( + get_model_support_handler_for_spec, + get_model_support_spec, +) +from art.megatron.model_support.spec import ( + ArchitectureReport, + MinimalLayerCoverageReport, + NativeVllmLoraStatus, + ValidationReport, + ValidationStageResult, +) + +REPO_ROOT = Path(__file__).resolve().parents[4] +TESTS_DIR = REPO_ROOT / "tests" +LOCAL_LOG_DIR = REPO_ROOT / ".local" +CORRECTNESS_LOG_PATH = LOCAL_LOG_DIR / "correctness.log" +SENSITIVITY_LOG_PATH = LOCAL_LOG_DIR / "sensitivity.log" +LIVE_TRAINING_LOG_PATH = LOCAL_LOG_DIR / "live_training.log" +ORACLE_LIVE_TRAINING_LOG_ENV = "ART_ORACLE_LIVE_TRAINING_LOG" +SKIP_SENSITIVITY_ENV = "ART_MODEL_SUPPORT_SKIP_SENSITIVITY" + +MANDATORY_VALIDATION_STAGES = ( + "dependency_resolution", + "architecture_discovery", + "hf_parity", + "lora_coverage", + "train_inf_mismatch", + "merged_vllm_serving", + "correctness_sensitivity", + "chat_template_rollout", + "packed_position_ids", + "yes_no_trainability", +) +NATIVE_VLLM_LORA_STAGE = "native_vllm_lora" +SUBPROCESS_VALIDATION_STAGES = frozenset( + { + "hf_parity", + "lora_coverage", + "train_inf_mismatch", + "merged_vllm_serving", + "correctness_sensitivity", + "chat_template_rollout", + "packed_position_ids", + "yes_no_trainability", + NATIVE_VLLM_LORA_STAGE, + } +) + + +def build_validation_stage_names( + *, + include_native_vllm_lora: bool = False, + native_vllm_lora_status: NativeVllmLoraStatus | None = None, +) -> list[str]: + stages = list(MANDATORY_VALIDATION_STAGES) + if include_native_vllm_lora or native_vllm_lora_status not in {None, "disabled"}: + stages.append(NATIVE_VLLM_LORA_STAGE) + return stages + + +def detect_dependency_versions() 
-> dict[str, str]: + versions: dict[str, str] = {} + for package_name in ("transformers", "vllm", "megatron-bridge"): + try: + versions[package_name] = importlib.metadata.version(package_name) + except importlib.metadata.PackageNotFoundError: + continue + return versions + + +def initialize_validation_report( + *, + base_model: str, + include_native_vllm_lora: bool = False, + allow_unvalidated_arch: bool = False, +) -> ValidationReport: + spec = get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + handler = get_model_support_handler_for_spec(spec) + return ValidationReport( + base_model=base_model, + model_key=spec.key, + dependency_versions=detect_dependency_versions(), + stages=[ + ValidationStageResult(name=stage_name) + for stage_name in build_validation_stage_names( + include_native_vllm_lora=include_native_vllm_lora, + native_vllm_lora_status=handler.native_vllm_lora_status, + ) + ], + ) + + +def _stage_error_metrics(exc: Exception) -> dict[str, Any]: + return {"error": f"{type(exc).__name__}: {exc}"} + + +def _truthy_env(name: str) -> bool: + value = os.environ.get(name) + return value is not None and value.strip().lower() in {"1", "true", "yes", "on"} + + +def _import_integration_module(module_name: str) -> Any: + tests_dir = str(TESTS_DIR) + if tests_dir not in sys.path: + sys.path.insert(0, tests_dir) + return importlib.import_module(module_name) + + +def _subprocess_log_tail(log_path: Path, *, max_lines: int = 40) -> str: + if not log_path.exists(): + return "" + lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines() + return "\n".join(lines[-max_lines:]) + + +@contextmanager +def _redirect_output(log_path: Path): + log_path.parent.mkdir(parents=True, exist_ok=True) + with log_path.open("w", encoding="utf-8") as log_file: + with redirect_stdout(log_file), redirect_stderr(log_file): + yield + + +@contextmanager +def _temporary_env(**updates: str): + previous = {key: os.environ.get(key) for key in updates} + os.environ.update(updates) + try: + yield + finally: + for key, value in previous.items(): + if value is None: + os.environ.pop(key, None) + continue + os.environ[key] = value + + +def _run_stage_in_subprocess( + *, + stage_name: str, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + with tempfile.TemporaryDirectory(prefix=f"model_support_{stage_name}_") as tmp_dir: + tmp_path = Path(tmp_dir) + architecture_json = tmp_path / "architecture.json" + output_json = tmp_path / "stage_result.json" + log_path = tmp_path / "stage.log" + architecture_json.write_text( + architecture.model_dump_json(indent=2), + encoding="utf-8", + ) + cmd = [ + sys.executable, + "-m", + "integration.megatron.model_support.workflow_stage_worker", + "--stage", + stage_name, + "--base-model", + base_model, + "--architecture-json", + str(architecture_json), + "--output-json", + str(output_json), + ] + if allow_unvalidated_arch: + cmd.append("--allow-unsupported-arch") + env = os.environ.copy() + existing_pythonpath = env.get("PYTHONPATH") + env["PYTHONPATH"] = ( + str(TESTS_DIR) + if not existing_pythonpath + else f"{TESTS_DIR}{os.pathsep}{existing_pythonpath}" + ) + with log_path.open("w", encoding="utf-8") as log_file: + completed = subprocess.run( + cmd, + cwd=str(REPO_ROOT), + env=env, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + check=False, + ) + if completed.returncode != 0: + tail = _subprocess_log_tail(log_path) + error = ( + f"subprocess exited 
with code {completed.returncode}" + if not tail + else tail + ) + return ValidationStageResult( + name=stage_name, + passed=False, + metrics={"error": error}, + ) + if not output_json.exists(): + return ValidationStageResult( + name=stage_name, + passed=False, + metrics={"error": "stage worker did not write output_json"}, + ) + return ValidationStageResult.model_validate_json( + output_json.read_text(encoding="utf-8") + ) + + +def run_hf_parity_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + hf_parity = _import_integration_module( + "integration.megatron.model_support.hf_parity" + ) + oracle_harness = _import_integration_module( + "integration.megatron.model_support.oracle_harness" + ) + spec = get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + handler = get_model_support_handler_for_spec(spec) + case_config = oracle_harness.OracleCaseConfig( + base_model=base_model, + is_moe=handler.is_moe, + precision="fp32", + num_layers=max(1, architecture.recommended_min_layers), + num_steps=1, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + report = hf_parity.run_hf_parity(case_config=case_config) + case_artifacts = oracle_harness.ensure_case_artifacts(case_config) + artifact_dir = str( + Path(case_artifacts.case_dir) / hf_parity.HF_PARITY_OUTPUT_DIRNAME + ) + return ValidationStageResult( + name="hf_parity", + passed=report.signal == "pass", + metrics={ + "requested_num_layers": report.requested_num_layers, + "coverage": report.coverage.model_dump(mode="json"), + "signal": report.signal, + "pass_count": report.pass_count, + "fail_count": report.fail_count, + "phases": [row.model_dump(mode="json") for row in report.metrics], + }, + artifact_dir=artifact_dir, + ) + + +def run_lora_coverage_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + lora_coverage = _import_integration_module( + "integration.megatron.model_support.lora_coverage" + ) + oracle_harness = _import_integration_module( + "integration.megatron.model_support.oracle_harness" + ) + spec = get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + handler = get_model_support_handler_for_spec(spec) + case_config = oracle_harness.OracleCaseConfig( + base_model=base_model, + is_moe=handler.is_moe, + precision="fp32", + num_layers=max(1, architecture.recommended_min_layers), + num_steps=1, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + report = lora_coverage.run_lora_coverage(case_config) + return ValidationStageResult( + name="lora_coverage", + passed=not report.missing_wrapped_target_modules + and not report.missing_exported_target_modules, + metrics=report.model_dump(mode="json"), + ) + + +def run_train_inf_mismatch_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + del architecture + del allow_unvalidated_arch + train_inf_mismatch = _import_integration_module( + "integration.megatron.train_inf_mismatch.workflow_stage" + ) + report = train_inf_mismatch.run_train_inf_mismatch(base_model=base_model) + return ValidationStageResult( + name="train_inf_mismatch", + passed=report.passed, + metrics=report.model_dump(mode="json"), + artifact_dir=report.artifact_dir, + ) + + +def run_correctness_sensitivity_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool 
= False, +) -> ValidationStageResult: + oracle_harness = _import_integration_module( + "integration.megatron.model_support.oracle_harness" + ) + spec = get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + handler = get_model_support_handler_for_spec(spec) + case_config = oracle_harness.OracleCaseConfig( + base_model=base_model, + is_moe=handler.is_moe, + precision="fp32", + num_layers=max(1, architecture.recommended_min_layers), + num_steps=1, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + suite_topologies = list( + oracle_harness.selected_suite_topologies(is_moe=handler.is_moe) + ) + objectives = list(oracle_harness.selected_oracle_objectives()) + skip_sensitivity = _truthy_env(SKIP_SENSITIVITY_ENV) + available_gpu_count = oracle_harness.available_gpu_count() + max_world_size = available_gpu_count + oracle_world_size = oracle_harness.oracle_topology( + is_moe=handler.is_moe + ).world_size() + if available_gpu_count < oracle_world_size: + raise RuntimeError( + "Need " + f"{oracle_world_size} GPUs for oracle topology, found {available_gpu_count}" + ) + selected_suite_topologies = [ + topology + for topology in suite_topologies + if topology.world_size() <= max_world_size + ] + excluded_suite_topologies = [ + topology + for topology in suite_topologies + if topology.world_size() > max_world_size + ] + mutations: list[str] = [] + excluded_sensitivity_mutations: list[str] = [] + if not skip_sensitivity: + for objective in objectives: + for ( + mutation + ) in oracle_harness.supported_sensitivity_mutations_for_objective( + objective, + is_moe=handler.is_moe, + ): + if mutation not in mutations: + mutations.append(mutation) + excluded_sensitivity_mutations = [ + mutation + for mutation in mutations + if oracle_harness.sensitivity_topology_for_mutation( + mutation, + is_moe=handler.is_moe, + ).world_size() + > max_world_size + ] + mutations = [ + mutation + for mutation in mutations + if mutation not in excluded_sensitivity_mutations + ] + LIVE_TRAINING_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + LIVE_TRAINING_LOG_PATH.write_text("", encoding="utf-8") + with _temporary_env(**{ORACLE_LIVE_TRAINING_LOG_ENV: str(LIVE_TRAINING_LOG_PATH)}): + with _redirect_output(CORRECTNESS_LOG_PATH): + suite_reports = oracle_harness.run_suite( + case_config=case_config, + max_world_size=max_world_size, + ) + sensitivity_reports = [] + if skip_sensitivity: + SENSITIVITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + SENSITIVITY_LOG_PATH.write_text( + ( + "Sensitivity suite skipped. " + f"Set {SKIP_SENSITIVITY_ENV}=0 to re-enable workflow sensitivity.\n" + ), + encoding="utf-8", + ) + elif not mutations: + SENSITIVITY_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) + SENSITIVITY_LOG_PATH.write_text( + ( + "Sensitivity suite skipped. 
" + f"No sensitivity mutations fit max_world_size={max_world_size}.\n" + ), + encoding="utf-8", + ) + else: + with _redirect_output(SENSITIVITY_LOG_PATH): + sensitivity_reports = oracle_harness.run_sensitivity_suite( + case_config=case_config, + mutations=mutations, + max_world_size=max_world_size, + ) + case_artifacts = oracle_harness.ensure_case_artifacts(case_config) + return ValidationStageResult( + name="correctness_sensitivity", + passed=True, + metrics={ + "requested_num_layers": case_config.num_layers, + "is_moe": handler.is_moe, + "allow_unvalidated_arch": allow_unvalidated_arch, + "objectives": objectives, + "sensitivity_mutations": mutations, + "excluded_sensitivity_mutations": excluded_sensitivity_mutations, + "available_gpu_count": available_gpu_count, + "max_world_size": max_world_size, + "required_gpu_count": oracle_world_size, + "topology_artifacts_retained": oracle_harness.keep_topology_artifacts(), + "correctness_variant_count": len(suite_reports), + "correctness_excluded_topology_count": len(excluded_suite_topologies), + "correctness_excluded_topologies": [ + topology.slug() for topology in excluded_suite_topologies + ], + "correctness_selected_topologies": [ + topology.slug() for topology in selected_suite_topologies + ], + "correctness_variants": [ + { + "variant": report.variant, + "topology": report.topology, + "signal": report.signal, + "fail_count": report.fail_count, + } + for report in suite_reports + ], + "sensitivity_skipped": skip_sensitivity, + "sensitivity_skip_reason": ( + f"{SKIP_SENSITIVITY_ENV}=1" if skip_sensitivity else None + ), + "sensitivity_variant_count": len(sensitivity_reports), + "sensitivity_variants": [ + { + "variant": report.variant, + "topology": report.topology, + "signal": report.signal, + "expected_signal": report.expected_signal, + "fail_count": report.fail_count, + } + for report in sensitivity_reports + ], + }, + artifact_dir=case_artifacts.case_dir, + ) + + +def run_merged_vllm_serving_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + merged_vllm_serving = _import_integration_module( + "integration.megatron.lora.merged_vllm_serving" + ) + oracle_harness = _import_integration_module( + "integration.megatron.model_support.oracle_harness" + ) + spec = get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + handler = get_model_support_handler_for_spec(spec) + case_config = oracle_harness.OracleCaseConfig( + base_model=base_model, + is_moe=handler.is_moe, + precision="fp32", + num_layers=max(1, architecture.recommended_min_layers), + num_steps=1, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + report = merged_vllm_serving.run_merged_vllm_serving(case_config) + return ValidationStageResult( + name="merged_vllm_serving", + passed=bool(report.model_ids), + metrics=report.model_dump(mode="json"), + artifact_dir=report.output_dir, + ) + + +def run_chat_template_rollout_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + del architecture + del allow_unvalidated_arch + chat_template_rollout = _import_integration_module( + "integration.megatron.model_support.chat_template_rollout" + ) + report = chat_template_rollout.run_chat_template_rollout(base_model=base_model) + return ValidationStageResult( + name="chat_template_rollout", + passed=report.passed, + metrics=report.model_dump(mode="json"), + artifact_dir=report.output_dir, + ) + 
+ +def run_yes_no_trainability_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + del architecture + yes_no_trainability = _import_integration_module( + "integration.megatron.trainability.yes_no_trainability" + ) + report = yes_no_trainability.run_yes_no_trainability( + base_model=base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + passed = ( + report.saturated_step is not None + and report.saturated_step > 0 + and report.initial_eval_reward < report.reward_threshold + and report.final_eval_reward is not None + and report.final_eval_reward >= report.reward_threshold + and report.final_eval_reward > report.initial_eval_reward + ) + return ValidationStageResult( + name="yes_no_trainability", + passed=passed, + metrics=report.model_dump(mode="json"), + artifact_dir=report.output_dir, + ) + + +def run_native_vllm_lora_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + native_vllm_lora = _import_integration_module( + "integration.megatron.lora.native_vllm_lora" + ) + oracle_harness = _import_integration_module( + "integration.megatron.model_support.oracle_harness" + ) + spec = get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + handler = get_model_support_handler_for_spec(spec) + case_config = oracle_harness.OracleCaseConfig( + base_model=base_model, + is_moe=handler.is_moe, + precision="fp32", + num_layers=max(1, architecture.recommended_min_layers), + num_steps=1, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + report = native_vllm_lora.run_native_vllm_lora(case_config) + passed = ( + report.rollout_weights_mode == "lora" + and report.step0_served + and report.step1_served + and report.step0_name in report.model_ids_before + and report.step1_name not in report.model_ids_before + and report.step0_name in report.model_ids_after + and report.step1_name in report.model_ids_after + ) + return ValidationStageResult( + name=NATIVE_VLLM_LORA_STAGE, + passed=passed, + metrics=report.model_dump(mode="json"), + artifact_dir=report.output_dir, + ) + + +def run_packed_position_ids_stage( + *, + base_model: str, + architecture: ArchitectureReport, + allow_unvalidated_arch: bool = False, +) -> ValidationStageResult: + packed_position_ids = _import_integration_module( + "integration.megatron.model_support.packed_position_ids" + ) + report = packed_position_ids.run_packed_position_ids( + base_model=base_model, + num_layers=max(1, architecture.recommended_min_layers), + allow_unvalidated_arch=allow_unvalidated_arch, + ) + metrics = report.model_dump(mode="json") + passed = bool(metrics["scenarios"]) and all( + scenario["matched"] and scenario["checked_token_count"] > 0 + for scenario in metrics["scenarios"] + ) + return ValidationStageResult( + name="packed_position_ids", + passed=passed, + metrics=metrics, + artifact_dir=report.output_dir, + ) + + +def build_validation_report( + *, + base_model: str, + include_native_vllm_lora: bool = False, + allow_unvalidated_arch: bool = False, +) -> ValidationReport: + report = initialize_validation_report( + base_model=base_model, + include_native_vllm_lora=include_native_vllm_lora, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + architecture = ( + inspect_architecture(base_model, allow_unvalidated_arch=True) + if allow_unvalidated_arch + else inspect_architecture(base_model) + ) + stage_runners = { + "hf_parity": 
run_hf_parity_stage, + "lora_coverage": run_lora_coverage_stage, + "train_inf_mismatch": run_train_inf_mismatch_stage, + "merged_vllm_serving": run_merged_vllm_serving_stage, + "correctness_sensitivity": run_correctness_sensitivity_stage, + "chat_template_rollout": run_chat_template_rollout_stage, + "packed_position_ids": run_packed_position_ids_stage, + "yes_no_trainability": run_yes_no_trainability_stage, + NATIVE_VLLM_LORA_STAGE: run_native_vllm_lora_stage, + } + stage_results: dict[str, ValidationStageResult] = {} + for stage_name, stage_runner in stage_runners.items(): + if stage_name in SUBPROCESS_VALIDATION_STAGES: + stage_results[stage_name] = _run_stage_in_subprocess( + stage_name=stage_name, + base_model=base_model, + architecture=architecture, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + continue + try: + stage_results[stage_name] = stage_runner( + base_model=base_model, + architecture=architecture, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + except Exception as exc: + stage_results[stage_name] = ValidationStageResult( + name=stage_name, + passed=False, + metrics=_stage_error_metrics(exc), + ) + for stage in report.stages: + if stage.name == "dependency_resolution": + stage.passed = True + stage.metrics = dict(report.dependency_versions) + continue + if stage.name != "architecture_discovery": + stage_result = stage_results.get(stage.name) + if stage_result is not None: + stage.passed = stage_result.passed + stage.metrics = dict(stage_result.metrics) + stage.artifact_dir = stage_result.artifact_dir + continue + stage.passed = not architecture.unresolved_risks + stage.metrics = { + "recommended_min_layers": architecture.recommended_min_layers, + "layer_families": [ + family.model_dump() for family in architecture.layer_families + ], + "unresolved_risks": list(architecture.unresolved_risks), + } + return report + + +def assess_minimal_layer_coverage( + *, + base_model: str, + num_layers: int, + architecture: ArchitectureReport | None = None, + allow_unvalidated_arch: bool = False, +) -> MinimalLayerCoverageReport: + architecture_report = architecture or ( + inspect_architecture(base_model, allow_unvalidated_arch=True) + if allow_unvalidated_arch + else inspect_architecture(base_model) + ) + missing_layer_families = [ + family.key + for family in architecture_report.layer_families + if family.layer_index is not None and family.layer_index >= num_layers + ] + return MinimalLayerCoverageReport( + base_model=base_model, + model_key=architecture_report.model_key, + requested_num_layers=num_layers, + recommended_min_layers=architecture_report.recommended_min_layers, + covered=not missing_layer_families and not architecture_report.unresolved_risks, + missing_layer_families=missing_layer_families, + unresolved_risks=list(architecture_report.unresolved_risks), + ) diff --git a/tests/integration/megatron/model_support/workflow_stage_worker.py b/tests/integration/megatron/model_support/workflow_stage_worker.py new file mode 100644 index 000000000..c854259fa --- /dev/null +++ b/tests/integration/megatron/model_support/workflow_stage_worker.py @@ -0,0 +1,63 @@ +import argparse +from pathlib import Path + +from art.megatron.model_support.spec import ArchitectureReport + +from .workflow import ( + run_chat_template_rollout_stage, + run_correctness_sensitivity_stage, + run_hf_parity_stage, + run_lora_coverage_stage, + run_merged_vllm_serving_stage, + run_native_vllm_lora_stage, + run_packed_position_ids_stage, + run_train_inf_mismatch_stage, + run_yes_no_trainability_stage, +) 
+ +_STAGE_RUNNERS = { + "hf_parity": run_hf_parity_stage, + "lora_coverage": run_lora_coverage_stage, + "train_inf_mismatch": run_train_inf_mismatch_stage, + "merged_vllm_serving": run_merged_vllm_serving_stage, + "correctness_sensitivity": run_correctness_sensitivity_stage, + "chat_template_rollout": run_chat_template_rollout_stage, + "packed_position_ids": run_packed_position_ids_stage, + "yes_no_trainability": run_yes_no_trainability_stage, + "native_vllm_lora": run_native_vllm_lora_stage, +} + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--stage", required=True) + parser.add_argument("--base-model", required=True) + parser.add_argument("--architecture-json", required=True) + parser.add_argument("--output-json", required=True) + parser.add_argument( + "--allow-unsupported-arch", + dest="allow_unvalidated_arch", + action="store_true", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + architecture = ArchitectureReport.model_validate_json( + Path(args.architecture_json).read_text(encoding="utf-8") + ) + stage_runner = _STAGE_RUNNERS[args.stage] + result = stage_runner( + base_model=args.base_model, + architecture=architecture, + allow_unvalidated_arch=args.allow_unvalidated_arch, + ) + Path(args.output_json).write_text( + result.model_dump_json(indent=2), + encoding="utf-8", + ) + + +if __name__ == "__main__": + main() diff --git a/tests/integration/megatron/runtime_isolation/README.md b/tests/integration/megatron/runtime_isolation/README.md new file mode 100644 index 000000000..d54f9ad85 --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/README.md @@ -0,0 +1,28 @@ +# Megatron Runtime Isolation Tests + +Runtime-boundary and vLLM-isolation integration tests live in this directory. + +Rules: + +- Write runtime-isolation artifacts under `tests/integration/megatron/runtime_isolation/artifacts/`. +- Do not run these tests from a dirty worktree. +- Any code involved in a test run must be committed before the test starts. +- Every artifact set must include the exact commit hash it ran from. + +Live smokes: + +- `test_live_runtime_server_smoke.py` validates the external runtime directly. +- `test_live_megatron_backend_smoke.py` validates ART-level Megatron shared and dedicated runtime flows. +- `test_live_yes_no_trainability.py` validates workflow-style yes/no trainability on the requested backend/mode matrix. +- `test_live_local_backend_smoke.py` validates the ART `LocalBackend` path. +- All of these live smokes are opt-in and are expected to write artifacts for every attempted run. + +Use the `artifact_dir` fixture from [conftest.py](./conftest.py) for artifact output. + +That fixture: + +- refuses to run when the worktree is dirty +- creates a per-test artifact directory under `artifacts/` +- writes `run_metadata.json` with the exact commit hash and test node id + +Artifact directories are git-ignored by design so reproducible outputs do not dirty the worktree.
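As a point of reference, here is a minimal sketch of a test that follows these conventions, assuming only the `artifact_dir` fixture described above; the test name and the result payload are illustrative and not part of the suite:

```python
import json
from pathlib import Path


def test_example_runtime_isolation_smoke(artifact_dir: Path) -> None:
    # By the time the test body runs, the fixture has already refused to run on
    # a dirty worktree and written run_metadata.json (commit hash, branch, test
    # node id) into a fresh per-test directory under artifacts/.
    assert (artifact_dir / "run_metadata.json").exists()

    # Every attempted run writes its own result artifact next to the metadata.
    payload = {"status": "ok"}  # illustrative result payload
    (artifact_dir / "example_result.json").write_text(
        json.dumps(payload, indent=2, sort_keys=True) + "\n",
        encoding="utf-8",
    )
```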
diff --git a/tests/integration/megatron/runtime_isolation/__init__.py b/tests/integration/megatron/runtime_isolation/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/integration/megatron/runtime_isolation/artifacts.py b/tests/integration/megatron/runtime_isolation/artifacts.py new file mode 100644 index 000000000..da754db97 --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/artifacts.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from datetime import datetime, timezone +import os +from pathlib import Path +import re +import subprocess +import sys +import uuid + +from pydantic import BaseModel + +TEST_ROOT = Path(__file__).resolve().parent +ARTIFACTS_ROOT = TEST_ROOT / "artifacts" +REPO_ROOT = Path( + subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + cwd=TEST_ROOT, + check=True, + capture_output=True, + text=True, + ).stdout.strip() +) + + +class ArtifactMetadata(BaseModel): + commit: str + branch: str + test_nodeid: str + created_at_utc: str + python_executable: str + artifact_dir: str + + +def _git(*args: str) -> str: + return subprocess.run( + ["git", *args], + cwd=REPO_ROOT, + check=True, + capture_output=True, + text=True, + ).stdout.strip() + + +def _dirty_lines() -> list[str]: + output = _git("status", "--porcelain=v1", "--untracked-files=all") + return [line for line in output.splitlines() if line] + + +def require_clean_git_state() -> str: + dirty = _dirty_lines() + if dirty: + rendered = "\n".join(dirty) + raise RuntimeError( + "Megatron runtime-isolation tests require a fully committed worktree.\n" + "Commit or remove these changes before running tests:\n" + f"{rendered}" + ) + return _git("rev-parse", "HEAD") + + +def _sanitize_nodeid(nodeid: str) -> str: + collapsed = re.sub(r"[^A-Za-z0-9_.-]+", "_", nodeid.strip()) + return collapsed.strip("._") or "unnamed_test" + + +def create_artifact_dir(test_nodeid: str) -> Path: + commit = require_clean_git_state() + branch = _git("branch", "--show-current") + test_name = _sanitize_nodeid(test_nodeid) + timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + run_id = f"{timestamp}_{os.getpid()}_{uuid.uuid4().hex[:8]}" + artifact_dir = ARTIFACTS_ROOT / test_name / commit[:12] / run_id + artifact_dir.mkdir(parents=True, exist_ok=False) + + metadata = ArtifactMetadata( + commit=commit, + branch=branch, + test_nodeid=test_nodeid, + created_at_utc=datetime.now(timezone.utc).isoformat(), + python_executable=sys.executable, + artifact_dir=str(artifact_dir), + ) + (artifact_dir / "run_metadata.json").write_text( + metadata.model_dump_json(indent=2) + "\n", + encoding="utf-8", + ) + return artifact_dir diff --git a/tests/integration/megatron/runtime_isolation/artifacts/.gitignore b/tests/integration/megatron/runtime_isolation/artifacts/.gitignore new file mode 100644 index 000000000..d6b7ef32c --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/artifacts/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/tests/integration/megatron/runtime_isolation/conftest.py b/tests/integration/megatron/runtime_isolation/conftest.py new file mode 100644 index 000000000..ca3d03e72 --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/conftest.py @@ -0,0 +1,36 @@ +from pathlib import Path + +import pytest + +from .artifacts import create_artifact_dir, require_clean_git_state + +TEST_ROOT = Path(__file__).resolve().parent +ARTIFACTS_ROOT = TEST_ROOT / "artifacts" + + 
+@pytest.fixture(scope="session", autouse=True) +def _require_clean_commit_state() -> None: + require_clean_git_state() + + +@pytest.fixture +def artifact_dir(request: pytest.FixtureRequest) -> Path: + return create_artifact_dir(request.node.nodeid) + + +def pytest_collection_modifyitems( + session: pytest.Session, + config: pytest.Config, + items: list[pytest.Item], +) -> None: + del session, config + yes_no_order = { + "test_megatron_dedicated_yes_no_trainability_live": 0, + "test_megatron_shared_yes_no_trainability_live": 1, + "test_unsloth_dedicated_yes_no_trainability_live": 2, + } + + def _sort_key(item: pytest.Item) -> tuple[int, str]: + return (yes_no_order.get(item.name, 99), item.nodeid) + + items.sort(key=_sort_key) diff --git a/tests/integration/megatron/runtime_isolation/test_art_import_boundary.py b/tests/integration/megatron/runtime_isolation/test_art_import_boundary.py new file mode 100644 index 000000000..dd7d57602 --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/test_art_import_boundary.py @@ -0,0 +1,86 @@ +import json +import os +from pathlib import Path +import subprocess +import sys + +ROOT = Path(__file__).resolve().parents[4] + + +def _run( + command: list[str], + *, + artifact_dir: Path, + env: dict[str, str] | None = None, +) -> subprocess.CompletedProcess[str]: + result = subprocess.run( + command, + cwd=ROOT, + env=env, + check=True, + capture_output=True, + text=True, + ) + (artifact_dir / "stdout.txt").write_text(result.stdout) + (artifact_dir / "stderr.txt").write_text(result.stderr) + return result + + +def _load_json_from_stdout(stdout: str) -> dict[str, object]: + return json.loads(stdout.strip().splitlines()[-1]) + + +def test_art_import_does_not_require_vllm_or_mutate_compile_threads( + artifact_dir: Path, +) -> None: + env = dict(os.environ) + env.pop("TORCHINDUCTOR_COMPILE_THREADS", None) + result = _run( + [ + sys.executable, + "-c", + ( + "import importlib.util, json, os; " + "before = os.environ.get('TORCHINDUCTOR_COMPILE_THREADS'); " + "import art; " + "after = os.environ.get('TORCHINDUCTOR_COMPILE_THREADS'); " + "print(json.dumps({" + "'before': before, " + "'after': after, " + "'has_vllm': importlib.util.find_spec('vllm') is not None" + "}))" + ), + ], + artifact_dir=artifact_dir, + env=env, + ) + payload = _load_json_from_stdout(result.stdout) + assert payload["has_vllm"] is False + assert payload["before"] is None + assert payload["after"] is None + + +def test_service_modules_import_without_vllm(artifact_dir: Path) -> None: + result = _run( + [ + sys.executable, + "-c", + ( + "import importlib, json; " + "modules = [" + "'art.unsloth.service', " + "'art.megatron.service', " + "'art.megatron.weights.merged_weight_export'" + "]; " + "loaded = [importlib.import_module(name).__name__ for name in modules]; " + "print(json.dumps({'loaded': loaded}))" + ), + ], + artifact_dir=artifact_dir, + ) + payload = _load_json_from_stdout(result.stdout) + assert payload["loaded"] == [ + "art.unsloth.service", + "art.megatron.service", + "art.megatron.weights.merged_weight_export", + ] diff --git a/tests/integration/megatron/runtime_isolation/test_art_separation_contract.py b/tests/integration/megatron/runtime_isolation/test_art_separation_contract.py new file mode 100644 index 000000000..7905c06eb --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/test_art_separation_contract.py @@ -0,0 +1,34 @@ +from pathlib import Path +import tomllib + +ROOT = Path(__file__).resolve().parents[4] + + +def test_art_source_has_no_vllm_imports() -> 
None: + offenders: list[str] = [] + for path in sorted((ROOT / "src" / "art").rglob("*.py")): + for line_number, line in enumerate(path.read_text().splitlines(), start=1): + stripped = line.strip() + if stripped.startswith("import vllm") or stripped.startswith("from vllm"): + offenders.append(f"{path.relative_to(ROOT)}:{line_number}") + assert offenders == [] + + +def test_art_pyproject_has_no_vllm_dependency_or_plugin_entrypoint() -> None: + pyproject = tomllib.loads((ROOT / "pyproject.toml").read_text()) + project = pyproject["project"] + backend = project["optional-dependencies"]["backend"] + megatron = project["optional-dependencies"]["megatron"] + dev = pyproject["dependency-groups"]["dev"] + + def _contains_vllm(values: list[str]) -> bool: + return any( + value.startswith("vllm") or value == "art-vllm-runtime" for value in values + ) + + assert not _contains_vllm(backend) + assert not _contains_vllm(megatron) + assert not _contains_vllm(dev) + assert "entry-points" not in project or "vllm.general_plugins" not in project.get( + "entry-points", {} + ) diff --git a/tests/integration/megatron/runtime_isolation/test_client.py b/tests/integration/megatron/runtime_isolation/test_client.py new file mode 100644 index 000000000..7d311d1d9 --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/test_client.py @@ -0,0 +1,44 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from art.megatron.runtime.client import stream_megatron_job, write_megatron_job +from art.megatron.runtime.jobs import ( + MegatronSyncJob, + MergedWeightTransferInitInfo, + MergedWeightTransferSpec, +) + + +@pytest.mark.asyncio +async def test_stream_megatron_job_raises_when_worker_exits( + tmp_path: Path, +) -> None: + job_path = tmp_path / "job.json" + log_path = tmp_path / "job.log" + job = MegatronSyncJob( + lora_path="/tmp/lora", + merged_weight_transfer=MergedWeightTransferSpec( + init_info=MergedWeightTransferInitInfo( + master_address="127.0.0.1", + master_port=12345, + rank_offset=1, + world_size=2, + ), + vllm_base_url="http://127.0.0.1:8000", + served_model_name="test@0", + ), + log_path=str(log_path), + ) + write_megatron_job(job, job_path=str(job_path)) + + with pytest.raises(RuntimeError, match="Megatron worker exited with code 17"): + async for _ in stream_megatron_job( + job, + job_path=str(job_path), + process=SimpleNamespace(returncode=17), + process_log_path="/tmp/megatron-runtime.log", + poll_interval=0.0, + ): + pass diff --git a/tests/integration/megatron/runtime_isolation/test_live_local_backend_smoke.py b/tests/integration/megatron/runtime_isolation/test_live_local_backend_smoke.py new file mode 100644 index 000000000..4849ca319 --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/test_live_local_backend_smoke.py @@ -0,0 +1,109 @@ +import json +import os +from pathlib import Path +import uuid + +import pytest + +torch = pytest.importorskip("torch") + +import art +from art.local import LocalBackend + +DEFAULT_BASE_MODEL = "Qwen/Qwen3-0.6B" +DEFAULT_GPU_MEMORY_UTILIZATION = 0.12 +DEFAULT_MAX_MODEL_LEN = 512 +DEFAULT_MAX_SEQ_LENGTH = 512 +LIVE_SMOKE_ENV = "ART_RUN_LIVE_VLLM_SEPARATION" + + +def _require_live_smoke_opt_in() -> None: + if os.environ.get(LIVE_SMOKE_ENV) != "1": + pytest.skip(f"set {LIVE_SMOKE_ENV}=1 to run the live runtime smoke") + + +def _safe_gpu_memory_utilization() -> float: + min_free_gib = float(os.environ.get("ART_TEST_MIN_FREE_GPU_GIB", "8")) + free_bytes, total_bytes = torch.cuda.mem_get_info() + free_gib = free_bytes / 
(1024**3) + if free_gib < min_free_gib: + pytest.skip( + f"Insufficient free GPU memory for live vLLM separation smoke: " + f"{free_gib:.1f} GiB free < {min_free_gib:.1f} GiB required." + ) + requested = float( + os.environ.get( + "ART_TEST_GPU_MEMORY_UTILIZATION", + str(DEFAULT_GPU_MEMORY_UTILIZATION), + ) + ) + return max(0.02, min(requested, (free_bytes / total_bytes) * 0.8)) + + +def _live_test_config() -> art.dev.InternalModelConfig: + return { + "rollout_weights_mode": "lora", + "engine_args": { + "gpu_memory_utilization": _safe_gpu_memory_utilization(), + "max_model_len": int( + os.environ.get("ART_TEST_MAX_MODEL_LEN", str(DEFAULT_MAX_MODEL_LEN)) + ), + "max_num_seqs": 4, + "enforce_eager": True, + }, + "init_args": { + "max_seq_length": int( + os.environ.get("ART_TEST_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH)) + ), + }, + } + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") +@pytest.mark.asyncio +async def test_local_backend_external_runtime_live_smoke( + tmp_path: Path, + artifact_dir: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _require_live_smoke_opt_in() + monkeypatch.setenv("WANDB_MODE", "offline") + + model_name = f"vllm-separation-live-{uuid.uuid4().hex[:8]}" + backend = LocalBackend(path=str(tmp_path)) + model = art.TrainableModel( + name=model_name, + project="integration-tests", + base_model=os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL), + _internal_config=_live_test_config(), + ) + + try: + await model.register(backend) + client = model.openai_client() + try: + step0_name = model.get_inference_name(step=0) + model_ids = [model_info.id async for model_info in client.models.list()] + completion = await client.chat.completions.create( + model=step0_name, + messages=[{"role": "user", "content": "Say hello."}], + max_tokens=8, + timeout=120, + logprobs=True, + top_logprobs=0, + ) + payload = { + "step0_name": step0_name, + "model_ids": model_ids, + "text": completion.choices[0].message.content, + "has_logprobs": completion.choices[0].logprobs is not None, + } + (artifact_dir / "live_smoke_result.json").write_text( + json.dumps(payload, indent=2, sort_keys=True) + ) + assert step0_name in model_ids + assert completion.choices[0].logprobs is not None + finally: + await client.close() + finally: + await backend.close() diff --git a/tests/integration/megatron/runtime_isolation/test_live_megatron_backend_smoke.py b/tests/integration/megatron/runtime_isolation/test_live_megatron_backend_smoke.py new file mode 100644 index 000000000..ad3ce4ffc --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/test_live_megatron_backend_smoke.py @@ -0,0 +1,602 @@ +import asyncio +from contextlib import asynccontextmanager +import json +import os +from pathlib import Path +from typing import Any, AsyncIterator, cast +import uuid + +import httpx +import pytest + +import art +from art import dev +from art.megatron.runtime.backend import MegatronBackend +from art.megatron.service import MegatronService + +from ..model_support.oracle_harness import ORACLE_TOPOLOGY, Topology +from ..model_support.oracle_worker import provider_topology_env +from ..trainability import ( + _build_trainable_groups, + _build_training_groups, + _engine_args_for_yes_no_trainability, + _evaluate_model, + _wandb_disabled, + _warmup_model, + build_prompts, +) + +torch = pytest.importorskip("torch") + +DEFAULT_BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507" +DEFAULT_MAX_SEQ_LENGTH = 1024 +DEFAULT_PACKED_SEQUENCE_LENGTH = 1024 +DEDICATED_MERGED_ENV = 
"ART_RUN_LIVE_MEGATRON_MERGED_SMOKE" +DEDICATED_MULTIRANK_MERGED_ENV = "ART_RUN_LIVE_MEGATRON_MULTIRANK_MERGED_SMOKE" +SHARED_LORA_ENV = "ART_RUN_LIVE_MEGATRON_SHARED_SMOKE" +SHARED_LONG_LORA_ENV = "ART_RUN_LIVE_MEGATRON_SHARED_LONG_SMOKE" +SHARED_TOPOLOGY = Topology(tp=2, ep=2, etp=1, dp=1, sp=True) + + +def _base_model() -> str: + return os.environ.get( + "ART_LIVE_MEGATRON_BASE_MODEL", + os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL), + ) + + +def _max_seq_length() -> int: + return int(os.environ.get("ART_TEST_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH))) + + +def _packed_sequence_length() -> int: + return int( + os.environ.get( + "ART_TEST_PACKED_SEQUENCE_LENGTH", + str(DEFAULT_PACKED_SEQUENCE_LENGTH), + ) + ) + + +def _train_group_prompts() -> list[str]: + prompt_count = int(os.environ.get("ART_TEST_MEGATRON_PROMPT_COUNT", "2")) + return build_prompts()[: max(1, prompt_count)] + + +def _rollouts_per_prompt() -> int: + return int(os.environ.get("ART_TEST_MEGATRON_ROLLOUTS_PER_PROMPT", "2")) + + +def _trainer_gpu_ids() -> list[int]: + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + raise RuntimeError("Need at least 2 visible CUDA GPUs for Megatron live smokes") + return [0] + + +def _inference_gpu_ids() -> list[int]: + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + raise RuntimeError("Need at least 2 visible CUDA GPUs for Megatron live smokes") + return [1] + + +def _multirank_trainer_gpu_ids() -> list[int]: + if not torch.cuda.is_available() or torch.cuda.device_count() < 3: + raise RuntimeError( + "Need at least 3 visible CUDA GPUs for multi-rank Megatron merged smoke" + ) + return [0, 1] + + +def _multirank_inference_gpu_ids() -> list[int]: + if not torch.cuda.is_available() or torch.cuda.device_count() < 3: + raise RuntimeError( + "Need at least 3 visible CUDA GPUs for multi-rank Megatron merged smoke" + ) + return [2] + + +def _require_opt_in(env_name: str) -> None: + if os.environ.get(env_name) != "1": + pytest.skip(f"set {env_name}=1 to run this live Megatron smoke") + + +def _shared_live_config() -> dev.InternalModelConfig: + return cast( + dev.InternalModelConfig, + { + "rollout_weights_mode": "lora", + "engine_args": { + **_engine_args_for_yes_no_trainability(inference_gpu_ids=[0, 1]), + "tensor_parallel_size": 2, + "enable_expert_parallel": True, + "enable_sleep_mode": True, + }, + "init_args": {"max_seq_length": _max_seq_length()}, + }, + ) + + +def _dedicated_merged_config() -> dev.InternalModelConfig: + return { + "trainer_gpu_ids": _trainer_gpu_ids(), + "inference_gpu_ids": _inference_gpu_ids(), + "rollout_weights_mode": "merged", + "engine_args": { + **_engine_args_for_yes_no_trainability( + inference_gpu_ids=_inference_gpu_ids() + ), + }, + "init_args": {"max_seq_length": _max_seq_length()}, + } + + +def _dedicated_multirank_merged_config() -> dev.InternalModelConfig: + return { + "trainer_gpu_ids": _multirank_trainer_gpu_ids(), + "inference_gpu_ids": _multirank_inference_gpu_ids(), + "rollout_weights_mode": "merged", + "engine_args": { + **_engine_args_for_yes_no_trainability( + inference_gpu_ids=_multirank_inference_gpu_ids() + ), + }, + "init_args": {"max_seq_length": _max_seq_length()}, + } + + +def _shared_long_steps() -> int: + return int(os.environ.get("ART_TEST_MEGATRON_SHARED_LONG_STEPS", "10")) + + +async def _list_model_ids(model: art.TrainableModel) -> list[str]: + client = model.openai_client() + return [model_info.id async for model_info in client.models.list()] + + +async def _chat_snapshot(model: 
art.TrainableModel, *, step: int) -> dict[str, object]: + client = model.openai_client() + completion = await client.chat.completions.create( + messages=[{"role": "user", "content": "Say hello."}], + model=model.get_inference_name(step=step), + max_tokens=8, + timeout=180.0, + logprobs=True, + top_logprobs=0, + ) + return { + "text": completion.choices[0].message.content, + "has_logprobs": completion.choices[0].logprobs is not None, + } + + +async def _runtime_is_sleeping(service: MegatronService) -> bool: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(f"{service._vllm_base_url}/is_sleeping") + response.raise_for_status() + return bool(response.json()["is_sleeping"]) + + +async def _wait_until_runtime_sleeping( + service: MegatronService, + *, + timeout_s: float = 300.0, + poll_s: float = 0.5, +) -> bool: + deadline = asyncio.get_running_loop().time() + timeout_s + while asyncio.get_running_loop().time() < deadline: + if await _runtime_is_sleeping(service): + return True + await asyncio.sleep(poll_s) + return False + + +@asynccontextmanager +async def _megatron_backend_context( + *, + backend_root: Path, + topology: Topology, +) -> AsyncIterator[MegatronBackend]: + with _wandb_disabled(): + with provider_topology_env(topology): + async with MegatronBackend( + path=str(backend_root), in_process=False + ) as backend: + yield backend + + +def _jitter_training_groups( + groups: list[art.TrajectoryGroup], + *, + step: int, +) -> list[art.TrajectoryGroup]: + jittered_groups: list[art.TrajectoryGroup] = [] + for group_index, group in enumerate(groups): + jittered_trajectories: list[art.Trajectory] = [] + for trajectory_index, trajectory in enumerate(group.trajectories): + reward = float(trajectory.reward) + 1e-3 * ( + 1 + step + group_index + trajectory_index + ) + jittered_trajectories.append( + art.Trajectory( + messages_and_choices=trajectory.messages_and_choices, + reward=reward, + ) + ) + jittered_groups.append(art.TrajectoryGroup(jittered_trajectories)) + return jittered_groups + + +async def _build_jittered_training_groups( + model: art.TrainableModel, + *, + step: int, + rollouts_per_prompt: int, +) -> list[art.TrajectoryGroup]: + if rollouts_per_prompt < 2: + raise ValueError("Shared Megatron long smoke requires rollouts_per_prompt >= 2") + return _jitter_training_groups( + await _build_training_groups( + model, + base_model=model.base_model, + prompts=_train_group_prompts(), + rollouts_per_prompt=rollouts_per_prompt, + ), + step=step, + ) + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Need at least 2 CUDA GPUs for Megatron live smokes", +) +@pytest.mark.asyncio +async def test_megatron_backend_shared_lora_runtime_sleep_wake_live_smoke( + artifact_dir: Path, +) -> None: + _require_opt_in(SHARED_LORA_ENV) + backend_root = artifact_dir / "art_workspace" + backend_root.mkdir(parents=True, exist_ok=True) + + async with _megatron_backend_context( + backend_root=backend_root, + topology=SHARED_TOPOLOGY, + ) as backend: + model = art.TrainableModel( + name=f"megatron-shared-live-{uuid.uuid4().hex[:8]}", + project="integration-tests", + base_model=_base_model(), + _internal_config=_shared_live_config(), + report_metrics=[], + ) + await model.register(backend) + service = cast(MegatronService, await backend._get_service(model)) + prompts = _train_group_prompts() + await _warmup_model(model, base_model=model.base_model, prompt=prompts[0]) + step0_name = model.get_inference_name(step=0) + model_ids_before 
= await _list_model_ids(model) + train_groups = await _build_trainable_groups( + model, + base_model=model.base_model, + prompts=prompts, + rollouts_per_prompt=_rollouts_per_prompt(), + ) + train_task = asyncio.create_task( + backend.train( + model, + train_groups, + learning_rate=float(os.environ.get("ART_TEST_MEGATRON_LR", "1e-4")), + loss_fn="cispo", + packed_sequence_length=_packed_sequence_length(), + ) + ) + observed_sleep = False + try: + while not train_task.done(): + if await _runtime_is_sleeping(service): + observed_sleep = True + break + await asyncio.sleep(0.5) + assert observed_sleep or train_task.done() + result = await train_task + finally: + if not train_task.done(): + await train_task + + latest_step = int(result.step) + latest_name = model.get_inference_name(step=latest_step) + model_ids_after = await _list_model_ids(model) + eval_reward = await _evaluate_model( + model, + base_model=model.base_model, + prompts=prompts, + step=latest_step, + ) + latest_snapshot = await _chat_snapshot(model, step=latest_step) + runtime_sleep_after = await _runtime_is_sleeping(service) + payload = { + "base_model": model.base_model, + "output_dir": service.output_dir, + "step0_name": step0_name, + "latest_name": latest_name, + "latest_step": latest_step, + "model_ids_before": model_ids_before, + "model_ids_after": model_ids_after, + "observed_sleep": observed_sleep, + "runtime_sleep_after": runtime_sleep_after, + "eval_reward": eval_reward, + "latest_snapshot": latest_snapshot, + } + (artifact_dir / "shared_megatron_live_result.json").write_text( + json.dumps(payload, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + assert observed_sleep + assert runtime_sleep_after is False + assert latest_step > 0 + assert step0_name in model_ids_before + assert step0_name in model_ids_after + assert latest_name in model_ids_after + assert latest_snapshot["has_logprobs"] is True + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Need at least 2 CUDA GPUs for Megatron live smokes", +) +@pytest.mark.asyncio +async def test_megatron_backend_dedicated_merged_live_smoke( + artifact_dir: Path, +) -> None: + _require_opt_in(DEDICATED_MERGED_ENV) + backend_root = artifact_dir / "art_workspace" + backend_root.mkdir(parents=True, exist_ok=True) + + async with _megatron_backend_context( + backend_root=backend_root, + topology=ORACLE_TOPOLOGY, + ) as backend: + model = art.TrainableModel( + name=f"megatron-merged-live-{uuid.uuid4().hex[:8]}", + project="integration-tests", + base_model=_base_model(), + _internal_config=_dedicated_merged_config(), + report_metrics=[], + ) + await model.register(backend) + service = cast(MegatronService, await backend._get_service(model)) + prompts = _train_group_prompts() + await _warmup_model(model, base_model=model.base_model, prompt=prompts[0]) + step0_name = model.get_inference_name(step=0) + model_ids_before = await _list_model_ids(model) + train_groups = await _build_trainable_groups( + model, + base_model=model.base_model, + prompts=prompts, + rollouts_per_prompt=_rollouts_per_prompt(), + ) + result = await backend.train( + model, + train_groups, + learning_rate=float(os.environ.get("ART_TEST_MEGATRON_LR", "1e-4")), + loss_fn="cispo", + packed_sequence_length=_packed_sequence_length(), + ) + latest_step = int(result.step) + latest_name = model.get_inference_name(step=latest_step) + model_ids_after = await _list_model_ids(model) + eval_reward = await _evaluate_model( + model, + base_model=model.base_model, + 
prompts=prompts, + step=latest_step, + ) + latest_snapshot = await _chat_snapshot(model, step=latest_step) + payload = { + "base_model": model.base_model, + "output_dir": service.output_dir, + "step0_name": step0_name, + "latest_name": latest_name, + "latest_step": latest_step, + "model_ids_before": model_ids_before, + "model_ids_after": model_ids_after, + "eval_reward": eval_reward, + "latest_snapshot": latest_snapshot, + } + (artifact_dir / "dedicated_megatron_merged_live_result.json").write_text( + json.dumps(payload, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + assert latest_step > 0 + assert step0_name in model_ids_before + assert latest_name in model_ids_after + assert step0_name not in model_ids_after + assert latest_snapshot["has_logprobs"] is True + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 3, + reason="Need at least 3 CUDA GPUs for multi-rank Megatron merged smoke", +) +@pytest.mark.asyncio +async def test_megatron_backend_dedicated_multirank_merged_live_smoke( + artifact_dir: Path, +) -> None: + _require_opt_in(DEDICATED_MULTIRANK_MERGED_ENV) + backend_root = artifact_dir / "art_workspace" + backend_root.mkdir(parents=True, exist_ok=True) + + async with _megatron_backend_context( + backend_root=backend_root, + topology=SHARED_TOPOLOGY, + ) as backend: + model = art.TrainableModel( + name=f"megatron-multirank-merged-live-{uuid.uuid4().hex[:8]}", + project="integration-tests", + base_model=_base_model(), + _internal_config=_dedicated_multirank_merged_config(), + report_metrics=[], + ) + await model.register(backend) + service = cast(MegatronService, await backend._get_service(model)) + prompts = _train_group_prompts() + await _warmup_model(model, base_model=model.base_model, prompt=prompts[0]) + step0_name = model.get_inference_name(step=0) + model_ids_before = await _list_model_ids(model) + train_groups = await _build_trainable_groups( + model, + base_model=model.base_model, + prompts=prompts, + rollouts_per_prompt=_rollouts_per_prompt(), + ) + result = await backend.train( + model, + train_groups, + learning_rate=float(os.environ.get("ART_TEST_MEGATRON_LR", "1e-4")), + loss_fn="cispo", + packed_sequence_length=_packed_sequence_length(), + ) + latest_step = int(result.step) + latest_name = model.get_inference_name(step=latest_step) + model_ids_after = await _list_model_ids(model) + eval_reward = await _evaluate_model( + model, + base_model=model.base_model, + prompts=prompts, + step=latest_step, + ) + latest_snapshot = await _chat_snapshot(model, step=latest_step) + payload = { + "base_model": model.base_model, + "output_dir": service.output_dir, + "step0_name": step0_name, + "latest_name": latest_name, + "latest_step": latest_step, + "model_ids_before": model_ids_before, + "model_ids_after": model_ids_after, + "eval_reward": eval_reward, + "latest_snapshot": latest_snapshot, + "trainer_gpu_ids": _multirank_trainer_gpu_ids(), + "inference_gpu_ids": _multirank_inference_gpu_ids(), + "topology": SHARED_TOPOLOGY.model_dump(), + } + ( + artifact_dir / "dedicated_megatron_multirank_merged_live_result.json" + ).write_text( + json.dumps(payload, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + assert latest_step > 0 + assert step0_name in model_ids_before + assert latest_name in model_ids_after + assert step0_name not in model_ids_after + assert latest_snapshot["has_logprobs"] is True + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Need at least 2 CUDA GPUs for 
Megatron live smokes", +) +@pytest.mark.asyncio +async def test_megatron_backend_shared_lora_ten_step_live_smoke( + artifact_dir: Path, +) -> None: + _require_opt_in(SHARED_LONG_LORA_ENV) + backend_root = artifact_dir / "art_workspace" + backend_root.mkdir(parents=True, exist_ok=True) + + async with _megatron_backend_context( + backend_root=backend_root, + topology=SHARED_TOPOLOGY, + ) as backend: + model = art.TrainableModel( + name=f"megatron-shared-long-live-{uuid.uuid4().hex[:8]}", + project="integration-tests", + base_model=_base_model(), + _internal_config=_shared_live_config(), + report_metrics=[], + ) + await model.register(backend) + service = cast(MegatronService, await backend._get_service(model)) + prompts = _train_group_prompts() + await _warmup_model(model, base_model=model.base_model, prompt=prompts[0]) + step0_name = model.get_inference_name(step=0) + model_ids_before = await _list_model_ids(model) + step_reports: list[dict[str, object]] = [] + + for step_index in range(_shared_long_steps()): + train_groups = await _build_jittered_training_groups( + model, + step=step_index, + rollouts_per_prompt=_rollouts_per_prompt(), + ) + train_task = asyncio.create_task( + backend.train( + model, + train_groups, + learning_rate=float(os.environ.get("ART_TEST_MEGATRON_LR", "1e-4")), + loss_fn="cispo", + packed_sequence_length=_packed_sequence_length(), + ) + ) + observed_sleep = False + try: + while not train_task.done(): + if await _runtime_is_sleeping(service): + observed_sleep = True + break + await asyncio.sleep(0.5) + assert observed_sleep or train_task.done() + result = await train_task + finally: + if not train_task.done(): + await train_task + + latest_step = int(result.step) + eval_reward = await _evaluate_model( + model, + base_model=model.base_model, + prompts=prompts, + step=latest_step, + ) + step_reports.append( + { + "step": latest_step, + "observed_sleep": observed_sleep, + "eval_reward": eval_reward, + "train_reward": sum( + trajectory.reward + for group in train_groups + for trajectory in group.trajectories + ) + / max(1, sum(len(group.trajectories) for group in train_groups)), + } + ) + + latest_step = int(cast(Any, step_reports[-1]["step"])) + latest_name = model.get_inference_name(step=latest_step) + model_ids_after = await _list_model_ids(model) + latest_snapshot = await _chat_snapshot(model, step=latest_step) + runtime_sleep_after = await _runtime_is_sleeping(service) + payload = { + "base_model": model.base_model, + "output_dir": service.output_dir, + "step0_name": step0_name, + "latest_name": latest_name, + "latest_step": latest_step, + "model_ids_before": model_ids_before, + "model_ids_after": model_ids_after, + "runtime_sleep_after": runtime_sleep_after, + "latest_snapshot": latest_snapshot, + "step_reports": step_reports, + } + (artifact_dir / "shared_megatron_ten_step_live_result.json").write_text( + json.dumps(payload, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + assert all(bool(step_report["observed_sleep"]) for step_report in step_reports) + assert runtime_sleep_after is False + assert latest_step >= _shared_long_steps() + assert step0_name in model_ids_before + assert step0_name in model_ids_after + assert latest_name in model_ids_after + assert latest_snapshot["has_logprobs"] is True diff --git a/tests/integration/megatron/runtime_isolation/test_live_runtime_server_smoke.py b/tests/integration/megatron/runtime_isolation/test_live_runtime_server_smoke.py new file mode 100644 index 000000000..5773873c1 --- /dev/null +++ 
b/tests/integration/megatron/runtime_isolation/test_live_runtime_server_smoke.py @@ -0,0 +1,180 @@ +import json +import os +from pathlib import Path +import socket +import subprocess +import uuid + +import httpx +import pytest + +import art.vllm_runtime as runtime + +torch = pytest.importorskip("torch") + +ROOT = Path(__file__).resolve().parents[4] +DEFAULT_BASE_MODEL = "Qwen/Qwen3-0.6B" +DEFAULT_GPU_MEMORY_UTILIZATION = 0.12 +DEFAULT_MAX_MODEL_LEN = 512 +LIVE_RUNTIME_SMOKE_ENV = "ART_RUN_LIVE_VLLM_RUNTIME_SMOKE" + + +def _require_live_runtime_smoke_opt_in() -> None: + if os.environ.get(LIVE_RUNTIME_SMOKE_ENV) != "1": + pytest.skip(f"set {LIVE_RUNTIME_SMOKE_ENV}=1 to run the live runtime smoke") + + +def _safe_gpu_memory_utilization() -> float: + min_free_gib = float(os.environ.get("ART_TEST_MIN_FREE_GPU_GIB", "8")) + free_bytes, total_bytes = torch.cuda.mem_get_info() + free_gib = free_bytes / (1024**3) + if free_gib < min_free_gib: + pytest.skip( + f"Insufficient free GPU memory for live runtime smoke: " + f"{free_gib:.1f} GiB free < {min_free_gib:.1f} GiB required." + ) + requested = float( + os.environ.get( + "ART_TEST_GPU_MEMORY_UTILIZATION", + str(DEFAULT_GPU_MEMORY_UTILIZATION), + ) + ) + return max(0.02, min(requested, (free_bytes / total_bytes) * 0.8)) + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") +@pytest.mark.asyncio +async def test_external_runtime_server_live_smoke( + tmp_path: Path, + artifact_dir: Path, +) -> None: + _require_live_runtime_smoke_opt_in() + + port = _find_free_port() + served_model_name = f"vllm-runtime-live-{uuid.uuid4().hex[:8]}" + renamed_model_name = f"{served_model_name}@renamed" + log_path = artifact_dir / "runtime.log" + launch_config = runtime.VllmRuntimeLaunchConfig( + base_model=os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL), + port=port, + host="127.0.0.1", + cuda_visible_devices=os.environ.get("CUDA_VISIBLE_DEVICES", "0"), + lora_path=str(tmp_path / "placeholder_lora"), + served_model_name=served_model_name, + rollout_weights_mode="merged", + engine_args={ + "gpu_memory_utilization": _safe_gpu_memory_utilization(), + "max_model_len": int( + os.environ.get("ART_TEST_MAX_MODEL_LEN", str(DEFAULT_MAX_MODEL_LEN)) + ), + "max_num_seqs": 4, + "enforce_eager": True, + }, + ) + command = runtime.build_vllm_runtime_server_cmd(launch_config) + env = os.environ.copy() + env["WANDB_MODE"] = "offline" + + with log_path.open("w", encoding="utf-8") as log_file: + process = subprocess.Popen( + command, + cwd=ROOT, + env=env, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + ) + try: + await runtime.wait_for_vllm_runtime( + process=process, + host=launch_config.host, + port=launch_config.port, + timeout=600.0, + ) + async with httpx.AsyncClient( + base_url=f"http://{launch_config.host}:{launch_config.port}", + timeout=120.0, + ) as client: + models_response = await client.get("/v1/models") + models_response.raise_for_status() + original_model_ids = [ + model_info["id"] for model_info in models_response.json()["data"] + ] + + rename_response = await client.post( + "/art/set_served_model_name", + json={"name": renamed_model_name}, + ) + rename_response.raise_for_status() + + renamed_models_response = await client.get("/v1/models") + renamed_models_response.raise_for_status() + renamed_model_ids = [ + model_info["id"] + for model_info in 
renamed_models_response.json()["data"] + ] + + sleep_response = await client.post( + "/sleep", + params={"level": 1, "mode": "wait"}, + ) + sleep_response.raise_for_status() + sleeping_response = await client.get("/is_sleeping") + sleeping_response.raise_for_status() + sleeping_before_wake = bool(sleeping_response.json()["is_sleeping"]) + + wake_response = await client.post("/wake_up") + wake_response.raise_for_status() + awake_response = await client.get("/is_sleeping") + awake_response.raise_for_status() + sleeping_after_wake = bool(awake_response.json()["is_sleeping"]) + + completion_response = await client.post( + "/v1/chat/completions", + json={ + "model": renamed_model_name, + "messages": [{"role": "user", "content": "Say hello."}], + "max_tokens": 8, + "logprobs": True, + "top_logprobs": 0, + }, + ) + completion_response.raise_for_status() + completion = completion_response.json() + + (artifact_dir / "runtime_smoke_result.json").write_text( + json.dumps( + { + "command": command, + "base_model": launch_config.base_model, + "original_model_ids": original_model_ids, + "renamed_model_ids": renamed_model_ids, + "sleeping_before_wake": sleeping_before_wake, + "sleeping_after_wake": sleeping_after_wake, + "text": completion["choices"][0]["message"]["content"], + "has_logprobs": completion["choices"][0]["logprobs"] is not None, + }, + indent=2, + sort_keys=True, + ) + + "\n", + encoding="utf-8", + ) + assert served_model_name in original_model_ids + assert renamed_model_name in renamed_model_ids + assert sleeping_before_wake is True + assert sleeping_after_wake is False + assert completion["choices"][0]["logprobs"] is not None + finally: + process.terminate() + try: + process.wait(timeout=30) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=30) diff --git a/tests/integration/megatron/runtime_isolation/test_runtime_launcher.py b/tests/integration/megatron/runtime_isolation/test_runtime_launcher.py new file mode 100644 index 000000000..0cb4bac95 --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/test_runtime_launcher.py @@ -0,0 +1,270 @@ +import importlib.util +import os +from pathlib import Path + +import pytest + +ROOT = Path(__file__).resolve().parents[4] +spec = importlib.util.spec_from_file_location( + "art_vllm_runtime_launcher", ROOT / "src" / "art" / "vllm_runtime.py" +) +assert spec is not None and spec.loader is not None +runtime = importlib.util.module_from_spec(spec) +spec.loader.exec_module(runtime) + + +def test_get_vllm_runtime_project_root_defaults_to_repo_subdir(monkeypatch) -> None: + monkeypatch.delenv("ART_VLLM_RUNTIME_PROJECT_ROOT", raising=False) + runtime_root = runtime.get_vllm_runtime_project_root() + assert runtime_root == ROOT / "vllm_runtime" + + +def test_get_vllm_runtime_project_root_honors_override(monkeypatch) -> None: + monkeypatch.setenv("ART_VLLM_RUNTIME_PROJECT_ROOT", "/tmp/custom-runtime") + assert runtime.get_vllm_runtime_project_root() == Path("/tmp/custom-runtime") + + +def test_build_runtime_server_cmd_uses_runtime_project( + monkeypatch, + tmp_path: Path, +) -> None: + monkeypatch.delenv("ART_VLLM_RUNTIME_BIN", raising=False) + runtime_root = tmp_path / "custom-runtime" + runtime_bin = runtime_root / ".venv" / "bin" / "art-vllm-runtime-server" + runtime_bin.parent.mkdir(parents=True, exist_ok=True) + runtime_bin.write_text("#!/bin/sh\n", encoding="ascii") + monkeypatch.setenv("ART_VLLM_RUNTIME_PROJECT_ROOT", str(runtime_root)) + command = runtime.build_vllm_runtime_server_cmd( + 
runtime.VllmRuntimeLaunchConfig( + base_model="Qwen/Qwen3-14B", + port=8000, + host="127.0.0.1", + cuda_visible_devices="1", + lora_path="/tmp/lora", + served_model_name="test@0", + rollout_weights_mode="merged", + engine_args={"weight_transfer_config": {"backend": "nccl"}}, + server_args={"tool_call_parser": "hermes"}, + ) + ) + assert command[0] == str(runtime_bin) + assert "--model=Qwen/Qwen3-14B" in command + assert ( + '--engine-args-json={"weight_transfer_config": {"backend": "nccl"}}' in command + ) + assert '--server-args-json={"tool_call_parser": "hermes"}' in command + + +def test_build_runtime_server_cmd_honors_runtime_bin_override(monkeypatch) -> None: + monkeypatch.setenv("ART_VLLM_RUNTIME_BIN", "/opt/art/bin/runtime --wrapped") + command = runtime.build_vllm_runtime_server_cmd( + runtime.VllmRuntimeLaunchConfig( + base_model="Qwen/Qwen3-14B", + port=8000, + host="127.0.0.1", + cuda_visible_devices="1", + lora_path="/tmp/lora", + served_model_name="test@0", + rollout_weights_mode="merged", + ) + ) + assert command[:2] == ["/opt/art/bin/runtime", "--wrapped"] + + +def test_cleanup_old_managed_runtimes_only_deletes_marked_venvs( + monkeypatch, + tmp_path: Path, +) -> None: + monkeypatch.delenv("ART_VLLM_RUNTIME_KEEP_OLD", raising=False) + cache_root = tmp_path.resolve() + keep_hash = "a" * 64 + old_hash = "b" * 64 + invalid_hash = "c" * 64 + + def write_runtime(path: Path, manifest_hash: str) -> None: + (path / ".venv").mkdir(parents=True) + (path / ".venv" / "pyvenv.cfg").write_text("venv\n") + marker = runtime.VllmRuntimeInstallMarker( + runtime_version="0.1.0", + protocol_version=runtime.RUNTIME_PROTOCOL_VERSION, + manifest_hash=manifest_hash, + runtime_wheel_sha256="wheel", + cache_root=str(cache_root), + ) + runtime._install_marker_path(path).write_text(marker.model_dump_json()) + + keep_dir = cache_root / keep_hash + old_dir = cache_root / old_hash + invalid_dir = cache_root / invalid_hash + arbitrary_dir = cache_root / "not-art" + write_runtime(keep_dir, keep_hash) + write_runtime(old_dir, old_hash) + invalid_dir.mkdir() + arbitrary_dir.mkdir() + (arbitrary_dir / "important.txt").write_text("do not delete\n") + + runtime._cleanup_old_managed_runtimes(cache_root, keep_hash=keep_hash) + + assert keep_dir.exists() + assert not old_dir.exists() + assert invalid_dir.exists() + assert arbitrary_dir.exists() + assert (arbitrary_dir / "important.txt").exists() + + +def test_cleanup_old_managed_runtimes_respects_keep_old( + monkeypatch, + tmp_path: Path, +) -> None: + monkeypatch.setenv("ART_VLLM_RUNTIME_KEEP_OLD", "1") + old_hash = "d" * 64 + old_dir = tmp_path / old_hash + (old_dir / ".venv").mkdir(parents=True) + (old_dir / ".venv" / "pyvenv.cfg").write_text("venv\n") + marker = runtime.VllmRuntimeInstallMarker( + runtime_version="0.1.0", + protocol_version=runtime.RUNTIME_PROTOCOL_VERSION, + manifest_hash=old_hash, + runtime_wheel_sha256="wheel", + cache_root=str(tmp_path.resolve()), + ) + runtime._install_marker_path(old_dir).write_text(marker.model_dump_json()) + + runtime._cleanup_old_managed_runtimes(tmp_path.resolve(), keep_hash="e" * 64) + + assert old_dir.exists() + + +def test_install_managed_runtime_installs_entrypoint_after_promote( + monkeypatch, + tmp_path: Path, +) -> None: + bundle_dir = tmp_path / "bundle" + bundle_dir.mkdir() + runtime_wheel = bundle_dir / "art_vllm_runtime-0.1.0-py3-none-any.whl" + pyproject = bundle_dir / "pyproject.toml" + lockfile = bundle_dir / "uv.lock" + runtime_wheel.write_text("wheel\n") + pyproject.write_text("[project]\nname = 
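The two cleanup tests above pin down a conservative contract: a sibling cache directory may be removed only when it carries both a readable install marker and a real virtualenv, and only when its manifest hash differs from the one being kept. A standalone sketch of that contract; the marker filename and JSON shape here are illustrative, not the real paths produced by runtime._install_marker_path and VllmRuntimeInstallMarker:

import json
import shutil
from pathlib import Path

def cleanup_old_runtimes(cache_root: Path, keep_hash: str) -> None:
    for candidate in cache_root.iterdir():
        if not candidate.is_dir() or candidate.name == keep_hash:
            continue
        marker = candidate / "install_marker.json"      # hypothetical filename
        venv_cfg = candidate / ".venv" / "pyvenv.cfg"
        if not (marker.is_file() and venv_cfg.is_file()):
            continue  # unmarked or arbitrary directories are never touched
        try:
            manifest_hash = json.loads(marker.read_text())["manifest_hash"]
        except (ValueError, KeyError):
            continue
        if manifest_hash != keep_hash:
            shutil.rmtree(candidate)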
'art-vllm-runtime'\n") + lockfile.write_text("version = 1\n") + manifest = runtime.VllmRuntimeManifest( + art_version="0.5.17", + runtime_version="0.1.0", + python=">=3.11", + runtime_wheel=runtime_wheel.name, + runtime_wheel_sha256=runtime._sha256_file(runtime_wheel), + pyproject_sha256=runtime._sha256_file(pyproject), + lockfile_sha256=runtime._sha256_file(lockfile), + ) + manifest_hash = runtime._manifest_hash(manifest) + cache_root = (tmp_path / "cache").resolve() + + def fake_run_install_command(command: list[str], *, cwd=None) -> None: + del cwd + if command[:2] == ["uv", "sync"]: + stage = Path(command[command.index("--project") + 1]) + bin_dir = stage / ".venv" / "bin" + bin_dir.mkdir(parents=True) + (stage / ".venv" / "pyvenv.cfg").write_text("venv\n") + (bin_dir / "python").write_text("#!/bin/sh\n") + return + assert command[:3] == ["uv", "pip", "install"] + runtime_python = Path(command[command.index("--python") + 1]) + assert runtime_python == cache_root / manifest_hash / ".venv" / "bin" / "python" + runtime_bin = runtime_python.parent / runtime.RUNTIME_SERVER + runtime_bin.write_text(f"#!{runtime_python}\n") + runtime_bin.chmod(runtime_bin.stat().st_mode | 0o111) + + monkeypatch.setattr(runtime, "_run_install_command", fake_run_install_command) + + runtime_bin = runtime._install_managed_runtime( + bundle_dir=bundle_dir, + cache_root=cache_root, + manifest=manifest, + manifest_hash=manifest_hash, + ) + + assert ( + runtime_bin + == cache_root / manifest_hash / ".venv" / "bin" / runtime.RUNTIME_SERVER + ) + assert runtime_bin.read_text().startswith( + f"#!{runtime._runtime_python(cache_root / manifest_hash)}" + ) + assert runtime._read_install_marker(cache_root / manifest_hash) is not None + + +def test_validate_managed_runtime_rejects_non_executable_entrypoint( + tmp_path: Path, +) -> None: + manifest = runtime.VllmRuntimeManifest( + art_version="0.5.17", + runtime_version="0.1.0", + python=">=3.11", + runtime_wheel="art_vllm_runtime-0.1.0-py3-none-any.whl", + runtime_wheel_sha256="wheel", + pyproject_sha256="pyproject", + lockfile_sha256="lockfile", + ) + manifest_hash = runtime._manifest_hash(manifest) + runtime_dir = tmp_path / manifest_hash + runtime_bin = runtime._runtime_bin(runtime_dir) + runtime_bin.parent.mkdir(parents=True) + (runtime_dir / ".venv" / "pyvenv.cfg").write_text("venv\n") + runtime_bin.write_text("#!/bin/sh\n") + runtime_bin.chmod(runtime_bin.stat().st_mode & ~0o111) + marker = runtime.VllmRuntimeInstallMarker( + runtime_version=manifest.runtime_version, + protocol_version=manifest.protocol_version, + manifest_hash=manifest_hash, + runtime_wheel_sha256=manifest.runtime_wheel_sha256, + cache_root=str(tmp_path.resolve()), + ) + runtime._install_marker_path(runtime_dir).write_text(marker.model_dump_json()) + + assert not os.access(runtime_bin, os.X_OK) + assert ( + runtime._validate_managed_runtime( + runtime_dir, + cache_root=tmp_path.resolve(), + manifest=manifest, + manifest_hash=manifest_hash, + ) + is None + ) + + +@pytest.mark.asyncio +async def test_wait_for_vllm_runtime_polls_http_health(monkeypatch) -> None: + seen: dict[str, object] = {} + + class FakeProcess: + def poll(self): + return None + + class FakeResponse: + status_code = 200 + + class FakeClient: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def get(self, url: str, timeout: float): + seen["url"] = url + seen["timeout"] = timeout + return FakeResponse() + + monkeypatch.setattr(runtime.httpx, "AsyncClient", lambda: 
FakeClient()) + await runtime.wait_for_vllm_runtime( + process=FakeProcess(), + host="127.0.0.1", + port=8123, + timeout=12.0, + ) + assert seen == { + "url": "http://127.0.0.1:8123/health", + "timeout": 5.0, + } diff --git a/tests/integration/megatron/runtime_isolation/test_runtime_project_isolation.py b/tests/integration/megatron/runtime_isolation/test_runtime_project_isolation.py new file mode 100644 index 000000000..a0c9ce492 --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/test_runtime_project_isolation.py @@ -0,0 +1,147 @@ +import json +from pathlib import Path +import subprocess + +ROOT = Path(__file__).resolve().parents[4] + + +def test_runtime_project_imports_in_its_own_project_env(artifact_dir: Path) -> None: + result = subprocess.run( + [ + "uv", + "run", + "--project", + str(ROOT / "vllm_runtime"), + "python", + "-c", + ( + "import importlib.util, json; " + "import art_vllm_runtime; " + "print(json.dumps({" + "'runtime_ok': True, " + "'has_vllm': importlib.util.find_spec('vllm') is not None" + "}))" + ), + ], + cwd=ROOT, + check=True, + capture_output=True, + text=True, + ) + (artifact_dir / "stdout.txt").write_text(result.stdout) + (artifact_dir / "stderr.txt").write_text(result.stderr) + payload = json.loads(result.stdout.strip()) + assert payload == {"runtime_ok": True, "has_vllm": True} + + +def test_runtime_server_source_contains_only_required_custom_routes() -> None: + source = ( + ROOT / "vllm_runtime" / "src" / "art_vllm_runtime" / "dedicated_server.py" + ).read_text() + for route in ("/sleep", "/wake_up", "/is_sleeping", "/art/set_served_model_name"): + assert route in source + + +def test_runtime_general_plugin_loads_full_patch_set() -> None: + pyproject = (ROOT / "vllm_runtime" / "pyproject.toml").read_text() + assert 'art = "art_vllm_runtime.patches:apply_vllm_runtime_patches"' in pyproject + + +def test_runtime_patch_set_does_not_install_lora_monkey_patches() -> None: + source = ( + ROOT / "vllm_runtime" / "src" / "art_vllm_runtime" / "patches.py" + ).read_text() + assert "patch_punica_ep_moe_lora_alignment" not in source + assert "patch_lora_duplicate_module_aliases" not in source + assert "patch_fused_moe_ep_lora_support" not in source + + +def test_runtime_cli_serializes_lora_target_modules_as_single_nargs_vector( + artifact_dir: Path, +) -> None: + result = subprocess.run( + [ + "uv", + "run", + "--project", + str(ROOT / "vllm_runtime"), + "python", + "-c", + ( + "import json; " + "from art_vllm_runtime.dedicated_server import _append_cli_arg; " + "args = []; " + "_append_cli_arg(args, 'lora_target_modules', ['a', 'b']); " + "print(json.dumps(args))" + ), + ], + cwd=ROOT, + check=True, + capture_output=True, + text=True, + ) + (artifact_dir / "lora_target_modules_stdout.txt").write_text(result.stdout) + (artifact_dir / "lora_target_modules_stderr.txt").write_text(result.stderr) + assert json.loads(result.stdout.strip()) == ["--lora-target-modules", "a", "b"] + + +def test_runtime_project_restores_nccl_unique_id_from_raw_bytes( + artifact_dir: Path, +) -> None: + result = subprocess.run( + [ + "uv", + "run", + "--project", + str(ROOT / "vllm_runtime"), + "python", + "-c", + ( + "import ctypes, json; " + "from art_vllm_runtime.patches import _restore_nccl_unique_id_payload; " + "from vllm.distributed.device_communicators.pynccl_wrapper import ncclUniqueId; " + "payload = bytes(range(128)); " + "restored = _restore_nccl_unique_id_payload(payload, ncclUniqueId()); " + "print(json.dumps({" + "'type': type(restored).__name__, " + "'matches': 
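The launcher test above asserts that wait_for_vllm_runtime issues GET /health with a 5-second per-request timeout. A minimal sketch of that polling loop, assuming only httpx and asyncio; the real helper also watches the child process and surfaces its exit code on failure:

import asyncio
import httpx

async def wait_for_health(host: str, port: int, timeout: float) -> None:
    url = f"http://{host}:{port}/health"
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while True:
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(url, timeout=5.0)
            if response.status_code == 200:
                return
        except httpx.HTTPError:
            pass  # server not accepting connections yet
        if loop.time() > deadline:
            raise TimeoutError(f"{url} did not become healthy within {timeout}s")
        await asyncio.sleep(0.5)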
ctypes.string_at(ctypes.byref(restored), ctypes.sizeof(restored)).hex() == payload.hex()" + "}))" + ), + ], + cwd=ROOT, + check=True, + capture_output=True, + text=True, + ) + (artifact_dir / "restore_stdout.txt").write_text(result.stdout) + (artifact_dir / "restore_stderr.txt").write_text(result.stderr) + payload = json.loads(result.stdout.strip()) + assert payload == {"type": "ncclUniqueId", "matches": True} + + +def test_runtime_project_nccl_wrapper_accepts_raw_bytes(artifact_dir: Path) -> None: + result = subprocess.run( + [ + "uv", + "run", + "--project", + str(ROOT / "vllm_runtime"), + "python", + "-c", + ( + "import json; " + "from art_vllm_runtime.patches import _normalize_nccl_comm_init_rank_unique_id; " + "FakeLibrary = type('FakeLibrary', (), {'unique_id_from_bytes': lambda self, data: {'restored': len(data)}}); " + "restored = _normalize_nccl_comm_init_rank_unique_id(FakeLibrary(), bytes(range(128))); " + "print(json.dumps(restored))" + ), + ], + cwd=ROOT, + check=True, + capture_output=True, + text=True, + ) + (artifact_dir / "nccl_wrapper_stdout.txt").write_text(result.stdout) + (artifact_dir / "nccl_wrapper_stderr.txt").write_text(result.stderr) + payload = json.loads(result.stdout.strip()) + assert payload == {"restored": 128} diff --git a/tests/integration/megatron/runtime_isolation/test_service_runtime_boundary.py b/tests/integration/megatron/runtime_isolation/test_service_runtime_boundary.py new file mode 100644 index 000000000..586f5673d --- /dev/null +++ b/tests/integration/megatron/runtime_isolation/test_service_runtime_boundary.py @@ -0,0 +1,239 @@ +import asyncio +from pathlib import Path +import sys +from types import SimpleNamespace +from typing import cast +from unittest.mock import AsyncMock + +import httpx +import pytest + +from art.megatron.service import MegatronService +from art.unsloth.service import UnslothService + + +class _AsyncOkResponse: + def raise_for_status(self) -> None: + return None + + +class _RecordingAsyncClient: + def __init__( + self, posts: list[tuple[str, dict[str, object] | None, float]] + ) -> None: + self._posts = posts + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def post( + self, + url: str, + *, + params: dict[str, object] | None = None, + timeout: float, + ) -> _AsyncOkResponse: + self._posts.append((url, params, timeout)) + return _AsyncOkResponse() + + +class _FakeAsyncioProcess: + returncode: int | None = None + + async def wait(self) -> int: + await asyncio.Event().wait() + return 0 + + +@pytest.mark.asyncio +async def test_megatron_shared_start_requires_runtime_sleep_mode( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = MegatronService( + model_name="test-model", + base_model="Qwen/Qwen3-0.6B", + config={ + "rollout_weights_mode": "lora", + "engine_args": {"enable_sleep_mode": False}, + }, + output_dir=str(tmp_path), + ) + monkeypatch.setattr(service, "_resolve_active_lora_path", lambda: "/tmp/lora") + monkeypatch.setattr(service, "_start_vllm_subprocess", AsyncMock()) + + with pytest.raises( + ValueError, + match="Shared-GPU mode requires engine_args.enable_sleep_mode=True", + ): + await service.start_openai_server(None) + + +@pytest.mark.asyncio +async def test_unsloth_shared_start_requires_runtime_sleep_mode( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = UnslothService( + model_name="test-model", + base_model="Qwen/Qwen3-0.6B", + config={ + "rollout_weights_mode": "lora", + "engine_args": 
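The two NCCL tests above check that a 128-byte unique-id payload survives a round trip into a ctypes structure. The underlying technique is plain ctypes: copy the raw bytes into the structure with memmove and read them back with string_at. A self-contained sketch using a stand-in struct (ncclUniqueId itself is an opaque 128-byte blob):

import ctypes

class FakeUniqueId(ctypes.Structure):
    # Stand-in for ncclUniqueId: an opaque 128-byte payload.
    _fields_ = [("internal", ctypes.c_byte * 128)]

payload = bytes(range(128))
restored = FakeUniqueId()
ctypes.memmove(ctypes.byref(restored), payload, len(payload))
assert ctypes.string_at(ctypes.byref(restored), ctypes.sizeof(restored)) == payload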
{"enable_sleep_mode": False}, + }, + output_dir=str(tmp_path), + ) + service.__dict__["_state"] = SimpleNamespace( + trainer=SimpleNamespace(save_model=lambda path: None), + offload_to_cpu=lambda: None, + ) + monkeypatch.setattr( + "art.unsloth.service.get_last_checkpoint_dir", lambda _output_dir: "/tmp/lora" + ) + monkeypatch.setattr("art.unsloth.service.get_step_from_dir", lambda _output_dir: 0) + monkeypatch.setattr(service, "_start_vllm_subprocess", AsyncMock()) + + with pytest.raises( + ValueError, + match="Shared-GPU mode requires engine_args.enable_sleep_mode=True", + ): + await service.start_openai_server(None) + + +@pytest.mark.asyncio +async def test_megatron_runtime_sleep_and_wake_use_runtime_routes( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = MegatronService( + model_name="test-model", + base_model="Qwen/Qwen3-0.6B", + config={"rollout_weights_mode": "lora"}, + output_dir=str(tmp_path), + ) + service._vllm_port = 8123 + posts: list[tuple[str, dict[str, object] | None, float]] = [] + monkeypatch.setattr(httpx, "AsyncClient", lambda: _RecordingAsyncClient(posts)) + + await service._sleep_runtime() + await service._wake_runtime() + + assert posts == [ + ("http://127.0.0.1:8123/sleep", {"level": 1, "mode": "wait"}, 300.0), + ("http://127.0.0.1:8123/wake_up", None, 300.0), + ] + assert service._is_sleeping is False + + +@pytest.mark.asyncio +async def test_unsloth_runtime_sleep_and_wake_use_runtime_routes( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = UnslothService( + model_name="test-model", + base_model="Qwen/Qwen3-0.6B", + config={"rollout_weights_mode": "lora"}, + output_dir=str(tmp_path), + ) + service._vllm_port = 8123 + posts: list[tuple[str, dict[str, object] | None, float]] = [] + monkeypatch.setattr(httpx, "AsyncClient", lambda: _RecordingAsyncClient(posts)) + + await service._sleep_runtime() + await service._wake_runtime() + + assert posts == [ + ("http://127.0.0.1:8123/sleep", {"level": 1, "mode": "wait"}, 300.0), + ("http://127.0.0.1:8123/wake_up", None, 300.0), + ] + assert service._is_sleeping is False + + +@pytest.mark.asyncio +async def test_megatron_dedicated_merged_start_syncs_initial_weights( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = MegatronService( + model_name="test-model", + base_model="Qwen/Qwen3-0.6B", + config={ + "trainer_gpu_ids": [0], + "inference_gpu_ids": [1], + "rollout_weights_mode": "merged", + }, + output_dir=str(tmp_path), + ) + start_vllm = AsyncMock(return_value=("127.0.0.1", 8000)) + sync_merged = AsyncMock() + monkeypatch.setattr(service, "_resolve_active_lora_path", lambda: "/tmp/lora") + monkeypatch.setattr(service, "_start_vllm_subprocess", start_vllm) + monkeypatch.setattr(service, "_sync_dedicated_merged_weights", sync_merged) + + location = await service.start_openai_server(None) + + assert location == ("127.0.0.1", 8000) + start_vllm.assert_awaited_once() + sync_merged.assert_awaited_once_with(lora_path="/tmp/lora", step=0) + + +@pytest.mark.asyncio +async def test_megatron_worker_uses_active_python_for_torchrun( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + pytest.importorskip("megatron.bridge") + service = MegatronService( + model_name="test-model", + base_model="Qwen/Qwen3-0.6B", + config={ + "trainer_gpu_ids": [0], + "inference_gpu_ids": [1], + "rollout_weights_mode": "lora", + }, + output_dir=str(tmp_path), + ) + recorded: dict[str, object] = {} + + async def _fake_create_subprocess_exec( + *command: str, + cwd: 
str, + env: dict[str, str], + stdout, + stderr, + start_new_session: bool, + ) -> _FakeAsyncioProcess: + recorded["command"] = list(command) + recorded["cwd"] = cwd + recorded["env"] = env + recorded["stdout"] = stdout + recorded["stderr"] = stderr + recorded["start_new_session"] = start_new_session + return _FakeAsyncioProcess() + + monkeypatch.setattr( + "art.megatron.service.asyncio.create_subprocess_exec", + _fake_create_subprocess_exec, + ) + monkeypatch.setattr(service, "_install_parent_signal_cleanup", lambda: None) + monkeypatch.setattr(service, "_allocate_master_port", lambda: 12345) + + await service._ensure_megatron_running() + command = cast(list[str], recorded["command"]) + assert isinstance(command, list) + assert command[0] == sys.executable + assert command[1].endswith("managed_process.py") + separator = command.index("--") + assert command[separator + 1 : separator + 4] == [ + sys.executable, + "-m", + "torch.distributed.run", + ] + assert "uv run" not in command + assert recorded["cwd"] == str(Path(__file__).resolve().parents[4]) + service._child_processes.close() + service._megatron_log_file.close() diff --git a/tests/integration/megatron/train_inf_mismatch/__init__.py b/tests/integration/megatron/train_inf_mismatch/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/integration/megatron/train_inf_mismatch/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/integration/megatron/train_inf_mismatch/artifacts.py b/tests/integration/megatron/train_inf_mismatch/artifacts.py new file mode 100644 index 000000000..1ee3dee72 --- /dev/null +++ b/tests/integration/megatron/train_inf_mismatch/artifacts.py @@ -0,0 +1,76 @@ +from datetime import datetime, timezone +import os +from pathlib import Path +import re +import subprocess +import sys +import uuid + +from pydantic import BaseModel + +TEST_ROOT = Path(__file__).resolve().parent +ARTIFACTS_ROOT = TEST_ROOT / "artifacts" +REPO_ROOT = Path( + subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + cwd=TEST_ROOT, + check=True, + capture_output=True, + text=True, + ).stdout.strip() +) + + +class ArtifactMetadata(BaseModel): + commit: str + branch: str + test_nodeid: str + created_at_utc: str + python_executable: str + artifact_dir: str + + +def _git(*args: str) -> str: + return subprocess.run( + ["git", *args], + cwd=REPO_ROOT, + check=True, + capture_output=True, + text=True, + ).stdout.strip() + + +def require_clean_git_state() -> str: + dirty = _git("status", "--porcelain=v1", "--untracked-files=all").splitlines() + if dirty: + rendered = "\n".join(dirty) + raise RuntimeError( + "Megatron train/inf mismatch tests require a committed worktree.\n" + "Commit or remove these changes before running tests:\n" + f"{rendered}" + ) + return _git("rev-parse", "HEAD") + + +def create_artifact_dir(test_nodeid: str) -> Path: + commit = require_clean_git_state() + test_name = re.sub(r"[^A-Za-z0-9_.-]+", "_", test_nodeid).strip("._") + run_id = ( + f"{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')}_" + f"{os.getpid()}_{uuid.uuid4().hex[:8]}" + ) + artifact_dir = ARTIFACTS_ROOT / (test_name or "unnamed_test") / commit[:12] / run_id + artifact_dir.mkdir(parents=True, exist_ok=False) + metadata = ArtifactMetadata( + commit=commit, + branch=_git("branch", "--show-current"), + test_nodeid=test_nodeid, + created_at_utc=datetime.now(timezone.utc).isoformat(), + python_executable=sys.executable, + artifact_dir=str(artifact_dir), + ) + (artifact_dir / "run_metadata.json").write_text( + 
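create_artifact_dir above sanitizes the pytest node id before using it as a directory name and nests each run under the short commit hash. A concrete example of the sanitization (the node id is illustrative):

import re

nodeid = "tests/integration/megatron/train_inf_mismatch/test_x.py::test_case[merged]"
test_name = re.sub(r"[^A-Za-z0-9_.-]+", "_", nodeid).strip("._")
assert test_name == (
    "tests_integration_megatron_train_inf_mismatch_test_x.py_test_case_merged"
)
# Resulting layout: artifacts/<test_name>/<commit[:12]>/<UTC timestamp>_<pid>_<uuid8>/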
metadata.model_dump_json(indent=2) + "\n", + encoding="utf-8", + ) + return artifact_dir diff --git a/tests/integration/megatron/train_inf_mismatch/artifacts/.gitignore b/tests/integration/megatron/train_inf_mismatch/artifacts/.gitignore new file mode 100644 index 000000000..d6b7ef32c --- /dev/null +++ b/tests/integration/megatron/train_inf_mismatch/artifacts/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/tests/integration/megatron/train_inf_mismatch/conftest.py b/tests/integration/megatron/train_inf_mismatch/conftest.py new file mode 100644 index 000000000..a3ffdf74f --- /dev/null +++ b/tests/integration/megatron/train_inf_mismatch/conftest.py @@ -0,0 +1,15 @@ +from pathlib import Path + +import pytest + +from .artifacts import create_artifact_dir, require_clean_git_state + + +@pytest.fixture(scope="session", autouse=True) +def _require_clean_commit_state() -> None: + require_clean_git_state() + + +@pytest.fixture +def artifact_dir(request: pytest.FixtureRequest) -> Path: + return create_artifact_dir(request.node.nodeid) diff --git a/tests/integration/megatron/train_inf_mismatch/output_parity.py b/tests/integration/megatron/train_inf_mismatch/output_parity.py new file mode 100644 index 000000000..562d65004 --- /dev/null +++ b/tests/integration/megatron/train_inf_mismatch/output_parity.py @@ -0,0 +1,1387 @@ +from __future__ import annotations + +import argparse +import asyncio +from contextlib import asynccontextmanager, contextmanager +import hashlib +import json +import math +import os +from pathlib import Path +import random +import shutil +import socket +import subprocess +import sys +import time +from typing import Any, AsyncIterator, Literal, cast + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from .artifacts import REPO_ROOT + +BF16_FWD_MEAN_ABS_PCT_LIMIT = 3.0 +MEAN_ABS_PCT_DENOMINATOR_EPS = 1e-18 +TOP_K = 20 + +RolloutMode = Literal["native_lora", "merged"] +EngineSide = Literal["megatron", "vllm"] +WeightState = Literal["base", "lora"] + + +class Topology(BaseModel): + model_config = ConfigDict(frozen=True) + + tp: int = 2 + ep: int = 2 + etp: int = 1 + dp: int = 1 + cp: int = 1 + pp: int = 1 + + def world_size(self) -> int: + return self.tp * self.dp * self.cp * self.pp + + def env(self) -> dict[str, str]: + return { + "ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE": str(self.tp), + "ART_MEGATRON_EXPERT_MODEL_PARALLEL_SIZE": str(self.ep), + "ART_MEGATRON_EXPERT_TENSOR_PARALLEL_SIZE": str(self.etp), + } + + def slug(self) -> str: + return ( + f"tp{self.tp}_ep{self.ep}_etp{self.etp}_dp{self.dp}_cp{self.cp}_pp{self.pp}" + ) + + +class ProbePackedConfig(BaseModel): + num_sequences: int = 4 + sequence_length: int = 1024 + prefill_tokens: int = 256 + completion_branches_per_prefix: int = 2 + decode_tokens: int = 128 + decode_tokens_jitter: int = 32 + vocab_high: int = 8192 + packing_mode: Literal["stop_early", "truncate"] = "stop_early" + + +class TrainInfOutputParityConfig(BaseModel): + base_model: str = "Qwen/Qwen3.5-35B-A3B" + seed: int = 20260512 + topology: Topology = Field(default_factory=Topology) + packed: ProbePackedConfig = Field(default_factory=ProbePackedConfig) + rollout_modes: list[RolloutMode] = Field(default_factory=list) + trainer_gpu_ids: list[int] = Field(default_factory=lambda: [0, 1]) + inference_gpu_ids: list[int] = Field(default_factory=lambda: [2, 3]) + allow_unvalidated_arch: bool = False + lora_target_modules: list[str] | None = None + engine_args: dict[str, Any] = Field(default_factory=dict) + server_args: dict[str, Any] = 
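In the Topology model above, the expert-parallel sizes are exported only through environment variables, while the launched world size comes from tp * dp * cp * pp. A small worked example of the default values, mirroring Topology.world_size(), Topology.env(), and Topology.slug():

tp, ep, etp, dp, cp, pp = 2, 2, 1, 1, 1, 1
world_size = tp * dp * cp * pp          # -> 2 torchrun ranks; ep/etp do not add ranks
env = {
    "ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE": str(tp),
    "ART_MEGATRON_EXPERT_MODEL_PARALLEL_SIZE": str(ep),
    "ART_MEGATRON_EXPERT_TENSOR_PARALLEL_SIZE": str(etp),
}
slug = f"tp{tp}_ep{ep}_etp{etp}_dp{dp}_cp{cp}_pp{pp}"
assert world_size == 2 and slug == "tp2_ep2_etp1_dp1_cp1_pp1"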
Field(default_factory=dict) + + @model_validator(mode="after") + def _set_default_rollout_modes(self) -> "TrainInfOutputParityConfig": + if not self.rollout_modes: + self.rollout_modes = default_rollout_modes_for_model( + self.base_model, + allow_unvalidated_arch=self.allow_unvalidated_arch, + ) + return self + + +class LogicalPrompt(BaseModel): + prompt_id: int + sample_id: int + family_id: int + completion_id: int + token_ids: list[int] + + +class LogicalToken(BaseModel): + token_id: int + sample_id: int + family_id: int + completion_id: int + prompt_id: int + art_packed_token_index: int + art_logit_index: int + vllm_prompt_token_index: int + + +class LogicalTokenMap(BaseModel): + prompts: list[LogicalPrompt] + tokens: list[LogicalToken] + + +class TokenTopK(BaseModel): + token_ids: list[int] + logprobs: list[float] + + +class ScoreBundle(BaseModel): + side: EngineSide + weight_state: WeightState + rollout_mode: RolloutMode | None = None + target_logprobs: list[float] + topk: list[TokenTopK] + + +class MeanAbsPctSummary(BaseModel): + mean_abs_pct: float + sequence_count: int + source_numel: int + trimmed_numel: int + + +class PairComparison(BaseModel): + mean_abs_pct: float + sequence_count: int + source_numel: int + trimmed_numel: int + mae: float + max_abs: float + p50_abs: float + p95_abs: float + p99_abs: float + + +class TopKComparison(BaseModel): + top1_match_rate: float + top20_overlap_rate: float + top20_intersection_logprob_mae: float + top20_intersection_kl_target_to_candidate: float + top20_intersection_kl_candidate_to_target: float + compared_intersection_count: int + + +class RolloutComparison(BaseModel): + rollout_mode: RolloutMode + base: PairComparison + lora: PairComparison + delta: PairComparison + base_topk: TopKComparison + lora_topk: TopKComparison + + +class TrainInfOutputParityReport(BaseModel): + base_model: str + artifact_dir: str + topology: str + trainer_gpu_ids: list[int] + inference_gpu_ids: list[int] + logical_prompt_count: int + logical_token_count: int + adapter_path: str + megatron_base_scores: str + megatron_lora_scores: str + rollout_comparisons: list[RolloutComparison] + passed: bool + + +class MegatronWorkerRequest(BaseModel): + config: TrainInfOutputParityConfig + artifact_dir: str + weight_state: WeightState + adapter_path: str | None = None + + +class MegatronWorkerResult(BaseModel): + score_path: str + logical_map_path: str + adapter_path: str | None = None + + +def _write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True, allow_nan=False) + handle.write("\n") + + +def _read_json(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as handle: + value = json.load(handle) + if not isinstance(value, dict): + raise TypeError(f"Expected JSON object in {path}") + return value + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _parse_gpu_ids(value: str | None, default: list[int]) -> list[int]: + if value is None or value.strip() == "": + return list(default) + return [int(part.strip()) for part in value.split(",") if part.strip()] + + +def _parse_str_list(value: str) -> list[str]: + parts = [part.strip() for part in value.split(",") if part.strip()] + if not parts: + raise ValueError("Expected at least one comma-separated value") + return parts + + +def 
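A few illustrative input/output pairs for the env parsing helpers above (not self-contained; these call the functions defined directly above):

assert _parse_gpu_ids(None, [0, 1]) == [0, 1]        # unset -> copy of the default
assert _parse_gpu_ids(" 2, 3 ", [0, 1]) == [2, 3]    # whitespace and blanks tolerated
assert _parse_str_list("q_proj, v_proj") == ["q_proj", "v_proj"]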
_parse_rollout_modes(value: str) -> list[RolloutMode]: + modes = _parse_str_list(value) + invalid = sorted(set(modes) - {"native_lora", "merged"}) + if invalid: + raise ValueError(f"Unsupported rollout modes: {invalid}") + return cast(list[RolloutMode], modes) + + +def default_rollout_modes_for_model( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> list[RolloutMode]: + from art.megatron.model_support.registry import native_vllm_lora_status_for_model + + modes: list[RolloutMode] = [] + if ( + native_vllm_lora_status_for_model( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + != "disabled" + ): + modes.append("native_lora") + modes.append("merged") + return modes + + +@contextmanager +def _provider_topology_env(topology: Topology) -> Any: + names = topology.env() + previous = {name: os.environ.get(name) for name in names} + os.environ.update(names) + try: + yield + finally: + for name, value in previous.items(): + if value is None: + os.environ.pop(name, None) + else: + os.environ[name] = value + + +def config_from_env() -> TrainInfOutputParityConfig: + config = TrainInfOutputParityConfig( + base_model=os.environ.get( + "ART_TRAIN_INF_MISMATCH_BASE_MODEL", + os.environ.get("BASE_MODEL", TrainInfOutputParityConfig().base_model), + ), + trainer_gpu_ids=_parse_gpu_ids( + os.environ.get("ART_TRAIN_INF_MISMATCH_TRAINER_GPU_IDS"), + [0, 1], + ), + inference_gpu_ids=_parse_gpu_ids( + os.environ.get("ART_TRAIN_INF_MISMATCH_INFERENCE_GPU_IDS"), + [2, 3], + ), + allow_unvalidated_arch=os.environ.get( + "ART_TRAIN_INF_MISMATCH_ALLOW_UNVALIDATED_ARCH", "0" + ) + == "1", + ) + if raw_modes := os.environ.get("ART_TRAIN_INF_MISMATCH_ROLLOUT_MODES"): + config.rollout_modes = _parse_rollout_modes(raw_modes) + if raw_seq_len := os.environ.get("ART_TRAIN_INF_MISMATCH_SEQUENCE_LENGTH"): + config.packed.sequence_length = int(raw_seq_len) + if raw_prefill := os.environ.get("ART_TRAIN_INF_MISMATCH_PREFILL_TOKENS"): + config.packed.prefill_tokens = int(raw_prefill) + if raw_decode := os.environ.get("ART_TRAIN_INF_MISMATCH_DECODE_TOKENS"): + config.packed.decode_tokens = int(raw_decode) + if raw_targets := os.environ.get("ART_TRAIN_INF_MISMATCH_LORA_TARGET_MODULES"): + config.lora_target_modules = _parse_str_list(raw_targets) + return config + + +def _prompt_family_segments( + group_ids: Any, + parent_ids: Any, + *, + required_completion_count: int = 1, +) -> list[tuple[tuple[int, int], list[tuple[int, int]]]]: + valid_tokens = int((group_ids != -1).sum().item()) + families: list[tuple[tuple[int, int], list[tuple[int, int]]]] = [] + cursor = 0 + while cursor < valid_tokens: + group_id = int(group_ids[cursor].item()) + parent_id = int(parent_ids[cursor].item()) + prompt_start = cursor + while cursor < valid_tokens and int(group_ids[cursor].item()) == group_id: + cursor += 1 + prompt_end = cursor + if group_id != parent_id: + continue + completions: list[tuple[int, int]] = [] + while cursor < valid_tokens: + completion_group_id = int(group_ids[cursor].item()) + completion_parent_id = int(parent_ids[cursor].item()) + if completion_parent_id != group_id or completion_group_id == group_id: + break + completion_start = cursor + while ( + cursor < valid_tokens + and int(group_ids[cursor].item()) == completion_group_id + ): + cursor += 1 + completions.append((completion_start, cursor)) + if len(completions) >= required_completion_count: + families.append(((prompt_start, prompt_end), completions)) + return families + + +def build_logical_token_map(packed_tensors: dict[str, Any]) -> 
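Segmenting packed rows into prompt families is the core of the logical token map: a shared prefix is a run whose group_id equals its parent_id, and its completions are the runs that immediately follow and point back at that group. A simplified, list-based mirror of _prompt_family_segments with a worked example (toy ids, default required_completion_count of 1):

def prompt_families(group_ids, parent_ids):
    valid = sum(1 for g in group_ids if g != -1)
    families, cursor = [], 0
    while cursor < valid:
        gid, pid = group_ids[cursor], parent_ids[cursor]
        prompt_start = cursor
        while cursor < valid and group_ids[cursor] == gid:
            cursor += 1
        prompt_end = cursor
        if gid != pid:
            continue  # not a shared prefix; move on to the next run
        completions = []
        while cursor < valid and parent_ids[cursor] == gid and group_ids[cursor] != gid:
            child_gid, child_start = group_ids[cursor], cursor
            while cursor < valid and group_ids[cursor] == child_gid:
                cursor += 1
            completions.append((child_start, cursor))
        if completions:
            families.append(((prompt_start, prompt_end), completions))
    return families

# One packed row: a 3-token shared prefix (group 7) followed by two branched
# completions (groups 8 and 9, both pointing back at parent 7), then padding.
group_ids = [7, 7, 7, 8, 8, 9, 9, -1]
parent_ids = [7, 7, 7, 7, 7, 7, 7, -1]
assert prompt_families(group_ids, parent_ids) == [((0, 3), [(3, 5), (5, 7)])]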
LogicalTokenMap: + tokens = packed_tensors["tokens"] + group_ids = packed_tensors["group_ids"] + parent_ids = packed_tensors["parent_ids"] + prompts: list[LogicalPrompt] = [] + logical_tokens: list[LogicalToken] = [] + prompt_id_by_tokens: dict[tuple[int, ...], int] = {} + + for sample_id in range(int(tokens.shape[0])): + families = _prompt_family_segments(group_ids[sample_id], parent_ids[sample_id]) + for family_id, (prompt_segment, completion_segments) in enumerate(families): + prompt_start, prompt_end = prompt_segment + prompt_len = prompt_end - prompt_start + for completion_id, (completion_start, completion_end) in enumerate( + completion_segments + ): + if completion_end - completion_start < 2: + continue + flat = [ + int(value) + for value in tokens[sample_id, prompt_start:prompt_end].tolist() + ] + [ + int(value) + for value in tokens[ + sample_id, completion_start:completion_end + ].tolist() + ] + flat_key = tuple(flat) + prompt_id = prompt_id_by_tokens.get(flat_key) + if prompt_id is None: + prompt_id = len(prompts) + prompt_id_by_tokens[flat_key] = prompt_id + prompts.append( + LogicalPrompt( + prompt_id=prompt_id, + sample_id=sample_id, + family_id=family_id, + completion_id=completion_id, + token_ids=flat, + ) + ) + for packed_i in range(completion_start + 1, completion_end): + logical_tokens.append( + LogicalToken( + token_id=int(tokens[sample_id, packed_i].item()), + sample_id=sample_id, + family_id=family_id, + completion_id=completion_id, + prompt_id=prompt_id, + art_packed_token_index=packed_i, + art_logit_index=packed_i - 1, + vllm_prompt_token_index=prompt_len + + (packed_i - completion_start), + ) + ) + + if not prompts or not logical_tokens: + raise RuntimeError("Shared-prefix probe produced no comparable logical tokens") + return LogicalTokenMap(prompts=prompts, tokens=logical_tokens) + + +def aggregate_mean_abs_pct( + *, + candidate: Any, + target: Any, + sequence_ids: list[int], +) -> MeanAbsPctSummary: + import torch + + cand = candidate.detach().float().reshape(-1) + ref = target.detach().float().reshape(-1) + if cand.shape != ref.shape: + raise RuntimeError(f"Shape mismatch: candidate={cand.shape} target={ref.shape}") + if cand.numel() != len(sequence_ids): + raise RuntimeError( + f"sequence_ids length mismatch: {len(sequence_ids)} != {cand.numel()}" + ) + if cand.numel() == 0: + return MeanAbsPctSummary( + mean_abs_pct=0.0, + sequence_count=0, + source_numel=0, + trimmed_numel=0, + ) + sequence_count = len({int(sequence_id) for sequence_id in sequence_ids}) + mean_abs_diff = float((cand - ref).abs().mean().item()) + mean_abs_reference = float(ref.abs().mean().item()) + return MeanAbsPctSummary( + mean_abs_pct=( + mean_abs_diff / (mean_abs_reference + MEAN_ABS_PCT_DENOMINATOR_EPS) + ) + * 100.0, + sequence_count=sequence_count, + source_numel=int(cand.numel()), + trimmed_numel=0, + ) + + +def _percentile(sorted_values: list[float], q: float) -> float: + if not sorted_values: + return 0.0 + index = min(len(sorted_values) - 1, max(0, math.ceil(q * len(sorted_values)) - 1)) + return float(sorted_values[index]) + + +def compare_pair( + *, + candidate: Any, + target: Any, + sequence_ids: list[int], +) -> PairComparison: + import torch + + cand = candidate.detach().float().reshape(-1) + ref = target.detach().float().reshape(-1) + pct = aggregate_mean_abs_pct( + candidate=cand, + target=ref, + sequence_ids=sequence_ids, + ) + diff = (cand - ref).abs() + sorted_diff = sorted(float(value) for value in diff.tolist()) + return PairComparison( + 
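aggregate_mean_abs_pct above reports 100 * mean(|candidate - target|) / (mean(|target|) + eps). A small numeric example (eps omitted, since it only guards against a zero denominator):

candidate = [-1.00, -2.00, -0.50]
reference = [-1.10, -1.90, -0.50]
mean_abs_diff = sum(abs(c - r) for c, r in zip(candidate, reference)) / 3   # ~0.0667
mean_abs_ref = sum(abs(r) for r in reference) / 3                           # ~1.1667
mean_abs_pct = 100.0 * mean_abs_diff / mean_abs_ref
assert round(mean_abs_pct, 2) == 5.71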
mean_abs_pct=pct.mean_abs_pct, + sequence_count=pct.sequence_count, + source_numel=pct.source_numel, + trimmed_numel=pct.trimmed_numel, + mae=float(diff.mean().item()) if diff.numel() else 0.0, + max_abs=float(diff.max().item()) if diff.numel() else 0.0, + p50_abs=_percentile(sorted_diff, 0.50), + p95_abs=_percentile(sorted_diff, 0.95), + p99_abs=_percentile(sorted_diff, 0.99), + ) + + +def _logsumexp(values: list[float]) -> float: + max_value = max(values) + return max_value + math.log(sum(math.exp(value - max_value) for value in values)) + + +def _restricted_kl( + left_by_id: dict[int, float], + right_by_id: dict[int, float], + token_ids: set[int], +) -> float: + if not token_ids: + return 0.0 + ordered_ids = sorted(token_ids) + left_values = [left_by_id[token_id] for token_id in ordered_ids] + right_values = [right_by_id[token_id] for token_id in ordered_ids] + left_log_z = _logsumexp(left_values) + right_log_z = _logsumexp(right_values) + kl = 0.0 + for left_value, right_value in zip(left_values, right_values, strict=True): + left_logprob = left_value - left_log_z + right_logprob = right_value - right_log_z + kl += math.exp(left_logprob) * (left_logprob - right_logprob) + return float(kl) + + +def compare_topk(candidate: ScoreBundle, target: ScoreBundle) -> TopKComparison: + if len(candidate.topk) != len(target.topk): + raise RuntimeError("top-k score length mismatch") + top1_matches = 0 + overlap_sum = 0.0 + intersection_abs_sum = 0.0 + intersection_count = 0 + target_to_candidate_kl_sum = 0.0 + candidate_to_target_kl_sum = 0.0 + kl_count = 0 + for cand_topk, ref_topk in zip(candidate.topk, target.topk, strict=True): + cand_ids = cand_topk.token_ids[:TOP_K] + ref_ids = ref_topk.token_ids[:TOP_K] + if cand_ids and ref_ids and cand_ids[0] == ref_ids[0]: + top1_matches += 1 + cand_set = set(cand_ids) + ref_set = set(ref_ids) + intersection = cand_set & ref_set + overlap_sum += len(intersection) / max(TOP_K, 1) + cand_by_id = dict(zip(cand_topk.token_ids, cand_topk.logprobs, strict=True)) + ref_by_id = dict(zip(ref_topk.token_ids, ref_topk.logprobs, strict=True)) + for token_id in intersection: + intersection_abs_sum += abs(cand_by_id[token_id] - ref_by_id[token_id]) + intersection_count += 1 + if intersection: + target_to_candidate_kl_sum += _restricted_kl( + ref_by_id, cand_by_id, intersection + ) + candidate_to_target_kl_sum += _restricted_kl( + cand_by_id, ref_by_id, intersection + ) + kl_count += 1 + count = max(len(candidate.topk), 1) + return TopKComparison( + top1_match_rate=top1_matches / count, + top20_overlap_rate=overlap_sum / count, + top20_intersection_logprob_mae=( + intersection_abs_sum / intersection_count if intersection_count else 0.0 + ), + top20_intersection_kl_target_to_candidate=( + target_to_candidate_kl_sum / kl_count if kl_count else 0.0 + ), + top20_intersection_kl_candidate_to_target=( + candidate_to_target_kl_sum / kl_count if kl_count else 0.0 + ), + compared_intersection_count=intersection_count, + ) + + +def compare_rollout( + *, + rollout_mode: RolloutMode, + megatron_base: ScoreBundle, + megatron_lora: ScoreBundle, + vllm_base: ScoreBundle, + vllm_lora: ScoreBundle, + logical_map: LogicalTokenMap, +) -> RolloutComparison: + import torch + + sequence_ids = [token.prompt_id for token in logical_map.tokens] + mb = torch.tensor(megatron_base.target_logprobs, dtype=torch.float32) + ml = torch.tensor(megatron_lora.target_logprobs, dtype=torch.float32) + vb = torch.tensor(vllm_base.target_logprobs, dtype=torch.float32) + vl = 
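_restricted_kl above renormalizes both top-k distributions over the intersection of token ids before computing KL, so mass missing from one engine's tail does not dominate the comparison. A self-contained sketch with two toy logprob maps (the values are illustrative):

import math

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def restricted_kl(left, right, token_ids):
    # KL(left || right) after renormalizing both over the shared token_ids.
    ids = sorted(token_ids)
    left_z = logsumexp([left[t] for t in ids])
    right_z = logsumexp([right[t] for t in ids])
    kl = 0.0
    for t in ids:
        lp, rp = left[t] - left_z, right[t] - right_z
        kl += math.exp(lp) * (lp - rp)
    return kl

left = {5: -0.2, 9: -1.8}    # e.g. target engine top-k logprobs at one position
right = {5: -0.4, 9: -1.2}   # e.g. candidate engine top-k logprobs
assert restricted_kl(left, left, {5, 9}) == 0.0
assert restricted_kl(left, right, {5, 9}) > 0.0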
torch.tensor(vllm_lora.target_logprobs, dtype=torch.float32) + return RolloutComparison( + rollout_mode=rollout_mode, + base=compare_pair(candidate=vb, target=mb, sequence_ids=sequence_ids), + lora=compare_pair(candidate=vl, target=ml, sequence_ids=sequence_ids), + delta=compare_pair( + candidate=vl - vb, + target=ml - mb, + sequence_ids=sequence_ids, + ), + base_topk=compare_topk(vllm_base, megatron_base), + lora_topk=compare_topk(vllm_lora, megatron_lora), + ) + + +def _set_seed(seed: int) -> None: + import numpy as np + import torch + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def _packed_tensor_config(config: TrainInfOutputParityConfig) -> Any: + from ..model_support.oracle_harness import PackedTensorConfig + + return PackedTensorConfig( + num_sequences=config.packed.num_sequences, + sequence_length=config.packed.sequence_length, + prefill_tokens=config.packed.prefill_tokens, + completion_branches_per_prefix=config.packed.completion_branches_per_prefix, + decode_tokens=config.packed.decode_tokens, + decode_tokens_jitter=config.packed.decode_tokens_jitter, + vocab_high=config.packed.vocab_high, + packing_mode=config.packed.packing_mode, + ) + + +def _build_packed_tensors(config: TrainInfOutputParityConfig) -> dict[str, Any]: + from ..model_support.packed_position_ids import ( + _build_art_realistic_packed_tensors, + ) + + return _build_art_realistic_packed_tensors( + _packed_tensor_config(config), config.seed + ) + + +def _configure_provider(provider: Any, config: TrainInfOutputParityConfig) -> None: + if hasattr(provider, "attention_dropout"): + provider.attention_dropout = 0.0 + if hasattr(provider, "hidden_dropout"): + provider.hidden_dropout = 0.0 + + +def _lora_target_modules(config: TrainInfOutputParityConfig) -> list[str]: + from art.dev.get_model_config import default_target_modules + + return list(config.lora_target_modules or default_target_modules(config.base_model)) + + +def _configure_lora_target_modules( + provider_bundle: Any, target_modules: list[str] +) -> None: + if not target_modules: + raise ValueError("LoRA target module override cannot be empty") + spec = provider_bundle.spec.model_copy( + update={"default_target_modules": tuple(target_modules)} + ) + provider_bundle.spec = spec + setattr(provider_bundle.provider, "_art_model_support_spec", spec) + + +def _build_deterministic_nonzero_lora( + initial_state: dict[str, Any], + *, + seed: int, +) -> dict[str, Any]: + import torch + + initialized: dict[str, Any] = {} + for key in sorted(initial_state): + value = initial_state[key] + if not isinstance(value, torch.Tensor): + raise TypeError(f"Expected tensor for LoRA key {key!r}") + digest = hashlib.sha256(f"{seed}:{key}".encode("utf-8")).digest() + key_seed = int.from_bytes(digest[:8], "little") % (2**31) + generator = torch.Generator(device="cpu").manual_seed(key_seed) + random_values = torch.randn(value.shape, generator=generator) + initialized[key] = (0.01 * random_values).to(value.dtype).contiguous() + return initialized + + +def _merge_sharded_lora(shards_by_rank: list[dict[str, Any]]) -> dict[str, Any]: + from art.megatron.weights.merge import merge_sharded_adapter_entries + + entries_by_key: dict[str, list[tuple[dict[str, Any], Any]]] = {} + for rank_entry in shards_by_rank: + state = rank_entry["state"] + manifest = rank_entry["manifest"] + for key, tensor in state.items(): + entries_by_key.setdefault(key, []).append((manifest[key], tensor)) + return 
merge_sharded_adapter_entries(entries_by_key) + + +def _collect_full_lora_state(model_chunks: list[Any]) -> dict[str, Any] | None: + import torch + + local_state: dict[str, Any] = {} + local_manifest: dict[str, Any] = {} + for chunk in model_chunks: + for module in chunk.modules(): + if hasattr(module, "sharded_lora_manifest"): + local_manifest.update(module.sharded_lora_manifest()) + if hasattr(module, "sharded_lora_state_dict"): + local_state.update( + { + key: value.detach().cpu() + for key, value in module.sharded_lora_state_dict().items() + } + ) + rank = torch.distributed.get_rank() # type: ignore[possibly-missing-attribute] + world_size = torch.distributed.get_world_size() # type: ignore[possibly-missing-attribute] + gathered = [None for _ in range(world_size)] if rank == 0 else None + torch.distributed.gather_object( # type: ignore[possibly-missing-attribute] + {"state": local_state, "manifest": local_manifest}, + gathered, + dst=0, + ) + if rank != 0: + return None + assert gathered is not None + return _merge_sharded_lora([entry for entry in gathered if entry is not None]) + + +def _adapter_config(config: TrainInfOutputParityConfig) -> dict[str, Any]: + from peft.tuners.lora.config import LoraConfig + + from art.megatron.lora import LORA_ALPHA, LORA_RANK + + return LoraConfig( + base_model_name_or_path=config.base_model, + r=LORA_RANK, + lora_alpha=LORA_ALPHA, + target_modules=_lora_target_modules(config), + bias="none", + ).to_dict() + + +def _save_vllm_lora_adapter( + *, + lora_path: Path, + state: dict[str, Any], + runtime: Any, + config: TrainInfOutputParityConfig, +) -> None: + import torch + + from art.megatron.model_support.lora_disk import save_vllm_lora_tensors + + if not state: + raise RuntimeError("Refusing to save empty LoRA state") + zero_keys = [ + key + for key, value in state.items() + if isinstance(value, torch.Tensor) + and int(torch.count_nonzero(value).item()) == 0 + ] + if zero_keys: + raise RuntimeError(f"Refusing zero LoRA tensors: {zero_keys[:5]}") + adapter_config = _adapter_config(config) + tensors, adapter_config = runtime.model_support_handler.to_vllm_lora_tensors( + state, + adapter_config=adapter_config, + ) + save_vllm_lora_tensors(lora_path, tensors, adapter_config) + + +def _run_logits( + *, + runtime: Any, + packed_tensors: dict[str, Any], +) -> Any: + import torch + + from art.megatron.flex_attention import create_shared_prefix_attention_state + + device = next(runtime.model[0].parameters()).device + input_ids = packed_tensors["tokens"].to(device=device) + position_ids = packed_tensors["input_pos"].to(device=device) + group_ids = packed_tensors["group_ids"].to(device=device) + parent_ids = packed_tensors["parent_ids"].to(device=device) + attention_state = create_shared_prefix_attention_state( + group_ids=group_ids, + parent_ids=parent_ids, + ) + with torch.no_grad(): + return runtime.model[0]( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device), + labels=None, + **runtime.model_support_handler.get_forward_kwargs( + runtime.model[0], + attention_bias=attention_state, + ), + ) + + +def _extract_scores_from_logits( + *, + logits: Any, + logical_map: LogicalTokenMap, + side: EngineSide, + weight_state: WeightState, + rollout_mode: RolloutMode | None = None, +) -> ScoreBundle: + import torch + + log_probs = torch.log_softmax(logits.detach().float(), dim=-1).cpu() + target_logprobs: list[float] = [] + topk: list[TokenTopK] = [] + for token in logical_map.tokens: + row = 
log_probs[token.sample_id, token.art_logit_index] + target_logprobs.append(float(row[token.token_id].item())) + values, indices = torch.topk(row, TOP_K) + topk.append( + TokenTopK( + token_ids=[int(value) for value in indices.tolist()], + logprobs=[float(value) for value in values.tolist()], + ) + ) + return ScoreBundle( + side=side, + weight_state=weight_state, + rollout_mode=rollout_mode, + target_logprobs=target_logprobs, + topk=topk, + ) + + +def _megatron_worker(request: MegatronWorkerRequest) -> None: + import torch + + from art.megatron import train as megatron_train + from art.megatron.weights.merge import load_lora_adapter_state_dict + + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group(backend="nccl") # type: ignore[possibly-missing-attribute] + _set_seed(request.config.seed) + os.environ.update(request.config.topology.env()) + + runtime = megatron_train.build_training_runtime( + model_identifier=request.config.base_model, + provider_torch_dtype=torch.bfloat16, + provider_bundle_configure=( + lambda bundle: ( + _configure_lora_target_modules( + bundle, + _lora_target_modules(request.config), + ) + if request.config.lora_target_modules is not None + else None + ) + ), + provider_configure=lambda provider: _configure_provider( + provider, request.config + ), + print_env=False, + build_optimizer=False, + # This worker only runs forward passes. Use the LoRA trainable path for + # both base and LoRA scoring so Megatron freezes base weights before DDP + # allocates buffers; base scoring simply does not load a nonzero adapter. + trainable_parameter_mode="lora", + allow_unvalidated_arch=request.config.allow_unvalidated_arch, + ) + for chunk in runtime.model: + chunk.eval() + + artifact_dir = Path(request.artifact_dir) + packed_tensors = _build_packed_tensors(request.config) + logical_map = build_logical_token_map(packed_tensors) + + adapter_path: Path | None = None + if request.weight_state == "lora": + if request.adapter_path is None: + initial_state = _collect_full_lora_state(cast(list[Any], runtime.model)) + if torch.distributed.get_rank() == 0: # type: ignore[possibly-missing-attribute] + adapter_path = artifact_dir / "active_lora" + initialized = _build_deterministic_nonzero_lora( + initial_state or {}, + seed=request.config.seed, + ) + _save_vllm_lora_adapter( + lora_path=adapter_path, + state=initialized, + runtime=runtime, + config=request.config, + ) + torch.distributed.barrier() # type: ignore[possibly-missing-attribute] + adapter_path = artifact_dir / "active_lora" + else: + adapter_path = Path(request.adapter_path) + adapter_model = load_lora_adapter_state_dict( + str(adapter_path), + handler=runtime.model_support_handler, + allow_unvalidated_arch=request.config.allow_unvalidated_arch, + ) + megatron_train.load_adapter_into_model(runtime.model, adapter_model) + + logits = _run_logits(runtime=runtime, packed_tensors=packed_tensors) + score = _extract_scores_from_logits( + logits=logits, + logical_map=logical_map, + side="megatron", + weight_state=request.weight_state, + ) + + if torch.distributed.get_rank() == 0: # type: ignore[possibly-missing-attribute] + score_path = artifact_dir / f"megatron_{request.weight_state}_scores.json" + logical_map_path = artifact_dir / "logical_token_map.json" + _write_json(score_path, score.model_dump(mode="json")) + _write_json(logical_map_path, logical_map.model_dump(mode="json")) + result = MegatronWorkerResult( + score_path=str(score_path), + 
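_extract_scores_from_logits above turns a logits row into a per-token target logprob plus a top-k slice. A toy, self-contained version of that extraction; the shapes, indices, and vocabulary size here are made up:

import torch

torch.manual_seed(0)
logits = torch.randn(1, 6, 32)                        # [samples, positions, vocab]
log_probs = torch.log_softmax(logits.float(), dim=-1)

art_logit_index, token_id, top_k = 3, 17, 5           # logit index i predicts token i + 1
row = log_probs[0, art_logit_index]
target_logprob = float(row[token_id])
topk_logprobs, topk_token_ids = torch.topk(row, top_k)
assert float(topk_logprobs[0]) >= target_logprob      # top-1 bounds every entry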
logical_map_path=str(logical_map_path), + adapter_path=str(adapter_path) if adapter_path is not None else None, + ) + _write_json( + artifact_dir / f"megatron_{request.weight_state}_worker_result.json", + result.model_dump(mode="json"), + ) + torch.distributed.barrier() # type: ignore[possibly-missing-attribute] + torch.distributed.destroy_process_group() # type: ignore[possibly-missing-attribute] + + +def _run_megatron_worker(request: MegatronWorkerRequest) -> MegatronWorkerResult: + artifact_dir = Path(request.artifact_dir) + request_path = artifact_dir / f"megatron_{request.weight_state}_request.json" + _write_json(request_path, request.model_dump(mode="json")) + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = ",".join( + str(value) for value in request.config.trainer_gpu_ids + ) + env["PYTHONUNBUFFERED"] = "1" + tests_dir = str(REPO_ROOT / "tests") + env["PYTHONPATH"] = ( + tests_dir + if not env.get("PYTHONPATH") + else f"{tests_dir}{os.pathsep}{env['PYTHONPATH']}" + ) + command = [ + sys.executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(request.config.topology.world_size()), + "-m", + "integration.megatron.train_inf_mismatch.output_parity", + "--worker", + "--request", + str(request_path), + ] + log_path = artifact_dir / f"megatron_{request.weight_state}_worker.log" + with log_path.open("w", encoding="utf-8") as log_file: + run = subprocess.run( + command, + cwd=str(REPO_ROOT / "tests"), + env=env, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + check=False, + ) + if run.returncode != 0: + tail = "\n".join(log_path.read_text(encoding="utf-8").splitlines()[-120:]) + raise RuntimeError( + f"Megatron {request.weight_state} worker failed with exit code " + f"{run.returncode}.\n{tail}" + ) + return MegatronWorkerResult.model_validate( + _read_json(artifact_dir / f"megatron_{request.weight_state}_worker_result.json") + ) + + +@asynccontextmanager +async def _direct_vllm_runtime( + *, + config: TrainInfOutputParityConfig, + artifact_dir: Path, + served_model_name: str, + lora_path: str, + rollout_weights_mode: Literal["lora", "merged"], + engine_args: dict[str, Any], +) -> AsyncIterator[tuple[str, int]]: + import art.vllm_runtime as runtime + + port = _free_port() + launch_config = runtime.VllmRuntimeLaunchConfig( + base_model=config.base_model, + port=port, + host="127.0.0.1", + cuda_visible_devices=",".join(str(value) for value in config.inference_gpu_ids), + lora_path=lora_path, + served_model_name=served_model_name, + rollout_weights_mode=rollout_weights_mode, + engine_args=engine_args, + server_args={ + "return_tokens_as_token_ids": True, + **config.server_args, + }, + ) + command = runtime.build_vllm_runtime_server_cmd(launch_config) + log_path = artifact_dir / f"vllm_{served_model_name}.log" + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + with log_path.open("w", encoding="utf-8") as log_file: + process = subprocess.Popen( + command, + cwd=str(runtime.get_vllm_runtime_working_dir()), + env=env, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + ) + try: + await runtime.wait_for_vllm_runtime( + process=process, + host=launch_config.host, + port=launch_config.port, + timeout=float( + os.environ.get("ART_TRAIN_INF_MISMATCH_VLLM_TIMEOUT", "1200") + ), + ) + yield launch_config.host, launch_config.port + finally: + process.terminate() + try: + process.wait(timeout=30) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=30) + + +async def _request_prompt_logprobs( + *, + base_url: 
str, + model_name: str, + prompt_token_ids: list[int], +) -> dict[str, Any]: + import httpx + + async with httpx.AsyncClient(timeout=300.0) as client: + response = await client.post( + f"{base_url}/v1/completions", + json={ + "model": model_name, + "prompt": prompt_token_ids, + "add_special_tokens": False, + "max_tokens": 0, + "echo": True, + "prompt_logprobs": TOP_K, + "return_token_ids": True, + }, + ) + response.raise_for_status() + return response.json() + + +def _logprob_entry_value(entry: dict[str, Any], token_id: int) -> float: + raw = entry.get(str(token_id)) + if raw is None: + raise RuntimeError(f"Token {token_id} missing from vLLM prompt_logprobs entry") + if isinstance(raw, dict): + return float(raw["logprob"]) + return float(raw.logprob) + + +def _topk_from_entry(entry: dict[str, Any]) -> TokenTopK: + parsed: list[tuple[int, int, float]] = [] + for raw_token_id, raw_value in entry.items(): + token_id = int(raw_token_id) + if isinstance(raw_value, dict): + rank = int(raw_value.get("rank", TOP_K + 1)) + logprob = float(raw_value["logprob"]) + else: + rank = int(raw_value.rank) + logprob = float(raw_value.logprob) + if 1 <= rank <= TOP_K: + parsed.append((rank, token_id, logprob)) + parsed.sort(key=lambda item: item[0]) + return TokenTopK( + token_ids=[token_id for _rank, token_id, _logprob in parsed[:TOP_K]], + logprobs=[logprob for _rank, _token_id, logprob in parsed[:TOP_K]], + ) + + +async def _score_vllm_at_url( + *, + base_url: str, + model_name: str, + logical_map: LogicalTokenMap, + weight_state: WeightState, + rollout_mode: RolloutMode, + artifact_dir: Path, +) -> ScoreBundle: + responses_by_prompt: dict[int, dict[str, Any]] = {} + prompt_by_id = {prompt.prompt_id: prompt for prompt in logical_map.prompts} + for prompt in logical_map.prompts: + response = await _request_prompt_logprobs( + base_url=base_url, + model_name=model_name, + prompt_token_ids=prompt.token_ids, + ) + choice = response["choices"][0] + returned_prompt_ids = [int(value) for value in choice["prompt_token_ids"]] + if returned_prompt_ids != prompt.token_ids: + raise RuntimeError( + "vLLM returned prompt_token_ids do not match request for " + f"prompt_id={prompt.prompt_id}" + ) + responses_by_prompt[prompt.prompt_id] = response + _write_json( + artifact_dir / f"vllm_{rollout_mode}_{weight_state}_responses.json", + responses_by_prompt, + ) + + target_logprobs: list[float] = [] + topk: list[TokenTopK] = [] + for token in logical_map.tokens: + prompt = prompt_by_id[token.prompt_id] + choice = responses_by_prompt[token.prompt_id]["choices"][0] + entries = choice["prompt_logprobs"] + returned_token_id = int(prompt.token_ids[token.vllm_prompt_token_index]) + if returned_token_id != token.token_id: + raise RuntimeError( + "Logical token alignment mismatch: " + f"expected={token.token_id} returned={returned_token_id}" + ) + entry = entries[token.vllm_prompt_token_index] + if entry is None: + raise RuntimeError( + f"Missing prompt logprob entry for prompt_id={token.prompt_id} " + f"index={token.vllm_prompt_token_index}" + ) + target_logprobs.append(_logprob_entry_value(entry, token.token_id)) + topk.append(_topk_from_entry(entry)) + return ScoreBundle( + side="vllm", + weight_state=weight_state, + rollout_mode=rollout_mode, + target_logprobs=target_logprobs, + topk=topk, + ) + + +async def _score_vllm_base( + *, + config: TrainInfOutputParityConfig, + rollout_mode: RolloutMode, + logical_map: LogicalTokenMap, + artifact_dir: Path, +) -> ScoreBundle: + served_name = 
f"train_inf_base_{rollout_mode}_{int(time.time())}" + placeholder_lora = artifact_dir / "unused_lora_placeholder" + placeholder_lora.mkdir(exist_ok=True) + engine_args = { + "tensor_parallel_size": len(config.inference_gpu_ids), + "enable_expert_parallel": len(config.inference_gpu_ids) > 1, + "max_model_len": config.packed.sequence_length + 8, + **config.engine_args, + } + if rollout_mode == "native_lora": + engine_args["enable_lora"] = True + engine_args["lora_target_modules"] = _lora_target_modules(config) + async with _direct_vllm_runtime( + config=config, + artifact_dir=artifact_dir, + served_model_name=served_name, + lora_path=str(placeholder_lora), + rollout_weights_mode="merged", + engine_args=engine_args, + ) as (host, port): + return await _score_vllm_at_url( + base_url=f"http://{host}:{port}", + model_name=served_name, + logical_map=logical_map, + weight_state="base", + rollout_mode=rollout_mode, + artifact_dir=artifact_dir, + ) + + +async def _score_vllm_native_lora( + *, + config: TrainInfOutputParityConfig, + adapter_path: str, + logical_map: LogicalTokenMap, + artifact_dir: Path, +) -> ScoreBundle: + served_name = f"train_inf_native_lora_{int(time.time())}" + engine_args = { + "tensor_parallel_size": len(config.inference_gpu_ids), + "enable_expert_parallel": len(config.inference_gpu_ids) > 1, + "max_model_len": config.packed.sequence_length + 8, + **config.engine_args, + } + engine_args["lora_target_modules"] = _lora_target_modules(config) + async with _direct_vllm_runtime( + config=config, + artifact_dir=artifact_dir, + served_model_name=served_name, + lora_path=adapter_path, + rollout_weights_mode="lora", + engine_args=engine_args, + ) as (host, port): + return await _score_vllm_at_url( + base_url=f"http://{host}:{port}", + model_name=served_name, + logical_map=logical_map, + weight_state="lora", + rollout_mode="native_lora", + artifact_dir=artifact_dir, + ) + + +async def _score_vllm_merged_lora( + *, + config: TrainInfOutputParityConfig, + adapter_path: str, + logical_map: LogicalTokenMap, + artifact_dir: Path, +) -> ScoreBundle: + from art import dev + from art.megatron.service import MegatronService + + service_name = f"train_inf_merged_lora_{int(time.time())}" + output_dir = artifact_dir / "merged_service" + from art.utils.output_dirs import get_step_checkpoint_dir + + checkpoint_dir = Path(get_step_checkpoint_dir(str(output_dir), 0)) + checkpoint_dir.mkdir(parents=True) + for filename in ("adapter_model.safetensors", "adapter_config.json"): + shutil.copy(Path(adapter_path) / filename, checkpoint_dir / filename) + internal_config = dev.InternalModelConfig( + trainer_gpu_ids=config.trainer_gpu_ids, + inference_gpu_ids=config.inference_gpu_ids, + rollout_weights_mode="merged", + allow_unvalidated_arch=config.allow_unvalidated_arch, + engine_args={ + "tensor_parallel_size": len(config.inference_gpu_ids), + "enable_expert_parallel": len(config.inference_gpu_ids) > 1, + "max_model_len": config.packed.sequence_length + 8, + **config.engine_args, + }, + ) + with _provider_topology_env(config.topology): + service = MegatronService( + model_name=service_name, + base_model=config.base_model, + config=internal_config, + output_dir=str(output_dir), + ) + try: + host, port = await service.start_openai_server( + {"server_args": {"port": _free_port(), **config.server_args}} + ) + return await _score_vllm_at_url( + base_url=f"http://{host}:{port}", + model_name=f"{service_name}@0", + logical_map=logical_map, + weight_state="lora", + rollout_mode="merged", + artifact_dir=artifact_dir, 
+ ) + finally: + await service.aclose() + + +def _assert_lora_active( + base: ScoreBundle, lora: ScoreBundle, *, side: EngineSide +) -> None: + import torch + + base_values = torch.tensor(base.target_logprobs, dtype=torch.float32) + lora_values = torch.tensor(lora.target_logprobs, dtype=torch.float32) + if not bool(torch.isfinite(base_values).all().item()): + raise RuntimeError(f"{side} base target logprobs contain non-finite values") + if not bool(torch.isfinite(lora_values).all().item()): + raise RuntimeError(f"{side} LoRA target logprobs contain non-finite values") + if int(torch.count_nonzero((lora_values - base_values).abs() > 0).item()) == 0: + raise RuntimeError(f"{side} LoRA is not active: all deltas are zero") + + +async def run_train_inf_output_parity( + *, + config: TrainInfOutputParityConfig, + artifact_dir: Path, +) -> TrainInfOutputParityReport: + _write_json(artifact_dir / "probe_config.json", config.model_dump(mode="json")) + lora_result = _run_megatron_worker( + MegatronWorkerRequest( + config=config, + artifact_dir=str(artifact_dir), + weight_state="lora", + adapter_path=None, + ) + ) + if lora_result.adapter_path is None: + raise RuntimeError("LoRA worker did not produce an adapter") + base_result = _run_megatron_worker( + MegatronWorkerRequest( + config=config, + artifact_dir=str(artifact_dir), + weight_state="base", + adapter_path=None, + ) + ) + logical_map = LogicalTokenMap.model_validate( + _read_json(Path(lora_result.logical_map_path)) + ) + base_logical_map = LogicalTokenMap.model_validate( + _read_json(Path(base_result.logical_map_path)) + ) + if base_logical_map != logical_map: + raise RuntimeError("Base and LoRA Megatron workers produced different maps") + + megatron_base = ScoreBundle.model_validate(_read_json(Path(base_result.score_path))) + megatron_lora = ScoreBundle.model_validate(_read_json(Path(lora_result.score_path))) + _assert_lora_active(megatron_base, megatron_lora, side="megatron") + + rollout_comparisons: list[RolloutComparison] = [] + for rollout_mode in config.rollout_modes: + vllm_base = await _score_vllm_base( + config=config, + rollout_mode=rollout_mode, + logical_map=logical_map, + artifact_dir=artifact_dir, + ) + if rollout_mode == "native_lora": + vllm_lora = await _score_vllm_native_lora( + config=config, + adapter_path=lora_result.adapter_path, + logical_map=logical_map, + artifact_dir=artifact_dir, + ) + else: + vllm_lora = await _score_vllm_merged_lora( + config=config, + adapter_path=lora_result.adapter_path, + logical_map=logical_map, + artifact_dir=artifact_dir, + ) + _assert_lora_active(vllm_base, vllm_lora, side="vllm") + _write_json( + artifact_dir / f"vllm_{rollout_mode}_base_scores.json", + vllm_base.model_dump(mode="json"), + ) + _write_json( + artifact_dir / f"vllm_{rollout_mode}_lora_scores.json", + vllm_lora.model_dump(mode="json"), + ) + rollout_comparisons.append( + compare_rollout( + rollout_mode=rollout_mode, + megatron_base=megatron_base, + megatron_lora=megatron_lora, + vllm_base=vllm_base, + vllm_lora=vllm_lora, + logical_map=logical_map, + ) + ) + + passed = all( + comparison.base.mean_abs_pct <= BF16_FWD_MEAN_ABS_PCT_LIMIT + and comparison.lora.mean_abs_pct <= BF16_FWD_MEAN_ABS_PCT_LIMIT + for comparison in rollout_comparisons + ) + report = TrainInfOutputParityReport( + base_model=config.base_model, + artifact_dir=str(artifact_dir), + topology=config.topology.slug(), + trainer_gpu_ids=config.trainer_gpu_ids, + inference_gpu_ids=config.inference_gpu_ids, + logical_prompt_count=len(logical_map.prompts), + 
logical_token_count=len(logical_map.tokens), + adapter_path=lora_result.adapter_path, + megatron_base_scores=base_result.score_path, + megatron_lora_scores=lora_result.score_path, + rollout_comparisons=rollout_comparisons, + passed=passed, + ) + _write_json(artifact_dir / "comparison_report.json", report.model_dump(mode="json")) + return report + + +def _worker_cli(request_path: Path) -> None: + request = MegatronWorkerRequest.model_validate(_read_json(request_path)) + _megatron_worker(request) + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--worker", action="store_true") + parser.add_argument("--request", type=Path) + return parser.parse_args(argv) + + +def _main(argv: list[str]) -> int: + args = _parse_args(argv) + if args.worker: + if args.request is None: + raise ValueError("--worker requires --request") + _worker_cli(args.request) + return 0 + raise ValueError("This module is intended to be run through pytest or --worker") + + +if __name__ == "__main__": + raise SystemExit(_main(sys.argv[1:])) diff --git a/tests/integration/megatron/train_inf_mismatch/test_live_output_parity.py b/tests/integration/megatron/train_inf_mismatch/test_live_output_parity.py new file mode 100644 index 000000000..1aef412f7 --- /dev/null +++ b/tests/integration/megatron/train_inf_mismatch/test_live_output_parity.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +from .output_parity import ( + BF16_FWD_MEAN_ABS_PCT_LIMIT, + config_from_env, + run_train_inf_output_parity, +) + +torch = pytest.importorskip("torch") + +LIVE_ENV = "ART_RUN_TRAIN_INF_MISMATCH_LIVE" + + +def _require_live_opt_in() -> None: + if os.environ.get(LIVE_ENV) != "1": + pytest.skip(f"set {LIVE_ENV}=1 to run train/inf output parity") + + +def _require_visible_gpus(gpu_ids: list[int]) -> None: + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for train/inf output parity") + visible_count = int(torch.cuda.device_count()) + required = max(gpu_ids) + 1 if gpu_ids else 0 + if visible_count < required: + pytest.skip( + f"Need visible CUDA device ids through {required - 1}, " + f"but torch sees {visible_count} devices" + ) + + +@pytest.mark.asyncio +async def test_train_inf_output_parity_live(artifact_dir: Path) -> None: + _require_live_opt_in() + config = config_from_env() + _require_visible_gpus(config.trainer_gpu_ids + config.inference_gpu_ids) + + report = await run_train_inf_output_parity( + config=config, + artifact_dir=artifact_dir, + ) + + assert report.logical_prompt_count > 0 + assert report.logical_token_count > 0 + assert report.passed, report.model_dump_json(indent=2) + for comparison in report.rollout_comparisons: + assert comparison.base.mean_abs_pct <= BF16_FWD_MEAN_ABS_PCT_LIMIT + assert comparison.lora.mean_abs_pct <= BF16_FWD_MEAN_ABS_PCT_LIMIT diff --git a/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py b/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py new file mode 100644 index 000000000..0a7c0aa15 --- /dev/null +++ b/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import math + +import pytest + +torch = pytest.importorskip("torch") + +from . 
import workflow_stage +from .output_parity import ( + TOP_K, + EngineSide, + ScoreBundle, + TokenTopK, + TrainInfOutputParityConfig, + WeightState, + aggregate_mean_abs_pct, + build_logical_token_map, + compare_rollout, + compare_topk, + config_from_env, +) + + +def test_logical_map_flattens_shared_prefix_branches() -> None: + packed = { + "tokens": torch.tensor([[10, 11, 12, 13, 14, 12, 15, 16]]), + "group_ids": torch.tensor([[0, 0, 1, 1, 1, 2, 2, 2]]), + "parent_ids": torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), + } + + logical_map = build_logical_token_map(packed) + + assert [prompt.token_ids for prompt in logical_map.prompts] == [ + [10, 11, 12, 13, 14], + [10, 11, 12, 15, 16], + ] + assert [token.token_id for token in logical_map.tokens] == [13, 14, 15, 16] + assert [token.art_logit_index for token in logical_map.tokens] == [2, 3, 5, 6] + assert [token.vllm_prompt_token_index for token in logical_map.tokens] == [ + 3, + 4, + 3, + 4, + ] + + +def test_aggregate_mean_abs_pct_uses_vllm_merge_formula() -> None: + summary = aggregate_mean_abs_pct( + candidate=torch.tensor([2.0, 4.0]), + target=torch.tensor([1.0, 3.0]), + sequence_ids=[0, 0], + ) + + assert summary.source_numel == 2 + assert summary.trimmed_numel == 0 + assert summary.mean_abs_pct == pytest.approx((2.0 / 4.0) * 100.0) + + +def test_aggregate_mean_abs_pct_does_not_trim_or_average_sequence_summaries() -> None: + target = torch.ones(80) + candidate = target.clone() + candidate[0] = 101.0 + candidate[1] = 51.0 + candidate[2] = 26.0 + candidate[3] = 2.0 + + summary = aggregate_mean_abs_pct( + candidate=candidate, + target=target, + sequence_ids=[0] * 40 + [1] * 40, + ) + + assert summary.source_numel == 80 + assert summary.sequence_count == 2 + assert summary.trimmed_numel == 0 + assert summary.mean_abs_pct == pytest.approx((176.0 / 80.0) * 100.0) + + +def _score( + values: list[float], + *, + side: EngineSide, + state: WeightState, +) -> ScoreBundle: + return ScoreBundle( + side=side, + weight_state=state, + target_logprobs=values, + topk=[ + TokenTopK( + token_ids=list(range(TOP_K)), + logprobs=[-float(index) for index in range(TOP_K)], + ) + for _ in values + ], + ) + + +def test_compare_rollout_reports_base_lora_and_delta_separately() -> None: + packed = { + "tokens": torch.tensor([[10, 11, 12, 13, 14]]), + "group_ids": torch.tensor([[0, 0, 1, 1, 1]]), + "parent_ids": torch.tensor([[0, 0, 0, 0, 0]]), + } + logical_map = build_logical_token_map(packed) + + report = compare_rollout( + rollout_mode="native_lora", + megatron_base=_score([-1.0, -2.0], side="megatron", state="base"), + megatron_lora=_score([-1.5, -2.5], side="megatron", state="lora"), + vllm_base=_score([-1.1, -2.2], side="vllm", state="base"), + vllm_lora=_score([-1.7, -2.8], side="vllm", state="lora"), + logical_map=logical_map, + ) + + assert report.base.mean_abs_pct > 0 + assert report.lora.mean_abs_pct > 0 + assert report.delta.mean_abs_pct > 0 + + +def test_compare_topk_reports_restricted_intersection_kl() -> None: + target = ScoreBundle( + side="megatron", + weight_state="base", + target_logprobs=[0.0], + topk=[ + TokenTopK( + token_ids=[10, 11], + logprobs=[math.log(0.75), math.log(0.25)], + ) + ], + ) + candidate = ScoreBundle( + side="vllm", + weight_state="base", + target_logprobs=[0.0], + topk=[ + TokenTopK( + token_ids=[10, 11], + logprobs=[math.log(0.5), math.log(0.5)], + ) + ], + ) + + report = compare_topk(candidate, target) + + assert report.top20_intersection_kl_target_to_candidate == pytest.approx( + 0.75 * math.log(0.75 / 0.5) + 0.25 * 
math.log(0.25 / 0.5) + ) + assert report.top20_intersection_kl_candidate_to_target == pytest.approx( + 0.5 * math.log(0.5 / 0.75) + 0.5 * math.log(0.5 / 0.25) + ) + + +def test_config_from_env_accepts_lora_target_module_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv( + "ART_TRAIN_INF_MISMATCH_LORA_TARGET_MODULES", + "experts,in_proj_qkv,in_proj_z", + ) + + config = config_from_env() + + assert config.lora_target_modules == ["experts", "in_proj_qkv", "in_proj_z"] + + +def test_default_rollout_modes_follow_model_support_native_lora_status() -> None: + assert TrainInfOutputParityConfig( + base_model="Qwen/Qwen3.5-35B-A3B" + ).rollout_modes == ["native_lora", "merged"] + assert TrainInfOutputParityConfig( + base_model="unvalidated/native-disabled", + allow_unvalidated_arch=True, + ).rollout_modes == ["merged"] + + +def test_config_from_env_rollout_modes_override_handler_default( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv( + "ART_TRAIN_INF_MISMATCH_BASE_MODEL", + "unvalidated/native-disabled", + ) + monkeypatch.setenv("ART_TRAIN_INF_MISMATCH_ALLOW_UNVALIDATED_ARCH", "1") + monkeypatch.setenv("ART_TRAIN_INF_MISMATCH_ROLLOUT_MODES", "native_lora") + + config = config_from_env() + + assert config.rollout_modes == ["native_lora"] + + +def test_workflow_stage_enables_live_train_inf_mismatch( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + import subprocess + + captured_env = {} + + def fake_run(*args, **kwargs): + captured_env.update(kwargs["env"]) + return subprocess.CompletedProcess( + args=args, + returncode=0, + stdout="1 passed\n", + stderr="", + ) + + monkeypatch.setattr(workflow_stage, "create_artifact_dir", lambda _nodeid: tmp_path) + monkeypatch.setattr(workflow_stage.subprocess, "run", fake_run) + + report = workflow_stage.run_train_inf_mismatch(base_model="Qwen/Qwen3.5-35B-A3B") + + assert report.passed is True + assert captured_env["ART_RUN_TRAIN_INF_MISMATCH_LIVE"] == "1" diff --git a/tests/integration/megatron/train_inf_mismatch/test_qwen35_vllm_lora_layout.py b/tests/integration/megatron/train_inf_mismatch/test_qwen35_vllm_lora_layout.py new file mode 100644 index 000000000..5fe449f44 --- /dev/null +++ b/tests/integration/megatron/train_inf_mismatch/test_qwen35_vllm_lora_layout.py @@ -0,0 +1,227 @@ +import torch + +from art.megatron.model_support.handlers import QWEN3_5_MOE_HANDLER + + +def _config(base_model: str, *, rank: int) -> dict: + return { + "base_model_name_or_path": base_model, + "r": rank, + "lora_alpha": rank, + "target_modules": [ + "in_proj_qkv", + "in_proj_z", + "out_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + "bias": "none", + } + + +def _small_q_gate_config(*, rank: int) -> dict: + config = _config("Qwen/Qwen3.5-35B-A3B", rank=rank) + config.update( + { + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 3, + } + ) + return config + + +def _sentinel( + expert: int, + module_id: int, + lora_id: int, + shape: tuple[int, int], +) -> torch.Tensor: + return ( + torch.arange(shape[0] * shape[1], dtype=torch.float32).reshape(shape) + + expert * 10_000 + + module_id * 1_000 + + lora_id * 100 + ) + + +def _qwen35_art_moe_tensors( + prefix: str, + *, + num_experts: int, + rank: int, + hidden: int, + intermediate: int, +) -> dict[str, torch.Tensor]: + tensors: dict[str, torch.Tensor] = {} + module_ids = {"gate_up_proj": 1, "down_proj": 2} + for expert in range(num_experts): + for module, module_id in module_ids.items(): + in_dim = intermediate if module == "down_proj" else hidden 
+ out_dim = hidden if module == "down_proj" else 2 * intermediate + module_prefix = f"{prefix}.mlp.experts.{expert}.{module}" + tensors[f"{module_prefix}.lora_A.weight"] = _sentinel( + expert, + module_id, + 0, + (rank, in_dim), + ) + tensors[f"{module_prefix}.lora_B.weight"] = _sentinel( + expert, + module_id, + 1, + (out_dim, rank), + ) + return tensors + + +def _q_proj_lora_b_to_vllm_expected( + tensor: torch.Tensor, + *, + num_heads: int, + num_groups: int, + head_dim: int, +) -> torch.Tensor: + heads_per_group = num_heads // num_groups + grouped = tensor.reshape(num_groups, 2 * heads_per_group, head_dim, tensor.shape[1]) + query = grouped[:, :heads_per_group] + gate = grouped[:, heads_per_group:] + return torch.cat((query, gate), dim=2).reshape(tensor.shape).contiguous() + + +def test_qwen35_q_proj_lora_b_translates_grouped_gate_layout() -> None: + rank = 2 + num_heads = 4 + num_groups = 2 + head_dim = 3 + rows = num_groups * 2 * (num_heads // num_groups) * head_dim + art_key = "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight" + vllm_key = ( + "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_B.weight" + ) + art_tensor = torch.arange(rows * rank, dtype=torch.float32).reshape(rows, rank) + adapter_config = _small_q_gate_config(rank=rank) + + vllm_tensors, vllm_config = QWEN3_5_MOE_HANDLER.to_vllm_lora_tensors( + {art_key: art_tensor}, + adapter_config=adapter_config, + ) + + assert vllm_config == adapter_config + assert torch.equal( + vllm_tensors[vllm_key], + _q_proj_lora_b_to_vllm_expected( + art_tensor, + num_heads=num_heads, + num_groups=num_groups, + head_dim=head_dim, + ), + ) + roundtrip = QWEN3_5_MOE_HANDLER.from_vllm_lora_tensors( + vllm_tensors, + adapter_config=adapter_config, + ) + assert torch.equal(roundtrip[art_key], art_tensor) + + +def test_qwen35_moe_layout_exports_vllm_3d_without_rank_rewrite() -> None: + rank = 2 + hidden = 3 + intermediate = 4 + num_experts = 4 + art_prefix = "base_model.model.model.layers.0" + vllm_prefix = "base_model.model.model.language_model.layers.0.mlp.experts" + art_tensors = _qwen35_art_moe_tensors( + art_prefix, + num_experts=num_experts, + rank=rank, + hidden=hidden, + intermediate=intermediate, + ) + + vllm_tensors, vllm_config = QWEN3_5_MOE_HANDLER.to_vllm_lora_tensors( + art_tensors, + adapter_config=_config("Qwen/Qwen3.5-35B-A3B", rank=rank), + ) + + assert vllm_config["r"] == rank + assert vllm_config["lora_alpha"] == rank + assert vllm_config["target_modules"] == [ + "in_proj_qkv", + "in_proj_z", + "out_proj", + "experts", + ] + assert set(vllm_tensors) == { + f"{vllm_prefix}.base_layer.lora_A.weight", + f"{vllm_prefix}.base_layer.lora_B.weight", + f"{vllm_prefix}.lora_A.weight", + f"{vllm_prefix}.lora_B.weight", + } + assert vllm_tensors[f"{vllm_prefix}.base_layer.lora_A.weight"].shape == ( + num_experts * rank, + hidden, + ) + assert vllm_tensors[f"{vllm_prefix}.base_layer.lora_B.weight"].shape == ( + 2 * intermediate, + num_experts * rank, + ) + assert vllm_tensors[f"{vllm_prefix}.lora_A.weight"].shape == ( + num_experts * rank, + intermediate, + ) + assert vllm_tensors[f"{vllm_prefix}.lora_B.weight"].shape == ( + hidden, + num_experts * rank, + ) + roundtrip = QWEN3_5_MOE_HANDLER.from_vllm_lora_tensors( + vllm_tensors, + adapter_config=vllm_config, + ) + assert set(roundtrip) == set(art_tensors) + for key, tensor in art_tensors.items(): + assert torch.equal(roundtrip[key], tensor), key + + +def test_qwen35_moe_path_keeps_dense_lora_rank_when_moe_is_present() -> None: + rank = 1 + num_heads = 4 
+ num_groups = 2 + head_dim = 3 + rows = num_groups * 2 * (num_heads // num_groups) * head_dim + art_prefix = "base_model.model.model.layers.0" + art_key = f"{art_prefix}.self_attn.q_proj.lora_B.weight" + vllm_key = ( + "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_B.weight" + ) + art_tensor = torch.arange(rows * rank, dtype=torch.float32).reshape(rows, rank) + art_tensors = { + **_qwen35_art_moe_tensors( + art_prefix, + num_experts=1, + rank=rank, + hidden=3, + intermediate=4, + ), + art_key: art_tensor, + } + + vllm_tensors, vllm_config = QWEN3_5_MOE_HANDLER.to_vllm_lora_tensors( + art_tensors, + adapter_config=_small_q_gate_config(rank=rank), + ) + + expected = _q_proj_lora_b_to_vllm_expected( + art_tensor, + num_heads=num_heads, + num_groups=num_groups, + head_dim=head_dim, + ) + assert vllm_config["r"] == rank + assert torch.equal(vllm_tensors[vllm_key], expected) + roundtrip = QWEN3_5_MOE_HANDLER.from_vllm_lora_tensors( + vllm_tensors, + adapter_config=vllm_config, + ) + assert torch.equal(roundtrip[art_key], art_tensor) diff --git a/tests/integration/megatron/train_inf_mismatch/workflow_stage.py b/tests/integration/megatron/train_inf_mismatch/workflow_stage.py new file mode 100644 index 000000000..296c0184d --- /dev/null +++ b/tests/integration/megatron/train_inf_mismatch/workflow_stage.py @@ -0,0 +1,85 @@ +import os +from pathlib import Path +import re +import subprocess +import sys + +from pydantic import BaseModel + +from .artifacts import REPO_ROOT, TEST_ROOT, create_artifact_dir + + +class TrainInfMismatchReport(BaseModel): + base_model: str + passed: bool + returncode: int + artifact_dir: str + test_root: str + stdout_path: str + stderr_path: str + passed_count: int + failed_count: int + skipped_count: int + + +def _pytest_counts(output: str) -> dict[str, int]: + counts = {"passed": 0, "failed": 0, "skipped": 0} + for line in reversed(output.splitlines()): + matches = re.findall(r"(\d+) (passed|failed|skipped|error|errors)", line) + if not matches: + continue + for count, kind in matches: + if kind in {"error", "errors"}: + counts["failed"] += int(count) + else: + counts[kind] += int(count) + return counts + return counts + + +def run_train_inf_mismatch(*, base_model: str) -> TrainInfMismatchReport: + artifact_dir = create_artifact_dir("workflow::train_inf_mismatch") + stdout_path = artifact_dir / "pytest_stdout.txt" + stderr_path = artifact_dir / "pytest_stderr.txt" + env = os.environ.copy() + env["BASE_MODEL"] = base_model + env["ART_RUN_TRAIN_INF_MISMATCH_LIVE"] = "1" + env["ART_TRAIN_INF_MISMATCH_BASE_MODEL"] = base_model + existing_pythonpath = env.get("PYTHONPATH") + tests_dir = str(REPO_ROOT / "tests") + env["PYTHONPATH"] = ( + tests_dir + if not existing_pythonpath + else f"{tests_dir}{os.pathsep}{existing_pythonpath}" + ) + result = subprocess.run( + [ + sys.executable, + "-m", + "pytest", + "-q", + str(TEST_ROOT), + f"--ignore={TEST_ROOT / 'artifacts'}", + "--tb=short", + ], + cwd=Path(REPO_ROOT), + env=env, + capture_output=True, + text=True, + check=False, + ) + stdout_path.write_text(result.stdout, encoding="utf-8") + stderr_path.write_text(result.stderr, encoding="utf-8") + counts = _pytest_counts(result.stdout + "\n" + result.stderr) + return TrainInfMismatchReport( + base_model=base_model, + passed=result.returncode == 0, + returncode=result.returncode, + artifact_dir=str(artifact_dir), + test_root=str(TEST_ROOT), + stdout_path=str(stdout_path), + stderr_path=str(stderr_path), + passed_count=counts["passed"], + 
failed_count=counts["failed"], + skipped_count=counts["skipped"], + ) diff --git a/tests/integration/megatron/trainability/__init__.py b/tests/integration/megatron/trainability/__init__.py new file mode 100644 index 000000000..a673a9653 --- /dev/null +++ b/tests/integration/megatron/trainability/__init__.py @@ -0,0 +1,31 @@ +from .yes_no_trainability import ( + TrainabilityStepReport, + YesNoTrainabilityReport, + _build_trainable_groups, + _build_training_groups, + _engine_args_for_yes_no_trainability, + _evaluate_model, + _wandb_disabled, + _warmup_model, + build_prompts, + run_megatron_dedicated_yes_no_trainability, + run_unsloth_dedicated_yes_no_trainability, + run_yes_no_trainability, + run_yes_no_trainability_async, +) + +__all__ = [ + "YesNoTrainabilityReport", + "TrainabilityStepReport", + "_build_trainable_groups", + "_build_training_groups", + "_engine_args_for_yes_no_trainability", + "_evaluate_model", + "_wandb_disabled", + "_warmup_model", + "build_prompts", + "run_megatron_dedicated_yes_no_trainability", + "run_unsloth_dedicated_yes_no_trainability", + "run_yes_no_trainability", + "run_yes_no_trainability_async", +] diff --git a/tests/integration/megatron/trainability/test_config.py b/tests/integration/megatron/trainability/test_config.py new file mode 100644 index 000000000..6004e9a9f --- /dev/null +++ b/tests/integration/megatron/trainability/test_config.py @@ -0,0 +1,199 @@ +import asyncio +from typing import cast + +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +import pytest + +import art + +from .yes_no_trainability import ( + _build_internal_config, + _build_variant, + _default_variant_name, + _evaluate_groups, + _TrainabilityVariant, + _variant_init_args, + _variant_max_steps, + _variant_packed_sequence_length, + _variant_rollouts_per_prompt, + _variant_train_kwargs, +) + + +class _ConcurrentCompletions: + def __init__(self, expected: int) -> None: + self.expected = expected + self.started = 0 + self.active = 0 + self.max_active = 0 + self.all_started = asyncio.Event() + + async def create(self, **kwargs): + self.started += 1 + self.active += 1 + self.max_active = max(self.max_active, self.active) + if self.started == self.expected: + self.all_started.set() + try: + await asyncio.wait_for(self.all_started.wait(), timeout=1.0) + return ChatCompletion( + id=f"completion-{self.started}", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + role="assistant", + content="maybe", + ), + ) + ], + created=0, + model=str(kwargs["model"]), + object="chat.completion", + ) + finally: + self.active -= 1 + + +class _FakeChat: + def __init__(self, completions: _ConcurrentCompletions) -> None: + self.completions = completions + + +class _FakeClient: + def __init__(self, completions: _ConcurrentCompletions) -> None: + self.chat = _FakeChat(completions) + + +class _FakeModel: + def __init__(self, client: _FakeClient) -> None: + self.client = client + + def openai_client(self) -> _FakeClient: + return self.client + + def get_inference_name(self, *, step: int | None = None) -> str: + return f"fake@{step}" + + +@pytest.mark.asyncio +async def test_eval_prompts_are_submitted_concurrently() -> None: + completions = _ConcurrentCompletions(expected=3) + + groups = await _evaluate_groups( + cast(art.TrainableModel, _FakeModel(_FakeClient(completions))), + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + prompts=["a", "b", "c"], + step=1, + ) + + assert 
len(groups) == 3 + assert completions.started == 3 + assert completions.max_active == 3 + assert [group.trajectories[0].reward for group in groups] == [1.0, 1.0, 1.0] + + +def test_megatron_variants_keep_short_packed_sequence_default(monkeypatch) -> None: + monkeypatch.delenv("ART_MODEL_SUPPORT_YES_NO_PACKED_SEQUENCE_LENGTH", raising=False) + variant = _TrainabilityVariant( + name="megatron_shared", + backend_name="megatron", + placement_mode="shared", + trainer_gpu_ids=[0, 1], + inference_gpu_ids=[0, 1], + ) + + assert _variant_packed_sequence_length(variant) == 1024 + assert _variant_train_kwargs(variant) == {"packed_sequence_length": 1024} + config = _build_internal_config( + variant, base_model="Qwen/Qwen3-30B-A3B-Instruct-2507" + ) + assert config["init_args"]["max_seq_length"] == 1024 + assert config["rollout_weights_mode"] == "lora" + assert ( + _default_variant_name("Qwen/Qwen3-30B-A3B-Instruct-2507") == "megatron_shared" + ) + assert _variant_rollouts_per_prompt(variant) == 4 + assert _variant_max_steps(variant) == 4 + + +def test_unsloth_variant_uses_chunk_aligned_training_length(monkeypatch) -> None: + monkeypatch.delenv("ART_MODEL_SUPPORT_YES_NO_PACKED_SEQUENCE_LENGTH", raising=False) + variant = _TrainabilityVariant( + name="unsloth_dedicated", + backend_name="local", + placement_mode="dedicated", + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + ) + + assert _variant_packed_sequence_length(variant) == 1024 + assert _variant_train_kwargs(variant) == {"packed_sequence_length": 1024} + assert _variant_init_args(variant) == {"max_seq_length": 1024} + assert _build_internal_config( + variant, base_model="Qwen/Qwen3-30B-A3B-Instruct-2507" + )["init_args"] == {"max_seq_length": 1024} + assert _variant_rollouts_per_prompt(variant) == 8 + assert _variant_max_steps(variant) == 12 + + +def test_qwen3_5_defaults_to_shared_lora_rollout() -> None: + variant = _TrainabilityVariant( + name="megatron_shared", + backend_name="megatron", + placement_mode="shared", + trainer_gpu_ids=[0, 1], + inference_gpu_ids=[0, 1], + ) + + config = _build_internal_config(variant, base_model="Qwen/Qwen3.5-35B-A3B") + + assert _default_variant_name("Qwen/Qwen3.5-35B-A3B") == "megatron_shared" + assert config["rollout_weights_mode"] == "lora" + assert "trainer_gpu_ids" not in config + assert "inference_gpu_ids" not in config + + +def test_validated_dense_model_uses_dense_shared_topology( + monkeypatch, +) -> None: + monkeypatch.setenv("ART_MODEL_SUPPORT_SHARED_GPU_IDS", "0,1") + built_variant = _build_variant( + "megatron_shared", + base_model="Qwen/Qwen3.5-4B", + ) + assert built_variant.topology is not None + assert built_variant.topology.tp == 2 + assert built_variant.topology.ep == 1 + assert built_variant.topology.etp == 1 + + variant = _TrainabilityVariant( + name="megatron_shared", + backend_name="megatron", + placement_mode="shared", + trainer_gpu_ids=[0, 1], + inference_gpu_ids=[0, 1], + ) + + config = _build_internal_config(variant, base_model="Qwen/Qwen3.5-4B") + assert config["rollout_weights_mode"] == "lora" + assert config["engine_args"]["enable_sleep_mode"] is True + assert "enable_expert_parallel" not in config["engine_args"] + + +def test_qwen3_5_moe_shared_variant_enables_expert_parallel(monkeypatch) -> None: + monkeypatch.setenv("ART_MODEL_SUPPORT_SHARED_GPU_IDS", "0,1") + variant = _TrainabilityVariant( + name="megatron_shared", + backend_name="megatron", + placement_mode="shared", + trainer_gpu_ids=[0, 1], + inference_gpu_ids=[0, 1], + ) + + config = _build_internal_config(variant, 
base_model="Qwen/Qwen3.5-35B-A3B") + + assert config["rollout_weights_mode"] == "lora" + assert config["engine_args"]["enable_expert_parallel"] is True diff --git a/tests/integration/megatron/trainability/test_live_yes_no_trainability.py b/tests/integration/megatron/trainability/test_live_yes_no_trainability.py new file mode 100644 index 000000000..119d3b74a --- /dev/null +++ b/tests/integration/megatron/trainability/test_live_yes_no_trainability.py @@ -0,0 +1,106 @@ +import json +import os +from pathlib import Path + +import pytest + +from .yes_no_trainability import run_yes_no_trainability_async + +torch = pytest.importorskip("torch") + +DEFAULT_BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507" +LIVE_ENV = "ART_RUN_LIVE_YES_NO_TRAINABILITY" + + +def _require_opt_in() -> None: + if os.environ.get(LIVE_ENV) != "1": + pytest.skip(f"set {LIVE_ENV}=1 to run live yes/no trainability validation") + + +def _base_model() -> str: + return os.environ.get( + "ART_LIVE_YES_NO_BASE_MODEL", + os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL), + ) + + +def _unsloth_base_model() -> str: + return os.environ.get("ART_LIVE_UNSLOTH_YES_NO_BASE_MODEL", _base_model()) + + +def _assert_passed(report) -> None: + assert report.saturated_step is not None + assert report.saturated_step > 0 + assert report.initial_eval_reward < report.reward_threshold + assert report.final_eval_reward is not None + assert report.final_eval_reward >= report.reward_threshold + assert report.final_eval_reward > report.initial_eval_reward + assert report.latest_step > 0 + assert report.step0_name in report.model_ids_before + assert report.latest_name in report.model_ids_after + if report.rollout_weights_mode == "merged": + assert report.step0_name not in report.model_ids_after + else: + assert report.step0_name in report.model_ids_after + assert report.latest_snapshot["has_logprobs"] is True + + +def _write_report(artifact_dir: Path, name: str, report) -> None: + (artifact_dir / name).write_text( + json.dumps(report.model_dump(mode="json"), indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Need at least 2 CUDA GPUs for live yes/no trainability validation", +) +@pytest.mark.asyncio +async def test_megatron_shared_yes_no_trainability_live( + artifact_dir: Path, +) -> None: + _require_opt_in() + report = await run_yes_no_trainability_async( + base_model=_base_model(), + variant_name="megatron_shared", + artifact_root=artifact_dir / "megatron_shared_workspace", + ) + _write_report(artifact_dir, "megatron_shared_yes_no_trainability.json", report) + _assert_passed(report) + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Need at least 2 CUDA GPUs for live yes/no trainability validation", +) +@pytest.mark.asyncio +async def test_megatron_dedicated_yes_no_trainability_live( + artifact_dir: Path, +) -> None: + _require_opt_in() + report = await run_yes_no_trainability_async( + base_model=_base_model(), + variant_name="megatron_dedicated", + artifact_root=artifact_dir / "megatron_dedicated_workspace", + ) + _write_report(artifact_dir, "megatron_dedicated_yes_no_trainability.json", report) + _assert_passed(report) + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Need at least 2 CUDA GPUs for live yes/no trainability validation", +) +@pytest.mark.asyncio +async def test_unsloth_dedicated_yes_no_trainability_live( + artifact_dir: Path, +) -> 
None: + _require_opt_in() + report = await run_yes_no_trainability_async( + base_model=_unsloth_base_model(), + variant_name="unsloth_dedicated", + artifact_root=artifact_dir / "unsloth_dedicated_workspace", + ) + _write_report(artifact_dir, "unsloth_dedicated_yes_no_trainability.json", report) + _assert_passed(report) diff --git a/tests/integration/megatron/trainability/yes_no_trainability.py b/tests/integration/megatron/trainability/yes_no_trainability.py new file mode 100644 index 000000000..8f4850505 --- /dev/null +++ b/tests/integration/megatron/trainability/yes_no_trainability.py @@ -0,0 +1,823 @@ +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager, contextmanager, nullcontext +import gc +from itertools import permutations +import os +from pathlib import Path +import re +import time +from typing import Any, AsyncIterator, Iterator, Literal, TypedDict, cast +import uuid + +from pydantic import BaseModel, Field +import torch + +import art +from art import dev +from art.local import LocalBackend +from art.megatron.model_support.registry import ( + get_model_support_spec, + model_uses_expert_parallel, +) +from art.megatron.model_support.spec import RolloutWeightsMode +from art.megatron.runtime.backend import MegatronBackend + +from ..model_support.oracle_harness import Topology, oracle_topology +from ..model_support.oracle_worker import provider_topology_env + +_TRAINER_GPU_IDS_ENV = "ART_MODEL_SUPPORT_TRAINER_GPU_IDS" +_INFERENCE_GPU_IDS_ENV = "ART_MODEL_SUPPORT_INFERENCE_GPU_IDS" +_SHARED_GPU_IDS_ENV = "ART_MODEL_SUPPORT_SHARED_GPU_IDS" +_TRAINABILITY_ROOT = ( + Path(__file__).resolve().parents[4] / ".local" / "model_support_validation" +) +_SHARED_MEGATRON_TOPOLOGY = Topology(tp=2, ep=2, etp=1, dp=1, sp=True) +_DENSE_SHARED_MEGATRON_TOPOLOGY = Topology(tp=2, ep=1, etp=1, dp=1, sp=True) +_VARIANT_NAME = Literal[ + "megatron_shared", + "megatron_dedicated", + "unsloth_dedicated", +] + + +class _TrainKwargs(TypedDict): + packed_sequence_length: int + + +class TrainabilityStepReport(BaseModel): + step: int + eval_reward: float + train_reward: float + train_metrics: dict[str, float] = Field(default_factory=dict) + + +class YesNoTrainabilityReport(BaseModel): + variant: _VARIANT_NAME + backend_name: Literal["megatron", "local"] + placement_mode: Literal["shared", "dedicated"] + base_model: str + output_dir: str + trainer_gpu_ids: list[int] + inference_gpu_ids: list[int] + rollout_weights_mode: str + reward_threshold: float + max_steps: int + prompt_count: int + eval_prompt_count: int + rollouts_per_prompt: int + latest_step: int + initial_eval_reward: float + final_eval_reward: float | None = None + saturated_step: int | None = None + step0_name: str + latest_name: str + model_ids_before: list[str] = Field(default_factory=list) + model_ids_after: list[str] = Field(default_factory=list) + latest_snapshot: dict[str, object] = Field(default_factory=dict) + steps: list[TrainabilityStepReport] = Field(default_factory=list) + + +class _TrainabilityVariant(BaseModel): + name: _VARIANT_NAME + backend_name: Literal["megatron", "local"] + placement_mode: Literal["shared", "dedicated"] + topology: Topology | None = None + trainer_gpu_ids: list[int] = Field(default_factory=list) + inference_gpu_ids: list[int] = Field(default_factory=list) + + +def build_prompts() -> list[str]: + prompt = os.environ.get("ART_MODEL_SUPPORT_YES_NO_PROMPT", "").strip() + prompt_count = _get_env_int("ART_MODEL_SUPPORT_YES_NO_PROMPT_COUNT", 8) + if prompt: + return [prompt] * 
max(1, prompt_count) + prompts = [ + f"{prefix} exactly one of {body}" + for prefix in ("respond with", "just respond with") + for use_quotes in (True, False) + for length in (3, 2) + for words in permutations(("yes", "no", "maybe"), length) + for body in [ + ", ".join(f"'{word}'" if use_quotes else word for word in words) + if length == 3 + else " or ".join(f"'{word}'" if use_quotes else word for word in words) + ] + ] + if prompt_count <= len(prompts): + return prompts[: max(1, prompt_count)] + return [prompts[index % len(prompts)] for index in range(prompt_count)] + + +def _slugify(value: str) -> str: + return value.lower().replace("/", "_").replace(".", "_").replace("-", "_") + + +def _parse_gpu_id_env(name: str) -> list[int] | None: + raw = os.environ.get(name) + if raw is None or raw.strip() == "": + return None + return [int(part.strip()) for part in raw.split(",") if part.strip()] + + +def _resolve_shared_gpu_ids() -> list[int]: + if shared_gpu_ids := _parse_gpu_id_env(_SHARED_GPU_IDS_ENV): + return shared_gpu_ids + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + raise RuntimeError("Need at least 2 visible CUDA GPUs for shared trainability") + return [0, 1] + + +def _resolve_dedicated_gpu_ids() -> tuple[list[int], list[int]]: + trainer_gpu_ids = _parse_gpu_id_env(_TRAINER_GPU_IDS_ENV) + inference_gpu_ids = _parse_gpu_id_env(_INFERENCE_GPU_IDS_ENV) + if trainer_gpu_ids is not None or inference_gpu_ids is not None: + if trainer_gpu_ids is None or inference_gpu_ids is None: + raise RuntimeError( + f"{_TRAINER_GPU_IDS_ENV} and {_INFERENCE_GPU_IDS_ENV} must both be set" + ) + return trainer_gpu_ids, inference_gpu_ids + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + raise RuntimeError( + "Need at least 2 visible CUDA GPUs for dedicated trainability" + ) + return [0], [1] + + +def _safe_gpu_memory_utilization(device_ids: list[int]) -> float: + requested = float( + os.environ.get("ART_MODEL_SUPPORT_YES_NO_GPU_MEMORY_UTILIZATION", "0.85") + ) + min_free_gib = float( + os.environ.get("ART_MODEL_SUPPORT_YES_NO_MIN_FREE_GPU_GIB", "8") + ) + min_utilization = min( + requested, + float( + os.environ.get( + "ART_MODEL_SUPPORT_YES_NO_MIN_GPU_MEMORY_UTILIZATION", + "0.5", + ) + ), + ) + attempts = _get_env_int("ART_MODEL_SUPPORT_YES_NO_GPU_MEMORY_RETRY_ATTEMPTS", 12) + sleep_s = _get_env_float("ART_MODEL_SUPPORT_YES_NO_GPU_MEMORY_RETRY_SLEEP_S", 5.0) + devices = sorted(set(device_ids)) + last_message = "no GPU memory samples collected" + + for attempt in range(attempts): + free_ratios: list[float] = [] + low_free: list[str] = [] + for device in devices: + free_bytes, total_bytes = torch.cuda.mem_get_info(device) + free_gib = free_bytes / (1024**3) + if free_gib < min_free_gib: + low_free.append( + f"GPU {device} has only {free_gib:.1f} GiB free < {min_free_gib:.1f} GiB required" + ) + free_ratios.append(free_bytes / total_bytes) + + utilization = max(0.02, min(requested, min(free_ratios) * 0.95)) + if not low_free and utilization >= min_utilization: + return utilization + + ratio_summary = ", ".join( + f"GPU {device}: free_ratio={ratio:.3f}" + for device, ratio in zip(devices, free_ratios, strict=True) + ) + last_message = "; ".join( + [ + *low_free, + f"computed gpu_memory_utilization={utilization:.3f}", + ratio_summary, + ] + ) + if attempt == attempts - 1: + break + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + time.sleep(sleep_s) + + raise RuntimeError( + "Unable to recover enough free GPU memory 
for yes/no validation runtime startup. "
+        f"{last_message}"
+    )
+
+
+def reward_for_answer(text: str) -> float:
+    return {"yes": 0.5, "no": 0.75, "maybe": 1.0}.get(
+        first_word_for_answer(text).lower(),
+        0.0,
+    )
+
+
+def first_word_for_answer(text: str | None) -> str:
+    if not text:
+        return ""
+    # Strip any leading <think>...</think> reasoning block before reading the answer.
+    stripped = re.sub(
+        r"<think>.*?</think>\s*",
+        "",
+        text,
+        flags=re.IGNORECASE | re.DOTALL,
+    )
+    first_word = stripped.strip().split(maxsplit=1)
+    if not first_word:
+        return ""
+    return first_word[0].strip(".,!?:;\"'()[]{}")
+
+
+def _get_env_int(name: str, default: int) -> int:
+    return int(os.environ.get(name, str(default)))
+
+
+def _get_env_float(name: str, default: float) -> float:
+    return float(os.environ.get(name, str(default)))
+
+
+def _get_env_bool(name: str, default: bool) -> bool:
+    raw = os.environ.get(name)
+    if raw is None:
+        return default
+    lowered = raw.strip().lower()
+    if lowered in {"1", "true", "yes", "on"}:
+        return True
+    if lowered in {"0", "false", "no", "off"}:
+        return False
+    raise ValueError(f"Invalid boolean value for {name}: {raw!r}")
+
+
+def _max_tokens() -> int:
+    return _get_env_int("ART_MODEL_SUPPORT_YES_NO_MAX_TOKENS", 5)
+
+
+def _render_chat_messages(base_model: str, prompt: str) -> art.Messages:
+    del base_model
+    return [{"role": "user", "content": prompt}]
+
+
+def _enable_thinking() -> bool:
+    return os.environ.get(
+        "ART_MODEL_SUPPORT_YES_NO_ENABLE_THINKING", ""
+    ).strip().lower() in {"1", "true", "yes", "on"}
+
+
+def _extra_body() -> dict[str, object]:
+    return {"chat_template_kwargs": {"enable_thinking": _enable_thinking()}}
+
+
+def _request_timeout(name: str, default: float) -> float:
+    return _get_env_float(name, default)
+
+
+def _engine_args_for_yes_no_trainability(
+    *,
+    inference_gpu_ids: list[int],
+    tensor_parallel_size: int = 1,
+    enable_expert_parallel: bool = False,
+    enable_sleep_mode: bool | None = None,
+) -> dev.EngineArgs:
+    engine_args: dict[str, object] = {
+        "gpu_memory_utilization": _safe_gpu_memory_utilization(inference_gpu_ids),
+        "max_model_len": _get_env_int("ART_MODEL_SUPPORT_YES_NO_MAX_MODEL_LEN", 128),
+        "max_num_seqs": _get_env_int("ART_MODEL_SUPPORT_YES_NO_MAX_NUM_SEQS", 4),
+        "enforce_eager": True,
+        "tensor_parallel_size": tensor_parallel_size,
+    }
+    if enable_expert_parallel:
+        engine_args["enable_expert_parallel"] = True
+    if enable_sleep_mode is not None:
+        engine_args["enable_sleep_mode"] = enable_sleep_mode
+    return cast(dev.EngineArgs, engine_args)
+
+
+@contextmanager
+def _wandb_disabled() -> Iterator[None]:
+    saved = {name: os.environ.get(name) for name in ("WANDB_API_KEY", "WANDB_MODE")}
+    os.environ.pop("WANDB_API_KEY", None)
+    os.environ["WANDB_MODE"] = "disabled"
+    try:
+        yield
+    finally:
+        for name, value in saved.items():
+            if value is None:
+                os.environ.pop(name, None)
+            else:
+                os.environ[name] = value
+
+
+def _artifact_dir(base_model: str, variant_name: _VARIANT_NAME) -> Path:
+    path = (
+        _TRAINABILITY_ROOT / _slugify(base_model) / variant_name / uuid.uuid4().hex[:8]
+    )
+    path.mkdir(parents=True, exist_ok=True)
+    return path
+
+
+def _build_variant(
+    variant_name: _VARIANT_NAME,
+    *,
+    base_model: str,
+    allow_unvalidated_arch: bool = False,
+) -> _TrainabilityVariant:
+    is_moe = model_uses_expert_parallel(
+        base_model,
+        allow_unvalidated_arch=allow_unvalidated_arch,
+    )
+    if variant_name == "megatron_shared":
+        shared_gpu_ids = _resolve_shared_gpu_ids()
+        return _TrainabilityVariant(
+            name=variant_name,
+            backend_name="megatron",
+            placement_mode="shared",
+            topology=_SHARED_MEGATRON_TOPOLOGY +
if is_moe + else _DENSE_SHARED_MEGATRON_TOPOLOGY, + trainer_gpu_ids=shared_gpu_ids, + inference_gpu_ids=shared_gpu_ids, + ) + trainer_gpu_ids, inference_gpu_ids = _resolve_dedicated_gpu_ids() + if variant_name == "megatron_dedicated": + return _TrainabilityVariant( + name=variant_name, + backend_name="megatron", + placement_mode="dedicated", + topology=oracle_topology(is_moe=is_moe), + trainer_gpu_ids=trainer_gpu_ids, + inference_gpu_ids=inference_gpu_ids, + ) + return _TrainabilityVariant( + name=variant_name, + backend_name="local", + placement_mode="dedicated", + trainer_gpu_ids=trainer_gpu_ids, + inference_gpu_ids=inference_gpu_ids, + ) + + +def _variant_packed_sequence_length(variant: _TrainabilityVariant) -> int: + return _get_env_int("ART_MODEL_SUPPORT_YES_NO_PACKED_SEQUENCE_LENGTH", 1024) + + +def _variant_train_kwargs(variant: _TrainabilityVariant) -> _TrainKwargs: + return {"packed_sequence_length": _variant_packed_sequence_length(variant)} + + +def _variant_init_args(variant: _TrainabilityVariant) -> dev.InitArgs: + return {"max_seq_length": _variant_packed_sequence_length(variant)} + + +def _variant_max_steps(variant: _TrainabilityVariant) -> int: + default = 12 if variant.backend_name == "local" else 4 + return _get_env_int("ART_MODEL_SUPPORT_YES_NO_MAX_STEPS", default) + + +def _variant_rollouts_per_prompt(variant: _TrainabilityVariant) -> int: + default = 8 if variant.backend_name == "local" else 4 + return _get_env_int("ART_MODEL_SUPPORT_YES_NO_ROLLOUTS_PER_PROMPT", default) + + +def _rollout_weights_mode( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> RolloutWeightsMode: + return get_model_support_spec( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ).default_rollout_weights_mode + + +def _default_variant_name( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> _VARIANT_NAME: + if ( + _rollout_weights_mode( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + == "merged" + ): + return "megatron_dedicated" + return "megatron_shared" + + +def _build_internal_config( + variant: _TrainabilityVariant, + *, + base_model: str, + rollout_weights_mode: RolloutWeightsMode | None = None, + allow_unvalidated_arch: bool = False, +) -> dev.InternalModelConfig: + shared = variant.placement_mode == "shared" + inference_gpu_ids = ( + variant.inference_gpu_ids if not shared else _resolve_shared_gpu_ids() + ) + engine_args = _engine_args_for_yes_no_trainability( + inference_gpu_ids=inference_gpu_ids, + tensor_parallel_size=len(inference_gpu_ids) if shared else 1, + enable_expert_parallel=( + shared + and variant.backend_name == "megatron" + and model_uses_expert_parallel( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + ), + enable_sleep_mode=True if shared else None, + ) + engine_args["model"] = base_model + internal_config = dev.InternalModelConfig( + rollout_weights_mode=rollout_weights_mode + or _rollout_weights_mode( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ), + engine_args=engine_args, + init_args=_variant_init_args(variant), + allow_unvalidated_arch=allow_unvalidated_arch, + ) + if not shared: + internal_config["trainer_gpu_ids"] = variant.trainer_gpu_ids + internal_config["inference_gpu_ids"] = variant.inference_gpu_ids + dev.validate_dedicated_config(internal_config) + return internal_config + + +@asynccontextmanager +async def _backend_context( + variant: _TrainabilityVariant, + *, + backend_root: Path, +) -> AsyncIterator[LocalBackend | MegatronBackend]: + with 
_wandb_disabled(): + topology_context = ( + provider_topology_env(variant.topology) + if variant.topology is not None + else nullcontext() + ) + with topology_context: + if variant.backend_name == "megatron": + async with MegatronBackend( + path=str(backend_root), + in_process=False, + ) as backend: + yield backend + return + async with LocalBackend(path=str(backend_root)) as backend: + yield backend + + +async def _list_model_ids(model: art.TrainableModel) -> list[str]: + client = model.openai_client() + return [model_info.id async for model_info in client.models.list()] + + +async def _chat_snapshot(model: art.TrainableModel, *, step: int) -> dict[str, object]: + client = model.openai_client() + completion = await client.chat.completions.create( + messages=[{"role": "user", "content": "Say hello."}], + model=model.get_inference_name(step=step), + max_tokens=8, + timeout=180.0, + logprobs=True, + top_logprobs=0, + ) + return { + "text": completion.choices[0].message.content, + "has_logprobs": completion.choices[0].logprobs is not None, + } + + +async def _evaluate_groups( + model: art.TrainableModel, + *, + base_model: str, + prompts: list[str], + step: int, +) -> list[art.TrajectoryGroup]: + client = model.openai_client() + + async def _group_for_prompt(prompt: str) -> art.TrajectoryGroup: + messages = _render_chat_messages(base_model, prompt) + completion = await client.chat.completions.create( + messages=messages, + model=model.get_inference_name(step=step), + max_tokens=_max_tokens(), + extra_body=_extra_body(), + temperature=_get_env_float( + "ART_MODEL_SUPPORT_YES_NO_EVAL_TEMPERATURE", + 0.0, + ), + timeout=_request_timeout("ART_MODEL_SUPPORT_YES_NO_EVAL_TIMEOUT", 180.0), + ) + choice = completion.choices[0] + return art.TrajectoryGroup( + [ + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(choice.message.content or ""), + ) + ] + ) + + return await art.gather_trajectory_groups( + [_group_for_prompt(prompt) for prompt in prompts] # ty: ignore[invalid-argument-type] + ) + + +def _mean_group_reward(groups: list[art.TrajectoryGroup]) -> float: + rewards = [ + trajectory.reward for group in groups for trajectory in group.trajectories + ] + return sum(rewards) / max(1, len(rewards)) + + +async def _evaluate_model( + model: art.TrainableModel, + *, + base_model: str, + prompts: list[str], + step: int, +) -> float: + return _mean_group_reward( + await _evaluate_groups( + model, + base_model=base_model, + prompts=prompts, + step=step, + ) + ) + + +async def _build_training_groups( + model: art.TrainableModel, + *, + base_model: str, + prompts: list[str], + rollouts_per_prompt: int, +) -> list[art.TrajectoryGroup]: + client = model.openai_client() + + async def _group_for_prompt(prompt: str) -> art.TrajectoryGroup: + messages = _render_chat_messages(base_model, prompt) + completion = await client.chat.completions.create( + messages=messages, + model=model.get_inference_name(), + max_tokens=_max_tokens(), + n=rollouts_per_prompt, + extra_body=_extra_body(), + temperature=_get_env_float( + "ART_MODEL_SUPPORT_YES_NO_ROLLOUT_TEMPERATURE", + 1.2, + ), + timeout=_request_timeout( + "ART_MODEL_SUPPORT_YES_NO_ROLLOUT_TIMEOUT", + 180.0, + ), + ) + return art.TrajectoryGroup( + [ + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(choice.message.content or ""), + ) + for choice in completion.choices + ] + ) + + return await art.gather_trajectory_groups( + [_group_for_prompt(prompt) for prompt in prompts] # ty: 
ignore[invalid-argument-type] + ) + + +def _group_has_reward_variance(group: art.TrajectoryGroup) -> bool: + return len({trajectory.reward for trajectory in group.trajectories}) > 1 + + +async def _build_trainable_groups( + model: art.TrainableModel, + *, + base_model: str, + prompts: list[str], + rollouts_per_prompt: int, +) -> list[art.TrajectoryGroup]: + max_attempts = _get_env_int("ART_MODEL_SUPPORT_YES_NO_MAX_ROLLOUT_ATTEMPTS", 4) + for _ in range(max_attempts): + groups = await _build_training_groups( + model, + base_model=base_model, + prompts=prompts, + rollouts_per_prompt=rollouts_per_prompt, + ) + trainable_groups = [ + group for group in groups if _group_has_reward_variance(group) + ] + if trainable_groups: + return trainable_groups + raise RuntimeError( + "No reward-variant trajectory groups were produced for yes/no trainability" + ) + + +async def _warmup_model( + model: art.TrainableModel, + *, + base_model: str, + prompt: str, +) -> None: + client = model.openai_client() + await client.chat.completions.create( + messages=_render_chat_messages(base_model, prompt), + model=model.get_inference_name(step=0), + max_tokens=1, + extra_body=_extra_body(), + temperature=0.0, + timeout=_request_timeout("ART_MODEL_SUPPORT_YES_NO_WARMUP_TIMEOUT", 900.0), + ) + + +async def run_yes_no_trainability_async( + *, + base_model: str, + variant_name: _VARIANT_NAME = "megatron_shared", + artifact_root: Path | None = None, + rollout_weights_mode: RolloutWeightsMode | None = None, + allow_unvalidated_arch: bool = False, +) -> YesNoTrainabilityReport: + variant = _build_variant( + variant_name, + base_model=base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + backend_root = artifact_root or _artifact_dir(base_model, variant.name) + backend_root.mkdir(parents=True, exist_ok=True) + reward_threshold = _get_env_float("ART_MODEL_SUPPORT_YES_NO_REWARD_THRESHOLD", 0.9) + max_steps = _variant_max_steps(variant) + rollouts_per_prompt = _variant_rollouts_per_prompt(variant) + eval_prompt_count = _get_env_int("ART_MODEL_SUPPORT_YES_NO_EVAL_PROMPTS", 8) + prompts = build_prompts() + eval_prompts = prompts[:eval_prompt_count] + internal_config = _build_internal_config( + variant, + base_model=base_model, + rollout_weights_mode=rollout_weights_mode, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + rollout_weights_mode = internal_config["rollout_weights_mode"] + model = art.TrainableModel( + name=f"{variant.name}-{uuid.uuid4().hex[:8]}", + project="model-support-validation", + base_model=base_model, + _internal_config=internal_config, + report_metrics=[], + ) + train_kwargs = _variant_train_kwargs(variant) + + async with _backend_context(variant, backend_root=backend_root) as backend: + await model.register(backend) + output_dir = Path(model.base_path) / model.project / "models" / model.name + await _warmup_model(model, base_model=base_model, prompt=prompts[0]) + step0_name = model.get_inference_name(step=0) + model_ids_before = await _list_model_ids(model) + initial_eval_groups = await _evaluate_groups( + model, + base_model=base_model, + prompts=eval_prompts, + step=0, + ) + initial_eval_reward = _mean_group_reward(initial_eval_groups) + await model.log(initial_eval_groups, step=0, split="val") + report = YesNoTrainabilityReport( + variant=variant.name, + backend_name=variant.backend_name, + placement_mode=variant.placement_mode, + base_model=base_model, + output_dir=str(output_dir), + trainer_gpu_ids=variant.trainer_gpu_ids, + inference_gpu_ids=variant.inference_gpu_ids, + 
rollout_weights_mode=rollout_weights_mode, + reward_threshold=reward_threshold, + max_steps=max_steps, + prompt_count=len(prompts), + eval_prompt_count=len(eval_prompts), + rollouts_per_prompt=rollouts_per_prompt, + latest_step=0, + initial_eval_reward=initial_eval_reward, + step0_name=step0_name, + latest_name=step0_name, + model_ids_before=model_ids_before, + ) + + for _ in range(max_steps): + train_groups = await _build_trainable_groups( + model, + base_model=base_model, + prompts=prompts, + rollouts_per_prompt=rollouts_per_prompt, + ) + result = await backend.train( + model, + train_groups, + learning_rate=_get_env_float( + "ART_MODEL_SUPPORT_YES_NO_LEARNING_RATE", + 1e-4, + ), + loss_fn="cispo", + packed_sequence_length=train_kwargs["packed_sequence_length"], + ) + await model.log( + train_groups, + metrics=result.metrics, + step=result.step, + split="train", + ) + eval_groups = await _evaluate_groups( + model, + base_model=base_model, + prompts=eval_prompts, + step=result.step, + ) + eval_reward = _mean_group_reward(eval_groups) + await model.log(eval_groups, step=result.step, split="val") + report.latest_step = int(result.step) + report.latest_name = model.get_inference_name(step=result.step) + report.final_eval_reward = float(eval_reward) + report.steps.append( + TrainabilityStepReport( + step=int(result.step), + eval_reward=float(eval_reward), + train_reward=sum( + trajectory.reward + for group in train_groups + for trajectory in group.trajectories + ) + / max(1, sum(len(group.trajectories) for group in train_groups)), + train_metrics={ + key: float(value) + for key, value in result.metrics.items() + if isinstance(value, int | float) + }, + ) + ) + if eval_reward >= reward_threshold: + report.saturated_step = int(result.step) + break + + report.model_ids_after = await _list_model_ids(model) + report.latest_snapshot = await _chat_snapshot(model, step=report.latest_step) + + output_dir = Path(report.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "report.json").write_text( + report.model_dump_json(indent=2), + encoding="utf-8", + ) + return report + + +def run_yes_no_trainability( + base_model: str, + *, + allow_unvalidated_arch: bool = False, +) -> YesNoTrainabilityReport: + return asyncio.run( + run_yes_no_trainability_async( + base_model=base_model, + variant_name=_default_variant_name( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ), + allow_unvalidated_arch=allow_unvalidated_arch, + ) + ) + + +def run_megatron_dedicated_yes_no_trainability( + base_model: str, + *, + rollout_weights_mode: RolloutWeightsMode | None = None, +) -> YesNoTrainabilityReport: + return asyncio.run( + run_yes_no_trainability_async( + base_model=base_model, + variant_name="megatron_dedicated", + rollout_weights_mode=rollout_weights_mode, + ) + ) + + +def run_unsloth_dedicated_yes_no_trainability( + base_model: str, +) -> YesNoTrainabilityReport: + return asyncio.run( + run_yes_no_trainability_async( + base_model=base_model, + variant_name="unsloth_dedicated", + ) + ) diff --git a/tests/integration/test_lora_quack_cutover.py b/tests/integration/test_lora_quack_cutover.py index 77ecd42c7..71e4a9df5 100644 --- a/tests/integration/test_lora_quack_cutover.py +++ b/tests/integration/test_lora_quack_cutover.py @@ -5,7 +5,7 @@ pytest.importorskip("quack") -from art.megatron.cute_grouped_lora_quack import quack_grouped_lora_dual +from art.megatron.kernels.cute_grouped_lora_quack import quack_grouped_lora_dual from art.megatron.lora import LoRA diff --git 
a/tests/integration/test_megatron_provider_support.py b/tests/integration/test_megatron_provider_support.py deleted file mode 100644 index e6030c902..000000000 --- a/tests/integration/test_megatron_provider_support.py +++ /dev/null @@ -1,164 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any, cast - -import pytest - -pytest.importorskip("megatron.bridge") -pytest.importorskip("megatron.bridge.models.qwen.qwen3_moe_bridge") -pytest.importorskip("megatron.bridge.models.qwen_vl.qwen35_vl_bridge") - -from megatron.bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge -from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLMoEBridge -from megatron.core.transformer.enums import AttnBackend - -from art.megatron.flex_attention import FlexDotProductAttention -import art.megatron.provider as provider_module - - -class _FakeProvider: - def __init__(self) -> None: - self.transformer_layer_spec = self._base_layer_spec - self.finalized = False - - def _base_layer_spec( - self, config: object, vp_stage: int | None = None - ) -> SimpleNamespace: - return SimpleNamespace( - submodules=SimpleNamespace( - self_attention=SimpleNamespace( - submodules=SimpleNamespace(core_attention=object()) - ) - ), - ) - - def finalize(self) -> None: - self.finalized = True - - -class _FakeHybridProvider(_FakeProvider): - def _base_layer_spec( - self, config: object, vp_stage: int | None = None - ) -> SimpleNamespace: - del config, vp_stage - gdn_layer = SimpleNamespace( - submodules=SimpleNamespace( - self_attention=SimpleNamespace(submodules=SimpleNamespace()) - ) - ) - attention_layer = SimpleNamespace( - submodules=SimpleNamespace( - self_attention=SimpleNamespace( - submodules=SimpleNamespace(core_attention=object()) - ) - ) - ) - return SimpleNamespace(layer_specs=[gdn_layer, attention_layer]) - - -class _FakeBridge: - def __init__(self, *, model_bridge: object, provider: _FakeProvider) -> None: - self._model_bridge = model_bridge - self._provider = provider - self.hf_pretrained = SimpleNamespace(model_name_or_path="unused") - - def to_megatron_provider(self) -> _FakeProvider: - return self._provider - - -@pytest.mark.parametrize("bridge_type", [Qwen3MoEBridge, Qwen35VLMoEBridge]) -def test_get_provider_accepts_supported_qwen_moe_bridges( - monkeypatch: pytest.MonkeyPatch, - bridge_type: type[object], -) -> None: - provider = _FakeProvider() - fake_bridge = _FakeBridge( - model_bridge=object.__new__(bridge_type), - provider=provider, - ) - monkeypatch.setattr( - provider_module.AutoBridge, - "from_hf_pretrained", - lambda *args, **kwargs: fake_bridge, - ) - monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 2) - - resolved = provider_module.get_provider("unused-model") - - assert resolved is provider - assert provider.finalized is True - assert resolved.attention_backend is AttnBackend.auto - assert resolved.recompute_granularity == "full" - assert resolved.recompute_method == "uniform" - assert resolved.recompute_num_layers == 1 - assert resolved.tensor_model_parallel_size == 2 - assert resolved.context_parallel_size == 1 - assert resolved.pipeline_model_parallel_size == 1 - assert resolved.expert_model_parallel_size == 2 - assert resolved.expert_tensor_parallel_size == 1 - assert resolved.sequence_parallel is True - assert resolved.moe_shared_expert_overlap is True - assert resolved.moe_router_dtype == "fp32" - assert resolved.moe_aux_loss_coeff == 0.0 - assert resolved.calculate_per_token_loss is True - - layer_spec = 
provider_module._resolve_layer_spec( - resolved.transformer_layer_spec, - resolved, - vp_stage=7, - ) - layer_spec = cast(Any, layer_spec) - assert ( - layer_spec.submodules.self_attention.submodules.core_attention - is FlexDotProductAttention - ) - - -def test_get_provider_rejects_unsupported_bridge( - monkeypatch: pytest.MonkeyPatch, -) -> None: - fake_bridge = _FakeBridge(model_bridge=object(), provider=_FakeProvider()) - monkeypatch.setattr( - provider_module.AutoBridge, - "from_hf_pretrained", - lambda *args, **kwargs: fake_bridge, - ) - - with pytest.raises( - AssertionError, - match="Only Qwen3 and Qwen3.5 MoE models are supported", - ): - provider_module.get_provider("unsupported-model") - - -def test_get_provider_preserves_hybrid_qwen35_layer_specs( - monkeypatch: pytest.MonkeyPatch, -) -> None: - provider = _FakeHybridProvider() - fake_bridge = _FakeBridge( - model_bridge=object.__new__(Qwen35VLMoEBridge), - provider=provider, - ) - monkeypatch.setattr( - provider_module.AutoBridge, - "from_hf_pretrained", - lambda *args, **kwargs: fake_bridge, - ) - monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 1) - - resolved = provider_module.get_provider("unused-qwen35") - layer_spec = provider_module._resolve_layer_spec( - resolved.transformer_layer_spec, - resolved, - vp_stage=0, - ) - - layer_specs = getattr(layer_spec, "layer_specs", None) - assert layer_specs is not None - gdn_layer, attention_layer = layer_specs - assert not hasattr(gdn_layer.submodules.self_attention.submodules, "core_attention") - assert ( - attention_layer.submodules.self_attention.submodules.core_attention - is FlexDotProductAttention - ) diff --git a/tests/integration/test_megatron_qwen35_lora_wrapping.py b/tests/integration/test_megatron_qwen35_lora_wrapping.py deleted file mode 100644 index a85719976..000000000 --- a/tests/integration/test_megatron_qwen35_lora_wrapping.py +++ /dev/null @@ -1,447 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterator -from contextlib import contextmanager -import socket - -import pytest - -torch = pytest.importorskip("torch") -pytest.importorskip("megatron.bridge") -pytest.importorskip("megatron.bridge.models.qwen_vl.qwen35_vl_provider") - -from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( - Qwen3_5MoeVisionConfig, - Qwen35VLMoEModelProvider, -) -from megatron.core import parallel_state as ps -from megatron.core.extensions.transformer_engine import ( - TELayerNormColumnParallelLinear, - TERowParallelLinear, -) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.attention import SelfAttention -from megatron.core.transformer.moe.shared_experts import SharedExpertMLP -from megatron.core.transformer.transformer_layer import TransformerLayer -from torch.distributed import destroy_process_group, init_process_group, is_initialized - -from art.megatron.bridge_adapter_compat import build_adapter_weights_by_base -from art.megatron.lora import ( - GatedDeltaNetInProjLoRA, - MLPExpertsLinearFC1LoRA, - MLPExpertsLinearFC2LoRA, - SelfAttentionLinearProjLoRA, - SelfAttentionLinearQKVLoRA, - SharedExpertsLinearFC1LoRA, - SharedExpertsLinearFC2LoRA, - apply_lora_adapters, -) - - -class _DenseMLP(torch.nn.Module): - def __init__( - self, - *, - linear_fc1: TELayerNormColumnParallelLinear, - linear_fc2: TERowParallelLinear, - ) -> None: - super().__init__() - self.linear_fc1 = linear_fc1 - self.linear_fc2 = linear_fc2 - - -def _make_qwen35_provider() -> 
Qwen35VLMoEModelProvider: - assert Qwen3_5MoeVisionConfig is not None - provider = Qwen35VLMoEModelProvider( - num_layers=4, - hidden_size=64, - ffn_hidden_size=128, - moe_ffn_hidden_size=32, - moe_shared_expert_intermediate_size=16, - num_attention_heads=4, - num_query_groups=1, - kv_channels=16, - linear_key_head_dim=8, - linear_value_head_dim=16, - linear_num_key_heads=2, - linear_num_value_heads=4, - num_moe_experts=4, - moe_router_topk=2, - normalization="RMSNorm", - gated_linear_unit=True, - add_bias_linear=False, - add_qkv_bias=False, - qk_layernorm=True, - hidden_dropout=0.0, - attention_dropout=0.0, - attention_output_gate=True, - experimental_attention_variant="gated_delta_net", - linear_attention_freq=4, - linear_conv_kernel_dim=2, - vocab_size=128, - seq_length=128, - position_embedding_type="mrope", - vision_config=Qwen3_5MoeVisionConfig(), - tensor_model_parallel_size=1, - expert_model_parallel_size=1, - pipeline_model_parallel_size=1, - context_parallel_size=1, - params_dtype=torch.bfloat16, - ) - provider.finalize() - return provider - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) - - -def _adapter_tensors( - lora: torch.nn.Module, - expert_idx: int | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: - a_t = lora.A_T if expert_idx is None else lora.A_T[expert_idx] - b_t = lora.B_T if expert_idx is None else lora.B_T[expert_idx] - return a_t.transpose(-1, -2).contiguous(), b_t.transpose(-1, -2).contiguous() - - -@contextmanager -def _single_rank_model_parallel() -> Iterator[None]: - if not torch.cuda.is_available(): - pytest.skip("CUDA is required for Megatron Qwen3.5 LoRA coverage.") - if is_initialized(): - pytest.skip("torch.distributed is already initialized in this process.") - - torch.cuda.set_device(0) - init_process_group( - backend="nccl", - init_method=f"tcp://127.0.0.1:{_find_free_port()}", - rank=0, - world_size=1, - ) - try: - ps.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - context_parallel_size=1, - expert_model_parallel_size=1, - ) - model_parallel_cuda_manual_seed(1234) - yield - finally: - if getattr(ps, "model_parallel_is_initialized", lambda: False)(): - ps.destroy_model_parallel() - if is_initialized(): - destroy_process_group() - - -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="No CUDA available in this environment", -) -def test_apply_lora_adapters_wraps_qwen35_gdn_and_shared_experts() -> None: - with _single_rank_model_parallel(): - provider = _make_qwen35_provider() - model = provider.provide_language_model(pre_process=True, post_process=True) - apply_lora_adapters([model], provider) - - gdn_in_proj_qkv_prefixes: list[str] = [] - gdn_in_proj_z_prefixes: list[str] = [] - gdn_out_proj_prefixes: list[str] = [] - shared_fc1_gate_prefixes: list[str] = [] - shared_fc1_up_prefixes: list[str] = [] - shared_fc2_prefixes: list[str] = [] - - for module in model.modules(): - in_proj = getattr(module, "in_proj", None) - if isinstance(in_proj, GatedDeltaNetInProjLoRA): - gdn_in_proj_qkv_prefixes.append(in_proj.qkv_lora.adapter_model_prefix) - gdn_in_proj_z_prefixes.append(in_proj.z_lora.adapter_model_prefix) - - out_proj = getattr(module, "out_proj", None) - if isinstance(out_proj, SelfAttentionLinearProjLoRA): - prefix = out_proj.lora.adapter_model_prefix - if prefix.endswith(".linear_attn.out_proj"): - gdn_out_proj_prefixes.append(prefix) - - linear_fc1 = getattr(module, 
"linear_fc1", None) - if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): - shared_fc1_gate_prefixes.append( - linear_fc1.gate_lora.adapter_model_prefix - ) - shared_fc1_up_prefixes.append(linear_fc1.up_lora.adapter_model_prefix) - - linear_fc2 = getattr(module, "linear_fc2", None) - if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): - shared_fc2_prefixes.append( - linear_fc2.row_parallel_lora.lora.adapter_model_prefix - ) - - assert gdn_in_proj_qkv_prefixes - assert gdn_in_proj_z_prefixes - assert gdn_out_proj_prefixes - assert shared_fc1_gate_prefixes - assert shared_fc1_up_prefixes - assert shared_fc2_prefixes - assert len(gdn_in_proj_qkv_prefixes) == len(gdn_in_proj_z_prefixes) - assert len(gdn_in_proj_qkv_prefixes) == len(gdn_out_proj_prefixes) - assert len(shared_fc1_gate_prefixes) == len(shared_fc1_up_prefixes) - assert len(shared_fc1_gate_prefixes) == len(shared_fc2_prefixes) - assert all( - prefix.startswith("base_model.model.model.layers.") - and prefix.endswith(".linear_attn.in_proj_qkv") - for prefix in gdn_in_proj_qkv_prefixes - ) - assert all( - prefix.startswith("base_model.model.model.layers.") - and prefix.endswith(".linear_attn.in_proj_z") - for prefix in gdn_in_proj_z_prefixes - ) - assert all( - prefix.startswith("base_model.model.model.layers.") - and prefix.endswith(".linear_attn.out_proj") - for prefix in gdn_out_proj_prefixes - ) - assert all( - prefix.startswith("base_model.model.model.layers.") - and prefix.endswith(".mlp.shared_expert.gate_proj") - for prefix in shared_fc1_gate_prefixes - ) - assert all( - prefix.startswith("base_model.model.model.layers.") - and prefix.endswith(".mlp.shared_expert.up_proj") - for prefix in shared_fc1_up_prefixes - ) - assert all( - prefix.startswith("base_model.model.model.layers.") - and prefix.endswith(".mlp.shared_expert.down_proj") - for prefix in shared_fc2_prefixes - ) - - -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="No CUDA available in this environment", -) -def test_apply_lora_adapters_accepts_layernorm_column_fc1_dense_path() -> None: - with _single_rank_model_parallel(): - provider = _make_qwen35_provider() - model = provider.provide_language_model(pre_process=True, post_process=True) - - target_layer = next( - module - for module in model.modules() - if isinstance(module, TransformerLayer) - and isinstance(module.self_attention, SelfAttention) - and isinstance(getattr(module.mlp, "shared_experts", None), SharedExpertMLP) - ) - dense_fc1 = target_layer.self_attention.linear_qkv - dense_fc2 = target_layer.self_attention.linear_proj - assert isinstance(dense_fc1, TELayerNormColumnParallelLinear) - assert isinstance(dense_fc2, TERowParallelLinear) - target_layer.mlp = _DenseMLP( - linear_fc1=dense_fc1, - linear_fc2=dense_fc2, - ) - - apply_lora_adapters([model], provider) - - assert isinstance(target_layer.mlp.linear_fc1, SharedExpertsLinearFC1LoRA) - assert isinstance(target_layer.mlp.linear_fc2, SharedExpertsLinearFC2LoRA) - - -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="No CUDA available in this environment", -) -def test_build_adapter_weights_handles_grouped_qwen35_moe_hf_weights() -> None: - with _single_rank_model_parallel(): - provider = _make_qwen35_provider() - model = provider.provide_language_model(pre_process=True, post_process=True) - apply_lora_adapters([model], provider) - - target_layer = next( - module - for module in model.modules() - if isinstance(module, TransformerLayer) - and hasattr(module.mlp, "experts") - and isinstance(module.mlp.experts.linear_fc1, 
MLPExpertsLinearFC1LoRA) - and isinstance(module.mlp.experts.linear_fc2, MLPExpertsLinearFC2LoRA) - ) - fc1_handler = target_layer.mlp.experts.linear_fc1 - fc2_handler = target_layer.mlp.experts.linear_fc2 - - for lora in (fc1_handler.gate_lora, fc1_handler.up_lora, fc2_handler.lora): - lora.A_T.data.fill_(1) - lora.B_T.data.fill_(1) - - adapter_weights_by_base = build_adapter_weights_by_base([model]) - layer_prefix = ( - f"language_model.decoder.layers.{target_layer.layer_number - 1}.mlp.experts" - ) - - for expert_idx in range(fc1_handler.gate_lora.num_local_experts): - fc1_weights = adapter_weights_by_base[ - f"{layer_prefix}.linear_fc1.weight{expert_idx}" - ] - fc2_weights = adapter_weights_by_base[ - f"{layer_prefix}.linear_fc2.weight{expert_idx}" - ] - - assert len(fc1_weights) == 1 - assert len(fc2_weights) == 1 - - gate_linear_in, gate_linear_out = _adapter_tensors( - fc1_handler.gate_lora, expert_idx - ) - up_linear_in, up_linear_out = _adapter_tensors( - fc1_handler.up_lora, expert_idx - ) - fc2_linear_in, fc2_linear_out = _adapter_tensors( - fc2_handler.lora, expert_idx - ) - - torch.testing.assert_close( - fc1_weights[0].linear_in_weight.weight, - torch.cat([gate_linear_in, up_linear_in], dim=0), - ) - torch.testing.assert_close( - fc1_weights[0].linear_out_weight.weight, - torch.cat( - [ - torch.cat( - [ - gate_linear_out, - torch.zeros( - ( - gate_linear_out.shape[0], - up_linear_in.shape[0], - ), - device=gate_linear_out.device, - dtype=gate_linear_out.dtype, - ), - ], - dim=1, - ), - torch.cat( - [ - torch.zeros( - ( - up_linear_out.shape[0], - gate_linear_in.shape[0], - ), - device=up_linear_out.device, - dtype=up_linear_out.dtype, - ), - up_linear_out, - ], - dim=1, - ), - ], - dim=0, - ), - ) - torch.testing.assert_close( - fc2_weights[0].linear_in_weight.weight, - fc2_linear_in, - ) - torch.testing.assert_close( - fc2_weights[0].linear_out_weight.weight, - fc2_linear_out, - ) - - -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="No CUDA available in this environment", -) -def test_build_adapter_weights_handles_grouped_qwen35_moe_hf_weights_with_expert_suffix() -> ( - None -): - with _single_rank_model_parallel(): - provider = _make_qwen35_provider() - model = provider.provide_language_model(pre_process=True, post_process=True) - apply_lora_adapters([model], provider) - - target_layer = next( - module - for module in model.modules() - if isinstance(module, TransformerLayer) - and hasattr(module.mlp, "experts") - and isinstance(module.mlp.experts.linear_fc1, MLPExpertsLinearFC1LoRA) - and isinstance(module.mlp.experts.linear_fc2, MLPExpertsLinearFC2LoRA) - ) - fc1_handler = target_layer.mlp.experts.linear_fc1 - fc2_handler = target_layer.mlp.experts.linear_fc2 - - for lora in (fc1_handler.gate_lora, fc1_handler.up_lora, fc2_handler.lora): - lora.A_T.data.fill_(1) - lora.B_T.data.fill_(1) - - adapter_weights_by_base = build_adapter_weights_by_base([model]) - layer_prefix = ( - f"language_model.decoder.layers.{target_layer.layer_number - 1}.mlp.experts" - ) - expert_idx = 0 - fc1_weights = adapter_weights_by_base[ - f"{layer_prefix}.linear_fc1.weight{expert_idx}" - ] - fc2_weights = adapter_weights_by_base[ - f"{layer_prefix}.linear_fc2.weight{expert_idx}" - ] - - assert len(fc1_weights) == 1 - assert len(fc2_weights) == 1 - assert fc1_weights[0].global_base_prefix == f"{layer_prefix}.linear_fc1" - assert fc2_weights[0].global_base_prefix == f"{layer_prefix}.linear_fc2" - - -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="No CUDA available in 
this environment", -) -def test_build_adapter_weights_exposes_qwen35_q_proj_adapter() -> None: - with _single_rank_model_parallel(): - provider = _make_qwen35_provider() - model = provider.provide_language_model(pre_process=True, post_process=True) - apply_lora_adapters([model], provider) - - target_layer = next( - module - for module in model.modules() - if isinstance(module, TransformerLayer) - and isinstance( - getattr(module.self_attention, "linear_qkv", None), - SelfAttentionLinearQKVLoRA, - ) - ) - qkv_handler = target_layer.self_attention.linear_qkv - - for lora in ( - qkv_handler.q_proj_lora, - qkv_handler.k_proj_lora, - qkv_handler.v_proj_lora, - ): - lora.A_T.data.fill_(1) - lora.B_T.data.fill_(1) - - adapter_weights = build_adapter_weights_by_base([model])[ - f"language_model.decoder.layers.{target_layer.layer_number - 1}.self_attention.linear_qkv.weight" - ] - adapter_weights_by_key = { - adapter_weight.adapter_key: adapter_weight - for adapter_weight in adapter_weights - } - - assert set(adapter_weights_by_key) == {"adapter_q", "adapter_k", "adapter_v"} - q_linear_in, q_linear_out = _adapter_tensors(qkv_handler.q_proj_lora) - torch.testing.assert_close( - adapter_weights_by_key["adapter_q"].linear_in_weight.weight, - q_linear_in, - ) - torch.testing.assert_close( - adapter_weights_by_key["adapter_q"].linear_out_weight.weight, - q_linear_out, - ) diff --git a/tests/support/__init__.py b/tests/support/__init__.py new file mode 100644 index 000000000..38361eaf5 --- /dev/null +++ b/tests/support/__init__.py @@ -0,0 +1 @@ +"""Shared test support helpers.""" diff --git a/tests/support/chat_template_conformance_cases.py b/tests/support/chat_template_conformance_cases.py new file mode 100644 index 000000000..b39d8f8d0 --- /dev/null +++ b/tests/support/chat_template_conformance_cases.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import json +from typing import Any, cast + +from openai.types.chat.chat_completion import Choice +from pydantic import BaseModel +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from art.trajectories import History, Trajectory, TrajectoryGroup +from art.types import MessagesAndChoices, Tools + + +def _tool_schema() -> Tools: + return cast( + Tools, + [ + { + "type": "function", + "function": { + "name": "lookup_weather", + "description": "Look up the weather forecast for a city.", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "days": {"type": "integer"}, + }, + "required": ["city", "days"], + }, + }, + } + ], + ) + + +def _tool_call(*, city: str) -> dict[str, Any]: + return { + "id": "call_weather", + "type": "function", + "function": { + "name": "lookup_weather", + "arguments": json.dumps({"city": city, "days": 3}), + }, + } + + +def _tool_message(*, forecast: str) -> dict[str, Any]: + return { + "role": "tool", + "tool_call_id": "call_weather", + "content": json.dumps({"forecast": forecast}), + } + + +def _choice_for_text( + text: str, + token_ids: list[int], + *, + tool_calls: list[dict[str, Any]] | None = None, +) -> Choice: + return Choice.model_validate( + { + "finish_reason": "stop", + "index": 0, + "logprobs": { + "content": [ + { + "token": f"token_id:{token_id}", + "bytes": list(str(token_id).encode("utf-8")), + "logprob": -0.1, + "top_logprobs": [], + } + for token_id in token_ids + ], + "refusal": None, + }, + "message": { + "content": text, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": 
tool_calls or [], + }, + } + ) + + +def _messages_and_choices(*items: Any) -> MessagesAndChoices: + return cast(MessagesAndChoices, list(items)) + + +class ChatTemplateConformanceInputs(BaseModel): + text_pack_group: TrajectoryGroup + non_final_tool_call_base: Trajectory + non_final_tool_call_mutated: Trajectory + tool_conversation_group: TrajectoryGroup + additional_histories_group: TrajectoryGroup + sft_tool_conversation: Trajectory + sft_tool_conversation_mutated: Trajectory + unsupported_assistant_tool_calls: Trajectory + + +def build_chat_template_conformance_inputs( + tokenizer: PreTrainedTokenizerBase, +) -> ChatTemplateConformanceInputs: + maybe_ids = tokenizer.encode("maybe", add_special_tokens=False) + yes_ids = tokenizer.encode("yes", add_special_tokens=False) + lookup_ids = tokenizer.encode("lookup_weather", add_special_tokens=False) + sunny_ids = tokenizer.encode("sunny", add_special_tokens=False) + rainy_ids = tokenizer.encode("rainy", add_special_tokens=False) + prior_yes_ids = tokenizer.encode("prior yes", add_special_tokens=False) + + tools = _tool_schema() + + return ChatTemplateConformanceInputs( + text_pack_group=TrajectoryGroup( + [ + Trajectory( + messages_and_choices=_messages_and_choices( + {"role": "user", "content": "Respond with one word."}, + _choice_for_text("maybe", maybe_ids), + ), + reward=1.0, + ), + Trajectory( + messages_and_choices=_messages_and_choices( + {"role": "user", "content": "Respond with one word."}, + _choice_for_text("yes", yes_ids), + ), + reward=0.0, + ), + ] + ), + non_final_tool_call_base=Trajectory( + messages_and_choices=_messages_and_choices( + {"role": "user", "content": "What is the weather forecast?"}, + _choice_for_text( + "lookup_weather", + lookup_ids, + tool_calls=[_tool_call(city="San Francisco")], + ), + _tool_message(forecast="sunny"), + _choice_for_text("sunny", sunny_ids), + ), + reward=1.0, + tools=tools, + ), + non_final_tool_call_mutated=Trajectory( + messages_and_choices=_messages_and_choices( + {"role": "user", "content": "What is the weather forecast?"}, + _choice_for_text( + "lookup_weather", + lookup_ids, + tool_calls=[_tool_call(city="New York")], + ), + _tool_message(forecast="sunny"), + _choice_for_text("sunny", sunny_ids), + ), + reward=1.0, + tools=tools, + ), + tool_conversation_group=TrajectoryGroup( + [ + Trajectory( + messages_and_choices=_messages_and_choices( + { + "role": "user", + "content": "What is the weather in San Francisco?", + }, + _choice_for_text( + "lookup_weather", + lookup_ids, + tool_calls=[_tool_call(city="San Francisco")], + ), + _tool_message(forecast="sunny"), + _choice_for_text("sunny", sunny_ids), + ), + reward=1.0, + tools=tools, + ), + Trajectory( + messages_and_choices=_messages_and_choices( + { + "role": "user", + "content": "What is the weather in New York?", + }, + _choice_for_text( + "lookup_weather", + lookup_ids, + tool_calls=[_tool_call(city="New York")], + ), + _tool_message(forecast="rainy"), + _choice_for_text("rainy", rainy_ids), + ), + reward=0.0, + tools=tools, + ), + ] + ), + additional_histories_group=TrajectoryGroup( + [ + Trajectory( + messages_and_choices=_messages_and_choices( + {"role": "user", "content": "Answer with one word."}, + _choice_for_text("maybe", maybe_ids), + ), + additional_histories=[ + History( + messages_and_choices=_messages_and_choices( + {"role": "user", "content": "Previous turn."}, + _choice_for_text("prior yes", prior_yes_ids), + ), + ) + ], + reward=1.0, + ), + Trajectory( + messages_and_choices=_messages_and_choices( + {"role": "user", 
"content": "Answer with one word."}, + _choice_for_text("yes", yes_ids), + ), + additional_histories=[ + History( + messages_and_choices=_messages_and_choices( + {"role": "user", "content": "Previous turn."}, + _choice_for_text("prior yes", prior_yes_ids), + ), + ) + ], + reward=0.0, + ), + ] + ), + sft_tool_conversation=Trajectory( + messages_and_choices=_messages_and_choices( + { + "role": "user", + "content": "What is the weather in San Francisco?", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [_tool_call(city="San Francisco")], + }, + _tool_message(forecast="sunny"), + {"role": "assistant", "content": "It will be sunny."}, + ), + tools=tools, + ), + sft_tool_conversation_mutated=Trajectory( + messages_and_choices=_messages_and_choices( + { + "role": "user", + "content": "What is the weather in San Francisco?", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [_tool_call(city="New York")], + }, + _tool_message(forecast="sunny"), + {"role": "assistant", "content": "It will be sunny."}, + ), + tools=tools, + ), + unsupported_assistant_tool_calls=Trajectory( + messages_and_choices=_messages_and_choices( + {"role": "user", "content": "Use the weather tool."}, + { + "role": "assistant", + "content": "", + "tool_calls": [_tool_call(city="San Francisco")], + }, + ), + tools=tools, + ), + ) diff --git a/tests/unit/test_dedicated_config.py b/tests/unit/test_dedicated_config.py index dd9127468..fea4fff84 100644 --- a/tests/unit/test_dedicated_config.py +++ b/tests/unit/test_dedicated_config.py @@ -98,9 +98,7 @@ def test_trainer_not_contiguous(): def test_dedicated_rejects_fast_inference(): - with pytest.raises( - ValueError, match="fast_inference is incompatible with dedicated" - ): + with pytest.raises(ValueError, match="fast_inference is no longer supported"): validate_dedicated_config( InternalModelConfig( trainer_gpu_ids=[0], @@ -123,15 +121,16 @@ def test_dedicated_rejects_enable_sleep_mode(): ) -def test_dedicated_allows_fast_inference_false(): - """fast_inference=False is fine in dedicated mode (it's the intended state).""" - validate_dedicated_config( - InternalModelConfig( - trainer_gpu_ids=[0], - inference_gpu_ids=[1], - init_args={"fast_inference": False}, # type: ignore[typeddict-item] +def test_dedicated_rejects_fast_inference_false(): + """fast_inference config is removed; vLLM always lives in its own runtime.""" + with pytest.raises(ValueError, match="fast_inference is no longer supported"): + validate_dedicated_config( + InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + init_args={"fast_inference": False}, # type: ignore[typeddict-item] + ) ) - ) def test_get_model_config_shared_mode(): @@ -142,7 +141,7 @@ def test_get_model_config_shared_mode(): assert "trainer_gpu_ids" not in result assert "inference_gpu_ids" not in result assert result["engine_args"]["enable_sleep_mode"] is True - assert result["init_args"].get("fast_inference") is False + assert "fast_inference" not in result["init_args"] assert result["rollout_weights_mode"] == "lora" assert result["peft_args"]["target_modules"] == [ "q_proj", @@ -172,9 +171,7 @@ def test_get_model_config_qwen3_5_moe_target_modules(base_model: str): "in_proj_qkv", "in_proj_z", "out_proj", - "gate_proj", - "up_proj", - "down_proj", + "experts", ] @@ -252,18 +249,14 @@ def test_merged_rollout_weights_requires_dedicated_mode(): validate_dedicated_config(InternalModelConfig(rollout_weights_mode="merged")) -def test_qwen3_5_moe_requires_merged_rollout_weights(): - with pytest.raises( - 
ValueError, - match="Qwen3.5-MoE models require rollout_weights_mode='merged'", - ): - validate_dedicated_config( - InternalModelConfig( - trainer_gpu_ids=[0], - inference_gpu_ids=[1], - engine_args={"model": "Qwen/Qwen3.5-35B-A3B"}, # type: ignore[typeddict-item] - ) +def test_qwen3_5_moe_allows_default_lora_rollout_weights(): + validate_dedicated_config( + InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + engine_args={"model": "Qwen/Qwen3.5-35B-A3B"}, # type: ignore[typeddict-item] ) + ) def test_qwen3_5_moe_allows_merged_rollout_weights(): @@ -277,15 +270,11 @@ def test_qwen3_5_moe_allows_merged_rollout_weights(): ) -def test_other_qwen3_5_moe_requires_merged_rollout_weights(): - with pytest.raises( - ValueError, - match="Qwen3.5-MoE models require rollout_weights_mode='merged'", - ): - validate_dedicated_config( - InternalModelConfig( - trainer_gpu_ids=[0], - inference_gpu_ids=[1], - engine_args={"model": "Qwen/Qwen3.5-397B-A17B"}, # type: ignore[typeddict-item] - ) +def test_other_qwen3_5_moe_allows_default_lora_rollout_weights(): + validate_dedicated_config( + InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + engine_args={"model": "Qwen/Qwen3.5-397B-A17B"}, # type: ignore[typeddict-item] ) + ) diff --git a/tests/unit/test_dedicated_server.py b/tests/unit/test_dedicated_server.py deleted file mode 100644 index 11209cef0..000000000 --- a/tests/unit/test_dedicated_server.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Unit tests for dedicated vLLM server entry point.""" - -import pytest - -pytest.importorskip("cloudpickle") -pytest.importorskip("vllm") - -from art.vllm.dedicated_server import _append_cli_arg, parse_args - - -def test_parse_args_required(): - args = parse_args( - [ - "--model", - "Qwen/Qwen3-14B", - "--port", - "8000", - "--cuda-visible-devices", - "1", - "--lora-path", - "/tmp/checkpoints/0000", - "--served-model-name", - "my-model@0", - ] - ) - assert args.model == "Qwen/Qwen3-14B" - assert args.port == 8000 - assert args.cuda_visible_devices == "1" - assert args.lora_path == "/tmp/checkpoints/0000" - assert args.served_model_name == "my-model@0" - assert args.host == "127.0.0.1" - assert args.rollout_weights_mode == "lora" - assert args.engine_args_json == "{}" - assert args.server_args_json == "{}" - - -def test_parse_args_with_engine_args(): - args = parse_args( - [ - "--model", - "test-model", - "--port", - "9000", - "--cuda-visible-devices", - "2", - "--lora-path", - "/tmp/lora", - "--served-model-name", - "test@1", - "--engine-args-json", - '{"max_model_len": 4096}', - ] - ) - assert args.engine_args_json == '{"max_model_len": 4096}' - - -def test_parse_args_custom_host(): - args = parse_args( - [ - "--model", - "test-model", - "--port", - "8000", - "--cuda-visible-devices", - "0", - "--lora-path", - "/tmp/lora", - "--served-model-name", - "test@0", - "--host", - "0.0.0.0", - ] - ) - assert args.host == "0.0.0.0" - - -def test_parse_args_with_server_args(): - args = parse_args( - [ - "--model", - "test-model", - "--port", - "8000", - "--cuda-visible-devices", - "1", - "--lora-path", - "/tmp/lora", - "--served-model-name", - "test@0", - "--server-args-json", - '{"enable_auto_tool_choice": true, "tool_call_parser": "hermes"}', - ] - ) - import json - - server_args = json.loads(args.server_args_json) - assert server_args["enable_auto_tool_choice"] is True - assert server_args["tool_call_parser"] == "hermes" - - -def test_parse_args_merged_mode(): - args = parse_args( - [ - "--model", - "test-model", - "--port", - "8000", - 
"--cuda-visible-devices", - "1", - "--lora-path", - "/tmp/lora", - "--served-model-name", - "test@0", - "--rollout-weights-mode", - "merged", - ] - ) - - assert args.rollout_weights_mode == "merged" - assert args.lora_path == "/tmp/lora" - - -def test_parse_args_requires_lora_path(): - with pytest.raises(SystemExit): - parse_args( - [ - "--model", - "test-model", - "--port", - "8000", - "--cuda-visible-devices", - "1", - "--served-model-name", - "test@0", - ] - ) - - -def test_append_cli_arg_serializes_dict_values(): - args: list[str] = [] - _append_cli_arg(args, "weight_transfer_config", {"backend": "nccl"}) - assert args == ['--weight-transfer-config={"backend": "nccl"}'] diff --git a/tests/unit/test_local_backend_monitor.py b/tests/unit/test_local_backend_monitor.py deleted file mode 100644 index 33f3ae7e1..000000000 --- a/tests/unit/test_local_backend_monitor.py +++ /dev/null @@ -1,141 +0,0 @@ -import asyncio -from pathlib import Path - -import pytest - -from art import TrainableModel -from art.local import LocalBackend - - -class _FakeResponse: - def __init__(self, body: str, status: int = 200) -> None: - self._body = body - self.status = status - - async def __aenter__(self) -> "_FakeResponse": - return self - - async def __aexit__(self, exc_type, exc, tb) -> bool: - return False - - async def text(self) -> str: - return self._body - - -class _FakeSession: - def __init__(self, urls: list[str]) -> None: - self._urls = urls - - async def __aenter__(self) -> "_FakeSession": - return self - - async def __aexit__(self, exc_type, exc, tb) -> bool: - return False - - def get(self, url: str, timeout) -> _FakeResponse: - del timeout - self._urls.append(url) - if url.endswith("/metrics"): - return _FakeResponse( - "vllm:num_requests_running 0\nvllm:num_requests_waiting 0\n" - ) - if url.endswith("/health"): - return _FakeResponse("ok") - raise AssertionError(f"Unexpected URL: {url}") - - -@pytest.mark.asyncio -async def test_monitor_openai_server_uses_health_probe_when_idle( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - backend = LocalBackend(path=str(tmp_path)) - model = TrainableModel( - name="qwen35-monitor", - project="unit-tests", - base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", - base_path=str(tmp_path), - ) - - class _FakeService: - async def vllm_engine_is_sleeping(self) -> bool: - return False - - backend._services[model.name] = _FakeService() # type: ignore[index] - requested_urls: list[str] = [] - sleep_calls = 0 - - async def fake_sleep(_seconds: float) -> None: - nonlocal sleep_calls - sleep_calls += 1 - if sleep_calls > 1: - raise asyncio.CancelledError - - monkeypatch.setattr("art.local.backend.asyncio.sleep", fake_sleep) - monkeypatch.setattr( - "art.local.backend.aiohttp.ClientSession", - lambda: _FakeSession(requested_urls), - ) - - with pytest.raises(asyncio.CancelledError): - await backend._monitor_openai_server( - model, - "http://127.0.0.1:1234/v1", - "default", - ) - - assert requested_urls == [ - "http://127.0.0.1:1234/metrics", - "http://127.0.0.1:1234/health", - ] - - -@pytest.mark.asyncio -async def test_close_cancels_monitor_tasks( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Monitor tasks should be cancelled during close() to avoid - ConnectionRefusedError after vLLM shuts down.""" - backend = LocalBackend(path=str(tmp_path)) - - class _FakeService: - aclose_called = False - - async def aclose(self) -> None: - self.aclose_called = True - - async def vllm_engine_is_sleeping(self) -> bool: - return False - - service = 
_FakeService() - backend._services["test-model"] = service # type: ignore[index] - - async def fake_sleep(_seconds: float) -> None: - await asyncio.sleep(0) # yield control - - monkeypatch.setattr("art.local.backend.asyncio.sleep", fake_sleep) - monkeypatch.setattr( - "art.local.backend.aiohttp.ClientSession", - lambda: _FakeSession([]), - ) - - model = TrainableModel( - name="test-model", - project="unit-tests", - base_model="test/model", - base_path=str(tmp_path), - ) - - task = asyncio.create_task( - backend._monitor_openai_server(model, "http://127.0.0.1:1234/v1", "default") - ) - backend._monitor_tasks["test-model"] = task - - # Let the monitor run one iteration - await asyncio.sleep(0) - - await backend.close() - - assert task.cancelled() or task.done() - assert len(backend._monitor_tasks) == 0 diff --git a/tests/unit/test_megatron_dedicated.py b/tests/unit/test_megatron_dedicated.py deleted file mode 100644 index dcc56aa2d..000000000 --- a/tests/unit/test_megatron_dedicated.py +++ /dev/null @@ -1,505 +0,0 @@ -import asyncio -from contextlib import nullcontext -import os -from pathlib import Path -import shlex -import sys -import types as pytypes -from types import SimpleNamespace -from typing import Any - -import pytest -import torch - -pytest.importorskip("vllm") - -from art import TrainableModel, types -from art.dev.model import InternalModelConfig -from art.dev.validate import QWEN3_5_MOE_MODELS -from art.megatron.backend import MegatronBackend -from art.megatron.jobs import ( - MegatronMergedTrainJob, - MergedWeightTransferInitInfo, -) -from art.megatron.service import MegatronService, create_identity_lora -from art.megatron.train import _compile_enabled, _unwrap_art_wrapper_name - - -@pytest.mark.asyncio -async def test_megatron_backend_dedicated_uses_trainer_gpus_without_child_process( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - config = InternalModelConfig( - trainer_gpu_ids=[0], - inference_gpu_ids=[1], - rollout_weights_mode="lora", - ) - model = TrainableModel( - name="megatron-dedicated", - project="unit-tests", - base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", - base_path=str(tmp_path), - _internal_config=config, - ) - backend = MegatronBackend(path=str(tmp_path)) - validated: dict[str, Any] = {} - - class FakeService: - def __init__( - self, - *, - model_name: str, - base_model: str, - config: InternalModelConfig, - output_dir: str, - ) -> None: - self.model_name = model_name - self.base_model = base_model - self.config = config - self.output_dir = output_dir - - monkeypatch.setattr( - "art.dev.get_model_config.get_model_config", - lambda *args, **kwargs: config, - ) - monkeypatch.setattr( - "art.dev.validate.validate_dedicated_config", - lambda cfg: validated.setdefault("config", cfg), - ) - monkeypatch.setattr( - "art.megatron.backend.move_to_child_process", - lambda *args, **kwargs: (_ for _ in ()).throw( - AssertionError( - "Dedicated Megatron service should not move to a child process" - ) - ), - ) - monkeypatch.setattr("art.megatron.service.MegatronService", FakeService) - - service = await backend._get_service(model) - - assert isinstance(service, FakeService) - assert validated["config"] is config - assert os.environ["CUDA_VISIBLE_DEVICES"] == "0" - - -@pytest.mark.asyncio -async def test_megatron_service_ensure_megatron_running_uses_trainer_gpus( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - service = MegatronService( - model_name="megatron-dedicated", - base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", - 
config=InternalModelConfig( - trainer_gpu_ids=[0, 1], - inference_gpu_ids=[2], - rollout_weights_mode="lora", - ), - output_dir=str(tmp_path), - ) - megatron_module = pytypes.ModuleType("megatron") - megatron_bridge_module = pytypes.ModuleType("megatron.bridge") - monkeypatch.setitem(sys.modules, "megatron", megatron_module) - monkeypatch.setitem(sys.modules, "megatron.bridge", megatron_bridge_module) - - seen: dict[str, Any] = {} - - monkeypatch.setattr( - "art.megatron.service.subprocess.run", lambda *args, **kwargs: None - ) - - async def fake_create_subprocess_shell( - command: str, - cwd: str, - env: dict[str, str], - start_new_session: bool, - ) -> Any: - seen["command"] = command - seen["cwd"] = cwd - seen["env"] = env - seen["start_new_session"] = start_new_session - return pytypes.SimpleNamespace(returncode=None) - - monkeypatch.setattr( - "art.megatron.service.asyncio.create_subprocess_shell", - fake_create_subprocess_shell, - ) - - await service._ensure_megatron_running() - - assert "uv run --project" in seen["command"] - assert "torchrun" in seen["command"] - assert "--nproc_per_node 2" in seen["command"] - assert seen["env"]["CUDA_VISIBLE_DEVICES"] == "0,1" - assert seen["env"]["MODEL_IDENTIFIER"] == "Qwen/Qwen3-30B-A3B-Instruct-2507" - assert seen["start_new_session"] is True - - -def test_unwrap_art_wrapper_name_strips_compiled_wrapper_segments() -> None: - assert ( - _unwrap_art_wrapper_name( - "module.module.decoder.layers.0._orig_mod.self_attention.linear_proj.linear_proj.weight" - ) - == "decoder.layers.0.self_attention.linear_proj.weight" - ) - assert ( - _unwrap_art_wrapper_name( - "module.module.decoder.layers.0._orig_mod.mlp.experts.linear_fc1.linear_fc1.weight7" - ) - == "decoder.layers.0.mlp.experts.linear_fc1.weight7" - ) - - -def test_compile_enabled_disables_qwen35_moe_by_default() -> None: - assert _compile_enabled("Qwen/Qwen3-30B-A3B-Instruct-2507") is True - assert _compile_enabled("Qwen/Qwen3.5-32B-Instruct") is True - for model_identifier in QWEN3_5_MOE_MODELS: - assert _compile_enabled(model_identifier) is False - - -def test_compile_enabled_honors_env_override( - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setenv("ART_DISABLE_MEGATRON_COMPILE", "0") - assert _compile_enabled("Qwen/Qwen3.5-35B-A3B") is True - monkeypatch.setenv("ART_DISABLE_MEGATRON_COMPILE", "1") - assert _compile_enabled("Qwen/Qwen3-30B-A3B-Instruct-2507") is False - - -def test_create_identity_lora_uses_nested_text_config_when_top_level_lacks_vocab_size( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - top_level_config = SimpleNamespace( - text_config=SimpleNamespace(vocab_size=128), - ) - seen: dict[str, Any] = {} - - class FakeModel: - name_or_path = "" - - def named_parameters(self) -> list[tuple[str, torch.Tensor]]: - return [ - ( - "model.layers.0.self_attn.q_proj.weight", - torch.empty(1, device="meta"), - ), - ( - "model.layers.0.linear_attn.in_proj_qkv.weight", - torch.empty(1, device="meta"), - ), - ] - - class FakePeftModel: - def save_pretrained(self, lora_path: str) -> None: - Path(lora_path).mkdir(parents=True, exist_ok=True) - - monkeypatch.setattr( - "transformers.AutoConfig.from_pretrained", - lambda *_args, **_kwargs: top_level_config, - ) - monkeypatch.setattr( - "transformers.AutoModelForCausalLM.from_config", - lambda config, **_kwargs: seen.setdefault("config", config) or FakeModel(), - ) - monkeypatch.setattr("accelerate.init_empty_weights", nullcontext) - monkeypatch.setattr( - "peft.get_peft_model", - lambda _model, lora_config, 
**_kwargs: ( - seen.setdefault("lora_config", lora_config) or FakePeftModel() - ), - ) - monkeypatch.setattr( - "art.megatron.service.convert_checkpoint_if_needed", - lambda _path: None, - ) - - create_identity_lora("Qwen/Qwen3.5-35B-A3B", str(tmp_path)) - - assert seen["config"] is top_level_config.text_config - assert ( - "model.layers.0.linear_attn.in_proj_qkv.weight" - in seen["lora_config"].target_parameters - ) - - -@pytest.mark.asyncio -async def test_megatron_service_start_openai_server_dedicated_starts_subprocess( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - checkpoint_dir = tmp_path / "checkpoints" / "0000" - checkpoint_dir.mkdir(parents=True) - service = MegatronService( - model_name="megatron-dedicated", - base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", - config=InternalModelConfig( - trainer_gpu_ids=[0], - inference_gpu_ids=[1], - rollout_weights_mode="lora", - ), - output_dir=str(tmp_path), - ) - seen: dict[str, Any] = {} - - monkeypatch.setattr( - "art.megatron.service.get_last_checkpoint_dir", - lambda _output_dir: str(checkpoint_dir), - ) - monkeypatch.setattr(service, "_ensure_identity_lora", lambda _path: None) - monkeypatch.setattr( - service, "_ensure_lora_adapter_config", lambda _path, source_path=None: None - ) - - async def fake_start_vllm_subprocess( - lora_path: str, - port: int, - config: dict[str, Any] | None, - ) -> tuple[str, int]: - seen["lora_path"] = lora_path - seen["port"] = port - seen["config"] = config - return ("127.0.0.1", port) - - monkeypatch.setattr(service, "_start_vllm_subprocess", fake_start_vllm_subprocess) - - location = await service.start_openai_server({"server_args": {"port": 8123}}) - - assert location == ("127.0.0.1", 8123) - assert seen["lora_path"] == str(checkpoint_dir) - assert seen["port"] == 8123 - - -@pytest.mark.asyncio -async def test_megatron_service_register_lora_for_step_dedicated_reloads_adapter( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - service = MegatronService( - model_name="megatron-dedicated", - base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", - config=InternalModelConfig( - trainer_gpu_ids=[0], - inference_gpu_ids=[1], - rollout_weights_mode="lora", - ), - output_dir=str(tmp_path), - ) - seen: list[tuple[str, int]] = [] - - monkeypatch.setattr( - service, - "_reload_adapter", - lambda checkpoint_dir, step: ( - seen.append((checkpoint_dir, step)) or asyncio.sleep(0) - ), - ) - - await service.register_lora_for_step(3, "/tmp/checkpoints/3") - - assert seen == [("/tmp/checkpoints/3", 3)] - - -@pytest.mark.asyncio -async def test_megatron_service_start_openai_server_merged_syncs_step_zero( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - checkpoint_dir = tmp_path / "checkpoints" / "0000" - checkpoint_dir.mkdir(parents=True) - service = MegatronService( - model_name="megatron-merged", - base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", - config=InternalModelConfig( - trainer_gpu_ids=[0], - inference_gpu_ids=[1], - rollout_weights_mode="merged", - ), - output_dir=str(tmp_path), - ) - calls: list[object] = [] - ensured_identity_paths: list[str] = [] - - monkeypatch.setattr( - "art.megatron.service.get_last_checkpoint_dir", - lambda _output_dir: str(checkpoint_dir), - ) - monkeypatch.setattr( - service, - "_ensure_identity_lora", - lambda path: ensured_identity_paths.append(path), - ) - monkeypatch.setattr(service, "_ensure_lora_adapter_config", lambda _path: None) - monkeypatch.setattr( - service, - "_start_vllm_subprocess", - lambda lora_path, port, config: 
asyncio.sleep(0, result=("127.0.0.1", port)), - ) - monkeypatch.setattr(service, "_clear_pending_jobs", lambda: calls.append("clear")) - monkeypatch.setattr( - service, - "_sync_dedicated_merged_weights", - lambda *, lora_path, step: calls.append((lora_path, step)) or asyncio.sleep(0), - ) - - location = await service.start_openai_server({"server_args": {"port": 8123}}) - - assert location == ("127.0.0.1", 8123) - assert ensured_identity_paths == [str(checkpoint_dir)] - assert calls == ["clear", (str(checkpoint_dir), 0)] - - -@pytest.mark.asyncio -async def test_megatron_service_start_openai_server_shared_lora_bootstraps_step_zero( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - checkpoint_dir = tmp_path / "checkpoints" / "0000" - checkpoint_dir.mkdir(parents=True) - service = MegatronService( - model_name="megatron-shared", - base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", - config=InternalModelConfig(), - output_dir=str(tmp_path), - ) - ensured_identity_paths: list[str] = [] - - monkeypatch.setattr( - "art.megatron.service.get_last_checkpoint_dir", - lambda _output_dir: str(checkpoint_dir), - ) - monkeypatch.setattr( - service, - "_ensure_identity_lora", - lambda path: ensured_identity_paths.append(path), - ) - monkeypatch.setattr(service, "_ensure_lora_adapter_config", lambda _path: None) - monkeypatch.setattr( - "art.megatron.service.dev.get_openai_server_config", - lambda **_kwargs: {"server_args": {"port": 8123}, "engine_args": {}}, - ) - monkeypatch.setattr( - "art.megatron.service.openai_server_task", - lambda **_kwargs: asyncio.sleep(0), - ) - - location = await service.start_openai_server({"server_args": {"port": 8123}}) - - assert location == ("0.0.0.0", 8123) - assert ensured_identity_paths == [str(checkpoint_dir)] - - -@pytest.mark.asyncio -async def test_megatron_service_register_lora_for_step_merged_sets_served_name( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - service = MegatronService( - model_name="megatron-merged", - base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", - config=InternalModelConfig( - trainer_gpu_ids=[0], - inference_gpu_ids=[1], - rollout_weights_mode="merged", - ), - output_dir=str(tmp_path), - ) - calls: list[int] = [] - - monkeypatch.setattr( - service, - "_set_served_model_name", - lambda step: calls.append(step) or asyncio.sleep(0), - ) - - await service.register_lora_for_step(3, "/tmp/checkpoints/3") - - assert calls == [3] - - -@pytest.mark.asyncio -async def test_megatron_service_train_merged_writes_merged_job_and_does_not_reload_adapter( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - checkpoint_dir = tmp_path / "checkpoints" / "0000" - checkpoint_dir.mkdir(parents=True) - adapter_path = checkpoint_dir / "adapter_model.safetensors" - adapter_path.write_bytes(b"adapter") - service = MegatronService( - model_name="megatron-merged", - base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", - config=InternalModelConfig( - trainer_gpu_ids=[0], - inference_gpu_ids=[1], - rollout_weights_mode="merged", - ), - output_dir=str(tmp_path), - ) - events: list[Any] = [] - - monkeypatch.setattr( - service, - "_ensure_megatron_running", - lambda: events.append("ensure") or asyncio.sleep(0), - ) - monkeypatch.setattr( - service, "_resolve_active_lora_path", lambda: str(checkpoint_dir) - ) - monkeypatch.setattr( - service, - "_init_merged_weight_transfer", - lambda: events.append("init") or asyncio.sleep(0), - ) - monkeypatch.setattr( - "art.megatron.service.create_megatron_job_paths", - lambda *, jobs_dir, 
training_log_dir: ("/tmp/job.json", "/tmp/log.jsonl"), - ) - monkeypatch.setattr( - "art.megatron.service.write_megatron_job", - lambda job, *, job_path: events.append(job), - ) - - async def fake_stream_megatron_job(job: Any, *, job_path: str): - events.append(("stream", job_path, job.lora_path)) - yield {"loss": 1.0} - - monkeypatch.setattr( - "art.megatron.service.stream_megatron_job", - fake_stream_megatron_job, - ) - monkeypatch.setattr( - service, - "_ensure_lora_adapter_config", - lambda _path, source_path=None: None, - ) - monkeypatch.setattr( - service, - "_reload_adapter", - lambda checkpoint_dir, step: (_ for _ in ()).throw( - AssertionError("merged mode should not hot-reload a LoRA adapter") - ), - ) - service._merged_weight_transfer_init_info = MergedWeightTransferInitInfo( - master_address="127.0.0.1", - master_port=1234, - rank_offset=1, - world_size=2, - ) - - results = [] - async for result in service.train( - {"dir": "/tmp/tensors", "num_sequences": 1, "sequence_length": 16}, - types.TrainConfig(learning_rate=5e-5), - {}, - ): - results.append(result) - - assert results == [{"loss": 1.0}] - assert events[0:2] == ["ensure", "init"] - job = events[2] - assert isinstance(job, MegatronMergedTrainJob) - assert job.merged_weight_transfer.served_model_name == "megatron-merged@1" - assert events[3] == ("stream", "/tmp/job.json", str(checkpoint_dir)) diff --git a/tests/unit/test_megatron_provider_env.py b/tests/unit/test_megatron_provider_env.py deleted file mode 100644 index 130f86737..000000000 --- a/tests/unit/test_megatron_provider_env.py +++ /dev/null @@ -1,164 +0,0 @@ -from importlib.util import module_from_spec, spec_from_file_location -from pathlib import Path -import sys -from types import ModuleType, SimpleNamespace - -import pytest - - -def _ensure_package(name: str) -> ModuleType: - module = sys.modules.get(name) - if module is None: - module = ModuleType(name) - module.__path__ = [] # type: ignore[attr-defined] - sys.modules[name] = module - return module - - -def _stub_module(name: str, **attrs: object) -> ModuleType: - module = sys.modules.get(name) - if module is None: - module = ModuleType(name) - sys.modules[name] = module - for key, value in attrs.items(): - setattr(module, key, value) - return module - - -def _load_provider_module() -> ModuleType: - module_name = "art_megatron_provider_under_test" - existing = sys.modules.get(module_name) - if existing is not None: - return existing - - for package_name in [ - "art", - "art.megatron", - "megatron", - "megatron.bridge", - "megatron.bridge.models", - "megatron.bridge.models.hf_pretrained", - "megatron.bridge.models.qwen", - "megatron.bridge.training", - "megatron.core", - "megatron.core.transformer", - ]: - _ensure_package(package_name) - - class _StateSource: - pass - - class _ModuleSpec: - pass - - class _GPTModelProvider: - pass - - class _AutoBridge: - @staticmethod - def from_hf_pretrained(*args: object, **kwargs: object) -> object: - raise NotImplementedError - - _stub_module("art.megatron.flex_attention", FlexDotProductAttention=object) - _stub_module("megatron.bridge", AutoBridge=_AutoBridge) - _stub_module( - "megatron.bridge.models.gpt_provider", - GPTModelProvider=_GPTModelProvider, - ) - _stub_module( - "megatron.bridge.models.hf_pretrained.state", - SafeTensorsStateSource=object, - StateDict=object, - StateSource=_StateSource, - ) - _stub_module( - "megatron.bridge.models.qwen.qwen3_moe_bridge", - Qwen3MoEBridge=object, - ) - _stub_module( - "megatron.bridge.training.flex_dispatcher_backend", - 
apply_flex_dispatcher_backend=lambda *args, **kwargs: None, - ) - _stub_module( - "megatron.core.transformer.enums", - AttnBackend=SimpleNamespace(auto="auto"), - ) - _stub_module("megatron.core.transformer.spec_utils", ModuleSpec=_ModuleSpec) - - provider_path = Path(__file__).resolve().parents[2] / "src/art/megatron/provider.py" - spec = spec_from_file_location(module_name, provider_path) - assert spec is not None and spec.loader is not None - module = module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - -megatron_provider = _load_provider_module() - - -def _fake_provider() -> SimpleNamespace: - return SimpleNamespace( - overlap_moe_expert_parallel_comm=False, - delay_wgrad_compute=False, - ep_overlap_early_attn_memory_release=False, - moe_deepep_num_sms=None, - moe_apply_probs_on_input=False, - bias_activation_fusion=False, - fine_grained_activation_offloading=False, - offload_modules=[], - tensor_model_parallel_size=1, - recompute_granularity="full", - recompute_method="uniform", - recompute_num_layers=1, - recompute_modules=None, - moe_shared_expert_overlap=True, - ) - - -def test_apply_runtime_env_overrides_warns_on_shared_expert_overlap_conflict( - monkeypatch: pytest.MonkeyPatch, -) -> None: - provider = _fake_provider() - monkeypatch.setenv("ART_MEGATRON_OVERLAP_MOE_EXPERT_PARALLEL_COMM", "true") - monkeypatch.setenv("ART_MEGATRON_MOE_SHARED_EXPERT_OVERLAP", "true") - monkeypatch.setattr( - megatron_provider, "_resolve_default_deepep_num_sms", lambda _: 20 - ) - - with pytest.warns(UserWarning, match="moe_shared_expert_overlap=False"): - megatron_provider._apply_runtime_env_overrides(provider) - - assert provider.moe_shared_expert_overlap is False - - -@pytest.mark.parametrize("raw_value", [None, "default"]) -def test_apply_runtime_env_overrides_uses_resolved_default_num_sms( - monkeypatch: pytest.MonkeyPatch, - raw_value: str | None, -) -> None: - provider = _fake_provider() - monkeypatch.setattr( - megatron_provider, "_resolve_default_deepep_num_sms", lambda _: 42 - ) - if raw_value is None: - monkeypatch.delenv("ART_MEGATRON_MOE_DEEPEP_NUM_SMS", raising=False) - else: - monkeypatch.setenv("ART_MEGATRON_MOE_DEEPEP_NUM_SMS", raw_value) - - megatron_provider._apply_runtime_env_overrides(provider) - - assert provider.moe_deepep_num_sms == 42 - - -@pytest.mark.parametrize("raw_value", ["none", "3", "0", "-2"]) -def test_env_default_or_even_positive_int_rejects_invalid_values( - monkeypatch: pytest.MonkeyPatch, - raw_value: str, -) -> None: - monkeypatch.setenv("ART_MEGATRON_MOE_DEEPEP_NUM_SMS", raw_value) - - with pytest.raises(ValueError, match="positive, even integer"): - megatron_provider._env_default_or_even_positive_int( - "ART_MEGATRON_MOE_DEEPEP_NUM_SMS" - ) diff --git a/tests/unit/test_megatron_qwen_helpers.py b/tests/unit/test_megatron_qwen_helpers.py deleted file mode 100644 index 0c0b77829..000000000 --- a/tests/unit/test_megatron_qwen_helpers.py +++ /dev/null @@ -1,59 +0,0 @@ -from types import SimpleNamespace -from typing import Any, cast - -import pytest - -pytest.importorskip("megatron.bridge") - -import torch - -from art.megatron.lora import SelfAttentionLinearQKVLoRA -from art.megatron.train import _canonical_art_param_name - - -def test_canonical_art_param_name_strips_art_wrapper_segments() -> None: - assert ( - _canonical_art_param_name( - "module.language_model.decoder.layers.0.self_attention.out_proj.linear_proj.weight" - ) - == "language_model.decoder.layers.0.self_attention.out_proj.weight" - ) - 
assert ( - _canonical_art_param_name( - "module.language_model.decoder.layers.0.mlp.linear_fc2.row_parallel_lora.linear_proj.weight" - ) - == "language_model.decoder.layers.0.mlp.linear_fc2.weight" - ) - - -def test_self_attention_linear_qkv_lora_accepts_nongated_qwen3_layout( - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setattr( - "art.megatron.lora.ps.get_tensor_model_parallel_world_size", lambda: 1 - ) - provider: Any = SimpleNamespace( - kv_channels=128, - num_query_groups=4, - num_attention_heads=32, - attention_output_gate=False, - ) - q_out_features = provider.kv_channels * provider.num_attention_heads - kv_out_features = provider.kv_channels * provider.num_query_groups - linear_qkv: Any = SimpleNamespace( - weight=torch.empty(q_out_features + 2 * kv_out_features, 16), - in_features=16, - return_layernorm_output=False, - return_layernorm_output_gathered=False, - ) - - wrapped = SelfAttentionLinearQKVLoRA( - adapter_model_prefix="base_model.model.model.layers.0.self_attn", - linear_qkv=cast(Any, linear_qkv), - rank=4, - alpha=8.0, - provider=cast(Any, provider), - ) - - assert wrapped.attention_output_gate is False - assert wrapped.q_proj_lora.B_T.shape[-1] == q_out_features diff --git a/tests/unit/test_metric_routing.py b/tests/unit/test_metric_routing.py index 7cb57d912..fb5eb65ae 100644 --- a/tests/unit/test_metric_routing.py +++ b/tests/unit/test_metric_routing.py @@ -60,7 +60,7 @@ def test_get_wandb_run_registers_taxonomy_sections(self, tmp_path: Path) -> None assert run is fake_run define_calls = [ - (call.args, call.kwargs) for call in fake_wandb.define_metric.call_args_list + (call.args, call.kwargs) for call in fake_run.define_metric.call_args_list ] assert define_calls == [ (("training_step",), {}), @@ -107,7 +107,7 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step( ) define_calls = [ - (call.args, call.kwargs) for call in fake_wandb.define_metric.call_args_list + (call.args, call.kwargs) for call in fake_run.define_metric.call_args_list ] assert ( ("costs/train/sample",), diff --git a/tests/unit/test_moe_routing_replay.py b/tests/unit/test_moe_routing_replay.py index be51ab325..a43a701a1 100644 --- a/tests/unit/test_moe_routing_replay.py +++ b/tests/unit/test_moe_routing_replay.py @@ -15,6 +15,8 @@ RouterCallRoute, StepRouterRoutes, StepRoutes, + TopologyAwareLocalTokenIndexer, + build_router_key_from_module_name, ) @@ -37,6 +39,11 @@ def _dense_from_compact( return probs, routing_map +def _assert_probs_close(actual: torch.Tensor, expected: torch.Tensor) -> None: + max_diff = (actual - expected).abs().max().item() + assert max_diff < 1e-6 + + def _make_bundle() -> tuple[MoeRoutingReplayBundle, RouterCallRoute]: router_key = "chunk_00.layer_0000.mlp.router" route = RouterCallRoute( @@ -84,6 +91,77 @@ def _make_bundle() -> tuple[MoeRoutingReplayBundle, RouterCallRoute]: return bundle, route +def _make_sampled_bundle() -> MoeRoutingReplayBundle: + router_key = "chunk_00.layer_0000.mlp.router" + route0 = RouterCallRoute( + expert_indices=torch.tensor([[0, 2], [1, 0]], dtype=torch.int32), + expert_probs=torch.tensor([[0.70, 0.30], [1.00, 0.00]], dtype=torch.float32), + expert_mask=torch.tensor([[True, True], [True, False]], dtype=torch.bool), + num_experts=3, + sample_index=0, + ) + route1 = RouterCallRoute( + expert_indices=torch.tensor([[2, 1], [0, 1]], dtype=torch.int32), + expert_probs=torch.tensor([[0.60, 0.40], [1.00, 0.00]], dtype=torch.float32), + expert_mask=torch.tensor([[True, True], [True, False]], dtype=torch.bool), + 
num_experts=3, + sample_index=1, + ) + return MoeRoutingReplayBundle( + topology=ParallelTopology(tp=1, ep=1, etp=1, dp=1, sp=False, cp=1, pp=1, vpp=1), + num_steps=1, + max_topk=2, + router_keys=[router_key], + steps={ + 0: StepRoutes( + routers={router_key: StepRouterRoutes(calls={0: route0, 1: route1})}, + global_token_uids=torch.arange(2, dtype=torch.int64), + ) + }, + ) + + +def _make_multi_call_bundle() -> MoeRoutingReplayBundle: + router_key = "chunk_00.layer_0000.mlp.router" + route0 = RouterCallRoute( + expert_indices=torch.tensor([[0, 2]], dtype=torch.int32), + expert_probs=torch.tensor([[0.70, 0.30]], dtype=torch.float32), + expert_mask=torch.tensor([[True, True]], dtype=torch.bool), + num_experts=3, + sample_index=0, + ) + route1 = RouterCallRoute( + expert_indices=torch.tensor([[1, 0]], dtype=torch.int32), + expert_probs=torch.tensor([[1.00, 0.00]], dtype=torch.float32), + expert_mask=torch.tensor([[True, False]], dtype=torch.bool), + num_experts=3, + sample_index=0, + ) + route2 = RouterCallRoute( + expert_indices=torch.tensor([[2, 1]], dtype=torch.int32), + expert_probs=torch.tensor([[0.55, 0.45]], dtype=torch.float32), + expert_mask=torch.tensor([[True, True]], dtype=torch.bool), + num_experts=3, + sample_index=1, + ) + return MoeRoutingReplayBundle( + topology=ParallelTopology(tp=1, ep=1, etp=1, dp=1, sp=False, cp=1, pp=1, vpp=1), + num_steps=1, + max_topk=2, + router_keys=[router_key], + steps={ + 0: StepRoutes( + routers={ + router_key: StepRouterRoutes( + calls={0: route0, 1: route1, 2: route2} + ) + }, + global_token_uids=torch.arange(1, dtype=torch.int64), + ) + }, + ) + + class _IdentityIndexer: def build_local_token_uids( self, @@ -99,6 +177,28 @@ def build_local_token_uids( return global_token_uids[:num_local_tokens].clone() +class _FakeParallelState: + def __init__( + self, + *, + tp_world_size: int = 1, + tp_rank: int = 0, + cp_world_size: int = 1, + ) -> None: + self._tp_world_size = tp_world_size + self._tp_rank = tp_rank + self._cp_world_size = cp_world_size + + def get_context_parallel_world_size(self) -> int: + return self._cp_world_size + + def get_tensor_model_parallel_world_size(self) -> int: + return self._tp_world_size + + def get_tensor_model_parallel_rank(self) -> int: + return self._tp_rank + + class _FakeRouter(nn.Module): def __init__(self) -> None: super().__init__() @@ -138,6 +238,60 @@ def __init__(self) -> None: self.decoder = _FakeDecoder() +def test_build_router_key_from_compiled_module_name() -> None: + assert ( + build_router_key_from_module_name( + chunk_index=0, + module_name="module.decoder.layers.0._orig_mod.mlp.router", + ) + == "chunk_00.layer_0000.mlp.router" + ) + + +def test_build_router_key_from_nested_compiled_module_name() -> None: + assert ( + build_router_key_from_module_name( + chunk_index=3, + module_name="module.decoder.layers.12.mlp._orig_mod.router", + ) + == "chunk_03.layer_0012.mlp.router" + ) + + +def test_topology_aware_local_token_indexer_keeps_merged_rows_when_counts_match() -> ( + None +): + indexer = TopologyAwareLocalTokenIndexer( + parallel_state_module=_FakeParallelState(tp_world_size=2, tp_rank=1) + ) + global_token_uids = torch.arange(256, dtype=torch.int64) + + local_uids = indexer.build_local_token_uids( + global_token_uids=global_token_uids, + num_local_tokens=256, + sequence_parallel=True, + context_parallel_size=1, + ) + + assert torch.equal(local_uids, global_token_uids) + + +def test_topology_aware_local_token_indexer_slices_sequence_parallel_rows() -> None: + indexer = TopologyAwareLocalTokenIndexer( 
+ parallel_state_module=_FakeParallelState(tp_world_size=2, tp_rank=1) + ) + global_token_uids = torch.arange(256, dtype=torch.int64) + + local_uids = indexer.build_local_token_uids( + global_token_uids=global_token_uids, + num_local_tokens=128, + sequence_parallel=True, + context_parallel_size=1, + ) + + assert torch.equal(local_uids, torch.arange(128, 256, dtype=torch.int64)) + + def test_bundle_roundtrip_disk() -> None: bundle, route = _make_bundle() with tempfile.TemporaryDirectory() as tmp_dir: @@ -174,7 +328,7 @@ def test_controller_patches_router_and_replays() -> None: expected_probs, expected_map = _dense_from_compact(route, dtype=logits.dtype) assert torch.equal(replay_map.cpu(), expected_map) - assert torch.allclose(replay_probs.cpu(), expected_probs, atol=0.0, rtol=0.0) + _assert_probs_close(replay_probs.cpu(), expected_probs) controller.finalize_step() controller.remove_router_patches() @@ -192,3 +346,92 @@ def test_controller_finalize_fails_when_unconsumed_calls_remain() -> None: controller.set_step(step_index=0, sample_index=0) with pytest.raises(RuntimeError, match="consumption mismatch"): controller.finalize_step() + + +def test_controller_reuses_route_for_recompute_with_same_active_micro() -> None: + bundle = _make_sampled_bundle() + controller = MoeRoutingReplayController( + bundle=bundle, + strict=True, + local_token_indexer=_IdentityIndexer(), + ) + chunk = _FakeChunk() + controller.install_router_patches([chunk]) + controller.set_step(step_index=0, sample_index=[0, 1]) + router = cast( + _FakeRouter, + chunk.decoder.layers[0].mlp.router, # ty: ignore[possibly-missing-attribute] + ) + logits = torch.randn((2, 3), dtype=torch.float32) + + controller.begin_micro(0, 0) + first_probs, first_map = router.routing(logits) + recompute_probs, recompute_map = router.routing(logits) + controller.begin_micro(1, 1) + second_probs, second_map = router.routing(logits) + + expected_first_probs, expected_first_map = _dense_from_compact( + bundle.steps[0].routers[bundle.router_keys[0]].calls[0], + dtype=logits.dtype, + ) + expected_second_probs, expected_second_map = _dense_from_compact( + bundle.steps[0].routers[bundle.router_keys[0]].calls[1], + dtype=logits.dtype, + ) + + assert torch.equal(first_map.cpu(), expected_first_map) + _assert_probs_close(first_probs.cpu(), expected_first_probs) + assert torch.equal(recompute_map.cpu(), expected_first_map) + _assert_probs_close(recompute_probs.cpu(), expected_first_probs) + assert torch.equal(second_map.cpu(), expected_second_map) + _assert_probs_close(second_probs.cpu(), expected_second_probs) + + controller.finalize_step() + controller.remove_router_patches() + + +def test_controller_consumes_multiple_captured_calls_before_recompute_reuse() -> None: + bundle = _make_multi_call_bundle() + controller = MoeRoutingReplayController( + bundle=bundle, + strict=True, + local_token_indexer=_IdentityIndexer(), + ) + chunk = _FakeChunk() + controller.install_router_patches([chunk]) + controller.set_step(step_index=0, sample_index=[0, 1]) + router = cast( + _FakeRouter, + chunk.decoder.layers[0].mlp.router, # ty: ignore[possibly-missing-attribute] + ) + logits = torch.randn((1, 3), dtype=torch.float32) + + controller.begin_micro(0, 0) + first_probs, first_map = router.routing(logits) + second_probs, second_map = router.routing(logits) + recompute_probs, recompute_map = router.routing(logits) + controller.begin_micro(1, 1) + next_probs, next_map = router.routing(logits) + + calls = bundle.steps[0].routers[bundle.router_keys[0]].calls + 
expected_first_probs, expected_first_map = _dense_from_compact( + calls[0], dtype=logits.dtype + ) + expected_second_probs, expected_second_map = _dense_from_compact( + calls[1], dtype=logits.dtype + ) + expected_next_probs, expected_next_map = _dense_from_compact( + calls[2], dtype=logits.dtype + ) + + assert torch.equal(first_map.cpu(), expected_first_map) + _assert_probs_close(first_probs.cpu(), expected_first_probs) + assert torch.equal(second_map.cpu(), expected_second_map) + _assert_probs_close(second_probs.cpu(), expected_second_probs) + assert torch.equal(recompute_map.cpu(), expected_second_map) + _assert_probs_close(recompute_probs.cpu(), expected_second_probs) + assert torch.equal(next_map.cpu(), expected_next_map) + _assert_probs_close(next_probs.cpu(), expected_next_probs) + + controller.finalize_step() + controller.remove_router_patches() diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py index 90e2c59d7..e4ed7d4ff 100644 --- a/tests/unit/test_pipeline_trainer_local_backend.py +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -199,6 +199,42 @@ async def fake_train_model( assert seen["dev_config"]["packed_sequence_length"] == 2048 +@pytest.mark.asyncio +async def test_local_backend_train_maps_normalize_advantages_to_scale_rewards( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="local-backend-normalize-advantages", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = LocalBackend(path=str(tmp_path)) + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + config: Any, + dev_config: dict[str, Any], + verbose: bool = False, + ): + seen["dev_config"] = dev_config + yield {} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign] + with patch.object(model, "_get_wandb_run", return_value=None): + await backend.train( + model, + [_make_group([0.0, 1.0])], + normalize_advantages=False, + save_checkpoint=False, + ) + + assert seen["dev_config"]["scale_rewards"] is False + + def _make_tokenized_result( trajectory: Trajectory, token_ids: list[int], @@ -360,7 +396,6 @@ async def aclose(self) -> None: [ ({"loss_fn": "dro"}, "loss_fn='cispo' or loss_fn='ppo'"), ({"loss_fn_config": {"clip": 0.2}}, "loss_fn_config=None"), - ({"normalize_advantages": False}, "normalize_advantages=True"), ({"adam_params": object()}, "adam_params=None"), ], ) diff --git a/tests/unit/test_preprocessing_tokenize.py b/tests/unit/test_preprocessing_tokenize.py index 889870a21..68654ef14 100644 --- a/tests/unit/test_preprocessing_tokenize.py +++ b/tests/unit/test_preprocessing_tokenize.py @@ -1,19 +1,22 @@ +import importlib import sys -import types from typing import cast from openai.types.chat.chat_completion import Choice import pytest from transformers.tokenization_utils_base import BatchEncoding -import art -from art.preprocessing.tokenize import ( - tokenize_sft_batch, - tokenize_trajectory, -) +from art.preprocessing.tokenize import tokenize_trajectory from art.trajectories import History, Trajectory from art.types import MessagesAndChoices +if "tests" not in sys.path: + sys.path.insert(0, "tests") + +build_chat_template_conformance_inputs = importlib.import_module( + "support.chat_template_conformance_cases" +).build_chat_template_conformance_inputs + pytest.importorskip("torch") pytest.importorskip("transformers") @@ 
-33,9 +36,16 @@ def apply_chat_template( **kwargs, ): del tools, kwargs - rendered = "".join( - f"<{message['role']}>{message.get('content', '')}" for message in messages - ) + rendered_parts = [] + for message in messages: + tool_calls = "".join( + f"{tool_call['function']['name']}:{tool_call['function']['arguments']}" + for tool_call in message.get("tool_calls", []) + ) + rendered_parts.append( + f"<{message['role']}>{tool_calls}{message.get('content', '')}" + ) + rendered = "".join(rendered_parts) if not tokenize: return rendered token_ids = self.encode(rendered, add_special_tokens=False) @@ -127,77 +137,6 @@ def test_tokenize_trajectory_accepts_batchencoding_chat_template_output() -> Non assert assistant_ids == tokenizer.encode("OK", add_special_tokens=False) -def test_tokenize_sft_batch_accepts_batchencoding_chat_template_output( - monkeypatch: pytest.MonkeyPatch, -) -> None: - tokenizer = _FakeTokenizer() - - fake_unsloth = types.ModuleType("unsloth") - fake_unsloth_zoo = types.ModuleType("unsloth_zoo") - fake_dataset_utils = types.ModuleType("unsloth_zoo.dataset_utils") - - def _train_on_responses_only(**kwargs): - del kwargs - - def _labels_fn(batch): - return {"labels": [list(batch["input_ids"][0])]} - - return _labels_fn - - fake_dataset_utils.train_on_responses_only = _train_on_responses_only # type: ignore[attr-defined] - fake_unsloth_zoo.dataset_utils = fake_dataset_utils # type: ignore[attr-defined] - - monkeypatch.setitem(sys.modules, "unsloth", fake_unsloth) - monkeypatch.setitem(sys.modules, "unsloth_zoo", fake_unsloth_zoo) - monkeypatch.setitem(sys.modules, "unsloth_zoo.dataset_utils", fake_dataset_utils) - - trajectory = Trajectory( - messages_and_choices=[ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "World"}, - ] - ) - - batch = tokenize_sft_batch( - trajectory_batch=[trajectory], - learning_rate=1e-5, - tokenizer=tokenizer, # type: ignore[arg-type] - instruction_part="", - response_part="", - ) - - expected_ids = tokenizer.encode( - tokenizer.apply_chat_template( - trajectory.messages_and_choices, - tokenize=False, - add_generation_prompt=False, - ), - add_special_tokens=False, - ) - - assert batch.trajectory_tensors[0]["input_ids"].tolist() == [expected_ids] - assert batch.trajectory_tensors[0]["attention_mask"].tolist() == [ - [1] * len(expected_ids) - ] - assert batch.num_dropped_trajectories == 0 - assert batch.num_tokens == len(expected_ids) - assert batch.num_trainable_tokens == len(expected_ids) - - dropped_batch = tokenize_sft_batch( - trajectory_batch=[trajectory], - learning_rate=1e-5, - tokenizer=tokenizer, # type: ignore[arg-type] - instruction_part="", - response_part="", - max_seq_length=len(expected_ids) - 1, - ) - assert dropped_batch.trajectory_tensors == [] - assert dropped_batch.num_trajectories == 0 - assert dropped_batch.num_tokens == 0 - assert dropped_batch.num_trainable_tokens == 0 - assert dropped_batch.num_dropped_trajectories == 1 - - def test_tokenize_trajectory_normalizes_mapping_tool_arguments_for_chat_template() -> ( None ): @@ -257,3 +196,60 @@ def test_tokenize_trajectory_normalizes_mapping_tool_arguments_for_chat_template ) assert result is not None + + +def test_tokenize_trajectory_non_final_tool_call_mutation_changes_prefill_tokens() -> ( + None +): + tokenizer = _Qwen3_5FakeTokenizer() + inputs = build_chat_template_conformance_inputs(tokenizer) # type: ignore[arg-type] + + base = tokenize_trajectory( + tokenizer=tokenizer, # type: ignore[arg-type] + image_processor=None, + history=History( + 
messages_and_choices=inputs.non_final_tool_call_base.messages_and_choices, + tools=inputs.non_final_tool_call_base.tools, + ), + advantage=1.0, + allow_training_without_logprobs=False, + trajectory=inputs.non_final_tool_call_base, + ) + mutated = tokenize_trajectory( + tokenizer=tokenizer, # type: ignore[arg-type] + image_processor=None, + history=History( + messages_and_choices=inputs.non_final_tool_call_mutated.messages_and_choices, + tools=inputs.non_final_tool_call_mutated.tools, + ), + advantage=1.0, + allow_training_without_logprobs=False, + trajectory=inputs.non_final_tool_call_mutated, + ) + + assert base is not None + assert mutated is not None + assert len(base.choice_offsets) >= 2 + assert len(mutated.choice_offsets) >= 2 + assert ( + base.token_ids[: base.choice_offsets[-1]] + != mutated.token_ids[: mutated.choice_offsets[-1]] + ) + + +def test_tokenize_trajectory_rejects_assistant_tool_calls_without_logprobs() -> None: + tokenizer = _Qwen3_5FakeTokenizer() + inputs = build_chat_template_conformance_inputs(tokenizer) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="Assistant message has tool_calls"): + tokenize_trajectory( + tokenizer=tokenizer, # type: ignore[arg-type] + image_processor=None, + history=History( + messages_and_choices=inputs.unsupported_assistant_tool_calls.messages_and_choices, + tools=inputs.unsupported_assistant_tool_calls.tools, + ), + advantage=1.0, + allow_training_without_logprobs=True, + trajectory=inputs.unsupported_assistant_tool_calls, + ) diff --git a/tests/unit/test_tinker_renderers.py b/tests/unit/test_tinker_renderers.py index 34671fd10..35db45bef 100644 --- a/tests/unit/test_tinker_renderers.py +++ b/tests/unit/test_tinker_renderers.py @@ -1,8 +1,9 @@ import json from typing import cast -from art.tinker.cookbook_v import renderers -from art.tinker.cookbook_v.tokenizer_utils import Tokenizer +from tinker_cookbook import renderers +from tinker_cookbook.tokenizer_utils import Tokenizer + from art.tinker.renderers import get_renderer_name from art.tinker_native.data import convert_openai_messages_to_renderer_format diff --git a/tests/unit/test_tokenize_trajectory_groups.ipynb b/tests/unit/test_tokenize_trajectory_groups.ipynb index ed488855d..7b10993e6 100644 --- a/tests/unit/test_tokenize_trajectory_groups.ipynb +++ b/tests/unit/test_tokenize_trajectory_groups.ipynb @@ -70,6 +70,7 @@ } ], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", "tokenized_results = list(\n", " tokenize_trajectory_groups(\n", " tokenizer,\n", @@ -121,12 +122,201 @@ " shuffle_group_trajectories=False,\n", " )\n", ")\n", + "actual = []\n", "for result in tokenized_results:\n", " result.advantage = round(result.advantage, 2)\n", " result.weight = round(result.weight, 2)\n", " # set prompt_id to 0 to eliminate stochasticity\n", " result.prompt_id = 0\n", - " display(result) # ty:ignore[unresolved-reference]" + " actual.append(\n", + " {\n", + " \"advantage\": result.advantage,\n", + " \"token_ids\": result.token_ids,\n", + " \"assistant_mask\": result.assistant_mask,\n", + " \"nan_logprobs\": sum(logprob != logprob for logprob in result.logprobs),\n", + " \"last_logprob\": \"nan\"\n", + " if result.logprobs[-1] != result.logprobs[-1]\n", + " else result.logprobs[-1],\n", + " \"weight\": result.weight,\n", + " \"prompt_id\": result.prompt_id,\n", + " \"prompt_length\": result.prompt_length,\n", + " }\n", + " )\n", + "\n", + "assert actual == [\n", + " {\n", + " \"advantage\": -1.0,\n", + " \"token_ids\": [\n", + " 151644,\n", + " 8948,\n", + " 198,\n", + " 2610,\n", + 
" 525,\n", + " 1207,\n", + " 16948,\n", + " 11,\n", + " 3465,\n", + " 553,\n", + " 54364,\n", + " 14817,\n", + " 13,\n", + " 1446,\n", + " 525,\n", + " 264,\n", + " 10950,\n", + " 17847,\n", + " 13,\n", + " 151645,\n", + " 198,\n", + " 151644,\n", + " 872,\n", + " 198,\n", + " 3838,\n", + " 374,\n", + " 279,\n", + " 6722,\n", + " 315,\n", + " 9625,\n", + " 30,\n", + " 151645,\n", + " 198,\n", + " 151644,\n", + " 77091,\n", + " 198,\n", + " 39572,\n", + " ],\n", + " \"assistant_mask\": [\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 1,\n", + " ],\n", + " \"nan_logprobs\": 37,\n", + " \"last_logprob\": \"nan\",\n", + " \"weight\": 1.0,\n", + " \"prompt_id\": 0,\n", + " \"prompt_length\": 35,\n", + " },\n", + " {\n", + " \"advantage\": 1.0,\n", + " \"token_ids\": [\n", + " 151644,\n", + " 8948,\n", + " 198,\n", + " 2610,\n", + " 525,\n", + " 1207,\n", + " 16948,\n", + " 11,\n", + " 3465,\n", + " 553,\n", + " 54364,\n", + " 14817,\n", + " 13,\n", + " 1446,\n", + " 525,\n", + " 264,\n", + " 10950,\n", + " 17847,\n", + " 13,\n", + " 151645,\n", + " 198,\n", + " 151644,\n", + " 872,\n", + " 198,\n", + " 3838,\n", + " 374,\n", + " 279,\n", + " 6722,\n", + " 315,\n", + " 9625,\n", + " 30,\n", + " 151645,\n", + " 198,\n", + " 151644,\n", + " 77091,\n", + " 198,\n", + " 59604,\n", + " ],\n", + " \"assistant_mask\": [\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 1,\n", + " ],\n", + " \"nan_logprobs\": 36,\n", + " \"last_logprob\": -0.01,\n", + " \"weight\": 1.0,\n", + " \"prompt_id\": 0,\n", + " \"prompt_length\": 35,\n", + " },\n", + "]" ] } ], diff --git a/tests/unit/test_unsloth_autocast_dtype.py b/tests/unit/test_unsloth_autocast_dtype.py index 5438077fa..f2962ef8b 100644 --- a/tests/unit/test_unsloth_autocast_dtype.py +++ b/tests/unit/test_unsloth_autocast_dtype.py @@ -47,6 +47,16 @@ def test_get_dtype_for_autocasting_honors_explicit_fp16(monkeypatch) -> None: assert _get_dtype_for_autocasting(model) == torch.float16 +def test_get_dtype_for_autocasting_honors_force_float32_override( + monkeypatch, +) -> None: + monkeypatch.setenv("ACCELERATE_MIXED_PRECISION", "bf16") + monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1") + model = _TinyModel([(torch.bfloat16, 8)]) + + assert _get_dtype_for_autocasting(model) == torch.float16 + + def test_get_dtype_for_autocasting_honors_explicit_bfloat16(monkeypatch) -> None: monkeypatch.setenv("ACCELERATE_MIXED_PRECISION", "bf16") monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False) diff --git a/tests/unit/test_vllm_patches_contract.py b/tests/unit/test_vllm_patches_contract.py deleted file mode 100644 index d9aafb7e7..000000000 --- a/tests/unit/test_vllm_patches_contract.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Unit tests for ART's vLLM patch contract.""" - -import importlib -from typing import Any, cast - -import pytest - -pytest.importorskip("cloudpickle") 
-pytest.importorskip("vllm") - -from art.vllm.patches import ( - patch_tool_parser_manager, - patch_transformers_v5_compat, - subclass_chat_completion_request, -) - - -def test_subclass_chat_completion_request_forces_logprobs() -> None: - protocol = importlib.import_module( - "vllm.entrypoints.openai.chat_completion.protocol" - ) - original = getattr(protocol, "ChatCompletionRequest") - - try: - subclass_chat_completion_request() - request_cls = getattr(protocol, "ChatCompletionRequest") - request = request_cls( - messages=[{"role": "user", "content": "hello"}], - model="dummy-model", - ) - assert request.logprobs is True - assert request.top_logprobs == 0 - finally: - setattr(protocol, "ChatCompletionRequest", original) - - -def test_patch_tool_parser_manager_falls_back_to_empty_delta_message() -> None: - protocol = importlib.import_module("vllm.entrypoints.openai.engine.protocol") - DeltaMessage = protocol.DeltaMessage - - from vllm.tool_parsers.abstract_tool_parser import ToolParserManager - - class DummyToolParser: - @staticmethod - def extract_tool_calls_streaming(*_args, **_kwargs): - return None - - original_get_tool_parser = ToolParserManager.get_tool_parser - - try: - setattr( - ToolParserManager, - "get_tool_parser", - classmethod(lambda _cls, _name: DummyToolParser), - ) - patch_tool_parser_manager() - - parser_cls = ToolParserManager.get_tool_parser("dummy") - result = parser_cls.extract_tool_calls_streaming("", "", "", [], [], [], None) # ty:ignore[missing-argument,invalid-argument-type] - - assert isinstance(result, DeltaMessage) - finally: - setattr(ToolParserManager, "get_tool_parser", original_get_tool_parser) - - -def test_patch_transformers_v5_compat_normalizes_rope_ignore_keys() -> None: - from transformers.configuration_utils import PretrainedConfig - - patch_transformers_v5_compat() - - class DummyRopeConfig: - default_theta = 10000.0 - rope_parameters = None - - def standardize_rope_params(self) -> None: - pass - - def validate_rope(self, ignore_keys=None) -> None: - self.ignore_keys = ignore_keys - - dummy = DummyRopeConfig() - PretrainedConfig.convert_rope_params_to_dict( # type: ignore[attr-defined] - cast(Any, dummy), - ignore_keys_at_rope_validation=cast(Any, ["mrope_section"]), - partial_rotary_factor=0.25, - ) - - assert dummy.ignore_keys == {"mrope_section", "partial_rotary_factor"} diff --git a/uv.lock b/uv.lock index 12bb20090..381013c1a 100644 --- a/uv.lock +++ b/uv.lock @@ -24,6 +24,7 @@ resolution-markers = [ overrides = [ { name = "flashinfer-python", specifier = "==0.6.1" }, { name = "numpy", specifier = "<2" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'", specifier = "==2.28.9" }, { name = "nvidia-resiliency-ext", specifier = "<0.5" }, { name = "quack-kernels", specifier = "==0.2.5" }, { name = "transformer-engine", specifier = "==2.11.0" }, @@ -301,25 +302,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] -[[package]] -name = "anthropic" -version = "0.86.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio", marker = "sys_platform == 'linux'" }, - { name = "distro", marker = "sys_platform == 'linux'" }, - { name = "docstring-parser", marker = "sys_platform == 'linux'" }, - { name = "httpx", marker = "sys_platform == 'linux'" }, - { 
name = "jiter", marker = "sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'linux'" }, - { name = "sniffio", marker = "sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/37/7a/8b390dc47945d3169875d342847431e5f7d5fa716b2e37494d57cfc1db10/anthropic-0.86.0.tar.gz", hash = "sha256:60023a7e879aa4fbb1fed99d487fe407b2ebf6569603e5047cfe304cebdaa0e5", size = 583820, upload-time = "2026-03-18T18:43:08.017Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/63/5f/67db29c6e5d16c8c9c4652d3efb934d89cb750cad201539141781d8eae14/anthropic-0.86.0-py3-none-any.whl", hash = "sha256:9d2bbd339446acce98858c5627d33056efe01f70435b22b63546fe7edae0cd57", size = 469400, upload-time = "2026-03-18T18:43:06.526Z" }, -] - [[package]] name = "antlr4-python3-runtime" version = "4.9.3" @@ -385,15 +367,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321, upload-time = "2024-02-06T09:43:09.663Z" }, ] -[[package]] -name = "astor" -version = "0.8.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5a/21/75b771132fee241dfe601d39ade629548a9626d1d39f333fde31bc46febe/astor-0.8.1.tar.gz", hash = "sha256:6a6effda93f4e1ce9f618779b2dd1d9d84f1e32812c23a29b3fff6fd7f63fa5e", size = 35090, upload-time = "2019-12-10T01:50:35.51Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/88/97eef84f48fa04fbd6750e62dcceafba6c63c81b7ac1420856c8dcc0a3f9/astor-0.8.1-py2.py3-none-any.whl", hash = "sha256:070a54e890cefb5b3739d19f30f5a5ec840ffc9c50ffa7d23cc9fc1a38ebbfc5", size = 27488, upload-time = "2019-12-10T01:50:33.628Z" }, -] - [[package]] name = "asttokens" version = "3.0.1" @@ -816,65 +789,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/0d/52d98722666d6fc6c3dd4c76df339501d6efd40e0ff95e6186a7b7f0befd/black-26.3.1-py3-none-any.whl", hash = "sha256:2bd5aa94fc267d38bb21a70d7410a89f1a1d318841855f698746f8e7f51acd1b", size = 207542, upload-time = "2026-03-12T03:36:01.668Z" }, ] -[[package]] -name = "blake3" -version = "1.0.8" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12' and sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/75/aa/abcd75e9600987a0bc6cfe9b6b2ff3f0e2cb08c170addc6e76035b5c4cb3/blake3-1.0.8.tar.gz", hash = "sha256:513cc7f0f5a7c035812604c2c852a0c1468311345573de647e310aca4ab165ba", size = 117308, upload-time = "2025-10-14T06:47:48.83Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f4/0a/515209b0c282c360e249b89cd85350d97cfd55fadbb4df736c67b77b27a1/blake3-1.0.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fcfe81b3ae3fb5d2e88be0d3259603ff95f0d5ed69f655c28fdaef31e49a470", size = 371092, upload-time = "2025-10-14T06:45:34.062Z" }, - { url = "https://files.pythonhosted.org/packages/a0/33/9d342a2bf5817f006bbe947335e5d387327541ea47590854947befd01251/blake3-1.0.8-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:58ce8d45a5bb5326482de72ea1969a378634236186a970fef63058a5b7b8b435", size = 374859, upload-time = "2025-10-14T06:45:35.262Z" }, - { url = 
"https://files.pythonhosted.org/packages/5b/fc/ea4bef850a7ec9fbb383503fd3c56056dd9fa44e10c3bc61050ab7b2bac0/blake3-1.0.8-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83605dbf43f581d8b7175b7f3bfe5388bad5a7c6ac175c9c11d669da31133f4b", size = 448585, upload-time = "2025-10-14T06:45:36.542Z" }, - { url = "https://files.pythonhosted.org/packages/a5/67/167a65a4c431715407d07b1b8b1367698a3ad88e7260edb85f0c5293f08a/blake3-1.0.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b5573b052777142b2cecc453d022c3f21aa4aba75011258410bb98f41c1a727", size = 507519, upload-time = "2025-10-14T06:45:37.814Z" }, - { url = "https://files.pythonhosted.org/packages/32/e2/0886e192d634b264c613b0fbf380745b39992b424a0effc00ef08783644e/blake3-1.0.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe1b02ab49bfd969ef50b9f17482a2011c77536654af21807ba5c2674e0bb2a0", size = 393645, upload-time = "2025-10-14T06:45:39.146Z" }, - { url = "https://files.pythonhosted.org/packages/fc/3b/7fb2fe615448caaa5f6632b2c7551117b38ccac747a3a5769181e9751641/blake3-1.0.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7780666dc6be809b49442d6d5ce06fdbe33024a87560b58471103ec17644682", size = 387640, upload-time = "2025-10-14T06:45:40.546Z" }, - { url = "https://files.pythonhosted.org/packages/bc/8c/2bfc942c6c97cb3d20f341859343bb86ee20af723fedfc886373e606079b/blake3-1.0.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:af394b50c6aa0b1b957a99453d1ee440ef67cd2d1b5669c731647dc723de8a3a", size = 550316, upload-time = "2025-10-14T06:45:42.003Z" }, - { url = "https://files.pythonhosted.org/packages/7e/75/0252be37620699b79dbaa799c9b402d63142a131d16731df4ef09d135dd7/blake3-1.0.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c63ece266a43014cf29e772a82857cd8e90315ae3ed53e3c5204851596edd5f2", size = 554463, upload-time = "2025-10-14T06:45:43.22Z" }, - { url = "https://files.pythonhosted.org/packages/ee/7d/85a4c0782f613de23d114a7a78fcce270f75b193b3ff3493a0de24ba104a/blake3-1.0.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:269f255b110840e52b6ce9db02217e39660ebad3e34ddd5bca8b8d378a77e4e1", size = 371296, upload-time = "2025-10-14T06:45:49.674Z" }, - { url = "https://files.pythonhosted.org/packages/e3/20/488475254976ed93fab57c67aa80d3b40df77f7d9db6528c9274bff53e08/blake3-1.0.8-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:66ca28a673025c40db3eba21a9cac52f559f83637efa675b3f6bd8683f0415f3", size = 374516, upload-time = "2025-10-14T06:45:51.23Z" }, - { url = "https://files.pythonhosted.org/packages/7b/21/2a1c47fedb77fb396512677ec6d46caf42ac6e9a897db77edd0a2a46f7bb/blake3-1.0.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcb04966537777af56c1f399b35525aa70a1225816e121ff95071c33c0f7abca", size = 447911, upload-time = "2025-10-14T06:45:52.637Z" }, - { url = "https://files.pythonhosted.org/packages/cb/7d/db0626df16029713e7e61b67314c4835e85c296d82bd907c21c6ea271da2/blake3-1.0.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e5b5da177d62cc4b7edf0cea08fe4dec960c9ac27f916131efa890a01f747b93", size = 505420, upload-time = "2025-10-14T06:45:54.445Z" }, - { url = "https://files.pythonhosted.org/packages/5b/55/6e737850c2d58a6d9de8a76dad2ae0f75b852a23eb4ecb07a0b165e6e436/blake3-1.0.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:38209b10482c97e151681ea3e91cc7141f56adbbf4820a7d701a923124b41e6a", size = 
394189, upload-time = "2025-10-14T06:45:55.719Z" }, - { url = "https://files.pythonhosted.org/packages/5b/94/eafaa5cdddadc0c9c603a6a6d8339433475e1a9f60c8bb9c2eed2d8736b6/blake3-1.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:504d1399b7fb91dfe5c25722d2807990493185faa1917456455480c36867adb5", size = 388001, upload-time = "2025-10-14T06:45:57.067Z" }, - { url = "https://files.pythonhosted.org/packages/17/81/735fa00d13de7f68b25e1b9cb36ff08c6f165e688d85d8ec2cbfcdedccc5/blake3-1.0.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c84af132aa09abeadf9a0118c8fb26f4528f3f42c10ef8be0fcf31c478774ec4", size = 550302, upload-time = "2025-10-14T06:45:58.657Z" }, - { url = "https://files.pythonhosted.org/packages/0e/c6/d1fe8bdea4a6088bd54b5a58bc40aed89a4e784cd796af7722a06f74bae7/blake3-1.0.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a25db3d36b55f5ed6a86470155cc749fc9c5b91c949b8d14f48658f9d960d9ec", size = 554211, upload-time = "2025-10-14T06:46:00.269Z" }, - { url = "https://files.pythonhosted.org/packages/77/57/e8a85fa261894bf7ce7af928ff3408aab60287ab8d58b55d13a3f700b619/blake3-1.0.8-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19fc6f2b7edab8acff6895fc6e38c19bd79f4c089e21153020c75dfc7397d52d", size = 370994, upload-time = "2025-10-14T06:46:07.398Z" }, - { url = "https://files.pythonhosted.org/packages/62/cd/765b76bb48b8b294fea94c9008b0d82b4cfa0fa2f3c6008d840d01a597e4/blake3-1.0.8-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4f54cff7f15d91dc78a63a2dd02a3dccdc932946f271e2adb4130e0b4cf608ba", size = 374372, upload-time = "2025-10-14T06:46:08.698Z" }, - { url = "https://files.pythonhosted.org/packages/36/7a/32084eadbb28592bb07298f0de316d2da586c62f31500a6b1339a7e7b29b/blake3-1.0.8-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7e12a777f6b798eb8d06f875d6e108e3008bd658d274d8c676dcf98e0f10537", size = 447627, upload-time = "2025-10-14T06:46:10.002Z" }, - { url = "https://files.pythonhosted.org/packages/a7/f4/3788a1d86e17425eea147e28d7195d7053565fc279236a9fd278c2ec495e/blake3-1.0.8-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddfc59b0176fb31168f08d5dd536e69b1f4f13b5a0f4b0c3be1003efd47f9308", size = 507536, upload-time = "2025-10-14T06:46:11.614Z" }, - { url = "https://files.pythonhosted.org/packages/fe/01/4639cba48513b94192681b4da472cdec843d3001c5344d7051ee5eaef606/blake3-1.0.8-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a2336d5b2a801a7256da21150348f41610a6c21dae885a3acb1ebbd7333d88d8", size = 394105, upload-time = "2025-10-14T06:46:12.808Z" }, - { url = "https://files.pythonhosted.org/packages/21/ae/6e55c19c8460fada86cd1306a390a09b0c5a2e2e424f9317d2edacea439f/blake3-1.0.8-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4072196547484c95a5a09adbb952e9bb501949f03f9e2a85e7249ef85faaba8", size = 386928, upload-time = "2025-10-14T06:46:16.284Z" }, - { url = "https://files.pythonhosted.org/packages/ee/6c/05b7a5a907df1be53a8f19e7828986fc6b608a44119641ef9c0804fbef15/blake3-1.0.8-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:0eab3318ec02f8e16fe549244791ace2ada2c259332f0c77ab22cf94dfff7130", size = 550003, upload-time = "2025-10-14T06:46:17.791Z" }, - { url = "https://files.pythonhosted.org/packages/b4/03/f0ea4adfedc1717623be6460b3710fcb725ca38082c14274369803f727e1/blake3-1.0.8-cp313-cp313-musllinux_1_1_x86_64.whl", hash = 
"sha256:a33b9a1fb6d1d559a8e0d04b041e99419a6bb771311c774f6ff57ed7119c70ed", size = 553857, upload-time = "2025-10-14T06:46:19.088Z" }, - { url = "https://files.pythonhosted.org/packages/13/da/722cebca11238f3b24d3cefd2361c9c9ea47cfa0ad9288eeb4d1e0b7cf93/blake3-1.0.8-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef153c5860d5bf1cc71aece69b28097d2a392913eb323d6b52555c875d0439fc", size = 370441, upload-time = "2025-10-14T06:46:26.29Z" }, - { url = "https://files.pythonhosted.org/packages/2e/d5/2f7440c8e41c0af995bad3a159e042af0f4ed1994710af5b4766ca918f65/blake3-1.0.8-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e8ae3689f0c7bfa6ce6ae45cab110e4c3442125c4c23b28f1f097856de26e4d1", size = 374312, upload-time = "2025-10-14T06:46:27.451Z" }, - { url = "https://files.pythonhosted.org/packages/a6/6c/fb6a7812e60ce3e110bcbbb11f167caf3e975c589572c41e1271f35f2c41/blake3-1.0.8-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fb83532f7456ddeb68dae1b36e1f7c52f9cb72852ac01159bbcb1a12b0f8be0", size = 447007, upload-time = "2025-10-14T06:46:29.056Z" }, - { url = "https://files.pythonhosted.org/packages/13/3b/c99b43fae5047276ea9d944077c190fc1e5f22f57528b9794e21f7adedc6/blake3-1.0.8-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ae7754c7d96e92a70a52e07c732d594cf9924d780f49fffd3a1e9235e0f5ba7", size = 507323, upload-time = "2025-10-14T06:46:30.661Z" }, - { url = "https://files.pythonhosted.org/packages/fc/bb/ba90eddd592f8c074a0694cb0a744b6bd76bfe67a14c2b490c8bdfca3119/blake3-1.0.8-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4bacaae75e98dee3b7da6c5ee3b81ee21a3352dd2477d6f1d1dbfd38cdbf158a", size = 393449, upload-time = "2025-10-14T06:46:31.805Z" }, - { url = "https://files.pythonhosted.org/packages/25/ed/58a2acd0b9e14459cdaef4344db414d4a36e329b9720921b442a454dd443/blake3-1.0.8-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9456c829601d72852d8ba0af8dae0610f7def1d59f5942efde1e2ef93e8a8b57", size = 386844, upload-time = "2025-10-14T06:46:33.195Z" }, - { url = "https://files.pythonhosted.org/packages/4a/04/fed09845b18d90862100c8e48308261e2f663aab25d3c71a6a0bdda6618b/blake3-1.0.8-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:497ef8096ec4ac1ffba9a66152cee3992337cebf8ea434331d8fd9ce5423d227", size = 549550, upload-time = "2025-10-14T06:46:35.23Z" }, - { url = "https://files.pythonhosted.org/packages/d6/65/1859fddfabc1cc72548c2269d988819aad96d854e25eae00531517925901/blake3-1.0.8-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:511133bab85ff60ed143424ce484d08c60894ff7323f685d7a6095f43f0c85c3", size = 553805, upload-time = "2025-10-14T06:46:36.532Z" }, - { url = "https://files.pythonhosted.org/packages/49/fa/b913eb9cc4af708c03e01e6b88a8bb3a74833ba4ae4b16b87e2829198e06/blake3-1.0.8-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a47939f04b89c5c6ff1e51e883e5efab1ea1bf01a02f4d208d216dddd63d0dd8", size = 370654, upload-time = "2025-10-14T06:46:43.907Z" }, - { url = "https://files.pythonhosted.org/packages/7f/4f/245e0800c33b99c8f2b570d9a7199b51803694913ee4897f339648502933/blake3-1.0.8-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73e0b4fa25f6e3078526a592fb38fca85ef204fd02eced6731e1cdd9396552d4", size = 374693, upload-time = "2025-10-14T06:46:45.186Z" }, - { url = 
"https://files.pythonhosted.org/packages/a2/a6/8cb182c8e482071dbdfcc6ec0048271fd48bcb78782d346119ff54993700/blake3-1.0.8-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b0543c57eb9d6dac9d4bced63e9f7f7b546886ac04cec8da3c3d9c8f30cbbb7", size = 447673, upload-time = "2025-10-14T06:46:46.358Z" }, - { url = "https://files.pythonhosted.org/packages/06/b7/1cbbb5574d2a9436d1b15e7eb5b9d82e178adcaca71a97b0fddaca4bfe3a/blake3-1.0.8-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ed972ebd553c0c25363459e9fc71a38c045d8419e365b59acd8cd791eff13981", size = 507233, upload-time = "2025-10-14T06:46:48.109Z" }, - { url = "https://files.pythonhosted.org/packages/9c/45/b55825d90af353b3e26c653bab278da9d6563afcf66736677f9397e465be/blake3-1.0.8-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3bafdec95dfffa3f6571e529644744e280337df15ddd9728f224ba70c5779b23", size = 393852, upload-time = "2025-10-14T06:46:49.511Z" }, - { url = "https://files.pythonhosted.org/packages/34/73/9058a1a457dd20491d1b37de53d6876eff125e1520d9b2dd7d0acbc88de2/blake3-1.0.8-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d78f06f3fb838b34c330e2987090376145cbe5944d8608a0c4779c779618f7b", size = 386442, upload-time = "2025-10-14T06:46:51.205Z" }, - { url = "https://files.pythonhosted.org/packages/30/6d/561d537ffc17985e276e08bf4513f1c106f1fdbef571e782604dc4e44070/blake3-1.0.8-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:dd03ff08d1b6e4fdda1cd03826f971ae8966ef6f683a8c68aa27fb21904b5aa9", size = 549929, upload-time = "2025-10-14T06:46:52.494Z" }, - { url = "https://files.pythonhosted.org/packages/03/2f/dbe20d2c57f1a67c63be4ba310bcebc707b945c902a0bde075d2a8f5cd5c/blake3-1.0.8-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:4e02a3c499e35bf51fc15b2738aca1a76410804c877bcd914752cac4f71f052a", size = 553750, upload-time = "2025-10-14T06:46:54.194Z" }, - { url = "https://files.pythonhosted.org/packages/11/33/503b37220a3e2e31917ef13722efd00055af51c5e88ae30974c733d7ece6/blake3-1.0.8-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88d527c247f9609dc1d45a08fd243e39f0d5300d54c57e048de24d4fa9240ebb", size = 370220, upload-time = "2025-10-14T06:47:02.573Z" }, - { url = "https://files.pythonhosted.org/packages/3e/df/fe817843adf59516c04d44387bd643b422a3b0400ea95c6ede6a49920737/blake3-1.0.8-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:506a47897a11ebe8f3cdeb52f1365d6a2f83959e98ccb0c830f8f73277d4d358", size = 373454, upload-time = "2025-10-14T06:47:03.784Z" }, - { url = "https://files.pythonhosted.org/packages/d1/4d/90a2a623575373dfc9b683f1bad1bf017feafa5a6d65d94fb09543050740/blake3-1.0.8-cp314-cp314t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5122a61b3b004bbbd979bdf83a3aaab432da3e2a842d7ddf1c273f2503b4884", size = 447102, upload-time = "2025-10-14T06:47:04.958Z" }, - { url = "https://files.pythonhosted.org/packages/93/ff/4e8ce314f60115c4c657b1fdbe9225b991da4f5bcc5d1c1f1d151e2f39d6/blake3-1.0.8-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0171e85d56dec1219abdae5f49a0ed12cb3f86a454c29160a64fd8a8166bba37", size = 506791, upload-time = "2025-10-14T06:47:06.82Z" }, - { url = "https://files.pythonhosted.org/packages/44/88/2963a1f18aab52bdcf35379b2b48c34bbc462320c37e76960636b8602c36/blake3-1.0.8-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:003f61e8c41dd9931edddf1cc6a1bb680fb2ac0ad15493ef4a1df9adc59ce9df", size 
= 393717, upload-time = "2025-10-14T06:47:09.085Z" }, - { url = "https://files.pythonhosted.org/packages/45/d1/a848ed8e8d4e236b9b16381768c9ae99d92890c24886bb4505aa9c3d2033/blake3-1.0.8-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2c3151955efb09ba58cd3e1263521e15e9e3866a40d6bd3556d86fc968e8f95", size = 386150, upload-time = "2025-10-14T06:47:10.363Z" }, - { url = "https://files.pythonhosted.org/packages/96/09/e3eb5d60f97c01de23d9f434e6e1fc117efb466eaa1f6ddbbbcb62580d6e/blake3-1.0.8-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:5eb25bca3cee2e0dd746a214784fb36be6a43640c01c55b6b4e26196e72d076c", size = 549120, upload-time = "2025-10-14T06:47:11.713Z" }, - { url = "https://files.pythonhosted.org/packages/14/ad/3d9661c710febb8957dd685fdb3e5a861aa0ac918eda3031365ce45789e2/blake3-1.0.8-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:ab4e1dea4fa857944944db78e8f20d99ee2e16b2dea5a14f514fb0607753ac83", size = 553264, upload-time = "2025-10-14T06:47:13.317Z" }, -] - [[package]] name = "blinker" version = "1.9.0" @@ -1090,31 +1004,6 @@ requires-dist = [ { name = "torch" }, ] -[[package]] -name = "cbor2" -version = "5.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bd/cb/09939728be094d155b5d4ac262e39877875f5f7e36eea66beb359f647bd0/cbor2-5.9.0.tar.gz", hash = "sha256:85c7a46279ac8f226e1059275221e6b3d0e370d2bb6bd0500f9780781615bcea", size = 111231, upload-time = "2026-03-22T15:56:50.638Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/31/43/fe29b1f897770011a5e7497f4523c2712282ee4a6cbf775ea6383fb7afb9/cbor2-5.9.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a9d6e4e0f988b0e766509a8071975a8ee99f930e14a524620bf38083106158d2", size = 268738, upload-time = "2026-03-22T15:56:05.222Z" }, - { url = "https://files.pythonhosted.org/packages/0a/1a/e494568f3d8aafbcdfe361df44c3bcf5cdab5183e25ea08e3d3f9fcf4075/cbor2-5.9.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5326336f633cc89dfe543c78829c16c3a6449c2c03277d1ddba99086c3323363", size = 262571, upload-time = "2026-03-22T15:56:06.411Z" }, - { url = "https://files.pythonhosted.org/packages/42/2e/92acd6f87382fd44a34d9d7e85cc45372e6ba664040b72d1d9df648b25d0/cbor2-5.9.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5e702b02d42a5ace45425b595ffe70fe35aebaf9a3cdfdc2c758b6189c744422", size = 262356, upload-time = "2026-03-22T15:56:08.236Z" }, - { url = "https://files.pythonhosted.org/packages/3f/68/52c039a28688baeeb78b0be7483855e6c66ea05884a937444deede0c87b8/cbor2-5.9.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2372d357d403e7912f104ff085950ffc82a5854d6d717f1ca1ce16a40a0ef5a7", size = 257604, upload-time = "2026-03-22T15:56:09.835Z" }, - { url = "https://files.pythonhosted.org/packages/09/fd/7ddf3d3153b54c69c3be77172b8d9aa3a9d74f62a7fbde614d53eaeed9a4/cbor2-5.9.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ae6c706ac1d85a0b3cb3395308fd0c4d55e3202b4760773675957e93cdff45fc", size = 287865, upload-time = "2026-03-22T15:56:14.813Z" }, - { url = "https://files.pythonhosted.org/packages/db/9d/7ede2cc42f9bb4260492e7d29d2aab781eacbbcfb09d983de1e695077199/cbor2-5.9.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4cd43d8fc374b31643b2830910f28177a606a7bc84975a62675dd3f2e320fc7b", size = 288246, upload-time = 
"2026-03-22T15:56:16.113Z" }, - { url = "https://files.pythonhosted.org/packages/ce/9d/588ebc7c5bc5843f609b05fe07be8575c7dec987735b0bbc908ac9c1264a/cbor2-5.9.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4aa07b392cc3d76fb31c08a46a226b58c320d1c172ff3073e864409ced7bc50f", size = 280214, upload-time = "2026-03-22T15:56:17.519Z" }, - { url = "https://files.pythonhosted.org/packages/f7/a1/6fc8f4b15c6a27e7fbb7966c30c2b4b18c274a3221fa2f5e6235502d34bc/cbor2-5.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:971d425b3a23b75953d8853d5f9911bdeefa09d759ee3b5e6b07b5ff3cbd9073", size = 282162, upload-time = "2026-03-22T15:56:18.975Z" }, - { url = "https://files.pythonhosted.org/packages/1b/10/df643a381aebc3f05486de4813662bc58accb640fc3275cb276a75e89694/cbor2-5.9.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac684fe195c39821fca70d18afbf748f728aefbfbf88456018d299e559b8cae0", size = 287682, upload-time = "2026-03-22T15:56:24.024Z" }, - { url = "https://files.pythonhosted.org/packages/c6/0c/8aa6b766059ae4a0ca1ec3ff96fe3823a69a7be880dba2e249f7fbe2700b/cbor2-5.9.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a54fbb32cb828c214f7f333a707e4aec61182e7efdc06ea5d9596d3ecee624a", size = 288009, upload-time = "2026-03-22T15:56:25.305Z" }, - { url = "https://files.pythonhosted.org/packages/74/07/6236bc25c183a9cf7e8062e5dddf9eae9b0b14ebf14a58a69fe5a1e872c6/cbor2-5.9.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4753a6d1bc71054d9179557bc65740860f185095ccb401d46637fff028a5b3ec", size = 280437, upload-time = "2026-03-22T15:56:26.479Z" }, - { url = "https://files.pythonhosted.org/packages/4e/0a/84328d23c3c68874ac6497edb9b1900579a1028efa54734df3f1762bbc15/cbor2-5.9.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:380e534482b843e43442b87d8777a7bf9bed20cb7526f89b780c3400f617304b", size = 282247, upload-time = "2026-03-22T15:56:28.644Z" }, - { url = "https://files.pythonhosted.org/packages/70/e1/a6cca2cc72e13f00030c6a649f57ae703eb2c620806ab70c40db8eab33fa/cbor2-5.9.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0322296b9d52f55880e300ba8ba09ecf644303b99b51138bbb1c0fb644fa7c3e", size = 286953, upload-time = "2026-03-22T15:56:33.292Z" }, - { url = "https://files.pythonhosted.org/packages/08/3c/24cd5ef488a957d90e016f200a3aad820e4c2f85edd61c9fe4523007a1ee/cbor2-5.9.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:422817286c1d0ce947fb2f7eca9212b39bddd7231e8b452e2d2cc52f15332dba", size = 285454, upload-time = "2026-03-22T15:56:34.703Z" }, - { url = "https://files.pythonhosted.org/packages/a4/35/dca96818494c0ba47cdd73e8d809b27fa91f8fa0ce32a068a09237687454/cbor2-5.9.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:9a4907e0c3035bb8836116854ed8e56d8aef23909d601fa59706320897ec2551", size = 279441, upload-time = "2026-03-22T15:56:35.888Z" }, - { url = "https://files.pythonhosted.org/packages/a4/44/d3362378b16e53cf7e535a3f5aed8476e2109068154e24e31981ef5bde9e/cbor2-5.9.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:fb7afe77f8d269e42d7c4b515c6fd14f1ccc0625379fb6829b269f493d16eddd", size = 279673, upload-time = "2026-03-22T15:56:37.08Z" }, - { url = "https://files.pythonhosted.org/packages/42/ff/b83492b096fbef26e9cb62c1a4bf2d3cef579ea7b33138c6c37c4ae66f67/cbor2-5.9.0-py3-none-any.whl", hash = "sha256:27695cbd70c90b8de5c4a284642c2836449b14e2c2e07e3ffe0744cb7669a01b", size = 
24627, upload-time = "2026-03-22T15:56:48.847Z" }, -] - [[package]] name = "certifi" version = "2026.2.25" @@ -1365,7 +1254,7 @@ wheels = [ [[package]] name = "comet-ml" -version = "3.57.3" +version = "3.57.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dulwich" }, @@ -1384,9 +1273,9 @@ dependencies = [ { name = "wrapt" }, { name = "wurlitzer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b8/e3/39e6a2ecb0f83bacc414c23cbab86a330540396607fc6894588a1d71e092/comet_ml-3.57.3.tar.gz", hash = "sha256:84be2337366d598e8a299eb56e2efcffb859d7de7957148f81f124d3ef02728d", size = 585296, upload-time = "2026-03-12T14:23:22.371Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7f/c6/3885cbc9fe99617ee492403d464906dc15bf17943397c31022fba0997e73/comet_ml-3.57.4.tar.gz", hash = "sha256:42b06f5b473ea270f665409477983f52fa5356ee88e9447f07fc610e47850b5e", size = 585959, upload-time = "2026-04-29T13:37:36.617Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/17/c875fe4d53286f8664d3edf7185b488cddfca6c208ba978aa34550e9d10d/comet_ml-3.57.3-py3-none-any.whl", hash = "sha256:4eedf5da98a000466759adbd7e96d6f44004ce1bd844e9bc03d6ec5f86583249", size = 786246, upload-time = "2026-03-12T14:23:20.942Z" }, + { url = "https://files.pythonhosted.org/packages/b8/fb/d6c7c9df3fffcd8f3ab6d9926bd6dcf7eedd14daa78f5f76dc4b50140707/comet_ml-3.57.4-py3-none-any.whl", hash = "sha256:8fc913b9b50fa60d372d8e2190f8543fffe4d6a0c9fddd9582b394826906e0e3", size = 787005, upload-time = "2026-04-29T13:37:34.703Z" }, ] [[package]] @@ -1398,21 +1287,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl", hash = "sha256:c615d91d75f7f04f095b30d1c1711babd43bdc6419c1be9886a85f2f4e489417", size = 7294, upload-time = "2025-07-25T14:02:02.896Z" }, ] -[[package]] -name = "compressed-tensors" -version = "0.13.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "loguru", marker = "sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'linux'" }, - { name = "torch", marker = "sys_platform == 'linux'" }, - { name = "transformers", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fc/65/88dd1c58fb9d0ded51b5c86471b937a1525f91fad2211a6f051dc1ea822d/compressed_tensors-0.13.0.tar.gz", hash = "sha256:23893824d3498ea3f1a829f14a8fa85f9a5e76a34c711a038b8d7c619ca9a67c", size = 200995, upload-time = "2025-12-16T16:03:55.397Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0b/b5/61ac2563c62490922b603c09113a083fd74af3630ec3931e769484d6dcb5/compressed_tensors-0.13.0-py3-none-any.whl", hash = "sha256:3518799c9baf034eb642efb551db6b0537b8713d45a64fe4def26f7f8d6cabec", size = 192620, upload-time = "2025-12-16T16:03:53.041Z" }, -] - [[package]] name = "configobj" version = "5.0.9" @@ -1695,25 +1569,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/06/fc198cc9bc0170fcc07344c04af5de3a70a67b30aa040120f06415e76c65/cudo_compute-0.3.6-py3-none-any.whl", hash = "sha256:1b72a8f09333106fe9c350d0b35171dce2b339752036f64c38096f4e20d6b5d1", size = 380302, upload-time = "2025-01-08T16:50:45.282Z" }, ] -[[package]] -name = "cupy-cuda12x" -version = "14.0.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cuda-pathfinder", marker = "sys_platform == 'linux'" }, - { name = "numpy", marker = "sys_platform == 'linux'" }, -] -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/d9/11/6d089629f44591864bc8a11fa64c9d4fcd1afb4a7217954c806fb47c4fe5/cupy_cuda12x-14.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:31e6a33579a06fde3ff238b8b6b72446384d17554b2a3b14f818c9ee44b0c2e6", size = 146237981, upload-time = "2026-02-20T10:22:29.065Z" }, - { url = "https://files.pythonhosted.org/packages/37/f0/0f1d79c0c7fccbc2ed0c0ff3be1b0562be60b764c729ca8ded1bd6d953aa/cupy_cuda12x-14.0.1-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bfbde2e9f7946021b49414f9c800991163f2a56a1318f3d7d69cbb06001a1585", size = 135080693, upload-time = "2026-02-20T10:22:35.843Z" }, - { url = "https://files.pythonhosted.org/packages/38/ca/b93ef9fca1471a65f136a73e10819634c0b83427362fc08fc9f29f935bf0/cupy_cuda12x-14.0.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:f244bc14fad6f1ef0c74abd98afa4b82d2534aecdba911197810ec0047f0d1f3", size = 145578614, upload-time = "2026-02-20T10:22:49.108Z" }, - { url = "https://files.pythonhosted.org/packages/5a/a6/944406223a190815d9df156a1d66f3b0352bd8827dc4a8c752196d616dbc/cupy_cuda12x-14.0.1-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:9f0c81c3509f77be3ae8444759d5b314201b2dfcbbf2ae0d0b5fb7a61f20893c", size = 134613763, upload-time = "2026-02-20T10:22:56.792Z" }, - { url = "https://files.pythonhosted.org/packages/99/67/f967c5aff77bd6ae6765faf20580db80bb8a7e2574e999166de1d4e50146/cupy_cuda12x-14.0.1-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:9d9b1bdcf9fa777593017867e8733192c071b94639a1b3e8b2ee99eb3f3ea760", size = 145128055, upload-time = "2026-02-20T10:23:08.765Z" }, - { url = "https://files.pythonhosted.org/packages/80/53/037c931731151c504cfc00069eb295c903927c92145115623f13bd2ea076/cupy_cuda12x-14.0.1-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:21fcb4e917e43237edcc5e3a1a1241e2a2946ba9e577ce36fd580bd9856f91e8", size = 134227269, upload-time = "2026-02-20T10:23:16.147Z" }, - { url = "https://files.pythonhosted.org/packages/5d/cb/ba61bcd602856aeabf362280cb3c17ed5fe03ae23e84578eb99f5245546c/cupy_cuda12x-14.0.1-cp314-cp314-manylinux2014_aarch64.whl", hash = "sha256:3be87da86d808d9fec23b0a1df001f15f8f145698bc4bebc6d6938fa7e11519f", size = 144976386, upload-time = "2026-02-20T10:23:29.877Z" }, - { url = "https://files.pythonhosted.org/packages/ba/73/34e5f334f6b1e5c5dff80af8109979fb0e8461b27e4454517e0e47486455/cupy_cuda12x-14.0.1-cp314-cp314-manylinux2014_x86_64.whl", hash = "sha256:fa356384760e01498d010af2d96de536ef3dad19db1d3a1ad0764e4323fb919f", size = 133521354, upload-time = "2026-02-20T10:23:37.063Z" }, -] - [[package]] name = "cut-cross-entropy" version = "25.1.1" @@ -1841,19 +1696,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" }, ] -[[package]] -name = "depyf" -version = "0.20.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "astor", marker = "sys_platform == 'linux'" }, - { name = "dill", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/88/35/83fb0178212279aa0af031031905804c6de5618435d229f41ed21bb9ad2c/depyf-0.20.0.tar.gz", hash = "sha256:fb7683bd72c44f67b56029df2c47721e9a02ffa4d7b19095f1c54c4ebf797a98", size = 6168761, upload-time = "2025-10-13T12:33:38.589Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/cf/65/4df6936130b56e1429114e663e7c1576cf845f3aef1b2dd200c0a5d19dba/depyf-0.20.0-py3-none-any.whl", hash = "sha256:d31effad4261cebecb58955d832e448ace88f432328f95f82fd99c30fd9308d4", size = 39381, upload-time = "2025-10-13T12:33:33.647Z" }, -] - [[package]] name = "diffusers" version = "0.37.0" @@ -1883,15 +1725,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/3d/9373ad9c56321fdab5b41197068e1d8c25883b3fea29dd361f9b55116869/dill-0.4.0-py3-none-any.whl", hash = "sha256:44f54bf6412c2c8464c14e8243eb163690a9800dbe2c367330883b19c7561049", size = 119668, upload-time = "2025-04-16T00:41:47.671Z" }, ] -[[package]] -name = "diskcache" -version = "5.6.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" }, -] - [[package]] name = "diskcache-weave" version = "5.6.3.post1" @@ -2129,16 +1962,6 @@ all = [ { name = "pyyaml" }, { name = "uvicorn", extra = ["standard"] }, ] -standard = [ - { name = "email-validator", marker = "sys_platform == 'linux'" }, - { name = "fastapi-cli", extra = ["standard"], marker = "sys_platform == 'linux'" }, - { name = "httpx", marker = "sys_platform == 'linux'" }, - { name = "jinja2", marker = "sys_platform == 'linux'" }, - { name = "pydantic-extra-types", marker = "sys_platform == 'linux'" }, - { name = "pydantic-settings", marker = "sys_platform == 'linux'" }, - { name = "python-multipart", marker = "sys_platform == 'linux'" }, - { name = "uvicorn", extra = ["standard"], marker = "sys_platform == 'linux'" }, -] [[package]] name = "fastapi-cli" @@ -2643,21 +2466,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" }, ] -[[package]] -name = "gguf" -version = "0.18.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", marker = "sys_platform == 'linux'" }, - { name = "pyyaml", marker = "sys_platform == 'linux'" }, - { name = "requests", marker = "sys_platform == 'linux'" }, - { name = "tqdm", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/3f/26/7622a41c39db9d7090225a4bf8368550e59694dcf7313b44f9a82b501209/gguf-0.18.0.tar.gz", hash = "sha256:b4659093d5d0dccdb5902a904d54b327f4052879fe5e90946ad5fce9f8018c2e", size = 107170, upload-time = "2026-02-27T15:05:39.254Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/0c/e0f1eae7535a97476fb903f65301e35da2a66182b8161066b7eb312b2cb8/gguf-0.18.0-py3-none-any.whl", hash = "sha256:af93f7ef198a265cbde5fa6a6b3101528bca285903949ab0a3e591cd993a1864", size = 114244, upload-time = "2026-02-27T15:05:37.991Z" }, -] - [[package]] name = "gitdb" version = "4.0.12" @@ -2992,19 +2800,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/48/b2/b096ccce418882fbfda4f7496f9357aaa9a5af1896a9a7f60d9f2b275a06/grpcio-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:dce09d6116df20a96acfdbf85e4866258c3758180e8c49845d6ba8248b6d0bbb", size = 4929852, upload-time = "2026-02-06T09:56:45.885Z" }, ] -[[package]] -name = "grpcio-reflection" -version = "1.71.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "grpcio", marker = "sys_platform == 'linux'" }, - { name = "protobuf", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/41/14/4e5f8e902fa9461abae292773b921a578f68333c7c3e731bcff7514f78cd/grpcio_reflection-1.71.2.tar.gz", hash = "sha256:bedfac3d2095d6c066b16b66bfce85b4be3e92dc9f3b7121e6f019d24a9c09c0", size = 18798, upload-time = "2025-06-28T04:24:06.019Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a3/89/c99ff79b90315cf47dbcdd86babb637764e5f14f523d622020bfee57dc4d/grpcio_reflection-1.71.2-py3-none-any.whl", hash = "sha256:c4f1a0959acb94ec9e1369bb7dab827cc9a6efcc448bdb10436246c8e52e2f57", size = 22684, upload-time = "2025-06-28T04:23:44.759Z" }, -] - [[package]] name = "gunicorn" version = "25.1.0" @@ -3250,15 +3045,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/8d/85c9701e9af72ca132a1783e2a54364a90c6da832304416a30fc11196ab2/httpx_aiohttp-0.1.12-py3-none-any.whl", hash = "sha256:5b0eac39a7f360fa7867a60bcb46bb1024eada9c01cbfecdb54dc1edb3fb7141", size = 6367, upload-time = "2025-12-12T10:12:14.018Z" }, ] -[[package]] -name = "httpx-sse" -version = "0.4.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, -] - [[package]] name = "huey" version = "2.6.0" @@ -3488,15 +3274,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ce/ff/3b59672c47c6284e8005b42e84ceba13864aa0f39f067c973d1af02f5d91/InquirerPy-0.3.4-py3-none-any.whl", hash = "sha256:c65fdfbac1fa00e3ee4fb10679f4d3ed7a012abf4833910e63c295827fe2a7d4", size = 67677, upload-time = "2022-06-27T23:11:17.723Z" }, ] -[[package]] -name = "interegular" -version = "0.3.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/dc/9d/8b6dde58a028a3962ce17e84d5fe73758df61378e00ef8ac3d85da34b0ff/interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600", size = 24705, upload-time = "2024-01-06T23:01:22.372Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/01/72d6472f80651673716d1deda2a5bbb633e563ecf94f4479da5519d69d25/interegular-0.3.3-py37-none-any.whl", hash = "sha256:b0c07007d48c89d6d19f7204972d369b2a77222722e126b6aa63aa721dc3b19c", size = 23635, upload-time = "2024-01-06T23:01:20.829Z" }, -] - [[package]] name = "intervaltree" version = "3.2.1" @@ -3920,22 +3697,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/4a/4a/cf14bf3b1f5ffb13c69cf5f0ea78031247790558ee88984a8bdd22fae60d/kaitaistruct-0.11-py2.py3-none-any.whl", hash = "sha256:5c6ce79177b4e193a577ecd359e26516d1d6d000a0bffd6e1010f2a46a62a561", size = 11372, upload-time = "2025-09-08T15:46:23.635Z" }, ] -[[package]] -name = "kaldi-native-fbank" -version = "1.22.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3a/2c/84076b352107ce12d56f28c313f1aca1be332d953dd96aec7b84976e6d53/kaldi-native-fbank-1.22.3.tar.gz", hash = "sha256:387bf87225c6b83c93ae652eeaef1b4d531994b6e398e7a77189de340674f9af", size = 71013, upload-time = "2025-10-09T02:31:21.487Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/53/720ffbe8b30de203570f397866334eb4c6364c9214699010f2086de911ff/kaldi_native_fbank-1.22.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48e5dd8e897bf4509be2c6eeb4bbab728eaaef1f214ae0510c96219c4253d17", size = 299054, upload-time = "2025-10-09T02:28:42.011Z" }, - { url = "https://files.pythonhosted.org/packages/52/3f/beb161e4fdf6710938ccf18418c147d87ba8f102903d6c6e4eda25588e22/kaldi_native_fbank-1.22.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ce84c65779c9eed6ec02699797a4ba1859451977537a993be3ea8167a210ec3e", size = 321921, upload-time = "2025-10-09T02:31:21.646Z" }, - { url = "https://files.pythonhosted.org/packages/43/28/6f4fd8953c0b3f30de4526fd024095032abcdc25b6736c77a891687c604e/kaldi_native_fbank-1.22.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f5a44b4a83cf9bf13d3f77858928068b06d3ec2238c27ff2e39393fbf7749c9f", size = 298887, upload-time = "2025-10-09T02:30:53.739Z" }, - { url = "https://files.pythonhosted.org/packages/84/90/01ef7331c52b1eaf9916f3f7a535155aac2e9e2ddad12a141613d92758c7/kaldi_native_fbank-1.22.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f16e74372fe9e20abb4183f98a8e2288d5ee4c48d04d94b6160311170e007661", size = 322002, upload-time = "2025-10-09T02:30:13.04Z" }, - { url = "https://files.pythonhosted.org/packages/9a/72/adb11d27c545aca1db442da744ee430a6aae377a33574bfd2ec159dcf673/kaldi_native_fbank-1.22.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f74b85948328ab4b4c88522f98a59f83dd5295443b08483e945c7de2c35e5dcc", size = 299276, upload-time = "2025-10-09T02:30:38.1Z" }, - { url = "https://files.pythonhosted.org/packages/bc/1e/496c7ae814b2a7f8f47d423dc33aae2cdfb1edf898e2faaf5c5b39b90363/kaldi_native_fbank-1.22.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3f9c6551ff5b6ae785dd15f819c3b2b7432d77bfb79ea8806748e2c7d900b5d", size = 322714, upload-time = "2025-10-09T02:30:32.698Z" }, - { url = "https://files.pythonhosted.org/packages/d6/4b/1f3f17a7b601124df88112a1d1fcb543c8d908d6674f752f7d3322991770/kaldi_native_fbank-1.22.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:41fb506fde155d97aeef95dd6ceccc38c2c5dd4401f9b8fded9bacaf1bafef36", size = 300037, upload-time = "2025-10-09T02:30:10.203Z" }, - { url = "https://files.pythonhosted.org/packages/2b/6a/374ec4e1cf13e672f5acd8272116c1885c2a7f84be491fc652415fc6e870/kaldi_native_fbank-1.22.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f1cc2b8eeec52a33868cf59bb95d40b335fa9cff7e15a6208e0e9b67b7fd7236", size = 322854, upload-time = "2025-10-09T02:31:26.003Z" }, -] - [[package]] name = "keyring" version = "25.7.0" @@ -4221,54 
+3982,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/89/eb28bfcf97d6b045c400e72eb047c381594467048c237dbb6c227764084c/litellm-1.82.0-py3-none-any.whl", hash = "sha256:5496b5d4532cccdc7a095c21cbac4042f7662021c57bc1d17be4e39838929e80", size = 14911978, upload-time = "2026-03-01T02:35:26.844Z" }, ] -[[package]] -name = "llguidance" -version = "1.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/95/48/3f7a9d3ff1b36bba92b5107a3a21286821227afe9ea464736133994d61fb/llguidance-1.3.0.tar.gz", hash = "sha256:861249afd51dc325646834462ea827e57a5c2b2042e108e6aae7059fdad9104d", size = 1070460, upload-time = "2025-10-20T19:58:44.164Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/11/44389d3d1526d7a5c38ffd587a5ebc61d7bee443ac1dea95f2089ad58f5f/llguidance-1.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f6caca5d78db7f76e1fbb0fff8607b861c32d47fa3d5dee2fc49de27ee269df", size = 2835242, upload-time = "2025-10-20T19:58:34.518Z" }, - { url = "https://files.pythonhosted.org/packages/83/a8/1ff2bedb8f9acb46a2d2d603415d272bb622c142ea86f5b95445cc6e366c/llguidance-1.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc17e9dd602c3879bf91664a64bf72f54c74dbfbeb24ccfab6a5fe435b12f7aa", size = 3033133, upload-time = "2025-10-20T19:58:38.721Z" }, -] - -[[package]] -name = "llvmlite" -version = "0.44.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/89/6a/95a3d3610d5c75293d5dbbb2a76480d5d4eeba641557b69fe90af6c5b84e/llvmlite-0.44.0.tar.gz", hash = "sha256:07667d66a5d150abed9157ab6c0b9393c9356f229784a4385c02f99e94fc94d4", size = 171880, upload-time = "2025-01-20T11:14:41.342Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/99/fe/d030f1849ebb1f394bb3f7adad5e729b634fb100515594aca25c354ffc62/llvmlite-0.44.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5d22c3bfc842668168a786af4205ec8e3ad29fb1bc03fd11fd48460d0df64c1", size = 42361858, upload-time = "2025-01-20T11:13:07.623Z" }, - { url = "https://files.pythonhosted.org/packages/d7/7a/ce6174664b9077fc673d172e4c888cb0b128e707e306bc33fff8c2035f0d/llvmlite-0.44.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f01a394e9c9b7b1d4e63c327b096d10f6f0ed149ef53d38a09b3749dcf8c9610", size = 41184200, upload-time = "2025-01-20T11:13:20.058Z" }, - { url = "https://files.pythonhosted.org/packages/cb/da/8341fd3056419441286c8e26bf436923021005ece0bff5f41906476ae514/llvmlite-0.44.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0143a5ef336da14deaa8ec26c5449ad5b6a2b564df82fcef4be040b9cacfea9", size = 42361901, upload-time = "2025-01-20T11:13:46.711Z" }, - { url = "https://files.pythonhosted.org/packages/53/ad/d79349dc07b8a395a99153d7ce8b01d6fcdc9f8231355a5df55ded649b61/llvmlite-0.44.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d752f89e31b66db6f8da06df8b39f9b91e78c5feea1bf9e8c1fba1d1c24c065d", size = 41184247, upload-time = "2025-01-20T11:13:56.159Z" }, - { url = "https://files.pythonhosted.org/packages/d2/1b/656f5a357de7135a3777bd735cc7c9b8f23b4d37465505bd0eaf4be9befe/llvmlite-0.44.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46224058b13c96af1365290bdfebe9a6264ae62fb79b2b55693deed11657a8bf", size = 42361904, upload-time = "2025-01-20T11:14:22.949Z" }, - { url = 
"https://files.pythonhosted.org/packages/d8/e1/12c5f20cb9168fb3464a34310411d5ad86e4163c8ff2d14a2b57e5cc6bac/llvmlite-0.44.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aa0097052c32bf721a4efc03bd109d335dfa57d9bffb3d4c24cc680711b8b4fc", size = 41184245, upload-time = "2025-01-20T11:14:31.731Z" }, -] - -[[package]] -name = "lm-format-enforcer" -version = "0.11.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "interegular", marker = "sys_platform == 'linux'" }, - { name = "packaging", marker = "sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'linux'" }, - { name = "pyyaml", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/84/d5/41cd417ba7dfdbbcfe46cebf81fb3dfd7c591b89897560ad05bb410a465d/lm_format_enforcer-0.11.3.tar.gz", hash = "sha256:e68081c108719cce284a9bcc889709b26ffb085a1945b5eba3a12cfa96d528da", size = 40258, upload-time = "2025-08-24T19:37:47.527Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/ef/11292bb0b85cf4c93447cab5a29f64576ed14d3ab4280e35ddd23486594a/lm_format_enforcer-0.11.3-py3-none-any.whl", hash = "sha256:cf586350875def1ae7a8fba84fcbbfc8371424b6c9d05c1fcba70aa233fbf06f", size = 45418, upload-time = "2025-08-24T19:37:46.325Z" }, -] - -[[package]] -name = "loguru" -version = "0.7.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3a/05/a1dae3dffd1116099471c643b8924f5aa6524411dc6c63fdae648c4f1aca/loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6", size = 63559, upload-time = "2024-12-06T11:20:56.608Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/29/0348de65b8cc732daa3e33e67806420b2ae89bdce2b04af740289c5c6c8c/loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c", size = 61595, upload-time = "2024-12-06T11:20:54.538Z" }, -] - [[package]] name = "lxml" version = "6.1.0" @@ -4620,30 +4333,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" }, ] -[[package]] -name = "mcp" -version = "1.26.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio", marker = "sys_platform == 'linux'" }, - { name = "httpx", marker = "sys_platform == 'linux'" }, - { name = "httpx-sse", marker = "sys_platform == 'linux'" }, - { name = "jsonschema", marker = "sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'linux'" }, - { name = "pydantic-settings", marker = "sys_platform == 'linux'" }, - { name = "pyjwt", extra = ["crypto"], marker = "sys_platform == 'linux'" }, - { name = "python-multipart", marker = "sys_platform == 'linux'" }, - { name = "sse-starlette", marker = "sys_platform == 'linux'" }, - { name = "starlette", marker = "sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'linux'" }, - { name = "typing-inspection", marker = "sys_platform == 'linux'" }, - { name = "uvicorn", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fc/6d/62e76bbb8144d6ed86e202b5edd8a4cb631e7c8130f3f4893c3f90262b10/mcp-1.26.0.tar.gz", hash = 
"sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66", size = 608005, upload-time = "2026-01-24T19:40:32.468Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" }, -] - [[package]] name = "mdurl" version = "0.1.2" @@ -4771,30 +4460,6 @@ av-decode = [ { name = "soundfile" }, ] -[[package]] -name = "mistral-common" -version = "1.10.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jsonschema", marker = "sys_platform == 'linux'" }, - { name = "numpy", marker = "sys_platform == 'linux'" }, - { name = "pillow", marker = "sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'linux'" }, - { name = "pydantic-extra-types", extra = ["pycountry"], marker = "sys_platform == 'linux'" }, - { name = "requests", marker = "sys_platform == 'linux'" }, - { name = "tiktoken", marker = "sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a7/22/f798c1acc3f8cf32b6201b063d96867d79aa39d31dff12478739e1a78979/mistral_common-1.10.0.tar.gz", hash = "sha256:e456ff101edbdfc094039ec6c26f7d0f73356729798d628a6e6e96c3917147bc", size = 6351515, upload-time = "2026-03-13T10:13:46.683Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/c6/1429a0a3ab40f8530492b62b52eb792266c261b22ed62aa7f25d61d531ae/mistral_common-1.10.0-py3-none-any.whl", hash = "sha256:c594d1a05202b61e8f0d867ec6064df4c5e5d492c2c2bdb6fd8fb4872c6afd8b", size = 6525284, upload-time = "2026-03-13T10:13:44.329Z" }, -] - -[package.optional-dependencies] -image = [ - { name = "opencv-python-headless", marker = "sys_platform == 'linux'" }, -] - [[package]] name = "ml-dtypes" version = "0.5.4" @@ -4915,24 +4580,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9a/7ac1db2ed7b5e21c50fadf925a53f0c77452a8a855ee4a119b084c2fa5d3/mlflow_tracing-3.10.1-py3-none-any.whl", hash = "sha256:649c722cc58d54f1f40559023a6bd6f3f08150c3ce3c3bb27972b3e795890f47", size = 1495173, upload-time = "2026-03-05T10:46:27.395Z" }, ] -[[package]] -name = "model-hosting-container-standards" -version = "0.1.14" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "fastapi", marker = "sys_platform == 'linux'" }, - { name = "httpx", marker = "sys_platform == 'linux'" }, - { name = "jmespath", marker = "sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'linux'" }, - { name = "setuptools", marker = "sys_platform == 'linux'" }, - { name = "starlette", marker = "sys_platform == 'linux'" }, - { name = "supervisor", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/3d/cf5c6029648cb0a116f7b5c2f74aa155ab0c6dd723a1f204a6d7ff354526/model_hosting_container_standards-0.1.14.tar.gz", hash = "sha256:b6cf4c46d88ce6acd6e543a578bb88ffd55d1179a7c09c22e61ae1d8a567c564", size = 90386, upload-time = "2026-03-18T21:25:14.513Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/48/94/052452842d39c562237a70345c57ec213a9db22bd25bba998fd2b32d70a7/model_hosting_container_standards-0.1.14-py3-none-any.whl", hash = "sha256:d678be6745899b8ba1e8246c96b101e7802a6a4ea3fb5d90ae8d6eb4204e84c6", size = 121406, upload-time = "2026-03-18T21:25:12.932Z" }, -] 
- [[package]] name = "more-itertools" version = "10.8.0" @@ -4977,34 +4624,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/75/bd9b7bb966668920f06b200e84454c8f3566b102183bc55c5473d96cb2b9/msal_extensions-1.3.1-py3-none-any.whl", hash = "sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca", size = 20583, upload-time = "2025-03-14T23:51:03.016Z" }, ] -[[package]] -name = "msgpack" -version = "1.1.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4d/f2/bfb55a6236ed8725a96b0aa3acbd0ec17588e6a2c3b62a93eb513ed8783f/msgpack-1.1.2.tar.gz", hash = "sha256:3b60763c1373dd60f398488069bcdc703cd08a711477b5d480eecc9f9626f47e", size = 173581, upload-time = "2025-10-08T09:15:56.596Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/71/46/b817349db6886d79e57a966346cf0902a426375aadc1e8e7a86a75e22f19/msgpack-1.1.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:61c8aa3bd513d87c72ed0b37b53dd5c5a0f58f2ff9f26e1555d3bd7948fb7296", size = 416962, upload-time = "2025-10-08T09:14:51.997Z" }, - { url = "https://files.pythonhosted.org/packages/da/e0/6cc2e852837cd6086fe7d8406af4294e66827a60a4cf60b86575a4a65ca8/msgpack-1.1.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:454e29e186285d2ebe65be34629fa0e8605202c60fbc7c4c650ccd41870896ef", size = 426183, upload-time = "2025-10-08T09:14:53.477Z" }, - { url = "https://files.pythonhosted.org/packages/25/98/6a19f030b3d2ea906696cedd1eb251708e50a5891d0978b012cb6107234c/msgpack-1.1.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7bc8813f88417599564fafa59fd6f95be417179f76b40325b500b3c98409757c", size = 411454, upload-time = "2025-10-08T09:14:54.648Z" }, - { url = "https://files.pythonhosted.org/packages/b7/cd/9098fcb6adb32187a70b7ecaabf6339da50553351558f37600e53a4a2a23/msgpack-1.1.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bafca952dc13907bdfdedfc6a5f579bf4f292bdd506fadb38389afa3ac5b208e", size = 422341, upload-time = "2025-10-08T09:14:56.328Z" }, - { url = "https://files.pythonhosted.org/packages/f2/60/a064b0345fc36c4c3d2c743c82d9100c40388d77f0b48b2f04d6041dbec1/msgpack-1.1.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c63eea553c69ab05b6747901b97d620bb2a690633c77f23feb0c6a947a8a7b8f", size = 417131, upload-time = "2025-10-08T09:15:05.136Z" }, - { url = "https://files.pythonhosted.org/packages/65/92/a5100f7185a800a5d29f8d14041f61475b9de465ffcc0f3b9fba606e4505/msgpack-1.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:372839311ccf6bdaf39b00b61288e0557916c3729529b301c52c2d88842add42", size = 427556, upload-time = "2025-10-08T09:15:06.837Z" }, - { url = "https://files.pythonhosted.org/packages/f5/87/ffe21d1bf7d9991354ad93949286f643b2bb6ddbeab66373922b44c3b8cc/msgpack-1.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2929af52106ca73fcb28576218476ffbb531a036c2adbcf54a3664de124303e9", size = 404920, upload-time = "2025-10-08T09:15:08.179Z" }, - { url = "https://files.pythonhosted.org/packages/ff/41/8543ed2b8604f7c0d89ce066f42007faac1eaa7d79a81555f206a5cdb889/msgpack-1.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:be52a8fc79e45b0364210eef5234a7cf8d330836d0a64dfbb878efa903d84620", size = 415013, upload-time = "2025-10-08T09:15:09.83Z" }, - { url = 
"https://files.pythonhosted.org/packages/d3/68/93180dce57f684a61a88a45ed13047558ded2be46f03acb8dec6d7c513af/msgpack-1.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1fdf7d83102bf09e7ce3357de96c59b627395352a4024f6e2458501f158bf999", size = 412721, upload-time = "2025-10-08T09:15:16.567Z" }, - { url = "https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fac4be746328f90caa3cd4bc67e6fe36ca2bf61d5c6eb6d895b6527e3f05071e", size = 424657, upload-time = "2025-10-08T09:15:17.825Z" }, - { url = "https://files.pythonhosted.org/packages/38/f8/4398c46863b093252fe67368b44edc6c13b17f4e6b0e4929dbf0bdb13f23/msgpack-1.1.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:fffee09044073e69f2bad787071aeec727183e7580443dfeb8556cbf1978d162", size = 402668, upload-time = "2025-10-08T09:15:19.003Z" }, - { url = "https://files.pythonhosted.org/packages/28/ce/698c1eff75626e4124b4d78e21cca0b4cc90043afb80a507626ea354ab52/msgpack-1.1.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5928604de9b032bc17f5099496417f113c45bc6bc21b5c6920caf34b3c428794", size = 419040, upload-time = "2025-10-08T09:15:20.183Z" }, - { url = "https://files.pythonhosted.org/packages/8e/a9/3536e385167b88c2cc8f4424c49e28d49a6fc35206d4a8060f136e71f94c/msgpack-1.1.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99e2cb7b9031568a2a5c73aa077180f93dd2e95b4f8d3b8e14a73ae94a9e667e", size = 411885, upload-time = "2025-10-08T09:15:27.22Z" }, - { url = "https://files.pythonhosted.org/packages/2f/40/dc34d1a8d5f1e51fc64640b62b191684da52ca469da9cd74e84936ffa4a6/msgpack-1.1.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:180759d89a057eab503cf62eeec0aa61c4ea1200dee709f3a8e9397dbb3b6931", size = 419658, upload-time = "2025-10-08T09:15:28.4Z" }, - { url = "https://files.pythonhosted.org/packages/3b/ef/2b92e286366500a09a67e03496ee8b8ba00562797a52f3c117aa2b29514b/msgpack-1.1.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:04fb995247a6e83830b62f0b07bf36540c213f6eac8e851166d8d86d83cbd014", size = 403290, upload-time = "2025-10-08T09:15:29.764Z" }, - { url = "https://files.pythonhosted.org/packages/78/90/e0ea7990abea5764e4655b8177aa7c63cdfa89945b6e7641055800f6c16b/msgpack-1.1.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8e22ab046fa7ede9e36eeb4cfad44d46450f37bb05d5ec482b02868f451c95e2", size = 415234, upload-time = "2025-10-08T09:15:31.022Z" }, - { url = "https://files.pythonhosted.org/packages/fc/6b/62e85ff7193663fbea5c0254ef32f0c77134b4059f8da89b958beb7696f3/msgpack-1.1.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5559d03930d3aa0f3aacb4c42c776af1a2ace2611871c84a75afe436695e6245", size = 435242, upload-time = "2025-10-08T09:15:37.647Z" }, - { url = "https://files.pythonhosted.org/packages/c1/47/5c74ecb4cc277cf09f64e913947871682ffa82b3b93c8dad68083112f412/msgpack-1.1.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:70c5a7a9fea7f036b716191c29047374c10721c389c21e9ffafad04df8c52c90", size = 432509, upload-time = "2025-10-08T09:15:38.794Z" }, - { url = 
"https://files.pythonhosted.org/packages/24/a4/e98ccdb56dc4e98c929a3f150de1799831c0a800583cde9fa022fa90602d/msgpack-1.1.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f2cb069d8b981abc72b41aea1c580ce92d57c673ec61af4c500153a626cb9e20", size = 415957, upload-time = "2025-10-08T09:15:40.238Z" }, - { url = "https://files.pythonhosted.org/packages/da/28/6951f7fb67bc0a4e184a6b38ab71a92d9ba58080b27a77d3e2fb0be5998f/msgpack-1.1.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d62ce1f483f355f61adb5433ebfd8868c5f078d1a52d042b0a998682b4fa8c27", size = 422910, upload-time = "2025-10-08T09:15:41.505Z" }, -] - [[package]] name = "msgspec" version = "0.20.0" @@ -5348,24 +4967,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/93/a7b983643d1253bb223234b5b226e69de6cda02b76cdca7770f684b795f5/ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9", size = 290806, upload-time = "2025-08-11T15:10:18.018Z" }, ] -[[package]] -name = "numba" -version = "0.61.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "llvmlite", marker = "sys_platform == 'linux'" }, - { name = "numpy", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1c/a0/e21f57604304aa03ebb8e098429222722ad99176a4f979d34af1d1ee80da/numba-0.61.2.tar.gz", hash = "sha256:8750ee147940a6637b80ecf7f95062185ad8726c8c28a2295b8ec1160a196f7d", size = 2820615, upload-time = "2025-04-09T02:58:07.659Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/97/c8/8740616c8436c86c1b9a62e72cb891177d2c34c2d24ddcde4c390371bf4c/numba-0.61.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3945615cd73c2c7eba2a85ccc9c1730c21cd3958bfcf5a44302abae0fb07bb60", size = 3829227, upload-time = "2025-04-09T02:57:46.63Z" }, - { url = "https://files.pythonhosted.org/packages/fc/06/66e99ae06507c31d15ff3ecd1f108f2f59e18b6e08662cd5f8a5853fbd18/numba-0.61.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:bbfdf4eca202cebade0b7d43896978e146f39398909a42941c9303f82f403a18", size = 3523422, upload-time = "2025-04-09T02:57:48.222Z" }, - { url = "https://files.pythonhosted.org/packages/9a/2d/e518df036feab381c23a624dac47f8445ac55686ec7f11083655eb707da3/numba-0.61.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b1bb509d01f23d70325d3a5a0e237cbc9544dd50e50588bc581ba860c213546", size = 3885928, upload-time = "2025-04-09T02:57:55.206Z" }, - { url = "https://files.pythonhosted.org/packages/10/0f/23cced68ead67b75d77cfcca3df4991d1855c897ee0ff3fe25a56ed82108/numba-0.61.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:48a53a3de8f8793526cbe330f2a39fe9a6638efcbf11bd63f3d2f9757ae345cd", size = 3577115, upload-time = "2025-04-09T02:57:56.818Z" }, - { url = "https://files.pythonhosted.org/packages/0d/e0/5ea04e7ad2c39288c0f0f9e8d47638ad70f28e275d092733b5817cf243c9/numba-0.61.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bdbca73ad81fa196bd53dc12e3aaf1564ae036e0c125f237c7644fe64a4928ab", size = 3893918, upload-time = "2025-04-09T02:58:02.933Z" }, - { url = "https://files.pythonhosted.org/packages/17/58/064f4dcb7d7e9412f16ecf80ed753f92297e39f399c905389688cf950b81/numba-0.61.2-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5f154aaea625fb32cfbe3b80c5456d514d416fcdf79733dd69c0df3a11348e9e", size = 3584056, upload-time = "2025-04-09T02:58:04.538Z" }, -] - [[package]] name = "numpy" version = "1.26.4" @@ -5586,10 
+5187,11 @@ wheels = [ [[package]] name = "nvidia-nccl-cu12" -version = "2.27.5" +version = "2.28.9" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, + { url = "https://files.pythonhosted.org/packages/08/c4/120d2dfd92dff2c776d68f361ff8705fdea2ca64e20b612fab0fd3f581ac/nvidia_nccl_cu12-2.28.9-py3-none-manylinux_2_18_aarch64.whl", hash = "sha256:50a36e01c4a090b9f9c47d92cec54964de6b9fcb3362d0e19b8ffc6323c21b60", size = 296766525, upload-time = "2025-11-18T05:49:16.094Z" }, + { url = "https://files.pythonhosted.org/packages/4a/4e/44dbb46b3d1b0ec61afda8e84837870f2f9ace33c564317d59b70bc19d3e/nvidia_nccl_cu12-2.28.9-py3-none-manylinux_2_18_x86_64.whl", hash = "sha256:485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab", size = 296782137, upload-time = "2025-11-18T05:49:34.248Z" }, ] [[package]] @@ -5791,40 +5393,6 @@ aiohttp = [ { name = "httpx-aiohttp" }, ] -[[package]] -name = "openai-harmony" -version = "0.0.8" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pydantic", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/3e/92/2d038d096f29179c7c9571b431f9e739f87a487121901725e23fe338dd9d/openai_harmony-0.0.8.tar.gz", hash = "sha256:6e43f98e6c242fa2de6f8ea12eab24af63fa2ed3e89c06341fb9d92632c5cbdf", size = 284777, upload-time = "2025-11-05T19:07:06.727Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d3/d2/ce6953ca87db9cae3e775024184da7d1c5cb88cead19a2d75b42f00a959c/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4f709815924ec325b9a890e6ab2bbb0ceec8e319a4e257328eb752cf36b2efc", size = 2948463, upload-time = "2025-11-05T19:06:48.17Z" }, - { url = "https://files.pythonhosted.org/packages/fa/4c/b553c9651662d6ce102ca7f3629d268b23df1abe5841e24bed81e8a8e949/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5cfcfd963b50a41fc656c84d3440ca6eecdccd6c552158ce790b8f2e33dfb5a9", size = 2704083, upload-time = "2025-11-05T19:06:50.205Z" }, - { url = "https://files.pythonhosted.org/packages/9b/af/4eec8f9ab9c27bcdb444460c72cf43011d176fc44c79d6e113094ca1e152/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a3a16972aa1cee38ea958470cd04ac9a2d5ac38fdcf77ab686611246220c158", size = 2959765, upload-time = "2025-11-05T19:06:53.62Z" }, - { url = "https://files.pythonhosted.org/packages/11/3c/33f3374e4624e0e776f6b13b73c45a7ead7f9c4529f8369ed5bfcaa30cac/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4d5cfa168e74d08f8ba6d58a7e49bc7daef4d58951ec69b66b0d56f4927a68d", size = 3427031, upload-time = "2025-11-05T19:06:51.829Z" }, - { url = "https://files.pythonhosted.org/packages/25/3f/1a192b93bb47c6b44cd98ba8cc1d3d2a9308f1bb700c3017e6352da11bda/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c007d277218a50db8839e599ed78e0fffe5130f614c3f6d93ae257f282071a29", size = 2953260, upload-time = "2025-11-05T19:06:55.406Z" }, - { url = 
"https://files.pythonhosted.org/packages/5b/f8/93b582cad3531797c3db7c2db5400fd841538ccddfd9f5e3df61be99a630/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:8565d4f5a0638da1bffde29832ed63c9e695c558611053add3b2dc0b56c92dbc", size = 3127044, upload-time = "2025-11-05T19:06:59.553Z" }, - { url = "https://files.pythonhosted.org/packages/1d/10/4327dbf87f75ae813405fd9a9b4a5cde63d506ffed0a096a440a4cabd89c/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:cbaa3bda75ef0d8836e1f8cc84af62f971b1d756d740efc95c38c3e04c0bfde2", size = 2932931, upload-time = "2025-11-05T19:07:01.437Z" }, - { url = "https://files.pythonhosted.org/packages/8a/c8/1774eec4f6f360ef57618fb8f52e3d3af245b2491bd0297513aa09eec04b/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:772922a9bd24e133950fad71eb1550836f415a88e8c77870e12d0c3bd688ddc2", size = 2996140, upload-time = "2025-11-05T19:07:03.438Z" }, - { url = "https://files.pythonhosted.org/packages/60/c3/3d1e01e2dba517a91760e4a03e4f20ffc75039a6fe584d0e6f9b5c78fd15/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:007b0476a1f331f8130783f901f1da6f5a7057af1a4891f1b6a31dec364189b5", size = 3205080, upload-time = "2025-11-05T19:07:05.078Z" }, -] - -[[package]] -name = "opencv-python-headless" -version = "4.13.0.92" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", marker = "sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/21/76/9417a6aef9def70e467a5bf560579f816148a4c658b7d525581b356eda9e/opencv_python_headless-4.13.0.92-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c8cfc8e87ed452b5cecb9419473ee5560a989859fe1d10d1ce11ae87b09a2cb", size = 33703709, upload-time = "2026-02-05T10:24:46.469Z" }, - { url = "https://files.pythonhosted.org/packages/92/ce/bd17ff5772938267fd49716e94ca24f616ff4cb1ff4c6be13085108037be/opencv_python_headless-4.13.0.92-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0525a3d2c0b46c611e2130b5fdebc94cf404845d8fa64d2f3a3b679572a5bd22", size = 56016764, upload-time = "2026-02-05T10:26:48.904Z" }, - { url = "https://files.pythonhosted.org/packages/8f/b4/b7bcbf7c874665825a8c8e1097e93ea25d1f1d210a3e20d4451d01da30aa/opencv_python_headless-4.13.0.92-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:eb60e36b237b1ebd40a912da5384b348df8ed534f6f644d8e0b4f103e272ba7d", size = 35010236, upload-time = "2026-02-05T10:28:11.031Z" }, - { url = "https://files.pythonhosted.org/packages/4b/33/b5db29a6c00eb8f50708110d8d453747ca125c8b805bc437b289dbdcc057/opencv_python_headless-4.13.0.92-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0bd48544f77c68b2941392fcdf9bcd2b9cdf00e98cb8c29b2455d194763cf99e", size = 60391106, upload-time = "2026-02-05T10:30:14.236Z" }, -] - [[package]] name = "openpipe-art" version = "0.5.17" @@ -5851,6 +5419,7 @@ backend = [ { name = "nbclient" }, { name = "nbmake" }, { name = "nvidia-cudnn-frontend", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-resiliency-ext" }, { name = "peft" }, { name = "pyarrow" }, @@ -5862,7 +5431,6 @@ backend = [ { name = "trl" }, { name = "unsloth" }, { name = "unsloth-zoo" }, - { name = "vllm", marker = "sys_platform == 'linux'" }, { name = "wandb" }, ] langgraph = [ @@ -5880,8 +5448,10 @@ megatron = [ { name = "ml-dtypes", marker = "python_full_version < '3.13'" }, { name = "numpy" }, { name = "nvidia-ml-py" }, + { name 
= "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-resiliency-ext" }, { name = "pybind11" }, + { name = "quack-kernels" }, { name = "torch" }, { name = "transformer-engine" }, { name = "transformer-engine-cu12" }, @@ -5896,6 +5466,7 @@ tinker = [ { name = "fastapi" }, { name = "huggingface-hub" }, { name = "numpy" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" }, { name = "pillow" }, { name = "pyarrow" }, { name = "pydantic" }, @@ -5955,6 +5526,9 @@ requires-dist = [ { name = "numpy", marker = "extra == 'tinker'", specifier = "<2" }, { name = "nvidia-cudnn-frontend", marker = "sys_platform == 'linux' and extra == 'backend'", specifier = "<1.21" }, { name = "nvidia-ml-py", marker = "extra == 'megatron'", specifier = "==13.580.82" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux' and extra == 'backend'", specifier = "==2.28.9" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux' and extra == 'megatron'", specifier = "==2.28.9" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux' and extra == 'tinker'", specifier = "==2.28.9" }, { name = "nvidia-resiliency-ext", marker = "sys_platform == 'linux' and extra == 'backend'", specifier = "<0.5" }, { name = "nvidia-resiliency-ext", marker = "sys_platform == 'linux' and extra == 'megatron'", specifier = "<0.5" }, { name = "openai", specifier = ">=2.14.0" }, @@ -5966,6 +5540,7 @@ requires-dist = [ { name = "pybind11", marker = "extra == 'megatron'", specifier = ">=2.13.6" }, { name = "pydantic", marker = "extra == 'tinker'", specifier = ">=2.12.5" }, { name = "pytest", marker = "extra == 'backend'", specifier = ">=8.4.1" }, + { name = "quack-kernels", marker = "extra == 'megatron'", specifier = "==0.2.5" }, { name = "seaborn", marker = "extra == 'plotting'", specifier = ">=0.13.2" }, { name = "setproctitle", specifier = ">=1.3.6" }, { name = "setuptools", marker = "extra == 'backend'", specifier = ">=78.1.0" }, @@ -5986,7 +5561,6 @@ requires-dist = [ { name = "unsloth", marker = "extra == 'backend'", specifier = "==2026.3.3" }, { name = "unsloth-zoo", marker = "extra == 'backend'", specifier = "==2026.3.1" }, { name = "uvicorn", marker = "extra == 'tinker'", specifier = ">=0.35.0" }, - { name = "vllm", marker = "sys_platform == 'linux' and extra == 'backend'", url = "https://github.com/vivekkalyan/vllm/releases/download/v0.17.0-art1/vllm-0.17.0%2Bart1-cp38-abi3-manylinux_2_31_x86_64.whl" }, { name = "wandb", marker = "extra == 'backend'", specifier = "==0.25.0" }, { name = "weave", specifier = ">=0.52.24" }, ] @@ -6024,67 +5598,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/44/4c45a34def3506122ae61ad684139f0bbc4e00c39555d4f7e20e0e001c8a/opentelemetry_api-1.33.1-py3-none-any.whl", hash = "sha256:4db83ebcf7ea93e64637ec6ee6fabee45c5cbe4abd9cf3da95c43828ddb50b83", size = 65771, upload-time = "2025-05-16T18:52:17.419Z" }, ] -[[package]] -name = "opentelemetry-exporter-otlp" -version = "1.33.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "opentelemetry-exporter-otlp-proto-grpc", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-exporter-otlp-proto-http", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/3f/c8ad4f1c3aaadcea2b0f1b4d7970e7b7898c145699769a789f3435143f69/opentelemetry_exporter_otlp-1.33.1.tar.gz", hash = "sha256:4d050311ea9486e3994575aa237e32932aad58330a31fba24fdba5c0d531cf04", size = 6189, upload-time = "2025-05-16T18:52:43.176Z" 
} -wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/32/b9add70dd4e845654fc9fcd1401a705477743880be6c3e62acb1ad0d8662/opentelemetry_exporter_otlp-1.33.1-py3-none-any.whl", hash = "sha256:9bcf1def35b880b55a49e31ebd63910edac14b294fd2ab884953c4deaff5b300", size = 7045, upload-time = "2025-05-16T18:52:21.022Z" }, -] - -[[package]] -name = "opentelemetry-exporter-otlp-proto-common" -version = "1.33.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "opentelemetry-proto", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/7a/18/a1ec9dcb6713a48b4bdd10f1c1e4d5d2489d3912b80d2bcc059a9a842836/opentelemetry_exporter_otlp_proto_common-1.33.1.tar.gz", hash = "sha256:c57b3fa2d0595a21c4ed586f74f948d259d9949b58258f11edb398f246bec131", size = 20828, upload-time = "2025-05-16T18:52:43.795Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/09/52/9bcb17e2c29c1194a28e521b9d3f2ced09028934c3c52a8205884c94b2df/opentelemetry_exporter_otlp_proto_common-1.33.1-py3-none-any.whl", hash = "sha256:b81c1de1ad349785e601d02715b2d29d6818aed2c809c20219f3d1f20b038c36", size = 18839, upload-time = "2025-05-16T18:52:22.447Z" }, -] - -[[package]] -name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.33.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "deprecated", marker = "sys_platform == 'linux'" }, - { name = "googleapis-common-protos", marker = "sys_platform == 'linux'" }, - { name = "grpcio", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-api", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-exporter-otlp-proto-common", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-proto", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-sdk", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d8/5f/75ef5a2a917bd0e6e7b83d3fb04c99236ee958f6352ba3019ea9109ae1a6/opentelemetry_exporter_otlp_proto_grpc-1.33.1.tar.gz", hash = "sha256:345696af8dc19785fac268c8063f3dc3d5e274c774b308c634f39d9c21955728", size = 22556, upload-time = "2025-05-16T18:52:44.76Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ba/ec/6047e230bb6d092c304511315b13893b1c9d9260044dd1228c9d48b6ae0e/opentelemetry_exporter_otlp_proto_grpc-1.33.1-py3-none-any.whl", hash = "sha256:7e8da32c7552b756e75b4f9e9c768a61eb47dee60b6550b37af541858d669ce1", size = 18591, upload-time = "2025-05-16T18:52:23.772Z" }, -] - -[[package]] -name = "opentelemetry-exporter-otlp-proto-http" -version = "1.33.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "deprecated", marker = "sys_platform == 'linux'" }, - { name = "googleapis-common-protos", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-api", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-exporter-otlp-proto-common", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-proto", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-sdk", marker = "sys_platform == 'linux'" }, - { name = "requests", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/60/48/e4314ac0ed2ad043c07693d08c9c4bf5633857f5b72f2fefc64fd2b114f6/opentelemetry_exporter_otlp_proto_http-1.33.1.tar.gz", hash = "sha256:46622d964a441acb46f463ebdc26929d9dec9efb2e54ef06acdc7305e8593c38", size = 15353, upload-time = "2025-05-16T18:52:45.522Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/63/ba/5a4ad007588016fe37f8d36bf08f325fe684494cc1e88ca8fa064a4c8f57/opentelemetry_exporter_otlp_proto_http-1.33.1-py3-none-any.whl", hash = "sha256:ebd6c523b89a2ecba0549adb92537cc2bf647b4ee61afbbd5a4c6535aa3da7cf", size = 17733, upload-time = "2025-05-16T18:52:25.137Z" }, -] - [[package]] name = "opentelemetry-proto" version = "1.33.1" @@ -6124,15 +5637,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/80/08b1698c52ff76d96ba440bf15edc2f4bc0a279868778928e947c1004bdd/opentelemetry_semantic_conventions-0.54b1-py3-none-any.whl", hash = "sha256:29dab644a7e435b58d3a3918b58c333c92686236b30f7891d5e51f02933ca60d", size = 194938, upload-time = "2025-05-16T18:52:38.796Z" }, ] -[[package]] -name = "opentelemetry-semantic-conventions-ai" -version = "0.4.13" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ba/e6/40b59eda51ac47009fb47afcdf37c6938594a0bd7f3b9fadcbc6058248e3/opentelemetry_semantic_conventions_ai-0.4.13.tar.gz", hash = "sha256:94efa9fb4ffac18c45f54a3a338ffeb7eedb7e1bb4d147786e77202e159f0036", size = 5368, upload-time = "2025-08-22T10:14:17.387Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/35/b5/cf25da2218910f0d6cdf7f876a06bed118c4969eacaf60a887cbaef44f44/opentelemetry_semantic_conventions_ai-0.4.13-py3-none-any.whl", hash = "sha256:883a30a6bb5deaec0d646912b5f9f6dcbb9f6f72557b73d0f2560bf25d13e2d5", size = 6080, upload-time = "2025-08-22T10:14:16.477Z" }, -] - [[package]] name = "orjson" version = "3.11.7" @@ -6249,20 +5753,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/cd/29cee6007bddf7a834e6cd6f536754c0535fcb939d384f0f37a38b1cddb8/ormsgpack-1.12.2-cp314-cp314t-win_amd64.whl", hash = "sha256:837dd316584485b72ef451d08dd3e96c4a11d12e4963aedb40e08f89685d8ec2", size = 117232, upload-time = "2026-01-18T20:55:45.448Z" }, ] -[[package]] -name = "outlines-core" -version = "0.2.11" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1a/d3/e04e9145f8f806723dec9b9e5227ad695a3efcd3ced7794cf7c22b15df5e/outlines_core-0.2.11.tar.gz", hash = "sha256:dfce56f717ff5083e54cbcfdb66cad243365437fccbb5509adaa7e31e030f1d8", size = 197263, upload-time = "2025-05-19T10:12:51.719Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/db/32c6e1170f139420e948fdd18a09a6175244bc0760dcf4dc2470e18411b9/outlines_core-0.2.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:132605b8dd1e3d1369da6a851992dd357f6376068292f6bd47caa7a28b794d19", size = 2289078, upload-time = "2025-05-19T10:12:12.118Z" }, - { url = "https://files.pythonhosted.org/packages/25/c3/b6e6f4e08fa84d2424f82705a6dc47fee33cb91989010fa678736957dcf6/outlines_core-0.2.11-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:b31d5fc83b78aad282dd667b8d6e684614481fe08a7609ce0ce45dee64cd2991", size = 2115075, upload-time = "2025-05-19T10:12:13.761Z" }, - { url = "https://files.pythonhosted.org/packages/92/c7/a65d1fddf49830ebc41422294eacde35286d9f68994a8aa905cb14f5aade/outlines_core-0.2.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86df9740368866295077346440d911df4972da2b3f1f54b8125e6f329e8a8891", size = 2287677, upload-time = "2025-05-19T10:12:24.24Z" }, - { url = "https://files.pythonhosted.org/packages/23/79/8795aed8be9b77dd69d78e7cfbfcf28c179e6b08da6e56bbbf48a09fe55f/outlines_core-0.2.11-cp312-cp312-manylinux_2_28_aarch64.whl", hash = 
"sha256:96ce4dd78f106799be4a0a5795cefd1352806162973756a4b6fce4bb6eddd7e4", size = 2113000, upload-time = "2025-05-19T10:12:25.446Z" }, - { url = "https://files.pythonhosted.org/packages/87/96/7dcdc5198844145ab35528f9f93a58c3d47b87e54d0f79357c631d7b7a9a/outlines_core-0.2.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:daef6eaaf8c3403455ab5cbf265cb5c6838df571eb7c4b23cddac19cfc701726", size = 2287320, upload-time = "2025-05-19T10:12:35.515Z" }, - { url = "https://files.pythonhosted.org/packages/4d/68/b420b6a3beaadbf8e9f2a82132120027efd6424634013fbeca8c2fed7467/outlines_core-0.2.11-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:76b2512417c68863f8f227a080e87f755682dfd895e23b021121318be11da579", size = 2112861, upload-time = "2025-05-19T10:12:36.742Z" }, -] - [[package]] name = "packaging" version = "26.0" @@ -6350,15 +5840,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/61/fae042894f4296ec49e3f193aff5d7c18440da9e48102c3315e1bc4519a7/parso-0.8.6-py2.py3-none-any.whl", hash = "sha256:2c549f800b70a5c4952197248825584cb00f033b29c692671d3bf08bf380baff", size = 106894, upload-time = "2026-02-09T15:45:21.391Z" }, ] -[[package]] -name = "partial-json-parser" -version = "0.2.1.1.post7" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6a/6d/eed37d7ebc1e0bcd27b831c0cf1fe94881934316187c4b30d23f29ea0bd4/partial_json_parser-0.2.1.1.post7.tar.gz", hash = "sha256:86590e1ba6bcb6739a2dfc17d2323f028cb5884f4c6ce23db376999132c9a922", size = 10296, upload-time = "2025-11-17T07:27:41.202Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/42/32/658973117bf0fd82a24abbfb94fe73a5e86216e49342985e10acce54775a/partial_json_parser-0.2.1.1.post7-py3-none-any.whl", hash = "sha256:145119e5eabcf80cbb13844a6b50a85c68bf99d376f8ed771e2a3c3b03e653ae", size = 10877, upload-time = "2025-11-17T07:27:40.457Z" }, -] - [[package]] name = "passlib" version = "1.7.4" @@ -6713,19 +6194,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/74/c3/24a2f845e3917201628ecaba4f18bab4d18a337834c1df2a159ee9d22a42/prometheus_client-0.24.1-py3-none-any.whl", hash = "sha256:150db128af71a5c2482b36e588fc8a6b95e498750da4b17065947c16070f4055", size = 64057, upload-time = "2026-01-14T15:26:24.42Z" }, ] -[[package]] -name = "prometheus-fastapi-instrumentator" -version = "7.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "prometheus-client", marker = "sys_platform == 'linux'" }, - { name = "starlette", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/69/6d/24d53033cf93826aa7857699a4450c1c67e5b9c710e925b1ed2b320c04df/prometheus_fastapi_instrumentator-7.1.0.tar.gz", hash = "sha256:be7cd61eeea4e5912aeccb4261c6631b3f227d8924542d79eaf5af3f439cbe5e", size = 20220, upload-time = "2025-03-19T19:35:05.351Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/27/72/0824c18f3bc75810f55dacc2dd933f6ec829771180245ae3cc976195dec0/prometheus_fastapi_instrumentator-7.1.0-py3-none-any.whl", hash = "sha256:978130f3c0bb7b8ebcc90d35516a6fe13e02d2eb358c8f83887cdef7020c31e9", size = 19296, upload-time = "2025-03-19T19:35:04.323Z" }, -] - [[package]] name = "prompt-toolkit" version = "3.0.52" @@ -7050,111 +6518,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = 
"sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, ] -[[package]] -name = "pybase64" -version = "1.4.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/aa/b8/4ed5c7ad5ec15b08d35cc79ace6145d5c1ae426e46435f4987379439dfea/pybase64-1.4.3.tar.gz", hash = "sha256:c2ed274c9e0ba9c8f9c4083cfe265e66dd679126cd9c2027965d807352f3f053", size = 137272, upload-time = "2025-12-06T13:27:04.013Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/fb/bb06a5b9885e7d853ac1e801c4d8abfdb4c8506deee33e53d55aa6690e67/pybase64-1.4.3-cp311-cp311-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:f9ef0388878bc15a084bd9bf73ec1b2b4ee513d11009b1506375e10a7aae5032", size = 68331, upload-time = "2025-12-06T13:22:54.197Z" }, - { url = "https://files.pythonhosted.org/packages/64/15/8d60b9ec5e658185fc2ee3333e01a6e30d717cf677b24f47cbb3a859d13c/pybase64-1.4.3-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95a57cccf106352a72ed8bc8198f6820b16cc7d55aa3867a16dea7011ae7c218", size = 71370, upload-time = "2025-12-06T13:22:55.517Z" }, - { url = "https://files.pythonhosted.org/packages/ac/29/a3e5c1667cc8c38d025a4636855de0fc117fc62e2afeb033a3c6f12c6a22/pybase64-1.4.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cd1c47dfceb9c7bd3de210fb4e65904053ed2d7c9dce6d107f041ff6fbd7e21", size = 59834, upload-time = "2025-12-06T13:22:56.682Z" }, - { url = "https://files.pythonhosted.org/packages/a9/00/8ffcf9810bd23f3984698be161cf7edba656fd639b818039a7be1d6405d4/pybase64-1.4.3-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:9fe9922698f3e2f72874b26890d53a051c431d942701bb3a37aae94da0b12107", size = 56652, upload-time = "2025-12-06T13:22:57.724Z" }, - { url = "https://files.pythonhosted.org/packages/81/62/379e347797cdea4ab686375945bc77ad8d039c688c0d4d0cfb09d247beb9/pybase64-1.4.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:af5f4bd29c86b59bb4375e0491d16ec8a67548fa99c54763aaedaf0b4b5a6632", size = 59382, upload-time = "2025-12-06T13:22:58.758Z" }, - { url = "https://files.pythonhosted.org/packages/c6/f2/9338ffe2f487086f26a2c8ca175acb3baa86fce0a756ff5670a0822bb877/pybase64-1.4.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c302f6ca7465262908131411226e02100f488f531bb5e64cb901aa3f439bccd9", size = 59990, upload-time = "2025-12-06T13:23:01.007Z" }, - { url = "https://files.pythonhosted.org/packages/f9/a4/85a6142b65b4df8625b337727aa81dc199642de3d09677804141df6ee312/pybase64-1.4.3-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:2f3f439fa4d7fde164ebbbb41968db7d66b064450ab6017c6c95cef0afa2b349", size = 54923, upload-time = "2025-12-06T13:23:02.369Z" }, - { url = "https://files.pythonhosted.org/packages/ac/00/e40215d25624012bf5b7416ca37f168cb75f6dd15acdb91ea1f2ea4dc4e7/pybase64-1.4.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7a23c6866551043f8b681a5e1e0d59469148b2920a3b4fc42b1275f25ea4217a", size = 58664, upload-time = "2025-12-06T13:23:03.378Z" }, - { url = "https://files.pythonhosted.org/packages/b0/73/d7e19a63e795c13837f2356268d95dc79d1180e756f57ced742a1e52fdeb/pybase64-1.4.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:56e6526f8565642abc5f84338cc131ce298a8ccab696b19bdf76fa6d7dc592ef", size = 52338, upload-time = "2025-12-06T13:23:04.458Z" }, 
- { url = "https://files.pythonhosted.org/packages/f2/32/3c746d7a310b69bdd9df77ffc85c41b80bce00a774717596f869b0d4a20e/pybase64-1.4.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:6a792a8b9d866ffa413c9687d9b611553203753987a3a582d68cbc51cf23da45", size = 68993, upload-time = "2025-12-06T13:23:05.526Z" }, - { url = "https://files.pythonhosted.org/packages/5d/b3/63cec68f9d6f6e4c0b438d14e5f1ef536a5fe63ce14b70733ac5e31d7ab8/pybase64-1.4.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:62ad29a5026bb22cfcd1ca484ec34b0a5ced56ddba38ceecd9359b2818c9c4f9", size = 58055, upload-time = "2025-12-06T13:23:06.931Z" }, - { url = "https://files.pythonhosted.org/packages/d5/cb/7acf7c3c06f9692093c07f109668725dc37fb9a3df0fa912b50add645195/pybase64-1.4.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:11b9d1d2d32ec358c02214363b8fc3651f6be7dd84d880ecd597a6206a80e121", size = 54430, upload-time = "2025-12-06T13:23:07.936Z" }, - { url = "https://files.pythonhosted.org/packages/33/39/4eb33ff35d173bfff4002e184ce8907f5d0a42d958d61cd9058ef3570179/pybase64-1.4.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0aebaa7f238caa0a0d373616016e2040c6c879ebce3ba7ab3c59029920f13640", size = 56272, upload-time = "2025-12-06T13:23:09.253Z" }, - { url = "https://files.pythonhosted.org/packages/19/97/a76d65c375a254e65b730c6f56bf528feca91305da32eceab8bcc08591e6/pybase64-1.4.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e504682b20c63c2b0c000e5f98a80ea867f8d97642e042a5a39818e44ba4d599", size = 70904, upload-time = "2025-12-06T13:23:10.336Z" }, - { url = "https://files.pythonhosted.org/packages/43/1b/9a8cab0042b464e9a876d5c65fe5127445a2436da36fda64899b119b1a1b/pybase64-1.4.3-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:f0b3f200c3e06316f6bebabd458b4e4bcd4c2ca26af7c0c766614d91968dee27", size = 68210, upload-time = "2025-12-06T13:23:18.813Z" }, - { url = "https://files.pythonhosted.org/packages/62/f7/965b79ff391ad208b50e412b5d3205ccce372a2d27b7218ae86d5295b105/pybase64-1.4.3-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bb632edfd132b3eaf90c39c89aa314beec4e946e210099b57d40311f704e11d4", size = 71599, upload-time = "2025-12-06T13:23:20.195Z" }, - { url = "https://files.pythonhosted.org/packages/03/4b/a3b5175130b3810bbb8ccfa1edaadbd3afddb9992d877c8a1e2f274b476e/pybase64-1.4.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:356ef1d74648ce997f5a777cf8f1aefecc1c0b4fe6201e0ef3ec8a08170e1b54", size = 59922, upload-time = "2025-12-06T13:23:21.487Z" }, - { url = "https://files.pythonhosted.org/packages/da/5d/c38d1572027fc601b62d7a407721688b04b4d065d60ca489912d6893e6cf/pybase64-1.4.3-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:c48361f90db32bacaa5518419d4eb9066ba558013aaf0c7781620279ecddaeb9", size = 56712, upload-time = "2025-12-06T13:23:22.77Z" }, - { url = "https://files.pythonhosted.org/packages/e7/d4/4e04472fef485caa8f561d904d4d69210a8f8fc1608ea15ebd9012b92655/pybase64-1.4.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:702bcaa16ae02139d881aeaef5b1c8ffb4a3fae062fe601d1e3835e10310a517", size = 59300, upload-time = "2025-12-06T13:23:24.543Z" }, - { url = "https://files.pythonhosted.org/packages/86/e7/16e29721b86734b881d09b7e23dfd7c8408ad01a4f4c7525f3b1088e25ec/pybase64-1.4.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = 
"sha256:53d0ffe1847b16b647c6413d34d1de08942b7724273dd57e67dcbdb10c574045", size = 60278, upload-time = "2025-12-06T13:23:25.608Z" }, - { url = "https://files.pythonhosted.org/packages/b1/02/18515f211d7c046be32070709a8efeeef8a0203de4fd7521e6b56404731b/pybase64-1.4.3-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:9a1792e8b830a92736dae58f0c386062eb038dfe8004fb03ba33b6083d89cd43", size = 54817, upload-time = "2025-12-06T13:23:26.633Z" }, - { url = "https://files.pythonhosted.org/packages/e7/be/14e29d8e1a481dbff151324c96dd7b5d2688194bb65dc8a00ca0e1ad1e86/pybase64-1.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1d468b1b1ac5ad84875a46eaa458663c3721e8be5f155ade356406848d3701f6", size = 58611, upload-time = "2025-12-06T13:23:27.684Z" }, - { url = "https://files.pythonhosted.org/packages/b4/8a/a2588dfe24e1bbd742a554553778ab0d65fdf3d1c9a06d10b77047d142aa/pybase64-1.4.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:e97b7bdbd62e71898cd542a6a9e320d9da754ff3ebd02cb802d69087ee94d468", size = 52404, upload-time = "2025-12-06T13:23:28.714Z" }, - { url = "https://files.pythonhosted.org/packages/27/fc/afcda7445bebe0cbc38cafdd7813234cdd4fc5573ff067f1abf317bb0cec/pybase64-1.4.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b33aeaa780caaa08ffda87fc584d5eab61e3d3bbb5d86ead02161dc0c20d04bc", size = 68817, upload-time = "2025-12-06T13:23:30.079Z" }, - { url = "https://files.pythonhosted.org/packages/d3/3a/87c3201e555ed71f73e961a787241a2438c2bbb2ca8809c29ddf938a3157/pybase64-1.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:1c0efcf78f11cf866bed49caa7b97552bc4855a892f9cc2372abcd3ed0056f0d", size = 57854, upload-time = "2025-12-06T13:23:31.17Z" }, - { url = "https://files.pythonhosted.org/packages/fd/7d/931c2539b31a7b375e7d595b88401eeb5bd6c5ce1059c9123f9b608aaa14/pybase64-1.4.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:66e3791f2ed725a46593f8bd2761ff37d01e2cdad065b1dceb89066f476e50c6", size = 54333, upload-time = "2025-12-06T13:23:32.422Z" }, - { url = "https://files.pythonhosted.org/packages/de/5e/537601e02cc01f27e9d75f440f1a6095b8df44fc28b1eef2cd739aea8cec/pybase64-1.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:72bb0b6bddadab26e1b069bb78e83092711a111a80a0d6b9edcb08199ad7299b", size = 56492, upload-time = "2025-12-06T13:23:33.515Z" }, - { url = "https://files.pythonhosted.org/packages/96/97/2a2e57acf8f5c9258d22aba52e71f8050e167b29ed2ee1113677c1b600c1/pybase64-1.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5b3365dbcbcdb0a294f0f50af0c0a16b27a232eddeeb0bceeefd844ef30d2a23", size = 70974, upload-time = "2025-12-06T13:23:36.27Z" }, - { url = "https://files.pythonhosted.org/packages/5c/8d/20b68f11adfc4c22230e034b65c71392e3e338b413bf713c8945bd2ccfb3/pybase64-1.4.3-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:27fdff227a0c0e182e0ba37a99109645188978b920dfb20d8b9c17eeee370d0d", size = 30932, upload-time = "2025-12-06T13:23:43.348Z" }, - { url = "https://files.pythonhosted.org/packages/f7/79/b1b550ac6bff51a4880bf6e089008b2e1ca16f2c98db5e039a08ac3ad157/pybase64-1.4.3-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:2a8204f1fdfec5aa4184249b51296c0de95445869920c88123978304aad42df1", size = 31394, upload-time = "2025-12-06T13:23:44.317Z" }, - { url = "https://files.pythonhosted.org/packages/82/70/b5d7c5932bf64ee1ec5da859fbac981930b6a55d432a603986c7f509c838/pybase64-1.4.3-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:874fc2a3777de6baf6aa921a7aa73b3be98295794bea31bd80568a963be30767", size = 38078, 
upload-time = "2025-12-06T13:23:45.348Z" }, - { url = "https://files.pythonhosted.org/packages/1c/c9/24b3b905cf75e23a9a4deaf203b35ffcb9f473ac0e6d8257f91a05dfce62/pybase64-1.4.3-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:1d45c8fe8fe82b65c36b227bb4a2cf623d9ada16bed602ce2d3e18c35285b72a", size = 68244, upload-time = "2025-12-06T13:23:49.026Z" }, - { url = "https://files.pythonhosted.org/packages/f8/cd/d15b0c3e25e5859fab0416dc5b96d34d6bd2603c1c96a07bb2202b68ab92/pybase64-1.4.3-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ad70c26ba091d8f5167e9d4e1e86a0483a5414805cdb598a813db635bd3be8b8", size = 71620, upload-time = "2025-12-06T13:23:50.081Z" }, - { url = "https://files.pythonhosted.org/packages/0d/31/4ca953cc3dcde2b3711d6bfd70a6f4ad2ca95a483c9698076ba605f1520f/pybase64-1.4.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e98310b7c43145221e7194ac9fa7fffc84763c87bfc5e2f59f9f92363475bdc1", size = 59930, upload-time = "2025-12-06T13:23:51.68Z" }, - { url = "https://files.pythonhosted.org/packages/60/55/e7f7bdcd0fd66e61dda08db158ffda5c89a306bbdaaf5a062fbe4e48f4a1/pybase64-1.4.3-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:398685a76034e91485a28aeebcb49e64cd663212fd697b2497ac6dfc1df5e671", size = 56425, upload-time = "2025-12-06T13:23:52.732Z" }, - { url = "https://files.pythonhosted.org/packages/cb/65/b592c7f921e51ca1aca3af5b0d201a98666d0a36b930ebb67e7c2ed27395/pybase64-1.4.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:7e46400a6461187ccb52ed75b0045d937529e801a53a9cd770b350509f9e4d50", size = 59327, upload-time = "2025-12-06T13:23:53.856Z" }, - { url = "https://files.pythonhosted.org/packages/23/95/1613d2fb82dbb1548595ad4179f04e9a8451bfa18635efce18b631eabe3f/pybase64-1.4.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1b62b9f2f291d94f5e0b76ab499790b7dcc78a009d4ceea0b0428770267484b6", size = 60294, upload-time = "2025-12-06T13:23:54.937Z" }, - { url = "https://files.pythonhosted.org/packages/9d/73/40431f37f7d1b3eab4673e7946ff1e8f5d6bd425ec257e834dae8a6fc7b0/pybase64-1.4.3-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:f30ceb5fa4327809dede614be586efcbc55404406d71e1f902a6fdcf322b93b2", size = 54858, upload-time = "2025-12-06T13:23:56.031Z" }, - { url = "https://files.pythonhosted.org/packages/a7/84/f6368bcaf9f743732e002a9858646fd7a54f428490d427dd6847c5cfe89e/pybase64-1.4.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0d5f18ed53dfa1d4cf8b39ee542fdda8e66d365940e11f1710989b3cf4a2ed66", size = 58629, upload-time = "2025-12-06T13:23:57.12Z" }, - { url = "https://files.pythonhosted.org/packages/43/75/359532f9adb49c6b546cafc65c46ed75e2ccc220d514ba81c686fbd83965/pybase64-1.4.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:119d31aa4b58b85a8ebd12b63c07681a138c08dfc2fe5383459d42238665d3eb", size = 52448, upload-time = "2025-12-06T13:23:58.298Z" }, - { url = "https://files.pythonhosted.org/packages/92/6c/ade2ba244c3f33ed920a7ed572ad772eb0b5f14480b72d629d0c9e739a40/pybase64-1.4.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3cf0218b0e2f7988cf7d738a73b6a1d14f3be6ce249d7c0f606e768366df2cce", size = 68841, upload-time = "2025-12-06T13:23:59.886Z" }, - { url = "https://files.pythonhosted.org/packages/a0/51/b345139cd236be382f2d4d4453c21ee6299e14d2f759b668e23080f8663f/pybase64-1.4.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = 
"sha256:12f4ee5e988bc5c0c1106b0d8fc37fb0508f12dab76bac1b098cb500d148da9d", size = 57910, upload-time = "2025-12-06T13:24:00.994Z" }, - { url = "https://files.pythonhosted.org/packages/1a/b8/9f84bdc4f1c4f0052489396403c04be2f9266a66b70c776001eaf0d78c1f/pybase64-1.4.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:937826bc7b6b95b594a45180e81dd4d99bd4dd4814a443170e399163f7ff3fb6", size = 54335, upload-time = "2025-12-06T13:24:02.046Z" }, - { url = "https://files.pythonhosted.org/packages/d0/c7/be63b617d284de46578a366da77ede39c8f8e815ed0d82c7c2acca560fab/pybase64-1.4.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:88995d1460971ef80b13e3e007afbe4b27c62db0508bc7250a2ab0a0b4b91362", size = 56486, upload-time = "2025-12-06T13:24:03.141Z" }, - { url = "https://files.pythonhosted.org/packages/5e/96/f252c8f9abd6ded3ef1ccd3cdbb8393a33798007f761b23df8de1a2480e6/pybase64-1.4.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:72326fe163385ed3e1e806dd579d47fde5d8a59e51297a60fc4e6cbc1b4fc4ed", size = 70978, upload-time = "2025-12-06T13:24:04.221Z" }, - { url = "https://files.pythonhosted.org/packages/46/fc/cb64964c3b29b432f54d1bce5e7691d693e33bbf780555151969ffd95178/pybase64-1.4.3-cp313-cp313t-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:2e745f2ce760c6cf04d8a72198ef892015ddb89f6ceba489e383518ecbdb13ab", size = 72317, upload-time = "2025-12-06T13:24:11.129Z" }, - { url = "https://files.pythonhosted.org/packages/0a/b7/fab2240da6f4e1ad46f71fa56ec577613cf5df9dce2d5b4cfaa4edd0e365/pybase64-1.4.3-cp313-cp313t-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fac217cd9de8581a854b0ac734c50fd1fa4b8d912396c1fc2fce7c230efe3a7", size = 75534, upload-time = "2025-12-06T13:24:12.433Z" }, - { url = "https://files.pythonhosted.org/packages/91/3b/3e2f2b6e68e3d83ddb9fa799f3548fb7449765daec9bbd005a9fbe296d7f/pybase64-1.4.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:da1ee8fa04b283873de2d6e8fa5653e827f55b86bdf1a929c5367aaeb8d26f8a", size = 65399, upload-time = "2025-12-06T13:24:13.928Z" }, - { url = "https://files.pythonhosted.org/packages/6b/08/476ac5914c3b32e0274a2524fc74f01cbf4f4af4513d054e41574eb018f6/pybase64-1.4.3-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:b0bf8e884ee822ca7b1448eeb97fa131628fe0ff42f60cae9962789bd562727f", size = 60487, upload-time = "2025-12-06T13:24:15.177Z" }, - { url = "https://files.pythonhosted.org/packages/f1/b8/618a92915330cc9cba7880299b546a1d9dab1a21fd6c0292ee44a4fe608c/pybase64-1.4.3-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1bf749300382a6fd1f4f255b183146ef58f8e9cb2f44a077b3a9200dfb473a77", size = 63959, upload-time = "2025-12-06T13:24:16.854Z" }, - { url = "https://files.pythonhosted.org/packages/a5/52/af9d8d051652c3051862c442ec3861259c5cdb3fc69774bc701470bd2a59/pybase64-1.4.3-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:153a0e42329b92337664cfc356f2065248e6c9a1bd651bbcd6dcaf15145d3f06", size = 64874, upload-time = "2025-12-06T13:24:18.328Z" }, - { url = "https://files.pythonhosted.org/packages/e4/51/5381a7adf1f381bd184d33203692d3c57cf8ae9f250f380c3fecbdbe554b/pybase64-1.4.3-cp313-cp313t-manylinux_2_31_riscv64.whl", hash = "sha256:86ee56ac7f2184ca10217ed1c655c1a060273e233e692e9086da29d1ae1768db", size = 58572, upload-time = "2025-12-06T13:24:19.417Z" }, - { url = 
"https://files.pythonhosted.org/packages/e0/f0/578ee4ffce5818017de4fdf544e066c225bc435e73eb4793cde28a689d0b/pybase64-1.4.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:0e71a4db76726bf830b47477e7d830a75c01b2e9b01842e787a0836b0ba741e3", size = 63636, upload-time = "2025-12-06T13:24:20.497Z" }, - { url = "https://files.pythonhosted.org/packages/b9/ad/8ae94814bf20159ea06310b742433e53d5820aa564c9fdf65bf2d79f8799/pybase64-1.4.3-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:2ba7799ec88540acd9861b10551d24656ca3c2888ecf4dba2ee0a71544a8923f", size = 56193, upload-time = "2025-12-06T13:24:21.559Z" }, - { url = "https://files.pythonhosted.org/packages/d1/31/6438cfcc3d3f0fa84d229fa125c243d5094e72628e525dfefadf3bcc6761/pybase64-1.4.3-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:2860299e4c74315f5951f0cf3e72ba0f201c3356c8a68f95a3ab4e620baf44e9", size = 72655, upload-time = "2025-12-06T13:24:22.673Z" }, - { url = "https://files.pythonhosted.org/packages/a3/0d/2bbc9e9c3fc12ba8a6e261482f03a544aca524f92eae0b4908c0a10ba481/pybase64-1.4.3-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:bb06015db9151f0c66c10aae8e3603adab6b6cd7d1f7335a858161d92fc29618", size = 62471, upload-time = "2025-12-06T13:24:23.8Z" }, - { url = "https://files.pythonhosted.org/packages/2c/0b/34d491e7f49c1dbdb322ea8da6adecda7c7cd70b6644557c6e4ca5c6f7c7/pybase64-1.4.3-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:242512a070817272865d37c8909059f43003b81da31f616bb0c391ceadffe067", size = 58119, upload-time = "2025-12-06T13:24:24.994Z" }, - { url = "https://files.pythonhosted.org/packages/ce/17/c21d0cde2a6c766923ae388fc1f78291e1564b0d38c814b5ea8a0e5e081c/pybase64-1.4.3-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:5d8277554a12d3e3eed6180ebda62786bf9fc8d7bb1ee00244258f4a87ca8d20", size = 60791, upload-time = "2025-12-06T13:24:26.046Z" }, - { url = "https://files.pythonhosted.org/packages/92/b2/eaa67038916a48de12b16f4c384bcc1b84b7ec731b23613cb05f27673294/pybase64-1.4.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f40b7ddd698fc1e13a4b64fbe405e4e0e1279e8197e37050e24154655f5f7c4e", size = 74701, upload-time = "2025-12-06T13:24:27.466Z" }, - { url = "https://files.pythonhosted.org/packages/e3/71/cf62b261d431857e8e054537a5c3c24caafa331de30daede7b2c6c558501/pybase64-1.4.3-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:8f183ac925a48046abe047360fe3a1b28327afb35309892132fe1915d62fb282", size = 30939, upload-time = "2025-12-06T13:24:34.001Z" }, - { url = "https://files.pythonhosted.org/packages/24/3e/d12f92a3c1f7c6ab5d53c155bff9f1084ba997a37a39a4f781ccba9455f3/pybase64-1.4.3-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:30bf3558e24dcce4da5248dcf6d73792adfcf4f504246967e9db155be4c439ad", size = 31401, upload-time = "2025-12-06T13:24:35.11Z" }, - { url = "https://files.pythonhosted.org/packages/9b/3d/9c27440031fea0d05146f8b70a460feb95d8b4e3d9ca8f45c972efb4c3d3/pybase64-1.4.3-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:a674b419de318d2ce54387dd62646731efa32b4b590907800f0bd40675c1771d", size = 38075, upload-time = "2025-12-06T13:24:36.53Z" }, - { url = "https://files.pythonhosted.org/packages/db/26/b136a4b65e5c94ff06217f7726478df3f31ab1c777c2c02cf698e748183f/pybase64-1.4.3-cp314-cp314-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:b51204d349a4b208287a8aa5b5422be3baa88abf6cc8ff97ccbda34919bbc857", size = 68460, upload-time = "2025-12-06T13:24:41.735Z" }, - { url = 
"https://files.pythonhosted.org/packages/68/6d/84ce50e7ee1ae79984d689e05a9937b2460d4efa1e5b202b46762fb9036c/pybase64-1.4.3-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:30f2fd53efecbdde4bdca73a872a68dcb0d1bf8a4560c70a3e7746df973e1ef3", size = 71688, upload-time = "2025-12-06T13:24:42.908Z" }, - { url = "https://files.pythonhosted.org/packages/e3/57/6743e420416c3ff1b004041c85eb0ebd9c50e9cf05624664bfa1dc8b5625/pybase64-1.4.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0932b0c5cfa617091fd74f17d24549ce5de3628791998c94ba57be808078eeaf", size = 60040, upload-time = "2025-12-06T13:24:44.37Z" }, - { url = "https://files.pythonhosted.org/packages/3b/68/733324e28068a89119af2921ce548e1c607cc5c17d354690fc51c302e326/pybase64-1.4.3-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:acb61f5ab72bec808eb0d4ce8b87ec9f38d7d750cb89b1371c35eb8052a29f11", size = 56478, upload-time = "2025-12-06T13:24:45.815Z" }, - { url = "https://files.pythonhosted.org/packages/b5/9e/f3f4aa8cfe3357a3cdb0535b78eb032b671519d3ecc08c58c4c6b72b5a91/pybase64-1.4.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:2bc2d5bc15168f5c04c53bdfe5a1e543b2155f456ed1e16d7edce9ce73842021", size = 59463, upload-time = "2025-12-06T13:24:46.938Z" }, - { url = "https://files.pythonhosted.org/packages/aa/d1/53286038e1f0df1cf58abcf4a4a91b0f74ab44539c2547b6c31001ddd054/pybase64-1.4.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:8a7bc3cd23880bdca59758bcdd6f4ef0674f2393782763910a7466fab35ccb98", size = 60360, upload-time = "2025-12-06T13:24:48.039Z" }, - { url = "https://files.pythonhosted.org/packages/00/9a/5cc6ce95db2383d27ff4d790b8f8b46704d360d701ab77c4f655bcfaa6a7/pybase64-1.4.3-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:ad15acf618880d99792d71e3905b0e2508e6e331b76a1b34212fa0f11e01ad28", size = 54999, upload-time = "2025-12-06T13:24:49.547Z" }, - { url = "https://files.pythonhosted.org/packages/64/e7/c3c1d09c3d7ae79e3aa1358c6d912d6b85f29281e47aa94fc0122a415a2f/pybase64-1.4.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:448158d417139cb4851200e5fee62677ae51f56a865d50cda9e0d61bda91b116", size = 58736, upload-time = "2025-12-06T13:24:50.641Z" }, - { url = "https://files.pythonhosted.org/packages/db/d5/0baa08e3d8119b15b588c39f0d39fd10472f0372e3c54ca44649cbefa256/pybase64-1.4.3-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:9058c49b5a2f3e691b9db21d37eb349e62540f9f5fc4beabf8cbe3c732bead86", size = 52298, upload-time = "2025-12-06T13:24:51.791Z" }, - { url = "https://files.pythonhosted.org/packages/00/87/fc6f11474a1de7e27cd2acbb8d0d7508bda3efa73dfe91c63f968728b2a3/pybase64-1.4.3-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:ce561724f6522907a66303aca27dce252d363fcd85884972d348f4403ba3011a", size = 69049, upload-time = "2025-12-06T13:24:53.253Z" }, - { url = "https://files.pythonhosted.org/packages/69/9d/7fb5566f669ac18b40aa5fc1c438e24df52b843c1bdc5da47d46d4c1c630/pybase64-1.4.3-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:63316560a94ac449fe86cb8b9e0a13714c659417e92e26a5cbf085cd0a0c838d", size = 57952, upload-time = "2025-12-06T13:24:54.342Z" }, - { url = "https://files.pythonhosted.org/packages/de/cc/ceb949232dbbd3ec4ee0190d1df4361296beceee9840390a63df8bc31784/pybase64-1.4.3-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:7ecd796f2ac0be7b73e7e4e232b8c16422014de3295d43e71d2b19fd4a4f5368", size = 54484, upload-time = 
"2025-12-06T13:24:55.774Z" }, - { url = "https://files.pythonhosted.org/packages/a7/69/659f3c8e6a5d7b753b9c42a4bd9c42892a0f10044e9c7351a4148d413a33/pybase64-1.4.3-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:d01e102a12fb2e1ed3dc11611c2818448626637857ec3994a9cf4809dfd23477", size = 56542, upload-time = "2025-12-06T13:24:57Z" }, - { url = "https://files.pythonhosted.org/packages/85/2c/29c9e6c9c82b72025f9676f9e82eb1fd2339ad038cbcbf8b9e2ac02798fc/pybase64-1.4.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:ebff797a93c2345f22183f454fd8607a34d75eca5a3a4a969c1c75b304cee39d", size = 71045, upload-time = "2025-12-06T13:24:58.179Z" }, - { url = "https://files.pythonhosted.org/packages/43/04/8b15c34d3c2282f1c1b0850f1113a249401b618a382646a895170bc9b5e7/pybase64-1.4.3-cp314-cp314t-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:a5ae04ea114c86eb1da1f6e18d75f19e3b5ae39cb1d8d3cd87c29751a6a22780", size = 72474, upload-time = "2025-12-06T13:25:06.434Z" }, - { url = "https://files.pythonhosted.org/packages/42/00/f34b4d11278f8fdc68bc38f694a91492aa318f7c6f1bd7396197ac0f8b12/pybase64-1.4.3-cp314-cp314t-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1755b3dce3a2a5c7d17ff6d4115e8bee4a1d5aeae74469db02e47c8f477147da", size = 75706, upload-time = "2025-12-06T13:25:07.636Z" }, - { url = "https://files.pythonhosted.org/packages/bb/5d/71747d4ad7fe16df4c4c852bdbdeb1f2cf35677b48d7c34d3011a7a6ad3a/pybase64-1.4.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fb852f900e27ffc4ec1896817535a0fa19610ef8875a096b59f21d0aa42ff172", size = 65589, upload-time = "2025-12-06T13:25:08.809Z" }, - { url = "https://files.pythonhosted.org/packages/49/b1/d1e82bd58805bb5a3a662864800bab83a83a36ba56e7e3b1706c708002a5/pybase64-1.4.3-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:9cf21ea8c70c61eddab3421fbfce061fac4f2fb21f7031383005a1efdb13d0b9", size = 60670, upload-time = "2025-12-06T13:25:10.04Z" }, - { url = "https://files.pythonhosted.org/packages/15/67/16c609b7a13d1d9fc87eca12ba2dce5e67f949eeaab61a41bddff843cbb0/pybase64-1.4.3-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:afff11b331fdc27692fc75e85ae083340a35105cea1a3c4552139e2f0e0d174f", size = 64194, upload-time = "2025-12-06T13:25:11.48Z" }, - { url = "https://files.pythonhosted.org/packages/3c/11/37bc724e42960f0106c2d33dc957dcec8f760c91a908cc6c0df7718bc1a8/pybase64-1.4.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9a5143df542c1ce5c1f423874b948c4d689b3f05ec571f8792286197a39ba02", size = 64984, upload-time = "2025-12-06T13:25:12.645Z" }, - { url = "https://files.pythonhosted.org/packages/6e/66/b2b962a6a480dd5dae3029becf03ea1a650d326e39bf1c44ea3db78bb010/pybase64-1.4.3-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:d62e9861019ad63624b4a7914dff155af1cc5d6d79df3be14edcaedb5fdad6f9", size = 58750, upload-time = "2025-12-06T13:25:13.848Z" }, - { url = "https://files.pythonhosted.org/packages/2b/15/9b6d711035e29b18b2e1c03d47f41396d803d06ef15b6c97f45b75f73f04/pybase64-1.4.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:84cfd4d92668ef5766cc42a9c9474b88960ac2b860767e6e7be255c6fddbd34a", size = 63816, upload-time = "2025-12-06T13:25:15.356Z" }, - { url = "https://files.pythonhosted.org/packages/b4/21/e2901381ed0df62e2308380f30d9c4d87d6b74e33a84faed3478d33a7197/pybase64-1.4.3-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = 
"sha256:60fc025437f9a7c2cc45e0c19ed68ed08ba672be2c5575fd9d98bdd8f01dd61f", size = 56348, upload-time = "2025-12-06T13:25:16.559Z" }, - { url = "https://files.pythonhosted.org/packages/c4/16/3d788388a178a0407aa814b976fe61bfa4af6760d9aac566e59da6e4a8b4/pybase64-1.4.3-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:edc8446196f04b71d3af76c0bd1fe0a45066ac5bffecca88adb9626ee28c266f", size = 72842, upload-time = "2025-12-06T13:25:18.055Z" }, - { url = "https://files.pythonhosted.org/packages/a6/63/c15b1f8bd47ea48a5a2d52a4ec61f037062932ea6434ab916107b58e861e/pybase64-1.4.3-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:e99f6fa6509c037794da57f906ade271f52276c956d00f748e5b118462021d48", size = 62651, upload-time = "2025-12-06T13:25:19.191Z" }, - { url = "https://files.pythonhosted.org/packages/bd/b8/f544a2e37c778d59208966d4ef19742a0be37c12fc8149ff34483c176616/pybase64-1.4.3-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:d94020ef09f624d841aa9a3a6029df8cf65d60d7a6d5c8687579fa68bd679b65", size = 58295, upload-time = "2025-12-06T13:25:20.822Z" }, - { url = "https://files.pythonhosted.org/packages/03/99/1fae8a3b7ac181e36f6e7864a62d42d5b1f4fa7edf408c6711e28fba6b4d/pybase64-1.4.3-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:f64ce70d89942a23602dee910dec9b48e5edf94351e1b378186b74fcc00d7f66", size = 60960, upload-time = "2025-12-06T13:25:22.099Z" }, - { url = "https://files.pythonhosted.org/packages/9d/9e/cd4c727742345ad8384569a4466f1a1428f4e5cc94d9c2ab2f53d30be3fe/pybase64-1.4.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8ea99f56e45c469818b9781903be86ba4153769f007ba0655fa3b46dc332803d", size = 74863, upload-time = "2025-12-06T13:25:23.442Z" }, - { url = "https://files.pythonhosted.org/packages/bf/44/d4b7adc7bf4fd5b52d8d099121760c450a52c390223806b873f0b6a2d551/pybase64-1.4.3-graalpy311-graalpy242_311_native-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:a492518f3078a4e3faaef310697d21df9c6bc71908cebc8c2f6fbfa16d7d6b1f", size = 43227, upload-time = "2025-12-06T13:26:21.845Z" }, - { url = "https://files.pythonhosted.org/packages/08/86/2ba2d8734ef7939debeb52cf9952e457ba7aa226cae5c0e6dd631f9b851f/pybase64-1.4.3-graalpy311-graalpy242_311_native-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cae1a0f47784fd16df90d8acc32011c8d5fcdd9ab392c9ec49543e5f6a9c43a4", size = 35804, upload-time = "2025-12-06T13:26:23.149Z" }, - { url = "https://files.pythonhosted.org/packages/fa/8f/43c3bb11ca9bacf81cb0b7a71500bb65b2eda6d5fe07433c09b543de97f3/pybase64-1.4.3-graalpy312-graalpy250_312_native-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5c29a582b0ea3936d02bd6fe9bf674ab6059e6e45ab71c78404ab2c913224414", size = 43461, upload-time = "2025-12-06T13:26:28.906Z" }, - { url = "https://files.pythonhosted.org/packages/2d/4c/2a5258329200be57497d3972b5308558c6de42e3749c6cc2aa1cbe34b25a/pybase64-1.4.3-graalpy312-graalpy250_312_native-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b6b664758c804fa919b4f1257aa8cf68e95db76fc331de5f70bfc3a34655afe1", size = 36058, upload-time = "2025-12-06T13:26:30.092Z" }, - { url = "https://files.pythonhosted.org/packages/d3/22/832a2f9e76cdf39b52e01e40d8feeb6a04cf105494f2c3e3126d0149717f/pybase64-1.4.3-pp311-pypy311_pp73-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:bd4d2293de9fd212e294c136cec85892460b17d24e8c18a6ba18750928037750", size = 40681, upload-time = 
"2025-12-06T13:26:43.782Z" }, - { url = "https://files.pythonhosted.org/packages/12/d7/6610f34a8972415fab3bb4704c174a1cc477bffbc3c36e526428d0f3957d/pybase64-1.4.3-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2af6d0d3a691911cc4c9a625f3ddcd3af720738c21be3d5c72de05629139d393", size = 41294, upload-time = "2025-12-06T13:26:44.936Z" }, - { url = "https://files.pythonhosted.org/packages/64/25/ed24400948a6c974ab1374a233cb7e8af0a5373cea0dd8a944627d17c34a/pybase64-1.4.3-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5cfc8c49a28322d82242088378f8542ce97459866ba73150b062a7073e82629d", size = 35447, upload-time = "2025-12-06T13:26:46.098Z" }, -] - [[package]] name = "pybind11" version = "3.0.2" @@ -7248,15 +6611,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/e3/0f15da0fb5864a37637820e4bde463a52ba0c052a8edab06aad46b9e578b/pycasbin-2.8.0-py3-none-any.whl", hash = "sha256:1a9e370de553c677c4dff75a5d6f3b0eb354b73b20d7df77ff4ee61a71267a3a", size = 476153, upload-time = "2026-02-02T03:34:12.555Z" }, ] -[[package]] -name = "pycountry" -version = "26.2.16" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/de/1d/061b9e7a48b85cfd69f33c33d2ef784a531c359399ad764243399673c8f5/pycountry-26.2.16.tar.gz", hash = "sha256:5b6027d453fcd6060112b951dd010f01f168b51b4bf8a1f1fc8c95c8d94a0801", size = 7711342, upload-time = "2026-02-17T03:42:52.367Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9c/42/7703bd45b62fecd44cd7d3495423097e2f7d28bc2e99e7c1af68892ab157/pycountry-26.2.16-py3-none-any.whl", hash = "sha256:115c4baf7cceaa30f59a4694d79483c9167dbce7a9de4d3d571c5f3ea77c305a", size = 8044600, upload-time = "2026-02-17T03:42:49.777Z" }, -] - [[package]] name = "pycparser" version = "3.0" @@ -7426,11 +6780,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/c1/3226e6d7f5a4f736f38ac11a6fbb262d701889802595cdb0f53a885ac2e0/pydantic_extra_types-2.11.1-py3-none-any.whl", hash = "sha256:1722ea2bddae5628ace25f2aa685b69978ef533123e5638cfbddb999e0100ec1", size = 79526, upload-time = "2026-03-16T08:08:02.533Z" }, ] -[package.optional-dependencies] -pycountry = [ - { name = "pycountry", marker = "sys_platform == 'linux'" }, -] - [[package]] name = "pydantic-settings" version = "2.13.1" @@ -7646,15 +6995,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" }, ] -[[package]] -name = "python-json-logger" -version = "4.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/29/bf/eca6a3d43db1dae7070f70e160ab20b807627ba953663ba07928cdd3dc58/python_json_logger-4.0.0.tar.gz", hash = "sha256:f58e68eb46e1faed27e0f574a55a0455eecd7b8a5b88b85a784519ba3cff047f", size = 17683, upload-time = "2025-10-06T04:15:18.984Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/51/e5/fecf13f06e5e5f67e8837d777d1bc43fac0ed2b77a676804df5c34744727/python_json_logger-4.0.0-py3-none-any.whl", hash = "sha256:af09c9daf6a813aa4cc7180395f50f2a9e5fa056034c9953aec92e381c5ba1e2", size = 15548, upload-time = "2025-10-06T04:15:17.553Z" }, -] - [[package]] name = "python-multipart" version = "0.0.22" @@ -7853,10 +7193,10 @@ name = 
"quack-kernels" version = "0.2.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "apache-tvm-ffi", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cutlass-dsl", marker = "sys_platform == 'linux'" }, - { name = "torch", marker = "sys_platform == 'linux'" }, - { name = "torch-c-dlpack-ext", marker = "sys_platform == 'linux'" }, + { name = "apache-tvm-ffi" }, + { name = "nvidia-cutlass-dsl" }, + { name = "torch" }, + { name = "torch-c-dlpack-ext" }, ] sdist = { url = "https://files.pythonhosted.org/packages/89/de/472a20a625495e31c33a99a30867c1d58335a1afa02dc30019f667702d1d/quack_kernels-0.2.5.tar.gz", hash = "sha256:06241a5962c09b4a2c27d4d21208e31790836fecde4373c6e9d874fdd88b5590", size = 152256, upload-time = "2026-01-31T09:07:09.998Z" } wheels = [ @@ -7898,34 +7238,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c4/43/80f67e0336cb2fc725f8e06f7fe35c1d0fe946f4d2b8b2175e797e07349e/qwen_vl_utils-0.0.14-py3-none-any.whl", hash = "sha256:5e28657bfd031e56bd447c5901b58ddfc3835285ed100f4c56580e0ade054e96", size = 8120, upload-time = "2025-09-23T09:38:56.297Z" }, ] -[[package]] -name = "ray" -version = "2.54.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click", marker = "sys_platform == 'linux'" }, - { name = "filelock", marker = "sys_platform == 'linux'" }, - { name = "jsonschema", marker = "sys_platform == 'linux'" }, - { name = "msgpack", marker = "sys_platform == 'linux'" }, - { name = "packaging", marker = "sys_platform == 'linux'" }, - { name = "protobuf", marker = "sys_platform == 'linux'" }, - { name = "pyyaml", marker = "sys_platform == 'linux'" }, - { name = "requests", marker = "sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/29/7871f4206e6b00a9bb784c16dad32ccd01e9df5a93545db92de220eb2871/ray-2.54.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:491ae56ab80d8822c4eaf4d5bb96dcf32a6231d8d7b76eb8034400eb9be1bb18", size = 72066630, upload-time = "2026-02-18T04:05:04.957Z" }, - { url = "https://files.pythonhosted.org/packages/1d/e8/d2c8ebd9cd945abc817b01ad02a29df78cdb86cd07d764587e16977389d0/ray-2.54.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:928bb09245a3c6f7c3c113ba8eafc69f948da9602d7f33e8251ecdf97c157615", size = 72895723, upload-time = "2026-02-18T04:05:10.686Z" }, - { url = "https://files.pythonhosted.org/packages/60/ad/e07aca3637e9c3ec4857ec4366208099cf8488ece8061a9925ba29b66382/ray-2.54.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:795ae21d6b764245d3f521bc5833446d58569e7dfde9c5777417eb285d87450f", size = 72107346, upload-time = "2026-02-18T04:05:27.999Z" }, - { url = "https://files.pythonhosted.org/packages/9e/b9/cc5ea8460c3dc602e6b7198277a7c59ba2b8929374ab22efa8df9f3deac8/ray-2.54.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:a972afd5aa3dda99d0b2f369b5f62e5dd95865ab7d37bf2e0a0e0d2cfbd9b325", size = 72967230, upload-time = "2026-02-18T04:05:33.771Z" }, - { url = "https://files.pythonhosted.org/packages/fd/8c/4a4a38eaec6e9614076a96967f58540f4f8d4aa0c793f43150c5df23cb9a/ray-2.54.0-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:8952c23a8aa94f10728c2d16e0dc3732d09aa0e6254801757ff494984a214f45", size = 72013826, upload-time = "2026-02-18T04:05:49.866Z" }, - { url = "https://files.pythonhosted.org/packages/42/ac/e7ec2a406bd755f61c7090460fa5ab3f09b00c3c2d8db6d0b559f78a30eb/ray-2.54.0-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:ab89e6089abb6e46fb98fdd96d399b31a852d79127cd8ac00746c61d93defa2c", 
size = 72880209, upload-time = "2026-02-18T04:05:55.498Z" }, -] - -[package.optional-dependencies] -cgraph = [ - { name = "cupy-cuda12x", marker = "sys_platform == 'linux'" }, -] - [[package]] name = "referencing" version = "0.37.0" @@ -8763,29 +8075,66 @@ wheels = [ [[package]] name = "simplejson" -version = "4.0.0" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bc/cd/0fbfa6dff885f6d20c786d951c9ff9d29f5c8705d895df6b3b3e107af22d/simplejson-4.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:718b82b5b3efc76c5c178441f96c67e59261fd80cb4c327a02e8b68ee3fc66be", size = 108971, upload-time = "2026-04-18T19:08:49.367Z" }, - { url = "https://files.pythonhosted.org/packages/e6/ed/954e4a32803cbce0bdcc76c4310f7c0386722895eb91d9560c0d84a3456b/simplejson-4.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:86885efbb6ddbab9aa6b73a5520ed331c3b442757bfc917cee49fc27f5e4b291", size = 89162, upload-time = "2026-04-18T19:08:51.126Z" }, - { url = "https://files.pythonhosted.org/packages/40/96/360f29e7c8f0de11bde878f08d4fc064190b67e03f74b6d7604bd0c22a4a/simplejson-4.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0166383d28901691a6d12533295bccdf9ac850c129dffac489216139050d9934", size = 89573, upload-time = "2026-04-18T19:08:52.426Z" }, - { url = "https://files.pythonhosted.org/packages/d5/81/10c4d9dba518f1fdb1cfde78660c12615134c46cd1f7014947f703244dae/simplejson-4.0.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bd894e2d36f250bf7eba553e32840947c400c6dac43e2e726f276e92d157936d", size = 174948, upload-time = "2026-04-18T19:08:54.052Z" }, - { url = "https://files.pythonhosted.org/packages/1e/87/ba488e54cb8653ab8b328c6c15c170da6735c926d5cc5fe12de03041a8a2/simplejson-4.0.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:73781d0393a00fc8cc820da9016a3ee9297c23bcbf0c6dff470e508db53b6c58", size = 172433, upload-time = "2026-04-18T19:08:55.564Z" }, - { url = "https://files.pythonhosted.org/packages/b6/83/2606b446beddd767579eaf5a915f43450c57e149e34bc58cddd116ecc016/simplejson-4.0.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:bf34b959a8326e93ec6652fb8126458b84ad5f32940cfc4c7be5cfeb196837ce", size = 182001, upload-time = "2026-04-18T19:08:57.552Z" }, - { url = "https://files.pythonhosted.org/packages/07/cd/0a1e7790ae4839b9421b3e2e2e6d7b864eaa2e03ac5bb189324fa4785f64/simplejson-4.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:03fce7a33ef707dd272f4eef120d76db7fcb029aad480181c6478dda9c6f8937", size = 170292, upload-time = "2026-04-18T19:08:58.929Z" }, - { url = "https://files.pythonhosted.org/packages/c1/14/5794540aa669df02153a7af5da9b90e5b9b9dd670b99d59254f45efe0b54/simplejson-4.0.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:c6174e3d98fc081f0ec810392d3140b8b8a7158e751ef9518237adbdf34edcf8", size = 178935, upload-time = "2026-04-18T19:09:00.309Z" }, - { url = "https://files.pythonhosted.org/packages/0d/d1/1dd7e7df22b537ab70d84fcfd1aa13ae20ef57f521c82b5379dc518e2a9a/simplejson-4.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:7eaf534e2c5436a20f37224e730674597519261c3bb2ba0f5f60ac5c20697246", size = 173586, upload-time = "2026-04-18T19:09:02.078Z" }, - { url = "https://files.pythonhosted.org/packages/ee/89/55b8b360b880ff860eca05266b28352bca0684daf2389d342ca9a3948586/simplejson-4.0.0-cp311-cp311-win32.whl", hash = 
"sha256:e09a8894d7a74f88ceb896a6776ac2b8564be85ce8641eee54a115d82932e93c", size = 87866, upload-time = "2026-04-18T19:09:03.655Z" }, - { url = "https://files.pythonhosted.org/packages/23/c2/34533a4f22f447f800709ef3f04f5fdd3a7418ef3cd038b36e678c9f2e59/simplejson-4.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:83c13de82ae196c6250af95e898300b79e8ac4d81a32c0918530d6f1905aa218", size = 89925, upload-time = "2026-04-18T19:09:04.979Z" }, - { url = "https://files.pythonhosted.org/packages/9d/b9/5f3ba2c372c82213e0414ca2de73bf4e8fcd5300fa493980df12df62f2d5/simplejson-4.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a1df4396b2e91f1341c59d74745f4487ef283a40664cf4d41b0f3224fcedb44b", size = 110115, upload-time = "2026-04-18T19:09:06.31Z" }, - { url = "https://files.pythonhosted.org/packages/c6/fd/3ca6a2cfb82eb94e24839362b646d61193fbed7a968cf6edfd3b219441bc/simplejson-4.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c1c76ad790978ade6681b541ae6c08b7e67fdb42afbe6f190f6249b941165b99", size = 89807, upload-time = "2026-04-18T19:09:07.665Z" }, - { url = "https://files.pythonhosted.org/packages/d5/1b/d3d0e743102bda48d4e56b4f89cd2167e65f6709fa742181a4d1ca03fc50/simplejson-4.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9b9d637c6309213134009d341470f704b873fac1d864a834f2e0234a6d5ea67b", size = 90103, upload-time = "2026-04-18T19:09:09.287Z" }, - { url = "https://files.pythonhosted.org/packages/ea/b7/574bc1e563a57199024f4285b71f25a35a73109cab4dabe7b3c3e16aa1b2/simplejson-4.0.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bc283ec7e90a4981901f29aea892409c2793ebf2f40d0e52b65833866f01078b", size = 183969, upload-time = "2026-04-18T19:09:10.931Z" }, - { url = "https://files.pythonhosted.org/packages/7f/67/da1cf7bc6cac906dc96fcebc10e67007e44934b87e63d1dde0ee32a2bee3/simplejson-4.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:752b5936832130a3522358cdfbbbf43e1b03e08c518ae9d0784e9d495e9146e0", size = 181209, upload-time = "2026-04-18T19:09:12.448Z" }, - { url = "https://files.pythonhosted.org/packages/3b/93/959214ec5cf3d7b5288e3ca5dbb02da3405ec9da7b0b1f3bcc418579c48b/simplejson-4.0.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d9b034d63d097ff658134e1cf392f93c5b9de068ac9dacac761e5a4dc7d1a7a0", size = 189816, upload-time = "2026-04-18T19:09:14.214Z" }, - { url = "https://files.pythonhosted.org/packages/25/b9/10d32426964e1a55c43b999a6e6cfedb7f8a5a475a4b5c4e8ab48a0ad870/simplejson-4.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8985428511c21b2b148c7e2cdf39f8c441fbb833f28231d93c5f9f02739cce30", size = 177993, upload-time = "2026-04-18T19:09:15.722Z" }, - { url = "https://files.pythonhosted.org/packages/31/35/efd8526409fdeece24eef58a661b94b986668a3aa574528f61b9fade8a1a/simplejson-4.0.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:feabd1a6a45c9e6f72b3fdfdb1b1fc90f89f95ef1cf00e38dfdec80aa29bbea0", size = 186797, upload-time = "2026-04-18T19:09:17.227Z" }, - { url = "https://files.pythonhosted.org/packages/2b/4b/a9db28f6e48cd28b82f8db054e3b420ae6e0358115f9269336892fe2b576/simplejson-4.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:90f103b2782ffb28828577d75f91cd6be7a6a5fb923753fd5be47500ebf06244", size = 181522, upload-time = "2026-04-18T19:09:19.192Z" }, +version = "4.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/0e/2a/54837395a3487c725669428d513293612a48d82b95a0642c936932e5d898/simplejson-4.1.1.tar.gz", hash = "sha256:c08eb9f7a90f77ae470e19a07472e9a79ebc0d1c2315d86a72767665bd5ba79f", size = 118860, upload-time = "2026-04-24T19:24:59.819Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/25/39013ffe279d90093ec1c848565b3683c586906c10fa55d9000ec29d046b/simplejson-4.1.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2867c64d92abd1992c15666fae198203093f593e43d6b81adf176bae530d493a", size = 111538, upload-time = "2026-04-24T19:22:49.051Z" }, + { url = "https://files.pythonhosted.org/packages/f2/ae/2c272971c8a87e2539c54a98eb6ff037bee1e2e93943c3986cf7500a4f3a/simplejson-4.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c47c46e16c8ea9e4850061e6ed5aa2b9cd2074cb2274bfd9c138cba15ce7453", size = 90594, upload-time = "2026-04-24T19:22:50.408Z" }, + { url = "https://files.pythonhosted.org/packages/4e/a2/6eebfb99dedc139f549200f61ade6d1890ac5707c5d427bdfa6fe39c9313/simplejson-4.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e294e33dbf316a9bbdd4030d46503c9b0f19470ae7ad6af5bae6c426bc2e869f", size = 90718, upload-time = "2026-04-24T19:22:51.694Z" }, + { url = "https://files.pythonhosted.org/packages/80/7e/c9e6c0c4ad8415e64dad0c47f619b556b02680a41631b4dbc281d55dc54d/simplejson-4.1.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7ce252b28fddbdd83db5bd7d93dad2a8a591d7ada098afec9c1b23d6b722a7a4", size = 180901, upload-time = "2026-04-24T19:22:53.025Z" }, + { url = "https://files.pythonhosted.org/packages/34/09/69e331e3994b1ed9be6ce9ace4ade704e7ed503edf869929ca7bb404eda8/simplejson-4.1.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4c44ef6b02a4eb67ed17a72342341792149b3ff46f15426c26e970e49addf327", size = 178133, upload-time = "2026-04-24T19:22:54.574Z" }, + { url = "https://files.pythonhosted.org/packages/5d/40/ed806f24afef295c1032448f5ff6f6f2979392d5645ddb9f4fed7f38194d/simplejson-4.1.1-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:82bfca2b85a34178c25829c703f0a9e9f113a5af7539285bd3efb583a0bf1ba3", size = 188155, upload-time = "2026-04-24T19:22:56.044Z" }, + { url = "https://files.pythonhosted.org/packages/38/94/8d6f515b827b0f7881a49c8c1ac6920b7ae9428939ef04238c973278b42a/simplejson-4.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0e4b23f71dd781f8830f1663dc01a4944d3dbf87a1f93d78fba1cf64722d0ccf", size = 176225, upload-time = "2026-04-24T19:22:57.981Z" }, + { url = "https://files.pythonhosted.org/packages/c9/fd/6dffb4956563d48bbe46b91ff341adae34920e94008fd6b8d728072abfc7/simplejson-4.1.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:82fee635d7b73ad801030b05a75fbd34a098da0c2ecf600667a03636d09e1e42", size = 185535, upload-time = "2026-04-24T19:22:59.618Z" }, + { url = "https://files.pythonhosted.org/packages/de/d2/a509ee37763e79aec75d68f8521db1440306edeba3b8b4064ab4ee8bf1d9/simplejson-4.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:68e62eda21192c5ea9bb92d571ca46a4477fef48762f50d433de2b4253051551", size = 179302, upload-time = "2026-04-24T19:23:01.324Z" }, + { url = "https://files.pythonhosted.org/packages/d8/23/5b343bfd2a79d3b6818e4db3586c405a001a090d4c89d336e31273ce7177/simplejson-4.1.1-cp311-cp311-win32.whl", hash = "sha256:ffd3d82294b47f5ec64050021ace95fd62628a0c1cc8bbf4d06d2d1fb697e055", size = 88408, upload-time = "2026-04-24T19:23:02.808Z" }, + { 
url = "https://files.pythonhosted.org/packages/38/04/df9b37aedbd524dca20840d25ebe01d6ae486b89792aeff5d15b9c4114f7/simplejson-4.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:78a3fe0995be42bed62a26aa78e0e0b4d87c6545785346b9cc898f3389569a35", size = 90526, upload-time = "2026-04-24T19:23:04.408Z" }, + { url = "https://files.pythonhosted.org/packages/60/25/e90998fe8e480eb43b966c09e835379887d427567ebd496563d3b1e16b19/simplejson-4.1.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:19040a17154dc03d289bab68d73ce0a6a0be01de30c584bbdd93490bead14b22", size = 112414, upload-time = "2026-04-24T19:23:06.084Z" }, + { url = "https://files.pythonhosted.org/packages/9c/a0/abd4785f36c3400f1fbb21f517be39295a750a714f04b7ee175adf6ef580/simplejson-4.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a94ebaecdbaa80d9551a3ec6bf0c9302fc8b53ab6c1b2bfd498a1df4cb28158d", size = 91120, upload-time = "2026-04-24T19:23:07.877Z" }, + { url = "https://files.pythonhosted.org/packages/b8/78/fc060d2e3b13c6ec59288574b8efac64075e316b2afba4396a56b2422f78/simplejson-4.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:67341c95c0a168ab4a6d1e807e50463f1c8da932c3286d81e201266c427061fa", size = 91055, upload-time = "2026-04-24T19:23:09.264Z" }, + { url = "https://files.pythonhosted.org/packages/0c/b6/156a8de1e1b47694f0e7de6675866936608d45dc68388fd017d36f8693be/simplejson-4.1.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:45ec18e337fec538b7e902d489505c450b2454653d1290f3f50385e6fd8aa607", size = 190297, upload-time = "2026-04-24T19:23:11.226Z" }, + { url = "https://files.pythonhosted.org/packages/86/1c/e4d0eab695be3eb21d0f46bce820752031f03e7113f9c80a9b3c73ee7157/simplejson-4.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:820c69a4710400e9b248d5670647d60be58824369282d3925e516b3ff1a7cd82", size = 187002, upload-time = "2026-04-24T19:23:12.982Z" }, + { url = "https://files.pythonhosted.org/packages/76/0e/7f5a59d29426b062d5928fb88b403c3f797129d53be7102f955dbe51aa44/simplejson-4.1.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2e708d373a10e4378ef2d59f8361850c7150fd907ed49efe49bc5492160476d1", size = 195146, upload-time = "2026-04-24T19:23:14.517Z" }, + { url = "https://files.pythonhosted.org/packages/78/18/9943db224dd4d5fa3c090c3e56a94c37b254338c83995ec5680285111c40/simplejson-4.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:980fc33353f81fd12d8c49d44f8c2760d1dc8192285e627c5180d141035b228a", size = 183931, upload-time = "2026-04-24T19:23:16.742Z" }, + { url = "https://files.pythonhosted.org/packages/c2/08/9a690da9a766161c06c627d805362cf159f1abe480969372b2897649b955/simplejson-4.1.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:de2ed102fff88dacf543699f53ee3a533cc11539a39baa176b7e09dd783069d6", size = 192228, upload-time = "2026-04-24T19:23:18.33Z" }, + { url = "https://files.pythonhosted.org/packages/05/88/bd8aad36b451ffb0e0a3f721d695a88befa6d1ac7d1e02ae788ca7ff4029/simplejson-4.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2785ff8edc0e28bf773a32543a6bbed46351453c997b3f6709c744e3c2f7eabb", size = 187808, upload-time = "2026-04-24T19:23:21.165Z" }, + { url = "https://files.pythonhosted.org/packages/04/ee/14f91db0d1f481533b651dafbf8cd0da088d9817f7af30c68f7f19f9c847/simplejson-4.1.1-cp312-cp312-win32.whl", hash = "sha256:2e0d5ead6d14610467ec356ec1f6b5d8a56aa216abaad8d41c8b873b16cf313f", size = 88512, upload-time = 
"2026-04-24T19:23:22.764Z" }, + { url = "https://files.pythonhosted.org/packages/b9/c4/90de06b2d8737c68c05ff9274113f854dbf6a5f28b7a955212111672cb57/simplejson-4.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:63a5451f557d6be48a231bae932458655c620902b868170b2f1c8afed496f6b4", size = 90748, upload-time = "2026-04-24T19:23:24.494Z" }, + { url = "https://files.pythonhosted.org/packages/37/a9/47b445eeb559c9593453a0648e0fd6d08e8adff64dd5e5ced66726da8a09/simplejson-4.1.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dff52fc7af272e84fc21cc5a06c927c823ca6ae00af14f3b0d7707b42775ed98", size = 113160, upload-time = "2026-04-24T19:23:26.033Z" }, + { url = "https://files.pythonhosted.org/packages/4c/65/cb72db31523c164dea5dc55b02dad065a40c478856bc7534b279d2b51906/simplejson-4.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:971aed0647ad6e840a3943bec812fcda5f2d26a5497a4981d1fb49aa4f9a396c", size = 91521, upload-time = "2026-04-24T19:23:27.572Z" }, + { url = "https://files.pythonhosted.org/packages/9a/e5/54cb7c50ad5fdc1e0a86b7df4b135c2cbd5c4623605aa94466659098e8da/simplejson-4.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:249e2e220aa6d9b9d936bde84eb7bf79d5b6c5a8273c6e411f8b1635a9073f2d", size = 91407, upload-time = "2026-04-24T19:23:28.991Z" }, + { url = "https://files.pythonhosted.org/packages/38/2e/21a3ede87f0bf82d6c7bcb90480d50a6490eb974c6ab20881188e440957c/simplejson-4.1.1-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8e5cdd6a5d52299f345c15ab5678cc4249e24f383f361d986afbc3c7072a6b6b", size = 192451, upload-time = "2026-04-24T19:23:30.56Z" }, + { url = "https://files.pythonhosted.org/packages/59/df/9903edd3102bf0b5984edfcb90c88612330996efa3b4fbf8a971d6e17839/simplejson-4.1.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:642cec364e0676e2d5a73fa4d31d0c7c55886997caa2fde24e8292ca44d32728", size = 189015, upload-time = "2026-04-24T19:23:32.647Z" }, + { url = "https://files.pythonhosted.org/packages/98/cd/33230927a780e1398b857e3944abb914556994d252b1d765ae40d112cb25/simplejson-4.1.1-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:76fe296ca1df23d290033f10aaacf534fd1b3e3007e7f9ff8aa68b21413aaa78", size = 196658, upload-time = "2026-04-24T19:23:34.563Z" }, + { url = "https://files.pythonhosted.org/packages/cd/84/2c5a7444eb53e9a86d3738299bffddd9f53aeed799ded2f45368221fdb19/simplejson-4.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:8f0ad25b7dc4e0fb23858355819f2e994f1a5badcdcde8737eac7921c2f1ed2a", size = 185967, upload-time = "2026-04-24T19:23:36.191Z" }, + { url = "https://files.pythonhosted.org/packages/d3/68/454378e06d059cd412a7ed5d87fb6d29fd5b60f13a4d89fc1f764ff434df/simplejson-4.1.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a59ebd0533f03fd06ff0c42ba0f02d93cbcdd7944922bf3b93911327a95b901f", size = 193940, upload-time = "2026-04-24T19:23:38.151Z" }, + { url = "https://files.pythonhosted.org/packages/d5/d5/a15bf915f623a2c5a079d6e3be8256fdb8ef06f110669493a09b9d6933e0/simplejson-4.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:bccbf4419676b517939852e5aeff2af6aee4dc046881c67a1581fa6f1cb01abd", size = 189795, upload-time = "2026-04-24T19:23:40.139Z" }, + { url = "https://files.pythonhosted.org/packages/d2/c9/37212ae7dc4b607f0978c408e8633f05c810884e054c33113184c6c2c8a2/simplejson-4.1.1-cp313-cp313-win32.whl", hash = "sha256:6c845363eb5fd166fb7c72243da38f4fcfde666ede7fdf2cc6fd7762894626f7", size = 88773, 
upload-time = "2026-04-24T19:23:41.754Z" }, + { url = "https://files.pythonhosted.org/packages/fe/a5/c7a0a47883a9015b54c9d8a4b62f2aba17bd4335b1787b9b8a0fc2fa6d52/simplejson-4.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:104d8324c34f25b4b90800bc5fa363780cbc3d8496aef061cba7ce1af9162270", size = 90888, upload-time = "2026-04-24T19:23:43.11Z" }, + { url = "https://files.pythonhosted.org/packages/d3/18/4a118a6a92eb33bb08c8e2fe7ec85cb96f0673491bb2b829930831ee4fbe/simplejson-4.1.1-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:ed7473602b6625de793b6acba49aa949f144a475f538792067e4cf2fda2071f5", size = 110492, upload-time = "2026-04-24T19:23:44.957Z" }, + { url = "https://files.pythonhosted.org/packages/07/f4/84d160e9fa8cada1e0a9381cae4fa81eecd573577a5b34366d8ced59bdf7/simplejson-4.1.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:225c9caa324c5b554d009fb9cac22aee7711e71bd96f487938c659af467e828e", size = 90152, upload-time = "2026-04-24T19:23:46.355Z" }, + { url = "https://files.pythonhosted.org/packages/68/31/9a5432c433a7671107182cdc9a20ea78a70f99c4e5334aa54b6d4d0d79ed/simplejson-4.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:95407269340c7f22f09776ea7b717a52cf56cfcf119b5e45f66faa4a26445bea", size = 90115, upload-time = "2026-04-24T19:23:47.743Z" }, + { url = "https://files.pythonhosted.org/packages/78/91/3635cdb13318cb0a328abaa69e2b91251caad39d6779aa308098f341f6cb/simplejson-4.1.1-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3851658d642c1184d2023f0e6c9ce44a21eb1629e74e7c84ef956b128841fe12", size = 184036, upload-time = "2026-04-24T19:23:49.472Z" }, + { url = "https://files.pythonhosted.org/packages/fa/ba/149b6ec5393f6849d98c59cadba888b710a8ef4b805ab91e11a566960d40/simplejson-4.1.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:95a3bb0f78e85f4937f99092239f2011ce06f0f2d803df5c299cc05abbeae008", size = 180543, upload-time = "2026-04-24T19:23:51.023Z" }, + { url = "https://files.pythonhosted.org/packages/df/7c/a5d968d0b527a748b667e62bea94309ccbcb1e2b108e8f0cf8547efaa12b/simplejson-4.1.1-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:bbfdaa7c0603f75b7b14b211b7f2be44696d4e26833ad2d91d5c87bf5fb9a920", size = 188725, upload-time = "2026-04-24T19:23:52.995Z" }, + { url = "https://files.pythonhosted.org/packages/db/e3/6a8d11181d587ef00e2db9112357e6832111e56dd56b01b5c11758a1965d/simplejson-4.1.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:39e3c584071dced8c21b4689f0254303521daeb9b5bc1f4289755d71fa3cb0d3", size = 177492, upload-time = "2026-04-24T19:23:54.581Z" }, + { url = "https://files.pythonhosted.org/packages/67/e3/8b0eb8b06e8198cfbd1270487da163d0093df05cc4f557350cd65e2f7e79/simplejson-4.1.1-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:036a27bd0469b9d79557cbddb392969f876cd7f278cfbd0fba81534927a06575", size = 185281, upload-time = "2026-04-24T19:23:56.13Z" }, + { url = "https://files.pythonhosted.org/packages/dc/5f/64990f07ec9e2cb1a814c674e2e21b5693207f74ac70eb72151b847ea4e6/simplejson-4.1.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b70bfd2f67f3351baba08aa3ae9233c83f21fd95ae5e6b3d0ecb8c647929112f", size = 181848, upload-time = "2026-04-24T19:23:57.92Z" }, + { url = "https://files.pythonhosted.org/packages/61/a5/bbc1bc0447f339f79f99ab8c37f7f037cb2f1f93af75d6a4d553096bb0c3/simplejson-4.1.1-cp314-cp314-win32.whl", hash = "sha256:37233c72ce88d06acb92747347742b3c07871eba6789f060c179c9302dde8efe", 
size = 88761, upload-time = "2026-04-24T19:23:59.397Z" }, + { url = "https://files.pythonhosted.org/packages/18/72/ec1b5cbdcb140c132e6c7bdf99bd73e4f675439e77126c88f472fcffa09c/simplejson-4.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:cc0442dea71cd9cbf30a0b8b9929ab5aa6c02c0443a3d977351e6ec5bada4388", size = 91018, upload-time = "2026-04-24T19:24:00.85Z" }, + { url = "https://files.pythonhosted.org/packages/3d/97/4fa437f68ff72219bac3bf3d050de9c6265691f3a170e16954bd69d7cddd/simplejson-4.1.1-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:c996a4d38290c515af347740659ce095b425449c164a5c9fa3977caa6eff5dbe", size = 113919, upload-time = "2026-04-24T19:24:02.287Z" }, + { url = "https://files.pythonhosted.org/packages/c2/83/59de041d09eb4a9577f7015d7263c32095dfb7fde49717dff62145d89809/simplejson-4.1.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:c65c763fb20d7ca113c1c14dce2fc04a0fc3a57aceff533d6fdac707c7bffb40", size = 91904, upload-time = "2026-04-24T19:24:03.812Z" }, + { url = "https://files.pythonhosted.org/packages/03/8e/46bb345d540f6eb31427d984a4e518cdb182d0621814fee4fee045e8815b/simplejson-4.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:0da5c9f57206ee7ef280ff7f1d924937b0a64f9a271a5ef371a2ecdbebba7421", size = 91752, upload-time = "2026-04-24T19:24:05.622Z" }, + { url = "https://files.pythonhosted.org/packages/83/e2/1b2ce97f068835eb3d253c116a4df7a3f436b7bf2fb5ff1ba29287e8b0ec/simplejson-4.1.1-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ea3426e786425d10e9e82f8a6eda74a7d6eb10d99165ac3d0d3bbcb65c0ea343", size = 214021, upload-time = "2026-04-24T19:24:07.447Z" }, + { url = "https://files.pythonhosted.org/packages/48/70/d93e556df6a0786298644a7c08304fcbeddc248325f23f38acbebeb21165/simplejson-4.1.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d75cea7a1025edd7e439b2966b3d977c45b5b899e2adaf422811b3ac702ed9fb", size = 213530, upload-time = "2026-04-24T19:24:09.289Z" }, + { url = "https://files.pythonhosted.org/packages/1b/a5/c93bf305b9f00d7259e09e713d60e75bd0f7f53da970f716ab90491770e7/simplejson-4.1.1-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:63c2ada8e58f266491f19eed2eeeb7c25c6141e52f8f9e820f6bb94156cf8dbc", size = 218282, upload-time = "2026-04-24T19:24:10.991Z" }, + { url = "https://files.pythonhosted.org/packages/0c/20/a9b5d2e27ec44b069ee251bd55544fc76929a067107b1050001566ba86f3/simplejson-4.1.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d1fffb56305c5b475ee746cf9e04f97423ba5aaacd292dc1255bd75b1d3b124b", size = 209249, upload-time = "2026-04-24T19:24:12.662Z" }, + { url = "https://files.pythonhosted.org/packages/97/e4/e06ee682ed5df67592181f5ecb062e35878967e27f5b6e087237d4548d95/simplejson-4.1.1-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:a6525ec733f43d0541206cffa64fd2aad5a7ae3eb76566aff49cd4db6382209a", size = 213963, upload-time = "2026-04-24T19:24:14.302Z" }, + { url = "https://files.pythonhosted.org/packages/9c/9f/1e160e4cd8cdbf062bf6a454cdf814dc7a48eb47e566fdb8f80ccb202605/simplejson-4.1.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:861e393260508efa64d8805a8e49c416c3484907e3f146ce966c69552b49b9a3", size = 210474, upload-time = "2026-04-24T19:24:15.917Z" }, + { url = "https://files.pythonhosted.org/packages/7a/e6/cecd913df322df5bbe7ebb8ba39e0708e505a165553900da8a7761026d6f/simplejson-4.1.1-cp314-cp314t-win32.whl", hash = 
"sha256:d083b89d30948a751d3d97476c2ed91e4caaa24a1a1459bdbadb8876242c71fe", size = 91134, upload-time = "2026-04-24T19:24:17.635Z" }, + { url = "https://files.pythonhosted.org/packages/97/73/f540dde99cc1d393bd062ab3b5735b777561a5d8f8a5f2e241164444d77a/simplejson-4.1.1-cp314-cp314t-win_amd64.whl", hash = "sha256:4cbb299d0528ec0447fe366d8c9641860e28f997a62730690fef905f1f41046e", size = 94467, upload-time = "2026-04-24T19:24:19.109Z" }, + { url = "https://files.pythonhosted.org/packages/ce/6a/8b74c52ffd33dbbde00fe7251fee6a0acdc8cea33f7a43805aed258fb79b/simplejson-4.1.1-py3-none-any.whl", hash = "sha256:2ce92b3748f02423e26d2bfb636fb9d7a8f67c8f5854dcae69d350d123b2eee2", size = 69195, upload-time = "2026-04-24T19:24:57.962Z" }, ] [[package]] @@ -9102,19 +8451,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/49/4b/359f28a903c13438ef59ebeee215fb25da53066db67b305c125f1c6d2a25/sqlparse-0.5.5-py3-none-any.whl", hash = "sha256:12a08b3bf3eec877c519589833aed092e2444e68240a3577e8e26148acc7b1ba", size = 46138, upload-time = "2025-12-19T07:17:46.573Z" }, ] -[[package]] -name = "sse-starlette" -version = "3.3.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio", marker = "sys_platform == 'linux'" }, - { name = "starlette", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/14/2f/9223c24f568bb7a0c03d751e609844dce0968f13b39a3f73fbb3a96cd27a/sse_starlette-3.3.3.tar.gz", hash = "sha256:72a95d7575fd5129bd0ae15275ac6432bb35ac542fdebb82889c24bb9f3f4049", size = 32420, upload-time = "2026-03-17T20:05:55.529Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/78/e2/b8cff57a67dddf9a464d7e943218e031617fb3ddc133aeeb0602ff5f6c85/sse_starlette-3.3.3-py3-none-any.whl", hash = "sha256:c5abb5082a1cc1c6294d89c5290c46b5f67808cfdb612b7ec27e8ba061c22e8d", size = 14329, upload-time = "2026-03-17T20:05:54.35Z" }, -] - [[package]] name = "stack-data" version = "0.6.3" @@ -9142,15 +8478,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, ] -[[package]] -name = "supervisor" -version = "4.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a9/b5/37e7a3706de436a8a2d75334711dad1afb4ddffab09f25e31d89e467542f/supervisor-4.3.0.tar.gz", hash = "sha256:4a2bf149adf42997e1bb44b70c43b613275ec9852c3edacca86a9166b27e945e", size = 468912, upload-time = "2025-08-23T18:25:02.418Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/65/5e726c372da8a5e35022a94388b12252710aad0c2351699c3d76ae8dba78/supervisor-4.3.0-py2.py3-none-any.whl", hash = "sha256:0bcb763fddafba410f35cbde226aa7f8514b9fb82eb05a0c85f6588d1c13f8db", size = 320736, upload-time = "2025-08-23T18:25:00.767Z" }, -] - [[package]] name = "sympy" version = "1.14.0" @@ -9553,7 +8880,7 @@ dependencies = [ { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = 
"sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -9603,18 +8930,26 @@ name = "torch-c-dlpack-ext" version = "0.1.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "torch", marker = "sys_platform == 'linux'" }, + { name = "torch" }, ] sdist = { url = "https://files.pythonhosted.org/packages/37/de/921b6491efce5c389a5ef9bbed3d2d6660005840dae488124173180859ab/torch_c_dlpack_ext-0.1.5.tar.gz", hash = "sha256:d06f0357d575d22a168cc77acb9020fc4bae30968ceb6718a055dcbe92bacabe", size = 12913, upload-time = "2026-01-12T11:25:08.484Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/65/66/c12a9bb3a5ddc0962c00467891bf1ffdda39a4d4780bf0fbbf54523ff34e/torch_c_dlpack_ext-0.1.5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:56bd25a2af19280bf8a06aa62cff5510106f43235b9327d8561b3e9a659c4d84", size = 5076782, upload-time = "2026-01-12T11:24:37.868Z" }, { url = "https://files.pythonhosted.org/packages/20/e1/64e1e579d107064785549e70758e38a42376ab7e73d86897ed4beab10e74/torch_c_dlpack_ext-0.1.5-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fba674110e1fab0b176bb5a28223e157db65c90767d4ba74abdbee9f537b0e9d", size = 440949, upload-time = "2026-01-12T11:24:39.716Z" }, { url = "https://files.pythonhosted.org/packages/64/5c/3e1382a620824f92920ab3fae132d8fb4e85898284c99e0c6a7764e452ce/torch_c_dlpack_ext-0.1.5-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3448c4f0d64104d0b2e58080a7efa72304a04960c18f338024b80b13cd3eca26", size = 897768, upload-time = "2026-01-12T11:24:41.209Z" }, + { url = "https://files.pythonhosted.org/packages/54/4f/76ea1006b9038b496d01e916c91efd17cb782abde2491a261cf203f57e30/torch_c_dlpack_ext-0.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:74676474e0afa9a4216c4755ea7cf05e8158be1d168f6bda669ba91097c263f2", size = 1479088, upload-time = "2026-01-12T11:24:42.436Z" }, + { url = "https://files.pythonhosted.org/packages/b1/67/10d236698525d7b7db4d74ec0a4b01f5b2db33968995fdd9ac6b4635e327/torch_c_dlpack_ext-0.1.5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:c0f2bd51fcd99c0e5b50314e1985f2728c4941bfa821f065e6c30951d1f995ca", size = 5291237, upload-time = "2026-01-12T11:24:44.011Z" }, { url = "https://files.pythonhosted.org/packages/87/06/8d760997307a5c3be4384424667bf31aae0a42060838c532c7d846516175/torch_c_dlpack_ext-0.1.5-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3562ee411258676f9c38b8ad39306d1c8d027b6a86f6a87c920d2d009a9d1510", size = 443069, upload-time = "2026-01-12T11:24:45.451Z" }, { url = "https://files.pythonhosted.org/packages/e2/79/a914539b4785f3e44f891aa012a886edb8bc10fe081c440981c57543ce21/torch_c_dlpack_ext-0.1.5-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e6f9da4bb9af70e27facc777458be62e10dbbbddda7672d16138db0553c5a524", size = 897846, upload-time = "2026-01-12T11:24:48.168Z" }, + { url = "https://files.pythonhosted.org/packages/3a/e6/7d7a97a3953208d6d6ce749180c34d1dab48464ded9a76cecabe9d021ce6/torch_c_dlpack_ext-0.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:670fbbab70123cc228bed41693a3720757af57a0ad22669063c9db25321e8f55", size = 1482855, upload-time = "2026-01-12T11:24:49.581Z" }, + { url = 
"https://files.pythonhosted.org/packages/ca/c6/65346a201d921b616731311fc9941f15137672b444cebdad702cb52ccee0/torch_c_dlpack_ext-0.1.5-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:74acea2ed395cadda63342845b9e9ee7cd4537846223dacfb4431b4610109265", size = 1993243, upload-time = "2026-01-12T11:24:51.079Z" }, { url = "https://files.pythonhosted.org/packages/fd/ec/faf10be09a5812b1c5ec9922b53fb5def5fc4080b81a653b9347bb169ebb/torch_c_dlpack_ext-0.1.5-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49f1e99d13c64e22dac0a34a1560e9e5a398a49a9fa81df83053e04fde6ec5bd", size = 443798, upload-time = "2026-01-12T11:24:52.754Z" }, { url = "https://files.pythonhosted.org/packages/2d/68/f434b48700f3e04f33882f54d8d3910327b935f55e14ec49da7d607bf470/torch_c_dlpack_ext-0.1.5-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:debe62e5ef93e631065d6b9f6e60d3d39bae6b89fa1b25d9523f40b3efbf8aba", size = 755004, upload-time = "2026-01-12T11:24:54.004Z" }, + { url = "https://files.pythonhosted.org/packages/03/a8/cc64e563f05ea99bd79bdb43f71f0f46452d3acd734da4843ede5fc73a35/torch_c_dlpack_ext-0.1.5-cp313-cp313-win_amd64.whl", hash = "sha256:30e3eab616dbc81dfdb7492aca557be551a9163ba9b585f97394a42b336b113a", size = 999126, upload-time = "2026-01-12T11:24:55.44Z" }, + { url = "https://files.pythonhosted.org/packages/96/5e/449324ca8e81573e650b6851fc31c1038f750d1de85d0b185d788e1c7a3a/torch_c_dlpack_ext-0.1.5-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:cac94a4905d391889e679a8da31e46dc325af5d55d13b7c70c0ce3d71d1ced6d", size = 1982154, upload-time = "2026-01-12T11:24:58.038Z" }, { url = "https://files.pythonhosted.org/packages/20/62/11c05b99f69aa5152bca0313e0dfa6d125a020cf890dc888ef009aa7891c/torch_c_dlpack_ext-0.1.5-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a58fdf45fb0bda7bc459632cec891570f31c11636d5851c825cf308ec8b73c2", size = 163825, upload-time = "2026-01-12T11:24:59.474Z" }, { url = "https://files.pythonhosted.org/packages/15/b5/be613cd8e71c9982bd07af530f86c5a7f30df7831d14cec5414857af7149/torch_c_dlpack_ext-0.1.5-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b985a324c68241cf83a9474b28015524b66775b12a91930dd4c0760aa628d01", size = 171740, upload-time = "2026-01-12T11:25:00.776Z" }, + { url = "https://files.pythonhosted.org/packages/5c/11/52e291f1659e2ec70a09f5ca4ad27e015eb4f0a1371ae68d23a9fbd1c704/torch_c_dlpack_ext-0.1.5-cp314-cp314-win_amd64.whl", hash = "sha256:d794e19fa3f330ab7a29987c07e031fc08e4953aec516d35701d0827863e356b", size = 277086, upload-time = "2026-01-12T11:25:01.901Z" }, ] [[package]] @@ -9626,28 +8961,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/3d/0c5a5833a135a045510e06c06b3d4cf316b06d59415bc21e0b021a000cc8/torchao-0.16.0-py3-none-any.whl", hash = "sha256:d0a8d773351fd17b95fee81dfbcbf98577b567dcdbec47d221b0ee258432101d", size = 1164150, upload-time = "2026-02-10T22:12:15.28Z" }, ] -[[package]] -name = "torchaudio" -version = "2.10.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "torch", marker = "sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/b7/c66dc34a27441d78997e20d0ffe2f5ad73db9f7b1267511be255bb94ac9b/torchaudio-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:87c841a21e82703ebd4a29170c4e60c25a2b47312dc212930087ad58965ac0c8", size = 391843, upload-time = "2026-01-21T16:28:43.093Z" }, - { url = 
"https://files.pythonhosted.org/packages/13/ae/a2a34a64947c4fa4a61b4c86d8f36fbcb4ebfec30fdde140267db260f96c/torchaudio-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:b2c77fb9114dd463dc805560bf55a1ac2a52e219794cc32b7b32cf2aeffd2826", size = 1894140, upload-time = "2026-01-21T16:28:35.892Z" }, - { url = "https://files.pythonhosted.org/packages/ea/3f/df620439a76ece170472d41438d11a1545d5db5dc9f1eaeab8c6e055a328/torchaudio-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:42b148a0921a3721abd1f6ae098b1ec9f89703e555c4f7a0d44da87b8decbcb9", size = 391973, upload-time = "2026-01-21T16:28:39.732Z" }, - { url = "https://files.pythonhosted.org/packages/98/25/e55a30d7138f8fe56ed006df25b0a3c27681f0ec7bc9989e1778e6d559c3/torchaudio-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:0e77b2956448d63790a99beed0b74ac8b8cd3a94dcdd9ad01974411078f46278", size = 1895234, upload-time = "2026-01-21T16:28:37.034Z" }, - { url = "https://files.pythonhosted.org/packages/49/fd/831c2595c81b17141180ca11ab3c0836cc544ef13e15aa0e7b2cb619e582/torchaudio-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5bc39ff3ea341097ce1ab023dd88c9dd8ca5f96ebf48821e7d23766137bb55d7", size = 392757, upload-time = "2026-01-21T16:28:33.631Z" }, - { url = "https://files.pythonhosted.org/packages/8e/d8/405c80c57dc68ca5855bddfaae57c3d84ea7397bf1eb2aa5d59c9fa1d3a9/torchaudio-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:3057c4286db5673d266124a2a10ca54e19f516772e9057f44573a7da5b85e328", size = 1897099, upload-time = "2026-01-21T16:28:24.793Z" }, - { url = "https://files.pythonhosted.org/packages/43/8c/653e7f67855424bf3b7cbb48335f8316f7fb02bb01a6cab38f6bf9555676/torchaudio-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:b41b254d958632dc00dc7768431cadda516c91641d798775cbb19bcd4f0d2be4", size = 393430, upload-time = "2026-01-21T16:28:34.855Z" }, - { url = "https://files.pythonhosted.org/packages/8e/1f/f91fcb9dd47a19b720fb48042a2f6f023651948e73726e98fff60d5ed5c7/torchaudio-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:da1081d1018a1e95f5a13947402aeb037cf5ac8861219a6164df004898a96bb1", size = 1897271, upload-time = "2026-01-21T16:28:23.519Z" }, - { url = "https://files.pythonhosted.org/packages/57/a1/ef5571406858f4ea89c18d6ad844d21cb9858708149e6bbd9a789ee30ea5/torchaudio-2.10.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:b2d5e11a2bec08f02a4f5fb7d1902ff82d48c533a27ceedc21e6ade650cf65b3", size = 393061, upload-time = "2026-01-21T16:28:25.802Z" }, - { url = "https://files.pythonhosted.org/packages/9d/0f/a0cf0ebc6f71b1868ea056dd4cd4f1a2244b8da8bc38372a1adc984a7c1f/torchaudio-2.10.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:77f6cf11a3b61af1b0967cd642368ecd30a86d70f622b22410ae6cb42d980b72", size = 1897137, upload-time = "2026-01-21T16:28:15.366Z" }, - { url = "https://files.pythonhosted.org/packages/53/8a/946aa07393845b918d318b5e34b3bd0359fd27fc9fac10a85fae2bb86382/torchaudio-2.10.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:ed912de8ec1b400e17a5172badcfcddc601a9cd4e02d200f3a9504fc8e54961c", size = 393434, upload-time = "2026-01-21T16:28:18.668Z" }, - { url = "https://files.pythonhosted.org/packages/e1/68/e37e8fbbae986afa80f8851e08fc017eb8ae5f7b398ee28ed92303da163e/torchaudio-2.10.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:f7aa33a8198e87949896e16ea245ea731906445becdf10130e8823c68494a94a", size = 1897289, upload-time = "2026-01-21T16:28:17.059Z" }, -] - [[package]] name = "torchvision" version = "0.25.0" @@ 
-10092,28 +9405,28 @@ wheels = [ [[package]] name = "uv" -version = "0.11.7" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9b/7d/17750123a8c8e324627534fe1ae2e7a46689db8492f1a834ab4fd229a7d8/uv-0.11.7.tar.gz", hash = "sha256:46d971489b00bdb27e0aa715e4a5cd4ef2c28ea5b6ef78f2b67bf861eb44b405", size = 4083385, upload-time = "2026-04-15T21:42:55.474Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/5b/2bb2ab6fe6c78c2be10852482ef0cae5f3171460a6e5e24c32c9a0843163/uv-0.11.7-py3-none-linux_armv6l.whl", hash = "sha256:f422d39530516b1dfb28bb6e90c32bb7dacd50f6a383cd6e40c1a859419fbc8c", size = 23757265, upload-time = "2026-04-15T21:43:14.494Z" }, - { url = "https://files.pythonhosted.org/packages/b2/f5/36ff27b01e60a88712628c8a5a6003b8e418883c24e084e506095844a797/uv-0.11.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8b2fe1ec6775dad10183e3fdce430a5b37b7857d49763c884f3a67eaa8ca6f8a", size = 23184529, upload-time = "2026-04-15T21:42:30.225Z" }, - { url = "https://files.pythonhosted.org/packages/8a/fa/f379be661316698f877e78f4c51e5044be0b6f390803387237ad92c4057f/uv-0.11.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:162fa961a9a081dcea6e889c79f738a5ae56507047e4672964972e33c301bea9", size = 21780167, upload-time = "2026-04-15T21:42:44.942Z" }, - { url = "https://files.pythonhosted.org/packages/f2/7f/fbed29775b0612f4f5679d3226268f1a347161abc1727b4080fb41d9f46f/uv-0.11.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:5985a15a92bd9a170fc1947abb1fbc3e9828c5a430ad85b5bed8356c20b67a71", size = 23609640, upload-time = "2026-04-15T21:42:22.57Z" }, - { url = "https://files.pythonhosted.org/packages/ad/de/989a69634a869a22322770120557c2d8cbba5b77ec7cfad326b4ec0f0547/uv-0.11.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:fab0bb43fbbc0ee5b5fee212078d2300c371b725faff7cf72eeaafa0bff0606b", size = 23322484, upload-time = "2026-04-15T21:43:26.52Z" }, - { url = "https://files.pythonhosted.org/packages/24/08/c1af05ea602eb4eb75d86badb6b0594cc104c3ca83ccf06d9ed4dd2186ad/uv-0.11.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23d457d6731ebdb83f1bffebe4894edab2ef43c1ec5488433c74300db4958924", size = 23326385, upload-time = "2026-04-15T21:42:41.32Z" }, - { url = "https://files.pythonhosted.org/packages/68/99/e246962da06383e992ecab55000c62a50fb36efef855ea7264fad4816bf4/uv-0.11.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d6a17507b8139b8803f445a03fd097f732ce8356b1b7b13cdb4dd8ef7f4b2e0", size = 24985751, upload-time = "2026-04-15T21:42:37.777Z" }, - { url = "https://files.pythonhosted.org/packages/45/2d/b0b68083859579ce811996c1480765ec6a2442b44c451eaef53e6218fbae/uv-0.11.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dd48823ca4b505124389f49ae50626ba9f57212b9047738efc95126ed5f3844d", size = 25724160, upload-time = "2026-04-15T21:43:18.762Z" }, - { url = "https://files.pythonhosted.org/packages/4e/19/5970e89d9e458fd3c4966bbc586a685a1c0ab0a8bf334503f63fa20b925b/uv-0.11.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb91f52ee67e10d5290f2c2897e2171357f1a10966de38d83eefa93d96843b0c", size = 25028512, upload-time = "2026-04-15T21:43:02.721Z" }, - { url = "https://files.pythonhosted.org/packages/83/eb/4e1557daf6693cb446ed28185664ad6682fd98c6dbac9e433cbc35df450a/uv-0.11.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:4e4d5e31bea86e1b6e0f5a0f95e14e80018e6f6c0129256d2915a4b3d793644d", size = 24933975, upload-time = "2026-04-15T21:42:18.828Z" }, - { url = "https://files.pythonhosted.org/packages/68/55/3b517ec8297f110d6981f525cccf26f86e30883fbb9c282769cffbcdcfca/uv-0.11.7-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:ceae53b202ea92bc954759bc7c7570cdcd5c3512fce15701198c19fd2dfb8605", size = 23706403, upload-time = "2026-04-15T21:43:10.664Z" }, - { url = "https://files.pythonhosted.org/packages/dc/30/7d93a0312d60e147722967036dc8ea37baab4802784bddc22464cb707deb/uv-0.11.7-py3-none-manylinux_2_31_riscv64.musllinux_1_1_riscv64.whl", hash = "sha256:f97e9f4e4d44fb5c4dfaa05e858ef3414a96416a2e4af270ecd88a3e5fb049a9", size = 24495797, upload-time = "2026-04-15T21:42:26.538Z" }, - { url = "https://files.pythonhosted.org/packages/8c/89/d49480bdab7725d36982793857e461d471bde8e1b7f438ffccee677a7bf8/uv-0.11.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:750ee5b96959b807cf442b73dd8b55111862d63f258f896787ea5f06b68aaca9", size = 24580471, upload-time = "2026-04-15T21:42:52.871Z" }, - { url = "https://files.pythonhosted.org/packages/b6/9f/c57dc03b48be17b564e304eb9ff982890c12dfb888b1ce370788733329ab/uv-0.11.7-py3-none-musllinux_1_1_i686.whl", hash = "sha256:f394331f0507e80ee732cb3df737589de53bed999dd02a6d24682f08c2f8ac4f", size = 24113637, upload-time = "2026-04-15T21:42:34.094Z" }, - { url = "https://files.pythonhosted.org/packages/13/ba/b87e358b629a68258527e3490e73b7b148770f4d2257842dea3b7981d4e8/uv-0.11.7-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:0df59ab0c6a4b14a763e8445e1c303af9abeb53cdfa4428daf9ff9642c0a3cce", size = 25119850, upload-time = "2026-04-15T21:43:22.529Z" }, - { url = "https://files.pythonhosted.org/packages/4b/74/16d229e1d8574bcbafa6dc643ac20b70c3e581f42ac31a6f4fd53035ffe3/uv-0.11.7-py3-none-win32.whl", hash = "sha256:553e67cc766d013ce24353fecd4ea5533d2aedcfd35f9fac430e07b1d1f23ed4", size = 22918454, upload-time = "2026-04-15T21:42:58.702Z" }, - { url = "https://files.pythonhosted.org/packages/a6/1d/b73e473da616ac758b8918fb218febcc46ddf64cba9e03894dfa226b28bd/uv-0.11.7-py3-none-win_amd64.whl", hash = "sha256:5674dfb5944513f4b3735b05c2deba6b1b01151f46729d533d413a9a905f8c5d", size = 25447744, upload-time = "2026-04-15T21:42:48.813Z" }, - { url = "https://files.pythonhosted.org/packages/1b/bb/e6bfdea92ed270f3445a5a3c17599d041b3f2dbc5026c09e02830a03bbaf/uv-0.11.7-py3-none-win_arm64.whl", hash = "sha256:6158b7e39464f1aa1e040daa0186cae4749a78b5cd80ac769f32ca711b8976b1", size = 23941816, upload-time = "2026-04-15T21:43:06.732Z" }, +version = "0.11.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/02/69a3b06fd8a91f95b79e95e14f5ccdd4df0f124c381aefe9d1e2784d5a65/uv-0.11.11.tar.gz", hash = "sha256:2ba46a912a1775957c579a1a42c8c8b480418502326b72427b1cad972c8f659f", size = 4112827, upload-time = "2026-05-06T20:04:47.982Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/54/39d3c58de992767834120fe3735b85cc60dd00a69b377c3d947ca6f172a1/uv-0.11.11-py3-none-linux_armv6l.whl", hash = "sha256:4977a1193e5dc9c2934b9f97d6cf787382f80deae17646640ee583cfc61486c0", size = 23537936, upload-time = "2026-05-06T20:04:58.626Z" }, + { url = "https://files.pythonhosted.org/packages/de/c9/d2d7ca30abf4c2d5ae0d9360a1e154115af176308ef1ecdc8bf7af724cf8/uv-0.11.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:92817f276758e41b4160fcb6d457ebd9f228f0473efe3808891164f326fdea38", size = 23068282, upload-time = "2026-05-06T20:05:01.466Z" 
}, + { url = "https://files.pythonhosted.org/packages/fa/37/f64decba47d7afaace3f238aa4a416dca947bd0a1a9b534c3a0f179e1016/uv-0.11.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6eec6ad051e6e5d922cd547b9f7b09a7f821597ae01900a6f01b0a01317e5fd0", size = 21671522, upload-time = "2026-05-06T20:05:04.382Z" }, + { url = "https://files.pythonhosted.org/packages/93/a6/c129878d7c2a66ffdaa12dc253d3135c5e10fc5b5e15812791e188c6dbec/uv-0.11.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:1d227bb53b701e533f0aa074dd145a6fa31492dc7d6d57a6e72a700b9a4a1991", size = 23283200, upload-time = "2026-05-06T20:04:39.879Z" }, + { url = "https://files.pythonhosted.org/packages/8f/c2/cff1f9ab7eda3d863e9866fca0e14df37c0fd734b66ebb77d751258b2fae/uv-0.11.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:05ee9f18701692fcb22db98085c041a3be7a35b88c710dea4487c293f42a4b95", size = 23081561, upload-time = "2026-05-06T20:05:07.149Z" }, + { url = "https://files.pythonhosted.org/packages/ca/44/ebd02ca8fae5961d1bcbcee11019dd170dd0d42517afad753281335700cc/uv-0.11.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0632af539d6a1ee00f58da9e7db32fd99e12187aa67426cb90d871154ab5debb", size = 23105780, upload-time = "2026-05-06T20:04:50.107Z" }, + { url = "https://files.pythonhosted.org/packages/86/f7/0741abcd70591a65f85fc4e8fecd3fb3fb4bdfe50042cccf016714955fd9/uv-0.11.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb3f2715551d2fc9ef44b6cf0918fcc556cd99e9bf6caa1d8a870a4657d2b180", size = 24542681, upload-time = "2026-05-06T20:04:53.014Z" }, + { url = "https://files.pythonhosted.org/packages/b1/42/46e7e35f1f39e39d4bf0f712479768cf8d33eb7f35b67fceaea43e975dfd/uv-0.11.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c86bd6460579857d7e359bdbfe6f688076c654481ae933151d1449f9ea672fb6", size = 25459284, upload-time = "2026-05-06T20:04:34.168Z" }, + { url = "https://files.pythonhosted.org/packages/e8/fc/efdb16e1a6c619b021259ac8d8e4b6afd97efb446054ea28761eb2e1a177/uv-0.11.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0f69f4df007c7506db8d7f77ccabd466a886ac21e9b04a479dd0cd22e26d2262", size = 24560769, upload-time = "2026-05-06T20:04:42.648Z" }, + { url = "https://files.pythonhosted.org/packages/4c/f8/a5d5bac297b1379719050788c6b852c6b3eefcb1e82d8465ed22c10cede7/uv-0.11.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5b9f31dab557b5ee4257d8c6ba2608a63c7278537cb0cd102cf6fc518e3fb5c", size = 24639659, upload-time = "2026-05-06T20:04:31.491Z" }, + { url = "https://files.pythonhosted.org/packages/ee/d5/f3be167a43192062f1409fd6b857a612665d331174293b4ffc73218872e1/uv-0.11.11-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:8e8faf2e5b3517155fd18e509b19b21135247d43b7fb9a8d61a44a53118d5ab7", size = 23388445, upload-time = "2026-05-06T20:04:25.199Z" }, + { url = "https://files.pythonhosted.org/packages/9f/cd/ef1f573ee8edd2beab9fcd2449121483829621b3b57f7ba3f35c56ef373b/uv-0.11.11-py3-none-manylinux_2_31_riscv64.musllinux_1_1_riscv64.whl", hash = "sha256:3f8c9a1bea743a3fe39e956455686f4d0dd25ef58e8d70dc11a45381fd7c50e5", size = 24114301, upload-time = "2026-05-06T20:04:28.586Z" }, + { url = "https://files.pythonhosted.org/packages/9d/be/9181158465719e875a6995c10af24e00cdefba3fe6c9c8cbb02d34b2ade7/uv-0.11.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:f68dc7b62050a26ac6b1491398aebbbf0fa5485627e73b1d626666a097dbab07", size 
= 24155126, upload-time = "2026-05-06T20:04:55.98Z" }, + { url = "https://files.pythonhosted.org/packages/71/9c/bb306f9964870847f02a931d1fff896726f8bafcf9ce917122ac1bfef14c/uv-0.11.11-py3-none-musllinux_1_1_i686.whl", hash = "sha256:29ddb0d9b24a30ff4360b94e3cb704e82cd5fda86dc224032251f33ab5ceb79e", size = 23824684, upload-time = "2026-05-06T20:05:10.305Z" }, + { url = "https://files.pythonhosted.org/packages/56/48/434a1cf4798ca200e0dcb36411ba38013edb6d3e1aeb4cd85e8a2d7db9ca/uv-0.11.11-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:505a31f2c30fa9e83b1853cab06c5b92e66341c914c6f20f3878903aa09a6f34", size = 24862560, upload-time = "2026-05-06T20:04:37.287Z" }, + { url = "https://files.pythonhosted.org/packages/63/3a/997cddf82917f084d486e1c268c7e94836190fd928c93aa3fb92caee9a7f/uv-0.11.11-py3-none-win32.whl", hash = "sha256:c1e0e3e18cc94680642eac3c3f19f2635c17dd058edcb41b78cbdc459f574eb4", size = 22573619, upload-time = "2026-05-06T20:04:45.35Z" }, + { url = "https://files.pythonhosted.org/packages/30/5f/db34b840f8d86833ef810de8150fc9ce01a03c779393e08eadbcc4c010d5/uv-0.11.11-py3-none-win_amd64.whl", hash = "sha256:36412b13f6287304789abdf40122d268cee548fce3573e07d148a29370181421", size = 25170135, upload-time = "2026-05-06T20:05:13.001Z" }, + { url = "https://files.pythonhosted.org/packages/2d/3e/f3ba2557b437ec5b1fde1e0d5248b723432dc90f09b0050f52695596fd2e/uv-0.11.11-py3-none-win_arm64.whl", hash = "sha256:011f42faf5d267a6681ea77e3f236f275cb4490efeecb9599de74dc7ad7df8f6", size = 23597162, upload-time = "2026-05-06T20:05:16.095Z" }, ] [[package]] @@ -10193,174 +9506,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/59/7d02447a55b2e55755011a647479041bc92a82e143f96a8195cb33bd0a1c/virtualenv-21.2.0-py3-none-any.whl", hash = "sha256:1bd755b504931164a5a496d217c014d098426cddc79363ad66ac78125f9d908f", size = 5825084, upload-time = "2026-03-09T17:24:35.378Z" }, ] -[[package]] -name = "vllm" -version = "0.17.0+art1" -source = { url = "https://github.com/vivekkalyan/vllm/releases/download/v0.17.0-art1/vllm-0.17.0%2Bart1-cp38-abi3-manylinux_2_31_x86_64.whl" } -dependencies = [ - { name = "aiohttp", marker = "sys_platform == 'linux'" }, - { name = "anthropic", marker = "sys_platform == 'linux'" }, - { name = "blake3", marker = "sys_platform == 'linux'" }, - { name = "cachetools", marker = "sys_platform == 'linux'" }, - { name = "cbor2", marker = "sys_platform == 'linux'" }, - { name = "cloudpickle", marker = "sys_platform == 'linux'" }, - { name = "compressed-tensors", marker = "sys_platform == 'linux'" }, - { name = "depyf", marker = "sys_platform == 'linux'" }, - { name = "diskcache", marker = "sys_platform == 'linux'" }, - { name = "einops", marker = "sys_platform == 'linux'" }, - { name = "fastapi", extra = ["standard"], marker = "sys_platform == 'linux'" }, - { name = "filelock", marker = "sys_platform == 'linux'" }, - { name = "flashinfer-python", marker = "sys_platform == 'linux'" }, - { name = "gguf", marker = "sys_platform == 'linux'" }, - { name = "grpcio", marker = "sys_platform == 'linux'" }, - { name = "grpcio-reflection", marker = "sys_platform == 'linux'" }, - { name = "ijson", marker = "sys_platform == 'linux'" }, - { name = "kaldi-native-fbank", marker = "sys_platform == 'linux'" }, - { name = "lark", marker = "sys_platform == 'linux'" }, - { name = "llguidance", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'ppc64le' and sys_platform == 'linux') or 
(platform_machine == 's390x' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "lm-format-enforcer", marker = "sys_platform == 'linux'" }, - { name = "mcp", marker = "sys_platform == 'linux'" }, - { name = "mistral-common", extra = ["image"], marker = "sys_platform == 'linux'" }, - { name = "model-hosting-container-standards", marker = "sys_platform == 'linux'" }, - { name = "msgspec", marker = "sys_platform == 'linux'" }, - { name = "ninja", marker = "sys_platform == 'linux'" }, - { name = "numba", marker = "sys_platform == 'linux'" }, - { name = "numpy", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cutlass-dsl", marker = "sys_platform == 'linux'" }, - { name = "openai", marker = "sys_platform == 'linux'" }, - { name = "openai-harmony", marker = "sys_platform == 'linux'" }, - { name = "opencv-python-headless", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-api", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-exporter-otlp", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-sdk", marker = "sys_platform == 'linux'" }, - { name = "opentelemetry-semantic-conventions-ai", marker = "sys_platform == 'linux'" }, - { name = "outlines-core", marker = "sys_platform == 'linux'" }, - { name = "partial-json-parser", marker = "sys_platform == 'linux'" }, - { name = "pillow", marker = "sys_platform == 'linux'" }, - { name = "prometheus-client", marker = "sys_platform == 'linux'" }, - { name = "prometheus-fastapi-instrumentator", marker = "sys_platform == 'linux'" }, - { name = "protobuf", marker = "sys_platform == 'linux'" }, - { name = "psutil", marker = "sys_platform == 'linux'" }, - { name = "py-cpuinfo", marker = "sys_platform == 'linux'" }, - { name = "pybase64", marker = "sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'linux'" }, - { name = "python-json-logger", marker = "sys_platform == 'linux'" }, - { name = "pyyaml", marker = "sys_platform == 'linux'" }, - { name = "pyzmq", marker = "sys_platform == 'linux'" }, - { name = "quack-kernels", marker = "sys_platform == 'linux'" }, - { name = "ray", extra = ["cgraph"], marker = "sys_platform == 'linux'" }, - { name = "regex", marker = "sys_platform == 'linux'" }, - { name = "requests", marker = "sys_platform == 'linux'" }, - { name = "sentencepiece", marker = "sys_platform == 'linux'" }, - { name = "setproctitle", marker = "sys_platform == 'linux'" }, - { name = "setuptools", marker = "python_full_version >= '3.12' and sys_platform == 'linux'" }, - { name = "six", marker = "python_full_version >= '3.12' and sys_platform == 'linux'" }, - { name = "tiktoken", marker = "sys_platform == 'linux'" }, - { name = "tokenizers", marker = "sys_platform == 'linux'" }, - { name = "torch", marker = "sys_platform == 'linux'" }, - { name = "torchaudio", marker = "sys_platform == 'linux'" }, - { name = "torchvision", marker = "sys_platform == 'linux'" }, - { name = "tqdm", marker = "sys_platform == 'linux'" }, - { name = "transformers", marker = "sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'linux'" }, - { name = "watchfiles", marker = "sys_platform == 'linux'" }, - { name = "xgrammar", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'ppc64le' and sys_platform == 'linux') or (platform_machine == 's390x' and sys_platform == 'linux') or (platform_machine == 'x86_64' and 
sys_platform == 'linux')" }, -] -wheels = [ - { url = "https://github.com/vivekkalyan/vllm/releases/download/v0.17.0-art1/vllm-0.17.0%2Bart1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:dfe9f4bf82bb1fe677fdde81d0cd62702dedf252144847951b2fc13fa4932057" }, -] - -[package.metadata] -requires-dist = [ - { name = "aiohttp", specifier = ">=3.13.3" }, - { name = "anthropic", specifier = ">=0.71.0" }, - { name = "blake3" }, - { name = "cachetools" }, - { name = "cbor2" }, - { name = "cloudpickle" }, - { name = "compressed-tensors", specifier = "==0.13.0" }, - { name = "datasets", marker = "extra == 'bench'" }, - { name = "depyf", specifier = "==0.20.0" }, - { name = "diskcache", specifier = "==5.6.3" }, - { name = "einops" }, - { name = "fastapi", extras = ["standard"], specifier = ">=0.115.0" }, - { name = "fastsafetensors", marker = "extra == 'fastsafetensors'", specifier = ">=0.2.2" }, - { name = "filelock", specifier = ">=3.16.1" }, - { name = "flashinfer-python", specifier = "==0.6.4" }, - { name = "gguf", specifier = ">=0.17.0" }, - { name = "grpcio" }, - { name = "grpcio-reflection" }, - { name = "helion", marker = "extra == 'helion'" }, - { name = "ijson" }, - { name = "kaldi-native-fbank", specifier = ">=1.18.7" }, - { name = "lark", specifier = "==1.2.2" }, - { name = "librosa", marker = "extra == 'audio'" }, - { name = "llguidance", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'", specifier = ">=1.3.0,<1.4.0" }, - { name = "lm-format-enforcer", specifier = "==0.11.3" }, - { name = "matplotlib", marker = "extra == 'bench'" }, - { name = "mcp" }, - { name = "mistral-common", extras = ["audio"], marker = "extra == 'audio'" }, - { name = "mistral-common", extras = ["image"], specifier = ">=1.9.1" }, - { name = "model-hosting-container-standards", specifier = ">=0.1.13,<1.0.0" }, - { name = "msgspec" }, - { name = "ninja" }, - { name = "numba", specifier = "==0.61.2" }, - { name = "numpy" }, - { name = "nvidia-cutlass-dsl", specifier = ">=4.4.0.dev1" }, - { name = "openai", specifier = ">=1.99.1,<2.25.0" }, - { name = "openai-harmony", specifier = ">=0.0.3" }, - { name = "opencv-python-headless", specifier = ">=4.13.0" }, - { name = "opentelemetry-api", specifier = ">=1.27.0" }, - { name = "opentelemetry-api", marker = "extra == 'otel'", specifier = ">=1.26.0" }, - { name = "opentelemetry-exporter-otlp", specifier = ">=1.27.0" }, - { name = "opentelemetry-exporter-otlp", marker = "extra == 'otel'", specifier = ">=1.26.0" }, - { name = "opentelemetry-sdk", specifier = ">=1.27.0" }, - { name = "opentelemetry-sdk", marker = "extra == 'otel'", specifier = ">=1.26.0" }, - { name = "opentelemetry-semantic-conventions-ai", specifier = ">=0.4.1" }, - { name = "opentelemetry-semantic-conventions-ai", marker = "extra == 'otel'", specifier = ">=0.4.1" }, - { name = "outlines-core", specifier = "==0.2.11" }, - { name = "pandas", marker = "extra == 'bench'" }, - { name = "partial-json-parser" }, - { name = "petit-kernel", marker = "extra == 'petit-kernel'" }, - { name = "pillow" }, - { name = "plotly", marker = "extra == 'bench'" }, - { name = "prometheus-client", specifier = ">=0.18.0" }, - { name = "prometheus-fastapi-instrumentator", specifier = ">=7.0.0" }, - { name = "protobuf", specifier = ">=5.29.6,!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*" }, - { name = "psutil" }, - { name = "py-cpuinfo" }, - { name = "pybase64" }, - { name = 
"pydantic", specifier = ">=2.12.0" }, - { name = "python-json-logger" }, - { name = "pyyaml" }, - { name = "pyzmq", specifier = ">=25.0.0" }, - { name = "quack-kernels", specifier = ">=0.2.7" }, - { name = "ray", extras = ["cgraph"], specifier = ">=2.48.0" }, - { name = "regex" }, - { name = "requests", specifier = ">=2.26.0" }, - { name = "runai-model-streamer", extras = ["gcs", "s3"], marker = "extra == 'runai'", specifier = ">=0.15.3" }, - { name = "scipy", marker = "extra == 'audio'" }, - { name = "scipy", marker = "extra == 'bench'" }, - { name = "seaborn", marker = "extra == 'bench'" }, - { name = "sentencepiece" }, - { name = "setproctitle" }, - { name = "setuptools", marker = "python_full_version >= '3.12'", specifier = ">=77.0.3,<81.0.0" }, - { name = "six", marker = "python_full_version >= '3.12'", specifier = ">=1.16.0" }, - { name = "soundfile", marker = "extra == 'audio'" }, - { name = "tensorizer", marker = "extra == 'tensorizer'", specifier = "==2.10.1" }, - { name = "tiktoken", specifier = ">=0.6.0" }, - { name = "tokenizers", specifier = ">=0.21.1" }, - { name = "torch", specifier = "==2.10.0" }, - { name = "torchaudio", specifier = "==2.10.0" }, - { name = "torchvision", specifier = "==0.25.0" }, - { name = "tqdm" }, - { name = "transformers", specifier = ">=4.56.0,<5.3" }, - { name = "typing-extensions", specifier = ">=4.10" }, - { name = "watchfiles" }, - { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'", specifier = "==0.1.29" }, -] -provides-extras = ["bench", "tensorizer", "fastsafetensors", "runai", "audio", "video", "flashinfer", "petit-kernel", "helion", "otel"] - [[package]] name = "waitress" version = "3.0.2" @@ -10831,27 +9976,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/49/0b/88c39c128a05d5b553a67cb9c4c3fc32eefb91f836f838befab9e78f8364/xformers-0.0.35-py39-none-win_amd64.whl", hash = "sha256:57381ce3cbb79b593e6b62cb20a937885345fad2796de2aa6fbb66c033601179", size = 2638618, upload-time = "2026-02-20T20:33:04.104Z" }, ] -[[package]] -name = "xgrammar" -version = "0.1.29" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", marker = "sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'linux'" }, - { name = "torch", marker = "sys_platform == 'linux'" }, - { name = "transformers", marker = "sys_platform == 'linux'" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/02/a3/70dbe3ffd331a1e7e1ad5a95690a4086e6c7cdb8089f5c7eda712219ccec/xgrammar-0.1.29.tar.gz", hash = "sha256:cf195afa81b489eebf35d4c6f37f27136d05420739ab4a6f7f065c938d7e4baa", size = 2321317, upload-time = "2025-12-19T08:23:54.53Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/0b/b5e5c99ce13a9d378a940cda07c5a08b50cc7efb66936c6ac8fa8232a0d5/xgrammar-0.1.29-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51bcfd63bd48a0b26209ffd2143a42067518559355ec9e4e574cef2ae74fac7c", size = 34699408, upload-time = "2025-12-19T08:23:16.906Z" }, - { url = "https://files.pythonhosted.org/packages/a3/a0/4ebc1b3f5af79a3f73d0566034758f3fbcd9c64174646314a9a6f7cc1d27/xgrammar-0.1.29-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:e27b50cf8c565845295a8263a4a0790c00a7c1fd783e76222fc0f575654d6f56", size = 34903461, upload-time = "2025-12-19T08:23:19.556Z" }, - { url = "https://files.pythonhosted.org/packages/57/94/18793c64bf0368075a34c06e196bf002f1e6ab0aee332268f44e8d356d5a/xgrammar-0.1.29-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6eb370a16b27a683e5f2b9e429ab41440c69977d4a504849ed61831b94cc704c", size = 34705239, upload-time = "2025-12-19T08:23:28.369Z" }, - { url = "https://files.pythonhosted.org/packages/3e/da/4c14e3e00be698009b52700f15326a23272b4b00475939b6acc86b151188/xgrammar-0.1.29-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79e6e4f5cd33be77418cf91efc482f2b3d773d309891224383bc8a4948ad7b07", size = 34906135, upload-time = "2025-12-19T08:23:30.838Z" }, - { url = "https://files.pythonhosted.org/packages/e9/c5/e4965c9921e7bb6061f246ae7f8c7b9b1dfc21262248100c2f9b398b361e/xgrammar-0.1.29-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb22aea775971f7d8c4d0e193257ebeb71b68acd9d36af3331ca5fd4d9a46991", size = 34904126, upload-time = "2025-12-19T08:23:38.335Z" }, -] - [[package]] name = "xxhash" version = "3.6.0" diff --git a/vllm_runtime/pyproject.toml b/vllm_runtime/pyproject.toml new file mode 100644 index 000000000..7d8bed9e5 --- /dev/null +++ b/vllm_runtime/pyproject.toml @@ -0,0 +1,41 @@ +[project] +name = "art-vllm-runtime" +version = "0.1.0" +description = "Tiny ART-owned vLLM runtime package" +requires-python = ">=3.12,<3.13" +dependencies = [ + "nvidia-nccl-cu12==2.28.9 ; sys_platform == 'linux'", + "transformers==5.6.2", + "vllm @ https://wheels.vllm.ai/ecd0b60aad2f4e28dd00ababfc1402690d88cbed/vllm-0.20.2rc1.dev168%2Bgecd0b60aa.cu129-cp38-abi3-manylinux_2_34_x86_64.whl ; sys_platform == 'linux'", +] + +[project.scripts] +art-vllm-runtime-server = "art_vllm_runtime.dedicated_server:main" + +[project.entry-points."vllm.general_plugins"] +art = "art_vllm_runtime.patches:apply_vllm_runtime_patches" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/art_vllm_runtime"] + +[tool.hatch.build] +sources = ["src"] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.uv] +required-version = ">=0.6.15" +override-dependencies = [ + "flashinfer-python==0.6.8.post1", + "numpy<2", + "nvidia-nccl-cu12==2.28.9 ; sys_platform == 'linux'", + "torch @ https://download.pytorch.org/whl/test/cu128/torch-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", + "torchaudio @ https://download.pytorch.org/whl/test/cu128/torchaudio-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", + "torchvision @ https://download.pytorch.org/whl/test/cu128/torchvision-0.26.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", + "transformers==5.6.2", +] diff --git a/vllm_runtime/src/art_vllm_runtime/__init__.py b/vllm_runtime/src/art_vllm_runtime/__init__.py new file mode 100644 index 000000000..80e13097f --- /dev/null +++ b/vllm_runtime/src/art_vllm_runtime/__init__.py @@ -0,0 +1,15 @@ +from art_vllm_runtime.patches import ( + apply_vllm_runtime_patches, + patch_listen_for_disconnect, + patch_tool_parser_manager, + patch_transformers_v5_compat, + subclass_chat_completion_request, +) + +__all__ = [ + "apply_vllm_runtime_patches", + "patch_listen_for_disconnect", + "patch_tool_parser_manager", + "patch_transformers_v5_compat", + "subclass_chat_completion_request", +] diff --git a/src/art/vllm/dedicated_server.py 
diff --git a/src/art/vllm/dedicated_server.py b/vllm_runtime/src/art_vllm_runtime/dedicated_server.py
similarity index 64%
rename from src/art/vllm/dedicated_server.py
rename to vllm_runtime/src/art_vllm_runtime/dedicated_server.py
index 47921be6b..73590f03b 100644
--- a/src/art/vllm/dedicated_server.py
+++ b/vllm_runtime/src/art_vllm_runtime/dedicated_server.py
@@ -1,17 +1,13 @@
-"""Dedicated vLLM subprocess entry point.
-
-Launched by UnslothService in dedicated mode as:
-    python -m art.vllm.dedicated_server --model --port ...
-
-Sets CUDA_VISIBLE_DEVICES and applies ART patches before starting vLLM.
-Must be imported/run as a standalone process — not imported into the main training process.
-"""
+"""Dedicated vLLM subprocess entry point for the ART-owned runtime package."""
 
 import argparse
 import asyncio
+from http import HTTPStatus
 import json
 import os
 
+from art_vllm_runtime.patches import apply_vllm_runtime_patches
+
 
 def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="ART dedicated vLLM server")
@@ -38,24 +34,51 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
     return parser.parse_args(argv)
 
 
-def _patch_art_dedicated_routes() -> None:
-    from fastapi import APIRouter, FastAPI, Request
+def _patch_art_runtime_routes() -> None:
+    from fastapi import APIRouter, FastAPI, Query, Request
     from fastapi.responses import JSONResponse
     from vllm.entrypoints.openai import api_server
-    from vllm.tasks import SupportedTask
 
-    if getattr(api_server, "_art_dedicated_routes_patched", False):
+    if getattr(api_server, "_art_runtime_routes_patched", False):
         return
 
     original_build_app = api_server.build_app
 
-    def art_build_app(
-        args: argparse.Namespace,
-        supported_tasks: tuple[SupportedTask, ...] | None = None,
-    ) -> FastAPI:
-        app = original_build_app(args, supported_tasks)
+    def art_build_app(*build_args: object, **build_kwargs: object) -> FastAPI:
+        app = original_build_app(*build_args, **build_kwargs)
         router = APIRouter()
 
+        def engine(request: Request):
+            return request.app.state.engine_client
+
+        @router.post("/sleep")
+        async def sleep(
+            raw_request: Request,
+            level: int = Query(default=1, ge=0, le=2),
+            mode: str = Query(default="abort", pattern="^(abort|wait|keep)$"),
+        ) -> JSONResponse:
+            try:
+                await engine(raw_request).sleep(level=level, mode=mode)
+            except ValueError as err:
+                return JSONResponse(
+                    content={"error": str(err)},
+                    status_code=HTTPStatus.BAD_REQUEST.value,
+                )
+            return JSONResponse(
+                content={"status": "sleeping", "level": level, "mode": mode}
+            )
+
+        @router.post("/wake_up")
+        async def wake_up(raw_request: Request) -> JSONResponse:
+            await engine(raw_request).wake_up()
+            return JSONResponse(content={"status": "awake"})
+
+        @router.get("/is_sleeping")
+        async def is_sleeping(raw_request: Request) -> JSONResponse:
+            return JSONResponse(
+                content={"is_sleeping": await engine(raw_request).is_sleeping()}
+            )
+
         @router.post("/art/set_served_model_name")
         async def set_served_model_name(raw_request: Request) -> JSONResponse:
             body = await raw_request.json()
@@ -70,7 +93,7 @@ async def set_served_model_name(raw_request: Request) -> JSONResponse:
         return app
 
     setattr(api_server, "build_app", art_build_app)
-    setattr(api_server, "_art_dedicated_routes_patched", True)
+    setattr(api_server, "_art_runtime_routes_patched", True)
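With the hunk above, the dedicated server's FastAPI app gains /sleep, /wake_up, and /is_sleeping routes on top of vLLM's stock API server, alongside the pre-existing /art/set_served_model_name route. A small smoke check against a running server, assuming it listens on localhost:8000 (the host/port and the requests dependency are illustrative assumptions, not part of this change):

import requests

base = "http://localhost:8000"
# level and mode map onto the Query parameters declared in the sleep route above.
print(requests.post(f"{base}/sleep", params={"level": 1, "mode": "abort"}).json())
print(requests.get(f"{base}/is_sleeping").json())  # expect {"is_sleeping": true}
print(requests.post(f"{base}/wake_up").json())     # expect {"status": "awake"}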
vllm_args.append(f"{cli_key}={json.dumps(value)}") case list(): + if key == "lora_target_modules": + vllm_args.append(cli_key) + for item in value: + match item: + case str() | int() | float(): + vllm_args.append(str(item)) + case dict(): + vllm_args.append(json.dumps(item)) + case _: + assert False, ( + f"Unsupported CLI list item for {key}: {type(item)}" + ) + return for item in value: match item: case str() | int() | float(): @@ -102,22 +138,12 @@ def _append_cli_arg(vllm_args: list[str], key: str, value: object) -> None: def main(argv: list[str] | None = None) -> None: args = parse_args(argv) - # Must set CUDA_VISIBLE_DEVICES before any torch/CUDA import os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "1" if args.rollout_weights_mode == "merged": os.environ["VLLM_SERVER_DEV_MODE"] = "1" - # Patches must be applied before vLLM's api_server is imported - from .patches import ( - patch_listen_for_disconnect, - patch_tool_parser_manager, - subclass_chat_completion_request, - ) - - subclass_chat_completion_request() - patch_listen_for_disconnect() - patch_tool_parser_manager() + apply_vllm_runtime_patches() from vllm.entrypoints.openai import api_server from vllm.entrypoints.openai.cli_args import ( @@ -129,8 +155,7 @@ def main(argv: list[str] | None = None) -> None: engine_args = json.loads(args.engine_args_json) server_args = json.loads(args.server_args_json) - if args.rollout_weights_mode == "merged": - _patch_art_dedicated_routes() + _patch_art_runtime_routes() vllm_args = [ f"--model={args.model}", @@ -155,9 +180,6 @@ def main(argv: list[str] | None = None) -> None: vllm_parser = make_arg_parser(vllm_parser) namespace = vllm_parser.parse_args(vllm_args) validate_parsed_serve_args(namespace) - - # stdout/stderr are captured to a log file by the parent process, - # so no separate uvicorn file handler is needed here. 
diff --git a/vllm_runtime/src/art_vllm_runtime/patches.py b/vllm_runtime/src/art_vllm_runtime/patches.py
new file mode 100644
index 000000000..aed69b601
--- /dev/null
+++ b/vllm_runtime/src/art_vllm_runtime/patches.py
@@ -0,0 +1,162 @@
+"""Monkey patches and bootstrap contract for the ART-owned vLLM runtime."""
+
+import ctypes
+from typing import Any
+
+
+def apply_vllm_runtime_patches() -> None:
+    patch_transformers_v5_compat()
+    subclass_chat_completion_request()
+    patch_listen_for_disconnect()
+    patch_tool_parser_manager()
+    patch_nccl_unique_id_bootstrap()
+
+
+def patch_transformers_v5_compat() -> None:
+    _patch_rope_validation_ignore_keys()
+    _patch_qwen3_vl_moe_tie_word_embeddings()
+
+
+def _patch_rope_validation_ignore_keys() -> None:
+    from transformers.configuration_utils import PretrainedConfig
+
+    original = PretrainedConfig.convert_rope_params_to_dict
+    if getattr(original, "__art_patched__", False):
+        return
+
+    def patched(self: Any, ignore_keys_at_rope_validation: Any = None, **kwargs: Any):
+        if ignore_keys_at_rope_validation is not None:
+            ignore_keys_at_rope_validation = set(ignore_keys_at_rope_validation)
+        return original(
+            self,
+            ignore_keys_at_rope_validation=ignore_keys_at_rope_validation,
+            **kwargs,
+        )
+
+    patched.__art_patched__ = True  # type: ignore[attr-defined]
+    PretrainedConfig.convert_rope_params_to_dict = patched  # type: ignore[method-assign]
+
+
+def _patch_qwen3_vl_moe_tie_word_embeddings() -> None:
+    from transformers import Qwen3VLMoeTextConfig
+
+    setattr(Qwen3VLMoeTextConfig, "tie_word_embeddings", False)
+
+
+def subclass_chat_completion_request() -> None:
+    from vllm.entrypoints.openai.chat_completion import protocol
+
+    if getattr(protocol, "_art_chat_completion_request_patched", False):
+        return
+
+    class ChatCompletionRequest(protocol.ChatCompletionRequest):
+        def __init__(self, *args: object, **kwargs: object) -> None:
+            super().__init__(*args, **kwargs)  # ty:ignore[invalid-argument-type]
+            self.logprobs = True
+            if self.top_logprobs is None:
+                self.top_logprobs = 0
+
+    protocol.ChatCompletionRequest = ChatCompletionRequest  # ty:ignore[invalid-assignment]
+    setattr(protocol, "_art_chat_completion_request_patched", True)
+
+
+def patch_listen_for_disconnect() -> None:
+    import vllm.entrypoints.utils
+
+    if getattr(vllm.entrypoints.utils, "_art_listen_for_disconnect_patched", False):
+        return
+
+    async def patched_listen_for_disconnect(request: Any) -> None:
+        try:
+            while True:
+                message = await request.receive()
+                if message["type"] == "http.disconnect":
+                    break
+        except UnboundLocalError:
+            pass
+
+    vllm.entrypoints.utils.listen_for_disconnect = patched_listen_for_disconnect  # ty:ignore[invalid-assignment]
+    setattr(vllm.entrypoints.utils, "_art_listen_for_disconnect_patched", True)
+
+
+def patch_tool_parser_manager() -> None:
+    from vllm.entrypoints.openai.engine.protocol import DeltaMessage
+    from vllm.tool_parsers.abstract_tool_parser import ToolParserManager
+
+    original = ToolParserManager.get_tool_parser
+    if getattr(original, "__art_patched__", False):
+        return
+
+    def patched_get_tool_parser(name: str) -> type:
+        tool_parser_class = original(name)
+        current = tool_parser_class.extract_tool_calls_streaming
+        if getattr(current, "__art_patched__", False):
+            return tool_parser_class
+
+        def patch(
+            *args: Any,
+            **kwargs: Any,
+        ) -> Any:
+            return current(*args, **kwargs) or DeltaMessage()
+
+        patch.__art_patched__ = True  # type: ignore[attr-defined]
+        tool_parser_class.extract_tool_calls_streaming = patch  # ty:ignore[invalid-assignment]
+        return tool_parser_class
+
+    patched_get_tool_parser.__art_patched__ = True  # type: ignore[attr-defined]
+    ToolParserManager.get_tool_parser = patched_get_tool_parser  # ty:ignore[invalid-assignment]
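+
+
+# The three helpers below harden the NCCL bootstrap path: if
+# StatelessProcessGroup.broadcast_obj hands back a ncclUniqueId as raw bytes,
+# _restore_nccl_unique_id_payload rebuilds the ctypes struct from those bytes,
+# and byte payloads reaching ncclCommInitRank are converted with
+# unique_id_from_bytes before the original NCCL call runs.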
+
+
+def _restore_nccl_unique_id_payload(
+    payload: object,
+    template: object | None,
+) -> object:
+    from vllm.distributed.device_communicators.pynccl_wrapper import ncclUniqueId
+
+    if not isinstance(payload, (bytes, bytearray)) or not isinstance(
+        template, ncclUniqueId
+    ):
+        return payload
+    raw = bytes(payload)
+    assert len(raw) == ctypes.sizeof(ncclUniqueId)
+    unique_id = ncclUniqueId()
+    ctypes.memmove(ctypes.byref(unique_id), raw, len(raw))
+    return unique_id
+
+
+def _normalize_nccl_comm_init_rank_unique_id(library: Any, unique_id: object) -> object:
+    if isinstance(unique_id, (bytes, bytearray)):
+        return library.unique_id_from_bytes(bytes(unique_id))
+    return unique_id
+
+
+def patch_nccl_unique_id_bootstrap() -> None:
+    from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
+    from vllm.distributed.utils import StatelessProcessGroup
+
+    original_broadcast = StatelessProcessGroup.broadcast_obj
+    if not getattr(original_broadcast, "__art_patched__", False):
+
+        def patched_broadcast(self: Any, obj: Any | None, src: int) -> Any:
+            return _restore_nccl_unique_id_payload(
+                original_broadcast(self, obj, src), obj
+            )
+
+        patched_broadcast.__art_patched__ = True  # type: ignore[attr-defined]
+        StatelessProcessGroup.broadcast_obj = patched_broadcast  # type: ignore[method-assign]
+
+    original_comm_init_rank = NCCLLibrary.ncclCommInitRank
+    if getattr(original_comm_init_rank, "__art_patched__", False):
+        return
+
+    def patched_comm_init_rank(
+        self: Any,
+        world_size: int,
+        unique_id: object,
+        rank: int,
+    ) -> Any:
+        unique_id = _normalize_nccl_comm_init_rank_unique_id(self, unique_id)
+        return original_comm_init_rank(self, world_size, unique_id, rank)
+
+    patched_comm_init_rank.__art_patched__ = True  # type: ignore[attr-defined]
+    NCCLLibrary.ncclCommInitRank = patched_comm_init_rank  # type: ignore[method-assign]
diff --git a/vllm_runtime/uv.lock b/vllm_runtime/uv.lock
new file mode 100644
index 000000000..1956cd581
--- /dev/null
+++ b/vllm_runtime/uv.lock
@@ -0,0 +1,2699 @@
+version = 1
+revision = 3
+requires-python = "==3.12.*"
+
+[manifest]
+overrides = [
+    { name = "flashinfer-python", specifier = "==0.6.8.post1" },
+    { name = "numpy", specifier = "<2" },
+    { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'", specifier = "==2.28.9" },
+    { name = "torch", url = "https://download.pytorch.org/whl/test/cu128/torch-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" },
+    { name = "torchaudio", url = "https://download.pytorch.org/whl/test/cu128/torchaudio-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" },
+    { name = "torchvision", url = "https://download.pytorch.org/whl/test/cu128/torchvision-0.26.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" },
+    { name = "transformers", specifier = "==5.6.2" },
+]
+
+[[package]]
+name = "aiohappyeyeballs"
+version = "2.6.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/26/30/f84a107a9c4331c14b2b586036f40965c128aa4fee4dda5d3d51cb14ad54/aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558", size = 22760, upload-time = "2025-03-12T01:42:48.764Z" }
+wheels = [
+    { url =
"https://files.pythonhosted.org/packages/0f/15/5bf3b99495fb160b63f95972b81750f18f7f4e02ad051373b669d17d44f2/aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8", size = 15265, upload-time = "2025-03-12T01:42:47.083Z" }, +] + +[[package]] +name = "aiohttp" +version = "3.13.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohappyeyeballs" }, + { name = "aiosignal" }, + { name = "attrs" }, + { name = "frozenlist" }, + { name = "multidict" }, + { name = "propcache" }, + { name = "yarl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/9a/152096d4808df8e4268befa55fba462f440f14beab85e8ad9bf990516918/aiohttp-3.13.5.tar.gz", hash = "sha256:9d98cc980ecc96be6eb4c1994ce35d28d8b1f5e5208a23b421187d1209dbb7d1", size = 7858271, upload-time = "2026-03-31T22:01:03.343Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/84/c9ecc5828cb0b3695856c07c0a6817a99d51e2473400f705275a2b3d9239/aiohttp-3.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a60eaa2d440cd4707696b52e40ed3e2b0f73f65be07fd0ef23b6b539c9c0b0b4", size = 1749199, upload-time = "2026-03-31T21:57:41.938Z" }, + { url = "https://files.pythonhosted.org/packages/f0/d3/3c6d610e66b495657622edb6ae7c7fd31b2e9086b4ec50b47897ad6042a9/aiohttp-3.13.5-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55b3bdd3292283295774ab585160c4004f4f2f203946997f49aac032c84649e9", size = 1721013, upload-time = "2026-03-31T21:57:43.904Z" }, + { url = "https://files.pythonhosted.org/packages/49/a0/24409c12217456df0bae7babe3b014e460b0b38a8e60753d6cb339f6556d/aiohttp-3.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2b2355dc094e5f7d45a7bb262fe7207aa0460b37a0d87027dcf21b5d890e7d5", size = 1781501, upload-time = "2026-03-31T21:57:46.285Z" }, + { url = "https://files.pythonhosted.org/packages/98/9d/b65ec649adc5bccc008b0957a9a9c691070aeac4e41cea18559fef49958b/aiohttp-3.13.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b38765950832f7d728297689ad78f5f2cf79ff82487131c4d26fe6ceecdc5f8e", size = 1878981, upload-time = "2026-03-31T21:57:48.734Z" }, + { url = "https://files.pythonhosted.org/packages/57/d8/8d44036d7eb7b6a8ec4c5494ea0c8c8b94fbc0ed3991c1a7adf230df03bf/aiohttp-3.13.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b18f31b80d5a33661e08c89e202edabf1986e9b49c42b4504371daeaa11b47c1", size = 1767934, upload-time = "2026-03-31T21:57:51.171Z" }, + { url = "https://files.pythonhosted.org/packages/31/04/d3f8211f273356f158e3464e9e45484d3fb8c4ce5eb2f6fe9405c3273983/aiohttp-3.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:33add2463dde55c4f2d9635c6ab33ce154e5ecf322bd26d09af95c5f81cfa286", size = 1566671, upload-time = "2026-03-31T21:57:53.326Z" }, + { url = "https://files.pythonhosted.org/packages/41/db/073e4ebe00b78e2dfcacff734291651729a62953b48933d765dc513bf798/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:327cc432fdf1356fb4fbc6fe833ad4e9f6aacb71a8acaa5f1855e4b25910e4a9", size = 1705219, upload-time = "2026-03-31T21:57:55.385Z" }, + { url = "https://files.pythonhosted.org/packages/48/45/7dfba71a2f9fd97b15c95c06819de7eb38113d2cdb6319669195a7d64270/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = 
"sha256:7c35b0bf0b48a70b4cb4fc5d7bed9b932532728e124874355de1a0af8ec4bc88", size = 1743049, upload-time = "2026-03-31T21:57:57.341Z" }, + { url = "https://files.pythonhosted.org/packages/18/71/901db0061e0f717d226386a7f471bb59b19566f2cae5f0d93874b017271f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:df23d57718f24badef8656c49743e11a89fd6f5358fa8a7b96e728fda2abf7d3", size = 1749557, upload-time = "2026-03-31T21:57:59.626Z" }, + { url = "https://files.pythonhosted.org/packages/08/d5/41eebd16066e59cd43728fe74bce953d7402f2b4ddfdfef2c0e9f17ca274/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:02e048037a6501a5ec1f6fc9736135aec6eb8a004ce48838cb951c515f32c80b", size = 1558931, upload-time = "2026-03-31T21:58:01.972Z" }, + { url = "https://files.pythonhosted.org/packages/30/e6/4a799798bf05740e66c3a1161079bda7a3dd8e22ca392481d7a7f9af82a6/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:31cebae8b26f8a615d2b546fee45d5ffb76852ae6450e2a03f42c9102260d6fe", size = 1774125, upload-time = "2026-03-31T21:58:04.007Z" }, + { url = "https://files.pythonhosted.org/packages/84/63/7749337c90f92bc2cb18f9560d67aa6258c7060d1397d21529b8004fcf6f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:888e78eb5ca55a615d285c3c09a7a91b42e9dd6fc699b166ebd5dee87c9ccf14", size = 1732427, upload-time = "2026-03-31T21:58:06.337Z" }, +] + +[[package]] +name = "aiosignal" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "frozenlist" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, +] + +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size 
= 13643, upload-time = "2024-05-20T21:33:24.1Z" }, +] + +[[package]] +name = "anthropic" +version = "0.92.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "docstring-parser" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/01/2d/fc5c5a369db977efbaa646d77ba42b38a6de4e95789884032b0e2e3fc834/anthropic-0.92.0.tar.gz", hash = "sha256:d1e792ed0692379452a1af6b266df495e973c3695cd0aace2a108b838393cbc4", size = 652420, upload-time = "2026-04-08T16:55:35.37Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/21/bf5b5ab10b6932c5c43eaa66b6e3f256de569cf0323d89f9cc281a0d0f39/anthropic-0.92.0-py3-none-any.whl", hash = "sha256:f92a4bd065d5cab90a96b65bb44e473bf7c6fe731a743cd156e9ad1d245c381e", size = 621195, upload-time = "2026-04-08T16:55:33.639Z" }, +] + +[[package]] +name = "anyio" +version = "4.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/14/2c5dd9f512b66549ae92767a9c7b330ae88e1932ca57876909410251fe13/anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc", size = 231622, upload-time = "2026-03-24T12:59:09.671Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" }, +] + +[[package]] +name = "apache-tvm-ffi" +version = "0.1.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6f/60/1e787a0b5ebf318483235be2a689ee367173983067e441b8379564f667c0/apache_tvm_ffi-0.1.9.tar.gz", hash = "sha256:d2d402587e8906de0a07f4746aa78f3d452c7efe3625d4bb39ac2ad693bce530", size = 2513731, upload-time = "2026-02-27T19:28:06.602Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/c0/6d3d54f50012255b41bc3e24944c086f63c4707c8686c7c6780e9283eb96/apache_tvm_ffi-0.1.9-cp312-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d503029e66c43b1a1cb1a42a1e9bb428c8a28dcbdec31c28e705472ca648a3a", size = 2203712, upload-time = "2026-02-27T19:27:25.867Z" }, + { url = "https://files.pythonhosted.org/packages/c6/dd/2bab4c6cd86257dbf99e93452a1af833113f8dc3e25a25579f6e4e4c8a94/apache_tvm_ffi-0.1.9-cp312-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28241371934ea8af10d5067087ba1229ebddded7b2c02d33a258ec2a96df8c46", size = 2299704, upload-time = "2026-02-27T19:27:27.477Z" }, + { url = "https://files.pythonhosted.org/packages/7a/4a/b469bcb2e1014cb84d336d2a59f42958a058251c577a4c2680cacad346e2/apache_tvm_ffi-0.1.9-cp312-abi3-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:87cacce81df55685fc6a76e1e3c5db1200e85e87bf5974b692c59d131b7bc622", size = 2130865, upload-time = "2026-02-27T19:27:29.092Z" }, + { url = "https://files.pythonhosted.org/packages/70/ef/5402da5d37f5270fd88ea0348acca78dba9be8bdbf6c2bcae0935eb03ef1/apache_tvm_ffi-0.1.9-cp312-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f45eb43499acac45ff6c93564f0ff2d3ca27b69656d540fd56ce59d51c0b4c65", size = 
2278991, upload-time = "2026-02-27T19:27:30.729Z" }, +] + +[[package]] +name = "art-vllm-runtime" +version = "0.1.0" +source = { editable = "." } +dependencies = [ + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" }, + { name = "transformers" }, + { name = "vllm", marker = "sys_platform == 'linux'" }, +] + +[package.metadata] +requires-dist = [ + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'", specifier = "==2.28.9" }, + { name = "transformers", specifier = "==5.6.2" }, + { name = "vllm", marker = "sys_platform == 'linux'", url = "https://wheels.vllm.ai/ecd0b60aad2f4e28dd00ababfc1402690d88cbed/vllm-0.20.2rc1.dev168%2Bgecd0b60aa.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, +] + +[[package]] +name = "astor" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/21/75b771132fee241dfe601d39ade629548a9626d1d39f333fde31bc46febe/astor-0.8.1.tar.gz", hash = "sha256:6a6effda93f4e1ce9f618779b2dd1d9d84f1e32812c23a29b3fff6fd7f63fa5e", size = 35090, upload-time = "2019-12-10T01:50:35.51Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/88/97eef84f48fa04fbd6750e62dcceafba6c63c81b7ac1420856c8dcc0a3f9/astor-0.8.1-py2.py3-none-any.whl", hash = "sha256:070a54e890cefb5b3739d19f30f5a5ec840ffc9c50ffa7d23cc9fc1a38ebbfc5", size = 27488, upload-time = "2019-12-10T01:50:33.628Z" }, +] + +[[package]] +name = "attrs" +version = "26.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/8e/82a0fe20a541c03148528be8cac2408564a6c9a0cc7e9171802bc1d26985/attrs-26.1.0.tar.gz", hash = "sha256:d03ceb89cb322a8fd706d4fb91940737b6642aa36998fe130a9bc96c985eff32", size = 952055, upload-time = "2026-03-19T14:22:25.026Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" }, +] + +[[package]] +name = "blake3" +version = "1.0.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/75/aa/abcd75e9600987a0bc6cfe9b6b2ff3f0e2cb08c170addc6e76035b5c4cb3/blake3-1.0.8.tar.gz", hash = "sha256:513cc7f0f5a7c035812604c2c852a0c1468311345573de647e310aca4ab165ba", size = 117308, upload-time = "2025-10-14T06:47:48.83Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/7d/85a4c0782f613de23d114a7a78fcce270f75b193b3ff3493a0de24ba104a/blake3-1.0.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:269f255b110840e52b6ce9db02217e39660ebad3e34ddd5bca8b8d378a77e4e1", size = 371296, upload-time = "2025-10-14T06:45:49.674Z" }, + { url = "https://files.pythonhosted.org/packages/e3/20/488475254976ed93fab57c67aa80d3b40df77f7d9db6528c9274bff53e08/blake3-1.0.8-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:66ca28a673025c40db3eba21a9cac52f559f83637efa675b3f6bd8683f0415f3", size = 374516, upload-time = "2025-10-14T06:45:51.23Z" }, + { url = "https://files.pythonhosted.org/packages/7b/21/2a1c47fedb77fb396512677ec6d46caf42ac6e9a897db77edd0a2a46f7bb/blake3-1.0.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcb04966537777af56c1f399b35525aa70a1225816e121ff95071c33c0f7abca", size = 447911, upload-time = "2025-10-14T06:45:52.637Z" }, + { url = 
"https://files.pythonhosted.org/packages/cb/7d/db0626df16029713e7e61b67314c4835e85c296d82bd907c21c6ea271da2/blake3-1.0.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e5b5da177d62cc4b7edf0cea08fe4dec960c9ac27f916131efa890a01f747b93", size = 505420, upload-time = "2025-10-14T06:45:54.445Z" }, + { url = "https://files.pythonhosted.org/packages/5b/55/6e737850c2d58a6d9de8a76dad2ae0f75b852a23eb4ecb07a0b165e6e436/blake3-1.0.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:38209b10482c97e151681ea3e91cc7141f56adbbf4820a7d701a923124b41e6a", size = 394189, upload-time = "2025-10-14T06:45:55.719Z" }, + { url = "https://files.pythonhosted.org/packages/5b/94/eafaa5cdddadc0c9c603a6a6d8339433475e1a9f60c8bb9c2eed2d8736b6/blake3-1.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:504d1399b7fb91dfe5c25722d2807990493185faa1917456455480c36867adb5", size = 388001, upload-time = "2025-10-14T06:45:57.067Z" }, + { url = "https://files.pythonhosted.org/packages/17/81/735fa00d13de7f68b25e1b9cb36ff08c6f165e688d85d8ec2cbfcdedccc5/blake3-1.0.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c84af132aa09abeadf9a0118c8fb26f4528f3f42c10ef8be0fcf31c478774ec4", size = 550302, upload-time = "2025-10-14T06:45:58.657Z" }, + { url = "https://files.pythonhosted.org/packages/0e/c6/d1fe8bdea4a6088bd54b5a58bc40aed89a4e784cd796af7722a06f74bae7/blake3-1.0.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a25db3d36b55f5ed6a86470155cc749fc9c5b91c949b8d14f48658f9d960d9ec", size = 554211, upload-time = "2025-10-14T06:46:00.269Z" }, +] + +[[package]] +name = "cachetools" +version = "7.0.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/dd/57fe3fdb6e65b25a5987fd2cdc7e22db0aef508b91634d2e57d22928d41b/cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990", size = 37367, upload-time = "2026-03-09T20:51:29.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/f3/39cf3367b8107baa44f861dc802cbf16263c945b62d8265d36034fc07bea/cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114", size = 13918, upload-time = "2026-03-09T20:51:27.33Z" }, +] + +[[package]] +name = "cbor2" +version = "5.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/cb/09939728be094d155b5d4ac262e39877875f5f7e36eea66beb359f647bd0/cbor2-5.9.0.tar.gz", hash = "sha256:85c7a46279ac8f226e1059275221e6b3d0e370d2bb6bd0500f9780781615bcea", size = 111231, upload-time = "2026-03-22T15:56:50.638Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/09/fd/7ddf3d3153b54c69c3be77172b8d9aa3a9d74f62a7fbde614d53eaeed9a4/cbor2-5.9.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ae6c706ac1d85a0b3cb3395308fd0c4d55e3202b4760773675957e93cdff45fc", size = 287865, upload-time = "2026-03-22T15:56:14.813Z" }, + { url = "https://files.pythonhosted.org/packages/db/9d/7ede2cc42f9bb4260492e7d29d2aab781eacbbcfb09d983de1e695077199/cbor2-5.9.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4cd43d8fc374b31643b2830910f28177a606a7bc84975a62675dd3f2e320fc7b", size = 288246, upload-time = "2026-03-22T15:56:16.113Z" }, + { url = 
"https://files.pythonhosted.org/packages/ce/9d/588ebc7c5bc5843f609b05fe07be8575c7dec987735b0bbc908ac9c1264a/cbor2-5.9.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4aa07b392cc3d76fb31c08a46a226b58c320d1c172ff3073e864409ced7bc50f", size = 280214, upload-time = "2026-03-22T15:56:17.519Z" }, + { url = "https://files.pythonhosted.org/packages/f7/a1/6fc8f4b15c6a27e7fbb7966c30c2b4b18c274a3221fa2f5e6235502d34bc/cbor2-5.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:971d425b3a23b75953d8853d5f9911bdeefa09d759ee3b5e6b07b5ff3cbd9073", size = 282162, upload-time = "2026-03-22T15:56:18.975Z" }, + { url = "https://files.pythonhosted.org/packages/42/ff/b83492b096fbef26e9cb62c1a4bf2d3cef579ea7b33138c6c37c4ae66f67/cbor2-5.9.0-py3-none-any.whl", hash = "sha256:27695cbd70c90b8de5c4a284642c2836449b14e2c2e07e3ffe0744cb7669a01b", size = 24627, upload-time = "2026-03-22T15:56:48.847Z" }, +] + +[[package]] +name = "certifi" +version = "2026.2.25" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/2d/7bf41579a8986e348fa033a31cdd0e4121114f6bce2457e8876010b092dd/certifi-2026.2.25.tar.gz", hash = "sha256:e887ab5cee78ea814d3472169153c2d12cd43b14bd03329a39a9c6e2e80bfba7", size = 155029, upload-time = "2026-02-25T02:54:17.342Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9a/3c/c17fb3ca2d9c3acff52e30b309f538586f9f5b9c9cf454f3845fc9af4881/certifi-2026.2.25-py3-none-any.whl", hash = "sha256:027692e4402ad994f1c42e52a4997a9763c646b73e4096e4d5d6db8af1d6f0fa", size = 153684, upload-time = "2026-02-25T02:54:15.766Z" }, +] + +[[package]] +name = "cffi" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pycparser", marker = "implementation_name != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/df/a4f0fbd47331ceeba3d37c2e51e9dfc9722498becbeec2bd8bc856c9538a/cffi-2.0.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe", size = 212529, upload-time = "2025-09-08T23:22:47.349Z" }, + { url = "https://files.pythonhosted.org/packages/d5/72/12b5f8d3865bf0f87cf1404d8c374e7487dcf097a1c91c436e72e6badd83/cffi-2.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062", size = 220097, upload-time = "2025-09-08T23:22:48.677Z" }, + { url = "https://files.pythonhosted.org/packages/c2/95/7a135d52a50dfa7c882ab0ac17e8dc11cec9d55d2c18dda414c051c5e69e/cffi-2.0.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e", size = 207983, upload-time = "2025-09-08T23:22:50.06Z" }, + { url = "https://files.pythonhosted.org/packages/3a/c8/15cb9ada8895957ea171c62dc78ff3e99159ee7adb13c0123c001a2546c1/cffi-2.0.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037", size = 206519, upload-time = "2025-09-08T23:22:51.364Z" }, + { url = 
"https://files.pythonhosted.org/packages/78/2d/7fa73dfa841b5ac06c7b8855cfc18622132e365f5b81d02230333ff26e9e/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba", size = 219572, upload-time = "2025-09-08T23:22:52.902Z" }, + { url = "https://files.pythonhosted.org/packages/07/e0/267e57e387b4ca276b90f0434ff88b2c2241ad72b16d31836adddfd6031b/cffi-2.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94", size = 222963, upload-time = "2025-09-08T23:22:54.518Z" }, + { url = "https://files.pythonhosted.org/packages/b6/75/1f2747525e06f53efbd878f4d03bac5b859cbc11c633d0fb81432d98a795/cffi-2.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187", size = 221361, upload-time = "2025-09-08T23:22:55.867Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/a1/67fe25fac3c7642725500a3f6cfe5821ad557c3abb11c9d20d12c7008d3e/charset_normalizer-3.4.7.tar.gz", hash = "sha256:ae89db9e5f98a11a4bf50407d4363e7b09b31e55bc117b4f7d80aab97ba009e5", size = 144271, upload-time = "2026-04-02T09:28:39.342Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/e3/0fadc706008ac9d7b9b5be6dc767c05f9d3e5df51744ce4cc9605de7b9f4/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6178f72c5508bfc5fd446a5905e698c6212932f25bcdd4b47a757a50605a90e2", size = 208061, upload-time = "2026-04-02T09:26:25.568Z" }, + { url = "https://files.pythonhosted.org/packages/42/f0/3dd1045c47f4a4604df85ec18ad093912ae1344ac706993aff91d38773a2/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e1421b502d83040e6d7fb2fb18dff63957f720da3d77b2fbd3187ceb63755d7b", size = 229031, upload-time = "2026-04-02T09:26:26.865Z" }, + { url = "https://files.pythonhosted.org/packages/dc/67/675a46eb016118a2fbde5a277a5d15f4f69d5f3f5f338e5ee2f8948fcf43/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:edac0f1ab77644605be2cbba52e6b7f630731fc42b34cb0f634be1a6eface56a", size = 225239, upload-time = "2026-04-02T09:26:28.044Z" }, + { url = "https://files.pythonhosted.org/packages/4b/f8/d0118a2f5f23b02cd166fa385c60f9b0d4f9194f574e2b31cef350ad7223/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5649fd1c7bade02f320a462fdefd0b4bd3ce036065836d4f42e0de958038e116", size = 216589, upload-time = "2026-04-02T09:26:29.239Z" }, + { url = "https://files.pythonhosted.org/packages/b1/f1/6d2b0b261b6c4ceef0fcb0d17a01cc5bc53586c2d4796fa04b5c540bc13d/charset_normalizer-3.4.7-cp312-cp312-manylinux_2_31_armv7l.whl", hash = "sha256:203104ed3e428044fd943bc4bf45fa73c0730391f9621e37fe39ecf477b128cb", size = 202733, upload-time = "2026-04-02T09:26:30.5Z" }, + { url = "https://files.pythonhosted.org/packages/6f/c0/7b1f943f7e87cc3db9626ba17807d042c38645f0a1d4415c7a14afb5591f/charset_normalizer-3.4.7-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:298930cec56029e05497a76988377cbd7457ba864beeea92ad7e844fe74cd1f1", size = 212652, upload-time = "2026-04-02T09:26:31.709Z" }, + { url = 
"https://files.pythonhosted.org/packages/38/dd/5a9ab159fe45c6e72079398f277b7d2b523e7f716acc489726115a910097/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:708838739abf24b2ceb208d0e22403dd018faeef86ddac04319a62ae884c4f15", size = 211229, upload-time = "2026-04-02T09:26:33.282Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ff/531a1cad5ca855d1c1a8b69cb71abfd6d85c0291580146fda7c82857caa1/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:0f7eb884681e3938906ed0434f20c63046eacd0111c4ba96f27b76084cd679f5", size = 203552, upload-time = "2026-04-02T09:26:34.845Z" }, + { url = "https://files.pythonhosted.org/packages/c1/4c/a5fb52d528a8ca41f7598cb619409ece30a169fbdf9cdce592e53b46c3a6/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4dc1e73c36828f982bfe79fadf5919923f8a6f4df2860804db9a98c48824ce8d", size = 230806, upload-time = "2026-04-02T09:26:36.152Z" }, + { url = "https://files.pythonhosted.org/packages/59/7a/071feed8124111a32b316b33ae4de83d36923039ef8cf48120266844285b/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:aed52fea0513bac0ccde438c188c8a471c4e0f457c2dd20cdbf6ea7a450046c7", size = 212316, upload-time = "2026-04-02T09:26:37.672Z" }, + { url = "https://files.pythonhosted.org/packages/fd/35/f7dba3994312d7ba508e041eaac39a36b120f32d4c8662b8814dab876431/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:fea24543955a6a729c45a73fe90e08c743f0b3334bbf3201e6c4bc1b0c7fa464", size = 227274, upload-time = "2026-04-02T09:26:38.93Z" }, + { url = "https://files.pythonhosted.org/packages/8a/2d/a572df5c9204ab7688ec1edc895a73ebded3b023bb07364710b05dd1c9be/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb6d88045545b26da47aa879dd4a89a71d1dce0f0e549b1abcb31dfe4a8eac49", size = 218468, upload-time = "2026-04-02T09:26:40.17Z" }, + { url = "https://files.pythonhosted.org/packages/db/8f/61959034484a4a7c527811f4721e75d02d653a35afb0b6054474d8185d4c/charset_normalizer-3.4.7-py3-none-any.whl", hash = "sha256:3dce51d0f5e7951f8bb4900c257dad282f49190fdbebecd4ba99bcc41fef404d", size = 61958, upload-time = "2026-04-02T09:28:37.794Z" }, +] + +[[package]] +name = "click" +version = "8.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/57/75/31212c6bf2503fdf920d87fee5d7a86a2e3bcf444984126f13d8e4016804/click-8.3.2.tar.gz", hash = "sha256:14162b8b3b3550a7d479eafa77dfd3c38d9dc8951f6f69c78913a8f9a7540fd5", size = 302856, upload-time = "2026-04-03T19:14:45.118Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/20/71885d8b97d4f3dde17b1fdb92dbd4908b00541c5a3379787137285f602e/click-8.3.2-py3-none-any.whl", hash = "sha256:1924d2c27c5653561cd2cae4548d1406039cb79b858b747cfea24924bbc1616d", size = 108379, upload-time = "2026-04-03T19:14:43.505Z" }, +] + +[[package]] +name = "cloudpickle" +version = "3.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "compressed-tensors" +version = "0.15.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "loguru" }, + { name = "pydantic" }, + { name = "torch" }, + { name = "transformers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/1b/c3c4a98ec5f2727656336f07a0c35862195c310d8eb0b2fa5b4be6848680/compressed_tensors-0.15.0.1.tar.gz", hash = "sha256:a8e93054e8a5ec49c980b09ed36c4c1249b4a8ee167920a8e461c4da26e78d99", size = 229412, upload-time = "2026-04-10T14:23:54.708Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/52/93833dc1610e017ac5b7dcd59b8304d8ef67d1114c2d124e728a2cbbea12/compressed_tensors-0.15.0.1-py3-none-any.whl", hash = "sha256:e1b1f322e82e475715e242bad46925a304ea8e5c98b5055a15b8eb22fb6bfea9", size = 194260, upload-time = "2026-04-10T14:23:53.098Z" }, +] + +[[package]] +name = "cryptography" +version = "46.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/47/93/ac8f3d5ff04d54bc814e961a43ae5b0b146154c89c61b47bb07557679b18/cryptography-46.0.7.tar.gz", hash = "sha256:e4cfd68c5f3e0bfdad0d38e023239b96a2fe84146481852dffbcca442c245aa5", size = 750652, upload-time = "2026-04-08T01:57:54.692Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/45/6d80dc379b0bbc1f9d1e429f42e4cb9e1d319c7a8201beffd967c516ea01/cryptography-46.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b36a4695e29fe69215d75960b22577197aca3f7a25b9cf9d165dcfe9d80bc325", size = 4275492, upload-time = "2026-04-08T01:56:19.36Z" }, + { url = "https://files.pythonhosted.org/packages/4a/9a/1765afe9f572e239c3469f2cb429f3ba7b31878c893b246b4b2994ffe2fe/cryptography-46.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ad9ef796328c5e3c4ceed237a183f5d41d21150f972455a9d926593a1dcb308", size = 4426670, upload-time = "2026-04-08T01:56:21.415Z" }, + { url = "https://files.pythonhosted.org/packages/8f/3e/af9246aaf23cd4ee060699adab1e47ced3f5f7e7a8ffdd339f817b446462/cryptography-46.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:73510b83623e080a2c35c62c15298096e2a5dc8d51c3b4e1740211839d0dea77", size = 4280275, upload-time = "2026-04-08T01:56:23.539Z" }, + { url = "https://files.pythonhosted.org/packages/0f/54/6bbbfc5efe86f9d71041827b793c24811a017c6ac0fd12883e4caa86b8ed/cryptography-46.0.7-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = 
"sha256:cbd5fb06b62bd0721e1170273d3f4d5a277044c47ca27ee257025146c34cbdd1", size = 4928402, upload-time = "2026-04-08T01:56:25.624Z" }, + { url = "https://files.pythonhosted.org/packages/2d/cf/054b9d8220f81509939599c8bdbc0c408dbd2bdd41688616a20731371fe0/cryptography-46.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:420b1e4109cc95f0e5700eed79908cef9268265c773d3a66f7af1eef53d409ef", size = 4459985, upload-time = "2026-04-08T01:56:27.309Z" }, + { url = "https://files.pythonhosted.org/packages/f9/46/4e4e9c6040fb01c7467d47217d2f882daddeb8828f7df800cb806d8a2288/cryptography-46.0.7-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:24402210aa54baae71d99441d15bb5a1919c195398a87b563df84468160a65de", size = 3990652, upload-time = "2026-04-08T01:56:29.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/5f/313586c3be5a2fbe87e4c9a254207b860155a8e1f3cca99f9910008e7d08/cryptography-46.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8a469028a86f12eb7d2fe97162d0634026d92a21f3ae0ac87ed1c4a447886c83", size = 4279805, upload-time = "2026-04-08T01:56:30.928Z" }, + { url = "https://files.pythonhosted.org/packages/69/33/60dfc4595f334a2082749673386a4d05e4f0cf4df8248e63b2c3437585f2/cryptography-46.0.7-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9694078c5d44c157ef3162e3bf3946510b857df5a3955458381d1c7cfc143ddb", size = 4892883, upload-time = "2026-04-08T01:56:32.614Z" }, + { url = "https://files.pythonhosted.org/packages/c7/0b/333ddab4270c4f5b972f980adef4faa66951a4aaf646ca067af597f15563/cryptography-46.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:42a1e5f98abb6391717978baf9f90dc28a743b7d9be7f0751a6f56a75d14065b", size = 4459756, upload-time = "2026-04-08T01:56:34.306Z" }, + { url = "https://files.pythonhosted.org/packages/d2/14/633913398b43b75f1234834170947957c6b623d1701ffc7a9600da907e89/cryptography-46.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91bbcb08347344f810cbe49065914fe048949648f6bd5c2519f34619142bbe85", size = 4410244, upload-time = "2026-04-08T01:56:35.977Z" }, + { url = "https://files.pythonhosted.org/packages/10/f2/19ceb3b3dc14009373432af0c13f46aa08e3ce334ec6eff13492e1812ccd/cryptography-46.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5d1c02a14ceb9148cc7816249f64f623fbfee39e8c03b3650d842ad3f34d637e", size = 4674868, upload-time = "2026-04-08T01:56:38.034Z" }, + { url = "https://files.pythonhosted.org/packages/a5/d0/36a49f0262d2319139d2829f773f1b97ef8aef7f97e6e5bd21455e5a8fb5/cryptography-46.0.7-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:84d4cced91f0f159a7ddacad249cc077e63195c36aac40b4150e7a57e84fffe7", size = 4270628, upload-time = "2026-04-08T01:57:12.885Z" }, + { url = "https://files.pythonhosted.org/packages/8a/6c/1a42450f464dda6ffbe578a911f773e54dd48c10f9895a23a7e88b3e7db5/cryptography-46.0.7-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:128c5edfe5e5938b86b03941e94fac9ee793a94452ad1365c9fc3f4f62216832", size = 4415405, upload-time = "2026-04-08T01:57:14.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/4ed714dbe93a066dc1f4b4581a464d2d7dbec9046f7c8b7016f5286329e2/cryptography-46.0.7-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5e51be372b26ef4ba3de3c167cd3d1022934bc838ae9eaad7e644986d2a3d163", size = 4272715, upload-time = "2026-04-08T01:57:16.638Z" }, + { url = "https://files.pythonhosted.org/packages/b7/e6/a26b84096eddd51494bba19111f8fffe976f6a09f132706f8f1bf03f51f7/cryptography-46.0.7-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = 
"sha256:cdf1a610ef82abb396451862739e3fc93b071c844399e15b90726ef7470eeaf2", size = 4918400, upload-time = "2026-04-08T01:57:19.021Z" }, + { url = "https://files.pythonhosted.org/packages/c7/08/ffd537b605568a148543ac3c2b239708ae0bd635064bab41359252ef88ed/cryptography-46.0.7-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1d25aee46d0c6f1a501adcddb2d2fee4b979381346a78558ed13e50aa8a59067", size = 4450634, upload-time = "2026-04-08T01:57:21.185Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/0cd51dd86ab5b9befe0d031e276510491976c3a80e9f6e31810cce46c4ad/cryptography-46.0.7-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:cdfbe22376065ffcf8be74dc9a909f032df19bc58a699456a21712d6e5eabfd0", size = 3985233, upload-time = "2026-04-08T01:57:22.862Z" }, + { url = "https://files.pythonhosted.org/packages/92/49/819d6ed3a7d9349c2939f81b500a738cb733ab62fbecdbc1e38e83d45e12/cryptography-46.0.7-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:abad9dac36cbf55de6eb49badd4016806b3165d396f64925bf2999bcb67837ba", size = 4271955, upload-time = "2026-04-08T01:57:24.814Z" }, + { url = "https://files.pythonhosted.org/packages/80/07/ad9b3c56ebb95ed2473d46df0847357e01583f4c52a85754d1a55e29e4d0/cryptography-46.0.7-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:935ce7e3cfdb53e3536119a542b839bb94ec1ad081013e9ab9b7cfd478b05006", size = 4879888, upload-time = "2026-04-08T01:57:26.88Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c7/201d3d58f30c4c2bdbe9b03844c291feb77c20511cc3586daf7edc12a47b/cryptography-46.0.7-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:35719dc79d4730d30f1c2b6474bd6acda36ae2dfae1e3c16f2051f215df33ce0", size = 4449961, upload-time = "2026-04-08T01:57:29.068Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ef/649750cbf96f3033c3c976e112265c33906f8e462291a33d77f90356548c/cryptography-46.0.7-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bbc6ccf49d05ac8f7d7b5e2e2c33830d4fe2061def88210a126d130d7f71a85", size = 4401696, upload-time = "2026-04-08T01:57:31.029Z" }, + { url = "https://files.pythonhosted.org/packages/41/52/a8908dcb1a389a459a29008c29966c1d552588d4ae6d43f3a1a4512e0ebe/cryptography-46.0.7-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a1529d614f44b863a7b480c6d000fe93b59acee9c82ffa027cfadc77521a9f5e", size = 4664256, upload-time = "2026-04-08T01:57:33.144Z" }, +] + +[[package]] +name = "cuda-bindings" +version = "12.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/c2/65bfd79292b8ff18be4dd7f7442cea37bcbc1a228c1886f1dea515c45b67/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:694ba35023846625ef471257e6b5a4bc8af690f961d197d77d34b1d1db393f56", size = 11760260, upload-time = "2025-10-21T14:51:40.79Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, +] + +[[package]] +name = "cuda-pathfinder" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/f9/1b9b60a30fc463c14cdea7a77228131a0ccc89572e8df9cb86c9648271ab/cuda_pathfinder-1.5.2-py3-none-any.whl", hash = 
"sha256:0c5f160a7756c5b072723cbbd6d861e38917ef956c68150b02f0b6e9271c71fa", size = 49988, upload-time = "2026-04-06T23:01:05.17Z" }, +] + +[[package]] +name = "cuda-python" +version = "12.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-bindings" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/f3/6b032a554019cfb3447e671798c1bd3e79b5f1af20d10253f56cea269ef2/cuda_python-12.9.4-py3-none-any.whl", hash = "sha256:d2cacea882a69863f1e7d27ee71d75f0684f4c76910aff839067e4f89c902279", size = 7594, upload-time = "2025-10-21T14:55:12.846Z" }, +] + +[[package]] +name = "cuda-tile" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/49/4592bc94ca05a07c7947ea114fd12734c8497f2daffee9faa79a03e39fb5/cuda_tile-1.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:375316b64c51ee7cfadb2f170a30c1547bc41eb39f1e233a6556713857d2e81f", size = 245744, upload-time = "2026-04-20T15:52:09.621Z" }, + { url = "https://files.pythonhosted.org/packages/40/76/84cb68be463c827bf79da9fa0aa5140838de6455ef6f438bbe0ffa75d378/cuda_tile-1.3.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:e4865acbff1172aaee304bf9c550586088d8b4545a384423597a590899386709", size = 247301, upload-time = "2026-04-20T15:51:04.042Z" }, +] + +[[package]] +name = "cuda-toolkit" +version = "12.8.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/c8/7dce3a0b15b42a3b58e7d96eb22a687d3bf2c44e01d149a6874629cd9938/cuda_toolkit-12.8.1-py2.py3-none-any.whl", hash = "sha256:adc7906af4ecbf9a352f9dca5734eceb21daec281ccfcf5675e1d2f724fc2cba", size = 2283, upload-time = "2025-08-13T02:03:07.842Z" }, +] + +[package.optional-dependencies] +cublas = [ + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cudart = [ + { name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cufft = [ + { name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cufile = [ + { name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux'" }, +] +cupti = [ + { name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +curand = [ + { name = "nvidia-curand-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cusolver = [ + { name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cusparse = [ + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +nvjitlink = [ + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +nvrtc = [ + { name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +nvtx = [ + { name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[[package]] +name = "depyf" +version = "0.20.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "astor" }, + { name = "dill" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/35/83fb0178212279aa0af031031905804c6de5618435d229f41ed21bb9ad2c/depyf-0.20.0.tar.gz", hash = "sha256:fb7683bd72c44f67b56029df2c47721e9a02ffa4d7b19095f1c54c4ebf797a98", size = 6168761, upload-time = "2025-10-13T12:33:38.589Z" 
} +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/65/4df6936130b56e1429114e663e7c1576cf845f3aef1b2dd200c0a5d19dba/depyf-0.20.0-py3-none-any.whl", hash = "sha256:d31effad4261cebecb58955d832e448ace88f432328f95f82fd99c30fd9308d4", size = 39381, upload-time = "2025-10-13T12:33:33.647Z" }, +] + +[[package]] +name = "dill" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/81/e1/56027a71e31b02ddc53c7d65b01e68edf64dea2932122fe7746a516f75d5/dill-0.4.1.tar.gz", hash = "sha256:423092df4182177d4d8ba8290c8a5b640c66ab35ec7da59ccfa00f6fa3eea5fa", size = 187315, upload-time = "2026-01-19T02:36:56.85Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/77/dc8c558f7593132cf8fefec57c4f60c83b16941c574ac5f619abb3ae7933/dill-0.4.1-py3-none-any.whl", hash = "sha256:1e1ce33e978ae97fcfcff5638477032b801c46c7c65cf717f95fbc2248f79a9d", size = 120019, upload-time = "2026-01-19T02:36:55.663Z" }, +] + +[[package]] +name = "diskcache" +version = "5.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" }, +] + +[[package]] +name = "distro" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722, upload-time = "2023-12-24T09:54:32.31Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, +] + +[[package]] +name = "dnspython" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, +] + +[[package]] +name = "docstring-parser" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, +] + +[[package]] +name = "einops" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/77/850bef8d72ffb9219f0b1aac23fbc1bf7d038ee6ea666f331fa273031aa2/einops-0.8.2.tar.gz", hash = "sha256:609da665570e5e265e27283aab09e7f279ade90c4f01bcfca111f3d3e13f2827", size = 56261, upload-time = "2026-01-26T04:13:17.638Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl", hash = "sha256:54058201ac7087911181bfec4af6091bb59380360f069276601256a76af08193", size = 65638, upload-time = "2026-01-26T04:13:18.546Z" }, +] + +[[package]] +name = "email-validator" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dnspython" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" }, +] + +[[package]] +name = "fastapi" +version = "0.135.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/e6/7adb4c5fa231e82c35b8f5741a9f2d055f520c29af5546fd70d3e8e1cd2e/fastapi-0.135.3.tar.gz", hash = "sha256:bd6d7caf1a2bdd8d676843cdcd2287729572a1ef524fc4d65c17ae002a1be654", size = 396524, upload-time = "2026-04-01T16:23:58.188Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/a4/5caa2de7f917a04ada20018eccf60d6cc6145b0199d55ca3711b0fc08312/fastapi-0.135.3-py3-none-any.whl", hash = "sha256:9b0f590c813acd13d0ab43dd8494138eb58e484bfac405db1f3187cfc5810d98", size = 117734, upload-time = "2026-04-01T16:23:59.328Z" }, +] + +[package.optional-dependencies] +standard = [ + { name = "email-validator" }, + { name = "fastapi-cli", extra = ["standard"] }, + { name = "httpx" }, + { name = "jinja2" }, + { name = "pydantic-extra-types" }, + { name = "pydantic-settings" }, + { name = "python-multipart" }, + { name = "uvicorn", extra = ["standard"] }, +] + +[[package]] +name = "fastapi-cli" +version = "0.0.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "rich-toolkit" }, + { name = "typer" }, + { name = "uvicorn", extra = ["standard"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6e/58/74797ae9e4610cfa0c6b34c8309096d3b20bb29be3b8b5fbf1004d10fa5f/fastapi_cli-0.0.24.tar.gz", hash = "sha256:1afc9c9e21d7ebc8a3ca5e31790cd8d837742be7e4f8b9236e99cb3451f0de00", size = 19043, upload-time = "2026-02-24T10:45:10.476Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c7/4b/68f9fe268e535d79c76910519530026a4f994ce07189ac0dded45c6af825/fastapi_cli-0.0.24-py3-none-any.whl", hash = "sha256:4a1f78ed798f106b4fee85ca93b85d8fe33c0a3570f775964d37edb80b8f0edc", size = 12304, upload-time = "2026-02-24T10:45:09.552Z" }, +] + +[package.optional-dependencies] +standard = [ + { name = "fastapi-cloud-cli" }, + { name = "uvicorn", extra = ["standard"] }, +] + +[[package]] +name = "fastapi-cloud-cli" +version = "0.16.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastar" }, + { name = "httpx" }, + { name = "pydantic", extra = ["email"] }, + { name = "rich-toolkit" }, + { name = "rignore" }, + { name = "sentry-sdk" }, + { name = "typer" }, + { name = "uvicorn", extra = ["standard"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/70/ca14fae57a221610d3e2e3dfad2b6e97ee31fcafaa36f90a2158d57e9a73/fastapi_cloud_cli-0.16.1.tar.gz", hash = "sha256:33b552c4ad46cd33823ef53f93b8b7813db2306c80c1cbcfa4d72067c99b26ab", size = 46193, upload-time = "2026-04-08T09:12:54.151Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/8b/f8c9eb116d2e89de5e0875c5fce90f23143410f41fe27725be04bdcec328/fastapi_cloud_cli-0.16.1-py3-none-any.whl", hash = "sha256:8b43bd8c7dd3710393d3be4c248c6a00807202b488a543716562529a8316cbee", size = 33212, upload-time = "2026-04-08T09:12:52.949Z" }, +] + +[[package]] +name = "fastar" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/8a/841a8fea5d704ed19836a1f7f83fe2b2d95624a14e9ddf45823ffb518c98/fastar-0.10.0.tar.gz", hash = "sha256:cba4452d6a33894faf5b0b9d55342a1259ad5c94cbdb16af09346084e0787680", size = 70357, upload-time = "2026-04-08T01:02:01.507Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/dd/bc0deb3c8fc1966f074725e4f44bf6573a4f1de8e3b7d77e08371ebeb0ea/fastar-0.10.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e0df3df848fe78657f9f9b40a811606cae34aa45ad79cd51f26d6f048f0d4ae1", size = 866216, upload-time = "2026-04-08T01:00:23.092Z" }, + { url = "https://files.pythonhosted.org/packages/97/3c/45023b3538b0eb34d0ac04b6bd4dc707c1480a48e88af5365d7be7448334/fastar-0.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a453abf99af0f42bb03db90f9bd4aa69b5a7b88d50841577d428ec51f206856f", size = 761054, upload-time = "2026-04-08T00:59:20.36Z" }, + { url = "https://files.pythonhosted.org/packages/69/07/23294498fceda38c3472f2c24a6aee1478991f1fd1982392bca6345af3ae/fastar-0.10.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c6a3e7acc58377de02ff3e8937d4b7e09b1270c294a0d5a0d3c2614aee69058e", size = 758885, upload-time = "2026-04-08T00:59:32.486Z" }, + { url = "https://files.pythonhosted.org/packages/19/89/1e0b3b5ef774deb0937bfeb93d2d21147a1db7a8d741ea63903b1f5d7cd6/fastar-0.10.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50a4a5fcd001f289fe66cbcff0aaf9e081532253cd7427270734988b22db6136", size = 924750, upload-time = "2026-04-08T00:59:44.41Z" }, + { url = "https://files.pythonhosted.org/packages/b1/85/486c640b768f9f6524d9cebd32e84808070136fea5696884b946bf63ecbb/fastar-0.10.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:54f60b5a87a2884efa8fc51978989e58cb1dc0ec1f645629491cd12f1dd5bb77", size = 817365, upload-time = "2026-04-08T01:00:09.616Z" }, + { url = 
"https://files.pythonhosted.org/packages/f3/4b/271ac7f9067ab39cffe95f2349604ac2248906be6fd86a70abb3c9f3d8bb/fastar-0.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:edaa085c8555620ec24aac1663251d62bdece619fcf6a4ad9dc2389a5fa13220", size = 819348, upload-time = "2026-04-08T01:00:35.083Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fc/ca87c6fee7eaad484711f8dca44c792e4dc0f2d3f4548c93939b06bdc7eb/fastar-0.10.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:4110f5a357ea88fa35f27021cf30c26d863a5b589d6ac9e4e854ed02b34c9f35", size = 885868, upload-time = "2026-04-08T00:59:56.124Z" }, + { url = "https://files.pythonhosted.org/packages/2f/00/588f0960ab1b36978d75a91bd44d9be9072c05211b04f224adcff9e83285/fastar-0.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:efa48b89ca2c8496f7fa0d36162e12d7476c597d0bae4d8fc42f86b958bd8fea", size = 968860, upload-time = "2026-04-08T01:01:12.557Z" }, + { url = "https://files.pythonhosted.org/packages/f4/4f/e07b9d82a58c27a8018d098b3ed51f561732c17fa6643c317bfba2907bdc/fastar-0.10.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:2637a20a69ea34455aa53cca8340273166bba8bd5c06727ea64ec151ba56abe0", size = 1036445, upload-time = "2026-04-08T01:01:25.512Z" }, + { url = "https://files.pythonhosted.org/packages/19/6e/de7934cea77c9938ecad2443b114cfee13a760534bb88279a0701b12fac3/fastar-0.10.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e9ea5e45a1dd85c3104273b4b1628112f6a09115ed95dc0d31595097ce278fb2", size = 1074104, upload-time = "2026-04-08T01:01:38.464Z" }, + { url = "https://files.pythonhosted.org/packages/7e/8d/54d56acbe2bbab3efbf2c1b93ea709e0cd78b7ff9d42b4038f520a580009/fastar-0.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:68d70adc24b9f4cf4520ed60dbd9fb60a6eb22bb96fd6756bcb387616cb2a979", size = 1026288, upload-time = "2026-04-08T01:01:51.658Z" }, +] + +[[package]] +name = "fastsafetensors" +version = "0.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d2/69/e34a1e86a02b255896c57263bf0dfbae45b4708fd609b937f783c2202e7b/fastsafetensors-0.3.1.tar.gz", hash = "sha256:b7eb039a564d77280d17e5d63b27e9963ba5158ad02d2a3c1772c62072a81a53", size = 55665, upload-time = "2026-05-06T08:48:59.125Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/50/909871d673bacd6dfc7fee5e59bcd4ec9fbd19775bafe567ad236a3adced/fastsafetensors-0.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ac76f33e47959b7c31658fbbda1805df7540819828a3ce6a94eb34b4db0b1fa7", size = 1854825, upload-time = "2026-05-06T08:48:54.452Z" }, +] + +[[package]] +name = "filelock" +version = "3.25.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/b8/00651a0f559862f3bb7d6f7477b192afe3f583cc5e26403b44e59a55ab34/filelock-3.25.2.tar.gz", hash = "sha256:b64ece2b38f4ca29dd3e810287aa8c48182bbecd1ae6e9ae126c9b35f1382694", size = 40480, upload-time = "2026-03-11T20:45:38.487Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/a5/842ae8f0c08b61d6484b52f99a03510a3a72d23141942d216ebe81fefbce/filelock-3.25.2-py3-none-any.whl", hash = "sha256:ca8afb0da15f229774c9ad1b455ed96e85a81373065fb10446672f64444ddf70", size = 26759, upload-time = "2026-03-11T20:45:37.437Z" }, +] + +[[package]] +name = "flashinfer-cubin" +version = "0.6.8.post1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/11/b7/5e3b1a8c67031b421a8bd29c2bc29b900a550bb3392e8bda18bb15b5e476/flashinfer_cubin-0.6.8.post1-py3-none-any.whl", hash = "sha256:43636d4cd39e694a83d76a89f87fefcdf4cecb4c4f7dd22dac25ec368c1e901f", size = 295154113, upload-time = "2026-04-18T18:28:21.738Z" }, +] + +[[package]] +name = "flashinfer-python" +version = "0.6.8.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apache-tvm-ffi" }, + { name = "click" }, + { name = "cuda-tile" }, + { name = "einops" }, + { name = "ninja" }, + { name = "numpy" }, + { name = "nvidia-cudnn-frontend" }, + { name = "nvidia-cutlass-dsl" }, + { name = "nvidia-ml-py" }, + { name = "packaging" }, + { name = "requests" }, + { name = "tabulate" }, + { name = "torch" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/53/1e/2760fef9e74abc4480961048e5790b4c9e955872fb4d7d97900cfddced5a/flashinfer_python-0.6.8.post1.tar.gz", hash = "sha256:b18e4121baf9b93fa9a9f368ba9b981a0342895f50ab9dddc224aeb964ed346f", size = 6675885, upload-time = "2026-04-18T18:28:13.299Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/6d/1e8a8533913e33a50a486332ce0673f4fdb860f6eb9ed450327c5c1762cb/flashinfer_python-0.6.8.post1-py3-none-any.whl", hash = "sha256:818f9b8cc2fe66c42a1f6264be4841ac8821ada703685a02cfccb2b5124a710b", size = 9385316, upload-time = "2026-04-18T18:28:10.285Z" }, +] + +[[package]] +name = "frozenlist" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/f5/c831fac6cc817d26fd54c7eaccd04ef7e0288806943f7cc5bbf69f3ac1f0/frozenlist-1.8.0.tar.gz", hash = "sha256:3ede829ed8d842f6cd48fc7081d7a41001a56f1f38603f9d49bf3020d59a31ad", size = 45875, upload-time = "2025-10-06T05:38:17.865Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/bd/d91c5e39f490a49df14320f4e8c80161cfcce09f1e2cde1edd16a551abb3/frozenlist-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:494a5952b1c597ba44e0e78113a7266e656b9794eec897b19ead706bd7074383", size = 242411, upload-time = "2025-10-06T05:36:09.801Z" }, + { url = "https://files.pythonhosted.org/packages/8f/83/f61505a05109ef3293dfb1ff594d13d64a2324ac3482be2cedc2be818256/frozenlist-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96f423a119f4777a4a056b66ce11527366a8bb92f54e541ade21f2374433f6d4", size = 243014, upload-time = "2025-10-06T05:36:11.394Z" }, + { url = "https://files.pythonhosted.org/packages/d8/cb/cb6c7b0f7d4023ddda30cf56b8b17494eb3a79e3fda666bf735f63118b35/frozenlist-1.8.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3462dd9475af2025c31cc61be6652dfa25cbfb56cbbf52f4ccfe029f38decaf8", size = 234909, upload-time = "2025-10-06T05:36:12.598Z" }, + { url = "https://files.pythonhosted.org/packages/31/c5/cd7a1f3b8b34af009fb17d4123c5a778b44ae2804e3ad6b86204255f9ec5/frozenlist-1.8.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c4c800524c9cd9bac5166cd6f55285957fcfc907db323e193f2afcd4d9abd69b", size = 250049, upload-time = "2025-10-06T05:36:14.065Z" }, + { url = "https://files.pythonhosted.org/packages/c0/01/2f95d3b416c584a1e7f0e1d6d31998c4a795f7544069ee2e0962a4b60740/frozenlist-1.8.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:d6a5df73acd3399d893dafc71663ad22534b5aa4f94e8a2fabfe856c3c1b6a52", size = 256485, upload-time = "2025-10-06T05:36:15.39Z" }, + { url = "https://files.pythonhosted.org/packages/ce/03/024bf7720b3abaebcff6d0793d73c154237b85bdf67b7ed55e5e9596dc9a/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:405e8fe955c2280ce66428b3ca55e12b3c4e9c336fb2103a4937e891c69a4a29", size = 237619, upload-time = "2025-10-06T05:36:16.558Z" }, + { url = "https://files.pythonhosted.org/packages/69/fa/f8abdfe7d76b731f5d8bd217827cf6764d4f1d9763407e42717b4bed50a0/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:908bd3f6439f2fef9e85031b59fd4f1297af54415fb60e4254a95f75b3cab3f3", size = 250320, upload-time = "2025-10-06T05:36:17.821Z" }, + { url = "https://files.pythonhosted.org/packages/f5/3c/b051329f718b463b22613e269ad72138cc256c540f78a6de89452803a47d/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:294e487f9ec720bd8ffcebc99d575f7eff3568a08a253d1ee1a0378754b74143", size = 246820, upload-time = "2025-10-06T05:36:19.046Z" }, + { url = "https://files.pythonhosted.org/packages/0f/ae/58282e8f98e444b3f4dd42448ff36fa38bef29e40d40f330b22e7108f565/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:74c51543498289c0c43656701be6b077f4b265868fa7f8a8859c197006efb608", size = 250518, upload-time = "2025-10-06T05:36:20.763Z" }, + { url = "https://files.pythonhosted.org/packages/8f/96/007e5944694d66123183845a106547a15944fbbb7154788cbf7272789536/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:776f352e8329135506a1d6bf16ac3f87bc25b28e765949282dcc627af36123aa", size = 239096, upload-time = "2025-10-06T05:36:22.129Z" }, + { url = "https://files.pythonhosted.org/packages/9a/9a/e35b4a917281c0b8419d4207f4334c8e8c5dbf4f3f5f9ada73958d937dcc/frozenlist-1.8.0-py3-none-any.whl", hash = "sha256:0c18a16eab41e82c295618a77502e17b195883241c563b00f0aa5106fc4eaa0d", size = 13409, upload-time = "2025-10-06T05:38:16.721Z" }, +] + +[[package]] +name = "fsspec" +version = "2026.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/cf/b50ddf667c15276a9ab15a70ef5f257564de271957933ffea49d2cdbcdfb/fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41", size = 313547, upload-time = "2026-03-27T19:11:14.892Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/1f/5f4a3cd9e4440e9d9bc78ad0a91a1c8d46b4d429d5239ebe6793c9fe5c41/fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4", size = 202595, upload-time = "2026-03-27T19:11:13.595Z" }, +] + +[[package]] +name = "gguf" +version = "0.18.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3f/26/7622a41c39db9d7090225a4bf8368550e59694dcf7313b44f9a82b501209/gguf-0.18.0.tar.gz", hash = "sha256:b4659093d5d0dccdb5902a904d54b327f4052879fe5e90946ad5fce9f8018c2e", size = 107170, upload-time = "2026-02-27T15:05:39.254Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/0c/e0f1eae7535a97476fb903f65301e35da2a66182b8161066b7eb312b2cb8/gguf-0.18.0-py3-none-any.whl", hash = "sha256:af93f7ef198a265cbde5fa6a6b3101528bca285903949ab0a3e591cd993a1864", size = 114244, upload-time = "2026-02-27T15:05:37.991Z" }, +] + +[[package]] +name = 
"googleapis-common-protos" +version = "1.74.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/20/18/a746c8344152d368a5aac738d4c857012f2c5d1fd2eac7e17b647a7861bd/googleapis_common_protos-1.74.0.tar.gz", hash = "sha256:57971e4eeeba6aad1163c1f0fc88543f965bb49129b8bb55b2b7b26ecab084f1", size = 151254, upload-time = "2026-04-02T21:23:26.679Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/b0/be5d3329badb9230b765de6eea66b73abd5944bdeb5afb3562ddcd80ae84/googleapis_common_protos-1.74.0-py3-none-any.whl", hash = "sha256:702216f78610bb510e3f12ac3cafd281b7ac45cc5d86e90ad87e4d301a3426b5", size = 300743, upload-time = "2026-04-02T21:22:49.108Z" }, +] + +[[package]] +name = "grpcio" +version = "1.80.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/48/af6173dbca4454f4637a4678b67f52ca7e0c1ed7d5894d89d434fecede05/grpcio-1.80.0.tar.gz", hash = "sha256:29aca15edd0688c22ba01d7cc01cb000d72b2033f4a3c72a81a19b56fd143257", size = 12978905, upload-time = "2026-03-30T08:49:10.502Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/e8/a2b749265eb3415abc94f2e619bbd9e9707bebdda787e61c593004ec927a/grpcio-1.80.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:c624cc9f1008361014378c9d776de7182b11fe8b2e5a81bc69f23a295f2a1ad0", size = 6015616, upload-time = "2026-03-30T08:47:13.428Z" }, + { url = "https://files.pythonhosted.org/packages/6e/5e/d319c6e997b50c155ac5a8cb12f5173d5b42677510e886d250d50264949d/grpcio-1.80.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d334591df610ab94714048e0d5b4f3dd5ad1bee74dfec11eee344220077a79de", size = 6563866, upload-time = "2026-03-30T08:47:18.588Z" }, + { url = "https://files.pythonhosted.org/packages/ae/f6/fdd975a2cb4d78eb67769a7b3b3830970bfa2e919f1decf724ae4445f42c/grpcio-1.80.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:0cb517eb1d0d0aaf1d87af7cc5b801d686557c1d88b2619f5e31fab3c2315921", size = 7273060, upload-time = "2026-03-30T08:47:21.113Z" }, + { url = "https://files.pythonhosted.org/packages/db/f0/a3deb5feba60d9538a962913e37bd2e69a195f1c3376a3dd44fe0427e996/grpcio-1.80.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4e78c4ac0d97dc2e569b2f4bcbbb447491167cb358d1a389fc4af71ab6f70411", size = 6782121, upload-time = "2026-03-30T08:47:23.827Z" }, + { url = "https://files.pythonhosted.org/packages/ca/84/36c6dcfddc093e108141f757c407902a05085e0c328007cb090d56646cdf/grpcio-1.80.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2ed770b4c06984f3b47eb0517b1c69ad0b84ef3f40128f51448433be904634cd", size = 7383811, upload-time = "2026-03-30T08:47:26.517Z" }, + { url = "https://files.pythonhosted.org/packages/7c/ef/f3a77e3dc5b471a0ec86c564c98d6adfa3510d38f8ee99010410858d591e/grpcio-1.80.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:256507e2f524092f1473071a05e65a5b10d84b82e3ff24c5b571513cfaa61e2f", size = 8393860, upload-time = "2026-03-30T08:47:29.439Z" }, + { url = "https://files.pythonhosted.org/packages/9b/8d/9d4d27ed7f33d109c50d6b5ce578a9914aa68edab75d65869a17e630a8d1/grpcio-1.80.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9a6284a5d907c37db53350645567c522be314bac859a64a7a5ca63b77bb7958f", size = 7830132, upload-time = "2026-03-30T08:47:33.254Z" }, +] + +[[package]] +name = "h11" +version = "0.16.0" +source = { registry 
= "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + +[[package]] +name = "hf-xet" +version = "1.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/53/92/ec9ad04d0b5728dca387a45af7bc98fbb0d73b2118759f5f6038b61a57e8/hf_xet-1.4.3.tar.gz", hash = "sha256:8ddedb73c8c08928c793df2f3401ec26f95be7f7e516a7bee2fbb546f6676113", size = 670477, upload-time = "2026-03-31T22:40:07.874Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/9f/9c23e4a447b8f83120798f9279d0297a4d1360bdbf59ef49ebec78fe2545/hf_xet-1.4.3-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d0da85329eaf196e03e90b84c2d0aca53bd4573d097a75f99609e80775f98025", size = 3805048, upload-time = "2026-03-31T22:39:53.105Z" }, + { url = "https://files.pythonhosted.org/packages/0b/f8/7aacb8e5f4a7899d39c787b5984e912e6c18b11be136ef13947d7a66d265/hf_xet-1.4.3-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:e23717ce4186b265f69afa66e6f0069fe7efbf331546f5c313d00e123dc84583", size = 3562178, upload-time = "2026-03-31T22:39:51.295Z" }, + { url = "https://files.pythonhosted.org/packages/df/9a/a24b26dc8a65f0ecc0fe5be981a19e61e7ca963b85e062c083f3a9100529/hf_xet-1.4.3-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc360b70c815bf340ed56c7b8c63aacf11762a4b099b2fe2c9bd6d6068668c08", size = 4212320, upload-time = "2026-03-31T22:39:42.922Z" }, + { url = "https://files.pythonhosted.org/packages/53/60/46d493db155d2ee2801b71fb1b0fd67696359047fdd8caee2c914cc50c79/hf_xet-1.4.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:39f2d2e9654cd9b4319885733993807aab6de9dfbd34c42f0b78338d6617421f", size = 3991546, upload-time = "2026-03-31T22:39:41.335Z" }, + { url = "https://files.pythonhosted.org/packages/bc/f5/067363e1c96c6b17256910830d1b54099d06287e10f4ec6ec4e7e08371fc/hf_xet-1.4.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:49ad8a8cead2b56051aa84d7fce3e1335efe68df3cf6c058f22a65513885baac", size = 4193200, upload-time = "2026-03-31T22:40:01.936Z" }, + { url = "https://files.pythonhosted.org/packages/42/4b/53951592882d9c23080c7644542fda34a3813104e9e11fa1a7d82d419cb8/hf_xet-1.4.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7716d62015477a70ea272d2d68cd7cad140f61c52ee452e133e139abfe2c17ba", size = 4429392, upload-time = "2026-03-31T22:40:03.492Z" }, + { url = "https://files.pythonhosted.org/packages/8a/21/75a6c175b4e79662ad8e62f46a40ce341d8d6b206b06b4320d07d55b188c/hf_xet-1.4.3-cp37-abi3-win_amd64.whl", hash = "sha256:6b591fcad34e272a5b02607485e4f2a1334aebf1bc6d16ce8eb1eb8978ac2021", size = 3677359, upload-time = "2026-03-31T22:40:13.619Z" }, + { url = "https://files.pythonhosted.org/packages/8a/7c/44314ecd0e89f8b2b51c9d9e5e7a60a9c1c82024ac471d415860557d3cd8/hf_xet-1.4.3-cp37-abi3-win_arm64.whl", hash = "sha256:7c2c7e20bcfcc946dc67187c203463f5e932e395845d098cc2a93f5b67ca0b47", size = 3533664, upload-time = "2026-03-31T22:40:12.152Z" }, +] + +[[package]] +name = "httpcore" +version = "1.0.9" 
+source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, +] + +[[package]] +name = "httptools" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/46/120a669232c7bdedb9d52d4aeae7e6c7dfe151e99dc70802e2fc7a5e1993/httptools-0.7.1.tar.gz", hash = "sha256:abd72556974f8e7c74a259655924a717a2365b236c882c3f6f8a45fe94703ac9", size = 258961, upload-time = "2025-10-10T03:55:08.559Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/a6/b3965e1e146ef5762870bbe76117876ceba51a201e18cc31f5703e454596/httptools-0.7.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2c15f37ef679ab9ecc06bfc4e6e8628c32a8e4b305459de7cf6785acd57e4d03", size = 517655, upload-time = "2025-10-10T03:54:41.347Z" }, + { url = "https://files.pythonhosted.org/packages/11/7d/71fee6f1844e6fa378f2eddde6c3e41ce3a1fb4b2d81118dd544e3441ec0/httptools-0.7.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7fe6e96090df46b36ccfaf746f03034e5ab723162bc51b0a4cf58305324036f2", size = 511440, upload-time = "2025-10-10T03:54:42.452Z" }, + { url = "https://files.pythonhosted.org/packages/22/a5/079d216712a4f3ffa24af4a0381b108aa9c45b7a5cc6eb141f81726b1823/httptools-0.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f72fdbae2dbc6e68b8239defb48e6a5937b12218e6ffc2c7846cc37befa84362", size = 495186, upload-time = "2025-10-10T03:54:43.937Z" }, + { url = "https://files.pythonhosted.org/packages/e9/9e/025ad7b65278745dee3bd0ebf9314934c4592560878308a6121f7f812084/httptools-0.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e99c7b90a29fd82fea9ef57943d501a16f3404d7b9ee81799d41639bdaae412c", size = 499192, upload-time = "2025-10-10T03:54:45.003Z" }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, +] + +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = 
"sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + +[[package]] +name = "huggingface-hub" +version = "1.9.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "httpx" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "tqdm" }, + { name = "typer" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cf/65/fb800d327bf25bf31b798dd08935d326d064ecb9b359059fecd91b3a98e8/huggingface_hub-1.9.2.tar.gz", hash = "sha256:8d09d080a186bd950a361bfc04b862dfb04d6a2b41d48e9ba1b37507cfd3f1e1", size = 750284, upload-time = "2026-04-08T08:43:11.127Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/d4/e33bf0b362810a9b96c5923e38908950d58ecb512db42e3730320c7f4a3a/huggingface_hub-1.9.2-py3-none-any.whl", hash = "sha256:e1e62ce237d4fbeca9f970aeb15176fbd503e04c25577bfd22f44aa7aa2b5243", size = 637349, upload-time = "2026-04-08T08:43:09.114Z" }, +] + +[[package]] +name = "idna" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, +] + +[[package]] +name = "ijson" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f4/57/60d1a6a512f2f0508d0bc8b4f1cc5616fd3196619b66bd6a01f9155a1292/ijson-3.5.0.tar.gz", hash = "sha256:94688760720e3f5212731b3cb8d30267f9a045fb38fb3870254e7b9504246f31", size = 68658, upload-time = "2026-02-24T03:58:30.974Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/76/6f91bdb019dd978fce1bc5ea1cd620cfc096d258126c91db2c03a20a7f34/ijson-3.5.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7d48dc2984af02eb3c56edfb3f13b3f62f2f3e4fe36f058c8cfc75d93adf4fed", size = 138977, upload-time = "2026-02-24T03:57:11.932Z" }, + { url = "https://files.pythonhosted.org/packages/11/be/bbc983059e48a54b0121ee60042979faed7674490bbe7b2c41560db3f436/ijson-3.5.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f1e73a44844d9adbca9cf2c4132cd875933e83f3d4b23881fcaf82be83644c7d", size = 149785, upload-time = "2026-02-24T03:57:13.255Z" }, + { url = 
"https://files.pythonhosted.org/packages/6d/81/2fee58f9024a3449aee83edfa7167fb5ccd7e1af2557300e28531bb68e16/ijson-3.5.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7389a56b8562a19948bdf1d7bae3a2edc8c7f86fb59834dcb1c4c722818e645a", size = 149729, upload-time = "2026-02-24T03:57:14.191Z" }, + { url = "https://files.pythonhosted.org/packages/c7/56/f1706761fcc096c9d414b3dcd000b1e6e5c24364c21cfba429837f98ee8d/ijson-3.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3176f23f8ebec83f374ed0c3b4e5a0c4db7ede54c005864efebbed46da123608", size = 150697, upload-time = "2026-02-24T03:57:15.855Z" }, + { url = "https://files.pythonhosted.org/packages/d9/6e/ee0d9c875a0193b632b3e9ccd1b22a50685fb510256ad57ba483b6529f77/ijson-3.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:6babd88e508630c6ef86c9bebaaf13bb2fb8ec1d8f8868773a03c20253f599bc", size = 142873, upload-time = "2026-02-24T03:57:16.831Z" }, + { url = "https://files.pythonhosted.org/packages/d2/bf/f9d4399d0e6e3fd615035290a71e97c843f17f329b43638c0a01cf112d73/ijson-3.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dc1b3836b174b6db2fa8319f1926fb5445abd195dc963368092103f8579cb8ed", size = 151583, upload-time = "2026-02-24T03:57:17.757Z" }, +] + +[[package]] +name = "importlib-metadata" +version = "8.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107, upload-time = "2025-12-21T10:00:19.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, +] + +[[package]] +name = "interegular" +version = "0.3.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/9d/8b6dde58a028a3962ce17e84d5fe73758df61378e00ef8ac3d85da34b0ff/interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600", size = 24705, upload-time = "2024-01-06T23:01:22.372Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/01/72d6472f80651673716d1deda2a5bbb633e563ecf94f4479da5519d69d25/interegular-0.3.3-py37-none-any.whl", hash = "sha256:b0c07007d48c89d6d19f7204972d369b2a77222722e126b6aa63aa721dc3b19c", size = 23635, upload-time = "2024-01-06T23:01:20.829Z" }, +] + +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + +[[package]] +name = "jiter" +version = "0.13.0" 
+source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/5e/4ec91646aee381d01cdb9974e30882c9cd3b8c5d1079d6b5ff4af522439a/jiter-0.13.0.tar.gz", hash = "sha256:f2839f9c2c7e2dffc1bc5929a510e14ce0a946be9365fd1219e7ef342dae14f4", size = 164847, upload-time = "2026-02-02T12:37:56.441Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/52/e5719a60ac5d4d7c5995461a94ad5ef962a37c8bf5b088390e6fad59b2ff/jiter-0.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1e2b199f446d3e82246b4fd9236d7cb502dc2222b18698ba0d986d2fecc6152", size = 348821, upload-time = "2026-02-02T12:36:00.093Z" }, + { url = "https://files.pythonhosted.org/packages/61/db/c1efc32b8ba4c740ab3fc2d037d8753f67685f475e26b9d6536a4322bcdd/jiter-0.13.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04670992b576fa65bd056dbac0c39fe8bd67681c380cb2b48efa885711d9d726", size = 364163, upload-time = "2026-02-02T12:36:01.937Z" }, + { url = "https://files.pythonhosted.org/packages/55/8a/fb75556236047c8806995671a18e4a0ad646ed255276f51a20f32dceaeec/jiter-0.13.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a1aff1fbdb803a376d4d22a8f63f8e7ccbce0b4890c26cc7af9e501ab339ef0", size = 483709, upload-time = "2026-02-02T12:36:03.41Z" }, + { url = "https://files.pythonhosted.org/packages/7e/16/43512e6ee863875693a8e6f6d532e19d650779d6ba9a81593ae40a9088ff/jiter-0.13.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b3fb8c2053acaef8580809ac1d1f7481a0a0bdc012fd7f5d8b18fb696a5a089", size = 370480, upload-time = "2026-02-02T12:36:04.791Z" }, + { url = "https://files.pythonhosted.org/packages/f8/4c/09b93e30e984a187bc8aaa3510e1ec8dcbdcd71ca05d2f56aac0492453aa/jiter-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdaba7d87e66f26a2c45d8cbadcbfc4bf7884182317907baf39cfe9775bb4d93", size = 360735, upload-time = "2026-02-02T12:36:06.994Z" }, + { url = "https://files.pythonhosted.org/packages/1a/1b/46c5e349019874ec5dfa508c14c37e29864ea108d376ae26d90bee238cd7/jiter-0.13.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7b88d649135aca526da172e48083da915ec086b54e8e73a425ba50999468cc08", size = 391814, upload-time = "2026-02-02T12:36:08.368Z" }, + { url = "https://files.pythonhosted.org/packages/15/9e/26184760e85baee7162ad37b7912797d2077718476bf91517641c92b3639/jiter-0.13.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e404ea551d35438013c64b4f357b0474c7abf9f781c06d44fcaf7a14c69ff9e2", size = 513990, upload-time = "2026-02-02T12:36:09.993Z" }, + { url = "https://files.pythonhosted.org/packages/e9/34/2c9355247d6debad57a0a15e76ab1566ab799388042743656e566b3b7de1/jiter-0.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1f4748aad1b4a93c8bdd70f604d0f748cdc0e8744c5547798acfa52f10e79228", size = 548021, upload-time = "2026-02-02T12:36:11.376Z" }, + { url = "https://files.pythonhosted.org/packages/c4/10/528b439290763bff3d939268085d03382471b442f212dca4ff5f12802d43/jiter-0.13.0-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab44b178f7981fcaea7e0a5df20e773c663d06ffda0198f1a524e91b2fde7e59", size = 337384, upload-time = "2026-02-02T12:37:53.582Z" }, + { url = "https://files.pythonhosted.org/packages/67/8a/a342b2f0251f3dac4ca17618265d93bf244a2a4d089126e81e4c1056ac50/jiter-0.13.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7bb00b6d26db67a05fe3e12c76edc75f32077fb51deed13822dc648fa373bc19", size = 343768, upload-time = "2026-02-02T12:37:55.055Z" }, +] + +[[package]] +name = "jmespath" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/59/322338183ecda247fb5d1763a6cbe46eff7222eaeebafd9fa65d4bf5cb11/jmespath-1.1.0.tar.gz", hash = "sha256:472c87d80f36026ae83c6ddd0f1d05d4e510134ed462851fd5f754c8c3cbb88d", size = 27377, upload-time = "2026-01-22T16:35:26.279Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, +] + +[[package]] +name = "jsonschema" +version = "4.26.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "jsonschema-specifications" }, + { name = "referencing" }, + { name = "rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/fc/e067678238fa451312d4c62bf6e6cf5ec56375422aee02f9cb5f909b3047/jsonschema-4.26.0.tar.gz", hash = "sha256:0c26707e2efad8aa1bfc5b7ce170f3fccc2e4918ff85989ba9ffa9facb2be326", size = 366583, upload-time = "2026-01-07T13:41:07.246Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl", hash = "sha256:d489f15263b8d200f8387e64b4c3a75f06629559fb73deb8fdfb525f2dab50ce", size = 90630, upload-time = "2026-01-07T13:41:05.306Z" }, +] + +[[package]] +name = "jsonschema-specifications" +version = "2025.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "referencing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/74/a633ee74eb36c44aa6d1095e7cc5569bebf04342ee146178e2d36600708b/jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d", size = 32855, upload-time = "2025-09-08T01:34:59.186Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, +] + +[[package]] +name = "lark" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/60/bc7622aefb2aee1c0b4ba23c1446d3e30225c8770b38d7aedbfb65ca9d5a/lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80", size = 252132, upload-time = "2024-08-13T19:49:00.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/00/d90b10b962b4277f5e64a78b6609968859ff86889f5b898c1a778c06ec00/lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c", size = 111036, upload-time = "2024-08-13T19:48:58.603Z" }, +] + +[[package]] +name = "llguidance" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/48/3f7a9d3ff1b36bba92b5107a3a21286821227afe9ea464736133994d61fb/llguidance-1.3.0.tar.gz", hash = "sha256:861249afd51dc325646834462ea827e57a5c2b2042e108e6aae7059fdad9104d", size = 1070460, upload-time = 
"2025-10-20T19:58:44.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/11/44389d3d1526d7a5c38ffd587a5ebc61d7bee443ac1dea95f2089ad58f5f/llguidance-1.3.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f6caca5d78db7f76e1fbb0fff8607b861c32d47fa3d5dee2fc49de27ee269df", size = 2835242, upload-time = "2025-10-20T19:58:34.518Z" }, + { url = "https://files.pythonhosted.org/packages/83/a8/1ff2bedb8f9acb46a2d2d603415d272bb622c142ea86f5b95445cc6e366c/llguidance-1.3.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc17e9dd602c3879bf91664a64bf72f54c74dbfbeb24ccfab6a5fe435b12f7aa", size = 3033133, upload-time = "2025-10-20T19:58:38.721Z" }, +] + +[[package]] +name = "llvmlite" +version = "0.47.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/88/a8952b6d5c21e74cbf158515b779666f692846502623e9e3c39d8e8ba25f/llvmlite-0.47.0.tar.gz", hash = "sha256:62031ce968ec74e95092184d4b0e857e444f8fdff0b8f9213707699570c33ccc", size = 193614, upload-time = "2026-03-31T18:29:53.497Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/4b/e3f2cd17822cf772a4a51a0a8080b0032e6d37b2dbe8cfb724eac4e31c52/llvmlite-0.47.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5853bf26160857c0c2573415ff4efe01c4c651e59e2c55c2a088740acfee51cd", size = 56275178, upload-time = "2026-03-31T18:28:48.342Z" }, + { url = "https://files.pythonhosted.org/packages/b6/55/a3b4a543185305a9bdf3d9759d53646ed96e55e7dfd43f53e7a421b8fbae/llvmlite-0.47.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:003bcf7fa579e14db59c1a1e113f93ab8a06b56a4be31c7f08264d1d4072d077", size = 55128632, upload-time = "2026-03-31T18:28:52.901Z" }, +] + +[[package]] +name = "lm-format-enforcer" +version = "0.11.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "interegular" }, + { name = "packaging" }, + { name = "pydantic" }, + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/d5/41cd417ba7dfdbbcfe46cebf81fb3dfd7c591b89897560ad05bb410a465d/lm_format_enforcer-0.11.3.tar.gz", hash = "sha256:e68081c108719cce284a9bcc889709b26ffb085a1945b5eba3a12cfa96d528da", size = 40258, upload-time = "2025-08-24T19:37:47.527Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/ef/11292bb0b85cf4c93447cab5a29f64576ed14d3ab4280e35ddd23486594a/lm_format_enforcer-0.11.3-py3-none-any.whl", hash = "sha256:cf586350875def1ae7a8fba84fcbbfc8371424b6c9d05c1fcba70aa233fbf06f", size = 45418, upload-time = "2025-08-24T19:37:46.325Z" }, +] + +[[package]] +name = "loguru" +version = "0.7.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/05/a1dae3dffd1116099471c643b8924f5aa6524411dc6c63fdae648c4f1aca/loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6", size = 63559, upload-time = "2024-12-06T11:20:56.608Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/29/0348de65b8cc732daa3e33e67806420b2ae89bdce2b04af740289c5c6c8c/loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c", size = 61595, upload-time = "2024-12-06T11:20:54.538Z" }, +] + +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } 
+dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, +] + +[[package]] +name = "markupsafe" +version = "3.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7e/99/7690b6d4034fffd95959cbe0c02de8deb3098cc577c67bb6a24fe5d7caa7/markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698", size = 80313, upload-time = "2025-09-27T18:37:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/2c/799f4742efc39633a1b54a92eec4082e4f815314869865d876824c257c1e/markupsafe-3.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d", size = 24332, upload-time = "2025-09-27T18:36:32.813Z" }, + { url = "https://files.pythonhosted.org/packages/3c/2e/8d0c2ab90a8c1d9a24f0399058ab8519a3279d1bd4289511d74e909f060e/markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d", size = 22947, upload-time = "2025-09-27T18:36:33.86Z" }, + { url = "https://files.pythonhosted.org/packages/2c/54/887f3092a85238093a0b2154bd629c89444f395618842e8b0c41783898ea/markupsafe-3.0.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a", size = 21962, upload-time = "2025-09-27T18:36:35.099Z" }, + { url = "https://files.pythonhosted.org/packages/c9/2f/336b8c7b6f4a4d95e91119dc8521402461b74a485558d8f238a68312f11c/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b", size = 23760, upload-time = "2025-09-27T18:36:36.001Z" }, + { url = "https://files.pythonhosted.org/packages/32/43/67935f2b7e4982ffb50a4d169b724d74b62a3964bc1a9a527f5ac4f1ee2b/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f", size = 21529, upload-time = "2025-09-27T18:36:36.906Z" }, + { url = "https://files.pythonhosted.org/packages/89/e0/4486f11e51bbba8b0c041098859e869e304d1c261e59244baa3d295d47b7/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b", size = 23015, upload-time = "2025-09-27T18:36:37.868Z" }, +] + +[[package]] +name = "mcp" +version = "1.27.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = 
"uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8b/eb/c0cfc62075dc6e1ec1c64d352ae09ac051d9334311ed226f1f425312848a/mcp-1.27.0.tar.gz", hash = "sha256:d3dc35a7eec0d458c1da4976a48f982097ddaab87e278c5511d5a4a56e852b83", size = 607509, upload-time = "2026-04-02T14:48:08.88Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/46/f6b4ad632c67ef35209a66127e4bddc95759649dd595f71f13fba11bdf9a/mcp-1.27.0-py3-none-any.whl", hash = "sha256:5ce1fa81614958e267b21fb2aa34e0aea8e2c6ede60d52aba45fd47246b4d741", size = 215967, upload-time = "2026-04-02T14:48:07.24Z" }, +] + +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + +[[package]] +name = "mistral-common" +version = "1.11.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonschema" }, + { name = "numpy" }, + { name = "pillow" }, + { name = "pydantic" }, + { name = "pydantic-extra-types", extra = ["pycountry"] }, + { name = "requests" }, + { name = "tiktoken" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/eb/12167a1bea9714582e5b4f539f9c019323363e314a499c72855ff0e5ad43/mistral_common-1.11.2.tar.gz", hash = "sha256:79f68fc2d1190f28637f40e053f919c8c2697e00b2aa679ddee562a95183f4ad", size = 6357845, upload-time = "2026-05-04T19:47:40.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/f0/6a5d604b972e442b9d36c117d01788feddad099e4965699e3516ee6fefc3/mistral_common-1.11.2-py3-none-any.whl", hash = "sha256:ebb42062cd705a0aa2bc69b4cde2b83d446ae58150b7e29322c90cb08fcfca6c", size = 6531968, upload-time = "2026-05-04T19:47:37.718Z" }, +] + +[package.optional-dependencies] +image = [ + { name = "opencv-python-headless" }, +] + +[[package]] +name = "ml-dtypes" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/0f/428ef6881782e5ebb7eca459689448c0394fa0a80bea3aa9262cba5445ea/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900", size = 5028464, upload-time = "2025-11-17T22:31:50.135Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff", size = 5009002, upload-time = "2025-11-17T22:31:52.001Z" }, +] + +[[package]] 
+name = "model-hosting-container-standards" +version = "0.1.14" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastapi" }, + { name = "httpx" }, + { name = "jmespath" }, + { name = "pydantic" }, + { name = "setuptools" }, + { name = "starlette" }, + { name = "supervisor" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/3d/cf5c6029648cb0a116f7b5c2f74aa155ab0c6dd723a1f204a6d7ff354526/model_hosting_container_standards-0.1.14.tar.gz", hash = "sha256:b6cf4c46d88ce6acd6e543a578bb88ffd55d1179a7c09c22e61ae1d8a567c564", size = 90386, upload-time = "2026-03-18T21:25:14.513Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/94/052452842d39c562237a70345c57ec213a9db22bd25bba998fd2b32d70a7/model_hosting_container_standards-0.1.14-py3-none-any.whl", hash = "sha256:d678be6745899b8ba1e8246c96b101e7802a6a4ea3fb5d90ae8d6eb4204e84c6", size = 121406, upload-time = "2026-03-18T21:25:12.932Z" }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, +] + +[[package]] +name = "msgspec" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c2/ae/d8fab0915716e70910012c0410d16b5eedf542493d19aa80c155215208bf/msgspec-0.21.0.tar.gz", hash = "sha256:9a37c1fb022f895bb24dfac597e449e19eb0cbe62447a832601cb19bb480b51d", size = 318712, upload-time = "2026-04-08T19:57:50.919Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/69/a978335a9724a69ac4428e06be1cb8ce7e737453857575028159bd264ded/msgspec-0.21.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:46e5e9b23bfa453572d8290541327d84cac1f74bbf45b88053dfea3b92d2608b", size = 218640, upload-time = "2026-04-08T19:57:09.203Z" }, + { url = "https://files.pythonhosted.org/packages/7b/34/3cb2b8a506850b8667c1167eb817a0b6605ebdf0027d301815ca2404f72b/msgspec-0.21.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff68f1f12aa3fa1335b79a5bb8b9158cfea2944b4cf8253d05fe28ab6d3510f", size = 224786, upload-time = "2026-04-08T19:57:10.679Z" }, + { url = "https://files.pythonhosted.org/packages/ff/4e/690f1487f72f37ca4482d4c63dceaf48d2b68db76d374108d7f0a15cc72c/msgspec-0.21.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6067127b5e44430a59fddff8d934a7a37ce96862cb25994415b68db7d4457bd5", size = 222514, upload-time = "2026-04-08T19:57:11.974Z" }, + { url = "https://files.pythonhosted.org/packages/83/95/4199f819d2b82db9c7d6de235591c02eebe4796672184eccad7f2b67d4e1/msgspec-0.21.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:11043d534a1bfcd08f1d4d5b50ba60015527b4c8517ec12c2213899e81913584", size = 227101, upload-time = "2026-04-08T19:57:13.278Z" }, +] + +[[package]] +name = "multidict" +version = "6.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/1a/c2/c2d94cbe6ac1753f3fc980da97b3d930efe1da3af3c9f5125354436c073d/multidict-6.7.1.tar.gz", hash = "sha256:ec6652a1bee61c53a3e5776b6049172c53b6aaba34f18c9ad04f82712bac623d", size = 102010, upload-time = "2026-01-26T02:46:45.979Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/3b/d6bd75dc4f3ff7c73766e04e705b00ed6dbbaccf670d9e05a12b006f5a21/multidict-6.7.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cb2a55f408c3043e42b40cc8eecd575afa27b7e0b956dfb190de0f8499a57a53", size = 251018, upload-time = "2026-01-26T02:43:56.198Z" }, + { url = "https://files.pythonhosted.org/packages/fd/80/c959c5933adedb9ac15152e4067c702a808ea183a8b64cf8f31af8ad3155/multidict-6.7.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eb0ce7b2a32d09892b3dd6cc44877a0d02a33241fafca5f25c8b6b62374f8b75", size = 258883, upload-time = "2026-01-26T02:43:57.499Z" }, + { url = "https://files.pythonhosted.org/packages/86/85/7ed40adafea3d4f1c8b916e3b5cc3a8e07dfcdcb9cd72800f4ed3ca1b387/multidict-6.7.1-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c3a32d23520ee37bf327d1e1a656fec76a2edd5c038bf43eddfa0572ec49c60b", size = 242413, upload-time = "2026-01-26T02:43:58.755Z" }, + { url = "https://files.pythonhosted.org/packages/d2/57/b8565ff533e48595503c785f8361ff9a4fde4d67de25c207cd0ba3befd03/multidict-6.7.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:9c90fed18bffc0189ba814749fdcc102b536e83a9f738a9003e569acd540a733", size = 268404, upload-time = "2026-01-26T02:44:00.216Z" }, + { url = "https://files.pythonhosted.org/packages/e0/50/9810c5c29350f7258180dfdcb2e52783a0632862eb334c4896ac717cebcb/multidict-6.7.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:da62917e6076f512daccfbbde27f46fed1c98fee202f0559adec8ee0de67f71a", size = 269456, upload-time = "2026-01-26T02:44:02.202Z" }, + { url = "https://files.pythonhosted.org/packages/f3/8d/5e5be3ced1d12966fefb5c4ea3b2a5b480afcea36406559442c6e31d4a48/multidict-6.7.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfde23ef6ed9db7eaee6c37dcec08524cb43903c60b285b172b6c094711b3961", size = 256322, upload-time = "2026-01-26T02:44:03.56Z" }, + { url = "https://files.pythonhosted.org/packages/31/6e/d8a26d81ac166a5592782d208dd90dfdc0a7a218adaa52b45a672b46c122/multidict-6.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3758692429e4e32f1ba0df23219cd0b4fc0a52f476726fff9337d1a57676a582", size = 253955, upload-time = "2026-01-26T02:44:04.845Z" }, + { url = "https://files.pythonhosted.org/packages/59/4c/7c672c8aad41534ba619bcd4ade7a0dc87ed6b8b5c06149b85d3dd03f0cd/multidict-6.7.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:398c1478926eca669f2fd6a5856b6de9c0acf23a2cb59a14c0ba5844fa38077e", size = 251254, upload-time = "2026-01-26T02:44:06.133Z" }, + { url = "https://files.pythonhosted.org/packages/7b/bd/84c24de512cbafbdbc39439f74e967f19570ce7924e3007174a29c348916/multidict-6.7.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c102791b1c4f3ab36ce4101154549105a53dc828f016356b3e3bcae2e3a039d3", size = 252059, upload-time = "2026-01-26T02:44:07.518Z" }, + { url = "https://files.pythonhosted.org/packages/fa/ba/f5449385510825b73d01c2d4087bf6d2fccc20a2d42ac34df93191d3dd03/multidict-6.7.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = 
"sha256:a088b62bd733e2ad12c50dad01b7d0166c30287c166e137433d3b410add807a6", size = 263588, upload-time = "2026-01-26T02:44:09.382Z" }, + { url = "https://files.pythonhosted.org/packages/d7/11/afc7c677f68f75c84a69fe37184f0f82fce13ce4b92f49f3db280b7e92b3/multidict-6.7.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d51ff4785d58d3f6c91bdbffcb5e1f7ddfda557727043aa20d20ec4f65e324a", size = 259642, upload-time = "2026-01-26T02:44:10.73Z" }, + { url = "https://files.pythonhosted.org/packages/2b/17/ebb9644da78c4ab36403739e0e6e0e30ebb135b9caf3440825001a0bddcb/multidict-6.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc5907494fccf3e7d3f94f95c91d6336b092b5fc83811720fae5e2765890dfba", size = 251377, upload-time = "2026-01-26T02:44:12.042Z" }, + { url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" }, +] + +[[package]] +name = "networkx" +version = "3.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, +] + +[[package]] +name = "ninja" +version = "1.13.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/73/79a0b22fc731989c708068427579e840a6cf4e937fe7ae5c5d0b7356ac22/ninja-1.13.0.tar.gz", hash = "sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978", size = 242558, upload-time = "2025-08-11T15:10:19.421Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/de/6e1cd6b84b412ac1ef327b76f0641aeb5dcc01e9d3f9eee0286d0c34fd93/ninja-1.13.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630", size = 177467, upload-time = "2025-08-11T15:09:52.767Z" }, + { url = "https://files.pythonhosted.org/packages/c8/83/49320fb6e58ae3c079381e333575fdbcf1cca3506ee160a2dcce775046fa/ninja-1.13.0-py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c", size = 187834, upload-time = "2025-08-11T15:09:54.115Z" }, + { url = "https://files.pythonhosted.org/packages/56/c7/ba22748fb59f7f896b609cd3e568d28a0a367a6d953c24c461fe04fc4433/ninja-1.13.0-py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e", size = 202736, upload-time = "2025-08-11T15:09:55.745Z" }, + { url = "https://files.pythonhosted.org/packages/79/22/d1de07632b78ac8e6b785f41fa9aad7a978ec8c0a1bf15772def36d77aac/ninja-1.13.0-py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988", size = 179034, upload-time = "2025-08-11T15:09:57.394Z" }, + { url = 
"https://files.pythonhosted.org/packages/ed/de/0e6edf44d6a04dabd0318a519125ed0415ce437ad5a1ec9b9be03d9048cf/ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa", size = 180716, upload-time = "2025-08-11T15:09:58.696Z" }, + { url = "https://files.pythonhosted.org/packages/54/28/938b562f9057aaa4d6bfbeaa05e81899a47aebb3ba6751e36c027a7f5ff7/ninja-1.13.0-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1", size = 146843, upload-time = "2025-08-11T15:10:00.046Z" }, + { url = "https://files.pythonhosted.org/packages/2a/fb/d06a3838de4f8ab866e44ee52a797b5491df823901c54943b2adb0389fbb/ninja-1.13.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2", size = 154402, upload-time = "2025-08-11T15:10:01.657Z" }, + { url = "https://files.pythonhosted.org/packages/31/bf/0d7808af695ceddc763cf251b84a9892cd7f51622dc8b4c89d5012779f06/ninja-1.13.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f", size = 552388, upload-time = "2025-08-11T15:10:03.349Z" }, + { url = "https://files.pythonhosted.org/packages/9d/70/c99d0c2c809f992752453cce312848abb3b1607e56d4cd1b6cded317351a/ninja-1.13.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714", size = 472501, upload-time = "2025-08-11T15:10:04.735Z" }, + { url = "https://files.pythonhosted.org/packages/9f/43/c217b1153f0e499652f5e0766da8523ce3480f0a951039c7af115e224d55/ninja-1.13.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72", size = 638280, upload-time = "2025-08-11T15:10:06.512Z" }, + { url = "https://files.pythonhosted.org/packages/8c/45/9151bba2c8d0ae2b6260f71696330590de5850e5574b7b5694dce6023e20/ninja-1.13.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db", size = 642420, upload-time = "2025-08-11T15:10:08.35Z" }, + { url = "https://files.pythonhosted.org/packages/3c/fb/95752eb635bb8ad27d101d71bef15bc63049de23f299e312878fc21cb2da/ninja-1.13.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5", size = 585106, upload-time = "2025-08-11T15:10:09.818Z" }, + { url = "https://files.pythonhosted.org/packages/c1/31/aa56a1a286703800c0cbe39fb4e82811c277772dc8cd084f442dd8e2938a/ninja-1.13.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96", size = 707138, upload-time = "2025-08-11T15:10:11.366Z" }, + { url = "https://files.pythonhosted.org/packages/34/6f/5f5a54a1041af945130abdb2b8529cbef0cdcbbf9bcf3f4195378319d29a/ninja-1.13.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200", size = 581758, upload-time = "2025-08-11T15:10:13.295Z" }, +] + +[[package]] +name = "numba" +version = "0.65.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llvmlite" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/61/7299643b9c18d669e04be7c5bcb64d985070d07553274817b45b049e7bfe/numba-0.65.0.tar.gz", hash = "sha256:edad0d9f6682e93624c00125a471ae4df186175d71fd604c983c377cdc03e68b", size = 2764131, upload-time = 
"2026-04-01T03:52:01.946Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/36/88406bd58600cc696417b8e5dd6a056478da808f3eaf48d18e2421e0c2d9/numba-0.65.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a52d92ffd297c10364bce60cd1fcb88f99284ab5df085f2c6bcd1cb33b529a6f", size = 3801411, upload-time = "2026-04-01T03:51:34.321Z" }, + { url = "https://files.pythonhosted.org/packages/0c/61/ce753a1d7646dd477e16d15e89473703faebb8995d2f71d7ad69a540b565/numba-0.65.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da8e371e328c06d0010c3d8b44b21858652831b85bcfba78cb22c042e22dbd8e", size = 3501622, upload-time = "2026-04-01T03:51:36.348Z" }, +] + +[[package]] +name = "numpy" +version = "1.26.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129, upload-time = "2024-02-06T00:26:44.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/12/8f2020a8e8b8383ac0177dc9570aad031a3beb12e38847f7129bacd96228/numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218", size = 20335901, upload-time = "2024-02-05T23:55:32.801Z" }, + { url = "https://files.pythonhosted.org/packages/75/5b/ca6c8bd14007e5ca171c7c03102d17b4f4e0ceb53957e8c44343a9546dcc/numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b", size = 13685868, upload-time = "2024-02-05T23:55:56.28Z" }, + { url = "https://files.pythonhosted.org/packages/79/f8/97f10e6755e2a7d027ca783f63044d5b1bc1ae7acb12afe6a9b4286eac17/numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b", size = 13925109, upload-time = "2024-02-05T23:56:20.368Z" }, + { url = "https://files.pythonhosted.org/packages/0f/50/de23fde84e45f5c4fda2488c759b69990fd4512387a8632860f3ac9cd225/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed", size = 17950613, upload-time = "2024-02-05T23:56:56.054Z" }, + { url = "https://files.pythonhosted.org/packages/4c/0c/9c603826b6465e82591e05ca230dfc13376da512b25ccd0894709b054ed0/numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a", size = 13572172, upload-time = "2024-02-05T23:57:21.56Z" }, + { url = "https://files.pythonhosted.org/packages/76/8c/2ba3902e1a0fc1c74962ea9bb33a534bb05984ad7ff9515bf8d07527cadd/numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0", size = 17786643, upload-time = "2024-02-05T23:57:56.585Z" }, + { url = "https://files.pythonhosted.org/packages/28/4a/46d9e65106879492374999e76eb85f87b15328e06bd1550668f79f7b18c6/numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110", size = 5677803, upload-time = "2024-02-05T23:58:08.963Z" }, + { url = "https://files.pythonhosted.org/packages/16/2e/86f24451c2d530c88daf997cb8d6ac622c1d40d19f5a031ed68a4b73a374/numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = 
"sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818", size = 15517754, upload-time = "2024-02-05T23:58:36.364Z" }, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" }, + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" }, + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, + { url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" }, + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.19.0.56" +source = { registry = "https://pypi.org/simple" } 
+dependencies = [ + { name = "nvidia-cublas-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/09/b8/277c51962ee46fa3e5b203ac5f76107c650f781d6891e681e28e6f3e9fe6/nvidia_cudnn_cu12-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:08caaf27fe556aca82a3ee3b5aa49a77e7de0cfcb7ff4e5c29da426387a8267e", size = 656910700, upload-time = "2026-02-03T20:40:25.508Z" }, + { url = "https://files.pythonhosted.org/packages/c5/41/65225d42fba06fb3dd3972485ea258e7dd07a40d6e01c95da6766ad87354/nvidia_cudnn_cu12-9.19.0.56-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:ac6ad90a075bb33a94f2b4cf4622eac13dd4dc65cf6dd9c7572a318516a36625", size = 657906812, upload-time = "2026-02-03T20:44:12.638Z" }, +] + +[[package]] +name = "nvidia-cudnn-frontend" +version = "1.18.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/b4/604e230378680ee117849a4e1045baca092f93161a829291a84d5acce70c/nvidia_cudnn_frontend-1.18.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:310b417f2848a83d1437203fcaeea320a74fb7f28af20bf42bf5afc9c01f1c12", size = 2027408, upload-time = "2026-01-27T23:32:46.576Z" }, + { url = "https://files.pythonhosted.org/packages/c6/52/08f98262e77b1cbcc834cc1a5db494d0661ea1dbdea58c2e2d51a57fdaca/nvidia_cudnn_frontend-1.18.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6c023539ca6de99234cf5102c3ec0d6af817f5396fc93028a22ba5b834a35b8a", size = 2159245, upload-time = "2026-01-27T23:07:32.664Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, + { url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" }, + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cusparse-cu12" }, + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" }, + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, +] + +[[package]] +name = "nvidia-cutlass-dsl" +version = "4.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cutlass-dsl-libs-base" }, +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/a9/03/678dab0383db1ddfc449da216220f40404189eb36eeed9d87a4fa4bdb0e6/nvidia_cutlass_dsl-4.4.2-py3-none-any.whl", hash = "sha256:7cfb9ef19062b055b9372c7a627004724e2755e4c8b16c3cc88807d64501a4ae", size = 10167, upload-time = "2026-03-16T02:18:59.043Z" }, +] + +[[package]] +name = "nvidia-cutlass-dsl-libs-base" +version = "4.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-python" }, + { name = "numpy" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/7d/0df5e38d11e52cc72095a14d6448bc1c5d0d4b00b069a1189ca417fb225b/nvidia_cutlass_dsl_libs_base-4.4.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2ec8812eeadcbb6fe20bda2e295ed9c00653f8253b78e33cf0ab65a47b829e73", size = 75473821, upload-time = "2026-03-16T02:27:08.371Z" }, + { url = "https://files.pythonhosted.org/packages/56/98/e264964741d9cc9816625d9600d17a5249fd5cbd8c2d166fb0d0c34dfe5a/nvidia_cutlass_dsl_libs_base-4.4.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:22e37b58f7a6f2f43bba533c4df8a088012122e0b4e9a632eca23937adeafb39", size = 74355593, upload-time = "2026-03-16T02:25:11.762Z" }, +] + +[[package]] +name = "nvidia-ml-py" +version = "13.595.45" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ce/49/c29f6e30d8662d2e94fef17739ea7309cc76aba269922ae999e4cc07f268/nvidia_ml_py-13.595.45.tar.gz", hash = "sha256:c9f34897fe0441ff35bc8f35baf80f830a20b0f4e6ce71e0a325bc0e66acf079", size = 50780, upload-time = "2026-03-19T16:59:44.956Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/24/fc256107d23597fa33d319505ce77160fa1a2349c096d01901ffc7cb7fc4/nvidia_ml_py-13.595.45-py3-none-any.whl", hash = "sha256:b65a7977f503d56154b14d683710125ef93594adb63fbf7e559336e3318f1376", size = 51776, upload-time = "2026-03-19T16:59:43.603Z" }, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.28.9" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/c4/120d2dfd92dff2c776d68f361ff8705fdea2ca64e20b612fab0fd3f581ac/nvidia_nccl_cu12-2.28.9-py3-none-manylinux_2_18_aarch64.whl", hash = "sha256:50a36e01c4a090b9f9c47d92cec54964de6b9fcb3362d0e19b8ffc6323c21b60", size = 296766525, upload-time = "2025-11-18T05:49:16.094Z" }, + { url = "https://files.pythonhosted.org/packages/4a/4e/44dbb46b3d1b0ec61afda8e84837870f2f9ace33c564317d59b70bc19d3e/nvidia_nccl_cu12-2.28.9-py3-none-manylinux_2_18_x86_64.whl", hash = "sha256:485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab", size = 296782137, upload-time = "2025-11-18T05:49:34.248Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, + { url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" }, +] + 
+[[package]] +name = "nvidia-nvshmem-cu12" +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/6a/03aa43cc9bd3ad91553a88b5f6fb25ed6a3752ae86ce2180221962bc2aa5/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15", size = 138936938, upload-time = "2025-09-06T00:32:05.589Z" }, + { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" }, + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + +[[package]] +name = "openai" +version = "2.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/13/17e87641b89b74552ed408a92b231283786523edddc95f3545809fab673c/openai-2.24.0.tar.gz", hash = "sha256:1e5769f540dbd01cb33bc4716a23e67b9d695161a734aff9c5f925e2bf99a673", size = 658717, upload-time = "2026-02-24T20:02:07.958Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/30/844dc675ee6902579b8eef01ed23917cc9319a1c9c0c14ec6e39340c96d0/openai-2.24.0-py3-none-any.whl", hash = "sha256:fed30480d7d6c884303287bde864980a4b137b60553ffbcf9ab4a233b7a73d94", size = 1120122, upload-time = "2026-02-24T20:02:05.669Z" }, +] + +[[package]] +name = "openai-harmony" +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/92/2d038d096f29179c7c9571b431f9e739f87a487121901725e23fe338dd9d/openai_harmony-0.0.8.tar.gz", hash = "sha256:6e43f98e6c242fa2de6f8ea12eab24af63fa2ed3e89c06341fb9d92632c5cbdf", size = 284777, upload-time = "2025-11-05T19:07:06.727Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d3/d2/ce6953ca87db9cae3e775024184da7d1c5cb88cead19a2d75b42f00a959c/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4f709815924ec325b9a890e6ab2bbb0ceec8e319a4e257328eb752cf36b2efc", size = 2948463, upload-time = "2025-11-05T19:06:48.17Z" }, + { url = "https://files.pythonhosted.org/packages/fa/4c/b553c9651662d6ce102ca7f3629d268b23df1abe5841e24bed81e8a8e949/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash 
= "sha256:5cfcfd963b50a41fc656c84d3440ca6eecdccd6c552158ce790b8f2e33dfb5a9", size = 2704083, upload-time = "2025-11-05T19:06:50.205Z" }, + { url = "https://files.pythonhosted.org/packages/9b/af/4eec8f9ab9c27bcdb444460c72cf43011d176fc44c79d6e113094ca1e152/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a3a16972aa1cee38ea958470cd04ac9a2d5ac38fdcf77ab686611246220c158", size = 2959765, upload-time = "2025-11-05T19:06:53.62Z" }, + { url = "https://files.pythonhosted.org/packages/11/3c/33f3374e4624e0e776f6b13b73c45a7ead7f9c4529f8369ed5bfcaa30cac/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4d5cfa168e74d08f8ba6d58a7e49bc7daef4d58951ec69b66b0d56f4927a68d", size = 3427031, upload-time = "2025-11-05T19:06:51.829Z" }, + { url = "https://files.pythonhosted.org/packages/25/3f/1a192b93bb47c6b44cd98ba8cc1d3d2a9308f1bb700c3017e6352da11bda/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c007d277218a50db8839e599ed78e0fffe5130f614c3f6d93ae257f282071a29", size = 2953260, upload-time = "2025-11-05T19:06:55.406Z" }, + { url = "https://files.pythonhosted.org/packages/5b/f8/93b582cad3531797c3db7c2db5400fd841538ccddfd9f5e3df61be99a630/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:8565d4f5a0638da1bffde29832ed63c9e695c558611053add3b2dc0b56c92dbc", size = 3127044, upload-time = "2025-11-05T19:06:59.553Z" }, + { url = "https://files.pythonhosted.org/packages/1d/10/4327dbf87f75ae813405fd9a9b4a5cde63d506ffed0a096a440a4cabd89c/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:cbaa3bda75ef0d8836e1f8cc84af62f971b1d756d740efc95c38c3e04c0bfde2", size = 2932931, upload-time = "2025-11-05T19:07:01.437Z" }, + { url = "https://files.pythonhosted.org/packages/8a/c8/1774eec4f6f360ef57618fb8f52e3d3af245b2491bd0297513aa09eec04b/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:772922a9bd24e133950fad71eb1550836f415a88e8c77870e12d0c3bd688ddc2", size = 2996140, upload-time = "2025-11-05T19:07:03.438Z" }, + { url = "https://files.pythonhosted.org/packages/60/c3/3d1e01e2dba517a91760e4a03e4f20ffc75039a6fe584d0e6f9b5c78fd15/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:007b0476a1f331f8130783f901f1da6f5a7057af1a4891f1b6a31dec364189b5", size = 3205080, upload-time = "2025-11-05T19:07:05.078Z" }, +] + +[[package]] +name = "opencv-python-headless" +version = "4.13.0.92" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/76/9417a6aef9def70e467a5bf560579f816148a4c658b7d525581b356eda9e/opencv_python_headless-4.13.0.92-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c8cfc8e87ed452b5cecb9419473ee5560a989859fe1d10d1ce11ae87b09a2cb", size = 33703709, upload-time = "2026-02-05T10:24:46.469Z" }, + { url = "https://files.pythonhosted.org/packages/92/ce/bd17ff5772938267fd49716e94ca24f616ff4cb1ff4c6be13085108037be/opencv_python_headless-4.13.0.92-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0525a3d2c0b46c611e2130b5fdebc94cf404845d8fa64d2f3a3b679572a5bd22", size = 56016764, upload-time = "2026-02-05T10:26:48.904Z" }, + { url = "https://files.pythonhosted.org/packages/8f/b4/b7bcbf7c874665825a8c8e1097e93ea25d1f1d210a3e20d4451d01da30aa/opencv_python_headless-4.13.0.92-cp37-abi3-manylinux_2_28_aarch64.whl", hash = 
"sha256:eb60e36b237b1ebd40a912da5384b348df8ed534f6f644d8e0b4f103e272ba7d", size = 35010236, upload-time = "2026-02-05T10:28:11.031Z" }, + { url = "https://files.pythonhosted.org/packages/4b/33/b5db29a6c00eb8f50708110d8d453747ca125c8b805bc437b289dbdcc057/opencv_python_headless-4.13.0.92-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0bd48544f77c68b2941392fcdf9bcd2b9cdf00e98cb8c29b2455d194763cf99e", size = 60391106, upload-time = "2026-02-05T10:30:14.236Z" }, +] + +[[package]] +name = "opentelemetry-api" +version = "1.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2c/1d/4049a9e8698361cc1a1aa03a6c59e4fa4c71e0c0f94a30f988a6876a2ae6/opentelemetry_api-1.40.0.tar.gz", hash = "sha256:159be641c0b04d11e9ecd576906462773eb97ae1b657730f0ecf64d32071569f", size = 70851, upload-time = "2026-03-04T14:17:21.555Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = "sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676, upload-time = "2026-03-04T14:17:01.24Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp" +version = "1.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-exporter-otlp-proto-grpc" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d0/37/b6708e0eff5c5fb9aba2e0ea09f7f3bcbfd12a592d2a780241b5f6014df7/opentelemetry_exporter_otlp-1.40.0.tar.gz", hash = "sha256:7caa0870b95e2fcb59d64e16e2b639ecffb07771b6cd0000b5d12e5e4fef765a", size = 6152, upload-time = "2026-03-04T14:17:23.235Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/fc/aea77c28d9f3ffef2fdafdc3f4a235aee4091d262ddabd25882f47ce5c5f/opentelemetry_exporter_otlp-1.40.0-py3-none-any.whl", hash = "sha256:48c87e539ec9afb30dc443775a1334cc5487de2f72a770a4c00b1610bf6c697d", size = 7023, upload-time = "2026-03-04T14:17:03.612Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-common" +version = "1.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-proto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/51/bc/1559d46557fe6eca0b46c88d4c2676285f1f3be2e8d06bb5d15fbffc814a/opentelemetry_exporter_otlp_proto_common-1.40.0.tar.gz", hash = "sha256:1cbee86a4064790b362a86601ee7934f368b81cd4cc2f2e163902a6e7818a0fa", size = 20416, upload-time = "2026-03-04T14:17:23.801Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/ca/8f122055c97a932311a3f640273f084e738008933503d0c2563cd5d591fc/opentelemetry_exporter_otlp_proto_common-1.40.0-py3-none-any.whl", hash = "sha256:7081ff453835a82417bf38dccf122c827c3cbc94f2079b03bba02a3165f25149", size = 18369, upload-time = "2026-03-04T14:17:04.796Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-grpc" +version = "1.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "grpcio" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-common" }, + { name = "opentelemetry-proto" }, + { name = "opentelemetry-sdk" }, + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/8f/7f/b9e60435cfcc7590fa87436edad6822240dddbc184643a2a005301cc31f4/opentelemetry_exporter_otlp_proto_grpc-1.40.0.tar.gz", hash = "sha256:bd4015183e40b635b3dab8da528b27161ba83bf4ef545776b196f0fb4ec47740", size = 25759, upload-time = "2026-03-04T14:17:24.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/6f/7ee0980afcbdcd2d40362da16f7f9796bd083bf7f0b8e038abfbc0300f5d/opentelemetry_exporter_otlp_proto_grpc-1.40.0-py3-none-any.whl", hash = "sha256:2aa0ca53483fe0cf6405087a7491472b70335bc5c7944378a0a8e72e86995c52", size = 20304, upload-time = "2026-03-04T14:17:05.942Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-common" }, + { name = "opentelemetry-proto" }, + { name = "opentelemetry-sdk" }, + { name = "requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2e/fa/73d50e2c15c56be4d000c98e24221d494674b0cc95524e2a8cb3856d95a4/opentelemetry_exporter_otlp_proto_http-1.40.0.tar.gz", hash = "sha256:db48f5e0f33217588bbc00274a31517ba830da576e59503507c839b38fa0869c", size = 17772, upload-time = "2026-03-04T14:17:25.324Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/3a/8865d6754e61c9fb170cdd530a124a53769ee5f740236064816eb0ca7301/opentelemetry_exporter_otlp_proto_http-1.40.0-py3-none-any.whl", hash = "sha256:a8d1dab28f504c5d96577d6509f80a8150e44e8f45f82cdbe0e34c99ab040069", size = 19960, upload-time = "2026-03-04T14:17:07.153Z" }, +] + +[[package]] +name = "opentelemetry-proto" +version = "1.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/77/dd38991db037fdfce45849491cb61de5ab000f49824a00230afb112a4392/opentelemetry_proto-1.40.0.tar.gz", hash = "sha256:03f639ca129ba513f5819810f5b1f42bcb371391405d99c168fe6937c62febcd", size = 45667, upload-time = "2026-03-04T14:17:31.194Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/b2/189b2577dde745b15625b3214302605b1353436219d42b7912e77fa8dc24/opentelemetry_proto-1.40.0-py3-none-any.whl", hash = "sha256:266c4385d88923a23d63e353e9761af0f47a6ed0d486979777fe4de59dc9b25f", size = 72073, upload-time = "2026-03-04T14:17:16.673Z" }, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/fd/3c3125b20ba18ce2155ba9ea74acb0ae5d25f8cd39cfd37455601b7955cc/opentelemetry_sdk-1.40.0.tar.gz", hash = "sha256:18e9f5ec20d859d268c7cb3c5198c8d105d073714db3de50b593b8c1345a48f2", size = 184252, upload-time = "2026-03-04T14:17:31.87Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/c5/6a852903d8bfac758c6dc6e9a68b015d3c33f2f1be5e9591e0f4b69c7e0a/opentelemetry_sdk-1.40.0-py3-none-any.whl", hash = "sha256:787d2154a71f4b3d81f20524a8ce061b7db667d24e46753f32a7bc48f1c1f3f1", size = 141951, upload-time = "2026-03-04T14:17:17.961Z" }, +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.61b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name 
= "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/c0/4ae7973f3c2cfd2b6e321f1675626f0dab0a97027cc7a297474c9c8f3d04/opentelemetry_semantic_conventions-0.61b0.tar.gz", hash = "sha256:072f65473c5d7c6dc0355b27d6c9d1a679d63b6d4b4b16a9773062cb7e31192a", size = 145755, upload-time = "2026-03-04T14:17:32.664Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/37/cc6a55e448deaa9b27377d087da8615a3416d8ad523d5960b78dbeadd02a/opentelemetry_semantic_conventions-0.61b0-py3-none-any.whl", hash = "sha256:fa530a96be229795f8cef353739b618148b0fe2b4b3f005e60e262926c4d38e2", size = 231621, upload-time = "2026-03-04T14:17:19.33Z" }, +] + +[[package]] +name = "opentelemetry-semantic-conventions-ai" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-sdk" }, + { name = "opentelemetry-semantic-conventions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/02/10aeacc37a38a3a8fa16ff67bec1ae3bf882539f6f9efb0f70acf802ca2d/opentelemetry_semantic_conventions_ai-0.5.1.tar.gz", hash = "sha256:153906200d8c1d2f8e09bd78dbef526916023de85ac3dab35912bfafb69ff04c", size = 26533, upload-time = "2026-03-26T14:20:38.73Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/22/41fb05f1dc5fda2c468e05a41814c20859016c85117b66c8a257cae814f6/opentelemetry_semantic_conventions_ai-0.5.1-py3-none-any.whl", hash = "sha256:25aeb22bd261543b4898a73824026d96770e5351209c7d07a0b1314762b1f6e4", size = 11250, upload-time = "2026-03-26T14:20:37.108Z" }, +] + +[[package]] +name = "outlines-core" +version = "0.2.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/04/4a0812eb27c086cfd2e66e7ec9150f33e105912a9b7f8b335e3479f03a06/outlines_core-0.2.14.tar.gz", hash = "sha256:64808deed1591ca3029ff64346ceb974cd5d780c916ea82504951fe83523039e", size = 191539, upload-time = "2026-01-09T15:59:10.016Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/29/3a04944407207a5d214879ca5ca33c2bd3e65199a4e927051c1bdaaa4d50/outlines_core-0.2.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb2060c240c4507f334965a8948dbeeb22007560d797f6debd92346c0b620cb", size = 2341426, upload-time = "2026-01-09T15:58:33.553Z" }, + { url = "https://files.pythonhosted.org/packages/b2/a7/a77f746272504bac3f628047d56ea1731b61549a3e1d9bbfd226f2968246/outlines_core-0.2.14-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:1de34681c7e0e7e1551fc9036e4fa3c57986336c905a10536591ceb6d869c258", size = 2236941, upload-time = "2026-01-09T15:58:35.118Z" }, +] + +[[package]] +name = "packaging" +version = "26.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, +] + +[[package]] +name = "partial-json-parser" +version = "0.2.1.1.post7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/6a/6d/eed37d7ebc1e0bcd27b831c0cf1fe94881934316187c4b30d23f29ea0bd4/partial_json_parser-0.2.1.1.post7.tar.gz", hash = "sha256:86590e1ba6bcb6739a2dfc17d2323f028cb5884f4c6ce23db376999132c9a922", size = 10296, upload-time = "2025-11-17T07:27:41.202Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/32/658973117bf0fd82a24abbfb94fe73a5e86216e49342985e10acce54775a/partial_json_parser-0.2.1.1.post7-py3-none-any.whl", hash = "sha256:145119e5eabcf80cbb13844a6b50a85c68bf99d376f8ed771e2a3c3b03e653ae", size = 10877, upload-time = "2025-11-17T07:27:40.457Z" }, +] + +[[package]] +name = "pillow" +version = "12.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/21/c2bcdd5906101a30244eaffc1b6e6ce71a31bd0742a01eb89e660ebfac2d/pillow-12.2.0.tar.gz", hash = "sha256:a830b1a40919539d07806aa58e1b114df53ddd43213d9c8b75847eee6c0182b5", size = 46987819, upload-time = "2026-04-01T14:46:17.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/af/4e8e6869cbed569d43c416fad3dc4ecb944cb5d9492defaed89ddd6fe871/pillow-12.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:03e7e372d5240cc23e9f07deca4d775c0817bffc641b01e9c3af208dbd300987", size = 6284462, upload-time = "2026-04-01T14:43:18.268Z" }, + { url = "https://files.pythonhosted.org/packages/e9/9e/c05e19657fd57841e476be1ab46c4d501bffbadbafdc31a6d665f8b737b6/pillow-12.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b86024e52a1b269467a802258c25521e6d742349d760728092e1bc2d135b4d76", size = 8094744, upload-time = "2026-04-01T14:43:20.716Z" }, + { url = "https://files.pythonhosted.org/packages/2b/54/1789c455ed10176066b6e7e6da1b01e50e36f94ba584dc68d9eebfe9156d/pillow-12.2.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7371b48c4fa448d20d2714c9a1f775a81155050d383333e0a6c15b1123dda005", size = 6398371, upload-time = "2026-04-01T14:43:23.443Z" }, + { url = "https://files.pythonhosted.org/packages/43/e3/fdc657359e919462369869f1c9f0e973f353f9a9ee295a39b1fea8ee1a77/pillow-12.2.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:62f5409336adb0663b7caa0da5c7d9e7bdbaae9ce761d34669420c2a801b2780", size = 7087215, upload-time = "2026-04-01T14:43:26.758Z" }, + { url = "https://files.pythonhosted.org/packages/8b/f8/2f6825e441d5b1959d2ca5adec984210f1ec086435b0ed5f52c19b3b8a6e/pillow-12.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:01afa7cf67f74f09523699b4e88c73fb55c13346d212a59a2db1f86b0a63e8c5", size = 6509783, upload-time = "2026-04-01T14:43:29.56Z" }, + { url = "https://files.pythonhosted.org/packages/67/f9/029a27095ad20f854f9dba026b3ea6428548316e057e6fc3545409e86651/pillow-12.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc3d34d4a8fbec3e88a79b92e5465e0f9b842b628675850d860b8bd300b159f5", size = 7212112, upload-time = "2026-04-01T14:43:32.091Z" }, +] + +[[package]] +name = "prometheus-client" +version = "0.24.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/58/a794d23feb6b00fc0c72787d7e87d872a6730dd9ed7c7b3e954637d8f280/prometheus_client-0.24.1.tar.gz", hash = "sha256:7e0ced7fbbd40f7b84962d5d2ab6f17ef88a72504dcf7c0b40737b43b2a461f9", size = 85616, upload-time = "2026-01-14T15:26:26.965Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/74/c3/24a2f845e3917201628ecaba4f18bab4d18a337834c1df2a159ee9d22a42/prometheus_client-0.24.1-py3-none-any.whl", hash = "sha256:150db128af71a5c2482b36e588fc8a6b95e498750da4b17065947c16070f4055", size = 64057, upload-time = "2026-01-14T15:26:24.42Z" }, +] + +[[package]] +name = "prometheus-fastapi-instrumentator" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "prometheus-client" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/6d/24d53033cf93826aa7857699a4450c1c67e5b9c710e925b1ed2b320c04df/prometheus_fastapi_instrumentator-7.1.0.tar.gz", hash = "sha256:be7cd61eeea4e5912aeccb4261c6631b3f227d8924542d79eaf5af3f439cbe5e", size = 20220, upload-time = "2025-03-19T19:35:05.351Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/72/0824c18f3bc75810f55dacc2dd933f6ec829771180245ae3cc976195dec0/prometheus_fastapi_instrumentator-7.1.0-py3-none-any.whl", hash = "sha256:978130f3c0bb7b8ebcc90d35516a6fe13e02d2eb358c8f83887cdef7020c31e9", size = 19296, upload-time = "2025-03-19T19:35:04.323Z" }, +] + +[[package]] +name = "propcache" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/da/e9fc233cf63743258bff22b3dfa7ea5baef7b5bc324af47a0ad89b8ffc6f/propcache-0.4.1.tar.gz", hash = "sha256:f48107a8c637e80362555f37ecf49abe20370e557cc4ab374f04ec4423c97c3d", size = 46442, upload-time = "2025-10-08T19:49:02.291Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/d3/6c7ee328b39a81ee877c962469f1e795f9db87f925251efeb0545e0020d0/propcache-0.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ec17c65562a827bba85e3872ead335f95405ea1674860d96483a02f5c698fa72", size = 225505, upload-time = "2025-10-08T19:46:50.055Z" }, + { url = "https://files.pythonhosted.org/packages/01/5d/1c53f4563490b1d06a684742cc6076ef944bc6457df6051b7d1a877c057b/propcache-0.4.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:405aac25c6394ef275dee4c709be43745d36674b223ba4eb7144bf4d691b7367", size = 230242, upload-time = "2025-10-08T19:46:51.815Z" }, + { url = "https://files.pythonhosted.org/packages/20/e1/ce4620633b0e2422207c3cb774a0ee61cac13abc6217763a7b9e2e3f4a12/propcache-0.4.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0013cb6f8dde4b2a2f66903b8ba740bdfe378c943c4377a200551ceb27f379e4", size = 238474, upload-time = "2025-10-08T19:46:53.208Z" }, + { url = "https://files.pythonhosted.org/packages/46/4b/3aae6835b8e5f44ea6a68348ad90f78134047b503765087be2f9912140ea/propcache-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15932ab57837c3368b024473a525e25d316d8353016e7cc0e5ba9eb343fbb1cf", size = 221575, upload-time = "2025-10-08T19:46:54.511Z" }, + { url = "https://files.pythonhosted.org/packages/6e/a5/8a5e8678bcc9d3a1a15b9a29165640d64762d424a16af543f00629c87338/propcache-0.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:031dce78b9dc099f4c29785d9cf5577a3faf9ebf74ecbd3c856a7b92768c3df3", size = 216736, upload-time = "2025-10-08T19:46:56.212Z" }, + { url = "https://files.pythonhosted.org/packages/f1/63/b7b215eddeac83ca1c6b934f89d09a625aa9ee4ba158338854c87210cc36/propcache-0.4.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = 
"sha256:ab08df6c9a035bee56e31af99be621526bd237bea9f32def431c656b29e41778", size = 213019, upload-time = "2025-10-08T19:46:57.595Z" }, + { url = "https://files.pythonhosted.org/packages/57/74/f580099a58c8af587cac7ba19ee7cb418506342fbbe2d4a4401661cca886/propcache-0.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4d7af63f9f93fe593afbf104c21b3b15868efb2c21d07d8732c0c4287e66b6a6", size = 220376, upload-time = "2025-10-08T19:46:59.067Z" }, + { url = "https://files.pythonhosted.org/packages/c4/ee/542f1313aff7eaf19c2bb758c5d0560d2683dac001a1c96d0774af799843/propcache-0.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cfc27c945f422e8b5071b6e93169679e4eb5bf73bbcbf1ba3ae3a83d2f78ebd9", size = 226988, upload-time = "2025-10-08T19:47:00.544Z" }, + { url = "https://files.pythonhosted.org/packages/8f/18/9c6b015dd9c6930f6ce2229e1f02fb35298b847f2087ea2b436a5bfa7287/propcache-0.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:35c3277624a080cc6ec6f847cbbbb5b49affa3598c4535a0a4682a697aaa5c75", size = 215615, upload-time = "2025-10-08T19:47:01.968Z" }, + { url = "https://files.pythonhosted.org/packages/5b/5a/bc7b4a4ef808fa59a816c17b20c4bef6884daebbdf627ff2a161da67da19/propcache-0.4.1-py3-none-any.whl", hash = "sha256:af2a6052aeb6cf17d3e46ee169099044fd8224cbaf75c76a2ef596e8163e2237", size = 13305, upload-time = "2025-10-08T19:49:00.792Z" }, +] + +[[package]] +name = "protobuf" +version = "6.33.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/70/e908e9c5e52ef7c3a6c7902c9dfbb34c7e29c25d2f81ade3856445fd5c94/protobuf-6.33.6.tar.gz", hash = "sha256:a6768d25248312c297558af96a9f9c929e8c4cee0659cb07e780731095f38135", size = 444531, upload-time = "2026-03-18T19:05:00.988Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/90/b3c01fdec7d2f627b3a6884243ba328c1217ed2d978def5c12dc50d328a3/protobuf-6.33.6-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e2afbae9b8e1825e3529f88d514754e094278bb95eadc0e199751cdd9a2e82a2", size = 324610, upload-time = "2026-03-18T19:04:53.096Z" }, + { url = "https://files.pythonhosted.org/packages/9b/ca/25afc144934014700c52e05103c2421997482d561f3101ff352e1292fb81/protobuf-6.33.6-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c96c37eec15086b79762ed265d59ab204dabc53056e3443e702d2681f4b39ce3", size = 339381, upload-time = "2026-03-18T19:04:54.616Z" }, + { url = "https://files.pythonhosted.org/packages/16/92/d1e32e3e0d894fe00b15ce28ad4944ab692713f2e7f0a99787405e43533a/protobuf-6.33.6-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:e9db7e292e0ab79dd108d7f1a94fe31601ce1ee3f7b79e0692043423020b0593", size = 323436, upload-time = "2026-03-18T19:04:55.768Z" }, + { url = "https://files.pythonhosted.org/packages/c4/72/02445137af02769918a93807b2b7890047c32bfb9f90371cbc12688819eb/protobuf-6.33.6-py3-none-any.whl", hash = "sha256:77179e006c476e69bf8e8ce866640091ec42e1beb80b213c3900006ecfba6901", size = 170656, upload-time = "2026-03-18T19:04:59.826Z" }, +] + +[[package]] +name = "psutil" +version = "7.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" }, + { url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" }, + { url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" }, + { url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" }, +] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, +] + +[[package]] +name = "pybase64" +version = "1.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/b8/4ed5c7ad5ec15b08d35cc79ace6145d5c1ae426e46435f4987379439dfea/pybase64-1.4.3.tar.gz", hash = "sha256:c2ed274c9e0ba9c8f9c4083cfe265e66dd679126cd9c2027965d807352f3f053", size = 137272, upload-time = "2025-12-06T13:27:04.013Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/1b/9a8cab0042b464e9a876d5c65fe5127445a2436da36fda64899b119b1a1b/pybase64-1.4.3-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:f0b3f200c3e06316f6bebabd458b4e4bcd4c2ca26af7c0c766614d91968dee27", size = 68210, upload-time = "2025-12-06T13:23:18.813Z" }, + { url = "https://files.pythonhosted.org/packages/62/f7/965b79ff391ad208b50e412b5d3205ccce372a2d27b7218ae86d5295b105/pybase64-1.4.3-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bb632edfd132b3eaf90c39c89aa314beec4e946e210099b57d40311f704e11d4", size = 71599, upload-time = "2025-12-06T13:23:20.195Z" }, + { url = "https://files.pythonhosted.org/packages/03/4b/a3b5175130b3810bbb8ccfa1edaadbd3afddb9992d877c8a1e2f274b476e/pybase64-1.4.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:356ef1d74648ce997f5a777cf8f1aefecc1c0b4fe6201e0ef3ec8a08170e1b54", size = 59922, upload-time = "2025-12-06T13:23:21.487Z" }, + { url = 
"https://files.pythonhosted.org/packages/da/5d/c38d1572027fc601b62d7a407721688b04b4d065d60ca489912d6893e6cf/pybase64-1.4.3-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.whl", hash = "sha256:c48361f90db32bacaa5518419d4eb9066ba558013aaf0c7781620279ecddaeb9", size = 56712, upload-time = "2025-12-06T13:23:22.77Z" }, + { url = "https://files.pythonhosted.org/packages/e7/d4/4e04472fef485caa8f561d904d4d69210a8f8fc1608ea15ebd9012b92655/pybase64-1.4.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:702bcaa16ae02139d881aeaef5b1c8ffb4a3fae062fe601d1e3835e10310a517", size = 59300, upload-time = "2025-12-06T13:23:24.543Z" }, + { url = "https://files.pythonhosted.org/packages/86/e7/16e29721b86734b881d09b7e23dfd7c8408ad01a4f4c7525f3b1088e25ec/pybase64-1.4.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:53d0ffe1847b16b647c6413d34d1de08942b7724273dd57e67dcbdb10c574045", size = 60278, upload-time = "2025-12-06T13:23:25.608Z" }, + { url = "https://files.pythonhosted.org/packages/b1/02/18515f211d7c046be32070709a8efeeef8a0203de4fd7521e6b56404731b/pybase64-1.4.3-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:9a1792e8b830a92736dae58f0c386062eb038dfe8004fb03ba33b6083d89cd43", size = 54817, upload-time = "2025-12-06T13:23:26.633Z" }, + { url = "https://files.pythonhosted.org/packages/e7/be/14e29d8e1a481dbff151324c96dd7b5d2688194bb65dc8a00ca0e1ad1e86/pybase64-1.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1d468b1b1ac5ad84875a46eaa458663c3721e8be5f155ade356406848d3701f6", size = 58611, upload-time = "2025-12-06T13:23:27.684Z" }, + { url = "https://files.pythonhosted.org/packages/b4/8a/a2588dfe24e1bbd742a554553778ab0d65fdf3d1c9a06d10b77047d142aa/pybase64-1.4.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:e97b7bdbd62e71898cd542a6a9e320d9da754ff3ebd02cb802d69087ee94d468", size = 52404, upload-time = "2025-12-06T13:23:28.714Z" }, + { url = "https://files.pythonhosted.org/packages/27/fc/afcda7445bebe0cbc38cafdd7813234cdd4fc5573ff067f1abf317bb0cec/pybase64-1.4.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b33aeaa780caaa08ffda87fc584d5eab61e3d3bbb5d86ead02161dc0c20d04bc", size = 68817, upload-time = "2025-12-06T13:23:30.079Z" }, + { url = "https://files.pythonhosted.org/packages/d3/3a/87c3201e555ed71f73e961a787241a2438c2bbb2ca8809c29ddf938a3157/pybase64-1.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:1c0efcf78f11cf866bed49caa7b97552bc4855a892f9cc2372abcd3ed0056f0d", size = 57854, upload-time = "2025-12-06T13:23:31.17Z" }, + { url = "https://files.pythonhosted.org/packages/fd/7d/931c2539b31a7b375e7d595b88401eeb5bd6c5ce1059c9123f9b608aaa14/pybase64-1.4.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:66e3791f2ed725a46593f8bd2761ff37d01e2cdad065b1dceb89066f476e50c6", size = 54333, upload-time = "2025-12-06T13:23:32.422Z" }, + { url = "https://files.pythonhosted.org/packages/de/5e/537601e02cc01f27e9d75f440f1a6095b8df44fc28b1eef2cd739aea8cec/pybase64-1.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:72bb0b6bddadab26e1b069bb78e83092711a111a80a0d6b9edcb08199ad7299b", size = 56492, upload-time = "2025-12-06T13:23:33.515Z" }, + { url = "https://files.pythonhosted.org/packages/96/97/2a2e57acf8f5c9258d22aba52e71f8050e167b29ed2ee1113677c1b600c1/pybase64-1.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5b3365dbcbcdb0a294f0f50af0c0a16b27a232eddeeb0bceeefd844ef30d2a23", size = 70974, upload-time = "2025-12-06T13:23:36.27Z" }, + { url = 
"https://files.pythonhosted.org/packages/fa/8f/43c3bb11ca9bacf81cb0b7a71500bb65b2eda6d5fe07433c09b543de97f3/pybase64-1.4.3-graalpy312-graalpy250_312_native-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5c29a582b0ea3936d02bd6fe9bf674ab6059e6e45ab71c78404ab2c913224414", size = 43461, upload-time = "2025-12-06T13:26:28.906Z" }, + { url = "https://files.pythonhosted.org/packages/2d/4c/2a5258329200be57497d3972b5308558c6de42e3749c6cc2aa1cbe34b25a/pybase64-1.4.3-graalpy312-graalpy250_312_native-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b6b664758c804fa919b4f1257aa8cf68e95db76fc331de5f70bfc3a34655afe1", size = 36058, upload-time = "2025-12-06T13:26:30.092Z" }, +] + +[[package]] +name = "pycountry" +version = "26.2.16" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/de/1d/061b9e7a48b85cfd69f33c33d2ef784a531c359399ad764243399673c8f5/pycountry-26.2.16.tar.gz", hash = "sha256:5b6027d453fcd6060112b951dd010f01f168b51b4bf8a1f1fc8c95c8d94a0801", size = 7711342, upload-time = "2026-02-17T03:42:52.367Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/42/7703bd45b62fecd44cd7d3495423097e2f7d28bc2e99e7c1af68892ab157/pycountry-26.2.16-py3-none-any.whl", hash = "sha256:115c4baf7cceaa30f59a4694d79483c9167dbce7a9de4d3d571c5f3ea77c305a", size = 8044600, upload-time = "2026-02-17T03:42:49.777Z" }, +] + +[[package]] +name = "pycparser" +version = "3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/7d/92392ff7815c21062bea51aa7b87d45576f649f16458d78b7cf94b9ab2e6/pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29", size = 103492, upload-time = "2026-01-21T14:26:51.89Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992", size = 48172, upload-time = "2026-01-21T14:26:50.693Z" }, +] + +[[package]] +name = "pydantic" +version = "2.12.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/44/36f1a6e523abc58ae5f928898e4aca2e0ea509b5aa6f6f392a5d882be928/pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49", size = 821591, upload-time = "2025-11-26T15:11:46.471Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" }, +] + +[package.optional-dependencies] +email = [ + { name = "email-validator" }, +] + +[[package]] +name = "pydantic-core" +version = "2.41.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/70/23b021c950c2addd24ec408e9ab05d59b035b39d97cdc1130e1bce647bb6/pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e", size = 460952, upload-time = "2025-11-04T13:43:49.098Z" } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/68/b8/a01b53cb0e59139fbc9e4fda3e9724ede8de279097179be4ff31f1abb65a/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69", size = 1919200, upload-time = "2025-11-04T13:40:02.241Z" }, + { url = "https://files.pythonhosted.org/packages/38/de/8c36b5198a29bdaade07b5985e80a233a5ac27137846f3bc2d3b40a47360/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75", size = 2052578, upload-time = "2025-11-04T13:40:04.401Z" }, + { url = "https://files.pythonhosted.org/packages/00/b5/0e8e4b5b081eac6cb3dbb7e60a65907549a1ce035a724368c330112adfdd/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05", size = 2208504, upload-time = "2025-11-04T13:40:06.072Z" }, + { url = "https://files.pythonhosted.org/packages/77/56/87a61aad59c7c5b9dc8caad5a41a5545cba3810c3e828708b3d7404f6cef/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc", size = 2335816, upload-time = "2025-11-04T13:40:07.835Z" }, + { url = "https://files.pythonhosted.org/packages/0d/76/941cc9f73529988688a665a5c0ecff1112b3d95ab48f81db5f7606f522d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c", size = 2075366, upload-time = "2025-11-04T13:40:09.804Z" }, + { url = "https://files.pythonhosted.org/packages/d3/43/ebef01f69baa07a482844faaa0a591bad1ef129253ffd0cdaa9d8a7f72d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5", size = 2171698, upload-time = "2025-11-04T13:40:12.004Z" }, + { url = "https://files.pythonhosted.org/packages/b1/87/41f3202e4193e3bacfc2c065fab7706ebe81af46a83d3e27605029c1f5a6/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c", size = 2132603, upload-time = "2025-11-04T13:40:13.868Z" }, + { url = "https://files.pythonhosted.org/packages/49/7d/4c00df99cb12070b6bccdef4a195255e6020a550d572768d92cc54dba91a/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294", size = 2329591, upload-time = "2025-11-04T13:40:15.672Z" }, + { url = "https://files.pythonhosted.org/packages/cc/6a/ebf4b1d65d458f3cda6a7335d141305dfa19bdc61140a884d165a8a1bbc7/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1", size = 2319068, upload-time = "2025-11-04T13:40:17.532Z" }, + { url = "https://files.pythonhosted.org/packages/6e/0d/e3549b2399f71d56476b77dbf3cf8937cec5cd70536bdc0e374a421d0599/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c007fe8a43d43b3969e8469004e9845944f1a80e6acd47c150856bb87f230c56", size = 1942879, upload-time = "2025-11-04T13:42:56.483Z" }, + { url = 
"https://files.pythonhosted.org/packages/f7/07/34573da085946b6a313d7c42f82f16e8920bfd730665de2d11c0c37a74b5/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76d0819de158cd855d1cbb8fcafdf6f5cf1eb8e470abe056d5d161106e38062b", size = 2139017, upload-time = "2025-11-04T13:42:59.471Z" }, +] + +[[package]] +name = "pydantic-extra-types" +version = "2.11.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/d3/3be31542180c0300b6860129ff1e3a428f3ef580727616ce22462626129b/pydantic_extra_types-2.11.2.tar.gz", hash = "sha256:3a2b83b61fe920925688e7838b59caa90a45637d1dbba2b1364b8d1f7ff72a0a", size = 203929, upload-time = "2026-04-05T20:50:51.556Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/a4/7b6ab05c18d6c6e682382a0f0235301684452c4131a869f45961d1d032c9/pydantic_extra_types-2.11.2-py3-none-any.whl", hash = "sha256:683b8943252543e49760f89733b1519bc62f31d1a287ebbdc5a7b7959fb4acfd", size = 82851, upload-time = "2026-04-05T20:50:50.036Z" }, +] + +[package.optional-dependencies] +pycountry = [ + { name = "pycountry" }, +] + +[[package]] +name = "pydantic-settings" +version = "2.13.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/6d/fffca34caecc4a3f97bda81b2098da5e8ab7efc9a66e819074a11955d87e/pydantic_settings-2.13.1.tar.gz", hash = "sha256:b4c11847b15237fb0171e1462bf540e294affb9b86db4d9aa5c01730bdbe4025", size = 223826, upload-time = "2026-02-19T13:45:08.055Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" }, +] + +[[package]] +name = "pygments" +version = "2.20.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/b2/bc9c9196916376152d655522fdcebac55e66de6603a76a02bca1b6414f6c/pygments-2.20.0.tar.gz", hash = "sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f", size = 4955991, upload-time = "2026-03-29T13:29:33.898Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/7e/a72dd26f3b0f4f2bf1dd8923c85f7ceb43172af56d63c7383eb62b332364/pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176", size = 1231151, upload-time = "2026-03-29T13:29:30.038Z" }, +] + +[[package]] +name = "pyjwt" +version = "2.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c2/27/a3b6e5bf6ff856d2509292e95c8f57f0df7017cf5394921fc4e4ef40308a/pyjwt-2.12.1.tar.gz", hash = "sha256:c74a7a2adf861c04d002db713dd85f84beb242228e671280bf709d765b03672b", size = 102564, upload-time = "2026-03-13T19:27:37.25Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl", hash = "sha256:28ca37c070cad8ba8cd9790cd940535d40274d22f80ab87f3ac6a713e6e8454c", size = 29726, upload-time = "2026-03-13T19:27:35.677Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = 
"cryptography" }, +] + +[[package]] +name = "python-dotenv" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135, upload-time = "2026-03-01T16:00:26.196Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" }, +] + +[[package]] +name = "python-json-logger" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f7/ff/3cc9165fd44106973cd7ac9facb674a65ed853494592541d339bdc9a30eb/python_json_logger-4.1.0.tar.gz", hash = "sha256:b396b9e3ed782b09ff9d6e4f1683d46c83ad0d35d2e407c09a9ebbf038f88195", size = 17573, upload-time = "2026-03-29T04:39:56.805Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/be/0631a861af4d1c875f096c07d34e9a63639560a717130e7a87cbc82b7e3f/python_json_logger-4.1.0-py3-none-any.whl", hash = "sha256:132994765cf75bf44554be9aa49b06ef2345d23661a96720262716438141b6b2", size = 15021, upload-time = "2026-03-29T04:39:55.266Z" }, +] + +[[package]] +name = "python-multipart" +version = "0.0.24" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/45/e23b5dc14ddb9918ae4a625379506b17b6f8fc56ca1d82db62462f59aea6/python_multipart-0.0.24.tar.gz", hash = "sha256:9574c97e1c026e00bc30340ef7c7d76739512ab4dfd428fec8c330fa6a5cc3c8", size = 37695, upload-time = "2026-04-05T20:49:13.829Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/73/89930efabd4da63cea44a3f438aeb753d600123570e6d6264e763617a9ce/python_multipart-0.0.24-py3-none-any.whl", hash = "sha256:9b110a98db707df01a53c194f0af075e736a770dc5058089650d70b4a182f950", size = 24420, upload-time = "2026-04-05T20:49:12.555Z" }, +] + +[[package]] +name = "pyyaml" +version = "6.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" }, + { url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" }, + { url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = 
"2025-09-25T21:32:13.652Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", size = 844011, upload-time = "2025-09-25T21:32:15.21Z" }, + { url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" }, + { url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" }, + { url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" }, + { url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" }, + { url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" }, +] + +[[package]] +name = "pyzmq" +version = "27.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/0b/3c9baedbdf613ecaa7aa07027780b8867f57b6293b6ee50de316c9f3222b/pyzmq-27.1.0.tar.gz", hash = "sha256:ac0765e3d44455adb6ddbf4417dcce460fc40a05978c08efdf2948072f6db540", size = 281750, upload-time = "2025-09-08T23:10:18.157Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/5e/c3c49fdd0f535ef45eefcc16934648e9e59dace4a37ee88fc53f6cd8e641/pyzmq-27.1.0-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1c179799b118e554b66da67d88ed66cd37a169f1f23b5d9f0a231b4e8d44a113", size = 895645, upload-time = "2025-09-08T23:08:05.301Z" }, + { url = "https://files.pythonhosted.org/packages/f8/e5/b0b2504cb4e903a74dcf1ebae157f9e20ebb6ea76095f6cfffea28c42ecd/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3837439b7f99e60312f0c926a6ad437b067356dc2bc2ec96eb395fd0fe804233", size = 652574, upload-time = "2025-09-08T23:08:06.828Z" }, + { url = 
"https://files.pythonhosted.org/packages/f8/9b/c108cdb55560eaf253f0cbdb61b29971e9fb34d9c3499b0e96e4e60ed8a5/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43ad9a73e3da1fab5b0e7e13402f0b2fb934ae1c876c51d0afff0e7c052eca31", size = 840995, upload-time = "2025-09-08T23:08:08.396Z" }, + { url = "https://files.pythonhosted.org/packages/c2/bb/b79798ca177b9eb0825b4c9998c6af8cd2a7f15a6a1a4272c1d1a21d382f/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0de3028d69d4cdc475bfe47a6128eb38d8bc0e8f4d69646adfbcd840facbac28", size = 1642070, upload-time = "2025-09-08T23:08:09.989Z" }, + { url = "https://files.pythonhosted.org/packages/9c/80/2df2e7977c4ede24c79ae39dcef3899bfc5f34d1ca7a5b24f182c9b7a9ca/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:cf44a7763aea9298c0aa7dbf859f87ed7012de8bda0f3977b6fb1d96745df856", size = 2021121, upload-time = "2025-09-08T23:08:11.907Z" }, + { url = "https://files.pythonhosted.org/packages/46/bd/2d45ad24f5f5ae7e8d01525eb76786fa7557136555cac7d929880519e33a/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f30f395a9e6fbca195400ce833c731e7b64c3919aa481af4d88c3759e0cb7496", size = 1878550, upload-time = "2025-09-08T23:08:13.513Z" }, +] + +[[package]] +name = "quack-kernels" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apache-tvm-ffi" }, + { name = "nvidia-cutlass-dsl" }, + { name = "torch" }, + { name = "torch-c-dlpack-ext" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7b/db/d2e480fd71c38b88ffcbf40298d604400c64e0ffcaa06d6aa61a87b2673a/quack_kernels-0.3.9.tar.gz", hash = "sha256:4fd272f52142e408a591b94be7c6a0261e222e034e599bce6da827eeae8ad04d", size = 212760, upload-time = "2026-04-05T06:34:58.642Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/a8/eea5885361143c19505a8e86890a681c363ac0f9ac6ba02b5c2c82ebe44b/quack_kernels-0.3.9-py3-none-any.whl", hash = "sha256:160364a32fd72df6e934adb2bb2ae324843ddccffc88aaa6f5de4c9a00ec7ac8", size = 216038, upload-time = "2026-04-05T06:34:57.426Z" }, +] + +[[package]] +name = "referencing" +version = "0.37.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "rpds-py" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" }, +] + +[[package]] +name = "regex" +version = "2026.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cb/0e/3a246dbf05666918bd3664d9d787f84a9108f6f43cc953a077e4a7dfdb7e/regex-2026.4.4.tar.gz", hash = "sha256:e08270659717f6973523ce3afbafa53515c4dc5dcad637dc215b6fd50f689423", size = 416000, upload-time = "2026-04-03T20:56:28.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/28/b972a4d3df61e1d7bcf1b59fdb3cddef22f88b6be43f161bb41ebc0e4081/regex-2026.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = 
"sha256:c07ab8794fa929e58d97a0e1796b8b76f70943fa39df225ac9964615cf1f9d52", size = 490434, upload-time = "2026-04-03T20:53:40.219Z" }, + { url = "https://files.pythonhosted.org/packages/84/20/30041446cf6dc3e0eab344fc62770e84c23b6b68a3b657821f9f80cb69b4/regex-2026.4.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2c785939dc023a1ce4ec09599c032cc9933d258a998d16ca6f2b596c010940eb", size = 292061, upload-time = "2026-04-03T20:53:41.862Z" }, + { url = "https://files.pythonhosted.org/packages/62/c8/3baa06d75c98c46d4cc4262b71fd2edb9062b5665e868bca57859dadf93a/regex-2026.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b1ce5c81c9114f1ce2f9288a51a8fd3aeea33a0cc440c415bf02da323aa0a76", size = 289628, upload-time = "2026-04-03T20:53:43.701Z" }, + { url = "https://files.pythonhosted.org/packages/31/87/3accf55634caad8c0acab23f5135ef7d4a21c39f28c55c816ae012931408/regex-2026.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:760ef21c17d8e6a4fe8cf406a97cf2806a4df93416ccc82fc98d25b1c20425be", size = 796651, upload-time = "2026-04-03T20:53:45.379Z" }, + { url = "https://files.pythonhosted.org/packages/f6/0c/aaa2c83f34efedbf06f61cb1942c25f6cf1ee3b200f832c4d05f28306c2e/regex-2026.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7088fcdcb604a4417c208e2169715800d28838fefd7455fbe40416231d1d47c1", size = 865916, upload-time = "2026-04-03T20:53:47.064Z" }, + { url = "https://files.pythonhosted.org/packages/d9/f6/8c6924c865124643e8f37823eca845dc27ac509b2ee58123685e71cd0279/regex-2026.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:07edca1ba687998968f7db5bc355288d0c6505caa7374f013d27356d93976d13", size = 912287, upload-time = "2026-04-03T20:53:49.422Z" }, + { url = "https://files.pythonhosted.org/packages/11/0e/a9f6f81013e0deaf559b25711623864970fe6a098314e374ccb1540a4152/regex-2026.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:993f657a7c1c6ec51b5e0ba97c9817d06b84ea5fa8d82e43b9405de0defdc2b9", size = 801126, upload-time = "2026-04-03T20:53:51.096Z" }, + { url = "https://files.pythonhosted.org/packages/71/61/3a0cc8af2dc0c8deb48e644dd2521f173f7e6513c6e195aad9aa8dd77ac5/regex-2026.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:2b69102a743e7569ebee67e634a69c4cb7e59d6fa2e1aa7d3bdbf3f61435f62d", size = 776788, upload-time = "2026-04-03T20:53:52.889Z" }, + { url = "https://files.pythonhosted.org/packages/64/0b/8bb9cbf21ef7dee58e49b0fdb066a7aded146c823202e16494a36777594f/regex-2026.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dac006c8b6dda72d86ea3d1333d45147de79a3a3f26f10c1cf9287ca4ca0ac3", size = 785184, upload-time = "2026-04-03T20:53:55.627Z" }, + { url = "https://files.pythonhosted.org/packages/99/c2/d3e80e8137b25ee06c92627de4e4d98b94830e02b3e6f81f3d2e3f504cf5/regex-2026.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:50a766ee2010d504554bfb5f578ed2e066898aa26411d57e6296230627cdefa0", size = 859913, upload-time = "2026-04-03T20:53:57.249Z" }, + { url = "https://files.pythonhosted.org/packages/bc/e6/9d5d876157d969c804622456ef250017ac7a8f83e0e14f903b9e6df5ce95/regex-2026.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:9e2f5217648f68e3028c823df58663587c1507a5ba8419f4fdfc8a461be76043", size = 765732, upload-time = "2026-04-03T20:53:59.428Z" }, + { url = 
"https://files.pythonhosted.org/packages/82/80/b568935b4421388561c8ed42aff77247285d3ae3bb2a6ca22af63bae805e/regex-2026.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:39d8de85a08e32632974151ba59c6e9140646dcc36c80423962b1c5c0a92e244", size = 852152, upload-time = "2026-04-03T20:54:01.505Z" }, + { url = "https://files.pythonhosted.org/packages/39/29/f0f81217e21cd998245da047405366385d5c6072048038a3d33b37a79dc0/regex-2026.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:55d9304e0e7178dfb1e106c33edf834097ddf4a890e2f676f6c5118f84390f73", size = 789076, upload-time = "2026-04-03T20:54:03.323Z" }, + { url = "https://files.pythonhosted.org/packages/49/1d/1d957a61976ab9d4e767dd4f9d04b66cc0c41c5e36cf40e2d43688b5ae6f/regex-2026.4.4-cp312-cp312-win32.whl", hash = "sha256:04bb679bc0bde8a7bfb71e991493d47314e7b98380b083df2447cda4b6edb60f", size = 266700, upload-time = "2026-04-03T20:54:05.639Z" }, + { url = "https://files.pythonhosted.org/packages/c5/5c/bf575d396aeb58ea13b06ef2adf624f65b70fafef6950a80fc3da9cae3bc/regex-2026.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:db0ac18435a40a2543dbb3d21e161a6c78e33e8159bd2e009343d224bb03bb1b", size = 277768, upload-time = "2026-04-03T20:54:07.312Z" }, + { url = "https://files.pythonhosted.org/packages/c9/27/049df16ec6a6828ccd72add3c7f54b4df029669bea8e9817df6fff58be90/regex-2026.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:4ce255cc05c1947a12989c6db801c96461947adb7a59990f1360b5983fab4983", size = 270568, upload-time = "2026-04-03T20:54:09.484Z" }, +] + +[[package]] +name = "requests" +version = "2.33.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5f/a4/98b9c7c6428a668bf7e42ebb7c79d576a1c3c1e3ae2d47e674b468388871/requests-2.33.1.tar.gz", hash = "sha256:18817f8c57c6263968bc123d237e3b8b08ac046f5456bd1e307ee8f4250d3517", size = 134120, upload-time = "2026-03-30T16:09:15.531Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/8e/7540e8a2036f79a125c1d2ebadf69ed7901608859186c856fa0388ef4197/requests-2.33.1-py3-none-any.whl", hash = "sha256:4e6d1ef462f3626a1f0a0a9c42dd93c63bad33f9f1c1937509b8c5c8718ab56a", size = 64947, upload-time = "2026-03-30T16:09:13.83Z" }, +] + +[[package]] +name = "rich" +version = "14.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, +] + +[[package]] +name = "rich-toolkit" +version = "0.19.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/ba/dae9e3096651042754da419a4042bc1c75e07d615f9b15066d738838e4df/rich_toolkit-0.19.7.tar.gz", hash = 
"sha256:133c0915872da91d4c25d85342d5ec1dfacc69b63448af1a08a0d4b4f23ef46e", size = 195877, upload-time = "2026-02-24T16:06:20.555Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/3c/c923619f6d2f5fafcc96fec0aaf9550a46cd5b6481f06e0c6b66a2a4fed0/rich_toolkit-0.19.7-py3-none-any.whl", hash = "sha256:0288e9203728c47c5a4eb60fd2f0692d9df7455a65901ab6f898437a2ba5989d", size = 32963, upload-time = "2026-02-24T16:06:22.066Z" }, +] + +[[package]] +name = "rignore" +version = "0.7.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/f5/8bed2310abe4ae04b67a38374a4d311dd85220f5d8da56f47ae9361be0b0/rignore-0.7.6.tar.gz", hash = "sha256:00d3546cd793c30cb17921ce674d2c8f3a4b00501cb0e3dd0e82217dbeba2671", size = 57140, upload-time = "2025-11-05T21:41:21.968Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/c8/dea564b36dedac8de21c18e1851789545bc52a0c22ece9843444d5608a6a/rignore-0.7.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bda49950d405aa8d0ebe26af807c4e662dd281d926530f03f29690a2e07d649a", size = 897821, upload-time = "2025-11-05T20:40:52.613Z" }, + { url = "https://files.pythonhosted.org/packages/b3/2b/ee96db17ac1835e024c5d0742eefb7e46de60020385ac883dd3d1cde2c1f/rignore-0.7.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b5fd5ab3840b8c16851d327ed06e9b8be6459702a53e5ab1fc4073b684b3789e", size = 873963, upload-time = "2025-11-05T20:41:07.49Z" }, + { url = "https://files.pythonhosted.org/packages/a5/8c/ad5a57bbb9d14d5c7e5960f712a8a0b902472ea3f4a2138cbf70d1777b75/rignore-0.7.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ced2a248352636a5c77504cb755dc02c2eef9a820a44d3f33061ce1bb8a7f2d2", size = 1169216, upload-time = "2025-11-05T20:41:23.73Z" }, + { url = "https://files.pythonhosted.org/packages/80/e6/5b00bc2a6bc1701e6878fca798cf5d9125eb3113193e33078b6fc0d99123/rignore-0.7.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a04a3b73b75ddc12c9c9b21efcdaab33ca3832941d6f1d67bffd860941cd448a", size = 942942, upload-time = "2025-11-05T20:41:39.393Z" }, + { url = "https://files.pythonhosted.org/packages/85/e5/7f99bd0cc9818a91d0e8b9acc65b792e35750e3bdccd15a7ee75e64efca4/rignore-0.7.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d24321efac92140b7ec910ac7c53ab0f0c86a41133d2bb4b0e6a7c94967f44dd", size = 959787, upload-time = "2025-11-05T20:42:09.765Z" }, + { url = "https://files.pythonhosted.org/packages/55/54/2ffea79a7c1eabcede1926347ebc2a81bc6b81f447d05b52af9af14948b9/rignore-0.7.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:73c7aa109d41e593785c55fdaa89ad80b10330affa9f9d3e3a51fa695f739b20", size = 984245, upload-time = "2025-11-05T20:41:54.062Z" }, + { url = "https://files.pythonhosted.org/packages/41/f7/e80f55dfe0f35787fa482aa18689b9c8251e045076c35477deb0007b3277/rignore-0.7.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1734dc49d1e9501b07852ef44421f84d9f378da9fbeda729e77db71f49cac28b", size = 1078647, upload-time = "2025-11-05T21:40:13.463Z" }, + { url = "https://files.pythonhosted.org/packages/d4/cf/2c64f0b6725149f7c6e7e5a909d14354889b4beaadddaa5fff023ec71084/rignore-0.7.6-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5719ea14ea2b652c0c0894be5dfde954e1853a80dea27dd2fbaa749618d837f5", size = 1139186, upload-time = "2025-11-05T21:40:31.27Z" }, + { url = 
"https://files.pythonhosted.org/packages/75/95/a86c84909ccc24af0d094b50d54697951e576c252a4d9f21b47b52af9598/rignore-0.7.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8e23424fc7ce35726854f639cb7968151a792c0c3d9d082f7f67e0c362cfecca", size = 1117604, upload-time = "2025-11-05T21:40:48.07Z" }, + { url = "https://files.pythonhosted.org/packages/7f/5e/13b249613fd5d18d58662490ab910a9f0be758981d1797789913adb4e918/rignore-0.7.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3efdcf1dd84d45f3e2bd2f93303d9be103888f56dfa7c3349b5bf4f0657ec696", size = 1127725, upload-time = "2025-11-05T21:41:05.804Z" }, +] + +[[package]] +name = "rpds-py" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/af/3f2f423103f1113b36230496629986e0ef7e199d2aa8392452b484b38ced/rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84", size = 69469, upload-time = "2025-11-30T20:24:38.837Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/1c/ae157e83a6357eceff62ba7e52113e3ec4834a84cfe07fa4b0757a7d105f/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca28829ae5f5d569bb62a79512c842a03a12576375d5ece7d2cadf8abe96ec28", size = 390763, upload-time = "2025-11-30T20:22:21.661Z" }, + { url = "https://files.pythonhosted.org/packages/d4/36/eb2eb8515e2ad24c0bd43c3ee9cd74c33f7ca6430755ccdb240fd3144c44/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a1010ed9524c73b94d15919ca4d41d8780980e1765babf85f9a2f90d247153dd", size = 408951, upload-time = "2025-11-30T20:22:23.408Z" }, + { url = "https://files.pythonhosted.org/packages/d6/65/ad8dc1784a331fabbd740ef6f71ce2198c7ed0890dab595adb9ea2d775a1/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d1736cfb49381ba528cd5baa46f82fdc65c06e843dab24dd70b63d09121b3f", size = 514622, upload-time = "2025-11-30T20:22:25.16Z" }, + { url = "https://files.pythonhosted.org/packages/63/8e/0cfa7ae158e15e143fe03993b5bcd743a59f541f5952e1546b1ac1b5fd45/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d948b135c4693daff7bc2dcfc4ec57237a29bd37e60c2fabf5aff2bbacf3e2f1", size = 414492, upload-time = "2025-11-30T20:22:26.505Z" }, + { url = "https://files.pythonhosted.org/packages/60/1b/6f8f29f3f995c7ffdde46a626ddccd7c63aefc0efae881dc13b6e5d5bb16/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47f236970bccb2233267d89173d3ad2703cd36a0e2a6e92d0560d333871a3d23", size = 394080, upload-time = "2025-11-30T20:22:27.934Z" }, + { url = "https://files.pythonhosted.org/packages/6d/d5/a266341051a7a3ca2f4b750a3aa4abc986378431fc2da508c5034d081b70/rpds_py-0.30.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:2e6ecb5a5bcacf59c3f912155044479af1d0b6681280048b338b28e364aca1f6", size = 408680, upload-time = "2025-11-30T20:22:29.341Z" }, + { url = "https://files.pythonhosted.org/packages/10/3b/71b725851df9ab7a7a4e33cf36d241933da66040d195a84781f49c50490c/rpds_py-0.30.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a8fa71a2e078c527c3e9dc9fc5a98c9db40bcc8a92b4e8858e36d329f8684b51", size = 423589, upload-time = "2025-11-30T20:22:31.469Z" }, + { url = "https://files.pythonhosted.org/packages/00/2b/e59e58c544dc9bd8bd8384ecdb8ea91f6727f0e37a7131baeff8d6f51661/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:73c67f2db7bc334e518d097c6d1e6fed021bbc9b7d678d6cc433478365d1d5f5", size = 573289, upload-time = "2025-11-30T20:22:32.997Z" }, + { url = "https://files.pythonhosted.org/packages/da/3e/a18e6f5b460893172a7d6a680e86d3b6bc87a54c1f0b03446a3c8c7b588f/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5ba103fb455be00f3b1c2076c9d4264bfcb037c976167a6047ed82f23153f02e", size = 599737, upload-time = "2025-11-30T20:22:34.419Z" }, + { url = "https://files.pythonhosted.org/packages/5c/e2/714694e4b87b85a18e2c243614974413c60aa107fd815b8cbc42b873d1d7/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7cee9c752c0364588353e627da8a7e808a66873672bcb5f52890c33fd965b394", size = 563120, upload-time = "2025-11-30T20:22:35.903Z" }, +] + +[[package]] +name = "safetensors" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" }, + { url = "https://files.pythonhosted.org/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" }, + { url = "https://files.pythonhosted.org/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" }, + { url = "https://files.pythonhosted.org/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" }, + { url = "https://files.pythonhosted.org/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" }, + { url = "https://files.pythonhosted.org/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" }, + { url = "https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" }, + { url = 
"https://files.pythonhosted.org/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" }, + { url = "https://files.pythonhosted.org/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" }, + { url = "https://files.pythonhosted.org/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" }, + { url = "https://files.pythonhosted.org/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = "2025-11-19T15:18:40.162Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" }, + { url = "https://files.pythonhosted.org/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" }, + { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" }, +] + +[[package]] +name = "sentencepiece" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/15/2e7a025fc62d764b151ae6d0f2a92f8081755ebe8d4a64099accc6f77ba6/sentencepiece-0.2.1.tar.gz", hash = "sha256:8138cec27c2f2282f4a34d9a016e3374cd40e5c6e9cb335063db66a0a3b71fad", size = 3228515, upload-time = "2025-08-12T07:00:51.718Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/fa/d3d5ebcba3cb9e6d3775a096251860c41a6bc53a1b9461151df83fe93255/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99f955df238021bf11f0fc37cdb54fd5e5b5f7fd30ecc3d93fb48b6815437167", size = 1316273, upload-time = "2025-08-12T06:59:44.476Z" }, + { url = "https://files.pythonhosted.org/packages/04/88/14f2f4a2b922d8b39be45bf63d79e6cd3a9b2f248b2fcb98a69b12af12f5/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cdfecef430d985f1c2bcbfff3defd1d95dae876fbd0173376012d2d7d24044b", size = 1387881, upload-time = "2025-08-12T06:59:46.09Z" }, +] + +[[package]] +name = "sentry-sdk" +version = "2.57.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/4f/87/46c0406d8b5ddd026f73adaf5ab75ce144219c41a4830b52df4b9ab55f7f/sentry_sdk-2.57.0.tar.gz", hash = "sha256:4be8d1e71c32fb27f79c577a337ac8912137bba4bcbc64a4ec1da4d6d8dc5199", size = 435288, upload-time = "2026-03-31T09:39:29.264Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/64/982e07b93219cb52e1cca5d272cb579e2f3eb001956c9e7a9a6d106c9473/sentry_sdk-2.57.0-py2.py3-none-any.whl", hash = "sha256:812c8bf5ff3d2f0e89c82f5ce80ab3a6423e102729c4706af7413fd1eb480585", size = 456489, upload-time = "2026-03-31T09:39:27.524Z" }, +] + +[[package]] +name = "setproctitle" +version = "1.3.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8d/48/49393a96a2eef1ab418b17475fb92b8fcfad83d099e678751b05472e69de/setproctitle-1.3.7.tar.gz", hash = "sha256:bc2bc917691c1537d5b9bca1468437176809c7e11e5694ca79a9ca12345dcb9e", size = 27002, upload-time = "2025-09-05T12:51:25.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/99/71630546b9395b095f4082be41165d1078204d1696c2d9baade3de3202d0/setproctitle-1.3.7-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2906b6c7959cdb75f46159bf0acd8cc9906cf1361c9e1ded0d065fe8f9039629", size = 32932, upload-time = "2025-09-05T12:49:39.271Z" }, + { url = "https://files.pythonhosted.org/packages/50/22/cee06af4ffcfb0e8aba047bd44f5262e644199ae7527ae2c1f672b86495c/setproctitle-1.3.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6915964a6dda07920a1159321dcd6d94fc7fc526f815ca08a8063aeca3c204f1", size = 33736, upload-time = "2025-09-05T12:49:40.565Z" }, + { url = "https://files.pythonhosted.org/packages/5c/00/a5949a8bb06ef5e7df214fc393bb2fb6aedf0479b17214e57750dfdd0f24/setproctitle-1.3.7-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:cff72899861c765bd4021d1ff1c68d60edc129711a2fdba77f9cb69ef726a8b6", size = 35605, upload-time = "2025-09-05T12:49:42.362Z" }, + { url = "https://files.pythonhosted.org/packages/b0/3a/50caca532a9343828e3bf5778c7a84d6c737a249b1796d50dd680290594d/setproctitle-1.3.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b7cb05bd446687ff816a3aaaf831047fc4c364feff7ada94a66024f1367b448c", size = 33143, upload-time = "2025-09-05T12:49:43.515Z" }, + { url = "https://files.pythonhosted.org/packages/ca/14/b843a251296ce55e2e17c017d6b9f11ce0d3d070e9265de4ecad948b913d/setproctitle-1.3.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:3a57b9a00de8cae7e2a1f7b9f0c2ac7b69372159e16a7708aa2f38f9e5cc987a", size = 34434, upload-time = "2025-09-05T12:49:45.31Z" }, + { url = "https://files.pythonhosted.org/packages/c8/b7/06145c238c0a6d2c4bc881f8be230bb9f36d2bf51aff7bddcb796d5eed67/setproctitle-1.3.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d8828b356114f6b308b04afe398ed93803d7fca4a955dd3abe84430e28d33739", size = 32795, upload-time = "2025-09-05T12:49:46.419Z" }, +] + +[[package]] +name = "setuptools" +version = "80.10.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/76/95/faf61eb8363f26aa7e1d762267a8d602a1b26d4f3a1e758e92cb3cb8b054/setuptools-80.10.2.tar.gz", hash = "sha256:8b0e9d10c784bf7d262c4e5ec5d4ec94127ce206e8738f29a437945fbc219b70", size = 1200343, upload-time = "2026-01-25T22:38:17.252Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/94/b8/f1f62a5e3c0ad2ff1d189590bfa4c46b4f3b6e49cef6f26c6ee4e575394d/setuptools-80.10.2-py3-none-any.whl", hash = "sha256:95b30ddfb717250edb492926c92b5221f7ef3fbcc2b07579bcd4a27da21d0173", size = 1064234, upload-time = "2026-01-25T22:38:15.216Z" }, +] + +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, +] + +[[package]] +name = "sse-starlette" +version = "3.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/26/8c/f9290339ef6d79badbc010f067cd769d6601ec11a57d78569c683fb4dd87/sse_starlette-3.3.4.tar.gz", hash = "sha256:aaf92fc067af8a5427192895ac028e947b484ac01edbc3caf00e7e7137c7bef1", size = 32427, upload-time = "2026-03-29T09:00:23.307Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/7f/3de5402f39890ac5660b86bcf5c03f9d855dad5c4ed764866d7b592b46fd/sse_starlette-3.3.4-py3-none-any.whl", hash = "sha256:84bb06e58939a8b38d8341f1bc9792f06c2b53f48c608dd207582b664fc8f3c1", size = 14330, upload-time = "2026-03-29T09:00:21.846Z" }, +] + +[[package]] +name = "starlette" +version = "0.52.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c4/68/79977123bb7be889ad680d79a40f339082c1978b5cfcf62c2d8d196873ac/starlette-0.52.1.tar.gz", hash = 
"sha256:834edd1b0a23167694292e94f597773bc3f89f362be6effee198165a35d62933", size = 2653702, upload-time = "2026-01-18T13:34:11.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, +] + +[[package]] +name = "supervisor" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/b5/37e7a3706de436a8a2d75334711dad1afb4ddffab09f25e31d89e467542f/supervisor-4.3.0.tar.gz", hash = "sha256:4a2bf149adf42997e1bb44b70c43b613275ec9852c3edacca86a9166b27e945e", size = 468912, upload-time = "2025-08-23T18:25:02.418Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/65/5e726c372da8a5e35022a94388b12252710aad0c2351699c3d76ae8dba78/supervisor-4.3.0-py2.py3-none-any.whl", hash = "sha256:0bcb763fddafba410f35cbde226aa7f8514b9fb82eb05a0c85f6588d1c13f8db", size = 320736, upload-time = "2025-08-23T18:25:00.767Z" }, +] + +[[package]] +name = "sympy" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, +] + +[[package]] +name = "tabulate" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/46/58/8c37dea7bbf769b20d58e7ace7e5edfe65b849442b00ffcdd56be88697c6/tabulate-0.10.0.tar.gz", hash = "sha256:e2cfde8f79420f6deeffdeda9aaec3b6bc5abce947655d17ac662b126e48a60d", size = 91754, upload-time = "2026-03-04T18:55:34.402Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl", hash = "sha256:f0b0622e567335c8fabaaa659f1b33bcb6ddfe2e496071b743aa113f8774f2d3", size = 39814, upload-time = "2026-03-04T18:55:31.284Z" }, +] + +[[package]] +name = "tiktoken" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "regex" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/c5/ed88504d2f4a5fd6856990b230b56d85a777feab84e6129af0822f5d0f70/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37", size = 1129008, upload-time = "2025-10-06T20:21:54.832Z" }, + { url = 
"https://files.pythonhosted.org/packages/f4/90/3dae6cc5436137ebd38944d396b5849e167896fc2073da643a49f372dc4f/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad", size = 1152665, upload-time = "2025-10-06T20:21:56.129Z" }, + { url = "https://files.pythonhosted.org/packages/a3/fe/26df24ce53ffde419a42f5f53d755b995c9318908288c17ec3f3448313a3/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5", size = 1194230, upload-time = "2025-10-06T20:21:57.546Z" }, + { url = "https://files.pythonhosted.org/packages/20/cc/b064cae1a0e9fac84b0d2c46b89f4e57051a5f41324e385d10225a984c24/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3", size = 1254688, upload-time = "2025-10-06T20:21:58.619Z" }, +] + +[[package]] +name = "tilelang" +version = "0.1.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apache-tvm-ffi" }, + { name = "cloudpickle" }, + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "psutil" }, + { name = "setuptools", marker = "sys_platform == 'darwin'" }, + { name = "torch" }, + { name = "torch-c-dlpack-ext" }, + { name = "tqdm" }, + { name = "typing-extensions" }, + { name = "z3-solver" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/56/70/5051f65821baa30a3d61fc48f8ba10c776490315e8c90f82559b92089756/tilelang-0.1.9.tar.gz", hash = "sha256:287f727c913bb648fcf6c1968809ba3390e55eeed257a5c6bb9a80bc05966af4", size = 93395292, upload-time = "2026-04-22T09:19:11.988Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/8a/1cbeee79d62abaa02441c2d00621554e41aa62dbf3b94a4feb3867184b01/tilelang-0.1.9-cp38-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4bbccfe9035aed775ffafb6dc25a5994504b24e2c5d95d0f39643edfafa7bf12", size = 45419374, upload-time = "2026-04-22T09:15:56.014Z" }, + { url = "https://files.pythonhosted.org/packages/c6/a7/f4bfb86f87e107703146e703204cec2c0eae2492b633e0052b0ace3febb6/tilelang-0.1.9-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:77ab0ee2f40f66ea015b6b21426d482751e28cbc635ef9d1198cbd6502454a7c", size = 42110365, upload-time = "2026-04-22T09:17:18.292Z" }, +] + +[[package]] +name = "tokenizers" +version = "0.22.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/73/6f/f80cfef4a312e1fb34baf7d85c72d4411afde10978d4657f8cdd811d3ccc/tokenizers-0.22.2.tar.gz", hash = "sha256:473b83b915e547aa366d1eee11806deaf419e17be16310ac0a14077f1e28f917", size = 372115, upload-time = "2026-01-05T10:45:15.988Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/97/5dbfabf04c7e348e655e907ed27913e03db0923abb5dfdd120d7b25630e1/tokenizers-0.22.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:544dd704ae7238755d790de45ba8da072e9af3eea688f698b137915ae959281c", size = 3100275, upload-time = "2026-01-05T10:41:02.158Z" }, + { url = "https://files.pythonhosted.org/packages/2e/47/174dca0502ef88b28f1c9e06b73ce33500eedfac7a7692108aec220464e7/tokenizers-0.22.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:1e418a55456beedca4621dbab65a318981467a2b188e982a23e117f115ce5001", size = 2981472, upload-time = "2026-01-05T10:41:00.276Z" }, + { url = 
"https://files.pythonhosted.org/packages/d6/84/7990e799f1309a8b87af6b948f31edaa12a3ed22d11b352eaf4f4b2e5753/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2249487018adec45d6e3554c71d46eb39fa8ea67156c640f7513eb26f318cec7", size = 3290736, upload-time = "2026-01-05T10:40:32.165Z" }, + { url = "https://files.pythonhosted.org/packages/78/59/09d0d9ba94dcd5f4f1368d4858d24546b4bdc0231c2354aa31d6199f0399/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25b85325d0815e86e0bac263506dd114578953b7b53d7de09a6485e4a160a7dd", size = 3168835, upload-time = "2026-01-05T10:40:38.847Z" }, + { url = "https://files.pythonhosted.org/packages/47/50/b3ebb4243e7160bda8d34b731e54dd8ab8b133e50775872e7a434e524c28/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bfb88f22a209ff7b40a576d5324bf8286b519d7358663db21d6246fb17eea2d5", size = 3521673, upload-time = "2026-01-05T10:40:56.614Z" }, + { url = "https://files.pythonhosted.org/packages/e0/fa/89f4cb9e08df770b57adb96f8cbb7e22695a4cb6c2bd5f0c4f0ebcf33b66/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c774b1276f71e1ef716e5486f21e76333464f47bece56bbd554485982a9e03e", size = 3724818, upload-time = "2026-01-05T10:40:44.507Z" }, + { url = "https://files.pythonhosted.org/packages/64/04/ca2363f0bfbe3b3d36e95bf67e56a4c88c8e3362b658e616d1ac185d47f2/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df6c4265b289083bf710dff49bc51ef252f9d5be33a45ee2bed151114a56207b", size = 3379195, upload-time = "2026-01-05T10:40:51.139Z" }, + { url = "https://files.pythonhosted.org/packages/2e/76/932be4b50ef6ccedf9d3c6639b056a967a86258c6d9200643f01269211ca/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:369cc9fc8cc10cb24143873a0d95438bb8ee257bb80c71989e3ee290e8d72c67", size = 3274982, upload-time = "2026-01-05T10:40:58.331Z" }, + { url = "https://files.pythonhosted.org/packages/1d/28/5f9f5a4cc211b69e89420980e483831bcc29dade307955cc9dc858a40f01/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:29c30b83d8dcd061078b05ae0cb94d3c710555fbb44861139f9f83dcca3dc3e4", size = 9478245, upload-time = "2026-01-05T10:41:04.053Z" }, + { url = "https://files.pythonhosted.org/packages/6c/fb/66e2da4704d6aadebf8cb39f1d6d1957df667ab24cff2326b77cda0dcb85/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:37ae80a28c1d3265bb1f22464c856bd23c02a05bb211e56d0c5301a435be6c1a", size = 9560069, upload-time = "2026-01-05T10:45:10.673Z" }, + { url = "https://files.pythonhosted.org/packages/16/04/fed398b05caa87ce9b1a1bb5166645e38196081b225059a6edaff6440fac/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:791135ee325f2336f498590eb2f11dc5c295232f288e75c99a36c5dbce63088a", size = 9899263, upload-time = "2026-01-05T10:45:12.559Z" }, + { url = "https://files.pythonhosted.org/packages/05/a1/d62dfe7376beaaf1394917e0f8e93ee5f67fea8fcf4107501db35996586b/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38337540fbbddff8e999d59970f3c6f35a82de10053206a7562f1ea02d046fa5", size = 10033429, upload-time = "2026-01-05T10:45:14.333Z" }, + { url = "https://files.pythonhosted.org/packages/fd/18/a545c4ea42af3df6effd7d13d250ba77a0a86fb20393143bbb9a92e434d4/tokenizers-0.22.2-cp39-abi3-win32.whl", hash = "sha256:a6bf3f88c554a2b653af81f3204491c818ae2ac6fbc09e76ef4773351292bc92", size = 2502363, upload-time = 
"2026-01-05T10:45:20.593Z" }, + { url = "https://files.pythonhosted.org/packages/65/71/0670843133a43d43070abeb1949abfdef12a86d490bea9cd9e18e37c5ff7/tokenizers-0.22.2-cp39-abi3-win_amd64.whl", hash = "sha256:c9ea31edff2968b44a88f97d784c2f16dc0729b8b143ed004699ebca91f05c48", size = 2747786, upload-time = "2026-01-05T10:45:18.411Z" }, + { url = "https://files.pythonhosted.org/packages/72/f4/0de46cfa12cdcbcd464cc59fde36912af405696f687e53a091fb432f694c/tokenizers-0.22.2-cp39-abi3-win_arm64.whl", hash = "sha256:9ce725d22864a1e965217204946f830c37876eee3b2ba6fc6255e8e903d5fcbc", size = 2612133, upload-time = "2026-01-05T10:45:17.232Z" }, +] + +[[package]] +name = "torch" +version = "2.11.0+cu128" +source = { url = "https://download.pytorch.org/whl/test/cu128/torch-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" } +dependencies = [ + { name = "cuda-bindings", marker = "sys_platform == 'linux'" }, + { name = "cuda-toolkit", extra = ["cublas", "cudart", "cufft", "cufile", "cupti", "curand", "cusolver", "cusparse", "nvjitlink", "nvrtc", "nvtx"], marker = "sys_platform == 'linux'" }, + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx" }, + { name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'" }, + { name = "setuptools" }, + { name = "sympy" }, + { name = "triton", marker = "sys_platform == 'linux'" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/test/cu128/torch-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:d252cf975fb18c94a85336323ad425f473df56dab35a44b00399bd70c7a3b997" }, +] + +[package.metadata] +requires-dist = [ + { name = "cuda-bindings", marker = "sys_platform == 'linux'", specifier = ">=12.9.4,<13" }, + { name = "cuda-toolkit", extras = ["cublas", "cudart", "cufft", "cufile", "cupti", "curand", "cusolver", "cusparse", "nvjitlink", "nvrtc", "nvtx"], marker = "sys_platform == 'linux'", specifier = "==12.8.1" }, + { name = "filelock" }, + { name = "fsspec", specifier = ">=0.8.5" }, + { name = "jinja2" }, + { name = "networkx", specifier = ">=2.5.1" }, + { name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux'", specifier = "==9.19.0.56" }, + { name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux'", specifier = "==0.7.1" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'", specifier = "==2.28.9" }, + { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'", specifier = "==3.4.5" }, + { name = "opt-einsum", marker = "extra == 'opt-einsum'", specifier = ">=3.3" }, + { name = "optree", marker = "extra == 'optree'", specifier = ">=0.13.0" }, + { name = "pyyaml", marker = "extra == 'pyyaml'" }, + { name = "setuptools", specifier = "<82" }, + { name = "sympy", specifier = ">=1.13.3" }, + { name = "triton", marker = "sys_platform == 'linux'", specifier = "==3.6.0" }, + { name = "typing-extensions", specifier = ">=4.10.0" }, +] +provides-extras = ["optree", "opt-einsum", "pyyaml"] + +[[package]] +name = "torch-c-dlpack-ext" +version = "0.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/37/de/921b6491efce5c389a5ef9bbed3d2d6660005840dae488124173180859ab/torch_c_dlpack_ext-0.1.5.tar.gz", hash = 
"sha256:d06f0357d575d22a168cc77acb9020fc4bae30968ceb6718a055dcbe92bacabe", size = 12913, upload-time = "2026-01-12T11:25:08.484Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/06/8d760997307a5c3be4384424667bf31aae0a42060838c532c7d846516175/torch_c_dlpack_ext-0.1.5-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3562ee411258676f9c38b8ad39306d1c8d027b6a86f6a87c920d2d009a9d1510", size = 443069, upload-time = "2026-01-12T11:24:45.451Z" }, + { url = "https://files.pythonhosted.org/packages/e2/79/a914539b4785f3e44f891aa012a886edb8bc10fe081c440981c57543ce21/torch_c_dlpack_ext-0.1.5-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e6f9da4bb9af70e27facc777458be62e10dbbbddda7672d16138db0553c5a524", size = 897846, upload-time = "2026-01-12T11:24:48.168Z" }, +] + +[[package]] +name = "torchaudio" +version = "2.11.0+cu128" +source = { url = "https://download.pytorch.org/whl/test/cu128/torchaudio-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" } +wheels = [ + { url = "https://download.pytorch.org/whl/test/cu128/torchaudio-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:78b86a17f164bdaabdcee93fdfde2587fc43b9ebf15cd61dcf730b4f8615176b" }, +] + +[[package]] +name = "torchvision" +version = "0.26.0+cu128" +source = { url = "https://download.pytorch.org/whl/test/cu128/torchvision-0.26.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" } +dependencies = [ + { name = "numpy" }, + { name = "pillow" }, + { name = "torch" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/test/cu128/torchvision-0.26.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ccf26b4b659cfce6f2208cb8326071d51c70219a34856dfdf468d1e19af52c0d" }, +] + +[package.metadata] +requires-dist = [ + { name = "gdown", marker = "extra == 'gdown'", specifier = ">=4.7.3" }, + { name = "numpy" }, + { name = "pillow", specifier = ">=5.3.0,!=8.3.*" }, + { name = "scipy", marker = "extra == 'scipy'" }, + { name = "torch", specifier = "==2.11.0" }, +] +provides-extras = ["gdown", "scipy"] + +[[package]] +name = "tqdm" +version = "4.67.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/09/a9/6ba95a270c6f1fbcd8dac228323f2777d886cb206987444e4bce66338dd4/tqdm-4.67.3.tar.gz", hash = "sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb", size = 169598, upload-time = "2026-02-03T17:35:53.048Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/e1/3079a9ff9b8e11b846c6ac5c8b5bfb7ff225eee721825310c91b3b50304f/tqdm-4.67.3-py3-none-any.whl", hash = "sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf", size = 78374, upload-time = "2026-02-03T17:35:50.982Z" }, +] + +[[package]] +name = "transformers" +version = "5.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "regex" }, + { name = "safetensors" }, + { name = "tokenizers" }, + { name = "tqdm" }, + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a4/e9/c6c80a07690142a7d05444271f47b9f3c8aac7dea01d52e1137ee480ad78/transformers-5.6.2.tar.gz", hash = "sha256:e657134c3e5a6bc00a3c35f4e2674bb51adfcd89898495b788a18552bac2b91a", size = 8311867, upload-time = "2026-04-23T18:33:29.332Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/5d/95/0b0218149b0d6f14df35f5b8f676fa83df4f19ed253c3cc447107ef86eca/transformers-5.6.2-py3-none-any.whl", hash = "sha256:f8d3a1bb96778fed9b8aabfd0dd6e19843e4b0f2bb6b59f32b8a92051b0f348f", size = 10364898, upload-time = "2026-04-23T18:33:26.081Z" }, +] + +[[package]] +name = "triton" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/5d/08201db32823bdf77a0e2b9039540080b2e5c23a20706ddba942924ebcd6/triton-3.6.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:374f52c11a711fd062b4bfbb201fd9ac0a5febd28a96fb41b4a0f51dde3157f4", size = 176128243, upload-time = "2026-01-20T16:16:07.857Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, +] + +[[package]] +name = "typer" +version = "0.24.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/e3/70399cb7dd41c10ac53367ae42139cf4b1ca5f36bb3dc6c9d33acdb43655/typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464", size = 75949, upload-time = "2025-10-01T02:14:41.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, +] + +[[package]] +name = "urllib3" +version = "2.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, +] + +[[package]] +name = "uvicorn" +version = "0.44.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/da/6eee1ff8b6cbeed47eeb5229749168e81eb4b7b999a1a15a7176e51410c9/uvicorn-0.44.0.tar.gz", hash = "sha256:6c942071b68f07e178264b9152f1f16dfac5da85880c4ce06366a96d70d4f31e", size = 86947, upload-time = "2026-04-06T09:23:22.826Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/23/a5bbd9600dd607411fa644c06ff4951bec3a4d82c4b852374024359c19c0/uvicorn-0.44.0-py3-none-any.whl", hash = "sha256:ce937c99a2cc70279556967274414c087888e8cec9f9c94644dfca11bd3ced89", size = 69425, upload-time = "2026-04-06T09:23:21.524Z" }, +] + +[package.optional-dependencies] +standard = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "httptools" }, + { name = "python-dotenv" }, + { name = "pyyaml" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'" }, + { name = "watchfiles" }, + { name = "websockets" }, +] + +[[package]] +name = "uvloop" +version = "0.22.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/06/f0/18d39dbd1971d6d62c4629cc7fa67f74821b0dc1f5a77af43719de7936a7/uvloop-0.22.1.tar.gz", hash = "sha256:6c84bae345b9147082b17371e3dd5d42775bddce91f885499017f4607fdaf39f", size = 2443250, upload-time = "2025-10-16T22:17:19.342Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/68/a6ac446820273e71aa762fa21cdcc09861edd3536ff47c5cd3b7afb10eeb/uvloop-0.22.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:700e674a166ca5778255e0e1dc4e9d79ab2acc57b9171b79e65feba7184b3370", size = 4317413, upload-time = "2025-10-16T22:16:31.644Z" }, + { url = "https://files.pythonhosted.org/packages/5f/6f/e62b4dfc7ad6518e7eff2516f680d02a0f6eb62c0c212e152ca708a0085e/uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b5b1ac819a3f946d3b2ee07f09149578ae76066d70b44df3fa990add49a82e4", size = 4426307, upload-time = "2025-10-16T22:16:32.917Z" }, + { url = "https://files.pythonhosted.org/packages/90/60/97362554ac21e20e81bcef1150cb2a7e4ffdaf8ea1e5b2e8bf7a053caa18/uvloop-0.22.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e047cc068570bac9866237739607d1313b9253c3051ad84738cbb095be0537b2", size = 4131970, upload-time = "2025-10-16T22:16:34.015Z" }, + { url = "https://files.pythonhosted.org/packages/99/39/6b3f7d234ba3964c428a6e40006340f53ba37993f46ed6e111c6e9141d18/uvloop-0.22.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:512fec6815e2dd45161054592441ef76c830eddaad55c8aa30952e6fe1ed07c0", size = 4296343, upload-time = "2025-10-16T22:16:35.149Z" }, +] + +[[package]] +name = "vllm" +version = "0.20.2rc1.dev168+gecd0b60aa.cu129" +source 
= { url = "https://wheels.vllm.ai/ecd0b60aad2f4e28dd00ababfc1402690d88cbed/vllm-0.20.2rc1.dev168%2Bgecd0b60aa.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" } +dependencies = [ + { name = "aiohttp" }, + { name = "anthropic" }, + { name = "apache-tvm-ffi" }, + { name = "blake3" }, + { name = "cachetools" }, + { name = "cbor2" }, + { name = "cloudpickle" }, + { name = "compressed-tensors" }, + { name = "depyf" }, + { name = "diskcache" }, + { name = "einops" }, + { name = "fastapi", extra = ["standard"] }, + { name = "fastsafetensors" }, + { name = "filelock" }, + { name = "flashinfer-cubin" }, + { name = "flashinfer-python" }, + { name = "gguf" }, + { name = "ijson" }, + { name = "lark" }, + { name = "llguidance", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 'x86_64'" }, + { name = "lm-format-enforcer" }, + { name = "mcp" }, + { name = "mistral-common", extra = ["image"] }, + { name = "model-hosting-container-standards" }, + { name = "msgspec" }, + { name = "ninja" }, + { name = "numba" }, + { name = "numpy" }, + { name = "nvidia-cudnn-frontend" }, + { name = "nvidia-cutlass-dsl" }, + { name = "openai" }, + { name = "openai-harmony" }, + { name = "opencv-python-headless" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp" }, + { name = "opentelemetry-sdk" }, + { name = "opentelemetry-semantic-conventions-ai" }, + { name = "outlines-core" }, + { name = "partial-json-parser" }, + { name = "pillow" }, + { name = "prometheus-client" }, + { name = "prometheus-fastapi-instrumentator" }, + { name = "protobuf" }, + { name = "psutil" }, + { name = "py-cpuinfo" }, + { name = "pybase64" }, + { name = "pydantic" }, + { name = "python-json-logger" }, + { name = "pyyaml" }, + { name = "pyzmq" }, + { name = "quack-kernels" }, + { name = "regex" }, + { name = "requests" }, + { name = "sentencepiece" }, + { name = "setproctitle" }, + { name = "setuptools" }, + { name = "six" }, + { name = "tiktoken" }, + { name = "tilelang" }, + { name = "tokenizers" }, + { name = "torch" }, + { name = "torchaudio" }, + { name = "torchvision" }, + { name = "tqdm" }, + { name = "transformers" }, + { name = "typing-extensions" }, + { name = "watchfiles" }, + { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'" }, +] +wheels = [ + { url = "https://wheels.vllm.ai/ecd0b60aad2f4e28dd00ababfc1402690d88cbed/vllm-0.20.2rc1.dev168%2Bgecd0b60aa.cu129-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:ffc821955e01472615540047d585a5264b6cdc64b21b9273bbb9db18ee0c539d" }, +] + +[package.metadata] +requires-dist = [ + { name = "aiohttp", specifier = ">=3.13.3" }, + { name = "anthropic", specifier = ">=0.71.0" }, + { name = "apache-tvm-ffi", specifier = "==0.1.9" }, + { name = "av", marker = "extra == 'audio'" }, + { name = "blake3" }, + { name = "cachetools" }, + { name = "cbor2" }, + { name = "cloudpickle" }, + { name = "compressed-tensors", specifier = "==0.15.0.1" }, + { name = "datasets", marker = "extra == 'bench'" }, + { name = "depyf", specifier = "==0.20.0" }, + { name = "diskcache", specifier = "==5.6.3" }, + { name = "einops" }, + { name = "fastapi", extras = ["standard"], specifier = ">=0.115.0" }, + { name = "fastsafetensors", specifier = ">=0.2.2" }, + { name = "fastsafetensors", marker = "extra == 'fastsafetensors'", specifier = ">=0.2.2" }, + { name = "filelock", specifier = ">=3.16.1" 
}, + { name = "flashinfer-cubin", specifier = "==0.6.8.post1" }, + { name = "flashinfer-python", specifier = "==0.6.8.post1" }, + { name = "gguf", specifier = ">=0.17.0" }, + { name = "helion", marker = "extra == 'helion'", specifier = "==1.0.0" }, + { name = "ijson" }, + { name = "instanttensor", marker = "extra == 'instanttensor'", specifier = ">=0.1.5" }, + { name = "lark", specifier = "==1.2.2" }, + { name = "llguidance", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 'x86_64'", specifier = ">=1.3.0,<1.4.0" }, + { name = "lm-format-enforcer", specifier = "==0.11.3" }, + { name = "matplotlib", marker = "extra == 'bench'" }, + { name = "mcp" }, + { name = "mistral-common", extras = ["audio"], marker = "extra == 'audio'" }, + { name = "mistral-common", extras = ["image"], specifier = ">=1.11.2" }, + { name = "model-hosting-container-standards", specifier = ">=0.1.14,<1.0.0" }, + { name = "msgspec" }, + { name = "ninja" }, + { name = "numba", specifier = "==0.65.0" }, + { name = "numpy" }, + { name = "nvidia-cudnn-frontend", specifier = ">=1.13.0,<1.19.0" }, + { name = "nvidia-cutlass-dsl", specifier = ">=4.4.2" }, + { name = "openai", specifier = ">=2.0.0" }, + { name = "openai-harmony", specifier = ">=0.0.3" }, + { name = "opencv-python-headless", specifier = ">=4.13.0" }, + { name = "opentelemetry-api", specifier = ">=1.27.0" }, + { name = "opentelemetry-api", marker = "extra == 'otel'", specifier = ">=1.26.0" }, + { name = "opentelemetry-exporter-otlp", specifier = ">=1.27.0" }, + { name = "opentelemetry-exporter-otlp", marker = "extra == 'otel'", specifier = ">=1.26.0" }, + { name = "opentelemetry-sdk", specifier = ">=1.27.0" }, + { name = "opentelemetry-sdk", marker = "extra == 'otel'", specifier = ">=1.26.0" }, + { name = "opentelemetry-semantic-conventions-ai", specifier = ">=0.4.1" }, + { name = "opentelemetry-semantic-conventions-ai", marker = "extra == 'otel'", specifier = ">=0.4.1" }, + { name = "outlines-core", specifier = "==0.2.14" }, + { name = "pandas", marker = "extra == 'bench'" }, + { name = "partial-json-parser" }, + { name = "pillow" }, + { name = "plotly", marker = "extra == 'bench'" }, + { name = "prometheus-client", specifier = ">=0.18.0" }, + { name = "prometheus-fastapi-instrumentator", specifier = ">=7.0.0" }, + { name = "protobuf", specifier = ">=5.29.6,!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*" }, + { name = "psutil" }, + { name = "py-cpuinfo" }, + { name = "pybase64" }, + { name = "pydantic", specifier = ">=2.12.0" }, + { name = "python-json-logger" }, + { name = "pyyaml" }, + { name = "pyzmq", specifier = ">=25.0.0" }, + { name = "quack-kernels", specifier = ">=0.3.3" }, + { name = "regex" }, + { name = "requests", specifier = ">=2.26.0" }, + { name = "runai-model-streamer", extras = ["azure", "gcs", "s3"], marker = "extra == 'runai'", specifier = ">=0.15.7" }, + { name = "scipy", marker = "extra == 'audio'" }, + { name = "scipy", marker = "extra == 'bench'" }, + { name = "seaborn", marker = "extra == 'bench'" }, + { name = "sentencepiece" }, + { name = "setproctitle" }, + { name = "setuptools", marker = "python_full_version >= '3.12'", specifier = ">=77.0.3,<81.0.0" }, + { name = "six", marker = "python_full_version >= '3.12'", specifier = ">=1.16.0" }, + { name = "smg-grpc-servicer", extras = ["vllm"], marker = "extra == 'grpc'", specifier = ">=0.5.2" }, + { name = "soundfile", marker = "extra == 'audio'" }, + { name = "tensorizer", 
marker = "extra == 'tensorizer'", specifier = "==2.10.1" }, + { name = "tiktoken", specifier = ">=0.6.0" }, + { name = "tilelang", specifier = "==0.1.9" }, + { name = "tokenizers", specifier = ">=0.21.1" }, + { name = "torch", specifier = "==2.11.0" }, + { name = "torchaudio", specifier = "==2.11.0" }, + { name = "torchvision", specifier = "==0.26.0" }, + { name = "tqdm" }, + { name = "transformers", specifier = ">=4.56.0,!=5.0.*,!=5.1.*,!=5.2.*,!=5.3.*,!=5.4.*,!=5.5.0" }, + { name = "typing-extensions", specifier = ">=4.10" }, + { name = "watchfiles" }, + { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'", specifier = ">=0.2.0,<1.0.0" }, + { name = "zentorch-weekly", marker = "extra == 'zen'", specifier = "==5.2.1.dev20260408" }, +] +provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttensor", "runai", "audio", "video", "flashinfer", "helion", "grpc", "otel"] + +[[package]] +name = "watchfiles" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/c9/8869df9b2a2d6c59d79220a4db37679e74f807c559ffe5265e08b227a210/watchfiles-1.1.1.tar.gz", hash = "sha256:a173cb5c16c4f40ab19cecf48a534c409f7ea983ab8fed0741304a1c0a31b3f2", size = 94440, upload-time = "2025-10-14T15:06:21.08Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/5b/d3b460364aeb8da471c1989238ea0e56bec24b6042a68046adf3d9ddb01c/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8526e8f916bb5b9a0a777c8317c23ce65de259422bba5b31325a6fa6029d33af", size = 449374, upload-time = "2025-10-14T15:04:49.179Z" }, + { url = "https://files.pythonhosted.org/packages/b9/44/5769cb62d4ed055cb17417c0a109a92f007114a4e07f30812a73a4efdb11/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2edc3553362b1c38d9f06242416a5d8e9fe235c204a4072e988ce2e5bb1f69f6", size = 459485, upload-time = "2025-10-14T15:04:50.155Z" }, + { url = "https://files.pythonhosted.org/packages/19/0c/286b6301ded2eccd4ffd0041a1b726afda999926cf720aab63adb68a1e36/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30f7da3fb3f2844259cba4720c3fc7138eb0f7b659c38f3bfa65084c7fc7abce", size = 488813, upload-time = "2025-10-14T15:04:51.059Z" }, + { url = "https://files.pythonhosted.org/packages/c7/2b/8530ed41112dd4a22f4dcfdb5ccf6a1baad1ff6eed8dc5a5f09e7e8c41c7/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8979280bdafff686ba5e4d8f97840f929a87ed9cdf133cbbd42f7766774d2aa", size = 594816, upload-time = "2025-10-14T15:04:52.031Z" }, + { url = "https://files.pythonhosted.org/packages/ce/d2/f5f9fb49489f184f18470d4f99f4e862a4b3e9ac2865688eb2099e3d837a/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcc5c24523771db3a294c77d94771abcfcb82a0e0ee8efd910c37c59ec1b31bb", size = 475186, upload-time = "2025-10-14T15:04:53.064Z" }, + { url = "https://files.pythonhosted.org/packages/cf/68/5707da262a119fb06fbe214d82dd1fe4a6f4af32d2d14de368d0349eb52a/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db5d7ae38ff20153d542460752ff397fcf5c96090c1230803713cf3147a6803", size = 456812, upload-time = "2025-10-14T15:04:55.174Z" }, + { url = 
"https://files.pythonhosted.org/packages/66/ab/3cbb8756323e8f9b6f9acb9ef4ec26d42b2109bce830cc1f3468df20511d/watchfiles-1.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:28475ddbde92df1874b6c5c8aaeb24ad5be47a11f87cde5a28ef3835932e3e94", size = 630196, upload-time = "2025-10-14T15:04:56.22Z" }, + { url = "https://files.pythonhosted.org/packages/78/46/7152ec29b8335f80167928944a94955015a345440f524d2dfe63fc2f437b/watchfiles-1.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:36193ed342f5b9842edd3532729a2ad55c4160ffcfa3700e0d54be496b70dd43", size = 622657, upload-time = "2025-10-14T15:04:57.521Z" }, +] + +[[package]] +name = "websockets" +version = "16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" }, + { url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" }, + { url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" }, + { url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" }, + { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, +] + +[[package]] +name = "xgrammar" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apache-tvm-ffi" }, + { name = "numpy" }, + { name = "pydantic" }, + { name = "torch" }, + { name = "transformers" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a0/54/7e593fc41ffcaf5ac7c0379e0aec0cf03e53a742d1a91f64c6c7e79a6ac1/xgrammar-0.2.0.tar.gz", hash = "sha256:c4f0238a89869343171d43d069b8c5da874f3c2c25f408f20cd5987219a6adef", size = 2421093, upload-time = "2026-05-01T18:33:54.474Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/1c/92eac0cd125ba195e3f1e3e25e89aedcaecbf99a4034ab12b7655ac07453/xgrammar-0.2.0-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash 
= "sha256:ddad831bc7da41d52ed34b7e1050c9a37d3f5f2314eaed8e658cbd2a34625e31", size = 44155238, upload-time = "2026-05-01T18:32:38.679Z" }, + { url = "https://files.pythonhosted.org/packages/7e/30/99f4e83821db16d58dd41249ba46038ed47bce274c57ad5567030775fc62/xgrammar-0.2.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a36c744d24d93e178c138486aa02b390a80326b64ff11e222e063a028dd65849", size = 44616361, upload-time = "2026-05-01T18:32:42.536Z" }, +] + +[[package]] +name = "yarl" +version = "1.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "multidict" }, + { name = "propcache" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/6e/beb1beec874a72f23815c1434518bfc4ed2175065173fb138c3705f658d4/yarl-1.23.0.tar.gz", hash = "sha256:53b1ea6ca88ebd4420379c330aea57e258408dd0df9af0992e5de2078dc9f5d5", size = 194676, upload-time = "2026-03-01T22:07:53.373Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/30/58260ed98e6ff7f90ba84442c1ddd758c9170d70327394a6227b310cd60f/yarl-1.23.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9cbf44c5cb4a7633d078788e1b56387e3d3cf2b8139a3be38040b22d6c3221c8", size = 97587, upload-time = "2026-03-01T22:05:17.384Z" }, + { url = "https://files.pythonhosted.org/packages/76/0a/8b08aac08b50682e65759f7f8dde98ae8168f72487e7357a5d684c581ef9/yarl-1.23.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:53ad387048f6f09a8969631e4de3f1bf70c50e93545d64af4f751b2498755072", size = 92528, upload-time = "2026-03-01T22:05:18.804Z" }, + { url = "https://files.pythonhosted.org/packages/52/07/0b7179101fe5f8385ec6c6bb5d0cb9f76bd9fb4a769591ab6fb5cdbfc69a/yarl-1.23.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4a59ba56f340334766f3a4442e0efd0af895fae9e2b204741ef885c446b3a1a8", size = 105339, upload-time = "2026-03-01T22:05:20.235Z" }, + { url = "https://files.pythonhosted.org/packages/d3/8a/36d82869ab5ec829ca8574dfcb92b51286fcfb1e9c7a73659616362dc880/yarl-1.23.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:803a3c3ce4acc62eaf01eaca1208dcf0783025ef27572c3336502b9c232005e7", size = 105061, upload-time = "2026-03-01T22:05:22.268Z" }, + { url = "https://files.pythonhosted.org/packages/66/3e/868e5c3364b6cee19ff3e1a122194fa4ce51def02c61023970442162859e/yarl-1.23.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3d2bff8f37f8d0f96c7ec554d16945050d54462d6e95414babaa18bfafc7f51", size = 100132, upload-time = "2026-03-01T22:05:23.638Z" }, + { url = "https://files.pythonhosted.org/packages/cf/26/9c89acf82f08a52cb52d6d39454f8d18af15f9d386a23795389d1d423823/yarl-1.23.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c75eb09e8d55bceb4367e83496ff8ef2bc7ea6960efb38e978e8073ea59ecb67", size = 99289, upload-time = "2026-03-01T22:05:25.749Z" }, + { url = "https://files.pythonhosted.org/packages/6f/54/5b0db00d2cb056922356104468019c0a132e89c8d3ab67d8ede9f4483d2a/yarl-1.23.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877b0738624280e34c55680d6054a307aa94f7d52fa0e3034a9cc6e790871da7", size = 96950, upload-time = "2026-03-01T22:05:27.318Z" }, + { url = 
"https://files.pythonhosted.org/packages/f6/40/10fa93811fd439341fad7e0718a86aca0de9548023bbb403668d6555acab/yarl-1.23.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b5405bb8f0e783a988172993cfc627e4d9d00432d6bbac65a923041edacf997d", size = 93960, upload-time = "2026-03-01T22:05:28.738Z" }, + { url = "https://files.pythonhosted.org/packages/bc/d2/8ae2e6cd77d0805f4526e30ec43b6f9a3dfc542d401ac4990d178e4bf0cf/yarl-1.23.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:1c3a3598a832590c5a3ce56ab5576361b5688c12cb1d39429cf5dba30b510760", size = 104703, upload-time = "2026-03-01T22:05:30.438Z" }, + { url = "https://files.pythonhosted.org/packages/2f/0c/b3ceacf82c3fe21183ce35fa2acf5320af003d52bc1fcf5915077681142e/yarl-1.23.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:8419ebd326430d1cbb7efb5292330a2cf39114e82df5cc3d83c9a0d5ebeaf2f2", size = 98325, upload-time = "2026-03-01T22:05:31.835Z" }, + { url = "https://files.pythonhosted.org/packages/9d/e0/12900edd28bdab91a69bd2554b85ad7b151f64e8b521fe16f9ad2f56477a/yarl-1.23.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:be61f6fff406ca40e3b1d84716fde398fc08bc63dd96d15f3a14230a0973ed86", size = 105067, upload-time = "2026-03-01T22:05:33.358Z" }, + { url = "https://files.pythonhosted.org/packages/15/61/74bb1182cf79c9bbe4eb6b1f14a57a22d7a0be5e9cedf8e2d5c2086474c3/yarl-1.23.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ceb13c5c858d01321b5d9bb65e4cf37a92169ea470b70fec6f236b2c9dd7e34", size = 100285, upload-time = "2026-03-01T22:05:35.4Z" }, + { url = "https://files.pythonhosted.org/packages/69/68/c8739671f5699c7dc470580a4f821ef37c32c4cb0b047ce223a7f115757f/yarl-1.23.0-py3-none-any.whl", hash = "sha256:a2df6afe50dea8ae15fa34c9f824a3ee958d785fd5d089063d960bae1daa0a3f", size = 48288, upload-time = "2026-03-01T22:07:51.388Z" }, +] + +[[package]] +name = "z3-solver" +version = "4.15.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/8e/0c8f17309549d2e5cde9a3ccefa6365437f1e7bafe71878eaf9478e47b18/z3_solver-4.15.4.0.tar.gz", hash = "sha256:928c29b58c4eb62106da51c1914f6a4a55d0441f8f48a81b9da07950434a8946", size = 5018600, upload-time = "2025-10-29T18:12:03.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/c9/bb51a96af0091324c81b803f16c49f719f9f6ea0b0bb52200f5c97ec4892/z3_solver-4.15.4.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e103a6f203f505b8b8b8e5c931cc407c95b61556512d4921c1ddc0b3f41b08e", size = 29268352, upload-time = "2025-10-29T18:11:53.032Z" }, + { url = "https://files.pythonhosted.org/packages/bf/2e/0b49f7e4e53817cfb09a0f6585012b782dfe0b666e8abefcb4fac0570606/z3_solver-4.15.4.0-py3-none-manylinux_2_34_aarch64.whl", hash = "sha256:62c7e9cbdd711932301f29919ad9158de9b2f58b4d281dd259bbcd0a2f408ba1", size = 27226534, upload-time = "2025-10-29T18:11:55.59Z" }, +] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time 
= "2025-06-08T17:06:38.034Z" }, +]