diff --git a/modeling_bailingmm.py b/modeling_bailingmm.py index 8182fd6..c4f3eae 100644 --- a/modeling_bailingmm.py +++ b/modeling_bailingmm.py @@ -434,7 +434,7 @@ def forward( if output_hidden_states is not None else self.config.output_hidden_states ) - use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", False) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): @@ -479,8 +479,25 @@ def forward( None, # noqa ) + if ( + self.config.llm_config.rope_scaling is not None + and self.config.llm_config.rope_scaling["type"] == "3D" + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_token_id=self.config.llm_config.image_patch_token, + video_token_id=self.config.llm_config.image_patch_token, + image_start_token_id=self.config.llm_config.image_start_token, + video_start_token_id=self.config.llm_config.video_start_token, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + else: + rope_deltas = None + outputs = self.model( - input_ids=input_ids, + input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -587,18 +604,18 @@ def generate( assert audio_embeds is None assert position_ids is None condition_embeds = self.get_condition_embeds_for_image_gen( - input_ids=input_ids, + input_ids=input_ids, attention_mask=attention_mask, - image_embeds=image_embeds, + image_embeds=image_embeds, position_ids=position_ids, use_cache=use_cache, image_grid_thw=image_grid_thw, ) # negative_condition_embeds = self.get_condition_embeds_for_image_gen( - # input_ids=image_gen_negative_input_ids, + # input_ids=image_gen_negative_input_ids, # attention_mask=image_gen_negative_attention_mask, - # image_embeds=image_embeds, + # image_embeds=image_embeds, # position_ids=position_ids, # use_cache=use_cache, # image_grid_thw=image_grid_thw, @@ -608,7 +625,7 @@ def generate( if isinstance(image_gen_height, torch.Tensor): image_gen_height = int(image_gen_height.cpu().item()) - + if isinstance(image_gen_width, torch.Tensor): image_gen_width = int(image_gen_width.cpu().item()) @@ -627,7 +644,7 @@ def generate( "cfg_mode": image_gen_cfg_mode, "ref_x": pixel_values_reference, } - + image = self.diffusion_loss.sample( **sample_kwargs, ) @@ -701,13 +718,13 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32 safetensors_path = hf_hub_download( repo_id=inference_model_path, filename="model.safetensors", - subfolder="mlp" + subfolder="mlp" ) with safe_open(safetensors_path, framework="pt") as f: temp_state_dict = {key: f.get_tensor(key) for key in f.keys()} self.query_tokens_dict = nn.ParameterDict() self.img_gen_scales = [4, 8, 16] - for scale in self.img_gen_scales: + for scale in self.img_gen_scales: num_tokens = scale * scale scale_name = f"{scale}x{scale}" #weights = temp_state_dict[f"query_tokens_dict.{scale_name}"] @@ -717,7 +734,7 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32 self.query_tokens_dict.to(self.model.dtype).to(self.model.device) modified_state_dict_query_tokens = { f"{scale}x{scale}": temp_state_dict[f"query_tokens_dict.{scale}x{scale}"] - for scale in self.img_gen_scales + for scale in self.img_gen_scales } self.query_tokens_dict.load_state_dict(modified_state_dict_query_tokens, strict=True) # 计算各尺度的累积索引 @@ -726,7 +743,7 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32 for scale in self.img_gen_scales: current_idx += scale * scale self.scale_indices.append(current_idx) - + diffusion_mlp_state_dict = { key[len("mlp.") :] : temp_state_dict[key] for key in temp_state_dict if key.startswith("mlp.") @@ -735,18 +752,18 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32 if "sd3" in dit_type: from diffusion.sd3_loss import SD3Loss self.diffusion_loss = SD3Loss( - model_path=inference_model_path, - scheduler_path=inference_model_path, - vision_dim=self.model.config.hidden_size, + model_path=inference_model_path, + scheduler_path=inference_model_path, + vision_dim=self.model.config.hidden_size, mlp_state_dict=diffusion_mlp_state_dict, torch_dtype=torch_dtype, ) elif "sana" in dit_type: from diffusion.sana_loss import SANALoss self.diffusion_loss = SANALoss( - model_path=inference_model_path, - scheduler_path=inference_model_path, - vision_dim=self.model.config.hidden_size, + model_path=inference_model_path, + scheduler_path=inference_model_path, + vision_dim=self.model.config.hidden_size, mlp_state_dict=diffusion_mlp_state_dict, torch_dtype=torch_dtype, ) @@ -760,10 +777,10 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32 for layer in self.connector.model.layers: layer.self_attn.is_causal = False self.connector.to(self.model.device) - + self.proj_in = nn.Linear(self.model.config.hidden_size, self.connector.config.hidden_size) self.proj_out = nn.Linear(self.connector.config.hidden_size, self.model.config.hidden_size) - + modified_state_dict_in = { 'weight': temp_state_dict['proj_in.weight'], 'bias': temp_state_dict['proj_in.bias'] @@ -802,7 +819,7 @@ def from_pretrained( ) if load_image_gen: model.load_image_gen_modules( - pretrained_model_name_or_path, + pretrained_model_name_or_path, torch_dtype=kwargs["torch_dtype"] if "torch_dtype" in kwargs else torch.float32, dit_type=dit_type, ) @@ -810,9 +827,9 @@ def from_pretrained( def get_condition_embeds_for_image_gen( self, - input_ids, + input_ids, attention_mask, - image_embeds, + image_embeds, position_ids, use_cache, image_grid_thw, @@ -828,7 +845,7 @@ def get_condition_embeds_for_image_gen( ) query_tokens_embeds = torch.cat( - [self.query_tokens_dict[f"{scale}x{scale}"] for scale in self.img_gen_scales], + [self.query_tokens_dict[f"{scale}x{scale}"] for scale in self.img_gen_scales], dim=0, ) if image_embeds is None: @@ -847,7 +864,7 @@ def get_condition_embeds_for_image_gen( else: image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0) - + with torch.cuda.amp.autocast(dtype=torch.bfloat16): if image_embeds is None or input_ids.size(1) == 1: words_embeddings = self.model.get_input_embeddings()(input_ids.clip(0, self.model.get_input_embeddings().weight.shape[0] - 1)) @@ -859,7 +876,7 @@ def get_condition_embeds_for_image_gen( None, None, None, # noqa ) - if self.config.llm_config.rope_scaling is not None and self.config.llm_config.rope_scaling["type"] == "3D": + if self.config.llm_config.rope_scaling is not None and self.config.llm_config.rope_scaling["type"] == "3D": position_ids, _ = self.get_rope_index( input_ids, image_token_id=self.config.llm_config.image_patch_token, @@ -890,11 +907,11 @@ def get_condition_embeds_for_image_gen( scale_start_idxes = [0] + self.scale_indices[:-1] scale_end_idxes = self.scale_indices assert scale_end_idxes[-1] == hidden_states_gen.shape[1] - + scale, scale_start_idx, scale_end_idx = [ i for i in zip(self.img_gen_scales, scale_start_idxes, scale_end_idxes) ][-1] - + scale_hidden = hidden_states_gen[:, scale_start_idx : scale_end_idx, :] # 处理当前尺度的特征 @@ -903,11 +920,11 @@ def get_condition_embeds_for_image_gen( seq_shape = scale_embeds.shape with torch.cuda.amp.autocast(dtype=torch.bfloat16): scale_embeds = self.connector( - inputs_embeds=scale_embeds, - attention_mask=torch.ones(seq_shape[0],1,seq_shape[1],seq_shape[1]).to(scale_embeds.device), + inputs_embeds=scale_embeds, + attention_mask=torch.ones(seq_shape[0],1,seq_shape[1],seq_shape[1]).to(scale_embeds.device), output_hidden_states=True ).hidden_states[-1] - + scale_embeds = self.proj_out(scale_embeds) # 归一化 scale_embeds = torch.nn.functional.normalize(scale_embeds, dim=-1)