From fbe5a74dd518e0ce82ba24a30f660ada016df5f9 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Wed, 6 Mar 2024 10:08:24 +0100 Subject: [PATCH] FIX: set i/o types for function implementations --- benchmarks/ampform.py | 6 ++++-- benchmarks/expression.py | 2 +- src/tensorwaves/data/__init__.py | 2 +- src/tensorwaves/data/transform.py | 7 +++++-- src/tensorwaves/estimator.py | 2 +- src/tensorwaves/function/__init__.py | 10 ++++------ src/tensorwaves/interface.py | 5 ++++- tests/optimizer/test_fit_simple_model.py | 2 +- 8 files changed, 21 insertions(+), 15 deletions(-) diff --git a/benchmarks/ampform.py b/benchmarks/ampform.py index 54f76c16..2bbdcd70 100644 --- a/benchmarks/ampform.py +++ b/benchmarks/ampform.py @@ -20,9 +20,11 @@ from ampform.helicity import HelicityModel from qrules.combinatorics import StateDefinition + from tensorwaves.function import ParametrizedBackendFunction from tensorwaves.interface import ( DataSample, FitResult, + Function, ParameterValue, ParametrizedFunction, ) @@ -55,7 +57,7 @@ def formulate_amplitude_model( def create_function( model: HelicityModel, backend: str, max_complexity: int | None = None -) -> ParametrizedFunction: +) -> ParametrizedBackendFunction: return create_parametrized_function( expression=model.expression.doit(), parameters=model.parameter_defaults, @@ -66,7 +68,7 @@ def create_function( def generate_data( model: HelicityModel, - function: ParametrizedFunction, + function: Function[DataSample, np.ndarray], data_sample_size: int, phsp_sample_size: int, backend: str, diff --git a/benchmarks/expression.py b/benchmarks/expression.py index c524b390..1964f38d 100644 --- a/benchmarks/expression.py +++ b/benchmarks/expression.py @@ -61,7 +61,7 @@ def _generate_domain( def _generate_data( size: int, - function: Function, + function: Function[DataSample, np.ndarray], rng: np.random.Generator, bunch_size: int = 10_000, ) -> DataSample: diff --git a/src/tensorwaves/data/__init__.py b/src/tensorwaves/data/__init__.py index 8be4e7d5..c5fd6b45 100644 --- a/src/tensorwaves/data/__init__.py +++ b/src/tensorwaves/data/__init__.py @@ -71,7 +71,7 @@ class IntensityDistributionGenerator(DataGenerator): def __init__( self, domain_generator: DataGenerator, - function: Function, + function: Function[DataSample, np.ndarray], domain_transformer: DataTransformer | None = None, bunch_size: int = 50_000, ) -> None: diff --git a/src/tensorwaves/data/transform.py b/src/tensorwaves/data/transform.py index 2fea0154..42493d50 100644 --- a/src/tensorwaves/data/transform.py +++ b/src/tensorwaves/data/transform.py @@ -16,6 +16,7 @@ from ._attrs import to_tuple if TYPE_CHECKING: # pragma: no cover + import numpy as np import sympy as sp @@ -55,7 +56,9 @@ def __call__(self, data: DataSample) -> DataSample: class SympyDataTransformer(DataTransformer): """Implementation of a `.DataTransformer`.""" - def __init__(self, functions: Mapping[str, Function]) -> None: + def __init__( + self, functions: Mapping[str, Function[DataSample, np.ndarray]] + ) -> None: if any(not isinstance(f, Function) for f in functions.values()): msg = ( f"Not all values in the mapping are an instance of {Function.__name__}" @@ -64,7 +67,7 @@ def __init__(self, functions: Mapping[str, Function]) -> None: self.__functions = dict(functions) @property - def functions(self) -> dict[str, Function]: + def functions(self) -> dict[str, Function[DataSample, np.ndarray]]: """Read-only access to the internal mapping of functions.""" return dict(self.__functions) diff --git a/src/tensorwaves/estimator.py b/src/tensorwaves/estimator.py index 7f75b54b..8ce98a30 100644 --- a/src/tensorwaves/estimator.py +++ b/src/tensorwaves/estimator.py @@ -118,7 +118,7 @@ class ChiSquared(Estimator): def __init__( # noqa: PLR0913 self, - function: ParametrizedFunction, + function: ParametrizedFunction[DataSample, np.ndarray], domain: DataSample, observed_values: np.ndarray, weights: np.ndarray | None = None, diff --git a/src/tensorwaves/function/__init__.py b/src/tensorwaves/function/__init__.py index 39263092..8453c7b5 100644 --- a/src/tensorwaves/function/__init__.py +++ b/src/tensorwaves/function/__init__.py @@ -3,9 +3,10 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Callable, Iterable, Mapping +from typing import Callable, Iterable, Mapping import attrs +import numpy as np from attrs import field, frozen from tensorwaves.interface import ( @@ -15,9 +16,6 @@ ParametrizedFunction, ) -if TYPE_CHECKING: - import numpy as np - def _all_str( _: PositionalArgumentFunction, __: attrs.Attribute, value: Iterable[str] @@ -66,7 +64,7 @@ def _to_tuple(argument_order: Iterable[str]) -> tuple[str, ...]: @frozen -class PositionalArgumentFunction(Function): +class PositionalArgumentFunction(Function[DataSample, np.ndarray]): """Wrapper around a function with positional arguments. This class provides a :meth:`~.Function.__call__` that can take a `.DataSample` for @@ -90,7 +88,7 @@ def __call__(self, data: DataSample) -> np.ndarray: return self.function(*args) -class ParametrizedBackendFunction(ParametrizedFunction): +class ParametrizedBackendFunction(ParametrizedFunction[DataSample, np.ndarray]): """Implements `.ParametrizedFunction` for a specific computational back-end. .. seealso:: :func:`.create_parametrized_function` diff --git a/src/tensorwaves/interface.py b/src/tensorwaves/interface.py index 3e0777d3..c4064936 100644 --- a/src/tensorwaves/interface.py +++ b/src/tensorwaves/interface.py @@ -41,7 +41,10 @@ def __call__(self, data: InputType) -> OutputType: ... """Allowed types for parameter values.""" -class ParametrizedFunction(Function[DataSample, np.ndarray]): +class ParametrizedFunction( + Function[InputType, OutputType], + Generic[InputType, OutputType], +): """Interface of a callable function. A `ParametrizedFunction` identifies certain variables in a mathematical expression diff --git a/tests/optimizer/test_fit_simple_model.py b/tests/optimizer/test_fit_simple_model.py index 5ff89525..52e8f79a 100644 --- a/tests/optimizer/test_fit_simple_model.py +++ b/tests/optimizer/test_fit_simple_model.py @@ -38,7 +38,7 @@ def generate_domain( def generate_data( size: int, boundaries: dict[str, tuple[float, float]], - function: Function, + function: Function[DataSample, np.ndarray], rng: np.random.Generator, bunch_size: int = 10_000, ) -> DataSample: