Skip to content

Commit

Permalink
Merge pull request #19030 from jmchilton/fixes_for_21.01
Browse files Browse the repository at this point in the history
Fix numerous issues with tool input format "21.01"
  • Loading branch information
bgruening authored Nov 1, 2024
2 parents 9b39ec7 + 6d7842e commit 3ff3843
Show file tree
Hide file tree
Showing 10 changed files with 459 additions and 200 deletions.
4 changes: 3 additions & 1 deletion lib/galaxy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1836,7 +1836,9 @@ def expand_incoming(
# Expand these out to individual parameters for given jobs (tool executions).
expanded_incomings: List[ToolStateJobInstanceT]
collection_info: Optional[MatchingCollections]
expanded_incomings, collection_info = expand_meta_parameters(request_context, self, incoming)
expanded_incomings, collection_info = expand_meta_parameters(
request_context, self, incoming, input_format=input_format
)

self._ensure_expansion_is_valid(expanded_incomings, rerun_remap_job_id)

Expand Down
39 changes: 27 additions & 12 deletions lib/galaxy/tools/parameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def visit_input_values(
context=None,
no_replacement_value=REPLACE_ON_TRUTHY,
replace_optional_connections=False,
allow_case_inference=False,
unset_value=None,
):
"""
Given a tools parameter definition (`inputs`) and a specific set of
Expand Down Expand Up @@ -158,7 +160,7 @@ def visit_input_values(
"""

def callback_helper(input, input_values, name_prefix, label_prefix, parent_prefix, context=None, error=None):
value = input_values.get(input.name)
value = input_values.get(input.name, unset_value)
args = {
"input": input,
"parent": input_values,
Expand All @@ -182,13 +184,23 @@ def callback_helper(input, input_values, name_prefix, label_prefix, parent_prefi
input_values[input.name] = input.value

def get_current_case(input, input_values):
test_parameter = input.test_param
test_parameter_name = test_parameter.name
try:
return input.get_current_case(input_values[input.test_param.name])
if test_parameter_name not in input_values and allow_case_inference:
return input.get_current_case(test_parameter.get_initial_value(None, input_values))
else:
return input.get_current_case(input_values[test_parameter_name])
except (KeyError, ValueError):
return -1

context = ExpressionContext(input_values, context)
payload = {"context": context, "no_replacement_value": no_replacement_value}
payload = {
"context": context,
"no_replacement_value": no_replacement_value,
"allow_case_inference": allow_case_inference,
"unset_value": unset_value,
}
for input in inputs.values():
if isinstance(input, Repeat) or isinstance(input, UploadDataset):
values = input_values[input.name] = input_values.get(input.name, [])
Expand Down Expand Up @@ -411,16 +423,15 @@ def populate_state(
group_state = state[input.name]
if input.type == "repeat":
repeat_input = cast(Repeat, input)
if (
len(incoming[repeat_input.name]) > repeat_input.max
or len(incoming[repeat_input.name]) < repeat_input.min
repeat_name = repeat_input.name
repeat_incoming = incoming.get(repeat_name) or []
if repeat_incoming and (
len(repeat_incoming) > repeat_input.max or len(repeat_incoming) < repeat_input.min
):
errors[repeat_input.name] = (
"The number of repeat elements is outside the range specified by the tool."
)
errors[repeat_name] = "The number of repeat elements is outside the range specified by the tool."
else:
del group_state[:]
for rep in incoming[repeat_input.name]:
for rep in repeat_incoming:
new_state: ToolStateJobInstancePopulatedT = {}
group_state.append(new_state)
repeat_errors: ParameterValidationErrorsT = {}
Expand Down Expand Up @@ -454,10 +465,13 @@ def populate_state(
current_case = conditional_input.get_current_case(value)
group_state = state[conditional_input.name] = {}
cast_errors: ParameterValidationErrorsT = {}
incoming_for_conditional = cast(
ToolStateJobInstanceT, incoming.get(conditional_input.name) or {}
)
populate_state(
request_context,
conditional_input.cases[current_case].inputs,
cast(ToolStateJobInstanceT, incoming.get(conditional_input.name)),
incoming_for_conditional,
group_state,
cast_errors,
context=context,
Expand All @@ -475,10 +489,11 @@ def populate_state(
elif input.type == "section":
section_input = cast(Section, input)
section_errors: ParameterValidationErrorsT = {}
incoming_for_state = cast(ToolStateJobInstanceT, incoming.get(section_input.name) or {})
populate_state(
request_context,
section_input.inputs,
cast(ToolStateJobInstanceT, incoming.get(section_input.name)),
incoming_for_state,
group_state,
section_errors,
context=context,
Expand Down
150 changes: 124 additions & 26 deletions lib/galaxy/tools/parameters/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,19 @@
matching,
subcollections,
)
from galaxy.util import permutations
from galaxy.util.permutations import (
build_combos,
input_classification,
is_in_state,
state_copy,
state_get_value,
state_remove_value,
state_set_value,
)
from . import visit_input_values
from .wrapped import process_key
from .._types import (
InputFormatT,
ToolRequestT,
ToolStateJobInstanceT,
)
Expand Down Expand Up @@ -161,7 +170,15 @@ def is_batch(value):
ExpandedT = Tuple[List[ToolStateJobInstanceT], Optional[matching.MatchingCollections]]


def expand_meta_parameters(trans, tool, incoming: ToolRequestT) -> ExpandedT:
def expand_flat_parameters_to_nested(incoming_copy: ToolRequestT) -> Dict[str, Any]:
nested_dict: Dict[str, Any] = {}
for incoming_key, incoming_value in incoming_copy.items():
if not incoming_key.startswith("__"):
process_key(incoming_key, incoming_value=incoming_value, d=nested_dict)
return nested_dict


def expand_meta_parameters(trans, tool, incoming: ToolRequestT, input_format: InputFormatT) -> ExpandedT:
"""
Take in a dictionary of raw incoming parameters and expand to a list
of expanded incoming parameters (one set of parameters per tool
Expand All @@ -176,33 +193,24 @@ def expand_meta_parameters(trans, tool, incoming: ToolRequestT) -> ExpandedT:
# order matters, so the following reorders incoming
# according to tool.inputs (which is ordered).
incoming_copy = incoming.copy()
nested_dict: Dict[str, Any] = {}
for incoming_key, incoming_value in incoming_copy.items():
if not incoming_key.startswith("__"):
process_key(incoming_key, incoming_value=incoming_value, d=nested_dict)

reordered_incoming = {}

def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
if prefixed_name in incoming_copy:
reordered_incoming[prefixed_name] = incoming_copy[prefixed_name]
del incoming_copy[prefixed_name]
if input_format == "legacy":
nested_dict = expand_flat_parameters_to_nested(incoming_copy)
else:
nested_dict = incoming_copy

visit_input_values(inputs=tool.inputs, input_values=nested_dict, callback=visitor)
reordered_incoming.update(incoming_copy)
collections_to_match = matching.CollectionsToMatch()

def classifier(input_key):
value = incoming[input_key]
def classifier_from_value(value, input_key):
if isinstance(value, dict) and "values" in value:
# Explicit meta wrapper for inputs...
is_batch = value.get("batch", False)
is_linked = value.get("linked", True)
if is_batch and is_linked:
classification = permutations.input_classification.MATCHED
classification = input_classification.MATCHED
elif is_batch:
classification = permutations.input_classification.MULTIPLIED
classification = input_classification.MULTIPLIED
else:
classification = permutations.input_classification.SINGLE
classification = input_classification.SINGLE
if __collection_multirun_parameter(value):
collection_value = value["values"][0]
values = __expand_collection_parameter(
Expand All @@ -211,24 +219,114 @@ def classifier(input_key):
else:
values = value["values"]
else:
classification = permutations.input_classification.SINGLE
classification = input_classification.SINGLE
values = value
return classification, values

collections_to_match = matching.CollectionsToMatch()
nested = input_format != "legacy"
if not nested:
reordered_incoming = reorder_parameters(tool, incoming_copy, nested_dict, nested)
incoming_template = reordered_incoming

def classifier_flat(input_key):
return classifier_from_value(incoming[input_key], input_key)

# Stick an unexpanded version of multirun keys so they can be replaced,
# by expand_mult_inputs.
incoming_template = reordered_incoming
single_inputs, matched_multi_inputs, multiplied_multi_inputs = split_inputs_flat(
incoming_template, classifier_flat
)
else:
reordered_incoming = reorder_parameters(tool, incoming_copy, nested_dict, nested)
incoming_template = reordered_incoming
single_inputs, matched_multi_inputs, multiplied_multi_inputs = split_inputs_nested(
tool.inputs, incoming_template, classifier_from_value
)

expanded_incomings = permutations.expand_multi_inputs(incoming_template, classifier)
expanded_incomings = build_combos(single_inputs, matched_multi_inputs, multiplied_multi_inputs, nested=nested)
if collections_to_match.has_collections():
collection_info = trans.app.dataset_collection_manager.match_collections(collections_to_match)
else:
collection_info = None
return expanded_incomings, collection_info


def reorder_parameters(tool, incoming, nested_dict, nested):
# If we're going to multiply input dataset combinations
# order matters, so the following reorders incoming
# according to tool.inputs (which is ordered).
incoming_copy = state_copy(incoming, nested)

reordered_incoming = {}

def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
if is_in_state(incoming_copy, prefixed_name, nested):
value_to_copy_over = state_get_value(incoming_copy, prefixed_name, nested)
state_set_value(reordered_incoming, prefixed_name, value_to_copy_over, nested)
state_remove_value(incoming_copy, prefixed_name, nested)

visit_input_values(inputs=tool.inputs, input_values=nested_dict, callback=visitor)

def merge_into(from_object, into_object):
if isinstance(from_object, dict):
for key, value in from_object.items():
if key not in into_object:
into_object[key] = value
else:
into_target = into_object[key]
merge_into(value, into_target)
elif isinstance(from_object, list):
for index in from_object:
if len(into_object) <= index:
into_object.append(from_object[index])
else:
merge_into(from_object[index], into_object[index])

merge_into(incoming_copy, reordered_incoming)
return reordered_incoming


def split_inputs_flat(inputs: Dict[str, Any], classifier):
single_inputs: Dict[str, Any] = {}
matched_multi_inputs: Dict[str, Any] = {}
multiplied_multi_inputs: Dict[str, Any] = {}

for input_key in inputs:
input_type, expanded_val = classifier(input_key)
if input_type == input_classification.SINGLE:
single_inputs[input_key] = expanded_val
elif input_type == input_classification.MATCHED:
matched_multi_inputs[input_key] = expanded_val
elif input_type == input_classification.MULTIPLIED:
multiplied_multi_inputs[input_key] = expanded_val

return (single_inputs, matched_multi_inputs, multiplied_multi_inputs)


def split_inputs_nested(inputs, nested_dict, classifier):
single_inputs: Dict[str, Any] = {}
matched_multi_inputs: Dict[str, Any] = {}
multiplied_multi_inputs: Dict[str, Any] = {}
unset_value = object()

def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
if value is unset_value:
# don't want to inject extra nulls into state
return

input_type, expanded_val = classifier(value, prefixed_name)
if input_type == input_classification.SINGLE:
single_inputs[prefixed_name] = expanded_val
elif input_type == input_classification.MATCHED:
matched_multi_inputs[prefixed_name] = expanded_val
elif input_type == input_classification.MULTIPLIED:
multiplied_multi_inputs[prefixed_name] = expanded_val

visit_input_values(
inputs=inputs, input_values=nested_dict, callback=visitor, allow_case_inference=True, unset_value=unset_value
)
single_inputs_nested = expand_flat_parameters_to_nested(single_inputs)
return (single_inputs_nested, matched_multi_inputs, multiplied_multi_inputs)


def __expand_collection_parameter(trans, input_key, incoming_val, collections_to_match, linked=False):
# If subcollectin multirun of data_collection param - value will
# be "hdca_id|subcollection_type" else it will just be hdca_id
Expand Down
9 changes: 6 additions & 3 deletions lib/galaxy/tools/parameters/wrapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
InputValueWrapper,
SelectToolParameterWrapper,
)
from galaxy.util.permutations import (
looks_like_flattened_repeat_key,
split_flattened_repeat_key,
)

PARAMS_UNWRAPPED = object()

Expand Down Expand Up @@ -172,10 +176,9 @@ def process_key(incoming_key: str, incoming_value: Any, d: Dict[str, Any]):
# In case we get an empty repeat after we already filled in a repeat element
return
d[incoming_key] = incoming_value
elif key_parts[0].rsplit("_", 1)[-1].isdigit():
elif looks_like_flattened_repeat_key(key_parts[0]):
# Repeat
input_name, _index = key_parts[0].rsplit("_", 1)
index = int(_index)
input_name, index = split_flattened_repeat_key(key_parts[0])
d.setdefault(input_name, [])
newlist: List[Dict[Any, Any]] = [{} for _ in range(index + 1)]
d[input_name].extend(newlist[len(d[input_name]) :])
Expand Down
2 changes: 2 additions & 0 deletions lib/galaxy/tools/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,8 @@ def __init__(self, input_datasets: Optional[Dict[str, Any]] = None) -> None:
self.identifier_key_dict = {}

def identifier(self, dataset_value: str, input_values: Dict[str, str]) -> Optional[str]:
if isinstance(dataset_value, list):
raise TypeError(f"Expected {dataset_value} to be hashable")
element_identifier = None
if identifier_key := self.identifier_key_dict.get(dataset_value, None):
element_identifier = input_values.get(identifier_key, None)
Expand Down
Loading

0 comments on commit 3ff3843

Please sign in to comment.