From d022254709adb464cf521ed5b61c790530711fff Mon Sep 17 00:00:00 2001
From: John Chilton
Date: Wed, 16 Oct 2024 15:46:29 -0400
Subject: [PATCH 1/2] Refactor complex method for clarity.

---
 lib/galaxy/tools/parameters/meta.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/lib/galaxy/tools/parameters/meta.py b/lib/galaxy/tools/parameters/meta.py
index f2d8ba1a68d1..f6e8b70ca950 100644
--- a/lib/galaxy/tools/parameters/meta.py
+++ b/lib/galaxy/tools/parameters/meta.py
@@ -161,6 +161,14 @@ def is_batch(value):
 ExpandedT = Tuple[List[ToolStateJobInstanceT], Optional[matching.MatchingCollections]]
 
 
+def expand_flat_parameters_to_nested(incoming_copy: ToolRequestT) -> Dict[str, Any]:
+    nested_dict: Dict[str, Any] = {}
+    for incoming_key, incoming_value in incoming_copy.items():
+        if not incoming_key.startswith("__"):
+            process_key(incoming_key, incoming_value=incoming_value, d=nested_dict)
+    return nested_dict
+
+
 def expand_meta_parameters(trans, tool, incoming: ToolRequestT) -> ExpandedT:
     """
     Take in a dictionary of raw incoming parameters and expand to a list
@@ -176,11 +184,7 @@ def expand_meta_parameters(trans, tool, incoming: ToolRequestT) -> ExpandedT:
     # order matters, so the following reorders incoming
     # according to tool.inputs (which is ordered).
     incoming_copy = incoming.copy()
-    nested_dict: Dict[str, Any] = {}
-    for incoming_key, incoming_value in incoming_copy.items():
-        if not incoming_key.startswith("__"):
-            process_key(incoming_key, incoming_value=incoming_value, d=nested_dict)
-
+    nested_dict = expand_flat_parameters_to_nested(incoming_copy)
     reordered_incoming = {}
 
     def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
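
The helper extracted above turns the legacy flat key syntax into nested tool state: each "|"-separated
key segment becomes a dictionary level, and "_<index>" suffixes on repeat names become list indices.
A minimal sketch of that behavior, using only the process_key helper the patch already relies on
(the payload keys and values here are illustrative, not from a real tool):

    from galaxy.tools.parameters.wrapped import process_key

    flat = {
        "input1": {"src": "hda", "id": "abc123"},
        "queries_0|input2": {"src": "hda", "id": "def456"},
        "__rerun_remap_job_id": None,  # "__"-prefixed keys are skipped, as in the helper
    }

    nested = {}
    for key, value in flat.items():
        if not key.startswith("__"):
            process_key(key, incoming_value=value, d=nested)

    # nested is now:
    # {"input1": {"src": "hda", "id": "abc123"},
    #  "queries": [{"input2": {"src": "hda", "id": "def456"}}]}
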
From 6d7842e66fd7be15ddc9a9dbe08246749a4d7e85 Mon Sep 17 00:00:00 2001
From: John Chilton
Date: Sun, 13 Oct 2024 21:39:47 -0400
Subject: [PATCH 2/2] Numerous fixes for tool input format "21.01" parameters.

---
 lib/galaxy/tools/__init__.py              |   4 +-
 lib/galaxy/tools/parameters/__init__.py   |  39 ++--
 lib/galaxy/tools/parameters/meta.py       | 138 ++++++++++---
 lib/galaxy/tools/parameters/wrapped.py    |   9 +-
 lib/galaxy/tools/wrappers.py              |   2 +
 lib/galaxy/util/permutations.py           | 132 +++++++++----
 lib/galaxy_test/api/conftest.py           |  15 +-
 lib/galaxy_test/api/test_tool_execute.py  | 187 +++++++++++++++++-
 lib/galaxy_test/api/test_tools.py         | 106 ----------
 .../test_extended_metadata_mapping.py     |  15 ++
 10 files changed, 451 insertions(+), 196 deletions(-)

diff --git a/lib/galaxy/tools/__init__.py b/lib/galaxy/tools/__init__.py
index 24ddffd5875a..68bf154b0f7d 100644
--- a/lib/galaxy/tools/__init__.py
+++ b/lib/galaxy/tools/__init__.py
@@ -1836,7 +1836,9 @@ def expand_incoming(
         # Expand these out to individual parameters for given jobs (tool executions).
         expanded_incomings: List[ToolStateJobInstanceT]
         collection_info: Optional[MatchingCollections]
-        expanded_incomings, collection_info = expand_meta_parameters(request_context, self, incoming)
+        expanded_incomings, collection_info = expand_meta_parameters(
+            request_context, self, incoming, input_format=input_format
+        )
 
         self._ensure_expansion_is_valid(expanded_incomings, rerun_remap_job_id)
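
The input_format argument threaded through expand_incoming above selects between the two payload
shapes the execution API accepts. For reference, a hypothetical cat-style tool with one repeat would
express the same state in both formats like this (key names and ids are illustrative only):

    # Legacy flat format ("legacy"): one flat dict, "|"-separated paths,
    # repeats addressed by "_<index>" suffixes on the repeat name.
    flat_payload = {
        "input1": {"src": "hda", "id": "abc123"},
        "queries_0|input2": {"src": "hda", "id": "def456"},
    }

    # "21.01" nested format: real structure, repeats are lists of dicts.
    nested_payload = {
        "input1": {"src": "hda", "id": "abc123"},
        "queries": [{"input2": {"src": "hda", "id": "def456"}}],
    }
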
diff --git a/lib/galaxy/tools/parameters/__init__.py b/lib/galaxy/tools/parameters/__init__.py
index 9de78f2fbff5..ee3a8709817d 100644
--- a/lib/galaxy/tools/parameters/__init__.py
+++ b/lib/galaxy/tools/parameters/__init__.py
@@ -62,6 +62,8 @@ def visit_input_values(
     context=None,
     no_replacement_value=REPLACE_ON_TRUTHY,
     replace_optional_connections=False,
+    allow_case_inference=False,
+    unset_value=None,
 ):
     """
     Given a tools parameter definition (`inputs`) and a specific set of
@@ -158,7 +160,7 @@ def visit_input_values(
     """
 
     def callback_helper(input, input_values, name_prefix, label_prefix, parent_prefix, context=None, error=None):
-        value = input_values.get(input.name)
+        value = input_values.get(input.name, unset_value)
         args = {
             "input": input,
             "parent": input_values,
@@ -182,13 +184,23 @@ def callback_helper(input, input_values, name_prefix, label_prefix, parent_prefi
             input_values[input.name] = input.value
 
     def get_current_case(input, input_values):
+        test_parameter = input.test_param
+        test_parameter_name = test_parameter.name
         try:
-            return input.get_current_case(input_values[input.test_param.name])
+            if test_parameter_name not in input_values and allow_case_inference:
+                return input.get_current_case(test_parameter.get_initial_value(None, input_values))
+            else:
+                return input.get_current_case(input_values[test_parameter_name])
         except (KeyError, ValueError):
             return -1
 
     context = ExpressionContext(input_values, context)
-    payload = {"context": context, "no_replacement_value": no_replacement_value}
+    payload = {
+        "context": context,
+        "no_replacement_value": no_replacement_value,
+        "allow_case_inference": allow_case_inference,
+        "unset_value": unset_value,
+    }
     for input in inputs.values():
         if isinstance(input, Repeat) or isinstance(input, UploadDataset):
             values = input_values[input.name] = input_values.get(input.name, [])
@@ -411,16 +423,15 @@ def populate_state(
             group_state = state[input.name]
             if input.type == "repeat":
                 repeat_input = cast(Repeat, input)
-                if (
-                    len(incoming[repeat_input.name]) > repeat_input.max
-                    or len(incoming[repeat_input.name]) < repeat_input.min
+                repeat_name = repeat_input.name
+                repeat_incoming = incoming.get(repeat_name) or []
+                if repeat_incoming and (
+                    len(repeat_incoming) > repeat_input.max or len(repeat_incoming) < repeat_input.min
                 ):
-                    errors[repeat_input.name] = (
-                        "The number of repeat elements is outside the range specified by the tool."
-                    )
+                    errors[repeat_name] = "The number of repeat elements is outside the range specified by the tool."
                 else:
                     del group_state[:]
-                    for rep in incoming[repeat_input.name]:
+                    for rep in repeat_incoming:
                         new_state: ToolStateJobInstancePopulatedT = {}
                         group_state.append(new_state)
                         repeat_errors: ParameterValidationErrorsT = {}
@@ -454,10 +465,13 @@ def populate_state(
                     current_case = conditional_input.get_current_case(value)
                     group_state = state[conditional_input.name] = {}
                     cast_errors: ParameterValidationErrorsT = {}
+                    incoming_for_conditional = cast(
+                        ToolStateJobInstanceT, incoming.get(conditional_input.name) or {}
+                    )
                     populate_state(
                         request_context,
                         conditional_input.cases[current_case].inputs,
-                        cast(ToolStateJobInstanceT, incoming.get(conditional_input.name)),
+                        incoming_for_conditional,
                         group_state,
                         cast_errors,
                         context=context,
@@ -475,10 +489,11 @@ def populate_state(
             elif input.type == "section":
                 section_input = cast(Section, input)
                 section_errors: ParameterValidationErrorsT = {}
+                incoming_for_state = cast(ToolStateJobInstanceT, incoming.get(section_input.name) or {})
                 populate_state(
                     request_context,
                     section_input.inputs,
-                    cast(ToolStateJobInstanceT, incoming.get(section_input.name)),
+                    incoming_for_state,
                     group_state,
                     section_errors,
                     context=context,
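
The repeated `incoming.get(name) or {}` / `or []` pattern in populate_state above exists because
nested-format payloads may omit an optional repeat, conditional, or section entirely, or carry an
explicit null for it. A standalone illustration of the difference (dict contents hypothetical):

    incoming = {"input1": {"src": "hda", "id": "abc123"}}  # "queries" repeat omitted entirely

    # incoming["queries"] would raise KeyError; incoming.get("queries") returns None,
    # and the trailing "or []" also normalizes an explicit None to an empty repeat:
    repeat_incoming = incoming.get("queries") or []
    for rep in repeat_incoming:  # safely iterates zero times
        pass
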
diff --git a/lib/galaxy/tools/parameters/meta.py b/lib/galaxy/tools/parameters/meta.py
index f6e8b70ca950..b74df54fa269 100644
--- a/lib/galaxy/tools/parameters/meta.py
+++ b/lib/galaxy/tools/parameters/meta.py
@@ -19,10 +19,19 @@
     matching,
     subcollections,
 )
-from galaxy.util import permutations
+from galaxy.util.permutations import (
+    build_combos,
+    input_classification,
+    is_in_state,
+    state_copy,
+    state_get_value,
+    state_remove_value,
+    state_set_value,
+)
 from . import visit_input_values
 from .wrapped import process_key
 from .._types import (
+    InputFormatT,
     ToolRequestT,
     ToolStateJobInstanceT,
 )
@@ -169,7 +178,7 @@ def expand_flat_parameters_to_nested(incoming_copy: ToolRequestT) -> Dict[str, Any]:
     return nested_dict
 
 
-def expand_meta_parameters(trans, tool, incoming: ToolRequestT) -> ExpandedT:
+def expand_meta_parameters(trans, tool, incoming: ToolRequestT, input_format: InputFormatT) -> ExpandedT:
     """
     Take in a dictionary of raw incoming parameters and expand to a list
     of expanded incoming parameters (one set of parameters per tool
@@ -184,29 +193,24 @@ def expand_meta_parameters(trans, tool, incoming: ToolRequestT) -> ExpandedT:
     # order matters, so the following reorders incoming
     # according to tool.inputs (which is ordered).
     incoming_copy = incoming.copy()
-    nested_dict = expand_flat_parameters_to_nested(incoming_copy)
-    reordered_incoming = {}
-
-    def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
-        if prefixed_name in incoming_copy:
-            reordered_incoming[prefixed_name] = incoming_copy[prefixed_name]
-            del incoming_copy[prefixed_name]
+    if input_format == "legacy":
+        nested_dict = expand_flat_parameters_to_nested(incoming_copy)
+    else:
+        nested_dict = incoming_copy
 
-    visit_input_values(inputs=tool.inputs, input_values=nested_dict, callback=visitor)
-    reordered_incoming.update(incoming_copy)
+    collections_to_match = matching.CollectionsToMatch()
 
-    def classifier(input_key):
-        value = incoming[input_key]
+    def classifier_from_value(value, input_key):
         if isinstance(value, dict) and "values" in value:
             # Explicit meta wrapper for inputs...
             is_batch = value.get("batch", False)
             is_linked = value.get("linked", True)
             if is_batch and is_linked:
-                classification = permutations.input_classification.MATCHED
+                classification = input_classification.MATCHED
             elif is_batch:
-                classification = permutations.input_classification.MULTIPLIED
+                classification = input_classification.MULTIPLIED
             else:
-                classification = permutations.input_classification.SINGLE
+                classification = input_classification.SINGLE
             if __collection_multirun_parameter(value):
                 collection_value = value["values"][0]
                 values = __expand_collection_parameter(
@@ -215,17 +219,29 @@ def classifier(input_key):
             else:
                 values = value["values"]
         else:
-            classification = permutations.input_classification.SINGLE
+            classification = input_classification.SINGLE
             values = value
         return classification, values
 
-    collections_to_match = matching.CollectionsToMatch()
+    nested = input_format != "legacy"
+    if not nested:
+        reordered_incoming = reorder_parameters(tool, incoming_copy, nested_dict, nested)
+        incoming_template = reordered_incoming
 
-    # Stick an unexpanded version of multirun keys so they can be replaced,
-    # by expand_mult_inputs.
-    incoming_template = reordered_incoming
+        def classifier_flat(input_key):
+            return classifier_from_value(incoming[input_key], input_key)
 
-    expanded_incomings = permutations.expand_multi_inputs(incoming_template, classifier)
+        single_inputs, matched_multi_inputs, multiplied_multi_inputs = split_inputs_flat(
+            incoming_template, classifier_flat
+        )
+    else:
+        reordered_incoming = reorder_parameters(tool, incoming_copy, nested_dict, nested)
+        incoming_template = reordered_incoming
+        single_inputs, matched_multi_inputs, multiplied_multi_inputs = split_inputs_nested(
+            tool.inputs, incoming_template, classifier_from_value
+        )
+
+    expanded_incomings = build_combos(single_inputs, matched_multi_inputs, multiplied_multi_inputs, nested=nested)
     if collections_to_match.has_collections():
         collection_info = trans.app.dataset_collection_manager.match_collections(collections_to_match)
     else:
@@ -233,6 +249,84 @@ def classifier(input_key):
     return expanded_incomings, collection_info
 
 
+def reorder_parameters(tool, incoming, nested_dict, nested):
+    # If we're going to multiply input dataset combinations
+    # order matters, so the following reorders incoming
+    # according to tool.inputs (which is ordered).
+    incoming_copy = state_copy(incoming, nested)
+
+    reordered_incoming = {}
+
+    def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
+        if is_in_state(incoming_copy, prefixed_name, nested):
+            value_to_copy_over = state_get_value(incoming_copy, prefixed_name, nested)
+            state_set_value(reordered_incoming, prefixed_name, value_to_copy_over, nested)
+            state_remove_value(incoming_copy, prefixed_name, nested)
+
+    visit_input_values(inputs=tool.inputs, input_values=nested_dict, callback=visitor)
+
+    def merge_into(from_object, into_object):
+        if isinstance(from_object, dict):
+            for key, value in from_object.items():
+                if key not in into_object:
+                    into_object[key] = value
+                else:
+                    into_target = into_object[key]
+                    merge_into(value, into_target)
+        elif isinstance(from_object, list):
+            # iterate over indices, not elements - the lookups below
+            # expect integer positions into both lists
+            for index in range(len(from_object)):
+                if len(into_object) <= index:
+                    into_object.append(from_object[index])
+                else:
+                    merge_into(from_object[index], into_object[index])
+
+    merge_into(incoming_copy, reordered_incoming)
+    return reordered_incoming
+
+
+def split_inputs_flat(inputs: Dict[str, Any], classifier):
+    single_inputs: Dict[str, Any] = {}
+    matched_multi_inputs: Dict[str, Any] = {}
+    multiplied_multi_inputs: Dict[str, Any] = {}
+
+    for input_key in inputs:
+        input_type, expanded_val = classifier(input_key)
+        if input_type == input_classification.SINGLE:
+            single_inputs[input_key] = expanded_val
+        elif input_type == input_classification.MATCHED:
+            matched_multi_inputs[input_key] = expanded_val
+        elif input_type == input_classification.MULTIPLIED:
+            multiplied_multi_inputs[input_key] = expanded_val
+
+    return (single_inputs, matched_multi_inputs, multiplied_multi_inputs)
+
+
+def split_inputs_nested(inputs, nested_dict, classifier):
+    single_inputs: Dict[str, Any] = {}
+    matched_multi_inputs: Dict[str, Any] = {}
+    multiplied_multi_inputs: Dict[str, Any] = {}
+    unset_value = object()
+
+    def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
+        if value is unset_value:
+            # don't want to inject extra nulls into state
+            return
+
+        input_type, expanded_val = classifier(value, prefixed_name)
+        if input_type == input_classification.SINGLE:
+            single_inputs[prefixed_name] = expanded_val
+        elif input_type == input_classification.MATCHED:
+            matched_multi_inputs[prefixed_name] = expanded_val
+        elif input_type == input_classification.MULTIPLIED:
+            multiplied_multi_inputs[prefixed_name] = expanded_val
+
+    visit_input_values(
+        inputs=inputs, input_values=nested_dict, callback=visitor, allow_case_inference=True, unset_value=unset_value
+    )
+    single_inputs_nested = expand_flat_parameters_to_nested(single_inputs)
+    return (single_inputs_nested, matched_multi_inputs, multiplied_multi_inputs)
+
+
 def __expand_collection_parameter(trans, input_key, incoming_val, collections_to_match, linked=False):
     # If subcollection multirun of data_collection param - value will
     # be "hdca_id|subcollection_type" else it will just be hdca_id
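
The three classifications drive how build_combos (added to galaxy.util.permutations below) assembles
job states: MATCHED lists are zipped together index by index, while each MULTIPLIED list contributes
a cross product. A self-contained sketch of those semantics using only the standard library - an
illustration of the intended behavior, not the Galaxy implementation:

    from itertools import product

    single = {"param": "x"}
    matched = {"a": [1, 2], "b": [10, 20]}  # batch=True, linked=True
    multiplied = {"c": ["i", "ii"]}  # batch=True, linked=False

    # MATCHED inputs must share a length and are zipped together...
    combos = [{**single, "a": matched["a"][i], "b": matched["b"][i]} for i in range(2)]
    # ...then each MULTIPLIED input multiplies the number of combinations.
    expanded = [{**combo, "c": value} for combo, value in product(combos, multiplied["c"])]
    assert len(expanded) == 4  # 2 matched pairs x 2 unlinked values
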
diff --git a/lib/galaxy/tools/parameters/wrapped.py b/lib/galaxy/tools/parameters/wrapped.py
index 11fa98c0e644..d23e9be5edf9 100644
--- a/lib/galaxy/tools/parameters/wrapped.py
+++ b/lib/galaxy/tools/parameters/wrapped.py
@@ -25,6 +25,10 @@
     InputValueWrapper,
     SelectToolParameterWrapper,
 )
+from galaxy.util.permutations import (
+    looks_like_flattened_repeat_key,
+    split_flattened_repeat_key,
+)
 
 PARAMS_UNWRAPPED = object()
 
@@ -172,10 +176,9 @@ def process_key(incoming_key: str, incoming_value: Any, d: Dict[str, Any]):
             # In case we get an empty repeat after we already filled in a repeat element
             return
         d[incoming_key] = incoming_value
-    elif key_parts[0].rsplit("_", 1)[-1].isdigit():
+    elif looks_like_flattened_repeat_key(key_parts[0]):
         # Repeat
-        input_name, _index = key_parts[0].rsplit("_", 1)
-        index = int(_index)
+        input_name, index = split_flattened_repeat_key(key_parts[0])
         d.setdefault(input_name, [])
         newlist: List[Dict[Any, Any]] = [{} for _ in range(index + 1)]
         d[input_name].extend(newlist[len(d[input_name]) :])
diff --git a/lib/galaxy/tools/wrappers.py b/lib/galaxy/tools/wrappers.py
index 6dd47e7f9bb4..f5d1d6a5a0bb 100644
--- a/lib/galaxy/tools/wrappers.py
+++ b/lib/galaxy/tools/wrappers.py
@@ -802,6 +802,8 @@ def __init__(self, input_datasets: Optional[Dict[str, Any]] = None) -> None:
         self.identifier_key_dict = {}
 
     def identifier(self, dataset_value: str, input_values: Dict[str, str]) -> Optional[str]:
+        if isinstance(dataset_value, list):
+            raise TypeError(f"Expected {dataset_value} to be hashable")
         element_identifier = None
         if identifier_key := self.identifier_key_dict.get(dataset_value, None):
             element_identifier = input_values.get(identifier_key, None)
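
process_key now delegates the repeat-suffix parsing to the helpers this patch adds to
galaxy.util.permutations. Their behavior on a few sample keys:

    from galaxy.util.permutations import (
        looks_like_flattened_repeat_key,
        split_flattened_repeat_key,
    )

    assert looks_like_flattened_repeat_key("queries_0")
    assert not looks_like_flattened_repeat_key("input1")
    assert split_flattened_repeat_key("queries_12") == ("queries", 12)
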
diff --git a/lib/galaxy/util/permutations.py b/lib/galaxy/util/permutations.py
index 5dd1b11ee8fa..92573d8c6dbc 100644
--- a/lib/galaxy/util/permutations.py
+++ b/lib/galaxy/util/permutations.py
@@ -7,10 +7,8 @@
 with itertools product and permutations. These are open questions.
 """
 
-from typing import (
-    Dict,
-    TypeVar,
-)
+import copy
+from typing import Tuple
 
 from galaxy.exceptions import MessageException
 from galaxy.util.bunch import Bunch
@@ -21,47 +19,20 @@
     MULTIPLIED="multiplied",
 )
 
-# generic type of splitting input dictionary
-T = TypeVar("T")
-
 
 class InputMatchedException(MessageException):
     """Indicates problem matching inputs while building up inputs permutations."""
 
 
-def expand_multi_inputs(inputs: Dict[str, T], classifier, key_filter=None):
-    key_filter = key_filter or (lambda x: True)
-
-    single_inputs, matched_multi_inputs, multiplied_multi_inputs = __split_inputs(inputs, classifier, key_filter)
-
+def build_combos(single_inputs, matched_multi_inputs, multiplied_multi_inputs, nested):
     # Build up every combination of inputs to be run together.
-    input_combos = __extend_with_matched_combos(single_inputs, matched_multi_inputs)
-    input_combos = __extend_with_multiplied_combos(input_combos, multiplied_multi_inputs)
-
+    input_combos = __extend_with_matched_combos(single_inputs, matched_multi_inputs, nested)
+    input_combos = __extend_with_multiplied_combos(input_combos, multiplied_multi_inputs, nested)
     return input_combos
 
 
-def __split_inputs(inputs: Dict[str, T], classifier, key_filter):
-    key_filter = key_filter or (lambda x: True)
-
-    single_inputs: Dict[str, T] = {}
-    matched_multi_inputs: Dict[str, T] = {}
-    multiplied_multi_inputs: Dict[str, T] = {}
-
-    for input_key in filter(key_filter, inputs):
-        input_type, expanded_val = classifier(input_key)
-        if input_type == input_classification.SINGLE:
-            single_inputs[input_key] = expanded_val
-        elif input_type == input_classification.MATCHED:
-            matched_multi_inputs[input_key] = expanded_val
-        elif input_type == input_classification.MULTIPLIED:
-            multiplied_multi_inputs[input_key] = expanded_val
-
-    return (single_inputs, matched_multi_inputs, multiplied_multi_inputs)
-
-
-def __extend_with_matched_combos(single_inputs, multi_inputs):
+def __extend_with_matched_combos(single_inputs, multi_inputs, nested):
     """
     {a => 1, b => 2} and {c => {3, 4}, d => {5, 6}}
@@ -81,7 +52,7 @@ def __extend_with_matched_combos(single_inputs, multi_inputs):
 
     first_multi_value = multi_inputs.get(first_multi_input_key)
     for value in first_multi_value:
-        new_inputs = __copy_and_extend_inputs(single_inputs, first_multi_input_key, value)
+        new_inputs = __copy_and_extend_inputs(single_inputs, first_multi_input_key, value, nested=nested)
         matched_multi_inputs.append(new_inputs)
 
     for multi_input_key, multi_input_values in multi_inputs.items():
@@ -94,12 +65,12 @@ def __extend_with_matched_combos(single_inputs, multi_inputs):
             )
 
         for index, value in enumerate(multi_input_values):
-            matched_multi_inputs[index][multi_input_key] = value
+            state_set_value(matched_multi_inputs[index], multi_input_key, value, nested)
 
     return matched_multi_inputs
 
 
-def __extend_with_multiplied_combos(input_combos, multi_inputs):
+def __extend_with_multiplied_combos(input_combos, multi_inputs, nested):
     combos = input_combos
 
     for multi_input_key, multi_input_value in multi_inputs.items():
@@ -107,7 +78,7 @@ def __extend_with_multiplied_combos(input_combos, multi_inputs):
 
         for combo in combos:
             for input_value in multi_input_value:
-                iter_combo = __copy_and_extend_inputs(combo, multi_input_key, input_value)
+                iter_combo = __copy_and_extend_inputs(combo, multi_input_key, input_value, nested)
                 iter_combos.append(iter_combo)
 
         combos = iter_combos
@@ -115,7 +86,84 @@ def __extend_with_multiplied_combos(input_combos, multi_inputs):
     return combos
 
 
-def __copy_and_extend_inputs(inputs, key, value):
-    new_inputs = dict(inputs)
-    new_inputs[key] = value
+def __copy_and_extend_inputs(inputs, key, value, nested):
+    new_inputs = state_copy(inputs, nested)
+    state_set_value(new_inputs, key, value, nested)
     return new_inputs
+
+
+def state_copy(inputs, nested):
+    # can't deepcopy dicts with our models for a reason I don't understand;
+    # test_map_over_two_collections_unlinked breaks if I try to combine these
+    # two branches of the if
+    if nested:
+        state_dict_copy = copy.deepcopy(inputs)
+    else:
+        state_dict_copy = dict(inputs)
+    return state_dict_copy
+ if "|" not in key or not nested: + state_dict[key] = value + else: + first, rest = key.split("|", 1) + if first not in state_dict and looks_like_flattened_repeat_key(first): + repeat_name, index = split_flattened_repeat_key(first) + if repeat_name not in state_dict: + state_dict[repeat_name] = [] + repeat_state = state_dict[repeat_name] + while len(repeat_state) <= index: + repeat_state.append({}) + state_set_value(repeat_state[index], rest, value, nested) + else: + state_set_value(state_dict[first], rest, value, nested) + + +def state_remove_value(state_dict, key, nested): + if "|" not in key or not nested: + del state_dict[key] + else: + first, rest = key.split("|", 1) + child_dict = state_dict[first] + # repeats? + if "|" in rest: + state_remove_value(child_dict, rest, nested) + else: + del child_dict[rest] + if len(child_dict) == 0: + del state_dict[first] + + +def state_get_value(state_dict, key, nested): + if "|" not in key or not nested: + return state_dict[key] + else: + first, rest = key.split("|", 1) + if first not in state_dict and looks_like_flattened_repeat_key(first): + repeat_name, index = split_flattened_repeat_key(first) + return state_get_value(state_dict[repeat_name][index], rest, nested) + else: + return state_get_value(state_dict[first], rest, nested) + + +def is_in_state(state_dict, key, nested): + if not state_dict: + return False + if "|" not in key or not nested: + return key in state_dict + else: + first, rest = key.split("|", 1) + # repeats? + is_in_state(state_dict.get(first), rest, nested) + + +def looks_like_flattened_repeat_key(key: str) -> bool: + return key.rsplit("_", 1)[-1].isdigit() + + +def split_flattened_repeat_key(key: str) -> Tuple[str, int]: + input_name, _index = key.rsplit("_", 1) + index = int(_index) + return input_name, index diff --git a/lib/galaxy_test/api/conftest.py b/lib/galaxy_test/api/conftest.py index a011532c7af5..74d8958b9158 100644 --- a/lib/galaxy_test/api/conftest.py +++ b/lib/galaxy_test/api/conftest.py @@ -157,15 +157,24 @@ def tool_input_format(request) -> Iterator[DescribeToolInputs]: def check_required_tools(anonymous_galaxy_interactor, request): for marker in request.node.iter_markers(): if marker.name == "requires_tool_id": - tool_id = marker.args[0] + tool_id = _requires_marker_to_effective_tool_id(anonymous_galaxy_interactor, marker) check_missing_tool(tool_id not in get_tool_ids(anonymous_galaxy_interactor)) @pytest.fixture -def required_tool_ids(request) -> List[str]: +def required_tool_ids(anonymous_galaxy_interactor, request) -> List[str]: tool_ids = [] for marker in request.node.iter_markers(): if marker.name == "requires_tool_id": - tool_id = marker.args[0] + tool_id = _requires_marker_to_effective_tool_id(anonymous_galaxy_interactor, marker) tool_ids.append(tool_id) return tool_ids + + +def _requires_marker_to_effective_tool_id(anonymous_galaxy_interactor, marker): + tool_id = marker.args[0] + if "|" in tool_id: + any_of_tool_ids = tool_id.split("|") + all_tool_ids = get_tool_ids(anonymous_galaxy_interactor) + tool_id = [t for t in any_of_tool_ids if t in all_tool_ids][0] + return tool_id diff --git a/lib/galaxy_test/api/test_tool_execute.py b/lib/galaxy_test/api/test_tool_execute.py index 3d9c60f0ef79..95bf43e27921 100644 --- a/lib/galaxy_test/api/test_tool_execute.py +++ b/lib/galaxy_test/api/test_tool_execute.py @@ -7,12 +7,17 @@ files, etc..). 
""" +from dataclasses import dataclass from typing import List +import pytest + from galaxy_test.base.decorators import requires_tool_id from galaxy_test.base.populators import ( + DescribeToolExecution, DescribeToolInputs, RequiredTool, + SrcDict, TargetHistory, ) @@ -104,28 +109,43 @@ def test_identifier_map_over_multiple_input_in_conditional( @requires_tool_id("identifier_multiple_in_repeat") -def test_identifier_multiple_reduce_in_repeat_new_payload_form( - target_history: TargetHistory, required_tool: RequiredTool +def test_identifier_multiple_reduce_in_repeat( + target_history: TargetHistory, required_tool: RequiredTool, tool_input_format: DescribeToolInputs ): hdca = target_history.with_pair() - execute = required_tool.execute.with_nested_inputs( + inputs = tool_input_format.when.nested( { "the_repeat": [{"the_data": {"input1": hdca.src_dict}}], } + ).when.flat( + { + "the_repeat_0|the_data|input1": hdca.src_dict, + } ) + execute = required_tool.execute.with_inputs(inputs) execute.assert_has_single_job.assert_has_single_output.with_contents_stripped("forward\nreverse") @requires_tool_id("output_action_change_format") -def test_map_over_with_output_format_actions(target_history: TargetHistory, required_tool: RequiredTool): +def test_map_over_with_output_format_actions( + target_history: TargetHistory, required_tool: RequiredTool, tool_input_format: DescribeToolInputs +): hdca = target_history.with_pair() for use_action in ["do", "dont"]: - execute = required_tool.execute.with_inputs( + inputs = tool_input_format.when.flat( { "input_cond|dispatch": use_action, "input_cond|input": {"batch": True, "values": [hdca.src_dict]}, } + ).when.nested( + { + "input_cond": { + "dispatch": use_action, + "input": {"batch": True, "values": [hdca.src_dict]}, + } + } ) + execute = required_tool.execute.with_inputs(inputs) execute.assert_has_n_jobs(2).assert_creates_n_implicit_collections(1) expected_extension = "txt" if (use_action == "do") else "data" execute.assert_has_job(0).with_single_output.with_file_ext(expected_extension) @@ -196,7 +216,7 @@ def test_identifier_with_multiple_normal_datasets(target_history: TargetHistory, execute.assert_has_single_job.assert_has_single_output.with_contents_stripped("Normal HDA1\nNormal HDA2") -@requires_tool_id("cat1") +@requires_tool_id("cat|cat1") def test_map_over_empty_collection(target_history: TargetHistory, required_tool: RequiredTool): hdca = target_history.with_list([]) inputs = { @@ -204,7 +224,160 @@ def test_map_over_empty_collection(target_history: TargetHistory, required_tool: } execute = required_tool.execute.with_inputs(inputs) execute.assert_has_n_jobs(0) - execute.assert_creates_implicit_collection(0).named("Concatenate datasets on collection 1") + name = execute.assert_creates_implicit_collection(0).details["name"] + assert "Concatenate datasets" in name + assert "on collection 1" in name + + +@dataclass +class MultiRunInRepeatFixtures: + repeat_datasets: List[SrcDict] + common_dataset: SrcDict + + +@pytest.fixture +def multi_run_in_repeat_datasets(target_history: TargetHistory) -> MultiRunInRepeatFixtures: + dataset1 = target_history.with_dataset("123").src_dict + dataset2 = target_history.with_dataset("456").src_dict + common_dataset = target_history.with_dataset("Common").src_dict + return MultiRunInRepeatFixtures([dataset1, dataset2], common_dataset) + + +@requires_tool_id("cat|cat1") +def test_multi_run_in_repeat( + required_tool: RequiredTool, + multi_run_in_repeat_datasets: MultiRunInRepeatFixtures, + tool_input_format: 
+
+
+@requires_tool_id("cat|cat1")
+def test_multi_run_in_repeat(
+    required_tool: RequiredTool,
+    multi_run_in_repeat_datasets: MultiRunInRepeatFixtures,
+    tool_input_format: DescribeToolInputs,
+):
+    inputs = tool_input_format.when.flat(
+        {
+            "input1": {"batch": False, "values": [multi_run_in_repeat_datasets.common_dataset]},
+            "queries_0|input2": {"batch": True, "values": multi_run_in_repeat_datasets.repeat_datasets},
+        }
+    ).when.nested(
+        {
+            "input1": {"batch": False, "values": [multi_run_in_repeat_datasets.common_dataset]},
+            "queries": [
+                {
+                    "input2": {"batch": True, "values": multi_run_in_repeat_datasets.repeat_datasets},
+                }
+            ],
+        }
+    )
+    execute = required_tool.execute.with_inputs(inputs)
+    _check_multi_run_in_repeat(execute)
+
+
+@requires_tool_id("cat|cat1")
+def test_multi_run_in_repeat_mismatch(
+    required_tool: RequiredTool,
+    multi_run_in_repeat_datasets: MultiRunInRepeatFixtures,
+    tool_input_format: DescribeToolInputs,
+):
+    """Same test as above, but without the batch wrapper around the common dataset shared across the multirun."""
+    inputs = tool_input_format.when.flat(
+        {
+            "input1": multi_run_in_repeat_datasets.common_dataset,
+            "queries_0|input2": {"batch": True, "values": multi_run_in_repeat_datasets.repeat_datasets},
+        }
+    ).when.nested(
+        {
+            "input1": multi_run_in_repeat_datasets.common_dataset,
+            "queries": [
+                {
+                    "input2": {"batch": True, "values": multi_run_in_repeat_datasets.repeat_datasets},
+                }
+            ],
+        }
+    )
+    execute = required_tool.execute.with_inputs(inputs)
+    _check_multi_run_in_repeat(execute)
+
+
+def _check_multi_run_in_repeat(execute: DescribeToolExecution):
+    execute.assert_has_n_jobs(2)
+    execute.assert_has_job(0).with_single_output.with_contents_stripped("Common\n123")
+    execute.assert_has_job(1).with_single_output.with_contents_stripped("Common\n456")
+
+
+@dataclass
+class TwoMultiRunsFixture:
+    first_two_datasets: List[SrcDict]
+    second_two_datasets: List[SrcDict]
+
+
+@pytest.fixture
+def two_multi_run_datasets(target_history: TargetHistory) -> TwoMultiRunsFixture:
+    dataset1 = target_history.with_dataset("123").src_dict
+    dataset2 = target_history.with_dataset("456").src_dict
+    dataset3 = target_history.with_dataset("789").src_dict
+    dataset4 = target_history.with_dataset("0ab").src_dict
+    return TwoMultiRunsFixture([dataset1, dataset2], [dataset3, dataset4])
+
+
+@requires_tool_id("cat|cat1")
+def test_multirun_on_multiple_inputs(
+    required_tool: RequiredTool,
+    two_multi_run_datasets: TwoMultiRunsFixture,
+    tool_input_format: DescribeToolInputs,
+):
+    inputs = tool_input_format.when.flat(
+        {
+            "input1": {"batch": True, "values": two_multi_run_datasets.first_two_datasets},
+            "queries_0|input2": {"batch": True, "values": two_multi_run_datasets.second_two_datasets},
+        }
+    ).when.nested(
+        {
+            "input1": {"batch": True, "values": two_multi_run_datasets.first_two_datasets},
+            "queries": [
+                {"input2": {"batch": True, "values": two_multi_run_datasets.second_two_datasets}},
+            ],
+        }
+    )
+    execute = required_tool.execute.with_inputs(inputs)
+    execute.assert_has_n_jobs(2)
+    execute.assert_has_job(0).with_single_output.with_contents_stripped("123\n789")
+    execute.assert_has_job(1).with_single_output.with_contents_stripped("456\n0ab")
"input1": {"batch": True, "linked": False, "values": two_multi_run_datasets.first_two_datasets}, + "queries": [ + {"input2": {"batch": True, "linked": False, "values": two_multi_run_datasets.second_two_datasets}}, + ], + } + ) + execute = required_tool.execute.with_inputs(inputs) + execute.assert_has_n_jobs(4) + execute.assert_has_job(0).with_single_output.with_contents_stripped("123\n789") + execute.assert_has_job(1).with_single_output.with_contents_stripped("123\n0ab") + execute.assert_has_job(2).with_single_output.with_contents_stripped("456\n789") + execute.assert_has_job(3).with_single_output.with_contents_stripped("456\n0ab") + + +@requires_tool_id("cat|cat1") +def test_map_over_collection( + target_history: TargetHistory, required_tool: RequiredTool, tool_input_format: DescribeToolInputs +): + hdca = target_history.with_pair(["123", "456"]) + inputs = tool_input_format.when.any({"input1": {"batch": True, "values": [hdca.src_dict]}}) + execute = required_tool.execute.with_inputs(inputs) + execute.assert_has_n_jobs(2).assert_creates_n_implicit_collections(1) + output_collection = execute.assert_creates_implicit_collection(0) + output_collection.assert_has_dataset_element("forward").with_contents_stripped("123") + output_collection.assert_has_dataset_element("reverse").with_contents_stripped("456") @requires_tool_id("gx_repeat_boolean_min") diff --git a/lib/galaxy_test/api/test_tools.py b/lib/galaxy_test/api/test_tools.py index 002b9d873389..0f261de94016 100644 --- a/lib/galaxy_test/api/test_tools.py +++ b/lib/galaxy_test/api/test_tools.py @@ -71,21 +71,6 @@ def _build_pair(self, history_id, contents): hdca_id = create_response.json()["outputs"][0]["id"] return hdca_id - def _run_and_check_simple_collection_mapping(self, history_id, inputs): - create = self._run_cat(history_id, inputs=inputs, assert_ok=True) - outputs = create["outputs"] - jobs = create["jobs"] - implicit_collections = create["implicit_collections"] - assert len(jobs) == 2 - assert len(outputs) == 2 - assert len(implicit_collections) == 1 - output1 = outputs[0] - output2 = outputs[1] - output1_content = self.dataset_populator.get_history_dataset_content(history_id, dataset=output1) - output2_content = self.dataset_populator.get_history_dataset_content(history_id, dataset=output2) - assert output1_content.strip() == "123" - assert output2_content.strip() == "456" - def _run_cat(self, history_id, inputs, assert_ok=False, **kwargs): return self._run("cat", history_id, inputs, assert_ok=assert_ok, **kwargs) @@ -1503,56 +1488,6 @@ def test_multirun_non_data_parameter(self, history_id): ] assert sorted(len(c.split("\n")) for c in outputs_contents) == [1, 2, 3] - @skip_without_tool("cat1") - def test_multirun_in_repeat(self): - history_id, common_dataset, repeat_datasets = self._setup_repeat_multirun() - inputs = { - "input1": common_dataset, - "queries_0|input2": {"batch": True, "values": repeat_datasets}, - } - self._check_repeat_multirun(history_id, inputs) - - @skip_without_tool("cat1") - def test_multirun_in_repeat_mismatch(self): - history_id, common_dataset, repeat_datasets = self._setup_repeat_multirun() - inputs = { - "input1": {"batch": False, "values": [common_dataset]}, - "queries_0|input2": {"batch": True, "values": repeat_datasets}, - } - self._check_repeat_multirun(history_id, inputs) - - @skip_without_tool("cat1") - def test_multirun_on_multiple_inputs(self): - history_id, first_two, second_two = self._setup_two_multiruns() - inputs = { - "input1": {"batch": True, "values": first_two}, - "queries_0|input2": 
{"batch": True, "values": second_two}, - } - outputs = self._cat1_outputs(history_id, inputs=inputs) - assert len(outputs) == 2 - outputs_contents = [ - self.dataset_populator.get_history_dataset_content(history_id, dataset=o).strip() for o in outputs - ] - assert "123\n789" in outputs_contents - assert "456\n0ab" in outputs_contents - - @skip_without_tool("cat1") - def test_multirun_on_multiple_inputs_unlinked(self): - history_id, first_two, second_two = self._setup_two_multiruns() - inputs = { - "input1": {"batch": True, "linked": False, "values": first_two}, - "queries_0|input2": {"batch": True, "linked": False, "values": second_two}, - } - outputs = self._cat1_outputs(history_id, inputs=inputs) - outputs_contents = [ - self.dataset_populator.get_history_dataset_content(history_id, dataset=o).strip() for o in outputs - ] - assert len(outputs) == 4 - assert "123\n789" in outputs_contents - assert "456\n0ab" in outputs_contents - assert "123\n0ab" in outputs_contents - assert "456\n789" in outputs_contents - @skip_without_tool("dbkey_output_action") def test_dynamic_parameter_error_handling(self): # Run test with valid index once, then supply invalid dbkey and invalid table @@ -1676,47 +1611,6 @@ def _verify_element(self, history_id, element, **props): for key, value in props.items(): assert details[key] == value - def _setup_repeat_multirun(self): - history_id = self.dataset_populator.new_history() - new_dataset1 = self.dataset_populator.new_dataset(history_id, content="123") - new_dataset2 = self.dataset_populator.new_dataset(history_id, content="456") - common_dataset = self.dataset_populator.new_dataset(history_id, content="Common") - return ( - history_id, - dataset_to_param(common_dataset), - [dataset_to_param(new_dataset1), dataset_to_param(new_dataset2)], - ) - - def _check_repeat_multirun(self, history_id, inputs): - outputs = self._cat1_outputs(history_id, inputs=inputs) - assert len(outputs) == 2 - output1 = outputs[0] - output2 = outputs[1] - output1_content = self.dataset_populator.get_history_dataset_content(history_id, dataset=output1) - output2_content = self.dataset_populator.get_history_dataset_content(history_id, dataset=output2) - assert output1_content.strip() == "Common\n123" - assert output2_content.strip() == "Common\n456" - - def _setup_two_multiruns(self): - history_id = self.dataset_populator.new_history() - new_dataset1 = self.dataset_populator.new_dataset(history_id, content="123") - new_dataset2 = self.dataset_populator.new_dataset(history_id, content="456") - new_dataset3 = self.dataset_populator.new_dataset(history_id, content="789") - new_dataset4 = self.dataset_populator.new_dataset(history_id, content="0ab") - return ( - history_id, - [dataset_to_param(new_dataset1), dataset_to_param(new_dataset2)], - [dataset_to_param(new_dataset3), dataset_to_param(new_dataset4)], - ) - - @skip_without_tool("cat") - def test_map_over_collection(self, history_id): - hdca_id = self._build_pair(history_id, ["123", "456"]) - inputs = { - "input1": {"batch": True, "values": [{"src": "hdca", "id": hdca_id}]}, - } - self._run_and_check_simple_collection_mapping(history_id, inputs) - @skip_without_tool("output_filter_with_input") def test_map_over_with_output_filter_no_filtering(self, history_id): hdca_id = self.dataset_collection_populator.create_list_in_history(history_id, wait=True).json()["outputs"][0][ diff --git a/test/integration/test_extended_metadata_mapping.py b/test/integration/test_extended_metadata_mapping.py index b4e088146492..6dd1cbf6c747 100644 --- 
 a/test/integration/test_extended_metadata_mapping.py
+++ b/test/integration/test_extended_metadata_mapping.py
@@ -29,3 +29,18 @@ def test_map_over_collection(self, history_id):
             "input1": {"batch": True, "values": [{"src": "hdca", "id": hdca_id}]},
         }
         self._run_and_check_simple_collection_mapping(history_id, inputs)
+
+    def _run_and_check_simple_collection_mapping(self, history_id, inputs):
+        create = self._run_cat(history_id, inputs=inputs, assert_ok=True)
+        outputs = create["outputs"]
+        jobs = create["jobs"]
+        implicit_collections = create["implicit_collections"]
+        assert len(jobs) == 2
+        assert len(outputs) == 2
+        assert len(implicit_collections) == 1
+        output1 = outputs[0]
+        output2 = outputs[1]
+        output1_content = self.dataset_populator.get_history_dataset_content(history_id, dataset=output1)
+        output2_content = self.dataset_populator.get_history_dataset_content(history_id, dataset=output2)
+        assert output1_content.strip() == "123"
+        assert output2_content.strip() == "456"