Skip to content

Commit

Permalink
Add coverage and changelog
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Aug 12, 2024
1 parent 3024b91 commit 2648713
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
### Documentation
### Maintenance
### Refactoring
- Refactor workflow step resource updates to eliminate duplication ([#796](https://github.com/opensearch-project/flow-framework/pull/796))
Original file line number Diff line number Diff line change
Expand Up @@ -667,13 +667,13 @@ public void updateFlowFrameworkSystemIndexDocWithScript(
/**
* Creates a new ResourceCreated object and a script to update the state index
* @param workflowId workflowId for the relevant step
* @param nodeId WorkflowData object with relevent step information
* @param nodeId current process node (workflow step) id
* @param workflowStepName the workflowstep name that created the resource
* @param resourceId the id of the newly created resource
* @param listener the ActionListener for this step to handle completing the future after update
* @throws IOException if parsing fails on new resource
*/
public void updateResourceInStateIndex(
private void updateResourceInStateIndex(
String workflowId,
String nodeId,
String workflowStepName,
Expand Down Expand Up @@ -704,7 +704,7 @@ public void updateResourceInStateIndex(
/**
* Adds a resource to the state index, including common exception handling
* @param currentNodeInputs Inputs to the current node
* @param nodeId WorkflowData object with relevent step information
* @param nodeId current process node (workflow step) id
* @param workflowStepName the workflow step name that created the resource
* @param resourceId the id of the newly created resource
* @param listener the ActionListener for this step to handle completing the future after update
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.flowframework.TestHelpers;
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.model.ProvisioningProgress;
import org.opensearch.flowframework.model.ResourceCreated;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.workflow.CreateConnectorStep;
import org.opensearch.flowframework.workflow.CreateIndexStep;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.index.get.GetResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -488,4 +491,52 @@ public void testDeleteFlowFrameworkSystemIndexDoc() throws IOException {
exceptionCaptor.getValue().getMessage()
);
}

public void testAddResourceToStateIndex() throws IOException {
ClusterState mockClusterState = mock(ClusterState.class);
Metadata mockMetaData = mock(Metadata.class);
when(clusterService.state()).thenReturn(mockClusterState);
when(mockClusterState.metadata()).thenReturn(mockMetaData);
when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(true);

@SuppressWarnings("unchecked")
ActionListener<WorkflowData> listener = mock(ActionListener.class);
// test success
doAnswer(invocation -> {
ActionListener<UpdateResponse> responseListener = invocation.getArgument(1);
responseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "this_id", -2, 0, 0, Result.UPDATED));
return null;
}).when(client).update(any(UpdateRequest.class), any());

flowFrameworkIndicesHandler.addResourceToStateIndex(
new WorkflowData(Collections.emptyMap(), null, null),
"node_id",
CreateConnectorStep.NAME,
"this_id",
listener
);

ArgumentCaptor<WorkflowData> responseCaptor = ArgumentCaptor.forClass(WorkflowData.class);
verify(listener, times(1)).onResponse(responseCaptor.capture());
assertEquals("this_id", responseCaptor.getValue().getContent().get(WorkflowResources.CONNECTOR_ID));

// test failure
doAnswer(invocation -> {
ActionListener<UpdateResponse> responseListener = invocation.getArgument(1);
responseListener.onFailure(new Exception("Failed to update state"));
return null;
}).when(client).update(any(UpdateRequest.class), any());

flowFrameworkIndicesHandler.addResourceToStateIndex(
new WorkflowData(Collections.emptyMap(), null, null),
"node_id",
CreateConnectorStep.NAME,
"this_id",
listener
);

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals("Failed to update new created node_id resource create_connector id this_id", exceptionCaptor.getValue().getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.flowframework.common.CommonValue.DELAY_FIELD;

Expand Down Expand Up @@ -50,6 +53,49 @@ public void testNoOpStepDelay() throws IOException, InterruptedException {
assertTrue(System.nanoTime() - start > 900_000_000L);
}

public void testNoOpStepInterrupt() throws IOException, InterruptedException {
NoOpStep noopStep = new NoOpStep();
WorkflowData delayData = new WorkflowData(Map.of(DELAY_FIELD, "5s"), null, null);

CountDownLatch latch = new CountDownLatch(1);
// Fetch errors from the separate thread
AtomicReference<AssertionError> assertionError = new AtomicReference<>();

Thread testThread = new Thread(() -> {
try {
PlainActionFuture<WorkflowData> future = noopStep.execute(
"nodeId",
delayData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);
try {
future.actionGet();
} catch (Exception e) {
// Ignore the IllegalStateExcption/InterruptedExcpetion
}
assertTrue(future.isDone());
assertTrue(future.isCancelled());
assertTrue(Thread.currentThread().isInterrupted());
} catch (AssertionError e) {
assertionError.set(e);
} finally {
latch.countDown();
}
});

testThread.start();
Thread.sleep(100);
testThread.interrupt();

latch.await(1, TimeUnit.SECONDS);

if (assertionError.get() != null) {
throw assertionError.get();
}
}

public void testNoOpStepParse() throws IOException {
NoOpStep noopStep = new NoOpStep();
WorkflowData delayData = new WorkflowData(Map.of(DELAY_FIELD, "foo"), null, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
Expand Down Expand Up @@ -221,6 +222,88 @@ public void testRegisterLocalCustomModelSuccess() throws Exception {
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
}

// This method tests code in the abstract parent
public void testRegisterLocalCustomModelDeployStateUpdateFail() throws Exception {
String taskId = "abcd";
String modelId = "model-id";
String status = MLTaskState.COMPLETED.name();

// Stub register for success case
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any());

// Stub getTask for success case
doAnswer(invocation -> {
ActionListener<MLTask> actionListener = invocation.getArgument(1);
MLTask output = new MLTask(
taskId,
modelId,
null,
null,
MLTaskState.COMPLETED,
null,
null,
null,
null,
null,
null,
null,
null,
false
);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).getTask(any(), any());

AtomicInteger invocationCount = new AtomicInteger(0);
doAnswer(invocation -> {
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(4);
if (invocationCount.getAndIncrement() == 0) {
// succeed on first call (update register)
updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id"));
} else {
// fail on second call (update deploy)
updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource"));
}
return null;
}).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any());

WorkflowData boolStringWorkflowData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "xyz"),
Map.entry("version", "1.0.0"),
Map.entry("description", "description"),
Map.entry("function_name", "SPARSE_TOKENIZE"),
Map.entry("model_format", "TORCH_SCRIPT"),
Map.entry(MODEL_GROUP_ID, "abcdefg"),
Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"),
Map.entry("model_type", "bert"),
Map.entry("embedding_dimension", "384"),
Map.entry("framework_type", "sentence_transformers"),
Map.entry("url", "something.com"),
Map.entry(DEPLOY_FIELD, "true")
),
"test-id",
"test-node-id"
);

PlainActionFuture<WorkflowData> future = registerLocalModelStep.execute(
boolStringWorkflowData.getNodeId(),
boolStringWorkflowData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);

ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to update simulated deploy step resource model-id", ex.getCause().getMessage());
}

public void testRegisterLocalCustomModelFailure() {

doAnswer(invocation -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
Expand Down Expand Up @@ -224,6 +225,99 @@ public void testRegisterRemoteModelFailure() {

}

public void testRegisterRemoteModelUpdateFailure() {
String taskId = "abcd";
String modelId = "efgh";
String status = MLTaskState.CREATED.name();

doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId);
actionListener.onResponse(output);
return null;
}).when(mlNodeClient).register(any(MLRegisterModelInput.class), any());

doAnswer(invocation -> {
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onFailure(new RuntimeException("Failed to update register resource"));
return null;
}).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any());

WorkflowData deployWorkflowData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "xyz"),
Map.entry("description", "description"),
Map.entry(CONNECTOR_ID, "abcdefg"),
Map.entry(DEPLOY_FIELD, true)
),
"test-id",
"test-node-id"
);

PlainActionFuture<WorkflowData> future = this.registerRemoteModelStep.execute(
deployWorkflowData.getNodeId(),
deployWorkflowData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);

assertTrue(future.isDone());
ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to update new created test-node-id resource register_remote_model id efgh", ex.getCause().getMessage());
}

public void testRegisterRemoteModelDeployUpdateFailure() {
String taskId = "abcd";
String modelId = "efgh";
String status = MLTaskState.CREATED.name();

doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId);
actionListener.onResponse(output);
return null;
}).when(mlNodeClient).register(any(MLRegisterModelInput.class), any());

AtomicInteger invocationCount = new AtomicInteger(0);
doAnswer(invocation -> {
ActionListener<WorkflowData> updateResponseListener = invocation.getArgument(4);
if (invocationCount.getAndIncrement() == 0) {
// succeed on first call (update register)
updateResponseListener.onResponse(new WorkflowData(Map.of(MODEL_ID, modelId), "test-id", "test-node-id"));
} else {
// fail on second call (update deploy)
updateResponseListener.onFailure(new RuntimeException("Failed to update deploy resource"));
}
return null;
}).when(flowFrameworkIndicesHandler).addResourceToStateIndex(any(WorkflowData.class), anyString(), anyString(), anyString(), any());

WorkflowData deployWorkflowData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "xyz"),
Map.entry("description", "description"),
Map.entry(CONNECTOR_ID, "abcdefg"),
Map.entry(DEPLOY_FIELD, true)
),
"test-id",
"test-node-id"
);

PlainActionFuture<WorkflowData> future = this.registerRemoteModelStep.execute(
deployWorkflowData.getNodeId(),
deployWorkflowData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);

assertTrue(future.isDone());
ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to update simulated deploy step resource efgh", ex.getCause().getMessage());
}

public void testReisterRemoteModelInterfaceFailure() {
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
Expand Down

0 comments on commit 2648713

Please sign in to comment.