Skip to content

Commit

Permalink
Support group constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
nx10 committed Nov 20, 2023
1 parent 1215380 commit 3fe8e65
Show file tree
Hide file tree
Showing 7 changed files with 355 additions and 106 deletions.
181 changes: 91 additions & 90 deletions src/styx/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

from styx.boutiques import model as bt
from styx.boutiques.utils import boutiques_split_command
from styx.compiler.defs import STYX_DEFINITIONS
from styx.compiler.settings import CompilerSettings, DefsMode
from styx.compiler.utils import optional_float_to_int
from styx.pycodegen.core import INDENT as PY_INDENT
from styx.pycodegen.core import LineBuffer, PyArg, PyFunc, PyModule, collapse, indent
from styx.pycodegen.core import LineBuffer, PyArg, PyFunc, PyModule, collapse, expand, indent
from styx.pycodegen.utils import (
as_py_literal,
enquote,
Expand All @@ -15,89 +16,6 @@
ensure_snake_case,
)

RUNTIME_DECLARATIONS = [
'P = typing.TypeVar("P")',
'"""Input host file type."""',
'R = typing.TypeVar("R")',
'"""Output host file type."""',
"",
"",
"class Execution(typing.Protocol[P, R]):",
*indent(
[
'"""',
"Execution object used to execute commands.",
"Created by `Runner.start_execution()`.",
'"""',
"def input_file(self, host_file: P) -> str:",
*indent(
[
'"""',
"Resolve host input files.",
"Returns a local filepath.",
"Called (potentially multiple times) after "
"`Runner.start_execution()` and before `Runner.run()`.",
'"""',
"...",
]
),
"def run(self, cargs: list[str]) -> None:",
*indent(
[
'"""',
"Run the command.",
"Called after all `Execution.input_file()` calls and " "before `Execution.output_file()` calls.",
'"""',
"...",
]
),
"def output_file(self, local_file: str) -> R:",
*indent(
[
'"""',
"Resolve local output files.",
"Returns a host filepath.",
"Called (potentially multiple times) after " "`Runner.run()` and before `Execution.finalize()`.",
'"""',
"...",
]
),
"def finalize(self) -> None:",
*indent(
[
'"""',
"Finalize the execution.",
"Called after all `Execution.output_file()` calls.",
'"""',
"...",
]
),
]
),
"",
"",
"class Runner(typing.Protocol[P, R]):",
*indent(
[
'"""',
"Runner object used to execute commands.",
"Possible examples would be `LocalRunner`, " "`DockerRunner`, `DebugRunner`, ...",
"Used as a factory for `Execution` objects.",
'"""',
"def start_execution(self, tool_name: str) -> Execution[P, R]:",
*indent(
[
'"""',
"Start an execution.",
"Called before any `Execution.input_file()` calls.",
'"""',
"...",
]
),
]
),
]


class BtPrimitive(Enum):
String = 1
Expand Down Expand Up @@ -257,7 +175,7 @@ def _generate_raise_value_err(obj: str, expectation: str, reality: str | None =
)


