Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Fix user defined preprocess function missing prediction issue #2426

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -27,6 +28,8 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand Down Expand Up @@ -99,6 +102,15 @@ private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocs
return Tuple.tuple(textDocsLength / stepSize + 1, stepSize);
}
} else {
Optional<ConnectorAction> predictAction = getConnector().findPredictAction();
if (predictAction.isEmpty()) {
throw new IllegalArgumentException("no predict action found");
}
String preProcessFunction = predictAction.get().getPreProcessFunction();
if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) {
// user defined preprocess script, this case, the chunk size is always equals to text docs length.
return Tuple.tuple(textDocsLength, 1);
}
// consider as batch.
return Tuple.tuple(1, textDocsLength);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.engine.algorithms.remote;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
Expand All @@ -30,6 +31,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
Expand All @@ -42,6 +44,7 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableList;
Expand All @@ -67,10 +70,15 @@ public class AwsConnectorExecutorTest {

Encryptor encryptor;

@Mock
private ScriptService scriptService;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
when(scriptService.compile(any(), any()))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}"));
}

@Test
Expand Down Expand Up @@ -282,4 +290,80 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof IllegalArgumentException;
}

@Test
public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictionAction() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
ArgumentCaptor<Exception> exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture());
assert exceptionArgumentCaptor.getValue() instanceof IllegalArgumentException;
assert "no predict action found".equals(exceptionArgumentCaptor.getValue().getMessage());
}

@Test
public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPreProcessFunction() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.preProcessFunction(
"\n StringBuilder builder = new StringBuilder();\n builder.append(\"\\\"\");\n String first = params.text_docs[0];\n builder.append(first);\n builder.append(\"\\\"\");\n def parameters = \"{\" +\"\\\"text_inputs\\\":\" + builder + \"}\";\n return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"
)
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(executor.getScriptService()).thenReturn(scriptService);

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
}
}
Loading