@@ -246,15 +246,15 @@ impl<B: Backend> BiRefNetLoss<B> {
246246 let mut total_loss = pixel_loss;
247247
248248 // Add classification loss if enabled and predictions are provided
249- if let Some ( ref cls_loss) = self . classification_loss {
250- if let ( Some ( class_preds) , Some ( class_targets) ) = ( class_preds, class_targets) {
251- let cls_loss_value = cls_loss
252- . forward ( class_preds , class_targets , Reduction :: Mean )
253- . mul_scalar ( self . classification_weight )
254- . mul_scalar ( self . global_scale ) ;
255-
256- total_loss = total_loss + cls_loss_value ;
257- }
249+ if let Some ( ref cls_loss) = self . classification_loss
250+ && let ( Some ( class_preds) , Some ( class_targets) ) = ( class_preds, class_targets)
251+ {
252+ let cls_loss_value = cls_loss
253+ . forward ( class_preds , class_targets , Reduction :: Mean )
254+ . mul_scalar ( self . classification_weight )
255+ . mul_scalar ( self . global_scale ) ;
256+
257+ total_loss = total_loss + cls_loss_value ;
258258 }
259259
260260 Ok ( total_loss)
@@ -314,21 +314,21 @@ impl<B: Backend> BiRefNetLoss<B> {
314314 let mut total_loss = pixel_loss;
315315
316316 // Add classification loss if enabled and predictions are provided
317- if let Some ( ref cls_loss) = self . classification_loss {
318- if let ( Some ( class_preds) , Some ( class_targets) ) = ( class_preds, class_targets) {
319- let cls_loss_value = cls_loss
320- . forward ( class_preds , class_targets , Reduction :: Mean )
321- . mul_scalar ( self . classification_weight )
322- . mul_scalar ( self . global_scale ) ;
323-
324- total_loss = total_loss . clone ( ) + cls_loss_value . clone ( ) ;
325-
326- // Add classification loss to dictionary
327- loss_dict . insert (
328- "classification" . to_owned ( ) ,
329- cls_loss_value . into_scalar ( ) . to_f64 ( ) ,
330- ) ;
331- }
317+ if let Some ( ref cls_loss) = self . classification_loss
318+ && let ( Some ( class_preds) , Some ( class_targets) ) = ( class_preds, class_targets)
319+ {
320+ let cls_loss_value = cls_loss
321+ . forward ( class_preds , class_targets , Reduction :: Mean )
322+ . mul_scalar ( self . classification_weight )
323+ . mul_scalar ( self . global_scale ) ;
324+
325+ total_loss = total_loss . clone ( ) + cls_loss_value . clone ( ) ;
326+
327+ // Add classification loss to dictionary
328+ loss_dict . insert (
329+ "classification" . to_owned ( ) ,
330+ cls_loss_value . into_scalar ( ) . to_f64 ( ) ,
331+ ) ;
332332 }
333333
334334 // Add total loss to dictionary
0 commit comments