|
2 | 2 | #include "../../../devices/metax/metax_handle.h" |
3 | 3 | #include "../../../devices/metax/metax_kernel_common.h" |
4 | 4 | #include "binary_cross_entropy_with_logits_metax.h" |
| 5 | +#if defined(ENABLE_METAX_MC_API) |
5 | 6 | #include <mc_runtime.h> |
| 7 | +#else |
| 8 | +#include <hc_runtime.h> |
| 9 | +#endif |
6 | 10 | #include <type_traits> |
7 | 11 |
|
8 | 12 | namespace op::bce_with_logits::metax { |
@@ -191,7 +195,7 @@ infiniStatus_t Descriptor::calculate( |
191 | 195 | const void *pos_weight, |
192 | 196 | void *stream) const { |
193 | 197 |
|
194 | | - mcStream_t custream = (mcStream_t)stream; |
| 198 | + hcStream_t custream = (hcStream_t)stream; |
195 | 199 | size_t n = _info.num_elements; |
196 | 200 |
|
197 | 201 | // F16/BF16 + 归约需要 float workspace |
@@ -219,7 +223,7 @@ infiniStatus_t Descriptor::calculate( |
219 | 223 | case INFINI_DTYPE_F32: { |
220 | 224 | // 如果是规约操作,计算前需将输出位置清零 |
221 | 225 | if (_reduction != INFINIOP_REDUCTION_NONE) { |
222 | | - mcMemsetAsync(out, 0, sizeof(float), custream); |
| 226 | + hcMemsetAsync(out, 0, sizeof(float), custream); |
223 | 227 | } |
224 | 228 |
|
225 | 229 | bce_logits_kernel<float, float><<<grid, block, 0, custream>>>( |
@@ -255,7 +259,7 @@ infiniStatus_t Descriptor::calculate( |
255 | 259 | out_raw = out; |
256 | 260 | } else { |
257 | 261 | workspace_f = static_cast<float *>(workspace); |
258 | | - mcMemsetAsync(workspace_f, 0, sizeof(float), custream); |
| 262 | + hcMemsetAsync(workspace_f, 0, sizeof(float), custream); |
259 | 263 | out_raw = workspace_f; |
260 | 264 | } |
261 | 265 |
|
@@ -294,7 +298,7 @@ infiniStatus_t Descriptor::calculate( |
294 | 298 | out_raw = out; |
295 | 299 | } else { |
296 | 300 | workspace_f = static_cast<float *>(workspace); |
297 | | - mcMemsetAsync(workspace_f, 0, sizeof(float), custream); |
| 301 | + hcMemsetAsync(workspace_f, 0, sizeof(float), custream); |
298 | 302 | out_raw = workspace_f; |
299 | 303 | } |
300 | 304 |
|
@@ -324,8 +328,8 @@ infiniStatus_t Descriptor::calculate( |
324 | 328 | return INFINI_STATUS_BAD_TENSOR_DTYPE; |
325 | 329 | } |
326 | 330 |
|
327 | | - mcError_t err = mcGetLastError(); |
328 | | - if (err != mcSuccess) { |
| 331 | + hcError_t err = hcGetLastError(); |
| 332 | + if (err != hcSuccess) { |
329 | 333 | return INFINI_STATUS_INTERNAL_ERROR; |
330 | 334 | } |
331 | 335 | return INFINI_STATUS_SUCCESS; |
|
0 commit comments