Skip to content

Commit

Permalink
test: test median ignores null values
Browse files Browse the repository at this point in the history
  • Loading branch information
AlessandroMiola committed Oct 24, 2024
1 parent c49127e commit c762c6d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
9 changes: 0 additions & 9 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,6 @@ def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr:
dtypes=self._dtypes,
)

def median(self, *column_names: str) -> PolarsExpr:
import polars as pl # ignore-banned-import()

from narwhals._polars.expr import PolarsExpr

if self._backend_version < (0, 20, 4): # pragma: no cover
return PolarsExpr(pl.median([*column_names]), dtypes=self._dtypes) # type: ignore[arg-type]
return PolarsExpr(pl.median(*column_names), dtypes=self._dtypes)

def concat_str(
self,
exprs: Iterable[IntoPolarsExpr],
Expand Down
22 changes: 14 additions & 8 deletions tests/expr_and_series/median_test.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
from __future__ import annotations

from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import compare_dicts
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

data = {"a": [3, 8, 2], "b": [5, 5, 7], "z": [7.0, 8, 9]}
data = {
"a": [3, 8, 2, None],
"b": [5, 5, None, 7],
"z": [7.0, 8, 9, None],
"s": ["f", "a", "x", "x"],
}


@pytest.mark.parametrize(
"expr", [nw.col("a", "b", "z").median(), nw.median("a", "b", "z")]
)
def test_expr_median_expr(
def test_median_expr(
constructor: Constructor, expr: nw.Expr, request: pytest.FixtureRequest
) -> None:
if "dask_lazy_p2" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(expr)
expected = {"a": [3.0], "b": [5.0], "z": [8.0]}
compare_dicts(result, expected)
assert_equal_data(result, expected)


@pytest.mark.parametrize(("col", "expected"), [("a", 3.0), ("b", 5.0), ("z", 8.0)])
def test_expr_median_series(constructor_eager: Any, col: str, expected: float) -> None:
def test_median_series(
constructor_eager: ConstructorEager, col: str, expected: float
) -> None:
series = nw.from_native(constructor_eager(data), eager_only=True)[col]
result = series.median()
compare_dicts({col: [result]}, {col: [expected]})
assert_equal_data({col: [result]}, {col: [expected]})

0 comments on commit c762c6d

Please sign in to comment.