From 6c071da220174f5fc2e45c8238e778d1b5d0a9d1 Mon Sep 17 00:00:00 2001
From: hengtaoguo
Date: Wed, 13 May 2026 23:45:09 +0000
Subject: [PATCH] Add multimodal quality evaluation scripts for ChartQA

---
 benchmarks/multimodal/__init__.py        |  15 +
 benchmarks/multimodal/multimodal_eval.py | 370 +++++++++++++++++++++++
 src/maxtext/configs/types.py             |   1 +
 3 files changed, 386 insertions(+)
 create mode 100644 benchmarks/multimodal/__init__.py
 create mode 100644 benchmarks/multimodal/multimodal_eval.py

diff --git a/benchmarks/multimodal/__init__.py b/benchmarks/multimodal/__init__.py
new file mode 100644
index 0000000000..11f31009e1
--- /dev/null
+++ b/benchmarks/multimodal/__init__.py
@@ -0,0 +1,15 @@
+"""
+Copyright 2026 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/benchmarks/multimodal/multimodal_eval.py b/benchmarks/multimodal/multimodal_eval.py
new file mode 100644
index 0000000000..577e269621
--- /dev/null
+++ b/benchmarks/multimodal/multimodal_eval.py
@@ -0,0 +1,370 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This script runs a multimodal quality benchmark for a trained checkpoint.
+
+Usage:
+# Gemma3-4b on a single TPU v4-8 VM
+python3 -m benchmarks.multimodal.multimodal_eval MaxText/configs/base.yml \
+    model_name=gemma3-4b tokenizer_path=assets/tokenizer.gemma3 \
+    load_parameters_path=gs://maxtext-model-checkpoints/gemma3-4b/multimodal/2025-05-21-23-23-59/checkpoints/0/items \
+    base_output_directory=$YOUR_GCS_PATH \
+    per_device_batch_size=1 run_name=mmeval_test steps=1 async_checkpointing=false \
+    scan_layers=false use_multimodal=true attention=\'dot_product\' \
+    max_prefill_predict_length=550 max_target_length=570 \
+    hf_data_dir=HuggingFaceM4/ChartQA hf_eval_split=test
+
+# Llama4-17b-16e on a TPU v5p-128 cluster (images resized to 336x336 for simplicity)
+python -m benchmarks.multimodal.multimodal_eval \
+    MaxText/configs/base.yml model_name=llama4-17b-16e image_resize=336 \
+    tokenizer_path=meta-llama/Llama-4-Scout-17B-16E \
+    load_parameters_path=gs://maxtext-model-checkpoints/llama4-17b-16e/hybrid/2025-07-22-11-03-20/0/items \
+    base_output_directory=$YOUR_GCS_PATH \
+    per_device_batch_size=1 run_name=mmeval_test steps=1 async_checkpointing=false \
+    scan_layers=true use_multimodal=true attention=\'dot_product\' \
+    max_prefill_predict_length=350 max_target_length=370 \
+    hf_data_dir=HuggingFaceM4/ChartQA hf_eval_split=test hf_access_token=\'$YOUR_HF_ACCESS_TOKEN\' \
+    ici_fsdp_parallelism=1 ici_expert_parallelism=16 ici_tensor_parallelism=4
+"""
+
+
+import argparse
+import os
+import sys
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import List, Optional
+
+import absl.logging
+import datasets
+import jax
+import numpy as np
+import pandas as pd
+from PIL import Image
+from tqdm import tqdm
+
+from maxtext.configs import pyconfig
+from maxtext.inference.inference_utils import str2bool
+from maxtext.inference.maxengine import maxengine
+from maxtext.multimodal import processor as mm_processor
+from maxtext.trainers.post_train.rl import utils_rl
+from maxtext.utils import gcs_utils
+from maxtext.utils import max_logging
+from maxtext.utils import max_utils
+
+
+@dataclass
+class DebugConfig:
+  """Debug flags nested under `TmvpConfig.debug`."""
+
+  rl: bool = False
+
+
+@dataclass
+class TmvpConfig:
+  """Config object handed to `utils_rl.extract_answer` and `utils_rl.check_correctness`."""
+
+  # NOTE(review): assumed to be the "<answer>"/"</answer>" tags that DEFAULT_PROMPT_TEMPLATE asks the model
+  # to emit; the prompt and these tokens must stay in sync -- confirm against the RL config defaults.
+  solution_start_token: str = "<answer>"
+  solution_end_token: str = "</answer>"
+  debug: DebugConfig = field(default_factory=DebugConfig)
+
+
+absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log
+
+
+ASCII_UPPERCASE_A = ord("A")  # ASCII value for uppercase 'A'
+SUPPORTED_DATASETS = ["HuggingFaceM4/ChartQA"]
+
+# To guide any ckpts converted from HF to answer in the desired format, use a default prompt template
+DEFAULT_PROMPT_TEMPLATE = """You are an expert at answering questions based
+on provided charts. Your task is to extract the exact answer from the
+given context or determine that it's not present.
+For numerical answers, provide only the number.
+For text answers, provide only the exact text.
+For judgement questions, respond with "Yes" or "No".
+If not found, output "N/A".
+Your output must be only the exact answer within <answer></answer>, with no extra contents.
+
+Example:
+Question: What is the capital of France?
+Your answer: <answer>Paris</answer>
+
+Chart: {image_placeholder} Question: {question}
+"""
+
+# For MaxText SFT ckpts, use a simpler prompt (aligned with input_pipeline_utils.reformat_prompt)
+SFT_PROMPT_TEMPLATE = "{image_placeholder}{question}"
+
+
+@dataclass
+class ParsedDatasetExample:
+  """Parsed example from the HuggingFace dataset."""
+
+  question: Optional[str] = None
+  image_np: Optional[np.ndarray] = None
+  choices: Optional[List[str]] = None
+  answer: Optional[str] = None
+
+
+def parse_dataset_example(example, hf_dataset_name, config):
+  """Parse a single example from the HuggingFace dataset."""
+  parsed_example = ParsedDatasetExample()
+  if hf_dataset_name == "HuggingFaceM4/ChartQA":
+    parsed_example.question = example["query"]
+    parsed_example.image_np = np.asarray(example["image"].convert("RGB"))  # Convert PIL object to np array
+    parsed_example.answer = example["label"][0]
+  else:
+    raise ValueError(f"Unsupported dataset: {hf_dataset_name}")
+
+  # Resize the image if specified. This helps simplify the llama4's tiling, so we have a fixed input size
+  if getattr(config, "image_resize", -1) != -1:
+    pil_img = Image.fromarray(parsed_example.image_np)
+    pil_img = pil_img.resize((config.image_resize, config.image_resize))
+    parsed_example.image_np = np.asarray(pil_img.convert("RGB"))
+
+  return parsed_example
+
+
+def construct_prompt(
+    parsed_dataset_example: ParsedDatasetExample, config, local_args, system_message: Optional[str] = None
+):
+  """Construct prompt from a parsed dataset example."""
+  # image_placeholder = multimodal_utils.get_image_placeholder(config.model_name) if config.use_multimodal else ""
+  image_placeholder = config.image_placeholder
+  choices_text = (
+      "\n".join(f"{chr(ASCII_UPPERCASE_A + idx)}. {choice}" for idx, choice in enumerate(parsed_dataset_example.choices))
+      if parsed_dataset_example.choices
+      else ""
+  )
+  if local_args.ckpt_type == "base":
+    prompt = DEFAULT_PROMPT_TEMPLATE.format(
+        image_placeholder=image_placeholder,
+        question=parsed_dataset_example.question,
+        choices=choices_text if choices_text else "N/A",
+    )
+  elif local_args.ckpt_type == "sft":
+    prompt = mm_processor.reformat_prompt(
+        parsed_dataset_example.question,
+        image_placeholder,
+        config.model_name,
+        num_images=1 if config.use_multimodal else 0,
+    )
+  else:
+    raise ValueError(f"Unsupported ckpt_type: {local_args.ckpt_type}")
+
+  prompt = system_message + "\n\n" + prompt if system_message else prompt
+  return prompt
+
+
+def main(config, local_args):
+  """Decode an answer for every dataset example, score it, and write/upload the results CSV."""
+  engine = maxengine.MaxEngine(config)
+  params = engine.load_params()
+
+  metadata = engine.get_tokenizer()
+  tokenizer = engine.build_tokenizer(metadata)
+
+  max_prefill_predict_length = getattr(config, "max_prefill_predict_length", 1024)
+  max_target_length = getattr(config, "max_target_length", 2048)
+
+  # Loop-invariant config checks: run them once up front instead of once per example.
+  assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
+  assert config.quantization != "nanoo_fp8", "NANOO fp8 on AMD MI300/MI325 GPUs is not supported in decode.py yet"
+
+  # Initialize counters for overall accuracy
+  correct_count = 0
+  total_count = 0
+
+  # Get the HuggingFace dataset path and name from the config
+  hf_data_dir = config.hf_data_dir
+  hf_eval_split = config.hf_eval_split
+
+  # Config for saving csv results
+  timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+  results_file_name = local_args.tmp_results_file
+  result_gcs_path = f"{config.base_output_directory}/{timestamp}.csv" if config.base_output_directory else None
+  max_logging.log(f"Results will be saved to {results_file_name} and uploaded to GCS: {result_gcs_path}")
+  results_data = []
+  tmvp_config = TmvpConfig()
+
+  test_ds = datasets.load_dataset(hf_data_dir, "default", split=hf_eval_split)
+  for example in tqdm(test_ds, desc=f"Evaluating {hf_data_dir} dataset"):
+    parsed_dataset_example = parse_dataset_example(example, hf_data_dir, config)
+    prompt = construct_prompt(parsed_dataset_example, config, local_args)
+    processor_output = mm_processor.preprocess_image_for_training(parsed_dataset_example.image_np, config.model_name)
+    image_offsets = mm_processor.get_image_offsets(config=config, processor_output=processor_output)
+    # Shrink the text prefill budget by the image offset (computed once and reused below).
+    prefill_length = max_prefill_predict_length - image_offsets
+    print("\n" + "*" * 50)
+
+    # Tokenize the input
+    tokens, true_length = tokenizer.encode(prompt, is_bos=True, prefill_lengths=[prefill_length])
+    if config.use_multimodal:
+      tokens = mm_processor.prepare_text_for_image_fusion(tokens=tokens, config=config, processor_output=processor_output)
+      true_length += image_offsets
+    if true_length > max_prefill_predict_length:
+      max_logging.log(
+          f"Warning: Prompt length {true_length} exceeds max prefill length {max_prefill_predict_length}. Truncating."
+      )
+      tokens = tokens[:max_prefill_predict_length]
+      true_length = max_prefill_predict_length
+
+    # Perform prefill
+    prefill_result, first_token = engine.prefill(
+        params=params, padded_tokens=tokens, images=processor_output.pixel_values, true_length=true_length
+    )
+    slot = 0
+
+    # Initialize decode state
+    decode_state = engine.init_decode_state()
+    decode_state = engine.insert(prefill_result, decode_state, slot=slot)
+
+    steps = range(max_prefill_predict_length, max_target_length)
+    sampled_tokens = [first_token.get_result_at_slot(slot).tokens.item()]
+
+    # Decode before the loop and after every new token, so `output` is always bound (even when
+    # `steps` is empty) and always reflects the most recently generated token.
+    output = tokenizer.decode(sampled_tokens)
+    predicted_answer = utils_rl.extract_answer(output, tmvp_config)
+
+    for _ in steps:
+      # Stop as soon as an answer has been extracted
+      if predicted_answer != utils_rl.FALLBACK_ANSWER:
+        break
+
+      # Generate next token
+      decode_state, sampled_token = engine.generate(params, decode_state)
+      next_token = sampled_token.get_result_at_slot(slot).tokens.item()
+      if next_token == tokenizer.eos_id:
+        break
+      sampled_tokens.append(next_token)
+      output = tokenizer.decode(sampled_tokens)
+      predicted_answer = utils_rl.extract_answer(output, tmvp_config)
+
+    correct_answer = parsed_dataset_example.answer
+    exact_correct, _ = utils_rl.check_correctness(predicted_answer, [correct_answer], tmvp_config)
+    is_correct = exact_correct
+
+    # Log answer
+    max_logging.log(
+        f"{total_count + 1} | {parsed_dataset_example.question}\n"
+        f"[Model output] {output}\n"
+        f"[Label answer] {correct_answer}\n"
+        f"Matching: {is_correct}"
+    )
+
+    # Save results for CSV
+    results_data.append(
+        {
+            "question ID": total_count + 1,
+            "question": parsed_dataset_example.question,
+            "label": parsed_dataset_example.answer,
+            "output": output,
+            "is_correct": is_correct,
+        }
+    )
+
+    # Update accuracy for overall
+    if is_correct:
+      correct_count += 1
+    total_count += 1
+    max_logging.log(f"Running accuracy: {correct_count / (total_count):.4f} | Processed: {total_count}/{len(test_ds)}")
+
+    if local_args.num_examples != -1 and total_count >= local_args.num_examples:
+      break
+
+    # Every 100 rows, save intermediate results to CSV and upload to GCS
+    if total_count % 100 == 0 and result_gcs_path is not None and jax.process_index() == 0:
+      results_df = pd.DataFrame(results_data)
+      results_df.to_csv(results_file_name, index=False)
+      gcs_utils.upload_blob(result_gcs_path, results_file_name)
+      max_logging.log(f"Uploaded the results file to GCS bucket: {result_gcs_path}")
+
+  # Final accuracy
+  if total_count > 0:
+    accuracy = correct_count / total_count
+    max_logging.log(f"\nFinal accuracy on {hf_data_dir} dataset: {accuracy:.4f}")
+  else:
+    max_logging.log("No valid predictions were made.")
+
+  # Save predictions to CSV and upload to GCS
+  if result_gcs_path is not None and jax.process_index() == 0:
+    results_df = pd.DataFrame(results_data)
+    results_df.to_csv(results_file_name, index=False)
+    max_logging.log(f"Saved predictions to {results_file_name}")
+    gcs_utils.upload_blob(result_gcs_path, results_file_name)
+    max_logging.log(f"Uploaded the results file to GCS bucket: {result_gcs_path}")
+
+  if local_args.remove_tmp_results and os.path.exists(results_file_name):
+    os.remove(results_file_name)
+    max_logging.log(f"Removed temporary results file: {results_file_name}")
+
+
+def validate_config(config):
+  """Validate the config fields this benchmark requires; raises AssertionError on a bad config."""
+  assert not config.load_full_state_path, (
+      "Decode doesn't operate on full states! Convert to parameter checkpoint"
+      " first. Using generate_param_only_checkpoint."
+  )
+  assert (
+      config.hf_data_dir
+  ), "For benchmark evaluation, please specify the HuggingFace dataset name using the hf_data_dir config field."
+  assert config.hf_data_dir in SUPPORTED_DATASETS, (
+      f"Unsupported dataset {config.hf_data_dir}. Supported datasets are: {SUPPORTED_DATASETS}."
+      " Please add support for your desired dataset in the code of multimodal_eval.py."
+  )
+
+
+if __name__ == "__main__":
+  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      "--num_examples", type=int, required=False, default=-1, help="Number of examples to evaluate. Default to -1 (all)."
+  )
+  parser.add_argument(
+      "--tmp_results_file",
+      type=str,
+      required=False,
+      default="mm_eval_results.csv",
+      help="Temporary results CSV file path.",
+  )
+  parser.add_argument(
+      "--remove_tmp_results",
+      type=str2bool,
+      required=False,
+      default=True,
+      help="Whether to remove the temporary results CSV file after uploading to GCS.",
+  )
+  parser.add_argument(
+      "--ckpt_type",
+      type=str,
+      required=False,
+      default="base",
+      choices=["base", "sft"],
+      help=(
+          "Checkpoint type: 'base' (uses DEFAULT_PROMPT_TEMPLATE) or 'sft' (uses"
+          " SFT_PROMPT_TEMPLATE with model-specific reformat_prompt)."
+      ),
+  )
+
+  _local_args, remaining_args = parser.parse_known_args()
+  model_args = [sys.argv[0]] + remaining_args
+
+  cfg = pyconfig.initialize(model_args)
+  validate_config(cfg)
+  max_utils.print_system_information()
+  main(cfg, _local_args)
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 20594bccc3..0ab5495ca2 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -1809,6 +1809,7 @@ class MultimodalGeneral(BaseModel):
   use_mrope: bool = Field(False, description="Enable Multi-dimensional RoPE for Qwen3-Omni models.")
   mrope_section: list[int] = Field([24, 20, 20], description="Dimensions for temporal, height, width in MRoPE.")
   position_id_per_seconds: int = Field(25, description="Temporal granularity for MRoPE (tokens per second).")
+  image_resize: int = Field(-1, description="Resize images for simpler multimodal decoding; -1 disables resizing.")
 
 
 class VisionTower(BaseModel):