Skip to content

Commit

Permalink
Adds Create Ingest Pipeline Step (#44)
Browse files Browse the repository at this point in the history
* Initial create ingest pipeline implementation

Signed-off-by: Joshua Palis <[email protected]>

* Adding TODOs for storing response to global context index, fixing comments

Signed-off-by: Joshua Palis <[email protected]>

* Updating workflowData interface, modifying CreateIngestPipelineStep

Signed-off-by: Joshua Palis <[email protected]>

* updating workflow data extraction to read pipelineId from parameters rather than from content

Signed-off-by: Joshua Palis <[email protected]>

* removing unecessary cast

Signed-off-by: Joshua Palis <[email protected]>

* Pulls all required data from content rather than from params, fixes javadoc error

Signed-off-by: Joshua Palis <[email protected]>

* fixing comments

Signed-off-by: Joshua Palis <[email protected]>

* addressing PR comments, adding switch statement to handle parsing workflow data for required fields

Signed-off-by: Joshua Palis <[email protected]>

* Adding entry import

Signed-off-by: Joshua Palis <[email protected]>

* fixing comments

Signed-off-by: Joshua Palis <[email protected]>

* Adds unit tests for create ingest pipeline step, fixes pipeline request body generator

Signed-off-by: Joshua Palis <[email protected]>

* Adding failure tests

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments

Signed-off-by: Joshua Palis <[email protected]>

* Fixing workflow data

Signed-off-by: Joshua Palis <[email protected]>

---------

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis committed Sep 25, 2023
1 parent a530739 commit 1d22bee
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.flowframework.workflow.CreateIndex.CreateIndexStep;
import org.opensearch.flowframework.workflow.CreateIngestPipelineStep;
import org.opensearch.plugins.Plugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
Expand Down Expand Up @@ -48,7 +49,8 @@ public Collection<Object> createComponents(
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
this.client = client;
CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client);
CreateIndexStep createIndexStep = new CreateIndexStep(client);
return ImmutableList.of(createIndexStep);
return ImmutableList.of(createIngestPipelineStep, createIndexStep);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ingest.PutPipelineRequest;
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

/**
* Workflow step to create an ingest pipeline
*/
public class CreateIngestPipelineStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(CreateIngestPipelineStep.class);
private static final String NAME = "create_ingest_pipeline_step";

// Common pipeline configuration fields
private static final String PIPELINE_ID_FIELD = "id";
private static final String DESCRIPTION_FIELD = "description";
private static final String PROCESSORS_FIELD = "processors";
private static final String TYPE_FIELD = "type";

// Temporary text embedding processor fields
private static final String FIELD_MAP = "field_map";
private static final String MODEL_ID_FIELD = "model_id";
private static final String INPUT_FIELD = "input_field_name";
private static final String OUTPUT_FIELD = "output_field_name";

// Client to store a pipeline in the cluster state
private final ClusterAdminClient clusterAdminClient;

/**
* Instantiates a new CreateIngestPipelineStep
*
* @param client The client to create a pipeline and store workflow data into the global context index
*/
public CreateIngestPipelineStep(Client client) {
this.clusterAdminClient = client.admin().cluster();
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> createIngestPipelineFuture = new CompletableFuture<>();

String pipelineId = null;
String description = null;
String type = null;
String modelId = null;
String inputFieldName = null;
String outputFieldName = null;
BytesReference configuration = null;

// Extract required content from workflow data and generate the ingest pipeline configuration
for (WorkflowData workflowData : data) {

Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case PIPELINE_ID_FIELD:
pipelineId = (String) content.get(PIPELINE_ID_FIELD);
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case TYPE_FIELD:
type = (String) content.get(TYPE_FIELD);
break;
case MODEL_ID_FIELD:
modelId = (String) content.get(MODEL_ID_FIELD);
break;
case INPUT_FIELD:
inputFieldName = (String) content.get(INPUT_FIELD);
break;
case OUTPUT_FIELD:
outputFieldName = (String) content.get(OUTPUT_FIELD);
break;
default:
break;
}
}

// Determmine if fields have been populated, else iterate over remaining workflow data
if (Stream.of(pipelineId, description, modelId, type, inputFieldName, outputFieldName).allMatch(x -> x != null)) {
try {
configuration = BytesReference.bytes(
buildIngestPipelineRequestContent(description, modelId, type, inputFieldName, outputFieldName)
);
} catch (IOException e) {
logger.error("Failed to create ingest pipeline configuration: " + e.getMessage());
createIngestPipelineFuture.completeExceptionally(e);
}
break;
}
}

if (configuration == null) {
// Required workflow data not found
createIngestPipelineFuture.completeExceptionally(new Exception("Failed to create ingest pipeline, required inputs not found"));
} else {
// Create PutPipelineRequest and execute
PutPipelineRequest putPipelineRequest = new PutPipelineRequest(pipelineId, configuration, XContentType.JSON);
clusterAdminClient.putPipeline(putPipelineRequest, ActionListener.wrap(response -> {
logger.info("Created ingest pipeline : " + putPipelineRequest.getId());

// PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead
createIngestPipelineFuture.complete(new WorkflowData(Map.of("pipelineId", putPipelineRequest.getId())));

// TODO : Use node client to index response data to global context (pending global context index implementation)

}, exception -> {
logger.error("Failed to create ingest pipeline : " + exception.getMessage());
createIngestPipelineFuture.completeExceptionally(exception);
}));
}

