Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements ToPythonASTMapper #107

Merged
merged 5 commits into from
Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 251 additions & 2 deletions pymbolic/interop/ast.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
__copyright__ = "Copyright (C) 2015 Andreas Kloeckner"
__copyright__ = """
Copyright (C) 2015 Andreas Kloeckner
Copyright (C) 2022 Kaushik Kulkarni
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
Expand All @@ -22,6 +25,9 @@

import ast
import pymbolic.primitives as p
from typing import Tuple, List, Any
from pymbolic.typing import ExpressionT, ScalarT
from pymbolic.mapper import CachedMapper

__doc__ = r'''

Expand Down Expand Up @@ -189,7 +195,7 @@ def map_Call(self, expr): # noqa
# (expr func, expr* args, keyword* keywords)
func = self.rec(expr.func)
args = tuple([self.rec(arg) for arg in expr.args])
if expr.keywords:
if getattr(expr, "keywords", []):
return p.CallWithKwargs(func, args,
{
kw.arg: self.rec(kw.value)
Expand Down Expand Up @@ -252,4 +258,247 @@ def map_Tuple(self, expr): # noqa

# }}}


# {{{ PymbolicToASTMapper

class PymbolicToASTMapper(CachedMapper):
def map_variable(self, expr) -> ast.expr:
return ast.Name(id=expr.name)

def _map_multi_children_op(self,
children: Tuple[ExpressionT, ...],
op_type: ast.operator) -> ast.expr:
rec_children = [self.rec(child) for child in children]
result = rec_children[-1]
for child in rec_children[-2::-1]:
result = ast.BinOp(child, op_type, result)

return result

def map_sum(self, expr: p.Sum) -> ast.expr:
return self._map_multi_children_op(expr.children, ast.Add())

def map_product(self, expr: p.Product) -> ast.expr:
return self._map_multi_children_op(expr.children, ast.Mult())

def map_constant(self, expr: ScalarT) -> ast.expr:
import sys
if isinstance(expr, bool):
return ast.NameConstant(expr)
else:
# needed because of https://bugs.python.org/issue36280
if sys.version_info < (3, 8):
return ast.Num(expr)
else:
return ast.Constant(expr, None)

def map_call(self, expr: p.Call) -> ast.expr:
return ast.Call(
func=self.rec(expr.function),
args=[self.rec(param) for param in expr.parameters])

def map_call_with_kwargs(self, expr) -> ast.expr:
return ast.Call(
func=self.rec(expr.function),
args=[self.rec(param) for param in expr.parameters],
keywords=[
ast.keyword(
arg=kw,
value=self.rec(param))
for kw, param in sorted(expr.kw_parameters.items())])

def map_subscript(self, expr) -> ast.expr:
return ast.Subscript(value=self.rec(expr.aggregate),
slice=self.rec(expr.index))

def map_lookup(self, expr) -> ast.expr:
return ast.Attribute(self.rec(expr.aggregate),
expr.name)

def map_quotient(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.Div())

def map_floor_div(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.FloorDiv())

def map_remainder(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.Mod())

def map_power(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.base,
expr.exponent),
ast.Pow())

def map_left_shift(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.shiftee,
expr.shift),
ast.LShift())

def map_right_shift(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.RShift())

def map_bitwise_not(self, expr) -> ast.expr:
return ast.UnaryOp(ast.Invert(), self.rec(expr.child))

def map_bitwise_or(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitOr())

def map_bitwise_xor(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitXor())

def map_bitwise_and(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitAnd())

def map_logical_not(self, expr) -> ast.expr:
return ast.UnaryOp(self.rec(expr.child), ast.Not())

def map_logical_or(self, expr) -> ast.expr:
return ast.BoolOp(ast.Or(), [self.rec(child)
for child in expr.children])

def map_logical_and(self, expr) -> ast.expr:
return ast.BoolOp(ast.And(), [self.rec(child)
for child in expr.children])

def map_list(self, expr: List[Any]) -> ast.expr:
return ast.List([self.rec(el) for el in expr])

def map_tuple(self, expr: Tuple[Any, ...]) -> ast.expr:
return ast.Tuple([self.rec(el) for el in expr])

def map_if(self, expr: p.If) -> ast.expr:
return ast.IfExp(test=self.rec(expr.condition),
body=self.rec(expr.then),
orelse=self.rec(expr.else_))

def map_nan(self, expr: p.NaN) -> ast.expr:
if isinstance(expr.data_type(float("nan")), float):
return ast.Call(
ast.Name(id="float"),
args=[ast.Constant("nan")],
keywords=[])
else:
# TODO: would need attributes of NumPy
raise NotImplementedError("Non-float nan not implemented")

def map_slice(self, expr: p.Slice) -> ast.expr:
return ast.Slice(*[self.rec(child)
for child in expr.children])

def map_numpy_array(self, expr) -> ast.expr:
raise NotImplementedError

def map_multivector(self, expr) -> ast.expr:
raise NotImplementedError

def map_common_subexpression(self, expr) -> ast.expr:
raise NotImplementedError

def map_substitution(self, expr) -> ast.expr:
raise NotImplementedError

def map_derivative(self, expr) -> ast.expr:
raise NotImplementedError

def map_if_positive(self, expr) -> ast.expr:
raise NotImplementedError

def map_comparison(self, expr: p.Comparison) -> ast.expr:
raise NotImplementedError

def map_polynomial(self, expr) -> ast.expr:
raise NotImplementedError

def map_wildcard(self, expr) -> ast.expr:
raise NotImplementedError

def map_dot_wildcard(self, expr) -> ast.expr:
raise NotImplementedError

def map_star_wildcard(self, expr) -> ast.expr:
raise NotImplementedError

def map_function_symbol(self, expr) -> ast.expr:
raise NotImplementedError

def map_min(self, expr) -> ast.expr:
raise NotImplementedError

def map_max(self, expr) -> ast.expr:
raise NotImplementedError


def to_python_ast(expr) -> ast.expr:
"""
Maps *expr* to :class:`ast.expr`.
"""
return PymbolicToASTMapper()(expr)


def to_evaluatable_python_function(expr: ExpressionT,
fn_name: str
) -> str:
"""
Returns a :class:`str` of the python code with a single function *fn_name*
that takes in the variables in *expr* as keyword-only arguments and returns
the evaluated value of *expr*.

