Skip to content

Commit 6c0d7ee

Browse files
authored
Add unit test for scalar vs. CGSize1 probing (#509)
1 parent 1f09fa9 commit 6c0d7ee

5 files changed

Lines changed: 125 additions & 4 deletions

File tree

include/cuco/detail/extent/extent.inl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ struct window_extent {
4040
template <int32_t CGSize_, int32_t WindowSize_, typename SizeType_, std::size_t N_>
4141
friend auto constexpr make_window_extent(extent<SizeType_, N_> ext);
4242

43+
template <typename Rhs>
44+
friend __host__ __device__ constexpr value_type operator-(window_extent const& lhs,
45+
Rhs rhs) noexcept
46+
{
47+
return lhs.value() - rhs;
48+
}
49+
4350
template <typename Rhs>
4451
friend __host__ __device__ constexpr value_type operator/(window_extent const& lhs,
4552
Rhs rhs) noexcept

include/cuco/detail/probing_scheme_impl.inl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,8 @@ __host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operato
149149
using size_type = typename Extent::value_type;
150150
return detail::probing_iterator<Extent>{
151151
cuco::detail::sanitize_hash<size_type>(hash1_(probe_key)) % upper_bound,
152-
max(size_type{1},
153-
cuco::detail::sanitize_hash<size_type>(hash2_(probe_key)) %
154-
upper_bound), // step size in range [1, prime - 1]
152+
cuco::detail::sanitize_hash<size_type>(hash2_(probe_key)) % (upper_bound - 1) +
153+
1, // step size in range [1, prime - 1]
155154
upper_bound};
156155
}
157156

include/cuco/utility/fast_int.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,12 @@ struct fast_int {
151151
return rhs.mulhi(rhs.magic_, mul) >> rhs.shift_;
152152
}
153153

154+
template <typename Rhs>
155+
friend __host__ __device__ constexpr auto operator-(fast_int const& lhs, Rhs rhs) noexcept
156+
{
157+
return lhs.value() - rhs;
158+
}
159+
154160
template <typename Rhs>
155161
friend __host__ __device__ constexpr auto operator/(fast_int const& lhs, Rhs rhs) noexcept
156162
{

tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ ConfigureTest(UTILITY_TEST
5353
utility/extent_test.cu
5454
utility/storage_test.cu
5555
utility/fast_int_test.cu
56-
utility/hash_test.cu)
56+
utility/hash_test.cu
57+
utility/probing_scheme_test.cu)
5758

5859
###################################################################################################
5960
# - static_set tests ------------------------------------------------------------------------------
tests/utility/probing_scheme_test.cu

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <test_utils.hpp>
18+
19+
#include <cuco/detail/utility/cuda.hpp>
20+
#include <cuco/extent.cuh>
21+
#include <cuco/hash_functions.cuh>
22+
#include <cuco/probing_scheme.cuh>
23+
24+
#include <thrust/device_vector.h>
25+
26+
#include <cooperative_groups.h>
27+
28+
#include <catch2/catch_template_test_macros.hpp>
29+
30+
#include <cstddef>
31+
#include <cstdint>
32+
33+
/**
 * @brief Records the first `seq_length` slots visited by a scalar (non-cooperative) probing
 * scheme.
 *
 * Only the thread with global index 0 walks the sequence and writes to `out_seq`; all other
 * threads exit without touching memory.
 *
 * @tparam ProbingScheme Probing scheme type; its `cg_size` must be 1 (enforced at compile time)
 * @tparam Key Type of the probe key
 * @tparam Extent Type of the upper bound handed to the probing scheme
 * @tparam OutputIt Writable iterator type; must be dereferenceable in device code
 *
 * @param key Key for which the probing sequence is generated
 * @param upper_bound Upper bound forwarded to the probing scheme
 * @param seq_length Number of sequence elements to write
 * @param out_seq Beginning of the output range; must have room for `seq_length` elements
 */
template <class ProbingScheme, class Key, class Extent, class OutputIt>
__global__ void generate_scalar_probing_sequence(Key key,
                                                 Extent upper_bound,
                                                 size_t seq_length,
                                                 OutputIt out_seq)
{
  static_assert(ProbingScheme::cg_size == 1, "Invalid CG size");

  auto const global_idx = blockIdx.x * blockDim.x + threadIdx.x;
  auto scheme           = ProbingScheme{};

  // A single thread is enough to walk the scalar sequence.
  if (global_idx != 0) { return; }

  auto probe_it = scheme(key, upper_bound);

  for (size_t pos = 0; pos < seq_length; ++pos) {
    out_seq[pos] = *probe_it;
    probe_it++;
  }
}
54+
55+
/**
 * @brief Records the first `seq_length` slots visited by the cooperative-group overload of a
 * probing scheme.
 *
 * The first `cg_size` threads of the grid form one tile. Each tile thread starts writing at its
 * own rank and then advances by `cg_size`, so the threads fill disjoint, interleaved slots of
 * `out_seq`. For `cg_size == 1` this is a plain sequential walk.
 *
 * NOTE(review): this assumes the iterator returned for rank `r` yields the elements belonging to
 * output positions `r`, `r + cg_size`, ... - confirm against the probing iterator implementation.
 *
 * @tparam ProbingScheme Probing scheme type exposing `cg_size` and a tile-taking call operator
 * @tparam Key Type of the probe key
 * @tparam Extent Type of the upper bound handed to the probing scheme
 * @tparam OutputIt Writable iterator type; must be dereferenceable in device code
 *
 * @param key Key for which the probing sequence is generated
 * @param upper_bound Upper bound forwarded to the probing scheme
 * @param seq_length Number of sequence elements to write
 * @param out_seq Beginning of the output range; must have room for `seq_length` elements
 */
template <class ProbingScheme, class Key, class Extent, class OutputIt>
__global__ void generate_cg_probing_sequence(Key key,
                                             Extent upper_bound,
                                             size_t seq_length,
                                             OutputIt out_seq)
{
  auto constexpr cg_size = ProbingScheme::cg_size;

  auto const tid      = blockIdx.x * blockDim.x + threadIdx.x;
  auto probing_scheme = ProbingScheme{};

  // Only the first `cg_size` threads participate; they make up exactly one tile.
  if (tid < cg_size) {
    auto const tile =
      cooperative_groups::tiled_partition<cg_size>(cooperative_groups::this_thread_block());

    auto iter = probing_scheme(tile, key, upper_bound);

    // Stride by the tile width so that no two threads ever write the same output slot.
    for (size_t i = tile.thread_rank(); i < seq_length; i += cg_size) {
      out_seq[i] = *iter;
      iter++;
    }
  }
}
78+
79+
// Checks that the scalar overload of a probing scheme and its cooperative-group overload with a
// CG size of 1 generate the same probing sequence for the same key and extent.
TEMPLATE_TEST_CASE_SIG(
  "probing_scheme scalar vs CGSize 1 test",
  "",
  ((typename Key, cuco::test::probe_sequence Probe, int32_t WindowSize), Key, Probe, WindowSize),
  (int32_t, cuco::test::probe_sequence::double_hashing, 1),
  (int32_t, cuco::test::probe_sequence::double_hashing, 2),
  (int64_t, cuco::test::probe_sequence::double_hashing, 1),
  (int64_t, cuco::test::probe_sequence::double_hashing, 2),
  (int32_t, cuco::test::probe_sequence::linear_probing, 1),
  (int32_t, cuco::test::probe_sequence::linear_probing, 2),
  (int64_t, cuco::test::probe_sequence::linear_probing, 1),
  (int64_t, cuco::test::probe_sequence::linear_probing, 2))
{
  auto const upper_bound = cuco::make_window_extent<1, WindowSize>(cuco::extent<std::size_t>{10});
  constexpr size_t seq_length{8};
  constexpr Key key{42};

  using probe = std::conditional_t<Probe == cuco::test::probe_sequence::linear_probing,
                                   cuco::linear_probing<1, cuco::default_hash_function<Key>>,
                                   cuco::double_hashing<1, cuco::default_hash_function<Key>>>;

  // A failed launch would leave both output buffers untouched, in which case they could compare
  // equal without any sequence having been generated. Check every launch explicitly.
  thrust::device_vector<size_t> scalar_seq(seq_length);
  generate_scalar_probing_sequence<probe>
    <<<1, 1>>>(key, upper_bound, seq_length, scalar_seq.begin());
  REQUIRE(cudaGetLastError() == cudaSuccess);

  thrust::device_vector<size_t> cg_seq(seq_length);
  generate_cg_probing_sequence<probe><<<1, 1>>>(key, upper_bound, seq_length, cg_seq.begin());
  REQUIRE(cudaGetLastError() == cudaSuccess);

  // Surface asynchronous execution errors before the results are inspected.
  REQUIRE(cudaDeviceSynchronize() == cudaSuccess);

  REQUIRE(cuco::test::equal(
    scalar_seq.begin(), scalar_seq.end(), cg_seq.begin(), thrust::equal_to<std::size_t>{}));
}

0 commit comments

Comments
 (0)