Skip to content

Commit 36f1ac1

Browse files
committed
feat: Add SVE kernels for TopKV.
Resolves MLCE-1719 Change-Id: I7a0c7bd1154b9cb7f35c7fd1c3b8ad54698f8799 Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
1 parent 6305151 commit 36f1ac1

11 files changed

Lines changed: 516 additions & 9 deletions

File tree

filelist.json

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2453,6 +2453,7 @@
24532453
]
24542454
}
24552455
},
2456+
24562457
"TopKV": {
24572458
"files": {
24582459
"common": [
@@ -2463,15 +2464,17 @@
24632464
"neon": {
24642465
"fp16": [ "src/cpu/kernels/topkv/generic/neon/fp16.cpp" ],
24652466
"fp32": [ "src/cpu/kernels/topkv/generic/neon/fp32.cpp" ],
2466-
"integer":["src/cpu/kernels/topkv/generic/neon/integer.cpp"],
2467-
"qasymm8": [
2468-
"src/cpu/kernels/topkv/generic/neon/qasymm8.cpp"
2469-
],
2470-
"qasymm8_signed": [
2471-
"src/cpu/kernels/topkv/generic/neon/qasymm8_signed.cpp"
2472-
]
2467+
"integer": [ "src/cpu/kernels/topkv/generic/neon/integer.cpp" ],
2468+
"qasymm8": [ "src/cpu/kernels/topkv/generic/neon/qasymm8.cpp" ],
2469+
"qasymm8_signed": [ "src/cpu/kernels/topkv/generic/neon/qasymm8_signed.cpp" ]
2470+
},
2471+
"sve": {
2472+
"fp32": [ "src/cpu/kernels/topkv/generic/sve/fp32.cpp" ],
2473+
"fp16": [ "src/cpu/kernels/topkv/generic/sve/fp16.cpp" ],
2474+
"integer": [ "src/cpu/kernels/topkv/generic/sve/integer.cpp" ],
2475+
"qasymm8": [ "src/cpu/kernels/topkv/generic/sve/qasymm8.cpp" ],
2476+
"qasymm8_signed": [ "src/cpu/kernels/topkv/generic/sve/qasymm8_signed.cpp" ]
24732477
}
2474-
24752478
}
24762479
},
24772480
"Transpose": {

src/BUILD.bazel

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,12 @@ filegroup(
395395
"cpu/kernels/scale/sve/qasymm8.cpp",
396396
"cpu/kernels/scale/sve/qasymm8_signed.cpp",
397397
"cpu/kernels/softmax/generic/sve/impl.cpp",
398-
"cpu/kernels/softmax/generic/sve/impl_bf16.cpp"] +
398+
"cpu/kernels/softmax/generic/sve/impl_bf16.cpp",
399+
"cpu/kernels/topkv/generic/sve/fp16.cpp",
400+
"cpu/kernels/topkv/generic/sve/fp32.cpp",
401+
"cpu/kernels/topkv/generic/sve/integer.cpp",
402+
"cpu/kernels/topkv/generic/sve/qasymm8.cpp",
403+
"cpu/kernels/topkv/generic/sve/qasymm8_signed.cpp"] +
399404
glob(["**/*.h",
400405
"**/*.hpp",
401406
"**/*.inl"]),

src/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ target_sources(
365365
cpu/kernels/scale/sve/qasymm8_signed.cpp
366366
cpu/kernels/softmax/generic/sve/impl.cpp
367367
cpu/kernels/softmax/generic/sve/impl_bf16.cpp
368+
cpu/kernels/topkv/generic/sve/fp16.cpp
369+
cpu/kernels/topkv/generic/sve/fp32.cpp
370+
cpu/kernels/topkv/generic/sve/integer.cpp
371+
cpu/kernels/topkv/generic/sve/qasymm8.cpp
372+
cpu/kernels/topkv/generic/sve/qasymm8_signed.cpp
368373
)
369374

370375
target_sources(

src/cpu/kernels/CpuTopKVKernel.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,42 @@ namespace
4343
{
4444

4545
static const std::vector<CpuTopKVKernel::TopKVKernel> available_kernels = {
46+
47+
{"sve_fp16_topkv",
48+
[](const CpuTopKVKernelDataTypeISASelectorData &data)
49+
{ return (data.dt == DataType::F16) && data.isa.fp16 && data.isa.sve; },
50+
REGISTER_FP16_SVE(arm_compute::cpu::topkv_fp16_sve)},
51+
52+
{"sve_fp32_topkv",
53+
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::F32) && data.isa.sve; },
54+
REGISTER_FP32_SVE(arm_compute::cpu::topkv_fp32_sve)},
55+
56+
{"sve_qasymm8_topkv",
57+
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8) && data.isa.sve; },
58+
REGISTER_QASYMM8_SVE(arm_compute::cpu::topkv_qasymm8_sve)},
59+
60+
{"sve_qasymm8_signed_topkv",
61+
[](const CpuTopKVKernelDataTypeISASelectorData &data)
62+
{ return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve; },
63+
REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::topkv_qasymm8_signed_sve)},
64+
65+
{"sve_s32_topkv",
66+
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::S32) && data.isa.sve; },
67+
REGISTER_INTEGER_SVE(arm_compute::cpu::topkv_s32_sve)},
68+
4669
{"neon_s32_topkv", [](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::S32); },
4770
REGISTER_INTEGER_NEON(arm_compute::cpu::topkv_s32_neon)},
71+
4872
{"neon_fp32_topkv", [](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::F32); },
4973
REGISTER_FP32_NEON(arm_compute::cpu::topkv_fp32_neon)},
74+
5075
{"neon_fp16_topkv",
5176
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::F16) && data.isa.fp16; },
5277
REGISTER_FP16_NEON(arm_compute::cpu::topkv_fp16_neon)},
78+
5379
{"neon_qu8_topkv", [](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8); },
5480
REGISTER_QASYMM8_NEON(arm_compute::cpu::topkv_qasymm8_neon)},
81+
5582
{"neon_qs8_topkv",
5683
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8_SIGNED); },
5784
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::topkv_qasymm8_signed_neon)}};
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (c) 2026 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#if defined(__ARM_FEATURE_SVE)
25+
26+
#include "src/cpu/kernels/topkv/generic/sve/impl.h"
27+
28+
#include <arm_sve.h>
29+
30+
namespace arm_compute
31+
{
32+
namespace cpu
33+
{
34+
namespace detail
35+
{
36+
37+
template <>
38+
inline uint32_t vector_length<float16_t>()
39+
{
40+
return static_cast<uint32_t>(svcnth());
41+
}
42+
43+
template <>
44+
inline uint32_t count_gt_block<float16_t>(const float16_t *ptr, float16_t thr, uint32_t block_elems)
45+
{
46+
const svbool_t pg = svwhilelt_b16(static_cast<uint64_t>(0), static_cast<uint64_t>(block_elems));
47+
const svfloat16_t v = svld1_f16(pg, ptr);
48+
const svbool_t gt = svcmpgt_n_f16(pg, v, thr);
49+
return static_cast<uint32_t>(svcntp_b16(svptrue_b16(), gt));
50+
}
51+
52+
} // namespace detail
53+
54+
void topkv_fp16_sve(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win)
55+
{
56+
detail::topkv_sve_wrapper<float16_t>(predictions, targets, out, k, win);
57+
}
58+
59+
// Force instantiation into this TU
60+
template void
61+
detail::topkv_sve_wrapper<float16_t>(const ITensor *, const ITensor *, ITensor *, uint32_t, const Window &);
62+
63+
} // namespace cpu
64+
} // namespace arm_compute
65+
66+
#endif // __ARM_FEATURE_SVE
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (c) 2026 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#if defined(__ARM_FEATURE_SVE)
25+
26+
#include "src/cpu/kernels/topkv/generic/sve/impl.h"
27+
28+
#include <arm_sve.h>
29+
#include <cstdint>
30+
31+
namespace arm_compute
32+
{
33+
namespace cpu
34+
{
35+
namespace detail
36+
{
37+
38+
template <>
39+
inline uint32_t vector_length<float>()
40+
{
41+
return static_cast<uint32_t>(svcntw());
42+
}
43+
44+
template <>
45+
inline uint32_t count_gt_block<float>(const float *ptr, float thr, uint32_t block_elems)
46+
{
47+
const svbool_t pg = svwhilelt_b32(static_cast<uint64_t>(0), static_cast<uint64_t>(block_elems));
48+
const svfloat32_t v = svld1_f32(pg, ptr);
49+
const svbool_t gt = svcmpgt_n_f32(pg, v, thr);
50+
return static_cast<uint32_t>(svcntp_b32(svptrue_b32(), gt));
51+
}
52+
53+
} // namespace detail
54+
55+
void topkv_fp32_sve(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win)
56+
{
57+
detail::topkv_sve_wrapper<float>(predictions, targets, out, k, win);
58+
}
59+
60+
// Force instantiation into this TU
61+
template void detail::topkv_sve_wrapper<float>(const ITensor *, const ITensor *, ITensor *, uint32_t, const Window &);
62+
63+
} // namespace cpu
64+
} // namespace arm_compute
65+
66+
#endif // __ARM_FEATURE_SVE
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright (c) 2026 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#ifndef ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H
25+
#define ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H
26+
27+
#include "arm_compute/core/Coordinates.h"
28+
#include "arm_compute/core/Error.h"
29+
#include "arm_compute/core/Helpers.h"
30+
#include "arm_compute/core/ITensor.h"
31+
#include "arm_compute/core/Types.h"
32+
#include "arm_compute/core/Window.h"
33+
34+
#include <cstdint>
35+
#include <cstring>
36+
37+
namespace arm_compute
38+
{
39+
namespace cpu
40+
{
41+
namespace detail
42+
{
43+
44+
/*
45+
* Type-specific hooks (declared here, defined in each cpp).
46+
*
47+
* - vector_length<Scalar>()
48+
* Return the SVE vector length in elements for Scalar (no clamping).
49+
*
50+
* - count_gt_block<Scalar>(ptr, thr, block_elems)
51+
* Count how many elements in [ptr, ptr + block_elems) are > thr.
52+
* Tail-safe via predicate. block_elems is always <= vector_length<Scalar>().
53+
*
54+
* Each hook is specialized in the cpp file that contains the SVE intrinsics
55+
* (e.g., qasymm8.cpp, qasymm8_signed.cpp, fp16.cpp, fp32.cpp, integer.cpp).
56+
*/
57+
58+
template <typename Scalar>
59+
uint32_t vector_length();
60+
61+
template <typename Scalar>
62+
uint32_t count_gt_block(const Scalar *ptr, Scalar thr, uint32_t block_elems);
63+
64+
// ----------------------------------------------------------------------------
65+
// Generic wrapper (type-agnostic) - uses the above hooks.
66+
// Semantics (matching TopKV tests you showed):
67+
// - predictions is N x C
68+
// - window iterates across output elements (classes) => id.x() == class index c
69+
// - for each class c, targets[c] gives the sample index t
70+
// - scan across N samples and compute rank (#samples with value > predictions[t])
71+
// - output is U8 boolean: (rank < k)
72+
// ----------------------------------------------------------------------------
73+
template <typename Scalar>
74+
inline void
75+
topkv_sve_wrapper(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &window)
76+
{
77+
ARM_COMPUTE_ERROR_ON_NULLPTR(predictions, targets, out);
78+
ARM_COMPUTE_ERROR_ON(k == 0);
79+
80+
const ITensorInfo *pred_info = predictions->info();
81+
const uint32_t N = pred_info->dimension(0); // samples
82+
const uint32_t C = pred_info->dimension(1); // classes
83+
84+
const uint32_t vl = vector_length<Scalar>(); // cache once per kernel invocation
85+
86+
Iterator tgt_it(targets, window);
87+
Iterator out_it(out, window);
88+
89+
execute_window_loop(
90+
window,
91+
[&](const Coordinates &id)
92+
{
93+
const uint32_t c = static_cast<uint32_t>(id.x()); // class index
94+
ARM_COMPUTE_ERROR_ON(c >= C);
95+
96+
uint32_t t = {*reinterpret_cast<uint32_t *>(tgt_it.ptr())};
97+
ARM_COMPUTE_ERROR_ON(t >= N);
98+
99+
const Scalar *col_ptr = reinterpret_cast<const Scalar *>(predictions->ptr_to_element(Coordinates(0, c)));
100+
ARM_COMPUTE_ERROR_ON(col_ptr == nullptr);
101+
102+
const Scalar thr = col_ptr[t];
103+
104+
uint32_t rank = 0;
105+
uint32_t idx = 0;
106+
107+
while (idx < N)
108+
{
109+
const uint32_t remaining = N - idx;
110+
const uint32_t bw = (remaining < vl) ? remaining : vl;
111+
112+
rank += count_gt_block<Scalar>(col_ptr + idx, thr, bw);
113+
114+
if (rank >= k)
115+
{
116+
break;
117+
}
118+
119+
idx += bw;
120+
}
121+
122+
*reinterpret_cast<uint8_t *>(out_it.ptr()) = static_cast<uint8_t>(rank < k);
123+
},
124+
tgt_it, out_it);
125+
}
126+
127+
} // namespace detail
128+
} // namespace cpu
129+
} // namespace arm_compute
130+
131+
#endif // ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H

0 commit comments

Comments
 (0)