diff --git a/src/styx/compiler/compile/inputs.py b/src/styx/compiler/compile/inputs.py index 035d7a1..97be9d8 100644 --- a/src/styx/compiler/compile/inputs.py +++ b/src/styx/compiler/compile/inputs.py @@ -236,7 +236,7 @@ def _bt_template_str_parse( """Parse a Boutiques command line template string into segments.""" bt_template_str = boutiques_split_command(input_command_line_template) - bt_id_inputs = {input_.data.template_key: input_ for input_ in inputs} + template_key_inputs = {input_.data.template_key: input_ for input_ in inputs} segments: list[list[str | WithSymbol[InputArgument]]] = [] @@ -250,13 +250,12 @@ def _bt_template_str_parse( token = stack.pop() if isinstance(token, str): any_match = False - for _, bt_input in bt_id_inputs.items(): - value_key = bt_input.data.internal_id - if value_key == token: + for template_key, bt_input in template_key_inputs.items(): + if template_key == token: stack.append(bt_input) any_match = True break - o = token.split(value_key, 1) + o = token.split(template_key, 1) if len(o) == 2: stack.append(o[0]) stack.append(bt_input) diff --git a/src/styx/compiler/compile/outputs.py b/src/styx/compiler/compile/outputs.py index 93360ad..1e1cf0f 100644 --- a/src/styx/compiler/compile/outputs.py +++ b/src/styx/compiler/compile/outputs.py @@ -1,5 +1,5 @@ from styx.compiler.compile.common import SharedScopes, SharedSymbols -from styx.model.core import InputArgument, OutputArgument, WithSymbol +from styx.model.core import InputArgument, InputTypePrimitive, OutputArgument, WithSymbol from styx.pycodegen.core import PyFunc, PyModule, indent from styx.pycodegen.utils import as_py_literal, enbrace, enquote @@ -18,6 +18,8 @@ def generate_outputs_definition( '"""', f"Output object returned when calling `{symbols.function}(...)`.", '"""', + "root: OutputPathType", + '"""Output root folder. This is the root folder for all outputs."""', ]), ]) for out in outputs: @@ -54,16 +56,34 @@ def generate_output_building( func.body.append(f"{symbols.ret} = {symbols.output_class}(") + # Set root output path + func.body.extend(indent([f'root={symbols.execution}.output_file("."),'])) + for out in outputs: strip_extensions = out.data.stripped_file_extensions is not None if out.data.path_template is not None: s = out.data.path_template - for a in inputs: + for input_ in inputs: + if input_.data.template_key not in s: + continue + + substitute = input_.symbol + + if input_.data.type.primitive == InputTypePrimitive.File: + # Just use the stem of the file + # This is commonly used when output files 'inherit' the name of an input file + substitute = f"pathlib.Path({substitute}).stem" + elif input_.data.type.primitive != InputTypePrimitive.String: + raise Exception( + f"Unsupported input type {input_.data.type.primitive} " + f"for output path template of '{out.data.name}'." + ) + if strip_extensions: exts = as_py_literal(out.data.stripped_file_extensions, "'") - s = s.replace(f"{a.data.internal_id}", enbrace(f"{py_rstrip_fun}({a.symbol}, {exts})")) - else: - s = s.replace(f"{a.data.internal_id}", enbrace(a.symbol)) + substitute = f"{py_rstrip_fun}({substitute}, {exts})" + + s = s.replace(input_.data.template_key, enbrace(substitute)) s_optional = ", optional=True" if out.data.optional else ""