Skip to content

Commit

Permalink
Fix issue with lark 1.1.9 (#141)
Browse files Browse the repository at this point in the history
* fix lookup w/ string

* test null to empty

* pep

* re-remove pinned lark version

---------

Co-authored-by: sambles <[email protected]>
  • Loading branch information
ncerutti and sambles committed Sep 19, 2024
1 parent 55cae66 commit 0469379
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 7 deletions.
18 changes: 15 additions & 3 deletions ods_tools/odtf/transformers/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Callable, Iterable, List, Pattern, TypedDict, Union

from lark import Transformer as _LarkTransformer
from lark import Tree
from lark import Token, Tree
from lark import exceptions as lark_exceptions
from lark import v_args
from ..transformers.transform_utils import replace_multiple
Expand Down Expand Up @@ -452,7 +452,9 @@ def boolean(self, value):
:return: True if the value is "True", False otherwise
"""
return value == "True"
if isinstance(value, bool):
return value
return str(value).lower() == "true"

def null(self, value):
"""
Expand All @@ -479,6 +481,16 @@ def number(self, value):
return float(value)


def safe_lookup(r, name):
if isinstance(name, Token):
name = name.value
if name == 'True':
return True
elif name == 'False':
return False
return r.get(name, name)


def create_transformer_class(row, transformer_mapping):
"""
Creates a transformer class from the provided mapping overrides.
Expand All @@ -489,7 +501,7 @@ def create_transformer_class(row, transformer_mapping):
:return: The new transformer class
"""
transformer_mapping = {
"lookup": lambda r, name: r[name],
"lookup": safe_lookup,
"add": lambda r, lhs, rhs: add(lhs, rhs),
"subtract": lambda r, lhs, rhs: sub(lhs, rhs),
"multiply": lambda r, lhs, rhs: mul(lhs, rhs),
Expand Down
2 changes: 1 addition & 1 deletion requirements-extra.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
PyYAML
lark<=1.1.9
lark
networkx
pyodbc
sqlparams
Expand Down
2 changes: 1 addition & 1 deletion tests/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pytest
requests
chardet
PyYAML
lark<=1.1.9
lark
networkx
pyodbc
sqlparams
Expand Down
2 changes: 1 addition & 1 deletion tests/t_input.csv
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ Line,Input_int_1,Input_int_2,Input_string_1,Input_string_2,Input_multistring_1,I
6,55,55,letter_F,letter_F,"letter_C, letter_I, letter_A",ARG,5.2,5.2
7,101,101,letter_G,letter_G,"letter_B, letter_E, letter_E",,7.9,7.9
8,999,999,letter_H,letter_H,"letter_J, letter_I, letter_I","USA, UK",111.11,111.11
9,777,777,letter_I,letter_I,"letter_G, letter_I, letter_G",Null,0.001,0.001
9,777,777,letter_I,letter_I,"letter_G, letter_I, letter_G",,0.001,0.001
10,1,1,,,"letter_B, letter_A, letter_G","ARG, BRA, USA",,
3 changes: 2 additions & 1 deletion tests/test_ods_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,9 +948,10 @@ def test_transformation_as_expected(self):
0.000318471337579618, np.nan],
'Output_multistring_1': ["A;B;C", "A;J", "E;C", 'H', '', "C;I;A", "B;E;E", "J;I;I", "G;I;G", "B;A;G"],
'Output_multistring_2': ["United Kingdom;Italy", "Germany;Brasil", "France;France", "Sweden",
"Spain;Sweden", "Argentina", '', "United States;United Kingdom", "Null",
"Spain;Sweden", "Argentina", '', "United States;United Kingdom", '',
"Argentina;Brasil;United States"]
}

for column, values in expected_values.items():
if 'float' in column.lower():
assert np.allclose(output_df[column].tolist(), values, equal_nan=True, rtol=1e-5, atol=1e-5)
Expand Down

0 comments on commit 0469379

Please sign in to comment.