Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Create Connector actions parsing #127

Merged
merged 4 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;

Expand All @@ -24,8 +25,10 @@
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -84,39 +87,44 @@
String description = null;
String version = null;
String protocol = null;
Map<String, String> parameters = new HashMap<>();
Map<String, String> credentials = new HashMap<>();
List<ConnectorAction> actions = new ArrayList<>();

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
name = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case VERSION_FIELD:
version = (String) content.get(VERSION_FIELD);
break;
case PROTOCOL_FIELD:
protocol = (String) content.get(PROTOCOL_FIELD);
break;
case PARAMETERS_FIELD:
parameters = getParameterMap((Map<String, String>) content.get(PARAMETERS_FIELD));
break;
case CREDENTIALS_FIELD:
credentials = (Map<String, String>) content.get(CREDENTIALS_FIELD);
break;
case ACTIONS_FIELD:
actions = (List<ConnectorAction>) content.get(ACTIONS_FIELD);
break;
Map<String, String> parameters = Collections.emptyMap();
Map<String, String> credentials = Collections.emptyMap();
List<ConnectorAction> actions = Collections.emptyList();

try {
for (WorkflowData workflowData : data) {
for (Entry<String, Object> entry : workflowData.getContent().entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
name = (String) entry.getValue();
break;
case DESCRIPTION_FIELD:
description = (String) entry.getValue();
break;
case VERSION_FIELD:
version = (String) entry.getValue();
break;
case PROTOCOL_FIELD:
protocol = (String) entry.getValue();
break;
case PARAMETERS_FIELD:
parameters = getParameterMap(entry.getValue());
break;
case CREDENTIALS_FIELD:
credentials = getStringToStringMap(entry.getValue(), CREDENTIALS_FIELD);
break;
case ACTIONS_FIELD:
actions = getConnectorActionList(entry.getValue());
break;
}
}

}
} catch (IllegalArgumentException iae) {
createConnectorFuture.completeExceptionally(new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST));
return createConnectorFuture;
} catch (PrivilegedActionException pae) {
createConnectorFuture.completeExceptionally(new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED));
return createConnectorFuture;

Check warning on line 127 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L122-L127

Added lines #L122 - L127 were not covered by tests
}

if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) {
Expand Down Expand Up @@ -145,21 +153,48 @@
return NAME;
}

private static Map<String, String> getParameterMap(Map<String, String> params) {
@SuppressWarnings("unchecked")
private static Map<String, String> getStringToStringMap(Object map, String fieldName) {
if (map instanceof Map) {
return (Map<String, String>) map;
}
throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map.");

Check warning on line 161 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L161

Added line #L161 was not covered by tests
}

private static Map<String, String> getParameterMap(Object parameterMap) throws PrivilegedActionException {
Map<String, String> parameters = new HashMap<>();
for (String key : params.keySet()) {
String value = params.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
parameters.put(key, value);
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
for (Entry<String, String> entry : getStringToStringMap(parameterMap, PARAMETERS_FIELD).entrySet()) {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
parameters.put(entry.getKey(), entry.getValue());
return null;
});
}
return parameters;
}

private static List<ConnectorAction> getConnectorActionList(Object array) {
if (!(array instanceof Map[])) {
throw new IllegalArgumentException("[" + ACTIONS_FIELD + "] must be an array of key-value maps.");

Check warning on line 177 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L177

Added line #L177 was not covered by tests
}
List<ConnectorAction> actions = new ArrayList<>();
for (Map<?, ?> map : (Map<?, ?>[]) array) {
String actionType = (String) map.get(ConnectorAction.ACTION_TYPE_FIELD);
if (actionType == null) {
throw new IllegalArgumentException("[" + ConnectorAction.ACTION_TYPE_FIELD + "] is missing.");

Check warning on line 183 in src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java#L183

Added line #L183 was not covered by tests
}
@SuppressWarnings("unchecked")
ConnectorAction action = ConnectorAction.builder()
.actionType(ActionType.valueOf(actionType.toUpperCase(Locale.ROOT)))
.method((String) map.get(ConnectorAction.METHOD_FIELD))
.url((String) map.get(ConnectorAction.URL_FIELD))
.headers((Map<String, String>) map.get(ConnectorAction.HEADERS_FIELD))
.requestBody((String) map.get(ConnectorAction.REQUEST_BODY_FIELD))
.preProcessFunction((String) map.get(ConnectorAction.ACTION_PRE_PROCESS_FUNCTION))
.postProcessFunction((String) map.get(ConnectorAction.ACTION_POST_PROCESS_FUNCTION))
.build();
actions.add(action);
}
return actions;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
Expand All @@ -27,7 +28,6 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
Expand All @@ -44,45 +44,38 @@ public void setUp() throws Exception {

Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7"));
Map<String, String> credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2"));
Map<?, ?>[] actions = new Map<?, ?>[] {
Map.ofEntries(
Map.entry(ConnectorAction.ACTION_TYPE_FIELD, ConnectorAction.ActionType.PREDICT.name()),
Map.entry(ConnectorAction.METHOD_FIELD, "post"),
Map.entry(ConnectorAction.URL_FIELD, "foo.test"),
Map.entry(
ConnectorAction.REQUEST_BODY_FIELD,
"{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }"
)
) };

MockitoAnnotations.openMocks(this);

ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = "foot.test";

inputData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "test"),
Map.entry("description", "description"),
Map.entry("version", "1"),
Map.entry("protocol", "test"),
Map.entry("params", params),
Map.entry("credentials", credentials),
Map.entry(
"actions",
List.of(
new ConnectorAction(
actionType,
method,
url,
null,
"{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }",
null,
null
)
)
)
Map.entry(CommonValue.NAME_FIELD, "test"),
Map.entry(CommonValue.DESCRIPTION_FIELD, "description"),
Map.entry(CommonValue.VERSION_FIELD, "1"),
Map.entry(CommonValue.PROTOCOL_FIELD, "test"),
Map.entry(CommonValue.PARAMETERS_FIELD, params),
Map.entry(CommonValue.CREDENTIALS_FIELD, credentials),
Map.entry(CommonValue.ACTIONS_FIELD, actions)
)
);

}

public void testCreateConnector() throws IOException, ExecutionException, InterruptedException {

String connectorId = "connect";
CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
Expand All @@ -104,6 +97,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr
public void testCreateConnectorFailure() throws IOException {
CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
Expand Down
Loading