Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Toy: Limited Training Size #274

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 135 additions & 60 deletions section_faiss_vector_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,13 @@ func calculateNprobe(nlist int, indexOptimizedFor string) int32 {
// perhaps, parallelized merging can help speed things up over here.
func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
vecIndexes []*vecIndexInfo, w *CountHashWriter, closeCh chan struct{}) error {

// safe to assume that all the indexes are of the same config values, given
// that they are extracted from the field mapping info.
var dims, metric int
var dims, metric, nvecs, finalVecIDCap, indexDataCap, reconsCap int
var indexOptimizedFor string

var validMerge bool
var finalVecIDCap, indexDataCap, reconsCap int

for segI, segBase := range sbs {
// Considering merge operations on vector indexes are expensive, it is
// worth including an early exit if the merge is aborted, saving us
Expand All @@ -308,6 +307,8 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
}
if len(vecIndexes[segI].vecIds) > 0 {
indexReconsLen := len(vecIndexes[segI].vecIds) * index.D()
dims = index.D()
nvecs += len(vecIndexes[segI].vecIds)
if indexReconsLen > reconsCap {
reconsCap = indexReconsLen
}
Expand All @@ -328,47 +329,6 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
return nil
}

finalVecIDs := make([]int64, 0, finalVecIDCap)
// merging of indexes with reconstruction method.
// the indexes[i].vecIds has only the valid vecs of this vector
// index present in it, so we'd be reconstructing only those.
indexData := make([]float32, 0, indexDataCap)
// reusable buffer for reconstruction
recons := make([]float32, 0, reconsCap)
var err error
for i := 0; i < len(vecIndexes); i++ {
if isClosed(closeCh) {
freeReconstructedIndexes(vecIndexes)
return seg.ErrClosed
}

// reconstruct the vectors only if present, it could be that
// some of the indexes had all of their vectors updated/deleted.
if len(vecIndexes[i].vecIds) > 0 {
neededReconsLen := len(vecIndexes[i].vecIds) * vecIndexes[i].index.D()
recons = recons[:neededReconsLen]
// todo: parallelize reconstruction
recons, err = vecIndexes[i].index.ReconstructBatch(vecIndexes[i].vecIds, recons)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}
indexData = append(indexData, recons...)
// Adding vector IDs in the same order as the vectors
finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds...)
}
}

if len(indexData) == 0 {
// no valid vectors for this index, so we don't even have to
// record it in the section
freeReconstructedIndexes(vecIndexes)
return nil
}
recons = nil

nvecs := len(finalVecIDs)

// index type to be created after merge based on the number of vectors
// in indexData added into the index.
nlist := determineCentroids(nvecs)
Expand All @@ -378,8 +338,6 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
// to do the same is not needed because the following operations don't need
// the reconstructed ones anymore and doing so will hold up memory which can
// be detrimental while creating indexes during introduction.
freeReconstructedIndexes(vecIndexes)
vecIndexes = nil

faissIndex, err := faiss.IndexFactory(dims, indexDescription, metric)
if err != nil {
Expand All @@ -388,34 +346,151 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
defer faissIndex.Close()

if indexClass == IndexTypeIVF {
// the direct map maintained in the IVF index is essential for the
// reconstruction of vectors based on vector IDs in the future merges.
// the AddWithIDs API also needs a direct map to be set before using.
err = faissIndex.SetDirectMap(2)
if err != nil {
return err
}

nprobe := calculateNprobe(nlist, indexOptimizedFor)
faissIndex.SetNProbe(nprobe)
}

if nvecs < 100000 {
finalVecIDs := make([]int64, 0, finalVecIDCap)
// merging of indexes with reconstruction method.
// the indexes[i].vecIds has only the valid vecs of this vector
// index present in it, so we'd be reconstructing only those.
indexData := make([]float32, 0, indexDataCap)
// reusable buffer for reconstruction
recons := make([]float32, 0, reconsCap)
var err error
for i := 0; i < len(vecIndexes); i++ {
if isClosed(closeCh) {
freeReconstructedIndexes(vecIndexes)
return seg.ErrClosed
}

// reconstruct the vectors only if present, it could be that
// some of the indexes had all of their vectors updated/deleted.
if len(vecIndexes[i].vecIds) > 0 {
neededReconsLen := len(vecIndexes[i].vecIds) * vecIndexes[i].index.D()
recons = recons[:neededReconsLen]
// todo: parallelize reconstruction
recons, err = vecIndexes[i].index.ReconstructBatch(vecIndexes[i].vecIds, recons)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}
indexData = append(indexData, recons...)
// Adding vector IDs in the same order as the vectors
finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds...)
}
}

