Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Added a new parse util method to avoid repetition / refactored code with new method #728

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
### Enhancements
- Add Workflow Step for Reindex from source index to destination ([#718](https://github.com/opensearch-project/flow-framework/pull/718))
- Add param to delete workflow API to clear status even if resources exist ([#719](https://github.com/opensearch-project/flow-framework/pull/719))

### Bug Fixes
- Add user mapping to Workflow State index ([#705](https://github.com/opensearch-project/flow-framework/pull/705))

Expand Down
26 changes: 26 additions & 0 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.common.Booleans;
import org.opensearch.common.io.Streams;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentHelper;
Expand Down Expand Up @@ -472,4 +473,29 @@ public static Map<String, String> convertStringToObjectMapToStringToStringMap(Ma
return stringToStringMap;
}
}

/**
* Checks if the inputs map contains the specified key and parses the associated value to a generic class.
*
* @param <T> the type to which the value should be parsed
* @param inputs the map containing the input data
* @param key the key to check in the map
* @param type the class to parse the value to
* @throws IllegalArgumentException if the type is not supported
* @return the generic type value associated with the key if present, or null if the key is not found
*/
public static <T> T parseIfExists(Map<String, Object> inputs, String key, Class<T> type) {
if (!inputs.containsKey(key)) {
return null;
}

Object value = inputs.get(key);
if (type == Boolean.class) {
return type.cast(Booleans.parseBoolean(value.toString()));
} else if (type == Float.class) {
return type.cast(Float.parseFloat(value.toString()));
} else {
throw new IllegalArgumentException("Unsupported type: " + type);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Booleans;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
Expand Down Expand Up @@ -123,7 +122,7 @@ public PlainActionFuture<WorkflowData> execute(
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String allConfig = (String) inputs.get(ALL_CONFIG);
String modelInterface = (String) inputs.get(INTERFACE_FIELD);
final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null;
final Boolean deploy = ParseUtils.parseIfExists(inputs, DEPLOY_FIELD, Boolean.class);

// Build register model input
MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Booleans;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -147,9 +146,7 @@ public void onFailure(Exception ex) {
if (inputs.containsKey(MODEL_ACCESS_MODE)) {
modelAccessMode = AccessMode.from((inputs.get(MODEL_ACCESS_MODE)).toString().toLowerCase(Locale.ROOT));
}
Boolean isAddAllBackendRoles = inputs.containsKey(ADD_ALL_BACKEND_ROLES)
? Booleans.parseBoolean(inputs.get(ADD_ALL_BACKEND_ROLES).toString())
: null;
Boolean isAddAllBackendRoles = ParseUtils.parseIfExists(inputs, ADD_ALL_BACKEND_ROLES, Boolean.class);

MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder();
builder.name(modelGroupName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.common.Booleans;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
Expand Down Expand Up @@ -100,7 +99,7 @@ public PlainActionFuture<WorkflowData> execute(
String connectorId = (String) inputs.get(CONNECTOR_ID);
Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD);
String modelInterface = (String) inputs.get(INTERFACE_FIELD);
final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null;
final Boolean deploy = ParseUtils.parseIfExists(inputs, DEPLOY_FIELD, Boolean.class);

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
.functionName(FunctionName.REMOTE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Booleans;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
Expand Down Expand Up @@ -64,9 +63,7 @@ public PlainActionFuture<WorkflowData> execute(
String type = (String) inputs.get(TYPE);
String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
Boolean includeOutputInAgentResponse = inputs.containsKey(INCLUDE_OUTPUT_IN_AGENT_RESPONSE)
? Booleans.parseBoolean(inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE).toString())
: null;
Boolean includeOutputInAgentResponse = ParseUtils.parseIfExists(inputs, INCLUDE_OUTPUT_IN_AGENT_RESPONSE, Boolean.class);
Map<String, String> parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs);

MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,36 @@ public void testGetInputsFromPreviousSteps() {
assertEquals("Missing required inputs [not-here] in workflow [workflowId] node [nodeId]", e.getMessage());
assertEquals(RestStatus.BAD_REQUEST, e.getRestStatus());
}

public void testParseIfExistsWithBooleanClass() {
Map<String, Object> inputs = new HashMap<>();
inputs.put("key1", "true");
inputs.put("key2", "false");
inputs.put("key3", "true");

assertEquals(Boolean.TRUE, ParseUtils.parseIfExists(inputs, "key1", Boolean.class));
assertEquals(Boolean.FALSE, ParseUtils.parseIfExists(inputs, "key2", Boolean.class));
assertNull(ParseUtils.parseIfExists(inputs, "keyThatDoesntExist", Boolean.class));

}

public void testParseIfExistsWithFloatClass() {
Map<String, Object> inputs = new HashMap<>();
inputs.put("key1", "3.14");
inputs.put("key2", "0.01");
inputs.put("key3", "90.22");

assertEquals(Float.valueOf("3.14"), ParseUtils.parseIfExists(inputs, "key1", Float.class));
assertEquals(Float.valueOf("0.01"), ParseUtils.parseIfExists(inputs, "key2", Float.class));
assertNull(ParseUtils.parseIfExists(inputs, "keyThatDoesntExist", Float.class));

}

public void testParseIfExistWhenWrongTypeIsPassed() {

Map<String, Object> inputs = new HashMap<>();
inputs.put("key1", "3.14");

assertThrows(IllegalArgumentException.class, () -> ParseUtils.parseIfExists(inputs, "key1", Integer.class));
}
}
Loading