Skip to content

Commit

Permalink
IR progress
Browse files Browse the repository at this point in the history
  • Loading branch information
nx10 committed Sep 5, 2024
1 parent 27f553a commit 42d858e
Show file tree
Hide file tree
Showing 13 changed files with 414 additions and 359 deletions.
2 changes: 2 additions & 0 deletions src/styx/backend/python/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
"""Python wrapper backend."""

from .core import to_python
172 changes: 172 additions & 0 deletions src/styx/backend/python/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from styx.backend.python.lookup import LookupParam
from styx.backend.python.pycodegen.core import LineBuffer, PyFunc, indent
import styx.ir.core as ir


def _generate_raise_value_err(obj: str, expectation: str, reality: str | None = None) -> LineBuffer:
fstr = ""
if "{" in obj or "{" in expectation or (reality is not None and "{" in reality):
fstr = "f"

return (
[f'raise ValueError({fstr}"{obj} must be {expectation} but was {reality}")']
if reality is not None
else [f'raise ValueError({fstr}"{obj} must be {expectation}")']
)


def _param_compile_constraint_checks(buf: LineBuffer, param: ir.IParam, lookup: LookupParam) -> None:
"""Generate input constraint validation code for an input argument."""
py_symbol = lookup.py_symbol[param.param.id_]

min_value: float | int | None = None
max_value: float | int | None = None
list_count_min: int | None = None
list_count_max: int | None = None

if isinstance(param, (ir.IFloat, ir.IInt)):
min_value = param.min_value
max_value = param.max_value
elif isinstance(param, ir.IList):
list_count_min = param.list_.count_min
list_count_max = param.list_.count_max

val_opt = ""
if isinstance(param, ir.IOptional):
val_opt = f"{py_symbol} is not None and "

# List argument length validation
if list_count_min is not None and list_count_max is not None:
# Case: len(list[]) == X
assert list_count_min <= list_count_max
if list_count_min == list_count_max:
buf.extend([
f"if {val_opt}(len({py_symbol}) != {list_count_min}): ",
*indent(
_generate_raise_value_err(
f"Length of '{py_symbol}'",
f"{list_count_min}",
f"{{len({py_symbol})}}",
)
),
])
else:
# Case: X <= len(list[]) <= Y
buf.extend([
f"if {val_opt}not ({list_count_min} <= " f"len({py_symbol}) <= {list_count_max}): ",
*indent(
_generate_raise_value_err(
f"Length of '{py_symbol}'",
f"between {list_count_min} and {list_count_max}",
f"{{len({py_symbol})}}",
)
),
])
elif list_count_min is not None:
# Case len(list[]) >= X
buf.extend([
f"if {val_opt}not ({list_count_min} <= len({py_symbol})): ",
*indent(
_generate_raise_value_err(
f"Length of '{py_symbol}'",
f"greater than {list_count_min}",
f"{{len({py_symbol})}}",
)
),
])
elif list_count_max is not None:
# Case len(list[]) <= X
buf.extend([
f"if {val_opt}not (len({py_symbol}) <= {list_count_max}): ",
*indent(
_generate_raise_value_err(
f"Length of '{py_symbol}'",
f"less than {list_count_max}",
f"{{len({py_symbol})}}",
)
),
])