if len(indexData) == 0 {
// no valid vectors for this index, so we don't even have to
// record it in the section
freeReconstructedIndexes(vecIndexes)
return nil
}

// train the vector index, essentially performs k-means clustering to partition
// the data space of indexData such that during the search time, we probe
// only a subset of vectors -> non-exhaustive search. could be a time
// consuming step when the indexData is large.
err = faissIndex.Train(indexData)
recons = nil
freeReconstructedIndexes(vecIndexes)
vecIndexes = nil

if indexClass == IndexTypeIVF {
err = faissIndex.Train(indexData)
if err != nil {
return err
}
}
err = faissIndex.AddWithIDs(indexData, finalVecIDs)
if err != nil {
return err
}
}

err = faissIndex.AddWithIDs(indexData, finalVecIDs)
if err != nil {
return err
indexData = nil
finalVecIDs = nil

} else {
recons := make([]float32, 0, reconsCap)
curVecs := 0
vecLimit := 100000
if vecLimit < nlist*40 {
vecLimit = nlist * 40
}
finalVecIDs := make([]int64, 0, vecLimit)
indexData := make([]float32, 0, vecLimit*dims)
trained := false

var err error

for i := 0; i < len(vecIndexes); i++ {
if isClosed(closeCh) {
freeReconstructedIndexes(vecIndexes)
return seg.ErrClosed
}

if len(vecIndexes[i].vecIds) > 0 {
neededReconsLen := len(vecIndexes[i].vecIds) * dims
recons = recons[:neededReconsLen]
// todo: parallelize reconstruction
recons, err = vecIndexes[i].index.ReconstructBatch(vecIndexes[i].vecIds, recons)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}
vecLen := len(vecIndexes[i].vecIds)
shift := 0
for curVecs+vecLen > vecLimit {
indexData = append(indexData, recons[shift*dims:(shift+vecLimit-curVecs)*dims]...)
finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds[shift:(shift+vecLimit-curVecs)]...)
if !trained {
err = faissIndex.Train(indexData)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}
trained = true
}
err = faissIndex.AddWithIDs(indexData, finalVecIDs)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}
indexData = indexData[:0]
finalVecIDs = finalVecIDs[:0]
shift += vecLimit - curVecs
vecLen -= vecLimit - curVecs
curVecs = 0
}
if vecLen != 0 {
indexData = append(indexData, recons[shift*dims:(shift+vecLen)*dims]...)
finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds[shift:shift+vecLen]...)
curVecs = len(finalVecIDs)
}
}
}

recons = nil
freeReconstructedIndexes(vecIndexes)
vecIndexes = nil
if curVecs > 0 {
if !trained {
err = faissIndex.Train(indexData)
if err != nil {
return err
}
}
err = faissIndex.AddWithIDs(indexData, finalVecIDs)
if err != nil {
return err
}
}
indexData = nil
finalVecIDs = nil
}

indexData = nil
finalVecIDs = nil
var mergedIndexBytes []byte
mergedIndexBytes, err = faiss.WriteIndexIntoBuffer(faissIndex)
if err != nil {
Expand Down
Loading