diff --git a/cpp/include/raft/cluster/detail/agglomerative.cuh b/cpp/include/raft/cluster/detail/agglomerative.cuh deleted file mode 100644 index f2c83abdd3..0000000000 --- a/cpp/include/raft/cluster/detail/agglomerative.cuh +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::cluster::detail { -template -class UnionFind { - public: - value_idx next_label; - std::vector parent; - std::vector size; - - value_idx n_indices; - - UnionFind(value_idx N_) - : n_indices(2 * N_ - 1), parent(2 * N_ - 1, -1), size(2 * N_ - 1, 1), next_label(N_) - { - memset(size.data() + N_, 0, (size.size() - N_) * sizeof(value_idx)); - } - - value_idx find(value_idx n) - { - value_idx p; - p = n; - - while (parent[n] != -1) - n = parent[n]; - - // path compression - while (parent[p] != n) { - p = parent[p == -1 ? n_indices - 1 : p]; - parent[p == -1 ? n_indices - 1 : p] = n; - } - return n; - } - - void perform_union(value_idx m, value_idx n) - { - size[next_label] = size[m] + size[n]; - parent[m] = next_label; - parent[n] = next_label; - - next_label += 1; - } -}; - -/** - * Agglomerative labeling on host. This has not been found to be a bottleneck - * in the algorithm. A parallel version of this can be done using a parallel - * variant of Kruskal's MST algorithm - * (ref http://cucis.ece.northwestern.edu/publications/pdf/HenPat12.pdf), - * which breaks apart the sorted MST results into overlapping subsets and - * independently runs Kruskal's algorithm on each subset, merging them back - * together into a single hierarchy when complete. Unfortunately, - * this is nontrivial and the speedup wouldn't be useful until this - * becomes a bottleneck. - * - * @tparam value_idx - * @tparam value_t - * @param[in] handle the raft handle - * @param[in] rows src edges of the sorted MST - * @param[in] cols dst edges of the sorted MST - * @param[in] nnz the number of edges in the sorted MST - * @param[out] out_src parents of output - * @param[out] out_dst children of output - * @param[out] out_delta distances of output - * @param[out] out_size cluster sizes of output - */ -template -void build_dendrogram_host(raft::resources const& handle, - const value_idx* rows, - const value_idx* cols, - const value_t* data, - size_t nnz, - value_idx* children, - value_t* out_delta, - value_idx* out_size) -{ - auto stream = resource::get_cuda_stream(handle); - - value_idx n_edges = nnz; - - std::vector mst_src_h(n_edges); - std::vector mst_dst_h(n_edges); - std::vector mst_weights_h(n_edges); - - update_host(mst_src_h.data(), rows, n_edges, stream); - update_host(mst_dst_h.data(), cols, n_edges, stream); - update_host(mst_weights_h.data(), data, n_edges, stream); - - resource::sync_stream(handle, stream); - - std::vector children_h(n_edges * 2); - std::vector out_size_h(n_edges); - std::vector out_delta_h(n_edges); - - UnionFind U(nnz + 1); - - for (std::size_t i = 0; i < nnz; i++) { - value_idx a = mst_src_h[i]; - value_idx b = mst_dst_h[i]; - value_t delta = mst_weights_h[i]; - - value_idx aa = U.find(a); - value_idx bb = U.find(b); - - value_idx children_idx = i * 2; - - children_h[children_idx] = aa; - children_h[children_idx + 1] = bb; - out_delta_h[i] = delta; - out_size_h[i] = U.size[aa] + U.size[bb]; - - U.perform_union(aa, bb); - } - - raft::update_device(children, children_h.data(), n_edges * 2, stream); - raft::update_device(out_size, out_size_h.data(), n_edges, stream); - raft::update_device(out_delta, out_delta_h.data(), n_edges, stream); -} - -template -RAFT_KERNEL write_levels_kernel(const value_idx* children, value_idx* parents, value_idx n_vertices) -{ - value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < n_vertices) { - value_idx level = tid / 2; - value_idx child = children[tid]; - parents[child] = level; - } -} - -/** - * Instead of propagating a label from roots to children, - * the children each iterate up the tree until they find - * the label of their parent. This increases the potential - * parallelism. - * @tparam value_idx - * @param children - * @param parents - * @param n_leaves - * @param labels - */ -template -RAFT_KERNEL inherit_labels(const value_idx* children, - const value_idx* levels, - std::size_t n_leaves, - value_idx* labels, - int cut_level, - value_idx n_vertices) -{ - value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; - - if (tid < n_vertices) { - value_idx node = children[tid]; - value_idx cur_level = tid / 2; - - /** - * Any roots above the cut level should be ignored. - * Any leaves at the cut level should already be labeled - */ - if (cur_level > cut_level) return; - - value_idx cur_parent = node; - value_idx label = labels[cur_parent]; - - while (label == -1) { - cur_parent = cur_level + n_leaves; - cur_level = levels[cur_parent]; - label = labels[cur_parent]; - } - - labels[node] = label; - } -} - -template -struct init_label_roots { - init_label_roots(value_idx* labels_) : labels(labels_) {} - - template - __host__ __device__ void operator()(Tuple t) - { - labels[thrust::get<1>(t)] = thrust::get<0>(t); - } - - private: - value_idx* labels; -}; - -/** - * Cuts the dendrogram at a particular level where the number of nodes - * is equal to n_clusters, then propagates the resulting labels - * to all the children. - * - * @tparam value_idx - * @param handle - * @param labels - * @param children - * @param n_clusters - * @param n_leaves - */ -template -void extract_flattened_clusters(raft::resources const& handle, - value_idx* labels, - const value_idx* children, - size_t n_clusters, - size_t n_leaves) -{ - auto stream = resource::get_cuda_stream(handle); - auto thrust_policy = resource::get_thrust_policy(handle); - - // Handle special case where n_clusters == 1 - if (n_clusters == 1) { - thrust::fill(thrust_policy, labels, labels + n_leaves, 0); - } else { - /** - * Compute levels for each node - * - * 1. Initialize "levels" array of size n_leaves * 2 - * - * 2. For each entry in children, write parent - * out for each of the children - */ - - auto n_edges = (n_leaves - 1) * 2; - - thrust::device_ptr d_ptr = thrust::device_pointer_cast(children); - value_idx n_vertices = *(thrust::max_element(thrust_policy, d_ptr, d_ptr + n_edges)) + 1; - - // Prevent potential infinite loop from labeling disconnected - // connectivities graph. - RAFT_EXPECTS(n_leaves > 0, "n_leaves must be positive"); - RAFT_EXPECTS( - static_cast(n_vertices) == static_cast((n_leaves - 1) * 2), - "Multiple components found in MST or MST is invalid. " - "Cannot find single-linkage solution."); - - rmm::device_uvector levels(n_vertices, stream); - - value_idx n_blocks = ceildiv(n_vertices, (value_idx)tpb); - write_levels_kernel<<>>(children, levels.data(), n_vertices); - /** - * Step 1: Find label roots: - * - * 1. Copying children[children.size()-(n_clusters-1):] entries to - * separate arrayo - * 2. sort array - * 3. take first n_clusters entries - */ - - value_idx child_size = (n_clusters - 1) * 2; - rmm::device_uvector label_roots(child_size, stream); - - value_idx children_cpy_start = n_edges - child_size; - raft::copy_async(label_roots.data(), children + children_cpy_start, child_size, stream); - - thrust::sort(thrust_policy, - label_roots.data(), - label_roots.data() + (child_size), - thrust::greater()); - - rmm::device_uvector tmp_labels(n_vertices, stream); - - // Init labels to -1 - thrust::fill(thrust_policy, tmp_labels.data(), tmp_labels.data() + n_vertices, -1); - - // Write labels for cluster roots to "labels" - thrust::counting_iterator first(0); - - auto z_iter = thrust::make_zip_iterator( - thrust::make_tuple(first, label_roots.data() + (label_roots.size() - n_clusters))); - - thrust::for_each( - thrust_policy, z_iter, z_iter + n_clusters, init_label_roots(tmp_labels.data())); - - /** - * Step 2: Propagate labels by having children iterate through their parents - * 1. Initialize labels to -1 - * 2. For each element in levels array, propagate until parent's - * label is !=-1 - */ - value_idx cut_level = (n_edges / 2) - (n_clusters - 1); - - inherit_labels<<>>( - children, levels.data(), n_leaves, tmp_labels.data(), cut_level, n_vertices); - - // copy tmp labels to actual labels - raft::copy_async(labels, tmp_labels.data(), n_leaves, stream); - } -} - -}; // namespace raft::cluster::detail diff --git a/cpp/include/raft/cluster/detail/connectivities.cuh b/cpp/include/raft/cluster/detail/connectivities.cuh deleted file mode 100644 index c527b754c3..0000000000 --- a/cpp/include/raft/cluster/detail/connectivities.cuh +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include - -namespace raft::cluster::detail { - -template -struct distance_graph_impl { - void run(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c); -}; - -/** - * Connectivities specialization to build a knn graph - * @tparam value_idx - * @tparam value_t - */ -template -struct distance_graph_impl { - void run(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c) - { - auto stream = resource::get_cuda_stream(handle); - auto thrust_policy = resource::get_thrust_policy(handle); - - // Need to symmetrize knn into undirected graph - raft::sparse::COO knn_graph_coo(stream); - - raft::sparse::neighbors::knn_graph(handle, X, m, n, metric, knn_graph_coo, c); - - indices.resize(knn_graph_coo.nnz, stream); - data.resize(knn_graph_coo.nnz, stream); - - // self-loops get max distance - auto transform_in = thrust::make_zip_iterator( - thrust::make_tuple(knn_graph_coo.rows(), knn_graph_coo.cols(), knn_graph_coo.vals())); - - thrust::transform(thrust_policy, - transform_in, - transform_in + knn_graph_coo.nnz, - knn_graph_coo.vals(), - [=] __device__(const thrust::tuple& tup) { - bool self_loop = thrust::get<0>(tup) == thrust::get<1>(tup); - return (self_loop * std::numeric_limits::max()) + - (!self_loop * thrust::get<2>(tup)); - }); - - raft::sparse::convert::sorted_coo_to_csr( - knn_graph_coo.rows(), knn_graph_coo.nnz, indptr.data(), m + 1, stream); - - // TODO: Wouldn't need to copy here if we could compute knn - // graph directly on the device uvectors - // ref: https://github.com/rapidsai/raft/issues/227 - raft::copy_async(indices.data(), knn_graph_coo.cols(), knn_graph_coo.nnz, stream); - raft::copy_async(data.data(), knn_graph_coo.vals(), knn_graph_coo.nnz, stream); - } -}; - -template -RAFT_KERNEL fill_indices2(value_idx* indices, size_t m, size_t nnz) -{ - value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x; - if (tid >= nnz) return; - value_idx v = tid % m; - indices[tid] = v; -} - -/** - * Compute connected CSR of pairwise distances - * @tparam value_idx - * @tparam value_t - * @param handle - * @param X - * @param m - * @param n - * @param metric - * @param[out] indptr - * @param[out] indices - * @param[out] data - */ -template -void pairwise_distances(const raft::resources& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - value_idx* indptr, - value_idx* indices, - value_t* data) -{ - auto stream = resource::get_cuda_stream(handle); - auto exec_policy = resource::get_thrust_policy(handle); - - value_idx nnz = m * m; - - value_idx blocks = raft::ceildiv(nnz, (value_idx)256); - fill_indices2<<>>(indices, m, nnz); - - thrust::sequence(exec_policy, indptr, indptr + m, 0, (int)m); - - raft::update_device(indptr + m, &nnz, 1, stream); - - // TODO: It would ultimately be nice if the MST could accept - // dense inputs directly so we don't need to double the memory - // usage to hand it a sparse array here. - distance::pairwise_distance(handle, X, X, data, m, m, n, metric); - // self-loops get max distance - auto transform_in = - thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data)); - - thrust::transform(exec_policy, - transform_in, - transform_in + nnz, - data, - [=] __device__(const thrust::tuple& tup) { - value_idx idx = thrust::get<0>(tup); - bool self_loop = idx % m == idx / m; - return (self_loop * std::numeric_limits::max()) + - (!self_loop * thrust::get<1>(tup)); - }); -} - -/** - * Connectivities specialization for pairwise distances - * @tparam value_idx - * @tparam value_t - */ -template -struct distance_graph_impl { - void run(const raft::resources& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c) - { - auto stream = resource::get_cuda_stream(handle); - - size_t nnz = m * m; - - indices.resize(nnz, stream); - data.resize(nnz, stream); - - pairwise_distances(handle, X, m, n, metric, indptr.data(), indices.data(), data.data()); - } -}; - -/** - * Returns a CSR connectivities graph based on the given linkage distance. - * @tparam value_idx - * @tparam value_t - * @tparam dist_type - * @param[in] handle raft handle - * @param[in] X dense data for which to construct connectivites - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metric to use - * @param[out] indptr indptr array of connectivities graph - * @param[out] indices column indices array of connectivities graph - * @param[out] data distances array of connectivities graph - * @param[out] c constant 'c' used for nearest neighbors-based distances - * which will guarantee k <= log(n) + c - */ -template -void get_distance_graph(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c) -{ - auto stream = resource::get_cuda_stream(handle); - - indptr.resize(m + 1, stream); - - distance_graph_impl dist_graph; - dist_graph.run(handle, X, m, n, metric, indptr, indices, data, c); -} - -}; // namespace raft::cluster::detail diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh deleted file mode 100644 index 4efeedcbaa..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ /dev/null @@ -1,1255 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace cluster { -namespace detail { - -// ========================================================= -// Init functions -// ========================================================= - -// Selects 'n_clusters' samples randomly from X -template -void initRandom(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids) -{ - common::nvtx::range fun_scope("initRandom"); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_clusters = params.n_clusters; - detail::shuffleAndGather(handle, X, centroids, n_clusters, params.rng_state.seed); -} - -/* - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - - * @note This is the algorithm described in - * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. - * - * Scalable kmeans++ pseudocode - * 1: C = sample a point uniformly at random from X - * 2: while |C| < k - * 3: Sample x in X with probability p_x = d^2(x, C) / phi_X (C) - * 4: C = C U {x} - * 5: end for - */ -template -void kmeansPlusPlus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - common::nvtx::range fun_scope("kmeansPlusPlus"); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // number of seeding trials for each center (except the first) - auto n_trials = 2 + static_cast(std::ceil(log(n_clusters))); - - RAFT_LOG_DEBUG( - "Run sequential k-means++ to select %d centroids from %d input samples " - "(%d seeding trials per iterations)", - n_clusters, - n_samples, - n_trials); - - auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples); - - // temporary buffers - auto indices = raft::make_device_vector(handle, n_trials); - auto centroidCandidates = raft::make_device_matrix(handle, n_trials, n_features); - auto costPerCandidate = raft::make_device_vector(handle, n_trials); - auto minClusterDistance = raft::make_device_vector(handle, n_samples); - auto distBuffer = raft::make_device_matrix(handle, n_trials, n_samples); - - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - rmm::device_scalar clusterCost(stream); - rmm::device_scalar> minClusterIndexAndDistance(stream); - - // Device and matrix views - raft::device_vector_view indices_view(indices.data_handle(), n_trials); - auto const_weights_view = - raft::make_device_vector_view(minClusterDistance.data_handle(), n_samples); - auto const_indices_view = - raft::make_device_vector_view(indices.data_handle(), n_trials); - auto const_X_view = - raft::make_device_matrix_view(X.data_handle(), n_samples, n_features); - raft::device_matrix_view candidates_view( - centroidCandidates.data_handle(), n_trials, n_features); - - // L2 norm of X: ||c||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); - std::mt19937 gen(params.rng_state.seed); - std::uniform_int_distribution<> dis(0, n_samples - 1); - - // <<< Step-1 >>>: C <-- sample a point uniformly at random from X - auto initialCentroid = raft::make_device_matrix_view( - X.data_handle() + dis(gen) * n_features, 1, n_features); - int n_clusters_picked = 1; - - // store the chosen centroid in the buffer - raft::copy( - centroidsRawData.data_handle(), initialCentroid.data_handle(), initialCentroid.size(), stream); - - // C = initial set of centroids - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), initialCentroid.extent(0), initialCentroid.extent(1)); - // <<< End of Step-1 >>> - - // Calculate cluster distance, d^2(x, C), for all the points x in X to the nearest centroid - detail::minClusterDistanceCompute(handle, - X, - centroids, - minClusterDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - RAFT_LOG_DEBUG(" k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); - - // <<<< Step-2 >>> : while |C| < k - while (n_clusters_picked < n_clusters) { - // <<< Step-3 >>> : Sample x in X with probability p_x = d^2(x, C) / phi_X (C) - // Choose 'n_trials' centroid candidates from X with probability proportional to the squared - // distance to the nearest existing cluster - - raft::random::discrete(handle, rng, indices_view, const_weights_view); - raft::matrix::gather(handle, const_X_view, const_indices_view, candidates_view); - - // Calculate pairwise distance between X and the centroid candidates - // Output - pwd [n_trials x n_samples] - auto pwd = distBuffer.view(); - detail::pairwise_distance_kmeans( - handle, centroidCandidates.view(), X, pwd, workspace, metric); - - // Update nearest cluster distance for each centroid candidate - // Note pwd and minDistBuf points to same buffer which currently holds pairwise distance values. - // Outputs minDistanceBuf[n_trials x n_samples] where minDistance[i, :] contains updated - // minClusterDistance that includes candidate-i - auto minDistBuf = distBuffer.view(); - raft::linalg::matrixVectorOp(minDistBuf.data_handle(), - pwd.data_handle(), - minClusterDistance.data_handle(), - pwd.extent(1), - pwd.extent(0), - true, - true, - raft::min_op{}, - stream); - - // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using - // centroid candidate-i - raft::linalg::reduce(costPerCandidate.data_handle(), - minDistBuf.data_handle(), - minDistBuf.extent(1), - minDistBuf.extent(0), - static_cast(0), - true, - true, - stream); - - // Greedy Choice - Choose the candidate that has minimum cluster cost - // ArgMin operation below identifies the index of minimum cost in costPerCandidate - { - // Determine temporary device storage requirements - size_t temp_storage_bytes = 0; - cub::DeviceReduce::ArgMin(nullptr, - temp_storage_bytes, - costPerCandidate.data_handle(), - minClusterIndexAndDistance.data(), - costPerCandidate.extent(0), - stream); - - // Allocate temporary storage - workspace.resize(temp_storage_bytes, stream); - - // Run argmin-reduction - cub::DeviceReduce::ArgMin(workspace.data(), - temp_storage_bytes, - costPerCandidate.data_handle(), - minClusterIndexAndDistance.data(), - costPerCandidate.extent(0), - stream); - - int bestCandidateIdx = -1; - raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); - resource::sync_stream(handle); - /// <<< End of Step-3 >>> - - /// <<< Step-4 >>>: C = C U {x} - // Update minimum cluster distance corresponding to the chosen centroid candidate - raft::copy(minClusterDistance.data_handle(), - minDistBuf.data_handle() + bestCandidateIdx * n_samples, - n_samples, - stream); - - raft::copy(centroidsRawData.data_handle() + n_clusters_picked * n_features, - centroidCandidates.data_handle() + bestCandidateIdx * n_features, - n_features, - stream); - - ++n_clusters_picked; - /// <<< End of Step-4 >>> - } - - RAFT_LOG_DEBUG(" k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); - } /// <<<< Step-5 >>> -} - -/** - * - * @tparam DataT - * @tparam IndexT - * @param handle - * @param[in] X input matrix (size n_samples, n_features) - * @param[in] weight number of samples currently assigned to each centroid - * @param[in] cur_centroids matrix of current centroids (size n_clusters, n_features) - * @param[in] l2norm_x - * @param[out] min_cluster_and_dist - * @param[out] new_centroids - * @param[out] new_weight - * @param[inout] workspace - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - - // TODO: Figure out how to best wrap iterator types in mdspan - LabelsIterator cluster_labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids, - rmm::device_uvector& workspace) -{ - auto n_clusters = centroids.extent(0); - auto n_samples = X.extent(0); - - workspace.resize(n_samples, resource::get_cuda_stream(handle)); - - // Calculates weighted sum of all the samples assigned to cluster-i and stores the - // result in new_centroids[i] - raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), - X.extent(1), - cluster_labels, - sample_weights.data_handle(), - workspace.data(), - X.extent(0), - X.extent(1), - n_clusters, - new_centroids.data_handle(), - resource::get_cuda_stream(handle)); - - // Reduce weights by key to compute weight in each cluster - raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), - cluster_labels, - weight_per_cluster.data_handle(), - (IndexT)1, - (IndexT)sample_weights.extent(0), - (IndexT)n_clusters, - resource::get_cuda_stream(handle)); - - // Computes new_centroids[i] = new_centroids[i]/weight_per_cluster[i] where - // new_centroids[n_clusters x n_features] - 2D array, new_centroids[i] has sum of all the - // samples assigned to cluster-i - // weight_per_cluster[n_clusters] - 1D array, weight_per_cluster[i] contains sum of weights in - // cluster-i. - // Note - when weight_per_cluster[i] is 0, new_centroids[i] is reset to 0 - raft::linalg::matrixVectorOp(new_centroids.data_handle(), - new_centroids.data_handle(), - weight_per_cluster.data_handle(), - new_centroids.extent(1), - new_centroids.extent(0), - true, - false, - raft::div_checkzero_op{}, - resource::get_cuda_stream(handle)); - - // copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0 - cub::ArgIndexInputIterator itr_wt(weight_per_cluster.data_handle()); - raft::matrix::gather_if( - const_cast(centroids.data_handle()), - static_cast(centroids.extent(1)), - static_cast(centroids.extent(0)), - itr_wt, - itr_wt, - static_cast(weight_per_cluster.size()), - new_centroids.data_handle(), - [=] __device__(raft::KeyValuePair map) { // predicate - // copy when the sum of weights in the cluster is 0 - return map.value == 0; - }, - raft::key_op{}, - resource::get_cuda_stream(handle)); -} - -// TODO: Resizing is needed to use mdarray instead of rmm::device_uvector -template -void kmeans_fit_main(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - common::nvtx::range fun_scope("kmeans_fit_main"); - logger::get(RAFT_NAME).set_level(params.verbosity); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // temporary buffer to store intermediate centroids, destructor releases the - // resource - auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); - - // temporary buffer to store weights per cluster, destructor releases the - // resource - auto wtInCluster = raft::make_device_vector(handle, n_clusters); - - rmm::device_scalar clusterCostD(stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - auto l2normx_view = - raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - RAFT_LOG_DEBUG( - "Calling KMeans.fit with %d samples of input data and the initialized " - "cluster centers", - n_samples); - - DataT priorClusteringCost = 0; - for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { - RAFT_LOG_DEBUG( - "KMeans.fit: Iteration-%d: fitting the model using the initialized " - "cluster centers", - n_iter[0]); - - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // Using TransformInputIteratorT to dereference an array of - // raft::KeyValuePair and converting them to just return the Key to be used - // in reduce_rows_by_key prims - detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - - update_centroids(handle, - X, - weight, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - itr, - wtInCluster.view(), - newCentroids.view(), - workspace); - - // compute the squared norm between the newCentroids and the original - // centroids, destructor releases the resource - auto sqrdNorm = raft::make_device_scalar(handle, DataT(0)); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - newCentroids.size(), - raft::sqdiff_op{}, - stream, - centroids.data_handle(), - newCentroids.data_handle()); - - DataT sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); - - raft::copy( - centroidsRawData.data_handle(), newCentroids.data_handle(), newCentroids.size(), stream); - - bool done = false; - if (params.inertia_check) { - // calculate cluster cost phi_x(C) - detail::computeClusterCost(handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - DataT curClusteringCost = clusterCostD.value(stream); - - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers"); - - if (n_iter[0] > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; - } - - resource::sync_stream(handle, stream); - if (sqrdNormError < params.tol) done = true; - - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); - break; - } - } - - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - weight.data_handle(), - minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); - - // calculate cluster cost phi_x(C) - detail::computeClusterCost(handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - inertia[0] = clusterCostD.value(stream); - - RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", - n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], - inertia[0]); -} - -/* - * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm. - - * @note This is the algorithm described in - * "Scalable K-Means++", 2012, Bahman Bahmani, Benjamin Moseley, - * Andrea Vattani, Ravi Kumar, Sergei Vassilvitskii, - * https://arxiv.org/abs/1203.6402 - - * Scalable kmeans++ pseudocode - * 1: C = sample a point uniformly at random from X - * 2: psi = phi_X (C) - * 3: for O( log(psi) ) times do - * 4: C' = sample each point x in X independently with probability - * p_x = l * (d^2(x, C) / phi_X (C) ) - * 5: C = C U C' - * 6: end for - * 7: For x in C, set w_x to be the number of points in X closer to x than any - * other point in C - * 8: Recluster the weighted points in C into k clusters - - * TODO: Resizing is needed to use mdarray instead of rmm::device_uvector - - */ -template -void initScalableKMeansPlusPlus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - common::nvtx::range fun_scope("initScalableKMeansPlusPlus"); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); - - // <<<< Step-1 >>> : C <- sample a point uniformly at random from X - std::mt19937 gen(params.rng_state.seed); - std::uniform_int_distribution<> dis(0, n_samples - 1); - - auto cIdx = dis(gen); - auto initialCentroid = raft::make_device_matrix_view( - X.data_handle() + cIdx * n_features, 1, n_features); - - // flag the sample that is chosen as initial centroid - std::vector h_isSampleCentroid(n_samples); - std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); - h_isSampleCentroid[cIdx] = 1; - - // device buffer to flag the sample that is chosen as initial centroid - auto isSampleCentroid = raft::make_device_vector(handle, n_samples); - - raft::copy( - isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); - - rmm::device_uvector centroidsBuf(initialCentroid.size(), stream); - - // reset buffer to store the chosen centroid - raft::copy(centroidsBuf.data(), initialCentroid.data_handle(), initialCentroid.size(), stream); - - auto potentialCentroids = raft::make_device_matrix_view( - centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1)); - // <<< End of Step-1 >>> - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - auto minClusterDistanceVec = raft::make_device_vector(handle, n_samples); - auto uniformRands = raft::make_device_vector(handle, n_samples); - rmm::device_scalar clusterCost(stream); - - // <<< Step-2 >>>: psi <- phi_X (C) - detail::minClusterDistanceCompute(handle, - X, - potentialCentroids, - minClusterDistanceVec.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // compute partial cluster cost from the samples in rank - detail::computeClusterCost(handle, - minClusterDistanceVec.view(), - workspace, - raft::make_device_scalar_view(clusterCost.data()), - raft::identity_op{}, - raft::add_op{}); - - auto psi = clusterCost.value(stream); - - // <<< End of Step-2 >>> - - // Scalable kmeans++ paper claims 8 rounds is sufficient - resource::sync_stream(handle, stream); - int niter = std::min(8, (int)ceil(log(psi))); - RAFT_LOG_DEBUG("KMeans||: psi = %g, log(psi) = %g, niter = %d ", psi, log(psi), niter); - - // <<<< Step-3 >>> : for O( log(psi) ) times do - for (int iter = 0; iter < niter; ++iter) { - RAFT_LOG_DEBUG("KMeans|| - Iteration %d: # potential centroids sampled - %d", - iter, - potentialCentroids.extent(0)); - - detail::minClusterDistanceCompute(handle, - X, - potentialCentroids, - minClusterDistanceVec.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - detail::computeClusterCost(handle, - minClusterDistanceVec.view(), - workspace, - raft::make_device_scalar_view(clusterCost.data()), - raft::identity_op{}, - raft::add_op{}); - - psi = clusterCost.value(stream); - - // <<<< Step-4 >>> : Sample each point x in X independently and identify new - // potentialCentroids - raft::random::uniform( - handle, rng, uniformRands.data_handle(), uniformRands.extent(0), (DataT)0, (DataT)1); - - detail::SamplingOp select_op(psi, - params.oversampling_factor, - n_clusters, - uniformRands.data_handle(), - isSampleCentroid.data_handle()); - - rmm::device_uvector CpRaw(0, stream); - detail::sampleCentroids(handle, - X, - minClusterDistanceVec.view(), - isSampleCentroid.view(), - select_op, - CpRaw, - workspace); - auto Cp = raft::make_device_matrix_view( - CpRaw.data(), CpRaw.size() / n_features, n_features); - /// <<<< End of Step-4 >>>> - - /// <<<< Step-5 >>> : C = C U C' - // append the data in Cp to the buffer holding the potentialCentroids - centroidsBuf.resize(centroidsBuf.size() + Cp.size(), stream); - raft::copy( - centroidsBuf.data() + centroidsBuf.size() - Cp.size(), Cp.data_handle(), Cp.size(), stream); - - IndexT tot_centroids = potentialCentroids.extent(0) + Cp.extent(0); - potentialCentroids = - raft::make_device_matrix_view(centroidsBuf.data(), tot_centroids, n_features); - /// <<<< End of Step-5 >>> - } /// <<<< Step-6 >>> - - RAFT_LOG_DEBUG("KMeans||: total # potential centroids sampled - %d", - potentialCentroids.extent(0)); - - if ((int)potentialCentroids.extent(0) > n_clusters) { - // <<< Step-7 >>>: For x in C, set w_x to be the number of pts closest to X - // temporary buffer to store the sample count per cluster, destructor - // releases the resource - auto weight = raft::make_device_vector(handle, potentialCentroids.extent(0)); - - detail::countSamplesInCluster( - handle, params, X, L2NormX.view(), potentialCentroids, workspace, weight.view()); - - // <<< end of Step-7 >>> - - // Step-8: Recluster the weighted points in C into k clusters - detail::kmeansPlusPlus( - handle, params, potentialCentroids, centroidsRawData, workspace); - - auto inertia = make_host_scalar(0); - auto n_iter = make_host_scalar(0); - KMeansParams default_params; - default_params.n_clusters = params.n_clusters; - - detail::kmeans_fit_main(handle, - default_params, - potentialCentroids, - weight.view(), - centroidsRawData, - inertia.view(), - n_iter.view(), - workspace); - - } else if ((int)potentialCentroids.extent(0) < n_clusters) { - // supplement with random - auto n_random_clusters = n_clusters - potentialCentroids.extent(0); - - RAFT_LOG_DEBUG( - "[Warning!] KMeans||: found fewer than %d centroids during " - "initialization (found %d centroids, remaining %d centroids will be " - "chosen randomly from input samples)", - n_clusters, - potentialCentroids.extent(0), - n_random_clusters); - - // generate `n_random_clusters` centroids - KMeansParams rand_params; - rand_params.init = KMeansParams::InitMethod::Random; - rand_params.n_clusters = n_random_clusters; - initRandom(handle, rand_params, X, centroidsRawData); - - // copy centroids generated during kmeans|| iteration to the buffer - raft::copy(centroidsRawData.data_handle() + n_random_clusters * n_features, - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); - } else { - // found the required n_clusters - raft::copy(centroidsRawData.data_handle(), - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); - } -} - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. It must be noted - * that the data must be in row-major format and stored in device accessible - * location. - * @param[in] n_samples Number of samples in the input X. - * @param[in] n_features Number of features or the dimensions of each - * sample. - * @param[in] sample_weight Optional weights for each observation in X. - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] Otherwise, generated centroids from the - * kmeans algorithm is stored at the address pointed by 'centroids'. - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - common::nvtx::range fun_scope("kmeans_fit"); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - cudaStream_t stream = resource::get_cuda_stream(handle); - // Check that parameters are valid - if (sample_weight.has_value()) - RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, - "invalid parameter (sample_weight!=n_samples)"); - RAFT_EXPECTS(n_clusters > 0, "invalid parameter (n_clusters<=0)"); - RAFT_EXPECTS(params.tol > 0, "invalid parameter (tol<=0)"); - RAFT_EXPECTS(params.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); - RAFT_EXPECTS((int)centroids.extent(0) == params.n_clusters, - "invalid parameter (centroids.extent(0) != n_clusters)"); - RAFT_EXPECTS(centroids.extent(1) == n_features, - "invalid parameter (centroids.extent(1) != n_features)"); - - // Display a message if the batch size is smaller than n_samples but will be ignored - if (params.batch_samples < (int)n_samples && - (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_samples=%d was passed, but batch_samples=%d will be used (reason: " - "batch_samples has no impact on the memory footprint when FusedL2NN can be used)", - params.batch_samples, - (int)n_samples); - } - // Display a message if batch_centroids is set and a fusedL2NN-compatible metric is used - if (params.batch_centroids != 0 && params.batch_centroids != params.n_clusters && - (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_centroids=%d was passed, but batch_centroids=%d will be used (reason: " - "batch_centroids has no impact on the memory footprint when FusedL2NN can be used)", - params.batch_centroids, - params.n_clusters); - } - - logger::get(RAFT_NAME).set_level(params.verbosity); - - // Allocate memory - rmm::device_uvector workspace(0, stream); - auto weight = raft::make_device_vector(handle, n_samples); - if (sample_weight.has_value()) - raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); - else - thrust::fill(resource::get_thrust_policy(handle), - weight.data_handle(), - weight.data_handle() + weight.size(), - 1); - - // check if weights sum up to n_samples - checkWeight(handle, weight.view(), workspace); - - auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); - - auto n_init = params.n_init; - if (params.init == KMeansParams::InitMethod::Array && n_init != 1) { - RAFT_LOG_DEBUG( - "Explicit initial center position passed: performing only one init in " - "k-means instead of n_init=%d", - n_init); - n_init = 1; - } - - std::mt19937 gen(params.rng_state.seed); - inertia[0] = std::numeric_limits::max(); - - for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { - KMeansParams iter_params = params; - iter_params.rng_state.seed = gen(); - - DataT iter_inertia = std::numeric_limits::max(); - IndexT n_current_iter = 0; - if (iter_params.init == KMeansParams::InitMethod::Random) { - // initializing with random samples from input dataset - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers by " - "randomly choosing from the " - "input data.", - seed_iter + 1, - n_init); - initRandom(handle, iter_params, X, centroidsRawData.view()); - } else if (iter_params.init == KMeansParams::InitMethod::KMeansPlusPlus) { - // default method to initialize is kmeans++ - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers using " - "k-means++ algorithm.", - seed_iter + 1, - n_init); - if (iter_params.oversampling_factor == 0) - detail::kmeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - else - detail::initScalableKMeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - } else if (iter_params.init == KMeansParams::InitMethod::Array) { - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers from " - "the ndarray array input " - "passed to init argument.", - seed_iter + 1, - n_init); - raft::copy( - centroidsRawData.data_handle(), centroids.data_handle(), n_clusters * n_features, stream); - } else { - THROW("unknown initialization method to select initial centers"); - } - - detail::kmeans_fit_main(handle, - iter_params, - X, - weight.view(), - centroidsRawData.view(), - raft::make_host_scalar_view(&iter_inertia), - raft::make_host_scalar_view(&n_current_iter), - workspace); - if (iter_inertia < inertia[0]) { - inertia[0] = iter_inertia; - n_iter[0] = n_current_iter; - raft::copy( - centroids.data_handle(), centroidsRawData.data_handle(), n_clusters * n_features, stream); - } - RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", - seed_iter + 1, - n_init, - inertia[0], - n_iter[0]); - } - RAFT_LOG_DEBUG("KMeans.fit: async call returned (fit could still be running on the device)"); -} - -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT& inertia, - IndexT& n_iter) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); - std::optional> sample_weightView = std::nullopt; - if (sample_weight) - sample_weightView = - raft::make_device_vector_view(sample_weight, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); - - detail::kmeans_fit( - handle, params, XView, sample_weightView, centroidsView, inertiaView, n_iterView); -} - -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - common::nvtx::range fun_scope("kmeans_predict"); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - cudaStream_t stream = resource::get_cuda_stream(handle); - // Check that parameters are valid - if (sample_weight.has_value()) - RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, - "invalid parameter (sample_weight!=n_samples)"); - RAFT_EXPECTS(params.n_clusters > 0, "invalid parameter (n_clusters<=0)"); - RAFT_EXPECTS(params.tol > 0, "invalid parameter (tol<=0)"); - RAFT_EXPECTS(params.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); - RAFT_EXPECTS((int)centroids.extent(0) == params.n_clusters, - "invalid parameter (centroids.extent(0) != n_clusters)"); - RAFT_EXPECTS(centroids.extent(1) == n_features, - "invalid parameter (centroids.extent(1) != n_features)"); - - logger::get(RAFT_NAME).set_level(params.verbosity); - auto metric = params.metric; - - // Allocate memory - // Device-accessible allocation of expandable storage used as temporary buffers - rmm::device_uvector workspace(0, stream); - auto weight = raft::make_device_vector(handle, n_samples); - if (sample_weight.has_value()) - raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); - else - thrust::fill(resource::get_thrust_policy(handle), - weight.data_handle(), - weight.data_handle() + weight.size(), - 1); - - // check if weights sum up to n_samples - if (normalize_weight) checkWeight(handle, weight.view(), workspace); - - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - auto l2normx_view = - raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // calculate cluster cost phi_x(C) - rmm::device_scalar clusterCostD(stream); - // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - weight.data_handle(), - minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); - - detail::computeClusterCost(handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - thrust::transform(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - labels.data_handle(), - raft::key_op{}); - - inertia[0] = clusterCostD.value(stream); -} - -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - bool normalize_weight, - DataT& inertia) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); - std::optional> sample_weightView{std::nullopt}; - if (sample_weight) - sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - - detail::kmeans_predict(handle, - params, - XView, - sample_weightView, - centroidsView, - labelsView, - normalize_weight, - inertiaView); -} - -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - common::nvtx::range fun_scope("kmeans_fit_predict"); - if (!centroids.has_value()) { - auto n_features = X.extent(1); - auto centroids_matrix = - raft::make_device_matrix(handle, params.n_clusters, n_features); - detail::kmeans_fit( - handle, params, X, sample_weight, centroids_matrix.view(), inertia, n_iter); - detail::kmeans_predict( - handle, params, X, sample_weight, centroids_matrix.view(), labels, true, inertia); - } else { - detail::kmeans_fit( - handle, params, X, sample_weight, centroids.value(), inertia, n_iter); - detail::kmeans_predict( - handle, params, X, sample_weight, centroids.value(), labels, true, inertia); - } -} - -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - DataT& inertia, - IndexT& n_iter) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - std::optional> sample_weightView{std::nullopt}; - if (sample_weight) - sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); - std::optional> centroidsView{std::nullopt}; - if (centroids) - centroidsView.emplace( - raft::make_device_matrix_view(centroids, params.n_clusters, n_features)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); - - detail::kmeans_fit_predict( - handle, params, XView, sample_weightView, centroidsView, labelsView, inertiaView, n_iterView); -} - -/** - * @brief Transform X to a cluster-distance space. - * - * @param[in] handle The handle to the cuML library context that - * manages the CUDA resources. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format - * @param[in] centroids Cluster centroids. The data must be in row-major format. - * @param[out] X_new X transformed in the new space.. - */ -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) -{ - common::nvtx::range fun_scope("kmeans_transform"); - logger::get(RAFT_NAME).set_level(params.verbosity); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // Device-accessible allocation of expandable storage used as temporary buffers - rmm::device_uvector workspace(0, stream); - auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples); - - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (IndexT dIdx = 0; dIdx < (IndexT)n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min(static_cast(dataBatchSize), static_cast(n_samples - dIdx)); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + n_features * dIdx, ns, n_features); - - // pairwiseDistanceView [ns x n_clusters] - auto pairwiseDistanceView = raft::make_device_matrix_view( - X_new.data_handle() + n_clusters * dIdx, ns, n_clusters); - - // calculate pairwise distance between cluster centroids and current batch - // of input dataset - pairwise_distance_kmeans( - handle, datasetView, centroids, pairwiseDistanceView, workspace, metric); - } -} - -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT* X_new) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); - auto X_newView = raft::make_device_matrix_view(X_new, n_samples, n_features); - - detail::kmeans_transform(handle, params, XView, centroidsView, X_newView); -} -} // namespace detail -} // namespace cluster -} // namespace raft diff --git a/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh b/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh deleted file mode 100644 index 97755351c4..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::cluster::detail { - -template -void compute_dispersion(raft::resources const& handle, - raft::device_matrix_view X, - KMeansParams& params, - raft::device_matrix_view centroids_view, - raft::device_vector_view labels, - raft::device_vector_view clusterSizes, - rmm::device_uvector& workspace, - raft::host_vector_view clusterDispertionView, - raft::host_vector_view resultsView, - raft::host_scalar_view residual, - raft::host_scalar_view n_iter, - int val, - idx_t n, - idx_t d) -{ - auto centroids_const_view = - raft::make_device_matrix_view(centroids_view.data_handle(), val, d); - - idx_t* clusterSizes_ptr = clusterSizes.data_handle(); - auto cluster_sizes_view = - raft::make_device_vector_view(clusterSizes_ptr, val); - - params.n_clusters = val; - - raft::cluster::detail::kmeans_fit_predict( - handle, params, X, std::nullopt, std::make_optional(centroids_view), labels, residual, n_iter); - - detail::countLabels(handle, labels.data_handle(), clusterSizes.data_handle(), n, val, workspace); - - resultsView[val] = residual[0]; - clusterDispertionView[val] = raft::stats::cluster_dispersion( - handle, centroids_const_view, cluster_sizes_view, std::nullopt, n); -} - -template -void find_k(raft::resources const& handle, - raft::device_matrix_view X, - raft::host_scalar_view best_k, - raft::host_scalar_view residual, - raft::host_scalar_view n_iter, - idx_t kmax, - idx_t kmin = 1, - idx_t maxiter = 100, - value_t tol = 1e-2) -{ - idx_t n = X.extent(0); - idx_t d = X.extent(1); - - RAFT_EXPECTS(n >= 1, "n must be >= 1"); - RAFT_EXPECTS(d >= 1, "d must be >= 1"); - RAFT_EXPECTS(kmin >= 1, "kmin must be >= 1"); - RAFT_EXPECTS(kmax <= n, "kmax must be <= number of data samples in X"); - RAFT_EXPECTS(tol >= 0, "tolerance must be >= 0"); - RAFT_EXPECTS(maxiter >= 0, "maxiter must be >= 0"); - // Allocate memory - // Device memory - - auto centroids = raft::make_device_matrix(handle, kmax, X.extent(1)); - auto clusterSizes = raft::make_device_vector(handle, kmax); - auto labels = raft::make_device_vector(handle, n); - - rmm::device_uvector workspace(0, resource::get_cuda_stream(handle)); - - idx_t* clusterSizes_ptr = clusterSizes.data_handle(); - - // Host memory - auto results = raft::make_host_vector(kmax + 1); - auto clusterDispersion = raft::make_host_vector(kmax + 1); - - auto clusterDispertionView = clusterDispersion.view(); - auto resultsView = results.view(); - - // Loop to find *best* k - // Perform k-means in binary search - int left = kmin; // must be at least 2 - int right = kmax; // int(floor(len(data)/2)) #assumption of clusters of size 2 at least - int mid = ((unsigned int)left + (unsigned int)right) >> 1; - int oldmid = mid; - int tests = 0; - double objective[3]; // 0= left of mid, 1= right of mid - if (left == 1) left = 2; // at least do 2 clusters - - KMeansParams params; - params.max_iter = maxiter; - params.tol = tol; - - auto centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), left, d); - compute_dispersion(handle, - X, - params, - centroids_view, - labels.view(), - clusterSizes.view(), - workspace, - clusterDispertionView, - resultsView, - residual, - n_iter, - left, - n, - d); - - // eval right edge0 - resultsView[right] = 1e20; - while (resultsView[right] > resultsView[left] && tests < 3) { - centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), right, d); - compute_dispersion(handle, - X, - params, - centroids_view, - labels.view(), - clusterSizes.view(), - workspace, - clusterDispertionView, - resultsView, - residual, - n_iter, - right, - n, - d); - - tests += 1; - } - - objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; - objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right]; - while (left < right - 1) { - resultsView[mid] = 1e20; - tests = 0; - while (resultsView[mid] > resultsView[left] && tests < 3) { - centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), mid, d); - compute_dispersion(handle, - X, - params, - centroids_view, - labels.view(), - clusterSizes.view(), - workspace, - clusterDispertionView, - resultsView, - residual, - n_iter, - mid, - n, - d); - - if (resultsView[mid] > resultsView[left] && (mid + 1) < right) { - mid += 1; - resultsView[mid] = 1e20; - } else if (resultsView[mid] > resultsView[left] && (mid - 1) > left) { - mid -= 1; - resultsView[mid] = 1e20; - } - tests += 1; - } - - // maximize Calinski-Harabasz Index, minimize resid/ cluster - objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; - objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right]; - objective[2] = (n - mid) / (mid - 1) * clusterDispertionView[mid] / resultsView[mid]; - objective[0] = (objective[2] - objective[0]) / (mid - left); - objective[1] = (objective[1] - objective[2]) / (right - mid); - - if (objective[0] > 0 && objective[1] < 0) { - // our point is in the left-of-mid side - right = mid; - } else { - left = mid; - } - oldmid = mid; - mid = ((unsigned int)right + (unsigned int)left) >> 1; - } - - best_k[0] = right; - objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; - objective[1] = (n - oldmid) / (oldmid - 1) * clusterDispertionView[oldmid] / resultsView[oldmid]; - if (objective[1] < objective[0]) { best_k[0] = left; } - - // if best_k isn't what we just ran, re-run to get correct centroids and dist data on return-> - // this saves memory - if (best_k[0] != oldmid) { - auto centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), best_k[0], d); - - params.n_clusters = best_k[0]; - raft::cluster::detail::kmeans_fit_predict(handle, - params, - X, - std::nullopt, - std::make_optional(centroids_view), - labels.view(), - residual, - n_iter); - } -} -} // namespace raft::cluster::detail \ No newline at end of file diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh deleted file mode 100644 index 0a5a3ba5aa..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ /dev/null @@ -1,1089 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include - -namespace raft::cluster::detail { - -constexpr static inline float kAdjustCentersWeight = 7.0f; - -/** - * @brief Predict labels for the dataset; floating-point types only. - * - * NB: no minibatch splitting is done here, it may require large amount of temporary memory (n_rows - * * n_cluster * sizeof(MathT)). - * - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * - * @param[in] handle The raft handle. - * @param[in] params Structure containing the hyper-parameters - * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] - * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data - * @param[in] dataset Pointer to the data [n_rows, dim] - * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] - * @param[in] n_rows Number samples in the `dataset` - * @param[out] labels Output predictions [n_rows] - * @param[inout] mr (optional) Memory resource to use for temporary allocations - */ -template -inline std::enable_if_t> predict_core( - const raft::resources& handle, - const kmeans_balanced_params& params, - const MathT* centers, - IdxT n_clusters, - IdxT dim, - const MathT* dataset, - const MathT* dataset_norm, - IdxT n_rows, - LabelT* labels, - rmm::device_async_resource_ref mr) -{ - auto stream = resource::get_cuda_stream(handle); - switch (params.metric) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: { - auto workspace = raft::make_device_mdarray( - handle, mr, make_extents((sizeof(int)) * n_rows)); - - auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( - handle, mr, make_extents(n_rows)); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - thrust::fill(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - initial_value); - - auto centroidsNorm = - raft::make_device_mdarray(handle, mr, make_extents(n_clusters)); - raft::linalg::rowNorm( - centroidsNorm.data_handle(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream); - - raft::distance::fusedL2NNMinReduce, IdxT>( - minClusterAndDistance.data_handle(), - dataset, - centers, - dataset_norm, - centroidsNorm.data_handle(), - n_rows, - n_clusters, - dim, - (void*)workspace.data_handle(), - (params.metric == raft::distance::DistanceType::L2Expanded) ? false : true, - false, - stream); - - // todo(lsugy): use KVP + iterator in caller. - // Copy keys to output labels - thrust::transform(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + n_rows, - labels, - raft::compose_op, raft::key_op>()); - break; - } - case raft::distance::DistanceType::InnerProduct: { - // TODO: pass buffer - rmm::device_uvector distances(n_rows * n_clusters, stream, mr); - - MathT alpha = -1.0; - MathT beta = 0.0; - - linalg::gemm(handle, - true, - false, - n_clusters, - n_rows, - dim, - &alpha, - centers, - dim, - dataset, - dim, - &beta, - distances.data(), - n_clusters, - stream); - - auto distances_const_view = raft::make_device_matrix_view( - distances.data(), n_rows, n_clusters); - auto labels_view = raft::make_device_vector_view(labels, n_rows); - raft::matrix::argmin(handle, distances_const_view, labels_view); - break; - } - default: { - RAFT_FAIL("The chosen distance metric is not supported (%d)", int(params.metric)); - } - } -} - -/** - * @brief Suggest a minibatch size for kmeans prediction. - * - * This function is used as a heuristic to split the work over a large dataset - * to reduce the size of temporary memory allocations. - * - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * - * @param[in] n_clusters number of clusters in kmeans clustering - * @param[in] n_rows Number of samples in the dataset - * @param[in] dim Number of features in the dataset - * @param[in] metric Distance metric - * @param[in] needs_conversion Whether the data needs to be converted to MathT - * @return A suggested minibatch size and the expected memory cost per-row (in bytes) - */ -template -constexpr auto calc_minibatch_size(IdxT n_clusters, - IdxT n_rows, - IdxT dim, - raft::distance::DistanceType metric, - bool needs_conversion) -> std::tuple -{ - n_clusters = std::max(1, n_clusters); - - // Estimate memory needs per row (i.e element of the batch). - size_t mem_per_row = 0; - switch (metric) { - // fusedL2NN needs a mutex and a key-value pair for each row. - case distance::DistanceType::L2Expanded: - case distance::DistanceType::L2SqrtExpanded: { - mem_per_row += sizeof(int); - mem_per_row += sizeof(raft::KeyValuePair); - } break; - // Other metrics require storing a distance matrix. - default: { - mem_per_row += sizeof(MathT) * n_clusters; - } - } - - // If we need to convert to MathT, space required for the converted batch. - if (!needs_conversion) { mem_per_row += sizeof(MathT) * dim; } - - // Heuristic: calculate the minibatch size in order to use at most 1GB of memory. - IdxT minibatch_size = (1 << 30) / mem_per_row; - minibatch_size = 64 * div_rounding_up_safe(minibatch_size, IdxT{64}); - minibatch_size = std::min(minibatch_size, n_rows); - return std::make_tuple(minibatch_size, mem_per_row); -} - -/** - * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. - * - * @note all pointers must be accessible on the device. - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam CounterT counter type supported by CUDA's native atomicAdd - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle. - * @param[inout] centers Pointer to the output [n_clusters, dim] - * @param[inout] cluster_sizes Number of rows in each cluster [n_clusters] - * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data - * @param[in] dataset Pointer to the data [n_rows, dim] - * @param[in] n_rows Number of samples in the `dataset` - * @param[in] labels Output predictions [n_rows] - * @param[in] reset_counters Whether to clear the output arrays before calculating. - * When set to `false`, this function may be used to update existing centers and sizes using - * the weighted average principle. - * @param[in] mapping_op Mapping operation from T to MathT - * @param[inout] mr (optional) Memory resource to use for temporary allocations on the device - */ -template -void calc_centers_and_sizes(const raft::resources& handle, - MathT* centers, - CounterT* cluster_sizes, - IdxT n_clusters, - IdxT dim, - const T* dataset, - IdxT n_rows, - const LabelT* labels, - bool reset_counters, - MappingOpT mapping_op, - rmm::device_async_resource_ref mr) -{ - auto stream = resource::get_cuda_stream(handle); - - if (!reset_counters) { - raft::linalg::matrixVectorOp( - centers, centers, cluster_sizes, dim, n_clusters, true, false, raft::mul_op(), stream); - } - - rmm::device_uvector workspace(0, stream, mr); - - // If we reset the counters, we can compute directly the new sizes in cluster_sizes. - // If we don't reset, we compute in a temporary buffer and add in a separate step. - rmm::device_uvector temp_cluster_sizes(0, stream, mr); - CounterT* temp_sizes = cluster_sizes; - if (!reset_counters) { - temp_cluster_sizes.resize(n_clusters, stream); - temp_sizes = temp_cluster_sizes.data(); - } - - // Apply mapping only when the data and math types are different. - if constexpr (std::is_same_v) { - raft::linalg::reduce_rows_by_key( - dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); - } else { - // todo(lsugy): use iterator from KV output of fusedL2NN - cub::TransformInputIterator mapping_itr(dataset, mapping_op); - raft::linalg::reduce_rows_by_key( - mapping_itr, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); - } - - // Compute weight of each cluster - raft::cluster::detail::countLabels(handle, labels, temp_sizes, n_rows, n_clusters, workspace); - - // Add previous sizes if necessary - if (!reset_counters) { - raft::linalg::add(cluster_sizes, cluster_sizes, temp_sizes, n_clusters, stream); - } - - raft::linalg::matrixVectorOp(centers, - centers, - cluster_sizes, - dim, - n_clusters, - true, - false, - raft::div_checkzero_op(), - stream); -} - -/** Computes the L2 norm of the dataset, converting to MathT if necessary */ -template -void compute_norm(const raft::resources& handle, - MathT* dataset_norm, - const T* dataset, - IdxT dim, - IdxT n_rows, - MappingOpT mapping_op, - std::optional mr = std::nullopt) -{ - common::nvtx::range fun_scope("compute_norm"); - auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector mapped_dataset( - 0, stream, mr.value_or(resource::get_workspace_resource(handle))); - - const MathT* dataset_ptr = nullptr; - - if (std::is_same_v) { - dataset_ptr = reinterpret_cast(dataset); - } else { - mapped_dataset.resize(n_rows * dim, stream); - - linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream); - - dataset_ptr = static_cast(mapped_dataset.data()); - } - - raft::linalg::rowNorm( - dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream); -} - -/** - * @brief Predict labels for the dataset. - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle - * @param[in] params Structure containing the hyper-parameters - * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] - * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data - * @param[in] dataset Pointer to the data [n_rows, dim] - * @param[in] n_rows Number samples in the `dataset` - * @param[out] labels Output predictions [n_rows] - * @param[in] mapping_op Mapping operation from T to MathT - * @param[inout] mr (optional) memory resource to use for temporary allocations - * @param[in] dataset_norm (optional) Pre-computed norms of each row in the dataset [n_rows] - */ -template -void predict(const raft::resources& handle, - const kmeans_balanced_params& params, - const MathT* centers, - IdxT n_clusters, - IdxT dim, - const T* dataset, - IdxT n_rows, - LabelT* labels, - MappingOpT mapping_op, - std::optional mr = std::nullopt, - const MathT* dataset_norm = nullptr) -{ - auto stream = resource::get_cuda_stream(handle); - common::nvtx::range fun_scope( - "predict(%zu, %u)", static_cast(n_rows), n_clusters); - auto mem_res = mr.value_or(resource::get_workspace_resource(handle)); - auto [max_minibatch_size, _mem_per_row] = - calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); - rmm::device_uvector cur_dataset( - std::is_same_v ? 0 : max_minibatch_size * dim, stream, mem_res); - bool need_compute_norm = - dataset_norm == nullptr && (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded); - rmm::device_uvector cur_dataset_norm( - need_compute_norm ? max_minibatch_size : 0, stream, mem_res); - const MathT* dataset_norm_ptr = nullptr; - auto cur_dataset_ptr = cur_dataset.data(); - for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { - IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); - - if constexpr (std::is_same_v) { - cur_dataset_ptr = const_cast(dataset + offset * dim); - } else { - linalg::unaryOp( - cur_dataset_ptr, dataset + offset * dim, minibatch_size * dim, mapping_op, stream); - } - - // Compute the norm now if it hasn't been pre-computed. - if (need_compute_norm) { - compute_norm( - handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res); - dataset_norm_ptr = cur_dataset_norm.data(); - } else if (dataset_norm != nullptr) { - dataset_norm_ptr = dataset_norm + offset; - } - - predict_core(handle, - params, - centers, - n_clusters, - dim, - cur_dataset_ptr, - dataset_norm_ptr, - minibatch_size, - labels + offset, - mem_res); - } -} - -template -__launch_bounds__((WarpSize * BlockDimY)) RAFT_KERNEL - adjust_centers_kernel(MathT* centers, // [n_clusters, dim] - IdxT n_clusters, - IdxT dim, - const T* dataset, // [n_rows, dim] - IdxT n_rows, - const LabelT* labels, // [n_rows] - const CounterT* cluster_sizes, // [n_clusters] - MathT threshold, - IdxT average, - IdxT seed, - IdxT* count, - MappingOpT mapping_op) -{ - IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.y); - if (l >= n_clusters) return; - auto csize = static_cast(cluster_sizes[l]); - // skip big clusters - if (csize > static_cast(average * threshold)) return; - - // choose a "random" i that belongs to a rather large cluster - IdxT i; - IdxT j = laneId(); - if (j == 0) { - do { - auto old = atomicAdd(count, IdxT{1}); - i = (seed * (old + 1)) % n_rows; - } while (static_cast(cluster_sizes[labels[i]]) < average); - } - i = raft::shfl(i, 0); - - // Adjust the center of the selected smaller cluster to gravitate towards - // a sample from the selected larger cluster. - const IdxT li = static_cast(labels[i]); - // Weight of the current center for the weighted average. - // We dump it for anomalously small clusters, but keep constant otherwise. - const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); - // Weight for the datapoint used to shift the center. - const MathT wd = 1.0; - for (; j < dim; j += WarpSize) { - MathT val = 0; - val += wc * centers[j + dim * li]; - val += wd * mapping_op(dataset[j + dim * i]); - val /= wc + wd; - centers[j + dim * l] = val; - } -} - -/** - * @brief Adjust centers for clusters that have small number of entries. - * - * For each cluster, where the cluster size is not bigger than a threshold, the center is moved - * towards a data point that belongs to a large cluster. - * - * NB: if this function returns `true`, you should update the labels. - * - * NB: all pointers must be on the device side. - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam CounterT counter type supported by CUDA's native atomicAdd - * @tparam MappingOpT type of the mapping operation - * - * @param[inout] centers cluster centers [n_clusters, dim] - * @param[in] n_clusters number of rows in `centers` - * @param[in] dim number of columns in `centers` and `dataset` - * @param[in] dataset a host pointer to the row-major data matrix [n_rows, dim] - * @param[in] n_rows number of rows in `dataset` - * @param[in] labels a host pointer to the cluster indices [n_rows] - * @param[in] cluster_sizes number of rows in each cluster [n_clusters] - * @param[in] threshold defines a criterion for adjusting a cluster - * (cluster_sizes <= average_size * threshold) - * 0 <= threshold < 1 - * @param[in] mapping_op Mapping operation from T to MathT - * @param[in] stream CUDA stream - * @param[inout] device_memory memory resource to use for temporary allocations - * - * @return whether any of the centers has been updated (and thus, `labels` need to be recalculated). - */ -template -auto adjust_centers(MathT* centers, - IdxT n_clusters, - IdxT dim, - const T* dataset, - IdxT n_rows, - const LabelT* labels, - const CounterT* cluster_sizes, - MathT threshold, - MappingOpT mapping_op, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref device_memory) -> bool -{ - common::nvtx::range fun_scope( - "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); - if (n_clusters == 0) { return false; } - constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541, - 601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223, - 1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987, - 2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741}; - static IdxT i = 0; - static IdxT i_primes = 0; - - bool adjusted = false; - IdxT average = n_rows / n_clusters; - IdxT ofst; - do { - i_primes = (i_primes + 1) % kPrimes.size(); - ofst = kPrimes[i_primes]; - } while (n_rows % ofst == 0); - - constexpr uint32_t kBlockDimY = 4; - const dim3 block_dim(WarpSize, kBlockDimY, 1); - const dim3 grid_dim(1, raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1); - rmm::device_scalar update_count(0, stream, device_memory); - adjust_centers_kernel<<>>(centers, - n_clusters, - dim, - dataset, - n_rows, - labels, - cluster_sizes, - threshold, - average, - ofst, - update_count.data(), - mapping_op); - adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync - - return adjusted; -} - -/** - * @brief Expectation-maximization-balancing combined in an iterative process. - * - * Note, the `cluster_centers` is assumed to be already initialized here. - * Thus, this function can be used for fine-tuning existing clusters; - * to train from scratch, use `build_clusters` function below. - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam CounterT counter type supported by CUDA's native atomicAdd - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle - * @param[in] params Structure containing the hyper-parameters - * @param[in] n_iters Requested number of iterations (can differ from params.n_iter!) - * @param[in] dim Dimensionality of the dataset - * @param[in] dataset Pointer to a managed row-major array [n_rows, dim] - * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] - * @param[in] n_rows Number of rows in the dataset - * @param[in] n_cluster Requested number of clusters - * @param[inout] cluster_centers Pointer to a managed row-major array [n_clusters, dim] - * @param[out] cluster_labels Pointer to a managed row-major array [n_rows] - * @param[out] cluster_sizes Pointer to a managed row-major array [n_clusters] - * @param[in] balancing_pullback - * if the cluster centers are rebalanced on this number of iterations, - * one extra iteration is performed (this could happen several times) (default should be `2`). - * In other words, the first and then every `ballancing_pullback`-th rebalancing operation adds - * one more iteration to the main cycle. - * @param[in] balancing_threshold - * the rebalancing takes place if any cluster is smaller than `avg_size * balancing_threshold` - * on a given iteration (default should be `~ 0.25`). - * @param[in] mapping_op Mapping operation from T to MathT - * @param[inout] device_memory - * A memory resource for device allocations (makes sense to provide a memory pool here) - */ -template -void balancing_em_iters(const raft::resources& handle, - const kmeans_balanced_params& params, - uint32_t n_iters, - IdxT dim, - const T* dataset, - const MathT* dataset_norm, - IdxT n_rows, - IdxT n_clusters, - MathT* cluster_centers, - LabelT* cluster_labels, - CounterT* cluster_sizes, - uint32_t balancing_pullback, - MathT balancing_threshold, - MappingOpT mapping_op, - rmm::device_async_resource_ref device_memory) -{ - auto stream = resource::get_cuda_stream(handle); - uint32_t balancing_counter = balancing_pullback; - for (uint32_t iter = 0; iter < n_iters; iter++) { - // Balancing step - move the centers around to equalize cluster sizes - // (but not on the first iteration) - if (iter > 0 && adjust_centers(cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - cluster_sizes, - balancing_threshold, - mapping_op, - stream, - device_memory)) { - if (balancing_counter++ >= balancing_pullback) { - balancing_counter -= balancing_pullback; - n_iters++; - } - } - switch (params.metric) { - // For some metrics, cluster calculation and adjustment tends to favor zero center vectors. - // To avoid converging to zero, we normalize the center vectors on every iteration. - case raft::distance::DistanceType::InnerProduct: - case raft::distance::DistanceType::CosineExpanded: - case raft::distance::DistanceType::CorrelationExpanded: { - auto clusters_in_view = raft::make_device_matrix_view( - cluster_centers, n_clusters, dim); - auto clusters_out_view = raft::make_device_matrix_view( - cluster_centers, n_clusters, dim); - raft::linalg::row_normalize( - handle, clusters_in_view, clusters_out_view, raft::linalg::L2Norm); - break; - } - default: break; - } - // E: Expectation step - predict labels - predict(handle, - params, - cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - mapping_op, - device_memory, - dataset_norm); - // M: Maximization step - calculate optimal cluster centers - calc_centers_and_sizes(handle, - cluster_centers, - cluster_sizes, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - true, - mapping_op, - device_memory); - } -} - -/** Randomly initialize cluster centers and then call `balancing_em_iters`. */ -template -void build_clusters(const raft::resources& handle, - const kmeans_balanced_params& params, - IdxT dim, - const T* dataset, - IdxT n_rows, - IdxT n_clusters, - MathT* cluster_centers, - LabelT* cluster_labels, - CounterT* cluster_sizes, - MappingOpT mapping_op, - rmm::device_async_resource_ref device_memory, - const MathT* dataset_norm = nullptr) -{ - auto stream = resource::get_cuda_stream(handle); - - // "randomly" initialize labels - auto labels_view = raft::make_device_vector_view(cluster_labels, n_rows); - linalg::map_offset( - handle, - labels_view, - raft::compose_op(raft::cast_op(), raft::mod_const_op(n_clusters))); - - // update centers to match the initialized labels. - calc_centers_and_sizes(handle, - cluster_centers, - cluster_sizes, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - true, - mapping_op, - device_memory); - - // run EM - balancing_em_iters(handle, - params, - params.n_iters, - dim, - dataset, - dataset_norm, - n_rows, - n_clusters, - cluster_centers, - cluster_labels, - cluster_sizes, - 2, - MathT{0.25}, - mapping_op, - device_memory); -} - -/** Calculate how many fine clusters should belong to each mesocluster. */ -template -inline auto arrange_fine_clusters(IdxT n_clusters, - IdxT n_mesoclusters, - IdxT n_rows, - const CounterT* mesocluster_sizes) -{ - std::vector fine_clusters_nums(n_mesoclusters); - std::vector fine_clusters_csum(n_mesoclusters + 1); - fine_clusters_csum[0] = 0; - - IdxT n_lists_rem = n_clusters; - IdxT n_nonempty_ms_rem = 0; - for (IdxT i = 0; i < n_mesoclusters; i++) { - n_nonempty_ms_rem += mesocluster_sizes[i] > CounterT{0} ? 1 : 0; - } - IdxT n_rows_rem = n_rows; - CounterT mesocluster_size_sum = 0; - CounterT mesocluster_size_max = 0; - IdxT fine_clusters_nums_max = 0; - for (IdxT i = 0; i < n_mesoclusters; i++) { - if (i < n_mesoclusters - 1) { - // Although the algorithm is meant to produce balanced clusters, when something - // goes wrong, we may get empty clusters (e.g. during development/debugging). - // The code below ensures a proportional arrangement of fine cluster numbers - // per mesocluster, even if some clusters are empty. - if (mesocluster_sizes[i] == 0) { - fine_clusters_nums[i] = 0; - } else { - n_nonempty_ms_rem--; - auto s = static_cast( - static_cast(n_lists_rem * mesocluster_sizes[i]) / n_rows_rem + .5); - s = std::min(s, n_lists_rem - n_nonempty_ms_rem); - fine_clusters_nums[i] = std::max(s, IdxT{1}); - } - } else { - fine_clusters_nums[i] = n_lists_rem; - } - n_lists_rem -= fine_clusters_nums[i]; - n_rows_rem -= mesocluster_sizes[i]; - mesocluster_size_max = max(mesocluster_size_max, mesocluster_sizes[i]); - mesocluster_size_sum += mesocluster_sizes[i]; - fine_clusters_nums_max = max(fine_clusters_nums_max, fine_clusters_nums[i]); - fine_clusters_csum[i + 1] = fine_clusters_csum[i] + fine_clusters_nums[i]; - } - - RAFT_EXPECTS(static_cast(mesocluster_size_sum) == n_rows, - "mesocluster sizes do not add up (%zu) to the total trainset size (%zu)", - static_cast(mesocluster_size_sum), - static_cast(n_rows)); - RAFT_EXPECTS(fine_clusters_csum[n_mesoclusters] == n_clusters, - "fine cluster numbers do not add up (%zu) to the total number of clusters (%zu)", - static_cast(fine_clusters_csum[n_mesoclusters]), - static_cast(n_clusters)); - - return std::make_tuple(static_cast(mesocluster_size_max), - fine_clusters_nums_max, - std::move(fine_clusters_nums), - std::move(fine_clusters_csum)); -} - -/** - * Given the (coarse) mesoclusters and the distribution of fine clusters within them, - * build the fine clusters. - * - * Processing one mesocluster at a time: - * 1. Copy mesocluster data into a separate buffer - * 2. Predict fine cluster - * 3. Refince the fine cluster centers - * - * As a result, the fine clusters are what is returned by `build_hierarchical`; - * this function returns the total number of fine clusters, which can be checked to be - * the same as the requested number of clusters. - * - * Note: this function uses at most `fine_clusters_nums_max` points per mesocluster for training; - * if one of the clusters is larger than that (as given by `mesocluster_sizes`), the extra data - * is ignored and a warning is reported. - */ -template -auto build_fine_clusters(const raft::resources& handle, - const kmeans_balanced_params& params, - IdxT dim, - const T* dataset_mptr, - const MathT* dataset_norm_mptr, - const LabelT* labels_mptr, - IdxT n_rows, - const IdxT* fine_clusters_nums, - const IdxT* fine_clusters_csum, - const CounterT* mesocluster_sizes, - IdxT n_mesoclusters, - IdxT mesocluster_size_max, - IdxT fine_clusters_nums_max, - MathT* cluster_centers, - MappingOpT mapping_op, - rmm::device_async_resource_ref managed_memory, - rmm::device_async_resource_ref device_memory) -> IdxT -{ - auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); - rmm::device_uvector mc_trainset_buf(mesocluster_size_max * dim, stream, device_memory); - rmm::device_uvector mc_trainset_norm_buf(mesocluster_size_max, stream, device_memory); - auto mc_trainset_ids = mc_trainset_ids_buf.data(); - auto mc_trainset = mc_trainset_buf.data(); - auto mc_trainset_norm = mc_trainset_norm_buf.data(); - - // label (cluster ID) of each vector - rmm::device_uvector mc_trainset_labels(mesocluster_size_max, stream, device_memory); - - rmm::device_uvector mc_trainset_ccenters( - fine_clusters_nums_max * dim, stream, device_memory); - // number of vectors in each cluster - rmm::device_uvector mc_trainset_csizes_tmp( - fine_clusters_nums_max, stream, device_memory); - - // Training clusters in each meso-cluster - IdxT n_clusters_done = 0; - for (IdxT i = 0; i < n_mesoclusters; i++) { - IdxT k = 0; - for (IdxT j = 0; j < n_rows && k < mesocluster_size_max; j++) { - if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; } - } - if (k != static_cast(mesocluster_sizes[i])) - RAFT_LOG_WARN("Incorrect mesocluster size at %d. %zu vs %zu", - static_cast(i), - static_cast(k), - static_cast(mesocluster_sizes[i])); - if (k == 0) { - RAFT_LOG_DEBUG("Empty cluster %d", i); - RAFT_EXPECTS(fine_clusters_nums[i] == 0, - "Number of fine clusters must be zero for the empty mesocluster (got %d)", - static_cast(fine_clusters_nums[i])); - continue; - } else { - RAFT_EXPECTS(fine_clusters_nums[i] > 0, - "Number of fine clusters must be non-zero for a non-empty mesocluster"); - } - - cub::TransformInputIterator mapping_itr(dataset_mptr, mapping_op); - raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); - if (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded) { - thrust::gather(resource::get_thrust_policy(handle), - mc_trainset_ids, - mc_trainset_ids + k, - dataset_norm_mptr, - mc_trainset_norm); - } - - build_clusters(handle, - params, - dim, - mc_trainset, - k, - fine_clusters_nums[i], - mc_trainset_ccenters.data(), - mc_trainset_labels.data(), - mc_trainset_csizes_tmp.data(), - mapping_op, - device_memory, - mc_trainset_norm); - - raft::copy(cluster_centers + (dim * fine_clusters_csum[i]), - mc_trainset_ccenters.data(), - fine_clusters_nums[i] * dim, - stream); - resource::sync_stream(handle, stream); - n_clusters_done += fine_clusters_nums[i]; - } - return n_clusters_done; -} - -/** - * @brief Hierarchical balanced k-means - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle. - * @param[in] params Structure containing the hyper-parameters - * @param dim number of columns in `centers` and `dataset` - * @param[in] dataset a device pointer to the source dataset [n_rows, dim] - * @param n_rows number of rows in the input - * @param[out] cluster_centers a device pointer to the found cluster centers [n_cluster, dim] - * @param n_cluster - * @param metric the distance type - * @param mapping_op Mapping operation from T to MathT - * @param stream - */ -template -void build_hierarchical(const raft::resources& handle, - const kmeans_balanced_params& params, - IdxT dim, - const T* dataset, - IdxT n_rows, - MathT* cluster_centers, - IdxT n_clusters, - MappingOpT mapping_op) -{ - auto stream = resource::get_cuda_stream(handle); - using LabelT = uint32_t; - - common::nvtx::range fun_scope( - "build_hierarchical(%zu, %u)", static_cast(n_rows), n_clusters); - - IdxT n_mesoclusters = std::min(n_clusters, static_cast(std::sqrt(n_clusters) + 0.5)); - RAFT_LOG_DEBUG("build_hierarchical: n_mesoclusters: %u", n_mesoclusters); - - // TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf. - rmm::mr::managed_memory_resource managed_memory; - rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle); - auto [max_minibatch_size, mem_per_row] = - calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); - - // Precompute the L2 norm of the dataset if relevant. - const MathT* dataset_norm = nullptr; - rmm::device_uvector dataset_norm_buf(0, stream, device_memory); - if (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded) { - dataset_norm_buf.resize(n_rows, stream); - for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { - IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); - compute_norm(handle, - dataset_norm_buf.data() + offset, - dataset + dim * offset, - dim, - minibatch_size, - mapping_op, - device_memory); - } - dataset_norm = (const MathT*)dataset_norm_buf.data(); - } - - /* Temporary workaround to cub::DeviceHistogram not supporting any type that isn't natively - * supported by atomicAdd: find a supported CounterT based on the IdxT. */ - typedef typename std::conditional_t - CounterT; - - // build coarse clusters (mesoclusters) - rmm::device_uvector mesocluster_labels_buf(n_rows, stream, &managed_memory); - rmm::device_uvector mesocluster_sizes_buf(n_mesoclusters, stream, &managed_memory); - { - rmm::device_uvector mesocluster_centers_buf(n_mesoclusters * dim, stream, device_memory); - build_clusters(handle, - params, - dim, - dataset, - n_rows, - n_mesoclusters, - mesocluster_centers_buf.data(), - mesocluster_labels_buf.data(), - mesocluster_sizes_buf.data(), - mapping_op, - device_memory, - dataset_norm); - } - - auto mesocluster_sizes = mesocluster_sizes_buf.data(); - auto mesocluster_labels = mesocluster_labels_buf.data(); - - resource::sync_stream(handle, stream); - - // build fine clusters - auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = - arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); - - const IdxT mesocluster_size_max_balanced = div_rounding_up_safe( - 2lu * size_t(n_rows), std::max(size_t(n_mesoclusters), 1lu)); - if (mesocluster_size_max > mesocluster_size_max_balanced) { - RAFT_LOG_WARN( - "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). " - "At most %u points will be used for training within each mesocluster. " - "Consider increasing the number of training iterations `n_iters`.", - mesocluster_size_max, - mesocluster_size_max_balanced, - mesocluster_size_max_balanced); - RAFT_LOG_TRACE_VEC(mesocluster_sizes, n_mesoclusters); - RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters); - mesocluster_size_max = mesocluster_size_max_balanced; - } - - auto n_clusters_done = build_fine_clusters(handle, - params, - dim, - dataset, - dataset_norm, - mesocluster_labels, - n_rows, - fine_clusters_nums.data(), - fine_clusters_csum.data(), - mesocluster_sizes, - n_mesoclusters, - mesocluster_size_max, - fine_clusters_nums_max, - cluster_centers, - mapping_op, - &managed_memory, - device_memory); - RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters."); - - rmm::device_uvector cluster_sizes(n_clusters, stream, device_memory); - rmm::device_uvector labels(n_rows, stream, device_memory); - - // Fine-tuning k-means for all clusters - // - // (*) Since the likely cluster centroids have been calculated hierarchically already, the number - // of iterations for fine-tuning kmeans for whole clusters should be reduced. However, there is a - // possibility that the clusters could be unbalanced here, in which case the actual number of - // iterations would be increased. - // - balancing_em_iters(handle, - params, - std::max(params.n_iters / 10, 2), - dim, - dataset, - dataset_norm, - n_rows, - n_clusters, - cluster_centers, - labels.data(), - cluster_sizes.data(), - 5, - MathT{0.2}, - mapping_op, - device_memory); -} - -} // namespace raft::cluster::detail diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh deleted file mode 100644 index 8263aa4615..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ /dev/null @@ -1,663 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace cluster { -namespace detail { - -template -struct SamplingOp { - DataT* rnd; - uint8_t* flag; - DataT cluster_cost; - double oversampling_factor; - IndexT n_clusters; - - CUB_RUNTIME_FUNCTION __forceinline__ - SamplingOp(DataT c, double l, IndexT k, DataT* rand, uint8_t* ptr) - : cluster_cost(c), oversampling_factor(l), n_clusters(k), rnd(rand), flag(ptr) - { - } - - __host__ __device__ __forceinline__ bool operator()( - const raft::KeyValuePair& a) const - { - DataT prob_threshold = (DataT)rnd[a.key]; - - DataT prob_x = ((oversampling_factor * n_clusters * a.value) / cluster_cost); - - return !flag[a.key] && (prob_x > prob_threshold); - } -}; - -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const - { - return a.key; - } -}; - -// Computes the intensity histogram from a sequence of labels -template -void countLabels(raft::resources const& handle, - SampleIteratorT labels, - CounterT* count, - IndexT n_samples, - IndexT n_clusters, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - - // CUB::DeviceHistogram requires a signed index type - typedef typename std::make_signed_t CubIndexT; - - CubIndexT num_levels = n_clusters + 1; - CubIndexT lower_level = 0; - CubIndexT upper_level = n_clusters; - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr, - temp_storage_bytes, - labels, - count, - num_levels, - lower_level, - upper_level, - static_cast(n_samples), - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(workspace.data(), - temp_storage_bytes, - labels, - count, - num_levels, - lower_level, - upper_level, - static_cast(n_samples), - stream)); -} - -template -void checkWeight(raft::resources const& handle, - raft::device_vector_view weight, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto wt_aggr = raft::make_device_scalar(handle, 0); - auto n_samples = weight.extent(0); - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data_handle(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum(workspace.data(), - temp_storage_bytes, - weight.data_handle(), - wt_aggr.data_handle(), - n_samples, - stream)); - DataT wt_sum = 0; - raft::copy(&wt_sum, wt_aggr.data_handle(), 1, stream); - resource::sync_stream(handle, stream); - - if (wt_sum != n_samples) { - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %d samples", - n_samples); - - auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::unaryOp(weight.data_handle(), - weight.data_handle(), - n_samples, - raft::mul_const_op{scale}, - stream); - } -} - -template -IndexT getDataBatchSize(int batch_samples, IndexT n_samples) -{ - auto minVal = std::min(static_cast(batch_samples), n_samples); - return (minVal == 0) ? n_samples : minVal; -} - -template -IndexT getCentroidsBatchSize(int batch_centroids, IndexT n_local_clusters) -{ - auto minVal = std::min(static_cast(batch_centroids), n_local_clusters); - return (minVal == 0) ? n_local_clusters : minVal; -} - -template -void computeClusterCost(raft::resources const& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, - MainOpT main_op, - ReductionOpT reduction_op) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - - cub::TransformInputIterator itr(minClusterDistance.data_handle(), - main_op); - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, - temp_storage_bytes, - itr, - clusterCost.data_handle(), - minClusterDistance.size(), - reduction_op, - OutputT(), - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(workspace.data(), - temp_storage_bytes, - itr, - clusterCost.data_handle(), - minClusterDistance.size(), - reduction_op, - OutputT(), - stream)); -} - -template -void sampleCentroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view minClusterDistance, - raft::device_vector_view isSampleCentroid, - SamplingOp& select_op, - rmm::device_uvector& inRankCp, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_local_samples = X.extent(0); - auto n_features = X.extent(1); - - auto nSelected = raft::make_device_scalar(handle, 0); - cub::ArgIndexInputIterator ip_itr(minClusterDistance.data_handle()); - auto sampledMinClusterDistance = - raft::make_device_vector, IndexT>(handle, n_local_samples); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceSelect::If(nullptr, - temp_storage_bytes, - ip_itr, - sampledMinClusterDistance.data_handle(), - nSelected.data_handle(), - n_local_samples, - select_op, - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceSelect::If(workspace.data(), - temp_storage_bytes, - ip_itr, - sampledMinClusterDistance.data_handle(), - nSelected.data_handle(), - n_local_samples, - select_op, - stream)); - - IndexT nPtsSampledInRank = 0; - raft::copy(&nPtsSampledInRank, nSelected.data_handle(), 1, stream); - resource::sync_stream(handle, stream); - - uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); - thrust::for_each_n(resource::get_thrust_policy(handle), - sampledMinClusterDistance.data_handle(), - nPtsSampledInRank, - [=] __device__(raft::KeyValuePair val) { - rawPtr_isSampleCentroid[val.key] = 1; - }); - - inRankCp.resize(nPtsSampledInRank * n_features, stream); - - raft::matrix::gather((DataT*)X.data_handle(), - X.extent(1), - X.extent(0), - sampledMinClusterDistance.data_handle(), - nPtsSampledInRank, - inRankCp.data(), - raft::key_op{}, - stream); -} - -// calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', -// result will be stored in 'pairwiseDistance[n x k]' -template -void pairwise_distance_kmeans(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view pairwiseDistance, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric) -{ - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - - ASSERT(X.extent(1) == centroids.extent(1), - "# features in dataset and centroids are different (must be same)"); - - raft::distance::pairwise_distance(handle, - X.data_handle(), - centroids.data_handle(), - pairwiseDistance.data_handle(), - n_samples, - n_clusters, - n_features, - workspace, - metric); -} - -// shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores -// in 'out' does not modify the input -template -void shuffleAndGather(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - uint32_t n_samples_to_gather, - uint64_t seed) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = in.extent(0); - auto n_features = in.extent(1); - - auto indices = raft::make_device_vector(handle, n_samples); - - // shuffle indices on device - raft::random::permute(indices.data_handle(), - nullptr, - nullptr, - (IndexT)in.extent(1), - (IndexT)in.extent(0), - true, - stream); - - raft::matrix::gather((DataT*)in.data_handle(), - in.extent(1), - in.extent(0), - indices.data_handle(), - static_cast(n_samples_to_gather), - out.data_handle(), - stream); -} - -// Calculates a pair for every sample in input 'X' where key is an -// index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroid[key]' -template -void minClusterAndDistanceCompute( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - raft::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - // todo(lsugy): change batch size computation when using fusedL2NN! - bool is_fused = metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); - - if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - raft::linalg::L2Norm, - true, - stream); - } else { - // TODO: Unless pool allocator is used, passing in a workspace for this - // isn't really increasing performance because this needs to do a re-allocation - // anyways. ref https://github.com/rapidsai/raft/issues/930 - L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - - thrust::fill(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - initial_value); - - // tile over the input dataset - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + (dIdx * n_features), ns, n_features); - - // minClusterAndDistanceView [ns x n_clusters] - auto minClusterAndDistanceView = - raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(int)) * ns, stream); - - // todo(lsugy): remove cIdx - raft::distance::fusedL2NNMinReduce, IndexT>( - minClusterAndDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != raft::distance::DistanceType::L2Expanded, - false, - stream); - } else { - // tile over the centroids - for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = raft::make_device_matrix_view( - centroids.data_handle() + (cIdx * n_features), nc, n_features); - - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch - auto pairwiseDistanceView = - raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - - // calculate pairwise distance between current tile of cluster centroids - // and input dataset - pairwise_distance_kmeans( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); - - // argmin reduction returning pair - // calculates the closest centroid and the distance to the closest - // centroid - raft::linalg::coalescedReduction( - minClusterAndDistanceView.data_handle(), - pairwiseDistanceView.data_handle(), - pairwiseDistanceView.extent(1), - pairwiseDistanceView.extent(0), - initial_value, - stream, - true, - [=] __device__(const DataT val, const IndexT i) { - raft::KeyValuePair pair; - pair.key = cIdx + i; - pair.value = val; - return pair; - }, - raft::argmin_op{}, - raft::identity_op{}); - } - } - } -} - -template -void minClusterDistanceCompute(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - raft::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - - bool is_fused = metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); - - if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - raft::linalg::L2Norm, - true, - stream); - } else { - L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - - thrust::fill(resource::get_thrust_policy(handle), - minClusterDistance.data_handle(), - minClusterDistance.data_handle() + minClusterDistance.size(), - std::numeric_limits::max()); - - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + dIdx * n_features, ns, n_features); - - // minClusterDistanceView [ns x n_clusters] - auto minClusterDistanceView = - raft::make_device_vector_view(minClusterDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(IndexT)) * ns, stream); - - raft::distance::fusedL2NNMinReduce( - minClusterDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != raft::distance::DistanceType::L2Expanded, - false, - stream); - } else { - // tile over the centroids - for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = raft::make_device_matrix_view( - centroids.data_handle() + cIdx * n_features, nc, n_features); - - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch - auto pairwiseDistanceView = - raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - - // calculate pairwise distance between current tile of cluster centroids - // and input dataset - pairwise_distance_kmeans( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); - - raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(), - pairwiseDistanceView.data_handle(), - pairwiseDistanceView.extent(1), - pairwiseDistanceView.extent(0), - std::numeric_limits::max(), - stream, - true, - raft::identity_op{}, - raft::min_op{}, - raft::identity_op{}); - } - } - } -} - -template -void countSamplesInCluster(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view L2NormX, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - raft::device_vector_view sampleCountInCluster) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store distance matrix, destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - detail::minClusterAndDistanceCompute(handle, - X, - (raft::device_matrix_view)centroids, - minClusterAndDistance.view(), - L2NormX, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // Using TransformInputIteratorT to dereference an array of raft::KeyValuePair - // and converting them to just return the Key to be used in reduce_rows_by_key - // prims - detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - - // count # of samples in each cluster - countLabels(handle, - itr, - sampleCountInCluster.data_handle(), - (IndexT)n_samples, - (IndexT)n_clusters, - workspace); -} -} // namespace detail -} // namespace cluster -} // namespace raft diff --git a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh deleted file mode 100644 index e89f5480e3..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh +++ /dev/null @@ -1,1001 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Note: This file is deprecated and will be removed in a future release - * Please use include/raft/cluster/kmeans.cuh instead - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace raft { -namespace cluster { -namespace detail { -// ========================================================= -// Useful grid settings -// ========================================================= - -constexpr unsigned int BLOCK_SIZE = 1024; -constexpr unsigned int WARP_SIZE = 32; -constexpr unsigned int BSIZE_DIV_WSIZE = (BLOCK_SIZE / WARP_SIZE); - -// ========================================================= -// CUDA kernels -// ========================================================= - -/** - * @brief Compute distances between observation vectors and centroids - * Block dimensions should be (warpSize, 1, - * blockSize/warpSize). Ideally, the grid is large enough so there - * are d threads in the x-direction, k threads in the y-direction, - * and n threads in the z-direction. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, d*n entries) Observation matrix. Matrix is - * stored column-major and each column is an observation - * vector. Matrix dimensions are d x n. - * @param centroids (Input, d*k entries) Centroid matrix. Matrix is - * stored column-major and each column is a centroid. Matrix - * dimensions are d x k. - * @param dists (Output, n*k entries) Distance matrix. Matrix is - * stored column-major and the (i,j)-entry is the square of the - * Euclidean distance between the ith observation vector and jth - * centroid. Matrix dimensions are n x k. Entries must be - * initialized to zero. - */ -template -RAFT_KERNEL computeDistances(index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const value_type_t* __restrict__ centroids, - value_type_t* __restrict__ dists) -{ - // Loop index - index_type_t i; - - // Block indices - index_type_t bidx; - // Global indices - index_type_t gidx, gidy, gidz; - - // Private memory - value_type_t centroid_private, dist_private; - - // Global x-index indicates index of vector entry - bidx = blockIdx.x; - while (bidx * blockDim.x < d) { - gidx = threadIdx.x + bidx * blockDim.x; - - // Global y-index indicates centroid - gidy = threadIdx.y + blockIdx.y * blockDim.y; - while (gidy < k) { - // Load centroid coordinate from global memory - centroid_private = (gidx < d) ? centroids[IDX(gidx, gidy, d)] : 0; - - // Global z-index indicates observation vector - gidz = threadIdx.z + blockIdx.z * blockDim.z; - while (gidz < n) { - // Load observation vector coordinate from global memory - dist_private = (gidx < d) ? obs[IDX(gidx, gidz, d)] : 0; - - // Compute contribution of current entry to distance - dist_private = centroid_private - dist_private; - dist_private = dist_private * dist_private; - - // Perform reduction on warp - for (i = WARP_SIZE / 2; i > 0; i /= 2) - dist_private += __shfl_down_sync(warp_full_mask(), dist_private, i, 2 * i); - - // Write result to global memory - if (threadIdx.x == 0) atomicAdd(dists + IDX(gidz, gidy, n), dist_private); - - // Move to another observation vector - gidz += blockDim.z * gridDim.z; - } - - // Move to another centroid - gidy += blockDim.y * gridDim.y; - } - - // Move to another vector entry - bidx += gridDim.x; - } -} - -/** - * @brief Find closest centroid to observation vectors. - * Block and grid dimensions should be 1-dimensional. Ideally the - * grid is large enough so there are n threads. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param n Number of observation vectors. - * @param k Number of clusters. - * @param centroids (Input, d*k entries) Centroid matrix. Matrix is - * stored column-major and each column is a centroid. Matrix - * dimensions are d x k. - * @param dists (Input/output, n*k entries) Distance matrix. Matrix - * is stored column-major and the (i,j)-entry is the square of - * the Euclidean distance between the ith observation vector and - * jth centroid. Matrix dimensions are n x k. On exit, the first - * n entries give the square of the Euclidean distance between - * observation vectors and closest centroids. - * @param codes (Output, n entries) Cluster assignments. - * @param clusterSizes (Output, k entries) Number of points in each - * cluster. Entries must be initialized to zero. - */ -template -RAFT_KERNEL minDistances(index_type_t n, - index_type_t k, - value_type_t* __restrict__ dists, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes) -{ - // Loop index - index_type_t i, j; - - // Current matrix entry - value_type_t dist_curr; - - // Smallest entry in row - value_type_t dist_min; - index_type_t code_min; - - // Each row in observation matrix is processed by a thread - i = threadIdx.x + blockIdx.x * blockDim.x; - while (i < n) { - // Find minimum entry in row - code_min = 0; - dist_min = dists[IDX(i, 0, n)]; - for (j = 1; j < k; ++j) { - dist_curr = dists[IDX(i, j, n)]; - code_min = (dist_curr < dist_min) ? j : code_min; - dist_min = (dist_curr < dist_min) ? dist_curr : dist_min; - } - - // Transfer result to global memory - dists[i] = dist_min; - codes[i] = code_min; - - // Increment cluster sizes - atomicAdd(clusterSizes + code_min, 1); - - // Move to another row - i += blockDim.x * gridDim.x; - } -} - -/** - * @brief Check if newly computed distances are smaller than old distances. - * Block and grid dimensions should be 1-dimensional. Ideally the - * grid is large enough so there are n threads. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param n Number of observation vectors. - * @param dists_old (Input/output, n entries) Distances between - * observation vectors and closest centroids. On exit, entries - * are replaced by entries in 'dists_new' if the corresponding - * observation vectors are closest to the new centroid. - * @param dists_new (Input, n entries) Distance between observation - * vectors and new centroid. - * @param codes_old (Input/output, n entries) Cluster - * assignments. On exit, entries are replaced with 'code_new' if - * the corresponding observation vectors are closest to the new - * centroid. - * @param code_new Index associated with new centroid. - */ -template -RAFT_KERNEL minDistances2(index_type_t n, - value_type_t* __restrict__ dists_old, - const value_type_t* __restrict__ dists_new, - index_type_t* __restrict__ codes_old, - index_type_t code_new) -{ - // Loop index - index_type_t i = threadIdx.x + blockIdx.x * blockDim.x; - - // Distances - value_type_t dist_old_private; - value_type_t dist_new_private; - - // Each row is processed by a thread - while (i < n) { - // Get old and new distances - dist_old_private = dists_old[i]; - dist_new_private = dists_new[i]; - - // Update if new distance is smaller than old distance - if (dist_new_private < dist_old_private) { - dists_old[i] = dist_new_private; - codes_old[i] = code_new; - } - - // Move to another row - i += blockDim.x * gridDim.x; - } -} - -/** - * @brief Compute size of k-means clusters. - * Block and grid dimensions should be 1-dimensional. Ideally the - * grid is large enough so there are n threads. - * @tparam index_type_t the type of data used for indexing. - * @param n Number of observation vectors. - * @param k Number of clusters. - * @param codes (Input, n entries) Cluster assignments. - * @param clusterSizes (Output, k entries) Number of points in each - * cluster. Entries must be initialized to zero. - */ -template -RAFT_KERNEL computeClusterSizes(index_type_t n, - const index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes) -{ - index_type_t i = threadIdx.x + blockIdx.x * blockDim.x; - while (i < n) { - atomicAdd(clusterSizes + codes[i], 1); - i += blockDim.x * gridDim.x; - } -} - -/** - * @brief Divide rows of centroid matrix by cluster sizes. - * Divides the ith column of the sum matrix by the size of the ith - * cluster. If the sum matrix has been initialized so that the ith - * row is the sum of all observation vectors in the ith cluster, - * this kernel produces cluster centroids. The grid and block - * dimensions should be 2-dimensional. Ideally the grid is large - * enough so there are d threads in the x-direction and k threads - * in the y-direction. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param clusterSizes (Input, k entries) Number of points in each - * cluster. - * @param centroids (Input/output, d*k entries) Sum matrix. Matrix - * is stored column-major and matrix dimensions are d x k. The - * ith column is the sum of all observation vectors in the ith - * cluster. On exit, the matrix is the centroid matrix (each - * column is the mean position of a cluster). - */ -template -RAFT_KERNEL divideCentroids(index_type_t d, - index_type_t k, - const index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ centroids) -{ - // Global indices - index_type_t gidx, gidy; - - // Current cluster size - index_type_t clusterSize_private; - - // Observation vector is determined by global y-index - gidy = threadIdx.y + blockIdx.y * blockDim.y; - while (gidy < k) { - // Get cluster size from global memory - clusterSize_private = clusterSizes[gidy]; - - // Add vector entries to centroid matrix - // vector entris are determined by global x-index - gidx = threadIdx.x + blockIdx.x * blockDim.x; - while (gidx < d) { - centroids[IDX(gidx, gidy, d)] /= clusterSize_private; - gidx += blockDim.x * gridDim.x; - } - - // Move to another centroid - gidy += blockDim.y * gridDim.y; - } -} - -// ========================================================= -// Helper functions -// ========================================================= - -/** - * @brief Randomly choose new centroids. - * Centroid is randomly chosen with k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param rand Random number drawn uniformly from [0,1). - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are n x d. - * @param dists (Input, device memory, 2*n entries) Workspace. The - * first n entries should be the distance between observation - * vectors and the closest centroid. - * @param centroid (Output, device memory, d entries) Centroid - * coordinates. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int chooseNewCentroid(raft::resources const& handle, - index_type_t n, - index_type_t d, - value_type_t rand, - const value_type_t* __restrict__ obs, - value_type_t* __restrict__ dists, - value_type_t* __restrict__ centroid) -{ - // Cumulative sum of distances - value_type_t* distsCumSum = dists + n; - // Residual sum of squares - value_type_t distsSum{0}; - // Observation vector that is chosen as new centroid - index_type_t obsIndex; - - auto stream = resource::get_cuda_stream(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - // Compute cumulative sum of distances - thrust::inclusive_scan(thrust_exec_policy, - thrust::device_pointer_cast(dists), - thrust::device_pointer_cast(dists + n), - thrust::device_pointer_cast(distsCumSum)); - RAFT_CHECK_CUDA(stream); - RAFT_CUDA_TRY(cudaMemcpyAsync( - &distsSum, distsCumSum + n - 1, sizeof(value_type_t), cudaMemcpyDeviceToHost, stream)); - - // Randomly choose observation vector - // Probabilities are proportional to square of distance to closest - // centroid (see k-means++ algorithm) - // - // seg-faults due to Thrust bug - // on binary-search-like algorithms - // when run with stream dependent - // execution policies; fixed on Thrust GitHub - // hence replace w/ linear interpolation, - // until the Thrust issue gets resolved: - // - // obsIndex = (thrust::lower_bound( - // thrust_exec_policy, thrust::device_pointer_cast(distsCumSum), - // thrust::device_pointer_cast(distsCumSum + n), distsSum * rand) - - // thrust::device_pointer_cast(distsCumSum)); - // - // linear interpolation logic: - //{ - value_type_t minSum{0}; - RAFT_CUDA_TRY( - cudaMemcpyAsync(&minSum, distsCumSum, sizeof(value_type_t), cudaMemcpyDeviceToHost, stream)); - RAFT_CHECK_CUDA(stream); - - if (distsSum > minSum) { - value_type_t vIndex = static_cast(n - 1); - obsIndex = static_cast(vIndex * (distsSum * rand - minSum) / (distsSum - minSum)); - } else { - obsIndex = 0; - } - //} - - RAFT_CHECK_CUDA(stream); - obsIndex = std::max(obsIndex, static_cast(0)); - obsIndex = std::min(obsIndex, n - 1); - - // Record new centroid position - RAFT_CUDA_TRY(cudaMemcpyAsync(centroid, - obs + IDX(0, obsIndex, d), - d * sizeof(value_type_t), - cudaMemcpyDeviceToDevice, - stream)); - - return 0; -} - -/** - * @brief Choose initial cluster centroids for k-means algorithm. - * Centroids are randomly chosen with k-means++ algorithm - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param centroids (Output, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Output, device memory, k entries) Number of - * points in each cluster. - * @param dists (Output, device memory, 2*n entries) Workspace. On - * exit, the first n entries give the square of the Euclidean - * distance between observation vectors and the closest centroid. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int initializeCentroids(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - value_type_t* __restrict__ centroids, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ dists, - unsigned long long seed) -{ - // ------------------------------------------------------- - // Variable declarations - // ------------------------------------------------------- - - // Loop index - index_type_t i; - - // Random number generator - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution uniformDist(0, 1); - - auto stream = resource::get_cuda_stream(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - constexpr unsigned grid_lower_bound{65535}; - - // ------------------------------------------------------- - // Implementation - // ------------------------------------------------------- - - // Initialize grid dimensions - dim3 blockDim_warp{WARP_SIZE, 1, BSIZE_DIV_WSIZE}; - - // CUDA grid dimensions - dim3 gridDim_warp{std::min(ceildiv(d, WARP_SIZE), grid_lower_bound), - 1, - std::min(ceildiv(n, BSIZE_DIV_WSIZE), grid_lower_bound)}; - - // CUDA grid dimensions - dim3 gridDim_block{std::min(ceildiv(n, BLOCK_SIZE), grid_lower_bound), 1, 1}; - - // Assign observation vectors to code 0 - RAFT_CUDA_TRY(cudaMemsetAsync(codes, 0, n * sizeof(index_type_t), stream)); - - // Choose first centroid - thrust::fill(thrust_exec_policy, - thrust::device_pointer_cast(dists), - thrust::device_pointer_cast(dists + n), - 1); - RAFT_CHECK_CUDA(stream); - if (chooseNewCentroid(handle, n, d, uniformDist(rng), obs, dists, centroids)) - WARNING("error in k-means++ (could not pick centroid)"); - - // Compute distances from first centroid - RAFT_CUDA_TRY(cudaMemsetAsync(dists, 0, n * sizeof(value_type_t), stream)); - computeDistances<<>>(n, d, 1, obs, centroids, dists); - RAFT_CHECK_CUDA(stream); - - // Choose remaining centroids - for (i = 1; i < k; ++i) { - // Choose ith centroid - if (chooseNewCentroid(handle, n, d, uniformDist(rng), obs, dists, centroids + IDX(0, i, d))) - WARNING("error in k-means++ (could not pick centroid)"); - - // Compute distances from ith centroid - RAFT_CUDA_TRY(cudaMemsetAsync(dists + n, 0, n * sizeof(value_type_t), stream)); - computeDistances<<>>( - n, d, 1, obs, centroids + IDX(0, i, d), dists + n); - RAFT_CHECK_CUDA(stream); - - // Recompute minimum distances - minDistances2<<>>(n, dists, dists + n, codes, i); - RAFT_CHECK_CUDA(stream); - } - - // Compute cluster sizes - RAFT_CUDA_TRY(cudaMemsetAsync(clusterSizes, 0, k * sizeof(index_type_t), stream)); - computeClusterSizes<<>>(n, codes, clusterSizes); - RAFT_CHECK_CUDA(stream); - - return 0; -} - -/** - * @brief Find cluster centroids closest to observation vectors. - * Distance is measured with Euclidean norm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param centroids (Input, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param dists (Output, device memory, n*k entries) Workspace. On - * exit, the first n entries give the square of the Euclidean - * distance between observation vectors and the closest centroid. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Output, device memory, k entries) Number of - * points in each cluster. - * @param residual_host (Output, host memory, 1 entry) Residual sum - * of squares of assignment. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int assignCentroids(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const value_type_t* __restrict__ centroids, - value_type_t* __restrict__ dists, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes, - value_type_t* residual_host) -{ - auto stream = resource::get_cuda_stream(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - // Compute distance between centroids and observation vectors - RAFT_CUDA_TRY(cudaMemsetAsync(dists, 0, n * k * sizeof(value_type_t), stream)); - - // CUDA grid dimensions - dim3 blockDim{WARP_SIZE, 1, BLOCK_SIZE / WARP_SIZE}; - - dim3 gridDim; - constexpr unsigned grid_lower_bound{65535}; - gridDim.x = std::min(ceildiv(d, WARP_SIZE), grid_lower_bound); - gridDim.y = std::min(static_cast(k), grid_lower_bound); - gridDim.z = std::min(ceildiv(n, BSIZE_DIV_WSIZE), grid_lower_bound); - - computeDistances<<>>(n, d, k, obs, centroids, dists); - RAFT_CHECK_CUDA(stream); - - // Find centroid closest to each observation vector - RAFT_CUDA_TRY(cudaMemsetAsync(clusterSizes, 0, k * sizeof(index_type_t), stream)); - blockDim.x = BLOCK_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = std::min(ceildiv(n, BLOCK_SIZE), grid_lower_bound); - gridDim.y = 1; - gridDim.z = 1; - minDistances<<>>(n, k, dists, codes, clusterSizes); - RAFT_CHECK_CUDA(stream); - - // Compute residual sum of squares - *residual_host = thrust::reduce( - thrust_exec_policy, thrust::device_pointer_cast(dists), thrust::device_pointer_cast(dists + n)); - - return 0; -} - -/** - * @brief Update cluster centroids for k-means algorithm. - * All clusters are assumed to be non-empty. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Input, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Input, device memory, k entries) Number of - * points in each cluster. - * @param centroids (Output, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param work (Output, device memory, n*d entries) Workspace. - * @param work_int (Output, device memory, 2*d*n entries) - * Workspace. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int updateCentroids(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const index_type_t* __restrict__ codes, - const index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ centroids, - value_type_t* __restrict__ work, - index_type_t* __restrict__ work_int) -{ - // ------------------------------------------------------- - // Variable declarations - // ------------------------------------------------------- - - // Useful constants - const value_type_t one = 1; - const value_type_t zero = 0; - - constexpr unsigned grid_lower_bound{65535}; - - auto stream = resource::get_cuda_stream(handle); - auto cublas_h = resource::get_cublas_handle(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - // Device memory - thrust::device_ptr obs_copy(work); - thrust::device_ptr codes_copy(work_int); - thrust::device_ptr rows(work_int + d * n); - - // Take transpose of observation matrix - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgeam(cublas_h, - CUBLAS_OP_T, - CUBLAS_OP_N, - n, - d, - &one, - obs, - d, - &zero, - (value_type_t*)NULL, - n, - thrust::raw_pointer_cast(obs_copy), - n, - stream)); - - // Cluster assigned to each observation matrix entry - thrust::sequence(thrust_exec_policy, rows, rows + d * n); - RAFT_CHECK_CUDA(stream); - thrust::transform(thrust_exec_policy, - rows, - rows + d * n, - thrust::make_constant_iterator(n), - rows, - thrust::modulus()); - RAFT_CHECK_CUDA(stream); - thrust::gather( - thrust_exec_policy, rows, rows + d * n, thrust::device_pointer_cast(codes), codes_copy); - RAFT_CHECK_CUDA(stream); - - // Row associated with each observation matrix entry - thrust::sequence(thrust_exec_policy, rows, rows + d * n); - RAFT_CHECK_CUDA(stream); - thrust::transform(thrust_exec_policy, - rows, - rows + d * n, - thrust::make_constant_iterator(n), - rows, - thrust::divides()); - RAFT_CHECK_CUDA(stream); - - // Sort and reduce to add observation vectors in same cluster - thrust::stable_sort_by_key(thrust_exec_policy, - codes_copy, - codes_copy + d * n, - make_zip_iterator(make_tuple(obs_copy, rows))); - RAFT_CHECK_CUDA(stream); - thrust::reduce_by_key(thrust_exec_policy, - rows, - rows + d * n, - obs_copy, - codes_copy, // Output to codes_copy is ignored - thrust::device_pointer_cast(centroids)); - RAFT_CHECK_CUDA(stream); - - // Divide sums by cluster size to get centroid matrix - // - // CUDA grid dimensions - dim3 blockDim{WARP_SIZE, BLOCK_SIZE / WARP_SIZE, 1}; - - // CUDA grid dimensions - dim3 gridDim{std::min(ceildiv(d, WARP_SIZE), grid_lower_bound), - std::min(ceildiv(k, BSIZE_DIV_WSIZE), grid_lower_bound), - 1}; - - divideCentroids<<>>(d, k, clusterSizes, centroids); - RAFT_CHECK_CUDA(stream); - - return 0; -} - -// ========================================================= -// k-means algorithm -// ========================================================= - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Output, device memory, k entries) Number of - * points in each cluster. - * @param centroids (Output, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param work (Output, device memory, n*max(k,d) entries) - * Workspace. - * @param work_int (Output, device memory, 2*d*n entries) - * Workspace. - * @param residual_host (Output, host memory, 1 entry) Residual sum - * of squares (sum of squares of distances between observation - * vectors and centroids). - * @param iters_host (Output, host memory, 1 entry) Number of - * k-means iterations. - * @param seed random seed to be used. - * @return error flag. - */ -template -int kmeans(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ centroids, - value_type_t* __restrict__ work, - index_type_t* __restrict__ work_int, - value_type_t* residual_host, - index_type_t* iters_host, - unsigned long long seed) -{ - // ------------------------------------------------------- - // Variable declarations - // ------------------------------------------------------- - - // Current iteration - index_type_t iter; - - constexpr unsigned grid_lower_bound{65535}; - - // Residual sum of squares at previous iteration - value_type_t residualPrev = 0; - - // Random number generator - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution uniformDist(0, 1); - - // ------------------------------------------------------- - // Initialization - // ------------------------------------------------------- - - auto stream = resource::get_cuda_stream(handle); - auto cublas_h = resource::get_cublas_handle(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - // Trivial cases - if (k == 1) { - RAFT_CUDA_TRY(cudaMemsetAsync(codes, 0, n * sizeof(index_type_t), stream)); - RAFT_CUDA_TRY( - cudaMemcpyAsync(clusterSizes, &n, sizeof(index_type_t), cudaMemcpyHostToDevice, stream)); - if (updateCentroids(handle, n, d, k, obs, codes, clusterSizes, centroids, work, work_int)) - WARNING("could not compute k-means centroids"); - - dim3 blockDim{WARP_SIZE, 1, BLOCK_SIZE / WARP_SIZE}; - - dim3 gridDim{std::min(ceildiv(d, WARP_SIZE), grid_lower_bound), - 1, - std::min(ceildiv(n, BLOCK_SIZE / WARP_SIZE), grid_lower_bound)}; - - RAFT_CUDA_TRY(cudaMemsetAsync(work, 0, n * k * sizeof(value_type_t), stream)); - computeDistances<<>>(n, d, 1, obs, centroids, work); - RAFT_CHECK_CUDA(stream); - *residual_host = thrust::reduce( - thrust_exec_policy, thrust::device_pointer_cast(work), thrust::device_pointer_cast(work + n)); - RAFT_CHECK_CUDA(stream); - return 0; - } - if (n <= k) { - thrust::sequence(thrust_exec_policy, - thrust::device_pointer_cast(codes), - thrust::device_pointer_cast(codes + n)); - RAFT_CHECK_CUDA(stream); - thrust::fill_n(thrust_exec_policy, thrust::device_pointer_cast(clusterSizes), n, 1); - RAFT_CHECK_CUDA(stream); - - if (n < k) - RAFT_CUDA_TRY(cudaMemsetAsync(clusterSizes + n, 0, (k - n) * sizeof(index_type_t), stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync( - centroids, obs, d * n * sizeof(value_type_t), cudaMemcpyDeviceToDevice, stream)); - *residual_host = 0; - return 0; - } - - // Initialize cuBLAS - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY( - raft::linalg::detail::cublassetpointermode(cublas_h, CUBLAS_POINTER_MODE_HOST, stream)); - - // ------------------------------------------------------- - // k-means++ algorithm - // ------------------------------------------------------- - - // Choose initial cluster centroids - if (initializeCentroids(handle, n, d, k, obs, centroids, codes, clusterSizes, work, seed)) - WARNING("could not initialize k-means centroids"); - - // Apply k-means iteration until convergence - for (iter = 0; iter < maxiter; ++iter) { - // Update cluster centroids - if (updateCentroids(handle, n, d, k, obs, codes, clusterSizes, centroids, work, work_int)) - WARNING("could not update k-means centroids"); - - // Determine centroid closest to each observation - residualPrev = *residual_host; - if (assignCentroids(handle, n, d, k, obs, centroids, work, codes, clusterSizes, residual_host)) - WARNING("could not assign observation vectors to k-means clusters"); - - // Reinitialize empty clusters with new centroids - index_type_t emptyCentroid = (thrust::find(thrust_exec_policy, - thrust::device_pointer_cast(clusterSizes), - thrust::device_pointer_cast(clusterSizes + k), - 0) - - thrust::device_pointer_cast(clusterSizes)); - - // FIXME: emptyCentroid never reaches k (infinite loop) under certain - // conditions, such as if obs is corrupt (as seen as a result of a - // DataFrame column of NULL edge vals used to create the Graph) - while (emptyCentroid < k) { - if (chooseNewCentroid( - handle, n, d, uniformDist(rng), obs, work, centroids + IDX(0, emptyCentroid, d))) - WARNING("could not replace empty centroid"); - if (assignCentroids( - handle, n, d, k, obs, centroids, work, codes, clusterSizes, residual_host)) - WARNING("could not assign observation vectors to k-means clusters"); - emptyCentroid = (thrust::find(thrust_exec_policy, - thrust::device_pointer_cast(clusterSizes), - thrust::device_pointer_cast(clusterSizes + k), - 0) - - thrust::device_pointer_cast(clusterSizes)); - RAFT_CHECK_CUDA(stream); - } - - // Check for convergence - if (std::fabs(residualPrev - (*residual_host)) / n < tol) { - ++iter; - break; - } - } - - // Warning if k-means has failed to converge - if (std::fabs(residualPrev - (*residual_host)) / n >= tol) WARNING("k-means failed to converge"); - - *iters_host = iter; - return 0; -} - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param residual On exit, residual sum of squares (sum of squares - * of distances between observation vectors and centroids). - * @param iters on exit, number of k-means iterations. - * @param seed random seed to be used. - * @return error flag - */ -template -int kmeans(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - value_type_t& residual, - index_type_t& iters, - unsigned long long seed = 123456) -{ - // Check that parameters are valid - RAFT_EXPECTS(n > 0, "invalid parameter (n<1)"); - RAFT_EXPECTS(d > 0, "invalid parameter (d<1)"); - RAFT_EXPECTS(k > 0, "invalid parameter (k<1)"); - RAFT_EXPECTS(tol > 0, "invalid parameter (tol<=0)"); - RAFT_EXPECTS(maxiter >= 0, "invalid parameter (maxiter<0)"); - - // Allocate memory - raft::spectral::matrix::vector_t clusterSizes(handle, k); - raft::spectral::matrix::vector_t centroids(handle, d * k); - raft::spectral::matrix::vector_t work(handle, n * std::max(k, d)); - raft::spectral::matrix::vector_t work_int(handle, 2 * d * n); - - // Perform k-means - return kmeans(handle, - n, - d, - k, - tol, - maxiter, - obs, - codes, - clusterSizes.raw(), - centroids.raw(), - work.raw(), - work_int.raw(), - &residual, - &iters, - seed); -} - -} // namespace detail -} // namespace cluster -} // namespace raft diff --git a/cpp/include/raft/cluster/detail/mst.cuh b/cpp/include/raft/cluster/detail/mst.cuh deleted file mode 100644 index 55becc8e15..0000000000 --- a/cpp/include/raft/cluster/detail/mst.cuh +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace raft::cluster::detail { - -template -void merge_msts(sparse::solver::Graph_COO& coo1, - sparse::solver::Graph_COO& coo2, - cudaStream_t stream) -{ - /** Add edges to existing mst **/ - int final_nnz = coo2.n_edges + coo1.n_edges; - - coo1.src.resize(final_nnz, stream); - coo1.dst.resize(final_nnz, stream); - coo1.weights.resize(final_nnz, stream); - - /** - * Construct final edge list - */ - raft::copy_async(coo1.src.data() + coo1.n_edges, coo2.src.data(), coo2.n_edges, stream); - raft::copy_async(coo1.dst.data() + coo1.n_edges, coo2.dst.data(), coo2.n_edges, stream); - raft::copy_async(coo1.weights.data() + coo1.n_edges, coo2.weights.data(), coo2.n_edges, stream); - - coo1.n_edges = final_nnz; -} - -/** - * Connect an unconnected knn graph (one in which mst returns an msf). The - * device buffers underlying the Graph_COO object are modified in-place. - * @tparam value_idx index type - * @tparam value_t floating-point value type - * @param[in] handle raft handle - * @param[in] X original dense data from which knn grpah was constructed - * @param[inout] msf edge list containing the mst result - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[inout] color the color labels array returned from the mst invocation - * @return updated MST edge list - */ -template -void connect_knn_graph( - raft::resources const& handle, - const value_t* X, - sparse::solver::Graph_COO& msf, - size_t m, - size_t n, - value_idx* color, - red_op reduction_op, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded) -{ - auto stream = resource::get_cuda_stream(handle); - - raft::sparse::COO connected_edges(stream); - - // default row and column batch sizes are chosen for computing cross component nearest neighbors. - // Reference: PR #1445 - static constexpr size_t default_row_batch_size = 4096; - static constexpr size_t default_col_batch_size = 16; - - raft::sparse::neighbors::cross_component_nn(handle, - connected_edges, - X, - color, - m, - n, - reduction_op, - min(m, default_row_batch_size), - min(n, default_col_batch_size)); - - rmm::device_uvector indptr2(m + 1, stream); - raft::sparse::convert::sorted_coo_to_csr( - connected_edges.rows(), connected_edges.nnz, indptr2.data(), m + 1, stream); - - // On the second call, we hand the MST the original colors - // and the new set of edges and let it restart the optimization process - auto new_mst = - raft::sparse::solver::mst(handle, - indptr2.data(), - connected_edges.cols(), - connected_edges.vals(), - m, - connected_edges.nnz, - color, - stream, - false, - false); - - merge_msts(msf, new_mst, stream); -} - -/** - * Constructs an MST and sorts the resulting edges in ascending - * order by their weight. - * - * Hierarchical clustering heavily relies upon the ordering - * and vertices returned in the MST. If the result of the - * MST was actually a minimum-spanning forest, the CSR - * being passed into the MST is not connected. In such a - * case, this graph will be connected by performing a - * KNN across the components. - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle - * @param[in] indptr CSR indptr of connectivities graph - * @param[in] indices CSR indices array of connectivities graph - * @param[in] pw_dists CSR weights array of connectivities graph - * @param[in] m number of rows in X / src vertices in connectivities graph - * @param[in] n number of columns in X - * @param[out] mst_src output src edges - * @param[out] mst_dst output dst edges - * @param[out] mst_weight output weights (distances) - * @param[in] max_iter maximum iterations to run knn graph connection. This - * argument is really just a safeguard against the potential for infinite loops. - */ -template -void build_sorted_mst( - raft::resources const& handle, - const value_t* X, - const value_idx* indptr, - const value_idx* indices, - const value_t* pw_dists, - size_t m, - size_t n, - value_idx* mst_src, - value_idx* mst_dst, - value_t* mst_weight, - value_idx* color, - size_t nnz, - red_op reduction_op, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded, - int max_iter = 10) -{ - auto stream = resource::get_cuda_stream(handle); - - // We want to have MST initialize colors on first call. - auto mst_coo = raft::sparse::solver::mst( - handle, indptr, indices, pw_dists, (value_idx)m, nnz, color, stream, false, true); - - int iters = 1; - int n_components = raft::sparse::neighbors::get_n_components(color, m, stream); - - while (n_components > 1 && iters < max_iter) { - connect_knn_graph(handle, X, mst_coo, m, n, color, reduction_op); - - iters++; - - n_components = raft::sparse::neighbors::get_n_components(color, m, stream); - } - - /** - * The `max_iter` argument was introduced only to prevent the potential for an infinite loop. - * Ideally the log2(n) guarantees of the MST should be enough to connect KNN graphs with a - * massive number of data samples in very few iterations. If it does not, there are 3 likely - * reasons why (in order of their likelihood): - * 1. There is a bug in this code somewhere - * 2. Either the given KNN graph wasn't generated from X or the same metric is not being used - * to generate the 1-nn (currently only L2SqrtExpanded is supported). - * 3. max_iter was not large enough to connect the graph (less likely). - * - * Note that a KNN graph generated from 50 random isotropic balls (with significant overlap) - * was able to be connected in a single iteration. - */ - RAFT_EXPECTS(n_components == 1, - "KNN graph could not be connected in %d iterations. " - "Please verify that the input knn graph is generated from X " - "(and the same distance metric used)," - " or increase 'max_iter'", - max_iter); - - raft::sparse::op::coo_sort_by_weight( - mst_coo.src.data(), mst_coo.dst.data(), mst_coo.weights.data(), mst_coo.n_edges, stream); - - raft::copy_async(mst_src, mst_coo.src.data(), mst_coo.n_edges, stream); - raft::copy_async(mst_dst, mst_coo.dst.data(), mst_coo.n_edges, stream); - raft::copy_async(mst_weight, mst_coo.weights.data(), mst_coo.n_edges, stream); -} - -}; // namespace raft::cluster::detail \ No newline at end of file diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh deleted file mode 100644 index ccc6472684..0000000000 --- a/cpp/include/raft/cluster/detail/single_linkage.cuh +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::cluster::detail { - -static const size_t EMPTY = 0; - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. - - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[out] out struct containing output dendrogram and cluster assignments - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control - * of k. The algorithm will set `k = log(n) + c` - * @param[in] n_clusters number of clusters to assign data samples - */ -template -void single_linkage(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - linkage_output* out, - int c, - size_t n_clusters) -{ - ASSERT(n_clusters <= m, "n_clusters must be less than or equal to the number of data points"); - - auto stream = resource::get_cuda_stream(handle); - - rmm::device_uvector indptr(EMPTY, stream); - rmm::device_uvector indices(EMPTY, stream); - rmm::device_uvector pw_dists(EMPTY, stream); - - /** - * 1. Construct distance graph - */ - detail::get_distance_graph( - handle, X, m, n, metric, indptr, indices, pw_dists, c); - - rmm::device_uvector mst_rows(m - 1, stream); - rmm::device_uvector mst_cols(m - 1, stream); - rmm::device_uvector mst_data(m - 1, stream); - - /** - * 2. Construct MST, sorted by weights - */ - rmm::device_uvector color(m, stream); - raft::sparse::neighbors::FixConnectivitiesRedOp op(m); - detail::build_sorted_mst(handle, - X, - indptr.data(), - indices.data(), - pw_dists.data(), - m, - n, - mst_rows.data(), - mst_cols.data(), - mst_data.data(), - color.data(), - indices.size(), - op, - metric); - - pw_dists.release(); - - /** - * Perform hierarchical labeling - */ - size_t n_edges = mst_rows.size(); - - rmm::device_uvector out_delta(n_edges, stream); - rmm::device_uvector out_size(n_edges, stream); - // Create dendrogram - detail::build_dendrogram_host(handle, - mst_rows.data(), - mst_cols.data(), - mst_data.data(), - n_edges, - out->children, - out_delta.data(), - out_size.data()); - detail::extract_flattened_clusters(handle, out->labels, out->children, n_clusters, m); - - out->m = m; - out->n_clusters = n_clusters; - out->n_leaves = m; - out->n_connected_components = 1; -} -}; // namespace raft::cluster::detail \ No newline at end of file diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh deleted file mode 100644 index 38318e8ec8..0000000000 --- a/cpp/include/raft/cluster/kmeans.cuh +++ /dev/null @@ -1,1120 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::cluster::kmeans { - -/** - * Functor used for sampling centroids - */ -template -using SamplingOp = detail::SamplingOp; - -/** - * Functor used to extract the index from a KeyValue pair - * storing both index and a distance. - */ -template -using KeyValueIndexOp = detail::KeyValueIndexOp; - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::cluster; - * ... - * raft::raft::resources handle; - * raft::cluster::KMeansParams params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids, - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -[[deprecated("Use cuVS instead")]] void fit( - raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - detail::kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); -} - -/** - * @brief Predict the closest cluster each sample in X belongs to. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::cluster; - * ... - * raft::raft::resources handle; - * raft::cluster::KMeansParams params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * ... - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * false, - * labels.view(), - * raft::make_scalar_view(&ineratia)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X New data to predict. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[in] centroids Cluster centroids. The data must be in - * row-major format. - * [dim = n_clusters x n_features] - * @param[in] normalize_weight True if the weights should be normalized - * @param[out] labels Index of the cluster each sample in X - * belongs to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to - * their closest cluster center. - */ -template -[[deprecated("Use cuVS instead")]] void predict( - raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - detail::kmeans_predict( - handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); -} - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::cluster; - * ... - * raft::raft::resources handle; - * raft::cluster::KMeansParams params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -[[deprecated("Use cuVS instead")]] void fit_predict( - raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - detail::kmeans_fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -/** - * @brief Transform X to a cluster-distance space. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Cluster centroids. The data must be in row-major format. - * [dim = n_clusters x n_features] - * @param[out] X_new X transformed in the new space. - * [dim = n_samples x n_features] - */ -template -void transform(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) -{ - detail::kmeans_transform(handle, params, X, centroids, X_new); -} - -template -[[deprecated("Use cuVS instead")]] void transform(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT* X_new) -{ - detail::kmeans_transform( - handle, params, X, centroids, n_samples, n_features, X_new); -} - -/** - * Automatically find the optimal value of k using a binary search. - * This method maximizes the Calinski-Harabasz Index while minimizing the per-cluster inertia. - * - * @code{.cpp} - * #include - * #include - * #include - * - * #include - * - * using namespace raft::cluster; - * - * raft::handle_t handle; - * int n_samples = 100, n_features = 15, n_clusters = 10; - * auto X = raft::make_device_matrix(handle, n_samples, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * - * raft::random::make_blobs(handle, X, labels, n_clusters); - * - * auto best_k = raft::make_host_scalar(0); - * auto n_iter = raft::make_host_scalar(0); - * auto inertia = raft::make_host_scalar(0); - * - * kmeans::find_k(handle, X, best_k.view(), inertia.view(), n_iter.view(), n_clusters+1); - * - * @endcode - * - * @tparam idx_t indexing type (should be integral) - * @tparam value_t value type (should be floating point) - * @param handle raft handle - * @param X input observations (shape n_samples, n_dims) - * @param best_k best k found from binary search - * @param inertia inertia of best k found - * @param n_iter number of iterations used to find best k - * @param kmax maximum k to try in search - * @param kmin minimum k to try in search (should be >= 1) - * @param maxiter maximum number of iterations to run - * @param tol tolerance for early stopping convergence - */ -template -void find_k(raft::resources const& handle, - raft::device_matrix_view X, - raft::host_scalar_view best_k, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - idx_t kmax, - idx_t kmin = 1, - idx_t maxiter = 100, - value_t tol = 1e-3) -{ - detail::find_k(handle, X, best_k, inertia, n_iter, kmax, kmin, maxiter, tol); -} - -/** - * @brief Select centroids according to a sampling operation - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] isSampleCentroid Flag the sample chosen as initial centroid - * [dim = n_samples] - * @param[in] select_op The sampling operation used to select the centroids - * @param[out] inRankCp The sampled centroids - * [dim = n_selected_centroids x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void sample_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view minClusterDistance, - raft::device_vector_view isSampleCentroid, - SamplingOp& select_op, - rmm::device_uvector& inRankCp, - rmm::device_uvector& workspace) -{ - detail::sampleCentroids( - handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); -} - -/** - * @brief Compute cluster cost - * - * @tparam DataT the type of data used for weights, distances. - * @tparam ReductionOpT the type of data used for the reduction operation. - * - * @param[in] handle The raft handle - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] clusterCost Resulting cluster cost - * @param[in] reduction_op The reduction operation used for the cost - * - */ -template -void cluster_cost(raft::resources const& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, - ReductionOpT reduction_op) -{ - detail::computeClusterCost( - handle, minClusterDistance, workspace, clusterCost, raft::identity_op{}, reduction_op); -} - -/** - * @brief Update centroids given current centroids and number of points assigned to each centroid. - * This function also produces a vector of RAFT key/value pairs containing the cluster assignment - * for each point and its distance. - * - * @tparam DataT - * @tparam IndexT - * @param[in] handle: Raft handle to use for managing library resources - * @param[in] X: input matrix (size n_samples, n_features) - * @param[in] sample_weights: number of samples currently assigned to each centroid (size n_samples) - * @param[in] centroids: matrix of current centroids (size n_clusters, n_features) - * @param[in] labels: Iterator of labels (can also be a raw pointer) - * @param[out] weight_per_cluster: sum of sample weights per cluster (size n_clusters) - * @param[out] new_centroids: output matrix of updated centroids (size n_clusters, n_features) - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - LabelsIterator labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids) -{ - // TODO: Passing these into the algorithm doesn't really present much of a benefit - // because they are being resized anyways. - // ref https://github.com/rapidsai/raft/issues/930 - rmm::device_uvector workspace(0, resource::get_cuda_stream(handle)); - - detail::update_centroids( - handle, X, sample_weights, centroids, labels, weight_per_cluster, new_centroids, workspace); -} - -/** - * @brief Compute distance for every sample to it's nearest centroid - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] metric Distance metric to use - * @param[in] batch_samples batch size for input data samples - * @param[in] batch_centroids batch size for input centroids - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void min_cluster_distance(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - raft::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - detail::minClusterDistanceCompute(handle, - X, - centroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - metric, - batch_samples, - batch_centroids, - workspace); -} - -/** - * @brief Calculates a pair for every sample in input 'X' where key is an - * index of one of the 'centroids' (index of the nearest centroid) and 'value' - * is the distance between the sample and the 'centroid[key]' - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest - * centroid and it's distance - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] metric distance metric - * @param[in] batch_samples batch size of data samples - * @param[in] batch_centroids batch size of centroids - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void min_cluster_and_distance( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - raft::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - metric, - batch_samples, - batch_centroids, - workspace); -} - -/** - * @brief Shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores - * in 'out' does not modify the input - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] in The data to shuffle and gather - * [dim = n_samples x n_features] - * @param[out] out The sampled data - * [dim = n_samples_to_gather x n_features] - * @param[in] n_samples_to_gather Number of sample to gather - * @param[in] seed Seed for the shuffle - * - */ -template -void shuffle_and_gather(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - uint32_t n_samples_to_gather, - uint64_t seed) -{ - detail::shuffleAndGather(handle, in, out, n_samples_to_gather, seed); -} - -/** - * @brief Count the number of samples in each cluster - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] sampleCountInCluster The count for each centroid - * [dim = n_cluster] - * - */ -template -void count_samples_in_cluster(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view L2NormX, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - raft::device_vector_view sampleCountInCluster) -{ - detail::countSamplesInCluster( - handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); -} - -/** - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - * - * @see "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[out] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void init_plus_plus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace) -{ - detail::kmeansPlusPlus(handle, params, X, centroids, workspace); -} - -/* - * @brief Main function used to fit KMeans (after cluster initialization) - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] Initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void fit_main(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - detail::kmeans_fit_main( - handle, params, X, sample_weights, centroids, inertia, n_iter, workspace); -} - -}; // end namespace raft::cluster::kmeans - -namespace raft::cluster { - -/** - * Note: All of the functions below in raft::cluster are deprecated and will - * be removed in a future release. Please use raft::cluster::kmeans instead. - */ - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - kmeans::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); -} - -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT& inertia, - IndexT& n_iter) -{ - kmeans::fit( - handle, params, X, sample_weight, centroids, n_samples, n_features, inertia, n_iter); -} - -/** - * @brief Predict the closest cluster each sample in X belongs to. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X New data to predict. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[in] centroids Cluster centroids. The data must be in - * row-major format. - * [dim = n_clusters x n_features] - * @param[in] normalize_weight True if the weights should be normalized - * @param[out] labels Index of the cluster each sample in X - * belongs to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to - * their closest cluster center. - */ -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - kmeans::predict( - handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); -} - -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - bool normalize_weight, - DataT& inertia) -{ - kmeans::predict(handle, - params, - X, - sample_weight, - centroids, - n_samples, - n_features, - labels, - normalize_weight, - inertia); -} - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - DataT& inertia, - IndexT& n_iter) -{ - kmeans::fit_predict( - handle, params, X, sample_weight, centroids, n_samples, n_features, labels, inertia, n_iter); -} - -/** - * @brief Transform X to a cluster-distance space. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Cluster centroids. The data must be in row-major format. - * [dim = n_clusters x n_features] - * @param[out] X_new X transformed in the new space. - * [dim = n_samples x n_features] - */ -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) -{ - kmeans::transform(handle, params, X, centroids, X_new); -} - -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT* X_new) -{ - kmeans::transform(handle, params, X, centroids, n_samples, n_features, X_new); -} - -template -using SamplingOp = kmeans::SamplingOp; - -template -using KeyValueIndexOp = kmeans::KeyValueIndexOp; - -/** - * @brief Select centroids according to a sampling operation - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] isSampleCentroid Flag the sample chosen as initial centroid - * [dim = n_samples] - * @param[in] select_op The sampling operation used to select the centroids - * @param[out] inRankCp The sampled centroids - * [dim = n_selected_centroids x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void sampleCentroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view minClusterDistance, - raft::device_vector_view isSampleCentroid, - SamplingOp& select_op, - rmm::device_uvector& inRankCp, - rmm::device_uvector& workspace) -{ - kmeans::sample_centroids( - handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); -} - -/** - * @brief Compute cluster cost - * - * @tparam DataT the type of data used for weights, distances. - * @tparam ReductionOpT the type of data used for the reduction operation. - * - * @param[in] handle The raft handle - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] clusterCost Resulting cluster cost - * @param[in] reduction_op The reduction operation used for the cost - * - */ -template -void computeClusterCost(raft::resources const& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, - ReductionOpT reduction_op) -{ - kmeans::cluster_cost(handle, minClusterDistance, workspace, clusterCost, reduction_op); -} - -/** - * @brief Compute distance for every sample to it's nearest centroid - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void minClusterDistanceCompute(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace) -{ - kmeans::min_cluster_distance(handle, - X, - centroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); -} - -/** - * @brief Calculates a pair for every sample in input 'X' where key is an - * index of one of the 'centroids' (index of the nearest centroid) and 'value' - * is the distance between the sample and the 'centroid[key]' - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest - * centroid and it's distance - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void minClusterAndDistanceCompute( - raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace) -{ - kmeans::min_cluster_and_distance(handle, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); -} - -/** - * @brief Shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores - * in 'out' does not modify the input - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] in The data to shuffle and gather - * [dim = n_samples x n_features] - * @param[out] out The sampled data - * [dim = n_samples_to_gather x n_features] - * @param[in] n_samples_to_gather Number of sample to gather - * @param[in] seed Seed for the shuffle - * - */ -template -void shuffleAndGather(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - uint32_t n_samples_to_gather, - uint64_t seed) -{ - kmeans::shuffle_and_gather(handle, in, out, n_samples_to_gather, seed); -} - -/** - * @brief Count the number of samples in each cluster - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] sampleCountInCluster The count for each centroid - * [dim = n_cluster] - * - */ -template -void countSamplesInCluster(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view L2NormX, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - raft::device_vector_view sampleCountInCluster) -{ - kmeans::count_samples_in_cluster( - handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); -} - -/* - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - - * @note This is the algorithm described in - * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[out] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void kmeansPlusPlus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - kmeans::init_plus_plus(handle, params, X, centroidsRawData, workspace); -} - -/* - * @brief Main function used to fit KMeans (after cluster initialization) - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] Initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void kmeans_fit_main(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - kmeans::fit_main( - handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); -} -}; // namespace raft::cluster diff --git a/cpp/include/raft/cluster/kmeans_balanced.cuh b/cpp/include/raft/cluster/kmeans_balanced.cuh deleted file mode 100644 index 7479047fce..0000000000 --- a/cpp/include/raft/cluster/kmeans_balanced.cuh +++ /dev/null @@ -1,371 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -#include - -namespace raft::cluster::kmeans_balanced { - -/** - * @brief Find clusters of balanced sizes with a hierarchical k-means algorithm. - * - * This variant of the k-means algorithm first clusters the dataset in mesoclusters, then clusters - * the subsets associated to each mesocluster into fine clusters, and finally runs a few k-means - * iterations over the whole dataset and with all the centroids to obtain the final clusters. - * - * Each k-means iteration applies expectation-maximization-balancing: - * - Balancing: adjust centers for clusters that have a small number of entries. If the size of a - * cluster is below a threshold, the center is moved towards a bigger cluster. - * - Expectation: predict the labels (i.e find closest cluster centroid to each point) - * - Maximization: calculate optimal centroids (i.e find the center of gravity of each cluster) - * - * The number of mesoclusters is chosen by rounding the square root of the number of clusters. E.g - * for 512 clusters, we would have 23 mesoclusters. The number of fine clusters per mesocluster is - * chosen proportionally to the number of points in each mesocluster. - * - * This variant of k-means uses random initialization and a fixed number of iterations, though - * iterations can be repeated if the balancing step moved the centroids. - * - * Additionally, this algorithm supports quantized datasets in arbitrary types but the core part of - * the algorithm will work with a floating-point type, hence a conversion function can be provided - * to map the data type to the math type. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * raft::cluster::kmeans_balanced_params params; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * raft::cluster::kmeans_balanced::fit(handle, params, X, centroids.view()); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Training instances to cluster. The data must be in row-major format. - * [dim = n_samples x n_features] - * @param[out] centroids The generated centroids [dim = n_clusters x n_features] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic - * datatype. If DataT == MathT, this must be the identity. - */ -template -[[deprecated("Use cuVS instead")]] void fit(const raft::resources& handle, - kmeans_balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= - static_cast(std::numeric_limits::max()), - "The chosen index type cannot represent all indices for the given dataset"); - RAFT_EXPECTS(centroids.extent(0) > IndexT{0} && centroids.extent(0) <= X.extent(0), - "The number of centroids must be strictly positive and cannot exceed the number of " - "points in the training dataset."); - - detail::build_hierarchical(handle, - params, - X.extent(1), - X.data_handle(), - X.extent(0), - centroids.data_handle(), - centroids.extent(0), - mapping_op); -} - -/** - * @brief Predict the closest cluster each sample in X belongs to. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * raft::cluster::kmeans_balanced_params params; - * auto labels = raft::make_device_vector(handle, n_rows); - * raft::cluster::kmeans_balanced::predict(handle, params, X, centroids, labels); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Dataset for which to infer the closest clusters. - * [dim = n_samples x n_features] - * @param[in] centroids The input centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic - * datatype. If DataT == MathT, this must be the identity. - */ -template -[[deprecated("Use cuVS instead")]] void predict( - const raft::resources& handle, - kmeans_balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= - static_cast(std::numeric_limits::max()), - "The chosen index type cannot represent all indices for the given dataset"); - RAFT_EXPECTS(static_cast(centroids.extent(0)) <= - static_cast(std::numeric_limits::max()), - "The chosen label type cannot represent all cluster labels"); - - detail::predict(handle, - params, - centroids.data_handle(), - centroids.extent(0), - X.extent(1), - X.data_handle(), - X.extent(0), - labels.data_handle(), - mapping_op); -} - -/** - * @brief Compute hierarchical balanced k-means clustering and predict cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * raft::cluster::kmeans_balanced_params params; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, n_rows); - * raft::cluster::kmeans_balanced::fit_predict( - * handle, params, X, centroids.view(), labels.view()); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Training instances to cluster. The data must be in row-major format. - * [dim = n_samples x n_features] - * @param[out] centroids The output centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic - * datatype. If DataT and MathT are the same, this must be the identity. - */ -template -[[deprecated("Use cuVS instead")]] void fit_predict( - const raft::resources& handle, - kmeans_balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) -{ - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - raft::cluster::kmeans_balanced::fit(handle, params, X, centroids, mapping_op); - raft::cluster::kmeans_balanced::predict(handle, params, X, centroids_const, labels, mapping_op); -} - -namespace helpers { - -/** - * @brief Randomly initialize centers and apply expectation-maximization-balancing iterations - * - * This is essentially the non-hierarchical balanced k-means algorithm which is used by the - * hierarchical algorithm once to build the mesoclusters and once per mesocluster to build the fine - * clusters. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * raft::cluster::kmeans_balanced_params params; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * auto sizes = raft::make_device_vector(handle, n_clusters); - * raft::cluster::kmeans_balanced::build_clusters( - * handle, params, X, centroids.view(), labels.view(), sizes.view()); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam CounterT Counter type supported by CUDA's native atomicAdd. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Training instances to cluster. The data must be in row-major format. - * [dim = n_samples x n_features] - * @param[out] centroids The output centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the - * arithmetic datatype. If DataT == MathT, this must be the identity. - * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] - */ -template -[[deprecated("Use cuVS instead")]] void build_clusters( - const raft::resources& handle, - const kmeans_balanced_params& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - raft::device_vector_view cluster_sizes, - MappingOpT mapping_op = raft::identity_op(), - std::optional> X_norm = std::nullopt) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), - "Number of rows in centroids and clusyer_sizes are different"); - - detail::build_clusters(handle, - params, - X.extent(1), - X.data_handle(), - X.extent(0), - centroids.extent(0), - centroids.data_handle(), - labels.data_handle(), - cluster_sizes.data_handle(), - mapping_op, - resource::get_workspace_resource(handle), - X_norm.has_value() ? X_norm.value().data_handle() : nullptr); -} - -/** - * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. - * - * Let `S_i = {x_k | x_k \in X & labels[k] == i}` be the vectors in the dataset with label i. - * - * On exit, - * `centers_i = (\sum_{x \in S_i} x + w_i * center_i) / (|S_i| + w_i)`, - * where `w_i = reset_counters ? 0 : cluster_size[i]`. - * - * In other words, the updated cluster centers are a weighted average of the existing cluster - * center, and the coordinates of the points labeled with i. _This allows calling this function - * multiple times with different datasets with the same effect as if calling this function once - * on the combined dataset_. - * - * @code{.cpp} - * #include - * #include - * ... - * raft::handle_t handle; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto sizes = raft::make_device_vector(handle, n_clusters); - * raft::cluster::kmeans_balanced::calc_centers_and_sizes( - * handle, X, labels, centroids.view(), sizes.view(), true); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam CounterT Counter type supported by CUDA's native atomicAdd. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] X Dataset for which to calculate cluster centers. The data must be in - * row-major format. [dim = n_samples x n_features] - * @param[in] labels The input labels [dim = n_samples] - * @param[out] centroids The output centroids [dim = n_clusters x n_features] - * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] - * @param[in] reset_counters Whether to clear the output arrays before calculating. - * When set to `false`, this function may be used to update existing - * centers and sizes using the weighted average principle. - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the - * arithmetic datatype. If DataT == MathT, this must be the identity. - */ -template -[[deprecated("Use cuVS instead")]] void calc_centers_and_sizes( - const raft::resources& handle, - raft::device_matrix_view X, - raft::device_vector_view labels, - raft::device_matrix_view centroids, - raft::device_vector_view cluster_sizes, - bool reset_counters = true, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), - "Number of rows in centroids and clusyer_sizes are different"); - - detail::calc_centers_and_sizes(handle, - centroids.data_handle(), - cluster_sizes.data_handle(), - centroids.extent(0), - X.extent(1), - X.data_handle(), - X.extent(0), - labels.data_handle(), - reset_counters, - mapping_op, - resource::get_workspace_resource(handle)); -} - -} // namespace helpers - -} // namespace raft::cluster::kmeans_balanced diff --git a/cpp/include/raft/cluster/kmeans_balanced_types.hpp b/cpp/include/raft/cluster/kmeans_balanced_types.hpp deleted file mode 100644 index 11b77e288a..0000000000 --- a/cpp/include/raft/cluster/kmeans_balanced_types.hpp +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace raft::cluster::kmeans_balanced { - -/** - * Simple object to specify hyper-parameters to the balanced k-means algorithm. - * - * The following metrics are currently supported in k-means balanced: - * - InnerProduct - * - L2Expanded - * - L2SqrtExpanded - */ -struct kmeans_balanced_params : kmeans_base_params { - /** - * Number of training iterations - */ - uint32_t n_iters = 20; -}; - -} // namespace raft::cluster::kmeans_balanced - -namespace raft::cluster { - -using kmeans_balanced::kmeans_balanced_params; - -} // namespace raft::cluster diff --git a/cpp/include/raft/cluster/kmeans_deprecated.cuh b/cpp/include/raft/cluster/kmeans_deprecated.cuh deleted file mode 100644 index 11f964eef5..0000000000 --- a/cpp/include/raft/cluster/kmeans_deprecated.cuh +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -namespace raft { -namespace cluster { - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param residual On exit, residual sum of squares (sum of squares - * of distances between observation vectors and centroids). - * @param iters on exit, number of k-means iterations. - * @param seed random seed to be used. - * @return error flag - */ -template -int kmeans(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - value_type_t& residual, - index_type_t& iters, - unsigned long long seed = 123456) -{ - return detail::kmeans( - handle, n, d, k, tol, maxiter, obs, codes, residual, iters, seed); -} -} // namespace cluster -} // namespace raft diff --git a/cpp/include/raft/cluster/kmeans_types.hpp b/cpp/include/raft/cluster/kmeans_types.hpp deleted file mode 100644 index 4d956ad7a0..0000000000 --- a/cpp/include/raft/cluster/kmeans_types.hpp +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include - -namespace raft::cluster { - -/** Base structure for parameters that are common to all k-means algorithms */ -struct kmeans_base_params { - /** - * Metric to use for distance computation. The supported metrics can vary per algorithm. - */ - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; -}; - -} // namespace raft::cluster - -namespace raft::cluster::kmeans { - -/** - * Simple object to specify hyper-parameters to the kmeans algorithm. - */ -struct KMeansParams : kmeans_base_params { - enum InitMethod { - - /** - * Sample the centroids using the kmeans++ strategy - */ - KMeansPlusPlus, - - /** - * Sample the centroids uniformly at random - */ - Random, - - /** - * User provides the array of initial centroids - */ - Array - }; - - /** - * The number of clusters to form as well as the number of centroids to generate (default:8). - */ - int n_clusters = 8; - - /** - * Method for initialization, defaults to k-means++: - * - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm - * to select the initial cluster centers. - * - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at - * random from the input data for the initial centroids. - * - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. - */ - InitMethod init = KMeansPlusPlus; - - /** - * Maximum number of iterations of the k-means algorithm for a single run. - */ - int max_iter = 300; - - /** - * Relative tolerance with regards to inertia to declare convergence. - */ - double tol = 1e-4; - - /** - * verbosity level. - */ - int verbosity = RAFT_LEVEL_INFO; - - /** - * Seed to the random number generator. - */ - raft::random::RngState rng_state{0}; - - /** - * Number of instance k-means algorithm will be run with different seeds. - */ - int n_init = 1; - - /** - * Oversampling factor for use in the k-means|| algorithm - */ - double oversampling_factor = 2.0; - - // batch_samples and batch_centroids are used to tile 1NN computation which is - // useful to optimize/control the memory footprint - // Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 - // then don't tile the centroids - int batch_samples = 1 << 15; - - /** - * if 0 then batch_centroids = n_clusters - */ - int batch_centroids = 0; // - - bool inertia_check = false; -}; - -} // namespace raft::cluster::kmeans - -namespace raft::cluster { - -using kmeans::KMeansParams; - -} // namespace raft::cluster diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh deleted file mode 100644 index 067445c542..0000000000 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -namespace raft::cluster { - -/** - * Note: All of the functions below in the raft::cluster namespace are deprecated - * and will be removed in a future release. Please use raft::cluster::hierarchy - * instead. - */ - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. - - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[out] out struct containing output dendrogram and cluster assignments - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control - * of k. The algorithm will set `k = log(n) + c` - * @param[in] n_clusters number of clusters to assign data samples - */ -template -[[deprecated("Use cuVS instead")]] void single_linkage(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - linkage_output* out, - int c, - size_t n_clusters) -{ - detail::single_linkage( - handle, X, m, n, metric, out, c, n_clusters); -} -}; // namespace raft::cluster - -namespace raft::cluster::hierarchy { - -constexpr int DEFAULT_CONST_C = 15; - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. - - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2) - * @param[out] labels output labels vector (size n_rows) - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[in] n_clusters number of clusters to assign data samples - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control of k. The algorithm will set `k = log(n) + c` - */ -template -[[deprecated("Use cuVS instead")]] void single_linkage( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view dendrogram, - raft::device_vector_view labels, - raft::distance::DistanceType metric, - size_t n_clusters, - std::optional c = std::make_optional(DEFAULT_CONST_C)) -{ - linkage_output out_arrs; - out_arrs.children = dendrogram.data_handle(); - out_arrs.labels = labels.data_handle(); - - raft::cluster::single_linkage( - handle, - X.data_handle(), - static_cast(X.extent(0)), - static_cast(X.extent(1)), - metric, - &out_arrs, - c.has_value() ? c.value() : DEFAULT_CONST_C, - n_clusters); -} -}; // namespace raft::cluster::hierarchy diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp deleted file mode 100644 index cd815622bf..0000000000 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::cluster::hierarchy { - -/** - * Determines the method for computing the minimum spanning tree (MST) - */ -enum LinkageDistance { - - /** - * Use a pairwise distance matrix as input to the mst. This - * is very fast and the best option for fairly small datasets (~50k data points) - */ - PAIRWISE = 0, - - /** - * Construct a KNN graph as input to the mst and provide additional - * edges if the mst does not converge. This is slower but scales - * to very large datasets. - */ - KNN_GRAPH = 1 -}; - -}; // end namespace raft::cluster::hierarchy - -// The code below is now considered legacy -namespace raft::cluster { - -using hierarchy::LinkageDistance; - -/** - * Simple container object for consolidating linkage results. This closely - * mirrors the trained instance variables populated in - * Scikit-learn's AgglomerativeClustering estimator. - * @tparam value_idx - * @tparam value_t - */ -template -class linkage_output { - public: - idx_t m; - idx_t n_clusters; - - idx_t n_leaves; - idx_t n_connected_components; - - // TODO: These will be made private in a future release - idx_t* labels; // size: m - idx_t* children; // size: (m-1, 2) - - raft::device_vector_view get_labels() - { - return raft::make_device_vector_view(labels, m); - } - - raft::device_matrix_view get_children() - { - return raft::make_device_matrix_view(children, m - 1, 2); - } -}; - -class linkage_output_int : public linkage_output {}; -class linkage_output_int64 : public linkage_output {}; - -}; // namespace raft::cluster diff --git a/cpp/include/raft/cluster/specializations.cuh b/cpp/include/raft/cluster/specializations.cuh deleted file mode 100644 index e85b05575f..0000000000 --- a/cpp/include/raft/cluster/specializations.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_HIDE_DEPRECATION_WARNINGS -#pragma message( \ - __FILE__ \ - " is deprecated and will be removed." \ - " Including specializations is not necessary any more." \ - " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") -#endif