Skip to content

Commit

Permalink
Extend numpy support to v1.20
Browse files Browse the repository at this point in the history
Add support for M1 processors
  • Loading branch information
MattiSG committed Apr 16, 2021
2 parents 3123d27 + 15dd1ff commit 6c05349
Show file tree
Hide file tree
Showing 22 changed files with 615 additions and 224 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

### 35.3.7 [#990](https://github.com/openfisca/openfisca-core/pull/990)

#### Technical changes

- Update dependencies.
- Extend NumPy compatibility to v1.20 to support M1 processors.

### 35.3.6 [#984](https://github.com/openfisca/openfisca-core/pull/984)

#### Technical changes
Expand Down
54 changes: 34 additions & 20 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,47 @@

from openfisca_core.indexed_enums import config, EnumArray

if typing.TYPE_CHECKING:
IndexedEnumArray = numpy.object_


class Enum(enum.Enum):
"""
Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_, whose items have an
index.
Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_, whose items
have an index.
"""

# Tweak enums to add an index attribute to each enum item
def __init__(self, name: str) -> None:
# When the enum item is initialized, self._member_names_ contains the names of
# the previously initialized items, so its length is the index of this item.
# When the enum item is initialized, self._member_names_ contains the
# names of the previously initialized items, so its length is the index
# of this item.
self.index = len(self._member_names_)

# Bypass the slow Enum.__eq__
__eq__ = object.__eq__

# In Python 3, __hash__ must be defined if __eq__ is defined to stay hashable.
# In Python 3, __hash__ must be defined if __eq__ is defined to stay
# hashable.
__hash__ = object.__hash__

@classmethod
def encode(
cls,
array: typing.Union[
EnumArray,
numpy.ndarray[int],
numpy.ndarray[str],
numpy.ndarray[Enum],
numpy.int_,
numpy.float_,
IndexedEnumArray,
],
) -> EnumArray:
"""
Encode a string numpy array, an enum item numpy array, or an int numpy array
into an :any:`EnumArray`. See :any:`EnumArray.decode` for decoding.
Encode a string numpy array, an enum item numpy array, or an int numpy
array into an :any:`EnumArray`. See :any:`EnumArray.decode` for
decoding.
:param ndarray array: Array of string identifiers, or of enum items, to encode.
:param ndarray array: Array of string identifiers, or of enum items, to
encode.
:returns: An :any:`EnumArray` encoding the input array values.
:rtype: :any:`EnumArray`
Expand All @@ -59,24 +66,31 @@ def encode(
>>> encoded_array[0]
2 # Encoded value
"""
if type(array) is EnumArray:
if isinstance(array, EnumArray):
return array

if array.dtype.kind in {'U', 'S'}: # String array
# String array
if isinstance(array, numpy.ndarray) and \
array.dtype.kind in {'U', 'S'}:
array = numpy.select(
[array == item.name for item in cls],
[item.index for item in cls],
).astype(config.ENUM_ARRAY_DTYPE)

elif array.dtype.kind == 'O': # Enum items arrays
# Enum items arrays
elif isinstance(array, numpy.ndarray) and \
array.dtype.kind == 'O':
# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
# variable.possible_values, while the array values may come from directly
# importing a module containing an Enum class. However, variables (and
# hence their possible_values) are loaded by a call to load_module, which
# gives them a different identity from the ones imported in the usual way.
# So, instead of relying on the "cls" passed in, we use only its name to
# check that the values in the array, if non-empty, are of the right type.
# variable.possible_values, while the array values may come from
# directly importing a module containing an Enum class. However,
# variables (and hence their possible_values) are loaded by a call
# to load_module, which gives them a different identity from the
# ones imported in the usual way.
#
# So, instead of relying on the "cls" passed in, we use only its
# name to check that the values in the array, if non-empty, are of
# the right type.
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
cls = array[0].__class__

Expand Down
22 changes: 13 additions & 9 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
if typing.TYPE_CHECKING:
from openfisca_core.indexed_enums import Enum

IndexedEnumArray = numpy.object_


