From fb4f76d4b16c692db456313943a974d4766a90da Mon Sep 17 00:00:00 2001 From: Nicolas Grande Date: Wed, 13 May 2026 20:44:27 +0000 Subject: [PATCH] update post training deps. --- .../extra_deps/post_train_base_deps.txt | 2 +- .../extra_deps/post_train_github_deps.txt | 4 +- .../extra_deps/tpu_post_train_overrides.txt | 7 +- .../tpu-post-train-requirements.txt | 109 +++++++++--------- src/maxtext/layers/moe.py | 2 - 5 files changed, 62 insertions(+), 62 deletions(-) diff --git a/src/dependencies/extra_deps/post_train_base_deps.txt b/src/dependencies/extra_deps/post_train_base_deps.txt index f57ad482c5..aba67e10bb 100644 --- a/src/dependencies/extra_deps/post_train_base_deps.txt +++ b/src/dependencies/extra_deps/post_train_base_deps.txt @@ -1 +1 @@ -google-tunix @ https://github.com/google/tunix/archive/387072374f99a100cb11f99dec951940b1475a04.zip +google-tunix @ https://github.com/google/tunix/archive/b0712f6fd32a1896cf354cf151d9c472d4ebe856.zip diff --git a/src/dependencies/extra_deps/post_train_github_deps.txt b/src/dependencies/extra_deps/post_train_github_deps.txt index 474d664122..3ea084167c 100644 --- a/src/dependencies/extra_deps/post_train_github_deps.txt +++ b/src/dependencies/extra_deps/post_train_github_deps.txt @@ -1,3 +1,3 @@ -r post_train_base_deps.txt -tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/39d9a9d38d3c96a7e1e57f9e693cf1c96a44e87d.zip -vllm @ git+https://github.com/vllm-project/vllm@529c671e8075d265a48b72e0eaaeb5e30d2f1630 +tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/a46baf9ee149da0fbc1cfe335650e3780e30b585.zip +vllm @ git+https://github.com/vllm-project/vllm@a51376b3f05a2f74eac6ceeed7e52598b871a0fb diff --git a/src/dependencies/extra_deps/tpu_post_train_overrides.txt b/src/dependencies/extra_deps/tpu_post_train_overrides.txt index c7899a2df1..ca27b47c47 100644 --- a/src/dependencies/extra_deps/tpu_post_train_overrides.txt +++ b/src/dependencies/extra_deps/tpu_post_train_overrides.txt @@ -1,4 +1,7 @@ flax==0.12.4 google-metrax>=0.2.3 -libtpu>=0.0.39 -optax==0.2.6 \ No newline at end of file +optax==0.2.6 +transformers>=5.5.0 +datasets>=4.8.5 +fsspec==2026.2.0 +gcsfs==2026.2.0 \ No newline at end of file diff --git a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt index 4744ce0315..e0f8b0315f 100644 --- a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt @@ -8,7 +8,7 @@ aiohttp>=3.13.5 aiosignal>=1.4.0 annotated-doc>=0.0.4 annotated-types>=0.7.0 -anthropic>=0.100.0 +anthropic>=0.102.0 antlr4-python3-runtime>=4.9.3 anyio>=4.13.0 apache-tvm-ffi>=0.1.11 @@ -19,18 +19,18 @@ astor>=0.8.1 astroid>=4.0.4 asttokens>=3.0.1 astunparse>=1.6.3 -attrs>=25.4.0 +attrs>=26.1.0 auditwheel>=6.6.0 black>=25.12.0 -boto3>=1.43.5 -botocore>=1.43.5 -build>=1.4.0 +boto3>=1.43.6 +botocore>=1.43.6 +build>=1.4.3 cachetools>=7.1.1 -cbor2>=6.0.1 +cbor2>=6.1.0 certifi>=2026.2.25 cffi>=2.0.0 ; implementation_name == 'pypy' or platform_python_implementation != 'PyPy' cfgv>=3.5.0 -charset-normalizer>=3.4.6 +charset-normalizer>=3.4.7 cheroot>=11.1.2 chex>=0.1.91 click>=8.3.3 @@ -47,11 +47,12 @@ cuda-bindings>=13.2.0 ; sys_platform == 'linux' cuda-pathfinder>=1.5.4 ; sys_platform == 'linux' cuda-toolkit>=13.0.2 ; sys_platform == 'linux' cycler>=0.12.1 -dataclasses-json>=0.6.7 +dataclasses>=0.5 +dataclasses-json>=0.0.1 datasets>=4.8.5 debugpy>=1.8.20 decorator>=5.2.1 -dill>=0.4.1 +dill>=0.3.7 distlib>=0.4.0 distro>=1.9.0 dm-tree>=0.1.10 @@ -66,7 +67,7 @@ execnet>=2.1.2 executing>=2.2.1 fastapi>=0.136.1 fastjsonschema>=2.21.2 -filelock>=3.20.3 +filelock>=3.28.0 flatbuffers>=25.12.19 flax>=0.12.4 fonttools>=4.62.1 @@ -77,15 +78,15 @@ gcsfs>=2026.2.0 gepa>=0.1.1 gguf>=0.19.0 google-api-core>=2.30.3 -google-api-python-client>=2.195.0 -google-auth>=2.50.0 -google-auth-httplib2>=0.3.1 -google-auth-oauthlib>=1.3.1 -google-cloud-aiplatform>=1.150.0 +google-api-python-client>=2.196.0 +google-auth>=2.52.0 +google-auth-httplib2>=0.4.0 +google-auth-oauthlib>=1.4.0 +google-cloud-aiplatform>=1.152.0 google-cloud-appengine-logging>=1.9.0 google-cloud-audit-log>=0.5.0 google-cloud-bigquery>=3.41.0 -google-cloud-core>=2.5.1 +google-cloud-core>=2.6.0 google-cloud-logging>=3.15.0 google-cloud-mldiagnostics>=1.0.2 google-cloud-monitoring>=2.30.0 @@ -96,13 +97,13 @@ google-crc32c>=1.8.0 google-genai>=1.75.0 google-metrax>=0.2.3 google-pasta>=0.2.0 -google-resumable-media>=2.8.2 +google-resumable-media>=2.9.0 google-tunix>=0.1.3 -googleapis-common-protos>=1.74.0 +googleapis-common-protos>=1.75.0 grain>=0.2.16 grpc-google-iam-v1>=0.14.4 -grpcio>=1.78.0 -grpcio-status>=1.78.0 +grpcio>=1.80.0 +grpcio-status>=1.80.0 gspread>=6.2.1 gviz-api>=1.10.0 h11>=0.16.0 @@ -112,7 +113,7 @@ hf-xet>=1.5.0 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or httpcore>=1.0.9 httplib2>=0.31.2 httpx>=0.28.1 -huggingface-hub>=1.14.0 +huggingface-hub>=1.5.0 humanize>=4.15.0 hypothesis>=6.142.1 identify>=2.6.19 @@ -128,8 +129,8 @@ ipython-pygments-lexers>=1.1.1 ipywidgets>=8.1.8 isort>=8.0.1 jaraco-functools>=4.4.0 -jax>=0.9.2 -jaxlib>=0.9.2 +jax>=0.10.0 +jaxlib>=0.10.0 jaxtyping>=0.3.9 jedi>=0.20.0 jinja2>=3.1.6 @@ -143,12 +144,12 @@ jupyter-core>=5.9.1 jupyterlab-widgets>=3.0.16 kagglehub>=1.0.1 kagglesdk>=0.1.23 -keras>=3.13.2 +keras>=3.14.0 kiwisolver>=1.5.0 latex2sympy2-extended>=1.11.0 libclang>=18.1.1 libcst>=1.8.6 -libtpu>=0.0.39 +libtpu>=0.0.40 ; platform_machine == 'x86_64' and sys_platform == 'linux' llguidance>=1.7.5 llvmlite>=0.47.0 loguru>=0.7.3 @@ -156,10 +157,9 @@ lxml>=6.1.0 markdown>=3.10.2 markdown-it-py>=4.0.0 markupsafe>=3.0.3 -marshmallow>=3.26.2 math-verify>=0.9.0 matplotlib>=3.10.8 -matplotlib-inline>=0.2.1 +matplotlib-inline>=0.2.2 mccabe>=0.7.0 mdurl>=0.1.2 mistral-common>=1.11.2 @@ -172,7 +172,7 @@ mpmath>=1.3.0 msgpack>=1.1.2 msgspec>=0.21.1 multidict>=6.7.1 -multiprocess>=0.70.19 +multiprocess>=0.70.15 mypy-extensions>=1.1.0 namex>=0.1.0 nbclient>=0.10.4 @@ -184,25 +184,25 @@ nodeenv>=1.10.0 numba>=0.65.1 numpy>=2.1.3 numpy-typing-compat>=20251206.2.1 -nvidia-cublas>=13.1.0.3 ; sys_platform == 'linux' -nvidia-cuda-cccl>=13.2.27 +nvidia-cublas>=13.1.1.3 ; sys_platform == 'linux' +nvidia-cuda-cccl>=13.2.75 nvidia-cuda-cupti>=13.0.85 ; sys_platform == 'linux' nvidia-cuda-nvrtc>=13.0.88 ; sys_platform == 'linux' nvidia-cuda-runtime>=13.0.96 ; sys_platform == 'linux' -nvidia-cudnn-cu13>=9.19.0.56 ; sys_platform == 'linux' +nvidia-cudnn-cu13>=9.20.0.48 ; sys_platform == 'linux' nvidia-cufft>=12.0.0.61 ; sys_platform == 'linux' nvidia-cufile>=1.15.1.6 ; sys_platform == 'linux' nvidia-curand>=10.4.0.35 ; sys_platform == 'linux' nvidia-cusolver>=12.0.4.66 ; sys_platform == 'linux' nvidia-cusparse>=12.6.3.3 ; sys_platform == 'linux' -nvidia-cusparselt-cu13>=0.8.0 ; sys_platform == 'linux' -nvidia-nccl-cu13>=2.28.9 ; sys_platform == 'linux' +nvidia-cusparselt-cu13>=0.8.1 ; sys_platform == 'linux' +nvidia-nccl-cu13>=2.29.7 ; sys_platform == 'linux' nvidia-nvjitlink>=13.0.88 ; sys_platform == 'linux' nvidia-nvshmem-cu13>=3.4.5 ; sys_platform == 'linux' nvidia-nvtx>=13.0.85 ; sys_platform == 'linux' oauthlib>=3.3.1 omegaconf>=2.3.0 -openai>=2.35.1 +openai>=2.36.0 openai-harmony>=0.0.8 opentelemetry-api>=1.41.1 opt-einsum>=3.4.0 @@ -211,8 +211,8 @@ optree>=0.19.0 optype>=0.17.0 orbax-checkpoint>=0.11.39 orbax-export>=0.0.8 -packaging>=26.0 -pandas>=3.0.2 +packaging>=26.1 +pandas>=3.0.3 papermill>=2.7.0 parameterized>=0.9.0 parso>=0.8.7 @@ -221,7 +221,7 @@ pathspec>=1.1.1 pathwaysutils>=0.1.8 perfetto>=0.16.0 pexpect>=4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' -pillow>=12.1.1 +pillow>=12.2.0 platformdirs>=4.9.6 pluggy>=1.6.0 portpicker>=1.6.0 @@ -230,8 +230,8 @@ prometheus-client>=0.25.0 prometheus-fastapi-instrumentator>=7.1.0 promise>=2.3 prompt-toolkit>=3.0.52 -propcache>=0.4.1 -proto-plus>=1.27.2 +propcache>=0.5.2 +proto-plus>=1.28.0 protobuf>=6.33.6 psutil>=7.2.2 ptyprocess>=0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' @@ -250,7 +250,7 @@ pydantic-extra-types>=2.11.1 pydot>=4.0.1 pyelftools>=0.32 pyglove>=0.4.5 -pygments>=2.19.2 +pygments>=2.20.0 pyink>=25.12.0 pylint>=4.0.5 pyparsing>=3.3.2 @@ -259,6 +259,7 @@ pytest>=8.4.2 pytest-mock>=3.15.1 pytest-xdist>=3.8.0 python-dateutil>=2.9.0.post0 +python-discovery>=1.3.1 python-dotenv>=1.2.2 python-json-logger>=4.1.0 pytokens>=0.4.1 @@ -268,10 +269,10 @@ pyzmq>=27.1.0 qwix>=0.1.6 ray>=2.55.1 referencing>=0.37.0 -regex>=2026.4.4 -requests>=2.32.5 +regex>=2026.5.9 +requests>=2.33.1 requests-oauthlib>=2.0.0 -rich>=14.3.3 +rich>=15.0.0 rpds-py>=0.30.0 runai-model-streamer>=0.15.9 runai-model-streamer-gcs>=0.15.9 @@ -279,11 +280,10 @@ runai-model-streamer-s3>=0.15.9 s3transfer>=0.17.0 safetensors>=0.7.0 scipy>=1.17.1 -scipy-stubs>=1.17.1.2 +scipy-stubs>=1.17.1.4 sentencepiece>=0.2.1 seqio>=0.0.20 setuptools>=80.10.2 -shellingham>=1.5.4 simple-parsing>=0.1.8 simplejson>=4.1.1 six>=1.17.0 @@ -300,7 +300,7 @@ tensorboard-data-server>=0.7.2 tensorboard-plugin-profile>=2.13.0 tensorboardx>=2.6.5 tensorflow>=2.20.0 -tensorflow-datasets>=4.9.9 +tensorflow-datasets>=4.9.10 tensorflow-metadata>=1.17.3 tensorflow-text>=2.20.1 tensorstore>=0.1.82 @@ -309,29 +309,28 @@ tiktoken>=0.12.0 tokamax>=0.0.12 tokenizers>=0.22.2 toml>=0.10.2 -tomlkit>=0.14.0 +tomlkit>=0.15.0 toolz>=1.1.0 -torch>=2.11.0 +torch>=2.12.0 torchax>=0.0.12 -torchvision>=0.26.0 +torchvision>=0.27.0 tornado>=6.5.5 tpu-info>=0.11.0 tqdm>=4.67.3 traitlets>=5.15.0 -transformers>=5.8.0 +transformers>=5.5.0 treescope>=0.1.10 -triton>=3.6.0 ; sys_platform == 'linux' +triton>=3.7.0 ; sys_platform == 'linux' typeguard>=2.13.3 -typer>=0.25.1 +typer>=0.4.0 typing-extensions>=4.15.0 -typing-inspect>=0.9.0 typing-inspection>=0.4.2 tzdata>=2026.2 ; sys_platform == 'emscripten' or sys_platform == 'win32' uritemplate>=4.2.0 urllib3>=2.6.3 uvicorn>=0.46.0 uvloop>=0.22.1 -virtualenv>=20.36.1 +virtualenv>=21.3.3 wadler-lindig>=0.1.7 watchfiles>=1.1.1 wcwidth>=0.7.0 @@ -346,5 +345,5 @@ xprof>=2.22.2 xxhash>=3.7.0 yapf>=0.43.0 yarl>=1.23.0 -zipp>=3.23.0 +zipp>=3.23.1 zstandard>=0.25.0 diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 975e8fe9a2..0003d28b4f 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -2193,8 +2193,6 @@ def fused_moe_matmul( use_ep=use_ep, activation=activation, scoring_fn=scoring_fn, - sc_kernel_threshold=16777216, - sc_kernel_col_chunk_size=1024, ) # Reshape output 2D [T, D] -> 3D [B, S, D]