Skip to content

Commit

Permalink
IR (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
nx10 authored Sep 10, 2024
1 parent fa72060 commit ee40d49
Show file tree
Hide file tree
Showing 49 changed files with 2,086 additions and 1,950 deletions.
1 change: 1 addition & 0 deletions src/styx/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Styx Python backend."""
3 changes: 3 additions & 0 deletions src/styx/backend/python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Python wrapper backend."""

from .core import to_python as to_python
172 changes: 172 additions & 0 deletions src/styx/backend/python/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import styx.ir.core as ir
from styx.backend.python.lookup import LookupParam
from styx.backend.python.pycodegen.core import LineBuffer, PyFunc, indent


def _generate_raise_value_err(obj: str, expectation: str, reality: str | None = None) -> LineBuffer:
fstr = ""
if "{" in obj or "{" in expectation or (reality is not None and "{" in reality):
fstr = "f"

return (
[f'raise ValueError({fstr}"{obj} must be {expectation} but was {reality}")']
if reality is not None
else [f'raise ValueError({fstr}"{obj} must be {expectation}")']
)


def _param_compile_constraint_checks(buf: LineBuffer, param: ir.IParam, lookup: LookupParam) -> None:
"""Generate input constraint validation code for an input argument."""
py_symbol = lookup.py_symbol[param.param.id_]

min_value: float | int | None = None
max_value: float | int | None = None
list_count_min: int | None = None
list_count_max: int | None = None

if isinstance(param, (ir.IFloat, ir.IInt)):
min_value = param.min_value
max_value = param.max_value
elif isinstance(param, ir.IList):
list_count_min = param.list_.count_min
list_count_max = param.list_.count_max

val_opt = ""
if isinstance(param, ir.IOptional):
val_opt = f"{py_symbol} is not None and "

# List argument length validation
if list_count_min is not None and list_count_max is not None:
# Case: len(list[]) == X
assert list_count_min <= list_count_max
if list_count_min == list_count_max:
buf.extend([
f"if {val_opt}(len({py_symbol}) != {list_count_min}): ",
*indent(
_generate_raise_value_err(
f"Length of '{py_symbol}'",
f"{list_count_min}",
f"{{len({py_symbol})}}",
)
),
])
else:
# Case: X <= len(list[]) <= Y
buf.extend([
f"if {val_opt}not ({list_count_min} <= " f"len({py_symbol}) <= {list_count_max}): ",
*indent(
_generate_raise_value_err(
f"Length of '{py_symbol}'",
f"between {list_count_min} and {list_count_max}",
f"{{len({py_symbol})}}",
)
),
])
elif list_count_min is not None:
# Case len(list[]) >= X
buf.extend([
f"if {val_opt}not ({list_count_min} <= len({py_symbol})): ",
*indent(
_generate_raise_value_err(
f"Length of '{py_symbol}'",
f"greater than {list_count_min}",
f"{{len({py_symbol})}}",
)
),
])
elif list_count_max is not None:
# Case len(list[]) <= X
buf.extend([
f"if {val_opt}not (len({py_symbol}) <= {list_count_max}): ",
*indent(
_generate_raise_value_err(
f"Length of '{py_symbol}'",
f"less than {list_count_max}",
f"{{len({py_symbol})}}",
)
),
])

# Numeric argument range validation
op_min = "<="
op_max = "<="
if min_value is not None and max_value is not None:
# Case: X <= arg <= Y
assert min_value <= max_value
if isinstance(param, ir.IList):
buf.extend([
f"if {val_opt}not ({min_value} {op_min} min({py_symbol}) "
f"and max({py_symbol}) {op_max} {max_value}): ",
*indent(
_generate_raise_value_err(
f"All elements of '{py_symbol}'",
f"between {min_value} {op_min} x {op_max} {max_value}",
)
),
])
else:
buf.extend([
f"if {val_opt}not ({min_value} {op_min} {py_symbol} {op_max} {max_value}): ",
*indent(
_generate_raise_value_err(
f"'{py_symbol}'",
f"between {min_value} {op_min} x {op_max} {max_value}",
f"{{{py_symbol}}}",
)
),
])
elif min_value is not None:
# Case: X <= arg
if isinstance(param, ir.IList):
buf.extend([
f"if {val_opt}not ({min_value} {op_min} min({py_symbol})): ",
*indent(
_generate_raise_value_err(
f"All elements of '{py_symbol}'",
f"greater than {min_value} {op_min} x",
)
),
])
else:
buf.extend([
f"if {val_opt}not ({min_value} {op_min} {py_symbol}): ",
*indent(
_generate_raise_value_err(
f"'{py_symbol}'",
f"greater than {min_value} {op_min} x",
f"{{{py_symbol}}}",
)
),
])
elif max_value is not None:
# Case: arg <= X
if isinstance(param, ir.IList):
buf.extend([
f"if {val_opt}not (max({py_symbol}) {op_max} {max_value}): ",
*indent(
_generate_raise_value_err(
f"All elements of '{py_symbol}'",
f"less than x {op_max} {max_value}",
)
),
])
else:
buf.extend([
f"if {val_opt}not ({py_symbol} {op_max} {max_value}): ",
*indent(
_generate_raise_value_err(
f"'{py_symbol}'",
f"less than x {op_max} {max_value}",
f"{{{py_symbol}}}",
)
),
])


def struct_compile_constraint_checks(
func: PyFunc,
struct: ir.IStruct | ir.IParam,
lookup: LookupParam,
) -> None:
for param in struct.struct.iter_params():
_param_compile_constraint_checks(func.body, param, lookup)
53 changes: 53 additions & 0 deletions src/styx/backend/python/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import dataclasses
from typing import Any, Generator, Iterable

from styx.backend.python.documentation import docs_to_docstring
from styx.backend.python.interface import compile_interface
from styx.backend.python.pycodegen.core import PyModule
from styx.backend.python.pycodegen.scope import Scope
from styx.backend.python.pycodegen.utils import python_snakify
from styx.ir.core import Interface, Package


@dataclasses.dataclass
class _PackageData:
package: Package
package_symbol: str
scope: Scope
module: PyModule


def to_python(interfaces: Iterable[Interface]) -> Generator[tuple[str, list[str]], Any, None]:
"""For a stream of IR interfaces return a stream of Python modules and their module paths.
Args:
interfaces: Stream of IR interfaces.
Returns:
Stream of tuples (Python module, module path).
"""
packages: dict[str, _PackageData] = {}
global_scope = Scope(parent=Scope.python())

for interface in interfaces:
if interface.package.name not in packages:
packages[interface.package.name] = _PackageData(
package=interface.package,
package_symbol=global_scope.add_or_dodge(python_snakify(interface.package.name)),
scope=Scope(parent=global_scope),
module=PyModule(
docstr=docs_to_docstring(interface.package.docs),
),
)
package_data = packages[interface.package.name]

# interface_module_symbol = global_scope.add_or_dodge(python_snakify(interface.command.param.name))
interface_module_symbol = python_snakify(interface.command.param.name)

interface_module = PyModule()
compile_interface(interface=interface, package_scope=package_data.scope, interface_module=interface_module)
package_data.module.imports.append(f"from .{interface_module_symbol} import *")
yield interface_module.text(), [package_data.package_symbol, interface_module_symbol]

for package_data in packages.values():
yield package_data.module.text(), [package_data.package_symbol, "__init__"]
54 changes: 54 additions & 0 deletions src/styx/backend/python/documentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from styx.ir.core import (
Documentation,
)


def _ensure_period(s: str) -> str:
if not s.endswith("."):
return f"{s}."
return s


def _ensure_double_linebreak_if_not_empty(s: str) -> str:
if s == "" or s.endswith("\n\n"):
return s
if s.endswith("\n"):
return f"{s}\n"
return f"{s}\n\n"


def docs_to_docstring(docs: Documentation) -> str | None:
re = ""
if docs.title:
re += docs.title

if docs.description:
re = _ensure_double_linebreak_if_not_empty(re)
re += _ensure_period(docs.description)

if docs.authors:
re = _ensure_double_linebreak_if_not_empty(re)
if len(docs.authors) == 1:
re += f"Author: {docs.authors[0]}"
else:
re += f"Authors: {', '.join(docs.authors)}"

if docs.literature:
re = _ensure_double_linebreak_if_not_empty(re)
if len(docs.literature) == 1:
re += f"Literature: {docs.literature[0]}"
else:
entries = "\n".join(docs.literature)
re += f"Literature:\n{entries}"

if docs.urls:
re = _ensure_double_linebreak_if_not_empty(re)
if len(docs.urls) == 1:
re += f"URL: {docs.urls[0]}"
else:
entries = "\n".join(docs.urls)
re += f"URLs:\n{entries}"

if re:
return re
return None
Loading

0 comments on commit ee40d49

Please sign in to comment.