Skip to content

Commit bc2b071

Browse files
committed
fix fill inplace call for aclnn
1 parent f6baaf8 commit bc2b071

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

impl/ascend/functions/nlllossv2.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
1818
AscendTensor inputAt(input);
1919
if (inputAt.numel() <= 0) {
2020
if (diopiReduction_t::ReductionMean == reduction) {
21-
DIOPI_ASCEND_CALL_ACLNN(aclnnInpalceFillScalar, ctx, out, std::nanf(""));
21+
diopiScalar_t nans{diopi_dtype_float64, std::nanf("")};
22+
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceFillScalar, ctx, out, &nans);
2223
} else if (diopiReduction_t::ReductionSum == reduction || diopiReduction_t::ReductionNone == reduction) {
2324
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceZero, ctx, out);
2425
}
@@ -44,6 +45,7 @@ diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
4445
AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1});
4546
AscendTensor outView = (outAt.numel() > 1) ? outAt.view({outAt.shape(0), outAt.numel() / outAt.shape(0), 1}) : outAt;
4647
AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1});
48+
DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2d, ctx, inputView, targetView, weightTmp, reduction, ignoreIndex, outView, totalWeight);
4749
}
4850

4951
return diopiSuccess;
@@ -85,7 +87,7 @@ diopiError_t diopiNLLLossV2Backward(diopiContextHandle_t ctx, diopiTensorHandle_
8587
gradInputAt.view({gradInputAt.shape(0), gradInputAt.shape(1), gradInputAt.numel() / gradInputAt.shape(0) / gradInputAt.shape(1), 1});
8688
AscendTensor gradOutputView;
8789
if (gradOutputAt.numel() > 1) {
88-
gradOutputView.view({gradOutputAt.shape(0), gradOutputAt.numel() / gradOutputAt.shape(0), 1});
90+
gradOutputView = gradOutputAt.view({gradOutputAt.shape(0), gradOutputAt.numel() / gradOutputAt.shape(0), 1});
8991
} else {
9092
gradOutputView = gradOutputAt;
9193
}

0 commit comments

Comments
 (0)