From aa36e6636b5179d305a6186a56cc272f3678b52e Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Wed, 8 Nov 2023 14:39:44 +0100 Subject: [PATCH] ENH: set data keys as first positional arguments (#488) --- src/tensorwaves/function/sympy/__init__.py | 5 ++++- tests/function/test_function.py | 3 ++- tests/test_estimator.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/tensorwaves/function/sympy/__init__.py b/src/tensorwaves/function/sympy/__init__.py index 6a682d37..f63ab6f2 100644 --- a/src/tensorwaves/function/sympy/__init__.py +++ b/src/tensorwaves/function/sympy/__init__.py @@ -109,7 +109,10 @@ def create_parametrized_function( [0.0, 0.0, 0.0, 0.0, 0.0] """ free_symbols = _get_free_symbols(expression) - sorted_symbols = sorted(free_symbols, key=lambda s: s.name) + parameter_set = set(parameters) + parameter_symbols = sorted(free_symbols & parameter_set, key=lambda s: s.name) + data_symbols = sorted(free_symbols - parameter_set, key=lambda s: s.name) + sorted_symbols = tuple(data_symbols + parameter_symbols) # for partial+gradient lambdified_function = _lambdify_normal_or_fast( expression=expression, symbols=sorted_symbols, diff --git a/tests/function/test_function.py b/tests/function/test_function.py index 7bca3c3c..a6f79d3f 100644 --- a/tests/function/test_function.py +++ b/tests/function/test_function.py @@ -34,7 +34,8 @@ def function(self) -> ParametrizedBackendFunction: return create_parametrized_function(expression, parameters, backend="numpy") def test_argument_order(self, function: ParametrizedBackendFunction): - assert function.argument_order == ("c_1", "c_2", "c_3", "c_4", "x") + """Test whether data arguments come before parameters.""" + assert function.argument_order == ("x", "c_1", "c_2", "c_3", "c_4") @pytest.mark.parametrize( ("test_data", "expected_results"), diff --git a/tests/test_estimator.py b/tests/test_estimator.py index 6fe2f42a..707382f2 100644 --- a/tests/test_estimator.py +++ b/tests/test_estimator.py @@ -112,7 +112,7 @@ def test_create_cached_function(backend): assert isinstance(cached_function, ParametrizedBackendFunction) assert isinstance(cache_transformer, SympyDataTransformer) - assert cached_function.argument_order == ("a", "c", "f0", "x") + assert cached_function.argument_order == ("f0", "x", "a", "c") # data args first assert set(cached_function.parameters) == {"a", "c"} assert set(cache_transformer.functions) == {"f0", "x"}