Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 48 additions & 31 deletions modeling_bailingmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def forward(
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if (input_ids is None) ^ (inputs_embeds is not None):
Expand Down Expand Up @@ -479,8 +479,25 @@ def forward(
None, # noqa
)

if (
self.config.llm_config.rope_scaling is not None
and self.config.llm_config.rope_scaling["type"] == "3D"
):
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_token_id=self.config.llm_config.image_patch_token,
video_token_id=self.config.llm_config.image_patch_token,
image_start_token_id=self.config.llm_config.image_start_token,
video_start_token_id=self.config.llm_config.video_start_token,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
attention_mask=attention_mask,
)
else:
rope_deltas = None

outputs = self.model(
input_ids=input_ids,
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
Expand Down Expand Up @@ -587,18 +604,18 @@ def generate(
assert audio_embeds is None
assert position_ids is None
condition_embeds = self.get_condition_embeds_for_image_gen(
input_ids=input_ids,
input_ids=input_ids,
attention_mask=attention_mask,
image_embeds=image_embeds,
image_embeds=image_embeds,
position_ids=position_ids,
use_cache=use_cache,
image_grid_thw=image_grid_thw,
)

# negative_condition_embeds = self.get_condition_embeds_for_image_gen(
# input_ids=image_gen_negative_input_ids,
# input_ids=image_gen_negative_input_ids,
# attention_mask=image_gen_negative_attention_mask,
# image_embeds=image_embeds,
# image_embeds=image_embeds,
# position_ids=position_ids,
# use_cache=use_cache,
# image_grid_thw=image_grid_thw,
Expand All @@ -608,7 +625,7 @@ def generate(

if isinstance(image_gen_height, torch.Tensor):
image_gen_height = int(image_gen_height.cpu().item())

if isinstance(image_gen_width, torch.Tensor):
image_gen_width = int(image_gen_width.cpu().item())

Expand All @@ -627,7 +644,7 @@ def generate(
"cfg_mode": image_gen_cfg_mode,
"ref_x": pixel_values_reference,
}

image = self.diffusion_loss.sample(
**sample_kwargs,
)
Expand Down Expand Up @@ -701,13 +718,13 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32
safetensors_path = hf_hub_download(
repo_id=inference_model_path,
filename="model.safetensors",
subfolder="mlp"
subfolder="mlp"
)
with safe_open(safetensors_path, framework="pt") as f:
temp_state_dict = {key: f.get_tensor(key) for key in f.keys()}
self.query_tokens_dict = nn.ParameterDict()
self.img_gen_scales = [4, 8, 16]
for scale in self.img_gen_scales:
for scale in self.img_gen_scales:
num_tokens = scale * scale
scale_name = f"{scale}x{scale}"
#weights = temp_state_dict[f"query_tokens_dict.{scale_name}"]
Expand All @@ -717,7 +734,7 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32
self.query_tokens_dict.to(self.model.dtype).to(self.model.device)
modified_state_dict_query_tokens = {
f"{scale}x{scale}": temp_state_dict[f"query_tokens_dict.{scale}x{scale}"]
for scale in self.img_gen_scales
for scale in self.img_gen_scales
}
self.query_tokens_dict.load_state_dict(modified_state_dict_query_tokens, strict=True)
# Compute the cumulative end index of each scale (running sum of scale * scale)
Expand All @@ -726,7 +743,7 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32
for scale in self.img_gen_scales:
current_idx += scale * scale
self.scale_indices.append(current_idx)

diffusion_mlp_state_dict = {
key[len("mlp.") :] : temp_state_dict[key]
for key in temp_state_dict if key.startswith("mlp.")
Expand All @@ -735,18 +752,18 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32
if "sd3" in dit_type:
from diffusion.sd3_loss import SD3Loss
self.diffusion_loss = SD3Loss(
model_path=inference_model_path,
scheduler_path=inference_model_path,
vision_dim=self.model.config.hidden_size,
model_path=inference_model_path,
scheduler_path=inference_model_path,
vision_dim=self.model.config.hidden_size,
mlp_state_dict=diffusion_mlp_state_dict,
torch_dtype=torch_dtype,
)
elif "sana" in dit_type:
from diffusion.sana_loss import SANALoss
self.diffusion_loss = SANALoss(
model_path=inference_model_path,
scheduler_path=inference_model_path,
vision_dim=self.model.config.hidden_size,
model_path=inference_model_path,
scheduler_path=inference_model_path,
vision_dim=self.model.config.hidden_size,
mlp_state_dict=diffusion_mlp_state_dict,
torch_dtype=torch_dtype,
)
Expand All @@ -760,10 +777,10 @@ def load_image_gen_modules(self, inference_model_path, torch_dtype=torch.float32
for layer in self.connector.model.layers:
layer.self_attn.is_causal = False
self.connector.to(self.model.device)

self.proj_in = nn.Linear(self.model.config.hidden_size, self.connector.config.hidden_size)
self.proj_out = nn.Linear(self.connector.config.hidden_size, self.model.config.hidden_size)

modified_state_dict_in = {
'weight': temp_state_dict['proj_in.weight'],
'bias': temp_state_dict['proj_in.bias']
Expand Down Expand Up @@ -802,17 +819,17 @@ def from_pretrained(
)
if load_image_gen:
model.load_image_gen_modules(
pretrained_model_name_or_path,
pretrained_model_name_or_path,
torch_dtype=kwargs["torch_dtype"] if "torch_dtype" in kwargs else torch.float32,
dit_type=dit_type,
)
return model

def get_condition_embeds_for_image_gen(
self,
input_ids,
input_ids,
attention_mask,
image_embeds,
image_embeds,
position_ids,
use_cache,
image_grid_thw,
Expand All @@ -828,7 +845,7 @@ def get_condition_embeds_for_image_gen(
)

query_tokens_embeds = torch.cat(
[self.query_tokens_dict[f"{scale}x{scale}"] for scale in self.img_gen_scales],
[self.query_tokens_dict[f"{scale}x{scale}"] for scale in self.img_gen_scales],
dim=0,
)
if image_embeds is None:
Expand All @@ -847,7 +864,7 @@ def get_condition_embeds_for_image_gen(
else:
image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0)


with torch.cuda.amp.autocast(dtype=torch.bfloat16):
if image_embeds is None or input_ids.size(1) == 1:
words_embeddings = self.model.get_input_embeddings()(input_ids.clip(0, self.model.get_input_embeddings().weight.shape[0] - 1))
Expand All @@ -859,7 +876,7 @@ def get_condition_embeds_for_image_gen(
None, None, None, # noqa
)

if self.config.llm_config.rope_scaling is not None and self.config.llm_config.rope_scaling["type"] == "3D":
if self.config.llm_config.rope_scaling is not None and self.config.llm_config.rope_scaling["type"] == "3D":
position_ids, _ = self.get_rope_index(
input_ids,
image_token_id=self.config.llm_config.image_patch_token,
Expand Down Expand Up @@ -890,11 +907,11 @@ def get_condition_embeds_for_image_gen(
scale_start_idxes = [0] + self.scale_indices[:-1]
scale_end_idxes = self.scale_indices
assert scale_end_idxes[-1] == hidden_states_gen.shape[1]

scale, scale_start_idx, scale_end_idx = [
i for i in zip(self.img_gen_scales, scale_start_idxes, scale_end_idxes)
][-1]

scale_hidden = hidden_states_gen[:, scale_start_idx : scale_end_idx, :]

# Process the features of the current scale
Expand All @@ -903,11 +920,11 @@ def get_condition_embeds_for_image_gen(
seq_shape = scale_embeds.shape
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
scale_embeds = self.connector(
inputs_embeds=scale_embeds,
attention_mask=torch.ones(seq_shape[0],1,seq_shape[1],seq_shape[1]).to(scale_embeds.device),
inputs_embeds=scale_embeds,
attention_mask=torch.ones(seq_shape[0],1,seq_shape[1],seq_shape[1]).to(scale_embeds.device),
output_hidden_states=True
).hidden_states[-1]

scale_embeds = self.proj_out(scale_embeds)
# Normalize along the last dimension
scale_embeds = torch.nn.functional.normalize(scale_embeds, dim=-1)
Expand Down