Skip to content

Commit

Permalink
Ensure f32 -> f32, expand tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Mar 18, 2024
1 parent 2b0e43b commit 7e0627f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 36 deletions.
12 changes: 5 additions & 7 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2557,14 +2557,12 @@ impl DataFrame {
pub fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Series>> {
match self.columns.len() {
0 => Ok(None),
1 => {
let dtype = self.columns[0].dtype();
Ok(Some(if dtype.is_numeric() || dtype == &DataType::Boolean {
1 => Ok(Some(match self.columns[0].dtype() {
dt if dt != &DataType::Float32 && (dt.is_numeric() || dt == &DataType::Boolean) => {
self.columns[0].cast(&DataType::Float64)?
} else {
self.columns[0].clone()
}))
},
},
_ => self.columns[0].clone(),
})),
_ => {
let columns = self
.columns
Expand Down
35 changes: 20 additions & 15 deletions py-polars/tests/unit/operations/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
if TYPE_CHECKING:
import numpy.typing as npt

from polars.type_aliases import PolarsDataType


def test_quantile_expr_input() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [0, 0, 0.3, 0.2, 0]})
Expand Down Expand Up @@ -442,29 +444,32 @@ def test_grouping_hash_14749() -> None:


@pytest.mark.parametrize(
("dtype"),
("in_dtype", "out_dtype"),
[
pl.Boolean,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
(pl.Boolean, pl.Float64),
(pl.UInt8, pl.Float64),
(pl.UInt16, pl.Float64),
(pl.UInt32, pl.Float64),
(pl.UInt64, pl.Float64),
(pl.Int8, pl.Float64),
(pl.Int16, pl.Float64),
(pl.Int32, pl.Float64),
(pl.Int64, pl.Float64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
],
)
def test_horizontal_mean_single_column(dtype) -> None:
def test_horizontal_mean_single_column(
in_dtype: PolarsDataType,
out_dtype: PolarsDataType,
) -> None:
out = (
pl.LazyFrame({"a": pl.Series([1, 0], dtype=dtype)})
pl.LazyFrame({"a": pl.Series([1, 0], dtype=in_dtype)})
.select(pl.mean_horizontal(pl.all()))
.collect()
)

assert_frame_equal(
out, pl.DataFrame({"a": pl.Series([1.0, 0.0], dtype=pl.Float64)})
)
assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1.0, 0.0], dtype=out_dtype)}))


def test_horizontal_mean_in_groupby_15115() -> None:
Expand Down
36 changes: 22 additions & 14 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from collections import OrderedDict
from datetime import date, timedelta
from typing import Any, Iterator, Mapping
from typing import TYPE_CHECKING, Any, Iterator, Mapping

import pytest

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType


class CustomSchema(Mapping[str, Any]):
"""Dummy schema object for testing compatibility with Mapping."""
Expand Down Expand Up @@ -635,22 +638,27 @@ def test_literal_subtract_schema_13284() -> None:


@pytest.mark.parametrize(
("dtype"),
("in_dtype", "out_dtype"),
[
pl.Boolean,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
(pl.Boolean, pl.Float64),
(pl.UInt8, pl.Float64),
(pl.UInt16, pl.Float64),
(pl.UInt32, pl.Float64),
(pl.UInt64, pl.Float64),
(pl.Int8, pl.Float64),
(pl.Int16, pl.Float64),
(pl.Int32, pl.Float64),
(pl.Int64, pl.Float64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
],
)
def test_schema_mean_horizontal_single_column(dtype) -> None:
lf = pl.LazyFrame({"a": pl.Series([1, 0], dtype=dtype)}).select(
def test_schema_mean_horizontal_single_column(
in_dtype: PolarsDataType,
out_dtype: PolarsDataType,
) -> None:
lf = pl.LazyFrame({"a": pl.Series([1, 0], dtype=in_dtype)}).select(
pl.mean_horizontal(pl.all())
)

assert lf.schema == OrderedDict([("a", pl.Float64)])
assert lf.schema == OrderedDict([("a", out_dtype)])

0 comments on commit 7e0627f

Please sign in to comment.