diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5994be09c7..d6a3e7fcb7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,12 @@
# Changelog
+### 35.3.7 [#990](https://github.com/openfisca/openfisca-core/pull/990)
+
+#### Technical changes
+
+- Update dependencies.
+ - Extend NumPy compatibility to v1.20 to support M1 processors.
+
### 35.3.6 [#984](https://github.com/openfisca/openfisca-core/pull/984)
#### Technical changes
diff --git a/openfisca_core/indexed_enums/enum.py b/openfisca_core/indexed_enums/enum.py
index 7ba1f78b41..3fbb98f6d5 100644
--- a/openfisca_core/indexed_enums/enum.py
+++ b/openfisca_core/indexed_enums/enum.py
@@ -7,23 +7,28 @@
from openfisca_core.indexed_enums import config, EnumArray
+if typing.TYPE_CHECKING:
+ IndexedEnumArray = numpy.object_
+
class Enum(enum.Enum):
"""
- Enum based on `enum34 `_, whose items have an
- index.
+ Enum based on `enum34 `_, whose items
+ have an index.
"""
# Tweak enums to add an index attribute to each enum item
def __init__(self, name: str) -> None:
- # When the enum item is initialized, self._member_names_ contains the names of
- # the previously initialized items, so its length is the index of this item.
+ # When the enum item is initialized, self._member_names_ contains the
+ # names of the previously initialized items, so its length is the index
+ # of this item.
self.index = len(self._member_names_)
# Bypass the slow Enum.__eq__
__eq__ = object.__eq__
- # In Python 3, __hash__ must be defined if __eq__ is defined to stay hashable.
+ # In Python 3, __hash__ must be defined if __eq__ is defined to stay
+ # hashable.
__hash__ = object.__hash__
@classmethod
@@ -31,16 +36,18 @@ def encode(
cls,
array: typing.Union[
EnumArray,
- numpy.ndarray[int],
- numpy.ndarray[str],
- numpy.ndarray[Enum],
+ numpy.int_,
+ numpy.float_,
+ IndexedEnumArray,
],
) -> EnumArray:
"""
- Encode a string numpy array, an enum item numpy array, or an int numpy array
- into an :any:`EnumArray`. See :any:`EnumArray.decode` for decoding.
+ Encode a string numpy array, an enum item numpy array, or an int numpy
+ array into an :any:`EnumArray`. See :any:`EnumArray.decode` for
+ decoding.
- :param ndarray array: Array of string identifiers, or of enum items, to encode.
+ :param ndarray array: Array of string identifiers, or of enum items, to
+ encode.
:returns: An :any:`EnumArray` encoding the input array values.
:rtype: :any:`EnumArray`
@@ -59,24 +66,31 @@ def encode(
>>> encoded_array[0]
2 # Encoded value
"""
- if type(array) is EnumArray:
+ if isinstance(array, EnumArray):
return array
- if array.dtype.kind in {'U', 'S'}: # String array
+ # String array
+ if isinstance(array, numpy.ndarray) and \
+ array.dtype.kind in {'U', 'S'}:
array = numpy.select(
[array == item.name for item in cls],
[item.index for item in cls],
).astype(config.ENUM_ARRAY_DTYPE)
- elif array.dtype.kind == 'O': # Enum items arrays
+ # Enum items arrays
+ elif isinstance(array, numpy.ndarray) and \
+ array.dtype.kind == 'O':
# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
- # variable.possible_values, while the array values may come from directly
- # importing a module containing an Enum class. However, variables (and
- # hence their possible_values) are loaded by a call to load_module, which
- # gives them a different identity from the ones imported in the usual way.
- # So, instead of relying on the "cls" passed in, we use only its name to
- # check that the values in the array, if non-empty, are of the right type.
+ # variable.possible_values, while the array values may come from
+ # directly importing a module containing an Enum class. However,
+ # variables (and hence their possible_values) are loaded by a call
+ # to load_module, which gives them a different identity from the
+ # ones imported in the usual way.
+ #
+ # So, instead of relying on the "cls" passed in, we use only its
+ # name to check that the values in the array, if non-empty, are of
+ # the right type.
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
cls = array[0].__class__
diff --git a/openfisca_core/indexed_enums/enum_array.py b/openfisca_core/indexed_enums/enum_array.py
index 278c0bd289..90ab35922a 100644
--- a/openfisca_core/indexed_enums/enum_array.py
+++ b/openfisca_core/indexed_enums/enum_array.py
@@ -7,6 +7,8 @@
if typing.TYPE_CHECKING:
from openfisca_core.indexed_enums import Enum
+ IndexedEnumArray = numpy.object_
+
class EnumArray(numpy.ndarray):
"""
@@ -20,7 +22,7 @@ class EnumArray(numpy.ndarray):
# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array.
def __new__(
cls,
- input_array: numpy.ndarray[int],
+ input_array: numpy.int_,
possible_values: typing.Optional[typing.Type[Enum]] = None,
) -> EnumArray:
obj = numpy.asarray(input_array).view(cls)
@@ -28,15 +30,15 @@ def __new__(
return obj
# See previous comment
- def __array_finalize__(self, obj: typing.Optional[numpy.ndarray[int]]) -> None:
+ def __array_finalize__(self, obj: typing.Optional[numpy.int_]) -> None:
if obj is None:
return
self.possible_values = getattr(obj, "possible_values", None)
def __eq__(self, other: typing.Any) -> bool:
- # When comparing to an item of self.possible_values, use the item index to
- # speed up the comparison.
+ # When comparing to an item of self.possible_values, use the item index
+ # to speed up the comparison.
if other.__class__.__name__ is self.possible_values.__name__:
# Use view(ndarray) so that the result is a classic ndarray, not an
# EnumArray.
@@ -49,8 +51,8 @@ def __ne__(self, other: typing.Any) -> bool:
def _forbidden_operation(self, other: typing.Any) -> typing.NoReturn:
raise TypeError(
- "Forbidden operation. The only operations allowed on EnumArrays are "
- "'==' and '!='.",
+ "Forbidden operation. The only operations allowed on EnumArrays "
+ "are '==' and '!='.",
)
__add__ = _forbidden_operation
@@ -62,7 +64,7 @@ def _forbidden_operation(self, other: typing.Any) -> typing.NoReturn:
__and__ = _forbidden_operation
__or__ = _forbidden_operation
- def decode(self) -> numpy.ndarray[Enum]:
+ def decode(self) -> IndexedEnumArray:
"""
Return the array of enum items corresponding to self.
@@ -72,14 +74,16 @@ def decode(self) -> numpy.ndarray[Enum]:
>>> enum_array[0]
>>> 2 # Encoded value
>>> enum_array.decode()[0]
- # Decoded value : enum item
+
+
+ Decoded value: enum item
"""
return numpy.select(
[self == item.index for item in self.possible_values],
list(self.possible_values),
)
- def decode_to_str(self) -> numpy.ndarray[str]:
+ def decode_to_str(self) -> numpy.str_:
"""
Return the array of string identifiers corresponding to self.
diff --git a/openfisca_core/taxscales/abstract_rate_tax_scale.py b/openfisca_core/taxscales/abstract_rate_tax_scale.py
index bc972824f1..b9316273d1 100644
--- a/openfisca_core/taxscales/abstract_rate_tax_scale.py
+++ b/openfisca_core/taxscales/abstract_rate_tax_scale.py
@@ -3,28 +3,36 @@
import typing
import warnings
-import numpy
-
from openfisca_core.taxscales import RateTaxScaleLike
+if typing.TYPE_CHECKING:
+ import numpy
+
+ NumericalArray = typing.Union[numpy.int_, numpy.float_]
+
class AbstractRateTaxScale(RateTaxScaleLike):
"""
- Base class for various types of rate-based tax scales: marginal rate, linear
- average rate...
+ Base class for various types of rate-based tax scales: marginal rate,
+ linear average rate...
"""
- def __init__(self, name: typing.Optional[str] = None, option = None, unit = None) -> None:
+ def __init__(
+ self, name: typing.Optional[str] = None,
+ option: typing.Any = None,
+ unit: typing.Any = None,
+ ) -> None:
message = [
- "The 'AbstractRateTaxScale' class has been deprecated since version",
- "34.7.0, and will be removed in the future.",
+ "The 'AbstractRateTaxScale' class has been deprecated since",
+ "version 34.7.0, and will be removed in the future.",
]
+
warnings.warn(" ".join(message), DeprecationWarning)
- super(AbstractRateTaxScale, self).__init__(name, option, unit)
+ super().__init__(name, option, unit)
def calc(
self,
- tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
+ tax_base: NumericalArray,
right: bool,
) -> typing.NoReturn:
raise NotImplementedError(
diff --git a/openfisca_core/taxscales/abstract_tax_scale.py b/openfisca_core/taxscales/abstract_tax_scale.py
index 9d5f23538b..9cbeeb7565 100644
--- a/openfisca_core/taxscales/abstract_tax_scale.py
+++ b/openfisca_core/taxscales/abstract_tax_scale.py
@@ -3,24 +3,34 @@
import typing
import warnings
-import numpy
-
from openfisca_core.taxscales import TaxScaleLike
+if typing.TYPE_CHECKING:
+ import numpy
+
+ NumericalArray = typing.Union[numpy.int_, numpy.float_]
+
class AbstractTaxScale(TaxScaleLike):
"""
- Base class for various types of tax scales: amount-based tax scales, rate-based
- tax scales...
+ Base class for various types of tax scales: amount-based tax scales,
+ rate-based tax scales...
"""
- def __init__(self, name: typing.Optional[str] = None, option = None, unit = None) -> None:
+ def __init__(
+ self,
+ name: typing.Optional[str] = None,
+ option: typing.Any = None,
+ unit: numpy.int_ = None,
+ ) -> None:
+
message = [
- "The 'AbstractTaxScale' class has been deprecated since version 34.7.0,",
- "and will be removed in the future.",
+ "The 'AbstractTaxScale' class has been deprecated since",
+ "version 34.7.0, and will be removed in the future.",
]
+
warnings.warn(" ".join(message), DeprecationWarning)
- super(AbstractTaxScale, self).__init__(name, option, unit)
+ super().__init__(name, option, unit)
def __repr__(self) -> typing.NoReturn:
raise NotImplementedError(
@@ -30,7 +40,7 @@ def __repr__(self) -> typing.NoReturn:
def calc(
self,
- tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
+ tax_base: NumericalArray,
right: bool,
) -> typing.NoReturn:
raise NotImplementedError(
diff --git a/openfisca_core/taxscales/amount_tax_scale_like.py b/openfisca_core/taxscales/amount_tax_scale_like.py
index 465635bf0b..cfc0a6973f 100644
--- a/openfisca_core/taxscales/amount_tax_scale_like.py
+++ b/openfisca_core/taxscales/amount_tax_scale_like.py
@@ -9,14 +9,19 @@
class AmountTaxScaleLike(TaxScaleLike, abc.ABC):
"""
- Base class for various types of amount-based tax scales: single amount, marginal
- amount...
+ Base class for various types of amount-based tax scales: single amount,
+ marginal amount...
"""
amounts: typing.List
- def __init__(self, name: typing.Optional[str] = None, option = None, unit = None) -> None:
- super(AmountTaxScaleLike, self).__init__(name, option, unit)
+ def __init__(
+ self,
+ name: typing.Optional[str] = None,
+ option: typing.Any = None,
+ unit: typing.Any = None,
+ ) -> None:
+ super().__init__(name, option, unit)
self.amounts = []
def __repr__(self) -> str:
@@ -24,7 +29,8 @@ def __repr__(self) -> str:
os.linesep.join(
[
f"- threshold: {threshold}{os.linesep} amount: {amount}"
- for (threshold, amount) in zip(self.thresholds, self.amounts)
+ for (threshold, amount)
+ in zip(self.thresholds, self.amounts)
]
)
)
@@ -37,6 +43,7 @@ def add_bracket(
if threshold in self.thresholds:
i = self.thresholds.index(threshold)
self.amounts[i] += amount
+
else:
i = bisect.bisect_left(self.thresholds, threshold)
self.thresholds.insert(i, threshold)
@@ -45,5 +52,6 @@ def add_bracket(
def to_dict(self) -> dict:
return {
str(threshold): self.amounts[index]
- for index, threshold in enumerate(self.thresholds)
+ for index, threshold
+ in enumerate(self.thresholds)
}
diff --git a/openfisca_core/taxscales/helpers.py b/openfisca_core/taxscales/helpers.py
index a1b49e197e..181fbfed36 100644
--- a/openfisca_core/taxscales/helpers.py
+++ b/openfisca_core/taxscales/helpers.py
@@ -10,11 +10,13 @@
if typing.TYPE_CHECKING:
from openfisca_core.parameters import ParameterNodeAtInstant
+ TaxScales = typing.Optional[taxscales.MarginalRateTaxScale]
+
def combine_tax_scales(
node: ParameterNodeAtInstant,
- combined_tax_scales: typing.Optional[taxscales.MarginalRateTaxScale] = None,
- ) -> typing.Optional[taxscales.MarginalRateTaxScale]:
+ combined_tax_scales: TaxScales = None,
+ ) -> TaxScales:
"""
Combine all the MarginalRateTaxScales in the node into a single
MarginalRateTaxScale.
diff --git a/openfisca_core/taxscales/linear_average_rate_tax_scale.py b/openfisca_core/taxscales/linear_average_rate_tax_scale.py
index 885d29d8cb..d1fe9c8094 100644
--- a/openfisca_core/taxscales/linear_average_rate_tax_scale.py
+++ b/openfisca_core/taxscales/linear_average_rate_tax_scale.py
@@ -10,14 +10,17 @@
log = logging.getLogger(__name__)
+if typing.TYPE_CHECKING:
+ NumericalArray = typing.Union[numpy.int_, numpy.float_]
+
class LinearAverageRateTaxScale(RateTaxScaleLike):
def calc(
self,
- tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
+ tax_base: NumericalArray,
right: bool = False,
- ) -> numpy.ndarray[float]:
+ ) -> numpy.float_:
if len(self.rates) == 1:
return tax_base * self.rates[0]
diff --git a/openfisca_core/taxscales/marginal_amount_tax_scale.py b/openfisca_core/taxscales/marginal_amount_tax_scale.py
index 9208813d47..348f2445c0 100644
--- a/openfisca_core/taxscales/marginal_amount_tax_scale.py
+++ b/openfisca_core/taxscales/marginal_amount_tax_scale.py
@@ -6,19 +6,29 @@
from openfisca_core.taxscales import AmountTaxScaleLike
+if typing.TYPE_CHECKING:
+ NumericalArray = typing.Union[numpy.int_, numpy.float_]
+
class MarginalAmountTaxScale(AmountTaxScaleLike):
def calc(
self,
- tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
+ tax_base: NumericalArray,
right: bool = False,
- ) -> numpy.ndarray[float]:
+ ) -> numpy.float_:
"""
- Matches the input amount to a set of brackets and returns the sum of cell
- values from the lowest bracket to the one containing the input.
+ Matches the input amount to a set of brackets and returns the sum of
+ cell values from the lowest bracket to the one containing the input.
"""
base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T
- thresholds1 = numpy.tile(numpy.hstack((self.thresholds, numpy.inf)), (len(tax_base), 1))
- a = numpy.maximum(numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0)
+
+ thresholds1 = numpy.tile(
+ numpy.hstack((self.thresholds, numpy.inf)), (len(tax_base), 1)
+ )
+
+ a = numpy.maximum(
+ numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0
+ )
+
return numpy.dot(self.amounts, a.T > 0)
diff --git a/openfisca_core/taxscales/marginal_rate_tax_scale.py b/openfisca_core/taxscales/marginal_rate_tax_scale.py
index e3749209ed..38331e0bb8 100644
--- a/openfisca_core/taxscales/marginal_rate_tax_scale.py
+++ b/openfisca_core/taxscales/marginal_rate_tax_scale.py
@@ -9,11 +9,14 @@
from openfisca_core import taxscales
from openfisca_core.taxscales import RateTaxScaleLike
+if typing.TYPE_CHECKING:
+ NumericalArray = typing.Union[numpy.int_, numpy.float_]
+
class MarginalRateTaxScale(RateTaxScaleLike):
def add_tax_scale(self, tax_scale: RateTaxScaleLike) -> None:
- # Pour ne pas avoir de problèmes avec les barèmes vides
+ # So as not to have problems with empty scales
if (len(tax_scale.thresholds) > 0):
for threshold_low, threshold_high, rate in zip(
tax_scale.thresholds[:-1],
@@ -22,7 +25,7 @@ def add_tax_scale(self, tax_scale: RateTaxScaleLike) -> None:
):
self.combine_bracket(rate, threshold_low, threshold_high)
- # Pour traiter le dernier threshold
+ # To process the last threshold
self.combine_bracket(
tax_scale.rates[-1],
tax_scale.thresholds[-1],
@@ -30,16 +33,17 @@ def add_tax_scale(self, tax_scale: RateTaxScaleLike) -> None:
def calc(
self,
- tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
+ tax_base: NumericalArray,
factor: float = 1.0,
round_base_decimals: typing.Optional[int] = None,
- ) -> numpy.ndarray[float]:
+ ) -> numpy.float_:
"""
- Compute the tax amount for the given tax bases by applying the taxscale.
+ Compute the tax amount for the given tax bases by applying a taxscale.
:param ndarray tax_base: Array of the tax bases.
- :param float factor: Factor to apply to the thresholds of the tax scale.
- :param int round_base_decimals: Decimals to keep when rounding thresholds.
+ :param float factor: Factor to apply to the thresholds of the taxscale.
+ :param int round_base_decimals: Decimals to keep when rounding
+ thresholds.
:returns: Float array with tax amount for the given tax bases.
@@ -52,20 +56,31 @@ def calc(
>>> tax_scale.calc(tax_base)
[0.0, 5.0]
"""
-
base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T
factor = numpy.ones(len(tax_base)) * factor
- # finfo(float_).eps is used to avoid nan = 0 * inf creation
- thresholds1 = numpy.outer(factor + numpy.finfo(numpy.float_).eps, numpy.array(self.thresholds + [numpy.inf]))
+ # To avoid the creation of:
+ #
+ # numpy.nan = 0 * numpy.inf
+ #
+ # We use:
+ #
+ # numpy.finfo(float_).eps
+ thresholds1 = numpy.outer(
+ factor + numpy.finfo(numpy.float_).eps,
+ numpy.array(self.thresholds + [numpy.inf]),
+ )
if round_base_decimals is not None:
thresholds1 = numpy.round_(thresholds1, round_base_decimals)
- a = numpy.maximum(numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0)
+ a = numpy.maximum(
+ numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0
+ )
if round_base_decimals is None:
return numpy.dot(self.rates, a.T)
+
else:
r = numpy.tile(self.rates, (len(tax_base), 1))
b = numpy.round_(a, round_base_decimals)
@@ -91,26 +106,30 @@ def combine_bracket(
if threshold_high:
j = self.thresholds.index(threshold_high) - 1
+
else:
j = len(self.thresholds) - 1
+
while i <= j:
self.add_bracket(self.thresholds[i], rate)
i += 1
def marginal_rates(
self,
- tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
+ tax_base: NumericalArray,
factor: float = 1.0,
round_base_decimals: typing.Optional[int] = None,
- ) -> numpy.ndarray[float]:
+ ) -> numpy.float_:
"""
Compute the marginal tax rates relevant for the given tax bases.
:param ndarray tax_base: Array of the tax bases.
- :param float factor: Factor to apply to the thresholds of the tax scale.
- :param int round_base_decimals: Decimals to keep when rounding thresholds.
+ :param float factor: Factor to apply to the thresholds of a tax scale.
+ :param int round_base_decimals: Decimals to keep when rounding
+ thresholds.
- :returns: Float array with relevant marginal tax rate for the given tax bases.
+ :returns: Float array with relevant marginal tax rate for the given tax
+ bases.
For instance:
@@ -135,29 +154,29 @@ def inverse(self) -> MarginalRateTaxScale:
Invert a taxscale:
- Assume tax_scale composed of bracket which thresholds are expressed in term
- of brut revenue.
+ Assume tax_scale composed of bracket whose thresholds are expressed
+ in terms of gross revenue.
- The inverse is another MarginalTaxSclae which thresholds are expressed in
- terms of net revenue.
+ The inverse is another MarginalRateTaxScale whose thresholds are
+ expressed in terms of net revenue.
- IF net = revbrut - tax_scale.calc(revbrut)
- THEN brut = tax_scale.inverse().calc(net)
+ If net = gross_revenue - tax_scale.calc(gross_revenue)
+ Then gross = tax_scale.inverse().calc(net)
"""
# Threshold of net revenue.
net_threshold: int = 0
- # Threshold of brut revenue.
+ # Threshold of gross revenue.
threshold: int
- # Ordonnée à l'origine des segments des différents seuils dans une
- # représentation du revenu imposable comme fonction linéaire par morceaux du
- # revenu brut.
+ # The intercept of the segments of the different thresholds in a
+ # representation of taxable revenue as a piecewise linear function
+ # of gross revenue.
theta: int
- # Actually 1 / (1- global_rate)
+ # Actually 1 / (1 - global_rate)
inverse = self.__class__(
- name = self.name + "'",
+ name = str(self.name) + "'",
option = self.option,
unit = self.unit,
)
@@ -167,7 +186,8 @@ def inverse(self) -> MarginalRateTaxScale:
previous_rate = 0
theta = 0
- # On calcule le seuil de revenu imposable de la tranche considérée.
+ # We calculate the taxable revenue threshold of the considered
+ # bracket.
net_threshold = (1 - previous_rate) * threshold + theta
inverse.add_bracket(net_threshold, 1 / (1 - rate))
theta = (rate - previous_rate) * threshold + theta
diff --git a/openfisca_core/taxscales/rate_tax_scale_like.py b/openfisca_core/taxscales/rate_tax_scale_like.py
index ba21e3007a..824a94debe 100644
--- a/openfisca_core/taxscales/rate_tax_scale_like.py
+++ b/openfisca_core/taxscales/rate_tax_scale_like.py
@@ -11,17 +11,25 @@
from openfisca_core.errors import EmptyArgumentError
from openfisca_core.taxscales import TaxScaleLike
+if typing.TYPE_CHECKING:
+ NumericalArray = typing.Union[numpy.int_, numpy.float_]
+
class RateTaxScaleLike(TaxScaleLike, abc.ABC):
"""
- Base class for various types of rate-based tax scales: marginal rate, linear
- average rate...
+ Base class for various types of rate-based tax scales: marginal rate,
+ linear average rate...
"""
rates: typing.List
- def __init__(self, name: typing.Optional[str] = None, option = None, unit = None) -> None:
- super(RateTaxScaleLike, self).__init__(name, option, unit)
+ def __init__(
+ self,
+ name: typing.Optional[str] = None,
+ option: typing.Any = None,
+ unit: typing.Any = None,
+ ) -> None:
+ super().__init__(name, option, unit)
self.rates = []
def __repr__(self) -> str:
@@ -29,7 +37,8 @@ def __repr__(self) -> str:
os.linesep.join(
[
f"- threshold: {threshold}{os.linesep} rate: {rate}"
- for (threshold, rate) in zip(self.thresholds, self.rates)
+ for (threshold, rate)
+ in zip(self.thresholds, self.rates)
]
)
)
@@ -42,6 +51,7 @@ def add_bracket(
if threshold in self.thresholds:
i = self.thresholds.index(threshold)
self.rates[i] += rate
+
else:
i = bisect.bisect_left(self.thresholds, threshold)
self.thresholds.insert(i, threshold)
@@ -85,7 +95,11 @@ def multiply_thresholds(
for i, threshold in enumerate(self.thresholds):
if decimals is not None:
- self.thresholds[i] = numpy.around(threshold * factor, decimals = decimals)
+ self.thresholds[i] = numpy.around(
+ threshold * factor,
+ decimals = decimals,
+ )
+
else:
self.thresholds[i] = threshold * factor
@@ -111,18 +125,19 @@ def multiply_thresholds(
def bracket_indices(
self,
- tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
+ tax_base: NumericalArray,
factor: float = 1.0,
round_decimals: typing.Optional[int] = None,
- ) -> numpy.ndarray[int]:
+ ) -> numpy.int_:
"""
Compute the relevant bracket indices for the given tax bases.
:param ndarray tax_base: Array of the tax bases.
- :param float factor: Factor to apply to the thresholds of the tax scales.
+ :param float factor: Factor to apply to the thresholds.
:param int round_decimals: Decimals to keep when rounding thresholds.
- :returns: Int array with relevant bracket indices for the given tax bases.
+ :returns: Integer array with relevant bracket indices for the given tax
+ bases.
For instance:
@@ -153,8 +168,17 @@ def bracket_indices(
base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T
factor = numpy.ones(len(tax_base)) * factor
- # finfo(float_).eps is used to avoid nan = 0 * inf creation
- thresholds1 = numpy.outer(factor + numpy.finfo(numpy.float_).eps, numpy.array(self.thresholds))
+ # To avoid the creation of:
+ #
+ # numpy.nan = 0 * numpy.inf
+ #
+ # We use:
+ #
+ # numpy.finfo(float_).eps
+ thresholds1 = numpy.outer(
+ + factor
+ + numpy.finfo(numpy.float_).eps, numpy.array(self.thresholds)
+ )
if round_decimals is not None:
thresholds1 = numpy.round_(thresholds1, round_decimals)
@@ -164,5 +188,6 @@ def bracket_indices(
def to_dict(self) -> dict:
return {
str(threshold): self.rates[index]
- for index, threshold in enumerate(self.thresholds)
+ for index, threshold
+ in enumerate(self.thresholds)
}
diff --git a/openfisca_core/taxscales/single_amount_tax_scale.py b/openfisca_core/taxscales/single_amount_tax_scale.py
index 891614adee..bdfee48010 100644
--- a/openfisca_core/taxscales/single_amount_tax_scale.py
+++ b/openfisca_core/taxscales/single_amount_tax_scale.py
@@ -6,19 +6,37 @@
from openfisca_core.taxscales import AmountTaxScaleLike
+if typing.TYPE_CHECKING:
+ NumericalArray = typing.Union[numpy.int_, numpy.float_]
+
class SingleAmountTaxScale(AmountTaxScaleLike):
def calc(
self,
- tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
+ tax_base: NumericalArray,
right: bool = False,
- ) -> numpy.ndarray[float]:
+ ) -> numpy.float_:
"""
- Matches the input amount to a set of brackets and returns the single cell value
- that fits within that bracket.
+ Matches the input amount to a set of brackets and returns the single
+ cell value that fits within that bracket.
"""
- guarded_thresholds = numpy.array([-numpy.inf] + self.thresholds + [numpy.inf])
- bracket_indices = numpy.digitize(tax_base, guarded_thresholds, right = right)
- guarded_amounts = numpy.array([0] + self.amounts + [0])
+ guarded_thresholds = numpy.array(
+ [-numpy.inf]
+ + self.thresholds
+ + [numpy.inf]
+ )
+
+ bracket_indices = numpy.digitize(
+ tax_base,
+ guarded_thresholds,
+ right = right,
+ )
+
+ guarded_amounts = numpy.array(
+ [0]
+ + self.amounts
+ + [0]
+ )
+
return guarded_amounts[bracket_indices - 1]
diff --git a/openfisca_core/taxscales/tax_scale_like.py b/openfisca_core/taxscales/tax_scale_like.py
index 513e0bcdfd..8177ee0505 100644
--- a/openfisca_core/taxscales/tax_scale_like.py
+++ b/openfisca_core/taxscales/tax_scale_like.py
@@ -8,20 +8,28 @@
from openfisca_core import commons
+if typing.TYPE_CHECKING:
+ NumericalArray = typing.Union[numpy.int_, numpy.float_]
+
class TaxScaleLike(abc.ABC):
"""
- Base class for various types of tax scales: amount-based tax scales, rate-based
- tax scales...
+ Base class for various types of tax scales: amount-based tax scales,
+ rate-based tax scales...
"""
- name: str
- option: None
- unit: None
+ name: typing.Optional[str]
+ option: typing.Any
+ unit: typing.Any
thresholds: typing.List
@abc.abstractmethod
- def __init__(self, name: typing.Optional[str] = None, option = None, unit = None) -> None:
+ def __init__(
+ self,
+ name: typing.Optional[str] = None,
+ option: typing.Any = None,
+ unit: typing.Any = None,
+ ) -> None:
self.name = name or "Untitled TaxScale"
self.option = option
self.unit = unit
@@ -46,9 +54,9 @@ def __repr__(self) -> str:
@abc.abstractmethod
def calc(
self,
- tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
+ tax_base: NumericalArray,
right: bool,
- ) -> numpy.ndarray[float]:
+ ) -> numpy.float_:
...
@abc.abstractmethod
diff --git a/openfisca_core/tracers/computation_log.py b/openfisca_core/tracers/computation_log.py
index 7fe4d23e81..51baa4f439 100644
--- a/openfisca_core/tracers/computation_log.py
+++ b/openfisca_core/tracers/computation_log.py
@@ -1,61 +1,101 @@
+from __future__ import annotations
+
import typing
import numpy
+import numpy.typing
from openfisca_core.indexed_enums import EnumArray
+if typing.TYPE_CHECKING:
+ from openfisca_core.tracers import FullTracer, TraceNode
+
+ Array = typing.Union[EnumArray, numpy.typing.ArrayLike]
+
class ComputationLog:
- def __init__(self, full_tracer):
+ _full_tracer: FullTracer
+
+ def __init__(self, full_tracer: FullTracer) -> None:
self._full_tracer = full_tracer
- def display(self, value):
+ def display(
+ self,
+ value: typing.Optional[Array],
+ ) -> str:
if isinstance(value, EnumArray):
value = value.decode_to_str()
return numpy.array2string(value, max_line_width = float("inf"))
- def _get_node_log(self, node, depth, aggregate) -> typing.List[str]:
+ def _get_node_log(
+ self,
+ node: TraceNode,
+ depth: int,
+ aggregate: bool,
+ ) -> typing.List[str]:
- def print_line(depth, node) -> str:
+ def print_line(depth: int, node: TraceNode) -> str:
+ indent = ' ' * depth
value = node.value
- if aggregate:
+
+ if value is None:
+ formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}"
+
+ elif aggregate:
try:
- formatted_value = str({'avg': numpy.mean(value), 'max': numpy.max(value), 'min': numpy.min(value)})
+ formatted_value = str({
+ 'avg': numpy.mean(value),
+ 'max': numpy.max(value),
+ 'min': numpy.min(value),
+ })
+
except TypeError:
formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}"
+
else:
formatted_value = self.display(value)
- return "{}{}<{}> >> {}".format(' ' * depth, node.name, node.period, formatted_value)
-
- # if not self.trace.get(node):
- # return print_line(depth, node, "Calculation aborted due to a circular dependency")
+ return f"{indent}{node.name}<{node.period}> >> {formatted_value}"
node_log = [print_line(depth, node)]
- children_logs = self._flatten(
+
+ children_logs = [
self._get_node_log(child, depth + 1, aggregate)
- for child in node.children
- )
+ for child
+ in node.children
+ ]
- return node_log + children_logs
+ return node_log + self._flatten(children_logs)
- def _flatten(self, list_of_lists):
+ def _flatten(
+ self,
+ list_of_lists: typing.List[typing.List[str]],
+ ) -> typing.List[str]:
return [item for _list in list_of_lists for item in _list]
- def lines(self, aggregate = False) -> typing.List[str]:
+ def lines(self, aggregate: bool = False) -> typing.List[str]:
depth = 1
- lines_by_tree = [self._get_node_log(node, depth, aggregate) for node in self._full_tracer.trees]
+
+ lines_by_tree = [
+ self._get_node_log(node, depth, aggregate)
+ for node
+ in self._full_tracer.trees
+ ]
+
return self._flatten(lines_by_tree)
- def print_log(self, aggregate = False):
+ def print_log(self, aggregate = False) -> None:
"""
Print the computation log of a simulation.
- If ``aggregate`` is ``False`` (default), print the value of each computed vector.
+ If ``aggregate`` is ``False`` (default), print the value of each
+ computed vector.
+
+ If ``aggregate`` is ``True``, only print the minimum, maximum, and
+ average value of each computed vector.
- If ``aggregate`` is ``True``, only print the minimum, maximum, and average value of each computed vector.
This mode is more suited for simulations on a large population.
"""
for line in self.lines(aggregate):
diff --git a/openfisca_core/tracers/flat_trace.py b/openfisca_core/tracers/flat_trace.py
index bbd9b1fa7d..362c386137 100644
--- a/openfisca_core/tracers/flat_trace.py
+++ b/openfisca_core/tracers/flat_trace.py
@@ -3,16 +3,22 @@
import typing
import numpy
+import numpy.typing
from openfisca_core.indexed_enums import EnumArray
if typing.TYPE_CHECKING:
- from openfisca_core.tracers import TraceNode
+ from openfisca_core.tracers import TraceNode, FullTracer
+
+ Array = typing.Union[EnumArray, numpy.typing.ArrayLike]
+ Trace = typing.Dict[str, dict]
class FlatTrace:
- def __init__(self, full_tracer):
+ _full_tracer: FullTracer
+
+ def __init__(self, full_tracer: FullTracer) -> None:
self._full_tracer = full_tracer
def key(self, node: TraceNode) -> str:
@@ -20,17 +26,23 @@ def key(self, node: TraceNode) -> str:
period = node.period
return f"{name}<{period}>"
- def get_trace(self):
+ def get_trace(self) -> dict:
trace = {}
+
for node in self._full_tracer.browse_trace():
- trace.update({ # We don't want cache read to overwrite data about the initial calculation. We therefore use a non-overwriting update.
+ # We don't want cache read to overwrite data about the initial
+ # calculation.
+ #
+ # We therefore use a non-overwriting update.
+ trace.update({
key: node_trace
for key, node_trace in self._get_flat_trace(node).items()
if key not in trace
})
+
return trace
- def get_serialized_trace(self):
+ def get_serialized_trace(self) -> dict:
return {
key: {
**flat_trace,
@@ -39,16 +51,26 @@ def get_serialized_trace(self):
for key, flat_trace in self.get_trace().items()
}
- def serialize(self, value: numpy.ndarray) -> typing.List:
+ def serialize(
+ self,
+ value: typing.Optional[Array],
+ ) -> typing.Union[typing.Optional[Array], list]:
if isinstance(value, EnumArray):
value = value.decode_to_str()
- if isinstance(value, numpy.ndarray) and numpy.issubdtype(value.dtype, numpy.dtype(bytes)):
+
+ if isinstance(value, numpy.ndarray) and \
+ numpy.issubdtype(value.dtype, numpy.dtype(bytes)):
value = value.astype(numpy.dtype(str))
+
if isinstance(value, numpy.ndarray):
value = value.tolist()
+
return value
- def _get_flat_trace(self, node: TraceNode) -> typing.Dict[str, typing.Dict]:
+ def _get_flat_trace(
+ self,
+ node: TraceNode,
+ ) -> Trace:
key = self.key(node)
node_trace = {
@@ -57,11 +79,15 @@ def _get_flat_trace(self, node: TraceNode) -> typing.Dict[str, typing.Dict]:
self.key(child) for child in node.children
],
'parameters': {
- self.key(parameter): self.serialize(parameter.value) for parameter in node.parameters
+ self.key(parameter):
+ self.serialize(parameter.value)
+ for parameter
+ in node.parameters
},
'value': node.value,
'calculation_time': node.calculation_time(),
'formula_time': node.formula_time(),
},
}
+
return node_trace
diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py
index ab109fb9c9..56e695b4e0 100644
--- a/openfisca_core/tracers/full_tracer.py
+++ b/openfisca_core/tracers/full_tracer.py
@@ -1,9 +1,8 @@
+from __future__ import annotations
+
import time
import typing
-import numpy
-
-# from openfisca_core import tracers
from openfisca_core.tracers import (
ComputationLog,
FlatTrace,
@@ -12,71 +11,117 @@
TraceNode,
)
+if typing.TYPE_CHECKING:
+ import numpy
+ import numpy.typing
+
+ from openfisca_core.periods import Period
+
+ Stack = typing.List[typing.Dict[str, typing.Union[str, Period]]]
+
class FullTracer:
- def __init__(self):
+ _simple_tracer: SimpleTracer
+ _trees: list
+ _current_node: typing.Optional[TraceNode]
+
+ def __init__(self) -> None:
self._simple_tracer = SimpleTracer()
self._trees = []
self._current_node = None
- def record_calculation_start(self, variable: str, period):
+ def record_calculation_start(
+ self,
+ variable: str,
+ period: Period,
+ ) -> None:
self._simple_tracer.record_calculation_start(variable, period)
self._enter_calculation(variable, period)
self._record_start_time()
- def _enter_calculation(self, variable: str, period):
- new_node = TraceNode(name = variable, period = period, parent = self._current_node)
+ def _enter_calculation(
+ self,
+ variable: str,
+ period: Period,
+ ) -> None:
+ new_node = TraceNode(
+ name = variable,
+ period = period,
+ parent = self._current_node,
+ )
+
if self._current_node is None:
self._trees.append(new_node)
+
else:
self._current_node.append_child(new_node)
- self._current_node = new_node
- def record_parameter_access(self, parameter: str, period, value):
- self._current_node.parameters.append(TraceNode(name = parameter, period = period, value = value))
+ self._current_node = new_node
- def _record_start_time(self, time_in_s: typing.Optional[float] = None):
+ def record_parameter_access(
+ self,
+ parameter: str,
+ period: Period,
+ value: numpy.typing.ArrayLike,
+ ) -> None:
+
+ if self._current_node is not None:
+ self._current_node.parameters.append(
+ TraceNode(name = parameter, period = period, value = value),
+ )
+
+ def _record_start_time(
+ self,
+ time_in_s: typing.Optional[float] = None,
+ ) -> None:
if time_in_s is None:
time_in_s = self._get_time_in_sec()
- self._current_node.start = time_in_s
+ if self._current_node is not None:
+ self._current_node.start = time_in_s
- def record_calculation_result(self, value: numpy.ndarray):
- self._current_node.value = value
+ def record_calculation_result(self, value: numpy.typing.ArrayLike) -> None:
+ if self._current_node is not None:
+ self._current_node.value = value
- def record_calculation_end(self):
+ def record_calculation_end(self) -> None:
self._simple_tracer.record_calculation_end()
self._record_end_time()
self._exit_calculation()
- def _record_end_time(self, time_in_s: typing.Optional[float] = None):
+ def _record_end_time(
+ self,
+ time_in_s: typing.Optional[float] = None,
+ ) -> None:
if time_in_s is None:
time_in_s = self._get_time_in_sec()
- self._current_node.end = time_in_s
+ if self._current_node is not None:
+ self._current_node.end = time_in_s
- def _exit_calculation(self):
- self._current_node = self._current_node.parent
+ def _exit_calculation(self) -> None:
+ if self._current_node is not None:
+ self._current_node = self._current_node.parent
@property
- def stack(self):
+ def stack(self) -> Stack:
return self._simple_tracer.stack
@property
- def trees(self):
+ def trees(self) -> typing.List[TraceNode]:
return self._trees
@property
- def computation_log(self):
+ def computation_log(self) -> ComputationLog:
return ComputationLog(self)
@property
- def performance_log(self):
+ def performance_log(self) -> PerformanceLog:
return PerformanceLog(self)
@property
- def flat_trace(self):
+ def flat_trace(self) -> FlatTrace:
return FlatTrace(self)
def _get_time_in_sec(self) -> float:
@@ -91,24 +136,34 @@ def generate_performance_graph(self, dir_path: str) -> None:
def generate_performance_tables(self, dir_path: str) -> None:
self.performance_log.generate_performance_tables(dir_path)
- def _get_nb_requests(self, tree, variable: str):
+ def _get_nb_requests(self, tree: TraceNode, variable: str) -> int:
tree_call = tree.name == variable
- children_calls = sum(self._get_nb_requests(child, variable) for child in tree.children)
+ children_calls = sum(
+ self._get_nb_requests(child, variable)
+ for child
+ in tree.children
+ )
return tree_call + children_calls
- def get_nb_requests(self, variable: str):
- return sum(self._get_nb_requests(tree, variable) for tree in self.trees)
+ def get_nb_requests(self, variable: str) -> int:
+ return sum(
+ self._get_nb_requests(tree, variable)
+ for tree
+ in self.trees
+ )
- def get_flat_trace(self):
+ def get_flat_trace(self) -> dict:
return self.flat_trace.get_trace()
- def get_serialized_flat_trace(self):
+ def get_serialized_flat_trace(self) -> dict:
return self.flat_trace.get_serialized_trace()
def browse_trace(self) -> typing.Iterator[TraceNode]:
+
def _browse_node(node):
yield node
+
for child in node.children:
yield from _browse_node(child)
diff --git a/openfisca_core/tracers/performance_log.py b/openfisca_core/tracers/performance_log.py
index 7ab8c09965..cdd66a286b 100644
--- a/openfisca_core/tracers/performance_log.py
+++ b/openfisca_core/tracers/performance_log.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import csv
import importlib.resources
import itertools
@@ -7,41 +9,82 @@
from openfisca_core.tracers import TraceNode
+if typing.TYPE_CHECKING:
+ from openfisca_core.tracers import FullTracer
+
+ Trace = typing.Dict[str, dict]
+ Calculation = typing.Tuple[str, dict]
+ SortedTrace = typing.List[Calculation]
+
class PerformanceLog:
- def __init__(self, full_tracer):
+ def __init__(self, full_tracer: FullTracer) -> None:
self._full_tracer = full_tracer
- def generate_graph(self, dir_path):
+ def generate_graph(self, dir_path: str) -> None:
with open(os.path.join(dir_path, 'performance_graph.html'), 'w') as f:
- template = importlib.resources.read_text('openfisca_core.scripts.assets', 'index.html')
- perf_graph_html = template.replace('{{data}}', json.dumps(self._json()))
+ template = importlib.resources.read_text(
+ 'openfisca_core.scripts.assets',
+ 'index.html',
+ )
+
+ perf_graph_html = template.replace(
+ '{{data}}',
+ json.dumps(self._json()),
+ )
+
f.write(perf_graph_html)
def generate_performance_tables(self, dir_path: str) -> None:
-
flat_trace = self._full_tracer.get_flat_trace()
csv_rows = [
- {'name': key, 'calculation_time': trace['calculation_time'], 'formula_time': trace['formula_time']}
- for key, trace in flat_trace.items()
+ {
+ 'name': key,
+ 'calculation_time': trace['calculation_time'],
+ 'formula_time': trace['formula_time'],
+ }
+ for key, trace
+ in flat_trace.items()
]
- self._write_csv(os.path.join(dir_path, 'performance_table.csv'), csv_rows)
+
+ self._write_csv(
+ os.path.join(dir_path, 'performance_table.csv'),
+ csv_rows,
+ )
aggregated_csv_rows = [
{'name': key, **aggregated_time}
- for key, aggregated_time in self.aggregate_calculation_times(flat_trace).items()
+ for key, aggregated_time
+ in self.aggregate_calculation_times(flat_trace).items()
]
- self._write_csv(os.path.join(dir_path, 'aggregated_performance_table.csv'), aggregated_csv_rows)
+ self._write_csv(
+ os.path.join(dir_path, 'aggregated_performance_table.csv'),
+ aggregated_csv_rows,
+ )
- def aggregate_calculation_times(self, flat_trace: typing.Dict) -> typing.Dict[str, typing.Dict]:
+ def aggregate_calculation_times(
+ self,
+ flat_trace: Trace,
+ ) -> typing.Dict[str, dict]:
- def _aggregate_calculations(calculations):
+ def _aggregate_calculations(calculations: list) -> dict:
calculation_count = len(calculations)
- calculation_time = sum(calculation[1]['calculation_time'] for calculation in calculations)
- formula_time = sum(calculation[1]['formula_time'] for calculation in calculations)
+
+ calculation_time = sum(
+ calculation[1]['calculation_time']
+ for calculation
+ in calculations
+ )
+
+ formula_time = sum(
+ calculation[1]['formula_time']
+ for calculation
+ in calculations
+ )
+
return {
'calculation_count': calculation_count,
'calculation_time': TraceNode.round(calculation_time),
@@ -50,26 +93,43 @@ def _aggregate_calculations(calculations):
'avg_formula_time': TraceNode.round(formula_time / calculation_count),
}
- all_calculations = sorted(flat_trace.items())
+ def _groupby(calculation: Calculation) -> str:
+ return calculation[0].split('<')[0]
+
+ all_calculations: SortedTrace = sorted(flat_trace.items())
+
return {
variable_name: _aggregate_calculations(list(calculations))
- for variable_name, calculations in itertools.groupby(all_calculations, lambda calculation: calculation[0].split('<')[0])
+ for variable_name, calculations
+ in itertools.groupby(all_calculations, _groupby)
}
- def _json(self):
+ def _json(self) -> dict:
children = [self._json_tree(tree) for tree in self._full_tracer.trees]
calculations_total_time = sum(child['value'] for child in children)
- return {'name': 'All calculations', 'value': calculations_total_time, 'children': children}
- def _json_tree(self, tree: TraceNode):
+ return {
+ 'name': 'All calculations',
+ 'value': calculations_total_time,
+ 'children': children,
+ }
+
+ def _json_tree(self, tree: TraceNode) -> dict:
calculation_total_time = tree.calculation_time()
children = [self._json_tree(child) for child in tree.children]
- return {'name': f"{tree.name}<{tree.period}>", 'value': calculation_total_time, 'children': children}
- def _write_csv(self, path: str, rows: typing.List[typing.Dict[str, typing.Any]]) -> None:
+ return {
+ 'name': f"{tree.name}<{tree.period}>",
+ 'value': calculation_total_time,
+ 'children': children,
+ }
+
+ def _write_csv(self, path: str, rows: typing.List[dict]) -> None:
fieldnames = list(rows[0].keys())
+
with open(path, 'w') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames = fieldnames)
writer.writeheader()
+
for row in rows:
writer.writerow(row)
diff --git a/openfisca_core/tracers/simple_tracer.py b/openfisca_core/tracers/simple_tracer.py
index 3e6c602487..c54b313f60 100644
--- a/openfisca_core/tracers/simple_tracer.py
+++ b/openfisca_core/tracers/simple_tracer.py
@@ -1,23 +1,35 @@
-import numpy
+from __future__ import annotations
+
+import typing
+
+if typing.TYPE_CHECKING:
+ import numpy
+ import numpy.typing
+
+ from openfisca_core.periods import Period
+
+ Stack = typing.List[typing.Dict[str, typing.Union[str, Period]]]
class SimpleTracer:
- def __init__(self):
+ _stack: Stack
+
+ def __init__(self) -> None:
self._stack = []
- def record_calculation_start(self, variable: str, period):
+ def record_calculation_start(self, variable: str, period: Period) -> None:
self.stack.append({'name': variable, 'period': period})
- def record_calculation_result(self, value: numpy.ndarray):
+ def record_calculation_result(self, value: numpy.typing.ArrayLike) -> None:
pass # ignore calculation result
def record_parameter_access(self, parameter: str, period, value):
pass
- def record_calculation_end(self):
+ def record_calculation_end(self) -> None:
self.stack.pop()
@property
- def stack(self):
+ def stack(self) -> Stack:
return self._stack
diff --git a/openfisca_core/tracers/trace_node.py b/openfisca_core/tracers/trace_node.py
index 71f15cfb93..93b630886c 100644
--- a/openfisca_core/tracers/trace_node.py
+++ b/openfisca_core/tracers/trace_node.py
@@ -3,9 +3,14 @@
import dataclasses
import typing
-import numpy
+if typing.TYPE_CHECKING:
+ import numpy
-from openfisca_core.periods import Period
+ from openfisca_core.indexed_enums import EnumArray
+ from openfisca_core.periods import Period
+
+ Array = typing.Union[EnumArray, numpy.typing.ArrayLike]
+ Time = typing.Union[float, int]
@dataclasses.dataclass
@@ -15,23 +20,35 @@ class TraceNode:
parent: typing.Optional[TraceNode] = None
children: typing.List[TraceNode] = dataclasses.field(default_factory = list)
parameters: typing.List[TraceNode] = dataclasses.field(default_factory = list)
- value: numpy.ndarray = None
+ value: typing.Optional[Array] = None
start: float = 0
end: float = 0
- def calculation_time(self, round_ = True):
+ def calculation_time(self, round_: bool = True) -> Time:
result = self.end - self.start
+
if round_:
return self.round(result)
+
return result
- def formula_time(self):
- result = self.calculation_time(round_ = False) - sum(child.calculation_time(round_ = False) for child in self.children)
+ def formula_time(self) -> float:
+ children_calculation_time = sum(
+ child.calculation_time(round_ = False)
+ for child
+ in self.children
+ )
+
+ result = (
+ + self.calculation_time(round_ = False)
+ - children_calculation_time
+ )
+
return self.round(result)
- def append_child(self, node: TraceNode):
+ def append_child(self, node: TraceNode) -> None:
self.children.append(node)
@staticmethod
- def round(time):
+ def round(time: Time) -> float:
return float(f'{time:.4g}') # Keep only 4 significant figures
diff --git a/openfisca_core/tracers/tracing_parameter_node_at_instant.py b/openfisca_core/tracers/tracing_parameter_node_at_instant.py
index 2595989467..94abf6b2ca 100644
--- a/openfisca_core/tracers/tracing_parameter_node_at_instant.py
+++ b/openfisca_core/tracers/tracing_parameter_node_at_instant.py
@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+import typing
+
import numpy
from openfisca_core.parameters import (
@@ -6,30 +10,70 @@
VectorialParameterNodeAtInstant,
)
+if typing.TYPE_CHECKING:
+ import numpy.typing
+
+ from openfisca_core.tracers import FullTracer
+
+ ParameterNode = typing.Union[
+ VectorialParameterNodeAtInstant,
+ ParameterNodeAtInstant,
+ ]
+
+ Child = typing.Union[ParameterNode, numpy.typing.ArrayLike]
+
class TracingParameterNodeAtInstant:
- def __init__(self, parameter_node_at_instant, tracer):
+ def __init__(
+ self,
+ parameter_node_at_instant: ParameterNode,
+ tracer: FullTracer,
+ ) -> None:
self.parameter_node_at_instant = parameter_node_at_instant
self.tracer = tracer
- def __getattr__(self, key):
+ def __getattr__(
+ self,
+ key: str,
+ ) -> typing.Union[TracingParameterNodeAtInstant, Child]:
child = getattr(self.parameter_node_at_instant, key)
return self.get_traced_child(child, key)
- def __getitem__(self, key):
+ def __getitem__(
+ self,
+ key: typing.Union[str, numpy.typing.ArrayLike],
+ ) -> typing.Union[TracingParameterNodeAtInstant, Child]:
child = self.parameter_node_at_instant[key]
return self.get_traced_child(child, key)
- def get_traced_child(self, child, key):
+ def get_traced_child(
+ self,
+ child: Child,
+ key: typing.Union[str, numpy.typing.ArrayLike],
+ ) -> typing.Union[TracingParameterNodeAtInstant, Child]:
period = self.parameter_node_at_instant._instant_str
- if isinstance(child, (ParameterNodeAtInstant, VectorialParameterNodeAtInstant)):
+
+ if isinstance(
+ child,
+ (ParameterNodeAtInstant, VectorialParameterNodeAtInstant),
+ ):
return TracingParameterNodeAtInstant(child, self.tracer)
- if not isinstance(key, str) or isinstance(self.parameter_node_at_instant, VectorialParameterNodeAtInstant):
- # In case of vectorization, we keep the parent node name as, for instance, rate[status].zone1 is best described as the value of "rate"
+
+ if not isinstance(key, str) or \
+ isinstance(
+ self.parameter_node_at_instant,
+ VectorialParameterNodeAtInstant,
+ ):
+ # In case of vectorization, we keep the parent node name as, for
+ # instance, rate[status].zone1 is best described as the value of
+ # "rate".
name = self.parameter_node_at_instant._name
+
else:
name = '.'.join([self.parameter_node_at_instant._name, key])
+
if isinstance(child, (numpy.ndarray,) + config.ALLOWED_PARAM_TYPES):
self.tracer.record_parameter_access(name, period, child)
+
return child
diff --git a/openfisca_core/variables/config.py b/openfisca_core/variables/config.py
index edda2664b6..b260bb3dd9 100644
--- a/openfisca_core/variables/config.py
+++ b/openfisca_core/variables/config.py
@@ -8,7 +8,7 @@
VALUE_TYPES = {
bool: {
- 'dtype': numpy.bool,
+ 'dtype': numpy.bool_,
'default': False,
'json_type': 'boolean',
'formatted_value_type': 'Boolean',
diff --git a/setup.py b/setup.py
index 3411528df7..cb28724032 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
general_requirements = [
'dpath >= 1.5.0, < 2.0.0',
'pytest >= 4.4.1, < 6.0.0', # For openfisca test
- 'numpy >= 1.11, < 1.19',
+ 'numpy >= 1.11, < 1.21',
'psutil >= 5.4.7, < 6.0.0',
'PyYAML >= 3.10',
'sortedcontainers == 2.2.2',
@@ -35,7 +35,7 @@
setup(
name = 'OpenFisca-Core',
- version = '35.3.6',
+ version = '35.3.7',
author = 'OpenFisca Team',
author_email = 'contact@openfisca.org',
classifiers = [