class EnumArray(numpy.ndarray):
"""
Expand All @@ -20,23 +22,23 @@ class EnumArray(numpy.ndarray):
# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array.
def __new__(
cls,
input_array: numpy.ndarray[int],
input_array: numpy.int_,
possible_values: typing.Optional[typing.Type[Enum]] = None,
) -> EnumArray:
obj = numpy.asarray(input_array).view(cls)
obj.possible_values = possible_values
return obj

# See previous comment
def __array_finalize__(self, obj: typing.Optional[numpy.ndarray[int]]) -> None:
def __array_finalize__(self, obj: typing.Optional[numpy.int_]) -> None:
if obj is None:
return

self.possible_values = getattr(obj, "possible_values", None)

def __eq__(self, other: typing.Any) -> bool:
# When comparing to an item of self.possible_values, use the item index to
# speed up the comparison.
# When comparing to an item of self.possible_values, use the item index
# to speed up the comparison.
if other.__class__.__name__ is self.possible_values.__name__:
# Use view(ndarray) so that the result is a classic ndarray, not an
# EnumArray.
Expand All @@ -49,8 +51,8 @@ def __ne__(self, other: typing.Any) -> bool:

def _forbidden_operation(self, other: typing.Any) -> typing.NoReturn:
raise TypeError(
"Forbidden operation. The only operations allowed on EnumArrays are "
"'==' and '!='.",
"Forbidden operation. The only operations allowed on EnumArrays "
"are '==' and '!='.",
)

__add__ = _forbidden_operation
Expand All @@ -62,7 +64,7 @@ def _forbidden_operation(self, other: typing.Any) -> typing.NoReturn:
__and__ = _forbidden_operation
__or__ = _forbidden_operation

def decode(self) -> numpy.ndarray[Enum]:
def decode(self) -> IndexedEnumArray:
"""
Return the array of enum items corresponding to self.
Expand All @@ -72,14 +74,16 @@ def decode(self) -> numpy.ndarray[Enum]:
>>> enum_array[0]
>>> 2 # Encoded value
>>> enum_array.decode()[0]
<HousingOccupancyStatus.free_lodger: 'Free lodger'> # Decoded value : enum item
<HousingOccupancyStatus.free_lodger: 'Free lodger'>
Decoded value: enum item
"""
return numpy.select(
[self == item.index for item in self.possible_values],
list(self.possible_values),
)

def decode_to_str(self) -> numpy.ndarray[str]:
def decode_to_str(self) -> numpy.str_:
"""
Return the array of string identifiers corresponding to self.
Expand Down
26 changes: 17 additions & 9 deletions openfisca_core/taxscales/abstract_rate_tax_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,36 @@
import typing
import warnings

import numpy

from openfisca_core.taxscales import RateTaxScaleLike

if typing.TYPE_CHECKING:
import numpy

NumericalArray = typing.Union[numpy.int_, numpy.float_]


class AbstractRateTaxScale(RateTaxScaleLike):
"""
Base class for various types of rate-based tax scales: marginal rate, linear
average rate...
Base class for various types of rate-based tax scales: marginal rate,
linear average rate...
"""

def __init__(self, name: typing.Optional[str] = None, option = None, unit = None) -> None:
def __init__(
self, name: typing.Optional[str] = None,
option: typing.Any = None,
unit: typing.Any = None,
) -> None:
message = [
"The 'AbstractRateTaxScale' class has been deprecated since version",
"34.7.0, and will be removed in the future.",
"The 'AbstractRateTaxScale' class has been deprecated since",
"version 34.7.0, and will be removed in the future.",
]

warnings.warn(" ".join(message), DeprecationWarning)
super(AbstractRateTaxScale, self).__init__(name, option, unit)
super().__init__(name, option, unit)

