Skip to content

Commit

Permalink
Merge pull request #70 from KIwabuchi/feature/dnnd_new_api
Browse files Browse the repository at this point in the history
(DNND) Update Simple API Examples
  • Loading branch information
KIwabuchi authored Jun 12, 2024
2 parents b365ba3 + 5059ab3 commit 12bd23b
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 39 deletions.
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ if (SALTATLAS_USE_METALL)
add_saltatlas_example(dnnd_pm_optimize_example)
add_saltatlas_example(dnnd_pm_query_example)
add_saltatlas_example(dnnd_simple_example)
add_saltatlas_example(dnnd_simple_custom_distance_example)
add_saltatlas_example(dnnd_simple_custom_point_example)

add_saltatlas_dnnd_example_feature_type(dnnd_pm_const_example float)
Expand Down
70 changes: 70 additions & 0 deletions examples/dnnd_simple_custom_distance_example.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright 2024 Lawrence Livermore National Security, LLC and other
// saltatlas Project Developers. See the top-level COPYRIGHT file for details.
//
// SPDX-License-Identifier: MIT

/// \brief A simple example of using DNND's simple API with a custom distance function.
/// It is recommended to see the examples/dnnd_simple_example.cpp beforehand.
/// Usage:
/// cd build
/// mpirun -n 2 ./example/dnnd_simple_custom_distance_example

#include <iostream>
#include <vector>

#include <ygm/comm.hpp>

#include <saltatlas/dnnd/dnnd_simple.hpp>

// Point ID type and distance type
using id_t = uint32_t;
using dist_t = double;

// Point Type
using point_type = saltatlas::feature_vector<float>;

// Custom distance function
// The distance function should have the following signature:
// distance_type(const point_type& a, const point_type& b);
dist_t custom_distance(const point_type& p1, const point_type& p2) {
  // Squared L2 distance example: sum of the squared per-dimension
  // differences (no square root is taken).
  // The number of dimensions is taken from p1; p2 is indexed over the same
  // range.
  const size_t num_dims = p1.size();
  dist_t sum_of_squares = 0.0;
  for (size_t d = 0; d < num_dims; ++d) {
    const auto diff = p1[d] - p2[d];
    sum_of_squares += diff * diff;
  }
  return sum_of_squares;
}

int main(int argc, char** argv) {
ygm::comm comm(&argc, &argv);

saltatlas::dnnd<id_t, point_type, dist_t> g(custom_distance, comm);
std::vector<std::filesystem::path> paths{
"../examples/datasets/point_5-4.txt"};
g.load_points(paths.begin(), paths.end(), "wsv");

// ----- NNG build and NN search APIs ----- //
int k = 4;
g.build(k);

bool make_graph_undirected = true;
g.optimize(make_graph_undirected);

// Run queries
std::vector<point_type> queries;
if (comm.rank() == 0) {
queries.push_back(point_type{61.58, 29.68, 20.43, 99.22, 21.81});
}
int num_to_search = 4;
const auto results = g.query(queries.begin(), queries.end(), num_to_search);

if (comm.rank() == 0) {
std::cout << "Neighbours (id, distance):";
for (const auto& [nn_id, nn_dist] : results[0]) {
std::cout << " " << nn_id << " (" << nn_dist << ")";
}
std::cout << std::endl;
}

return 0;
}
8 changes: 8 additions & 0 deletions examples/dnnd_simple_custom_point_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
//
// SPDX-License-Identifier: MIT

/// \brief A simple example of using DNND's simple API with a custom point type
/// and custom distance function.
/// It is recommended to see the examples/dnnd_simple_example.cpp beforehand.
/// Usage:
/// cd build
/// mpirun -n 2 ./example/dnnd_simple_custom_point_example

#include <iostream>
#include <vector>