return createIngestPipelineFuture;
}

@Override
public String getName() {
return NAME;
}

/**
* Temporary, generates the ingest pipeline request content for text_embedding processor from workflow data
* {
* "description" : "<description>",
* "processors" : [
* {
* "<type>" : {
* "model_id" : "<model_id>",
* "field_map" : {
* "<input_field_name>" : "<output_field_name>"
* }
* }
* ]
* }
*
* @param description The description of the ingest pipeline configuration
* @param modelId The ID of the model that will be used in the embedding interface
* @param type The processor type
* @param inputFieldName The field name used to cache text for text embeddings
* @param outputFieldName The field name in which output text is stored
* @throws IOException if the request content fails to be generated
* @return the xcontent builder with the formatted ingest pipeline configuration
*/
private XContentBuilder buildIngestPipelineRequestContent(
String description,
String modelId,
String type,
String inputFieldName,
String outputFieldName
) throws IOException {
return XContentFactory.jsonBuilder()
.startObject()
.field(DESCRIPTION_FIELD, description)
.startArray(PROCESSORS_FIELD)
.startObject()
.startObject(type)
.field(MODEL_ID_FIELD, modelId)
.startObject(FIELD_MAP)
.field(inputFieldName, outputFieldName)
.endObject()
.endObject()
.endObject()
.endArray()
.endObject();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow.CreateIngestPipeline;

import org.opensearch.action.ingest.PutPipelineRequest;
import org.opensearch.action.support.master.AcknowledgedResponse;
import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.ClusterAdminClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.workflow.CreateIngestPipelineStep;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.test.OpenSearchTestCase;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.ArgumentCaptor;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class CreateIngestPipelineStepTests extends OpenSearchTestCase {

private WorkflowData inputData;
private WorkflowData outpuData;
private Client client;
private AdminClient adminClient;
private ClusterAdminClient clusterAdminClient;

@Override
public void setUp() throws Exception {
super.setUp();

inputData = new WorkflowData(
Map.ofEntries(
Map.entry("id", "pipelineId"),
Map.entry("description", "some description"),
Map.entry("type", "text_embedding"),
Map.entry("model_id", "model_id"),
Map.entry("input_field_name", "inputField"),
Map.entry("output_field_name", "outputField")
)
);

// Set output data to returned pipelineId
outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipelineId", "pipelineId")));

client = mock(Client.class);
adminClient = mock(AdminClient.class);
clusterAdminClient = mock(ClusterAdminClient.class);

when(client.admin()).thenReturn(adminClient);
when(adminClient.cluster()).thenReturn(clusterAdminClient);
}

public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException {

CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client);

ArgumentCaptor<ActionListener> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
CompletableFuture<WorkflowData> future = createIngestPipelineStep.execute(List.of(inputData));

assertFalse(future.isDone());

// Mock put pipeline request execution and return true
verify(clusterAdminClient, times(1)).putPipeline(any(PutPipelineRequest.class), actionListenerCaptor.capture());
actionListenerCaptor.getValue().onResponse(new AcknowledgedResponse(true));

assertTrue(future.isDone() && !future.isCompletedExceptionally());
assertEquals(outpuData.getContent(), future.get().getContent());
}

public void testCreateIngestPipelineStepFailure() throws InterruptedException {

CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client);

ArgumentCaptor<ActionListener> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
CompletableFuture<WorkflowData> future = createIngestPipelineStep.execute(List.of(inputData));

assertFalse(future.isDone());

// Mock put pipeline request execution and return false
verify(clusterAdminClient, times(1)).putPipeline(any(PutPipelineRequest.class), actionListenerCaptor.capture());
actionListenerCaptor.getValue().onFailure(new Exception("Failed to create ingest pipeline"));

assertTrue(future.isDone() && future.isCompletedExceptionally());

ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get());
assertTrue(exception.getCause() instanceof Exception);
assertEquals("Failed to create ingest pipeline", exception.getCause().getMessage());
}

public void testMissingData() throws InterruptedException {
CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client);

// Data with missing input and output fields
WorkflowData incorrectData = new WorkflowData(
Map.ofEntries(
Map.entry("id", "pipelineId"),
Map.entry("description", "some description"),
Map.entry("type", "text_embedding"),
Map.entry("model_id", "model_id")
)
);

CompletableFuture<WorkflowData> future = createIngestPipelineStep.execute(List.of(incorrectData));
assertTrue(future.isDone() && future.isCompletedExceptionally());

ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get());
assertTrue(exception.getCause() instanceof Exception);
assertEquals("Failed to create ingest pipeline, required inputs not found", exception.getCause().getMessage());
}

}

0 comments on commit 1d22bee

Please sign in to comment.