Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Use AbstractBatchingProcessor for InferenceProcessor #832

Merged
merged 2 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
- Enable sorting and search_after features in Hybrid Search [#827](https://github.com/opensearch-project/neural-search/pull/827)
### Enhancements
- InferenceProcessor inherits from AbstractBatchingProcessor to support sub batching in processor [#820](https://github.com/opensearch-project/neural-search/pull/820)
- Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
- Enable '.' for nested field in text embedding processor ([#811](https://github.com/opensearch-project/neural-search/pull/811))
### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
Expand All @@ -46,7 +46,7 @@
* and set the target fields according to the field name map.
*/
@Log4j2
public abstract class InferenceProcessor extends AbstractProcessor {
public abstract class InferenceProcessor extends AbstractBatchingProcessor {

public static final String MODEL_ID_FIELD = "model_id";
public static final String FIELD_MAP_FIELD = "field_map";
Expand All @@ -69,6 +69,7 @@ public abstract class InferenceProcessor extends AbstractProcessor {
public InferenceProcessor(
String tag,
String description,
int batchSize,
String type,
String listTypeNestedMapKey,
String modelId,
Expand All @@ -77,7 +78,7 @@ public InferenceProcessor(
Environment environment,
ClusterService clusterService
) {
super(tag, description);
super(tag, description, batchSize);
this.type = type;
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it");
validateEmbeddingConfiguration(fieldMap);
Expand Down Expand Up @@ -144,7 +145,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Ex
abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);

@Override
public void batchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
handler.accept(Collections.emptyList());
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ public final class SparseEncodingProcessor extends InferenceProcessor {
public SparseEncodingProcessor(
String tag,
String description,
int batchSize,
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment,
ClusterService clusterService
) {
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
public TextEmbeddingProcessor(
String tag,
String description,
int batchSize,
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment,
ClusterService clusterService
) {
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;

Expand All @@ -24,27 +24,23 @@
* Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
@Log4j2
public class SparseEncodingProcessorFactory implements Processor.Factory {
public class SparseEncodingProcessorFactory extends AbstractBatchingProcessor.Factory {
private final MLCommonsClientAccessor clientAccessor;
private final Environment environment;
private final ClusterService clusterService;

public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) {
super(TYPE);
this.clientAccessor = clientAccessor;
this.environment = environment;
this.clusterService = clusterService;
}

@Override
public SparseEncodingProcessor create(
Map<String, Processor.Factory> registry,
String processorTag,
String description,
Map<String, Object> config
) throws Exception {
String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
Map<String, Object> fieldMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);

return new SparseEncodingProcessor(processorTag, description, modelId, fieldMap, clientAccessor, environment, clusterService);
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);

return new SparseEncodingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;

/**
* Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
public class TextEmbeddingProcessorFactory implements Processor.Factory {
public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcessor.Factory {

private final MLCommonsClientAccessor clientAccessor;

Expand All @@ -34,20 +34,16 @@ public TextEmbeddingProcessorFactory(
final Environment environment,
final ClusterService clusterService
) {
super(TYPE);
this.clientAccessor = clientAccessor;
this.environment = environment;
this.clusterService = clusterService;
}

@Override
public TextEmbeddingProcessor create(
final Map<String, Processor.Factory> registry,
final String processorTag,
final String description,
final Map<String, Object> config
) throws Exception {
String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
Map<String, Object> filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);
return new TextEmbeddingProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment, clusterService);
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
Map<String, Object> filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
return new TextEmbeddingProcessor(tag, description, batchSize, modelId, filedMap, clientAccessor, environment, clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.Factory;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.FIELD_MAP_FIELD;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.MODEL_ID_FIELD;
Expand All @@ -16,6 +15,7 @@

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;

Expand All @@ -25,31 +25,18 @@
* Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
@AllArgsConstructor
public class TextImageEmbeddingProcessorFactory implements Factory {
public class TextImageEmbeddingProcessorFactory implements Processor.Factory {

private final MLCommonsClientAccessor clientAccessor;
private final Environment environment;
private final ClusterService clusterService;

@Override
public TextImageEmbeddingProcessor create(
final Map<String, Factory> registry,
final String processorTag,
final String description,
final Map<String, Object> config
) throws Exception {
String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD);
Map<String, String> filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);
return new TextImageEmbeddingProcessor(
processorTag,
description,
modelId,
embedding,
filedMap,
clientAccessor,
environment,
clusterService
);
public Processor create(Map<String, Processor.Factory> processorFactories, String tag, String description, Map<String, Object> config)
throws Exception {
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
String embedding = readStringProperty(TYPE, tag, config, EMBEDDING_FIELD);
Map<String, String> filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
return new TextImageEmbeddingProcessor(tag, description, modelId, embedding, filedMap, clientAccessor, environment, clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.neuralsearch.processor;

import lombok.Getter;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.MockitoAnnotations;
Expand All @@ -15,6 +16,7 @@
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -42,6 +44,7 @@ public class InferenceProcessorTests extends InferenceProcessorTestCase {
private static final String DESCRIPTION = "description";
private static final String MAP_KEY = "map_key";
private static final String MODEL_ID = "model_id";
private static final int BATCH_SIZE = 10;
private static final Map<String, Object> FIELD_MAP = Map.of("key1", "embedding_key1", "key2", "embedding_key2");

@Before
Expand All @@ -54,7 +57,7 @@ public void setup() {
}

public void test_batchExecute_emptyInput() {
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null);
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
Consumer resultHandler = mock(Consumer.class);
processor.batchExecute(Collections.emptyList(), resultHandler);
ArgumentCaptor<List<IngestDocumentWrapper>> captor = ArgumentCaptor.forClass(List.class);
Expand All @@ -65,7 +68,7 @@ public void test_batchExecute_emptyInput() {

public void test_batchExecute_allFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null);
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
Expand All @@ -83,7 +86,7 @@ public void test_batchExecute_allFailedValidation() {

public void test_batchExecute_partialFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null);
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4"));
Expand All @@ -105,7 +108,7 @@ public void test_batchExecute_partialFailedValidation() {
public void test_batchExecute_happyCase() {
final int docCount = 2;
List<List<Float>> inferenceResults = createMockVectorWithLength(6);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, null);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("value1", "value2"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4"));
Expand All @@ -126,7 +129,7 @@ public void test_batchExecute_happyCase() {
public void test_batchExecute_sort() {
final int docCount = 2;
List<List<Float>> inferenceResults = createMockVectorWithLength(100);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, null);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("aaaaa", "bbb"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("cc", "ddd"));
Expand Down Expand Up @@ -158,7 +161,7 @@ public void test_batchExecute_sort() {
public void test_doBatchExecute_exception() {
final int docCount = 2;
List<List<Float>> inferenceResults = createMockVectorWithLength(6);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, new RuntimeException());
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, BATCH_SIZE, new RuntimeException());
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("value1", "value2"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4"));
Expand All @@ -174,12 +177,36 @@ public void test_doBatchExecute_exception() {
verify(clientAccessor).inferenceSentences(anyString(), anyList(), any());
}

/**
 * Verifies that batchExecute splits its input into sub-batches of the configured batch size.
 * Five documents with a batch size of two are expected to produce three doBatchExecute calls
 * (2 + 2 + 1 inputs), and every wrapper handed back to the handler is expected to match the
 * wrapper at the same position of the input list.
 */
public void test_batchExecute_subBatches() {
    final int docCount = 5;
    final int subBatchSize = 2;
    List<List<Float>> inferenceResults = createMockVectorWithLength(6);
    TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, subBatchSize, null);
    List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
    for (int i = 0; i < docCount; ++i) {
        wrapperList.get(i).getIngestDocument().setFieldValue("key1", Collections.singletonList("value" + i));
    }

    // The handler may be invoked more than once, so collect everything it receives.
    List<IngestDocumentWrapper> allResults = new ArrayList<>();
    processor.batchExecute(wrapperList, allResults::addAll);

    // Check the count first so a dropped document fails with a clear message
    // instead of an IndexOutOfBoundsException inside the loop below.
    assertEquals(docCount, allResults.size());
    for (int i = 0; i < docCount; ++i) {
        IngestDocumentWrapper expected = wrapperList.get(i);
        IngestDocumentWrapper actual = allResults.get(i);
        // assertEquals takes (expected, actual); keeping that order makes failure output readable.
        assertEquals(expected.getIngestDocument(), actual.getIngestDocument());
        assertEquals(expected.getSlot(), actual.getSlot());
        assertEquals(expected.getException(), actual.getException());
    }

    // One recorded inference input list per sub-batch: sizes 2, 2 and 1.
    List<List<String>> inferenceInputs = processor.getAllInferenceInputs();
    assertEquals(3, inferenceInputs.size());
    assertEquals(List.of("value0", "value1"), inferenceInputs.get(0));
    assertEquals(List.of("value2", "value3"), inferenceInputs.get(1));
    assertEquals(List.of("value4"), inferenceInputs.get(2));
}

private class TestInferenceProcessor extends InferenceProcessor {
List<?> vectors;
Exception exception;

public TestInferenceProcessor(List<?> vectors, Exception exception) {
super(TAG, DESCRIPTION, TYPE, MAP_KEY, MODEL_ID, FIELD_MAP, clientAccessor, environment, clusterService);
@Getter
List<List<String>> allInferenceInputs = new ArrayList<>();

public TestInferenceProcessor(List<?> vectors, int batchSize, Exception exception) {
super(TAG, DESCRIPTION, batchSize, TYPE, MAP_KEY, MODEL_ID, FIELD_MAP, clientAccessor, environment, clusterService);
this.vectors = vectors;
this.exception = exception;
}
Expand All @@ -196,6 +223,7 @@ public void doExecute(
void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
// use to verify if doBatchExecute is called from InferenceProcessor
clientAccessor.inferenceSentences(MODEL_ID, inferenceList, ActionListener.wrap(results -> {}, ex -> {}));
allInferenceInputs.add(inferenceList);
if (this.exception != null) {
onException.accept(this.exception);
} else {
Expand Down
Loading
Loading