def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
if "lora" in inputs:
# Image-to-LoRA models need to load lora here.
pipe.clear_lora(verbose=0)
pipe.load_lora(pipe.dit, state_dict=inputs["lora"], hotload=True, verbose=0)
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
noise = torch.randn_like(inputs["input_latents"])
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
if "first_frame_latents" in inputs:
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
if "first_frame_latents" in inputs:
noise_pred = noise_pred[:, :, 1:]
training_target = training_target[:, :, 1:]
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
FlowMatchSFTLoss中只有对首帧输入的处理,没有隔离尾帧的latent,这样在首尾帧训练时会不会有影响?