Skip to content

Commit

Permalink
Don't allow passing missing data to generalized ufuncs.
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed May 13, 2024
1 parent 9bfa30c commit 8f870a4
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 12 deletions.
36 changes: 24 additions & 12 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
from polars.dependencies import numpy as np
from polars.dependencies import pandas as pd
from polars.dependencies import pyarrow as pa
from polars.exceptions import ModuleUpgradeRequired, ShapeError
from polars.exceptions import ComputeError, ModuleUpgradeRequired, ShapeError
from polars.meta import get_index_type
from polars.series.array import ArrayNameSpace
from polars.series.binary import BinaryNameSpace
Expand Down Expand Up @@ -1295,7 +1295,7 @@ def __getitem__(

def __getitem__(
self,
item: (int | Series | range | slice | np.ndarray[Any, Any] | list[int]),
item: int | Series | range | slice | np.ndarray[Any, Any] | list[int],
) -> Any:
if isinstance(item, Series) and item.dtype.is_integer():
return self._take_with_series(item._pos_idxs(self.len()))
Expand Down Expand Up @@ -1404,13 +1404,10 @@ def __array_ufunc__(
raise NotImplementedError(msg)

args: list[int | float | np.ndarray[Any, Any]] = []

validity_mask = self.is_not_null()
for arg in inputs:
if isinstance(arg, (int, float, np.ndarray)):
args.append(arg)
elif isinstance(arg, Series):
validity_mask &= arg.is_not_null()
args.append(arg.to_physical()._s.to_numpy_view())
else:
msg = f"unsupported type {type(arg).__name__!r} for {arg!r}"
Expand Down Expand Up @@ -1443,6 +1440,13 @@ def __array_ufunc__(
else dtype_char_minimum
)

if ufunc.signature:
# Only generalized ufuncs have a signature set, and they're the
# ones that have problems with missing data.
if self.null_count() > 0:
msg = "Can't pass a Series with missing data to a generalized ufunc, as it might give unexpected results. See https://docs.pola.rs/user-guide/expressions/missing-data/ for suggestions on how to remove or fill in missing data."
raise ComputeError(msg)

f = get_ffi_func("apply_ufunc_<>", numpy_char_code_to_dtype(dtype_char), s)

if f is None:
Expand All @@ -1453,12 +1457,20 @@ def __array_ufunc__(
raise NotImplementedError(msg)

series = f(lambda out: ufunc(*args, out=out, dtype=dtype_char, **kwargs))
return (
self._from_pyseries(series)
.to_frame()
.select(F.when(validity_mask).then(F.col(self.name)))
.to_series(0)
)
result = self._from_pyseries(series)
if not ufunc.signature:
# Missing data is allowed, so filter it out:
validity_mask = self.is_not_null()
for arg in inputs:
if isinstance(arg, Series):
validity_mask &= arg.is_not_null()

result = (
result.to_frame()
.select(F.when(validity_mask).then(F.col(self.name)))
.to_series(0)
)
return result
else:
msg = (
"only `__call__` is implemented for numpy ufuncs on a Series, got "
Expand Down Expand Up @@ -4143,7 +4155,7 @@ def equals(

def cast(
self,
dtype: (PolarsDataType | type[int] | type[float] | type[str] | type[bool]),
dtype: PolarsDataType | type[int] | type[float] | type[str] | type[bool],
*,
strict: bool = True,
) -> Self:
Expand Down
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pip

# Interoperability
numpy
numba; python_version < '3.13' # Numba can lag Python releases
pandas
pyarrow
pydantic>=2.0.0
Expand Down
28 changes: 28 additions & 0 deletions py-polars/tests/unit/interop/numpy/_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Infrastructure for testing Numba.
Numba releases often lag for a few months after Python releases, so we don't
want Numba to be a blocker for Python 3.X support. So this minimally emulates
the Numba module, while allowing for the fact that Numba may not be installed.
"""

import pytest

try:
from numba import float64, guvectorize # type: ignore[import-untyped]
except ImportError:
float64 = []

def guvectorize(_a, _b): # type: ignore[no-untyped-def]
"""When Numba is unavailable, skip tests using the decorated function."""

def decorator(_): # type: ignore[no-untyped-def]
def skip(*_args, **_kwargs): # type: ignore[no-untyped-def]
pytest.skip("Numba not available")

return skip

return decorator


__all__ = ["guvectorize", "float64"]
22 changes: 22 additions & 0 deletions py-polars/tests/unit/interop/numpy/test_ufunc_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.interop.numpy._numba import float64, guvectorize


def test_ufunc() -> None:
Expand Down Expand Up @@ -130,3 +131,24 @@ def test_ufunc_multiple_expressions() -> None:
def test_grouped_ufunc() -> None:
df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [0.1, 0.1, -0.1, -0.1]})
df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1))


@guvectorize([(float64[:], float64[:])], "(n)->(n)")
def gufunc_mean(arr, result): # type: ignore[no-untyped-def]
mean = arr.mean()
for i in range(len(arr)):
result[i] = mean + i


def test_generalized_ufunc() -> None:
df = pl.DataFrame({"s": [1.0, 2.0, 3.0]})
result = df.select([pl.col("s").map_batches(gufunc_mean).alias("result")])
expected = pl.DataFrame({"result": [2.0, 3.0, 4.0]})
assert_frame_equal(result, expected)


def test_grouped_generalized_ufunc() -> None:
df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [1.0, 2.0, 3.0, 4.0]})
result = df.group_by("id").agg(pl.col("values").map_batches(gufunc_mean)).sort("id")
expected = pl.DataFrame({"id": ["a", "b"], "values": [[1.5, 2.5], [3.5, 4.5]]})
assert_frame_equal(result, expected)
30 changes: 30 additions & 0 deletions py-polars/tests/unit/interop/numpy/test_ufunc_series.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import cast

import numpy as np
import pytest
from numpy.testing import assert_array_equal

import polars as pl
from polars.testing import assert_series_equal
from tests.unit.interop.numpy._numba import float64, guvectorize


def test_ufunc() -> None:
Expand Down Expand Up @@ -119,3 +121,31 @@ def test_numpy_string_array() -> None:
np.char.capitalize(s_str),
np.array(["Aa", "Bb", "Cc", "Dd"], dtype="<U2"),
)


@guvectorize([(float64[:], float64[:])], "(n)->(n)")
def add_one(arr, result): # type: ignore[no-untyped-def]
for i in range(len(arr)):
result[i] = arr[i] + 1.0


def test_generalized_ufunc() -> None:
"""A generalized ufunc can be called on a pl.Series."""
s_float = pl.Series("f", [1.0, 2.0, 3.0])
result = add_one(s_float)
assert_series_equal(result, pl.Series("f", [2.0, 3.0, 4.0]))


def test_generalized_ufunc_missing_data() -> None:
"""
If a pl.Series is missing data, using a generalized ufunc is not allowed.
While this particular example isn't necessarily a semantic issue, consider
a mean() function running on integers: it will give wrong results if the
input is missing data, since NumPy has no way to model missing slots. In
the general case, we can't assume the function will handle missing data
correctly.
"""
s_float = pl.Series("f", [1.0, 2.0, 3.0, None], dtype=pl.Float64)
with pytest.raises(pl.ComputeError):
add_one(s_float)

0 comments on commit 8f870a4

Please sign in to comment.