# Numeric argument range validation
op_min = "<="
op_max = "<="
if min_value is not None and max_value is not None:
# Case: X <= arg <= Y
assert min_value <= max_value
if isinstance(param, ir.IList):
buf.extend([
f"if {val_opt}not ({min_value} {op_min} min({py_symbol}) "
f"and max({py_symbol}) {op_max} {max_value}): ",
*indent(
_generate_raise_value_err(
f"All elements of '{py_symbol}'",
f"between {min_value} {op_min} x {op_max} {max_value}",
)
),
])
else:
buf.extend([
f"if {val_opt}not ({min_value} {op_min} {py_symbol} {op_max} {max_value}): ",
*indent(
_generate_raise_value_err(
f"'{py_symbol}'",
f"between {min_value} {op_min} x {op_max} {max_value}",
f"{{{py_symbol}}}",
)
),
])
elif min_value is not None:
# Case: X <= arg
if isinstance(param, ir.IList):
buf.extend([
f"if {val_opt}not ({min_value} {op_min} min({py_symbol})): ",
*indent(
_generate_raise_value_err(
f"All elements of '{py_symbol}'",
f"greater than {min_value} {op_min} x",
)
),
])
else:
buf.extend([
f"if {val_opt}not ({min_value} {op_min} {py_symbol}): ",
*indent(
_generate_raise_value_err(
f"'{py_symbol}'",
f"greater than {min_value} {op_min} x",
f"{{{py_symbol}}}",
)
),
])
elif max_value is not None:
# Case: arg <= X
if isinstance(param, ir.IList):
buf.extend([
f"if {val_opt}not (max({py_symbol}) {op_max} {max_value}): ",
*indent(
_generate_raise_value_err(
f"All elements of '{py_symbol}'",
f"less than x {op_max} {max_value}",
)
),
])
else:
buf.extend([
f"if {val_opt}not ({py_symbol} {op_max} {max_value}): ",
*indent(
_generate_raise_value_err(
f"'{py_symbol}'",
f"less than x {op_max} {max_value}",
f"{{{py_symbol}}}",
)
),
])


def struct_compile_constraint_checks(
func: PyFunc,
struct: ir.IStruct | ir.IParam,
lookup: LookupParam,
) -> None:
for param in struct.struct.iter_params():
_param_compile_constraint_checks(func.body, param, lookup)
100 changes: 11 additions & 89 deletions src/styx/backend/python/interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import styx.ir.core as ir
from styx.backend.python.constraints import struct_compile_constraint_checks
from styx.backend.python.documentation import docs_to_docstring
from styx.backend.python.lookup import LookupParam
from styx.backend.python.metadata import generate_static_metadata
from styx.backend.python.pycodegen.core import (
LineBuffer,
Expand All @@ -12,100 +14,19 @@
indent,
)
from styx.backend.python.pycodegen.scope import Scope
from styx.backend.python.pycodegen.utils import as_py_literal, python_pascalize, python_snakify
from styx.backend.python.pycodegen.utils import as_py_literal, python_snakify
from styx.backend.python.utils import (
iter_params_recursively,
param_py_default_value,
param_py_type,
param_py_var_is_set_by_user,
param_py_var_to_str,
struct_has_outputs,
)


class _LookupParam:
"""Pre-compute and store Python symbols, types, class-names, etc. to reduce spaghetti code everywhere else."""

def __init__(
self,
interface: ir.Interface,
package_scope: Scope,
function_symbol: str,
function_scope: Scope,
) -> None:
def _collect_output_field_symbols(
param: ir.IStruct | ir.IParam, lookup_output_field_symbol: dict[ir.IdType, str]
) -> None:
scope = Scope(parent=package_scope)
for output in param.param.outputs:
output_field_symbol = scope.add_or_dodge(output.name)
assert output.id_ not in lookup_output_field_symbol
lookup_output_field_symbol[output.id_] = output_field_symbol

def _collect_py_symbol(param: ir.IStruct | ir.IParam, lookup_py_symbol: dict[ir.IdType, str]) -> None:
scope = Scope(parent=function_scope)
for elem in param.struct.iter_params():
symbol = scope.add_or_dodge(python_snakify(elem.param.name))
assert elem.param.id_ not in lookup_py_symbol
lookup_py_symbol[elem.param.id_] = symbol

self.param: dict[ir.IdType, ir.IParam] = {interface.command.param.id_: interface.command}
"""Find param object by its ID. IParam.id_ -> IParam"""
self.py_type: dict[ir.IdType, str] = {interface.command.param.id_: function_symbol}
"""Find Python type by param id. IParam.id_ -> Python type"""
self.py_symbol: dict[ir.IdType, str] = {}
"""Find function-parameter symbol by param ID. IParam.id_ -> Python symbol"""
self.py_output_type: dict[ir.IdType, str] = {
interface.command.param.id_: package_scope.add_or_dodge(
python_pascalize(f"{interface.command.struct.name}_Outputs")
)
}
"""Find outputs class name by struct param ID. IStruct.id_ -> Python class name"""
self.py_output_field_symbol: dict[ir.IdType, str] = {}
"""Find output field symbol by output ID. Output.id_ -> Python symbol"""