Expand Down Expand Up @@ -39,6 +46,7 @@ int main(int argc, char** argv) {

// Add points
{
// Assuming ids and points are stored in vectors
std::vector<id_t> ids;
std::vector<graph_point> points;
g.add_points(ids.begin(), ids.end(), points.begin(), points.end());
Expand Down
16 changes: 15 additions & 1 deletion examples/dnnd_simple_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
//
// SPDX-License-Identifier: MIT

/// \brief A simple example of using DNND's simple API.
/// Usage:
/// cd build
/// mpirun -n 2 ./example/dnnd_simple_example

#include <iostream>
#include <vector>

Expand All @@ -20,8 +25,16 @@ using point_type = saltatlas::feature_vector<float>;
int main(int argc, char** argv) {
ygm::comm comm(&argc, &argv);

saltatlas::dnnd<id_t, point_type, dist_t> g(saltatlas::distance::id::l2,
// Create a DNND object
// Use the squared L2 distance function
saltatlas::dnnd<id_t, point_type, dist_t> g(saltatlas::distance::id::sql2,
comm);

// Load points from file(s)
// The file format is assumed to be whitespace-separated values (wsv)
// One point per line. Each feature value is separated by a whitespace.
// DNND assigns an ID to each point in the order they are loaded,
// i.e., ID is the line number starting from 0.
std::vector<std::filesystem::path> paths{
"../examples/datasets/point_5-4.txt"};
g.load_points(paths.begin(), paths.end(), "wsv");
Expand All @@ -41,6 +54,7 @@ int main(int argc, char** argv) {
int num_to_search = 4;
const auto results = g.query(queries.begin(), queries.end(), num_to_search);

// Show the query results
if (comm.rank() == 0) {
std::cout << "Neighbours (id, distance):";
for (const auto& [nn_id, nn_dist] : results[0]) {
Expand Down
26 changes: 13 additions & 13 deletions include/saltatlas/dnnd/detail/base_dnnd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ struct data_core {
knn_index(allocator) {}

saltatlas::distance::id distance_id;
uint64_t rnd_seed;
uint64_t rnd_seed; // TODO: does not have to hold?
point_store_type pstore;
knn_index_type knn_index;
std::size_t index_k{0};
Expand Down Expand Up @@ -122,6 +122,7 @@ class base_dnnd {
/// \return A point partitioner instance.
point_partitioner get_point_partitioner() const {
  // The communicator size is captured by value, so the returned callable
  // does not refer back to this object.
  const int num_ranks = m_comm.size();
  // TODO: hash id?
  // Maps a point ID to (ID modulo communicator size).
  return [num_ranks](const id_type& point_id) { return point_id % num_ranks; };
};

Expand Down Expand Up @@ -275,8 +276,8 @@ class base_dnnd {
protected:
/// \brief Initialize the internal data core instance.
/// \param data_core A data core instance.
void init_data_core(data_core_type& data_core,
const distance_function_type& distance_function) {
void set_data_core(data_core_type& data_core,
const distance_function_type& distance_function) {
if (m_data_core) {
m_comm.cerr0() << "Data core is already initialized." << std::endl;
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
Expand All @@ -286,12 +287,12 @@ class base_dnnd {
m_distance_function = distance_function;
}

/// \brief init_data_core for using a pre-defined distance function.
/// \brief set_data_core for using a pre-defined distance function.
/// The data core argument must contain a valid distance function ID.
void init_data_core(data_core_type& data_core) {
init_data_core(data_core,
distance::distance_function<point_type, distance_type>(
data_core.distance_id));
void set_data_core(data_core_type& data_core) {
set_data_core(data_core,
distance::distance_function<point_type, distance_type>(
data_core.distance_id));
}

private:
Expand Down Expand Up @@ -323,11 +324,10 @@ class base_dnnd {
return ret;
}

data_core_type* m_data_core{nullptr};
ygm::comm& m_comm;
bool m_verbose{false};
ygm::ygm_ptr<self_type> m_self{this};
distance_function_type m_distance_function;
data_core_type* m_data_core{nullptr};
ygm::comm& m_comm;
bool m_verbose{false}; // TODO: make this changeable after construction
distance_function_type m_distance_function;
};

} // namespace saltatlas::dndetail
4 changes: 2 additions & 2 deletions include/saltatlas/dnnd/dnnd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class dnnd : public dndetail::base_dnnd<Id, Point, Distance> {
const bool verbose = false)
: base_type(verbose, comm),
m_data_core(distance::convert_to_distance_id(distance_name), rnd_seed) {
base_type::init_data_core(m_data_core);
base_type::set_data_core(m_data_core);
}

/// \brief Constructor.
Expand All @@ -62,7 +62,7 @@ class dnnd : public dndetail::base_dnnd<Id, Point, Distance> {
const uint64_t rnd_seed = std::random_device{}(),
const bool verbose = false)
: base_type(verbose, comm), m_data_core(distance::id::custom, rnd_seed) {
base_type::init_data_core(m_data_core, distance_func);
base_type::set_data_core(m_data_core, distance_func);
}

private:
Expand Down
6 changes: 3 additions & 3 deletions include/saltatlas/dnnd/dnnd_pm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class dnnd_pm : public dndetail::base_dnnd<
const bool verbose = false)
: base_type(verbose, comm) {
priv_create(datastore_path, distance_name, rnd_seed);
base_type::init_data_core(*m_data_core);
base_type::set_data_core(*m_data_core);
comm.cf_barrier();
}

Expand All @@ -90,7 +90,7 @@ class dnnd_pm : public dndetail::base_dnnd<
const bool verbose = false)
: base_type(verbose, comm) {
priv_open(datastore_path);
base_type::init_data_core(*m_data_core);
base_type::set_data_core(*m_data_core);
comm.cf_barrier();
}

Expand All @@ -102,7 +102,7 @@ class dnnd_pm : public dndetail::base_dnnd<
ygm::comm& comm, const bool verbose = false)
: base_type(verbose, comm) {
priv_open_read_only(datastore_path);
base_type::init_data_core(*m_data_core);
base_type::set_data_core(*m_data_core);
comm.cf_barrier();
}

Expand Down
Loading

0 comments on commit 12bd23b

Please sign in to comment.