Skip to content

Commit

Permalink
Even more coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Jul 20, 2024
1 parent de9de39 commit 52c63dd
Showing 1 changed file with 83 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
Expand Down Expand Up @@ -221,6 +222,88 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
}

// This method tests code in the abstract parent
public void testRegisterLocalCustomModelDeployStateUpdateFail() throws Exception {
String taskId = "abcd";
String modelId = "model-id";
String status = MLTaskState.COMPLETED.name();

// Stub register for success case
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any());

// Stub getTask for success case
doAnswer(invocation -> {
ActionListener<MLTask> actionListener = invocation.getArgument(1);
MLTask output = new MLTask(
taskId,
modelId,
null,
null,
MLTaskState.COMPLETED,
null,
null,
null,
null,
null,
null,
null,
null,
false
);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).getTask(any(), any());

AtomicInteger invocationCount = new AtomicInteger(0);
doAnswer(invocation -> {
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(4);
if (invocationCount.getAndIncrement() == 0) {
// succeed on first call (update register)
updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id"));
} else {
// fail on second call (update deploy)
updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource"));
}
return null;
}).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any());

WorkflowData boolStringWorkflowData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "xyz"),
Map.entry("version", "1.0.0"),
Map.entry("description", "description"),
Map.entry("function_name", "SPARSE_TOKENIZE"),
Map.entry("model_format", "TORCH_SCRIPT"),
Map.entry(MODEL_GROUP_ID, "abcdefg"),
Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"),
Map.entry("model_type", "bert"),
Map.entry("embedding_dimension", "384"),
Map.entry("framework_type", "sentence_transformers"),
Map.entry("url", "something.com"),
Map.entry(DEPLOY_FIELD, "true")
),
"test-id",
"test-node-id"
);

PlainActionFuture<WorkflowData> future = registerLocalModelStep.execute(
boolStringWorkflowData.getNodeId(),
boolStringWorkflowData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);

ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to update simulated deploy step resource model-id", ex.getCause().getMessage());
}

public void testRegisterLocalCustomModelFailure() {

doAnswer(invocation -> {
Expand Down

0 comments on commit 52c63dd

Please sign in to comment.