Skip to content

Commit

Permalink
Fixes uses of for_all to comply with YGM API (#72)
Browse files Browse the repository at this point in the history
* Fixes uses of for_all to comply with YGM API

* Changes ASSERT_RELEASE to YGM_ASSERT_RELEASE
  • Loading branch information
steiltre authored Aug 2, 2024
1 parent 57bcde1 commit 37c8623
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 90 deletions.
4 changes: 2 additions & 2 deletions examples/binary_nn_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ void build_index(

dist_index.comm().cout0("Distributing data across ranks");

auto add_point_lambda = [&dist_index, &bag_data](const auto &index,
const auto &point) {
auto add_point_lambda = [&dist_index, &bag_data](const auto &index_point) {
const auto &[index, point] = index_point;
dist_index.queue_data_point_insertion(index, point);
};

Expand Down
3 changes: 2 additions & 1 deletion examples/dhnsw_ascii_levenshtein.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ int main(int argc, char** argv) {
world.barrier();

world.cout0("Distributing data to local HNSWs");
bag_data.for_all([&dist_index, &world](const auto& id, const auto& line) {
bag_data.for_all([&dist_index, &world](const auto& id_line) {
const auto& [id, line] = id_line;
dist_index.queue_data_point_insertion(id, line);
});

Expand Down
18 changes: 9 additions & 9 deletions examples/dhnsw_big_ann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dist_t l2_sqr(const point_t& v1, const point_t& v2) {
std::cout << "Size mismatch: " << v1.size() << " != " << v2.size()
<< std::endl;
}
ASSERT_DEBUG(v1.size() == v2.size());
YGM_ASSERT_DEBUG(v1.size() == v2.size());

dist_t d = 0;
for (std::size_t i = 0; i < v1.size(); ++i) {
Expand Down Expand Up @@ -65,7 +65,7 @@ pair_bag<index_t, point_t> read_points(
pt.push_back(feature);
}

ASSERT_RELEASE(pt.size() == 128);
YGM_ASSERT_RELEASE(pt.size() == 128);

to_return.async_insert(std::make_pair(id, pt));
});
Expand Down Expand Up @@ -172,7 +172,7 @@ read_ground_truth(const std::string& filename, ygm::comm& world) {
// Sanity check on ground-truth
nn_dist_truth.for_all([](const auto& id, const auto& nn_vec) {
for (const auto& nn_dist : nn_vec) {
ASSERT_RELEASE(nn_dist.second >= 0.0);
YGM_ASSERT_RELEASE(nn_dist.second >= 0.0);
}
});

Expand Down Expand Up @@ -426,10 +426,10 @@ int main(int argc, char** argv) {
world.barrier();

world.cout0("Distributing data to local HNSWs");
index_points.for_all(
[&dist_index, &world](const auto& id, const auto& point) {
dist_index.queue_data_point_insertion(id, point);
});
index_points.for_all([&dist_index, &world](const auto& id_point) {
const auto& [id, point] = id_point;
dist_index.queue_data_point_insertion(id, point);
});

world.barrier();

Expand Down Expand Up @@ -495,8 +495,8 @@ int main(int argc, char** argv) {

query_points.for_all([&dist_index, k, num_hops, num_initial_queries,
voronoi_rank, store_results_lambda](
const auto& query_index,
const auto& query_point) {
const auto& query_index_point) {
const auto& [query_index, query_point] = query_index_point;
dist_index.query(query_point, k, num_hops, num_initial_queries,
voronoi_rank, store_results_lambda, query_index);
});
Expand Down
3 changes: 2 additions & 1 deletion examples/dhnsw_verbose_levenshtein.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ int main(int argc, char** argv) {
fuzzy_partitioner.print_tree();

world.cout0("Distributing data to local HNSWs");
bag_data.for_all([&dist_index, &world](const auto& id, const auto& line) {
bag_data.for_all([&dist_index, &world](const auto& id_line) {
const auto& [id, line] = id_line;
dist_index.queue_data_point_insertion(id, line);
});

Expand Down
10 changes: 5 additions & 5 deletions include/saltatlas/dhnsw/detail/dist_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ class dhnsw_impl {
const point_type &v) {
index_vec_type point_partitions =
partitioner().find_point_partitions(v, m_max_voronoi_rank);
ASSERT_RELEASE(point_partitions[0] < m_num_cells);
YGM_ASSERT_RELEASE(point_partitions[0] < m_num_cells);
add_data_point_to_insertion_queue(index, v, point_partitions);
}

void add_data_point_to_insertion_queue(const index_type index,
const point_type &v,
const index_vec_type &closest_seeds) {
auto insertion_cell = closest_seeds[0];
ASSERT_RELEASE(insertion_cell < m_num_cells);
YGM_ASSERT_RELEASE(insertion_cell < m_num_cells);
m_comm->async(
cell_owner(insertion_cell),
[](auto mbox, ygm::ygm_ptr<dhnsw_impl> pthis, index_type index,
Expand All @@ -118,11 +118,11 @@ class dhnsw_impl {
}

void initialize_hnsw() {
ASSERT_RELEASE(m_constructed_index == false);
YGM_ASSERT_RELEASE(m_constructed_index == false);

// Initialize HNSW structures
for (int i = 0; i < num_local_cells(); ++i) {
ASSERT_RELEASE(m_cell_add_vec[i].size() > 0);
YGM_ASSERT_RELEASE(m_cell_add_vec[i].size() > 0);

hnswlib::HierarchicalNSW<dist_type> *hnsw =
new hnswlib::HierarchicalNSW<dist_type>(
Expand Down Expand Up @@ -187,7 +187,7 @@ class dhnsw_impl {
}

const point_type &get_point(index_type index) {
ASSERT_RELEASE(m_local_data.count(index) > 0);
YGM_ASSERT_RELEASE(m_local_data.count(index) > 0);
return m_local_data[index];
}

Expand Down
62 changes: 31 additions & 31 deletions include/saltatlas/dhnsw/detail/query_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ template <typename DistType, typename IndexType, typename Point,
template <typename, typename, typename> class Partitioner>
class query_engine_impl {
public:
using dist_type = DistType;
using index_type = IndexType;
using point_type = Point;
using dist_type = DistType;
using index_type = IndexType;
using point_type = Point;
using partitioner_type = Partitioner<dist_type, index_type, point_type>;
using query_engine_impl_type =
query_engine_impl<dist_type, index_type, point_type, Partitioner>;
using dhnsw_impl_type =
dhnsw_impl<dist_type, index_type, point_type, Partitioner>;
using dist_ngbr_mmap_type = std::multimap<dist_type, index_type>;
using dist_ngbr_owner_map_type = std::map<index_type, int>;
using dist_ngbr_features_map_type = std::map<index_type, point_type>;
using dist_ngbr_mmap_type = std::multimap<dist_type, index_type>;
using dist_ngbr_owner_map_type = std::map<index_type, int>;
using dist_ngbr_features_map_type = std::map<index_type, point_type>;
using dist_ngbr_features_mmap_type =
std::multimap<dist_type, std::pair<index_type, point_type>>;

Expand Down Expand Up @@ -87,8 +87,8 @@ class query_engine_impl {
};

const point_type get_query_point() const { return m_query_point; }
const int get_k() const { return m_k; }
const int get_max_hops() const { return m_k; }
const int get_k() const { return m_k; }
const int get_max_hops() const { return m_k; }
const int get_initial_num_queries() const { return m_initial_num_queries; }
const int get_voronoi_rank() const { return m_voronoi_rank; }
const int get_num_local_queries() const { return m_queries_spawned; }
Expand Down Expand Up @@ -157,7 +157,7 @@ class query_engine_impl {
private:
void update_nearest_neighbors(const dist_ngbr_mmap_type &returned_neighbors,
const int owner_rank) {
ASSERT_RELEASE(owner_rank < engine->m_comm->size());
YGM_ASSERT_RELEASE(owner_rank < engine->m_comm->size());
if (m_nearest_neighbors.size() > 0) {
merge_nearest_neighbors(returned_neighbors, owner_rank);
} else {
Expand Down Expand Up @@ -240,16 +240,16 @@ class query_engine_impl {
engine->m_query_id_recycler.return_id(m_id);
return;
} else {
auto get_neighbor_features_lambda =
[](auto engine, const query_locator locator,
const index_type ngbr_index) {
auto get_neighbor_features_lambda = [](auto engine,
const query_locator locator,
const index_type ngbr_index) {
auto neighbor_features_response_lambda =
[](auto engine, const query_id_type &id,
const index_type ngbr_index, const point_type &ngbr) {
auto &query_controller = engine->m_query_controllers[id];

query_controller.m_nearest_neighbor_features[ngbr_index] = ngbr;
ASSERT_RELEASE(
YGM_ASSERT_RELEASE(
query_controller.m_nearest_neighbor_features.size() <=
query_controller.m_k);

Expand Down Expand Up @@ -302,7 +302,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);

query_locator locator{engine->m_comm->rank(), m_id};
for (const auto &[idx, owner_rank] : m_nearest_neighbor_owners) {
ASSERT_RELEASE(owner_rank < engine->m_comm->size());
YGM_ASSERT_RELEASE(owner_rank < engine->m_comm->size());
engine->m_comm->async(owner_rank, get_neighbor_features_lambda,
engine->pthis, locator, idx);
}
Expand All @@ -314,7 +314,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
auto cell_query_lambda = [](auto engine, const point_type &q,
const index_type s_cell,
const dist_type max_dist, const int s_k,
const int s_voronoi_rank,
const int s_voronoi_rank,
const query_locator locator) {
int local_cell = engine->local_cell_index(s_cell);

Expand Down Expand Up @@ -404,13 +404,13 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
// m_queries_returned instead, but might check
// between round finishing and next round starting
point_type m_query_point;
int m_k;
int m_max_hops;
int m_initial_num_queries;
int m_voronoi_rank;
int m_queries_spawned;
int m_queries_returned;
int m_current_hops;
int m_k;
int m_max_hops;
int m_initial_num_queries;
int m_voronoi_rank;
int m_queries_spawned;
int m_queries_returned;
int m_current_hops;
std::set<index_type> m_queried_cells;
std::set<index_type> m_next_cells;
dist_ngbr_mmap_type m_nearest_neighbors;
Expand All @@ -436,7 +436,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
// True when at least one id is currently held in m_available_ids.
bool has_id_available() { return !m_available_ids.empty(); }

T get_id() {
ASSERT_RELEASE(m_available_ids.size() > 0);
YGM_ASSERT_RELEASE(m_available_ids.size() > 0);

T id = m_available_ids.back();
m_available_ids.pop_back();
Expand Down Expand Up @@ -589,7 +589,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);

void (*fun_ptr)(const point_type &, const dist_ngbr_mmap_type &,
const query_controller &, cereal::YGMInputArchive &) =
[](const point_type &query_pt,
[](const point_type &query_pt,
const dist_ngbr_mmap_type &nearest_neighbors,
const query_controller &controller, cereal::YGMInputArchive &bia) {
std::tuple<PackArgs...> ta;
Expand Down Expand Up @@ -637,10 +637,10 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
return to_return;
}

void deserialize_lambda(cereal::YGMInputArchive &iarchive,
const point_type &query_pt,
const dist_ngbr_mmap_type &nearest_neighbors,
const query_controller &controller) {
void deserialize_lambda(cereal::YGMInputArchive &iarchive,
const point_type &query_pt,
const dist_ngbr_mmap_type &nearest_neighbors,
const query_controller &controller) {
int64_t iptr;
iarchive(iptr);
iptr += (int64_t)&reference;
Expand All @@ -653,7 +653,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
void deserialize_lambda_with_features(
cereal::YGMInputArchive &iarchive, const point_type &query_pt,
const dist_ngbr_features_mmap_type &nearest_neighbors,
const query_controller &controller) {
const query_controller &controller) {
int64_t iptr;
iarchive(iptr);
iptr += (int64_t)&reference;
Expand All @@ -664,9 +664,9 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
}

std::vector<query_controller> m_query_controllers;
id_recycler<query_id_type> m_query_id_recycler;
id_recycler<query_id_type> m_query_id_recycler;

ygm::comm *m_comm;
ygm::comm *m_comm;
ygm::ygm_ptr<dhnsw_impl_type> m_dist_index_impl_ptr;
ygm::ygm_ptr<query_engine_impl_type> pthis;
};
Expand Down
2 changes: 1 addition & 1 deletion include/saltatlas/dnnd/data_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ void read_points_with_id_helper(
std::cerr << "Duplicate ID " << id << std::endl;
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
}
ref_point_store[id]= sent_point;
ref_point_store[id] = sent_point;
};
comm.async(point_partitioner(id), receiver, id, point);
}
Expand Down
4 changes: 2 additions & 2 deletions include/saltatlas/dnnd/detail/query_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class dknn_batch_query_kernel {
for (std::size_t i = 0; i < queries.size(); ++i) {
m_comm.async(
r,
[](const ygm::ygm_ptr<self_type>& dst_self, const id_type id,
[](ygm::ygm_ptr<self_type> dst_self, const id_type id,
const point_type& q) {
assert(!dst_self->m_query_store.contains(id));
dst_self->m_query_store[id] = q;
Expand Down Expand Up @@ -320,4 +320,4 @@ class dknn_batch_query_kernel {
std::unordered_map<std::size_t, std::unordered_set<id_type>> m_visited;
};

} // namespace saltatlas::dndetail
} // namespace saltatlas::dndetail
Loading

0 comments on commit 37c8623

Please sign in to comment.