Skip to content

Commit

Permalink
Fix update document with knn_vector size not matching issue (#208)
Browse files Browse the repository at this point in the history
* Fix update document with knn_vector size not matching issue

Signed-off-by: zane-neo <[email protected]>

* Add Apache Commons Lang3 back to fix PR gradle build failure

Signed-off-by: zane-neo <[email protected]>

* Fix PR check failure

Signed-off-by: zane-neo <[email protected]>

* Fix PR check failure

Signed-off-by: zane-neo <[email protected]>

* Fix PR check jar hell failure

Signed-off-by: zane-neo <[email protected]>

* Add Apache Commons Lang back to fix the PR check

Signed-off-by: zane-neo <[email protected]>

* Fix register model group failure in IT

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Rename the method appendVectorFieldsToDocument to setVectorFieldsToDocument

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jul 11, 2023
1 parent 7ee1eb3 commit 9599371
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Ex
handler.accept(ingestDocument, null);
} else {
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> {
appendVectorFieldsToDocument(ingestDocument, knnMap, vectors);
setVectorFieldsToDocument(ingestDocument, knnMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
}
Expand All @@ -115,11 +115,11 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Ex

}

// Diff rendering: where a line was changed, the pre-change version is shown first and the
// post-change version directly below it (the commit renames appendVectorFieldsToDocument to
// setVectorFieldsToDocument and swaps appendFieldValue for setFieldValue).
/**
 * Stores the vectors returned by inference on the ingest document.
 *
 * Builds a field-to-value map with buildTextEmbeddingResult from knnMap, vectors and the
 * document's source-and-metadata map, then applies every entry to the document.
 *
 * @param ingestDocument document that receives the vector fields
 * @param knnMap         map handed to buildTextEmbeddingResult together with the vectors
 * @param vectors        inference output; must not be null
 * @throws NullPointerException if vectors is null
 */
void appendVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> knnMap, List<List<Float>> vectors) {
void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> knnMap, List<List<Float>> vectors) {
// Fail fast with a descriptive message when inference produced no result.
Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
log.debug("Text embedding result fetched, starting build vector output!");
Map<String, Object> textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata());
// Pre-change line: each entry was passed to appendFieldValue.
textEmbeddingResult.forEach(ingestDocument::appendFieldValue);
// Post-change line: each entry is passed to setFieldValue instead.
// NOTE(review): presumably set replaces an existing value where append accumulated onto it,
// which would explain the vector size mismatch on document update - confirm against IngestDocument.
textEmbeddingResult.forEach(ingestDocument::setFieldValue);
}

@SuppressWarnings({ "unchecked" })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Predicate;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -501,6 +502,7 @@ private String registerModelGroup() throws IOException, URISyntaxException {
String modelGroupRegisterRequestBody = Files.readString(
Path.of(classLoader.getResource("processor/CreateModelGroupRequestBody.json").toURI())
);
modelGroupRegisterRequestBody = modelGroupRegisterRequestBody.replace("<MODEL_GROUP_NAME>", UUID.randomUUID().toString());
Response modelGroupResponse = makeRequest(
client(),
"POST",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.NodeNotConnectedException;

import com.google.common.collect.ImmutableMap;

public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

@Mock
Expand Down Expand Up @@ -168,7 +170,9 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) {
output,
new long[] { 1, 2 },
MLResultDataType.FLOAT64,
ByteBuffer.wrap(new byte[12])
ByteBuffer.wrap(new byte[12]),
"mockResult",
ImmutableMap.of("mockKey", "mockValue")
);
mlModelTensorList.add(tensor);
final ModelTensors modelTensors = new ModelTensors(mlModelTensorList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@

import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.isA;
import static org.mockito.Mockito.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -350,7 +357,7 @@ public void testProcessResponse_successful() throws Exception {
Map<String, Object> knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument);

List<List<Float>> modelTensorList = createMockVectorResult();
processor.appendVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList);
processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList);
assertEquals(12, ingestDocument.getSourceAndMetadata().size());
}

Expand Down Expand Up @@ -398,6 +405,20 @@ public void testBuildVectorOutput_withNestedMap_successful() {
assertNotNull(actionGamesKnn);
}

/**
 * Applies the mock vector result to the same ingest document twice and checks that the
 * document still has 12 top-level entries and that the list stored under "oriKey6_knn"
 * has exactly two elements after the second application.
 */
public void test_updateDocument_appendVectorFieldsToDocument_successful() {
    final Map<String, Object> processorConfig = createPlainStringConfiguration();
    final IngestDocument document = createPlainIngestDocument();
    final TextEmbeddingProcessor embeddingProcessor = createInstanceWithNestedMapConfiguration(processorConfig);
    final Map<String, Object> knnFields = embeddingProcessor.buildMapWithKnnKeyAndOriginalValue(document);

    // First application: simulates the initial indexing of the document.
    embeddingProcessor.setVectorFieldsToDocument(document, knnFields, createMockVectorResult());

    // Second application with a fresh result: simulates updating the same document.
    embeddingProcessor.setVectorFieldsToDocument(document, knnFields, createMockVectorResult());

    assertEquals(12, document.getSourceAndMetadata().size());

    final Object storedVector = document.getSourceAndMetadata().get("oriKey6_knn");
    assertEquals(2, ((List<?>) storedVector).size());
}

private List<List<Float>> createMockVectorResult() {
List<List<Float>> modelTensorList = new ArrayList<>();
List<Float> number1 = ImmutableList.of(1.234f, 2.354f);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"name": "test_model_group_public",
"name": "<MODEL_GROUP_NAME>",
"description": "This is a public model group"
}

0 comments on commit 9599371

Please sign in to comment.