Skip to content

Commit

Permalink
Added Predict and SearchRequest workflow step
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Mar 27, 2024
1 parent 8820d89 commit d803413
Show file tree
Hide file tree
Showing 5 changed files with 314 additions and 1 deletion.
14 changes: 14 additions & 0 deletions src/main/java/org/opensearch/flowframework/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,20 @@ private CommonValue() {}
/** Pipeline Configurations */
public static final String CONFIGURATIONS = "configurations";

/** Indices searched to build the input dataset of the predict step */
public static final String INPUT_INDEX = "input_index";

/** Source fields to include when fetching documents for the predict step */
public static final String INCLUDES = "includes";

/** Vectors field */
public static final String VECTORS = "vectors";

/** Search request */
public static final String SEARCH_REQUEST = "search_request";
/** Search response */
public static final String SEARCH_RESPONSE = "search_response";

/*
* Constants associated with resource provisioning / state
*/
Expand Down
147 changes: 147 additions & 0 deletions src/main/java/org/opensearch/flowframework/workflow/PredictStep.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.INCLUDES;
import static org.opensearch.flowframework.common.CommonValue.INPUT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.VECTORS;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;

/**
 * Workflow step that runs an ML prediction over documents fetched from one or more indices
 * and publishes the resulting vectors under the {@code vectors} output key.
 *
 * <p>Required inputs: the model id, the indices to read from and the source fields to include.
 */
public class PredictStep implements WorkflowStep {

    private static final Logger logger = LogManager.getLogger(PredictStep.class);

    /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */
    public static final String NAME = "predict";
    private final MachineLearningNodeClient mlClient;

    /**
     * Instantiate this class
     * @param mlClient client to instantiate MLClient
     */
    public PredictStep(MachineLearningNodeClient mlClient) {
        this.mlClient = mlClient;
    }

    /**
     * Resolves the step inputs, builds a match-all search dataset and submits a predict request.
     *
     * @param currentNodeId id of the workflow node being executed
     * @param currentNodeInputs inputs declared on the current node
     * @param outputs outputs of previously executed nodes
     * @param previousNodeInputs mapping used to look up inputs from previous nodes
     * @param params additional parameters forwarded to the input resolution
     * @return a future completed with the vectors on success, or failed with the cause otherwise
     */
    @Override
    public PlainActionFuture<WorkflowData> execute(
        String currentNodeId,
        WorkflowData currentNodeInputs,
        Map<String, WorkflowData> outputs,
        Map<String, String> previousNodeInputs,
        Map<String, String> params
    ) {
        PlainActionFuture<WorkflowData> predictFuture = PlainActionFuture.newFuture();

        ActionListener<MLOutput> actionListener = new ActionListener<>() {

            @Override
            public void onResponse(MLOutput mlOutput) {
                final List<List<Float>> vector;
                try {
                    vector = buildVectorFromResponse(mlOutput);
                } catch (Exception ex) {
                    // Converting the response can throw (unexpected output type, missing data).
                    // Fail the future here; otherwise the exception escapes the listener and the
                    // future is never completed.
                    String errorMessage = "Failed to build vectors from the prediction response";
                    logger.error(errorMessage, ex);
                    predictFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(ex)));
                    return;
                }
                logger.info("Prediction done. Storing vectors");
                predictFuture.onResponse(
                    new WorkflowData(Map.ofEntries(Map.entry(VECTORS, vector)), currentNodeInputs.getWorkflowId(), currentNodeId)
                );
            }

