diff --git a/src/maxtext/input_pipeline/input_pipeline_utils.py b/src/maxtext/input_pipeline/input_pipeline_utils.py
index 1b9910f433..e845539d28 100644
--- a/src/maxtext/input_pipeline/input_pipeline_utils.py
+++ b/src/maxtext/input_pipeline/input_pipeline_utils.py
@@ -714,9 +714,15 @@ def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) -
     if preprocessed_image.pixel_values is None:
       raise ValueError("Input preprocessed_image must have pixel_values to pad images.")
 
+    if self.config.model_name and self.config.model_name.startswith("qwen3-omni"):
+      return preprocessed_image
+
     # Determine the maximum number of images/masks allowed.
     image_offsets = mm_processor.get_image_offsets(self.config, preprocessed_image)
-    single_image_offset = image_offsets // preprocessed_image.pixel_values.shape[0]
+    num_images = getattr(preprocessed_image, "num_images", 0)
+    if num_images <= 0:
+      num_images = preprocessed_image.pixel_values.shape[0]
+    single_image_offset = image_offsets // num_images
     # Reserve space for at least one text token.
     max_num_items = (self.max_length - 1) // single_image_offset
 
diff --git a/src/maxtext/multimodal/processor.py b/src/maxtext/multimodal/processor.py
index 02418f262e..202e73444f 100644
--- a/src/maxtext/multimodal/processor.py
+++ b/src/maxtext/multimodal/processor.py
@@ -68,6 +68,10 @@ def preprocess_image_for_training(image, model_name):
     from maxtext.multimodal.processor_llama4 import preprocess_mm_data_llama4  # pylint: disable=import-outside-toplevel
 
     return preprocess_mm_data_llama4(image)
+  elif model_name in ["qwen3-omni-30b-a3b"]:
+    from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training  # pylint: disable=import-outside-toplevel
+
+    return preprocess_mm_data_qwen3_omni_for_training(image)
   else:
     raise ValueError(f"Model {model_name} not supported for image preprocessing.")
 
diff --git a/src/maxtext/multimodal/processor_qwen3_omni.py b/src/maxtext/multimodal/processor_qwen3_omni.py
index e784b5c0c3..01fbfdf640 100644
--- a/src/maxtext/multimodal/processor_qwen3_omni.py
+++ b/src/maxtext/multimodal/processor_qwen3_omni.py
@@ -122,7 +122,7 @@ def smart_resize(
   return h_bar, w_bar
 
 
-def pre_process_qwen3_image(image: np.ndarray | list[np.ndarray], config):
+def pre_process_qwen3_image(image: np.ndarray | list[np.ndarray], config, force_resize=None):
   """Performs a bi-linear resize (with anti-aliasing) and normalizes the image."""
   patch_size = config.patch_size_for_vit
   merge_size = config.spatial_merge_size_for_vit
@@ -135,23 +135,27 @@ def pre_process_qwen3_image(image: np.ndarray | list[np.ndarray], config):
   for img in images_in:
     pil_img = Image.fromarray(img)
 
-    # Qwen3-Omni performs one resize during fetch_image and another resize before patchify.
-    resized_height_1, resized_width_1 = smart_resize(
-        height=img.shape[0],
-        width=img.shape[1],
-        factor=IMAGE_FACTOR,
-        min_pixels=MIN_PIXELS,
-        max_pixels=MAX_PIXELS,
-    )
-    pil_img = pil_img.resize((resized_width_1, resized_height_1))
-    resized_height_2, resized_width_2 = smart_resize(
-        height=resized_height_1,
-        width=resized_width_1,
-        factor=patch_size * merge_size,
-        min_pixels=MIN_PIXELS,
-        max_pixels=MAX_PIXELS,
-    )
-    resized_img_pil = pil_img.resize((resized_width_2, resized_height_2), resample=resample_method)
+    if force_resize is not None:
+      resized_height_2, resized_width_2 = force_resize
+      resized_img_pil = pil_img.resize((resized_width_2, resized_height_2), resample=resample_method)
+    else:
+      # Qwen3-Omni performs one resize during fetch_image and another resize before patchify.
+      resized_height_1, resized_width_1 = smart_resize(
+          height=img.shape[0],
+          width=img.shape[1],
+          factor=IMAGE_FACTOR,
+          min_pixels=MIN_PIXELS,
+          max_pixels=MAX_PIXELS,
+      )
+      pil_img = pil_img.resize((resized_width_1, resized_height_1))
+      resized_height_2, resized_width_2 = smart_resize(
+          height=resized_height_1,
+          width=resized_width_1,
+          factor=patch_size * merge_size,
+          min_pixels=MIN_PIXELS,
+          max_pixels=MAX_PIXELS,
+      )
+      resized_img_pil = pil_img.resize((resized_width_2, resized_height_2), resample=resample_method)
 
     resized_img_np = np.array(resized_img_pil).astype(np.float32)
     img_np = mm_utils.normalize_images(resized_img_np, mean=IMAGE_MEAN, std=IMAGE_STD)
@@ -474,6 +478,33 @@ def pre_process_audio_qwen3_omni(audio_array):
   return audio_features, audio_features_mask
 
 
+def preprocess_mm_data_qwen3_omni_for_training(images):
+  """Preprocesses image(s) for Qwen3-Omni SFT training using default model constants."""
+
+  class _DefaultConfig:
+    patch_size_for_vit = 16
+    spatial_merge_size_for_vit = 2
+    temporal_patch_size_for_vit = QWEN3_TEMPORAL_PATCH_SIZE
+
+  images_in = [images] if isinstance(images, np.ndarray) else images
+  pixel_values, pixel_grid_thw = pre_process_qwen3_image(images_in, _DefaultConfig(), force_resize=(768, 768))
+  pixel_values = np.reshape(
+      pixel_values,
+      (
+          len(images_in),
+          3,  # num_channels_for_vit
+          _DefaultConfig.temporal_patch_size_for_vit * pixel_grid_thw[0, 0],
+          _DefaultConfig.patch_size_for_vit * pixel_grid_thw[0, 1],
+          _DefaultConfig.patch_size_for_vit * pixel_grid_thw[0, 2],
+      ),
+  )
+  return Qwen3OmniPreprocessorOutput(
+      num_images=len(images_in),
+      pixel_values=pixel_values,
+      pixel_grid_thw=pixel_grid_thw,
+  )
+
+
 def preprocess_mm_data_qwen3_omni(config):
   """Placeholder for multimodal data preprocessing."""
   processor_outputs = Qwen3OmniPreprocessorOutput()