Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support calling methods that doesn't match a rule name #1076

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class CalculateTree(Transformer):
number = float

def __init__(self):
super().__init__()
self.vars = {}

def assign_var(self, name, value):
Expand Down
2 changes: 1 addition & 1 deletion lark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .utils import logger
from .tree import Tree
from .visitors import Transformer, Visitor, v_args, Discard, Transformer_NonRecursive
from .visitors import Transformer, Visitor, v_args, Discard, Transformer_NonRecursive, call_for
from .exceptions import (ParseError, LexError, GrammarError, UnexpectedToken,
UnexpectedInput, UnexpectedCharacters, UnexpectedEOF, LarkError)
from .lexer import Token
Expand Down
5 changes: 5 additions & 0 deletions lark/load_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@

class FindRuleSize(Transformer):
def __init__(self, keep_all_tokens):
super().__init__()
self.keep_all_tokens = keep_all_tokens

def _will_not_get_removed(self, sym):
Expand Down Expand Up @@ -225,6 +226,7 @@ def expansions(self, args):
@inline_args
class EBNF_to_BNF(Transformer_InPlace):
def __init__(self):
super().__init__()
self.new_rules = []
self.rules_cache = {}
self.prefix = 'anon'
Expand Down Expand Up @@ -440,6 +442,7 @@ class PrepareAnonTerminals(Transformer_InPlace):
"""Create a unique list of anonymous terminals. Attempt to give meaningful names to them when we add them"""

def __init__(self, terminals):
super().__init__()
self.terminals = terminals
self.term_set = {td.name for td in self.terminals}
self.term_reverse = {td.pattern: td for td in terminals}
Expand Down Expand Up @@ -495,6 +498,7 @@ class _ReplaceSymbols(Transformer_InPlace):
"""Helper for ApplyTemplates"""

def __init__(self):
super().__init__()
self.names = {}

def value(self, c):
Expand All @@ -513,6 +517,7 @@ class ApplyTemplates(Transformer_InPlace):
"""Apply the templates, creating new rules that represent the used templates"""

def __init__(self, rule_defs):
super().__init__()
self.rule_defs = rule_defs
self.replacer = _ReplaceSymbols()
self.created_templates = set()
Expand Down
1 change: 1 addition & 0 deletions lark/reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class WriteTokensTransformer(Transformer_InPlace):
term_subs: Dict[str, Callable[[Symbol], str]]

def __init__(self, tokens: Dict[str, TerminalDef], term_subs: Dict[str, Callable[[Symbol], str]]) -> None:
super().__init__()
self.tokens = tokens
self.term_subs = term_subs

Expand Down
1 change: 1 addition & 0 deletions lark/tools/nearley.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _get_rulename(name):
@v_args(inline=True)
class NearleyToLark(Transformer):
def __init__(self):
super().__init__()
self._count = 0
self.extra_rules = {}
self.extra_rules_rev = {}
Expand Down
1 change: 1 addition & 0 deletions lark/tree_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _match_tree_template(self, template, tree):

class _ReplaceVars(Transformer):
def __init__(self, conf, vars):
super().__init__()
self._conf = conf
self._vars = vars

Expand Down
137 changes: 98 additions & 39 deletions lark/visitors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar, Tuple, List, Callable, Generic, Type, Union, Optional, Any
from typing import TypeVar, Tuple, List, Callable, Generic, Type, Union, Optional, Any, Dict
from abc import ABC

from .utils import combine_alternatives
Expand All @@ -9,6 +9,7 @@
###{standalone
from functools import wraps, update_wrapper
from inspect import getmembers, getmro
import warnings

_T = TypeVar('_T')
_R = TypeVar('_R')
Expand All @@ -35,6 +36,49 @@ def __repr__(self):

Discard = _DiscardType()

# User function lookup and overrides


def call_for(rule_name: str) -> Callable:
def _call_for(func: Callable) -> Callable:
func._rule_name = rule_name
return func

return _call_for


class _UserFuncLookup:
_user_func_overrides: Dict[str, Callable]

def __init__(self):
self._user_func_overrides = {}

def _look_up_user_func(self, rule_name: str) -> Optional[Callable]:
user_func = getattr(self, rule_name, None)
if user_func is not None:
return user_func

# backwards compatibility for subclass not calling __init__()
if not hasattr(self, "_user_func_overrides"):
warnings.warn("Subclasses of Transformer and Visitor should call super().__init__().",
DeprecationWarning)
self._user_func_overrides = {}

# check cache
user_func = self._user_func_overrides.get(rule_name)
if user_func is not None:
return user_func

for attr_name in dir(self):
if attr_name.startswith("_"):
continue
attr = getattr(self, attr_name)
if hasattr(attr, "_rule_name") and attr._rule_name == rule_name:
self._user_func_overrides[attr._rule_name] = attr
return attr

return None

# Transformers

class _Decoratable:
Expand Down Expand Up @@ -64,7 +108,7 @@ def __class_getitem__(cls, _):
return cls