def calc(
self,
tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
tax_base: NumericalArray,
right: bool,
) -> typing.NoReturn:
raise NotImplementedError(
Expand Down
28 changes: 19 additions & 9 deletions openfisca_core/taxscales/abstract_tax_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,34 @@
import typing
import warnings

import numpy

from openfisca_core.taxscales import TaxScaleLike

if typing.TYPE_CHECKING:
import numpy

NumericalArray = typing.Union[numpy.int_, numpy.float_]


class AbstractTaxScale(TaxScaleLike):
"""
Base class for various types of tax scales: amount-based tax scales, rate-based
tax scales...
Base class for various types of tax scales: amount-based tax scales,
rate-based tax scales...
"""

def __init__(self, name: typing.Optional[str] = None, option = None, unit = None) -> None:
def __init__(
self,
name: typing.Optional[str] = None,
option: typing.Any = None,
unit: numpy.int_ = None,
) -> None:

message = [
"The 'AbstractTaxScale' class has been deprecated since version 34.7.0,",
"and will be removed in the future.",
"The 'AbstractTaxScale' class has been deprecated since",
"version 34.7.0, and will be removed in the future.",
]

warnings.warn(" ".join(message), DeprecationWarning)
super(AbstractTaxScale, self).__init__(name, option, unit)
super().__init__(name, option, unit)

def __repr__(self) -> typing.NoReturn:
raise NotImplementedError(
Expand All @@ -30,7 +40,7 @@ def __repr__(self) -> typing.NoReturn:

def calc(
self,
tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
tax_base: NumericalArray,
right: bool,
) -> typing.NoReturn:
raise NotImplementedError(
Expand Down
20 changes: 14 additions & 6 deletions openfisca_core/taxscales/amount_tax_scale_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,28 @@

class AmountTaxScaleLike(TaxScaleLike, abc.ABC):
"""
Base class for various types of amount-based tax scales: single amount, marginal
amount...
Base class for various types of amount-based tax scales: single amount,
marginal amount...
"""

amounts: typing.List

def __init__(self, name: typing.Optional[str] = None, option = None, unit = None) -> None:
super(AmountTaxScaleLike, self).__init__(name, option, unit)
def __init__(
self,
name: typing.Optional[str] = None,
option: typing.Any = None,
unit: typing.Any = None,
) -> None:
super().__init__(name, option, unit)
self.amounts = []

def __repr__(self) -> str:
return tools.indent(
os.linesep.join(
[
f"- threshold: {threshold}{os.linesep} amount: {amount}"
for (threshold, amount) in zip(self.thresholds, self.amounts)
for (threshold, amount)
in zip(self.thresholds, self.amounts)
]
)
)
Expand All @@ -37,6 +43,7 @@ def add_bracket(
if threshold in self.thresholds:
i = self.thresholds.index(threshold)
self.amounts[i] += amount

else:
i = bisect.bisect_left(self.thresholds, threshold)
self.thresholds.insert(i, threshold)
Expand All @@ -45,5 +52,6 @@ def add_bracket(
def to_dict(self) -> dict:
return {
str(threshold): self.amounts[index]
for index, threshold in enumerate(self.thresholds)
for index, threshold
in enumerate(self.thresholds)
}
6 changes: 4 additions & 2 deletions openfisca_core/taxscales/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
if typing.TYPE_CHECKING:
from openfisca_core.parameters import ParameterNodeAtInstant

TaxScales = typing.Optional[taxscales.MarginalRateTaxScale]


def combine_tax_scales(
node: ParameterNodeAtInstant,
combined_tax_scales: typing.Optional[taxscales.MarginalRateTaxScale] = None,
) -> typing.Optional[taxscales.MarginalRateTaxScale]:
combined_tax_scales: TaxScales = None,
) -> TaxScales:
"""
Combine all the MarginalRateTaxScales in the node into a single
MarginalRateTaxScale.
Expand Down
7 changes: 5 additions & 2 deletions openfisca_core/taxscales/linear_average_rate_tax_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@

log = logging.getLogger(__name__)

if typing.TYPE_CHECKING:
NumericalArray = typing.Union[numpy.int_, numpy.float_]


class LinearAverageRateTaxScale(RateTaxScaleLike):

def calc(
self,
tax_base: typing.Union[numpy.ndarray[int], numpy.ndarray[float]],
tax_base: NumericalArray,
right: bool = False,
) -> numpy.ndarray[float]:
) -> numpy.float_:
if len(self.rates) == 1:
return tax_base * self.rates[0]

Expand Down
Loading

0 comments on commit 6c05349

Please sign in to comment.