Skip to content

Commit

Permalink
(DNND) Add query_with_features() in the simple API
Browse files Browse the repository at this point in the history
  • Loading branch information
Keita Iwabuchi committed Sep 5, 2024
1 parent 3096d8f commit 92d4fef
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 36 deletions.
75 changes: 52 additions & 23 deletions examples/dnnd_simple_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,55 +47,84 @@ int main(int argc, char** argv) {
g.optimize(make_graph_undirected);

// Run queries
std::vector<point_type> queries;
if (comm.rank() == 0) {
queries.push_back(point_type{61.58, 29.68, 20.43, 99.22, 21.81});
}
int num_to_search = 4;
const auto results = g.query(queries.begin(), queries.end(), num_to_search);

// Show the query results
if (comm.rank() == 0) {
std::cout << "Neighbours (id, distance):";
for (const auto& [nn_id, nn_dist] : results[0]) {
std::cout << " " << nn_id << " (" << nn_dist << ")";
{
std::vector<point_type> queries;
if (comm.rank() == 0) {
queries.push_back(point_type{61.58, 29.68, 20.43, 99.22, 21.81});
} else if (comm.rank() == 1) {
queries.push_back(point_type{78.44, 54.43, 59.68, 65.80, 24.361});
}

{
int num_to_search = 4;
const auto results =
g.query(queries.begin(), queries.end(), num_to_search);
// Show the query results
comm.cout0() << "Query owner rank: neighbours (id, distance)..."
<< std::endl;
for (int i = 0; i <= 1; ++i) {
if (comm.rank() == i) {
std::cout << "Rank " << i << ": ";
for (const auto& [nn_id, nn_dist] : results[0]) {
std::cout << " " << nn_id << " (" << nn_dist << ")";
}
std::cout << std::endl;
}
comm.cf_barrier();
}
}

// Get nearest neighbours with features
comm.cout0()<< "\nNearest neighbor query result with features\n(show only the nearest point of the query from rank 0)" << std::endl;
{
int num_to_search = 4;
const auto results =
g.query_with_features(queries.begin(), queries.end(), num_to_search);
// Show the query results
if (comm.rank() == 0) {
auto& neighbours = results.first;
auto& features = results.second;
std::cout << "Point ID " << neighbours[0][0].id << ", distance "
<< neighbours[0][0].distance << ", feature {";
for (const auto& v : features[0][0]) {
std::cout << v << " ";
}
std::cout << "}" << std::endl;
}
}
std::cout << std::endl;
}

// --- Point Data Accessors --- //
// Get a local point by ID
comm.cout0() << "\nPoint 0's features: " << std::endl;
{
id_t pid = 0;
if (g.contains_local(pid)) {
auto p0 = g.get_local_point(pid);
std::cout << "Point 0 : ";
for (const auto& v : p0) {
std::cout << v << " ";
}
std::cout << std::endl;
}
comm.cout0() << std::endl;
}

// Point data iterator
comm.cout0() << "\nRank 0's all local points" << std::endl;
{
comm.cout0() << "Rank 0 local points" << std::endl;
for (const auto& [id, point] : g.local_points()) {
comm.cout0() << "Point " << id << " : ";
comm.cout0() << "Point ID " << id << " : ";
for (const auto& v : point) {
comm.cout0() << v << " ";
}
comm.cout0() << std::endl;
}
comm.cout0() << std::endl;
}

// Get points including the ones that are stored in other ranks
comm.cout0() << "\nGet points including the ones that are stored in other ranks"
<< std::endl;
{
auto points = g.get_points({0, 1});
id_t ids[] = {0, 1};
auto points = g.get_points(ids, ids + 2);
for (const auto& [id, point] : points) {
comm.cout0() << "Point " << id << " : ";
comm.cout0() << "Point ID " << id << " : ";
for (const auto& v : point) {
comm.cout0() << v << " ";
}
Expand Down
70 changes: 57 additions & 13 deletions include/saltatlas/dnnd/dnnd_simple.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,34 @@ class dnnd {
return query_result;
}

template <typename query_iterator>
std::pair<neighbor_store_type, std::vector<std::vector<point_type>>>
query_with_features(query_iterator queries_begin, query_iterator queries_end,
const int k, const double epsilon = 0.1) {
auto query_result = query(queries_begin, queries_end, k, epsilon);

std::vector<std::vector<point_type>> neighbor_features_store;
neighbor_features_store.reserve(query_result.size());
for (const auto& neighbors : query_result) {
std::vector<id_type> neighbor_ids;
neighbor_ids.reserve(neighbors.size());
for (const auto& neighbor : neighbors) {
neighbor_ids.push_back(neighbor.id);
}
auto neighbor_features =
get_points(neighbor_ids.begin(), neighbor_ids.end());

std::vector<point_type> neighbor_features_vec;
neighbor_features_vec.reserve(neighbor_ids.size());
for (const auto& id : neighbor_ids) {
neighbor_features_vec.push_back(neighbor_features.at(id));
}
neighbor_features_store.push_back(std::move(neighbor_features_vec));
}

return std::make_pair(query_result, neighbor_features_store);
}

/// \brief Dump the k-NN index to distributed files.
/// \param out_file_prefix File path prefix.
/// \param dump_distance If true, also dump distances
Expand Down Expand Up @@ -308,35 +336,51 @@ class dnnd {
return m_pstore.at(id);
}

std::vector<point_type> get_local_points(
const std::vector<id_type>& ids) const {
std::vector<point_type> points;
points.reserve(ids.size());
for (const auto& id : ids) {
points.push_back(m_pstore.at(id));
template <typename id_iterator>
std::unordered_map<id_type, point_type> get_local_points(
id_iterator ids_begin, id_iterator ids_end) const {
static_assert(
std::is_same_v<typename std::iterator_traits<id_iterator>::value_type,
id_type>,
"id_iterator must be an iterator of id_type");

std::unordered_map<id_type, point_type> points;
points.reserve(std::distance(ids_begin, ids_end));
for (auto it = ids_begin; it != ids_end; ++it) {
const auto id = *it;
points.emplace(id, m_pstore.at(id));
}
return points;
}

std::vector<std::pair<id_type, point_type>> get_points(
const std::vector<id_type>& ids) const {
std::vector<std::pair<id_type, point_type>> points;
points.reserve(ids.size());
template <typename id_iterator>
std::unordered_map<id_type, point_type> get_points(
id_iterator ids_begin, id_iterator ids_end) const {
static_assert(
std::is_same_v<typename std::iterator_traits<id_iterator>::value_type,
id_type>,
"id_iterator must be an iterator of id_type");

std::unordered_map<id_type, point_type> points;
points.reserve(std::distance(ids_begin, ids_end));

static auto& ref_points = points;
ref_points = points;

auto proc = [](auto comm, auto pthis, const id_type id,
const int source_rank) {
const int source_rank) {
assert(pthis->contains_local(id));

comm->async(
source_rank,
[](auto, const auto& id, const auto& point) {
ref_points.emplace_back(id, point);
ref_points.emplace(id, point);
},
id, pthis->get_local_point(id));
};

for (const auto& id : ids) {
for (auto it = ids_begin; it != ids_end; ++it) {
const auto id = *it;
m_comm.async(get_owner(id), proc, m_this, id, m_comm.rank());
}

Expand Down

0 comments on commit 92d4fef

Please sign in to comment.