Skip to content

Commit

Permalink
Fix the flaky test due to m_l_limit_exceeded_exception (opensearch-pr…
Browse files Browse the repository at this point in the history
…oject#150) (opensearch-project#164)

* increase the CB threshold, delete model after test

* add log

* add wait time

* enhancement: wait model undeploy before delete; refactor the wait response logic

* modify ci yml

---------

(cherry picked from commit 0791c34)

Signed-off-by: zhichao-aws <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
2 people authored and yuye-aws committed Apr 26, 2024
1 parent 1042f1a commit 362f8c3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 19 deletions.
12 changes: 3 additions & 9 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ jobs:
needs: Get-CI-Image-Tag
strategy:
matrix:
java:
- 11
- 17
- 21.0.1
java: [11, 17, 21]
name: Build and Test skills plugin on Linux
runs-on: ubuntu-latest
container:
Expand Down Expand Up @@ -71,7 +68,7 @@ jobs:
build-MacOS:
strategy:
matrix:
java: [ 11, 17 ]
java: [11, 17, 21]

name: Build and Test skills Plugin on MacOS
needs: Get-CI-Image-Tag
Expand All @@ -95,10 +92,7 @@ jobs:
build-windows:
strategy:
matrix:
java:
- 11
- 17
- 21.0.1
java: [11, 17, 21]
name: Build and Test skills plugin on Windows
needs: Get-CI-Image-Tag
runs-on: windows-latest
Expand Down
52 changes: 42 additions & 10 deletions src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.apache.commons.lang3.StringUtils;
Expand All @@ -35,6 +36,7 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand All @@ -57,6 +59,7 @@ public void updateClusterSettings() {
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
// The default threshold for the native circuit breaker is 90, which may not be enough on the test runner machine
updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100);
updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true);
}

Expand Down Expand Up @@ -123,26 +126,35 @@ protected String indexMonitor(String monitorAsJsonString) {
}

