The brute-force KNN graph build in UMAP does not use the `all_neighbors` API; instead, it calls `cuvs::neighbors::brute_force::build` and `cuvs::neighbors::brute_force::search` directly.
|
if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) { |
|
auto idx = [&]() { |
|
if (data_on_device) { // inputsA on device |
|
return cuvs::neighbors::brute_force::build( |
|
handle, |
|
{static_cast<cuvs::distance::DistanceType>(params->metric), params->p}, |
|
raft::make_device_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d)); |
|
} else { // inputsA on host |
|
return cuvs::neighbors::brute_force::build( |
|
handle, |
|
{static_cast<cuvs::distance::DistanceType>(params->metric), params->p}, |
|
raft::make_host_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d)); |
|
} |
|
}(); |
|
cuvs::neighbors::brute_force::search( |
|
handle, |
|
idx, |
|
raft::make_device_matrix_view<const float, int64_t>(inputsB.X, inputsB.n, inputsB.d), |
|
raft::make_device_matrix_view<int64_t, int64_t>(out.knn_indices, inputsB.n, n_neighbors), |
|
raft::make_device_matrix_view<float, int64_t>(out.knn_dists, inputsB.n, n_neighbors)); |
|
} else { // nn_descent |
Proposal: consolidate this path with the `all_neighbors` API call already used for NN Descent, and remove the brute-force/NN-Descent if-else branching.
The brute-force KNN graph build in UMAP does not use `all_neighbors`; it calls brute-force build and search directly. See `cuml/cpp/src/umap/knn_graph/algo.cuh`, lines 62 to 82 at commit `604e7e0`.
Proposal: consolidate this path with the `all_neighbors` API call already used for NN Descent, and remove the brute-force/NN-Descent if-else branching.