Skip to content

Commit

Permalink
Support script score when doc value is disabled and fix misusing DISI (
Browse files Browse the repository at this point in the history
…opensearch-project#1696)

* Revert "Revert 'Support script score when doc value is disabled' (opensearch-project#1662)"

This reverts commit bd2f403.

Signed-off-by: panguixin <[email protected]>

* fix misusing doc value

Signed-off-by: panguixin <[email protected]>

* add changelog

Signed-off-by: panguixin <[email protected]>

---------

Signed-off-by: panguixin <[email protected]>
  • Loading branch information
bugmakerrrrrr authored May 20, 2024
1 parent e584822 commit 4d59d4c
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 46 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696)
### Bug Fixes
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

package org.opensearch.knn.index;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
Expand Down Expand Up @@ -39,10 +40,29 @@ public long ramBytesUsed() {
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
BinaryDocValues values = DocValues.getBinary(reader, fieldName);
return new KNNVectorScriptDocValues(values, fieldName, vectorDataType);
FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName);
if (fieldInfo == null) {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}

DocIdSetIterator values;
if (fieldInfo.hasVectorValues()) {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
values = reader.getFloatVectorValues(fieldName);
break;
case BYTE:
values = reader.getByteVectorValues(fieldName);
break;
default:
throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding());
}
} else {
values = DocValues.getBinary(reader, fieldName);
}
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e);
}
}

Expand Down
117 changes: 107 additions & 10 deletions src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,40 @@
package org.opensearch.knn.index;

import java.io.IOException;
import java.util.Objects;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;

import java.io.IOException;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final BinaryDocValues binaryDocValues;
private final DocIdSetIterator vectorValues;
private final String fieldName;
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;
private int lastDocID = -1;

@Override
public void setNextDocId(int docId) throws IOException {
if (binaryDocValues.advanceExact(docId)) {
docExists = true;
return;
if (docId < lastDocID) {
throw new IllegalArgumentException("docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + docId);
}

lastDocID = docId;

int curDocID = vectorValues.docID();
if (lastDocID > curDocID) {
curDocID = vectorValues.advance(docId);
}
docExists = false;
docExists = lastDocID == curDocID;
}

public float[] getValue() {
Expand All @@ -44,12 +54,14 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
return vectorDataType.getVectorFromBytesRef(binaryDocValues.binaryValue());
return doGetValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

protected abstract float[] doGetValue() throws IOException;

@Override
public int size() {
return docExists ? 1 : 0;
Expand All @@ -59,4 +71,89 @@ public int size() {
public float[] get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

/**
* Creates a KNNVectorScriptDocValues object based on the provided parameters.
*
* @param values The DocIdSetIterator representing the vector values.
* @param fieldName The name of the field.
* @param vectorDataType The data type of the vector.
* @return A KNNVectorScriptDocValues object based on the type of the values.
* @throws IllegalArgumentException If the type of values is unsupported.
*/
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
Objects.requireNonNull(values, "values must not be null");
if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof FloatVectorValues) {
return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof BinaryDocValues) {
return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
} else {
throw new IllegalArgumentException("Unsupported values type: " + values.getClass());
}
}

private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
private final ByteVectorValues values;

KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
byte[] bytes = values.vectorValue();
float[] value = new float[bytes.length];
for (int i = 0; i < bytes.length; i++) {
value[i] = (float) bytes[i];
}
return value;
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
private final FloatVectorValues values;

KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
return values.vectorValue();
}
}

private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues {
private final BinaryDocValues values;

KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}
}

/**
* Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type.
*
* @param fieldName The name of the field.
* @param type The data type of the vector.
* @return An empty KNNVectorScriptDocValues object.
*/
public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) {
return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@

package org.opensearch.knn.index;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.knn.KNNTestCase;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.document.BinaryDocValuesField;
Expand Down Expand Up @@ -33,26 +41,39 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase {
public void setUp() throws Exception {
super.setUp();
directory = newDirectory();
createKNNVectorDocument(directory);
Class<? extends DocIdSetIterator> valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class);
createKNNVectorDocument(directory, valuesClass);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME),
MOCK_INDEX_FIELD_NAME,
VectorDataType.FLOAT
);
LeafReader leafReader = reader.getContext().leaves().get(0).reader();
DocIdSetIterator vectorValues;
if (BinaryDocValues.class.equals(valuesClass)) {
vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME);
} else if (ByteVectorValues.class.equals(valuesClass)) {
vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME);
} else {
vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME);
}

scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
}

private void createKNNVectorDocument(Directory directory) throws IOException {
private void createKNNVectorDocument(Directory directory, Class<? extends DocIdSetIterator> valuesClass) throws IOException {
IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random()));
IndexWriter writer = new IndexWriter(directory, conf);
Document knnDocument = new Document();
knnDocument.add(
new BinaryDocValuesField(
Field field;
if (BinaryDocValues.class.equals(valuesClass)) {
field = new BinaryDocValuesField(
MOCK_INDEX_FIELD_NAME,
new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue()
)
);
);
} else if (ByteVectorValues.class.equals(valuesClass)) {
field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA);
} else {
field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA);
}

knnDocument.add(field);
writer.addDocument(knnDocument);
writer.commit();
writer.close();
Expand Down Expand Up @@ -84,4 +105,18 @@ public void testSize() throws IOException {
public void testGet() throws IOException {
expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
}

public void testUnsupportedValues() throws IOException {
expectThrows(
IllegalArgumentException.class,
() -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT)
);
}

public void testEmptyValues() throws IOException {
KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
assertEquals(0, values.size());
scriptDocValues.setNextDocId(0);
assertEquals(0, values.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() {
createKNNFloatVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
return new KNNVectorScriptDocValues(
return KNNVectorScriptDocValues.create(
leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME),
VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME,
VectorDataType.FLOAT
Expand All @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() {
createKNNByteVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
return new KNNVectorScriptDocValues(
return KNNVectorScriptDocValues.create(
leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME),
VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME,
VectorDataType.BYTE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx
if (scriptDocValues == null) {
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = new KNNVectorScriptDocValues(
scriptDocValues = KNNVectorScriptDocValues.create(
leafReaderContext.reader().getBinaryDocValues(fieldName),
fieldName,
VectorDataType.FLOAT
Expand Down
Loading

0 comments on commit 4d59d4c

Please sign in to comment.