From 9006b32350312924f78447bd8eb14aea9a0a92cc Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Thu, 29 Jul 2021 14:19:27 +0300 Subject: [PATCH] Unmark deleted --- .gitignore | 3 + hnswlib/hnswalg.h | 80 +++++++++++-------- python_bindings/bindings.cpp | 55 ++++++------- python_bindings/tests/bindings_test_labels.py | 19 +++-- 4 files changed, 88 insertions(+), 69 deletions(-) diff --git a/.gitignore b/.gitignore index d2cde965..dab30385 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,6 @@ python_bindings/tests/__pycache__/ *.pyd hnswlib.cpython*.so var/ +.idea/ +.vscode/ + diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index f23c17d9..f2a8b9dc 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -18,7 +18,6 @@ namespace hnswlib { public: static const tableint max_update_element_locks = 65536; HierarchicalNSW(SpaceInterface *s) { - } HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { @@ -29,7 +28,7 @@ namespace hnswlib { link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) { max_elements_ = max_elements; - has_deletions_=false; + num_deleted_ = 0; data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); @@ -56,8 +55,6 @@ namespace hnswlib { visited_list_pool_ = new VisitedListPool(1, max_elements); - - //initializations for special treatment of the first node enterpoint_node_ = -1; maxlevel_ = -1; @@ -92,6 +89,7 @@ namespace hnswlib { size_t cur_element_count; size_t size_data_per_element_; size_t size_links_per_element_; + size_t num_deleted_; size_t M_; size_t maxM_; @@ -112,20 +110,15 @@ namespace hnswlib { std::vector link_list_update_locks_; tableint enterpoint_node_; - size_t size_links_level0_; size_t offsetData_, offsetLevel0_; - char *data_level0_memory_; char **linkLists_; std::vector element_levels_; size_t data_size_; - bool has_deletions_; - - size_t label_offset_; DISTFUNC fstdistfunc_; void *dist_func_param_; @@ -547,7 +540,7 @@ namespace hnswlib { } } - if (has_deletions_) { + if (num_deleted_) { std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, ef_); top_candidates.swap(top_candidates1); @@ -623,8 +616,6 @@ namespace hnswlib { } void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { - - std::ifstream input(location, std::ios::binary); if (!input.is_open()) @@ -639,7 +630,7 @@ namespace hnswlib { readBinaryPOD(input, max_elements_); readBinaryPOD(input, cur_element_count); - size_t max_elements=max_elements_i; + size_t max_elements = max_elements_i; if(max_elements < cur_element_count) max_elements = max_elements_; max_elements_ = max_elements; @@ -688,26 +679,19 @@ namespace hnswlib { input.seekg(pos,input.beg); - data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); if (data_level0_memory_ == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); input.read(data_level0_memory_, cur_element_count * size_data_per_element_); - - - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); std::vector(max_elements).swap(link_list_locks_); std::vector(max_update_element_locks).swap(link_list_update_locks_); - visited_list_pool_ = new VisitedListPool(1, max_elements); - linkLists_ = (char **) malloc(sizeof(void *) * max_elements); if (linkLists_ == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); @@ -731,11 +715,9 @@ namespace hnswlib { } } - has_deletions_=false; - for (size_t i = 0; i < cur_element_count; i++) { if(isMarkedDeleted(i)) - has_deletions_=true; + num_deleted_ += 1; } input.close(); @@ -765,19 +747,19 @@ namespace hnswlib { } static const unsigned char DELETE_MARK = 0x01; -// static const unsigned char REUSE_MARK = 0x10; + // static const unsigned char REUSE_MARK = 0x10; /** * Marks an element with the given label deleted, does NOT really change the current graph. * @param label */ void markDelete(labeltype label) { - has_deletions_=true; auto search = label_lookup_.find(label); if (search == label_lookup_.end()) { throw std::runtime_error("Label not found"); } - markDeletedInternal(search->second); + tableint internalId = search->second; + markDeletedInternal(internalId); } /** @@ -786,8 +768,31 @@ namespace hnswlib { * @param internalId */ void markDeletedInternal(tableint internalId) { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur |= DELETE_MARK; + assert(internalId < cur_element_count); + if (!isMarkedDeleted(internalId)) + { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + num_deleted_ += 1; + } + else + { + throw std::runtime_error("The requested to delete element is already deleted"); + } + } + + /** + * Remove the deleted mark of the node, does NOT really change the current graph. + * @param label + */ + void unmarkDelete(labeltype label) + { + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + unmarkDeletedInternal(internalId); } /** @@ -795,8 +800,17 @@ namespace hnswlib { * @param internalId */ void unmarkDeletedInternal(tableint internalId) { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur &= ~DELETE_MARK; + assert(internalId < cur_element_count); + if (isMarkedDeleted(internalId)) + { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur &= ~DELETE_MARK; + num_deleted_ -= 1; + } + else + { + throw std::runtime_error("The requested to undelete element is not deleted"); + } } /** @@ -857,8 +871,8 @@ namespace hnswlib { } for (auto&& neigh : sNeigh) { -// if (neigh == internalId) -// continue; + // if (neigh == internalId) + // continue; std::priority_queue, std::vector>, CompareByFirst> candidates; size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 @@ -1133,7 +1147,7 @@ namespace hnswlib { } std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (has_deletions_) { + if (num_deleted_) { top_candidates=searchBaseLayerST( currObj, query_data, std::max(ef_, k)); } diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 285b5185..48fdf475 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -70,16 +70,14 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn std::rethrow_exception(lastException); } } - - } - inline void assert_true(bool expr, const std::string & msg) { - if (expr == false) - throw std::runtime_error("Unpickle Error: "+msg); - return; - } +inline void assert_true(bool expr, const std::string & msg) { + if (expr == false) + throw std::runtime_error("Unpickle Error: "+msg); + return; +} template @@ -141,14 +139,12 @@ class Index { seed=random_seed; } - void set_ef(size_t ef) { default_ef=ef; if (appr_alg) appr_alg->ef_ = ef; } - void set_num_threads(int num_threads) { this->num_threads_default = num_threads; } @@ -207,14 +203,14 @@ class Index { if (!ids_.is_none()) { py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); auto ids_numpy = items.request(); - if(ids_numpy.ndim==1 && ids_numpy.shape[0]==rows) { + if(ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) { std::vector ids1(ids_numpy.shape[0]); for (size_t i = 0; i < ids1.size(); i++) { ids1[i] = items.data()[i]; } ids.swap(ids1); } - else if(ids_numpy.ndim==0 && rows==1) { + else if(ids_numpy.ndim == 0 && rows == 1) { ids.push_back(*items.data()); } else @@ -227,7 +223,7 @@ class Index { int start = 0; if (!ep_added) { size_t id = ids.size() ? ids.at(0) : (cur_l); - float *vector_data=(float *) items.data(0); + float *vector_data = (float *) items.data(0); std::vector norm_array(dim); if(normalize){ normalize_vector(vector_data, norm_array.data()); @@ -279,7 +275,6 @@ class Index { } std::vector getIdsList() { - std::vector ids; for(auto kv : appr_alg->label_lookup_) { @@ -290,9 +285,6 @@ class Index { py::dict getAnnData() const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */ - - - std::unique_lock templock(appr_alg->global); unsigned int level0_npy_size = appr_alg->cur_element_count * appr_alg->size_data_per_element_; @@ -369,7 +361,7 @@ class Index { "mult"_a=appr_alg->mult_, "ef_construction"_a=appr_alg->ef_construction_, "ef"_a=appr_alg->ef_, - "has_deletions"_a=appr_alg->has_deletions_, + "has_deletions"_a=(bool)appr_alg->num_deleted_, "size_links_per_element"_a=appr_alg->size_links_per_element_, "label_lookup_external"_a=py::array_t( @@ -402,10 +394,7 @@ class Index { {sizeof(char)}, // C-style contiguous strides for double link_list_npy, // the data pointer free_when_done_ll) - ); - - } @@ -431,7 +420,6 @@ class Index { static Index * createFromParams(const py::dict d) { - // check serialization version assert_true(((int)py::int_(Index::ser_version)) >= d["ser_version"].cast(), "Invalid serialization version!"); @@ -466,8 +454,6 @@ class Index { } void setAnnData(const py::dict d) { /* WARNING: Index::setAnnData is not thread-safe with Index::addItems */ - - std::unique_lock templock(appr_alg->global); assert_true(appr_alg->offsetLevel0_ == d["offset_level0"].cast(), "Invalid value of offsetLevel0_ "); @@ -489,7 +475,6 @@ class Index { assert_true(appr_alg->ef_construction_ == d["ef_construction"].cast(), "Invalid value of ef_construction_ "); appr_alg->ef_ = d["ef"].cast(); - appr_alg->has_deletions_=d["has_deletions"].cast(); assert_true(appr_alg->size_links_per_element_ == d["size_links_per_element"].cast(), "Invalid value of size_links_per_element_ "); @@ -535,10 +520,20 @@ class Index { } } + + // set num_deleted + appr_alg->num_deleted_ = 0; + bool has_deletions = d["has_deletions"].cast(); + if (has_deletions) + { + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + if(appr_alg->isMarkedDeleted(i)) + appr_alg->num_deleted_ += 1; + } + } } py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) { - py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); hnswlib::labeltype *data_numpy_l; @@ -561,7 +556,6 @@ class Index { features = buffer.shape[0]; } - // avoid using threads when the number of searches is small: if(rows<=num_threads*4){ @@ -609,7 +603,6 @@ class Index { } ); } - } py::capsule free_when_done_l(data_numpy_l, [](void *f) { delete[] f; @@ -618,7 +611,6 @@ class Index { delete[] f; }); - return py::make_tuple( py::array_t( {rows, k}, // shape @@ -638,6 +630,10 @@ class Index { appr_alg->markDelete(label); } + void unmarkDeleted(size_t label) { + appr_alg->unmarkDelete(label); + } + void resizeIndex(size_t new_size) { appr_alg->resizeIndex(new_size); } @@ -649,11 +645,9 @@ class Index { size_t getCurrentCount() const { return appr_alg->cur_element_count; } - }; - PYBIND11_PLUGIN(hnswlib) { py::module m("hnswlib"); @@ -672,6 +666,7 @@ PYBIND11_PLUGIN(hnswlib) { .def("save_index", &Index::saveIndex, py::arg("path_to_index")) .def("load_index", &Index::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0) .def("mark_deleted", &Index::markDeleted, py::arg("label")) + .def("unmark_deleted", &Index::unmarkDeleted, py::arg("label")) .def("resize_index", &Index::resizeIndex, py::arg("new_size")) .def("get_max_elements", &Index::getMaxElements) .def("get_current_count", &Index::getCurrentCount) diff --git a/python_bindings/tests/bindings_test_labels.py b/python_bindings/tests/bindings_test_labels.py index 668d7694..87259f1f 100644 --- a/python_bindings/tests/bindings_test_labels.py +++ b/python_bindings/tests/bindings_test_labels.py @@ -94,23 +94,23 @@ def testRandomSelf(self): self.assertEqual(np.sum(~np.asarray(sorted_labels) == np.asarray(range(num_elements))), 0) # Delete data1 - labels1, _ = p.knn_query(data1, k=1) + labels1_deleted, _ = p.knn_query(data1, k=1) - for l in labels1: + for l in labels1_deleted: p.mark_deleted(l[0]) labels2, _ = p.knn_query(data2, k=1) - items=p.get_items(labels2) + items = p.get_items(labels2) diff_with_gt_labels = np.mean(np.abs(data2-items)) self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) # console labels1_after, _ = p.knn_query(data1, k=1) for la in labels1_after: - for lb in labels1: + for lb in labels1_deleted: if la[0] == lb[0]: self.assertTrue(False) print("All the data in data1 are removed") - # checking saving/loading index with elements marked as deleted + # Checking saving/loading index with elements marked as deleted del_index_path = "with_deleted.bin" p.save_index(del_index_path) p = hnswlib.Index(space='l2', dim=dim) @@ -119,9 +119,16 @@ def testRandomSelf(self): labels1_after, _ = p.knn_query(data1, k=1) for la in labels1_after: - for lb in labels1: + for lb in labels1_deleted: if la[0] == lb[0]: self.assertTrue(False) + # Unmark deleted data + for l in labels1_deleted: + p.unmark_deleted(l[0]) + labels_restored, _ = p.knn_query(data1, k=1) + self.assertAlmostEqual(np.mean(labels_restored.reshape(-1) == np.arange(len(data1))), 1.0, 3) + print("All the data in data1 are restored") + os.remove(index_path) os.remove(del_index_path)