Skip to content

Commit

Permalink
Add unit tests + small fixes
Browse files Browse the repository at this point in the history
Signed-off-by: krishy91 <[email protected]>
  • Loading branch information
krishy91 committed Jan 16, 2024
1 parent 861d3c0 commit 70dcf1b
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
*/
package org.opensearch.neuralsearch.processor;

import java.util.*;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -48,14 +52,14 @@ public abstract class InferenceProcessor extends AbstractProcessor {
private final Environment environment;

public InferenceProcessor(
String tag,
String description,
String type,
String listTypeNestedMapKey,
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
String tag,
String description,
String type,
String listTypeNestedMapKey,
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
) {
super(tag, description);
this.type = type;
Expand All @@ -71,21 +75,21 @@ public InferenceProcessor(

private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
if (fieldMap == null
|| fieldMap.size() == 0
|| fieldMap.entrySet()
|| fieldMap.size() == 0
|| fieldMap.entrySet()
.stream()
.anyMatch(
x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString())
x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString())
)) {
throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value");
}
}

public abstract void doExecute(
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
);

@Override
Expand Down Expand Up @@ -162,34 +166,34 @@ Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge
}

private void buildMapWithProcessorKeyAndOriginalValueForMapType(
String parentKey,
Object processorKey,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> treeRes
String parentKey,
Object processorKey,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> treeRes
) {
if (processorKey == null || sourceAndMetadataMap == null) return;
if (processorKey instanceof Map) {
Map<String, Object> next = new LinkedHashMap<>();
if (sourceAndMetadataMap.get(parentKey) instanceof Map) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
buildMapWithProcessorKeyAndOriginalValueForMapType(
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
(Map<String, Object>) sourceAndMetadataMap.get(parentKey),
next
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
(Map<String, Object>) sourceAndMetadataMap.get(parentKey),
next
);
}
} else if (sourceAndMetadataMap.get(parentKey) instanceof List) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
List<Map<String, Object>> list = (List<Map<String, Object>>) sourceAndMetadataMap.get(parentKey);
List<Object> listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList());
Map<String, Object> map = new HashMap();
Map<String, Object> map = new LinkedHashMap<>();
map.put(nestedFieldMapEntry.getKey(), listOfStrings);
buildMapWithProcessorKeyAndOriginalValueForMapType(
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
map,
next
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
map,
next
);
}
}
Expand Down Expand Up @@ -234,9 +238,9 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl
validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
Expand Down Expand Up @@ -287,11 +291,11 @@ Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> res

@SuppressWarnings({ "unchecked" })
private void putNLPResultToSourceMapForMapType(
String processorKey,
Object sourceValue,
List<?> results,
IndexWrapper indexWrapper,
Map<String, Object> sourceAndMetadataMap
String processorKey,
Object sourceValue,
List<?> results,
IndexWrapper indexWrapper,
Map<String, Object> sourceAndMetadataMap
) {
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
if (sourceValue instanceof Map) {
Expand All @@ -303,11 +307,11 @@ private void putNLPResultToSourceMapForMapType(
}
} else {
putNLPResultToSourceMapForMapType(
inputNestedMapEntry.getKey(),
inputNestedMapEntry.getValue(),
results,
indexWrapper,
(Map<String, Object>) sourceAndMetadataMap.get(processorKey)
inputNestedMapEntry.getKey(),
inputNestedMapEntry.getValue(),
results,
indexWrapper,
(Map<String, Object>) sourceAndMetadataMap.get(processorKey)
);
}
}
Expand All @@ -321,7 +325,7 @@ private void putNLPResultToSourceMapForMapType(
private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
return keyToResult;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Arrays;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

Expand Down Expand Up @@ -404,6 +405,20 @@ public void testBuildVectorOutput_withNestedMap_successful() {
assertNotNull(actionGamesKnn);
}

public void testBuildVectorOutput_withNestedList_successful() {
Map<String, Object> config = createNestedListConfiguration();
IngestDocument ingestDocument = createNestedListIngestDocument();
TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config);
Map<String, Object> knnMap = textEmbeddingProcessor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
List<List<Float>> modelTensorList = createMockVectorResult();
textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata());
List<Map<String, Object>> nestedObj = (List<Map<String, Object>>) ingestDocument.getSourceAndMetadata().get("nestedField");
assertTrue(nestedObj.get(0).containsKey("vectorField"));
assertTrue(nestedObj.get(1).containsKey("vectorField"));
assertNotNull(nestedObj.get(0).get("vectorField"));
assertNotNull(nestedObj.get(1).get("vectorField"));
}

public void test_updateDocument_appendVectorFieldsToDocument_successful() {
Map<String, Object> config = createPlainStringConfiguration();
IngestDocument ingestDocument = createPlainIngestDocument();
Expand Down Expand Up @@ -520,4 +535,22 @@ private IngestDocument createNestedMapIngestDocument() {
result.put("favorites", favorite);
return new IngestDocument(result, new HashMap<>());
}

private Map<String, Object> createNestedListConfiguration() {
Map<String, Object> nestedConfig = new HashMap<>();
nestedConfig.put("textField", "vectorField");
Map<String, Object> result = new HashMap<>();
result.put("nestedField", nestedConfig);
return result;
}

private IngestDocument createNestedListIngestDocument() {
HashMap<String, Object> nestedObj1 = new HashMap<>();
nestedObj1.put("textField", "This is a text field");
HashMap<String, Object> nestedObj2 = new HashMap<>();
nestedObj2.put("textField", "This is another text field");
HashMap<String, Object> nestedList = new HashMap<>();
nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2));
return new IngestDocument(nestedList, new HashMap<>());
}
}

0 comments on commit 70dcf1b

Please sign in to comment.