class Transformer(_Decoratable, ABC, Generic[_T]):
class Transformer(_UserFuncLookup, _Decoratable, ABC, Generic[_T]):
"""Transformers visit each node of the tree, and run the appropriate method on it according to the node's data.

Methods are provided by the user via inheritance, and called according to ``tree.data``.
Expand Down Expand Up @@ -95,39 +139,38 @@ class Transformer(_Decoratable, ABC, Generic[_T]):
__visit_tokens__ = True # For backwards compatibility

def __init__(self, visit_tokens: bool=True) -> None:
super().__init__()
self.__visit_tokens__ = visit_tokens

def _call_userfunc(self, tree, new_children=None):
# Assumes tree is already transformed
children = new_children if new_children is not None else tree.children
try:
f = getattr(self, tree.data)
except AttributeError:
f = super()._look_up_user_func(tree.data)
if f is None:
return self.__default__(tree.data, children, tree.meta)
else:
try:
wrapper = getattr(f, 'visit_wrapper', None)
if wrapper is not None:
return f.visit_wrapper(f, tree.data, children, tree.meta)
else:
return f(children)
except GrammarError:
raise
except Exception as e:
raise VisitError(tree.data, tree, e)

def _call_userfunc_token(self, token):
try:
f = getattr(self, token.type)
except AttributeError:
wrapper = getattr(f, 'visit_wrapper', None)
if wrapper is not None:
return f.visit_wrapper(f, tree.data, children, tree.meta)
else:
return f(children)
except GrammarError:
raise
except Exception as e:
raise VisitError(tree.data, tree, e)

def _call_userfunc_token(self, token):
f = super()._look_up_user_func(token.type)
if f is None:
return self.__default_token__(token)
else:
try:
return f(token)
except GrammarError:
raise
except Exception as e:
raise VisitError(token.type, token, e)

try:
return f(token)
except GrammarError:
raise
except Exception as e:
raise VisitError(token.type, token, e)

def _transform_children(self, children):
for c in children:
Expand Down Expand Up @@ -204,6 +247,19 @@ def foo(self, children):
assert composed_transformer.transform(t) == 'foobar'

"""
prefix_format = "{}__{}"

def _make_merged_method(prefix_with: str, to_wrap: Callable) -> Callable:
# Python methods don't allow attributes to be set, while that is allowed for functions. As
# a result, a wrapping function is needed to update the rule name.
# A factory function is needed to capture a reference to the method that is being wrapped.
@wraps(to_wrap)
def _merged_method(*args, **kwargs):
return to_wrap(*args, **kwargs)

_merged_method._rule_name = prefix_format.format(prefix_with, to_wrap._rule_name)
return _merged_method

if base_transformer is None:
base_transformer = Transformer()
for prefix, transformer in transformers_to_merge.items():
Expand All @@ -213,10 +269,12 @@ def foo(self, children):
continue
if method_name.startswith("_") or method_name == "transform":
continue
prefixed_method = prefix + "__" + method_name
prefixed_method = prefix_format.format(prefix, method_name)
if hasattr(base_transformer, prefixed_method):
raise AttributeError("Cannot merge: method '%s' appears more than once" % prefixed_method)

if hasattr(method, "_rule_name"):
method = _make_merged_method(prefix, method)
setattr(base_transformer, prefixed_method, method)

return base_transformer
Expand All @@ -226,12 +284,10 @@ class InlineTransformer(Transformer): # XXX Deprecated
def _call_userfunc(self, tree, new_children=None):
# Assumes tree is already transformed
children = new_children if new_children is not None else tree.children
try:
f = getattr(self, tree.data)
except AttributeError:
f = super()._look_up_user_func(tree.data)
if f is None:
return self.__default__(tree.data, children, tree.meta)
else:
return f(*children)
return f(*children)

class TransformerChain(Generic[_T]):

Expand Down Expand Up @@ -317,9 +373,12 @@ def _transform_tree(self, tree):

# Visitors

class VisitorBase:
class VisitorBase(_UserFuncLookup):
def _call_userfunc(self, tree):
return getattr(self, tree.data, self.__default__)(tree)
f = super()._look_up_user_func(tree.data)
if f is None:
f = self.__default__
return f(tree)

def __default__(self, tree):
"""Default function that is called if there is no attribute matching ``tree.data``
Expand Down Expand Up @@ -379,7 +438,7 @@ def visit_topdown(self,tree: Tree) -> Tree:
return tree


class Interpreter(_Decoratable, ABC, Generic[_T]):
class Interpreter(_UserFuncLookup, _Decoratable, ABC, Generic[_T]):
"""Interpreter walks the tree starting at the root.

Visits the tree, starting with the root and finally the leaves (top-down)
Expand All @@ -392,7 +451,10 @@ class Interpreter(_Decoratable, ABC, Generic[_T]):
"""

def visit(self, tree: Tree) -> _T:
f = getattr(self, tree.data)
f = super()._look_up_user_func(tree.data)
if f is None:
return self.__default__(tree)

wrapper = getattr(f, 'visit_wrapper', None)
if wrapper is not None:
return f.visit_wrapper(f, tree.data, tree.children, tree.meta)
Expand All @@ -403,9 +465,6 @@ def visit_children(self, tree: Tree) -> List[_T]:
return [self.visit(child) if isinstance(child, Tree) else child
for child in tree.children]

def __getattr__(self, name):
return self.__default__

def __default__(self, tree):
return self.visit_children(tree)

Expand Down
Loading