            @Override
            public void onFailure(Exception e) {
                String errorMessage = "Failed to predict the data";
                logger.error(errorMessage, e);
                predictFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e)));
            }
        };

        Set<String> requiredKeys = Set.of(MODEL_ID, INPUT_INDEX, INCLUDES);

        Set<String> optionalKeys = Collections.emptySet();

        try {
            Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
                requiredKeys,
                optionalKeys,
                currentNodeInputs,
                outputs,
                previousNodeInputs,
                params
            );

            String modelId = (String) inputs.get(MODEL_ID);
            // The runtime shape of multi-valued inputs is not visible from this class, so both
            // array and list forms are accepted instead of relying on unchecked casts.
            List<String> indexes = toStringList(inputs.get(INPUT_INDEX), INPUT_INDEX);
            String[] includes = toStringList(inputs.get(INCLUDES), INCLUDES).toArray(new String[0]);

            MLInputDataset inputDataset = new SearchQueryInputDataset(indexes, generateQuery(includes));

            // NOTE(review): the function name is hard-coded to KMEANS while the response handler
            // expects a ModelTensorOutput - confirm that this pairing is intended.
            MLInput mlInput = new MLInput(FunctionName.KMEANS, null, inputDataset);

            mlClient.predict(modelId, mlInput, actionListener);

        } catch (Exception e) {
            predictFuture.onFailure(e);
        }
        return predictFuture;
    }

    /**
     * Normalises a multi-valued step input to a list of strings.
     *
     * @param value the raw input; a {@code String[]}, a {@code List} or a single {@code String}
     * @param fieldName the input key, used in the error message
     * @return a new list holding the string form of each element
     * @throws IllegalArgumentException if the value has any other type (including {@code null})
     */
    private static List<String> toStringList(Object value, String fieldName) {
        if (value instanceof String[]) {
            return new ArrayList<>(Arrays.asList((String[]) value));
        }
        if (value instanceof List) {
            List<String> result = new ArrayList<>();
            for (Object element : (List<?>) value) {
                result.add(String.valueOf(element));
            }
            return result;
        }
        if (value instanceof String) {
            List<String> result = new ArrayList<>();
            result.add((String) value);
            return result;
        }
        throw new IllegalArgumentException("Input [" + fieldName + "] must be a string or a list of strings");
    }

    /**
     * Flattens every tensor of the model output into one list of floats per tensor.
     *
     * @param mlOutput the prediction output, expected to be a {@link ModelTensorOutput}
     * @return one float list per tensor, in response order
     * @throws IllegalStateException if the output is not a {@link ModelTensorOutput}
     */
    private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
        if (!(mlOutput instanceof ModelTensorOutput)) {
            throw new IllegalStateException("Unexpected prediction output type; expected ModelTensorOutput");
        }
        final List<List<Float>> vector = new ArrayList<>();
        final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
        final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
        for (final ModelTensors tensors : tensorOutputList) {
            final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
            for (final ModelTensor tensor : tensorsList) {
                // Convert through Number so that non-Float values (e.g. Double, Integer)
                // do not raise a ClassCastException.
                vector.add(Arrays.stream(tensor.getData()).map(value -> ((Number) value).floatValue()).collect(Collectors.toList()));
            }
        }
        return vector;
    }

    /**
     * Builds a match-all search source limited to 1000 hits that fetches only the given fields.
     *
     * @param includes source fields to include in each hit
     * @return the configured search source
     */
    private SearchSourceBuilder generateQuery(String[] includes) {
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.size(1000);
        searchSourceBuilder.fetchSource(includes, null);
        searchSourceBuilder.query(QueryBuilders.matchAllQuery());
        return searchSourceBuilder;
    }

    @Override
    public String getName() {
        return NAME;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

import static org.opensearch.flowframework.common.CommonValue.*;
import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME;

/**
 * Workflow step that runs a search against an index, optionally using a JSON search body
 * supplied through the {@code configurations} input, and publishes both the request and the
 * response as step outputs.
 */
public class SearchRequestStep implements WorkflowStep {

    private static final Logger logger = LogManager.getLogger(SearchRequestStep.class);
    private final Client client;

    /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */
    public static final String NAME = "search_request";

    /**
     * Instantiate this class
     *
     * @param client Client to search on an index
     */
    public SearchRequestStep(Client client) {
        this.client = client;
    }

    /**
     * Resolves the index name and search configurations, then executes the search.
     *
     * @param currentNodeId id of the workflow node being executed
     * @param currentNodeInputs inputs declared on the current node
     * @param outputs outputs of previously executed nodes
     * @param previousNodeInputs mapping used to look up inputs from previous nodes
     * @param params additional parameters forwarded to the input resolution
     * @return a future completed with the search request and response on success, or failed
     *         when the configurations cannot be parsed or the search fails
     */
    @Override
    public PlainActionFuture<WorkflowData> execute(
        String currentNodeId,
        WorkflowData currentNodeInputs,
        Map<String, WorkflowData> outputs,
        Map<String, String> previousNodeInputs,
        Map<String, String> params
    ) {
        PlainActionFuture<WorkflowData> searchIndexFuture = PlainActionFuture.newFuture();

        Set<String> requiredKeys = Set.of(INDEX_NAME, CONFIGURATIONS);
        Set<String> optionalKeys = Collections.emptySet();

        try {
            Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
                requiredKeys,
                optionalKeys,
                currentNodeInputs,
                outputs,
                previousNodeInputs,
                params
            );

            String indexName = (String) inputs.get(INDEX_NAME);
            String configurations = (String) inputs.get(CONFIGURATIONS);

            SearchRequest searchRequest = new SearchRequest(indexName);

            // An empty configuration means "search with the default source".
            if (!configurations.isEmpty()) {
                BytesReference configurationsBytes = new BytesArray(configurations.getBytes(StandardCharsets.UTF_8));
                // try-with-resources closes the parser once the search source has been read.
                try (
                    XContentParser parser = JsonXContent.jsonXContent.createParser(
                        NamedXContentRegistry.EMPTY,
                        DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
                        configurationsBytes.streamInput()
                    )
                ) {
                    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
                    searchSourceBuilder.parseXContent(parser);
                    searchRequest.source(searchSourceBuilder);
                } catch (IOException ex) {
                    String errorMessage = "Failed to search for the index based on the query;";
                    logger.error(errorMessage, ex);
                    searchIndexFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST));
                    // The step has already failed; do not go on to run a search with a
                    // partially configured request.
                    return searchIndexFuture;
                }
            }

            client.search(searchRequest, ActionListener.wrap(searchResponse -> {
                searchIndexFuture.onResponse(
                    new WorkflowData(
                        Map.ofEntries(Map.entry(SEARCH_RESPONSE, searchResponse), Map.entry(SEARCH_REQUEST, searchRequest)),
                        currentNodeInputs.getWorkflowId(),
                        currentNodeId
                    )
                );
            }, exception -> {
                String errorMessage = "Failed to search on the index";
                logger.error(errorMessage, exception);
                searchIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(exception)));
            }));

        } catch (Exception e) {
            searchIndexFuture.onFailure(e);
        }

        return searchIndexFuture;
    }

    @Override
    public String getName() {
        // Must match the key this step is registered under in the factory.
        return NAME;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION;
import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE;
import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME;
import static org.opensearch.flowframework.common.CommonValue.INCLUDES;
import static org.opensearch.flowframework.common.CommonValue.INPUT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS;
Expand All @@ -47,10 +49,13 @@
import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID;
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.SEARCH_REQUEST;
import static org.opensearch.flowframework.common.CommonValue.SEARCH_RESPONSE;
import static org.opensearch.flowframework.common.CommonValue.SUCCESS;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TYPE;
import static org.opensearch.flowframework.common.CommonValue.URL;
import static org.opensearch.flowframework.common.CommonValue.VECTORS;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;
import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -104,6 +109,8 @@ public WorkflowStepFactory(
);
stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(PredictStep.NAME, () -> new PredictStep(mlClient));
stepMap.put(SearchRequestStep.NAME, () -> new SearchRequestStep(client));
stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient));
stepMap.put(RegisterModelGroupStep.NAME, () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ToolStep.NAME, ToolStep::new);
Expand Down Expand Up @@ -226,6 +233,20 @@ public enum WorkflowSteps {
List.of(PIPELINE_ID),
Collections.emptyList(),
null
),

/** Predict Step */
PREDICT(PredictStep.NAME, List.of(MODEL_ID, INPUT_INDEX, INCLUDES), List.of(VECTORS), Collections.emptyList(), null),

/** Create Search Request Step */
SEARCH_REQUEST_STEP(
SearchRequestStep.NAME,
List.of(INDEX_NAME, CONFIGURATIONS),
// temporary for POC
List.of(SEARCH_REQUEST, SEARCH_RESPONSE),
Collections.emptyList(),
null

);

private final String workflowStepName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void testParseWorkflowValidator() throws IOException {

WorkflowValidator validator = new WorkflowValidator(workflowStepValidators);

assertEquals(17, validator.getWorkflowStepValidators().size());
assertEquals(19, validator.getWorkflowStepValidators().size());

assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector"));
assertEquals(7, validator.getWorkflowStepValidators().get("create_connector").getInputs().size());
Expand Down

0 comments on commit d803413

Please sign in to comment.