diff --git a/bailingmm_utils.py b/bailingmm_utils.py index a0d0e81..c97da4c 100644 --- a/bailingmm_utils.py +++ b/bailingmm_utils.py @@ -41,6 +41,17 @@ FPS_MIN_FRAMES = 4 FPS_MAX_FRAMES = 128 +VideoInput = Union[ + List["Image.Image"], + "np.ndarray", + "torch.Tensor", + List["np.ndarray"], + List["torch.Tensor"], + List[List["Image.Image"]], + List[List["np.ndarrray"]], + List[List["torch.Tensor"]], +] + def is_decord_available() -> bool: import importlib.util return importlib.util.find_spec("decord") is not None @@ -504,3 +515,59 @@ def process_vision_info( if len(audio_inputs) == 0: audio_inputs = None return image_inputs, video_inputs, audio_inputs + +def get_closest_ratio(height: float, width: float, aspect_ratios: dict): + aspect_ratio = height / width + closest_ratio = min(aspect_ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return aspect_ratios[closest_ratio], float(closest_ratio) + +def process_ratio(ori_h, ori_w): + ASPECT_RATIO_512 = { + "0.25": [256, 1024], + "0.26": [256, 992], + "0.27": [256, 960], + "0.28": [256, 928], + "0.32": [288, 896], + "0.33": [288, 864], + "0.35": [288, 832], + "0.4": [320, 800], + "0.42": [320, 768], + "0.48": [352, 736], + "0.5": [352, 704], + "0.52": [352, 672], + "0.57": [384, 672], + "0.6": [384, 640], + "0.68": [416, 608], + "0.72": [416, 576], + "0.78": [448, 576], + "0.82": [448, 544], + "0.88": [480, 544], + "0.94": [480, 512], + "1.0": [512, 512], + "1.07": [512, 480], + "1.13": [544, 480], + "1.21": [544, 448], + "1.29": [576, 448], + "1.38": [576, 416], + "1.46": [608, 416], + "1.67": [640, 384], + "1.75": [672, 384], + "2.0": [704, 352], + "2.09": [736, 352], + "2.4": [768, 320], + "2.5": [800, 320], + "2.89": [832, 288], + "3.0": [864, 288], + "3.11": [896, 288], + "3.62": [928, 256], + "3.75": [960, 256], + "3.88": [992, 256], + "4.0": [1024, 256], + } + closest_size, _ = get_closest_ratio(ori_h, ori_w, aspect_ratios=ASPECT_RATIO_512) + closest_size = list(map(lambda x: int(x), closest_size)) + if closest_size[0] / ori_h > closest_size[1] / ori_w: + resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h) + else: + resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1] + return closest_size, resize_size \ No newline at end of file diff --git a/image_processing_bailingmm2.py b/image_processing_bailingmm2.py index 382383a..8f3030d 100644 --- a/image_processing_bailingmm2.py +++ b/image_processing_bailingmm2.py @@ -30,7 +30,7 @@ resize, to_channel_dimension_format, ) -from transformers.video_utils import VideoInput +from bailingmm_utils import VideoInput from transformers.image_utils import ( OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, diff --git a/modeling_bailing_moe_v2.py b/modeling_bailing_moe_v2.py index e87609f..16dd079 100644 --- a/modeling_bailing_moe_v2.py +++ b/modeling_bailing_moe_v2.py @@ -1561,6 +1561,8 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, second_per_grid_ts: Optional[torch.Tensor] = None, + image_mask=None, + audio_mask=None, **kwargs, ) -> Union[Tuple, MoeModelOutputWithPast]: @@ -1595,7 +1597,12 @@ def forward( # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + assert input_ids.size(1) == inputs_embeds.size(1), "{} vs {}".format( + input_ids.size, + inputs_embeds.size, + ) + batch_size, seq_length = inputs_embeds.shape[:2] + #raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: @@ -1810,6 +1817,8 @@ def forward( return_dict: Optional[bool] = None, second_per_grid_ts: Optional[torch.Tensor] = None, num_logits_to_keep: Optional[int] = 0, + image_mask=None, + audio_mask=None, **kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -1865,6 +1874,8 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, second_per_grid_ts=second_per_grid_ts, + image_mask=image_mask, + audio_mask=audio_mask, **kwargs, ) diff --git a/modeling_bailingmm2.py b/modeling_bailingmm2.py index 33a87f8..73e10ce 100644 --- a/modeling_bailingmm2.py +++ b/modeling_bailingmm2.py @@ -18,7 +18,8 @@ from configuration_bailingmm2 import BailingMM2Config from modeling_bailing_moe_v2 import BailingMoeV2ForCausalLM from modeling_utils import Transpose, encode_audio_segments, patch_continuous_features, build_modality_mask - +from bailingmm_utils import process_ratio +import os # vision encoder from qwen2_5_vit import Qwen2_5_VisionTransformer @@ -38,6 +39,7 @@ class BailingMM2NativeForConditionalGeneration(PreTrainedModel): def __init__( self, config: BailingMM2Config, + empty_load=False, ): super().__init__(config) self.config: BailingMM2Config = config @@ -45,6 +47,9 @@ def __init__( self.llm_dytpe = torch.bfloat16 + if empty_load: + return + if self.config.vision_config: self.vision = Qwen2_5_VisionTransformer(self.config.vision_config) @@ -120,9 +125,81 @@ def generate( video_grid_thw: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.Tensor]] = None, num_logits_to_keep: Optional[int] = 0, + image_gen: Optional[bool] = False, + image_gen_pixel_values_reference: Optional[torch.FloatTensor] = None, + image_gen_negative_input_ids: Optional[torch.LongTensor] = None, + image_gen_negative_attention_mask: Optional[torch.Tensor] = None, + image_gen_steps: Optional[int] = 30, + image_gen_seed: Optional[int] = 42, + image_gen_cfg: Optional[float] = 5.0, + image_gen_image_cfg: Optional[float] = 2.0, + image_gen_cfg_mode: Optional[int] = 1, + image_gen_height: Optional[int] = 512, + image_gen_width: Optional[int] = 512, + image_gen_llm_hidden_states: Optional[torch.LongTensor] = None, + image_gen_negative_llm_hidden_states: Optional[torch.LongTensor] = None, **generate_kwargs, ): image_embeds, video_embeds, audio_embeds, audio_embeds_lengths = None, None, None, None + + if image_gen: + if image_gen_llm_hidden_states is None: + assert self.model is not None + assert self.vision is not None + if pixel_values is not None: + image_embeds = self.extract_image_feature(pixel_values, grid_thw=image_grid_thw) + assert self.loaded_image_gen_modules is True, "please add `load_image_gen=True` in from_pretrained() method" + assert position_ids is None + condition_embeds = self.get_condition_embeds_for_image_gen( + input_ids=input_ids, + attention_mask=attention_mask, + image_embeds=image_embeds, + position_ids=position_ids, + use_cache=use_cache, + image_grid_thw=image_grid_thw, + llm_hidden_states=image_gen_llm_hidden_states, + ) + negative_condition_embeds = self.get_condition_embeds_for_image_gen( + input_ids=image_gen_negative_input_ids, + attention_mask=image_gen_negative_attention_mask, + image_embeds=image_embeds, + position_ids=position_ids, + use_cache=use_cache, + image_grid_thw=image_grid_thw, + llm_hidden_states=image_gen_negative_llm_hidden_states, + ) if image_gen_negative_input_ids is not None else condition_embeds * 0.0 + if isinstance(image_gen_height, torch.Tensor): + image_gen_height = int(image_gen_height.cpu().item()) + + if isinstance(image_gen_width, torch.Tensor): + image_gen_width = int(image_gen_width.cpu().item()) + closest_size, _ = process_ratio(ori_h=image_gen_height, ori_w=image_gen_width) + image_gen_height, image_gen_width = closest_size + if image_gen_seed is None or image_gen_seed < 0: + from datetime import datetime + image_gen_seed = datetime.now().microsecond % 1000 + sample_kwargs = { + "encoder_hidden_states": condition_embeds, + "steps": image_gen_steps, + "seed": image_gen_seed, + "cfg": image_gen_cfg, + "height": image_gen_height, + "width": image_gen_width, + "negative_encoder_hidden_states": negative_condition_embeds, + "image_cfg": image_gen_image_cfg, + "cfg_mode": image_gen_cfg_mode, + "ref_x": image_gen_pixel_values_reference, + } + print("image_gen_seed: ", image_gen_seed) + print("image_gen_steps: ", image_gen_steps) + print("image_gen_height: ", image_gen_height) + print("image_gen_width: ", image_gen_width) + + image = self.diffusion_loss.sample( + **sample_kwargs, + ) + return image + if pixel_values is not None: image_embeds = self.extract_image_feature(pixel_values, grid_thw=image_grid_thw) if pixel_values_videos is not None: @@ -150,6 +227,285 @@ def generate( ) return outputs + def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32, dit_type="sd3"): + device = torch.device(torch.cuda.current_device()) + if self.model is not None: + device = self.model.device + from transformers import AutoModelForCausalLM + import os + from safetensors.torch import load_file + if os.path.exists(inference_model_path): + temp_state_dict = load_file(os.path.join(inference_model_path, 'mlp', 'model.safetensors')) + else: + from huggingface_hub import hf_hub_download + from safetensors import safe_open + safetensors_path = hf_hub_download( + repo_id=inference_model_path, + filename="model.safetensors", + subfolder="mlp" + ) + with safe_open(safetensors_path, framework="pt") as f: + temp_state_dict = {key: f.get_tensor(key) for key in f.keys()} + self.query_tokens_dict = nn.ParameterDict() + #self.img_gen_scales = [4, 8, 16] + self.img_gen_scales = [16] + for scale in self.img_gen_scales: + num_tokens = scale * scale + scale_name = f"{scale}x{scale}" + #weights = temp_state_dict[f"query_tokens_dict.{scale_name}"] + self.query_tokens_dict[scale_name] = nn.Parameter( + torch.nn.functional.normalize(torch.randn(num_tokens, self.config.llm_config.hidden_size), dim=-1) + ) + self.query_tokens_dict.to(torch_dtype).to(device) + modified_state_dict_query_tokens = { + f"{scale}x{scale}": temp_state_dict[f"query_tokens_dict.{scale}x{scale}"] + for scale in self.img_gen_scales + } + self.query_tokens_dict.load_state_dict(modified_state_dict_query_tokens, strict=True) + # 计算各尺度的累积索引 + self.scale_indices = [] + current_idx = 0 + for scale in self.img_gen_scales: + current_idx += scale * scale + self.scale_indices.append(current_idx) + + diffusion_mlp_state_dict = { + key[len("mlp.") :] : temp_state_dict[key] + for key in temp_state_dict if key.startswith("mlp.") + } + diffusion_c_input_dim = 2048 + if "sd3" in dit_type: + from diffusion.sd3_loss import SD3Loss + self.diffusion_loss = SD3Loss( + model_path=inference_model_path, + scheduler_path=inference_model_path, + vision_dim=diffusion_c_input_dim, + mlp_state_dict=diffusion_mlp_state_dict, + torch_dtype=torch_dtype, + ) + elif "sana" in dit_type: + from diffusion.sana_loss import SANALoss + self.diffusion_loss = SANALoss( + model_path=inference_model_path, + scheduler_path=inference_model_path, + vision_dim=diffusion_c_input_dim, + mlp_state_dict=diffusion_mlp_state_dict, + torch_dtype=torch_dtype, + ) + else: + raise ValueError("unsupported dit type: {}".format(dit_type)) + self.diffusion_loss.to(device) + #self.norm_query_embeds = True + # load connector + self.connector = AutoModelForCausalLM.from_pretrained(inference_model_path, subfolder='connector', torch_dtype=torch_dtype) + for layer in self.connector.model.layers: + layer.self_attn.is_causal = False + self.connector.to(device) + + + self.proj_in = nn.Linear(self.config.llm_config.hidden_size, self.connector.config.hidden_size) + self.proj_out = nn.Linear(self.connector.config.hidden_size, diffusion_c_input_dim) + + modified_state_dict_in = { + 'weight': temp_state_dict['proj_in.weight'], + 'bias': temp_state_dict['proj_in.bias'] + } + self.proj_in.load_state_dict(modified_state_dict_in, strict=True) + modified_state_dict_out = { + 'weight': temp_state_dict['proj_out.weight'], + 'bias': temp_state_dict['proj_out.bias'] + } + self.proj_out.load_state_dict(modified_state_dict_out, strict=True) + self.proj_in.to(device) + self.proj_out.to(device) + self.loaded_image_gen_modules = True + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + **kwargs, + ): + load_image_gen = False + if "load_image_gen" in kwargs: + load_image_gen = kwargs["load_image_gen"] + del kwargs["load_image_gen"] + dit_type = "sd3" + if "dit_type" in kwargs: + dit_type = kwargs["dit_type"] + del kwargs["dit_type"] + load_vlm = True + if "load_vlm" in kwargs: + load_vlm = kwargs["load_vlm"] + del kwargs["load_vlm"] + if load_vlm: + model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + **kwargs, + ) + else: + model = cls( + BailingMM2Config.from_dict(BailingMM2Config.get_config_dict(pretrained_model_name_or_path)[0]), + empty_load=True, + ) + if load_image_gen: + model.load_image_gen_modules( + pretrained_model_name_or_path, + torch_dtype=kwargs["torch_dtype"] if "torch_dtype" in kwargs else torch.float32, + dit_type=dit_type, + ) + return model + + def append_input_ids_with_multiscale_learnable_tokens( + self, + text_ids, + attention_mask, + scales, + start_token_id, + end_token_id, + patch_token_id, + ): + assert text_ids.shape[0] == 1 + assert attention_mask.shape == text_ids.shape + gen_mask = torch.zeros_like(attention_mask) + for scale in scales: + text_ids = torch.cat( + [ + text_ids, + torch.tensor([[start_token_id]]).to(text_ids.dtype).to(text_ids.device), + torch.tensor([[patch_token_id] * (scale**2)]) + .to(text_ids.dtype) + .to(text_ids.device), + torch.tensor([[end_token_id]]).to(text_ids.dtype).to(text_ids.device), + ], + dim=1, + ) + attention_mask = torch.cat( + [ + attention_mask, + torch.tensor([[1] * ((scale**2) + 2)]) + .to(attention_mask.dtype) + .to(attention_mask.device), + ], + dim=1, + ) + gen_mask = torch.cat( + [ + gen_mask, + torch.tensor([[0]]).to(gen_mask.dtype).to(gen_mask.device), + torch.tensor([[1] * (scale**2)]).to(gen_mask.dtype).to(gen_mask.device), + torch.tensor([[0]]).to(gen_mask.dtype).to(gen_mask.device), + ], + dim=1, + ) + assert text_ids.shape == attention_mask.shape + assert attention_mask.shape == gen_mask.shape + return text_ids, attention_mask, gen_mask + + def get_condition_embeds_for_image_gen( + self, + input_ids, + attention_mask, + image_embeds, + position_ids, + use_cache, + image_grid_thw, + llm_hidden_states, + ): + input_ids, attention_mask, gen_mask = self.append_input_ids_with_multiscale_learnable_tokens( + input_ids, + attention_mask, + self.img_gen_scales, + self.config.llm_config.image_patch_token + 1, + self.config.llm_config.image_patch_token + 2, + self.config.llm_config.image_patch_token, + ) + if llm_hidden_states is None: + query_tokens_embeds = torch.cat( + [self.query_tokens_dict[f"{scale}x{scale}"] for scale in self.img_gen_scales], + dim=0, + ) + if image_embeds is None: + image_embeds = query_tokens_embeds + else: + image_embeds = torch.cat([image_embeds, query_tokens_embeds], dim=0) + new_image_grid_thw = [] + for scale in self.img_gen_scales: + new_image_grid_thw.append([1, 2, scale * scale * 2]) + new_image_grid_thw = torch.tensor(new_image_grid_thw, dtype=input_ids.dtype).to(input_ids.device) + if image_grid_thw is None: + image_grid_thw = new_image_grid_thw + else: + image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + if image_embeds is None or input_ids.size(1) == 1: + words_embeddings = self.model.get_input_embeddings()(input_ids.clip(0, self.model.get_input_embeddings().weight.shape[0] - 1)) + image_mask = None + audio_mask = None + else: + words_embeddings, image_mask, audio_mask = self.model.model.prompt_wrap_navit( + input_ids.clip(0, self.model.get_input_embeddings().weight.shape[0] - 1), image_embeds, None, None, + None, None, None, # noqa + ) + # if self.config.llm_config.rope_scaling is not None and self.config.llm_config.rope_scaling["type"] == "3D": + # position_ids, _ = self.get_rope_index( + # input_ids, + # image_token_id=self.config.llm_config.image_patch_token, + # video_token_id=self.config.llm_config.image_patch_token, + # image_start_token_id=self.config.llm_config.image_start_token, + # video_start_token_id=self.config.llm_config.video_start_token, + # image_grid_thw=image_grid_thw, + # video_grid_thw=None, + # attention_mask=attention_mask, + # ) + assert input_ids.size(1) == words_embeddings.size(1), "{} vs {}".format( + input_ids.size, + words_embeddings.size, + ) + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=words_embeddings, + use_cache=False, + image_mask=image_mask, + audio_mask=audio_mask, + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states[-1] + else: + hidden_states = llm_hidden_states + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + gen_mask = gen_mask.unsqueeze(-1).expand(gen_mask.shape[0], gen_mask.shape[1], hidden_states.shape[-1]).to(hidden_states.device).bool() + hidden_states_gen = torch.masked_select(hidden_states, gen_mask).view(hidden_states.shape[0], -1, hidden_states.shape[-1]) + # 分解hidden_states为不同尺度的表示 + scale_start_idxes = [0] + self.scale_indices[:-1] + scale_end_idxes = self.scale_indices + assert scale_end_idxes[-1] == hidden_states_gen.shape[1] + + scale, scale_start_idx, scale_end_idx = [ + i for i in zip(self.img_gen_scales, scale_start_idxes, scale_end_idxes) + ][-1] + + scale_hidden = hidden_states_gen[:, scale_start_idx : scale_end_idx, :] + # 处理当前尺度的特征 + scale_embeds = self.proj_in(scale_hidden) + seq_shape = scale_embeds.shape + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + scale_embeds = self.connector( + inputs_embeds=scale_embeds, + attention_mask=torch.ones(seq_shape[0],1,seq_shape[1],seq_shape[1]).to(scale_embeds.device), + output_hidden_states=True + ).hidden_states[-1] + + scale_embeds = self.proj_out(scale_embeds) + # 归一化 + scale_embeds = torch.nn.functional.normalize(scale_embeds, dim=-1) + return scale_embeds + __all__ = [ "BailingMM2NativeForConditionalGeneration" diff --git a/processing_bailingmm2.py b/processing_bailingmm2.py index cbd3ffa..71f6d58 100644 --- a/processing_bailingmm2.py +++ b/processing_bailingmm2.py @@ -28,14 +28,14 @@ from transformers.feature_extraction_utils import BatchFeature from transformers.image_utils import ImageInput -from transformers.video_utils import VideoInput from transformers.processing_utils import ( ProcessingKwargs, ProcessorMixin, ) from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -from bailingmm_utils import process_vision_info +from bailingmm_utils import process_vision_info, VideoInput, process_ratio +import torchvision DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" @@ -194,11 +194,23 @@ def __call__( image_inputs = {} video_inputs = {} audio_inputs = {} + image_gen_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] text = self._expand_image_tokens(text, image_grid_thw) + ref_pil = images[0] if isinstance(images, list) else images + ref_pil = ref_pil.convert("RGB") + closest_size, resize_size = process_ratio(ori_h=ref_pil.size[1], ori_w=ref_pil.size[0]) + ref_pil = torchvision.transforms.functional.resize(ref_pil, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR) + ref_pil = torchvision.transforms.functional.center_crop(ref_pil, closest_size) + ref_tensor = ((torchvision.transforms.functional.to_tensor(ref_pil) - 0.5) * 2.0).unsqueeze(0) + image_gen_inputs = { + "image_gen_pixel_values_reference": ref_tensor, + "image_gen_height": torch.LongTensor([ref_pil.size[1]]), + "image_gen_width": torch.LongTensor([ref_pil.size[0]]), + } if videos is not None: video_inputs = self.image_processor(images=None, videos=videos, do_resize=False, **output_kwargs["videos_kwargs"]) @@ -224,7 +236,7 @@ def __call__( audio_inputs["audio_placeholder_loc_lens"] = torch.tensor(loc_lens, dtype=torch.long) audio_inputs.pop('encoder_feats_lengths') - return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs, **audio_inputs}) + return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs, **audio_inputs, **image_gen_inputs}) def apply_system_template(self, sys_prompt_exp=None, use_cot_system_prompt=False): if use_cot_system_prompt: diff --git a/test_infer_gen_image.py b/test_infer_gen_image.py index b3e3a95..8e782d8 100644 --- a/test_infer_gen_image.py +++ b/test_infer_gen_image.py @@ -1,110 +1,113 @@ import os -import time import torch -from transformers import AutoProcessor +import time +import numpy as np +from bisect import bisect_left -from modeling_bailingmm import BailingMMNativeForConditionalGeneration from IPython import embed -import torchvision -from PIL import Image -import re - -import torch.nn as nn -from collections import defaultdict -from bailingmm_utils import process_ratio - -def auto_translate(model, processor, text): - if re.search(r'[\u4e00-\u9fff]', text): - prefix = "Translate the Chinese phrase below into natural English. Return only the translation result without any explanations, prefixes, or formatting. Phrase to translate:" - text = f"{prefix}{text}" - messages = [ - { - "role": "HUMAN", - "content": [ - {"type": "text", "text": text}, - ], - } - ] - text = processor.apply_chat_template(messages, add_generation_prompt=True) - image_inputs, video_inputs, audio_inputs = processor.process_vision_info(messages) - inputs = processor( - text=[text], - images=image_inputs, - videos=video_inputs, - audios=audio_inputs, - return_tensors="pt", - ).to(model.device) - - for k in inputs.keys(): - if k == "pixel_values" or k == "pixel_values_videos" or k == "audio_feats": - inputs[k] = inputs[k].to(dtype=torch.bfloat16) - - #srt_time = time.time() + +from transformers import ( + AutoProcessor, +) + +from modeling_bailingmm2 import BailingMM2NativeForConditionalGeneration + +import warnings + +warnings.filterwarnings("ignore") + + +def split_model(): + device_map = {} + world_size = torch.cuda.device_count() + num_layers = 32 + layer_per_gpu = num_layers // world_size + layer_per_gpu = [i * layer_per_gpu for i in range(1, world_size + 1)] + for i in range(num_layers): + device_map[f'model.model.layers.{i}'] = bisect_left(layer_per_gpu, i) + + device_map['vision'] = 0 + device_map['audio'] = 0 + device_map['linear_proj'] = 0 + device_map['linear_proj_audio'] = 0 + device_map['model.model.word_embeddings.weight'] = 0 + device_map['model.model.norm.weight'] = 0 + device_map['model.lm_head.weight'] = 0 + device_map['model.model.norm'] = 0 + device_map[f'model.model.layers.{num_layers - 1}'] = 0 + return device_map + +def generate(messages, processor, model, sys_prompt_exp=None, use_cot_system_prompt=False, max_new_tokens=512): + text = processor.apply_chat_template( + messages, + sys_prompt_exp=sys_prompt_exp, + use_cot_system_prompt=use_cot_system_prompt + ) + image_inputs, video_inputs, audio_inputs = processor.process_vision_info(messages) + + inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + audios=audio_inputs, + return_tensors="pt", + ).to(model.device) + + for k in inputs.keys(): + if k == "pixel_values" or k == "pixel_values_videos" or k == "audio_feats": + inputs[k] = inputs[k].to(dtype=torch.bfloat16) + + srt_time = time.time() + + with torch.no_grad(): generated_ids = model.generate( **inputs, - max_new_tokens=128, - use_cache=False, + max_new_tokens=max_new_tokens, + use_cache=True, eos_token_id=processor.gen_terminator, + num_logits_to_keep=1, ) - generated_ids_trimmed = [ - out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) - ] - text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False - )[0] - - return text - -def generate_gen_image( - model, - processor, - prompt, - height=None, - width=None, - input_image_path=None, - use_auto_translate=True, - debug=False, - POSITIVE_PREFIX_T2I=None, - POSITIVE_PREFIX_I2I="high quality", - NEGATIVE_PREFIX="worst quality, low quality, bad eyes, bad iris, twisted face, blurry, bad hand, watermark, multiple limbs, deformed fingers, bad fingers, ugly, monochrome, horror, geometry, bad anatomy, bad limbs, Blurry pupil, bad shading, error, bad composition, Extra fingers, strange fingers, Extra ears, extra leg, bad leg, disability, Blurry eyes, bad eyes, Twisted body, confusion, bad legs", - image_gen_steps=30, -): - # input_image_path 设置为 None, 运行 文生图,否则 图生图 - if height is None or width is None: - image_gen_width, image_gen_height = 512 * 1, 512 * 1 - else: - image_gen_width, image_gen_height = width, height - - closest_size, _ = process_ratio(ori_h=image_gen_height, ori_w=image_gen_width) - image_gen_height, image_gen_width = closest_size[0] * 1, closest_size[1] * 1 - - if use_auto_translate: - prompt_ori = prompt - prompt = auto_translate(model, processor, prompt) - if debug: - print("prompt: {} -> translated: {}".format(prompt_ori, prompt)) - - if input_image_path is not None: - if POSITIVE_PREFIX_I2I: - prompt = "{}; Requirement: {}".format(prompt, POSITIVE_PREFIX_I2I) - else: - if POSITIVE_PREFIX_T2I: - prompt = "{}; Requirement: {}".format(prompt, POSITIVE_PREFIX_T2I) - + + end_time = time.time() + + generated_ids_trimmed = [ + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + + # tps = generated_ids.shape[1] / (end_time - srt_time) + # print(f"generated {generated_ids.shape[1]} tokens in {end_time - srt_time:.2f} seconds, tokens per second: {tps:.2f} tokens/s") + + output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + + return output_text + +if __name__ == '__main__': + model_name_or_path = "/nativemm/share/cpfs/weilong.cwl/checkpoints/Ming_Flash_2.0_sft1_merged" + #"/input/sunyunxiao.syx/checkpoints/Ming_Flash_2.0_sft1/" + model = BailingMM2NativeForConditionalGeneration.from_pretrained( + model_name_or_path, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map=split_model(), + load_image_gen=True, + ).to(dtype=torch.bfloat16) + + processor = AutoProcessor.from_pretrained("./", trust_remote_code=True) + + # gen_input_pixels = 451584 + # processor.image_processor.max_pixels = gen_input_pixels + # processor.image_processor.min_pixels = gen_input_pixels + messages = [ { "role": "HUMAN", "content": [ - {"type": "text", "text": prompt}, - ] if input_image_path is None else [ - {"type": "image", "image": input_image_path}, - {"type": "text", "text": prompt}, + {"type": "text", "text": "Draw a beautiful girl with short black hair and red dress."}, ], } ] - if debug: - print("messages:", messages) - text = processor.apply_chat_template(messages, add_generation_prompt=True) image_inputs, video_inputs, audio_inputs = processor.process_vision_info(messages) @@ -117,97 +120,121 @@ def generate_gen_image( return_tensors="pt", ).to(model.device) - if "image_gen_height" in inputs: - image_gen_height = inputs["image_gen_height"] - del inputs["image_gen_height"] - - if "image_gen_width" in inputs: - image_gen_width = inputs["image_gen_width"] - del inputs["image_gen_width"] - for k in inputs.keys(): - if k == "pixel_values" or k == "pixel_values_videos" or k == "audio_feats": + if k in ["pixel_values", "pixel_values_videos", "audio_feats", "pixel_values_reference"]: inputs[k] = inputs[k].to(dtype=torch.bfloat16) - if NEGATIVE_PREFIX: - negative_messages = [ - { - "role": "HUMAN", - "content": [ - {"type": "text", "text": NEGATIVE_PREFIX}, - ] if input_image_path is None else [ - {"type": "image", "image": input_image_path}, - {"type": "text", "text": NEGATIVE_PREFIX}, - ], - } - ] - if debug: - print("negative_messages:", negative_messages) - - negative_text = processor.apply_chat_template(negative_messages, add_generation_prompt=True) - negative_inputs = processor( - text=[negative_text], - images=image_inputs, - videos=None, - audios=None, - return_tensors="pt", - ).to(model.device) - inputs["image_gen_negative_input_ids"] = negative_inputs["input_ids"] - inputs["image_gen_negative_attention_mask"] = negative_inputs["attention_mask"] - + # set `image_gen=True` to enable image generation image = model.generate( **inputs, - max_new_tokens=128, - use_cache=False, - eos_token_id=processor.gen_terminator, image_gen=True, - image_gen_height=image_gen_height, - image_gen_width=image_gen_width, - image_gen_steps=image_gen_steps, ) + + image.save("./t2i_girl.jpg") - return image + print("Instruction: Draw a beautiful girl with short black hair and red dress.") -if __name__ == '__main__': - model_path = "inclusionAI/Ming-Lite-Omni" - processor = AutoProcessor.from_pretrained('.', trust_remote_code=True) - model = BailingMMNativeForConditionalGeneration.from_pretrained( - model_path, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - load_image_gen=True, - ).to("cuda") + vision_path = "/input/sunyunxiao.syx/assets/" - gen_input_pixels = 451584 - processor.image_processor.max_pixels = gen_input_pixels - processor.image_processor.min_pixels = gen_input_pixels - - image = generate_gen_image( - model=model, - processor=processor, - prompt="a beautiful girl wearing a red dress.", - POSITIVE_PREFIX_T2I="", - POSITIVE_PREFIX_I2I="", - image_gen_steps=30, - NEGATIVE_PREFIX="", - ) - image.save("./woman_red.jpg") - - image = generate_gen_image( - model=model, - processor=processor, - prompt="给人物戴上墨镜", - input_image_path="./woman_red.jpg", - POSITIVE_PREFIX_T2I="", - POSITIVE_PREFIX_I2I="", - image_gen_steps=30, - NEGATIVE_PREFIX="", - use_auto_translate=False, - ) - image.save("./woman_red_sunglasses.jpg") + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": os.path.join(vision_path, "flowers.jpg")}, + {"type": "text", "text": "What kind of flower is this?"}, + ], + } + ] + srt_time = time.time() + output_text = generate(messages, processor=processor, model=model, max_new_tokens=512) + print(output_text) + print(f"Generate time: {(time.time() - srt_time):.2f}s") + + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "text", "text": "请介绍下你自己"} + ], + } + ] + + srt_time = time.time() + output_text = generate(messages, processor=processor, model=model, max_new_tokens=512) + print(output_text) + print(f"Generate time: {(time.time() - srt_time):.2f}s") + + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "video", "video": os.path.join(vision_path, "yoga.mp4")}, + {"type": "text", "text": "What is the woman doing?"}, + ], + } + ] + srt_time = time.time() + output_text = generate(messages, processor=processor, model=model, max_new_tokens=512) + print(output_text) + print(f"Generate time: {(time.time() - srt_time):.2f}s") + + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "text", "text": "中国的首都是哪里?"}, + ], + }, + { + "role": "ASSISTANT", + "content": [ + {"type": "text", "text": "北京"}, + ], + }, + { + "role": "HUMAN", + "content": [ + {"type": "text", "text": "它的占地面积是多少?有多少常住人口?"}, + ], + }, + ] + + srt_time = time.time() + output_text = generate(messages, processor=processor, model=model, max_new_tokens=512) + print(output_text) + print(f"Generate time: {(time.time() - srt_time):.2f}s") + + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "text", "text": "请详细介绍鹦鹉的生活习性。"} + ], + } + ] + + srt_time = time.time() + output_text = generate(messages, processor=processor, model=model, max_new_tokens=8192, use_cot_system_prompt=True) + print(output_text) + print(f"Generate time: {(time.time() - srt_time):.2f}s") + + + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "video", "video": os.path.join(vision_path, "yoga.mp4"), "max_frames": 40, "sample": "uniform"}, + {"type": "image", "image": os.path.join(vision_path, "flowers.jpg")}, + {"type": "text", "text": "What is the woman doing in the video and what kind of flower is in the image?"}, + ], + } + ] - \ No newline at end of file + srt_time = time.time() + output_text = generate(messages, processor=processor, model=model, max_new_tokens=512) + print(output_text) + print(f"Generate time: {(time.time() - srt_time):.2f}s") \ No newline at end of file