Skip to content

Commit

Permalink
Merge pull request #334 from dyashuni/unmark_deleted
Browse files Browse the repository at this point in the history
Unmark deleted
  • Loading branch information
yurymalkov authored Nov 23, 2021
2 parents 36d00bf + 9006b32 commit 47bb1a1
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 69 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ python_bindings/tests/__pycache__/
*.pyd
hnswlib.cpython*.so
var/
.idea/
.vscode/

80 changes: 47 additions & 33 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ namespace hnswlib {
public:
static const tableint max_update_element_locks = 65536;
// Constructor taking only the space pointer; its body is empty and `s` is not used here.
HierarchicalNSW(SpaceInterface<dist_t> *s) {

}

HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) {
Expand All @@ -29,7 +28,7 @@ namespace hnswlib {
link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
max_elements_ = max_elements;

has_deletions_=false;
num_deleted_ = 0;
data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
Expand All @@ -56,8 +55,6 @@ namespace hnswlib {

visited_list_pool_ = new VisitedListPool(1, max_elements);



//initializations for special treatment of the first node
enterpoint_node_ = -1;
maxlevel_ = -1;
Expand Down Expand Up @@ -92,6 +89,7 @@ namespace hnswlib {
size_t cur_element_count;
size_t size_data_per_element_;
size_t size_links_per_element_;
size_t num_deleted_;

size_t M_;
size_t maxM_;
Expand All @@ -112,20 +110,15 @@ namespace hnswlib {
std::vector<std::mutex> link_list_update_locks_;
tableint enterpoint_node_;


size_t size_links_level0_;
size_t offsetData_, offsetLevel0_;


char *data_level0_memory_;
char **linkLists_;
std::vector<int> element_levels_;

size_t data_size_;

bool has_deletions_;


size_t label_offset_;
DISTFUNC<dist_t> fstdistfunc_;
void *dist_func_param_;
Expand Down Expand Up @@ -547,7 +540,7 @@ namespace hnswlib {
}
}

if (has_deletions_) {
if (num_deleted_) {
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
ef_);
top_candidates.swap(top_candidates1);
Expand Down Expand Up @@ -623,8 +616,6 @@ namespace hnswlib {
}

void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {


std::ifstream input(location, std::ios::binary);

if (!input.is_open())
Expand All @@ -639,7 +630,7 @@ namespace hnswlib {
readBinaryPOD(input, max_elements_);
readBinaryPOD(input, cur_element_count);

size_t max_elements=max_elements_i;
size_t max_elements = max_elements_i;
if(max_elements < cur_element_count)
max_elements = max_elements_;
max_elements_ = max_elements;
Expand Down Expand Up @@ -688,26 +679,19 @@ namespace hnswlib {

input.seekg(pos,input.beg);


data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
if (data_level0_memory_ == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
input.read(data_level0_memory_, cur_element_count * size_data_per_element_);




size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);


size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
std::vector<std::mutex>(max_elements).swap(link_list_locks_);
std::vector<std::mutex>(max_update_element_locks).swap(link_list_update_locks_);


visited_list_pool_ = new VisitedListPool(1, max_elements);


linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
if (linkLists_ == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
Expand All @@ -731,11 +715,9 @@ namespace hnswlib {
}
}

has_deletions_=false;

for (size_t i = 0; i < cur_element_count; i++) {
if(isMarkedDeleted(i))
has_deletions_=true;
num_deleted_ += 1;
}

input.close();
Expand Down Expand Up @@ -765,19 +747,19 @@ namespace hnswlib {
}

static const unsigned char DELETE_MARK = 0x01;
// static const unsigned char REUSE_MARK = 0x10;
// static const unsigned char REUSE_MARK = 0x10;
/**
 * Marks an element with the given label deleted, does NOT really change the current graph.
 * @param label external label of the element to mark
 * @throws std::runtime_error if the label is not present in label_lookup_;
 *         any exception raised by markDeletedInternal propagates to the caller
 */
void markDelete(labeltype label)
{
    auto search = label_lookup_.find(label);
    if (search == label_lookup_.end()) {
        throw std::runtime_error("Label not found");
    }
    // Delegate exactly once: the internal routine owns the deleted-element bookkeeping,
    // so no state is touched here when the lookup fails.
    tableint internalId = search->second;
    markDeletedInternal(internalId);
}

/**
Expand All @@ -786,17 +768,49 @@ namespace hnswlib {
* @param internalId
*/
void markDeletedInternal(tableint internalId) {
    // Sets the DELETE_MARK bit for the element and bumps num_deleted_.
    // Throws std::runtime_error when the element already carries the mark.
    // The bounds assert and the already-deleted check must run BEFORE the bit is
    // written; otherwise the check would always see the freshly set mark.
    assert(internalId < cur_element_count);
    if (isMarkedDeleted(internalId)) {
        throw std::runtime_error("The requested to delete element is already deleted");
    }

    // The mark lives in the byte at offset 2 from the pointer returned by get_linklist0.
    unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
    *ll_cur |= DELETE_MARK;
    num_deleted_ += 1;
}

/**
 * Clears the deleted mark of the element that carries the given label.
 * The graph itself is NOT modified by this call.
 * @param label external label of the element to restore
 */
void unmarkDelete(labeltype label)
{
    // Resolve the external label to its internal id; unknown labels are an error.
    const auto found = label_lookup_.find(label);
    if (found == label_lookup_.end())
        throw std::runtime_error("Label not found");

    const tableint restoredId = found->second;
    unmarkDeletedInternal(restoredId);
}

/**
 * Remove the deleted mark of the node.
 * @param internalId internal id of the element; asserted to be below cur_element_count
 * @throws std::runtime_error if the element is not currently marked deleted
 */
void unmarkDeletedInternal(tableint internalId) {
    // The bounds assert and the is-deleted check must run BEFORE the bit is cleared;
    // otherwise the check would always see an unmarked element.
    assert(internalId < cur_element_count);
    if (!isMarkedDeleted(internalId)) {
        throw std::runtime_error("The requested to undelete element is not deleted");
    }

    // The mark lives in the byte at offset 2 from the pointer returned by get_linklist0.
    unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
    *ll_cur &= ~DELETE_MARK;
    num_deleted_ -= 1;
}

/**
Expand Down Expand Up @@ -857,8 +871,8 @@ namespace hnswlib {
}

for (auto&& neigh : sNeigh) {
// if (neigh == internalId)
// continue;
// if (neigh == internalId)
// continue;

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
Expand Down Expand Up @@ -1133,7 +1147,7 @@ namespace hnswlib {
}

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
if (has_deletions_) {
if (num_deleted_) {
top_candidates=searchBaseLayerST<true,true>(
currObj, query_data, std::max(ef_, k));
}
Expand Down
Loading

0 comments on commit 47bb1a1

Please sign in to comment.