Skip to content

Commit

Permalink
addressed comments and fixed some unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Oct 30, 2023
1 parent b334ca3 commit 2f8efed
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 169 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public enum FlowFrameworkIndex {
),
WORKFLOW_STATE(
WORKFLOW_STATE_INDEX,
ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getGlobalContextMappings),
ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getWorkflowStateMappings),
WORKFLOW_STATE_INDEX_VERSION
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,16 @@
import static org.opensearch.flowframework.common.CommonValue.NO_SCHEMA_VERSION;
import static org.opensearch.flowframework.common.CommonValue.SCHEMA_VERSION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX_MAPPING;

/**
* A handler for global context related operations
* A handler for operations on system indices in the AI Flow Framework plugin
* The current indices we have are global-context and workflow-state indices
*/
public class FlowFrameworkIndicesHandler {
private static final Logger logger = LogManager.getLogger(FlowFrameworkIndicesHandler.class);
private final Client client;
ClusterService clusterService;
private final ClusterService clusterService;
private static final Map<String, AtomicBoolean> indexMappingUpdated = new HashMap<>();
private static final Map<String, Object> indexSettings = Map.of("index.auto_expand_replicas", "0-1");

Expand All @@ -70,6 +72,9 @@ public class FlowFrameworkIndicesHandler {
public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService) {
this.client = client;
this.clusterService = clusterService;
for (FlowFrameworkIndex mlIndex : FlowFrameworkIndex.values()) {
indexMappingUpdated.put(mlIndex.getIndexName(), new AtomicBoolean(false));
}
}

static {
Expand All @@ -87,6 +92,15 @@ public static String getGlobalContextMappings() throws IOException {
return getIndexMappings(GLOBAL_CONTEXT_INDEX_MAPPING);
}

/**
* Get workflow-state index mapping
* @return workflow-state index mapping
* @throws IOException if mapping file cannot be read correctly
*/
public static String getWorkflowStateMappings() throws IOException {
return getIndexMappings(WORKFLOW_STATE_INDEX_MAPPING);
}

/**
* Create global context index if it's absent
* @param listener The action listener
Expand Down Expand Up @@ -314,9 +328,9 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL
*/
public void updateTemplateInGlobalContext(String documentId, Template template, ActionListener<IndexResponse> listener) {
if (!doesIndexExist(GLOBAL_CONTEXT_INDEX)) {
String exceptionMessage = "Failed to update workflow state for workflow_id : "
String exceptionMessage = "Failed to update template for workflow_id : "
+ documentId
+ ", workflow_state index does not exist.";
+ ", global_context index does not exist.";
logger.error(exceptionMessage);
listener.onFailure(new Exception(exceptionMessage));
} else {
Expand All @@ -337,53 +351,34 @@ public void updateTemplateInGlobalContext(String documentId, Template template,

/**
* Updates a document in the workflow state index
* @param workflowStateDocId the document ID
* @param indexName the index that we will be updating a document of.
* @param documentId the document ID
* @param updatedFields the fields to update the global state index with
* @param listener action listener
*/
public void updateWorkflowState(String workflowStateDocId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener) {
if (!doesIndexExist(WORKFLOW_STATE_INDEX)) {
String exceptionMessage = "Failed to update state for given workflow due to missing workflow_state index";
public void updateFlowFrameworkSystemIndexDoc(
String indexName,
String documentId,
Map<String, Object> updatedFields,
ActionListener<UpdateResponse> listener
) {
if (!doesIndexExist(indexName)) {
String exceptionMessage = "Failed to update document for given workflow due to missing " + indexName + " index";
logger.error(exceptionMessage);
listener.onFailure(new Exception(exceptionMessage));
} else {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
UpdateRequest updateRequest = new UpdateRequest(WORKFLOW_STATE_INDEX, workflowStateDocId);
UpdateRequest updateRequest = new UpdateRequest(indexName, documentId);
Map<String, Object> updatedContent = new HashMap<>();
updatedContent.putAll(updatedFields);
updateRequest.doc(updatedContent);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
// TODO: decide what condition can be considered as an update conflict and add retry strategy
client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
logger.error("Failed to update workflow_state entry : {}. {}", workflowStateDocId, e.getMessage());
logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage());
listener.onFailure(e);
}
}
}

/**
* Update global context index for specific fields
* @param documentId global context index document id
* @param updatedFields updated fields; key: field name, value: new value
* @param listener UpdateResponse action listener
*/
public void storeResponseToGlobalContext(
String documentId,
Map<String, Object> updatedFields,
ActionListener<UpdateResponse> listener
) {
UpdateRequest updateRequest = new UpdateRequest(GLOBAL_CONTEXT_INDEX, documentId);
Map<String, Object> updatedUserOutputsContext = new HashMap<>();
updatedUserOutputsContext.putAll(updatedFields);
updateRequest.doc(updatedUserOutputsContext);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
// TODO: decide what condition can be considered as an update conflict and add retry strategy

try {
client.update(updateRequest, listener);
} catch (Exception e) {
logger.error("Failed to update global_context index");
listener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
// TODO: transfer this to more detailed array for each step
public enum ProvisioningProgress {
NOT_STARTED,
IN_PROGRESS,
DONE,
NOT_STARTED
DONE
}
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,6 @@ public String toString() {
+ compatibilityVersion
+ ", workflows="
+ workflows
+ ", user="
+ user
+ "]";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.model.WorkflowState.PROVISIONING_PROGRESS_FIELD;
import static org.opensearch.flowframework.model.WorkflowState.STATE_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.getUserContext;
Expand Down Expand Up @@ -78,7 +79,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
globalContextResponse.getId(),
user,
ActionListener.wrap(stateResponse -> {
logger.info("create state workflow doc " + stateResponse);
logger.info("create state workflow doc");
listener.onResponse(new WorkflowResponse(globalContextResponse.getId()));
}, exception -> {
logger.error("Failed to save workflow state : {}", exception.getMessage());
Expand All @@ -95,11 +96,12 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
request.getWorkflowId(),
request.getTemplate(),
ActionListener.wrap(response -> {
flowFrameworkIndicesHandler.updateWorkflowState(
flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(
WORKFLOW_STATE_INDEX,
request.getWorkflowId(),
ImmutableMap.of(STATE_FIELD, State.NOT_STARTED, PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED),
ActionListener.wrap(updateResponse -> {
logger.info("updated workflow {} state to NOT_STARTED", request.getWorkflowId());
logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.NOT_STARTED.name());
listener.onResponse(new WorkflowResponse(request.getWorkflowId()));
}, exception -> {
logger.error("Failed to update workflow state : {}", exception.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
Expand All @@ -44,10 +43,10 @@
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.model.WorkflowState.PROVISIONING_PROGRESS_FIELD;
import static org.opensearch.flowframework.model.WorkflowState.PROVISION_START_TIME_FIELD;
import static org.opensearch.flowframework.model.WorkflowState.STATE_FIELD;
import static org.opensearch.flowframework.model.WorkflowState.USER_OUTPUTS_FIELD;

/**
* Transport Action to provision a workflow from a stored use case template
Expand Down Expand Up @@ -88,7 +87,6 @@ public ProvisionWorkflowTransportAction(

@Override
protected void doExecute(Task task, WorkflowRequest request, ActionListener<WorkflowResponse> listener) {

// Retrieve use case template from global context
String workflowId = request.getWorkflowId();
GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId);
Expand All @@ -111,25 +109,20 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
// Parse template from document source
Template template = Template.parse(response.getSourceAsString());

flowFrameworkIndicesHandler.updateWorkflowState(
flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(
WORKFLOW_STATE_INDEX,
workflowId,
ImmutableMap.of(
STATE_FIELD,
State.PROVISIONING,
PROVISIONING_PROGRESS_FIELD,
ProvisioningProgress.IN_PROGRESS,
PROVISION_START_TIME_FIELD,
Instant.now().toEpochMilli(),
USER_OUTPUTS_FIELD,
Map.of("key1", "key2")
Instant.now().toEpochMilli()
),
ActionListener.wrap(updateResponse -> {
logger.info("updated workflow {} state to PROVISIONING", request.getWorkflowId());
listener.onResponse(new WorkflowResponse(request.getWorkflowId()));
}, exception -> {
logger.error("Failed to update workflow state : {}", exception.getMessage());
listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST));
})
}, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage()); })
);

// Respond to rest action then execute provisioning workflow async
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ public CompletableFuture<WorkflowData> execute() {
if (this.future.isDone()) {
throw new IllegalStateException("Process Node [" + this.id + "] already executed.");
}

CompletableFuture.runAsync(() -> {
List<CompletableFuture<WorkflowData>> predFutures = predecessors.stream().map(p -> p.future()).collect(Collectors.toList());
try {
Expand All @@ -152,9 +153,11 @@ public CompletableFuture<WorkflowData> execute() {
}
}, this.nodeTimeout, ThreadPool.Names.SAME);
}
// record start time for this step.
CompletableFuture<WorkflowData> stepFuture = this.workflowStep.execute(input);
// If completed exceptionally, this is a no-op
future.complete(stepFuture.get());
// record end time passing workflow steps
if (delayExec != null) {
delayExec.cancel();
}
Expand Down
Loading

0 comments on commit 2f8efed

Please sign in to comment.