From 42d858e1d9855c02350cb25634876c412c7d0c67 Mon Sep 17 00:00:00 2001 From: Florian Rupprecht Date: Thu, 5 Sep 2024 14:27:30 -0400 Subject: [PATCH] IR progress --- src/styx/backend/python/__init__.py | 2 + src/styx/backend/python/constraints.py | 172 ++++++++++++++++++++ src/styx/backend/python/interface.py | 100 ++---------- src/styx/backend/python/lookup.py | 83 ++++++++++ src/styx/frontend/__init__.py | 1 + src/styx/frontend/boutiques/__init__.py | 3 + src/styx/frontend/boutiques/core.py | 198 ++++++++++-------------- src/styx/ir/core.py | 22 +-- src/styx/ir/dyn.py | 50 ++++++ tests/test_carg_building.py | 2 +- tests/test_groups.py | 130 ---------------- tests/test_output_files.py | 2 +- tests/utils/compile_boutiques.py | 8 +- 13 files changed, 414 insertions(+), 359 deletions(-) create mode 100644 src/styx/backend/python/constraints.py create mode 100644 src/styx/backend/python/lookup.py create mode 100644 src/styx/ir/dyn.py delete mode 100644 tests/test_groups.py diff --git a/src/styx/backend/python/__init__.py b/src/styx/backend/python/__init__.py index 556e1b3..146306d 100644 --- a/src/styx/backend/python/__init__.py +++ b/src/styx/backend/python/__init__.py @@ -1 +1,3 @@ """Python wrapper backend.""" + +from .core import to_python diff --git a/src/styx/backend/python/constraints.py b/src/styx/backend/python/constraints.py new file mode 100644 index 0000000..85e71ce --- /dev/null +++ b/src/styx/backend/python/constraints.py @@ -0,0 +1,172 @@ +from styx.backend.python.lookup import LookupParam +from styx.backend.python.pycodegen.core import LineBuffer, PyFunc, indent +import styx.ir.core as ir + + +def _generate_raise_value_err(obj: str, expectation: str, reality: str | None = None) -> LineBuffer: + fstr = "" + if "{" in obj or "{" in expectation or (reality is not None and "{" in reality): + fstr = "f" + + return ( + [f'raise ValueError({fstr}"{obj} must be {expectation} but was {reality}")'] + if reality is not None + else [f'raise ValueError({fstr}"{obj} must be {expectation}")'] + ) + + +def _param_compile_constraint_checks(buf: LineBuffer, param: ir.IParam, lookup: LookupParam) -> None: + """Generate input constraint validation code for an input argument.""" + py_symbol = lookup.py_symbol[param.param.id_] + + min_value: float | int | None = None + max_value: float | int | None = None + list_count_min: int | None = None + list_count_max: int | None = None + + if isinstance(param, (ir.IFloat, ir.IInt)): + min_value = param.min_value + max_value = param.max_value + elif isinstance(param, ir.IList): + list_count_min = param.list_.count_min + list_count_max = param.list_.count_max + + val_opt = "" + if isinstance(param, ir.IOptional): + val_opt = f"{py_symbol} is not None and " + + # List argument length validation + if list_count_min is not None and list_count_max is not None: + # Case: len(list[]) == X + assert list_count_min <= list_count_max + if list_count_min == list_count_max: + buf.extend([ + f"if {val_opt}(len({py_symbol}) != {list_count_min}): ", + *indent( + _generate_raise_value_err( + f"Length of '{py_symbol}'", + f"{list_count_min}", + f"{{len({py_symbol})}}", + ) + ), + ]) + else: + # Case: X <= len(list[]) <= Y + buf.extend([ + f"if {val_opt}not ({list_count_min} <= " f"len({py_symbol}) <= {list_count_max}): ", + *indent( + _generate_raise_value_err( + f"Length of '{py_symbol}'", + f"between {list_count_min} and {list_count_max}", + f"{{len({py_symbol})}}", + ) + ), + ]) + elif list_count_min is not None: + # Case len(list[]) >= X + buf.extend([ + f"if {val_opt}not ({list_count_min} <= len({py_symbol})): ", + *indent( + _generate_raise_value_err( + f"Length of '{py_symbol}'", + f"greater than {list_count_min}", + f"{{len({py_symbol})}}", + ) + ), + ]) + elif list_count_max is not None: + # Case len(list[]) <= X + buf.extend([ + f"if {val_opt}not (len({py_symbol}) <= {list_count_max}): ", + *indent( + _generate_raise_value_err( + f"Length of '{py_symbol}'", + f"less than {list_count_max}", + f"{{len({py_symbol})}}", + ) + ), + ]) + + # Numeric argument range validation + op_min = "<=" + op_max = "<=" + if min_value is not None and max_value is not None: + # Case: X <= arg <= Y + assert min_value <= max_value + if isinstance(param, ir.IList): + buf.extend([ + f"if {val_opt}not ({min_value} {op_min} min({py_symbol}) " + f"and max({py_symbol}) {op_max} {max_value}): ", + *indent( + _generate_raise_value_err( + f"All elements of '{py_symbol}'", + f"between {min_value} {op_min} x {op_max} {max_value}", + ) + ), + ]) + else: + buf.extend([ + f"if {val_opt}not ({min_value} {op_min} {py_symbol} {op_max} {max_value}): ", + *indent( + _generate_raise_value_err( + f"'{py_symbol}'", + f"between {min_value} {op_min} x {op_max} {max_value}", + f"{{{py_symbol}}}", + ) + ), + ]) + elif min_value is not None: + # Case: X <= arg + if isinstance(param, ir.IList): + buf.extend([ + f"if {val_opt}not ({min_value} {op_min} min({py_symbol})): ", + *indent( + _generate_raise_value_err( + f"All elements of '{py_symbol}'", + f"greater than {min_value} {op_min} x", + ) + ), + ]) + else: + buf.extend([ + f"if {val_opt}not ({min_value} {op_min} {py_symbol}): ", + *indent( + _generate_raise_value_err( + f"'{py_symbol}'", + f"greater than {min_value} {op_min} x", + f"{{{py_symbol}}}", + ) + ), + ]) + elif max_value is not None: + # Case: arg <= X + if isinstance(param, ir.IList): + buf.extend([ + f"if {val_opt}not (max({py_symbol}) {op_max} {max_value}): ", + *indent( + _generate_raise_value_err( + f"All elements of '{py_symbol}'", + f"less than x {op_max} {max_value}", + ) + ), + ]) + else: + buf.extend([ + f"if {val_opt}not ({py_symbol} {op_max} {max_value}): ", + *indent( + _generate_raise_value_err( + f"'{py_symbol}'", + f"less than x {op_max} {max_value}", + f"{{{py_symbol}}}", + ) + ), + ]) + + +def struct_compile_constraint_checks( + func: PyFunc, + struct: ir.IStruct | ir.IParam, + lookup: LookupParam, +) -> None: + for param in struct.struct.iter_params(): + _param_compile_constraint_checks(func.body, param, lookup) diff --git a/src/styx/backend/python/interface.py b/src/styx/backend/python/interface.py index 8b42798..8c12671 100644 --- a/src/styx/backend/python/interface.py +++ b/src/styx/backend/python/interface.py @@ -1,5 +1,7 @@ import styx.ir.core as ir +from styx.backend.python.constraints import struct_compile_constraint_checks from styx.backend.python.documentation import docs_to_docstring +from styx.backend.python.lookup import LookupParam from styx.backend.python.metadata import generate_static_metadata from styx.backend.python.pycodegen.core import ( LineBuffer, @@ -12,100 +14,19 @@ indent, ) from styx.backend.python.pycodegen.scope import Scope -from styx.backend.python.pycodegen.utils import as_py_literal, python_pascalize, python_snakify +from styx.backend.python.pycodegen.utils import as_py_literal, python_snakify from styx.backend.python.utils import ( - iter_params_recursively, param_py_default_value, - param_py_type, param_py_var_is_set_by_user, param_py_var_to_str, struct_has_outputs, ) -class _LookupParam: - """Pre-compute and store Python symbols, types, class-names, etc. to reduce spaghetti code everywhere else.""" - - def __init__( - self, - interface: ir.Interface, - package_scope: Scope, - function_symbol: str, - function_scope: Scope, - ) -> None: - def _collect_output_field_symbols( - param: ir.IStruct | ir.IParam, lookup_output_field_symbol: dict[ir.IdType, str] - ) -> None: - scope = Scope(parent=package_scope) - for output in param.param.outputs: - output_field_symbol = scope.add_or_dodge(output.name) - assert output.id_ not in lookup_output_field_symbol - lookup_output_field_symbol[output.id_] = output_field_symbol - - def _collect_py_symbol(param: ir.IStruct | ir.IParam, lookup_py_symbol: dict[ir.IdType, str]) -> None: - scope = Scope(parent=function_scope) - for elem in param.struct.iter_params(): - symbol = scope.add_or_dodge(python_snakify(elem.param.name)) - assert elem.param.id_ not in lookup_py_symbol - lookup_py_symbol[elem.param.id_] = symbol - - self.param: dict[ir.IdType, ir.IParam] = {interface.command.param.id_: interface.command} - """Find param object by its ID. IParam.id_ -> IParam""" - self.py_type: dict[ir.IdType, str] = {interface.command.param.id_: function_symbol} - """Find Python type by param id. IParam.id_ -> Python type""" - self.py_symbol: dict[ir.IdType, str] = {} - """Find function-parameter symbol by param ID. IParam.id_ -> Python symbol""" - self.py_output_type: dict[ir.IdType, str] = { - interface.command.param.id_: package_scope.add_or_dodge( - python_pascalize(f"{interface.command.struct.name}_Outputs") - ) - } - """Find outputs class name by struct param ID. IStruct.id_ -> Python class name""" - self.py_output_field_symbol: dict[ir.IdType, str] = {} - """Find output field symbol by output ID. Output.id_ -> Python symbol""" - - _collect_py_symbol( - param=interface.command, - lookup_py_symbol=self.py_symbol, - ) - _collect_output_field_symbols( - param=interface.command, - lookup_output_field_symbol=self.py_output_field_symbol, - ) - - for elem in iter_params_recursively(interface.command): - self.param[elem.param.id_] = elem - - if isinstance(elem, ir.IStruct): - if elem.param.id_ not in self.py_type: # Struct unions may resolve these first - self.py_type[elem.param.id_] = package_scope.add_or_dodge( - python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}") - ) - self.py_output_type[elem.param.id_] = package_scope.add_or_dodge( - python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}_Outputs") - ) - _collect_py_symbol( - param=elem, - lookup_py_symbol=self.py_symbol, - ) - _collect_output_field_symbols( - param=elem, - lookup_output_field_symbol=self.py_output_field_symbol, - ) - elif isinstance(elem, ir.IStructUnion): - for alternative in elem.alts: - self.py_type[alternative.param.id_] = package_scope.add_or_dodge( - python_pascalize(f"{interface.command.struct.name}_{alternative.struct.name}") - ) - self.py_type[elem.param.id_] = param_py_type(elem, self.py_type) - else: - self.py_type[elem.param.id_] = param_py_type(elem, self.py_type) - - def _compile_struct( param: ir.IStruct | ir.IParam, interface_module: PyModule, - lookup: _LookupParam, + lookup: LookupParam, metadata_symbol: str, root_function: bool, ) -> None: @@ -128,8 +49,6 @@ def _compile_struct( ) pyargs = func_cargs_building.args interface_module.funcs.append(func_cargs_building) - - pyargs.append(PyArg(name="runner", type="Runner | None", default="None", docstring="Command runner")) else: func_cargs_building = PyFunc( name="run", @@ -191,6 +110,8 @@ def _compile_struct( root_function=False, ) + struct_compile_constraint_checks(func=func_cargs_building, struct=param, lookup=lookup) + func_cargs_building.body.extend([ "runner = runner or get_global_runner()", f"execution = runner.start_execution({metadata_symbol})", @@ -199,6 +120,7 @@ def _compile_struct( _compile_cargs_building(param, lookup, func_cargs_building, access_via_self=not root_function) if root_function: + pyargs.append(PyArg(name="runner", type="Runner | None", default="None", docstring="Command runner")) _compile_outputs_building( param=param, func=func_cargs_building, @@ -227,7 +149,7 @@ def _compile_struct( def _compile_cargs_building( param: ir.IParam | ir.IStruct, - lookup: _LookupParam, + lookup: LookupParam, func: PyFunc, access_via_self: bool, ) -> None: @@ -282,7 +204,7 @@ def _compile_cargs_building( def _compile_outputs_class( param: ir.IStruct | ir.IParam, interface_module: PyModule, - lookup: _LookupParam, + lookup: LookupParam, ) -> None: outputs_class = PyDataClass( name=lookup.py_output_type[param.param.id_], @@ -328,7 +250,7 @@ def _compile_outputs_class( def _compile_outputs_building( param: ir.IStruct | ir.IParam, func: PyFunc, - lookup: _LookupParam, + lookup: LookupParam, access_via_self: bool = False, ) -> None: """Generate the outputs building code.""" @@ -424,7 +346,7 @@ def compile_interface( function_scope.add_or_die("ret") # Lookup tables - lookup = _LookupParam( + lookup = LookupParam( interface=interface, package_scope=package_scope, function_symbol=function_symbol, diff --git a/src/styx/backend/python/lookup.py b/src/styx/backend/python/lookup.py new file mode 100644 index 0000000..266644c --- /dev/null +++ b/src/styx/backend/python/lookup.py @@ -0,0 +1,83 @@ +import styx.ir.core as ir +from styx.backend.python.pycodegen.scope import Scope +from styx.backend.python.pycodegen.utils import python_snakify, python_pascalize +from styx.backend.python.utils import iter_params_recursively, param_py_type + + +class LookupParam: + """Pre-compute and store Python symbols, types, class-names, etc. to reduce spaghetti code everywhere else.""" + + def __init__( + self, + interface: ir.Interface, + package_scope: Scope, + function_symbol: str, + function_scope: Scope, + ) -> None: + def _collect_output_field_symbols( + param: ir.IStruct | ir.IParam, lookup_output_field_symbol: dict[ir.IdType, str] + ) -> None: + scope = Scope(parent=package_scope) + for output in param.param.outputs: + output_field_symbol = scope.add_or_dodge(output.name) + assert output.id_ not in lookup_output_field_symbol + lookup_output_field_symbol[output.id_] = output_field_symbol + + def _collect_py_symbol(param: ir.IStruct | ir.IParam, lookup_py_symbol: dict[ir.IdType, str]) -> None: + scope = Scope(parent=function_scope) + for elem in param.struct.iter_params(): + symbol = scope.add_or_dodge(python_snakify(elem.param.name)) + assert elem.param.id_ not in lookup_py_symbol + lookup_py_symbol[elem.param.id_] = symbol + + self.param: dict[ir.IdType, ir.IParam] = {interface.command.param.id_: interface.command} + """Find param object by its ID. IParam.id_ -> IParam""" + self.py_type: dict[ir.IdType, str] = {interface.command.param.id_: function_symbol} + """Find Python type by param id. IParam.id_ -> Python type""" + self.py_symbol: dict[ir.IdType, str] = {} + """Find function-parameter symbol by param ID. IParam.id_ -> Python symbol""" + self.py_output_type: dict[ir.IdType, str] = { + interface.command.param.id_: package_scope.add_or_dodge( + python_pascalize(f"{interface.command.struct.name}_Outputs") + ) + } + """Find outputs class name by struct param ID. IStruct.id_ -> Python class name""" + self.py_output_field_symbol: dict[ir.IdType, str] = {} + """Find output field symbol by output ID. Output.id_ -> Python symbol""" + + _collect_py_symbol( + param=interface.command, + lookup_py_symbol=self.py_symbol, + ) + _collect_output_field_symbols( + param=interface.command, + lookup_output_field_symbol=self.py_output_field_symbol, + ) + + for elem in iter_params_recursively(interface.command): + self.param[elem.param.id_] = elem + + if isinstance(elem, ir.IStruct): + if elem.param.id_ not in self.py_type: # Struct unions may resolve these first + self.py_type[elem.param.id_] = package_scope.add_or_dodge( + python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}") + ) + self.py_output_type[elem.param.id_] = package_scope.add_or_dodge( + python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}_Outputs") + ) + _collect_py_symbol( + param=elem, + lookup_py_symbol=self.py_symbol, + ) + _collect_output_field_symbols( + param=elem, + lookup_output_field_symbol=self.py_output_field_symbol, + ) + elif isinstance(elem, ir.IStructUnion): + for alternative in elem.alts: + self.py_type[alternative.param.id_] = package_scope.add_or_dodge( + python_pascalize(f"{interface.command.struct.name}_{alternative.struct.name}") + ) + self.py_type[elem.param.id_] = param_py_type(elem, self.py_type) + else: + self.py_type[elem.param.id_] = param_py_type(elem, self.py_type) diff --git a/src/styx/frontend/__init__.py b/src/styx/frontend/__init__.py index e69de29..777cfeb 100644 --- a/src/styx/frontend/__init__.py +++ b/src/styx/frontend/__init__.py @@ -0,0 +1 @@ +"""Styx frontends.""" diff --git a/src/styx/frontend/boutiques/__init__.py b/src/styx/frontend/boutiques/__init__.py index e69de29..426278a 100644 --- a/src/styx/frontend/boutiques/__init__.py +++ b/src/styx/frontend/boutiques/__init__.py @@ -0,0 +1,3 @@ +"""Boutiques frontend""" + +from .core import from_boutiques diff --git a/src/styx/frontend/boutiques/core.py b/src/styx/frontend/boutiques/core.py index f35fbdf..4ec01aa 100644 --- a/src/styx/frontend/boutiques/core.py +++ b/src/styx/frontend/boutiques/core.py @@ -8,6 +8,7 @@ import styx.ir.core as ir from styx.frontend.boutiques.utils import boutiques_split_command +from styx.ir.dyn import dyn_param T = TypeVar("T") @@ -163,12 +164,14 @@ def _arg_elem_from_bt_elem( docs=input_docs, ) + constraints = _collect_constraints(d, input_type) + dlist = None if input_type.is_list: dlist = ir.DList( join=repeatable_join, - count_min=d.get("min-list-entries"), - count_max=d.get("max-list-entries"), + count_min=constraints.list_length_min, + count_max=constraints.list_length_max, ) match input_type.primitive: @@ -178,28 +181,15 @@ def _arg_elem_from_bt_elem( isinstance(o, str) for o in choices ]), "value-choices must be all string for string input" - if input_type.is_list: - if input_type.is_optional: - return ir.PStrListOpt( - param=dparam, - list_=dlist, - default_value=d.get("default-value"), - ) - return ir.PStrList( - param=dparam, - list_=dlist, - default_value=d.get("default-value"), - ) - if input_type.is_optional: - return ir.PStrOpt( - param=dparam, - default_value=d.get("default-value"), - choices=choices, - ) - return ir.PStr( + return dyn_param( + dyn_type="str", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, param=dparam, - default_value=d.get("default-value"), - choices=choices, + list_=dlist, + default_value=d.get("default-value", ir.IOptional.SetToNone) + if input_type.is_optional + else d.get("default-value"), ) case InputTypePrimitive.Integer: @@ -208,70 +198,41 @@ def _arg_elem_from_bt_elem( isinstance(o, int) for o in choices ]), "value-choices must be all int for integer input" - if input_type.is_list: - if input_type.is_optional: - return ir.PIntListOpt( - param=dparam, - list_=dlist, - default_value=d.get("default-value"), - ) - return ir.PIntList( - param=dparam, - list_=dlist, - default_value=d.get("default-value"), - ) - if input_type.is_optional: - return ir.PIntOpt( - param=dparam, - default_value=d.get("default-value"), - choices=choices, - ) - return ir.PInt( + return dyn_param( + dyn_type="int", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, param=dparam, - default_value=d.get("default-value"), - choices=choices, + list_=dlist, + default_value=d.get("default-value", ir.IOptional.SetToNone) + if input_type.is_optional + else d.get("default-value"), + min_value=constraints.value_min, + max_value=constraints.value_max, ) case InputTypePrimitive.Float: - if input_type.is_list: - if input_type.is_optional: - return ir.PFloatListOpt( - param=dparam, - list_=dlist, - default_value=d.get("default-value"), - ) - return ir.PFloatList( - param=dparam, - list_=dlist, - default_value=d.get("default-value"), - ) - if input_type.is_optional: - return ir.PFloatOpt( - param=dparam, - default_value=d.get("default-value"), - ) - return ir.PFloat( + return dyn_param( + dyn_type="float", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, param=dparam, - default_value=d.get("default-value"), + list_=dlist, + default_value=d.get("default-value", ir.IOptional.SetToNone) + if input_type.is_optional + else d.get("default-value"), + min_value=constraints.value_min, + max_value=constraints.value_max, ) case InputTypePrimitive.File: - if input_type.is_list: - if input_type.is_optional: - return ir.PFileListOpt( - param=dparam, - list_=dlist, - ) - return ir.PFileList( - param=dparam, - list_=dlist, - ) - if input_type.is_optional: - return ir.PFileOpt( - param=dparam, - ) - return ir.PFile( + return dyn_param( + dyn_type="file", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, param=dparam, + list_=dlist, + default_value_set_to_none=True, ) case InputTypePrimitive.Flag: @@ -289,28 +250,14 @@ def _arg_elem_from_bt_elem( dparam, dstruct = _struct_from_boutiques(d, id_counter) ir_id_lookup[input_bt_ref] = dparam.id_ # override - if input_type.is_list: - if input_type.is_optional: - return ir.PStructListOpt( - param=dparam, - struct=dstruct, - list_=dlist, - default_value_set_to_none=True, - ) - return ir.PStructList( - param=dparam, - struct=dstruct, - list_=dlist, - ) - if input_type.is_optional: - return ir.PStructOpt( - param=dparam, - struct=dstruct, - default_value_set_to_none=True, - ) - return ir.PStruct( + return dyn_param( + dyn_type="struct", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, param=dparam, struct=dstruct, + list_=dlist, + default_value_set_to_none=True, ) case InputTypePrimitive.SubCommandUnion: @@ -327,32 +274,47 @@ def _arg_elem_from_bt_elem( ) ) - if input_type.is_list: - if input_type.is_optional: - return ir.PStructUnionListOpt( - param=dparam, - alts=alts, - list_=dlist, - default_value_set_to_none=True, - ) - return ir.PStructUnionList( - param=dparam, - alts=alts, - list_=dlist, - ) - if input_type.is_optional: - return ir.PStructUnionOpt( - param=dparam, - alts=alts, - default_value_set_to_none=True, - ) - return ir.PStructUnion( + return dyn_param( + dyn_type="struct_union", + dyn_list=input_type.is_list, + dyn_optional=input_type.is_optional, param=dparam, alts=alts, + list_=dlist, + default_value_set_to_none=True, ) assert False +@dataclass +class _NumericConstraints: + value_min: int | float | None = None + value_max: int | float | None = None + list_length_min: int | None = None + list_length_max: int | None = None + + +def _collect_constraints(d, input_type): + ret = _NumericConstraints() + value_min_exclusive = False + value_max_exclusive = False + if input_type.primitive in (InputTypePrimitive.Float, InputTypePrimitive.Integer): + if (val := d.get("minimum")) is not None: + ret.value_min = int(val) if d.get("integer") else val + value_min_exclusive = d.get("exclusive-minimum") is True + if (val := d.get("maximum")) is not None: + ret.value_max = int(val) if d.get("integer") else val + value_max_exclusive = d.get("exclusive-maximum") is True + if d.get("list") is True: + ret.list_length_min = d.get("min-list-entries") + ret.list_length_max = d.get("max-list-entries") + if ret.value_min is not None and value_min_exclusive and input_type.primitive == InputTypePrimitive.Integer: + ret.value_min += 1 + if ret.value_max is not None and value_max_exclusive and input_type.primitive == InputTypePrimitive.Integer: + ret.value_max -= 1 + return ret + + def _struct_from_boutiques( bt: dict, id_counter: IdCounter, diff --git a/src/styx/ir/core.py b/src/styx/ir/core.py index 102fcf2..38b9eec 100644 --- a/src/styx/ir/core.py +++ b/src/styx/ir/core.py @@ -1,4 +1,5 @@ import dataclasses +import typing from abc import ABC from dataclasses import dataclass @@ -79,66 +80,53 @@ class IList(ABC): @dataclass class IInt(ABC): choices: list[int] | None = None + min_value: int | None = None + max_value: int | None = None @dataclass class PInt(IInt, IParam): default_value: int | None = None - min_value: int | None = None - max_value: int | None = None @dataclass class PIntOpt(IInt, IParam, IOptional): default_value: int | IOptional.SetToNoneAble | None = IOptional.SetToNone - min_value: int | None = None - max_value: int | None = None @dataclass class PIntList(IInt, IList, IParam): default_value: list[int] | None = None - all_min_value: int | None = None - all_max_value: int | None = None @dataclass class PIntListOpt(IInt, IList, IParam, IOptional): default_value: list[int] | IOptional.SetToNoneAble | None = IOptional.SetToNone - all_min_value: int | None = None - all_max_value: int | None = None class IFloat(ABC): - pass + min_value: int | None = None + max_value: int | None = None @dataclass class PFloat(IFloat, IParam): default_value: float | None = None - min_value: int | None = None - max_value: int | None = None @dataclass class PFloatOpt(IFloat, IParam, IOptional): default_value: float | IOptional.SetToNoneAble | None = IOptional.SetToNone - min_value: int | None = None - max_value: int | None = None @dataclass class PFloatList(IFloat, IList, IParam): default_value: list[float] | None = None - all_min_value: float | None = None - all_max_value: float | None = None @dataclass class PFloatListOpt(IFloat, IList, IParam, IOptional): default_value: list[float] | IOptional.SetToNoneAble | None = IOptional.SetToNone - all_min_value: float | None = None - all_max_value: float | None = None @dataclass diff --git a/src/styx/ir/dyn.py b/src/styx/ir/dyn.py new file mode 100644 index 0000000..c6f5b7a --- /dev/null +++ b/src/styx/ir/dyn.py @@ -0,0 +1,50 @@ +"""Convenience function that allows dynamic param class creation.""" + +import dataclasses +import typing + +import styx.ir.core as ir + + +def dyn_param( + dyn_type: typing.Literal["int", "float", "str", "file", "bool", "struct", "struct_union"], + dyn_list: bool, + dyn_optional: bool, + **kwargs, +) -> ir.IParam: + """Convenience function that allows dynamic param class creation.""" + + cls = { + ("int", True, True): ir.PIntListOpt, + ("int", True, False): ir.PIntList, + ("int", False, True): ir.PIntOpt, + ("int", False, False): ir.PInt, + ("float", True, True): ir.PFloatListOpt, + ("float", True, False): ir.PFloatList, + ("float", False, True): ir.PFloatOpt, + ("float", False, False): ir.PFloat, + ("str", True, True): ir.PStrListOpt, + ("str", True, False): ir.PStrList, + ("str", False, True): ir.PStrOpt, + ("str", False, False): ir.PStr, + ("file", True, True): ir.PFileListOpt, + ("file", True, False): ir.PFileList, + ("file", False, True): ir.PFileOpt, + ("file", False, False): ir.PFile, + ("bool", True, True): ir.PBoolListOpt, + ("bool", True, False): ir.PBoolList, + ("bool", False, True): ir.PBoolOpt, + ("bool", False, False): ir.PBool, + ("struct", True, True): ir.PStructListOpt, + ("struct", True, False): ir.PStructList, + ("struct", False, True): ir.PStructOpt, + ("struct", False, False): ir.PStruct, + ("struct_union", True, True): ir.PStructUnionListOpt, + ("struct_union", True, False): ir.PStructUnionList, + ("struct_union", False, True): ir.PStructUnionOpt, + ("struct_union", False, False): ir.PStructUnion, + }[(dyn_type, dyn_list, dyn_optional)] + kwargs_relevant = { + field.name: kwargs[field.name] for field in dataclasses.fields(cls) if field.name if field.name in kwargs + } + return cls(**kwargs_relevant) diff --git a/tests/test_carg_building.py b/tests/test_carg_building.py index 1107324..109bde7 100644 --- a/tests/test_carg_building.py +++ b/tests/test_carg_building.py @@ -266,7 +266,7 @@ def test_arg_order() -> None: test_module = dynamic_module(compiled_module, "test_module") dummy_runner = tests.utils.dummy_runner.DummyRunner() - test_module.dummy("aaa", "bbb", runner=dummy_runner) + test_module.dummy(a="aaa", b="bbb", runner=dummy_runner) assert dummy_runner.last_cargs is not None assert dummy_runner.last_cargs == ["bbb", "aaa"] diff --git a/tests/test_groups.py b/tests/test_groups.py deleted file mode 100644 index 4777608..0000000 --- a/tests/test_groups.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Argument group constraint tests.""" - -import pytest - -import tests.utils.dummy_runner -from tests.utils.compile_boutiques import boutiques2python -from tests.utils.dynmodule import ( - BT_TYPE_NUMBER, - boutiques_dummy, - dynamic_module, -) - -_XYZ_INPUTS = [ - { - "id": "x", - "name": "The x", - "value-key": "[X]", - "type": BT_TYPE_NUMBER, - "integer": True, - "optional": True, - }, - { - "id": "y", - "name": "The y", - "value-key": "[Y]", - "type": BT_TYPE_NUMBER, - "integer": True, - "optional": True, - }, - { - "id": "z", - "name": "The z", - "value-key": "[Z]", - "type": BT_TYPE_NUMBER, - "integer": True, - "optional": True, - }, -] - - -def test_mutually_exclusive() -> None: - """Mutually exclusive argument group.""" - model = boutiques_dummy({ - "command-line": "dummy [X] [Y] [Z]", - "inputs": _XYZ_INPUTS, - "groups": [ - { - "id": "group", - "name": "Group", - "members": ["x", "y", "z"], - "mutually-exclusive": True, - } - ], - }) - - compiled_module = boutiques2python(model) - - test_module = dynamic_module(compiled_module, "test_module") - dummy_runner = tests.utils.dummy_runner.DummyRunner() - - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner, x=1, y=2) - - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner, x=1, y=2, z=3) - - assert test_module.dummy(runner=dummy_runner, x=1) is not None - assert test_module.dummy(runner=dummy_runner, y=2) is not None - assert test_module.dummy(runner=dummy_runner, z=2) is not None - assert test_module.dummy(runner=dummy_runner) is not None - - -def test_all_or_none() -> None: - """All or none argument group.""" - model = boutiques_dummy({ - "command-line": "dummy [X] [Y] [Z]", - "inputs": _XYZ_INPUTS, - "groups": [ - { - "id": "group", - "name": "Group", - "members": ["x", "y", "z"], - "all-or-none": True, - } - ], - }) - - compiled_module = boutiques2python(model) - - test_module = dynamic_module(compiled_module, "test_module") - dummy_runner = tests.utils.dummy_runner.DummyRunner() - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner, x=1, y=2) - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner, z=3) - - assert test_module.dummy(runner=dummy_runner, x=1, y=2, z=3) is not None - assert test_module.dummy(runner=dummy_runner) is not None - - -def test_one_required() -> None: - """One required argument group.""" - model = boutiques_dummy({ - "command-line": "dummy [X] [Y] [Z]", - "inputs": _XYZ_INPUTS, - "groups": [ - { - "id": "group", - "name": "Group", - "members": ["x", "y", "z"], - "one-is-required": True, - } - ], - }) - - compiled_module = boutiques2python(model) - print(compiled_module) - - test_module = dynamic_module(compiled_module, "test_module") - dummy_runner = tests.utils.dummy_runner.DummyRunner() - with pytest.raises(ValueError): - test_module.dummy(runner=dummy_runner) - - assert test_module.dummy(runner=dummy_runner, x=1) is not None - assert test_module.dummy(runner=dummy_runner, y=2) is not None - assert test_module.dummy(runner=dummy_runner, z=3) is not None - assert test_module.dummy(runner=dummy_runner, x=1, y=2) is not None - assert test_module.dummy(runner=dummy_runner, x=1, z=3) is not None - assert test_module.dummy(runner=dummy_runner, y=2, z=3) is not None - assert test_module.dummy(runner=dummy_runner, x=1, y=2, z=3) is not None diff --git a/tests/test_output_files.py b/tests/test_output_files.py index f10a883..7123dab 100644 --- a/tests/test_output_files.py +++ b/tests/test_output_files.py @@ -59,7 +59,7 @@ def test_output_file_with_template() -> None: { "id": "out", "name": "The out", - "path-template": "out-{x}.txt", + "path-template": "out-[X].txt", } ], }) diff --git a/tests/utils/compile_boutiques.py b/tests/utils/compile_boutiques.py index ca5ec98..b30573f 100644 --- a/tests/utils/compile_boutiques.py +++ b/tests/utils/compile_boutiques.py @@ -1,6 +1,8 @@ -from styx.backend.python.core import to_python -from styx.frontend.boutiques.core import from_boutiques +from styx.backend.python import to_python +from styx.frontend.boutiques import from_boutiques def boutiques2python(boutiques: dict, package: str = "no_package") -> str: - return to_python([from_boutiques(boutiques, package)]).__next__()[0] + ir = from_boutiques(boutiques, package) + py = to_python([ir]).__next__()[0] + return py