Skip to content

Commit

Permalink
Merge pull request #76 from KIwabuchi/feature/dnnd_new_api
Browse files Browse the repository at this point in the history
Add functions in simple API
  • Loading branch information
KIwabuchi authored Sep 5, 2024
2 parents a68d15f + 92d4fef commit 62fa2b3
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 48 deletions.
43 changes: 25 additions & 18 deletions examples/dnnd_advanced_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ using point_type = saltatlas::pm_feature_vector<float>;
// Custom distance function
// The distance function should have the signature as follows:
// distance_type(const point_type& a, const point_type& b);
dist_t custom_distance(const point_type& p1, const point_type& p2) {
// A simple (squared) L2 distance example
dist_t dist = 0.0;
for (size_t i = 0; i < p1.size(); ++i) {
dist += (p1[i] - p2[i]) * (p1[i] - p2[i]);
dist_t custom_distance(const point_type& p0, const point_type& p1) {
  // Example custom metric: accumulates the squared per-element difference
  // between the two points, i.e. the squared L2 distance (no square root is
  // taken). Iterates over p0.size() elements, so p1 is expected to hold at
  // least as many elements as p0.
  dist_t sum = 0;
  const std::size_t num_elements = p0.size();
  std::size_t k = 0;
  while (k < num_elements) {
    const auto delta = (p0[k] - p1[k]);
    sum += delta * delta;
    ++k;
  }
  return sum;
}

int main(int argc, char** argv) {
Expand All @@ -46,18 +47,18 @@ int main(int argc, char** argv) {

// ----- NNG build and NN search APIs ----- //
int k = 4;
const auto id = g.build(custom_distance, k);
const auto id = g.build(saltatlas::distance::id::l2, k);

bool make_graph_undirected = true;
g.optimize(id, custom_distance, make_graph_undirected);
g.optimize(id, saltatlas::distance::id::l2, make_graph_undirected);

// Run queries
std::vector<point_type> queries;
if (comm.rank() == 0) {
queries.push_back(point_type{61.58, 29.68, 20.43, 99.22, 21.81});
}
int num_to_search = 4;
const auto results = g.query(id, custom_distance, queries.begin(),
const auto results = g.query(id, saltatlas::distance::id::l2, queries.begin(),
queries.end(), num_to_search);

if (comm.rank() == 0) {
Expand All @@ -69,6 +70,7 @@ int main(int argc, char** argv) {
}
}

// -- Persistent memory example -- //
std::filesystem::path datastorepath = "/tmp/dnnd-knng";
std::error_code ec;
std::filesystem::remove_all(datastorepath, ec);
Expand All @@ -80,30 +82,35 @@ int main(int argc, char** argv) {
"../examples/datasets/point_5-4.txt"};
g.load_points(paths.begin(), paths.end(), "wsv");
const auto id = g.build(custom_distance, 2);
comm.cout0() << "Created" << std::endl;
comm.cout0() << "Created KNNG " << id << std::endl;
}

{
saltatlas::dnnd<id_t, point_type, dist_t> g(saltatlas::open_only,
datastorepath, comm);
if (g.contains_local(0)) g.get_local_point(0);
comm.cf_barrier();
g.update(0, custom_distance, 4);
comm.cout0() << "Updated" << std::endl;
comm.cout0() << "Updated KNNG " << 0 << std::endl;

g.build(custom_distance, 4);
auto id = g.build(custom_distance, 4);
comm.cout0() << "Created KNNG " << id << std::endl;
}

{
saltatlas::dnnd<id_t, point_type, dist_t> g(saltatlas::open_read_only,
datastorepath, comm);
std::vector<point_type> queries;
std::vector<point_type> queries;
if (comm.rank() == 0) {
queries.push_back(point_type{61.58, 29.68, 20.43, 99.22, 21.81});
}
comm.cout0() << "Query 1" << std::endl;
g.query(0, custom_distance, queries.begin(), queries.end(), 4);

auto ret0 = g.query(0, custom_distance, queries.begin(), queries.end(), 4);
if (comm.rank() == 0) {
comm.cout0() << "Query result of rank0:" << std::endl;
std::cout << "Neighbours (id, distance):";
for (const auto& [nn_id, nn_dist] : ret0[0]) {
std::cout << " " << nn_id << " (" << nn_dist << ")";
}
std::cout << std::endl;
}

comm.cout0() << "Query 2" << std::endl;
std::vector<std::size_t> ids{0, 1};
Expand Down
13 changes: 11 additions & 2 deletions examples/dnnd_simple_custom_point_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,18 @@ int main(int argc, char** argv) {

g.optimize();

std::vector<graph_point> queries;
int num_to_search = 10;
std::vector<graph_point> queries{gen_point()};
int num_to_search = 2;
const auto results = g.query(queries.begin(), queries.end(), num_to_search);

if (comm.rank() == 0) {
std::cout << "Query result of rank 0" << std::endl;
std::cout << "Neighbours (id, distance):";
for (const auto& [nn_id, nn_dist] : results[0]) {
std::cout << " " << nn_id << " (" << nn_dist << ")";
}
std::cout << std::endl;
}

return 0;
}
101 changes: 75 additions & 26 deletions examples/dnnd_simple_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,40 +47,89 @@ int main(int argc, char** argv) {
g.optimize(make_graph_undirected);

// Run queries
std::vector<point_type> queries;
if (comm.rank() == 0) {
queries.push_back(point_type{61.58, 29.68, 20.43, 99.22, 21.81});
{
std::vector<point_type> queries;
if (comm.rank() == 0) {
queries.push_back(point_type{61.58, 29.68, 20.43, 99.22, 21.81});
} else if (comm.rank() == 1) {
queries.push_back(point_type{78.44, 54.43, 59.68, 65.80, 24.361});
}

{
int num_to_search = 4;
const auto results =
g.query(queries.begin(), queries.end(), num_to_search);
// Show the query results
comm.cout0() << "Query owner rank: neighbours (id, distance)..."
<< std::endl;
for (int i = 0; i <= 1; ++i) {
if (comm.rank() == i) {
std::cout << "Rank " << i << ": ";
for (const auto& [nn_id, nn_dist] : results[0]) {
std::cout << " " << nn_id << " (" << nn_dist << ")";
}
std::cout << std::endl;
}
comm.cf_barrier();
}
}

// Get nearest neighbours with features
comm.cout0()<< "\nNearest neighbor query result with features\n(show only the nearest point of the query from rank 0)" << std::endl;
{
int num_to_search = 4;
const auto results =
g.query_with_features(queries.begin(), queries.end(), num_to_search);
// Show the query results
if (comm.rank() == 0) {
auto& neighbours = results.first;
auto& features = results.second;
std::cout << "Point ID " << neighbours[0][0].id << ", distance "
<< neighbours[0][0].distance << ", feature {";
for (const auto& v : features[0][0]) {
std::cout << v << " ";
}
std::cout << "}" << std::endl;
}
}
}
int num_to_search = 4;
const auto results = g.query(queries.begin(), queries.end(), num_to_search);

// Show the query results
if (comm.rank() == 0) {
std::cout << "Neighbours (id, distance):";
for (const auto& [nn_id, nn_dist] : results[0]) {
std::cout << " " << nn_id << " (" << nn_dist << ")";

// --- Point Data Accessors --- //
comm.cout0() << "\nPoint 0's features: " << std::endl;
{
id_t pid = 0;
if (g.contains_local(pid)) {
auto p0 = g.get_local_point(pid);
for (const auto& v : p0) {
std::cout << v << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}

// Point Data Accessors
id_t pid = 0;
if (g.contains_local(pid)) {
auto p0 = g.get_local_point(pid);
std::cout << "Point 0 : ";
for (const auto& v : p0) {
std::cout << v << " ";
comm.cout0() << "\nRank 0's all local points" << std::endl;
{
for (const auto& [id, point] : g.local_points()) {
comm.cout0() << "Point ID " << id << " : ";
for (const auto& v : point) {
comm.cout0() << v << " ";
}
comm.cout0() << std::endl;
}
std::cout << std::endl;
}

// Point Data Accessors, another API
for (const auto& [id, point] : g.local_points()) {
comm.cout0() << "Point " << id << " : ";
for (const auto& v : point) {
comm.cout0() << v << " ";
comm.cout0() << "\nGet points including the ones that are stored in other ranks"
<< std::endl;
{
id_t ids[] = {0, 1};
auto points = g.get_points(ids, ids + 2);
for (const auto& [id, point] : points) {
comm.cout0() << "Point ID " << id << " : ";
for (const auto& v : point) {
comm.cout0() << v << " ";
}
comm.cout0() << std::endl;
}
comm.cout0() << std::endl;
}

// Dump a KNNG to files
Expand Down
Loading

0 comments on commit 62fa2b3

Please sign in to comment.