@@ -18,7 +18,8 @@ diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
1818 AscendTensor inputAt (input);
1919 if (inputAt.numel () <= 0 ) {
2020 if (diopiReduction_t::ReductionMean == reduction) {
21- DIOPI_ASCEND_CALL_ACLNN (aclnnInpalceFillScalar, ctx, out, std::nanf (" " ));
21+ diopiScalar_t nans{diopi_dtype_float64, std::nanf (" " )};
22+ DIOPI_ASCEND_CALL_ACLNN (aclnnInplaceFillScalar, ctx, out, &nans);
2223 } else if (diopiReduction_t::ReductionSum == reduction || diopiReduction_t::ReductionNone == reduction) {
2324 DIOPI_ASCEND_CALL_ACLNN (aclnnInplaceZero, ctx, out);
2425 }
@@ -44,6 +45,7 @@ diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
4445 AscendTensor inputView = inputAt.view ({inputAt.shape (0 ), inputAt.shape (1 ), inputAt.numel () / inputAt.shape (0 ) / inputAt.shape (1 ), 1 });
4546 AscendTensor outView = (outAt.numel () > 1 ) ? outAt.view ({outAt.shape (0 ), outAt.numel () / outAt.shape (0 ), 1 }) : outAt;
4647 AscendTensor targetView = targetAt.view ({targetAt.shape (0 ), targetAt.numel () / targetAt.shape (0 ), 1 });
48+ DIOPI_ASCEND_CALL_ACLNN (aclnnNLLLoss2d, ctx, inputView, targetView, weightTmp, reduction, ignoreIndex, outView, totalWeight);
4749 }
4850
4951 return diopiSuccess;
@@ -85,7 +87,7 @@ diopiError_t diopiNLLLossV2Backward(diopiContextHandle_t ctx, diopiTensorHandle_
8587 gradInputAt.view ({gradInputAt.shape (0 ), gradInputAt.shape (1 ), gradInputAt.numel () / gradInputAt.shape (0 ) / gradInputAt.shape (1 ), 1 });
8688 AscendTensor gradOutputView;
8789 if (gradOutputAt.numel () > 1 ) {
88- gradOutputView.view ({gradOutputAt.shape (0 ), gradOutputAt.numel () / gradOutputAt.shape (0 ), 1 });
90+ gradOutputView = gradOutputAt .view ({gradOutputAt.shape (0 ), gradOutputAt.numel () / gradOutputAt.shape (0 ), 1 });
8991 } else {
9092 gradOutputView = gradOutputAt;
9193 }
0 commit comments