diff --git a/src/styx/backend/__init__.py b/src/styx/backend/__init__.py new file mode 100644 index 0000000..cb28796 --- /dev/null +++ b/src/styx/backend/__init__.py @@ -0,0 +1 @@ +"""Styx Python backend.""" diff --git a/src/styx/backend/python/__init__.py b/src/styx/backend/python/__init__.py new file mode 100644 index 0000000..2cec5fb --- /dev/null +++ b/src/styx/backend/python/__init__.py @@ -0,0 +1,3 @@ +"""Python wrapper backend.""" + +from .core import to_python as to_python diff --git a/src/styx/backend/python/constraints.py b/src/styx/backend/python/constraints.py new file mode 100644 index 0000000..5ad9a71 --- /dev/null +++ b/src/styx/backend/python/constraints.py @@ -0,0 +1,172 @@ +import styx.ir.core as ir +from styx.backend.python.lookup import LookupParam +from styx.backend.python.pycodegen.core import LineBuffer, PyFunc, indent + + +def _generate_raise_value_err(obj: str, expectation: str, reality: str | None = None) -> LineBuffer: + fstr = "" + if "{" in obj or "{" in expectation or (reality is not None and "{" in reality): + fstr = "f" + + return ( + [f'raise ValueError({fstr}"{obj} must be {expectation} but was {reality}")'] + if reality is not None + else [f'raise ValueError({fstr}"{obj} must be {expectation}")'] + ) + + +def _param_compile_constraint_checks(buf: LineBuffer, param: ir.IParam, lookup: LookupParam) -> None: + """Generate input constraint validation code for an input argument.""" + py_symbol = lookup.py_symbol[param.param.id_] + + min_value: float | int | None = None + max_value: float | int | None = None + list_count_min: int | None = None + list_count_max: int | None = None + + if isinstance(param, (ir.IFloat, ir.IInt)): + min_value = param.min_value + max_value = param.max_value + elif isinstance(param, ir.IList): + list_count_min = param.list_.count_min + list_count_max = param.list_.count_max + + val_opt = "" + if isinstance(param, ir.IOptional): + val_opt = f"{py_symbol} is not None and " + + # List argument length validation + if list_count_min is not None and list_count_max is not None: + # Case: len(list[]) == X + assert list_count_min <= list_count_max + if list_count_min == list_count_max: + buf.extend([ + f"if {val_opt}(len({py_symbol}) != {list_count_min}): ", + *indent( + _generate_raise_value_err( + f"Length of '{py_symbol}'", + f"{list_count_min}", + f"{{len({py_symbol})}}", + ) + ), + ]) + else: + # Case: X <= len(list[]) <= Y + buf.extend([ + f"if {val_opt}not ({list_count_min} <= " f"len({py_symbol}) <= {list_count_max}): ", + *indent( + _generate_raise_value_err( + f"Length of '{py_symbol}'", + f"between {list_count_min} and {list_count_max}", + f"{{len({py_symbol})}}", + ) + ), + ]) + elif list_count_min is not None: + # Case len(list[]) >= X + buf.extend([ + f"if {val_opt}not ({list_count_min} <= len({py_symbol})): ", + *indent( + _generate_raise_value_err( + f"Length of '{py_symbol}'", + f"greater than {list_count_min}", + f"{{len({py_symbol})}}", + ) + ), + ]) + elif list_count_max is not None: + # Case len(list[]) <= X + buf.extend([ + f"if {val_opt}not (len({py_symbol}) <= {list_count_max}): ", + *indent( + _generate_raise_value_err( + f"Length of '{py_symbol}'", + f"less than {list_count_max}", + f"{{len({py_symbol})}}", + ) + ), + ]) + + # Numeric argument range validation + op_min = "<=" + op_max = "<=" + if min_value is not None and max_value is not None: + # Case: X <= arg <= Y + assert min_value <= max_value + if isinstance(param, ir.IList): + buf.extend([ + f"if {val_opt}not ({min_value} {op_min} min({py_symbol}) " + f"and 
max({py_symbol}) {op_max} {max_value}): ", + *indent( + _generate_raise_value_err( + f"All elements of '{py_symbol}'", + f"between {min_value} {op_min} x {op_max} {max_value}", + ) + ), + ]) + else: + buf.extend([ + f"if {val_opt}not ({min_value} {op_min} {py_symbol} {op_max} {max_value}): ", + *indent( + _generate_raise_value_err( + f"'{py_symbol}'", + f"between {min_value} {op_min} x {op_max} {max_value}", + f"{{{py_symbol}}}", + ) + ), + ]) + elif min_value is not None: + # Case: X <= arg + if isinstance(param, ir.IList): + buf.extend([ + f"if {val_opt}not ({min_value} {op_min} min({py_symbol})): ", + *indent( + _generate_raise_value_err( + f"All elements of '{py_symbol}'", + f"greater than {min_value} {op_min} x", + ) + ), + ]) + else: + buf.extend([ + f"if {val_opt}not ({min_value} {op_min} {py_symbol}): ", + *indent( + _generate_raise_value_err( + f"'{py_symbol}'", + f"greater than {min_value} {op_min} x", + f"{{{py_symbol}}}", + ) + ), + ]) + elif max_value is not None: + # Case: arg <= X + if isinstance(param, ir.IList): + buf.extend([ + f"if {val_opt}not (max({py_symbol}) {op_max} {max_value}): ", + *indent( + _generate_raise_value_err( + f"All elements of '{py_symbol}'", + f"less than x {op_max} {max_value}", + ) + ), + ]) + else: + buf.extend([ + f"if {val_opt}not ({py_symbol} {op_max} {max_value}): ", + *indent( + _generate_raise_value_err( + f"'{py_symbol}'", + f"less than x {op_max} {max_value}", + f"{{{py_symbol}}}", + ) + ), + ]) + + +def struct_compile_constraint_checks( + func: PyFunc, + struct: ir.IStruct | ir.IParam, + lookup: LookupParam, +) -> None: + for param in struct.struct.iter_params(): + _param_compile_constraint_checks(func.body, param, lookup) diff --git a/src/styx/backend/python/core.py b/src/styx/backend/python/core.py new file mode 100644 index 0000000..16ae990 --- /dev/null +++ b/src/styx/backend/python/core.py @@ -0,0 +1,53 @@ +import dataclasses +from typing import Any, Generator, Iterable + +from styx.backend.python.documentation import docs_to_docstring +from styx.backend.python.interface import compile_interface +from styx.backend.python.pycodegen.core import PyModule +from styx.backend.python.pycodegen.scope import Scope +from styx.backend.python.pycodegen.utils import python_snakify +from styx.ir.core import Interface, Package + + +@dataclasses.dataclass +class _PackageData: + package: Package + package_symbol: str + scope: Scope + module: PyModule + + +def to_python(interfaces: Iterable[Interface]) -> Generator[tuple[str, list[str]], Any, None]: + """For a stream of IR interfaces return a stream of Python modules and their module paths. + + Args: + interfaces: Stream of IR interfaces. + + Returns: + Stream of tuples (Python module, module path). 
+ """ + packages: dict[str, _PackageData] = {} + global_scope = Scope(parent=Scope.python()) + + for interface in interfaces: + if interface.package.name not in packages: + packages[interface.package.name] = _PackageData( + package=interface.package, + package_symbol=global_scope.add_or_dodge(python_snakify(interface.package.name)), + scope=Scope(parent=global_scope), + module=PyModule( + docstr=docs_to_docstring(interface.package.docs), + ), + ) + package_data = packages[interface.package.name] + + # interface_module_symbol = global_scope.add_or_dodge(python_snakify(interface.command.param.name)) + interface_module_symbol = python_snakify(interface.command.param.name) + + interface_module = PyModule() + compile_interface(interface=interface, package_scope=package_data.scope, interface_module=interface_module) + package_data.module.imports.append(f"from .{interface_module_symbol} import *") + yield interface_module.text(), [package_data.package_symbol, interface_module_symbol] + + for package_data in packages.values(): + yield package_data.module.text(), [package_data.package_symbol, "__init__"] diff --git a/src/styx/backend/python/documentation.py b/src/styx/backend/python/documentation.py new file mode 100644 index 0000000..7e83e4b --- /dev/null +++ b/src/styx/backend/python/documentation.py @@ -0,0 +1,54 @@ +from styx.ir.core import ( + Documentation, +) + + +def _ensure_period(s: str) -> str: + if not s.endswith("."): + return f"{s}." + return s + + +def _ensure_double_linebreak_if_not_empty(s: str) -> str: + if s == "" or s.endswith("\n\n"): + return s + if s.endswith("\n"): + return f"{s}\n" + return f"{s}\n\n" + + +def docs_to_docstring(docs: Documentation) -> str | None: + re = "" + if docs.title: + re += docs.title + + if docs.description: + re = _ensure_double_linebreak_if_not_empty(re) + re += _ensure_period(docs.description) + + if docs.authors: + re = _ensure_double_linebreak_if_not_empty(re) + if len(docs.authors) == 1: + re += f"Author: {docs.authors[0]}" + else: + re += f"Authors: {', '.join(docs.authors)}" + + if docs.literature: + re = _ensure_double_linebreak_if_not_empty(re) + if len(docs.literature) == 1: + re += f"Literature: {docs.literature[0]}" + else: + entries = "\n".join(docs.literature) + re += f"Literature:\n{entries}" + + if docs.urls: + re = _ensure_double_linebreak_if_not_empty(re) + if len(docs.urls) == 1: + re += f"URL: {docs.urls[0]}" + else: + entries = "\n".join(docs.urls) + re += f"URLs:\n{entries}" + + if re: + return re + return None diff --git a/src/styx/backend/python/interface.py b/src/styx/backend/python/interface.py new file mode 100644 index 0000000..83778f4 --- /dev/null +++ b/src/styx/backend/python/interface.py @@ -0,0 +1,454 @@ +import styx.ir.core as ir +from styx.backend.python.constraints import struct_compile_constraint_checks +from styx.backend.python.documentation import docs_to_docstring +from styx.backend.python.lookup import LookupParam +from styx.backend.python.metadata import generate_static_metadata +from styx.backend.python.pycodegen.core import ( + LineBuffer, + PyArg, + PyDataClass, + PyFunc, + PyModule, + expand, + indent, +) +from styx.backend.python.pycodegen.scope import Scope +from styx.backend.python.pycodegen.utils import as_py_literal, enquote, python_snakify +from styx.backend.python.utils import ( + param_py_default_value, + param_py_var_is_set_by_user, + param_py_var_to_str, + struct_has_outputs, +) + + +def _compile_struct( + struct: ir.IStruct | ir.IParam, + interface_module: PyModule, + lookup: LookupParam, + 
metadata_symbol: str, + root_function: bool, +) -> None: + has_outputs = root_function or struct_has_outputs(struct) + + outputs_type = lookup.py_output_type[struct.param.id_] + + if root_function: + func_cargs_building = PyFunc( + name=lookup.py_type[struct.param.id_], + return_type=outputs_type, + return_descr=f"NamedTuple of outputs " f"(described in `{outputs_type}`).", + docstring_body=docs_to_docstring(struct.param.docs), + ) + pyargs = func_cargs_building.args + else: + func_cargs_building = PyFunc( + name="run", + docstring_body="Build command line arguments. This method is called by the main command.", + return_type="list[str]", + return_descr="Command line arguments", + args=[ + PyArg(name="self", type=None, default=None, docstring="The sub-command object."), + PyArg(name="execution", type="Execution", default=None, docstring="The execution object."), + ], + ) + struct_class = PyDataClass( + name=lookup.py_struct_type[struct.param.id_], + docstring=docs_to_docstring(struct.param.docs), + methods=[func_cargs_building], + ) + if has_outputs: + func_outputs = PyFunc( + name="outputs", + docstring_body="Collect output file paths.", + return_type=outputs_type, + return_descr=f"NamedTuple of outputs " f"(described in `{outputs_type}`).", + args=[ + PyArg(name="self", type=None, default=None, docstring="The sub-command object."), + PyArg(name="execution", type="Execution", default=None, docstring="The execution object."), + ], + ) + pyargs = struct_class.fields + + # Collect param python symbols + for elem in struct.struct.iter_params(): + symbol = lookup.py_symbol[elem.param.id_] + pyargs.append( + PyArg( + name=symbol, + type=lookup.py_type[elem.param.id_], + default=param_py_default_value(elem), + docstring=elem.param.docs.description, + ) + ) + + if isinstance(elem, ir.IStruct): + _compile_struct( + struct=elem, + interface_module=interface_module, + lookup=lookup, + metadata_symbol=metadata_symbol, + root_function=False, + ) + elif isinstance(elem, ir.IStructUnion): + for child in elem.alts: + _compile_struct( + struct=child, + interface_module=interface_module, + lookup=lookup, + metadata_symbol=metadata_symbol, + root_function=False, + ) + + struct_compile_constraint_checks(func=func_cargs_building, struct=struct, lookup=lookup) + + if has_outputs: + _compile_outputs_class( + struct=struct, + interface_module=interface_module, + lookup=lookup, + ) + + if root_function: + func_cargs_building.body.extend([ + "runner = runner or get_global_runner()", + f"execution = runner.start_execution({metadata_symbol})", + ]) + + _compile_cargs_building(struct, lookup, func_cargs_building, access_via_self=not root_function) + + if root_function: + pyargs.append(PyArg(name="runner", type="Runner | None", default="None", docstring="Command runner")) + _compile_outputs_building( + struct=struct, + func=func_cargs_building, + lookup=lookup, + access_via_self=False, + ) + func_cargs_building.body.extend([ + "execution.run(cargs)", + "return ret", + ]) + interface_module.funcs_and_classes.append(func_cargs_building) + else: + if has_outputs: + _compile_outputs_building( + struct=struct, + func=func_outputs, + lookup=lookup, + access_via_self=True, + ) + func_outputs.body.extend([ + "return ret", + ]) + struct_class.methods.append(func_outputs) + func_cargs_building.body.extend([ + "return cargs", + ]) + interface_module.funcs_and_classes.append(struct_class) + interface_module.exports.append(struct_class.name) + + +def _compile_cargs_building( + param: ir.IParam | ir.IStruct, + lookup: LookupParam, + 
func: PyFunc, + access_via_self: bool, +) -> None: + func.body.append("cargs = []") + + for group in param.struct.groups: + group_conditions_py = [] + + building_cargs_py: list[tuple[str, bool]] = [] + for carg in group.cargs: + building_carg_py: list[tuple[str, bool]] = [] + for token in carg.tokens: + if isinstance(token, str): + building_carg_py.append((as_py_literal(token), False)) + continue + elem_symbol = lookup.py_symbol[token.param.id_] + if access_via_self: + elem_symbol = f"self.{elem_symbol}" + building_carg_py.append(param_py_var_to_str(token, elem_symbol)) + if (py_var_is_set_by_user := param_py_var_is_set_by_user(token, elem_symbol, False)) is not None: + group_conditions_py.append(py_var_is_set_by_user) + + if len(building_carg_py) == 1: + building_cargs_py.append(building_carg_py[0]) + else: + destructured = [s if not s_is_list else f'" ".join({s})' for s, s_is_list in building_carg_py] + building_cargs_py.append((" + ".join(destructured), False)) + + buf_appending: LineBuffer = [] + + if len(building_cargs_py) == 1: + for val, val_is_list in building_cargs_py: + if val_is_list: + buf_appending.append(f"cargs.extend({val})") + else: + buf_appending.append(f"cargs.append({val})") + else: + x = [(f"*{val}" if val_is_list else val) for val, val_is_list in building_cargs_py] + buf_appending.extend([ + "cargs.extend([", + *indent(expand(",\n".join(x))), + "])", + ]) + + if len(group_conditions_py) > 0: + func.body.append(f"if {' and '.join(group_conditions_py)}:") + func.body.extend(indent(buf_appending)) + else: + func.body.extend(buf_appending) + + +def _compile_outputs_class( + struct: ir.IStruct | ir.IParam, + interface_module: PyModule, + lookup: LookupParam, +) -> None: + outputs_class = PyDataClass( + name=lookup.py_output_type[struct.param.id_], + docstring=f"Output object returned when calling `{lookup.py_type[struct.param.id_]}(...)`.", + is_named_tuple=True, + ) + outputs_class.fields.append( + PyArg( + name="root", + type="OutputPathType", + default=None, + docstring="Output root folder. This is the root folder for all outputs.", + ) + ) + + for output in struct.param.outputs: + output_symbol = lookup.py_output_field_symbol[output.id_] + + # Optional if any of its param references is optional + optional = False + for token in output.tokens: + if isinstance(token, str): + continue + optional = optional or isinstance(lookup.param[token.ref_id], ir.IOptional) + + if not optional: + output_type = "OutputPathType" + else: + output_type = "OutputPathType | None" + + outputs_class.fields.append( + PyArg( + name=output_symbol, + type=output_type, + default=None, + docstring=output.docs.description, + ) + ) + + for sub_struct in struct.struct.iter_params(): + if isinstance(sub_struct, ir.IStruct): + if struct_has_outputs(sub_struct): + output_type = lookup.py_output_type[sub_struct.param.id_] + if isinstance(sub_struct, ir.IList): + output_type = f"typing.List[{output_type}]" + if isinstance(sub_struct, ir.IOptional): + output_type = f"{output_type} | None" + + output_symbol = lookup.py_symbol[sub_struct.param.id_] # todo: name collisions + + input_type = lookup.py_struct_type[sub_struct.param.id_] + docs_append = "" + if isinstance(sub_struct, ir.IList): + docs_append = "This is a list of outputs with the same length and order as the inputs." 
+ + outputs_class.fields.append( + PyArg( + name=output_symbol, + type=output_type, + default=None, + docstring=f"Outputs from {enquote(input_type, '`')}.{docs_append}", + ) + ) + elif isinstance(sub_struct, ir.IStructUnion): + if any([struct_has_outputs(s) for s in sub_struct.alts]): + alt_types = [ + lookup.py_output_type[sub_command.param.id_] + for sub_command in sub_struct.alts + if struct_has_outputs(sub_command) + ] + if len(alt_types) > 0: + output_type = ", ".join(alt_types) + output_type = f"typing.Union[{output_type}]" + + if isinstance(sub_struct, ir.IList): + output_type = f"typing.List[{output_type}]" + if isinstance(sub_struct, ir.IOptional): + output_type = f"{output_type} | None" + + output_symbol = lookup.py_symbol[sub_struct.param.id_] # todo: name collisions + + alt_input_types = [ + lookup.py_struct_type[sub_command.param.id_] + for sub_command in sub_struct.alts + if struct_has_outputs(sub_command) + ] + docs_append = "" + if isinstance(sub_struct, ir.IList): + docs_append = "This is a list of outputs with the same length and order as the inputs." + + input_types_human = ' or '.join([enquote(t, '`') for t in alt_input_types]) + outputs_class.fields.append( + PyArg( + name=output_symbol, + type=output_type, + default=None, + docstring=f"Outputs from {input_types_human}.{docs_append}", + ) + ) + + interface_module.funcs_and_classes.append(outputs_class) + interface_module.exports.append(outputs_class.name) + + +def _compile_outputs_building( + struct: ir.IStruct | ir.IParam, + func: PyFunc, + lookup: LookupParam, + access_via_self: bool = False, +) -> None: + """Generate the outputs building code.""" + func.body.append(f"ret = {lookup.py_output_type[struct.param.id_]}(") + + # Set root output path + func.body.extend(indent(['root=execution.output_file("."),'])) + + def _py_get_val( + output_param_reference: ir.OutputParamReference, + ) -> str: + param = lookup.param[output_param_reference.ref_id] + symbol = lookup.py_symbol[param.param.id_] + + substitute = symbol + if access_via_self: + substitute = f"self.{substitute}" + + if isinstance(param, ir.IList): + raise Exception(f"Output path template replacements cannot be lists. 
({param.param.name})") + + if isinstance(param, ir.IStr): + return substitute + + if isinstance(param, (ir.IInt, ir.IFloat)): + return f"str({substitute})" + + if isinstance(param, ir.IFile): + re = f"pathlib.Path({substitute}).name" + for suffix in output_param_reference.file_remove_suffixes: + re += f".removesuffix({as_py_literal(suffix)})" + return re + + if isinstance(param, ir.IBool): + raise Exception(f"Unsupported input type " f"for output path template of '{param.param.name}'.") + assert False + + for output in struct.param.outputs: + output_symbol = lookup.py_output_field_symbol[output.id_] + + output_segments: list[str] = [] + conditions = [] + for token in output.tokens: + if isinstance(token, str): + output_segments.append(as_py_literal(token)) + continue + output_segments.append(_py_get_val(token)) + + ostruct = lookup.param[token.ref_id] + param_symbol = lookup.py_symbol[ostruct.param.id_] + if (py_var_is_set_by_user := param_py_var_is_set_by_user(ostruct, param_symbol, False)) is not None: + conditions.append(py_var_is_set_by_user) + + condition_py = "" + if len(conditions) > 0: + condition_py = " and ".join(conditions) + condition_py = f" if ({condition_py}) else None" + + func.body.extend( + indent([f"{output_symbol}=execution.output_file({' + '.join(output_segments)}){condition_py},"]) + ) + + # sub struct outputs + for sub_struct in struct.struct.iter_params(): + has_outputs = False + if isinstance(sub_struct, ir.IStruct): + has_outputs = struct_has_outputs(sub_struct) + elif isinstance(sub_struct, ir.IStructUnion): + has_outputs = any([struct_has_outputs(s) for s in sub_struct.alts]) + if not has_outputs: + continue + + output_symbol = lookup.py_symbol[sub_struct.param.id_] # todo: name collisions + output_symbol_resolved = output_symbol + if access_via_self: + output_symbol_resolved = f"self.{output_symbol_resolved}" + + if isinstance(sub_struct, ir.IList): + opt = "" + if isinstance(sub_struct, ir.IOptional): + opt = f" if {output_symbol_resolved} else None" + func.body.extend( + indent([f"{output_symbol}=" f"[i.outputs(execution) for i in {output_symbol_resolved}]{opt},"]) + ) + else: + o = f"{output_symbol_resolved}.outputs(execution)" + if isinstance(sub_struct, ir.IOptional): + o = f"{o} if {output_symbol_resolved} else None" + func.body.extend(indent([f"{output_symbol}={o},"])) + + func.body.extend([")"]) + + +def compile_interface( + interface: ir.Interface, + package_scope: Scope, + interface_module: PyModule, +) -> None: + """Entry point to the Python backend.""" + interface_module.imports.extend([ + "import typing", + "import pathlib", + "from styxdefs import *", + "import dataclasses", + ]) + + metadata_symbol = generate_static_metadata( + module=interface_module, + scope=package_scope, + interface=interface, + ) + interface_module.exports.append(metadata_symbol) + + function_symbol = package_scope.add_or_dodge(python_snakify(interface.command.param.name)) + interface_module.exports.append(function_symbol) + + function_scope = Scope(parent=package_scope) + function_scope.add_or_die("runner") + function_scope.add_or_die("execution") + function_scope.add_or_die("cargs") + function_scope.add_or_die("ret") + + # Lookup tables + lookup = LookupParam( + interface=interface, + package_scope=package_scope, + function_symbol=function_symbol, + function_scope=function_scope, + ) + + _compile_struct( + struct=interface.command, + interface_module=interface_module, + lookup=lookup, + metadata_symbol=metadata_symbol, + root_function=True, + ) diff --git 
a/src/styx/backend/python/lookup.py b/src/styx/backend/python/lookup.py new file mode 100644 index 0000000..e8399cb --- /dev/null +++ b/src/styx/backend/python/lookup.py @@ -0,0 +1,88 @@ +import styx.ir.core as ir +from styx.backend.python.pycodegen.scope import Scope +from styx.backend.python.pycodegen.utils import python_pascalize, python_snakify +from styx.backend.python.utils import iter_params_recursively, param_py_type + + +class LookupParam: + """Pre-compute and store Python symbols, types, class-names, etc. to reduce spaghetti code everywhere else.""" + + def __init__( + self, + interface: ir.Interface, + package_scope: Scope, + function_symbol: str, + function_scope: Scope, + ) -> None: + def _collect_output_field_symbols( + param: ir.IStruct | ir.IParam, lookup_output_field_symbol: dict[ir.IdType, str] + ) -> None: + scope = Scope(parent=package_scope) + for output in param.param.outputs: + output_field_symbol = scope.add_or_dodge(python_snakify(output.name)) + assert output.id_ not in lookup_output_field_symbol + lookup_output_field_symbol[output.id_] = output_field_symbol + + def _collect_py_symbol(param: ir.IStruct | ir.IParam, lookup_py_symbol: dict[ir.IdType, str]) -> None: + scope = Scope(parent=function_scope) + for elem in param.struct.iter_params(): + symbol = scope.add_or_dodge(python_snakify(elem.param.name)) + assert elem.param.id_ not in lookup_py_symbol + lookup_py_symbol[elem.param.id_] = symbol + + self.param: dict[ir.IdType, ir.IParam] = {interface.command.param.id_: interface.command} + """Find param object by its ID. IParam.id_ -> IParam""" + self.py_struct_type: dict[ir.IdType, str] = {interface.command.param.id_: function_symbol} + """Find Python struct type by param id. IParam.id_ -> Python type + (this is different from py_type because of optionals and lists)""" + self.py_type: dict[ir.IdType, str] = {interface.command.param.id_: function_symbol} + """Find Python type by param id. IParam.id_ -> Python type""" + self.py_symbol: dict[ir.IdType, str] = {} + """Find function-parameter symbol by param ID. IParam.id_ -> Python symbol""" + self.py_output_type: dict[ir.IdType, str] = { + interface.command.param.id_: package_scope.add_or_dodge( + python_pascalize(f"{interface.command.struct.name}_Outputs") + ) + } + """Find outputs class name by struct param ID. IStruct.id_ -> Python class name""" + self.py_output_field_symbol: dict[ir.IdType, str] = {} + """Find output field symbol by output ID. 
Output.id_ -> Python symbol""" + + _collect_py_symbol( + param=interface.command, + lookup_py_symbol=self.py_symbol, + ) + _collect_output_field_symbols( + param=interface.command, + lookup_output_field_symbol=self.py_output_field_symbol, + ) + + for elem in iter_params_recursively(interface.command): + self.param[elem.param.id_] = elem + + if isinstance(elem, ir.IStruct): + if elem.param.id_ not in self.py_struct_type: # Struct unions may resolve these first + self.py_struct_type[elem.param.id_] = package_scope.add_or_dodge( + python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}") + ) + self.py_type[elem.param.id_] = param_py_type(elem, self.py_struct_type) + self.py_output_type[elem.param.id_] = package_scope.add_or_dodge( + python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}_Outputs") + ) + _collect_py_symbol( + param=elem, + lookup_py_symbol=self.py_symbol, + ) + _collect_output_field_symbols( + param=elem, + lookup_output_field_symbol=self.py_output_field_symbol, + ) + elif isinstance(elem, ir.IStructUnion): + for alternative in elem.alts: + self.py_struct_type[alternative.param.id_] = package_scope.add_or_dodge( + python_pascalize(f"{interface.command.struct.name}_{alternative.struct.name}") + ) + self.py_type[alternative.param.id_] = param_py_type(alternative, self.py_struct_type) + self.py_type[elem.param.id_] = param_py_type(elem, self.py_struct_type) + else: + self.py_type[elem.param.id_] = param_py_type(elem, self.py_struct_type) diff --git a/src/styx/backend/python/metadata.py b/src/styx/backend/python/metadata.py new file mode 100644 index 0000000..2145b17 --- /dev/null +++ b/src/styx/backend/python/metadata.py @@ -0,0 +1,33 @@ +from styx.backend.python.pycodegen.core import PyModule, indent +from styx.backend.python.pycodegen.scope import Scope +from styx.backend.python.pycodegen.utils import as_py_literal, python_screaming_snakify +from styx.ir.core import Interface + + +def generate_static_metadata( + module: PyModule, + scope: Scope, + interface: Interface, +) -> str: + """Generate the static metadata.""" + metadata_symbol = scope.add_or_dodge(f"{python_screaming_snakify(interface.command.param.name)}_METADATA") + + entries = { + "id": interface.uid, + "name": interface.command.param.name, + "package": interface.package.name, + } + + if interface.command.param.docs.literature: + entries["citations"] = interface.command.param.docs.literature + + if interface.package.docker: + entries["container_image_tag"] = interface.package.docker + + module.header.extend([ + f"{metadata_symbol} = Metadata(", + *indent([f"{k}={as_py_literal(v)}," for k, v in entries.items()]), + ")", + ]) + + return metadata_symbol diff --git a/src/styx/pycodegen/__init__.py b/src/styx/backend/python/pycodegen/__init__.py similarity index 100% rename from src/styx/pycodegen/__init__.py rename to src/styx/backend/python/pycodegen/__init__.py diff --git a/src/styx/pycodegen/core.py b/src/styx/backend/python/pycodegen/core.py similarity index 69% rename from src/styx/pycodegen/core.py rename to src/styx/backend/python/pycodegen/core.py index ee1b554..d533591 100644 --- a/src/styx/pycodegen/core.py +++ b/src/styx/backend/python/pycodegen/core.py @@ -3,7 +3,7 @@ from abc import ABC from dataclasses import dataclass, field -from styx.pycodegen.utils import enquote, ensure_endswith, linebreak_paragraph +from styx.backend.python.pycodegen.utils import enquote, ensure_endswith, linebreak_paragraph LineBuffer = list[str] INDENT = " " @@ -68,9 +68,9 @@ class PyArg: """Python 
function argument.""" name: str - type: str | None - default: str | None - docstring: str + type: str | None = None + default: str | None = None + docstring: str | None = None def declaration(self) -> str: """Generate the argument declaration ("var[: type][ = default]").""" @@ -84,11 +84,11 @@ def declaration(self) -> str: class PyFunc(PyGen): """Python function.""" - name: str = "" + name: str args: list[PyArg] = field(default_factory=list) - docstring_body: str = "" + docstring_body: str | None = None body: LineBuffer = field(default_factory=list) - return_descr: str = "" + return_descr: str | None = None return_type: str | None = None def generate(self) -> LineBuffer: @@ -107,8 +107,12 @@ def generate(self) -> LineBuffer: arg_docstr_buf = [] for arg in self.args: + if arg.name == "self": + continue arg_docstr = linebreak_paragraph( - f"{arg.name}: {arg.docstring}", width=80 - (4 * 3) - 1, first_line_width=80 - (4 * 2) - 1 + f"{arg.name}: {arg.docstring if arg.docstring else ''}", + width=80 - (4 * 3) - 1, + first_line_width=80 - (4 * 2) - 1, ) arg_docstr = ensure_endswith("\\\n".join(arg_docstr), ".").split("\n") arg_docstr_buf.append(arg_docstr[0]) @@ -116,7 +120,10 @@ def generate(self) -> LineBuffer: # Add docstring (Google style) - docstring_linebroken = linebreak_paragraph(self.docstring_body, width=80 - 4) + if self.docstring_body: + docstring_linebroken = linebreak_paragraph(self.docstring_body, width=80 - 4) + else: + docstring_linebroken = "" buf.extend( indent([ @@ -125,14 +132,16 @@ def generate(self) -> LineBuffer: "", "Args:", *indent(arg_docstr_buf), - "Returns:", - *indent([f"{self.return_descr}"]), + *(["Returns:", *indent([f"{self.return_descr}"])] if self.return_descr else []), '"""', ]) ) # Add function body - buf.extend(indent(self.body)) + if self.body: + buf.extend(indent(self.body)) + else: + buf.extend(indent(["pass"])) return buf @@ -144,28 +153,46 @@ class PyDataClass(PyGen): docstring: str fields: list[PyArg] = field(default_factory=list) methods: list[PyFunc] = field(default_factory=list) + is_named_tuple: bool = False def generate(self) -> LineBuffer: # Sort fields so default arguments come last self.fields.sort(key=lambda a: a.default is not None) def _arg_docstring(arg: PyArg) -> LineBuffer: + if not arg.docstring: + return [] return linebreak_paragraph(f'"""{arg.docstring}"""', width=80 - 4, first_line_width=80 - 4) args = concat([[f.declaration(), *_arg_docstring(f)] for f in self.fields]) methods = concat([method.generate() for method in self.methods], [""]) - buf = [ - "@dataclasses.dataclass", - f"class {self.name}:", - *indent([ - '"""', - f"{self.docstring}", - '"""', - *args, - *blank_before(methods), - ]), - ] + if not self.is_named_tuple: + buf = [ + "@dataclasses.dataclass", + f"class {self.name}:", + *indent([ + *( + ['"""', *linebreak_paragraph(self.docstring, width=80 - 4, first_line_width=80 - 4), '"""'] + if self.docstring + else [] + ), + *args, + *blank_before(methods), + ]), + ] + else: + buf = [ + f"class {self.name}(typing.NamedTuple):", + *indent([ + '"""', + f"{self.docstring}", + '"""', + *args, + *blank_before(methods), + ]), + ] + return buf @@ -175,15 +202,16 @@ class PyModule(PyGen): imports: LineBuffer = field(default_factory=list) header: LineBuffer = field(default_factory=list) - funcs: list[PyFunc] = field(default_factory=list) + funcs_and_classes: list[PyFunc | PyDataClass] = field(default_factory=list) footer: LineBuffer = field(default_factory=list) exports: list[str] = field(default_factory=list) + docstr: str | None = 
None def generate(self) -> LineBuffer: exports = ( [ "__all__ = [", - *indent(list(map(lambda x: f"{enquote(x)},", self.exports))), + *indent(list(map(lambda x: f"{enquote(x)},", sorted(self.exports)))), "]", ] if self.exports @@ -191,13 +219,14 @@ def generate(self) -> LineBuffer: ) return blank_after([ + *(['"""', *linebreak_paragraph(self.docstr), '"""'] if self.docstr else []), *comment([ "This file was auto generated by Styx.", "Do not edit this file directly.", ]), *blank_before(self.imports), *blank_before(self.header), - *[line for func in self.funcs for line in blank_before(func.generate(), 2)], + *[line for func in self.funcs_and_classes for line in blank_before(func.generate(), 2)], *blank_before(self.footer), *blank_before(exports, 2), ]) diff --git a/src/styx/pycodegen/scope.py b/src/styx/backend/python/pycodegen/scope.py similarity index 100% rename from src/styx/pycodegen/scope.py rename to src/styx/backend/python/pycodegen/scope.py diff --git a/src/styx/pycodegen/string_case.py b/src/styx/backend/python/pycodegen/string_case.py similarity index 100% rename from src/styx/pycodegen/string_case.py rename to src/styx/backend/python/pycodegen/string_case.py diff --git a/src/styx/pycodegen/utils.py b/src/styx/backend/python/pycodegen/utils.py similarity index 100% rename from src/styx/pycodegen/utils.py rename to src/styx/backend/python/pycodegen/utils.py diff --git a/src/styx/backend/python/utils.py b/src/styx/backend/python/utils.py new file mode 100644 index 0000000..3b9e1c1 --- /dev/null +++ b/src/styx/backend/python/utils.py @@ -0,0 +1,180 @@ +from typing import Any, Generator + +import styx.ir.core as ir +from styx.backend.python.pycodegen.utils import as_py_literal, enquote + + +def iter_params_recursively(param: ir.IParam | str, skip_self: bool = True) -> Generator[ir.IParam, Any, None]: + """Iterate through all child-params recursively.""" + if isinstance(param, str): + return + if not skip_self: + yield param + if isinstance(param, ir.IStruct): + for e in param.struct.iter_params(): + yield from iter_params_recursively(e, False) + elif isinstance(param, ir.IStructUnion): + for e in param.alts: + yield from iter_params_recursively(e, False) + + +def param_py_type(param: ir.IParam, lookup_struct_type: dict[ir.IdType, str]) -> str: + """Return the Python type expression for a param. + + Args: + param: The param. + lookup_struct_type: lookup dictionary for struct types (pre-compute). + + Returns: + Python type expression. + """ + + def _base() -> str: + if isinstance(param, ir.IStr): + if param.choices: + return f"typing.Literal[{', '.join(map(as_py_literal, param.choices))}]" + return "str" + if isinstance(param, ir.IInt): + if param.choices: + return f"typing.Literal[{', '.join(map(as_py_literal, param.choices))}]" + return "int" + if isinstance(param, ir.IFloat): + return "float" + if isinstance(param, ir.IFile): + return "InputPathType" + if isinstance(param, ir.IBool): + return "bool" + if isinstance(param, ir.IStruct): + return lookup_struct_type[param.param.id_] + if isinstance(param, ir.IStructUnion): + return f"typing.Union[{', '.join(lookup_struct_type[i.param.id_] for i in param.alts)}]" + assert False + + type_ = _base() + if isinstance(param, ir.IList): + type_ = f"list[{type_}]" + if isinstance(param, ir.IOptional): + type_ = f"{type_} | None" + + return type_ + + +def param_py_var_to_str( + param: ir.IParam, + symbol: str, +) -> tuple[str, bool]: + """Python var to str. 
+
+    Return a Python expression that converts the variable to a string or string array
+    and a boolean that indicates if the expression value is an array.
+    """
+
+    def _val() -> tuple[str, bool]:
+        if not isinstance(param, ir.IList):
+            if isinstance(param, ir.IStr):
+                return symbol, False
+            if isinstance(param, (ir.IInt, ir.IFloat)):
+                return f"str({symbol})", False
+            if isinstance(param, ir.IBool):
+                as_list = (len(param.value_true) > 1) or (len(param.value_false) > 1)
+                if as_list:
+                    value_true = param.value_true
+                    value_false = param.value_false
+                else:
+                    value_true = param.value_true[0] if len(param.value_true) > 0 else None
+                    value_false = param.value_false[0] if len(param.value_false) > 0 else None
+                if len(param.value_true) > 0:
+                    if len(param.value_false) > 0:
+                        return f"({as_py_literal(value_true)} if {symbol} else {as_py_literal(value_false)})", as_list
+                    return as_py_literal(value_true), as_list
+                assert len(param.value_false) > 0
+                return as_py_literal(value_false), as_list
+            if isinstance(param, ir.IFile):
+                return f"execution.input_file({symbol})", False
+            if isinstance(param, (ir.IStruct, ir.IStructUnion)):
+                return f"{symbol}.run(execution)", True
+            assert False
+
+        if param.list_.join is None:
+            if isinstance(param, ir.IStr):
+                return symbol, True
+            if isinstance(param, (ir.IInt, ir.IFloat)):
+                return f"map(str, {symbol})", True
+            if isinstance(param, ir.IBool):
+                assert False, "TODO: Not implemented yet"
+            if isinstance(param, ir.IFile):
+                return f"[execution.input_file(f) for f in {symbol}]", True
+            if isinstance(param, (ir.IStruct, ir.IStructUnion)):
+                return f"[a for c in [s.run(execution) for s in {symbol}] for a in c]", True
+            assert False
+
+        # arg.data.list_separator is not None
+        sep_join = f"{enquote(param.list_.join)}.join"
+        if isinstance(param, ir.IStr):
+            return f"{sep_join}({symbol})", False
+        if isinstance(param, (ir.IInt, ir.IFloat)):
+            return f"{sep_join}(map(str, {symbol}))", False
+        if isinstance(param, ir.IBool):
+            assert False, "TODO: Not implemented yet"
+        if isinstance(param, ir.IFile):
+            return f"{sep_join}([execution.input_file(f) for f in {symbol}])", False
+        if isinstance(param, (ir.IStruct, ir.IStructUnion)):
+            return f"{sep_join}([a for c in [s.run(execution) for s in {symbol}] for a in c])", False
+        assert False
+
+    return _val()
+
+
+def param_py_default_value(param: ir.IParam) -> str | None:
+    # Is this cheating? Maybe.
+
+    if hasattr(param, "default_value"):
+        if param.default_value is ir.IOptional.SetToNone:
+            return "None"
+        if param.default_value is None:
+            return None
+        return as_py_literal(param.default_value)
+
+    if hasattr(param, "default_value_set_to_none"):
+        if param.default_value_set_to_none:
+            return "None"
+        return None
+
+
+def param_py_var_is_set_by_user(
+    param: ir.IParam,
+    symbol: str,
+    enbrace_statement: bool = False,
+) -> str | None:
+    """Return a Python expression that checks if the variable is set by the user.
+
+    Returns `None` if the param must always be specified.
+ """ + if isinstance(param, ir.IOptional): + if enbrace_statement: + return f"({symbol} is not None)" + return f"{symbol} is not None" + + if isinstance(param, ir.IBool): + if len(param.value_true) > 0 and len(param.value_false) == 0: + return symbol + if len(param.value_false) > 0 and len(param.value_true) == 0: + if enbrace_statement: + return f"(not {symbol})" + return f"not {symbol}" + return None + + +def struct_has_outputs(struct: ir.IParam | ir.IStruct) -> bool: + """Check if the sub-command has outputs.""" + if len(struct.param.outputs) > 0: + return True + for p in struct.struct.iter_params(): + if isinstance(p, ir.IStruct): + if struct_has_outputs(p): + return True + if isinstance(p, ir.IStructUnion): + for struct in p.alts: + if struct_has_outputs(struct): + return True + return False diff --git a/src/styx/compiler/__init__.py b/src/styx/compiler/__init__.py deleted file mode 100644 index df72f69..0000000 --- a/src/styx/compiler/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Boutiques to python compiler.""" diff --git a/src/styx/compiler/compile/__init__.py b/src/styx/compiler/compile/__init__.py deleted file mode 100644 index dd0e2cf..0000000 --- a/src/styx/compiler/compile/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Compilation of Styx data model to Python code.""" diff --git a/src/styx/compiler/compile/common.py b/src/styx/compiler/compile/common.py deleted file mode 100644 index 300538d..0000000 --- a/src/styx/compiler/compile/common.py +++ /dev/null @@ -1,21 +0,0 @@ -from dataclasses import dataclass - -from styx.pycodegen.scope import Scope - - -@dataclass -class SharedScopes: - module: Scope - function: Scope - output_tuple: Scope - - -@dataclass -class SharedSymbols: - function: str - output_class: str - metadata: str - runner: str - execution: str - cargs: str - ret: str diff --git a/src/styx/compiler/compile/constraints.py b/src/styx/compiler/compile/constraints.py deleted file mode 100644 index 1f328c1..0000000 --- a/src/styx/compiler/compile/constraints.py +++ /dev/null @@ -1,226 +0,0 @@ -from styx.compiler.compile.inputs import codegen_var_is_set_by_user -from styx.model.core import GroupConstraint, InputArgument, WithSymbol -from styx.pycodegen.core import LineBuffer, PyFunc, expand, indent -from styx.pycodegen.utils import enquote - - -def _generate_raise_value_err(obj: str, expectation: str, reality: str | None = None) -> LineBuffer: - fstr = "" - if "{" in obj or "{" in expectation or (reality is not None and "{" in reality): - fstr = "f" - - return ( - [f'raise ValueError({fstr}"{obj} must be {expectation} but was {reality}")'] - if reality is not None - else [f'raise ValueError({fstr}"{obj} must be {expectation}")'] - ) - - -def generate_input_constraint_validation( - buf: LineBuffer, - input_: WithSymbol[InputArgument], -) -> None: - """Generate input constraint validation code for an input argument.""" - py_symbol = input_.symbol - constraints = input_.data.constraints - - val_opt = "" - if input_.data.type.is_optional: - val_opt = f"{py_symbol} is not None and " - - # List argument length validation - if constraints.list_length_min is not None and constraints.list_length_max is not None: - # Case: len(list[]) == X - assert constraints.list_length_min <= constraints.list_length_max - if constraints.list_length_min == constraints.list_length_max: - buf.extend([ - f"if {val_opt}(len({py_symbol}) != {constraints.list_length_min}): ", - *indent( - _generate_raise_value_err( - f"Length of '{py_symbol}'", - f"{constraints.list_length_min}", - 
f"{{len({py_symbol})}}", - ) - ), - ]) - else: - # Case: X <= len(list[]) <= Y - buf.extend([ - f"if {val_opt}not ({constraints.list_length_min} <= " - f"len({py_symbol}) <= {constraints.list_length_max}): ", - *indent( - _generate_raise_value_err( - f"Length of '{py_symbol}'", - f"between {constraints.list_length_min} and {constraints.list_length_max}", - f"{{len({py_symbol})}}", - ) - ), - ]) - elif constraints.list_length_min is not None: - # Case len(list[]) >= X - buf.extend([ - f"if {val_opt}not ({constraints.list_length_min} <= len({py_symbol})): ", - *indent( - _generate_raise_value_err( - f"Length of '{py_symbol}'", - f"greater than {constraints.list_length_min}", - f"{{len({py_symbol})}}", - ) - ), - ]) - elif constraints.list_length_max is not None: - # Case len(list[]) <= X - buf.extend([ - f"if {val_opt}not (len({py_symbol}) <= {constraints.list_length_max}): ", - *indent( - _generate_raise_value_err( - f"Length of '{py_symbol}'", - f"less than {constraints.list_length_max}", - f"{{len({py_symbol})}}", - ) - ), - ]) - - # Numeric argument range validation - op_min = "<" if constraints.value_min_exclusive else "<=" - op_max = "<" if constraints.value_max_exclusive else "<=" - if constraints.value_min is not None and constraints.value_max is not None: - # Case: X <= arg <= Y - assert constraints.value_min <= constraints.value_max - if input_.data.type.is_list: - buf.extend([ - f"if {val_opt}not ({constraints.value_min} {op_min} min({py_symbol}) " - f"and max({py_symbol}) {op_max} {constraints.value_max}): ", - *indent( - _generate_raise_value_err( - f"All elements of '{py_symbol}'", - f"between {constraints.value_min} {op_min} x {op_max} {constraints.value_max}", - ) - ), - ]) - else: - buf.extend([ - f"if {val_opt}not ({constraints.value_min} {op_min} {py_symbol} {op_max} {constraints.value_max}): ", - *indent( - _generate_raise_value_err( - f"'{py_symbol}'", - f"between {constraints.value_min} {op_min} x {op_max} {constraints.value_max}", - f"{{{py_symbol}}}", - ) - ), - ]) - elif constraints.value_min is not None: - # Case: X <= arg - if input_.data.type.is_list: - buf.extend([ - f"if {val_opt}not ({constraints.value_min} {op_min} min({py_symbol})): ", - *indent( - _generate_raise_value_err( - f"All elements of '{py_symbol}'", - f"greater than {constraints.value_min} {op_min} x", - ) - ), - ]) - else: - buf.extend([ - f"if {val_opt}not ({constraints.value_min} {op_min} {py_symbol}): ", - *indent( - _generate_raise_value_err( - f"'{py_symbol}'", - f"greater than {constraints.value_min} {op_min} x", - f"{{{py_symbol}}}", - ) - ), - ]) - elif constraints.value_max is not None: - # Case: arg <= X - if input_.data.type.is_list: - buf.extend([ - f"if {val_opt}not (max({py_symbol}) {op_max} {constraints.value_max}): ", - *indent( - _generate_raise_value_err( - f"All elements of '{py_symbol}'", - f"less than x {op_max} {constraints.value_max}", - ) - ), - ]) - else: - buf.extend([ - f"if {val_opt}not ({py_symbol} {op_max} {constraints.value_max}): ", - *indent( - _generate_raise_value_err( - f"'{py_symbol}'", - f"less than x {op_max} {constraints.value_max}", - f"{{{py_symbol}}}", - ) - ), - ]) - - -def generate_group_constraint_validation( - buf: LineBuffer, - group: GroupConstraint, # type: ignore - args_lookup: dict[str, WithSymbol[InputArgument]], -) -> None: - group_args = [args_lookup[x] for x in group.members if x] - if group.members_mutually_exclusive: - txt_members = [enquote(x) for x in expand(",\\n\n".join(group.members))] - check_members = expand(" 
+\n".join([codegen_var_is_set_by_user(x, True) for x in group_args])) - buf.extend(["if ("]) - buf.extend(indent(check_members)) - buf.extend([ - ") > 1:", - *indent([ - "raise ValueError(", - *indent([ - '"Only one of the following arguments can be specified:\\n"', - *txt_members, - ]), - ")", - ]), - ]) - if group.members_must_include_all_or_none: - txt_members = [enquote(x) for x in expand(",\\n\n".join(group.members))] - check_members = expand(" ==\n".join([codegen_var_is_set_by_user(x, True) for x in group_args])) - buf.extend(["if not ("]) - buf.extend(indent(check_members)) - buf.extend([ - "):", - *indent([ - "raise ValueError(", - *indent([ - '"All or none of the following arguments must be specified:\\n"', - *txt_members, - ]), - ")", - ]), - ]) - if group.members_must_include_one: - txt_members = [enquote("- " + x) for x in expand("\\n\n".join(group.members))] - check_members = expand(" or\n".join([codegen_var_is_set_by_user(x, True) for x in group_args])) - buf.extend(["if not ("]) - buf.extend(indent(check_members)) - buf.extend([ - "):", - *indent([ - "raise ValueError(", - *indent([ - '"One of the following arguments must be specified:\\n"', - *txt_members, - ]), - ")", - ]), - ]) - - -def generate_constraint_checks( - func: PyFunc, - group_constraints: list[GroupConstraint], - inputs: list[WithSymbol[InputArgument]], -) -> None: - for arg in inputs: - generate_input_constraint_validation(func.body, arg) - - inputs_lookup_bt_name = {x.data.name: x for x in inputs} - for group_constraint in group_constraints: - generate_group_constraint_validation(func.body, group_constraint, inputs_lookup_bt_name) diff --git a/src/styx/compiler/compile/descriptor.py b/src/styx/compiler/compile/descriptor.py deleted file mode 100644 index c6ad71e..0000000 --- a/src/styx/compiler/compile/descriptor.py +++ /dev/null @@ -1,146 +0,0 @@ -from styx.compiler.compile.common import SharedScopes, SharedSymbols -from styx.compiler.compile.constraints import generate_constraint_checks -from styx.compiler.compile.inputs import build_input_arguments, generate_command_line_args_building -from styx.compiler.compile.metadata import generate_static_metadata -from styx.compiler.compile.outputs import generate_output_building, generate_outputs_class -from styx.compiler.compile.subcommand import generate_sub_command_classes -from styx.compiler.settings import CompilerSettings -from styx.model.core import Descriptor, InputArgument, OutputArgument, SubCommand, WithSymbol -from styx.pycodegen.core import PyArg, PyFunc, PyModule -from styx.pycodegen.scope import Scope -from styx.pycodegen.utils import ( - python_pascalize, - python_screaming_snakify, - python_snakify, -) - - -def _generate_run_function( - module: PyModule, - symbols: SharedSymbols, - scopes: SharedScopes, - command: SubCommand, - inputs: list[WithSymbol[InputArgument]], - outputs: list[WithSymbol[OutputArgument]], -) -> None: - # Sub-command classes - sub_aliases, sub_sub_command_class_aliases, _ = generate_sub_command_classes( - module, symbols, command, scopes.module - ) - - # Function - func = PyFunc( - name=symbols.function, - return_type=symbols.output_class, - return_descr=f"NamedTuple of outputs " f"(described in `{symbols.output_class}`).", - docstring_body=command.doc, - ) - module.funcs.append(func) - - # Function arguments - func.args.extend(build_input_arguments(inputs, sub_aliases)) - func.args.append(PyArg(name="runner", type="Runner | None", default="None", docstring="Command runner")) - - # Function body: Runner instantiation - 
func.body.extend([ - f"{symbols.runner} = {symbols.runner} or get_global_runner()", - ]) - - # Constraint checking - generate_constraint_checks(func, command.group_constraints, inputs) - - # Function body - func.body.extend([ - f"{symbols.execution} = {symbols.runner}.start_execution({symbols.metadata})", - f"{symbols.cargs} = []", - ]) - - # Command line args building - generate_command_line_args_building(command.input_command_line_template, func, inputs) - - # Outputs static definition - generate_outputs_class( - module, symbols.output_class, symbols.function, outputs, inputs, sub_sub_command_class_aliases - ) - # Outputs building code - generate_output_building( - func, scopes.function, symbols.execution, symbols.output_class, symbols.ret, outputs, inputs, False - ) - - # Function body: Run and return - func.body.extend([ - f"{symbols.execution}.run({symbols.cargs})", - f"return {symbols.ret}", - ]) - - -def compile_descriptor(descriptor: Descriptor, settings: CompilerSettings) -> str: - """Compile a descriptor to Python code.""" - # --- Scopes and symbols --- - - _module_scope = Scope(parent=Scope.python()) - scopes = SharedScopes( - module=_module_scope, - function=Scope(parent=_module_scope), - output_tuple=Scope(parent=_module_scope), - ) - - # Module level symbols - scopes.module.add_or_die("styx") - scopes.module.add_or_die("InputFileType") - scopes.module.add_or_die("OutputFileType") - scopes.module.add_or_die("Runner") - scopes.module.add_or_die("Execution") - scopes.module.add_or_die("Metadata") - - symbols = SharedSymbols( - function=scopes.module.add_or_dodge(python_snakify(descriptor.command.name)), - output_class=scopes.module.add_or_dodge(f"{python_pascalize(descriptor.command.name)}Outputs"), - metadata=scopes.module.add_or_dodge(f"{python_screaming_snakify(descriptor.command.name)}_METADATA"), - runner=scopes.function.add_or_die("runner"), - execution=scopes.function.add_or_die("execution"), - cargs=scopes.function.add_or_die("cargs"), - ret=scopes.function.add_or_die("ret"), - ) - - # Input symbols - inputs: list[WithSymbol[InputArgument]] = [] - for input_ in descriptor.command.inputs: - py_symbol = scopes.function.add_or_dodge(python_snakify(input_.name)) - inputs.append(WithSymbol(input_, py_symbol)) - - # Output symbols - outputs: list[WithSymbol[OutputArgument]] = [] - for output in descriptor.command.outputs: - py_symbol = scopes.output_tuple.add_or_dodge(python_snakify(output.name)) - outputs.append(WithSymbol(output, py_symbol)) - - # --- Code generation --- - module = PyModule() - - module.imports.append("import typing") - module.imports.append("import pathlib") - module.imports.append("from styxdefs import *") - - module.exports.append(symbols.function) - module.exports.append(symbols.output_class) - module.exports.append(symbols.metadata) - - # Static metadata - generate_static_metadata(module, descriptor, symbols) - - # Main command run function - _generate_run_function( - module, - symbols, - scopes, - command=descriptor.command, - inputs=inputs, - outputs=outputs, - ) - - # --- Return code --- - - module.imports.sort() - module.exports.sort() - return module.text() diff --git a/src/styx/compiler/compile/inputs.py b/src/styx/compiler/compile/inputs.py deleted file mode 100644 index 8fd5308..0000000 --- a/src/styx/compiler/compile/inputs.py +++ /dev/null @@ -1,284 +0,0 @@ -from styx.model.boutiques_split_command import boutiques_split_command -from styx.model.core import InputArgument, InputTypePrimitive, WithSymbol -from styx.pycodegen.core import 
LineBuffer, PyArg, PyFunc, expand, indent -from styx.pycodegen.utils import as_py_literal, enquote - - -def _input_argument_to_py_type(arg: InputArgument, sub_command_types: dict[str, str]) -> str: - """Return the Python type expression.""" - - def _base() -> str: - if arg.type.is_enum: - assert arg.enum_values is not None - assert arg.type.primitive != InputTypePrimitive.Flag - assert arg.type.primitive != InputTypePrimitive.SubCommand - return f"typing.Literal[{', '.join(map(as_py_literal, arg.enum_values))}]" - - match arg.type.primitive: - case InputTypePrimitive.String: - return "str" - case InputTypePrimitive.Number: - return "float | int" - case InputTypePrimitive.Integer: - return "int" - case InputTypePrimitive.File: - return "InputPathType" - case InputTypePrimitive.Flag: - return "bool" - case InputTypePrimitive.SubCommand: - assert arg.sub_command is not None - return sub_command_types[arg.sub_command.internal_id] - case InputTypePrimitive.SubCommandUnion: - assert arg.sub_command_union is not None - return f"typing.Union[{', '.join(sub_command_types[i.internal_id] for i in arg.sub_command_union)}]" - case _: - assert False - - type_ = _base() - if arg.type.primitive != InputTypePrimitive.Flag: - if arg.type.is_list: - type_ = f"list[{type_}]" - if arg.type.is_optional: - type_ = f"{type_} | None" - return type_ - - -def build_input_arguments( - inputs: list[WithSymbol[InputArgument]], - sub_command_types: dict[str, str], -) -> list[PyArg]: - """Build Python function arguments from input arguments.""" - return [ - PyArg( - name=arg.symbol, - type=_input_argument_to_py_type(arg.data, sub_command_types), - default=as_py_literal(arg.data.default_value) if arg.data.has_default_value else None, - docstring=arg.data.doc, - ) - for arg in inputs - ] - - -def codegen_var_is_set_by_user(arg: WithSymbol[InputArgument], enbrace_statement: bool = False) -> str: - """Return a Python expression that checks if the variable is set by the user.""" - if arg.data.type.primitive == InputTypePrimitive.Flag: - return arg.symbol - if enbrace_statement: - return f"({arg.symbol} is not None)" - return f"{arg.symbol} is not None" - - -def _codegen_var_to_str(arg: WithSymbol[InputArgument]) -> tuple[str, bool]: - """Return a Python expression that converts the variable to a string or string array. - - Return a boolean that indicates if the expression is an array. 
- """ - if arg.data.type.primitive == InputTypePrimitive.Flag: - assert arg.data.command_line_flag is not None, f"Flag input must have a command line flag ({arg.data.name})" - return enquote(arg.data.command_line_flag), False - - def _val() -> tuple[str, bool]: - if not arg.data.type.is_list: - match arg.data.type.primitive: - case InputTypePrimitive.String: - return arg.symbol, False - case InputTypePrimitive.Number: - return f"str({arg.symbol})", False - case InputTypePrimitive.Integer: - return f"str({arg.symbol})", False - case InputTypePrimitive.File: - return f"execution.input_file({arg.symbol})", False - case InputTypePrimitive.SubCommand: - return f"{arg.symbol}.run(execution)", True - case InputTypePrimitive.SubCommandUnion: - return f"{arg.symbol}.run(execution)", True - case _: - assert False - - # arg.data.type.is_list is True - if arg.data.list_separator is None: - match arg.data.type.primitive: - case InputTypePrimitive.String: - return arg.symbol, True - case InputTypePrimitive.Number: - return f"map(str, {arg.symbol})", True - case InputTypePrimitive.Integer: - return f"map(str, {arg.symbol})", True - case InputTypePrimitive.File: - return f"[execution.input_file(f) for f in {arg.symbol}]", True - case InputTypePrimitive.SubCommand: - return f"[a for c in [s.run(execution) for s in {arg.symbol}] for a in c]", True - case InputTypePrimitive.SubCommandUnion: - return f"[a for c in [s.run(execution) for s in {arg.symbol}] for a in c]", True - case _: - assert False - - # arg.data.list_separator is not None - sep_join = f"{enquote(arg.data.list_separator)}.join" - match arg.data.type.primitive: - case InputTypePrimitive.String: - return f"{sep_join}({arg.symbol})", False - case InputTypePrimitive.Number: - return f"{sep_join}(map(str, {arg.symbol}))", False - case InputTypePrimitive.Integer: - return f"{sep_join}(map(str, {arg.symbol}))", False - case InputTypePrimitive.File: - return f"{sep_join}([execution.input_file(f) for f in {arg.symbol}])", False - case InputTypePrimitive.SubCommand: - return f"{sep_join}([a for c in [s.run(execution) for s in {arg.symbol}] for a in c])", False - case InputTypePrimitive.SubCommandUnion: - return f"{sep_join}([a for c in [s.run(execution) for s in {arg.symbol}] for a in c])", False - case _: - assert False - - if arg.data.command_line_flag is not None: - val, val_is_list = _val() - if arg.data.command_line_flag_separator is not None: - assert not val_is_list, "List variables with non-null command_line_flag_separator are not supported" - prefix = arg.data.command_line_flag + arg.data.command_line_flag_separator - return f"({enquote(prefix)} + {val})", False - - if val_is_list: - return f"[{enquote(arg.data.command_line_flag)}, *{val}]", True - return f"[{enquote(arg.data.command_line_flag)}, {val}]", True - return _val() - - -def _input_segment_to_py_arg_builder(buf: LineBuffer, segment: list[str | WithSymbol[InputArgument]]) -> None: - """Return a Python expression that builds the command line arguments.""" - if len(segment) == 0: - return - - input_args: list[WithSymbol[InputArgument]] = [i for i in segment if isinstance(i, WithSymbol)] - - indent_level = 0 - - # Are there variables? - if len(input_args) > 0: - optional_segment = True - for arg in input_args: - if not arg.data.type.is_optional: - optional_segment = False # Segment will always be included - if optional_segment: - # Codegen: Condition: Is any variable in the segment set by the user? 
- condition = [] - for arg in input_args: - condition.append(codegen_var_is_set_by_user(arg)) - buf.append(f"if {' or '.join(condition)}:") - indent_level += 1 - - # Codegen: Build the string - # Codegen: Append to the command line arguments - if len(input_args) > 1: - # We need to check which variables are set - statement = [] - for token in segment: - if isinstance(token, str): - if len(token) == 0: - continue - statement.append(enquote(token)) - else: - var, is_list = _codegen_var_to_str(token) - assert not is_list, "List variables are not supported in this context" - if token.data.type.is_optional: - statement.append(f'({var} if {codegen_var_is_set_by_user(token)} else "")') - else: - statement.append(var) - buf.extend( - indent( - [ - "cargs.append(", - *indent(expand(" +\n".join(statement))), - ")", - ], - indent_level, - ) - ) - - else: - # We know the var has been set by the user - if len(segment) == 1: - if isinstance(segment[0], str): - buf.extend(indent([f"cargs.append({enquote(segment[0])})"], indent_level)) - else: - var, is_list = _codegen_var_to_str(segment[0]) - if is_list: - buf.extend(indent([f"cargs.extend({var})"], indent_level)) - else: - buf.extend(indent([f"cargs.append({var})"], indent_level)) - return - - statement = [] - for token in segment: - if isinstance(token, str): - statement.append(enquote(token)) - else: - var, is_list = _codegen_var_to_str(token) - assert not is_list, f"List variables are not supported in this context ({var})" - statement.append(var) - - buf.extend( - indent( - [ - "cargs.append(", - *indent(expand(" +\n".join(statement))), - ")", - ], - indent_level, - ) - ) - - -def _bt_template_str_parse( - input_command_line_template: str, - inputs: list[WithSymbol[InputArgument]], -) -> list[list[str | WithSymbol[InputArgument]]]: - """Parse a Boutiques command line template string into segments.""" - bt_template_str = boutiques_split_command(input_command_line_template) - - template_key_inputs = {input_.data.template_key: input_ for input_ in inputs} - - segments: list[list[str | WithSymbol[InputArgument]]] = [] - - for arg in bt_template_str: - segment: list[str | WithSymbol[InputArgument]] = [] - - stack: list[str | WithSymbol[InputArgument]] = [arg] - - # turn template into segments - while stack: - token = stack.pop() - if isinstance(token, str): - any_match = False - for template_key, bt_input in template_key_inputs.items(): - if template_key == token: - stack.append(bt_input) - any_match = True - break - o = token.split(template_key, 1) - if len(o) == 2: - stack.append(o[0]) - stack.append(bt_input) - stack.append(o[1]) - any_match = True - break - if not any_match: - segment.insert(0, token) - elif isinstance(token, WithSymbol): - segment.insert(0, token) - else: - assert False - segments.append(segment) - - return segments - - -def generate_command_line_args_building( - input_command_line_template: str, - func: PyFunc, - inputs: list[WithSymbol[InputArgument]], -) -> None: - """Generate the command line arguments building code.""" - segments = _bt_template_str_parse(input_command_line_template, inputs) - for segment in segments: - _input_segment_to_py_arg_builder(func.body, segment) diff --git a/src/styx/compiler/compile/metadata.py b/src/styx/compiler/compile/metadata.py deleted file mode 100644 index 077d64c..0000000 --- a/src/styx/compiler/compile/metadata.py +++ /dev/null @@ -1,17 +0,0 @@ -from styx.compiler.compile.common import SharedSymbols -from styx.model.core import Descriptor -from styx.pycodegen.core import PyModule, indent 
-from styx.pycodegen.utils import as_py_literal - - -def generate_static_metadata( - module: PyModule, - descriptor: Descriptor, - symbols: SharedSymbols, -) -> None: - """Generate the static metadata.""" - module.header.extend([ - f"{symbols.metadata} = Metadata(", - *indent([f"{k}={as_py_literal(v)}," for k, v in descriptor.metadata.items()]), - ")", - ]) diff --git a/src/styx/compiler/compile/outputs.py b/src/styx/compiler/compile/outputs.py deleted file mode 100644 index 5c5eaed..0000000 --- a/src/styx/compiler/compile/outputs.py +++ /dev/null @@ -1,223 +0,0 @@ -from styx.compiler.compile.inputs import codegen_var_is_set_by_user -from styx.model.core import InputArgument, InputTypePrimitive, OutputArgument, SubCommand, WithSymbol -from styx.pycodegen.core import PyFunc, PyModule, indent -from styx.pycodegen.scope import Scope -from styx.pycodegen.utils import as_py_literal, enbrace, enquote - - -def _find_output_dependencies( - output: WithSymbol[OutputArgument], - inputs: list[WithSymbol[InputArgument]], -) -> list[WithSymbol[InputArgument]]: - """Find the input dependencies for an output.""" - return [input_ for input_ in inputs if input_.data.template_key in output.data.path_template] - - -def _sub_command_has_outputs(sub_command: SubCommand) -> bool: - """Check if the sub-command has outputs.""" - if len(sub_command.outputs) > 0: - return True - for input_ in sub_command.inputs: - if input_.type.primitive == InputTypePrimitive.SubCommand: - assert input_.sub_command is not None - if _sub_command_has_outputs(input_.sub_command): - return True - if input_.type.primitive == InputTypePrimitive.SubCommandUnion: - assert input_.sub_command_union is not None - for sub_command in input_.sub_command_union: - if _sub_command_has_outputs(sub_command): - return True - return False - - -def generate_outputs_class( - module: PyModule, - symbol_output_class: str, - symbol_parent_function: str, - outputs: list[WithSymbol[OutputArgument]], - inputs: list[WithSymbol[InputArgument]], - sub_command_output_class_aliases: dict[str, str], -) -> None: - """Generate the static output class definition.""" - module.header.extend([ - "", - "", - f"class {symbol_output_class}(typing.NamedTuple):", - *indent([ - '"""', - f"Output object returned when calling `{symbol_parent_function}(...)`.", - '"""', - "root: OutputPathType", - '"""Output root folder. 
This is the root folder for all outputs."""', - ]), - ]) - for out in outputs: - deps = _find_output_dependencies(out, inputs) - if any([input_.data.type.is_optional for input_ in deps]): - out_type = "OutputPathType | None" - else: - out_type = "OutputPathType" - - # Declaration - module.header.extend( - indent([ - f"{out.symbol}: {out_type}", - f'"""{out.data.doc}"""', - ]) - ) - - for input_ in inputs: - if input_.data.type.primitive == InputTypePrimitive.SubCommand: - assert input_.data.sub_command is not None - - if _sub_command_has_outputs(input_.data.sub_command): - sub_commands_type = sub_command_output_class_aliases[input_.data.sub_command.internal_id] - if input_.data.type.is_list: - sub_commands_type = f"typing.List[{sub_commands_type}]" - - module.header.extend( - indent([ - f"{input_.symbol}: {sub_commands_type}", - '"""Subcommand outputs"""', - ]) - ) - - elif input_.data.type.primitive == InputTypePrimitive.SubCommandUnion: - assert input_.data.sub_command_union is not None - - sub_commands = [ - sub_command_output_class_aliases[sub_command.internal_id] - for sub_command in input_.data.sub_command_union - if _sub_command_has_outputs(sub_command) - ] - if len(sub_commands) > 0: - sub_commands_type = ", ".join(sub_commands) - sub_commands_type = f"typing.Union[{sub_commands_type}]" - - if input_.data.type.is_list: - sub_commands_type = f"typing.List[{sub_commands_type}]" - - module.header.extend( - indent([ - f"{input_.symbol}: {sub_commands_type}", - '"""Subcommand outputs"""', - ]) - ) - - -def generate_output_building( - func: PyFunc, - func_scope: Scope, - symbol_execution: str, - symbol_output_class: str, - symbol_return_var: str, - outputs: list[WithSymbol[OutputArgument]], - inputs: list[WithSymbol[InputArgument]], - access_via_self: bool = False, -) -> None: - """Generate the output building code.""" - py_rstrip_fun = func_scope.add_or_dodge("_rstrip") - if any([out.data.stripped_file_extensions is not None for out in outputs]): - func.body.extend([ - f"def {py_rstrip_fun}(s, r):", - *indent([ - "for postfix in r:", - *indent([ - "if s.endswith(postfix):", - *indent(["return s[: -len(postfix)]"]), - ]), - "return s", - ]), - ]) - - func.body.append(f"{symbol_return_var} = {symbol_output_class}(") - - # Set root output path - func.body.extend(indent([f'root={symbol_execution}.output_file("."),'])) - - for out in outputs: - strip_extensions = out.data.stripped_file_extensions is not None - s_optional = ", optional=True" if out.data.optional else "" - if out.data.path_template is not None: - path_template = out.data.path_template - - input_dependencies = _find_output_dependencies(out, inputs) - - if len(input_dependencies) == 0: - # No substitutions needed - func.body.extend( - indent([f"{out.symbol}={symbol_execution}.output_file(f{enquote(path_template)}{s_optional}),"]) - ) - else: - for input_ in input_dependencies: - substitute = input_.symbol - - if access_via_self: - substitute = f"self.{substitute}" - - if input_.data.type.is_list: - raise Exception(f"Output path template replacements cannot be lists. 
({input_.data.name})") - - if input_.data.type.primitive == InputTypePrimitive.File: - # Just use the name of the file - # This is commonly used when output files 'inherit' the name of an input file - substitute = f"pathlib.Path({substitute}).name" - elif (input_.data.type.primitive == InputTypePrimitive.Number) or ( - input_.data.type.primitive == InputTypePrimitive.Integer - ): - # Convert to string - substitute = f"str({substitute})" - elif input_.data.type.primitive != InputTypePrimitive.String: - raise Exception( - f"Unsupported input type {input_.data.type.primitive} " - f"for output path template of '{out.data.name}'." - ) - - if strip_extensions: - exts = as_py_literal(out.data.stripped_file_extensions, "'") - substitute = f"{py_rstrip_fun}({substitute}, {exts})" - - path_template = path_template.replace(input_.data.template_key, enbrace(substitute)) - - resolved_output = f"{symbol_execution}.output_file(f{enquote(path_template)}{s_optional})" - - if any([input_.data.type.is_optional for input_ in input_dependencies]): - # Codegen: Condition: Is any variable in the segment set by the user? - condition = [codegen_var_is_set_by_user(i) for i in input_dependencies] - resolved_output = f"{resolved_output} if {' and '.join(condition)} else None" - - func.body.extend(indent([f"{out.symbol}={resolved_output},"])) - else: - raise NotImplementedError - - for input_ in inputs: - if (input_.data.type.primitive == InputTypePrimitive.SubCommand) or ( - input_.data.type.primitive == InputTypePrimitive.SubCommandUnion - ): - if input_.data.type.primitive == InputTypePrimitive.SubCommand: - assert input_.data.sub_command is not None - has_outouts = _sub_command_has_outputs(input_.data.sub_command) - else: - assert input_.data.sub_command_union is not None - has_outouts = any([ - _sub_command_has_outputs(sub_command) for sub_command in input_.data.sub_command_union - ]) - if has_outouts: - resolved_input = input_.symbol - if access_via_self: - resolved_input = f"self.{resolved_input}" - - if input_.data.type.is_list: - opt = "" - if input_.data.type.is_optional: - opt = f" if {resolved_input} else None" - func.body.extend( - indent([f"{input_.symbol}=" f"[i.outputs({symbol_execution}) for i in {resolved_input}]{opt},"]) - ) - else: - o = f"{resolved_input}.outputs({symbol_execution})" - if input_.data.type.is_optional: - o = f"{o} if {resolved_input} else None" - func.body.extend(indent([f"{input_.symbol}={o},"])) - - func.body.extend([")"]) diff --git a/src/styx/compiler/compile/reexport_module.py b/src/styx/compiler/compile/reexport_module.py deleted file mode 100644 index cc2c748..0000000 --- a/src/styx/compiler/compile/reexport_module.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Generate __init__.py that re-exports __all__ from all submodules.""" - -from styx.pycodegen.core import PyModule - - -def generate_reexport_module(relative_imports: list[str]) -> str: - """Generate __init__.py that re-exports __all__ from all submodules.""" - module: PyModule = PyModule() - module.imports = list(map(lambda item: f"from .{item} import *", sorted(relative_imports))) - return module.text() diff --git a/src/styx/compiler/compile/subcommand.py b/src/styx/compiler/compile/subcommand.py deleted file mode 100644 index ae49817..0000000 --- a/src/styx/compiler/compile/subcommand.py +++ /dev/null @@ -1,198 +0,0 @@ -from styx.compiler.compile.common import SharedSymbols -from styx.compiler.compile.constraints import generate_constraint_checks -from styx.compiler.compile.inputs import build_input_arguments, 
generate_command_line_args_building -from styx.compiler.compile.outputs import generate_output_building, generate_outputs_class -from styx.model.core import InputArgument, InputTypePrimitive, OutputArgument, SubCommand, WithSymbol -from styx.pycodegen.core import PyArg, PyDataClass, PyFunc, PyModule, blank_before -from styx.pycodegen.scope import Scope -from styx.pycodegen.utils import python_pascalize, python_snakify - - -def _sub_command_class_name(symbol_module: str, sub_command: SubCommand) -> str: - """Return the name of the sub-command class.""" - # Prefix the sub-command name with the module name so its likely unique across modules. - return python_pascalize(f"{symbol_module}_{sub_command.name}") - - -def _sub_command_output_class_name(symbol_module: str, sub_command: SubCommand) -> str: - """Return the name of the sub-command output class.""" - # Prefix the sub-command name with the module name so its likely unique across modules. - return python_pascalize(f"{symbol_module}_{sub_command.name}_Outputs") - - -def _sub_command_has_outputs(sub_command: SubCommand) -> bool: - """Check if the sub-command has outputs.""" - if len(sub_command.outputs) > 0: - return True - for input_ in sub_command.inputs: - if input_.type.primitive == InputTypePrimitive.SubCommand: - assert input_.sub_command is not None - if _sub_command_has_outputs(input_.sub_command): - return True - if input_.type.primitive == InputTypePrimitive.SubCommandUnion: - assert input_.sub_command_union is not None - for sub_command in input_.sub_command_union: - if _sub_command_has_outputs(sub_command): - return True - return False - - -def _generate_sub_command( - module: PyModule, - scope_module: Scope, - symbols: SharedSymbols, - sub_command: SubCommand, - outputs: list[WithSymbol[OutputArgument]], - inputs: list[WithSymbol[InputArgument]], - aliases: dict[str, str], - sub_command_output_class_aliases: dict[str, str], -) -> tuple[str, str]: - """Generate the static output class definition.""" - class_name = scope_module.add_or_dodge(_sub_command_class_name(symbols.function, sub_command)) - output_class_name = scope_module.add_or_dodge(_sub_command_output_class_name(symbols.function, sub_command)) - - module.exports.append(class_name) - sub_command_class = PyDataClass( - name=class_name, - docstring=sub_command.doc, - ) - # generate arguments - sub_command_class.fields.extend(build_input_arguments(inputs, aliases)) - - # generate run method - run_method = PyFunc( - name="run", - docstring_body="Build command line arguments. 
This method is called by the main command.", - args=[ - PyArg(name="self", type=None, default=None, docstring="The sub-command object."), - PyArg(name="execution", type="Execution", default=None, docstring="The execution object."), - ], - return_type="list[str]", - body=[ - "cargs = []", - ], - ) - inputs_self = [WithSymbol(i.data, f"self.{i.symbol}") for i in inputs] - - generate_constraint_checks(run_method, sub_command.group_constraints, inputs_self) - - generate_command_line_args_building(sub_command.input_command_line_template, run_method, inputs_self) - run_method.body.extend([ - "return cargs", - ]) - sub_command_class.methods.append(run_method) - - # Outputs method - - outputs_method = PyFunc( - name="outputs", - docstring_body="Collect output file paths.", - return_type=output_class_name, - return_descr=f"NamedTuple of outputs (described in `{output_class_name}`).", - args=[ - PyArg(name="self", type=None, default=None, docstring="The sub-command object."), - PyArg(name="execution", type="Execution", default=None, docstring="The execution object."), - ], - body=[], - ) - - if _sub_command_has_outputs(sub_command): - generate_outputs_class( - module, - output_class_name, - class_name + ".run", - outputs, - inputs, - sub_command_output_class_aliases, - ) - module.exports.append(output_class_name) - generate_output_building( - outputs_method, Scope(), symbols.execution, output_class_name, "ret", outputs, inputs, True - ) - outputs_method.body.extend(["return ret"]) - sub_command_class.methods.append(outputs_method) - - module.header.extend(blank_before(sub_command_class.generate(), 2)) - if "import dataclasses" not in module.imports: - module.imports.append("import dataclasses") - - return class_name, output_class_name - - -def generate_sub_command_classes( - module: PyModule, - symbols: SharedSymbols, - command: SubCommand, - scope_module: Scope, -) -> tuple[dict[str, str], dict[str, str], list[WithSymbol[InputArgument]]]: - """Build Python function arguments from input arguments.""" - # internal_id -> class_name - aliases: dict[str, str] = {} - # subcommand.internal_id -> subcommand.outputs() class name - sub_command_output_class_aliases: dict[str, str] = {} - - inputs_scope = Scope(parent=scope_module) - outputs_scope = Scope(parent=scope_module) - - # Input symbols - inputs: list[WithSymbol[InputArgument]] = [] - for i in command.inputs: - py_symbol = inputs_scope.add_or_dodge(python_snakify(i.name)) - inputs.append(WithSymbol(i, py_symbol)) - - for input_ in inputs: - if input_.data.type.primitive == InputTypePrimitive.SubCommand: - assert input_.data.sub_command is not None - sub_command = input_.data.sub_command - sub_aliases, sub_sub_command_output_class_aliases, sub_inputs = generate_sub_command_classes( - module, symbols, sub_command, inputs_scope - ) - aliases.update(sub_aliases) - sub_command_output_class_aliases.update(sub_sub_command_output_class_aliases) - - sub_outputs = [] - for output in sub_command.outputs: - py_symbol = outputs_scope.add_or_dodge(python_snakify(output.name)) - sub_outputs.append(WithSymbol(output, py_symbol)) - - sub_command_type, sub_command_output_type = _generate_sub_command( - module, - scope_module, - symbols, - sub_command, - sub_outputs, - sub_inputs, - aliases, - sub_command_output_class_aliases, - ) - aliases[sub_command.internal_id] = sub_command_type - sub_command_output_class_aliases[sub_command.internal_id] = sub_command_output_type - - if input_.data.type.primitive == InputTypePrimitive.SubCommandUnion: - assert 
input_.data.sub_command_union is not None - for sub_command in input_.data.sub_command_union: - sub_aliases, sub_sub_command_output_class_aliases, sub_inputs = generate_sub_command_classes( - module, symbols, sub_command, inputs_scope - ) - aliases.update(sub_aliases) - sub_command_output_class_aliases.update(sub_sub_command_output_class_aliases) - - sub_outputs = [] - for output in sub_command.outputs: - py_symbol = outputs_scope.add_or_dodge(python_snakify(output.name)) - sub_outputs.append(WithSymbol(output, py_symbol)) - - sub_command_type, sub_command_output_type = _generate_sub_command( - module, - scope_module, - symbols, - sub_command, - sub_outputs, - sub_inputs, - aliases, - sub_command_output_class_aliases, - ) - aliases[sub_command.internal_id] = sub_command_type - sub_command_output_class_aliases[sub_command.internal_id] = sub_command_output_type - - return aliases, sub_command_output_class_aliases, inputs diff --git a/src/styx/compiler/core.py b/src/styx/compiler/core.py deleted file mode 100644 index 180efa3..0000000 --- a/src/styx/compiler/core.py +++ /dev/null @@ -1,8 +0,0 @@ -from styx.compiler.compile.descriptor import compile_descriptor -from styx.compiler.settings import CompilerSettings -from styx.model.from_boutiques import descriptor_from_boutiques # type: ignore - - -def compile_boutiques_dict(boutiques_descriptor: dict, settings: CompilerSettings | None = None) -> str: - descriptor = descriptor_from_boutiques(boutiques_descriptor) - return compile_descriptor(descriptor, settings if settings is not None else CompilerSettings()) diff --git a/src/styx/compiler/settings.py b/src/styx/compiler/settings.py deleted file mode 100644 index c05a357..0000000 --- a/src/styx/compiler/settings.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Compiler settings.""" - -import pathlib -from dataclasses import dataclass - - -@dataclass -class CompilerSettings: - """Compiler settings.""" - - input_path: pathlib.Path | None = None - output_path: pathlib.Path | None = None - - debug_mode: bool = False diff --git a/src/styx/compiler/utils.py b/src/styx/compiler/utils.py deleted file mode 100644 index 37f5169..0000000 --- a/src/styx/compiler/utils.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Compiler utilities.""" - - -def optional_float_to_int(value: float | None) -> int | None: - """Convert an optional float to an optional int.""" - return int(value) if value is not None else None diff --git a/src/styx/frontend/__init__.py b/src/styx/frontend/__init__.py new file mode 100644 index 0000000..777cfeb --- /dev/null +++ b/src/styx/frontend/__init__.py @@ -0,0 +1 @@ +"""Styx frontends.""" diff --git a/src/styx/frontend/boutiques/__init__.py b/src/styx/frontend/boutiques/__init__.py new file mode 100644 index 0000000..f1522bd --- /dev/null +++ b/src/styx/frontend/boutiques/__init__.py @@ -0,0 +1,3 @@ +"""Boutiques frontend.""" + +from .core import from_boutiques as from_boutiques diff --git a/src/styx/frontend/boutiques/core.py b/src/styx/frontend/boutiques/core.py new file mode 100644 index 0000000..443097e --- /dev/null +++ b/src/styx/frontend/boutiques/core.py @@ -0,0 +1,480 @@ +"""Boutiques backend.""" + +import hashlib +import json +from dataclasses import dataclass +from enum import Enum +from typing import TypeVar + +import styx.ir.core as ir +from styx.frontend.boutiques.utils import boutiques_split_command +from styx.ir.dyn import dyn_param + +T = TypeVar("T") + + +def destruct_template( + template: str, + lookup: dict[str, T], +) -> list[str | T]: + """Destruct a template string to a list of strings 
and replacements. + + This is used to safely destruct boutiques `command-line` as well as `path-template` strings. + + Example: + >>> destruct_template( + >>> template="hello x, I am y", + >>> lookup={"x": 12, "y": 34}, + >>> ) + ["hello ", 12, ", I am ", 34] + """ + destructed: list[str | T] = [] + stack: list[str | T] = [template] + while len(stack) > 0: + x = stack.pop(0) + if not isinstance(x, str): + destructed.append(x) + continue + did_split = False + for alias, replacement in lookup.items(): + if alias in x: + left, right = x.split(alias, 1) + if len(right) > 0: + stack.insert(0, right) + stack.insert(0, replacement) + if len(left) > 0: + stack.insert(0, left) + did_split = True + break + if not did_split: + destructed.append(x) + return destructed + + +@dataclass +class IdCounter: + _counter: int = 0 + + def next(self) -> int: + self._counter += 1 + return self._counter - 1 + + +def _hash_from_boutiques(tool: dict) -> str: + """Generate a hash from a Boutiques tool.""" + json_str = json.dumps(tool, sort_keys=True) + return hashlib.sha1(json_str.encode()).hexdigest() + + +def _bt_template_str_parse( + input_command_line_template: str, + lookup_input: dict[str, dict], +) -> list[list[str | dict]]: + """Parse a Boutiques command line template string into segments.""" + bt_template_str = boutiques_split_command(input_command_line_template) + return [destruct_template(arg, lookup_input) for arg in bt_template_str] + + +class InputTypePrimitive(Enum): + String = 1 + Float = 2 + Integer = 3 + File = 4 + Flag = 5 + SubCommand = 6 + SubCommandUnion = 7 + + +@dataclass +class InputType: + primitive: InputTypePrimitive + is_list: bool = False + is_optional: bool = False + is_enum: bool = False + + +def _input_type_primitive_from_boutiques(bt_input: dict) -> InputTypePrimitive: + """Convert a Boutiques input to a Styx input type primitive.""" + if "type" not in bt_input: + raise ValueError(f"type is missing for input: '{bt_input['id']}'") + + if isinstance(bt_input["type"], dict): + return InputTypePrimitive.SubCommand + + if isinstance(bt_input["type"], list): + return InputTypePrimitive.SubCommandUnion + + bt_type_name = bt_input["type"] + if not isinstance(bt_type_name, str): + bt_type_name = bt_type_name.value + + if bt_type_name == "String": + return InputTypePrimitive.String + elif bt_type_name == "File": + return InputTypePrimitive.File + elif bt_type_name == "Flag": + return InputTypePrimitive.Flag + elif bt_type_name == "Number" and not bt_input.get("integer"): + return InputTypePrimitive.Float + elif bt_type_name == "Number" and bt_input.get("integer"): + return InputTypePrimitive.Integer + else: + raise NotImplementedError + + +def _input_type_from_boutiques(bt_input: dict) -> InputType: + """Convert a Boutiques input to a Styx input type.""" + bt_is_list = bt_input.get("list") is True + bt_is_optional = bt_input.get("optional") is True + bt_is_enum = bt_input.get("value-choices") is not None + primitive = _input_type_primitive_from_boutiques(bt_input) + if primitive == InputTypePrimitive.File: + assert not bt_is_enum + if primitive == InputTypePrimitive.Flag: + return InputType(InputTypePrimitive.Flag, False, True, False) + return InputType(primitive, bt_is_list, bt_is_optional, bt_is_enum) + + +def _arg_elem_from_bt_elem( + elem: dict, + id_counter: IdCounter, + ir_id_lookup: dict[str, ir.IdType], +) -> ir.IParam: + if not isinstance(elem, dict): + assert False + + d = elem + + input_bt_ref = d["value-key"] + input_docs = ir.Documentation( + title=d.get("name"), + 
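For illustration, a minimal usage sketch of the destruct_template helper defined above, mirroring its docstring example; the import path follows this diff's module layout and is an assumption:

from styx.frontend.boutiques.core import destruct_template

# Keys found in the lookup are spliced into the result as replacement objects;
# the remaining text stays as plain string fragments.
parts = destruct_template(
    template="hello x, I am y",
    lookup={"x": 12, "y": 34},
)
assert parts == ["hello ", 12, ", I am ", 34]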
description=d.get("description"), + ) + input_name = d["id"] + + repeatable_join: str | None = d.get("list-separator") + input_type = _input_type_from_boutiques(d) + + input_id = id_counter.next() + ir_id_lookup[input_bt_ref] = input_id + + dparam = ir.DParam( + id_=input_id, + name=input_name, + docs=input_docs, + ) + + constraints = _collect_constraints(d, input_type) + + dlist = None + if input_type.is_list: + dlist = ir.DList( + join=repeatable_join, + count_min=constraints.list_length_min, + count_max=constraints.list_length_max, + ) + + match input_type.primitive: + case InputTypePrimitive.String: + choices = d.get("value-choices") + assert choices is None or all([ + isinstance(o, str) for o in choices + ]), "value-choices must be all string for string input" + + return dyn_param( + dyn_type="str", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, + param=dparam, + list_=dlist, + default_value=d.get("default-value", ir.IOptional.SetToNone) + if input_type.is_optional + else d.get("default-value"), + choices=choices, + ) + + case InputTypePrimitive.Integer: + choices = d.get("value-choices") + assert choices is None or all([ + isinstance(o, int) for o in choices + ]), "value-choices must be all int for integer input" + + return dyn_param( + dyn_type="int", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, + param=dparam, + list_=dlist, + default_value=d.get("default-value", ir.IOptional.SetToNone) + if input_type.is_optional + else d.get("default-value"), + min_value=constraints.value_min, + max_value=constraints.value_max, + choices=choices, + ) + + case InputTypePrimitive.Float: + return dyn_param( + dyn_type="float", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, + param=dparam, + list_=dlist, + default_value=d.get("default-value", ir.IOptional.SetToNone) + if input_type.is_optional + else d.get("default-value"), + min_value=constraints.value_min, + max_value=constraints.value_max, + ) + + case InputTypePrimitive.File: + return dyn_param( + dyn_type="file", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, + param=dparam, + list_=dlist, + default_value_set_to_none=True, + ) + + case InputTypePrimitive.Flag: + input_prefix = d.get("command-line-flag") + assert input_prefix is not None, "Flag type input must have command-line-flag" + + dparam.prefix = [] + return ir.PBool( + param=dparam, + default_value=d.get("default-value") is True, + value_true=[input_prefix] if input_prefix else [], + value_false=[], + ) + case InputTypePrimitive.SubCommand: + dparam, dstruct = _struct_from_boutiques(d, id_counter) + ir_id_lookup[input_bt_ref] = dparam.id_ # override + + return dyn_param( + dyn_type="struct", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, + param=dparam, + struct=dstruct, + list_=dlist, + default_value_set_to_none=True, + ) + + case InputTypePrimitive.SubCommandUnion: + bt_alts = d.get("type") + assert isinstance(bt_alts, list) + + alts: list[ir.PStruct] = [] + for bt_alt in bt_alts: + alt_dparam, alt_dstruct = _struct_from_boutiques(bt_alt, id_counter) + alts.append( + ir.PStruct( + param=alt_dparam, + struct=alt_dstruct, + ) + ) + + return dyn_param( + dyn_type="struct_union", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, + param=dparam, + alts=alts, + list_=dlist, + default_value_set_to_none=True, + ) + assert False + + +@dataclass +class _NumericConstraints: + value_min: int | float | None = None + value_max: int | float | None = None + 
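As a hedged sketch of the mapping implemented above: a Boutiques Flag input becomes an ir.PBool whose true-value carries the command-line flag. The input dict below is made up for illustration, and the private helper is imported directly only for demonstration:

import styx.ir.core as ir
from styx.frontend.boutiques.core import IdCounter, _arg_elem_from_bt_elem

flag_input = {
    "id": "verbose",              # made-up example input
    "name": "Verbose",
    "value-key": "[VERBOSE]",
    "type": "Flag",
    "command-line-flag": "-v",
}
param = _arg_elem_from_bt_elem(flag_input, IdCounter(), ir_id_lookup={})
assert isinstance(param, ir.PBool)
assert param.value_true == ["-v"] and param.value_false == []
assert param.default_value is False   # no "default-value" key was given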
list_length_min: int | None = None + list_length_max: int | None = None + + +def _collect_constraints(d: dict, input_type: InputType) -> _NumericConstraints: + ret = _NumericConstraints() + value_min_exclusive = False + value_max_exclusive = False + if input_type.primitive in (InputTypePrimitive.Float, InputTypePrimitive.Integer): + if (val := d.get("minimum")) is not None: + ret.value_min = int(val) if d.get("integer") else val + value_min_exclusive = d.get("exclusive-minimum") is True + if (val := d.get("maximum")) is not None: + ret.value_max = int(val) if d.get("integer") else val + value_max_exclusive = d.get("exclusive-maximum") is True + if d.get("list") is True: + ret.list_length_min = d.get("min-list-entries") + ret.list_length_max = d.get("max-list-entries") + if ret.value_min is not None and value_min_exclusive and input_type.primitive == InputTypePrimitive.Integer: + ret.value_min += 1 + if ret.value_max is not None and value_max_exclusive and input_type.primitive == InputTypePrimitive.Integer: + ret.value_max -= 1 + return ret + + +def _struct_from_boutiques( + bt: dict, + id_counter: IdCounter, +) -> tuple[ir.DParam, ir.DStruct]: + def _get_authors(bt: dict) -> list[str]: + if "author" in bt: + return [bt["author"]] + return [] + + def _get_urls(bt: dict) -> list[str]: + if "url" in bt: + return [bt["url"]] + return [] + + parent_input: dict | None = None + if "type" not in bt: # Root boutiques descriptor + if (bt_id := bt.get("id", bt.get("name"))) is None: + raise Exception(f"Descriptor is missing id/name: {bt_id}") + + groups, ir_id_lookup = _collect_inputs(bt, id_counter) + outputs = _collect_outputs(bt, ir_id_lookup, id_counter) + + docs = ir.Documentation( + description=bt.get("description"), + authors=_get_authors(bt), + urls=_get_urls(bt), + ) + + return ir.DParam( + id_=id_counter.next(), + name=bt_id, + outputs=outputs, + docs=docs, + ), ir.DStruct( + name=bt_id, + groups=groups, + docs=docs, + ) + + else: + parent_input = bt + bt = bt["type"] + + groups, ir_id_lookup = _collect_inputs(bt, id_counter) + outputs = _collect_outputs(bt, ir_id_lookup, id_counter) + + docs_parent = ir.Documentation( + description=parent_input.get("description"), + authors=_get_authors(parent_input), + urls=_get_urls(parent_input), + ) + + docs = ir.Documentation( + description=bt.get("description"), + authors=_get_authors(bt), + urls=_get_urls(bt), + ) + + return ir.DParam( + id_=id_counter.next(), + name=parent_input["id"], + outputs=outputs, + docs=docs_parent, + ), ir.DStruct( + name=bt["id"], + groups=groups, + docs=docs, + ) + + +def _collect_outputs(bt: dict, ir_id_lookup: dict[str, ir.IdType], id_counter: IdCounter) -> list[ir.Output]: + outputs: list[ir.Output] = [] + for bt_output in bt.get("output-files", []): + path_template = bt_output["path-template"] + destructed = destruct_template(path_template, ir_id_lookup) + output_sequence = [ + ir.OutputParamReference( + ref_id=x, + file_remove_suffixes=bt_output.get("path-template-stripped-extensions", []), + ) + if isinstance(x, int) + else x + for x in destructed + ] + outputs.append( + ir.Output( + id_=id_counter.next(), + name=bt_output["id"], + tokens=output_sequence, + docs=ir.Documentation(description=bt_output.get("description")), + ) + ) + return outputs + + +def _collect_inputs(bt: dict, id_counter: IdCounter) -> tuple[list[ir.ConditionalGroup], dict[str, ir.IdType]]: + inputs_lookup = {i["value-key"]: i for i in bt.get("inputs", [])} + # maps boutiques 'value-keys' to expressions + ir_id_lookup: dict[str, ir.IdType] = {} 
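A small worked example of the exclusive-bound handling in _collect_constraints above: for integer inputs, exclusive Boutiques bounds are folded into inclusive ones. The input dict is made up, and the private helpers are imported purely for illustration:

from styx.frontend.boutiques.core import (
    _collect_constraints,
    _input_type_from_boutiques,
)

bt_input = {
    "id": "level",
    "type": "Number",
    "integer": True,
    "minimum": 0, "exclusive-minimum": True,    # 0 < x   becomes  1 <= x
    "maximum": 10, "exclusive-maximum": True,   # x < 10  becomes  x <= 9
}
c = _collect_constraints(bt_input, _input_type_from_boutiques(bt_input))
assert (c.value_min, c.value_max) == (1, 9)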
+ groups: list[ir.ConditionalGroup] = [] + for bt_segment in _bt_template_str_parse(bt.get("command-line", ""), inputs_lookup): + group = ir.ConditionalGroup() + carg = ir.Carg() + + for bt_elem in bt_segment: + if isinstance(bt_elem, str): + carg.tokens.append(bt_elem) + continue + + param = _arg_elem_from_bt_elem( + bt_elem, + id_counter, + ir_id_lookup, + ) + + if not isinstance(param, ir.IBool): + # bool arguments use command line flag as value + input_prefix: str | None = bt_elem.get("command-line-flag") + input_prefix_join: str | None = bt_elem.get("command-line-flag-separator") + if input_prefix_join is not None: + carg.tokens.append((input_prefix if input_prefix else "") + input_prefix_join) + elif input_prefix: + group.cargs.append(ir.Carg([input_prefix])) + + carg.tokens.append(param) + + group.cargs.append(carg) + groups.append(group) + + return groups, ir_id_lookup + + +def from_boutiques( + tool: dict, + package_name: str, + package_docs: ir.Documentation | None = None, +) -> ir.Interface: + """Convert a Boutiques tool to a Styx descriptor.""" + hash_ = _hash_from_boutiques(tool) + + docker: str | None = None + if "container-image" in tool: + docker = tool["container-image"].get("image") + + id_counter = IdCounter() + + dparam, dstruct = _struct_from_boutiques(tool, id_counter) + + return ir.Interface( + uid=f"{hash_}.boutiques", + package=ir.Package( + name=package_name, + version=tool.get("tool-version"), + docker=docker, + docs=package_docs if package_docs else ir.Documentation(), + ), + command=ir.PStruct( + param=dparam, + struct=dstruct, + ), + ) diff --git a/src/styx/model/boutiques_split_command.py b/src/styx/frontend/boutiques/utils.py similarity index 100% rename from src/styx/model/boutiques_split_command.py rename to src/styx/frontend/boutiques/utils.py diff --git a/src/styx/ir/__init__.py b/src/styx/ir/__init__.py new file mode 100644 index 0000000..bea421e --- /dev/null +++ b/src/styx/ir/__init__.py @@ -0,0 +1 @@ +"""Styx internal representation.""" diff --git a/src/styx/ir/core.py b/src/styx/ir/core.py new file mode 100644 index 0000000..9524313 --- /dev/null +++ b/src/styx/ir/core.py @@ -0,0 +1,294 @@ +import dataclasses +from abc import ABC +from dataclasses import dataclass +from typing import Any, Generator + + +@dataclass +class Documentation: + title: str | None = None + description: str | None = None + + authors: list[str] = dataclasses.field(default_factory=list) + literature: list[str] | None = dataclasses.field(default_factory=list) + urls: list[str] | None = dataclasses.field(default_factory=list) + + +@dataclass +class Package: + """Metadata for software package containing command.""" + + name: str + version: str | None + docker: str | None + + docs: Documentation = dataclasses.field(default_factory=Documentation) + + +IdType = int + + +@dataclass +class OutputParamReference: + ref_id: IdType + file_remove_suffixes: list[str] = dataclasses.field(default_factory=list) + + +@dataclass +class Output: + id_: IdType + name: str + tokens: list[str | OutputParamReference] = dataclasses.field(default_factory=list) + docs: Documentation | None = None + + +@dataclass +class DParam: + id_: IdType + name: str + + outputs: list[Output] = dataclasses.field(default_factory=list) + + docs: Documentation | None = None + + +@dataclass +class IParam(ABC): + param: DParam + + +class IOptional(ABC): + class SetToNoneAble: # noqa + pass + + SetToNone = SetToNoneAble() + pass + + +@dataclass +class DList: + count_min: int | None = None + count_max: int | None = None + 
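For illustration, a hedged end-to-end sketch of from_boutiques (defined above) applied to a minimal, made-up Boutiques descriptor; the tool content and package name are assumptions:

import styx.ir.core as ir
from styx.frontend.boutiques import from_boutiques

tool = {
    "name": "cat",
    "tool-version": "1.0",
    "command-line": "cat [INFILE]",
    "inputs": [
        {"id": "infile", "name": "Input file", "value-key": "[INFILE]", "type": "File"},
    ],
}
interface = from_boutiques(tool, package_name="demo")
assert isinstance(interface, ir.Interface)
assert interface.package.name == "demo" and interface.package.version == "1.0"
assert interface.command.struct.name == "cat"   # root struct is named after the tool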
join: str | None = None + + +@dataclass +class IList(ABC): + list_: DList = dataclasses.field(default_factory=DList) + + +@dataclass +class IInt(ABC): + choices: list[int] | None = None + min_value: int | None = None + max_value: int | None = None + + +@dataclass +class PInt(IInt, IParam): + default_value: int | None = None + + +@dataclass +class PIntOpt(IInt, IParam, IOptional): + default_value: int | IOptional.SetToNoneAble | None = IOptional.SetToNone + + +@dataclass +class PIntList(IInt, IList, IParam): + default_value: list[int] | None = None + + +@dataclass +class PIntListOpt(IInt, IList, IParam, IOptional): + default_value: list[int] | IOptional.SetToNoneAble | None = IOptional.SetToNone + + +class IFloat(ABC): + min_value: int | None = None + max_value: int | None = None + + +@dataclass +class PFloat(IFloat, IParam): + default_value: float | None = None + + +@dataclass +class PFloatOpt(IFloat, IParam, IOptional): + default_value: float | IOptional.SetToNoneAble | None = IOptional.SetToNone + + +@dataclass +class PFloatList(IFloat, IList, IParam): + default_value: list[float] | None = None + + +@dataclass +class PFloatListOpt(IFloat, IList, IParam, IOptional): + default_value: list[float] | IOptional.SetToNoneAble | None = IOptional.SetToNone + + +@dataclass +class IStr(ABC): + choices: list[str] | None = None + + +@dataclass +class PStr(IStr, IParam): + default_value: str | None = None + + +@dataclass +class PStrOpt(IStr, IParam, IOptional): + default_value: str | IOptional.SetToNoneAble | None = IOptional.SetToNone + + +@dataclass +class PStrList(IStr, IList, IParam): + default_value: list[str] | None = None + + +@dataclass +class PStrListOpt(IStr, IList, IParam, IOptional): + default_value: list[str] | IOptional.SetToNoneAble | None = IOptional.SetToNone + + +class IFile(ABC): + pass + + +@dataclass +class PFile(IFile, IParam): + pass + + +@dataclass +class PFileOpt(IFile, IParam, IOptional): + default_value_set_to_none: bool = True + + +@dataclass +class PFileList(IFile, IList, IParam): + pass + + +@dataclass +class PFileListOpt(IFile, IList, IParam, IOptional): + default_value_set_to_none: bool = True + + +@dataclass +class IBool(ABC): + value_true: list[str] = dataclasses.field(default_factory=list) + value_false: list[str] = dataclasses.field(default_factory=list) + + +@dataclass +class PBool(IBool, IParam): + default_value: bool | None = None + + +@dataclass +class PBoolOpt(IBool, IParam, IOptional): + default_value: bool | IOptional.SetToNoneAble | None = IOptional.SetToNone + + +@dataclass +class PBoolList(IBool, IList, IParam): + default_value: list[bool] | None = None + + +@dataclass +class PBoolListOpt(IBool, IList, IParam, IOptional): + default_value: list[bool] | IOptional.SetToNoneAble | None = IOptional.SetToNone + value_true: list[str] = dataclasses.field(default_factory=list) + value_false: list[str] = dataclasses.field(default_factory=list) + + +@dataclass +class Carg: + tokens: list[IParam | str] = dataclasses.field(default_factory=list) + + def iter_params(self) -> Generator[IParam, Any, None]: + for token in self.tokens: + if isinstance(token, IParam): + yield token + + +@dataclass +class ConditionalGroup: + cargs: list[Carg] = dataclasses.field(default_factory=list) + + def iter_params(self) -> Generator[IParam, Any, None]: + for carg in self.cargs: + yield from carg.iter_params() + + +@dataclass +class DStruct: + name: str | None = None + groups: list[ConditionalGroup] = dataclasses.field(default_factory=list) + """(group (cargs (join str+params))) """ + 
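A short sketch of the IOptional.SetToNone sentinel used by the optional parameter classes above; it appears to distinguish "the generated wrapper argument defaults to None" from an explicitly provided default value. Parameter names below are made up:

import styx.ir.core as ir

thresh = ir.PIntOpt(param=ir.DParam(id_=7, name="thresh"))
assert thresh.default_value is ir.IOptional.SetToNone   # no default given

thresh_5 = ir.PIntOpt(param=ir.DParam(id_=8, name="thresh"), default_value=5)
assert thresh_5.default_value == 5                      # explicit default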
docs: Documentation | None = None + + def iter_params(self) -> Generator[IParam, Any, None]: + for group in self.groups: + yield from group.iter_params() + + +@dataclass +class IStruct(ABC): + struct: DStruct = dataclasses.field(default_factory=DStruct) + + +@dataclass +class PStruct(IStruct, IParam): + pass + + +@dataclass +class PStructOpt(IStruct, IParam, IOptional): + default_value_set_to_none: bool = True + + +@dataclass +class PStructList(IStruct, IList, IParam): + pass + + +@dataclass +class PStructListOpt(IStruct, IList, IParam, IOptional): + default_value_set_to_none: bool = True + + +@dataclass +class IStructUnion(ABC): + alts: list[PStruct] = dataclasses.field(default_factory=list) + + +@dataclass +class PStructUnion(IStructUnion, IParam): + pass + + +@dataclass +class PStructUnionOpt(IStructUnion, IParam, IOptional): + default_value_set_to_none: bool = True + + +@dataclass +class PStructUnionList(IStructUnion, IList, IParam): + pass + + +@dataclass +class PStructUnionListOpt(IStructUnion, IList, IParam, IOptional): + default_value_set_to_none: bool = True + + +@dataclass +class Interface: + uid: str + package: Package + command: PStruct diff --git a/src/styx/ir/dyn.py b/src/styx/ir/dyn.py new file mode 100644 index 0000000..eeedbe6 --- /dev/null +++ b/src/styx/ir/dyn.py @@ -0,0 +1,49 @@ +"""Convenience function that allows dynamic param class creation.""" + +import dataclasses +import typing + +import styx.ir.core as ir + + +def dyn_param( + dyn_type: typing.Literal["int", "float", "str", "file", "bool", "struct", "struct_union"], + dyn_list: bool, + dyn_optional: bool, + **kwargs, # noqa +) -> ir.IParam: + """Convenience function that allows dynamic param class creation.""" + cls = { + ("int", True, True): ir.PIntListOpt, + ("int", True, False): ir.PIntList, + ("int", False, True): ir.PIntOpt, + ("int", False, False): ir.PInt, + ("float", True, True): ir.PFloatListOpt, + ("float", True, False): ir.PFloatList, + ("float", False, True): ir.PFloatOpt, + ("float", False, False): ir.PFloat, + ("str", True, True): ir.PStrListOpt, + ("str", True, False): ir.PStrList, + ("str", False, True): ir.PStrOpt, + ("str", False, False): ir.PStr, + ("file", True, True): ir.PFileListOpt, + ("file", True, False): ir.PFileList, + ("file", False, True): ir.PFileOpt, + ("file", False, False): ir.PFile, + ("bool", True, True): ir.PBoolListOpt, + ("bool", True, False): ir.PBoolList, + ("bool", False, True): ir.PBoolOpt, + ("bool", False, False): ir.PBool, + ("struct", True, True): ir.PStructListOpt, + ("struct", True, False): ir.PStructList, + ("struct", False, True): ir.PStructOpt, + ("struct", False, False): ir.PStruct, + ("struct_union", True, True): ir.PStructUnionListOpt, + ("struct_union", True, False): ir.PStructUnionList, + ("struct_union", False, True): ir.PStructUnionOpt, + ("struct_union", False, False): ir.PStructUnion, + }[(dyn_type, dyn_list, dyn_optional)] + kwargs_relevant = { + field.name: kwargs[field.name] for field in dataclasses.fields(cls) if field.name if field.name in kwargs + } + return cls(**kwargs_relevant) diff --git a/src/styx/ir/optimize.py b/src/styx/ir/optimize.py new file mode 100644 index 0000000..16ca075 --- /dev/null +++ b/src/styx/ir/optimize.py @@ -0,0 +1,10 @@ +import styx.ir.core as ir + + +# todo: +# Likely optimizations: +# - Find nested required=True repeated=False expressions and merge them +# - Find neighbouring ConstantParameters in ExpressionSequences and merge them +# - Find min 0 max 1 repetitions and convert them to required=False +def optimize(expr: 
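A hedged sketch of the dyn_param dispatch defined above: the (type, list, optional) triple selects one of the P* dataclasses from styx.ir.core, and only keyword arguments that are actual fields of that class are forwarded. The parameter values here are made up:

import styx.ir.core as ir
from styx.ir.dyn import dyn_param

p = dyn_param(
    "int", dyn_list=True, dyn_optional=True,
    param=ir.DParam(id_=0, name="shrink_factors"),
    list_=ir.DList(count_min=1),
    min_value=1,
)
assert isinstance(p, ir.PIntListOpt)
assert p.min_value == 1
assert p.default_value is ir.IOptional.SetToNone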
ir.Interface) -> ir.Interface: + return expr diff --git a/src/styx/ir/pretty_print.py b/src/styx/ir/pretty_print.py new file mode 100644 index 0000000..0fcc003 --- /dev/null +++ b/src/styx/ir/pretty_print.py @@ -0,0 +1,78 @@ +import dataclasses +from dataclasses import fields, is_dataclass +from typing import Any + +_LineBuffer = list[str] + + +def _expand(text: str) -> _LineBuffer: + """Expand a string into a LineBuffer.""" + return text.splitlines() + + +def _indent(lines: _LineBuffer, level: int = 1) -> _LineBuffer: + """Indent a LineBuffer by a given level.""" + if level == 0: + return lines + return [f"{' ' * level}{line}" for line in lines] + + +def _indentation(level: int = 1) -> str: + return " " * level + + +def _pretty_print(obj: Any, ind: int = 0) -> str: # noqa: ANN401 + def field_is_default(obj: Any, field_: dataclasses.Field) -> bool: # noqa: ANN401 + val = getattr(obj, field_.name) + if val == field_.default: + return True + if field_.default_factory != dataclasses.MISSING: + return val == field_.default_factory() + return False + + match obj: + case bool(): + return f"{obj}" + case str(): + return obj.__repr__() + case int(): + return f"{obj}" + case float(): + return f"{obj}" + case dict(): + if len(obj) == 0: + return "{}" + return f"\n{_indentation(ind)}".join([ + "{", + *_expand(",\n".join([f" {_pretty_print(key, 1)}: {_pretty_print(value, 1)}" for key, value in obj])), + "}", + ]) + case list(): + if len(obj) == 0: + return "[]" + return f"\n{_indentation(ind)}".join([ + "[", + *_expand(",\n".join([f" {_pretty_print(value, 1)}" for value in obj])), + "]", + ]) + case _: + if is_dataclass(obj): + return f"\n{_indentation(ind)}".join([ + f"{obj.__class__.__name__}(", + *_expand( + ",\n".join([ + f" {field.name}={_pretty_print(getattr(obj, field.name), 1)}" + for field in fields(obj) + if not field_is_default(obj, field) + ]) + ), + ")", + ]) + else: + return str(obj) + + +def pretty_print(obj: Any) -> None: # noqa: ANN401 + from rich import print + + print(_pretty_print(obj)) diff --git a/src/styx/ir/stats.py b/src/styx/ir/stats.py new file mode 100644 index 0000000..55d7df2 --- /dev/null +++ b/src/styx/ir/stats.py @@ -0,0 +1,43 @@ +import styx.ir.core as ir + + +def _expr_counter(expr: ir.IParam) -> int: + if isinstance(expr, ir.IStruct): + return 1 + sum([_expr_counter(e) for e in expr.struct.iter_params()]) + if isinstance(expr, ir.IStructUnion): + return 1 + sum([_expr_counter(e) for e in expr.alts]) + return 1 + + +def _param_counter(expr: ir.IParam) -> int: + if isinstance(expr, ir.IStruct): + return sum([_param_counter(e) for e in expr.struct.iter_params()]) + if isinstance(expr, ir.IStructUnion): + return sum([_param_counter(e) for e in expr.alts]) + return 1 + + +def _mccabe(expr: ir.IParam) -> int: + complexity = 1 + + if isinstance(expr, ir.IOptional) or ( + isinstance(expr, (ir.IStruct, ir.IStructUnion)) and isinstance(expr, ir.IList) + ): + complexity = 2 + + match expr: + case ir.IStruct(): + x = [_mccabe(e) for e in expr.struct.iter_params()] + return complexity * (sum(x) - len(x) + 1) + case ir.IStructUnion(): + return complexity * sum([_mccabe(e) for e in expr.alts]) + return complexity + + +def stats(interface: ir.Interface) -> dict[str, str | int | float]: + return { + "name": interface.command.param.name, + "num_expressions": _expr_counter(interface.command), + "num_params": _param_counter(interface.command), + "mccabe": _mccabe(interface.command), + } diff --git a/src/styx/main.py b/src/styx/main.py deleted file mode 100644 index 2b2d0a1..0000000 --- 
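For illustration, the IR inspection helpers above (pretty_print and stats) applied to a minimal hand-built interface; pretty_print relies on the rich package being installed, and the interface content is made up:

import styx.ir.core as ir
from styx.ir.pretty_print import pretty_print
from styx.ir.stats import stats

iface = ir.Interface(
    uid="demo.boutiques",
    package=ir.Package(name="demo", version="1.0", docker=None),
    command=ir.PStruct(param=ir.DParam(id_=0, name="demo"), struct=ir.DStruct(name="demo")),
)
pretty_print(iface)   # indented dump of the dataclass tree, omitting default-valued fields
print(stats(iface))   # {'name': 'demo', 'num_expressions': 1, 'num_params': 0, 'mccabe': 1}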
a/src/styx/main.py +++ /dev/null @@ -1,177 +0,0 @@ -import argparse -import json -import pathlib - -import tomli as tomllib # Remove once we move to python 3.11 - -from styx.compiler.compile.reexport_module import generate_reexport_module -from styx.compiler.core import compile_boutiques_dict -from styx.compiler.settings import CompilerSettings -from styx.pycodegen.utils import python_snakify - - -def load_settings_from_toml( - config_path: pathlib.Path, - override_input_folder: pathlib.Path | None = None, - override_output_folder: pathlib.Path | None = None, -) -> CompilerSettings: - """Load settings from a TOML file.""" - if not config_path.exists(): - if override_input_folder is None: - raise FileNotFoundError(f"Configuration file {config_path} does not exist") - return CompilerSettings( - input_path=override_input_folder, - output_path=override_output_folder, - ) - - with open(config_path, "rb") as f: - settings = tomllib.load(f) - return CompilerSettings( - input_path=override_input_folder or pathlib.Path(settings.get("input_path", ".")), - output_path=override_output_folder or pathlib.Path(settings.get("output_path", ".")), - ) - - -def collect_settings( - work_dir: pathlib.Path, - override_input_folder: pathlib.Path | None = None, - override_output_folder: pathlib.Path | None = None, - override_config_file: pathlib.Path | None = None, -) -> CompilerSettings: - """Collect settings.""" - if not work_dir.exists(): - raise FileNotFoundError(f"Work directory {work_dir} does not exist") - - config_file: pathlib.Path | None = None - - if override_config_file is not None: - config_file = override_config_file - elif override_input_folder is not None and (override_input_folder / "styx.toml").exists(): - config_file = override_input_folder / "styx.toml" - elif override_input_folder is not None and (override_input_folder / "pyproject.toml").exists(): - config_file = override_input_folder / "pyproject.toml" - elif (work_dir / "styx.toml").exists(): - config_file = work_dir / "styx.toml" - elif (work_dir / "pyproject.toml").exists(): - config_file = work_dir / "pyproject.toml" - - if config_file is not None: - settings = load_settings_from_toml( - config_path=config_file, - override_input_folder=override_input_folder, - override_output_folder=override_output_folder, - ) - else: - settings = CompilerSettings( - input_path=override_input_folder or work_dir, - output_path=override_output_folder, - ) - - return settings - - -def main() -> None: - """Main entry point.""" - parser = argparse.ArgumentParser(description="Compile JSON descriptors to Python modules") - parser.add_argument( - "-i", "--input-folder", type=pathlib.Path, help="Path to the input folder containing JSON descriptors" - ) - parser.add_argument( - "-o", "--output-folder", type=pathlib.Path, help="Path to the output folder for compiled Python modules" - ) - parser.add_argument("-c", "--config", type=pathlib.Path, help="Path to the configuration file") - parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") - args = parser.parse_args() - - settings = collect_settings( - work_dir=pathlib.Path.cwd(), - override_input_folder=args.input_folder, - override_output_folder=args.output_folder, - override_config_file=args.config, - ) - settings.debug_mode = settings.debug_mode or args.debug - - assert settings.input_path is not None - json_files = settings.input_path.glob("**/*.json") - - module_tree: dict = {} - fail_counter = 0 - total_counter = 0 - for json_path in json_files: - total_counter += 1 - 
output_module_path = json_path.parent.relative_to(settings.input_path).parts - # ensure module path is valid python symbol - output_module_path = tuple(python_snakify(part) for part in output_module_path) - output_module_name = python_snakify(json_path.stem) - output_file_name = f"{output_module_name}.py" - - subtree = module_tree - for part in output_module_path: - if part not in subtree: - subtree[part] = {} - subtree = subtree[part] - if "__items__" not in subtree: - subtree["__items__"] = [] - subtree["__items__"].append(output_module_name) - - # check if source is newer than target - if settings.output_path: - output_path = settings.output_path / pathlib.Path(*output_module_path) / output_file_name - if output_path.exists() and json_path.stat().st_mtime < output_path.stat().st_mtime: - continue - - with open(json_path, "r", encoding="utf-8") as json_file: - try: - json_data = json.load(json_file) - except json.JSONDecodeError: - print(f"Skipped: {json_path} (invalid JSON)") - fail_counter += 1 - continue - try: - code = compile_boutiques_dict(json_data, settings) - - if settings.output_path: - output_path = settings.output_path / pathlib.Path(*output_module_path) - output_path.mkdir(parents=True, exist_ok=True) - output_path = output_path / output_file_name - with open(output_path, "w") as py_file: - py_file.write(code) - print(f"Compiled {json_path} to {output_path}") - else: - print(f"Compiled {json_path} -> {pathlib.Path(*output_module_path) / output_file_name}: {'---' * 10}") - print(code) - print("---" * 10) - except Exception as e: - print(f"Skipped: {json_path}") - if settings.debug_mode: - raise e - fail_counter += 1 - import traceback - - print(traceback.format_exc()) - - # Re-export __init__.py files - # TODO: make optional - if settings.output_path is not None: - - def _walk_tree(tree: dict, path: str) -> None: - for key, value in tree.items(): - if key == "__items__": - assert settings.output_path is not None - out_path = settings.output_path / path / "__init__.py" - with open(out_path, "w") as init_file: - init_file.write(generate_reexport_module(value)) - print(f"Generated {out_path}") - else: - _walk_tree(value, f"{path}/{key}" if path else key) - - _walk_tree(module_tree, "") - - if fail_counter > 0: - print(f"Failed to compile {fail_counter}/{total_counter} descriptors.") - else: - print(f"Successfully compiled {total_counter} descriptors.") - - -if __name__ == "__main__": - main() diff --git a/src/styx/model/__init__.py b/src/styx/model/__init__.py deleted file mode 100644 index 97cc8b1..0000000 --- a/src/styx/model/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Styx data model.""" diff --git a/src/styx/model/core.py b/src/styx/model/core.py deleted file mode 100644 index 90ded29..0000000 --- a/src/styx/model/core.py +++ /dev/null @@ -1,105 +0,0 @@ -import pathlib -from dataclasses import dataclass -from enum import Enum -from typing import Generic, Mapping, Sequence, TypeAlias, TypeVar, Union - -T = TypeVar("T") - -TYPE_INPUT_VALUE_PRIMITIVE: TypeAlias = str | float | int | bool | pathlib.Path -TYPE_INPUT_VALUE: TypeAlias = TYPE_INPUT_VALUE_PRIMITIVE | Sequence[TYPE_INPUT_VALUE_PRIMITIVE] | None -TYPE_METADATA: TypeAlias = Mapping[str, str | int | float] - - -class InputTypePrimitive(Enum): - String = 1 - Number = 2 - Integer = 3 - File = 4 - Flag = 5 - SubCommand = 6 - SubCommandUnion = 7 - - -@dataclass -class InputType: - primitive: InputTypePrimitive - is_list: bool = False - is_optional: bool = False - is_enum: bool = False - - -@dataclass -class 
InputArgumentConstraints: - value_min: float | int | None = None - value_min_exclusive: bool = False - value_max: float | int | None = None - value_max_exclusive: bool = False - list_length_min: int | None = None - list_length_max: int | None = None - - -@dataclass -class InputArgument: - internal_id: str - template_key: str - - name: str - type: InputType - doc: str - constraints: InputArgumentConstraints - has_default_value: bool = False - default_value: TYPE_INPUT_VALUE | None = None - - command_line_flag: str | None = None - command_line_flag_separator: str | None = None - list_separator: str | None = None - enum_values: list[TYPE_INPUT_VALUE_PRIMITIVE] | None = None - - sub_command: Union["SubCommand", None] = None - sub_command_union: list["SubCommand"] | None = None - - -@dataclass -class OutputArgument: - name: str - doc: str - path_template: str - optional: bool = False - - stripped_file_extensions: list[str] | None = None - - -@dataclass -class GroupConstraint: - name: str - description: str - members: list[str] - - members_mutually_exclusive: bool = False - members_must_include_one: bool = False - members_must_include_all_or_none: bool = False - - -@dataclass -class SubCommand: - internal_id: str - - name: str - doc: str - input_command_line_template: str - inputs: list[InputArgument] - outputs: list[OutputArgument] - group_constraints: list[GroupConstraint] - - -@dataclass -class Descriptor: - hash: str - metadata: TYPE_METADATA - command: SubCommand - - -@dataclass -class WithSymbol(Generic[T]): - data: T - symbol: str diff --git a/src/styx/model/from_boutiques.py b/src/styx/model/from_boutiques.py deleted file mode 100644 index aa48e16..0000000 --- a/src/styx/model/from_boutiques.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Convert a Boutiques tool to a Styx descriptor.""" - -import hashlib -import json -import pathlib - -from styx.model.core import ( - TYPE_INPUT_VALUE, - TYPE_INPUT_VALUE_PRIMITIVE, - Descriptor, - GroupConstraint, - InputArgument, - InputArgumentConstraints, - InputType, - InputTypePrimitive, - OutputArgument, - SubCommand, -) - - -def _input_type_primitive_from_boutiques(bt_input: dict) -> InputTypePrimitive: - """Convert a Boutiques input to a Styx input type primitive.""" - if "type" not in bt_input: - raise ValueError(f"type is missing for input: '{bt_input['id']}'") - - if isinstance(bt_input["type"], dict): - return InputTypePrimitive.SubCommand - - if isinstance(bt_input["type"], list): - return InputTypePrimitive.SubCommandUnion - - bt_type_name = bt_input["type"] - if not isinstance(bt_type_name, str): - bt_type_name = bt_type_name.value - - if bt_type_name == "String": - return InputTypePrimitive.String - elif bt_type_name == "File": - return InputTypePrimitive.File - elif bt_type_name == "Flag": - return InputTypePrimitive.Flag - elif bt_type_name == "Number" and not bt_input.get("integer"): - return InputTypePrimitive.Number - elif bt_type_name == "Number" and bt_input.get("integer"): - return InputTypePrimitive.Integer - else: - raise NotImplementedError - - -def _input_type_from_boutiques(bt_input: dict) -> InputType: - """Convert a Boutiques input to a Styx input type.""" - bt_is_list = bt_input.get("list") is True - bt_is_optional = bt_input.get("optional") is True - bt_is_enum = bt_input.get("value-choices") is not None - primitive = _input_type_primitive_from_boutiques(bt_input) - if primitive == InputTypePrimitive.File: - assert not bt_is_enum - if primitive == InputTypePrimitive.Flag: - return InputType(InputTypePrimitive.Flag, False, True, 
False) - return InputType(primitive, bt_is_list, bt_is_optional, bt_is_enum) - - -def _default_value_from_boutiques(bt_input: dict) -> tuple[bool, TYPE_INPUT_VALUE | None]: - """Convert a Boutiques input to a Styx default value.""" - primitive = _input_type_primitive_from_boutiques(bt_input) - default_value = bt_input.get("default-value") - if default_value is None: - if primitive == InputTypePrimitive.Flag: - return True, False - if bt_input.get("optional") is True: - return True, None - else: - return False, None - - if primitive == InputTypePrimitive.File: - assert isinstance(default_value, str), f"Expected string default-value, got {type(default_value)}" - return True, pathlib.Path(default_value) - elif primitive == InputTypePrimitive.String: - assert isinstance(default_value, str), f"Expected string default-value, got {type(default_value)}" - elif primitive == InputTypePrimitive.Number: - assert isinstance(default_value, (int, float)), f"Expected number default-value, got {type(default_value)}" - elif primitive == InputTypePrimitive.Integer: - assert isinstance(default_value, int), f"Expected integer default-value, got {type(default_value)}" - elif primitive == InputTypePrimitive.Flag: - assert isinstance(default_value, bool), f"Expected boolean default-value, got {type(default_value)}" - elif primitive == InputTypePrimitive.SubCommand: - assert isinstance(default_value, str), f"Expected string default-value, got {type(default_value)}" - else: - raise NotImplementedError - - return True, default_value - - -def _constraints_from_boutiques(bt_input: dict) -> InputArgumentConstraints: - """Convert a Boutiques input to a Styx input constraints.""" - value_min = None - value_min_exclusive = False - value_max = None - value_max_exclusive = False - list_length_min = None - list_length_max = None - - input_type = _input_type_primitive_from_boutiques(bt_input) - - if input_type in (InputTypePrimitive.Number, InputTypePrimitive.Integer): - if (val := bt_input.get("minimum")) is not None: - value_min = int(val) if bt_input.get("integer") else val - value_min_exclusive = bt_input.get("exclusive-minimum") is True - if (val := bt_input.get("maximum")) is not None: - value_max = int(val) if bt_input.get("integer") else val - value_max_exclusive = bt_input.get("exclusive-maximum") is True - if bt_input.get("list") is True: - list_length_min = bt_input.get("min-list-entries") - list_length_max = bt_input.get("max-list-entries") - - return InputArgumentConstraints( - value_min=value_min, - value_min_exclusive=value_min_exclusive, - value_max=value_max, - value_max_exclusive=value_max_exclusive, - list_length_min=list_length_min, - list_length_max=list_length_max, - ) - - -def _sub_command_from_boutiques(bt_subcommand: dict) -> SubCommand: - """Convert a Boutiques input to a Styx sub-command.""" - if "id" not in bt_subcommand: - raise ValueError(f"id is missing for sub-command: '{bt_subcommand}'") - if "command-line" not in bt_subcommand: - raise ValueError(f"command-line is missing for sub-command: '{bt_subcommand}'") - - inputs = [] - if "inputs" in bt_subcommand: - for input_ in bt_subcommand["inputs"]: - inputs.append(_input_argument_from_boutiques(input_)) - - outputs = [] - if "output-files" in bt_subcommand: - for output in bt_subcommand["output-files"]: - outputs.append(_output_argument_from_boutiques(output)) - - group_constraints = [] - if "groups" in bt_subcommand: - for group in bt_subcommand["groups"]: - group_constraints.append(_group_constraint_from_boutiques(group)) - - return 
SubCommand( - internal_id=bt_subcommand["id"], - name=bt_subcommand["id"], - doc=bt_subcommand.get("description", "Description missing"), - input_command_line_template=bt_subcommand["command-line"], - inputs=inputs, - outputs=outputs, - group_constraints=group_constraints, - ) - - -def _input_argument_from_boutiques(bt_input: dict) -> InputArgument: - """Convert a Boutiques input to a Styx input argument.""" - if "id" not in bt_input: - raise ValueError(f"id is missing for input: '{bt_input}'") - if "type" not in bt_input: - raise ValueError(f"type is missing for input '{bt_input['id']}'") - # Do we want to automatically generate value-key from ID if missing? - # Note: Boutiques 0.5 does not require value-key and I don't know why. - if "value-key" not in bt_input: - raise ValueError(f"value-key is missing for input '{bt_input['id']}'") - if len(bt_input["value-key"]) == 0: - raise ValueError(f"value-key is empty for input '{bt_input['id']}'") - - type_ = _input_type_from_boutiques(bt_input) - has_default_value, default_value = _default_value_from_boutiques(bt_input) - constraints = _constraints_from_boutiques(bt_input) - list_separator = bt_input.get("list-separator", None) - - enum_values: list[TYPE_INPUT_VALUE_PRIMITIVE] | None = None - if (value_choices := bt_input.get("value-choices")) is not None: - assert isinstance(value_choices, list) - assert all(isinstance(value, (str, int, float)) for value in value_choices) - - if type_.primitive == InputTypePrimitive.Integer: - enum_values = [int(value) for value in value_choices] - else: - enum_values = value_choices - - sub_command = None - sub_command_union = None - if type_.primitive == InputTypePrimitive.SubCommand: - sub_command = _sub_command_from_boutiques(bt_input["type"]) - elif type_.primitive == InputTypePrimitive.SubCommandUnion: - sub_command_union = [_sub_command_from_boutiques(subcommand) for subcommand in bt_input["type"]] - - return InputArgument( - internal_id=bt_input["value-key"], - template_key=bt_input["value-key"], - name=bt_input["id"], - type=type_, - doc=bt_input.get("description", "Description missing"), - has_default_value=has_default_value, - default_value=default_value, - constraints=constraints, - command_line_flag=bt_input.get("command-line-flag"), - command_line_flag_separator=bt_input.get("command-line-flag-separator"), - list_separator=list_separator, - enum_values=enum_values, - sub_command=sub_command, - sub_command_union=sub_command_union, - ) - - -def _output_argument_from_boutiques(bt_output: dict) -> OutputArgument: - """Convert a Boutiques output to a Styx output argument.""" - if "id" not in bt_output: - raise ValueError(f"id is missing for output: '{bt_output}'") - if "path-template" not in bt_output: - raise ValueError(f"path-template is missing for output '{bt_output['id']}'") - - return OutputArgument( - name=bt_output["id"], - doc=bt_output.get("description", "Description missing"), - optional=bt_output.get("optional") is True, - stripped_file_extensions=bt_output.get("path-template-stripped-extensions"), - path_template=bt_output["path-template"], - ) - - -def _group_constraint_from_boutiques(bt_group: dict) -> GroupConstraint: - """Convert a Boutiques group to a Styx group constraint.""" - if "id" not in bt_group: - raise ValueError(f"id is missing for group: '{bt_group}'") - - return GroupConstraint( - name=bt_group["id"], - description=bt_group.get("description", "Description missing"), - members=bt_group.get("members", []), # A group without members does not make sense. Raise an error? 
- members_mutually_exclusive=bt_group.get("mutually-exclusive") is True, - members_must_include_one=bt_group.get("one-is-required") is True, - members_must_include_all_or_none=bt_group.get("all-or-none") is True, - ) - - -def _hash_from_boutiques(tool: dict) -> str: - """Generate a hash from a Boutiques tool.""" - json_str = json.dumps(tool, sort_keys=True) - return hashlib.sha1(json_str.encode()).hexdigest() - - -def _boutiques_metadata(tool: dict, tool_hash: str) -> dict: - """Extract metadata from a Boutiques tool.""" - if "name" not in tool: - raise ValueError(f"name is missing for tool '{tool}'") - - metadata = {"id": tool_hash, "name": tool["name"]} - if (container_image := tool.get("container-image")) is not None: - if (val := container_image.get("type")) is not None: - metadata["container_image_type"] = val - if (val := container_image.get("index")) is not None: - metadata["container_image_index"] = val - if (val := container_image["image"]) is not None: - metadata["container_image_tag"] = val - return metadata - - -def _boutiques_documentation(tool: dict) -> str: - """Extract documentation from a Boutiques tool.""" - if "name" not in tool: - raise ValueError(f"name is missing for tool '{tool}'") - - doc = tool["name"] - - if "author" in tool: - doc += f" by {tool['author']}" - - description = tool.get("description", "Description missing.") - if not description.endswith("."): - description += "." - - doc += f".\n\n{description}" - - if "url" in tool: - doc += f"\n\nMore information: {tool['url']}" - - return doc - - -def descriptor_from_boutiques(tool: dict) -> Descriptor: - """Convert a Boutiques tool to a Styx descriptor.""" - if "name" not in tool: - raise ValueError(f"name is missing for tool '{tool}'") - if "command-line" not in tool: - raise ValueError(f"command-line is missing for tool '{tool['name']}'") - - inputs = [] - if "inputs" in tool: - for input_ in tool["inputs"]: - inputs.append(_input_argument_from_boutiques(input_)) - - outputs = [] - if "output-files" in tool: - for output in tool["output-files"]: - outputs.append(_output_argument_from_boutiques(output)) - - group_constraints = [] - if "groups" in tool: - for group in tool["groups"]: - group_constraints.append(_group_constraint_from_boutiques(group)) - - hash_ = _hash_from_boutiques(tool) - - metadata = _boutiques_metadata(tool, hash_) - - return Descriptor( - hash=hash_, - metadata=metadata, - command=SubCommand( - internal_id=tool["name"], - name=tool["name"], - doc=_boutiques_documentation(tool), - input_command_line_template=tool["command-line"], - inputs=inputs, - outputs=outputs, - group_constraints=group_constraints, - ), - ) diff --git a/tests/test_carg_building.py b/tests/test_carg_building.py index 870c855..109bde7 100644 --- a/tests/test_carg_building.py +++ b/tests/test_carg_building.py @@ -1,8 +1,7 @@ """Test command line argument building.""" -import styx.compiler.core -import styx.compiler.settings import tests.utils.dummy_runner +from tests.utils.compile_boutiques import boutiques2python from tests.utils.dynmodule import ( BT_TYPE_FILE, BT_TYPE_FLAG, @@ -27,7 +26,7 @@ def test_positional_string_arg() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -51,7 +50,7 @@ def test_positional_number_arg() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = 
boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -75,7 +74,7 @@ def test_positional_file_arg() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -100,7 +99,7 @@ def test_flag_arg() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -125,7 +124,7 @@ def test_named_arg() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -159,7 +158,7 @@ def test_list_of_strings_arg() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -193,7 +192,7 @@ def test_list_of_numbers_arg() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -218,7 +217,7 @@ def test_static_args() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -263,11 +262,11 @@ def test_arg_order() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() - test_module.dummy("aaa", "bbb", runner=dummy_runner) + test_module.dummy(a="aaa", b="bbb", runner=dummy_runner) assert dummy_runner.last_cargs is not None assert dummy_runner.last_cargs == ["bbb", "aaa"] diff --git a/tests/test_default_values.py b/tests/test_default_values.py index b7702de..5e85f35 100644 --- a/tests/test_default_values.py +++ b/tests/test_default_values.py @@ -1,8 +1,7 @@ """Input argument default value tests.""" -import styx.compiler.core -import styx.compiler.settings import tests.utils.dummy_runner +from tests.utils.compile_boutiques import boutiques2python from tests.utils.dynmodule import ( BT_TYPE_STRING, boutiques_dummy, @@ -25,7 +24,7 @@ def test_default_string_arg() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() diff --git a/tests/test_groups.py b/tests/test_groups.py deleted file mode 100644 index 44d4e23..0000000 --- a/tests/test_groups.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Argument group constraint tests.""" - -import pytest - -import styx.compiler.core -import styx.compiler.settings -import tests.utils.dummy_runner -from tests.utils.dynmodule import ( - BT_TYPE_NUMBER, - boutiques_dummy, - dynamic_module, -) - -_XYZ_INPUTS = [ - { - "id": "x", - 
"name": "The x", - "value-key": "[X]", - "type": BT_TYPE_NUMBER, - "integer": True, - "optional": True, - }, - { - "id": "y", - "name": "The y", - "value-key": "[Y]", - "type": BT_TYPE_NUMBER, - "integer": True, - "optional": True, - }, - { - "id": "z", - "name": "The z", - "value-key": "[Z]", - "type": BT_TYPE_NUMBER, - "integer": True, - "optional": True, - }, -] - - -def test_mutually_exclusive() -> None: - """Mutually exclusive argument group.""" - model = boutiques_dummy({ - "command-line": "dummy [X] [Y] [Z]", - "inputs": _XYZ_INPUTS, - "groups": [ - { - "id": "group", - "name": "Group", - "members": ["x", "y", "z"], - "mutually-exclusive": True, - } - ], - }) - - compiled_module = styx.compiler.core.compile_boutiques_dict(model) - - test_module = dynamic_module(compiled_module, "test_module") - dummy_runner = tests.utils.dummy_runner.DummyRunner() - - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner, x=1, y=2) - - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner, x=1, y=2, z=3) - - assert test_module.dummy(runner=dummy_runner, x=1) is not None - assert test_module.dummy(runner=dummy_runner, y=2) is not None - assert test_module.dummy(runner=dummy_runner, z=2) is not None - assert test_module.dummy(runner=dummy_runner) is not None - - -def test_all_or_none() -> None: - """All or none argument group.""" - model = boutiques_dummy({ - "command-line": "dummy [X] [Y] [Z]", - "inputs": _XYZ_INPUTS, - "groups": [ - { - "id": "group", - "name": "Group", - "members": ["x", "y", "z"], - "all-or-none": True, - } - ], - }) - - compiled_module = styx.compiler.core.compile_boutiques_dict(model) - - test_module = dynamic_module(compiled_module, "test_module") - dummy_runner = tests.utils.dummy_runner.DummyRunner() - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner, x=1, y=2) - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner, z=3) - - assert test_module.dummy(runner=dummy_runner, x=1, y=2, z=3) is not None - assert test_module.dummy(runner=dummy_runner) is not None - - -def test_one_required() -> None: - """One required argument group.""" - model = boutiques_dummy({ - "command-line": "dummy [X] [Y] [Z]", - "inputs": _XYZ_INPUTS, - "groups": [ - { - "id": "group", - "name": "Group", - "members": ["x", "y", "z"], - "one-is-required": True, - } - ], - }) - - compiled_module = styx.compiler.core.compile_boutiques_dict(model) - print(compiled_module) - - test_module = dynamic_module(compiled_module, "test_module") - dummy_runner = tests.utils.dummy_runner.DummyRunner() - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner) - - assert test_module.dummy(runner=dummy_runner, x=1) is not None - assert test_module.dummy(runner=dummy_runner, y=2) is not None - assert test_module.dummy(runner=dummy_runner, z=3) is not None - assert test_module.dummy(runner=dummy_runner, x=1, y=2) is not None - assert test_module.dummy(runner=dummy_runner, x=1, z=3) is not None - assert test_module.dummy(runner=dummy_runner, y=2, z=3) is not None - assert test_module.dummy(runner=dummy_runner, x=1, y=2, z=3) is not None diff --git a/tests/test_numeric_ranges.py b/tests/test_numeric_ranges.py index ea95e6a..a331244 100644 --- a/tests/test_numeric_ranges.py +++ b/tests/test_numeric_ranges.py @@ -2,9 +2,8 @@ import pytest -import styx.compiler.core -import styx.compiler.settings import tests.utils.dummy_runner +from tests.utils.compile_boutiques import boutiques2python from tests.utils.dynmodule import ( BT_TYPE_NUMBER, 
boutiques_dummy, @@ -28,7 +27,7 @@ def test_below_range_minimum_inclusive() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -52,7 +51,7 @@ def test_above_range_maximum_inclusive() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -77,7 +76,7 @@ def test_above_range_maximum_exclusive() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -102,7 +101,7 @@ def test_below_range_minimum_exclusive() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -127,7 +126,7 @@ def test_outside_range() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() diff --git a/tests/test_output_files.py b/tests/test_output_files.py index 0c73432..7123dab 100644 --- a/tests/test_output_files.py +++ b/tests/test_output_files.py @@ -1,8 +1,7 @@ """Test output file paths.""" -import styx.compiler.core -import styx.compiler.settings import tests.utils.dummy_runner +from tests.utils.compile_boutiques import boutiques2python from tests.utils.dynmodule import ( BT_TYPE_FILE, BT_TYPE_NUMBER, @@ -32,7 +31,7 @@ def test_output_file() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -60,12 +59,12 @@ def test_output_file_with_template() -> None: { "id": "out", "name": "The out", - "path-template": "out-{x}.txt", + "path-template": "out-[X].txt", } ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() @@ -99,7 +98,7 @@ def test_output_file_with_template_and_stripped_extensions() -> None: ], }) - compiled_module = styx.compiler.core.compile_boutiques_dict(model) + compiled_module = boutiques2python(model) test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() diff --git a/tests/test_scope.py b/tests/test_scope.py index f2ce3a3..95babe6 100644 --- a/tests/test_scope.py +++ b/tests/test_scope.py @@ -2,7 +2,7 @@ import pytest -from styx.pycodegen.scope import Scope +from styx.backend.python.pycodegen.scope import Scope def test_scope_add_or_die() -> None: diff --git a/tests/test_string_case.py b/tests/test_string_case.py index 5d3ffe9..c86ed26 100644 --- a/tests/test_string_case.py +++ b/tests/test_string_case.py @@ -2,7 +2,7 @@ import pytest -from styx.pycodegen.string_case import ( +from styx.backend.python.pycodegen.string_case import ( camel_case, 
     pascal_case,
     screaming_snake_case,
diff --git a/tests/utils/compile_boutiques.py b/tests/utils/compile_boutiques.py
new file mode 100644
index 0000000..b30573f
--- /dev/null
+++ b/tests/utils/compile_boutiques.py
@@ -0,0 +1,8 @@
+from styx.backend.python import to_python
+from styx.frontend.boutiques import from_boutiques
+
+
+def boutiques2python(boutiques: dict, package: str = "no_package") -> str:
+    ir = from_boutiques(boutiques, package)
+    py = to_python([ir]).__next__()[0]
+    return py
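The new tests/utils/compile_boutiques.py helper is what the updated tests above now call in place of the old compiler entry point: it runs a Boutiques dict through the Boutiques frontend and the Python backend and returns the generated module source. Below is a minimal sketch of how a test exercises it, following the same pattern as the tests in this diff. The test-utility names (boutiques_dummy, dynamic_module, DummyRunner, BT_TYPE_STRING) are the ones already imported by the updated tests; the single-input descriptor itself is made up for illustration, and the exact cargs depend on the command-line template, so the sketch only checks that a command line was recorded.

import tests.utils.dummy_runner
from tests.utils.compile_boutiques import boutiques2python
from tests.utils.dynmodule import (
    BT_TYPE_STRING,
    boutiques_dummy,
    dynamic_module,
)


def test_sketch() -> None:
    # Hypothetical dummy Boutiques descriptor with one positional string input.
    model = boutiques_dummy({
        "command-line": "dummy [X]",
        "inputs": [
            {
                "id": "x",
                "name": "The x",
                "value-key": "[X]",
                "type": BT_TYPE_STRING,
            }
        ],
    })

    # Boutiques dict -> IR (frontend) -> generated Python wrapper source (backend).
    compiled_module = boutiques2python(model)

    # Import the generated source as a module and run it against a dummy runner,
    # passing inputs as keyword arguments as the updated tests do.
    test_module = dynamic_module(compiled_module, "test_module")
    dummy_runner = tests.utils.dummy_runner.DummyRunner()
    test_module.dummy(x="hello", runner=dummy_runner)

    # The exact argument order depends on the command-line template,
    # so only assert that a command line was built at all.
    assert dummy_runner.last_cargs is not None

Keeping this conversion in a single shared helper means the tests no longer depend on the removed styx.compiler.core module and exercise the same frontend/backend path the compiler itself uses.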