Skip to content

Commit

Permalink
Fixed fialing tests after ml-commons added model_gropu_id feature (#262)
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Aug 24, 2023
1 parent 70675c8 commit d12f480
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
Expand Down Expand Up @@ -76,6 +77,7 @@ protected void updateClusterSettings() {
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
// default threshold for native circuit breaker is 90, it may be not enough on test runner machine
updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true);
}

@SneakyThrows
Expand All @@ -99,6 +101,10 @@ protected void updateClusterSettings(String settingKey, Object value) {
}

protected String uploadModel(String requestBody) throws Exception {
String modelGroupId = registerModelGroup();
// model group id is dynamically generated, we need to update model update request body after group is registered
requestBody = requestBody.replace("<MODEL_GROUP_ID>", modelGroupId);

Response uploadResponse = makeRequest(
client(),
"POST",
Expand Down Expand Up @@ -677,4 +683,27 @@ protected String getDeployedModelId() {
assertEquals(1, modelIds.size());
return modelIds.iterator().next();
}

@SneakyThrows
private String registerModelGroup() throws IOException, URISyntaxException {
String modelGroupRegisterRequestBody = Files.readString(
Path.of(classLoader.getResource("processor/CreateModelGroupRequestBody.json").toURI())
);
Response modelGroupResponse = makeRequest(
client(),
"POST",
"/_plugins/_ml/model_groups/_register",
null,
toHttpEntity(modelGroupRegisterRequestBody),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
);
Map<String, Object> modelGroupResJson = XContentHelper.convertToMap(
XContentType.JSON.xContent(),
EntityUtils.toString(modelGroupResponse.getEntity()),
false
);
String modelGroupId = modelGroupResJson.get("model_group_id").toString();
assertNotNull(modelGroupId);
return modelGroupId;
}
}
5 changes: 5 additions & 0 deletions src/test/resources/processor/CreateModelGroupRequestBody.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"name": "test_model_group_public",
"description": "This is a public model group",
"access_mode": "public"
}
1 change: 1 addition & 0 deletions src/test/resources/processor/UploadModelRequestBody.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"model_format": "TORCH_SCRIPT",
"model_task_type": "text_embedding",
"model_content_hash_value": "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021",
"model_group_id": "<MODEL_GROUP_ID>",
"model_config": {
"model_type": "bert",
"embedding_dimension": 768,
Expand Down

0 comments on commit d12f480

Please sign in to comment.