From d803413d189a98f37f8ce99802875ddf06261007 Mon Sep 17 00:00:00 2001
From: Owais Kazi
Date: Wed, 27 Mar 2024 13:09:55 -0700
Subject: [PATCH] Added Predict and SearchRequest workflow step

Signed-off-by: Owais Kazi
---
 .../flowframework/common/CommonValue.java   |  14 ++
 .../flowframework/workflow/PredictStep.java | 192 ++++++++++++++++++
 .../workflow/SearchRequestStep.java         | 135 ++++++++++++
 .../workflow/WorkflowStepFactory.java       |  21 ++
 .../model/WorkflowValidatorTests.java       |   2 +-
 5 files changed, 363 insertions(+), 1 deletion(-)
 create mode 100644 src/main/java/org/opensearch/flowframework/workflow/PredictStep.java
 create mode 100644 src/main/java/org/opensearch/flowframework/workflow/SearchRequestStep.java

diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java
index d3960d90b..667c94819 100644
--- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java
+++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java
@@ -169,6 +169,20 @@ private CommonValue() {}
     /** Pipeline Configurations */
     public static final String CONFIGURATIONS = "configurations";
 
+    /** Indexes to read the prediction input from */
+    public static final String INPUT_INDEX = "input_index";
+
+    /** Source fields to include when reading the prediction input */
+    public static final String INCLUDES = "includes";
+
+    /** Vectors field */
+    public static final String VECTORS = "vectors";
+
+    /** Search request */
+    public static final String SEARCH_REQUEST = "search_request";
+    /** Search response */
+    public static final String SEARCH_RESPONSE = "search_response";
+
     /*
      * Constants associated with resource provisioning / state
      */
