Fix user defined preprocess function missing prediction issue #2418

Merged

merged 5 commits on May 9, 2024
@@ -11,6 +11,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
@@ -27,6 +28,8 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
@@ -99,6 +102,15 @@ private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocs
return Tuple.tuple(textDocsLength / stepSize + 1, stepSize);
}
} else {
Optional<ConnectorAction> predictAction = getConnector().findPredictAction();
Collaborator:
Can we have a unit test for this section?

if (predictAction.isEmpty()) {
throw new IllegalArgumentException("no predict action found");
}
String preProcessFunction = predictAction.get().getPreProcessFunction();
if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) {
// user defined preprocess script: in this case, the chunk count always equals the text docs length.
return Tuple.tuple(textDocsLength, 1);
Collaborator:
Is this understanding correct?

  1. textDocsLength means how many chunks
  2. 1 means the step or chunk size?

If correct, why hard-code the chunk size as 1? Issue #2417 is one example, but it's not the only case. For example, some users may process two documents in one chunk.

Collaborator Author:
textDocsLength is the number of chunks, and 1 is the step size. When a user processes two documents in one chunk, they have to specify input_docs_processed_step_size, and in that case this branch is not entered. Multi-modal is a case where the user needs to handle two documents in one chunk; I've tested that case and it works well.
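
To make the semantics concrete, here is a minimal sketch (illustration only, assuming the executor consumes the tuple as (chunkCount, stepSize); it is not the actual consuming code, which is outside this diff):

import java.util.ArrayList;
import java.util.List;

class ChunkingSketch {
    // Illustration only: split docs into chunkCount requests of stepSize docs each.
    static List<List<String>> slice(List<String> docs, int chunkCount, int stepSize) {
        List<List<String>> chunks = new ArrayList<>();
        for (int i = 0; i < chunkCount && i * stepSize < docs.size(); i++) {
            chunks.add(docs.subList(i * stepSize, Math.min((i + 1) * stepSize, docs.size())));
        }
        return chunks;
    }

    public static void main(String[] args) {
        List<String> docs = List.of("input1", "input2", "input3");
        // user defined preprocess script: 3 chunks, step size 1 -> one doc per request
        System.out.println(slice(docs, 3, 1)); // [[input1], [input2], [input3]]
        // batch case: 1 chunk, step size 3 -> all docs in one request
        System.out.println(slice(docs, 1, 3)); // [[input1, input2, input3]]
    }
}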

Collaborator:
Thanks for testing the multi-modal case.
It's possible to have a step size > 1 without input_docs_processed_step_size. For example, before the async HTTP client, a user could configure a pre_process function to read three documents and construct "inputs": [doc1, doc2, doc3]; when the response came back, it would move on to the next three docs.

Collaborator Author:
This is a rare case, and we have a workaround, which is to add the input_docs_processed_step_size configuration.
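
For illustration, the workaround amounts to declaring the step size as a connector parameter, in the same style as the parameter maps in the tests below (the exact key handling is assumed from this discussion, not shown in this diff):

// assumed workaround wiring: each preprocess call consumes three docs per request
Map<String, String> parameters = ImmutableMap
    .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "3");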

Collaborator:
Thanks. Based on our discussion, this sounds good.

}
// otherwise, consider all docs as one batch.
return Tuple.tuple(1, textDocsLength);
}
@@ -6,6 +6,7 @@
package org.opensearch.ml.engine.algorithms.remote;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
@@ -30,6 +31,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
@@ -42,6 +44,7 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableList;
@@ -67,10 +70,15 @@ public class AwsConnectorExecutorTest {

Encryptor encryptor;

@Mock
private ScriptService scriptService;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
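// stub ScriptService so a user defined preprocess script compiles to a fixed mock template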
when(scriptService.compile(any(), any()))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}"));
}

@Test
@@ -282,4 +290,80 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof IllegalArgumentException;
}

@Test
public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictionAction() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.build();
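// note: this predict action is intentionally not attached to the connector below, so findPredictAction() returns empty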
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
ArgumentCaptor<Exception> exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture());
assert exceptionArgumentCaptor.getValue() instanceof IllegalArgumentException;
assert "no predict action found".equals(exceptionArgumentCaptor.getValue().getMessage());
}

@Test
public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPreProcessFunction() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
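// user defined Painless script (not a built-in MLPreProcessFunction): wraps the first text doc as {"parameters": {"text_inputs": "<doc>"}}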
.preProcessFunction(
"\n StringBuilder builder = new StringBuilder();\n builder.append(\"\\\"\");\n String first = params.text_docs[0];\n builder.append(first);\n builder.append(\"\\\"\");\n def parameters = \"{\" +\"\\\"text_inputs\\\":\" + builder + \"}\";\n return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"
)
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(executor.getScriptService()).thenReturn(scriptService);

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
}
}