From 82cb292090673d7331bd7b21969017c0b3ba3ead Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Sun, 17 Mar 2024 18:04:06 -0400 Subject: [PATCH] Cast bool to f64 during mean_horizontal --- crates/polars-core/src/frame/mod.rs | 7 ++- .../unit/operations/test_aggregations.py | 59 +++++++++++++++++++ py-polars/tests/unit/test_schema.py | 32 +++++++++- 3 files changed, 96 insertions(+), 2 deletions(-) diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index f138dedffe1a..6fa2202fc2fa 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -2561,7 +2561,12 @@ impl DataFrame { pub fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { match self.columns.len() { 0 => Ok(None), - 1 => Ok(Some(self.columns[0].clone())), + 1 => Ok(Some(match self.columns[0].dtype() { + dt if dt != &DataType::Float32 && (dt.is_numeric() || dt == &DataType::Boolean) => { + self.columns[0].cast(&DataType::Float64)? + }, + _ => self.columns[0].clone(), + })), _ => { let columns = self .columns diff --git a/py-polars/tests/unit/operations/test_aggregations.py b/py-polars/tests/unit/operations/test_aggregations.py index 9bcfbb3bb4d8..1ac644fe221d 100644 --- a/py-polars/tests/unit/operations/test_aggregations.py +++ b/py-polars/tests/unit/operations/test_aggregations.py @@ -12,6 +12,8 @@ if TYPE_CHECKING: import numpy.typing as npt + from polars.type_aliases import PolarsDataType + def test_quantile_expr_input() -> None: df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [0, 0, 0.3, 0.2, 0]}) @@ -471,3 +473,60 @@ def test_grouping_hash_14749() -> None: .select(pl.col("x").max().over("grp"))["x"] .value_counts() ).to_dict(as_series=False) == {"x": [3], "count": [1004]} + + +@pytest.mark.parametrize( + ("in_dtype", "out_dtype"), + [ + (pl.Boolean, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.UInt16, pl.Float64), + (pl.UInt32, pl.Float64), + (pl.UInt64, pl.Float64), + (pl.Int8, pl.Float64), + (pl.Int16, pl.Float64), + (pl.Int32, pl.Float64), + (pl.Int64, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_horizontal_mean_single_column( + in_dtype: PolarsDataType, + out_dtype: PolarsDataType, +) -> None: + out = ( + pl.LazyFrame({"a": pl.Series([1, 0], dtype=in_dtype)}) + .select(pl.mean_horizontal(pl.all())) + .collect() + ) + + assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1.0, 0.0], dtype=out_dtype)})) + + +def test_horizontal_mean_in_groupby_15115() -> None: + nbr_records = 1000 + out = ( + pl.LazyFrame( + { + "w": [None, "one", "two", "three"] * nbr_records, + "x": [None, None, "two", "three"] * nbr_records, + "y": [None, None, None, "three"] * nbr_records, + "z": [None, None, None, None] * nbr_records, + } + ) + .select(pl.mean_horizontal(pl.all().is_null()).alias("mean_null")) + .group_by("mean_null") + .len() + .sort(by="mean_null") + .collect() + ) + assert_frame_equal( + out, + pl.DataFrame( + { + "mean_null": pl.Series([0.25, 0.5, 0.75, 1.0], dtype=pl.Float64), + "len": pl.Series([nbr_records] * 4, dtype=pl.UInt32), + } + ), + ) diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index cb04db196150..5a5b3093d678 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -2,13 +2,16 @@ from collections import OrderedDict from datetime import date, timedelta -from typing import Any, Iterator, Mapping +from typing import TYPE_CHECKING, Any, Iterator, Mapping import pytest import polars as pl from polars.testing import assert_frame_equal, assert_series_equal +if TYPE_CHECKING: + from polars.type_aliases import PolarsDataType + class CustomSchema(Mapping[str, Any]): """Dummy schema object for testing compatibility with Mapping.""" @@ -637,3 +640,30 @@ def test_literal_subtract_schema_13284() -> None: def test_schema_boolean_sum_horizontal() -> None: lf = pl.LazyFrame({"a": [True, False]}).select(pl.sum_horizontal("a")) assert lf.schema == OrderedDict([("a", pl.UInt32)]) + + +@pytest.mark.parametrize( + ("in_dtype", "out_dtype"), + [ + (pl.Boolean, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.UInt16, pl.Float64), + (pl.UInt32, pl.Float64), + (pl.UInt64, pl.Float64), + (pl.Int8, pl.Float64), + (pl.Int16, pl.Float64), + (pl.Int32, pl.Float64), + (pl.Int64, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_schema_mean_horizontal_single_column( + in_dtype: PolarsDataType, + out_dtype: PolarsDataType, +) -> None: + lf = pl.LazyFrame({"a": pl.Series([1, 0], dtype=in_dtype)}).select( + pl.mean_horizontal(pl.all()) + ) + + assert lf.schema == OrderedDict([("a", out_dtype)])