diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index db5dd1fa6..fdf2459df 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -8,6 +8,7 @@ import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; import java.io.IOException; +import java.net.URISyntaxException; import java.nio.file.Files; import java.nio.file.Path; import java.util.Collections; @@ -76,6 +77,7 @@ protected void updateClusterSettings() { updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); // default threshold for native circuit breaker is 90, it may be not enough on test runner machine updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100); + updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true); } @SneakyThrows @@ -99,6 +101,10 @@ protected void updateClusterSettings(String settingKey, Object value) { } protected String uploadModel(String requestBody) throws Exception { + String modelGroupId = registerModelGroup(); + // model group id is dynamically generated, we need to update model update request body after group is registered + requestBody = requestBody.replace("", modelGroupId); + Response uploadResponse = makeRequest( client(), "POST", @@ -677,4 +683,27 @@ protected String getDeployedModelId() { assertEquals(1, modelIds.size()); return modelIds.iterator().next(); } + + @SneakyThrows + private String registerModelGroup() throws IOException, URISyntaxException { + String modelGroupRegisterRequestBody = Files.readString( + Path.of(classLoader.getResource("processor/CreateModelGroupRequestBody.json").toURI()) + ); + Response modelGroupResponse = makeRequest( + client(), + "POST", + "/_plugins/_ml/model_groups/_register", + null, + toHttpEntity(modelGroupRegisterRequestBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Map modelGroupResJson = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(modelGroupResponse.getEntity()), + false + ); + String modelGroupId = modelGroupResJson.get("model_group_id").toString(); + assertNotNull(modelGroupId); + return modelGroupId; + } } diff --git a/src/test/resources/processor/CreateModelGroupRequestBody.json b/src/test/resources/processor/CreateModelGroupRequestBody.json new file mode 100644 index 000000000..d6d398c76 --- /dev/null +++ b/src/test/resources/processor/CreateModelGroupRequestBody.json @@ -0,0 +1,5 @@ +{ + "name": "test_model_group_public", + "description": "This is a public model group", + "access_mode": "public" +} \ No newline at end of file diff --git a/src/test/resources/processor/UploadModelRequestBody.json b/src/test/resources/processor/UploadModelRequestBody.json index 9fc53f3b9..95f9c9cb5 100644 --- a/src/test/resources/processor/UploadModelRequestBody.json +++ b/src/test/resources/processor/UploadModelRequestBody.json @@ -4,6 +4,7 @@ "model_format": "TORCH_SCRIPT", "model_task_type": "text_embedding", "model_content_hash_value": "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021", + "model_group_id": "", "model_config": { "model_type": "bert", "embedding_dimension": 768,