_collect_py_symbol(
param=interface.command,
lookup_py_symbol=self.py_symbol,
)
_collect_output_field_symbols(
param=interface.command,
lookup_output_field_symbol=self.py_output_field_symbol,
)

for elem in iter_params_recursively(interface.command):
self.param[elem.param.id_] = elem

if isinstance(elem, ir.IStruct):
if elem.param.id_ not in self.py_type: # Struct unions may resolve these first
self.py_type[elem.param.id_] = package_scope.add_or_dodge(
python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}")
)
self.py_output_type[elem.param.id_] = package_scope.add_or_dodge(
python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}_Outputs")
)
_collect_py_symbol(
param=elem,
lookup_py_symbol=self.py_symbol,
)
_collect_output_field_symbols(
param=elem,
lookup_output_field_symbol=self.py_output_field_symbol,
)
elif isinstance(elem, ir.IStructUnion):
for alternative in elem.alts:
self.py_type[alternative.param.id_] = package_scope.add_or_dodge(
python_pascalize(f"{interface.command.struct.name}_{alternative.struct.name}")
)
self.py_type[elem.param.id_] = param_py_type(elem, self.py_type)
else:
self.py_type[elem.param.id_] = param_py_type(elem, self.py_type)


def _compile_struct(
param: ir.IStruct | ir.IParam,
interface_module: PyModule,
lookup: _LookupParam,
lookup: LookupParam,
metadata_symbol: str,
root_function: bool,
) -> None:
Expand All @@ -128,8 +49,6 @@ def _compile_struct(
)
pyargs = func_cargs_building.args
interface_module.funcs.append(func_cargs_building)

pyargs.append(PyArg(name="runner", type="Runner | None", default="None", docstring="Command runner"))
else:
func_cargs_building = PyFunc(
name="run",
Expand Down Expand Up @@ -191,6 +110,8 @@ def _compile_struct(
root_function=False,
)

struct_compile_constraint_checks(func=func_cargs_building, struct=param, lookup=lookup)

func_cargs_building.body.extend([
"runner = runner or get_global_runner()",
f"execution = runner.start_execution({metadata_symbol})",
Expand All @@ -199,6 +120,7 @@ def _compile_struct(
_compile_cargs_building(param, lookup, func_cargs_building, access_via_self=not root_function)

if root_function:
pyargs.append(PyArg(name="runner", type="Runner | None", default="None", docstring="Command runner"))
_compile_outputs_building(
param=param,
func=func_cargs_building,
Expand Down Expand Up @@ -227,7 +149,7 @@ def _compile_struct(

def _compile_cargs_building(
param: ir.IParam | ir.IStruct,
lookup: _LookupParam,
lookup: LookupParam,
func: PyFunc,
access_via_self: bool,
) -> None:
Expand Down Expand Up @@ -282,7 +204,7 @@ def _compile_cargs_building(
def _compile_outputs_class(
param: ir.IStruct | ir.IParam,
interface_module: PyModule,
lookup: _LookupParam,
lookup: LookupParam,
) -> None:
outputs_class = PyDataClass(
name=lookup.py_output_type[param.param.id_],
Expand Down Expand Up @@ -328,7 +250,7 @@ def _compile_outputs_class(
def _compile_outputs_building(
param: ir.IStruct | ir.IParam,
func: PyFunc,
lookup: _LookupParam,
lookup: LookupParam,
access_via_self: bool = False,
) -> None:
"""Generate the outputs building code."""
Expand Down Expand Up @@ -424,7 +346,7 @@ def compile_interface(
function_scope.add_or_die("ret")

# Lookup tables
lookup = _LookupParam(
lookup = LookupParam(
interface=interface,
package_scope=package_scope,
function_symbol=function_symbol,
Expand Down
Loading

0 comments on commit 42d858e

Please sign in to comment.