@Zephyr271828 Zephyr271828 commented Jan 30, 2026

Description

This PR implements #2434, adding Weights & Biases (wandb) logging support to MaxText.

Implementation details

The wandb logging implementation follows the same style as the existing logging interfaces (local file, GCS, TensorBoard, and managed MLDiagnostics).

Initialization

class MetricLogger:
  """
  Logger for saving metrics to a local file, GCS, TensorBoard, and wandb.
  """

  def __init__(self, config, learning_rate_schedule):
    self.writer = max_utils.initialize_summary_writer(config.tensorboard_dir, config.run_name)
    self.config = config
    self.metadata = {}
    self.running_gcs_metrics = [] if config.gcs_metrics else None
    self.performance_metric_queue = self.get_performance_metric_queue(config)
    self.learning_rate_schedule = learning_rate_schedule
    self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
    self.buffered_train_metrics = None
    
    if self.config.managed_mldiagnostics:
      ManagedMLDiagnostics(config)  # Initialize the MLRun instance.
      
    # Only initialize wandb on a single host (the one whose hostname ends with "-0").
    self.enable_wandb = self.config.enable_wandb and socket.gethostname().endswith("-0")
    if self.enable_wandb:
      wandb.init(
        project=config.wandb_project_name,
        name=config.wandb_run_name,
        resume="allow",
      )  # Initialize the wandb logger.
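
For comparison, here is a minimal sketch (not part of this PR) of the same guard expressed with jax.process_index() == 0, the convention write_metrics below already uses for the GCS path, instead of the hostname suffix; maybe_init_wandb is a hypothetical helper:

import jax
import wandb

def maybe_init_wandb(config):
  """Sketch: initialize wandb on exactly one JAX process rather than matching hostnames.

  Assumes config exposes enable_wandb, wandb_project_name, and wandb_run_name,
  as in the snippet above.
  """
  enable_wandb = config.enable_wandb and jax.process_index() == 0
  if enable_wandb:
    wandb.init(
        project=config.wandb_project_name,
        name=config.wandb_run_name,
        resume="allow",
    )
  return enable_wandb

Both guards aim at the same outcome: exactly one process per multi-host run initializes and writes to wandb.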

Logging step

  def write_metrics(self, metrics, step, is_training=True):
    """Entry point for all metrics writing in Train's Main."""
    if metrics:
      self.log_metrics(metrics, step, is_training)

      if self.config.enable_tensorboard:
        self.write_metrics_to_tensorboard(metrics, step, is_training)

      if self.config.metrics_file:
        self.write_metrics_locally(metrics, step)

      if self.config.gcs_metrics and jax.process_index() == 0:
        self.write_metrics_for_gcs(metrics, step, is_training)

      if self.config.managed_mldiagnostics:
        self.write_metrics_to_managed_mldiagnostics(metrics, step)
        
      if self.enable_wandb:
        self.write_metrics_to_wandb(metrics, step)

  def write_metrics_to_wandb(self, metrics, step):
    """Write metrics to weights and biases (wandb)."""
    flat_metrics = {}
    for key, val in metrics.get("scalar", {}).items():
      flat_metrics[key] = float(val)
    for key, val in metrics.get("scalars", {}).items():
      for subkey, subval in val.items():
        flat_metrics[f"{key}/{subkey}"] = float(subval)
    wandb.log(flat_metrics, step=step)
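
To illustrate the flattening, here is a small self-contained example; the metric names are hypothetical placeholders, not the exact keys MaxText emits:

# Hypothetical metrics dict in the shape write_metrics_to_wandb expects.
metrics = {
    "scalar": {"learning/loss": 2.31, "perf/step_time_seconds": 1.23},
    "scalars": {"learning/grad_norm": {"mean": 0.5, "max": 1.7}},
}

flat_metrics = {}
for key, val in metrics.get("scalar", {}).items():
  flat_metrics[key] = float(val)
for key, val in metrics.get("scalars", {}).items():
  for subkey, subval in val.items():
    flat_metrics[f"{key}/{subkey}"] = float(subval)

# flat_metrics is now a single flat namespace suitable for wandb.log():
# {"learning/loss": 2.31, "perf/step_time_seconds": 1.23,
#  "learning/grad_norm/mean": 0.5, "learning/grad_norm/max": 1.7}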

Usage

python -u -m src.MaxText.train src/MaxText/configs/base.yml \
    ...
    enable_wandb=True \
    wandb_project_name=xxx \
    wandb_run_name=yyy \
    ...

Limitations

Currently this implementation does not support resuming an existing wandb run. To resume, we would first need to retrieve the run_id from somewhere and then call:

wandb.init(
    project=config.wandb_project_name,
    name=config.wandb_run_name,
    id=run_id,
    resume="allow",
)

It would make sense to save the run_ids in a cache directory inside the MaxText repo (a rough sketch of this idea is shown below), but I am not sure whether that is consistent with the design philosophy of this project.
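
As a rough sketch of that idea, assuming a file-based cache whose location and helper name are hypothetical and not part of this PR:

import os
import wandb

def init_wandb_with_resume(config, cache_dir):
  """Sketch: reuse a cached wandb run id so a restarted job resumes the same run."""
  run_id_path = os.path.join(cache_dir, "wandb_run_id.txt")  # hypothetical location
  run_id = None
  if os.path.exists(run_id_path):
    with open(run_id_path, "r", encoding="utf-8") as f:
      run_id = f.read().strip()
  run = wandb.init(
      project=config.wandb_project_name,
      name=config.wandb_run_name,
      id=run_id,  # None on the first launch; wandb then generates a fresh id.
      resume="allow",
  )
  with open(run_id_path, "w", encoding="utf-8") as f:
    f.write(run.id)
  return run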

Tests

Example training script:

#!/bin/bash

python tools/orchestration/multihost_runner.py \
    --TPU_PREFIX=${TPU_PREFIX} \
    --COMMAND="
    export WANDB_API_KEY=''
    export PYTHONPATH=./src:\${PYTHONPATH:-''}
    python -m src.MaxText.train src/MaxText/configs/base.yml \
        run_name=${RUN_NAME} \
        base_output_directory=${BASE_OUTPUT_DIRECTORY} \
        dataset_type=grain \
        grain_train_files=${DATA_FILES} \
        grain_file_type='arrayrecord' \
        grain_worker_count=1 \
        enable_data_shuffling=${SHUFFLE} \
        tokenize_train_data=False \
        tokenize_eval_data=False \
        max_target_length=${SEQ_LEN} \
        async_checkpointing=${ASYNC_CHECKPOINTING} \
        model_name=${MODEL_NAME} \
        steps=${NUM_STEPS} \
        per_device_batch_size=${BATCH_SIZE} \
        gradient_accumulation_steps=${GRAD_ACCUM} \
        gradient_clipping_threshold=${GRAD_CLIP} \
        learning_rate=${LR} \
        warmup_steps_fraction=${WARMUP_RATIO} \
        checkpoint_period=500 \
        enable_wandb=True \
        wandb_project_name=${WANDB_PROJ_NAME} \
        wandb_run_name=${TPU_PREFIX}_${RUN_NAME} \
        packing=false \
    "

Outputs:

Per train step:
 Total TFLOPs: 104.77 
 split as 55.92% learnable weight flops and 44.08% attention flops
I0131 11:17:43.038159 136759033735168 max_utils.py:695] Total memory size: 6.4 GB, Output size: 0.2 GB, Temp size: 6.2 GB, Argument size: 0.2 GB, Host temp size: 0.0 GB.
wandb: [wandb.login()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.
wandb: Currently logged in as: yx3038 (yx3038-new-york-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: setting up run negt2tz0
wandb: Tracking run with wandb version 0.24.1
wandb: Run data is saved locally in /home/zephyr/2026-01-31-11-16-10/wandb/run-20260131_111743-negt2tz0
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run yufeng-v6e-32-0003_qwen3-0.6b_L200_seqlen_8192_bs_1_grad_accum_2_lr_0.0003_min_lr_ratio_0.1_warmup_ratio_0.05
wandb: ⭐️ View project at https://wandb.ai/yx3038-new-york-university/llm_pruning
wandb: 🚀 View run at https://wandb.ai/yx3038-new-york-university/llm_pruning/runs/negt2tz0
I0131 11:17:44.613295 136759033735168 metric_logger.py:297] number parameters: 0.596 billion
I0131 11:17:44.614987 136707965978176 grain_pool.py:367] Grain pool will use 1 processes.
I0131 11:17:44.618239 136707965978176 grain_pool.py:440] Grain pool will start child processes.
I0131 11:17:44.620215 136707965978176 grain_pool.py:448] Grain pool started all child processes.
2026-01-31 11:17:46.845550: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-01-31 11:17:46.845879: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-31 11:17:46.882721: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-31 11:17:48.106673: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-31 11:17:48.107431: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-01-31 11:17:49.626479: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0131 11:17:58.143736 136759033735168 max_utils.py:654] 
Memstats: After params initialized:
I0131 11:17:58.144038 136759033735168 max_utils.py:660] 	Using (GB) 0.25 / 31.25 (0.800000%) on TPU_10(process=3,(2,2,0,0))
I0131 11:17:58.144209 136759033735168 max_utils.py:660] 	Using (GB) 0.25 / 31.25 (0.800000%) on TPU_11(process=3,(3,2,0,0))
I0131 11:17:58.144291 136759033735168 max_utils.py:660] 	Using (GB) 0.25 / 31.25 (0.800000%) on TPU_14(process=3,(2,3,0,0))
I0131 11:17:58.144401 136759033735168 max_utils.py:660] 	Using (GB) 0.25 / 31.25 (0.800000%) on TPU_15(process=3,(3,3,0,0))
I0131 11:18:29.464875 136759033735168 metric_logger.py:193] completed step: 1, seconds: 13.530, TFLOP/s/device: 7.743, Tokens/s/device: 1210.946, total_weights: 524224, loss: 252.466
I0131 11:18:30.696314 136759033735168 metric_logger.py:193] completed step: 2, seconds: 0.320, TFLOP/s/device: 327.449, Tokens/s/device: 51208.161, total_weights: 524224, loss: 252.398
I0131 11:18:31.927749 136759033735168 metric_logger.py:193] completed step: 3, seconds: 31.010, TFLOP/s/device: 3.378, Tokens/s/device: 528.341, total_weights: 524224, loss: 251.585
I0131 11:18:33.159628 136759033735168 metric_logger.py:193] completed step: 4, seconds: 1.231, TFLOP/s/device: 85.084, Tokens/s/device: 13305.786, total_weights: 524224, loss: 250.936
I0131 11:18:34.390431 136759033735168 metric_logger.py:193] completed step: 5, seconds: 1.231, TFLOP/s/device: 85.076, Tokens/s/device: 13304.576, total_weights: 524224, loss: 248.162
I0131 11:18:35.622099 136759033735168 metric_logger.py:193] completed step: 6, seconds: 1.232, TFLOP/s/device: 85.047, Tokens/s/device: 13300.062, total_weights: 524224, loss: 245.645
I0131 11:18:36.853450 136759033735168 metric_logger.py:193] completed step: 7, seconds: 1.230, TFLOP/s/device: 85.162, Tokens/s/device: 13318.008, total_weights: 524224, loss: 243.418
I0131 11:18:38.084757 136759033735168 metric_logger.py:193] completed step: 8, seconds: 1.232, TFLOP/s/device: 85.046, Tokens/s/device: 13299.846, total_weights: 524224, loss: 241.825
I0131 11:18:39.316017 136759033735168 metric_logger.py:193] completed step: 9, seconds: 1.231, TFLOP/s/device: 85.090, Tokens/s/device: 13306.867, total_weights: 524224, loss: 237.915
I0131 11:18:40.547462 136759033735168 metric_logger.py:193] completed step: 10, seconds: 1.231, TFLOP/s/device: 85.084, Tokens/s/device: 13305.829, total_weights: 524224, loss: 233.977
...

Wandb output:
[Screenshot: wandb run page showing the logged training metrics (captured 2026-01-31 19:29:26)]

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


google-cla bot commented Jan 30, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@Zephyr271828 Zephyr271828 marked this pull request as draft January 30, 2026 13:49
@Zephyr271828 Zephyr271828 marked this pull request as ready for review January 31, 2026 11:41