From f05d67f9fdfe524ce54dff47d32c19b1f8dbe812 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 23 Jan 2026 15:58:41 +0100 Subject: [PATCH] draft --- onnx_diagnostic/tasks/__init__.py | 2 + onnx_diagnostic/tasks/image_to_image.py | 586 ++++++++++++++++++ .../torch_models/hghub/hub_data.py | 6 +- 3 files changed, 593 insertions(+), 1 deletion(-) create mode 100644 onnx_diagnostic/tasks/image_to_image.py diff --git a/onnx_diagnostic/tasks/__init__.py b/onnx_diagnostic/tasks/__init__.py index f8f9663f..e12d73ae 100644 --- a/onnx_diagnostic/tasks/__init__.py +++ b/onnx_diagnostic/tasks/__init__.py @@ -5,6 +5,7 @@ fill_mask, image_classification, image_text_to_text, + image_to_image, image_to_video, mask_generation, mixture_of_expert, @@ -24,6 +25,7 @@ fill_mask, image_classification, image_text_to_text, + image_to_image, image_to_video, mask_generation, mixture_of_expert, diff --git a/onnx_diagnostic/tasks/image_to_image.py b/onnx_diagnostic/tasks/image_to_image.py new file mode 100644 index 00000000..e425e0bd --- /dev/null +++ b/onnx_diagnostic/tasks/image_to_image.py @@ -0,0 +1,586 @@ +import itertools +from typing import Any, Callable, Dict, Optional, Tuple +import torch +from ..helpers.cache_helper import make_dynamic_cache, get_make_hybrid_cache +from ..helpers.config_helper import ( + update_config, + check_hasattr, + _pick, + default_num_hidden_layers as nhl, +) +from .data import get_data + +__TASK__ = "image-to-image" + + +def reduce_model_config(config: Any) -> Dict[str, Any]: + """Reduces a model size.""" + kwargs: Dict[str, Any] = {} + if ( + hasattr(config, "architectures") + and config.architectures + and config.architectures[0] == "Gemma3ForConditionalGeneration" + ): + if hasattr(config, "vision_config"): + if hasattr(config.vision_config, "num_hidden_layers"): + config.vision_config.num_hidden_layers = min( + config.vision_config.num_hidden_layers, nhl() + ) + if hasattr(config, "text_config"): + if hasattr(config.text_config, "intermediate_size"): + config.text_config.intermediate_size = min( + config.text_config.intermediate_size, 10240 // 10 * 5 // 2 + ) + config.text_config.hidden_size = min( + config.text_config.hidden_size, 2560 // 10 * 5 // 2 + ) + update_config(config, kwargs) + return kwargs + + if hasattr(config, "num_hidden_layers"): + config.num_hidden_layers = min(config.num_hidden_layers, nhl()) + if hasattr(config, "mm_tokens_per_image"): + config.mm_tokens_per_image = min(config.mm_tokens_per_image, 2) + if hasattr(config, "vision_config"): + if hasattr(config.vision_config, "num_hidden_layers"): + config.vision_config.num_hidden_layers = min( + config.vision_config.num_hidden_layers, 2 + ) + if hasattr(config.vision_config, "num_heads"): + config.vision_config.num_heads = min(config.vision_config.num_heads, 4) + if hasattr(config.vision_config, "image_size"): + config.vision_config.image_size = min(config.vision_config.image_size, 168 // 2) + if hasattr(config.vision_config, "intermediate_size"): + config.vision_config.intermediate_size = min( + config.vision_config.intermediate_size, 1076 + ) + if hasattr(config.vision_config, "patch_size"): + config.vision_config.patch_size = min(config.vision_config.patch_size, 1) + if hasattr(config.vision_config, "temporal_patch_size"): + config.vision_config.temporal_patch_size = min( + config.vision_config.temporal_patch_size, 8 + ) + if hasattr(config.vision_config, "hidden_size"): + config.vision_config.hidden_size = min(config.vision_config.hidden_size, 16) + if hasattr(config, "text_config"): 
+        if hasattr(config.text_config, "intermediate_size"):
+            config.text_config.intermediate_size = min(
+                config.text_config.intermediate_size, 320
+            )
+        if hasattr(config.text_config, "hidden_size"):
+            config.text_config.hidden_size = min(config.text_config.hidden_size, 16)
+        if hasattr(config.text_config, "num_hidden_layers"):
+            config.text_config.num_hidden_layers = min(config.text_config.num_hidden_layers, 2)
+        if hasattr(config.text_config, "layer_types"):
+            config.text_config.layer_types = config.text_config.layer_types[
+                : config.text_config.num_hidden_layers
+            ]
+        if hasattr(config.text_config, "num_attention_heads"):
+            config.text_config.num_attention_heads = min(
+                config.text_config.num_attention_heads, 2
+            )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def _get_inputs_gemma3(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: Optional[int] = 1,
+    sequence_length: Optional[int] = 281,
+    n_images: Optional[int] = 1,
+    max_sequence_length: Optional[int] = 580,
+    total_sequence_length: Optional[int] = 860,
+    **kwargs,  # unused
+):
+    """
+    The function uses predefined values for ``input_ids`` and ``token_type_ids``.
+
+    **google/gemma-3-4b-it**
+
+    iteration 1
+
+    ::
+
+        cache_position:T7s281,
+        input_ids:T7s1x281,
+        token_type_ids:T7s1x281,
+        attention_mask:dict(sliding_attention:T9s1x1x281x580,
+        full_attention:T9s1x1x281x580),
+        pixel_values:T16s1x3x896x896,
+
+    iteration 2
+
+    ::
+
+        cache_position:T7s1,
+        past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
+        value_cache=#34[T1s1x4x580x256,...]),
+        input_ids:T7s1x1,
+        inputs_embeds:None,
+        token_type_ids:T7s1x1,
+        attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
+        position_ids:None,
+    """
+    batch_size = 1 if batch_size is None else batch_size
+    sequence_length = 281 if sequence_length is None else sequence_length
+    n_images = 1 if n_images is None else n_images
+    max_sequence_length = 580 if max_sequence_length is None else max_sequence_length
+    total_sequence_length = 860 if total_sequence_length is None else total_sequence_length
+
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
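+    # Note (comment added for clarity, not part of the original draft): "batch",
+    # "seq_length" and "total_length" are the dynamic axis names exposed through
+    # ``dynamic_shapes``; "total_length" is the last dimension of the attention masks.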
+ batch = "batch" + seq_length = "seq_length" + tot_length = "total_length" + + shapes = { + "input_ids": {0: batch, 1: seq_length}, + "token_type_ids": {0: batch, 1: seq_length}, + "attention_mask": { + "full_attention": {0: batch, 2: seq_length, 3: tot_length}, + "sliding_attention": {0: batch, 2: seq_length, 3: tot_length}, + }, + "position_ids": {0: batch, 1: seq_length}, + "cache_position": {0: seq_length}, + "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)], + "pixel_values": {0: batch}, + "use_cache": None, + } + + # retrieve specific inputs to keep the consistency between + # ids and images + dummies = get_data("dummies_imagetext2text_generation_gemma3.onnx") + dummies = dummies[("", 0, "I")][1] + dummies = {k: v for k, v in dummies.items() if k in shapes} + expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"} + + def _check_(): + assert expected & set( + dummies + ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}" + # assert sequence_length == dummies["input_ids"].shape[-1], ( + # f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for " + # f"model class {model.__class__.__name__}" + # ) + assert batch_size == dummies["input_ids"].shape[0], ( + f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for " + f"model class {model.__class__.__name__}" + ) + assert max_sequence_length == 580, ( + f"max_sequence_length={max_sequence_length} != 580 " + f"for model {model.__class__.__name__}" + ) + assert total_sequence_length == 860, ( + f"total_sequence_length={total_sequence_length} != 860 " + f"for model {model.__class__.__name__}" + ) + assert head_dim in ( + 256, + 32, + ), f"head_dim={head_dim} not in (32, 256) for model {model.__class__.__name__}" + assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}" + assert num_key_value_heads in (1, 4), ( + f"num_key_value_heads={num_key_value_heads} not in (1, 4) " + f"for this model {model.__class__.__name__}" + ) + + _check_() + + make_hybrid_cache = get_make_hybrid_cache() + assert make_hybrid_cache is not None, "not implemented when make_hybrid_cache is missing" + + inputs = dict( + input_ids=dummies["input_ids"], + token_type_ids=dummies["token_type_ids"], + attention_mask=dict( + full_attention=torch.randn(batch_size, 1, sequence_length, total_sequence_length), + sliding_attention=torch.randn( + batch_size, 1, sequence_length, total_sequence_length + ), + ), + position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)), + cache_position=torch.arange(0, sequence_length).to(torch.int64), + past_key_values=make_hybrid_cache( + [ + ( + torch.randn( + batch_size, num_key_value_heads, max_sequence_length, head_dim + ), + torch.randn( + batch_size, num_key_value_heads, max_sequence_length, head_dim + ), + ) + for i in range(num_hidden_layers) + ] + ), + pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1), + # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to( + # torch.int64 + # ), + use_cache=True, # Gemma3 does not set this value to true when a cache is provided + ) + return dict(inputs=inputs, dynamic_shapes=shapes) + + +def get_inputs_default( + model: torch.nn.Module, + config: Optional[Any], + dummy_max_token_id: int, + num_key_value_heads: int, + num_hidden_layers: int, + pad_token_id: int, + image_token_index: int, + head_dim: int, + width: int, + height: int, + num_channels: int, + batch_size: Optional[int] = 2, + 
sequence_length: Optional[int] = 43, + n_images: Optional[int] = 2, + max_sequence_length: Optional[int] = 43, + total_sequence_length: Optional[int] = 43, + add_second_input: int = 0, + **kwargs, # unused +): + batch_size = 2 if batch_size is None else batch_size + sequence_length = 43 if sequence_length is None else sequence_length + n_images = 2 if n_images is None else n_images + max_sequence_length = 43 if max_sequence_length is None else max_sequence_length + total_sequence_length = 43 if total_sequence_length is None else total_sequence_length + + assert batch_size > 0, "batch_size cannot be null" + assert ( + "cls_cache" not in kwargs + ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}." + batch = "batch" + batch_img = torch.export.Dim("batch_img", min=1, max=1024) + seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) + cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) + images = "images" # torch.export.Dim("images", min=1, max=4096) + + shapes = { + "input_ids": {0: batch, 1: seq_length}, + "token_type_ids": {0: batch, 1: seq_length}, + "attention_mask": {0: batch, 1: "cache+seq"}, + "position_ids": {0: batch, 1: seq_length}, + "past_key_values": list( + itertools.chain.from_iterable( + zip( + [{0: batch} for _ in range(num_hidden_layers)], + [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + ) + ) + ), + "pixel_values": ( + {0: batch, 1: images} + if model.__class__.__name__ == "IdeficsForVisionText2Text" + else {0: batch_img} + ), + "image_attention_mask": {0: batch, 1: seq_length, 2: images}, + "image_grid_thw": {0: batch}, + "use_cache": None, + } + + input_ids = torch.randint(0, dummy_max_token_id, (batch_size, total_sequence_length)).to( + torch.int64 + ) + if total_sequence_length > 0: + input_ids[0, 0] = image_token_index + if min(input_ids.shape) > 1: + input_ids[1, 1] = image_token_index + # input_ids[input_ids == image_token_index] = pad_token_id + token_type_ids = torch.zeros_like(input_ids) + token_type_ids[input_ids == image_token_index] = 1 + image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64) + if n_images > 0: + image_grid_thw[:, 1] = height + image_grid_thw[:, 2] = width + image_grid_thw[0, :] //= 2 + image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype) + + inputs = dict( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=torch.cat( + [ + torch.ones((batch_size, sequence_length), dtype=torch.int64), + ( + torch.ones(input_ids.shape) + if pad_token_id is None + else input_ids.ne(pad_token_id) + ).to(torch.int64), + ], + axis=-1, + ), + position_ids=torch.arange(0, total_sequence_length) + .to(torch.int64) + .expand((batch_size, -1)), + past_key_values=make_dynamic_cache( + [ + ( + torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim), + torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim), + ) + for i in range(num_hidden_layers) + ] + ), + pixel_values=( + torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1) + if model.__class__.__name__ == "IdeficsForVisionText2Text" + else torch.randn(n_images, num_channels, width, height).clamp(-1, 1) + ), + image_attention_mask=torch.ones((batch_size, total_sequence_length, n_images)).to( + torch.int64 + ), + image_grid_thw=image_grid_thw, + use_cache=True, # Gemma3 does not set this value to true when a cache is provided + ) + res = dict(inputs=inputs, dynamic_shapes=shapes) + return res + + +def get_inputs( + model: 
torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: Optional[int] = None,
+    sequence_length: Optional[int] = None,
+    n_images: Optional[int] = None,
+    max_sequence_length: Optional[int] = None,
+    total_sequence_length: Optional[int] = None,
+    add_second_input: int = 0,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``image-to-image``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param head_dim: last dimension of the cache
+    :param num_key_value_heads: number of key/value heads in the dummy cache
+    :param num_hidden_layers: number of layers in the dummy cache
+    :param dummy_max_token_id: dummy max token id
+    :param pad_token_id: pad_token_id
+    :param image_token_index: image_token_index
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param max_sequence_length: for the cache
+    :param total_sequence_length: for the mask
+    :param n_images: number of images
+    :param width: width of the image
+    :param height: height of the image
+    :param num_channels: number of channels
+    :param add_second_input: if not null, adds a second set of inputs ``inputs2``
+        with different dimensions
+    :return: dictionary
+
+    .. note::
+
+        The content of ``input_ids`` and its shape are correlated to the images.
+        The function uses predefined values and raises an exception
+        if the dimensions are not the expected ones.
+    """
+    if model.__class__.__name__.startswith("Gemma3"):
+        res = _get_inputs_gemma3(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
+            n_images=n_images,
+            **kwargs,
+        )
+    else:
+        res = get_inputs_default(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
+            n_images=n_images,
+            **kwargs,
+        )
+
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=3,
+            sequence_length=1,
+            max_sequence_length=1,
+            total_sequence_length=1,
+            n_images=0,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+ """ + if config is not None: + if hasattr(config, "text_config"): + check_hasattr( + config.text_config, + "vocab_size", + "hidden_size", + "num_attention_heads", + ("num_key_value_heads", "num_attention_heads"), + "intermediate_size", + "hidden_size", + "pad_token_id", + ) + text_config = True + else: + check_hasattr( + config, + "vocab_size", + "hidden_size", + "num_attention_heads", + ("num_key_value_heads", "num_attention_heads"), + "intermediate_size", + "hidden_size", + "vision_config", + ) + text_config = False + kwargs = dict( + head_dim=( + 16 + if config is None + else getattr( + config, + "head_dim", + ( + config.text_config.head_dim + if text_config and hasattr(config.text_config, "head_dim") + else ( + (config.text_config.hidden_size if text_config else config.hidden_size) + // ( + config.text_config.num_attention_heads + if text_config + else config.num_attention_heads + ) + ) + ), + ) + ), + dummy_max_token_id=( + 31999 + if config is None + else (config.text_config.vocab_size if text_config else config.vocab_size) - 1 + ), + num_hidden_layers=( + 4 + if config is None + else ( + config.text_config.num_hidden_layers + if text_config + else config.num_hidden_layers + ) + ), + num_key_value_heads=( + 8 + if config is None + else ( + _pick(config.text_config, "num_key_value_heads", "num_attention_heads") + if text_config + else _pick(config, "num_key_value_heads", "num_attention_heads") + ) + ), + intermediate_size=( + 1024 + if config is None + else ( + config.text_config.intermediate_size + if text_config + else config.intermediate_size + ) + ), + hidden_size=( + 512 + if config is None + else (config.text_config.hidden_size if text_config else config.hidden_size) + ), + width=( + 224 + if config is None or not hasattr(config.vision_config, "image_size") + else config.vision_config.image_size + ), + height=( + 224 + if config is None or not hasattr(config.vision_config, "image_size") + else config.vision_config.image_size + ), + num_channels=( + 3 + if config is None + else _pick(config.vision_config, "num_channels", "in_chans", "in_channels") + ), + pad_token_id=( + 0 + if config is None + or not hasattr(config, "text_config") + or not hasattr(config.text_config, "pad_token_id") + else config.text_config.pad_token_id + ), + image_token_index=( + 4 + if config is None + or ( + not hasattr(config, "image_token_index") + and not hasattr(config, "image_token_id") + ) + else _pick(config, "image_token_index", "image_token_id") + ), + ) + return kwargs, get_inputs diff --git a/onnx_diagnostic/torch_models/hghub/hub_data.py b/onnx_diagnostic/torch_models/hghub/hub_data.py index 1f81980c..d33344be 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data.py @@ -165,7 +165,10 @@ YolosModel,image-feature-extraction Alibaba-NLP/gte-large-en-v1.5,sentence-similarity emilyalsentzer/Bio_ClinicalBERT,fill-mask - nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video""") + nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video + diffusers/FLUX.2-dev-bnb-4bit,image-to-image + black-forest-labs/FLUX.2-dev,image-to-image + optimum-intel-internal-testing/tiny-random-stable-diffusion-xl-refiner,image-to-image""") __data_tasks__ = [ "audio-classification", @@ -177,6 +180,7 @@ "image-feature-extraction", "image-segmentation", "image-text-to-text", + "image-to-image", "image-to-text", "keypoint-detection", "mask-generation",
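
Reviewer note (not part of the patch): a minimal usage sketch of the new task module,
following the pattern shared by the other modules in onnx_diagnostic.tasks, where
random_input_kwargs returns the keyword arguments together with the get_inputs callable.
The placeholder model below is hypothetical; with config=None the function falls back
to its typical dimensions.

    import torch
    from onnx_diagnostic.tasks import image_to_image

    class DummyImageToImageModel(torch.nn.Module):
        # placeholder: any class whose name does not start with "Gemma3"
        # is routed to get_inputs_default
        def forward(self, **kwargs):
            return kwargs

    # kwargs holds the dummy dimensions, fn is get_inputs
    kwargs, fn = image_to_image.random_input_kwargs(None)
    data = fn(DummyImageToImageModel(), None, **kwargs)
    inputs, shapes = data["inputs"], data["dynamic_shapes"]
    print(sorted(inputs), sorted(shapes))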