Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dependencies/extra_deps/post_train_base_deps.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
google-tunix @ https://github.com/google/tunix/archive/387072374f99a100cb11f99dec951940b1475a04.zip
google-tunix @ https://github.com/google/tunix/archive/b0712f6fd32a1896cf354cf151d9c472d4ebe856.zip
4 changes: 2 additions & 2 deletions src/dependencies/extra_deps/post_train_github_deps.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-r post_train_base_deps.txt
tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/39d9a9d38d3c96a7e1e57f9e693cf1c96a44e87d.zip
vllm @ git+https://github.com/vllm-project/vllm@529c671e8075d265a48b72e0eaaeb5e30d2f1630
tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/a46baf9ee149da0fbc1cfe335650e3780e30b585.zip
vllm @ git+https://github.com/vllm-project/vllm@a51376b3f05a2f74eac6ceeed7e52598b871a0fb
7 changes: 5 additions & 2 deletions src/dependencies/extra_deps/tpu_post_train_overrides.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
flax==0.12.4
google-metrax>=0.2.3
libtpu>=0.0.39
optax==0.2.6
optax==0.2.6
transformers>=5.5.0
datasets>=4.8.5
fsspec==2026.2.0
gcsfs==2026.2.0
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ aiohttp>=3.13.5
aiosignal>=1.4.0
annotated-doc>=0.0.4
annotated-types>=0.7.0
anthropic>=0.100.0
anthropic>=0.102.0
antlr4-python3-runtime>=4.9.3
anyio>=4.13.0
apache-tvm-ffi>=0.1.11
Expand All @@ -19,18 +19,18 @@ astor>=0.8.1
astroid>=4.0.4
asttokens>=3.0.1
astunparse>=1.6.3
attrs>=25.4.0
attrs>=26.1.0
auditwheel>=6.6.0
black>=25.12.0
boto3>=1.43.5
botocore>=1.43.5
build>=1.4.0
boto3>=1.43.6
botocore>=1.43.6
build>=1.4.3
cachetools>=7.1.1
cbor2>=6.0.1
cbor2>=6.1.0
certifi>=2026.2.25
cffi>=2.0.0 ; implementation_name == 'pypy' or platform_python_implementation != 'PyPy'
cfgv>=3.5.0
charset-normalizer>=3.4.6
charset-normalizer>=3.4.7
cheroot>=11.1.2
chex>=0.1.91
click>=8.3.3
Expand All @@ -47,11 +47,12 @@ cuda-bindings>=13.2.0 ; sys_platform == 'linux'
cuda-pathfinder>=1.5.4 ; sys_platform == 'linux'
cuda-toolkit>=13.0.2 ; sys_platform == 'linux'
cycler>=0.12.1
dataclasses-json>=0.6.7
dataclasses>=0.5
dataclasses-json>=0.0.1
datasets>=4.8.5
debugpy>=1.8.20
decorator>=5.2.1
dill>=0.4.1
dill>=0.3.7
distlib>=0.4.0
distro>=1.9.0
dm-tree>=0.1.10
Expand All @@ -66,7 +67,7 @@ execnet>=2.1.2
executing>=2.2.1
fastapi>=0.136.1
fastjsonschema>=2.21.2
filelock>=3.20.3
filelock>=3.28.0
flatbuffers>=25.12.19
flax>=0.12.4
fonttools>=4.62.1
Expand All @@ -77,15 +78,15 @@ gcsfs>=2026.2.0
gepa>=0.1.1
gguf>=0.19.0
google-api-core>=2.30.3
google-api-python-client>=2.195.0
google-auth>=2.50.0
google-auth-httplib2>=0.3.1
google-auth-oauthlib>=1.3.1
google-cloud-aiplatform>=1.150.0
google-api-python-client>=2.196.0
google-auth>=2.52.0
google-auth-httplib2>=0.4.0
google-auth-oauthlib>=1.4.0
google-cloud-aiplatform>=1.152.0
google-cloud-appengine-logging>=1.9.0
google-cloud-audit-log>=0.5.0
google-cloud-bigquery>=3.41.0
google-cloud-core>=2.5.1
google-cloud-core>=2.6.0
google-cloud-logging>=3.15.0
google-cloud-mldiagnostics>=1.0.2
google-cloud-monitoring>=2.30.0
Expand All @@ -96,13 +97,13 @@ google-crc32c>=1.8.0
google-genai>=1.75.0
google-metrax>=0.2.3
google-pasta>=0.2.0
google-resumable-media>=2.8.2
google-resumable-media>=2.9.0
google-tunix>=0.1.3
googleapis-common-protos>=1.74.0
googleapis-common-protos>=1.75.0
grain>=0.2.16
grpc-google-iam-v1>=0.14.4
grpcio>=1.78.0
grpcio-status>=1.78.0
grpcio>=1.80.0
grpcio-status>=1.80.0
gspread>=6.2.1
gviz-api>=1.10.0
h11>=0.16.0
Expand All @@ -112,7 +113,7 @@ hf-xet>=1.5.0 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or
httpcore>=1.0.9
httplib2>=0.31.2
httpx>=0.28.1
huggingface-hub>=1.14.0
huggingface-hub>=1.5.0
humanize>=4.15.0
hypothesis>=6.142.1
identify>=2.6.19
Expand All @@ -128,8 +129,8 @@ ipython-pygments-lexers>=1.1.1
ipywidgets>=8.1.8
isort>=8.0.1
jaraco-functools>=4.4.0
jax>=0.9.2
jaxlib>=0.9.2
jax>=0.10.0
jaxlib>=0.10.0
jaxtyping>=0.3.9
jedi>=0.20.0
jinja2>=3.1.6
Expand All @@ -143,23 +144,22 @@ jupyter-core>=5.9.1
jupyterlab-widgets>=3.0.16
kagglehub>=1.0.1
kagglesdk>=0.1.23
keras>=3.13.2
keras>=3.14.0
kiwisolver>=1.5.0
latex2sympy2-extended>=1.11.0
libclang>=18.1.1
libcst>=1.8.6
libtpu>=0.0.39
libtpu>=0.0.40 ; platform_machine == 'x86_64' and sys_platform == 'linux'
llguidance>=1.7.5
llvmlite>=0.47.0
loguru>=0.7.3
lxml>=6.1.0
markdown>=3.10.2
markdown-it-py>=4.0.0
markupsafe>=3.0.3
marshmallow>=3.26.2
math-verify>=0.9.0
matplotlib>=3.10.8
matplotlib-inline>=0.2.1
matplotlib-inline>=0.2.2
mccabe>=0.7.0
mdurl>=0.1.2
mistral-common>=1.11.2
Expand All @@ -172,7 +172,7 @@ mpmath>=1.3.0
msgpack>=1.1.2
msgspec>=0.21.1
multidict>=6.7.1
multiprocess>=0.70.19
multiprocess>=0.70.15
mypy-extensions>=1.1.0
namex>=0.1.0
nbclient>=0.10.4
Expand All @@ -184,25 +184,25 @@ nodeenv>=1.10.0
numba>=0.65.1
numpy>=2.1.3
numpy-typing-compat>=20251206.2.1
nvidia-cublas>=13.1.0.3 ; sys_platform == 'linux'
nvidia-cuda-cccl>=13.2.27
nvidia-cublas>=13.1.1.3 ; sys_platform == 'linux'
nvidia-cuda-cccl>=13.2.75
nvidia-cuda-cupti>=13.0.85 ; sys_platform == 'linux'
nvidia-cuda-nvrtc>=13.0.88 ; sys_platform == 'linux'
nvidia-cuda-runtime>=13.0.96 ; sys_platform == 'linux'
nvidia-cudnn-cu13>=9.19.0.56 ; sys_platform == 'linux'
nvidia-cudnn-cu13>=9.20.0.48 ; sys_platform == 'linux'
nvidia-cufft>=12.0.0.61 ; sys_platform == 'linux'
nvidia-cufile>=1.15.1.6 ; sys_platform == 'linux'
nvidia-curand>=10.4.0.35 ; sys_platform == 'linux'
nvidia-cusolver>=12.0.4.66 ; sys_platform == 'linux'
nvidia-cusparse>=12.6.3.3 ; sys_platform == 'linux'
nvidia-cusparselt-cu13>=0.8.0 ; sys_platform == 'linux'
nvidia-nccl-cu13>=2.28.9 ; sys_platform == 'linux'
nvidia-cusparselt-cu13>=0.8.1 ; sys_platform == 'linux'
nvidia-nccl-cu13>=2.29.7 ; sys_platform == 'linux'
nvidia-nvjitlink>=13.0.88 ; sys_platform == 'linux'
nvidia-nvshmem-cu13>=3.4.5 ; sys_platform == 'linux'
nvidia-nvtx>=13.0.85 ; sys_platform == 'linux'
oauthlib>=3.3.1
omegaconf>=2.3.0
openai>=2.35.1
openai>=2.36.0
openai-harmony>=0.0.8
opentelemetry-api>=1.41.1
opt-einsum>=3.4.0
Expand All @@ -211,8 +211,8 @@ optree>=0.19.0
optype>=0.17.0
orbax-checkpoint>=0.11.39
orbax-export>=0.0.8
packaging>=26.0
pandas>=3.0.2
packaging>=26.1
pandas>=3.0.3
papermill>=2.7.0
parameterized>=0.9.0
parso>=0.8.7
Expand All @@ -221,7 +221,7 @@ pathspec>=1.1.1
pathwaysutils>=0.1.8
perfetto>=0.16.0
pexpect>=4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
pillow>=12.1.1
pillow>=12.2.0
platformdirs>=4.9.6
pluggy>=1.6.0
portpicker>=1.6.0
Expand All @@ -230,8 +230,8 @@ prometheus-client>=0.25.0
prometheus-fastapi-instrumentator>=7.1.0
promise>=2.3
prompt-toolkit>=3.0.52
propcache>=0.4.1
proto-plus>=1.27.2
propcache>=0.5.2
proto-plus>=1.28.0
protobuf>=6.33.6
psutil>=7.2.2
ptyprocess>=0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
Expand All @@ -250,7 +250,7 @@ pydantic-extra-types>=2.11.1
pydot>=4.0.1
pyelftools>=0.32
pyglove>=0.4.5
pygments>=2.19.2
pygments>=2.20.0
pyink>=25.12.0
pylint>=4.0.5
pyparsing>=3.3.2
Expand All @@ -259,6 +259,7 @@ pytest>=8.4.2
pytest-mock>=3.15.1
pytest-xdist>=3.8.0
python-dateutil>=2.9.0.post0
python-discovery>=1.3.1
python-dotenv>=1.2.2
python-json-logger>=4.1.0
pytokens>=0.4.1
Expand All @@ -268,22 +269,21 @@ pyzmq>=27.1.0
qwix>=0.1.6
ray>=2.55.1
referencing>=0.37.0
regex>=2026.4.4
requests>=2.32.5
regex>=2026.5.9
requests>=2.33.1
requests-oauthlib>=2.0.0
rich>=14.3.3
rich>=15.0.0
rpds-py>=0.30.0
runai-model-streamer>=0.15.9
runai-model-streamer-gcs>=0.15.9
runai-model-streamer-s3>=0.15.9
s3transfer>=0.17.0
safetensors>=0.7.0
scipy>=1.17.1
scipy-stubs>=1.17.1.2
scipy-stubs>=1.17.1.4
sentencepiece>=0.2.1
seqio>=0.0.20
setuptools>=80.10.2
shellingham>=1.5.4
simple-parsing>=0.1.8
simplejson>=4.1.1
six>=1.17.0
Expand All @@ -300,7 +300,7 @@ tensorboard-data-server>=0.7.2
tensorboard-plugin-profile>=2.13.0
tensorboardx>=2.6.5
tensorflow>=2.20.0
tensorflow-datasets>=4.9.9
tensorflow-datasets>=4.9.10
tensorflow-metadata>=1.17.3
tensorflow-text>=2.20.1
tensorstore>=0.1.82
Expand All @@ -309,29 +309,28 @@ tiktoken>=0.12.0
tokamax>=0.0.12
tokenizers>=0.22.2
toml>=0.10.2
tomlkit>=0.14.0
tomlkit>=0.15.0
toolz>=1.1.0
torch>=2.11.0
torch>=2.12.0
torchax>=0.0.12
torchvision>=0.26.0
torchvision>=0.27.0
tornado>=6.5.5
tpu-info>=0.11.0
tqdm>=4.67.3
traitlets>=5.15.0
transformers>=5.8.0
transformers>=5.5.0
treescope>=0.1.10
triton>=3.6.0 ; sys_platform == 'linux'
triton>=3.7.0 ; sys_platform == 'linux'
typeguard>=2.13.3
typer>=0.25.1
typer>=0.4.0
typing-extensions>=4.15.0
typing-inspect>=0.9.0
typing-inspection>=0.4.2
tzdata>=2026.2 ; sys_platform == 'emscripten' or sys_platform == 'win32'
uritemplate>=4.2.0
urllib3>=2.6.3
uvicorn>=0.46.0
uvloop>=0.22.1
virtualenv>=20.36.1
virtualenv>=21.3.3
wadler-lindig>=0.1.7
watchfiles>=1.1.1
wcwidth>=0.7.0
Expand All @@ -346,5 +345,5 @@ xprof>=2.22.2
xxhash>=3.7.0
yapf>=0.43.0
yarl>=1.23.0
zipp>=3.23.0
zipp>=3.23.1
zstandard>=0.25.0
2 changes: 0 additions & 2 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2193,8 +2193,6 @@ def fused_moe_matmul(
use_ep=use_ep,
activation=activation,
scoring_fn=scoring_fn,
sc_kernel_threshold=16777216,
sc_kernel_col_chunk_size=1024,
)

# Reshape output 2D [T, D] -> 3D [B, S, D]
Expand Down