Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes uses of for_all to comply with YGM API #72

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/binary_nn_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ void build_index(

dist_index.comm().cout0("Distributing data across ranks");

auto add_point_lambda = [&dist_index, &bag_data](const auto &index,
const auto &point) {
auto add_point_lambda = [&dist_index, &bag_data](const auto &index_point) {
const auto &[index, point] = index_point;
dist_index.queue_data_point_insertion(index, point);
};

Expand Down
3 changes: 2 additions & 1 deletion examples/dhnsw_ascii_levenshtein.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ int main(int argc, char** argv) {
world.barrier();

world.cout0("Distributing data to local HNSWs");
bag_data.for_all([&dist_index, &world](const auto& id, const auto& line) {
bag_data.for_all([&dist_index, &world](const auto& id_line) {
const auto& [id, line] = id_line;
dist_index.queue_data_point_insertion(id, line);
});

Expand Down
18 changes: 9 additions & 9 deletions examples/dhnsw_big_ann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dist_t l2_sqr(const point_t& v1, const point_t& v2) {
std::cout << "Size mismatch: " << v1.size() << " != " << v2.size()
<< std::endl;
}
ASSERT_DEBUG(v1.size() == v2.size());
YGM_ASSERT_DEBUG(v1.size() == v2.size());

dist_t d = 0;
for (std::size_t i = 0; i < v1.size(); ++i) {
Expand Down Expand Up @@ -65,7 +65,7 @@ pair_bag<index_t, point_t> read_points(
pt.push_back(feature);
}

ASSERT_RELEASE(pt.size() == 128);
YGM_ASSERT_RELEASE(pt.size() == 128);

to_return.async_insert(std::make_pair(id, pt));
});
Expand Down Expand Up @@ -172,7 +172,7 @@ read_ground_truth(const std::string& filename, ygm::comm& world) {
// Sanity check on ground-truth
nn_dist_truth.for_all([](const auto& id, const auto& nn_vec) {
for (const auto& nn_dist : nn_vec) {
ASSERT_RELEASE(nn_dist.second >= 0.0);
YGM_ASSERT_RELEASE(nn_dist.second >= 0.0);
}
});

Expand Down Expand Up @@ -426,10 +426,10 @@ int main(int argc, char** argv) {
world.barrier();

world.cout0("Distributing data to local HNSWs");
index_points.for_all(
[&dist_index, &world](const auto& id, const auto& point) {
dist_index.queue_data_point_insertion(id, point);
});
index_points.for_all([&dist_index, &world](const auto& id_point) {
const auto& [id, point] = id_point;
dist_index.queue_data_point_insertion(id, point);
});

world.barrier();

Expand Down Expand Up @@ -495,8 +495,8 @@ int main(int argc, char** argv) {

query_points.for_all([&dist_index, k, num_hops, num_initial_queries,
voronoi_rank, store_results_lambda](
const auto& query_index,
const auto& query_point) {
const auto& query_index_point) {
const auto& [query_index, query_point] = query_index_point;
dist_index.query(query_point, k, num_hops, num_initial_queries,
voronoi_rank, store_results_lambda, query_index);
});
Expand Down
3 changes: 2 additions & 1 deletion examples/dhnsw_verbose_levenshtein.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ int main(int argc, char** argv) {
fuzzy_partitioner.print_tree();

world.cout0("Distributing data to local HNSWs");
bag_data.for_all([&dist_index, &world](const auto& id, const auto& line) {
bag_data.for_all([&dist_index, &world](const auto& id_line) {
const auto& [id, line] = id_line;
dist_index.queue_data_point_insertion(id, line);
});

Expand Down
10 changes: 5 additions & 5 deletions include/saltatlas/dhnsw/detail/dist_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ class dhnsw_impl {
const point_type &v) {
index_vec_type point_partitions =
partitioner().find_point_partitions(v, m_max_voronoi_rank);
ASSERT_RELEASE(point_partitions[0] < m_num_cells);
YGM_ASSERT_RELEASE(point_partitions[0] < m_num_cells);
add_data_point_to_insertion_queue(index, v, point_partitions);
}

void add_data_point_to_insertion_queue(const index_type index,
const point_type &v,
const index_vec_type &closest_seeds) {
auto insertion_cell = closest_seeds[0];
ASSERT_RELEASE(insertion_cell < m_num_cells);
YGM_ASSERT_RELEASE(insertion_cell < m_num_cells);
m_comm->async(
cell_owner(insertion_cell),
[](auto mbox, ygm::ygm_ptr<dhnsw_impl> pthis, index_type index,
Expand All @@ -118,11 +118,11 @@ class dhnsw_impl {
}

void initialize_hnsw() {
ASSERT_RELEASE(m_constructed_index == false);
YGM_ASSERT_RELEASE(m_constructed_index == false);

// Initialize HNSW structures
for (int i = 0; i < num_local_cells(); ++i) {
ASSERT_RELEASE(m_cell_add_vec[i].size() > 0);
YGM_ASSERT_RELEASE(m_cell_add_vec[i].size() > 0);

hnswlib::HierarchicalNSW<dist_type> *hnsw =
new hnswlib::HierarchicalNSW<dist_type>(
Expand Down Expand Up @@ -187,7 +187,7 @@ class dhnsw_impl {
}

const point_type &get_point(index_type index) {
ASSERT_RELEASE(m_local_data.count(index) > 0);
YGM_ASSERT_RELEASE(m_local_data.count(index) > 0);
return m_local_data[index];
}

Expand Down
62 changes: 31 additions & 31 deletions include/saltatlas/dhnsw/detail/query_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ template <typename DistType, typename IndexType, typename Point,
template <typename, typename, typename> class Partitioner>
class query_engine_impl {
public:
using dist_type = DistType;
using index_type = IndexType;
using point_type = Point;
using dist_type = DistType;
using index_type = IndexType;
using point_type = Point;
using partitioner_type = Partitioner<dist_type, index_type, point_type>;
using query_engine_impl_type =
query_engine_impl<dist_type, index_type, point_type, Partitioner>;
using dhnsw_impl_type =
dhnsw_impl<dist_type, index_type, point_type, Partitioner>;
using dist_ngbr_mmap_type = std::multimap<dist_type, index_type>;
using dist_ngbr_owner_map_type = std::map<index_type, int>;
using dist_ngbr_features_map_type = std::map<index_type, point_type>;
using dist_ngbr_mmap_type = std::multimap<dist_type, index_type>;
using dist_ngbr_owner_map_type = std::map<index_type, int>;
using dist_ngbr_features_map_type = std::map<index_type, point_type>;
using dist_ngbr_features_mmap_type =
std::multimap<dist_type, std::pair<index_type, point_type>>;

Expand Down Expand Up @@ -87,8 +87,8 @@ class query_engine_impl {
};

const point_type get_query_point() const { return m_query_point; }
const int get_k() const { return m_k; }
const int get_max_hops() const { return m_k; }
const int get_k() const { return m_k; }
const int get_max_hops() const { return m_k; }
const int get_initial_num_queries() const { return m_initial_num_queries; }
const int get_voronoi_rank() const { return m_voronoi_rank; }
const int get_num_local_queries() const { return m_queries_spawned; }
Expand Down Expand Up @@ -157,7 +157,7 @@ class query_engine_impl {
private:
void update_nearest_neighbors(const dist_ngbr_mmap_type &returned_neighbors,
const int owner_rank) {
ASSERT_RELEASE(owner_rank < engine->m_comm->size());
YGM_ASSERT_RELEASE(owner_rank < engine->m_comm->size());
if (m_nearest_neighbors.size() > 0) {
merge_nearest_neighbors(returned_neighbors, owner_rank);
} else {
Expand Down Expand Up @@ -240,16 +240,16 @@ class query_engine_impl {
engine->m_query_id_recycler.return_id(m_id);
return;
} else {
auto get_neighbor_features_lambda =
[](auto engine, const query_locator locator,
const index_type ngbr_index) {
auto get_neighbor_features_lambda = [](auto engine,
const query_locator locator,
const index_type ngbr_index) {
auto neighbor_features_response_lambda =
[](auto engine, const query_id_type &id,
const index_type ngbr_index, const point_type &ngbr) {
auto &query_controller = engine->m_query_controllers[id];

query_controller.m_nearest_neighbor_features[ngbr_index] = ngbr;
ASSERT_RELEASE(
YGM_ASSERT_RELEASE(
query_controller.m_nearest_neighbor_features.size() <=
query_controller.m_k);

Expand Down Expand Up @@ -302,7 +302,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);

query_locator locator{engine->m_comm->rank(), m_id};
for (const auto &[idx, owner_rank] : m_nearest_neighbor_owners) {
ASSERT_RELEASE(owner_rank < engine->m_comm->size());
YGM_ASSERT_RELEASE(owner_rank < engine->m_comm->size());
engine->m_comm->async(owner_rank, get_neighbor_features_lambda,
engine->pthis, locator, idx);
}
Expand All @@ -314,7 +314,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
auto cell_query_lambda = [](auto engine, const point_type &q,
const index_type s_cell,
const dist_type max_dist, const int s_k,
const int s_voronoi_rank,
const int s_voronoi_rank,
const query_locator locator) {
int local_cell = engine->local_cell_index(s_cell);

Expand Down Expand Up @@ -404,13 +404,13 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
// m_queries_returned instead, but might check
// between round finishing and next round starting
point_type m_query_point;
int m_k;
int m_max_hops;
int m_initial_num_queries;
int m_voronoi_rank;
int m_queries_spawned;
int m_queries_returned;
int m_current_hops;
int m_k;
int m_max_hops;
int m_initial_num_queries;
int m_voronoi_rank;
int m_queries_spawned;
int m_queries_returned;
int m_current_hops;
std::set<index_type> m_queried_cells;
std::set<index_type> m_next_cells;
dist_ngbr_mmap_type m_nearest_neighbors;
Expand All @@ -436,7 +436,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
bool has_id_available() { return m_available_ids.size() > 0; }

T get_id() {
ASSERT_RELEASE(m_available_ids.size() > 0);
YGM_ASSERT_RELEASE(m_available_ids.size() > 0);

T id = m_available_ids.back();
m_available_ids.pop_back();
Expand Down Expand Up @@ -589,7 +589,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);

void (*fun_ptr)(const point_type &, const dist_ngbr_mmap_type &,
const query_controller &, cereal::YGMInputArchive &) =
[](const point_type &query_pt,
[](const point_type &query_pt,
const dist_ngbr_mmap_type &nearest_neighbors,
const query_controller &controller, cereal::YGMInputArchive &bia) {
std::tuple<PackArgs...> ta;
Expand Down Expand Up @@ -637,10 +637,10 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
return to_return;
}

void deserialize_lambda(cereal::YGMInputArchive &iarchive,
const point_type &query_pt,
const dist_ngbr_mmap_type &nearest_neighbors,
const query_controller &controller) {
void deserialize_lambda(cereal::YGMInputArchive &iarchive,
const point_type &query_pt,
const dist_ngbr_mmap_type &nearest_neighbors,
const query_controller &controller) {
int64_t iptr;
iarchive(iptr);
iptr += (int64_t)&reference;
Expand All @@ -653,7 +653,7 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
void deserialize_lambda_with_features(
cereal::YGMInputArchive &iarchive, const point_type &query_pt,
const dist_ngbr_features_mmap_type &nearest_neighbors,
const query_controller &controller) {
const query_controller &controller) {
int64_t iptr;
iarchive(iptr);
iptr += (int64_t)&reference;
Expand All @@ -664,9 +664,9 @@ fun_ptr(query_controller.m_query_point, nn_mmap, *this, iarchive);
}

std::vector<query_controller> m_query_controllers;
id_recycler<query_id_type> m_query_id_recycler;
id_recycler<query_id_type> m_query_id_recycler;

ygm::comm *m_comm;
ygm::comm *m_comm;
ygm::ygm_ptr<dhnsw_impl_type> m_dist_index_impl_ptr;
ygm::ygm_ptr<query_engine_impl_type> pthis;
};
Expand Down
2 changes: 1 addition & 1 deletion include/saltatlas/dnnd/data_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ void read_points_with_id_helper(
std::cerr << "Duplicate ID " << id << std::endl;
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
}
ref_point_store[id]= sent_point;
ref_point_store[id] = sent_point;
};
comm.async(point_partitioner(id), receiver, id, point);
}
Expand Down
4 changes: 2 additions & 2 deletions include/saltatlas/dnnd/detail/query_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class dknn_batch_query_kernel {
for (std::size_t i = 0; i < queries.size(); ++i) {
m_comm.async(
r,
[](const ygm::ygm_ptr<self_type>& dst_self, const id_type id,
[](ygm::ygm_ptr<self_type> dst_self, const id_type id,
const point_type& q) {
assert(!dst_self->m_query_store.contains(id));
dst_self->m_query_store[id] = q;
Expand Down Expand Up @@ -320,4 +320,4 @@ class dknn_batch_query_kernel {
std::unordered_map<std::size_t, std::unordered_set<id_type>> m_visited;
};

} // namespace saltatlas::dndetail
} // namespace saltatlas::dndetail
Loading
Loading