|
195 | 195 | "import jax\n", |
196 | 196 | "import os\n", |
197 | 197 | "import sys\n", |
| 198 | + "import subprocess\n", |
198 | 199 | "import transformers\n", |
199 | 200 | "\n", |
200 | 201 | "import MaxText\n", |
201 | | - "from MaxText import pyconfig\n", |
202 | 202 | "from MaxText.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n", |
203 | 203 | "from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n", |
204 | 204 | "from MaxText.sft import sft_trainer\n", |
|
312 | 312 | "source": [ |
313 | 313 | "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", |
314 | 314 | " # install torch for the conversion script\n", |
315 | | - " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", |
316 | | - "\n", |
317 | | - " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_REPO_ROOT} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n", |
318 | | - " {MAXTEXT_REPO_ROOT}/configs/base.yml \\\n", |
319 | | - " model_name={MODEL_NAME} \\\n", |
320 | | - " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", |
321 | | - " hf_access_token={HF_TOKEN} \\\n", |
322 | | - " use_multimodal=false \\\n", |
323 | | - " scan_layers=true \\\n", |
324 | | - " skip_jax_distributed_system=True\n", |
| 315 | + " subprocess.run('uv pip install torch --index-url https://download.pytorch.org/whl/cpu', shell=True, check=True)\n", |
| 316 | + "\n", |
| 317 | + " subprocess.run(f'JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_REPO_ROOT} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext {MAXTEXT_REPO_ROOT}/configs/base.yml model_name={MODEL_NAME} base_output_directory={MODEL_CHECKPOINT_PATH} hf_access_token={HF_TOKEN} use_multimodal=false scan_layers=true skip_jax_distributed_system=True', shell=True, check=True)\n", |
325 | 318 | "\n", |
326 | 319 | "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", |
327 | 320 | " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")" |
|
0 commit comments