.. testsetup::

>>> from pymbolic import parse
>>> from pymbolic.interop.ast import to_evaluatable_python_function

.. doctest::

>>> expr = parse("S//32 + E%32")
>>> # Skipping doctest as astunparse and ast.unparse have certain subtle
>>> # differences
>>> print(to_evaluatable_python_function(expr, "foo"))) # doctest: +SKIP
def foo(*, E, S):
return S // 32 + E % 32
"""
import sys
from pymbolic.mapper.dependency import CachedDependencyMapper

if sys.version_info < (3, 9):
try:
from astunparse import unparse
except ImportError:
raise RuntimeError("'to_evaluate_python_function' needs"
"astunparse for Py<3.9. Install via `pip"
" install astunparse`")
else:
unparse = ast.unparse

dep_mapper = CachedDependencyMapper(composite_leaves=True)
deps = sorted({dep.name for dep in dep_mapper(expr)})

ast_func = ast.FunctionDef(name=fn_name,
args=ast.arguments(args=[],
posonlyargs=[],
kwonlyargs=[ast.arg(dep, None)
for dep in deps],
kw_defaults=[None]*len(deps),
vararg=None,
kwarg=None,
defaults=[]),
body=[ast.Return(to_python_ast(expr))],
decorator_list=[])
ast_module = ast.Module([ast_func], type_ignores=[])

return unparse(ast.fix_missing_locations(ast_module))

# }}}

# vim: foldmethod=marker
14 changes: 14 additions & 0 deletions pymbolic/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pymbolic.primitives import Expression
from numbers import Number
from typing import Union

try:
import numpy as np
except ImportError:
BoolT = bool
else:
BoolT = Union[bool, np.bool_]


ScalarT = Union[Number, int, BoolT, float]
ExpressionT = Union[ScalarT, Expression]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# test_matchpy_interop.py
matchpy
astunparse; python_version < '3.9'
12 changes: 12 additions & 0 deletions test/test_pymbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,18 @@ def test_nodecount():
assert get_num_nodes(expr) == 12


def test_python_ast_interop_roundtrip():
from pymbolic.interop.ast import (ASTToPymbolic,
PymbolicToASTMapper)

ast2p = ASTToPymbolic()
p2ast = PymbolicToASTMapper()
ntests = 40
for i in range(ntests):
expr = generate_random_expression(seed=(5+i))
assert ast2p(p2ast(expr)) == expr


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down