From 2f8466da9266bfaabd6caa7cdeb05008d24d727d Mon Sep 17 00:00:00 2001 From: Branden Vandermoon Date: Thu, 14 May 2026 00:41:32 +0000 Subject: [PATCH] Refactor temporary directory generation across scripts and utilities --- .../checkpoint_conversion/examples/convert_gemma2_to_hf.sh | 2 +- .../checkpoint_conversion/examples/convert_gemma3_to_hf.sh | 2 +- src/maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh | 3 ++- src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh | 2 +- src/maxtext/inference/mlperf/llama_offline_run.sh | 3 ++- src/maxtext/inference/mlperf/matmul/timing_util.py | 4 +--- .../integration/vllm/torchax_converter/validate_converter.py | 3 ++- 7 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/examples/convert_gemma2_to_hf.sh b/src/maxtext/checkpoint_conversion/examples/convert_gemma2_to_hf.sh index 2d3c839f85..13ad3751ec 100644 --- a/src/maxtext/checkpoint_conversion/examples/convert_gemma2_to_hf.sh +++ b/src/maxtext/checkpoint_conversion/examples/convert_gemma2_to_hf.sh @@ -9,7 +9,7 @@ DATE=$(date +%Y-%m-%d) # Define variables for paths and arguments HF_CHECKPOINT_GCS_PATH="gs://maxtext-model-checkpoints/HuggingFace/gemma2-2b/${DATE}" # (optional)GCS path for HF model MAXTEXT_CHECKPOINT_DIR="gs://maxtext-model-checkpoints/gemma2-2b-it/2025-02-20-18-01/unscanned/checkpoints/0/items" -LOCAL_HF_CHECKPOINT_DIR="/tmp/hf_gemma2-2b_output" # HF requires a local dir +LOCAL_HF_CHECKPOINT_DIR=$(mktemp -d) # HF requires a local dir TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.gemma" MODEL_NAME="gemma2-2b" PER_DEVICE_BATCH_SIZE=1 diff --git a/src/maxtext/checkpoint_conversion/examples/convert_gemma3_to_hf.sh b/src/maxtext/checkpoint_conversion/examples/convert_gemma3_to_hf.sh index 5489d962d5..fe682a448f 100644 --- a/src/maxtext/checkpoint_conversion/examples/convert_gemma3_to_hf.sh +++ b/src/maxtext/checkpoint_conversion/examples/convert_gemma3_to_hf.sh @@ -9,7 +9,7 @@ DATE=$(date +%Y-%m-%d) # Define variables for paths and arguments HF_CHECKPOINT_GCS_PATH="gs://maxtext-model-checkpoints/HuggingFace/gemma3-4b/${DATE}" # (optional)GCS path for HF model MAXTEXT_CHECKPOINT_DIR="gs://maxtext-model-checkpoints/gemma3-4b/2025-03-18-19-03/unscanned/checkpoints/0/items" -LOCAL_HF_CHECKPOINT_DIR="/tmp/hf_gemma3-4b_output" # HF requires a local dir +LOCAL_HF_CHECKPOINT_DIR=$(mktemp -d) # HF requires a local dir TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.gemma3" MODEL_NAME="gemma3-4b" PER_DEVICE_BATCH_SIZE=1 diff --git a/src/maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh index 0fc5a07429..b0fccd7e74 100644 --- a/src/maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh @@ -17,7 +17,8 @@ bash src/dependencies/scripts/preflight.sh PLATFORM=gke # flags set as default # hlo dump -export XLA_FLAGS="--xla_dump_to=/tmp/xla_dump_file" +XLA_DUMP_DIR=$(mktemp -d) +export XLA_FLAGS="--xla_dump_to=${XLA_DUMP_DIR}" # debug export TPU_STDERR_LOG_LEVEL=0 diff --git a/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh b/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh index e4b3163964..fd43058588 100755 --- a/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh +++ b/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh @@ -51,7 +51,7 @@ done # Default parameters if [[ -z ${BASE_OUTPUT_DIRECTORY} ]] ; then - export BASE_OUTPUT_DIRECTORY="/tmp/maxtext" + export BASE_OUTPUT_DIRECTORY=$(mktemp -d) fi if [[ -z ${INFERENCE_LOG_FILE_PATH} ]] ; then export INFERENCE_LOG_FILE_PATH="${BASE_OUTPUT_DIRECTORY}/microbenchmark_llama2-70b_h100-8_results.txt" diff --git a/src/maxtext/inference/mlperf/llama_offline_run.sh b/src/maxtext/inference/mlperf/llama_offline_run.sh index 9d2991df49..8cf266ef27 100755 --- a/src/maxtext/inference/mlperf/llama_offline_run.sh +++ b/src/maxtext/inference/mlperf/llama_offline_run.sh @@ -122,7 +122,8 @@ fi # LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" # makes subsequent runs faster -export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2" +JAX_COMPILATION_CACHE_DIR=$(mktemp -d) +export JAX_COMPILATION_CACHE_DIR export LIBTPU_INIT_ARGS # Ensure working directory is at repository root. diff --git a/src/maxtext/inference/mlperf/matmul/timing_util.py b/src/maxtext/inference/mlperf/matmul/timing_util.py index d3b6aaa809..3eb2fa1f3a 100644 --- a/src/maxtext/inference/mlperf/matmul/timing_util.py +++ b/src/maxtext/inference/mlperf/matmul/timing_util.py @@ -14,7 +14,6 @@ """ Timing utility functions """ import datetime -import os.path import tempfile import jax @@ -25,8 +24,7 @@ def simple_timeit(f, *args, tries=10, task=None, enable_profile=False): assert task is not None trace_name = f"{task}" # + '_' ]+ ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) - temp_dir = tempfile.gettempdir() - trace_dir = os.path.join(temp_dir, trace_name) + trace_dir = tempfile.mkdtemp(prefix=trace_name + "_") print(trace_dir) outcomes_ms = [] diff --git a/src/maxtext/integration/vllm/torchax_converter/validate_converter.py b/src/maxtext/integration/vllm/torchax_converter/validate_converter.py index f0afd7bbdf..1f710de292 100644 --- a/src/maxtext/integration/vllm/torchax_converter/validate_converter.py +++ b/src/maxtext/integration/vllm/torchax_converter/validate_converter.py @@ -50,6 +50,7 @@ import io import logging import os +import tempfile from typing import Sequence from absl import app @@ -73,7 +74,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s") -_JAX_COMPILATION_CACHE_DIR = "/tmp/jax_cache" +_JAX_COMPILATION_CACHE_DIR = tempfile.mkdtemp() vllm_model_name_mapping = { "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B",