From 3fe8e65a2a96f8c0d833cbb66d8782b05707a164 Mon Sep 17 00:00:00 2001 From: Florian Rupprecht Date: Mon, 20 Nov 2023 15:55:52 -0500 Subject: [PATCH] Support group constraints --- src/styx/compiler/core.py | 181 +++++++++--------- src/styx/compiler/defs.py | 86 +++++++++ src/styx/main.py | 4 +- src/styx/pycodegen/core.py | 5 + src/styx/pycodegen/utils.py | 7 +- tests/test_groups.py | 163 ++++++++++++++++ ...t_validation.py => test_numeric_ranges.py} | 15 +- 7 files changed, 355 insertions(+), 106 deletions(-) create mode 100644 src/styx/compiler/defs.py create mode 100644 tests/test_groups.py rename tests/{test_input_validation.py => test_numeric_ranges.py} (96%) diff --git a/src/styx/compiler/core.py b/src/styx/compiler/core.py index 3f23373..1990fe4 100644 --- a/src/styx/compiler/core.py +++ b/src/styx/compiler/core.py @@ -3,10 +3,11 @@ from styx.boutiques import model as bt from styx.boutiques.utils import boutiques_split_command +from styx.compiler.defs import STYX_DEFINITIONS from styx.compiler.settings import CompilerSettings, DefsMode from styx.compiler.utils import optional_float_to_int from styx.pycodegen.core import INDENT as PY_INDENT -from styx.pycodegen.core import LineBuffer, PyArg, PyFunc, PyModule, collapse, indent +from styx.pycodegen.core import LineBuffer, PyArg, PyFunc, PyModule, collapse, expand, indent from styx.pycodegen.utils import ( as_py_literal, enquote, @@ -15,89 +16,6 @@ ensure_snake_case, ) -RUNTIME_DECLARATIONS = [ - 'P = typing.TypeVar("P")', - '"""Input host file type."""', - 'R = typing.TypeVar("R")', - '"""Output host file type."""', - "", - "", - "class Execution(typing.Protocol[P, R]):", - *indent( - [ - '"""', - "Execution object used to execute commands.", - "Created by `Runner.start_execution()`.", - '"""', - "def input_file(self, host_file: P) -> str:", - *indent( - [ - '"""', - "Resolve host input files.", - "Returns a local filepath.", - "Called (potentially multiple times) after " - "`Runner.start_execution()` and before `Runner.run()`.", - '"""', - "...", - ] - ), - "def run(self, cargs: list[str]) -> None:", - *indent( - [ - '"""', - "Run the command.", - "Called after all `Execution.input_file()` calls and " "before `Execution.output_file()` calls.", - '"""', - "...", - ] - ), - "def output_file(self, local_file: str) -> R:", - *indent( - [ - '"""', - "Resolve local output files.", - "Returns a host filepath.", - "Called (potentially multiple times) after " "`Runner.run()` and before `Execution.finalize()`.", - '"""', - "...", - ] - ), - "def finalize(self) -> None:", - *indent( - [ - '"""', - "Finalize the execution.", - "Called after all `Execution.output_file()` calls.", - '"""', - "...", - ] - ), - ] - ), - "", - "", - "class Runner(typing.Protocol[P, R]):", - *indent( - [ - '"""', - "Runner object used to execute commands.", - "Possible examples would be `LocalRunner`, " "`DockerRunner`, `DebugRunner`, ...", - "Used as a factory for `Execution` objects.", - '"""', - "def start_execution(self, tool_name: str) -> Execution[P, R]:", - *indent( - [ - '"""', - "Start an execution.", - "Called before any `Execution.input_file()` calls.", - '"""', - "...", - ] - ), - ] - ), -] - class BtPrimitive(Enum): String = 1 @@ -257,7 +175,7 @@ def _generate_raise_value_err(obj: str, expectation: str, reality: str | None = ) -def _generate_validation_expr( +def _generate_range_validation_expr( buf: LineBuffer, bt_input: BtInput, ) -> None: @@ -267,6 +185,7 @@ def _generate_validation_expr( # List argument length validation if bt_input.list_minimum is not None and bt_input.list_maximum is not None: + # Case: len(list[]) == X assert bt_input.list_minimum <= bt_input.list_maximum if bt_input.list_minimum == bt_input.list_maximum: buf.extend( @@ -282,6 +201,7 @@ def _generate_validation_expr( ] ) else: + # Case: X <= len(list[]) <= Y buf.extend( [ f"if {val_opt}not ({bt_input.list_minimum} <= len({bt_input.name}) <= {bt_input.list_maximum}): ", @@ -295,6 +215,7 @@ def _generate_validation_expr( ] ) elif bt_input.list_minimum is not None: + # Case len(list[]) >= X buf.extend( [ f"if {val_opt}not ({bt_input.list_minimum} <= len({bt_input.name})): ", @@ -308,6 +229,7 @@ def _generate_validation_expr( ] ) elif bt_input.list_maximum is not None: + # Case len(list[]) <= X buf.extend( [ f"if {val_opt}not (len({bt_input.name}) <= {bt_input.list_maximum}): ", @@ -325,6 +247,7 @@ def _generate_validation_expr( op_min = "<" if bt_input.minimum_exclusive else "<=" op_max = "<" if bt_input.maximum_exclusive else "<=" if bt_input.minimum is not None and bt_input.maximum is not None: + # Case: X <= arg <= Y assert bt_input.minimum <= bt_input.maximum if bt_input.type.is_list: buf.extend( @@ -353,6 +276,7 @@ def _generate_validation_expr( ] ) elif bt_input.minimum is not None: + # Case: X <= arg if bt_input.type.is_list: buf.extend( [ @@ -379,6 +303,7 @@ def _generate_validation_expr( ] ) elif bt_input.maximum is not None: + # Case: arg <= X if bt_input.type.is_list: buf.extend( [ @@ -406,7 +331,79 @@ def _generate_validation_expr( ) -def py_from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type: ignore +def _generate_group_constraint_expr( + buf: LineBuffer, + group: bt.Group, # type: ignore +) -> None: + if group.mutually_exclusive: + txt_members = [enquote(x) for x in expand(",\\n\n".join(group.members))] + check_members = expand(" +\n".join([f"({x} is not None)" for x in group.members])) + buf.extend(["if ("]) + buf.extend(indent(check_members)) + buf.extend( + [ + ") > 1:", + *indent( + [ + "raise ValueError(", + *indent( + [ + '"Only one of the following arguments can be specified:\\n"', + *txt_members, + ] + ), + ")", + ] + ), + ] + ) + if group.all_or_none: + txt_members = [enquote(x) for x in expand(",\\n\n".join(group.members))] + check_members = expand(" ==\n".join([f"({x} is None)" for x in group.members])) + buf.extend(["if not ("]) + buf.extend(indent(check_members)) + buf.extend( + [ + "):", + *indent( + [ + "raise ValueError(", + *indent( + [ + '"All or none of the following arguments must be specified:\\n"', + *txt_members, + ] + ), + ")", + ] + ), + ] + ) + if group.one_is_required: + txt_members = [enquote("- " + x) for x in expand("\\n\n".join(group.members))] + check_members = expand(" or\n".join([f"({x} is not None)" for x in group.members])) + buf.extend(["if not ("]) + buf.extend(indent(check_members)) + buf.extend( + [ + "):", + *indent( + [ + "raise ValueError(", + *indent( + [ + '"One of the following arguments must be specified:\\n"', + *txt_members, + ] + ), + ")", + ] + ), + ] + ) + + +def _from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type: ignore mod = PyModule() # Python names @@ -436,7 +433,11 @@ def py_from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type # Input validation for i in args: - _generate_validation_expr(buf_body, i) + _generate_range_validation_expr(buf_body, i) + + if tool.groups is not None: + for group in tool.groups: + _generate_group_constraint_expr(buf_body, group) # Command line args building for segment in cmd: @@ -450,11 +451,11 @@ def py_from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type # Definitions if settings.defs_mode == DefsMode.INLINE: - defs = RUNTIME_DECLARATIONS + defs = STYX_DEFINITIONS elif settings.defs_mode == DefsMode.IMPORT: defs = ["from styx.runners.styxdefs import *"] else: - return collapse(RUNTIME_DECLARATIONS) + return collapse(STYX_DEFINITIONS) buf_header = [] buf_header.extend( @@ -517,4 +518,4 @@ def py_from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type def compile_descriptor(descriptor: bt.Tool, settings: CompilerSettings) -> str: # type: ignore """Compile a Boutiques descriptor to Python code.""" - return py_from_boutiques(descriptor, settings) + return _from_boutiques(descriptor, settings) diff --git a/src/styx/compiler/defs.py b/src/styx/compiler/defs.py new file mode 100644 index 0000000..bf7f269 --- /dev/null +++ b/src/styx/compiler/defs.py @@ -0,0 +1,86 @@ +"""Static type declarations used by compiled code.""" + +from styx.pycodegen.core import indent + +STYX_DEFINITIONS = [ + 'P = typing.TypeVar("P")', + '"""Input host file type."""', + 'R = typing.TypeVar("R")', + '"""Output host file type."""', + "", + "", + "class Execution(typing.Protocol[P, R]):", + *indent( + [ + '"""', + "Execution object used to execute commands.", + "Created by `Runner.start_execution()`.", + '"""', + "def input_file(self, host_file: P) -> str:", + *indent( + [ + '"""', + "Resolve host input files.", + "Returns a local filepath.", + "Called (potentially multiple times) after " + "`Runner.start_execution()` and before `Runner.run()`.", + '"""', + "...", + ] + ), + "def run(self, cargs: list[str]) -> None:", + *indent( + [ + '"""', + "Run the command.", + "Called after all `Execution.input_file()` calls and " "before `Execution.output_file()` calls.", + '"""', + "...", + ] + ), + "def output_file(self, local_file: str) -> R:", + *indent( + [ + '"""', + "Resolve local output files.", + "Returns a host filepath.", + "Called (potentially multiple times) after " "`Runner.run()` and before `Execution.finalize()`.", + '"""', + "...", + ] + ), + "def finalize(self) -> None:", + *indent( + [ + '"""', + "Finalize the execution.", + "Called after all `Execution.output_file()` calls.", + '"""', + "...", + ] + ), + ] + ), + "", + "", + "class Runner(typing.Protocol[P, R]):", + *indent( + [ + '"""', + "Runner object used to execute commands.", + "Possible examples would be `LocalRunner`, " "`DockerRunner`, `DebugRunner`, ...", + "Used as a factory for `Execution` objects.", + '"""', + "def start_execution(self, tool_name: str) -> Execution[P, R]:", + *indent( + [ + '"""', + "Start an execution.", + "Called before any `Execution.input_file()` calls.", + '"""', + "...", + ] + ), + ] + ), +] diff --git a/src/styx/main.py b/src/styx/main.py index dbbe059..411b414 100644 --- a/src/styx/main.py +++ b/src/styx/main.py @@ -2,11 +2,11 @@ from styx.boutiques.utils import boutiques_from_dict from styx.compiler.core import compile_descriptor -from styx.compiler.settings import CompilerSettings +from styx.compiler.settings import CompilerSettings, DefsMode def main() -> None: - settings = CompilerSettings() + settings = CompilerSettings(defs_mode=DefsMode.IMPORT) with open("examples/bet.json", "r") as json_file: json_data = json.load(json_file) descriptor = boutiques_from_dict(json_data) diff --git a/src/styx/pycodegen/core.py b/src/styx/pycodegen/core.py index 03c19df..e1c6939 100644 --- a/src/styx/pycodegen/core.py +++ b/src/styx/pycodegen/core.py @@ -25,6 +25,11 @@ def collapse(lines: LineBuffer) -> str: return "\n".join(lines) +def expand(text: str) -> LineBuffer: + """Expand a string into a LineBuffer.""" + return text.splitlines() + + def blank_before(lines: LineBuffer, blanks: int = 1) -> LineBuffer: """Add blank lines at the beginning of a LineBuffer if it is not empty.""" return [*([""] * blanks), *lines] if len(lines) > 0 else lines diff --git a/src/styx/pycodegen/utils.py b/src/styx/pycodegen/utils.py index e3d0914..dddd2b6 100644 --- a/src/styx/pycodegen/utils.py +++ b/src/styx/pycodegen/utils.py @@ -50,9 +50,12 @@ def ensure_camel_case(string: str) -> str: return _RX_ENSURE_CAMEL.sub("_", string).title().replace("_", "") -def enquote(s: str) -> str: # noqa +def enquote( + s: str, + quote: str = '"', +) -> str: # noqa """Put a string in "quotes".""" - return f'"{s}"' + return f"{quote}{s}{quote}" def as_py_literal(obj: str | float | int | bool) -> str: diff --git a/tests/test_groups.py b/tests/test_groups.py new file mode 100644 index 0000000..14ed43a --- /dev/null +++ b/tests/test_groups.py @@ -0,0 +1,163 @@ +"""Argument group constraint tests.""" + +import styx.boutiques.utils +import styx.compiler.core +import styx.compiler.settings +import styx.runners.core +from tests.utils.dynmodule import ( + BT_TYPE_NUMBER, + boutiques_dummy, + dynamic_module, +) + +_XYZ_INPUTS = [ + { + "id": "x", + "name": "The x", + "value-key": "[X]", + "type": BT_TYPE_NUMBER, + "integer": True, + "optional": True, + }, + { + "id": "y", + "name": "The y", + "value-key": "[Y]", + "type": BT_TYPE_NUMBER, + "integer": True, + "optional": True, + }, + { + "id": "z", + "name": "The z", + "value-key": "[Z]", + "type": BT_TYPE_NUMBER, + "integer": True, + "optional": True, + }, +] + + +def test_mutually_exclusive() -> None: + """Mutually exclusive argument group.""" + settings = styx.compiler.settings.CompilerSettings(defs_mode=styx.compiler.settings.DefsMode.IMPORT) + model = styx.boutiques.utils.boutiques_from_dict( + boutiques_dummy( + { + "command-line": "dummy [X] [Y] [Z]", + "inputs": _XYZ_INPUTS, + "groups": [ + { + "id": "group", + "name": "Group", + "members": ["x", "y", "z"], + "mutually-exclusive": True, + } + ], + } + ) + ) + + compiled_module = styx.compiler.core.compile_descriptor(model, settings) + + test_module = dynamic_module(compiled_module, "test_module") + dummy_runner = styx.runners.core.DummyRunner() + try: + test_module.dummy(runner=dummy_runner, x=1, y=2) + except ValueError as e: + assert "Only one" in str(e) + else: + assert False, "Expected ValueError" + + try: + test_module.dummy(runner=dummy_runner, x=1, y=2, z=3) + except ValueError as e: + assert "Only one" in str(e) + else: + assert False, "Expected ValueError" + + assert test_module.dummy(runner=dummy_runner, x=1) is not None + assert test_module.dummy(runner=dummy_runner, y=2) is not None + assert test_module.dummy(runner=dummy_runner, z=2) is not None + assert test_module.dummy(runner=dummy_runner) is not None + + +def test_all_or_none() -> None: + """All or none argument group.""" + settings = styx.compiler.settings.CompilerSettings(defs_mode=styx.compiler.settings.DefsMode.IMPORT) + model = styx.boutiques.utils.boutiques_from_dict( + boutiques_dummy( + { + "command-line": "dummy [X] [Y] [Z]", + "inputs": _XYZ_INPUTS, + "groups": [ + { + "id": "group", + "name": "Group", + "members": ["x", "y", "z"], + "all-or-none": True, + } + ], + } + ) + ) + + compiled_module = styx.compiler.core.compile_descriptor(model, settings) + + test_module = dynamic_module(compiled_module, "test_module") + dummy_runner = styx.runners.core.DummyRunner() + try: + test_module.dummy(runner=dummy_runner, x=1, y=2) + except ValueError as e: + assert "All or none" in str(e) + else: + assert False, "Expected ValueError" + try: + test_module.dummy(runner=dummy_runner, z=3) + except ValueError as e: + assert "All or none" in str(e) + else: + assert False, "Expected ValueError" + + assert test_module.dummy(runner=dummy_runner, x=1, y=2, z=3) is not None + assert test_module.dummy(runner=dummy_runner) is not None + + +def test_one_required() -> None: + """One required argument group.""" + settings = styx.compiler.settings.CompilerSettings(defs_mode=styx.compiler.settings.DefsMode.IMPORT) + model = styx.boutiques.utils.boutiques_from_dict( + boutiques_dummy( + { + "command-line": "dummy [X] [Y] [Z]", + "inputs": _XYZ_INPUTS, + "groups": [ + { + "id": "group", + "name": "Group", + "members": ["x", "y", "z"], + "one-is-required": True, + } + ], + } + ) + ) + + compiled_module = styx.compiler.core.compile_descriptor(model, settings) + + test_module = dynamic_module(compiled_module, "test_module") + dummy_runner = styx.runners.core.DummyRunner() + try: + test_module.dummy(runner=dummy_runner) + except ValueError as e: + assert "One of" in str(e) + else: + assert False, "Expected ValueError" + + assert test_module.dummy(runner=dummy_runner, x=1) is not None + assert test_module.dummy(runner=dummy_runner, y=2) is not None + assert test_module.dummy(runner=dummy_runner, z=3) is not None + assert test_module.dummy(runner=dummy_runner, x=1, y=2) is not None + assert test_module.dummy(runner=dummy_runner, x=1, z=3) is not None + assert test_module.dummy(runner=dummy_runner, y=2, z=3) is not None + assert test_module.dummy(runner=dummy_runner, x=1, y=2, z=3) is not None diff --git a/tests/test_input_validation.py b/tests/test_numeric_ranges.py similarity index 96% rename from tests/test_input_validation.py rename to tests/test_numeric_ranges.py index 32957b2..bd6afc7 100644 --- a/tests/test_input_validation.py +++ b/tests/test_numeric_ranges.py @@ -1,13 +1,4 @@ -"""Input validation tests. - -Non-goals: -- Argument types. -> typing - -Goals: -- Numeric ranges of values. -- Mutually exclusive arguments. - -""" +"""Numeric ranges tests.""" import styx.boutiques.utils import styx.compiler.core @@ -183,13 +174,13 @@ def test_outside_range() -> None: try: test_module.dummy(runner=dummy_runner, x=11) except ValueError as e: - assert "must be less than" in str(e) + assert "must be between" in str(e) else: assert False, "Expected ValueError" try: test_module.dummy(runner=dummy_runner, x=4) except ValueError as e: - assert "must be greater than" in str(e) + assert "must be between" in str(e) else: assert False, "Expected ValueError"