def _generate_validation_expr(
def _generate_range_validation_expr(
buf: LineBuffer,
bt_input: BtInput,
) -> None:
Expand All @@ -267,6 +185,7 @@ def _generate_validation_expr(

# List argument length validation
if bt_input.list_minimum is not None and bt_input.list_maximum is not None:
# Case: len(list[]) == X
assert bt_input.list_minimum <= bt_input.list_maximum
if bt_input.list_minimum == bt_input.list_maximum:
buf.extend(
Expand All @@ -282,6 +201,7 @@ def _generate_validation_expr(
]
)
else:
# Case: X <= len(list[]) <= Y
buf.extend(
[
f"if {val_opt}not ({bt_input.list_minimum} <= len({bt_input.name}) <= {bt_input.list_maximum}): ",
Expand All @@ -295,6 +215,7 @@ def _generate_validation_expr(
]
)
elif bt_input.list_minimum is not None:
# Case len(list[]) >= X
buf.extend(
[
f"if {val_opt}not ({bt_input.list_minimum} <= len({bt_input.name})): ",
Expand All @@ -308,6 +229,7 @@ def _generate_validation_expr(
]
)
elif bt_input.list_maximum is not None:
# Case len(list[]) <= X
buf.extend(
[
f"if {val_opt}not (len({bt_input.name}) <= {bt_input.list_maximum}): ",
Expand All @@ -325,6 +247,7 @@ def _generate_validation_expr(
op_min = "<" if bt_input.minimum_exclusive else "<="
op_max = "<" if bt_input.maximum_exclusive else "<="
if bt_input.minimum is not None and bt_input.maximum is not None:
# Case: X <= arg <= Y
assert bt_input.minimum <= bt_input.maximum
if bt_input.type.is_list:
buf.extend(
Expand Down Expand Up @@ -353,6 +276,7 @@ def _generate_validation_expr(
]
)
elif bt_input.minimum is not None:
# Case: X <= arg
if bt_input.type.is_list:
buf.extend(
[
Expand All @@ -379,6 +303,7 @@ def _generate_validation_expr(
]
)
elif bt_input.maximum is not None:
# Case: arg <= X
if bt_input.type.is_list:
buf.extend(
[
Expand Down Expand Up @@ -406,7 +331,79 @@ def _generate_validation_expr(
)


def py_from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type: ignore
def _generate_group_constraint_expr(
buf: LineBuffer,
group: bt.Group, # type: ignore
) -> None:
if group.mutually_exclusive:
txt_members = [enquote(x) for x in expand(",\\n\n".join(group.members))]
check_members = expand(" +\n".join([f"({x} is not None)" for x in group.members]))
buf.extend(["if ("])
buf.extend(indent(check_members))
buf.extend(
[
") > 1:",
*indent(
[
"raise ValueError(",
*indent(
[
'"Only one of the following arguments can be specified:\\n"',
*txt_members,
]
),
")",
]
),
]
)
if group.all_or_none:
txt_members = [enquote(x) for x in expand(",\\n\n".join(group.members))]
check_members = expand(" ==\n".join([f"({x} is None)" for x in group.members]))
buf.extend(["if not ("])
buf.extend(indent(check_members))
buf.extend(
[
"):",
*indent(
[
"raise ValueError(",
*indent(
[
'"All or none of the following arguments must be specified:\\n"',
*txt_members,
]
),
")",
]
),
]
)
if group.one_is_required:
txt_members = [enquote("- " + x) for x in expand("\\n\n".join(group.members))]
check_members = expand(" or\n".join([f"({x} is not None)" for x in group.members]))
buf.extend(["if not ("])
buf.extend(indent(check_members))
buf.extend(
[
"):",
*indent(
[
"raise ValueError(",
*indent(
[
'"One of the following arguments must be specified:\\n"',
*txt_members,
]
),
")",
]
),
]
)


def _from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type: ignore
mod = PyModule()

# Python names
Expand Down Expand Up @@ -436,7 +433,11 @@ def py_from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type

# Input validation
for i in args:
_generate_validation_expr(buf_body, i)
_generate_range_validation_expr(buf_body, i)

if tool.groups is not None:
for group in tool.groups:
_generate_group_constraint_expr(buf_body, group)

# Command line args building
for segment in cmd:
Expand All @@ -450,11 +451,11 @@ def py_from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type

# Definitions
if settings.defs_mode == DefsMode.INLINE:
defs = RUNTIME_DECLARATIONS
defs = STYX_DEFINITIONS
elif settings.defs_mode == DefsMode.IMPORT:
defs = ["from styx.runners.styxdefs import *"]
else:
return collapse(RUNTIME_DECLARATIONS)
return collapse(STYX_DEFINITIONS)

buf_header = []
buf_header.extend(
Expand Down Expand Up @@ -517,4 +518,4 @@ def py_from_boutiques(tool: bt.Tool, settings: CompilerSettings) -> str: # type

def compile_descriptor(descriptor: bt.Tool, settings: CompilerSettings) -> str: # type: ignore
"""Compile a Boutiques descriptor to Python code."""
return py_from_boutiques(descriptor, settings)
return _from_boutiques(descriptor, settings)
86 changes: 86 additions & 0 deletions src/styx/compiler/defs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Static type declarations used by compiled code."""

from styx.pycodegen.core import indent

STYX_DEFINITIONS = [
'P = typing.TypeVar("P")',
'"""Input host file type."""',
'R = typing.TypeVar("R")',
'"""Output host file type."""',
"",
"",
"class Execution(typing.Protocol[P, R]):",
*indent(
[
'"""',
"Execution object used to execute commands.",
"Created by `Runner.start_execution()`.",
'"""',
"def input_file(self, host_file: P) -> str:",
*indent(
[
'"""',
"Resolve host input files.",
"Returns a local filepath.",
"Called (potentially multiple times) after "
"`Runner.start_execution()` and before `Runner.run()`.",
'"""',
"...",
]
),
"def run(self, cargs: list[str]) -> None:",
*indent(
[
'"""',
"Run the command.",
"Called after all `Execution.input_file()` calls and " "before `Execution.output_file()` calls.",
'"""',
"...",
]
),
"def output_file(self, local_file: str) -> R:",
*indent(
[
'"""',
"Resolve local output files.",
"Returns a host filepath.",
"Called (potentially multiple times) after " "`Runner.run()` and before `Execution.finalize()`.",
'"""',
"...",
]
),
"def finalize(self) -> None:",
*indent(
[
'"""',
"Finalize the execution.",
"Called after all `Execution.output_file()` calls.",
'"""',
"...",
]
),
]
),
"",
"",
"class Runner(typing.Protocol[P, R]):",
*indent(
[
'"""',
"Runner object used to execute commands.",
"Possible examples would be `LocalRunner`, " "`DockerRunner`, `DebugRunner`, ...",
"Used as a factory for `Execution` objects.",
'"""',
"def start_execution(self, tool_name: str) -> Execution[P, R]:",
*indent(
[
'"""',
"Start an execution.",
"Called before any `Execution.input_file()` calls.",
'"""',
"...",
]
),
]
),
]
4 changes: 2 additions & 2 deletions src/styx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from styx.boutiques.utils import boutiques_from_dict
from styx.compiler.core import compile_descriptor
from styx.compiler.settings import CompilerSettings
from styx.compiler.settings import CompilerSettings, DefsMode


def main() -> None:
settings = CompilerSettings()
settings = CompilerSettings(defs_mode=DefsMode.IMPORT)
with open("examples/bet.json", "r") as json_file:
json_data = json.load(json_file)
descriptor = boutiques_from_dict(json_data)
Expand Down
5 changes: 5 additions & 0 deletions src/styx/pycodegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def collapse(lines: LineBuffer) -> str:
return "\n".join(lines)


def expand(text: str) -> LineBuffer:
"""Expand a string into a LineBuffer."""
return text.splitlines()


def blank_before(lines: LineBuffer, blanks: int = 1) -> LineBuffer:
"""Add blank lines at the beginning of a LineBuffer if it is not empty."""
return [*([""] * blanks), *lines] if len(lines) > 0 else lines
Expand Down
Loading

0 comments on commit 3fe8e65

Please sign in to comment.