diff --git a/src/main/java/org/opensearch/flowframework/workflow/PredictStep.java b/src/main/java/org/opensearch/flowframework/workflow/PredictStep.java
new file mode 100644
index 000000000..f5a8121f7
--- /dev/null
+++ b/src/main/java/org/opensearch/flowframework/workflow/PredictStep.java
@@ -0,0 +1,192 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.flowframework.workflow;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.ExceptionsHelper;
+import org.opensearch.action.support.PlainActionFuture;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.flowframework.exception.WorkflowStepException;
+import org.opensearch.flowframework.util.ParseUtils;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.ml.client.MachineLearningNodeClient;
+import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.dataset.MLInputDataset;
+import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
+import org.opensearch.ml.common.input.MLInput;
+import org.opensearch.ml.common.output.MLOutput;
+import org.opensearch.ml.common.output.model.ModelTensor;
+import org.opensearch.ml.common.output.model.ModelTensorOutput;
+import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.search.builder.SearchSourceBuilder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.opensearch.flowframework.common.CommonValue.INCLUDES;
+import static org.opensearch.flowframework.common.CommonValue.INPUT_INDEX;
+import static org.opensearch.flowframework.common.CommonValue.VECTORS;
+import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
+
+/**
+ * Step that runs a prediction over documents fetched from the input indexes and
+ * stores the returned tensor data as vectors in the step output.
+ */
+public class PredictStep implements WorkflowStep {
+
+    private static final Logger logger = LogManager.getLogger(PredictStep.class);
+
+    /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */
+    public static final String NAME = "predict";
+
+    /** Maximum number of documents fetched from the input indexes for one prediction */
+    private static final int MAX_INPUT_DOCUMENTS = 1000;
+
+    private final MachineLearningNodeClient mlClient;
+
+    /**
+     * Instantiate this class
+     * @param mlClient client used to send the predict request
+     */
+    public PredictStep(MachineLearningNodeClient mlClient) {
+        this.mlClient = mlClient;
+    }
+
+    @Override
+    public PlainActionFuture<WorkflowData> execute(
+        String currentNodeId,
+        WorkflowData currentNodeInputs,
+        Map<String, WorkflowData> outputs,
+        Map<String, String> previousNodeInputs,
+        Map<String, String> params
+    ) {
+        PlainActionFuture<WorkflowData> predictFuture = PlainActionFuture.newFuture();
+
+        ActionListener<MLOutput> actionListener = new ActionListener<>() {
+
+            @Override
+            public void onResponse(MLOutput mlOutput) {
+                final List<List<Float>> vector;
+                try {
+                    vector = buildVectorFromResponse(mlOutput);
+                } catch (Exception e) {
+                    // Fail the future here so an unreadable output cannot leave it uncompleted
+                    logger.error("Failed to read the prediction output", e);
+                    predictFuture.onFailure(e);
+                    return;
+                }
+                logger.info("Prediction done. Storing vectors");
+                predictFuture.onResponse(
+                    new WorkflowData(Map.ofEntries(Map.entry(VECTORS, vector)), currentNodeInputs.getWorkflowId(), currentNodeId)
+                );
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                String errorMessage = "Failed to predict the data";
+                logger.error(errorMessage, e);
+                predictFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e)));
+            }
+        };
+
+        Set<String> requiredKeys = Set.of(MODEL_ID, INPUT_INDEX, INCLUDES);
+        Set<String> optionalKeys = Collections.emptySet();
+
+        try {
+            Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
+                requiredKeys,
+                optionalKeys,
+                currentNodeInputs,
+                outputs,
+                previousNodeInputs,
+                params
+            );
+
+            String modelId = (String) inputs.get(MODEL_ID);
+            // Accept either a list or an array so the step does not depend on how the template was parsed
+            List<String> indexes = toStringList(INPUT_INDEX, inputs.get(INPUT_INDEX));
+            String[] includes = toStringList(INCLUDES, inputs.get(INCLUDES)).toArray(new String[0]);
+
+            MLInputDataset inputDataset = new SearchQueryInputDataset(indexes, generateQuery(includes));
+            // NOTE(review): the function name is fixed to KMEANS; confirm it matches the models this step targets
+            MLInput mlInput = new MLInput(FunctionName.KMEANS, null, inputDataset);
+
+            mlClient.predict(modelId, mlInput, actionListener);
+        } catch (Exception e) {
+            predictFuture.onFailure(e);
+        }
+        return predictFuture;
+    }
+
+    /**
+     * Converts a step input holding one or more strings into a list of strings.
+     * @param key name of the input, used in the error message
+     * @param value the raw input: a collection, an array or a single string
+     * @return the input values as strings
+     * @throws WorkflowStepException if the value is none of the supported types
+     */
+    private static List<String> toStringList(String key, Object value) {
+        if (value instanceof Collection) {
+            return ((Collection<?>) value).stream().map(String::valueOf).collect(Collectors.toList());
+        }
+        if (value instanceof Object[]) {
+            return Arrays.stream((Object[]) value).map(String::valueOf).collect(Collectors.toList());
+        }
+        if (value instanceof String) {
+            return List.of((String) value);
+        }
+        throw new WorkflowStepException("Input [" + key + "] must be a string or a list of strings", RestStatus.BAD_REQUEST);
+    }
+
+    /**
+     * Collects the data of every tensor in the prediction output as a list of floats.
+     * @param mlOutput the prediction output
+     * @return one list of floats per tensor
+     * @throws WorkflowStepException if the output is not a {@link ModelTensorOutput}
+     */
+    private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
+        if (!(mlOutput instanceof ModelTensorOutput)) {
+            throw new WorkflowStepException("Prediction output does not contain model tensors", RestStatus.INTERNAL_SERVER_ERROR);
+        }
+        final List<List<Float>> vector = new ArrayList<>();
+        for (final ModelTensors tensors : ((ModelTensorOutput) mlOutput).getMlModelOutputs()) {
+            for (final ModelTensor tensor : tensors.getMlModelTensors()) {
+                // Convert rather than cast so tensor values that are not Float instances do not fail
+                vector.add(Arrays.stream(tensor.getData()).map(Number::floatValue).collect(Collectors.toList()));
+            }
+        }
+        return vector;
+    }
+
+    /**
+     * Builds a match-all search that fetches only the given source fields.
+     * @param includes the source fields to fetch
+     * @return the search source used to read the prediction input
+     */
+    private SearchSourceBuilder generateQuery(String[] includes) {
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.size(MAX_INPUT_DOCUMENTS);
+        searchSourceBuilder.fetchSource(includes, null);
+        searchSourceBuilder.query(QueryBuilders.matchAllQuery());
+        return searchSourceBuilder;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+}
diff --git a/src/main/java/org/opensearch/flowframework/workflow/SearchRequestStep.java b/src/main/java/org/opensearch/flowframework/workflow/SearchRequestStep.java
new file mode 100644
index 000000000..133fbc13a
--- /dev/null
+++ b/src/main/java/org/opensearch/flowframework/workflow/SearchRequestStep.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.flowframework.workflow;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.ExceptionsHelper;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.support.PlainActionFuture;
+import org.opensearch.client.Client;
+import org.opensearch.common.xcontent.json.JsonXContent;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.bytes.BytesArray;
+import org.opensearch.core.common.bytes.BytesReference;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.core.xcontent.DeprecationHandler;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.flowframework.exception.WorkflowStepException;
+import org.opensearch.flowframework.util.ParseUtils;
+import org.opensearch.search.builder.SearchSourceBuilder;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+
+import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
+import static org.opensearch.flowframework.common.CommonValue.SEARCH_REQUEST;
+import static org.opensearch.flowframework.common.CommonValue.SEARCH_RESPONSE;
+import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME;
+
+/**
+ * Step that runs a search request on an index and stores the request and its response in the step output.
+ */
+public class SearchRequestStep implements WorkflowStep {
+
+    private static final Logger logger = LogManager.getLogger(SearchRequestStep.class);
+    private final Client client;
+
+    /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */
+    public static final String NAME = "search_request";
+
+    /**
+     * Instantiate this class
+     *
+     * @param client Client to search on an index
+     */
+    public SearchRequestStep(Client client) {
+        this.client = client;
+    }
+
+    @Override
+    public PlainActionFuture<WorkflowData> execute(
+        String currentNodeId,
+        WorkflowData currentNodeInputs,
+        Map<String, WorkflowData> outputs,
+        Map<String, String> previousNodeInputs,
+        Map<String, String> params
+    ) {
+        PlainActionFuture<WorkflowData> searchIndexFuture = PlainActionFuture.newFuture();
+
+        Set<String> requiredKeys = Set.of(INDEX_NAME, CONFIGURATIONS);
+        Set<String> optionalKeys = Collections.emptySet();
+
+        try {
+            Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
+                requiredKeys,
+                optionalKeys,
+                currentNodeInputs,
+                outputs,
+                previousNodeInputs,
+                params
+            );
+
+            String indexName = (String) inputs.get(INDEX_NAME);
+            String configurations = (String) inputs.get(CONFIGURATIONS);
+            SearchRequest searchRequest = new SearchRequest(indexName);
+
+            // Empty configurations mean the request is sent without a search source
+            if (!configurations.isEmpty()) {
+                BytesReference configurationsBytes = new BytesArray(configurations.getBytes(StandardCharsets.UTF_8));
+                // NOTE(review): an empty registry may not resolve named query builders; confirm whether the node registry is needed
+                try (
+                    XContentParser parser = JsonXContent.jsonXContent.createParser(
+                        NamedXContentRegistry.EMPTY,
+                        DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+                        configurationsBytes.streamInput()
+                    )
+                ) {
+                    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+                    searchSourceBuilder.parseXContent(parser);
+                    searchRequest.source(searchSourceBuilder);
+                } catch (IOException ex) {
+                    String errorMessage = "Failed to parse the search request configurations";
+                    logger.error(errorMessage, ex);
+                    searchIndexFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST));
+                    // Stop here: the search must not run with a source that could not be parsed
+                    return searchIndexFuture;
+                }
+            }
+
+            client.search(searchRequest, ActionListener.wrap(searchResponse -> {
+                searchIndexFuture.onResponse(
+                    new WorkflowData(
+                        Map.ofEntries(Map.entry(SEARCH_RESPONSE, searchResponse), Map.entry(SEARCH_REQUEST, searchRequest)),
+                        currentNodeInputs.getWorkflowId(),
+                        currentNodeId
+                    )
+                );
+            }, exception -> {
+                String errorMessage = "Failed to search on the index";
+                logger.error(errorMessage, exception);
+                searchIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(exception)));
+            }));
+
+        } catch (Exception e) {
+            searchIndexFuture.onFailure(e);
+        }
+
+        return searchIndexFuture;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+}
diff --git 
a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 224cbf1eb..ad338a715 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -37,6 +37,8 @@ import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; +import static org.opensearch.flowframework.common.CommonValue.INCLUDES; +import static org.opensearch.flowframework.common.CommonValue.INPUT_INDEX; import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; @@ -47,10 +49,13 @@ import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.SEARCH_REQUEST; +import static org.opensearch.flowframework.common.CommonValue.SEARCH_RESPONSE; import static org.opensearch.flowframework.common.CommonValue.SUCCESS; import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; import static org.opensearch.flowframework.common.CommonValue.URL; +import static org.opensearch.flowframework.common.CommonValue.VECTORS; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; import static 
org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; @@ -104,6 +109,8 @@ public WorkflowStepFactory( ); stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(PredictStep.NAME, () -> new PredictStep(mlClient)); + stepMap.put(SearchRequestStep.NAME, () -> new SearchRequestStep(client)); stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); stepMap.put(RegisterModelGroupStep.NAME, () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ToolStep.NAME, ToolStep::new); @@ -226,6 +233,20 @@ public enum WorkflowSteps { List.of(PIPELINE_ID), Collections.emptyList(), null + ), + + /** Predict Step */ + PREDICT(PredictStep.NAME, List.of(MODEL_ID, INPUT_INDEX, INCLUDES), List.of(VECTORS), Collections.emptyList(), null), + + /** Create Search Request Step */ + SEARCH_REQUEST_STEP( + SearchRequestStep.NAME, + List.of(INDEX_NAME, CONFIGURATIONS), + // temporary for POC + List.of(SEARCH_REQUEST, SEARCH_RESPONSE), + Collections.emptyList(), + null + ); private final String workflowStepName; diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 80a9788c2..5b7848c9b 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -46,7 +46,7 @@ public void testParseWorkflowValidator() throws IOException { WorkflowValidator validator = new WorkflowValidator(workflowStepValidators); - assertEquals(17, validator.getWorkflowStepValidators().size()); + assertEquals(19, validator.getWorkflowStepValidators().size()); assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector")); assertEquals(7, 
validator.getWorkflowStepValidators().get("create_connector").getInputs().size());