diff --git a/release-notes/opensearch-knn.release-notes-2.15.0.0.md b/release-notes/opensearch-knn.release-notes-2.15.0.0.md index 3def01638..198c32ce9 100644 --- a/release-notes/opensearch-knn.release-notes-2.15.0.0.md +++ b/release-notes/opensearch-knn.release-notes-2.15.0.0.md @@ -9,6 +9,7 @@ Compatible with OpenSearch 2.15.0 * Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684) * Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696) * Add validation for pq m parameter before training starts [#1713](https://github.com/opensearch-project/k-NN/pull/1713) +* Block delete model requests if an index uses the model [#1722](https://github.com/opensearch-project/k-NN/pull/1722) ### Bug Fixes * Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692) * Update threshold value after new result is added [#1715](https://github.com/opensearch-project/k-NN/pull/1715) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 7c5bb61ad..0b4538ec8 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -25,6 +25,7 @@ public class KNNConstants { public static final String VECTOR = "vector"; public static final String K = "k"; public static final String TYPE_KNN_VECTOR = "knn_vector"; + public static final String PROPERTIES = "properties"; public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search"; public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction"; public static final String METHOD_PARAMETER_M = "m"; diff --git a/src/main/java/org/opensearch/knn/common/exception/DeleteModelWhenInTrainStateException.java b/src/main/java/org/opensearch/knn/common/exception/DeleteModelException.java similarity index 75% rename from src/main/java/org/opensearch/knn/common/exception/DeleteModelWhenInTrainStateException.java rename to src/main/java/org/opensearch/knn/common/exception/DeleteModelException.java index 00f6e6e80..d9590c3f8 100644 --- a/src/main/java/org/opensearch/knn/common/exception/DeleteModelWhenInTrainStateException.java +++ b/src/main/java/org/opensearch/knn/common/exception/DeleteModelException.java @@ -10,18 +10,18 @@ import org.opensearch.core.rest.RestStatus; /** - * Exception thrown when a model is deleted while it is in the training state. The RestStatus associated with this + * Exception thrown when a model is deleted while it is in the training state or in use by an index. The RestStatus associated with this * exception should be a {@link RestStatus#CONFLICT} because the request cannot be deleted due to the model being in - * the training state. + * the training state or in use by an index. */ -public class DeleteModelWhenInTrainStateException extends OpenSearchException { +public class DeleteModelException extends OpenSearchException { /** * Constructor * * @param msg detailed exception message * @param args arguments of the message */ - public DeleteModelWhenInTrainStateException(String msg, Object... args) { + public DeleteModelException(String msg, Object... args) { super(LoggerMessageFormat.format(msg, args)); } diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 6940fcd39..0bc6c5edb 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -49,7 +49,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.common.exception.DeleteModelWhenInTrainStateException; +import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; @@ -84,7 +84,7 @@ public interface ModelDao { /** - * Creates model index. It is possible that the 2 threads call this function simulateously. In this case, one + * Creates model index. It is possible that the 2 threads call this function simultaneously. In this case, one * thread will throw a ResourceAlreadyExistsException. This should be caught and handled. * * @param actionListener CreateIndexResponse listener @@ -527,7 +527,7 @@ public void delete(String modelId, ActionListener listener) // If model is in Training state, fail delete model request if (ModelState.TRAINING == getModelResponse.getModel().getModelMetadata().getState()) { String errorMessage = String.format("Cannot delete model [%s]. Model is still in training", modelId); - listener.onFailure(new DeleteModelWhenInTrainStateException(errorMessage)); + listener.onFailure(new DeleteModelException(errorMessage)); return; } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java index df0c26624..d7536a5a7 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java @@ -7,6 +7,7 @@ import lombok.Value; import lombok.extern.log4j.Log4j2; +import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.master.AcknowledgedResponse; @@ -22,16 +23,22 @@ import org.opensearch.common.Priority; import org.opensearch.common.inject.Inject; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.indices.IndicesService; +import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.indices.ModelGraveyard; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; +import static java.util.stream.Collectors.toList; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PLUGIN_NAME; /** @@ -42,6 +49,7 @@ public class UpdateModelGraveyardTransportAction extends TransportClusterManager UpdateModelGraveyardRequest, AcknowledgedResponse> { private UpdateModelGraveyardExecutor updateModelGraveyardExecutor; + private final IndicesService indicesService; @Inject public UpdateModelGraveyardTransportAction( @@ -49,7 +57,8 @@ public UpdateModelGraveyardTransportAction( ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver + IndexNameExpressionResolver indexNameExpressionResolver, + IndicesService indicesService ) { super( UpdateModelGraveyardAction.NAME, @@ -61,6 +70,7 @@ public UpdateModelGraveyardTransportAction( indexNameExpressionResolver ); this.updateModelGraveyardExecutor = new UpdateModelGraveyardExecutor(); + this.indicesService = indicesService; } @Override @@ -82,7 +92,7 @@ protected void clusterManagerOperation( // ClusterManager updates model graveyard based on request parameters clusterService.submitStateUpdateTask( PLUGIN_NAME, - new UpdateModelGraveyardTask(request.getModelId(), request.isRemoveRequest()), + new UpdateModelGraveyardTask(request.getModelId(), request.isRemoveRequest(), indicesService), ClusterStateTaskConfig.build(Priority.NORMAL), updateModelGraveyardExecutor, new ClusterStateTaskListener() { @@ -111,6 +121,7 @@ protected ClusterBlockException checkBlock(UpdateModelGraveyardRequest request, private static class UpdateModelGraveyardTask { String modelId; boolean isRemoveRequest; + IndicesService indicesService; } /** @@ -123,7 +134,8 @@ private static class UpdateModelGraveyardExecutor implements ClusterStateTaskExe * @return Represents the result of a batched execution of cluster state update tasks (UpdateModelGraveyardTasks) */ @Override - public ClusterTasksResult execute(ClusterState clusterState, List taskList) { + public ClusterTasksResult execute(ClusterState clusterState, List taskList) + throws IOException { // Check if the objects are not null and throw a customized NullPointerException Objects.requireNonNull(clusterState, "Cluster state must not be null"); @@ -146,6 +158,17 @@ public ClusterTasksResult execute(ClusterState cluster modelGraveyard.remove(task.getModelId()); continue; } + List indicesUsingModel = getIndicesUsingModel(clusterState, task); + // Throw exception if any indices are using the model + if (!indicesUsingModel.isEmpty()) { + throw new DeleteModelException( + String.format( + "Cannot delete model [%s]. Model is in use by the following indices %s, which must be deleted first.", + task.getModelId(), + indicesUsingModel + ) + ); + } modelGraveyard.add(task.getModelId()); } @@ -155,5 +178,50 @@ public ClusterTasksResult execute(ClusterState cluster ClusterState updatedClusterState = ClusterState.builder(clusterState).metadata(metaDataBuilder).build(); return new ClusterTasksResult.Builder().successes(taskList).build(updatedClusterState); } + + private List getIndicesUsingModel(ClusterState clusterState, UpdateModelGraveyardTask task) throws IOException { + Map indices = clusterState.metadata().indices(); + String[] knnIndicesList = indices.values() + .stream() + .filter(metadata -> "true".equals(metadata.getSettings().get("index.knn", "false"))) + .map(metadata -> metadata.getIndex().getName()) + .toArray(String[]::new); + if (knnIndicesList.length == 0) { + return Collections.emptyList(); + } + + return clusterState.metadata() + .findMappings(knnIndicesList, task.getIndicesService().getFieldFilter()) + .entrySet() + .stream() + .filter(entry -> entry.getValue() != null) + .filter(entry -> { + Object properties = entry.getValue().getSourceAsMap().get("properties"); + if (properties == null || properties instanceof Map == false) { + return false; + } + Map propertiesMap = (Map) properties; + return propertiesMapContainsModel(propertiesMap, task.getModelId()); + }) + .map(Map.Entry::getKey) + .collect(toList()); + } + + private boolean propertiesMapContainsModel(Map propertiesMap, String modelId) { + for (Map.Entry fieldsEntry : propertiesMap.entrySet()) { + if (fieldsEntry.getKey() != null && fieldsEntry.getValue() instanceof Map) { + Map innerMap = (Map) fieldsEntry.getValue(); + for (Map.Entry innerEntry : innerMap.entrySet()) { + // If model is in use, fail delete model request + if (innerEntry.getKey().equals(MODEL_ID) + && innerEntry.getValue() instanceof String + && innerEntry.getValue().equals(modelId)) { + return true; + } + } + } + } + return false; + } } } diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index 1c9e75c3a..d87a63ad5 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -16,6 +16,7 @@ import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; +import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -36,13 +37,12 @@ import org.opensearch.test.hamcrest.OpenSearchAssertions; import java.io.IOException; -import java.util.Collection; -import java.util.Collections; -import java.util.EnumSet; -import java.util.Map; +import java.util.*; import java.util.concurrent.ExecutionException; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; public class KNNSingleNodeTestCase extends OpenSearchSingleNodeTestCase { @Override @@ -181,6 +181,38 @@ protected void addDoc(String index, String docId, String fieldName, String dummy assertEquals(response.status(), RestStatus.CREATED); } + /** + * Index a new model + */ + protected void addDoc(Model model) throws IOException, ExecutionException, InterruptedException { + ModelMetadata modelMetadata = model.getModelMetadata(); + + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(MODEL_ID, model.getModelID()) + .field(KNN_ENGINE, modelMetadata.getKnnEngine().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue()) + .field(DIMENSION, modelMetadata.getDimension()) + .field(MODEL_STATE, modelMetadata.getState().getName()) + .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString()) + .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) + .field(MODEL_ERROR, modelMetadata.getError()); + + if (model.getModelBlob() != null) { + builder.field(MODEL_BLOB_PARAMETER, Base64.getEncoder().encodeToString(model.getModelBlob())); + } + + builder.endObject(); + + IndexRequest indexRequest = new IndexRequest().index(MODEL_INDEX_NAME) + .id(model.getModelID()) + .source(builder) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + IndexResponse response = client().index(indexRequest).get(); + assertTrue(response.status() == RestStatus.CREATED || response.status() == RestStatus.OK); + } + /** * Run a search against a k-NN index */ diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 699c920ff..3a9b7d596 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -1334,6 +1334,7 @@ public void testSharedIndexState_whenOneIndexDeleted_thenSecondIndexIsStillSearc // will give 15 second buffer from that Thread.sleep(1000 * 45); validateSearchWorkflow(secondIndexName, testData.queries, 10); + deleteKNNIndex(secondIndexName); deleteModel(modelId); } diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index ee2c77d1a..3a25c3064 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -18,6 +18,7 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.core.action.ActionListener; import org.opensearch.action.DocWriteResponse; @@ -26,7 +27,6 @@ import org.opensearch.action.delete.DeleteAction; import org.opensearch.action.delete.DeleteRequestBuilder; import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.action.support.master.AcknowledgedResponse; @@ -35,7 +35,8 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.knn.KNNSingleNodeTestCase; -import org.opensearch.knn.common.exception.DeleteModelWhenInTrainStateException; +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -48,13 +49,11 @@ import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest; -import org.opensearch.core.rest.RestStatus; import org.opensearch.knn.training.TrainingJobClusterStateListener; import java.io.IOException; import java.time.ZoneOffset; import java.time.ZonedDateTime; -import java.util.Base64; import java.util.Collections; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; @@ -66,17 +65,11 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; -import static org.opensearch.cluster.metadata.Metadata.builder; -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; -import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; -import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; +import static org.opensearch.knn.common.KNNConstants.PROPERTIES; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; public class ModelDaoTests extends KNNSingleNodeTestCase { @@ -568,7 +561,7 @@ public void testDelete() throws IOException, InterruptedException { ActionListener deleteModelTrainingListener = ActionListener.wrap( response -> fail("Deleting model when model does not exist should throw ResourceNotFoundException"), exception -> { - assertTrue(exception instanceof DeleteModelWhenInTrainStateException); + assertTrue(exception instanceof DeleteModelException); assertFalse(modelDao.isModelInGraveyard(modelId)); inProgressLatch2.countDown(); } @@ -636,6 +629,91 @@ public void testDelete() throws IOException, InterruptedException { assertTrue(inProgressLatch3.await(100, TimeUnit.SECONDS)); } + // Test Delete Model when the model is in use by an index + public void testDeleteModelInUse() throws IOException, ExecutionException, InterruptedException { + String modelId = "test-model-id-training"; + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + byte[] modelBlob = "deleteModel".getBytes(); + int dimension = 2; + createIndex(MODEL_INDEX_NAME); + + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "", + "", + MethodComponentContext.EMPTY + ), + modelBlob, + modelId + ); + + // created model and added it to index + addDoc(model); + + String testIndex = "test-index"; + String testField = "test-field"; + + /* + Constructs the following json: + { + "properties": { + "test-field": { + "type": "knn_vector", + "model_id": "test-model-id-training" + } + } + } + */ + XContentBuilder mappings = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES) + .startObject(testField) + .field(TYPE, TYPE_KNN_VECTOR) + .field(MODEL_ID, modelId) + .endObject() + .endObject() + .endObject(); + + XContentBuilder settings = XContentFactory.jsonBuilder().startObject().field(TestUtils.INDEX_KNN, "true").endObject(); + + // Create index using model + CreateIndexRequestBuilder createIndexRequestBuilder = client().admin() + .indices() + .prepareCreate(testIndex) + .setMapping(mappings) + .setSettings(settings); + createIndex(testIndex, createIndexRequestBuilder); + + CountDownLatch latch = new CountDownLatch(1); + modelDao.delete(modelId, new ActionListener() { + @Override + public void onResponse(DeleteModelResponse deleteModelResponse) { + fail("Received delete model response when the request should have failed."); + } + + @Override + public void onFailure(Exception e) { + assertTrue(e instanceof DeleteModelException); + assertEquals( + String.format( + "Cannot delete model [%s]. Model is in use by the following indices [%s], which must be deleted first.", + modelId, + testIndex + ), + e.getMessage() + ); + latch.countDown(); + } + }); + assertTrue(latch.await(60, TimeUnit.SECONDS)); + } + // Test Delete Model when modelId is in Model Graveyard (previous delete model request which failed to // remove modelId from model graveyard). But, the model does not exist public void testDeleteModelWithModelInGraveyardModelDoesNotExist() throws InterruptedException { @@ -911,34 +989,4 @@ public void testDeleteWithStepListenersOnFailureModelBlocked() throws Interrupte assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); } - - public void addDoc(Model model) throws IOException, ExecutionException, InterruptedException { - ModelMetadata modelMetadata = model.getModelMetadata(); - - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .field(MODEL_ID, model.getModelID()) - .field(KNN_ENGINE, modelMetadata.getKnnEngine().getName()) - .field(METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue()) - .field(DIMENSION, modelMetadata.getDimension()) - .field(MODEL_STATE, modelMetadata.getState().getName()) - .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString()) - .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) - .field(MODEL_ERROR, modelMetadata.getError()); - - if (model.getModelBlob() != null) { - builder.field(MODEL_BLOB_PARAMETER, Base64.getEncoder().encodeToString(model.getModelBlob())); - } - - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest().index(MODEL_INDEX_NAME) - .id(model.getModelID()) - .source(builder) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - IndexResponse response = client().index(indexRequest).get(); - assertTrue(response.status() == RestStatus.CREATED || response.status() == RestStatus.OK); - } - } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 9f28d5a71..9af1f49cc 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -34,6 +34,7 @@ import java.util.*; import static org.opensearch.knn.TestUtils.*; +import static org.opensearch.knn.TestUtils.PROPERTIES; import static org.opensearch.knn.common.KNNConstants.*; /** diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index b1983b964..cd60d566c 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -5,18 +5,41 @@ package org.opensearch.knn.plugin.transport; +import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNSingleNodeTestCase; +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.common.exception.DeleteModelException; +import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelGraveyard; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.PROPERTIES; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; + public class UpdateModelGraveyardTransportActionTests extends KNNSingleNodeTestCase { public void testExecutor() { @@ -165,4 +188,201 @@ public void testCheckBlock() { .getInstance(UpdateModelGraveyardTransportAction.class); assertNull(updateModelGraveyardTransportAction.checkBlock(null, null)); } + + public void testGetIndicesUsingModel() throws IOException, ExecutionException, InterruptedException { + // Get update transport action + UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction = node().injector() + .getInstance(UpdateModelGraveyardTransportAction.class); + + String modelId = "test-model-id"; + byte[] modelBlob = "testModel".getBytes(); + int dimension = 2; + + createIndex(MODEL_INDEX_NAME); + + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "", + "", + MethodComponentContext.EMPTY + ), + modelBlob, + modelId + ); + + // created model and added it to index + addDoc(model); + + // Create basic index (not using k-NN) + String testIndex1 = "test-index1"; + createIndex(testIndex1); + + // Attempt to add model id to graveyard with one non-knn index present, should succeed + UpdateModelGraveyardRequest addModelGraveyardRequest = new UpdateModelGraveyardRequest(modelId, false); + updateModelGraveyardAndAssertNoError(updateModelGraveyardTransportAction, addModelGraveyardRequest); + + // Remove model from graveyard to prepare for next check + UpdateModelGraveyardRequest removeModelGraveyardRequest = new UpdateModelGraveyardRequest(modelId, true); + updateModelGraveyardAndAssertNoError(updateModelGraveyardTransportAction, removeModelGraveyardRequest); + + // Create k-NN index not using the model + String testIndex2 = "test-index2"; + createKNNIndex(testIndex2); + + // Attempt to add model id to graveyard with one non-knn index and one k-nn index not using model present, should succeed + updateModelGraveyardAndAssertNoError(updateModelGraveyardTransportAction, addModelGraveyardRequest); + + // Remove model from graveyard to prepare for next check + updateModelGraveyardAndAssertNoError(updateModelGraveyardTransportAction, removeModelGraveyardRequest); + + // Create k-NN index using model + String testIndex3 = "test-index3"; + String testField3 = "test-field3"; + + /* + Constructs the following json: + { + "properties": { + "test-field3": { + "type": "knn_vector", + "model_id": "test-model-id" + } + } + } + */ + XContentBuilder mappings3 = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES) + .startObject(testField3) + .field(TYPE, TYPE_KNN_VECTOR) + .field(MODEL_ID, modelId) + .endObject() + .endObject() + .endObject(); + + XContentBuilder settings = XContentFactory.jsonBuilder().startObject().field(TestUtils.INDEX_KNN, "true").endObject(); + + CreateIndexRequestBuilder createIndexRequestBuilder3 = client().admin() + .indices() + .prepareCreate(testIndex3) + .setMapping(mappings3) + .setSettings(settings); + createIndex(testIndex3, createIndexRequestBuilder3); + + // Attempt to add model id to graveyard when one index is using model, should fail + List indicesUsingModel = new ArrayList<>(); + indicesUsingModel.add(testIndex3); + updateModelGraveyardAndAssertDeleteModelException( + updateModelGraveyardTransportAction, + addModelGraveyardRequest, + indicesUsingModel.toString() + ); + + // Create second k-NN index using model + String testIndex4 = "test-index4"; + String testField4 = "test-field4"; + String standardField = "standard-field"; + + /* + Constructs the following json: + { + "properties": { + "standard-field": { + "type": "knn_vector", + "dimension": "2" + } + "test-field4": { + "type": "knn_vector", + "model_id": "test-model-id" + } + } + } + */ + XContentBuilder mappings4 = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES) + .startObject(standardField) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .endObject() + .startObject(testField4) + .field(TYPE, TYPE_KNN_VECTOR) + .field(MODEL_ID, modelId) + .endObject() + .endObject() + .endObject(); + + CreateIndexRequestBuilder createIndexRequestBuilder4 = client().admin() + .indices() + .prepareCreate(testIndex4) + .setMapping(mappings4) + .setSettings(settings); + createIndex(testIndex4, createIndexRequestBuilder4); + + // Add index at beginning to match order of list returned by getIndicesUsingModel() + indicesUsingModel.add(0, testIndex4); + + // Attempt to add model id to graveyard when one index is using model, should fail + updateModelGraveyardAndAssertDeleteModelException( + updateModelGraveyardTransportAction, + addModelGraveyardRequest, + indicesUsingModel.toString() + ); + } + + public void updateModelGraveyardAndAssertNoError( + UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction, + UpdateModelGraveyardRequest updateModelGraveyardRequest + ) throws InterruptedException { + final CountDownLatch countDownLatch = new CountDownLatch(1); + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { + ClusterState clusterState1 = stateResponse1.getState(); + updateModelGraveyardTransportAction.clusterManagerOperation( + updateModelGraveyardRequest, + clusterState1, + ActionListener.wrap(acknowledgedResponse -> { + assertTrue(acknowledgedResponse.isAcknowledged()); + countDownLatch.countDown(); + }, e -> { fail("Update failed: " + e); }) + ); + }, e -> fail("Update failed: " + e))); + assertTrue(countDownLatch.await(60, TimeUnit.SECONDS)); + } + + public void updateModelGraveyardAndAssertDeleteModelException( + UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction, + UpdateModelGraveyardRequest updateModelGraveyardRequest, + String indicesPresentInException + ) throws InterruptedException { + final CountDownLatch countDownLatch = new CountDownLatch(1); + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { + ClusterState clusterState1 = stateResponse1.getState(); + updateModelGraveyardTransportAction.clusterManagerOperation( + updateModelGraveyardRequest, + clusterState1, + ActionListener.wrap(acknowledgedResponse -> { + fail(); + }, e -> { + assertTrue(e instanceof DeleteModelException); + assertEquals( + String.format( + "Cannot delete model [%s]. Model is in use by the following indices %s, which must be deleted first.", + updateModelGraveyardRequest.getModelId(), + indicesPresentInException + ), + e.getMessage() + ); + countDownLatch.countDown(); + }) + ); + }, e -> fail("Update failed: " + e))); + + assertTrue(countDownLatch.await(60, TimeUnit.SECONDS)); + } } diff --git a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java index c90afaa62..656b95c2e 100644 --- a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java +++ b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java @@ -312,6 +312,8 @@ public void testRecall_whenFaissIVFFP32_thenRecallAbove75percent() { createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), getModelMapping()); assertRecall(indexName, spaceType, 0.25f); + deleteIndex(indexName); + // Delete the model deleteModel(TEST_MODEL_ID); } @@ -387,6 +389,8 @@ public void testRecall_whenFaissIVFPQFP32_thenRecallAbove50percent() { createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), getModelMapping()); assertRecall(indexName, spaceType, 0.5f); + deleteIndex(indexName); + // Delete the model deleteModel(TEST_MODEL_ID); } @@ -463,6 +467,8 @@ public void testRecall_whenFaissHNSWPQFP32_thenRecallAbove50percent() { createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), getModelMapping()); assertRecall(indexName, spaceType, 0.5f); + deleteIndex(indexName); + // Delete the model deleteModel(TEST_MODEL_ID); }