// NOTE(review): this span is a rendered diff without +/- markers, so lines from the old
// waitTaskComplete(String) body and from the new waitResponseMeetingCondition(...) helper are
// interleaved here (two signatures, two makeRequest calls, two final fail(...) calls).
//
// As far as the visible lines show, the helper sends `method endpoint` (with `jsonEntity` as body)
// up to MAX_TASK_RESULT_QUERY_TIME_IN_SECOND times, asserts that every response is 200 OK, and
// returns the parsed response as soon as `condition` accepts it. Responses that do not match are
// logged and followed by a sleep of DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND. If no
// response matches, the test is failed and null is returned only to satisfy the compiler.
@SneakyThrows
protected Map<String, Object> waitTaskComplete(String taskId) {
protected Map<String, Object> waitResponseMeetingCondition(
String method,
String endpoint,
String jsonEntity,
Predicate<Map<String, Object>> condition
) {
// NOTE(review): the bound is a number of attempts; it equals seconds only if the polling
// interval constant is 1000 ms - confirm against the constant definitions.
for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) {
Response response = makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, (String) null, null);
Response response = makeRequest(client(), method, endpoint, null, jsonEntity, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
Map<String, Object> responseInMap = parseResponseToMap(response);
// NOTE(review): the task-state checks below (COMPLETED and the FAILED / CANCELLED /
// COMPLETED_WITH_ERROR fail-fast) look like the removed side of the diff, replaced by the
// generic condition.test(...) call - confirm against the committed file.
String state = responseInMap.get(MLTask.STATE_FIELD).toString();
if (state.equals(MLTaskState.COMPLETED.toString())) {
if (condition.test(responseInMap)) {
return responseInMap;
}
if (state.equals(MLTaskState.FAILED.toString())
|| state.equals(MLTaskState.CANCELLED.toString())
|| state.equals(MLTaskState.COMPLETED_WITH_ERROR.toString())) {
fail("The task failed with state " + state);
}
logger.info("The " + i + "-th response: " + responseInMap.toString());
Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
}
fail("The task failed to complete after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
// Unreachable in practice because fail(...) throws; kept so the method has a return on every path.
return null;
}

/**
 * Polls {@code GET /_plugins/_ml/tasks/{taskId}} until the task reports the COMPLETED state.
 *
 * <p>The test is failed immediately when the task reports FAILED, CANCELLED or
 * COMPLETED_WITH_ERROR, instead of polling until the wait helper times out and reporting
 * only a generic "failed to meet condition" message.
 *
 * @param taskId id of the ML task to wait for
 * @return the parsed task response whose state is COMPLETED
 */
@SneakyThrows
protected Map<String, Object> waitTaskComplete(String taskId) {
    Predicate<Map<String, Object>> condition = responseInMap -> {
        String state = responseInMap.get(MLTask.STATE_FIELD).toString();
        // Fail fast: once the task reports one of these states, further polling only delays
        // the failure and hides the actual task state from the test report.
        if (state.equals(MLTaskState.FAILED.toString())
            || state.equals(MLTaskState.CANCELLED.toString())
            || state.equals(MLTaskState.COMPLETED_WITH_ERROR.toString())) {
            fail("The task failed with state " + state);
        }
        return state.equals(MLTaskState.COMPLETED.toString());
    };
    return waitResponseMeetingCondition("GET", "/_plugins/_ml/tasks/" + taskId, (String) null, condition);
}

// Registers the model, then deploys it. Returns the model_id once the model is deployed
protected String registerModelThenDeploy(String requestBody) {
String registerModelTaskId = registerModel(requestBody);
Expand All @@ -153,6 +165,26 @@ protected String registerModelThenDeploy(String requestBody) {
return modelId;
}

/**
 * Polls {@code GET /_plugins/_ml/models/{modelId}} through the shared wait helper until the
 * reported model state is none of DEPLOYED, DEPLOYING or PARTIALLY_DEPLOYED.
 *
 * @param modelId id of the model whose state is polled
 */
@SneakyThrows
private void waitModelUndeployed(String modelId) {
    final String modelEndpoint = "/_plugins/_ml/models/" + modelId;
    Predicate<Map<String, Object>> noLongerDeployed = modelResponse -> {
        String modelState = modelResponse.get(MLModel.MODEL_STATE_FIELD).toString();
        boolean stillDeployed = modelState.equals(MLModelState.DEPLOYED.toString())
            || modelState.equals(MLModelState.DEPLOYING.toString())
            || modelState.equals(MLModelState.PARTIALLY_DEPLOYED.toString());
        return !stillDeployed;
    };
    waitResponseMeetingCondition("GET", modelEndpoint, (String) null, noLongerDeployed);
}

/**
 * Undeploys the given model, waits until it is no longer reported as deployed, then deletes it.
 *
 * <p>Does nothing when {@code modelId} is null or empty. Cleanup hooks call this method
 * unconditionally, and without the guard a model that was never registered would produce
 * requests against a ".../models/null" endpoint and a second, misleading failure.
 *
 * @param modelId id of the model to remove; may be null or empty
 */
@SneakyThrows
protected void deleteModel(String modelId) {
    if (modelId == null || modelId.isEmpty()) {
        return;
    }
    String modelEndpoint = "/_plugins/_ml/models/" + modelId;
    // need to undeploy first as the model can be in use
    makeRequest(client(), "POST", modelEndpoint + "/_undeploy", null, (String) null, null);
    waitModelUndeployed(modelId);
    makeRequest(client(), "DELETE", modelEndpoint, null, (String) null, null);
}

protected void createIndexWithConfiguration(String indexName, String indexConfiguration) throws Exception {
Response response = makeRequest(client(), "PUT", indexName, null, indexConfiguration, null);
Map<String, Object> responseInMap = parseResponseToMap(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public void setUp() {
public void tearDown() {
    // Runs the parent tear-down, removes the external indices, then deletes the test's model.
    try {
        super.tearDown();
        deleteExternalIndices();
    } finally {
        // Delete the model even when the earlier cleanup steps throw, so a failed cleanup
        // does not leave the model behind for the tests that run afterwards.
        deleteModel(modelId);
    }
}

public void testNeuralSparseSearchToolInFlowAgent() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public void stopMockLLM() {
server.stop(1);
}

/**
 * After-test hook: removes the model referenced by {@code modelId} by delegating to the
 * {@code deleteModel(String)} overload.
 */
@After
public void deleteModel() {
deleteModel(modelId);
}

private String setUpConnector() {
String url = String.format(Locale.ROOT, "http://127.0.0.1:%d/invoke", server.getAddress().getPort());
return createConnector(
Expand Down

0 comments on commit 362f8c3

Please sign in to comment.