From 9a3deb39d7abde405b614130398f078e30b534b8 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Fri, 10 May 2024 17:13:59 +0200 Subject: [PATCH 01/29] refactor(python): Return correct temporal type from Rust in `to_numpy` (#14353) --- py-polars/polars/series/series.py | 17 +----- py-polars/src/series/export.rs | 53 +++++++++++++++---- .../interop/numpy/test_to_numpy_series.py | 1 + 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 17b7ad7fd327..ebce4249dbeb 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4377,17 +4377,6 @@ def raise_on_copy() -> None: msg = "cannot return a zero-copy array" raise ValueError(msg) - def temporal_dtype_to_numpy(dtype: PolarsDataType) -> Any: - if dtype == Date: - return np.dtype("datetime64[D]") - elif dtype == Duration: - return np.dtype(f"timedelta64[{dtype.time_unit}]") # type: ignore[union-attr] - elif dtype == Datetime: - return np.dtype(f"datetime64[{dtype.time_unit}]") # type: ignore[union-attr] - else: - msg = f"invalid temporal type: {dtype}" - raise TypeError(msg) - if self.n_chunks() > 1: raise_on_copy() self = self.rechunk() @@ -4421,9 +4410,8 @@ def temporal_dtype_to_numpy(dtype: PolarsDataType) -> Any: np_array = s_u8._s.to_numpy_view().view(bool) elif dtype == Date: raise_on_copy() - np_dtype = temporal_dtype_to_numpy(dtype) s_i32 = self.to_physical() - np_array = s_i32._s.to_numpy_view().astype(np_dtype) + np_array = s_i32._s.to_numpy_view().astype(" Any: else: raise_on_copy() np_array = self._s.to_numpy() - if dtype in (Datetime, Duration, Date): - np_dtype = temporal_dtype_to_numpy(dtype) - np_array = np_array.view(np_dtype) if writable and not np_array.flags.writeable: raise_on_copy() diff --git a/py-polars/src/series/export.rs b/py-polars/src/series/export.rs index 38e4fae3ecb2..1097e823f0f0 100644 --- a/py-polars/src/series/export.rs +++ b/py-polars/src/series/export.rs @@ -185,7 +185,34 @@ impl PySeries { np_arr.into_py(py) }, Date => date_series_to_numpy(py, s), - Datetime(_, _) | Duration(_) => temporal_series_to_numpy(py, s), + Datetime(tu, _) => { + use numpy::datetime::{units, Datetime}; + match tu { + TimeUnit::Milliseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Microseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Nanoseconds => { + temporal_series_to_numpy::>(py, s) + }, + } + }, + Duration(tu) => { + use numpy::datetime::{units, Timedelta}; + match tu { + TimeUnit::Milliseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Microseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Nanoseconds => { + temporal_series_to_numpy::>(py, s) + }, + } + }, Time => { let ca = s.time().unwrap(); let iter = time_to_pyobject_iter(py, ca); @@ -254,19 +281,27 @@ where } /// Convert dates directly to i64 with i64::MIN representing a null value fn date_series_to_numpy(py: Python, s: &Series) -> PyObject { + use numpy::datetime::{units, Datetime}; + let s_phys = s.to_physical_repr(); let ca = s_phys.i32().unwrap(); - let mapper = |opt_v: Option| match opt_v { - Some(v) => v as i64, - None => i64::MIN, + let mapper = |opt_v: Option| { + let int = match opt_v { + Some(v) => v as i64, + None => i64::MIN, + }; + int.into() }; - let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(mapper)); - np_arr.into_py(py) + let iter = ca.iter().map(mapper); + PyArray1::>::from_iter_bound(py, iter).into_py(py) } /// Convert datetimes and durations with i64::MIN 
representing a null value -fn temporal_series_to_numpy(py: Python, s: &Series) -> PyObject { +fn temporal_series_to_numpy(py: Python, s: &Series) -> PyObject +where + T: From + numpy::Element, +{ let s_phys = s.to_physical_repr(); let ca = s_phys.i64().unwrap(); - let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(|v| v.unwrap_or(i64::MIN))); - np_arr.into_py(py) + let iter = ca.iter().map(|v| v.unwrap_or(i64::MIN).into()); + PyArray1::::from_iter_bound(py, iter).into_py(py) } diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py index fe2909672fa8..5674d183e3f6 100644 --- a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py @@ -130,6 +130,7 @@ def test_series_to_numpy_date() -> None: assert s.to_list() == result.tolist() assert result.dtype == np.dtype("datetime64[D]") + assert result.flags.writeable is True assert_allow_copy_false_raises(s) From 6de63afbff9d544166dfaf72965df3a9896f46bd Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 10 May 2024 11:19:48 -0400 Subject: [PATCH 02/29] chore(python): Update plugin example to PyO3 0.21 (#16157) Co-authored-by: Itamar Turner-Trauring --- examples/python_rust_compiled_function/Cargo.toml | 2 +- examples/python_rust_compiled_function/src/ffi.rs | 14 +++++++------- examples/python_rust_compiled_function/src/lib.rs | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/python_rust_compiled_function/Cargo.toml b/examples/python_rust_compiled_function/Cargo.toml index da8b5f37096a..94982fe498ef 100644 --- a/examples/python_rust_compiled_function/Cargo.toml +++ b/examples/python_rust_compiled_function/Cargo.toml @@ -14,4 +14,4 @@ polars = { path = "../../crates/polars" } pyo3 = { workspace = true, features = ["extension-module"] } [build-dependencies] -pyo3-build-config = "0.20" +pyo3-build-config = "0.21" diff --git a/examples/python_rust_compiled_function/src/ffi.rs b/examples/python_rust_compiled_function/src/ffi.rs index 16e4f09a440c..22222e8e20f8 100644 --- a/examples/python_rust_compiled_function/src/ffi.rs +++ b/examples/python_rust_compiled_function/src/ffi.rs @@ -7,7 +7,7 @@ use pyo3::{PyAny, PyObject, PyResult}; /// Take an arrow array from python and convert it to a rust arrow array. /// This operation does not copy data. -fn array_to_rust(arrow_array: &PyAny) -> PyResult { +fn array_to_rust(arrow_array: &Bound) -> PyResult { // prepare a pointer to receive the Array struct let array = Box::new(ffi::ArrowArray::empty()); let schema = Box::new(ffi::ArrowSchema::empty()); @@ -30,7 +30,7 @@ fn array_to_rust(arrow_array: &PyAny) -> PyResult { } /// Arrow array to Python. 
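
A short illustration of the user-facing effect of the first patch above: temporal Series now come back from Rust as typed NumPy arrays, with the i64::MIN sentinel surfacing as NaT. This is a sketch assuming a polars build that includes the change; the printed dtypes follow the Series' own time units.

from datetime import date, datetime, timedelta

import polars as pl

s = pl.Series([datetime(2021, 1, 1), None])
arr = s.to_numpy()
print(arr.dtype)    # datetime64[us], matching the Series' Datetime time unit
print(arr[1])       # NaT, decoded from the i64::MIN sentinel

dates = pl.Series([date(2021, 1, 1), date(2021, 1, 2)]).to_numpy()
print(dates.dtype)  # datetime64[D], as asserted in test_series_to_numpy_date

durs = pl.Series([timedelta(hours=1), None]).to_numpy()
print(durs.dtype)   # timedelta64[us]
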
-pub(crate) fn to_py_array(py: Python, pyarrow: &PyModule, array: ArrayRef) -> PyResult { +pub(crate) fn to_py_array(py: Python, pyarrow: &Bound, array: ArrayRef) -> PyResult { let schema = Box::new(ffi::export_field_to_c(&ArrowField::new( "", array.data_type().clone(), @@ -49,7 +49,7 @@ pub(crate) fn to_py_array(py: Python, pyarrow: &PyModule, array: ArrayRef) -> Py Ok(array.to_object(py)) } -pub fn py_series_to_rust_series(series: &PyAny) -> PyResult { +pub fn py_series_to_rust_series(series: &Bound) -> PyResult { // rechunk series so that they have a single arrow array let series = series.call_method0("rechunk")?; @@ -59,7 +59,7 @@ pub fn py_series_to_rust_series(series: &PyAny) -> PyResult { let array = series.call_method0("to_arrow")?; // retrieve rust arrow array - let array = array_to_rust(array)?; + let array = array_to_rust(&array)?; Series::try_from((name.as_str(), array)).map_err(|e| PyValueError::new_err(format!("{}", e))) } @@ -71,13 +71,13 @@ pub fn rust_series_to_py_series(series: &Series) -> PyResult { Python::with_gil(|py| { // import pyarrow - let pyarrow = py.import("pyarrow")?; + let pyarrow = py.import_bound("pyarrow")?; // pyarrow array - let pyarrow_array = to_py_array(py, pyarrow, array)?; + let pyarrow_array = to_py_array(py, &pyarrow, array)?; // import polars - let polars = py.import("polars")?; + let polars = py.import_bound("polars")?; let out = polars.call_method1("from_arrow", (pyarrow_array,))?; Ok(out.to_object(py)) }) diff --git a/examples/python_rust_compiled_function/src/lib.rs b/examples/python_rust_compiled_function/src/lib.rs index 71708aa90475..f8c2caec2123 100644 --- a/examples/python_rust_compiled_function/src/lib.rs +++ b/examples/python_rust_compiled_function/src/lib.rs @@ -5,7 +5,7 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; #[pyfunction] -fn hamming_distance(series_a: &PyAny, series_b: &PyAny) -> PyResult { +fn hamming_distance(series_a: &Bound, series_b: &Bound) -> PyResult { let series_a = ffi::py_series_to_rust_series(series_a)?; let series_b = ffi::py_series_to_rust_series(series_b)?; @@ -44,7 +44,7 @@ fn hamming_distance_strs(a: Option<&str>, b: Option<&str>) -> Option { } #[pymodule] -fn my_polars_functions(_py: Python, m: &PyModule) -> PyResult<()> { +fn my_polars_functions(_py: Python, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(hamming_distance)).unwrap(); Ok(()) } From be57336621c889800bca3c6fbb7c8424825967f5 Mon Sep 17 00:00:00 2001 From: thalassemia Date: Fri, 10 May 2024 10:28:45 -0700 Subject: [PATCH 03/29] feat(rust,python): Add run-length encoding to Parquet writer (#16125) Co-authored-by: ritchie --- .github/workflows/test-rust.yml | 2 + .../src/compute/cast/binary_to.rs | 1 + .../src/compute/cast/binview_to.rs | 2 + .../src/compute/cast/primitive_to.rs | 1 + .../polars-arrow/src/compute/cast/utf8_to.rs | 1 + crates/polars-io/src/parquet/write/writer.rs | 2 +- .../src/arrow/write/dictionary.rs | 41 +-- crates/polars-parquet/src/arrow/write/mod.rs | 2 +- .../src/arrow/write/nested/mod.rs | 10 +- .../polars-parquet/src/arrow/write/utils.rs | 6 +- .../parquet/encoding/hybrid_rle/encoder.rs | 305 +++++++++++++----- .../src/parquet/encoding/hybrid_rle/mod.rs | 4 +- .../tests/it/io/parquet/write/binary.rs | 4 +- .../tests/it/io/parquet/write/primitive.rs | 4 +- py-polars/tests/unit/io/test_parquet.py | 65 ++++ 15 files changed, 314 insertions(+), 136 deletions(-) diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml index 6db364e210d1..4e54ca0cf8e9 100644 --- 
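
Usage of the plugin example above is unchanged by the PyO3 0.21 migration. A sketch, assuming the crate has been built and installed into the current environment (for instance with `maturin develop`) so that the `my_polars_functions` module from lib.rs is importable:

import polars as pl

import my_polars_functions  # the #[pymodule] defined in lib.rs above

a = pl.Series("a", ["foo", "bar", None])
b = pl.Series("b", ["fop", "baz", "spam"])

# Element-wise Hamming distance between the two string Series; rows where a
# value is missing (or no distance can be computed) come back as null.
dist = my_polars_functions.hamming_distance(a, b)
print(dist)
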
a/.github/workflows/test-rust.yml +++ b/.github/workflows/test-rust.yml @@ -52,6 +52,7 @@ jobs: -p polars-io -p polars-lazy -p polars-ops + -p polars-parquet -p polars-plan -p polars-row -p polars-sql @@ -68,6 +69,7 @@ jobs: -p polars-io -p polars-lazy -p polars-ops + -p polars-parquet -p polars-plan -p polars-row -p polars-sql diff --git a/crates/polars-arrow/src/compute/cast/binary_to.rs b/crates/polars-arrow/src/compute/cast/binary_to.rs index c7970fe6a051..d5e8bfb30852 100644 --- a/crates/polars-arrow/src/compute/cast/binary_to.rs +++ b/crates/polars-arrow/src/compute/cast/binary_to.rs @@ -139,6 +139,7 @@ pub fn binary_to_dictionary( from: &BinaryArray, ) -> PolarsResult> { let mut array = MutableDictionaryArray::>::new(); + array.reserve(from.len()); array.try_extend(from.iter())?; Ok(array.into()) diff --git a/crates/polars-arrow/src/compute/cast/binview_to.rs b/crates/polars-arrow/src/compute/cast/binview_to.rs index 8c7ef4c2453a..1c157110ec49 100644 --- a/crates/polars-arrow/src/compute/cast/binview_to.rs +++ b/crates/polars-arrow/src/compute/cast/binview_to.rs @@ -21,6 +21,7 @@ pub(super) fn binview_to_dictionary( from: &BinaryViewArray, ) -> PolarsResult> { let mut array = MutableDictionaryArray::>::new(); + array.reserve(from.len()); array.try_extend(from.iter())?; Ok(array.into()) @@ -30,6 +31,7 @@ pub(super) fn utf8view_to_dictionary( from: &Utf8ViewArray, ) -> PolarsResult> { let mut array = MutableDictionaryArray::>::new(); + array.reserve(from.len()); array.try_extend(from.iter())?; Ok(array.into()) diff --git a/crates/polars-arrow/src/compute/cast/primitive_to.rs b/crates/polars-arrow/src/compute/cast/primitive_to.rs index d0d2056b70de..583b6ab19a96 100644 --- a/crates/polars-arrow/src/compute/cast/primitive_to.rs +++ b/crates/polars-arrow/src/compute/cast/primitive_to.rs @@ -318,6 +318,7 @@ pub fn primitive_to_dictionary( let mut array = MutableDictionaryArray::::try_empty(MutablePrimitiveArray::::from( from.data_type().clone(), ))?; + array.reserve(from.len()); array.try_extend(iter)?; Ok(array.into()) diff --git a/crates/polars-arrow/src/compute/cast/utf8_to.rs b/crates/polars-arrow/src/compute/cast/utf8_to.rs index 4df2876d394e..85b478c43817 100644 --- a/crates/polars-arrow/src/compute/cast/utf8_to.rs +++ b/crates/polars-arrow/src/compute/cast/utf8_to.rs @@ -27,6 +27,7 @@ pub fn utf8_to_dictionary( from: &Utf8Array, ) -> PolarsResult> { let mut array = MutableDictionaryArray::>::new(); + array.reserve(from.len()); array.try_extend(from.iter())?; Ok(array.into()) diff --git a/crates/polars-io/src/parquet/write/writer.rs b/crates/polars-io/src/parquet/write/writer.rs index 2408d66e9ba2..620ac11c3351 100644 --- a/crates/polars-io/src/parquet/write/writer.rs +++ b/crates/polars-io/src/parquet/write/writer.rs @@ -102,7 +102,7 @@ where WriteOptions { write_statistics: self.statistics, compression: self.compression, - version: Version::V2, + version: Version::V1, data_pagesize_limit: self.data_page_size, } } diff --git a/crates/polars-parquet/src/arrow/write/dictionary.rs b/crates/polars-parquet/src/arrow/write/dictionary.rs index b3ea666865c9..0525578589eb 100644 --- a/crates/polars-parquet/src/arrow/write/dictionary.rs +++ b/crates/polars-parquet/src/arrow/write/dictionary.rs @@ -1,7 +1,6 @@ use arrow::array::{Array, BinaryViewArray, DictionaryArray, DictionaryKey, Utf8ViewArray}; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::datatypes::{ArrowDataType, IntegerType}; -use num_traits::ToPrimitive; use polars_error::{polars_bail, PolarsResult}; use super::binary::{ @@ 
-16,23 +15,19 @@ use super::primitive::{ use super::{binview, nested, Nested, WriteOptions}; use crate::arrow::read::schema::is_nullable; use crate::arrow::write::{slice_nested_leaf, utils}; -use crate::parquet::encoding::hybrid_rle::encode_u32; +use crate::parquet::encoding::hybrid_rle::encode; use crate::parquet::encoding::Encoding; use crate::parquet::page::{DictPage, Page}; use crate::parquet::schema::types::PrimitiveType; use crate::parquet::statistics::{serialize_statistics, ParquetStatistics}; -use crate::write::{to_nested, DynIter, ParquetType}; +use crate::write::DynIter; pub(crate) fn encode_as_dictionary_optional( array: &dyn Array, + nested: &[Nested], type_: PrimitiveType, options: WriteOptions, ) -> Option>>> { - let nested = to_nested(array, &ParquetType::PrimitiveType(type_.clone())) - .ok()? - .pop() - .unwrap(); - let dtype = Box::new(array.data_type().clone()); let len_before = array.len(); @@ -52,35 +47,11 @@ pub(crate) fn encode_as_dictionary_optional( if (array.values().len() as f64) / (len_before as f64) > 0.75 { return None; } - if array.values().len().to_u16().is_some() { - let array = arrow::compute::cast::cast( - array, - &ArrowDataType::Dictionary( - IntegerType::UInt16, - Box::new(array.values().data_type().clone()), - false, - ), - Default::default(), - ) - .unwrap(); - - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - return Some(array_to_pages( - array, - type_, - &nested, - options, - Encoding::RleDictionary, - )); - } Some(array_to_pages( array, type_, - &nested, + nested, options, Encoding::RleDictionary, )) @@ -116,7 +87,7 @@ fn serialize_keys_values( buffer.push(num_bits as u8); // followed by the encoded indices. - Ok(encode_u32(buffer, keys, num_bits)?) + Ok(encode::(buffer, keys, num_bits)?) } else { let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64); @@ -124,7 +95,7 @@ fn serialize_keys_values( buffer.push(num_bits as u8); // followed by the encoded indices. - Ok(encode_u32(buffer, keys, num_bits)?) + Ok(encode::(buffer, keys, num_bits)?) 
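
The 0.75 cut-off above decides whether a column is worth dictionary-encoding at all. A rough way to see the effect from Python (illustrative only; column names are made up and the exact encodings reported can vary by polars and pyarrow version):

import io

import polars as pl
import pyarrow.parquet as pq

df = pl.DataFrame(
    {
        "repeated": ["a", "b", "c"] * 1000,           # few distinct values: dictionary pays off
        "unique": [f"val_{i}" for i in range(3000)],  # all distinct: dictionary is skipped
    }
)

f = io.BytesIO()
df.write_parquet(f)
f.seek(0)

meta = pq.ParquetFile(f).metadata
for i in range(meta.num_columns):
    col = meta.row_group(0).column(i)
    # Expect RLE_DICTIONARY for "repeated" but not for "unique"
    print(col.path_in_schema, col.encodings)
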
} } diff --git a/crates/polars-parquet/src/arrow/write/mod.rs b/crates/polars-parquet/src/arrow/write/mod.rs index a980177c4835..65e03cecaae4 100644 --- a/crates/polars-parquet/src/arrow/write/mod.rs +++ b/crates/polars-parquet/src/arrow/write/mod.rs @@ -219,7 +219,7 @@ pub fn array_to_pages( // Only take this path for primitive columns if matches!(nested.first(), Some(Nested::Primitive(_, _, _))) { if let Some(result) = - encode_as_dictionary_optional(primitive_array, type_.clone(), options) + encode_as_dictionary_optional(primitive_array, nested, type_.clone(), options) { return result; } diff --git a/crates/polars-parquet/src/arrow/write/nested/mod.rs b/crates/polars-parquet/src/arrow/write/nested/mod.rs index 46e15eec6c72..9aed392a06ee 100644 --- a/crates/polars-parquet/src/arrow/write/nested/mod.rs +++ b/crates/polars-parquet/src/arrow/write/nested/mod.rs @@ -6,7 +6,7 @@ use polars_error::PolarsResult; pub use rep::num_values; use super::Nested; -use crate::parquet::encoding::hybrid_rle::encode_u32; +use crate::parquet::encoding::hybrid_rle::encode; use crate::parquet::read::levels::get_bit_width; use crate::parquet::write::Version; @@ -41,12 +41,12 @@ fn write_rep_levels(buffer: &mut Vec, nested: &[Nested], version: Version) - match version { Version::V1 => { write_levels_v1(buffer, |buffer: &mut Vec| { - encode_u32(buffer, levels, num_bits)?; + encode::(buffer, levels, num_bits)?; Ok(()) })?; }, Version::V2 => { - encode_u32(buffer, levels, num_bits)?; + encode::(buffer, levels, num_bits)?; }, } @@ -65,10 +65,10 @@ fn write_def_levels(buffer: &mut Vec, nested: &[Nested], version: Version) - match version { Version::V1 => write_levels_v1(buffer, move |buffer: &mut Vec| { - encode_u32(buffer, levels, num_bits)?; + encode::(buffer, levels, num_bits)?; Ok(()) }), - Version::V2 => Ok(encode_u32(buffer, levels, num_bits)?), + Version::V2 => Ok(encode::(buffer, levels, num_bits)?), } } diff --git a/crates/polars-parquet/src/arrow/write/utils.rs b/crates/polars-parquet/src/arrow/write/utils.rs index 2032029b2de4..0ba9f4289bab 100644 --- a/crates/polars-parquet/src/arrow/write/utils.rs +++ b/crates/polars-parquet/src/arrow/write/utils.rs @@ -4,7 +4,7 @@ use polars_error::*; use super::{Version, WriteOptions}; use crate::parquet::compression::CompressionOptions; -use crate::parquet::encoding::hybrid_rle::encode_bool; +use crate::parquet::encoding::hybrid_rle::encode; use crate::parquet::encoding::Encoding; use crate::parquet::metadata::Descriptor; use crate::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, DataPageHeaderV2}; @@ -14,7 +14,7 @@ use crate::parquet::statistics::ParquetStatistics; fn encode_iter_v1>(buffer: &mut Vec, iter: I) -> PolarsResult<()> { buffer.extend_from_slice(&[0; 4]); let start = buffer.len(); - encode_bool(buffer, iter)?; + encode::(buffer, iter, 1)?; let end = buffer.len(); let length = end - start; @@ -25,7 +25,7 @@ fn encode_iter_v1>(buffer: &mut Vec, iter: I) -> Po } fn encode_iter_v2>(writer: &mut Vec, iter: I) -> PolarsResult<()> { - Ok(encode_bool(writer, iter)?) + Ok(encode::(writer, iter, 1)?) 
} fn encode_iter>( diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs index 1c4dd67ccec7..7e1858e44979 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs @@ -3,98 +3,232 @@ use std::io::Write; use super::bitpacked_encode; use crate::parquet::encoding::{bitpacked, ceil8, uleb128}; -/// RLE-hybrid encoding of `u32`. This currently only yields bitpacked values. -pub fn encode_u32>( - writer: &mut W, - iterator: I, - num_bits: u32, -) -> std::io::Result<()> { - let num_bits = num_bits as u8; - // the length of the iterator. - let length = iterator.size_hint().1.unwrap(); +// Arbitrary value that balances memory usage and storage overhead +const MAX_VALUES_PER_LITERAL_RUN: usize = (1 << 10) * 8; + +pub trait Encoder { + fn bitpacked_encode>( + writer: &mut W, + iterator: I, + num_bits: usize, + ) -> std::io::Result<()>; + + fn run_length_encode( + writer: &mut W, + run_length: usize, + value: T, + bit_width: u32, + ) -> std::io::Result<()>; +} - // write the length + indicator - let mut header = ceil8(length) as u64; - header <<= 1; - header |= 1; // it is bitpacked => first bit is set - let mut container = [0; 10]; - let used = uleb128::encode(header, &mut container); - writer.write_all(&container[..used])?; +const U32_BLOCK_LEN: usize = 32; - bitpacked_encode_u32(writer, iterator, num_bits as usize)?; +impl Encoder for u32 { + fn bitpacked_encode>( + writer: &mut W, + mut iterator: I, + num_bits: usize, + ) -> std::io::Result<()> { + // the length of the iterator. + let length = iterator.size_hint().1.unwrap(); + + let mut header = ceil8(length) as u64; + header <<= 1; + header |= 1; // it is bitpacked => first bit is set + let mut container = [0; 10]; + let used = uleb128::encode(header, &mut container); + writer.write_all(&container[..used])?; + + let chunks = length / U32_BLOCK_LEN; + let remainder = length - chunks * U32_BLOCK_LEN; + let mut buffer = [0u32; U32_BLOCK_LEN]; + + // simplified from ceil8(U32_BLOCK_LEN * num_bits) since U32_BLOCK_LEN = 32 + let compressed_chunk_size = 4 * num_bits; + + for _ in 0..chunks { + iterator + .by_ref() + .take(U32_BLOCK_LEN) + .zip(buffer.iter_mut()) + .for_each(|(item, buf)| *buf = item); + + let mut packed = [0u8; 4 * U32_BLOCK_LEN]; + bitpacked::encode_pack::(&buffer, num_bits, packed.as_mut()); + writer.write_all(&packed[..compressed_chunk_size])?; + } + + if remainder != 0 { + // Must be careful here to ensure we write a multiple of `num_bits` + // (the bit width) to align with the spec. Some readers also rely on + // this - see https://github.com/pola-rs/polars/pull/13883. + + // this is ceil8(remainder * num_bits), but we ensure the output is a + // multiple of num_bits by rewriting it as ceil8(remainder) * num_bits + let compressed_remainder_size = ceil8(remainder) * num_bits; + iterator + .by_ref() + .take(remainder) + .zip(buffer.iter_mut()) + .for_each(|(item, buf)| *buf = item); + + let mut packed = [0u8; 4 * U32_BLOCK_LEN]; + // No need to zero rest of buffer because remainder is either: + // * Multiple of 8: We pad non-terminal literal runs to have a + // multiple of 8 values. Once compressed, the data will end on + // clean byte boundaries and packed[..compressed_remainder_size] + // will include only the remainder values and nothing extra. 
+ // * Final run: Extra values from buffer will be included in + // packed[..compressed_remainder_size] but ignored when decoding + // because they extend beyond known column length + bitpacked::encode_pack(&buffer, num_bits, packed.as_mut()); + writer.write_all(&packed[..compressed_remainder_size])?; + }; + Ok(()) + } - Ok(()) + fn run_length_encode( + writer: &mut W, + run_length: usize, + value: u32, + bit_width: u32, + ) -> std::io::Result<()> { + // write the length + indicator + let mut header = run_length as u64; + header <<= 1; + let mut container = [0; 10]; + let used = uleb128::encode(header, &mut container); + writer.write_all(&container[..used])?; + + let num_bytes = ceil8(bit_width as usize); + let bytes = value.to_le_bytes(); + writer.write_all(&bytes[..num_bytes])?; + Ok(()) + } } -const U32_BLOCK_LEN: usize = 32; - -fn bitpacked_encode_u32>( - writer: &mut W, - mut iterator: I, - num_bits: usize, -) -> std::io::Result<()> { - // the length of the iterator. - let length = iterator.size_hint().1.unwrap(); - - let chunks = length / U32_BLOCK_LEN; - let remainder = length - chunks * U32_BLOCK_LEN; - let mut buffer = [0u32; U32_BLOCK_LEN]; - - // simplified from ceil8(U32_BLOCK_LEN * num_bits) since U32_BLOCK_LEN = 32 - let compressed_chunk_size = 4 * num_bits; - - for _ in 0..chunks { - iterator - .by_ref() - .take(U32_BLOCK_LEN) - .zip(buffer.iter_mut()) - .for_each(|(item, buf)| *buf = item); - - let mut packed = [0u8; 4 * U32_BLOCK_LEN]; - bitpacked::encode_pack::(&buffer, num_bits, packed.as_mut()); - writer.write_all(&packed[..compressed_chunk_size])?; +impl Encoder for bool { + fn bitpacked_encode>( + writer: &mut W, + iterator: I, + _num_bits: usize, + ) -> std::io::Result<()> { + // the length of the iterator. + let length = iterator.size_hint().1.unwrap(); + + let mut header = ceil8(length) as u64; + header <<= 1; + header |= 1; // it is bitpacked => first bit is set + let mut container = [0; 10]; + let used = uleb128::encode(header, &mut container); + writer.write_all(&container[..used])?; + bitpacked_encode(writer, iterator)?; + Ok(()) } - if remainder != 0 { - // Must be careful here to ensure we write a multiple of `num_bits` - // (the bit width) to align with the spec. Some readers also rely on - // this - see https://github.com/pola-rs/polars/pull/13883. - - // this is ceil8(remainder * num_bits), but we ensure the output is a - // multiple of num_bits by rewriting it as ceil8(remainder) * num_bits - let compressed_remainder_size = ceil8(remainder) * num_bits; - iterator - .by_ref() - .take(remainder) - .zip(buffer.iter_mut()) - .for_each(|(item, buf)| *buf = item); - - let mut packed = [0u8; 4 * U32_BLOCK_LEN]; - bitpacked::encode_pack(&buffer, num_bits, packed.as_mut()); - writer.write_all(&packed[..compressed_remainder_size])?; - }; - Ok(()) + fn run_length_encode( + writer: &mut W, + run_length: usize, + value: bool, + _bit_width: u32, + ) -> std::io::Result<()> { + // write the length + indicator + let mut header = run_length as u64; + header <<= 1; + let mut container = [0; 10]; + let used = uleb128::encode(header, &mut container); + writer.write_all(&container[..used])?; + writer.write_all(&(value as u8).to_le_bytes())?; + Ok(()) + } } -/// the bitpacked part of the encoder. -pub fn encode_bool>( +#[allow(clippy::comparison_chain)] +pub fn encode, W: Write, I: Iterator>( writer: &mut W, iterator: I, + num_bits: u32, ) -> std::io::Result<()> { - // the length of the iterator. 
- let length = iterator.size_hint().1.unwrap(); - - // write the length + indicator - let mut header = ceil8(length) as u64; - header <<= 1; - header |= 1; // it is bitpacked => first bit is set - let mut container = [0; 10]; - let used = uleb128::encode(header, &mut container); - - writer.write_all(&container[..used])?; - - // encode the iterator - bitpacked_encode(writer, iterator) + let mut consecutive_repeats: usize = 0; + let mut previous_val = T::default(); + let mut buffered_bits = [previous_val; MAX_VALUES_PER_LITERAL_RUN]; + let mut buffer_idx = 0; + let mut literal_run_idx = 0; + for val in iterator { + if val == previous_val { + consecutive_repeats += 1; + if consecutive_repeats >= 8 { + // Run is long enough to RLE, no need to buffer values + if consecutive_repeats > 8 { + continue; + } else { + // When we encounter a run long enough to potentially RLE, + // we must first ensure that the buffered literal run has + // a multiple of 8 values for bit-packing. If not, we pad + // up by taking some of the consecutive repeats + let literal_padding = (8 - (literal_run_idx % 8)) % 8; + consecutive_repeats -= literal_padding; + literal_run_idx += literal_padding; + } + } + // Too short to RLE, continue to buffer values + } else if consecutive_repeats > 8 { + // Value changed so start a new run but the current run is long + // enough to RLE. First, bit-pack any buffered literal run. Then, + // RLE current run and reset consecutive repeat counter and buffer. + if literal_run_idx > 0 { + debug_assert!(literal_run_idx % 8 == 0); + T::bitpacked_encode( + writer, + buffered_bits.iter().take(literal_run_idx).copied(), + num_bits as usize, + )?; + literal_run_idx = 0; + } + T::run_length_encode(writer, consecutive_repeats, previous_val, num_bits)?; + consecutive_repeats = 1; + buffer_idx = 0; + } else { + // Value changed so start a new run but the current run is not long + // enough to RLE. Consolidate all consecutive repeats into buffered + // literal run. + literal_run_idx = buffer_idx; + consecutive_repeats = 1; + } + // If buffer is full, bit-pack as literal run and reset + if buffer_idx == MAX_VALUES_PER_LITERAL_RUN { + T::bitpacked_encode(writer, buffered_bits.iter().copied(), num_bits as usize)?; + // If buffer fills up in the middle of a run, all but the last + // repeat is consolidated into the literal run. + debug_assert!( + (consecutive_repeats < 8) + && (buffer_idx - literal_run_idx == consecutive_repeats - 1) + ); + consecutive_repeats = 1; + buffer_idx = 0; + literal_run_idx = 0; + } + buffered_bits[buffer_idx] = val; + previous_val = val; + buffer_idx += 1; + } + // Final run not long enough to RLE, extend literal run. 
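
The multiple-of-8 constraint referenced in the comments above is easiest to see with concrete numbers. A tiny check, mirroring the ceil8 helper used by the encoder: with 9 leftover values at 2 bits each, ceil8(remainder * num_bits) would be 3 bytes and cut a group of 8 values short, while ceil8(remainder) * num_bits writes 4 bytes and always covers whole 8-value groups.

def ceil8(n: int) -> int:
    return (n + 7) // 8

remainder, num_bits = 9, 2
assert ceil8(remainder * num_bits) == 3   # would end mid-group
assert ceil8(remainder) * num_bits == 4   # pads to whole groups of 8 values, per the spec
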
+ if consecutive_repeats <= 8 { + literal_run_idx = buffer_idx; + } + // Bit-pack final buffered literal run, if any + if literal_run_idx > 0 { + T::bitpacked_encode( + writer, + buffered_bits.iter().take(literal_run_idx).copied(), + num_bits as usize, + )?; + } + // RLE final consecutive run if long enough + if consecutive_repeats > 8 { + T::run_length_encode(writer, consecutive_repeats, previous_val, num_bits)?; + } + Ok(()) } #[cfg(test)] @@ -108,7 +242,7 @@ mod tests { let mut vec = vec![]; - encode_bool(&mut vec, iter)?; + encode::(&mut vec, iter, 1)?; assert_eq!(vec, vec![(2 << 1 | 1), 0b10011101u8, 0b00011101]); @@ -119,9 +253,10 @@ mod tests { fn bool_from_iter() -> std::io::Result<()> { let mut vec = vec![]; - encode_bool( + encode::( &mut vec, vec![true, true, true, true, true, true, true, true].into_iter(), + 1, )?; assert_eq!(vec, vec![(1 << 1 | 1), 0b11111111]); @@ -132,7 +267,7 @@ mod tests { fn test_encode_u32() -> std::io::Result<()> { let mut vec = vec![]; - encode_u32(&mut vec, vec![0, 1, 2, 1, 2, 1, 1, 0, 3].into_iter(), 2)?; + encode::(&mut vec, vec![0, 1, 2, 1, 2, 1, 1, 0, 3].into_iter(), 2)?; assert_eq!( vec, @@ -153,7 +288,7 @@ mod tests { let values = (0..128).map(|x| x % 4); - encode_u32(&mut vec, values, 2)?; + encode::(&mut vec, values, 2)?; let length = 128; let expected = 0b11_10_01_00u8; @@ -170,7 +305,7 @@ mod tests { let values = vec![3, 3, 0, 3, 2, 3, 3, 3, 3, 1, 3, 3, 3, 0, 3].into_iter(); let mut vec = vec![]; - encode_u32(&mut vec, values, 2)?; + encode::(&mut vec, values, 2)?; let expected = vec![5, 207, 254, 247, 51]; assert_eq!(expected, vec); diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs index 3dc072552524..89816f87fb54 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs @@ -4,7 +4,7 @@ mod decoder; mod encoder; pub use bitmap::{encode_bool as bitpacked_encode, BitmapIter}; pub use decoder::Decoder; -pub use encoder::{encode_bool, encode_u32}; +pub use encoder::encode; use polars_utils::iter::FallibleIterator; use super::bitpacked; @@ -137,7 +137,7 @@ mod tests { let data = (0..1000).collect::>(); - encode_u32(&mut buffer, data.iter().cloned(), num_bits).unwrap(); + encode::(&mut buffer, data.iter().cloned(), num_bits).unwrap(); let decoder = HybridRleDecoder::try_new(&buffer, num_bits, data.len())?; diff --git a/crates/polars/tests/it/io/parquet/write/binary.rs b/crates/polars/tests/it/io/parquet/write/binary.rs index 3112f115c3e7..dd4e3a942c46 100644 --- a/crates/polars/tests/it/io/parquet/write/binary.rs +++ b/crates/polars/tests/it/io/parquet/write/binary.rs @@ -1,4 +1,4 @@ -use polars_parquet::parquet::encoding::hybrid_rle::encode_bool; +use polars_parquet::parquet::encoding::hybrid_rle::encode; use polars_parquet::parquet::encoding::Encoding; use polars_parquet::parquet::error::Result; use polars_parquet::parquet::metadata::Descriptor; @@ -25,7 +25,7 @@ fn unzip_option(array: &[Option>]) -> Result<(Vec, Vec)> { false } }); - encode_bool(&mut validity, iter)?; + encode::(&mut validity, iter, 1)?; // write the length, now that it is known let mut validity = validity.into_inner(); diff --git a/crates/polars/tests/it/io/parquet/write/primitive.rs b/crates/polars/tests/it/io/parquet/write/primitive.rs index 3b5ae150896a..e5da32252e99 100644 --- a/crates/polars/tests/it/io/parquet/write/primitive.rs +++ b/crates/polars/tests/it/io/parquet/write/primitive.rs @@ -1,4 +1,4 @@ 
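
The tests above pin down the byte layout the encoder emits. As a rough, pure-Python sketch of that hybrid layout (not the polars-parquet implementation; uleb128, bitpacked_run and rle_run are ad hoc helper names): each run starts with a ULEB128 header whose low bit selects bit-packing (1) or RLE (0), followed either by the packed literal values or by the single repeated value.

def uleb128(n: int) -> bytes:
    out = bytearray()
    while True:
        byte = n & 0x7F
        n >>= 7
        if n:
            out.append(byte | 0x80)
        else:
            out.append(byte)
            return bytes(out)

def ceil8(n: int) -> int:
    return (n + 7) // 8

def bitpacked_run(values: list[int], num_bits: int) -> bytes:
    # header = (number of 8-value groups << 1) | 1, then the values packed LSB-first
    header = uleb128((ceil8(len(values)) << 1) | 1)
    padded = values + [0] * (ceil8(len(values)) * 8 - len(values))
    acc = 0
    for i, v in enumerate(padded):
        acc |= v << (i * num_bits)
    return header + acc.to_bytes(ceil8(len(padded) * num_bits), "little")

def rle_run(value: int, run_length: int, bit_width: int) -> bytes:
    # header = run_length << 1 (low bit 0), then the value in ceil8(bit_width) bytes
    return uleb128(run_length << 1) + value.to_bytes(ceil8(bit_width), "little")

print(bitpacked_run([1] * 8, 1).hex())  # "03ff", the expected bytes in bool_from_iter above
print(rle_run(3, 9, 2).hex())           # "1203": a run of nine 3s at bit width 2
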
-use polars_parquet::parquet::encoding::hybrid_rle::encode_bool; +use polars_parquet::parquet::encoding::hybrid_rle::encode; use polars_parquet::parquet::encoding::Encoding; use polars_parquet::parquet::error::Result; use polars_parquet::parquet::metadata::Descriptor; @@ -24,7 +24,7 @@ fn unzip_option(array: &[Option]) -> Result<(Vec, Vec) false } }); - encode_bool(&mut validity, iter)?; + encode::(&mut validity, iter, 1)?; // write the length, now that it is known let mut validity = validity.into_inner(); diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 12ac1a835b40..6ef0c4201981 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -892,3 +892,68 @@ def test_no_glob_windows(tmp_path: Path) -> None: df.write_parquet(str(p2)) assert_frame_equal(pl.scan_parquet(str(p1), glob=False).collect(), df) + + +@pytest.mark.slow() +def test_hybrid_rle() -> None: + # 10_007 elements to test if not a nice multiple of 8 + n = 10_007 + literal_literal = [] + literal_rle = [] + for i in range(500): + literal_literal.append(np.repeat(i, 5)) + literal_literal.append(np.repeat(i + 2, 11)) + literal_rle.append(np.repeat(i, 5)) + literal_rle.append(np.repeat(i + 2, 15)) + literal_literal.append(np.random.randint(0, 10, size=2007)) + literal_rle.append(np.random.randint(0, 10, size=7)) + literal_literal = np.concatenate(literal_literal) + literal_rle = np.concatenate(literal_rle) + df = pl.DataFrame( + { + # Primitive types + "i64": pl.Series([1, 2], dtype=pl.Int64).sample(n, with_replacement=True), + "u64": pl.Series([1, 2], dtype=pl.UInt64).sample(n, with_replacement=True), + "i8": pl.Series([1, 2], dtype=pl.Int8).sample(n, with_replacement=True), + "u8": pl.Series([1, 2], dtype=pl.UInt8).sample(n, with_replacement=True), + "string": pl.Series(["abc", "def"], dtype=pl.String).sample( + n, with_replacement=True + ), + "categorical": pl.Series(["aaa", "bbb"], dtype=pl.Categorical).sample( + n, with_replacement=True + ), + # Fill up bit-packing buffer in middle of consecutive run + "large_bit_pack": np.concatenate( + [np.repeat(i, 5) for i in range(2000)] + + [np.random.randint(0, 10, size=7)] + ), + # Literal run that is not a multiple of 8 followed by consecutive + # run initially long enough to RLE but not after padding literal + "literal_literal": literal_literal, + # Literal run that is not a multiple of 8 followed by consecutive + # run long enough to RLE even after padding literal + "literal_rle": literal_rle, + # Final run not long enough to RLE + "final_literal": np.concatenate( + [np.random.randint(0, 100, 10_000), np.repeat(-1, 7)] + ), + # Final run long enough to RLE + "final_rle": np.concatenate( + [np.random.randint(0, 100, 9_998), np.repeat(-1, 9)] + ), + # Test filling up bit-packing buffer for encode_bool, + # which is only used to encode validities + "large_bit_pack_validity": [0, None] * 4092 + + [0] * 9 + + [1] * 9 + + [2] * 10 + + [0] * 1795, + } + ) + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + for column in pq.ParquetFile(f).metadata.to_dict()["row_groups"][0]["columns"]: + assert "RLE_DICTIONARY" in column["encodings"] + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) From c525d64fe087d65489faf35fb3ed86de8d2838fa Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Fri, 10 May 2024 18:29:53 +0100 Subject: [PATCH 04/29] feat(rust!): separate `rolling_*_by` from `rolling_*(..., by=...)` in Rust (#16102) --- crates/polars-core/Cargo.toml | 2 + 
.../src/chunked_array/ops/rolling_window.rs | 18 +- crates/polars-lazy/Cargo.toml | 6 +- crates/polars-lazy/src/prelude.rs | 4 +- crates/polars-lazy/src/tests/aggregations.rs | 12 +- crates/polars-ops/Cargo.toml | 1 + crates/polars-plan/Cargo.toml | 7 +- .../polars-plan/src/dsl/function_expr/mod.rs | 36 +- .../src/dsl/function_expr/rolling.rs | 150 ++------ .../src/dsl/function_expr/rolling_by.rs | 88 +++++ .../src/dsl/function_expr/schema.rs | 15 +- .../src/dsl/functions/correlation.rs | 16 +- crates/polars-plan/src/dsl/mod.rs | 130 +++++-- crates/polars-time/Cargo.toml | 3 +- crates/polars-time/src/chunkedarray/mod.rs | 4 +- .../chunkedarray/rolling_window/dispatch.rs | 320 ++++++++++++++---- .../src/chunkedarray/rolling_window/mod.rs | 158 +-------- crates/polars-time/src/windows/group_by.rs | 2 +- crates/polars/Cargo.toml | 4 +- crates/polars/tests/it/core/rolling_window.rs | 44 +-- py-polars/Cargo.toml | 1 + py-polars/polars/_utils/deprecation.py | 36 +- py-polars/polars/expr/expr.py | 210 ++++++++---- py-polars/src/expr/rolling.rs | 277 ++++++++++----- py-polars/src/lazyframe/visitor/expr_nodes.rs | 45 +-- .../unit/operations/rolling/test_rolling.py | 50 ++- 26 files changed, 1021 insertions(+), 618 deletions(-) create mode 100644 crates/polars-plan/src/dsl/function_expr/rolling_by.rs diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 297caa7a71f3..6b87dbb177e4 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -88,6 +88,7 @@ take_opt_iter = [] group_by_list = [] # rolling window functions rolling_window = [] +rolling_window_by = [] diagonal_concat = [] dataframe_arithmetic = [] product = [] @@ -135,6 +136,7 @@ docs-selection = [ "dot_product", "row_hash", "rolling_window", + "rolling_window_by", "dtype-categorical", "dtype-decimal", "diagonal_concat", diff --git a/crates/polars-core/src/chunked_array/ops/rolling_window.rs b/crates/polars-core/src/chunked_array/ops/rolling_window.rs index 95679a4dafae..26ea0c4db61f 100644 --- a/crates/polars-core/src/chunked_array/ops/rolling_window.rs +++ b/crates/polars-core/src/chunked_array/ops/rolling_window.rs @@ -1,6 +1,9 @@ use arrow::legacy::prelude::DynArgs; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; -#[derive(Clone)] +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RollingOptionsFixedWindow { /// The length of the window. pub window_size: usize, @@ -11,9 +14,22 @@ pub struct RollingOptionsFixedWindow { pub weights: Option>, /// Set the labels at the center of the window. 
pub center: bool, + #[cfg_attr(feature = "serde", serde(skip))] pub fn_params: DynArgs, } +#[cfg(feature = "rolling_window")] +impl PartialEq for RollingOptionsFixedWindow { + fn eq(&self, other: &Self) -> bool { + self.window_size == other.window_size + && self.min_periods == other.min_periods + && self.weights == other.weights + && self.center == other.center + && self.fn_params.is_none() + && other.fn_params.is_none() + } +} + impl Default for RollingOptionsFixedWindow { fn default() -> Self { RollingOptionsFixedWindow { diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 4c696453ed15..be6107c6b209 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -104,7 +104,10 @@ cum_agg = ["polars-plan/cum_agg"] interpolate = ["polars-plan/interpolate"] rolling_window = [ "polars-plan/rolling_window", - "polars-time/rolling_window", +] +rolling_window_by = [ + "polars-plan/rolling_window_by", + "polars-time/rolling_window_by", ] rank = ["polars-plan/rank"] diff = ["polars-plan/diff", "polars-plan/diff"] @@ -292,6 +295,7 @@ features = [ "replace", "rle", "rolling_window", + "rolling_window_by", "round_series", "row_hash", "search_sorted", diff --git a/crates/polars-lazy/src/prelude.rs b/crates/polars-lazy/src/prelude.rs index b986b5924d1b..0cdb926c886a 100644 --- a/crates/polars-lazy/src/prelude.rs +++ b/crates/polars-lazy/src/prelude.rs @@ -15,8 +15,8 @@ pub use polars_plan::logical_plan::{ }; pub use polars_plan::prelude::UnionArgs; pub(crate) use polars_plan::prelude::*; -#[cfg(feature = "rolling_window")] -pub use polars_time::{prelude::RollingOptions, Duration}; +#[cfg(feature = "rolling_window_by")] +pub use polars_time::Duration; #[cfg(feature = "dynamic_group_by")] pub use polars_time::{DynamicGroupOptions, PolarsTemporalGroupby, RollingGroupOptions}; pub(crate) use polars_utils::arena::{Arena, Node}; diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index 85a1177b4a63..0e67cba50566 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -173,14 +173,14 @@ fn test_power_in_agg_list1() -> PolarsResult<()> { .group_by([col("fruits")]) .agg([ col("A") - .rolling_min(RollingOptions { - window_size: Duration::new(1), + .rolling_min(RollingOptionsFixedWindow { + window_size: 1, ..Default::default() }) .alias("input"), col("A") - .rolling_min(RollingOptions { - window_size: Duration::new(1), + .rolling_min(RollingOptionsFixedWindow { + window_size: 1, ..Default::default() }) .pow(2.0) @@ -211,8 +211,8 @@ fn test_power_in_agg_list2() -> PolarsResult<()> { .lazy() .group_by([col("fruits")]) .agg([col("A") - .rolling_min(RollingOptions { - window_size: Duration::new(2), + .rolling_min(RollingOptionsFixedWindow { + window_size: 2, min_periods: 2, ..Default::default() }) diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 168998e8c330..6f7d410dab98 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -105,6 +105,7 @@ log = [] hash = [] reinterpret = ["polars-core/reinterpret"] rolling_window = ["polars-core/rolling_window"] +rolling_window_by = ["polars-core/rolling_window_by"] moment = [] mode = [] search_sorted = [] diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index ee6d0a2d43ee..bd3a1b9a3626 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -122,7 +122,11 @@ rolling_window = [ "polars-core/rolling_window", 
"polars-time/rolling_window", "polars-ops/rolling_window", - "polars-time/rolling_window", +] +rolling_window_by = [ + "polars-core/rolling_window_by", + "polars-time/rolling_window_by", + "polars-ops/rolling_window_by", ] rank = ["polars-ops/rank"] diff = ["polars-ops/diff"] @@ -180,6 +184,7 @@ features = [ "temporal", "serde", "rolling_window", + "rolling_window_by", "timezones", "dtype-date", "extract_groups", diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index e45cb5e86313..b9afd4c595d3 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -45,6 +45,8 @@ mod random; mod range; #[cfg(feature = "rolling_window")] pub mod rolling; +#[cfg(feature = "rolling_window_by")] +pub mod rolling_by; #[cfg(feature = "round_series")] mod round; #[cfg(feature = "row_hash")] @@ -96,6 +98,8 @@ pub use self::pow::PowFunction; pub(super) use self::range::RangeFunction; #[cfg(feature = "rolling_window")] pub(super) use self::rolling::RollingFunction; +#[cfg(feature = "rolling_window_by")] +pub(super) use self::rolling_by::RollingFunctionBy; #[cfg(feature = "strings")] pub(crate) use self::strings::StringFunction; #[cfg(feature = "dtype-struct")] @@ -156,6 +160,8 @@ pub enum FunctionExpr { FillNullWithStrategy(FillNullStrategy), #[cfg(feature = "rolling_window")] RollingExpr(RollingFunction), + #[cfg(feature = "rolling_window_by")] + RollingExprBy(RollingFunctionBy), ShiftAndFill, Shift, DropNans, @@ -420,6 +426,10 @@ impl Hash for FunctionExpr { RollingExpr(f) => { f.hash(state); }, + #[cfg(feature = "rolling_window_by")] + RollingExprBy(f) => { + f.hash(state); + }, #[cfg(feature = "moment")] Skew(a) => a.hash(state), #[cfg(feature = "moment")] @@ -609,6 +619,8 @@ impl Display for FunctionExpr { FillNull { .. } => "fill_null", #[cfg(feature = "rolling_window")] RollingExpr(func, ..) => return write!(f, "{func}"), + #[cfg(feature = "rolling_window_by")] + RollingExprBy(func, ..) 
=> return write!(f, "{func}"), ShiftAndFill => "shift_and_fill", DropNans => "drop_nans", DropNulls => "drop_nulls", @@ -907,25 +919,31 @@ impl From for SpecialEq> { use RollingFunction::*; match f { Min(options) => map!(rolling::rolling_min, options.clone()), - MinBy(options) => map_as_slice!(rolling::rolling_min_by, options.clone()), Max(options) => map!(rolling::rolling_max, options.clone()), - MaxBy(options) => map_as_slice!(rolling::rolling_max_by, options.clone()), Mean(options) => map!(rolling::rolling_mean, options.clone()), - MeanBy(options) => map_as_slice!(rolling::rolling_mean_by, options.clone()), Sum(options) => map!(rolling::rolling_sum, options.clone()), - SumBy(options) => map_as_slice!(rolling::rolling_sum_by, options.clone()), Quantile(options) => map!(rolling::rolling_quantile, options.clone()), - QuantileBy(options) => { - map_as_slice!(rolling::rolling_quantile_by, options.clone()) - }, Var(options) => map!(rolling::rolling_var, options.clone()), - VarBy(options) => map_as_slice!(rolling::rolling_var_by, options.clone()), Std(options) => map!(rolling::rolling_std, options.clone()), - StdBy(options) => map_as_slice!(rolling::rolling_std_by, options.clone()), #[cfg(feature = "moment")] Skew(window_size, bias) => map!(rolling::rolling_skew, window_size, bias), } }, + #[cfg(feature = "rolling_window_by")] + RollingExprBy(f) => { + use RollingFunctionBy::*; + match f { + MinBy(options) => map_as_slice!(rolling_by::rolling_min_by, options.clone()), + MaxBy(options) => map_as_slice!(rolling_by::rolling_max_by, options.clone()), + MeanBy(options) => map_as_slice!(rolling_by::rolling_mean_by, options.clone()), + SumBy(options) => map_as_slice!(rolling_by::rolling_sum_by, options.clone()), + QuantileBy(options) => { + map_as_slice!(rolling_by::rolling_quantile_by, options.clone()) + }, + VarBy(options) => map_as_slice!(rolling_by::rolling_var_by, options.clone()), + StdBy(options) => map_as_slice!(rolling_by::rolling_std_by, options.clone()), + } + }, #[cfg(feature = "hist")] Hist { bin_count, diff --git a/crates/polars-plan/src/dsl/function_expr/rolling.rs b/crates/polars-plan/src/dsl/function_expr/rolling.rs index f1ae64c5f792..9302ab4a1ad7 100644 --- a/crates/polars-plan/src/dsl/function_expr/rolling.rs +++ b/crates/polars-plan/src/dsl/function_expr/rolling.rs @@ -5,20 +5,13 @@ use super::*; #[derive(Clone, PartialEq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum RollingFunction { - Min(RollingOptions), - MinBy(RollingOptions), - Max(RollingOptions), - MaxBy(RollingOptions), - Mean(RollingOptions), - MeanBy(RollingOptions), - Sum(RollingOptions), - SumBy(RollingOptions), - Quantile(RollingOptions), - QuantileBy(RollingOptions), - Var(RollingOptions), - VarBy(RollingOptions), - Std(RollingOptions), - StdBy(RollingOptions), + Min(RollingOptionsFixedWindow), + Max(RollingOptionsFixedWindow), + Mean(RollingOptionsFixedWindow), + Sum(RollingOptionsFixedWindow), + Quantile(RollingOptionsFixedWindow), + Var(RollingOptionsFixedWindow), + Std(RollingOptionsFixedWindow), #[cfg(feature = "moment")] Skew(usize, bool), } @@ -29,19 +22,12 @@ impl Display for RollingFunction { let name = match self { Min(_) => "rolling_min", - MinBy(_) => "rolling_min_by", Max(_) => "rolling_max", - MaxBy(_) => "rolling_max_by", Mean(_) => "rolling_mean", - MeanBy(_) => "rolling_mean_by", Sum(_) => "rolling_sum", - SumBy(_) => "rolling_sum_by", Quantile(_) => "rolling_quantile", - QuantileBy(_) => "rolling_quantile_by", Var(_) => "rolling_var", - VarBy(_) => 
"rolling_var_by", Std(_) => "rolling_std", - StdBy(_) => "rolling_std_by", #[cfg(feature = "moment")] Skew(..) => "rolling_skew", }; @@ -66,123 +52,35 @@ impl Hash for RollingFunction { } } -fn convert<'a>( - f: impl Fn(RollingOptionsImpl) -> PolarsResult + 'a, - ss: &'a [Series], - expr_name: &'static str, -) -> impl Fn(RollingOptions) -> PolarsResult + 'a { - move |options| { - let mut by = ss[1].clone(); - by = by.rechunk(); - - let (by, tz) = match by.dtype() { - DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz), - DataType::Date => ( - by.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?, - &None, - ), - dt => polars_bail!(InvalidOperation: - "in `{}` operation, `by` argument of dtype `{}` is not supported (expected `{}`)", - expr_name, - dt, - "date/datetime"), - }; - if by.is_sorted_flag() != IsSorted::Ascending && options.warn_if_unsorted { - polars_warn!(format!( - "Series is not known to be sorted by `by` column in {} operation.\n\ - \n\ - To silence this warning, you may want to try:\n\ - - sorting your data by your `by` column beforehand;\n\ - - setting `.set_sorted()` if you already know your data is sorted;\n\ - - passing `warn_if_unsorted=False` if this warning is a false-positive\n \ - (this is known to happen when combining rolling aggregations with `over`);\n\n\ - before passing calling the rolling aggregation function.\n", - expr_name - )); - } - let by = by.datetime().unwrap(); - let by_values = by.cont_slice().map_err(|_| { - polars_err!( - ComputeError: - "`by` column should not have null values in 'rolling by' expression" - ) - })?; - let tu = by.time_unit(); - - let options = RollingOptionsImpl { - window_size: options.window_size, - min_periods: options.min_periods, - weights: options.weights, - center: options.center, - by: Some(by_values), - tu: Some(tu), - tz: tz.as_ref(), - closed_window: options.closed_window, - fn_params: options.fn_params.clone(), - }; - - f(options) - } -} - -pub(super) fn rolling_min(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_min(options.into()) -} - -pub(super) fn rolling_min_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_min(options), s, "rolling_min")(options) -} - -pub(super) fn rolling_max(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_max(options.into()) -} - -pub(super) fn rolling_max_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_max(options), s, "rolling_max")(options) -} - -pub(super) fn rolling_mean(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_mean(options.into()) -} - -pub(super) fn rolling_mean_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_mean(options), s, "rolling_mean")(options) -} - -pub(super) fn rolling_sum(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_sum(options.into()) -} - -pub(super) fn rolling_sum_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_sum(options), s, "rolling_sum")(options) +pub(super) fn rolling_min(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_min(options) } -pub(super) fn rolling_quantile(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_quantile(options.into()) +pub(super) fn rolling_max(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_max(options) } -pub(super) fn rolling_quantile_by(s: &[Series], options: RollingOptions) 
-> PolarsResult { - convert( - |options| s[0].rolling_quantile(options), - s, - "rolling_quantile", - )(options) +pub(super) fn rolling_mean(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_mean(options) } -pub(super) fn rolling_var(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_var(options.into()) +pub(super) fn rolling_sum(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_sum(options) } -pub(super) fn rolling_var_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_var(options), s, "rolling_var")(options) +pub(super) fn rolling_quantile( + s: &Series, + options: RollingOptionsFixedWindow, +) -> PolarsResult { + s.rolling_quantile(options) } -pub(super) fn rolling_std(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_std(options.into()) +pub(super) fn rolling_var(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_var(options) } -pub(super) fn rolling_std_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_std(options), s, "rolling_std")(options) +pub(super) fn rolling_std(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_std(options) } #[cfg(feature = "moment")] diff --git a/crates/polars-plan/src/dsl/function_expr/rolling_by.rs b/crates/polars-plan/src/dsl/function_expr/rolling_by.rs new file mode 100644 index 000000000000..c2b3510281f2 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/rolling_by.rs @@ -0,0 +1,88 @@ +use polars_time::chunkedarray::*; + +use super::*; + +#[derive(Clone, PartialEq, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum RollingFunctionBy { + MinBy(RollingOptionsDynamicWindow), + MaxBy(RollingOptionsDynamicWindow), + MeanBy(RollingOptionsDynamicWindow), + SumBy(RollingOptionsDynamicWindow), + QuantileBy(RollingOptionsDynamicWindow), + VarBy(RollingOptionsDynamicWindow), + StdBy(RollingOptionsDynamicWindow), +} + +impl Display for RollingFunctionBy { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use RollingFunctionBy::*; + + let name = match self { + MinBy(_) => "rolling_min_by", + MaxBy(_) => "rolling_max_by", + MeanBy(_) => "rolling_mean_by", + SumBy(_) => "rolling_sum_by", + QuantileBy(_) => "rolling_quantile_by", + VarBy(_) => "rolling_var_by", + StdBy(_) => "rolling_std_by", + }; + + write!(f, "{name}") + } +} + +impl Hash for RollingFunctionBy { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + } +} + +pub(super) fn rolling_min_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_min_by(&s[1], options) +} + +pub(super) fn rolling_max_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_max_by(&s[1], options) +} + +pub(super) fn rolling_mean_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_mean_by(&s[1], options) +} + +pub(super) fn rolling_sum_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_sum_by(&s[1], options) +} + +pub(super) fn rolling_quantile_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_quantile_by(&s[1], options) +} + +pub(super) fn rolling_var_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_var_by(&s[1], options) +} + +pub(super) fn rolling_std_by( + s: 
&[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_std_by(&s[1], options) +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 830891fea1cb..e301557a247a 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -64,15 +64,20 @@ impl FunctionExpr { RollingExpr(rolling_func, ..) => { use RollingFunction::*; match rolling_func { - Min(_) | MinBy(_) | Max(_) | MaxBy(_) | Sum(_) | SumBy(_) => { - mapper.with_same_dtype() - }, - Mean(_) | MeanBy(_) | Quantile(_) | QuantileBy(_) | Var(_) | VarBy(_) - | Std(_) | StdBy(_) => mapper.map_to_float_dtype(), + Min(_) | Max(_) | Sum(_) => mapper.with_same_dtype(), + Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(), #[cfg(feature = "moment")] Skew(..) => mapper.map_to_float_dtype(), } }, + #[cfg(feature = "rolling_window_by")] + RollingExprBy(rolling_func, ..) => { + use RollingFunctionBy::*; + match rolling_func { + MinBy(_) | MaxBy(_) | SumBy(_) => mapper.with_same_dtype(), + MeanBy(_) | QuantileBy(_) | VarBy(_) | StdBy(_) => mapper.map_to_float_dtype(), + } + }, ShiftAndFill => mapper.with_same_dtype(), DropNans => mapper.with_same_dtype(), DropNulls => mapper.with_same_dtype(), diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index a41a8c8621a2..651365091cbe 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -73,8 +73,8 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E #[cfg(feature = "rolling_window")] pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { // see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804 - let rolling_options = RollingOptions { - window_size: Duration::new(options.window_size as i64), + let rolling_options = RollingOptionsFixedWindow { + window_size: options.window_size as usize, min_periods: options.min_periods as usize, ..Default::default() }; @@ -85,8 +85,8 @@ pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { let var_x = x.clone().rolling_var(rolling_options.clone()); let var_y = y.clone().rolling_var(rolling_options); - let rolling_options_count = RollingOptions { - window_size: Duration::new(options.window_size as i64), + let rolling_options_count = RollingOptionsFixedWindow { + window_size: options.window_size as usize, min_periods: 0, ..Default::default() }; @@ -104,8 +104,8 @@ pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { #[cfg(feature = "rolling_window")] pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { // see: https://github.com/pandas-dev/pandas/blob/91111fd99898d9dcaa6bf6bedb662db4108da6e6/pandas/core/window/rolling.py#L1700 - let rolling_options = RollingOptions { - window_size: Duration::new(options.window_size as i64), + let rolling_options = RollingOptionsFixedWindow { + window_size: options.window_size as usize, min_periods: options.min_periods as usize, ..Default::default() }; @@ -113,8 +113,8 @@ pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); let mean_x = x.clone().rolling_mean(rolling_options.clone()); let mean_y = y.clone().rolling_mean(rolling_options); - let rolling_options_count = 
RollingOptions { - window_size: Duration::new(options.window_size as i64), + let rolling_options_count = RollingOptionsFixedWindow { + window_size: options.window_size as usize, min_periods: 0, ..Default::default() }; diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 015e999851df..3c4b96130583 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -3,12 +3,12 @@ #[cfg(feature = "dtype-categorical")] pub mod cat; -#[cfg(feature = "rolling_window")] +#[cfg(any(feature = "rolling_window", feature = "rolling_window_by"))] use std::any::Any; #[cfg(feature = "dtype-categorical")] pub use cat::*; -#[cfg(feature = "rolling_window")] +#[cfg(feature = "rolling_window_by")] pub(crate) use polars_time::prelude::*; mod arithmetic; @@ -1237,64 +1237,128 @@ impl Expr { self.apply_private(FunctionExpr::Interpolate(method)) } + #[cfg(feature = "rolling_window_by")] + #[allow(clippy::type_complexity)] + fn finish_rolling_by( + self, + by: Expr, + options: RollingOptionsDynamicWindow, + rolling_function_by: fn(RollingOptionsDynamicWindow) -> RollingFunctionBy, + ) -> Expr { + self.apply_many_private( + FunctionExpr::RollingExprBy(rolling_function_by(options)), + &[by], + false, + false, + ) + } + #[cfg(feature = "rolling_window")] #[allow(clippy::type_complexity)] fn finish_rolling( self, - options: RollingOptions, - rolling_function: fn(RollingOptions) -> RollingFunction, - rolling_function_by: fn(RollingOptions) -> RollingFunction, + options: RollingOptionsFixedWindow, + rolling_function: fn(RollingOptionsFixedWindow) -> RollingFunction, ) -> Expr { - if let Some(ref by) = options.by { - let name = by.clone(); - self.apply_many_private( - FunctionExpr::RollingExpr(rolling_function_by(options)), - &[col(&name)], - false, - false, - ) - } else { - self.apply_private(FunctionExpr::RollingExpr(rolling_function(options))) - } + self.apply_private(FunctionExpr::RollingExpr(rolling_function(options))) + } + + /// Apply a rolling minimum based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_min_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::MinBy) + } + + /// Apply a rolling maximum based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_max_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::MaxBy) + } + + /// Apply a rolling mean based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_mean_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::MeanBy) + } + + /// Apply a rolling sum based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_sum_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::SumBy) + } + + /// Apply a rolling quantile based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_quantile_by( + self, + by: Expr, + interpol: QuantileInterpolOptions, + quantile: f64, + mut options: RollingOptionsDynamicWindow, + ) -> Expr { + options.fn_params = Some(Arc::new(RollingQuantileParams { + prob: quantile, + interpol, + }) as Arc); + + self.finish_rolling_by(by, options, RollingFunctionBy::QuantileBy) + } + + /// Apply a rolling variance based on another column. 
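A minimal usage sketch of the new `rolling_*_by` expressions as they surface in Python (the frame, column names, and values below are illustrative assumptions, not taken from this patch): the `by` column comes first and `window_size` is a temporal duration string.

```python
from datetime import date

import polars as pl

df = pl.DataFrame(
    {
        "dt": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 5)],
        "values": [1.0, 2.0, 3.0],
    }
).sort("dt")

out = df.with_columns(
    # the window is defined by the `dt` column, not by a fixed row count
    mean=pl.col("values").rolling_mean_by("dt", window_size="3d"),
    q=pl.col("values").rolling_quantile_by("dt", "3d", quantile=0.2),
)
```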
+ #[cfg(feature = "rolling_window_by")] + pub fn rolling_var_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::VarBy) + } + + /// Apply a rolling std-dev based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_std_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::StdBy) + } + + /// Apply a rolling median based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_median_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.rolling_quantile_by(by, QuantileInterpolOptions::Linear, 0.5, options) } /// Apply a rolling minimum. /// /// See: [`RollingAgg::rolling_min`] #[cfg(feature = "rolling_window")] - pub fn rolling_min(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Min, RollingFunction::MinBy) + pub fn rolling_min(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Min) } /// Apply a rolling maximum. /// /// See: [`RollingAgg::rolling_max`] #[cfg(feature = "rolling_window")] - pub fn rolling_max(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Max, RollingFunction::MaxBy) + pub fn rolling_max(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Max) } /// Apply a rolling mean. /// /// See: [`RollingAgg::rolling_mean`] #[cfg(feature = "rolling_window")] - pub fn rolling_mean(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Mean, RollingFunction::MeanBy) + pub fn rolling_mean(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Mean) } /// Apply a rolling sum. /// /// See: [`RollingAgg::rolling_sum`] #[cfg(feature = "rolling_window")] - pub fn rolling_sum(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Sum, RollingFunction::SumBy) + pub fn rolling_sum(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Sum) } /// Apply a rolling median. /// /// See: [`RollingAgg::rolling_median`] #[cfg(feature = "rolling_window")] - pub fn rolling_median(self, options: RollingOptions) -> Expr { + pub fn rolling_median(self, options: RollingOptionsFixedWindow) -> Expr { self.rolling_quantile(QuantileInterpolOptions::Linear, 0.5, options) } @@ -1306,30 +1370,26 @@ impl Expr { self, interpol: QuantileInterpolOptions, quantile: f64, - mut options: RollingOptions, + mut options: RollingOptionsFixedWindow, ) -> Expr { options.fn_params = Some(Arc::new(RollingQuantileParams { prob: quantile, interpol, }) as Arc); - self.finish_rolling( - options, - RollingFunction::Quantile, - RollingFunction::QuantileBy, - ) + self.finish_rolling(options, RollingFunction::Quantile) } /// Apply a rolling variance. #[cfg(feature = "rolling_window")] - pub fn rolling_var(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Var, RollingFunction::VarBy) + pub fn rolling_var(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Var) } /// Apply a rolling std-dev. 
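For comparison, a sketch of the fixed-window variants after this split: they take `RollingOptionsFixedWindow`, so from Python `window_size` is an integer row count (the old `'2i'` string spelling is handled by a deprecation shim on the Python side). The series values here are illustrative assumptions only.

```python
import polars as pl

s = pl.Series("a", [1.0, 2.0, 3.0, 2.0, 1.0])

# fixed window of 2 rows; `weights` and `center` remain available for this
# variant, while `by`/`closed` now belong to the `_by` variants instead
out = s.rolling_mean(window_size=2, min_periods=1)
```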
#[cfg(feature = "rolling_window")] - pub fn rolling_std(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Std, RollingFunction::StdBy) + pub fn rolling_std(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Std) } /// Apply a rolling skew. diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index 2925de13c869..9e0773ecd2c6 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -32,7 +32,8 @@ dtype-date = ["polars-core/dtype-date", "temporal"] dtype-datetime = ["polars-core/dtype-datetime", "temporal"] dtype-time = ["polars-core/dtype-time", "temporal"] dtype-duration = ["polars-core/dtype-duration", "temporal"] -rolling_window = ["polars-core/rolling_window", "dtype-duration"] +rolling_window = ["polars-core/rolling_window"] +rolling_window_by = ["polars-core/rolling_window_by", "dtype-duration"] fmt = ["polars-core/fmt"] serde = ["dep:serde", "smartstring/serde"] temporal = ["polars-core/temporal"] diff --git a/crates/polars-time/src/chunkedarray/mod.rs b/crates/polars-time/src/chunkedarray/mod.rs index 4c2fb9cbf505..e61031d46ed1 100644 --- a/crates/polars-time/src/chunkedarray/mod.rs +++ b/crates/polars-time/src/chunkedarray/mod.rs @@ -6,7 +6,7 @@ mod datetime; #[cfg(feature = "dtype-duration")] mod duration; mod kernels; -#[cfg(feature = "rolling_window")] +#[cfg(any(feature = "rolling_window", feature = "rolling_window_by"))] mod rolling_window; pub mod string; #[cfg(feature = "dtype-time")] @@ -22,7 +22,7 @@ pub use datetime::DatetimeMethods; pub use duration::DurationMethods; use kernels::*; use polars_core::prelude::*; -#[cfg(feature = "rolling_window")] +#[cfg(any(feature = "rolling_window", feature = "rolling_window_by"))] pub use rolling_window::*; pub use string::StringMethods; #[cfg(feature = "dtype-time")] diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs index 1e6eb024919d..5feb3f9f99cb 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -1,13 +1,15 @@ +use polars_core::series::IsSorted; use polars_core::{with_match_physical_float_polars_type, with_match_physical_numeric_polars_type}; use super::*; use crate::prelude::*; use crate::series::AsSeries; +#[cfg(feature = "rolling_window")] #[allow(clippy::type_complexity)] fn rolling_agg( ca: &ChunkedArray, - options: RollingOptionsImpl, + options: RollingOptionsFixedWindow, rolling_agg_fn: &dyn Fn( &[T::Native], usize, @@ -24,79 +26,140 @@ fn rolling_agg( Option<&[f64]>, DynArgs, ) -> ArrayRef, - rolling_agg_fn_dynamic: Option< - &dyn Fn( - &[T::Native], - Duration, - &[i64], - ClosedWindow, - usize, - TimeUnit, - Option<&TimeZone>, - DynArgs, - ) -> PolarsResult, - >, ) -> PolarsResult where T: PolarsNumericType, { + polars_ensure!(options.min_periods <= options.window_size, InvalidOperation: "`min_periods` should be <= `window_size`"); if ca.is_empty() { return Ok(Series::new_empty(ca.name(), ca.dtype())); } let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); - // "5i" is a window size of 5, e.g. 
fixed - let arr = if options.by.is_none() { - let options: RollingOptionsFixedWindow = options.try_into()?; - Ok(match ca.null_count() { - 0 => rolling_agg_fn( - arr.values().as_slice(), - options.window_size, - options.min_periods, - options.center, - options.weights.as_deref(), - options.fn_params, - )?, - _ => rolling_agg_fn_nulls( - arr, - options.window_size, - options.min_periods, - options.center, - options.weights.as_deref(), - options.fn_params, - ), - }) - } else { - let options: RollingOptionsDynamicWindow = options.try_into()?; - if arr.null_count() > 0 { - polars_bail!(InvalidOperation: "'Expr.rolling_*(..., by=...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'") - } - let values = arr.values().as_slice(); - let tu = options.tu.expect("time_unit was set in `convert` function"); - let by = options.by; - let func = rolling_agg_fn_dynamic.expect("rolling_agg_fn_dynamic must have been passed"); - - func( - values, + let arr = match ca.null_count() { + 0 => rolling_agg_fn( + arr.values().as_slice(), + options.window_size, + options.min_periods, + options.center, + options.weights.as_deref(), + options.fn_params, + )?, + _ => rolling_agg_fn_nulls( + arr, options.window_size, - by, - options.closed_window, options.min_periods, - tu, - options.tz, + options.center, + options.weights.as_deref(), options.fn_params, + ), + }; + Series::try_from((ca.name(), arr)) +} + +#[cfg(feature = "rolling_window_by")] +#[allow(clippy::type_complexity)] +fn rolling_agg_by( + ca: &ChunkedArray, + by: &Series, + options: RollingOptionsDynamicWindow, + rolling_agg_fn_dynamic: &dyn Fn( + &[T::Native], + Duration, + &[i64], + ClosedWindow, + usize, + TimeUnit, + Option<&TimeZone>, + DynArgs, + ) -> PolarsResult, +) -> PolarsResult +where + T: PolarsNumericType, +{ + if ca.is_empty() { + return Ok(Series::new_empty(ca.name(), ca.dtype())); + } + let ca = ca.rechunk(); + ensure_duration_matches_data_type(options.window_size, by.dtype(), "window_size")?; + polars_ensure!(options.window_size.duration_ns()>0 && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive"); + if by.is_sorted_flag() != IsSorted::Ascending && options.warn_if_unsorted { + polars_warn!(format!( + "Series is not known to be sorted by `by` column in `rolling_*_by` operation.\n\ + \n\ + To silence this warning, you may want to try:\n\ + - sorting your data by your `by` column beforehand;\n\ + - setting `.set_sorted()` if you already know your data is sorted;\n\ + - passing `warn_if_unsorted=False` if this warning is a false-positive\n \ + (this is known to happen when combining rolling aggregations with `over`);\n\n\ + before passing calling the rolling aggregation function.\n", + )); + } + let (by, tz) = match by.dtype() { + DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz), + DataType::Date => ( + by.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?, + &None, + ), + dt => polars_bail!(InvalidOperation: + "in `rolling_*_by` operation, `by` argument of dtype `{}` is not supported (expected `{}`)", + dt, + "date/datetime"), + }; + let by = by.datetime().unwrap(); + let by_values = by.cont_slice().map_err(|_| { + polars_err!( + ComputeError: + "`by` column should not have null values in 'rolling by' expression" ) - }?; + })?; + let tu = by.time_unit(); + + let arr = ca.downcast_iter().next().unwrap(); + if arr.null_count() > 0 { + polars_bail!(InvalidOperation: "'Expr.rolling_*(..., by=...)' not yet supported for series 
with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'") + } + let values = arr.values().as_slice(); + let func = rolling_agg_fn_dynamic; + + let arr = func( + values, + options.window_size, + by_values, + options.closed_window, + options.min_periods, + tu, + tz.as_ref(), + options.fn_params, + )?; Series::try_from((ca.name(), arr)) } pub trait SeriesOpsTime: AsSeries { + /// Apply a rolling mean to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_mean_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().to_float()?; + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_mean, + ) + }) + } /// Apply a rolling mean to a Series. /// /// See: [`RollingAgg::rolling_mean`] #[cfg(feature = "rolling_window")] - fn rolling_mean(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_mean(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let s = self.as_series().to_float()?; with_match_physical_float_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); @@ -105,13 +168,31 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_mean, &rolling::nulls::rolling_mean, - Some(&super::rolling_kernels::no_nulls::rolling_mean), ) }) } + /// Apply a rolling sum to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_sum_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().clone(); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_sum, + ) + }) + } + /// Apply a rolling sum to a Series. #[cfg(feature = "rolling_window")] - fn rolling_sum(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_sum(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let mut s = self.as_series().clone(); if options.weights.is_some() { s = s.to_float()?; @@ -124,14 +205,32 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_sum, &rolling::nulls::rolling_sum, - Some(&super::rolling_kernels::no_nulls::rolling_sum), ) }) } + /// Apply a rolling quantile to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_quantile_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().to_float()?; + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_quantile, + ) + }) + } + /// Apply a rolling quantile to a Series. 
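The warning emitted by `rolling_agg_by` when the `by` column is not known to be sorted points at three remedies; a hedged sketch of what those look like from Python, reusing the illustrative `df` from the earlier example:

```python
# 1. sort by the `by` column beforehand (this sets the sorted flag)
df = df.sort("dt")

# 2. or mark the column as sorted if that is already known to be true
df = df.with_columns(pl.col("dt").set_sorted())

# 3. or opt out of the warning explicitly
out = df.with_columns(
    pl.col("values").rolling_sum_by("dt", "2d", warn_if_unsorted=False)
)
```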
#[cfg(feature = "rolling_window")] - fn rolling_quantile(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_quantile(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let s = self.as_series().to_float()?; with_match_physical_float_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); @@ -140,14 +239,32 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_quantile, &rolling::nulls::rolling_quantile, - Some(&super::rolling_kernels::no_nulls::rolling_quantile), ) }) } + /// Apply a rolling min to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_min_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().clone(); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_min, + ) + }) + } + /// Apply a rolling min to a Series. #[cfg(feature = "rolling_window")] - fn rolling_min(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_min(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let mut s = self.as_series().clone(); if options.weights.is_some() { s = s.to_float()?; @@ -160,13 +277,32 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_min, &rolling::nulls::rolling_min, - Some(&super::rolling_kernels::no_nulls::rolling_min), ) }) } + + /// Apply a rolling max to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_max_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().clone(); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_max, + ) + }) + } + /// Apply a rolling max to a Series. #[cfg(feature = "rolling_window")] - fn rolling_max(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_max(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let mut s = self.as_series().clone(); if options.weights.is_some() { s = s.to_float()?; @@ -179,14 +315,48 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_max, &rolling::nulls::rolling_max, - Some(&super::rolling_kernels::no_nulls::rolling_max), + ) + }) + } + + /// Apply a rolling variance to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_var_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().to_float()?; + + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let mut ca = ca.clone(); + + if let Some(idx) = ca.first_non_null() { + let k = ca.get(idx).unwrap(); + // TODO! remove this! + // This is a temporary hack to improve numeric stability. + // var(X) = var(X - k) + // This is temporary as we will rework the rolling methods + // the 100.0 absolute boundary is arbitrarily chosen. + // the algorithm will square numbers, so it loses precision rapidly + if k.abs() > 100.0 { + ca = ca - k; + } + } + + rolling_agg_by( + &ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_var, ) }) } /// Apply a rolling variance to a Series. 
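A further illustrative sketch (same assumed data as above): the variance and std-dev `_by` variants thread `ddof` through `RollingVarParams`, which is exposed as a keyword argument in Python.

```python
out = df.with_columns(
    var=pl.col("values").rolling_var_by("dt", "2d", ddof=1),
    std=pl.col("values").rolling_std_by("dt", "2d", ddof=1),
)
```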
#[cfg(feature = "rolling_window")] - fn rolling_var(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_var(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let s = self.as_series().to_float()?; with_match_physical_float_polars_type!(s.dtype(), |$T| { @@ -211,14 +381,36 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_var, &rolling::nulls::rolling_var, - Some(&super::rolling_kernels::no_nulls::rolling_var), ) }) } + /// Apply a rolling std_dev to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_std_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + self.rolling_var_by(by, options).map(|mut s| { + match s.dtype().clone() { + DataType::Float32 => { + let ca: &mut ChunkedArray = s._get_inner_mut().as_mut(); + ca.apply_mut(|v| v.powf(0.5)) + }, + DataType::Float64 => { + let ca: &mut ChunkedArray = s._get_inner_mut().as_mut(); + ca.apply_mut(|v| v.powf(0.5)) + }, + _ => unreachable!(), + } + s + }) + } + /// Apply a rolling std_dev to a Series. #[cfg(feature = "rolling_window")] - fn rolling_std(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_std(&self, options: RollingOptionsFixedWindow) -> PolarsResult { self.rolling_var(options).map(|mut s| { match s.dtype().clone() { DataType::Float32 => { diff --git a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs index d5ae53e1459f..0b2909b5dda4 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs @@ -1,4 +1,5 @@ mod dispatch; +#[cfg(feature = "rolling_window_by")] mod rolling_kernels; use arrow::array::{Array, ArrayRef, PrimitiveArray}; @@ -12,20 +13,13 @@ use crate::prelude::*; #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct RollingOptions { +pub struct RollingOptionsDynamicWindow { /// The length of the window. pub window_size: Duration, /// Amount of elements in the window that should be filled before computing a result. pub min_periods: usize, - /// An optional slice with the same length as the window that will be multiplied - /// elementwise with the values in the window. - pub weights: Option>, - /// Set the labels at the center of the window. - pub center: bool, - /// Compute the rolling aggregates with a window defined by a time column - pub by: Option, - /// The closed window of that time window if given - pub closed_window: Option, + /// Which side windows should be closed. 
+ pub closed_window: ClosedWindow, /// Optional parameters for the rolling function #[cfg_attr(feature = "serde", serde(skip))] pub fn_params: DynArgs, @@ -33,152 +27,14 @@ pub struct RollingOptions { pub warn_if_unsorted: bool, } -impl Default for RollingOptions { - fn default() -> Self { - RollingOptions { - window_size: Duration::parse("3i"), - min_periods: 1, - weights: None, - center: false, - by: None, - closed_window: None, - fn_params: None, - warn_if_unsorted: true, - } - } -} - -#[cfg(feature = "rolling_window")] -impl PartialEq for RollingOptions { +#[cfg(feature = "rolling_window_by")] +impl PartialEq for RollingOptionsDynamicWindow { fn eq(&self, other: &Self) -> bool { self.window_size == other.window_size && self.min_periods == other.min_periods - && self.weights == other.weights - && self.center == other.center - && self.by == other.by && self.closed_window == other.closed_window + && self.warn_if_unsorted == other.warn_if_unsorted && self.fn_params.is_none() && other.fn_params.is_none() } } - -#[derive(Clone)] -pub struct RollingOptionsImpl<'a> { - /// The length of the window. - pub window_size: Duration, - /// Amount of elements in the window that should be filled before computing a result. - pub min_periods: usize, - /// An optional slice with the same length as the window that will be multiplied - /// elementwise with the values in the window. - pub weights: Option>, - /// Set the labels at the center of the window. - pub center: bool, - pub by: Option<&'a [i64]>, - pub tu: Option, - pub tz: Option<&'a TimeZone>, - pub closed_window: Option, - pub fn_params: DynArgs, -} - -impl From for RollingOptionsImpl<'static> { - fn from(options: RollingOptions) -> Self { - RollingOptionsImpl { - window_size: options.window_size, - min_periods: options.min_periods, - weights: options.weights, - center: options.center, - by: None, - tu: None, - tz: None, - closed_window: options.closed_window, - fn_params: options.fn_params, - } - } -} - -impl Default for RollingOptionsImpl<'static> { - fn default() -> Self { - RollingOptionsImpl { - window_size: Duration::parse("3i"), - min_periods: 1, - weights: None, - center: false, - by: None, - tu: None, - tz: None, - closed_window: None, - fn_params: None, - } - } -} - -impl<'a> TryFrom> for RollingOptionsFixedWindow { - type Error = PolarsError; - fn try_from(options: RollingOptionsImpl<'a>) -> PolarsResult { - polars_ensure!( - options.window_size.parsed_int, - InvalidOperation: "if `window_size` is a temporal window (e.g. '1d', '2h, ...), then the `by` argument must be passed" - ); - polars_ensure!( - options.closed_window.is_none(), - InvalidOperation: "`closed_window` is not supported for fixed window size rolling aggregations, \ - consider using DataFrame.rolling for greater flexibility", - ); - let window_size = options.window_size.nanoseconds() as usize; - check_input(window_size, options.min_periods)?; - Ok(RollingOptionsFixedWindow { - window_size, - min_periods: options.min_periods, - weights: options.weights, - center: options.center, - fn_params: options.fn_params, - }) - } -} - -/// utility -fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> { - polars_ensure!( - min_periods <= window_size, - ComputeError: "`min_periods` should be <= `window_size`", - ); - Ok(()) -} - -#[derive(Clone)] -pub struct RollingOptionsDynamicWindow<'a> { - /// The length of the window. - pub window_size: Duration, - /// Amount of elements in the window that should be filled before computing a result. 
- pub min_periods: usize, - pub by: &'a [i64], - pub tu: Option, - pub tz: Option<&'a TimeZone>, - pub closed_window: ClosedWindow, - pub fn_params: DynArgs, -} - -impl<'a> TryFrom> for RollingOptionsDynamicWindow<'a> { - type Error = PolarsError; - fn try_from(options: RollingOptionsImpl<'a>) -> PolarsResult { - let duration = options.window_size; - polars_ensure!(duration.duration_ns() > 0 && !duration.negative, ComputeError:"window size should be strictly positive"); - polars_ensure!( - options.weights.is_none(), - InvalidOperation: "`weights` is not supported in 'rolling_*(..., by=...)' expression" - ); - polars_ensure!( - !options.window_size.parsed_int, - InvalidOperation: "if `by` argument is passed, then `window_size` must be a temporal window (e.g. '1d' or '2h', not '3i')" - ); - Ok(RollingOptionsDynamicWindow { - window_size: options.window_size, - min_periods: options.min_periods, - by: options.by.expect("by must have been set to get here"), - tu: options.tu, - tz: options.tz, - closed_window: options.closed_window.unwrap_or(ClosedWindow::Right), - fn_params: options.fn_params, - }) - } -} diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index c7cb2429fa22..7b48db38c8e6 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs @@ -443,7 +443,7 @@ pub(crate) fn group_by_values_iter_lookahead( }) } -#[cfg(feature = "rolling_window")] +#[cfg(feature = "rolling_window_by")] #[inline] pub(crate) fn group_by_values_iter( period: Duration, diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 9056f42abfaa..0c16fa597cfa 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -195,7 +195,8 @@ reinterpret = ["polars-core/reinterpret", "polars-lazy?/reinterpret", "polars-op repeat_by = ["polars-ops/repeat_by", "polars-lazy?/repeat_by"] replace = ["polars-ops/replace", "polars-lazy?/replace"] rle = ["polars-lazy?/rle"] -rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window", "polars-time/rolling_window"] +rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window"] +rolling_window_by = ["polars-core/rolling_window_by", "polars-lazy?/rolling_window_by", "polars-time/rolling_window_by"] round_series = ["polars-ops/round_series", "polars-lazy?/round_series"] row_hash = ["polars-core/row_hash", "polars-lazy?/row_hash"] search_sorted = ["polars-lazy?/search_sorted"] @@ -366,6 +367,7 @@ docs-selection = [ "take_opt_iter", "cum_agg", "rolling_window", + "rolling_window_by", "interpolate", "diff", "rank", diff --git a/crates/polars/tests/it/core/rolling_window.rs b/crates/polars/tests/it/core/rolling_window.rs index 17270932bca2..a374523e1454 100644 --- a/crates/polars/tests/it/core/rolling_window.rs +++ b/crates/polars/tests/it/core/rolling_window.rs @@ -4,8 +4,8 @@ use super::*; fn test_rolling() { let s = Int32Chunked::new("foo", &[1, 2, 3, 2, 1]).into_series(); let a = s - .rolling_sum(RollingOptionsImpl { - window_size: Duration::new(2), + .rolling_sum(RollingOptionsFixedWindow { + window_size: 2, min_periods: 1, ..Default::default() }) @@ -20,8 +20,8 @@ fn test_rolling() { .collect::>() ); let a = s - .rolling_min(RollingOptionsImpl { - window_size: Duration::new(2), + .rolling_min(RollingOptionsFixedWindow { + window_size: 2, min_periods: 1, ..Default::default() }) @@ -36,8 +36,8 @@ fn test_rolling() { .collect::>() ); let a = s - .rolling_max(RollingOptionsImpl { - window_size: Duration::new(2), + 
.rolling_max(RollingOptionsFixedWindow { + window_size: 2, weights: Some(vec![1., 1.]), min_periods: 1, ..Default::default() @@ -59,8 +59,8 @@ fn test_rolling() { fn test_rolling_min_periods() { let s = Int32Chunked::new("foo", &[1, 2, 3, 2, 1]).into_series(); let a = s - .rolling_max(RollingOptionsImpl { - window_size: Duration::new(2), + .rolling_max(RollingOptionsFixedWindow { + window_size: 2, min_periods: 2, ..Default::default() }) @@ -87,8 +87,8 @@ fn test_rolling_mean() { // check err on wrong input assert!(s - .rolling_mean(RollingOptionsImpl { - window_size: Duration::new(1), + .rolling_mean(RollingOptionsFixedWindow { + window_size: 1, min_periods: 2, ..Default::default() }) @@ -96,8 +96,8 @@ fn test_rolling_mean() { // validate that we divide by the proper window length. (same as pandas) let a = s - .rolling_mean(RollingOptionsImpl { - window_size: Duration::new(3), + .rolling_mean(RollingOptionsFixedWindow { + window_size: 3, min_periods: 1, center: false, ..Default::default() @@ -119,8 +119,8 @@ fn test_rolling_mean() { // check centered rolling window let a = s - .rolling_mean(RollingOptionsImpl { - window_size: Duration::new(3), + .rolling_mean(RollingOptionsFixedWindow { + window_size: 3, min_periods: 1, center: true, ..Default::default() @@ -144,8 +144,8 @@ fn test_rolling_mean() { let ca = Int32Chunked::from_slice("", &[1, 8, 6, 2, 16, 10]); let out = ca .into_series() - .rolling_mean(RollingOptionsImpl { - window_size: Duration::new(2), + .rolling_mean(RollingOptionsFixedWindow { + window_size: 2, weights: None, min_periods: 2, center: false, @@ -211,8 +211,8 @@ fn test_rolling_var() { .into_series(); // window larger than array assert_eq!( - s.rolling_var(RollingOptionsImpl { - window_size: Duration::new(10), + s.rolling_var(RollingOptionsFixedWindow { + window_size: 10, min_periods: 10, ..Default::default() }) @@ -221,8 +221,8 @@ fn test_rolling_var() { s.len() ); - let options = RollingOptionsImpl { - window_size: Duration::new(3), + let options = RollingOptionsFixedWindow { + window_size: 3, min_periods: 3, ..Default::default() }; @@ -252,8 +252,8 @@ fn test_rolling_var() { // check centered rolling window let out = s - .rolling_var(RollingOptionsImpl { - window_size: Duration::new(4), + .rolling_var(RollingOptionsFixedWindow { + window_size: 4, min_periods: 3, center: true, ..Default::default() diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index cb530b2d1445..d4452c23f913 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -77,6 +77,7 @@ features = [ "reinterpret", "replace", "rolling_window", + "rolling_window_by", "round_series", "row_hash", "rows", diff --git a/py-polars/polars/_utils/deprecation.py b/py-polars/polars/_utils/deprecation.py index b74c1a3a7c07..9c3382f4982d 100644 --- a/py-polars/polars/_utils/deprecation.py +++ b/py-polars/polars/_utils/deprecation.py @@ -6,13 +6,14 @@ from typing import TYPE_CHECKING, Callable, Sequence, TypeVar from polars._utils.various import find_stacklevel +from polars.exceptions import InvalidOperationError if TYPE_CHECKING: import sys from typing import Mapping from polars import Expr - from polars.type_aliases import Ambiguous + from polars.type_aliases import Ambiguous, ClosedInterval if sys.version_info >= (3, 10): from typing import ParamSpec @@ -275,3 +276,36 @@ def deprecate_saturating(duration: T) -> T: ) return duration[:-11] # type: ignore[return-value] return duration + + +def validate_rolling_by_aggs_arguments( + weights: list[float] | None, *, center: bool +) -> None: + if weights is not 
None: + msg = "`weights` is not supported in `rolling_*(..., by=...)` expression" + raise InvalidOperationError(msg) + if center: + msg = "`center=True` is not supported in `rolling_*(..., by=...)` expression" + raise InvalidOperationError(msg) + + +def validate_rolling_aggs_arguments( + window_size: int | str, closed: ClosedInterval | None +) -> int: + if isinstance(window_size, str): + issue_deprecation_warning( + "Passing a str to `rolling_*` is deprecated.\n\n" + "Please, either:\n" + "- pass an integer if you want a fixed window size (e.g. `rolling_mean(3)`)\n" + "- pass a string if you are computing the rolling operation based on another column (e.g. `rolling_mean_by('date', '3d'))\n", + version="0.20.26", + ) + try: + window_size = int(window_size.rstrip("i")) + except ValueError: + msg = f"Expected a string of the form 'ni', where `n` is a positive integer, got: {window_size}" + raise InvalidOperationError(msg) from None + if closed is not None: + msg = "`closed` is not supported in `rolling_*(...)` expression" + raise InvalidOperationError(msg) + return window_size diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index f662b348396a..799ba329509e 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -33,6 +33,8 @@ deprecate_renamed_parameter, deprecate_saturating, issue_deprecation_warning, + validate_rolling_aggs_arguments, + validate_rolling_by_aggs_arguments, ) from polars._utils.parse_expr_input import ( parse_as_expression, @@ -6165,7 +6167,7 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Self: @unstable() def rolling_min_by( self, - by: str, + by: IntoExpr, window_size: timedelta | str, *, min_periods: int = 1, @@ -6285,12 +6287,11 @@ def rolling_min_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_min( - window_size, None, min_periods, False, by, closed, warn_if_unsorted + self._pyexpr.rolling_min_by( + by, window_size, min_periods, closed, warn_if_unsorted ) ) @@ -6443,12 +6444,11 @@ def rolling_max_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_max( - window_size, None, min_periods, False, by, closed, warn_if_unsorted + self._pyexpr.rolling_max_by( + by, window_size, min_periods, closed, warn_if_unsorted ) ) @@ -6603,16 +6603,13 @@ def rolling_mean_by( └───────┴─────────────────────┴──────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_mean( + self._pyexpr.rolling_mean_by( + by, window_size, - None, min_periods, - False, - by, closed, warn_if_unsorted, ) @@ -6767,12 +6764,11 @@ def rolling_sum_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + 
window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_sum( - window_size, None, min_periods, False, by, closed, warn_if_unsorted + self._pyexpr.rolling_sum_by( + by, window_size, min_periods, closed, warn_if_unsorted ) ) @@ -6929,16 +6925,13 @@ def rolling_std_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_std( + self._pyexpr.rolling_std_by( + by, window_size, - None, min_periods, - False, - by, closed, ddof, warn_if_unsorted, @@ -7097,16 +7090,13 @@ def rolling_var_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_var( + self._pyexpr.rolling_var_by( + by, window_size, - None, min_periods, - False, - by, closed, ddof, warn_if_unsorted, @@ -7238,12 +7228,11 @@ def rolling_median_by( └───────┴─────────────────────┴────────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_median( - window_size, None, min_periods, False, by, closed, warn_if_unsorted + self._pyexpr.rolling_median_by( + by, window_size, min_periods, closed, warn_if_unsorted ) ) @@ -7378,18 +7367,15 @@ def rolling_quantile_by( └───────┴─────────────────────┴──────────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_quantile( + self._pyexpr.rolling_quantile_by( + by, quantile, interpolation, window_size, - None, min_periods, - False, - by, closed, warn_if_unsorted, ) @@ -7612,9 +7598,22 @@ def rolling_min( "`rolling_min(..., by='foo')`, please use `rolling_min_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_min_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_min( - window_size, weights, min_periods, center, by, closed, warn_if_unsorted + window_size, + weights, + min_periods, + center, ) ) @@ -7861,9 +7860,22 @@ def rolling_max( "`rolling_max(..., by='foo')`, please use `rolling_max_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_max_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + 
warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_max( - window_size, weights, min_periods, center, by, closed, warn_if_unsorted + window_size, + weights, + min_periods, + center, ) ) @@ -8112,15 +8124,22 @@ def rolling_mean( "`rolling_mean(..., by='foo')`, please use `rolling_mean_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_mean_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_mean( window_size, weights, min_periods, center, - by, - closed, - warn_if_unsorted, ) ) @@ -8367,9 +8386,22 @@ def rolling_sum( "`rolling_sum(..., by='foo')`, please use `rolling_sum_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_sum_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_sum( - window_size, weights, min_periods, center, by, closed, warn_if_unsorted + window_size, + weights, + min_periods, + center, ) ) @@ -8616,16 +8648,24 @@ def rolling_std( "`rolling_std(..., by='foo')`, please use `rolling_std_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_std_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + ddof=ddof, + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_std( window_size, weights, min_periods, center, - by, - closed, ddof, - warn_if_unsorted, ) ) @@ -8871,16 +8911,24 @@ def rolling_var( "`rolling_var(..., by='foo')`, please use `rolling_var_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_var_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + ddof=ddof, + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_var( window_size, weights, min_periods, center, - by, - closed, ddof, - warn_if_unsorted, ) ) @@ -9046,9 +9094,22 @@ def rolling_median( "`rolling_median(..., by='foo')`, please use `rolling_median_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_median_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, 
closed) return self._from_pyexpr( self._pyexpr.rolling_median( - window_size, weights, min_periods, center, by, closed, warn_if_unsorted + window_size, + weights, + min_periods, + center, ) ) @@ -9247,6 +9308,17 @@ def rolling_quantile( "`rolling_quantile(..., by='foo')`, please use `rolling_quantile_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_quantile_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + quantile=quantile, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_quantile( quantile, @@ -9255,9 +9327,6 @@ def rolling_quantile( weights, min_periods, center, - by, - closed, - warn_if_unsorted, ) ) @@ -11940,7 +12009,7 @@ def _prepare_alpha( def _prepare_rolling_window_args( window_size: int | timedelta | str, min_periods: int | None = None, -) -> tuple[str, int]: +) -> tuple[int | str, int]: if isinstance(window_size, int): if window_size < 1: msg = "`window_size` must be positive" @@ -11948,9 +12017,16 @@ def _prepare_rolling_window_args( if min_periods is None: min_periods = window_size - window_size = f"{window_size}i" elif isinstance(window_size, timedelta): window_size = parse_as_duration_string(window_size) if min_periods is None: min_periods = 1 return window_size, min_periods + + +def _prepare_rolling_by_window_args( + window_size: timedelta | str, +) -> str: + if isinstance(window_size, timedelta): + window_size = parse_as_duration_string(window_size) + return window_size diff --git a/py-polars/src/expr/rolling.rs b/py-polars/src/expr/rolling.rs index 5c0c24e3a7a5..44af77b2f469 100644 --- a/py-polars/src/expr/rolling.rs +++ b/py-polars/src/expr/rolling.rs @@ -10,201 +10,293 @@ use crate::{PyExpr, PySeries}; #[pymethods] impl PyExpr { - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[pyo3(signature = (window_size, weights, min_periods, center))] fn rolling_sum( &self, - window_size: &str, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, - warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { - window_size: Duration::parse(window_size), + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), - warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_sum(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - fn rolling_min( + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_sum_by( &self, + by: PyExpr, window_size: &str, - weights: Option>, min_periods: usize, - center: bool, - by: Option, - closed: Option>, + closed: Wrap, warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { + let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + warn_if_unsorted, + fn_params: None, + }; + self.inner.clone().rolling_sum_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center))] + fn rolling_min( + &self, + window_size: usize, + weights: Option>, + min_periods: usize, + center: bool, + ) -> Self { + let options = 
RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), - warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_min(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - fn rolling_max( + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_min_by( &self, + by: PyExpr, window_size: &str, - weights: Option>, min_periods: usize, - center: bool, - by: Option, - closed: Option>, + closed: Wrap, warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { + let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + warn_if_unsorted, + fn_params: None, + }; + self.inner.clone().rolling_min_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center))] + fn rolling_max( + &self, + window_size: usize, + weights: Option>, + min_periods: usize, + center: bool, + ) -> Self { + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), - warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_max(options).into() } + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_max_by( + &self, + by: PyExpr, + window_size: &str, + min_periods: usize, + closed: Wrap, + warn_if_unsorted: bool, + ) -> Self { + let options = RollingOptionsDynamicWindow { + window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + warn_if_unsorted, + fn_params: None, + }; + self.inner.clone().rolling_max_by(by.inner, options).into() + } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[pyo3(signature = (window_size, weights, min_periods, center))] fn rolling_mean( &self, - window_size: &str, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, - warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { - window_size: Duration::parse(window_size), + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), - warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_mean(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, ddof, warn_if_unsorted))] - fn rolling_std( + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_mean_by( &self, + by: PyExpr, window_size: &str, + min_periods: usize, + closed: Wrap, + warn_if_unsorted: bool, + ) -> Self { + let options = RollingOptionsDynamicWindow { + window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + warn_if_unsorted, + fn_params: None, + }; + + self.inner.clone().rolling_mean_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center, ddof))] + fn rolling_std( + &self, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, ddof: u8, - warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { - window_size: Duration::parse(window_size), + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), - 
warn_if_unsorted, }; self.inner.clone().rolling_std(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, ddof, warn_if_unsorted))] - fn rolling_var( + #[pyo3(signature = (by, window_size, min_periods, closed, ddof, warn_if_unsorted))] + fn rolling_std_by( &self, + by: PyExpr, window_size: &str, - weights: Option>, min_periods: usize, - center: bool, - by: Option, - closed: Option>, + closed: Wrap, ddof: u8, warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { + let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), + warn_if_unsorted, + }; + + self.inner.clone().rolling_std_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center, ddof))] + fn rolling_var( + &self, + window_size: usize, + weights: Option>, + min_periods: usize, + center: bool, + ddof: u8, + ) -> Self { + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), - warn_if_unsorted, }; self.inner.clone().rolling_var(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - fn rolling_median( + #[pyo3(signature = (by, window_size, min_periods, closed, ddof, warn_if_unsorted))] + fn rolling_var_by( &self, + by: PyExpr, window_size: &str, + min_periods: usize, + closed: Wrap, + ddof: u8, + warn_if_unsorted: bool, + ) -> Self { + let options = RollingOptionsDynamicWindow { + window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), + warn_if_unsorted, + }; + + self.inner.clone().rolling_var_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center))] + fn rolling_median( + &self, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, + ) -> Self { + let options = RollingOptionsFixedWindow { + window_size, + min_periods, + weights, + center, + fn_params: None, + }; + self.inner.clone().rolling_median(options).into() + } + + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_median_by( + &self, + by: PyExpr, + window_size: &str, + min_periods: usize, + closed: Wrap, warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { + let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), - weights, min_periods, - center, - by, - closed_window: closed.map(|c| c.0), + closed_window: closed.0, fn_params: None, warn_if_unsorted, }; - self.inner.clone().rolling_median(options).into() + self.inner + .clone() + .rolling_median_by(by.inner, options) + .into() } - #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center))] fn rolling_quantile( &self, quantile: f64, interpolation: Wrap, - window_size: &str, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, - warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { - window_size: Duration::parse(window_size), + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - 
by, - closed_window: closed.map(|c| c.0), fn_params: None, - warn_if_unsorted, }; self.inner @@ -213,6 +305,31 @@ impl PyExpr { .into() } + #[pyo3(signature = (by, quantile, interpolation, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_quantile_by( + &self, + by: PyExpr, + quantile: f64, + interpolation: Wrap, + window_size: &str, + min_periods: usize, + closed: Wrap, + warn_if_unsorted: bool, + ) -> Self { + let options = RollingOptionsDynamicWindow { + window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + fn_params: None, + warn_if_unsorted, + }; + + self.inner + .clone() + .rolling_quantile_by(by.inner, interpolation.0, quantile, options) + .into() + } + fn rolling_skew(&self, window_size: usize, bias: bool) -> Self { self.inner.clone().rolling_skew(window_size, bias).into() } diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index 5fd75d05bbc0..c5db4007dd16 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -1,5 +1,6 @@ use polars_core::series::IsSorted; use polars_plan::dsl::function_expr::rolling::RollingFunction; +use polars_plan::dsl::function_expr::rolling_by::RollingFunctionBy; use polars_plan::dsl::function_expr::trigonometry::TrigonometricFunction; use polars_plan::dsl::BooleanFunction; use polars_plan::prelude::{ @@ -628,49 +629,51 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { RollingFunction::Min(_) => { return Err(PyNotImplementedError::new_err("rolling min")) }, - RollingFunction::MinBy(_) => { - return Err(PyNotImplementedError::new_err("rolling min by")) - }, RollingFunction::Max(_) => { return Err(PyNotImplementedError::new_err("rolling max")) }, - RollingFunction::MaxBy(_) => { - return Err(PyNotImplementedError::new_err("rolling max by")) - }, RollingFunction::Mean(_) => { return Err(PyNotImplementedError::new_err("rolling mean")) }, - RollingFunction::MeanBy(_) => { - return Err(PyNotImplementedError::new_err("rolling mean by")) - }, RollingFunction::Sum(_) => { return Err(PyNotImplementedError::new_err("rolling sum")) }, - RollingFunction::SumBy(_) => { - return Err(PyNotImplementedError::new_err("rolling sum by")) - }, RollingFunction::Quantile(_) => { return Err(PyNotImplementedError::new_err("rolling quantile")) }, - RollingFunction::QuantileBy(_) => { - return Err(PyNotImplementedError::new_err("rolling quantile by")) - }, RollingFunction::Var(_) => { return Err(PyNotImplementedError::new_err("rolling var")) }, - RollingFunction::VarBy(_) => { - return Err(PyNotImplementedError::new_err("rolling var by")) - }, RollingFunction::Std(_) => { return Err(PyNotImplementedError::new_err("rolling std")) }, - RollingFunction::StdBy(_) => { - return Err(PyNotImplementedError::new_err("rolling std by")) - }, RollingFunction::Skew(_, _) => { return Err(PyNotImplementedError::new_err("rolling skew")) }, }, + FunctionExpr::RollingExprBy(rolling) => match rolling { + RollingFunctionBy::MinBy(_) => { + return Err(PyNotImplementedError::new_err("rolling min by")) + }, + RollingFunctionBy::MaxBy(_) => { + return Err(PyNotImplementedError::new_err("rolling max by")) + }, + RollingFunctionBy::MeanBy(_) => { + return Err(PyNotImplementedError::new_err("rolling mean by")) + }, + RollingFunctionBy::SumBy(_) => { + return Err(PyNotImplementedError::new_err("rolling sum by")) + }, + RollingFunctionBy::QuantileBy(_) => { + return Err(PyNotImplementedError::new_err("rolling quantile by")) + }, + 
RollingFunctionBy::VarBy(_) => { + return Err(PyNotImplementedError::new_err("rolling var by")) + }, + RollingFunctionBy::StdBy(_) => { + return Err(PyNotImplementedError::new_err("rolling std by")) + }, + }, FunctionExpr::ShiftAndFill => { return Err(PyNotImplementedError::new_err("shift and fill")) }, diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index bc69f1d5ca0b..8898c8a29d31 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -52,6 +52,9 @@ def test_rolling_kernels_and_rolling( pl.col("values").rolling_var_by("dt", period, closed=closed).alias("var"), pl.col("values").rolling_mean_by("dt", period, closed=closed).alias("mean"), pl.col("values").rolling_std_by("dt", period, closed=closed).alias("std"), + pl.col("values") + .rolling_quantile_by("dt", period, quantile=0.2, closed=closed) + .alias("quantile"), ] ) out2 = ( @@ -63,6 +66,7 @@ def test_rolling_kernels_and_rolling( pl.col("values").var().alias("var"), pl.col("values").mean().alias("mean"), pl.col("values").std().alias("std"), + pl.col("values").quantile(quantile=0.2).alias("quantile"), ] ) ) @@ -220,13 +224,13 @@ def test_rolling_crossing_dst( def test_rolling_by_invalid() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a") - msg = "in `rolling_min` operation, `by` argument of dtype `i64` is not supported" + msg = r"in `rolling_\*_by` operation, `by` argument of dtype `i64` is not supported" with pytest.raises(InvalidOperationError, match=msg): - df.select(pl.col("b").rolling_min_by("a", 2)) # type: ignore[arg-type] + df.select(pl.col("b").rolling_min_by("a", "2i")) df = pl.DataFrame({"a": [1, 2, 3], "b": [date(2020, 1, 1)] * 3}).sort("b") - msg = "if `by` argument is passed, then `window_size` must be a temporal window" + msg = "`window_size` duration may not be a parsed integer" with pytest.raises(InvalidOperationError, match=msg): - df.select(pl.col("a").rolling_min_by("b", 2)) # type: ignore[arg-type] + df.select(pl.col("a").rolling_min_by("b", "2i")) def test_rolling_infinity() -> None: @@ -240,7 +244,10 @@ def test_rolling_invalid_closed_option() -> None: df = pl.DataFrame( {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ).sort("a", "b") - with pytest.raises(InvalidOperationError, match="consider using DataFrame.rolling"): + with pytest.raises( + InvalidOperationError, + match=r"`closed` is not supported in `rolling_\*\(...\)` expression", + ): df.with_columns(pl.col("a").rolling_sum(2, closed="left")) @@ -248,21 +255,31 @@ def test_rolling_by_non_temporal_window_size() -> None: df = pl.DataFrame( {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ).sort("a", "b") - msg = "if `by` argument is passed, then `window_size` must be a temporal window" + msg = "`window_size` duration may not be a parsed integer" with pytest.raises(InvalidOperationError, match=msg): - df.with_columns(pl.col("a").rolling_sum_by("b", 2, closed="left")) # type: ignore[arg-type] + df.with_columns(pl.col("a").rolling_sum_by("b", "2i", closed="left")) def test_rolling_by_weights() -> None: df = pl.DataFrame( {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ).sort("b") - msg = r"`weights` is not supported in 'rolling_\*\(..., by=...\)' expression" + msg = r"`weights` is not supported in `rolling_\*\(..., by=...\)` expression" with pytest.raises(InvalidOperationError, match=msg): # noqa: 
SIM117 with pytest.deprecated_call(match="rolling_sum_by"): df.with_columns(pl.col("a").rolling_sum("2d", by="b", weights=[1, 2])) +def test_rolling_by_center() -> None: + df = pl.DataFrame( + {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} + ).sort("b") + msg = r"`center=True` is not supported in `rolling_\*\(..., by=...\)` expression" + with pytest.raises(InvalidOperationError, match=msg): # noqa: SIM117 + with pytest.deprecated_call(match="rolling_sum_by"): + df.with_columns(pl.col("a").rolling_sum("2d", by="b", center=True)) + + def test_rolling_extrema() -> None: # sorted data and nulls flags trigger different kernels df = ( @@ -566,11 +583,15 @@ def test_rolling_negative_period() -> None: df.lazy().rolling("ts", period="-1d", offset="-1d").agg( pl.col("value") ).collect() - with pytest.raises(ComputeError, match="window size should be strictly positive"): + with pytest.raises( + InvalidOperationError, match="`window_size` must be strictly positive" + ): df.select( pl.col("value").rolling_min_by("ts", window_size="-1d", closed="left") ) - with pytest.raises(ComputeError, match="window size should be strictly positive"): + with pytest.raises( + InvalidOperationError, match="`window_size` must be strictly positive" + ): df.lazy().select( pl.col("value").rolling_min_by("ts", window_size="-1d", closed="left") ).collect() @@ -984,7 +1005,10 @@ def test_temporal_windows_size_without_by_15977() -> None: df = pl.DataFrame( {"a": [1, 2, 3], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ) - with pytest.raises( - pl.InvalidOperationError, match="the `by` argument must be passed" + with pytest.raises( # noqa: SIM117 + InvalidOperationError, match="Expected a string of the form 'ni'" ): - df.select(pl.col("a").rolling_mean("3d")) + with pytest.deprecated_call( + match=r"Passing a str to `rolling_\*` is deprecated" + ): + df.select(pl.col("a").rolling_mean("3d")) From eabca7adf88f646c947fcc679f8b6617070679a4 Mon Sep 17 00:00:00 2001 From: Jeremy Monat Date: Fri, 10 May 2024 13:37:15 -0400 Subject: [PATCH 05/29] docs(python): Explain how Polars floor division differs from Python (#16054) --- py-polars/polars/expr/expr.py | 52 +++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 799ba329509e..4ce79d1ab742 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -5519,6 +5519,58 @@ def floordiv(self, other: Any) -> Self: │ 4 ┆ 2.0 ┆ 2 │ │ 5 ┆ 2.5 ┆ 2 │ └─────┴─────┴──────┘ + + Note that Polars' `floordiv` is subtly different from Python's floor division. + For example, consider 6.0 floor-divided by 0.1. + Python gives: + + >>> 6.0 // 0.1 + 59.0 + + because `0.1` is not represented internally as that exact value, + but a slightly larger value. + So the result of the division is slightly less than 60, + meaning the flooring operation returns 59.0. + + Polars instead first does the floating-point division, + resulting in a floating-point value of 60.0, + and then performs the flooring operation using :any:`floor`: + + >>> df = pl.DataFrame({"x": [6.0, 6.03]}) + >>> df.with_columns( + ... pl.col("x").truediv(0.1).alias("x/0.1"), + ... ).with_columns( + ... pl.col("x/0.1").floor().alias("x/0.1 floor"), + ... 
) + shape: (2, 3) + ┌──────┬───────┬─────────────┐ + │ x ┆ x/0.1 ┆ x/0.1 floor │ + │ --- ┆ --- ┆ --- │ + │ f64 ┆ f64 ┆ f64 │ + ╞══════╪═══════╪═════════════╡ + │ 6.0 ┆ 60.0 ┆ 60.0 │ + │ 6.03 ┆ 60.3 ┆ 60.0 │ + └──────┴───────┴─────────────┘ + + yielding the more intuitive result 60.0. + The row with x = 6.03 is included to demonstrate + the effect of the flooring operation. + + `floordiv` combines those two steps + to give the same result with one expression: + + >>> df.with_columns( + ... pl.col("x").floordiv(0.1).alias("x//0.1"), + ... ) + shape: (2, 2) + ┌──────┬────────┐ + │ x ┆ x//0.1 │ + │ --- ┆ --- │ + │ f64 ┆ f64 │ + ╞══════╪════════╡ + │ 6.0 ┆ 60.0 │ + │ 6.03 ┆ 60.0 │ + └──────┴────────┘ """ return self.__floordiv__(other) From 21b3d43f777c386e0a0c9cf5604ea6255812a64e Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 10 May 2024 22:02:31 +0400 Subject: [PATCH 06/29] =?UTF-8?q?feat:=20Allow=20implicit=20string=20?= =?UTF-8?q?=E2=86=92=20temporal=20conversion=20in=20SQL=20comparisons=20(#?= =?UTF-8?q?15958)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 1 + crates/polars-sql/Cargo.toml | 1 + crates/polars-sql/src/context.rs | 18 +-- crates/polars-sql/src/functions.rs | 22 ++-- crates/polars-sql/src/sql_expr.rs | 137 ++++++++++++++++++---- crates/polars-sql/tests/simple_exprs.rs | 33 +++++- py-polars/tests/unit/sql/test_temporal.py | 59 +++++++++- 7 files changed, 228 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 661c396729e8..99c5dc7fbd3c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3122,6 +3122,7 @@ name = "polars-sql" version = "0.39.2" dependencies = [ "hex", + "once_cell", "polars-arrow", "polars-core", "polars-error", diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 1f2d32413563..6eae9faa2227 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -16,6 +16,7 @@ polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_ polars-plan = { workspace = true } hex = { workspace = true } +once_cell = { workspace = true } rand = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 6fc6ac559968..e594ce5b9e0c 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -370,8 +370,9 @@ impl SQLContext { let mut contains_wildcard_exclude = false; // Filter expression. + let schema = Some(lf.schema()?); if let Some(expr) = select_stmt.selection.as_ref() { - let mut filter_expression = parse_sql_expr(expr, self)?; + let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?; lf = self.process_subqueries(lf, vec![&mut filter_expression]); lf = lf.filter(filter_expression); } @@ -382,9 +383,9 @@ impl SQLContext { .iter() .map(|select_item| { Ok(match select_item { - SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self)?, + SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self, schema.as_deref())?, SelectItem::ExprWithAlias { expr, alias } => { - let expr = parse_sql_expr(expr, self)?; + let expr = parse_sql_expr(expr, self, schema.as_deref())?; expr.alias(&alias.value) }, SelectItem::QualifiedWildcard(oname, wildcard_options) => self @@ -427,7 +428,7 @@ impl SQLContext { ComputeError: "group_by error: a positive number or an expression expected", )), - _ => parse_sql_expr(e, self), + _ => parse_sql_expr(e, self, schema.as_deref()), }) .collect::>()? 
} else { @@ -506,8 +507,9 @@ impl SQLContext { lf = self.process_order_by(lf, &query.order_by)?; // Apply optional 'having' clause, post-aggregation. + let schema = Some(lf.schema()?); match select_stmt.having.as_ref() { - Some(expr) => lf.filter(parse_sql_expr(expr, self)?), + Some(expr) => lf.filter(parse_sql_expr(expr, self, schema.as_deref())?), None => lf, } }; @@ -517,10 +519,11 @@ impl SQLContext { Some(Distinct::Distinct) => lf.unique_stable(None, UniqueKeepStrategy::Any), Some(Distinct::On(exprs)) => { // TODO: support exprs in `unique` see https://github.com/pola-rs/polars/issues/5760 + let schema = Some(lf.schema()?); let cols = exprs .iter() .map(|e| { - let expr = parse_sql_expr(e, self)?; + let expr = parse_sql_expr(e, self, schema.as_deref())?; if let Expr::Column(name) = expr { Ok(name.to_string()) } else { @@ -664,8 +667,9 @@ impl SQLContext { let mut by = Vec::with_capacity(ob.len()); let mut descending = Vec::with_capacity(ob.len()); + let schema = Some(lf.schema()?); for ob in ob { - by.push(parse_sql_expr(&ob.expr, self)?); + by.push(parse_sql_expr(&ob.expr, self, schema.as_deref())?); descending.push(!ob.asc.unwrap_or(true)); polars_ensure!( ob.nulls_first.is_none(), diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 05cd4bc8959a..78eb5023d6a4 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1074,7 +1074,7 @@ impl SQLFunctionVisitor<'_> { .into_iter() .map(|arg| { if let FunctionArgExpr::Expr(e) = arg { - parse_sql_expr(e, self.ctx) + parse_sql_expr(e, self.ctx, None) } else { polars_bail!(ComputeError: "Only expressions are supported in UDFs") } @@ -1130,7 +1130,7 @@ impl SQLFunctionVisitor<'_> { let (order_by, desc): (Vec, Vec) = order_by .iter() .map(|o| { - let expr = parse_sql_expr(&o.expr, self.ctx)?; + let expr = parse_sql_expr(&o.expr, self.ctx, None)?; Ok(match o.asc { Some(b) => (expr, !b), None => (expr, false), @@ -1157,7 +1157,7 @@ impl SQLFunctionVisitor<'_> { let args = extract_args(self.func); match args.as_slice() { [FunctionArgExpr::Expr(sql_expr)] => { - let expr = parse_sql_expr(sql_expr, self.ctx)?; + let expr = parse_sql_expr(sql_expr, self.ctx, None)?; // apply the function on the inner expr -- e.g. 
SUM(a) -> SUM Ok(f(expr)) }, @@ -1179,7 +1179,7 @@ impl SQLFunctionVisitor<'_> { let args = extract_args(self.func); match args.as_slice() { [FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2)] => { - let expr1 = parse_sql_expr(sql_expr1, self.ctx)?; + let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?; let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?; f(expr1, expr2) }, @@ -1199,7 +1199,7 @@ impl SQLFunctionVisitor<'_> { let mut expr_args = vec![]; for arg in args { if let FunctionArgExpr::Expr(sql_expr) = arg { - expr_args.push(parse_sql_expr(sql_expr, self.ctx)?); + expr_args.push(parse_sql_expr(sql_expr, self.ctx, None)?); } else { return self.not_supported_error(); }; @@ -1215,7 +1215,7 @@ impl SQLFunctionVisitor<'_> { match args.as_slice() { [FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2), FunctionArgExpr::Expr(sql_expr3)] => { - let expr1 = parse_sql_expr(sql_expr1, self.ctx)?; + let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?; let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?; let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?; f(expr1, expr2, expr3) @@ -1239,7 +1239,7 @@ impl SQLFunctionVisitor<'_> { (false, []) => Ok(len()), // count(column_name) (false, [FunctionArgExpr::Expr(sql_expr)]) => { - let expr = parse_sql_expr(sql_expr, self.ctx)?; + let expr = parse_sql_expr(sql_expr, self.ctx, None)?; let expr = self.apply_window_spec(expr, &self.func.over)?; Ok(expr.count()) }, @@ -1247,7 +1247,7 @@ impl SQLFunctionVisitor<'_> { (false, [FunctionArgExpr::Wildcard]) => Ok(len()), // count(distinct column_name) (true, [FunctionArgExpr::Expr(sql_expr)]) => { - let expr = parse_sql_expr(sql_expr, self.ctx)?; + let expr = parse_sql_expr(sql_expr, self.ctx, None)?; let expr = self.apply_window_spec(expr, &self.func.over)?; Ok(expr.n_unique()) }, @@ -1267,7 +1267,7 @@ impl SQLFunctionVisitor<'_> { .order_by .iter() .map(|o| { - let e = parse_sql_expr(&o.expr, self.ctx)?; + let e = parse_sql_expr(&o.expr, self.ctx, None)?; Ok(o.asc.map_or(e.clone(), |b| { e.sort(SortOptions::default().with_order_descending(!b)) })) @@ -1279,7 +1279,7 @@ impl SQLFunctionVisitor<'_> { let partition_by = window_spec .partition_by .iter() - .map(|p| parse_sql_expr(p, self.ctx)) + .map(|p| parse_sql_expr(p, self.ctx, None)) .collect::>>()?; expr.over(partition_by) } @@ -1388,6 +1388,6 @@ impl FromSQLExpr for Expr { where Self: Sized, { - parse_sql_expr(expr, ctx) + parse_sql_expr(expr, ctx, None) } } diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index b20fde159b4f..eac1a56fbe05 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -8,6 +8,7 @@ use polars_plan::prelude::typed_lit; use polars_plan::prelude::LiteralValue::Null; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; +use regex::{Regex, RegexBuilder}; #[cfg(feature = "dtype-decimal")] use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::{ @@ -22,6 +23,21 @@ use sqlparser::parser::{Parser, ParserOptions}; use crate::functions::SQLFunctionVisitor; use crate::SQLContext; +static DATE_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); +static TIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); + +fn timeunit_from_precision(prec: &Option) -> PolarsResult { + Ok(match prec { + None => TimeUnit::Microseconds, + Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds, + Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, + Some(n) if (7u64..=9u64).contains(n) => 
TimeUnit::Nanoseconds, + Some(n) => { + polars_bail!(ComputeError: "invalid temporal type precision; expected 1-9, found {}", n) + }, + }) +} + pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { Ok(match data_type { // --------------------------------- @@ -106,22 +122,12 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { - let tu = match prec { - None => TimeUnit::Microseconds, - Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds, - Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, - Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds, - Some(n) => { - polars_bail!(ComputeError: "unsupported `timestamp` precision; expected a value between 1 and 9, found {}", n) - }, - }; - match tz { - TimezoneInfo::None => DataType::Datetime(tu, None), - _ => { - polars_bail!(ComputeError: "`timestamp` with timezone is not (yet) supported; found tz={}", tz) - }, - } + SQLDataType::Datetime(prec) => DataType::Datetime(timeunit_from_precision(prec)?, None), + SQLDataType::Timestamp(prec, tz) => match tz { + TimezoneInfo::None => DataType::Datetime(timeunit_from_precision(prec)?, None), + _ => { + polars_bail!(ComputeError: "`timestamp` with timezone is not (yet) supported; found tz={}", tz) + }, }, // --------------------------------- @@ -173,6 +179,7 @@ pub enum SubqueryRestriction { /// Recursively walks a SQL Expr to create a polars Expr pub(crate) struct SQLExprVisitor<'a> { ctx: &'a mut SQLContext, + active_schema: Option<&'a Schema>, } impl SQLExprVisitor<'_> { @@ -396,9 +403,70 @@ impl SQLExprVisitor<'_> { } } + /// Handle implicit temporal string comparisons. + /// + /// eg: "dt >= '2024-04-30'", or "dtm::date = '2077-10-10'" + fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr { + if let (Some(name), Some(s), expr_dtype) = match (left, right) { + // identify "col string" expressions + (Expr::Column(name), Expr::Literal(LiteralValue::String(s))) => { + (Some(name.clone()), Some(s), None) + }, + // identify "CAST(expr AS type) string" and/or "expr::type string" expressions + ( + Expr::Cast { + expr, data_type, .. + }, + Expr::Literal(LiteralValue::String(s)), + ) => { + if let Expr::Column(name) = &**expr { + (Some(name.clone()), Some(s), Some(data_type)) + } else { + (None, Some(s), Some(data_type)) + } + }, + _ => (None, None, None), + } { + if expr_dtype.is_none() && self.active_schema.is_none() { + right.clone() + } else { + let left_dtype = expr_dtype + .unwrap_or_else(|| self.active_schema.as_ref().unwrap().get(&name).unwrap()); + + let dt_regex = DATE_LITERAL_RE + .get_or_init(|| RegexBuilder::new(r"^\d{4}-[01]\d-[0-3]\d").build().unwrap()); + let tm_regex = TIME_LITERAL_RE.get_or_init(|| { + RegexBuilder::new(r"^[012]\d:[0-5]\d:[0-5]\d") + .build() + .unwrap() + }); + + match left_dtype { + DataType::Time if tm_regex.is_match(s) => { + right.clone().strict_cast(left_dtype.clone()) + }, + DataType::Date if dt_regex.is_match(s) => { + right.clone().strict_cast(left_dtype.clone()) + }, + DataType::Datetime(_, _) if dt_regex.is_match(s) => { + if s.len() == 10 { + // handle upcast from ISO date string (10 chars) to datetime + lit(format!("{}T00:00:00", s)).strict_cast(left_dtype.clone()) + } else { + lit(s.replacen(' ', "T", 1)).strict_cast(left_dtype.clone()) + } + }, + _ => right.clone(), + } + } + } else { + right.clone() + } + } + /// Visit a SQL binary operator. /// - /// e.g. column + 1 or column1 / column2 + /// e.g. 
"column + 1", "column1 <= column2" fn visit_binary_op( &mut self, left: &SQLExpr, @@ -406,7 +474,9 @@ impl SQLExprVisitor<'_> { right: &SQLExpr, ) -> PolarsResult { let left = self.visit_expr(left)?; - let right = self.visit_expr(right)?; + let mut right = self.visit_expr(right)?; + right = self.convert_temporal_strings(&left, &right); + Ok(match op { SQLBinaryOperator::And => left.and(right), SQLBinaryOperator::Divide => left / right, @@ -747,8 +817,25 @@ impl SQLExprVisitor<'_> { } }) .collect::>>()?; - let s = Series::from_any_values("", &list, true)?; + let mut s = Series::from_any_values("", &list, true)?; + + // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')". + // (not yet as versatile as the temporal string conversions in visit_binary_op) + if s.dtype() == &DataType::String { + // handle implicit temporal string comparisons, eg: "dt >= '2024-04-30'" + if let Expr::Column(name) = &expr { + if self.active_schema.is_some() { + let schema = self.active_schema.as_ref().unwrap(); + let left_dtype = schema.get(name); + if let Some(DataType::Date | DataType::Time | DataType::Datetime(_, _)) = + left_dtype + { + s = s.strict_cast(&left_dtype.unwrap().clone())?; + } + } + } + } if negated { Ok(expr.is_in(lit(s)).not()) } else { @@ -1011,16 +1098,20 @@ pub fn sql_expr>(s: S) -> PolarsResult { Ok(match &expr { SelectItem::ExprWithAlias { expr, alias } => { - let expr = parse_sql_expr(expr, &mut ctx)?; + let expr = parse_sql_expr(expr, &mut ctx, None)?; expr.alias(&alias.value) }, - SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx)?, + SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?, _ => polars_bail!(InvalidOperation: "Unable to parse '{}' as Expr", s.as_ref()), }) } -pub(crate) fn parse_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult { - let mut visitor = SQLExprVisitor { ctx }; +pub(crate) fn parse_sql_expr( + expr: &SQLExpr, + ctx: &mut SQLContext, + active_schema: Option<&Schema>, +) -> PolarsResult { + let mut visitor = SQLExprVisitor { ctx, active_schema }; visitor.visit_expr(expr) } diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 92a69a03ea0c..9a1338adf5fe 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -144,6 +144,37 @@ fn test_literal_exprs() { assert!(df_sql.equals_missing(&df_pl)); } +#[test] +fn test_implicit_date_string() { + let df = df! 
{ + "idx" => &[Some(0), Some(1), Some(2), Some(3)], + "dt" => &[Some("1955-10-01"), None, Some("2007-07-05"), Some("2077-06-11")], + } + .unwrap() + .lazy() + .select(vec![col("idx"), col("dt").cast(DataType::Date)]) + .collect() + .unwrap(); + + let mut context = SQLContext::new(); + context.register("frame", df.clone().lazy()); + for sql in [ + "SELECT idx, dt FROM frame WHERE dt >= '2007-07-05'", + "SELECT idx, dt FROM frame WHERE dt::date >= '2007-07-05'", + "SELECT idx, dt FROM frame WHERE dt::datetime >= '2007-07-05 00:00:00'", + "SELECT idx, dt FROM frame WHERE dt::timestamp >= '2007-07-05 00:00:00'", + ] { + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + let df_pl = df + .clone() + .lazy() + .filter(col("idx").gt_eq(lit(2))) + .collect() + .unwrap(); + assert!(df_sql.equals(&df_pl)); + } +} + #[test] fn test_prefixed_column_names() { let df = create_sample_df().unwrap(); @@ -331,7 +362,7 @@ fn test_agg_functions() { } #[test] -fn create_table() { +fn test_create_table() { let df = create_sample_df().unwrap(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py index 4babd435374f..c9d95c84c5ad 100644 --- a/py-polars/tests/unit/sql/test_temporal.py +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -147,6 +147,63 @@ def test_extract_century_millennium(dt: date, expected: list[int]) -> None: ) +@pytest.mark.parametrize( + ("constraint", "expected"), + [ + ("dtm >= '2020-12-30T10:30:45.987'", [0, 2]), + ("dtm::date > '2006-01-01'", [0, 2]), + ("dtm > '2006-01-01'", [0, 1, 2]), # << implies '2006-01-01 00:00:00' + ("dtm <= '2006-01-01'", []), # << implies '2006-01-01 00:00:00' + ("dt != '1960-01-07'", [0, 1]), + ("dt::datetime = '1960-01-07'", [2]), + ("dt::datetime = '1960-01-07 00:00:00'", [2]), + ("dt IN ('1960-01-07','2077-01-01','2222-02-22')", [1, 2]), + ( + "dtm = '2024-01-07 01:02:03.123456000' OR dtm = '2020-12-30 10:30:45.987654'", + [0, 2], + ), + ], +) +def test_implicit_temporal_strings(constraint: str, expected: list[int]) -> None: + df = pl.DataFrame( + { + "idx": [0, 1, 2], + "dtm": [ + datetime(2024, 1, 7, 1, 2, 3, 123456), + datetime(2006, 1, 1, 23, 59, 59, 555555), + datetime(2020, 12, 30, 10, 30, 45, 987654), + ], + "dt": [ + date(2020, 12, 30), + date(2077, 1, 1), + date(1960, 1, 7), + ], + } + ) + res = df.sql(f"SELECT idx FROM self WHERE {constraint}") + actual = sorted(res["idx"]) + assert actual == expected + + +@pytest.mark.parametrize( + "dtval", + [ + "2020-12-30T10:30:45", + "yyyy-mm-dd", + "2222-22-22", + "10:30:45", + ], +) +def test_implicit_temporal_string_errors(dtval: str) -> None: + df = pl.DataFrame({"dt": [date(2020, 12, 30)]}) + + with pytest.raises( + ComputeError, + match="(conversion.*failed)|(cannot compare.*string.*temporal)", + ): + df.sql(f"SELECT * FROM self WHERE dt = '{dtval}'") + + @pytest.mark.parametrize( ("unit", "expected"), [ @@ -182,6 +239,6 @@ def test_timestamp_time_unit_errors() -> None: for prec in (0, 15): with pytest.raises( ComputeError, - match=f"unsupported `timestamp` precision; expected a value between 1 and 9, found {prec}", + match=f"invalid temporal type precision; expected 1-9, found {prec}", ): ctx.execute(f"SELECT ts::timestamp({prec}) FROM frame_data") From 172299cf057972f702bdc5471dbc292a7e8e2bfb Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Sat, 11 May 2024 09:13:46 +0200 Subject: [PATCH 07/29] feat(python): Avoid an extra copy when converting Boolean Series to 
writable NumPy array (#16164) --- py-polars/polars/series/series.py | 26 +- py-polars/src/series/export.rs | 266 ++++++++++-------- .../interop/numpy/test_to_numpy_series.py | 10 +- 3 files changed, 164 insertions(+), 138 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index ebce4249dbeb..ff6053357f64 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -76,7 +76,6 @@ Object, String, Time, - UInt8, UInt16, UInt32, UInt64, @@ -4401,30 +4400,7 @@ def raise_on_copy() -> None: zero_copy_only=not allow_copy, writable=writable ) - if self.null_count() == 0: - if dtype.is_integer() or dtype.is_float() or dtype in (Datetime, Duration): - np_array = self._s.to_numpy_view() - elif dtype == Boolean: - raise_on_copy() - s_u8 = self.cast(UInt8) - np_array = s_u8._s.to_numpy_view().view(bool) - elif dtype == Date: - raise_on_copy() - s_i32 = self.to_physical() - np_array = s_i32._s.to_numpy_view().astype(" torch.Tensor: """ diff --git a/py-polars/src/series/export.rs b/py-polars/src/series/export.rs index 1097e823f0f0..de4fe0d4e0c8 100644 --- a/py-polars/src/series/export.rs +++ b/py-polars/src/series/export.rs @@ -1,6 +1,8 @@ use num_traits::{Float, NumCast}; use numpy::PyArray1; use polars_core::prelude::*; +use pyo3::exceptions::PyValueError; +use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyList; @@ -161,111 +163,133 @@ impl PySeries { /// Convert this Series to a NumPy ndarray. /// - /// This method will copy data - numeric types without null values should - /// be handled on the Python side in a zero-copy manner. - /// - /// This method will cast integers to floats so that `null = np.nan`. - fn to_numpy(&self, py: Python) -> PyResult { - use DataType::*; - let s = &self.series; - let out = match s.dtype() { - Int8 => numeric_series_to_numpy::(py, s), - Int16 => numeric_series_to_numpy::(py, s), - Int32 => numeric_series_to_numpy::(py, s), - Int64 => numeric_series_to_numpy::(py, s), - UInt8 => numeric_series_to_numpy::(py, s), - UInt16 => numeric_series_to_numpy::(py, s), - UInt32 => numeric_series_to_numpy::(py, s), - UInt64 => numeric_series_to_numpy::(py, s), - Float32 => numeric_series_to_numpy::(py, s), - Float64 => numeric_series_to_numpy::(py, s), - Boolean => { - let ca = s.bool().unwrap(); - let np_arr = PyArray1::from_iter_bound(py, ca.into_iter().map(|s| s.into_py(py))); - np_arr.into_py(py) - }, - Date => date_series_to_numpy(py, s), - Datetime(tu, _) => { - use numpy::datetime::{units, Datetime}; - match tu { - TimeUnit::Milliseconds => { - temporal_series_to_numpy::>(py, s) - }, - TimeUnit::Microseconds => { - temporal_series_to_numpy::>(py, s) - }, - TimeUnit::Nanoseconds => { - temporal_series_to_numpy::>(py, s) - }, - } - }, - Duration(tu) => { - use numpy::datetime::{units, Timedelta}; - match tu { - TimeUnit::Milliseconds => { - temporal_series_to_numpy::>(py, s) - }, - TimeUnit::Microseconds => { - temporal_series_to_numpy::>(py, s) - }, - TimeUnit::Nanoseconds => { - temporal_series_to_numpy::>(py, s) - }, + /// This method copies data only when necessary. Set `allow_copy` to raise an error if copy + /// is required. Set `writable` to make sure the resulting array is writable, possibly requiring + /// copying the data. 
+ fn to_numpy(&self, py: Python, allow_copy: bool, writable: bool) -> PyResult { + let is_empty = self.series.is_empty(); + + if self.series.null_count() == 0 { + if let Some(mut arr) = self.to_numpy_view(py) { + if writable || is_empty { + if !allow_copy && !is_empty { + return Err(PyValueError::new_err( + "cannot return a zero-copy writable array", + )); + } + arr = arr.call_method0(py, intern!(py, "copy"))?; } - }, - Time => { - let ca = s.time().unwrap(); - let iter = time_to_pyobject_iter(py, ca); - let np_arr = PyArray1::from_iter_bound(py, iter.map(|v| v.into_py(py))); - np_arr.into_py(py) - }, - String => { - let ca = s.str().unwrap(); - let np_arr = PyArray1::from_iter_bound(py, ca.into_iter().map(|s| s.into_py(py))); - np_arr.into_py(py) - }, - Binary => { - let ca = s.binary().unwrap(); - let np_arr = PyArray1::from_iter_bound(py, ca.into_iter().map(|s| s.into_py(py))); - np_arr.into_py(py) - }, - Categorical(_, _) | Enum(_, _) => { - let ca = s.categorical().unwrap(); - let np_arr = PyArray1::from_iter_bound(py, ca.iter_str().map(|s| s.into_py(py))); - np_arr.into_py(py) - }, - Decimal(_, _) => { - let ca = s.decimal().unwrap(); - let iter = decimal_to_pyobject_iter(py, ca); - let np_arr = PyArray1::from_iter_bound(py, iter.map(|v| v.into_py(py))); - np_arr.into_py(py) - }, - #[cfg(feature = "object")] - Object(_, _) => { - let ca = s - .as_any() - .downcast_ref::>() - .unwrap(); - let np_arr = - PyArray1::from_iter_bound(py, ca.into_iter().map(|opt_v| opt_v.to_object(py))); - np_arr.into_py(py) - }, - Null => { - let n = s.len(); - let np_arr = PyArray1::from_iter_bound(py, std::iter::repeat(f32::NAN).take(n)); - np_arr.into_py(py) - }, - dt => { - raise_err!( - format!("`to_numpy` not supported for dtype {dt:?}"), - ComputeError - ); - }, - }; - Ok(out) + return Ok(arr); + } + } + + if !allow_copy & !is_empty { + return Err(PyValueError::new_err("cannot return a zero-copy array")); + } + + series_to_numpy_with_copy(py, &self.series) } } -/// Convert numeric types to f32 or f64 with NaN representing a null value + +/// Convert a Series to a NumPy ndarray, copying data in the process. +/// +/// This method will cast integers to floats so that `null = np.nan`. 
+fn series_to_numpy_with_copy(py: Python, s: &Series) -> PyResult { + use DataType::*; + let out = match s.dtype() { + Int8 => numeric_series_to_numpy::(py, s), + Int16 => numeric_series_to_numpy::(py, s), + Int32 => numeric_series_to_numpy::(py, s), + Int64 => numeric_series_to_numpy::(py, s), + UInt8 => numeric_series_to_numpy::(py, s), + UInt16 => numeric_series_to_numpy::(py, s), + UInt32 => numeric_series_to_numpy::(py, s), + UInt64 => numeric_series_to_numpy::(py, s), + Float32 => numeric_series_to_numpy::(py, s), + Float64 => numeric_series_to_numpy::(py, s), + Boolean => boolean_series_to_numpy(py, s), + Date => date_series_to_numpy(py, s), + Datetime(tu, _) => { + use numpy::datetime::{units, Datetime}; + match tu { + TimeUnit::Milliseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Microseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Nanoseconds => { + temporal_series_to_numpy::>(py, s) + }, + } + }, + Duration(tu) => { + use numpy::datetime::{units, Timedelta}; + match tu { + TimeUnit::Milliseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Microseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Nanoseconds => { + temporal_series_to_numpy::>(py, s) + }, + } + }, + Time => { + let ca = s.time().unwrap(); + let iter = time_to_pyobject_iter(py, ca); + let np_arr = PyArray1::from_iter_bound(py, iter.map(|v| v.into_py(py))); + np_arr.into_py(py) + }, + String => { + let ca = s.str().unwrap(); + let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(|s| s.into_py(py))); + np_arr.into_py(py) + }, + Binary => { + let ca = s.binary().unwrap(); + let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(|s| s.into_py(py))); + np_arr.into_py(py) + }, + Categorical(_, _) | Enum(_, _) => { + let ca = s.categorical().unwrap(); + let np_arr = PyArray1::from_iter_bound(py, ca.iter_str().map(|s| s.into_py(py))); + np_arr.into_py(py) + }, + Decimal(_, _) => { + let ca = s.decimal().unwrap(); + let iter = decimal_to_pyobject_iter(py, ca); + let np_arr = PyArray1::from_iter_bound(py, iter.map(|v| v.into_py(py))); + np_arr.into_py(py) + }, + #[cfg(feature = "object")] + Object(_, _) => { + let ca = s + .as_any() + .downcast_ref::>() + .unwrap(); + let np_arr = + PyArray1::from_iter_bound(py, ca.into_iter().map(|opt_v| opt_v.to_object(py))); + np_arr.into_py(py) + }, + Null => { + let n = s.len(); + let np_arr = PyArray1::from_iter_bound(py, std::iter::repeat(f32::NAN).take(n)); + np_arr.into_py(py) + }, + dt => { + raise_err!( + format!("`to_numpy` not supported for dtype {dt:?}"), + ComputeError + ); + }, + }; + Ok(out) +} + +/// Convert numeric types to f32 or f64 with NaN representing a null value. fn numeric_series_to_numpy(py: Python, s: &Series) -> PyObject where T: PolarsNumericType, @@ -279,23 +303,41 @@ where let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(mapper)); np_arr.into_py(py) } -/// Convert dates directly to i64 with i64::MIN representing a null value +/// Convert booleans to u8 if no nulls are present, otherwise convert to objects. +fn boolean_series_to_numpy(py: Python, s: &Series) -> PyObject { + let ca = s.bool().unwrap(); + if s.null_count() == 0 { + let values = ca.into_no_null_iter(); + PyArray1::::from_iter_bound(py, values).into_py(py) + } else { + let values = ca.iter().map(|opt_v| opt_v.into_py(py)); + PyArray1::from_iter_bound(py, values).into_py(py) + } +} +/// Convert dates directly to i64 with i64::MIN representing a null value. 
fn date_series_to_numpy(py: Python, s: &Series) -> PyObject { use numpy::datetime::{units, Datetime}; let s_phys = s.to_physical_repr(); let ca = s_phys.i32().unwrap(); - let mapper = |opt_v: Option| { - let int = match opt_v { - Some(v) => v as i64, - None => i64::MIN, + + if s.null_count() == 0 { + let mapper = |v: i32| (v as i64).into(); + let values = ca.into_no_null_iter().map(mapper); + PyArray1::>::from_iter_bound(py, values).into_py(py) + } else { + let mapper = |opt_v: Option| { + match opt_v { + Some(v) => v as i64, + None => i64::MIN, + } + .into() }; - int.into() - }; - let iter = ca.iter().map(mapper); - PyArray1::>::from_iter_bound(py, iter).into_py(py) + let values = ca.iter().map(mapper); + PyArray1::>::from_iter_bound(py, values).into_py(py) + } } -/// Convert datetimes and durations with i64::MIN representing a null value +/// Convert datetimes and durations with i64::MIN representing a null value. fn temporal_series_to_numpy(py: Python, s: &Series) -> PyObject where T: From + numpy::Element, diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py index 5674d183e3f6..31e845e1bbe3 100644 --- a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py @@ -209,6 +209,7 @@ def test_series_to_numpy_bool() -> None: assert s.to_list() == result.tolist() assert result.dtype == np.bool_ + assert result.flags.writeable is True assert_allow_copy_false_raises(s) @@ -267,7 +268,14 @@ def test_to_numpy_empty() -> None: result = s.to_numpy(use_pyarrow=False, allow_copy=False) assert result.dtype == np.object_ assert result.shape == (0,) - assert result.size == 0 + + +def test_to_numpy_empty_writable() -> None: + s = pl.Series(dtype=pl.Int64) + result = s.to_numpy(use_pyarrow=False, allow_copy=False, writable=True) + assert result.dtype == np.int64 + assert result.shape == (0,) + assert result.flags.writeable is True def test_to_numpy_chunked() -> None: From b7207b262e48df188fb7b19ef1afc71f2869943e Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sat, 11 May 2024 13:02:29 +0200 Subject: [PATCH 08/29] fix: Flush parquet at end of batches tick (#16073) --- .../src/parquet/write/batched_writer.rs | 123 +++++++++++------- crates/polars-parquet/src/arrow/write/mod.rs | 2 +- .../polars-parquet/src/arrow/write/pages.rs | 48 +++++++ .../src/parquet/write/row_group.rs | 1 + 4 files changed, 128 insertions(+), 46 deletions(-) diff --git a/crates/polars-io/src/parquet/write/batched_writer.rs b/crates/polars-io/src/parquet/write/batched_writer.rs index 818fc65404c6..72dc8b2253c9 100644 --- a/crates/polars-io/src/parquet/write/batched_writer.rs +++ b/crates/polars-io/src/parquet/write/batched_writer.rs @@ -7,8 +7,8 @@ use polars_core::prelude::*; use polars_core::POOL; use polars_parquet::read::ParquetError; use polars_parquet::write::{ - array_to_columns, compress, CompressedPage, Compressor, DynIter, DynStreamingIterator, - Encoding, FallibleStreamingIterator, FileWriter, ParquetType, RowGroupIterColumns, + array_to_columns, arrays_to_columns, CompressedPage, Compressor, DynIter, DynStreamingIterator, + Encoding, FallibleStreamingIterator, FileWriter, Page, ParquetType, RowGroupIterColumns, SchemaDescriptor, WriteOptions, }; use rayon::prelude::*; @@ -44,6 +44,13 @@ impl BatchedWriter { }) } + pub fn encode_and_compress_multiple<'a>( + &'a self, + // A DataFrame with multiple chunks + chunked_df: &'a DataFrame, + ) { + } + /// Write a batch to the parquet 
writer. /// /// # Panics @@ -108,6 +115,42 @@ fn prepare_rg_iter<'a>( }) } +fn pages_iter_to_compressor( + encoded_columns: Vec>>, + options: WriteOptions, +) -> Vec>> { + encoded_columns + .into_iter() + .map(|encoded_pages| { + // iterator over pages + let pages = DynStreamingIterator::new( + Compressor::new_from_vec( + encoded_pages.map(|result| { + result.map_err(|e| { + ParquetError::FeatureNotSupported(format!("reraised in polars: {e}",)) + }) + }), + options.compression, + vec![], + ) + .map_err(PolarsError::from), + ); + + Ok(pages) + }) + .collect::>() +} + +fn array_to_pages_iter( + array: &ArrayRef, + type_: &ParquetType, + encoding: &[Encoding], + options: WriteOptions, +) -> Vec>> { + let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap(); + pages_iter_to_compressor(encoded_columns, options) +} + fn create_serializer( batch: RecordBatch, fields: &[ParquetType], @@ -116,30 +159,7 @@ fn create_serializer( parallel: bool, ) -> PolarsResult> { let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec)| { - let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap(); - - encoded_columns - .into_iter() - .map(|encoded_pages| { - // iterator over pages - let pages = DynStreamingIterator::new( - Compressor::new_from_vec( - encoded_pages.map(|result| { - result.map_err(|e| { - ParquetError::FeatureNotSupported(format!( - "reraised in polars: {e}", - )) - }) - }), - options.compression, - vec![], - ) - .map_err(PolarsError::from), - ); - - Ok(pages) - }) - .collect::>() + array_to_pages_iter(array, type_, encoding, options) }; let columns = if parallel { @@ -204,25 +224,7 @@ fn create_eager_serializer( options: WriteOptions, ) -> PolarsResult> { let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec)| { - let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap(); - - encoded_columns - .into_iter() - .map(|encoded_pages| { - let compressed_pages = encoded_pages - .into_iter() - .map(|page| { - let page = page?; - let page = compress(page, vec![], options.compression)?; - Ok(Ok(page)) - }) - .collect::>>()?; - - Ok(DynStreamingIterator::new(CompressedPages::new( - compressed_pages, - ))) - }) - .collect::>() + array_to_pages_iter(array, type_, encoding, options) }; let columns = batch @@ -237,3 +239,34 @@ fn create_eager_serializer( Ok(row_group) } + +fn create_eager_serializer_batches( + // DataFrame with multiple chunks + chunked_df: DataFrame, + fields: &[ParquetType], + encodings: &[Vec], + options: WriteOptions, +) -> PolarsResult> { + let func = move |((s, type_), encoding): ((&Series, &ParquetType), &Vec)| { + let n_chunks = s.chunks().len(); + let mut chunks = Vec::with_capacity(n_chunks); + for i in 0..n_chunks { + chunks.push(s.to_arrow(i, true)) + } + + let encoded_columns = arrays_to_columns(&chunks, type_.clone(), options, encoding).unwrap(); + pages_iter_to_compressor(encoded_columns, options) + }; + + let columns = chunked_df + .get_columns() + .iter() + .zip(fields) + .zip(encodings) + .flat_map(func) + .collect::>(); + + let row_group = DynIter::new(columns.into_iter()); + + Ok(row_group) +} diff --git a/crates/polars-parquet/src/arrow/write/mod.rs b/crates/polars-parquet/src/arrow/write/mod.rs index 65e03cecaae4..e5f46b39a476 100644 --- a/crates/polars-parquet/src/arrow/write/mod.rs +++ b/crates/polars-parquet/src/arrow/write/mod.rs @@ -66,7 +66,7 @@ pub struct WriteOptions { use arrow::compute::aggregate::estimated_bytes_size; use 
arrow::match_integer_type; pub use file::FileWriter; -pub use pages::{array_to_columns, Nested}; +pub use pages::{array_to_columns, arrays_to_columns, Nested}; use polars_error::{polars_bail, PolarsResult}; pub use row_group::{row_group_iter, RowGroupIterator}; pub use schema::to_parquet_type; diff --git a/crates/polars-parquet/src/arrow/write/pages.rs b/crates/polars-parquet/src/arrow/write/pages.rs index f62735258205..f11f749b0a37 100644 --- a/crates/polars-parquet/src/arrow/write/pages.rs +++ b/crates/polars-parquet/src/arrow/write/pages.rs @@ -295,6 +295,54 @@ pub fn array_to_columns + Send + Sync>( .collect() } +pub fn arrays_to_columns + Send + Sync>( + arrays: &[A], + type_: ParquetType, + options: WriteOptions, + encoding: &[Encoding], +) -> PolarsResult>>> { + let array = arrays[0].as_ref(); + let nested = to_nested(array, &type_)?; + + let types = to_parquet_leaves(type_); + + // leaves; index level is nesting depth. + // index i: has a vec because we have multiple chunks. + let mut leaves = vec![]; + + // Ensure we transpose the leaves. So that all the leaves from the same columns are at the same level vec. + let mut scratch = vec![]; + for arr in arrays { + scratch.clear(); + to_leaves_recursive(arr.as_ref(), &mut scratch); + for (i, leave) in scratch.iter().copied().enumerate() { + while i < leaves.len() { + leaves.push(vec![]); + } + leaves[i].push(leave); + } + } + + leaves + .into_iter() + .zip(nested) + .zip(types) + .zip(encoding.iter()) + .map(move |(((values, nested), type_), encoding)| { + let iter = values.into_iter().map(|leave_values| { + array_to_pages(leave_values, type_.clone(), &nested, options, *encoding) + }); + + // Need a scratch to bubble up the error :/ + let mut scratch = Vec::with_capacity(iter.size_hint().0); + for v in iter { + scratch.push(v?) + } + Ok(DynIter::new(scratch.into_iter().flatten())) + }) + .collect::>>() +} + #[cfg(test)] mod tests { use arrow::array::*; diff --git a/crates/polars-parquet/src/parquet/write/row_group.rs b/crates/polars-parquet/src/parquet/write/row_group.rs index 22dc42286bf7..0db63806a6a1 100644 --- a/crates/polars-parquet/src/parquet/write/row_group.rs +++ b/crates/polars-parquet/src/parquet/write/row_group.rs @@ -98,6 +98,7 @@ where let bytes_written = offset - initial; let num_rows = compute_num_rows(&columns)?; + dbg!(num_rows); // compute row group stats let file_offset = columns From cf1af8b32e5124e0ef8e7aa2633a93c2b297218f Mon Sep 17 00:00:00 2001 From: Jan Pipek Date: Sat, 11 May 2024 15:20:27 +0200 Subject: [PATCH 09/29] Fix: StringNameSpace.replace_all docstring (#16169) --- py-polars/polars/series/string.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 9854f1fac33c..da6304e7331e 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -1216,7 +1216,7 @@ def replace( def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Series: r""" - Replace first matching regex/literal substring with a new string value. + Replace all matching regex/literal substrings with a new string value. Parameters ---------- @@ -1227,12 +1227,10 @@ def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Ser String that will replace the matched substring. literal Treat `pattern` as a literal string. - n - Number of matches to replace. 
See Also -------- - replace_all + replace Notes ----- From 9ea2504d3c1369e01dc5ab04c3460fe78d33c312 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Sun, 12 May 2024 14:46:33 +0800 Subject: [PATCH 10/29] docs(python): Add deprecated messages to `cumfold` and `cumreduce` (#16173) --- py-polars/polars/functions/lazy.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 99a881cf6373..014b3b622fdf 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -2312,6 +2312,9 @@ def cumfold( Every cumulative result is added as a separate field in a Struct column. + .. deprecated:: 0.19.14 + This function has been renamed to :func:`cum_fold`. + Parameters ---------- acc @@ -2344,6 +2347,9 @@ def cumreduce( Every cumulative result is added as a separate field in a Struct column. + .. deprecated:: 0.19.14 + This function has been renamed to :func:`cum_reduce`. + Parameters ---------- function From 2f817420881f70f0d0f9fa6cf1573d03e01a326f Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sun, 12 May 2024 07:54:15 +0100 Subject: [PATCH 11/29] perf: use zeroable_vec in ewm_mean_by (#16166) --- crates/polars-ops/src/series/ops/ewm_by.rs | 72 +++++++++++-------- .../tests/unit/operations/test_ewm_by.py | 17 ++++- 2 files changed, 58 insertions(+), 31 deletions(-) diff --git a/crates/polars-ops/src/series/ops/ewm_by.rs b/crates/polars-ops/src/series/ops/ewm_by.rs index a14947467f52..ce374f757507 100644 --- a/crates/polars-ops/src/series/ops/ewm_by.rs +++ b/crates/polars-ops/src/series/ops/ewm_by.rs @@ -1,5 +1,9 @@ +use arrow::compute::concatenate::concatenate_validities; +use arrow::compute::utils::combine_validities_and; +use bytemuck::allocation::zeroed_vec; use num_traits::{Float, FromPrimitive, One, Zero}; use polars_core::prelude::*; +use polars_core::utils::align_chunks_binary; pub fn ewm_mean_by( s: &Series, @@ -7,17 +11,19 @@ pub fn ewm_mean_by( half_life: i64, assume_sorted: bool, ) -> PolarsResult { - let func = match assume_sorted { - true => ewm_mean_by_impl_sorted, - false => ewm_mean_by_impl, - }; match (s.dtype(), times.dtype()) { - (DataType::Float64, DataType::Int64) => { - Ok(func(s.f64().unwrap(), times.i64().unwrap(), half_life).into_series()) - }, - (DataType::Float32, DataType::Int64) => { - Ok(ewm_mean_by_impl(s.f32().unwrap(), times.i64().unwrap(), half_life).into_series()) - }, + (DataType::Float64, DataType::Int64) => Ok((if assume_sorted { + ewm_mean_by_impl_sorted(s.f64().unwrap(), times.i64().unwrap(), half_life) + } else { + ewm_mean_by_impl(s.f64().unwrap(), times.i64().unwrap(), half_life) + }) + .into_series()), + (DataType::Float32, DataType::Int64) => Ok((if assume_sorted { + ewm_mean_by_impl_sorted(s.f32().unwrap(), times.i64().unwrap(), half_life) + } else { + ewm_mean_by_impl(s.f32().unwrap(), times.i64().unwrap(), half_life) + }) + .into_series()), #[cfg(feature = "dtype-datetime")] (_, DataType::Datetime(time_unit, _)) => { let half_life = adjust_half_life_to_time_unit(half_life, time_unit); @@ -61,50 +67,56 @@ where ChunkedArray: ChunkTakeUnchecked, { let sorting_indices = times.arg_sort(Default::default()); - let values = unsafe { values.take_unchecked(&sorting_indices) }; - let times = unsafe { times.take_unchecked(&sorting_indices) }; + let sorted_values = unsafe { values.take_unchecked(&sorting_indices) }; + let sorted_times = unsafe { times.take_unchecked(&sorting_indices) }; let sorting_indices = sorting_indices .cont_slice() 
.expect("`arg_sort` should have returned a single chunk"); - let mut out = vec![None; times.len()]; + let mut out: Vec<_> = zeroed_vec(sorted_times.len()); let mut skip_rows: usize = 0; let mut prev_time: i64 = 0; let mut prev_result = T::Native::zero(); - for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() { + for (idx, (value, time)) in sorted_values.iter().zip(sorted_times.iter()).enumerate() { if let (Some(time), Some(value)) = (time, value) { prev_time = time; prev_result = value; unsafe { let out_idx = sorting_indices.get_unchecked(idx); - *out.get_unchecked_mut(*out_idx as usize) = Some(prev_result); + *out.get_unchecked_mut(*out_idx as usize) = prev_result; } skip_rows = idx + 1; break; }; } - values + sorted_values .iter() - .zip(times.iter()) + .zip(sorted_times.iter()) .enumerate() .skip(skip_rows) .for_each(|(idx, (value, time))| { - let result_opt = match (time, value) { - (Some(time), Some(value)) => { - let result = update(value, prev_result, time, prev_time, half_life); - prev_time = time; - prev_result = result; - Some(result) - }, - _ => None, + if let (Some(time), Some(value)) = (time, value) { + let result = update(value, prev_result, time, prev_time, half_life); + prev_time = time; + prev_result = result; + unsafe { + let out_idx = sorting_indices.get_unchecked(idx); + *out.get_unchecked_mut(*out_idx as usize) = result; + } }; - unsafe { - let out_idx = sorting_indices.get_unchecked(idx); - *out.get_unchecked_mut(*out_idx as usize) = result_opt; - } }); - ChunkedArray::::from_iter_options(values.name(), out.into_iter()) + let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(true)); + if (times.null_count() > 0) || (values.null_count() > 0) { + let (times, values) = align_chunks_binary(times, values); + let times_chunk_refs: Vec<_> = times.chunks().iter().map(|c| &**c).collect(); + let times_validity = concatenate_validities(×_chunk_refs); + let values_chunk_refs: Vec<_> = values.chunks().iter().map(|c| &**c).collect(); + let values_validity = concatenate_validities(&values_chunk_refs); + let validity = combine_validities_and(times_validity.as_ref(), values_validity.as_ref()); + arr = arr.with_validity_typed(validity); + } + ChunkedArray::with_chunk(values.name(), arr) } /// Fastpath if `times` is known to already be sorted. 
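
The rewritten unsorted kernel above fills a pre-zeroed output buffer and then derives the result validity by AND-ing the `times` and `values` validities, so a null in either input yields a null in the output. A minimal usage sketch of that behaviour (the inputs and expected values mirror the new test in the diff below; `half_life="2i"` is the integer-duration form used because `times` is an integer column):

    import polars as pl

    times = pl.Series([1, 2, None])    # null in `times`
    values = pl.Series([1, 2, 3])
    out = values.ewm_mean_by(times, half_life="2i")
    # -> [1.0, 1.292893, None]; the null row stays null, as asserted in the test
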
diff --git a/py-polars/tests/unit/operations/test_ewm_by.py b/py-polars/tests/unit/operations/test_ewm_by.py index aaace3e67cf4..43884d7e0b82 100644 --- a/py-polars/tests/unit/operations/test_ewm_by.py +++ b/py-polars/tests/unit/operations/test_ewm_by.py @@ -6,7 +6,7 @@ import pytest import polars as pl -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: from polars.type_aliases import PolarsIntegerType, TimeUnit @@ -223,3 +223,18 @@ def test_ewma_by_warn_two_chunks() -> None: pl.col("values").ewm_mean_by("by", half_life="2i"), ) assert_frame_equal(result, expected.sort("by")) + + +def test_ewma_by_multiple_chunks() -> None: + # times contains null + times = pl.Series([1, 2]).append(pl.Series([None], dtype=pl.Int64)) + values = pl.Series([1, 2]).append(pl.Series([3])) + result = values.ewm_mean_by(times, half_life="2i") + expected = pl.Series([1.0, 1.292893, None]) + assert_series_equal(result, expected) + + # values contains null + times = pl.Series([1, 2]).append(pl.Series([3])) + values = pl.Series([1, 2]).append(pl.Series([None], dtype=pl.Int64)) + result = values.ewm_mean_by(times, half_life="2i") + assert_series_equal(result, expected) From 492c9d942cd69210623712113119c9e71f2f463d Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 12 May 2024 13:08:25 +0200 Subject: [PATCH 12/29] chore: Remove unused code (#16175) --- .../src/parquet/write/batched_writer.rs | 71 +------------------ .../src/parquet/write/row_group.rs | 1 - 2 files changed, 2 insertions(+), 70 deletions(-) diff --git a/crates/polars-io/src/parquet/write/batched_writer.rs b/crates/polars-io/src/parquet/write/batched_writer.rs index 72dc8b2253c9..f5b42b7ef690 100644 --- a/crates/polars-io/src/parquet/write/batched_writer.rs +++ b/crates/polars-io/src/parquet/write/batched_writer.rs @@ -1,4 +1,3 @@ -use std::collections::VecDeque; use std::io::Write; use std::sync::Mutex; @@ -7,8 +6,8 @@ use polars_core::prelude::*; use polars_core::POOL; use polars_parquet::read::ParquetError; use polars_parquet::write::{ - array_to_columns, arrays_to_columns, CompressedPage, Compressor, DynIter, DynStreamingIterator, - Encoding, FallibleStreamingIterator, FileWriter, Page, ParquetType, RowGroupIterColumns, + array_to_columns, CompressedPage, Compressor, DynIter, DynStreamingIterator, Encoding, + FallibleStreamingIterator, FileWriter, Page, ParquetType, RowGroupIterColumns, SchemaDescriptor, WriteOptions, }; use rayon::prelude::*; @@ -44,13 +43,6 @@ impl BatchedWriter { }) } - pub fn encode_and_compress_multiple<'a>( - &'a self, - // A DataFrame with multiple chunks - chunked_df: &'a DataFrame, - ) { - } - /// Write a batch to the parquet writer. /// /// # Panics @@ -187,34 +179,6 @@ fn create_serializer( Ok(row_group) } -struct CompressedPages { - pages: VecDeque>, - current: Option, -} - -impl CompressedPages { - fn new(pages: VecDeque>) -> Self { - Self { - pages, - current: None, - } - } -} - -impl FallibleStreamingIterator for CompressedPages { - type Item = CompressedPage; - type Error = PolarsError; - - fn advance(&mut self) -> Result<(), Self::Error> { - self.current = self.pages.pop_front().transpose()?; - Ok(()) - } - - fn get(&self) -> Option<&Self::Item> { - self.current.as_ref() - } -} - /// This serializer encodes and compresses all eagerly in memory. /// Used for separating compute from IO. 
fn create_eager_serializer( @@ -239,34 +203,3 @@ fn create_eager_serializer( Ok(row_group) } - -fn create_eager_serializer_batches( - // DataFrame with multiple chunks - chunked_df: DataFrame, - fields: &[ParquetType], - encodings: &[Vec], - options: WriteOptions, -) -> PolarsResult> { - let func = move |((s, type_), encoding): ((&Series, &ParquetType), &Vec)| { - let n_chunks = s.chunks().len(); - let mut chunks = Vec::with_capacity(n_chunks); - for i in 0..n_chunks { - chunks.push(s.to_arrow(i, true)) - } - - let encoded_columns = arrays_to_columns(&chunks, type_.clone(), options, encoding).unwrap(); - pages_iter_to_compressor(encoded_columns, options) - }; - - let columns = chunked_df - .get_columns() - .iter() - .zip(fields) - .zip(encodings) - .flat_map(func) - .collect::>(); - - let row_group = DynIter::new(columns.into_iter()); - - Ok(row_group) -} diff --git a/crates/polars-parquet/src/parquet/write/row_group.rs b/crates/polars-parquet/src/parquet/write/row_group.rs index 0db63806a6a1..22dc42286bf7 100644 --- a/crates/polars-parquet/src/parquet/write/row_group.rs +++ b/crates/polars-parquet/src/parquet/write/row_group.rs @@ -98,7 +98,6 @@ where let bytes_written = offset - initial; let num_rows = compute_num_rows(&columns)?; - dbg!(num_rows); // compute row group stats let file_offset = columns From 1b97b6da20a0fc855ed4856f7dece3b1fd265262 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Sun, 12 May 2024 21:38:06 +1000 Subject: [PATCH 13/29] fix: Fix CSV skip_rows_after_header for streaming (#16176) --- crates/polars-pipe/src/executors/sources/csv.rs | 6 ++++++ py-polars/tests/unit/io/test_csv.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/crates/polars-pipe/src/executors/sources/csv.rs b/crates/polars-pipe/src/executors/sources/csv.rs index d72be4ba752e..28fa62ec83f4 100644 --- a/crates/polars-pipe/src/executors/sources/csv.rs +++ b/crates/polars-pipe/src/executors/sources/csv.rs @@ -61,6 +61,12 @@ impl CsvSource { let low_memory = options.low_memory; let reader: CsvReader = options + .with_skip_rows_after_header( + // If we don't set it to 0 here, it will skip double the amount of rows. + // But if we set it to 0, it will still skip the requested amount of rows. + // TODO: Find out why. Maybe has something to do with schema inference. 
+ 0, + ) .with_schema_overwrite(Some(self.schema.clone())) .with_n_rows(n_rows) .with_columns(with_columns) diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index b548a9f6f87d..e19d31ca2e29 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -2062,6 +2062,22 @@ def test_csv_escape_cf_15349() -> None: assert f.read() == b'test\nnormal\n"with\rcr"\n' +@pytest.mark.write_disk() +@pytest.mark.parametrize("streaming", [True, False]) +def test_skip_rows_after_header(tmp_path: Path, streaming: bool) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "data.csv" + + df = pl.Series("a", [1, 2, 3, 4, 5], dtype=pl.Int64).to_frame() + df.write_csv(path) + + skip = 2 + expect = df.slice(skip) + out = pl.scan_csv(path, skip_rows_after_header=skip).collect(streaming=streaming) + + assert_frame_equal(out, expect) + + @pytest.mark.parametrize("use_pyarrow", [True, False]) def test_skip_rows_after_header_pyarrow(use_pyarrow: bool) -> None: csv = textwrap.dedent( From 3892c750253e2beed38676038b0922760ce9a283 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 12 May 2024 14:45:22 +0200 Subject: [PATCH 14/29] fix: Fix streaming glob slice (#16174) --- crates/polars-core/src/frame/mod.rs | 3 + .../src/csv/read/read_impl/batched_mmap.rs | 22 ++++--- .../src/csv/read/read_impl/batched_read.rs | 28 +++++---- .../streaming/construct_pipeline.rs | 15 ++--- .../physical_plan/streaming/convert_alp.rs | 62 +------------------ .../optimizer/slice_pushdown_lp.rs | 13 +++- crates/polars-utils/src/arena.rs | 7 +++ .../tests/unit/streaming/test_streaming_io.py | 5 ++ 8 files changed, 59 insertions(+), 96 deletions(-) diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index c5433e12174f..77dbb569cb10 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -2249,6 +2249,9 @@ impl DataFrame { if offset == 0 && length == self.height() { return self.clone(); } + if length == 0 { + return self.clear(); + } let col = self .columns .iter() diff --git a/crates/polars-io/src/csv/read/read_impl/batched_mmap.rs b/crates/polars-io/src/csv/read/read_impl/batched_mmap.rs index cb8ce04947d8..f30f3105de51 100644 --- a/crates/polars-io/src/csv/read/read_impl/batched_mmap.rs +++ b/crates/polars-io/src/csv/read/read_impl/batched_mmap.rs @@ -170,7 +170,7 @@ impl<'a> CoreReader<'a> { to_cast: self.to_cast, ignore_errors: self.ignore_errors, truncate_ragged_lines: self.truncate_ragged_lines, - n_rows: self.n_rows, + remaining: self.n_rows.unwrap_or(usize::MAX), encoding: self.encoding, separator: self.separator, schema: self.schema, @@ -197,7 +197,7 @@ pub struct BatchedCsvReaderMmap<'a> { truncate_ragged_lines: bool, to_cast: Vec, ignore_errors: bool, - n_rows: Option, + remaining: usize, encoding: CsvEncoding, separator: u8, schema: SchemaRef, @@ -211,14 +211,9 @@ pub struct BatchedCsvReaderMmap<'a> { impl<'a> BatchedCsvReaderMmap<'a> { pub fn next_batches(&mut self, n: usize) -> PolarsResult>> { - if n == 0 { + if n == 0 || self.remaining == 0 { return Ok(None); } - if let Some(n_rows) = self.n_rows { - if self.rows_read >= n_rows as IdxSize { - return Ok(None); - } - } // get next `n` offset positions. 
let file_chunks_iter = (&mut self.file_chunks_iter).take(n); @@ -274,8 +269,15 @@ impl<'a> BatchedCsvReaderMmap<'a> { if self.row_index.is_some() { update_row_counts2(&mut chunks, self.rows_read) } - for df in &chunks { - self.rows_read += df.height() as IdxSize; + for df in &mut chunks { + let h = df.height(); + + if self.remaining < h { + *df = df.slice(0, self.remaining) + }; + self.remaining = self.remaining.saturating_sub(h); + + self.rows_read += h as IdxSize; } Ok(Some(chunks)) } diff --git a/crates/polars-io/src/csv/read/read_impl/batched_read.rs b/crates/polars-io/src/csv/read/read_impl/batched_read.rs index b42be05f14b4..8c405f88e3fc 100644 --- a/crates/polars-io/src/csv/read/read_impl/batched_read.rs +++ b/crates/polars-io/src/csv/read/read_impl/batched_read.rs @@ -246,7 +246,6 @@ impl<'a> CoreReader<'a> { Ok(BatchedCsvReaderRead { chunk_size: self.chunk_size, - finished: false, file_chunk_reader: chunk_iter, file_chunks: vec![], projection, @@ -260,20 +259,20 @@ impl<'a> CoreReader<'a> { to_cast: self.to_cast, ignore_errors: self.ignore_errors, truncate_ragged_lines: self.truncate_ragged_lines, - n_rows: self.n_rows, + remaining: self.n_rows.unwrap_or(usize::MAX), encoding: self.encoding, separator: self.separator, schema: self.schema, rows_read: 0, _cat_lock, decimal_comma: self.decimal_comma, + finished: false, }) } } pub struct BatchedCsvReaderRead<'a> { chunk_size: usize, - finished: bool, file_chunk_reader: ChunkReader<'a>, file_chunks: Vec<(SyncPtr, usize)>, projection: Vec, @@ -287,7 +286,7 @@ pub struct BatchedCsvReaderRead<'a> { to_cast: Vec, ignore_errors: bool, truncate_ragged_lines: bool, - n_rows: Option, + remaining: usize, encoding: CsvEncoding, separator: u8, schema: SchemaRef, @@ -297,19 +296,15 @@ pub struct BatchedCsvReaderRead<'a> { #[cfg(not(feature = "dtype-categorical"))] _cat_lock: Option, decimal_comma: bool, + finished: bool, } // impl<'a> BatchedCsvReaderRead<'a> { /// `n` number of batches. pub fn next_batches(&mut self, n: usize) -> PolarsResult>> { - if n == 0 || self.finished { + if n == 0 || self.remaining == 0 || self.finished { return Ok(None); } - if let Some(n_rows) = self.n_rows { - if self.rows_read >= n_rows as IdxSize { - return Ok(None); - } - } // get next `n` offset positions. @@ -331,7 +326,7 @@ impl<'a> BatchedCsvReaderRead<'a> { // get the final slice self.file_chunks .push(self.file_chunk_reader.get_buf_remaining()); - self.finished = true + self.finished = true; } // depleted the offsets iterator, we are done as well. 
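Both batched CSV readers in this commit swap the old `rows_read >= n_rows` check for a `remaining` row budget that is decremented per chunk, with the chunk that crosses the budget sliced down so a streaming slice never yields more rows than requested. A rough Python sketch of that bookkeeping, for illustration only (`take_n_rows` is a hypothetical helper, not part of either reader):

    import polars as pl

    def take_n_rows(chunks: list[pl.DataFrame], n_rows: int) -> list[pl.DataFrame]:
        # Hand out chunks until the row budget is exhausted, slicing the
        # final chunk so the running total never exceeds `n_rows`.
        remaining = n_rows
        out = []
        for df in chunks:
            if remaining == 0:
                break
            height = df.height
            if remaining < height:
                df = df.slice(0, remaining)
            remaining = max(remaining - height, 0)
            out.append(df)
        return out

    chunks = [pl.DataFrame({"a": [1, 2, 3]}), pl.DataFrame({"a": [4, 5, 6]})]
    assert sum(df.height for df in take_n_rows(chunks, 4)) == 4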
@@ -380,8 +375,15 @@ impl<'a> BatchedCsvReaderRead<'a> { if self.row_index.is_some() { update_row_counts2(&mut chunks, self.rows_read) } - for df in &chunks { - self.rows_read += df.height() as IdxSize; + for df in &mut chunks { + let h = df.height(); + + if self.remaining < h { + *df = df.slice(0, self.remaining) + }; + self.remaining = self.remaining.saturating_sub(h); + + self.rows_read += h as IdxSize; } Ok(Some(chunks)) } diff --git a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs index 840429855f5b..93e6a55aa6f3 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs @@ -68,10 +68,12 @@ fn jit_insert_slice( sink_nodes: &mut Vec<(usize, Node, Rc>)>, operator_offset: usize, ) { - // if the join/union has a slice, we add a new slice node + // if the join has a slice, we add a new slice node // note that we take the offset + 1, because we want to // slice AFTER the join has happened and the join will be an // operator + // NOTE: Don't do this for union, that doesn't work. + // TODO! Deal with this in the optimizer. use IR::*; let (offset, len) = match lp_arena.get(node) { Join { options, .. } if options.args.slice.is_some() => { @@ -80,19 +82,11 @@ fn jit_insert_slice( }; (offset, len) }, - Union { - options: - UnionOptions { - slice: Some((offset, len)), - .. - }, - .. - } => (*offset, *len), _ => return, }; let slice_node = lp_arena.add(Slice { - input: Node::default(), + input: node, offset, len: len as IdxSize, }); @@ -178,7 +172,6 @@ pub(super) fn construct( }, PipelineNode::Union(node) => { operator_nodes.push(node); - jit_insert_slice(node, lp_arena, &mut sink_nodes, operator_offset); let op = get_operator(node, lp_arena, expr_arena, &to_physical_piped_expr)?; operators.push(op); }, diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index 7ffdbd7935af..54fe0b1a68f3 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -81,21 +81,6 @@ fn insert_file_sink(mut root: Node, lp_arena: &mut Arena) -> Node { root } -fn insert_slice( - root: Node, - offset: i64, - len: IdxSize, - lp_arena: &mut Arena, - state: &mut Branch, -) { - let node = lp_arena.add(IR::Slice { - input: root, - offset, - len: len as IdxSize, - }); - state.operators_sinks.push(PipelineNode::Sink(node)); -} - pub(crate) fn insert_streaming_nodes( root: Node, lp_arena: &mut Arena, @@ -244,20 +229,8 @@ pub(crate) fn insert_streaming_nodes( ) } }, - Scan { - file_options: options, - scan_type, - .. - } if scan_type.streamable() => { + Scan { scan_type, .. } if scan_type.streamable() => { if state.streamable { - #[cfg(feature = "csv")] - if matches!(scan_type, FileScan::Csv { .. }) { - // the batched csv reader doesn't stop exactly at n_rows - if let Some(n_rows) = options.n_rows { - insert_slice(root, 0, n_rows as IdxSize, lp_arena, &mut state); - } - } - state.sources.push(root); pipeline_trees[current_idx].push(state) } @@ -320,38 +293,7 @@ pub(crate) fn insert_streaming_nodes( state.sources.push(root); pipeline_trees[current_idx].push(state); }, - Union { - options: - UnionOptions { - slice: Some((offset, len)), - .. - }, - .. 
- } if *offset >= 0 => { - insert_slice(root, *offset, *len as IdxSize, lp_arena, &mut state); - state.streamable = true; - let Union { inputs, .. } = lp_arena.get(root) else { - unreachable!() - }; - for (i, input) in inputs.iter().enumerate() { - let mut state = if i == 0 { - // Note the clone! - let mut state = state.clone(); - state.join_count += inputs.len() as u32 - 1; - state - } else { - let mut state = state.split_from_sink(); - state.join_count = 0; - state - }; - state.operators_sinks.push(PipelineNode::Union(root)); - stack.push(StackFrame::new(*input, state, current_idx)); - } - }, - Union { - inputs, - options: UnionOptions { slice: None, .. }, - } => { + Union { inputs, .. } => { { state.streamable = true; for (i, input) in inputs.iter().enumerate() { diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs index 18b3c9d85631..5a0975f4a654 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs @@ -209,7 +209,6 @@ impl SlicePushDown { Ok(lp) } (Union {mut inputs, mut options }, Some(state)) => { - options.slice = Some((state.offset, state.len as usize)); if state.offset == 0 { for input in &mut inputs { let input_lp = lp_arena.take(*input); @@ -217,7 +216,17 @@ impl SlicePushDown { lp_arena.replace(*input, input_lp); } } - Ok(Union {inputs, options}) + // The in-memory union node is slice aware. + // We still set this information, but the streaming engine will ignore it. + options.slice = Some((state.offset, state.len as usize)); + let lp = Union {inputs, options}; + + if self.streaming { + // Ensure the slice node remains. + self.no_pushdown_finish_opt(lp, Some(state), lp_arena) + } else { + Ok(lp) + } }, (Join { input_left, diff --git a/crates/polars-utils/src/arena.rs b/crates/polars-utils/src/arena.rs index df367b733f1f..31818eb03d86 100644 --- a/crates/polars-utils/src/arena.rs +++ b/crates/polars-utils/src/arena.rs @@ -104,6 +104,13 @@ impl Arena { } } +impl Arena { + pub fn duplicate(&mut self, node: Node) -> Node { + let item = self.items[node.0].clone(); + self.add(item) + } +} + impl Arena { #[inline] pub fn take(&mut self, idx: Node) -> T { diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py index d405fec1183c..982ba225e9d9 100644 --- a/py-polars/tests/unit/streaming/test_streaming_io.py +++ b/py-polars/tests/unit/streaming/test_streaming_io.py @@ -30,6 +30,11 @@ def test_scan_slice_streaming(io_files_path: Path) -> None: df = pl.scan_csv(foods_file_path).head(5).collect(streaming=True) assert df.shape == (5, 4) + # globbing + foods_file_path = io_files_path / "foods*.csv" + df = pl.scan_csv(foods_file_path).head(5).collect(streaming=True) + assert df.shape == (5, 4) + @pytest.mark.parametrize("dtype", [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16]) def test_scan_csv_overwrite_small_dtypes( From 0b66308a5f64ab35116d3a1f8a0dab96a94ff98b Mon Sep 17 00:00:00 2001 From: tharunsuresh-code <70054755+tharunsuresh-code@users.noreply.github.com> Date: Mon, 13 May 2024 00:12:03 +0530 Subject: [PATCH 15/29] docs(python): Add examples for multiple `Series` functions (#16172) --- py-polars/polars/series/series.py | 153 +++++++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index ff6053357f64..a7109c3c21d8 100644 --- 
a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3314,6 +3314,28 @@ def limit(self, n: int = 10) -> Series: See Also -------- head + + Examples + -------- + >>> s = pl.Series("a", [1, 2, 3, 4, 5]) + >>> s.limit(3) + shape: (3,) + Series: 'a' [i64] + [ + 1 + 2 + 3 + ] + + Pass a negative value to get all rows `except` the last `abs(n)`. + + >>> s.limit(-3) + shape: (2,) + Series: 'a' [i64] + [ + 1 + 2 + ] """ return self.head(n) @@ -4064,6 +4086,28 @@ def explode(self) -> Series: -------- Series.list.explode : Explode a list column. Series.str.explode : Explode a string column. + + Examples + -------- + >>> s = pl.Series("a", [[1, 2, 3], [4, 5, 6]]) + >>> s + shape: (2,) + Series: 'a' [list[i64]] + [ + [1, 2, 3] + [4, 5, 6] + ] + >>> s.explode() + shape: (6,) + Series: 'a' [i64] + [ + 1 + 2 + 3 + 4 + 5 + 6 + ] """ def equals( @@ -4212,6 +4256,29 @@ def rechunk(self, *, in_place: bool = False) -> Self: ---------- in_place In place or not. + + Examples + -------- + >>> s1 = pl.Series("a", [1, 2, 3]) + >>> s1.n_chunks() + 1 + >>> s2 = pl.Series("a", [4, 5, 6]) + >>> s = pl.concat([s1, s2], rechunk=False) + >>> s.n_chunks() + 2 + >>> s.rechunk(in_place=True) + shape: (6,) + Series: 'a' [i64] + [ + 1 + 2 + 3 + 4 + 5 + 6 + ] + >>> s.n_chunks() + 1 """ opt_s = self._s.rechunk(in_place) return self if in_place else self._from_pyseries(opt_s) @@ -6236,6 +6303,26 @@ def reinterpret(self, *, signed: bool = True) -> Series: ---------- signed If True, reinterpret as `pl.Int64`. Otherwise, reinterpret as `pl.UInt64`. + + Examples + -------- + >>> s = pl.Series("a", [-(2**60), -2, 3]) + >>> s + shape: (3,) + Series: 'a' [i64] + [ + -1152921504606846976 + -2 + 3 + ] + >>> s.reinterpret(signed=False) + shape: (3,) + Series: 'a' [u64] + [ + 17293822569102704640 + 18446744073709551614 + 3 + ] """ def interpolate(self, method: InterpolationMethod = "linear") -> Series: @@ -7204,7 +7291,21 @@ def set_sorted(self, *, descending: bool = False) -> Self: return self._from_pyseries(self._s.set_sorted_flag(descending)) def new_from_index(self, index: int, length: int) -> Self: - """Create a new Series filled with values from the given index.""" + """ + Create a new Series filled with values from the given index. + + Examples + -------- + >>> s = pl.Series("a", [1, 2, 3, 4, 5]) + >>> s.new_from_index(1, 3) + shape: (3,) + Series: 'a' [i64] + [ + 2 + 2 + 2 + ] + """ return self._from_pyseries(self._s.new_from_index(index, length)) def shrink_dtype(self) -> Series: @@ -7213,10 +7314,58 @@ def shrink_dtype(self) -> Series: Shrink to the dtype needed to fit the extrema of this [`Series`]. This can be used to reduce memory pressure. + + Examples + -------- + >>> s = pl.Series("a", [1, 2, 3, 4, 5, 6]) + >>> s + shape: (6,) + Series: 'a' [i64] + [ + 1 + 2 + 3 + 4 + 5 + 6 + ] + >>> s.shrink_dtype() + shape: (6,) + Series: 'a' [i8] + [ + 1 + 2 + 3 + 4 + 5 + 6 + ] """ def get_chunks(self) -> list[Series]: - """Get the chunks of this Series as a list of Series.""" + """ + Get the chunks of this Series as a list of Series. 
+ + Examples + -------- + >>> s1 = pl.Series("a", [1, 2, 3]) + >>> s2 = pl.Series("a", [4, 5, 6]) + >>> s = pl.concat([s1, s2], rechunk=False) + >>> s.get_chunks() + [shape: (3,) + Series: 'a' [i64] + [ + 1 + 2 + 3 + ], shape: (3,) + Series: 'a' [i64] + [ + 4 + 5 + 6 + ]] + """ return self._s.get_chunks() def implode(self) -> Self: From dbfc6b2e945028dfa027f0a5b688140fed1c4e10 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 13 May 2024 06:51:25 +0200 Subject: [PATCH 16/29] feat(python): Overhaul parametric test implementations and update Hypothesis to latest version (#16062) --- py-polars/docs/requirements-docs.txt | 2 +- py-polars/docs/source/reference/testing.rst | 49 +- py-polars/polars/_utils/constants.py | 26 + py-polars/polars/_utils/convert.py | 25 +- .../polars/testing/parametric/__init__.py | 43 +- .../polars/testing/parametric/primitives.py | 724 ------------------ .../polars/testing/parametric/profiles.py | 4 +- .../polars/testing/parametric/strategies.py | 494 ------------ .../testing/parametric/strategies/__init__.py | 22 + .../testing/parametric/strategies/_utils.py | 14 + .../testing/parametric/strategies/core.py | 451 +++++++++++ .../testing/parametric/strategies/data.py | 334 ++++++++ .../testing/parametric/strategies/dtype.py | 321 ++++++++ .../testing/parametric/strategies/legacy.py | 156 ++++ py-polars/requirements-dev.txt | 2 +- py-polars/tests/parametric/conftest.py | 2 +- py-polars/tests/parametric/test_dataframe.py | 10 +- .../tests/parametric/test_groupby_rolling.py | 19 +- py-polars/tests/parametric/test_lazyframe.py | 10 +- py-polars/tests/parametric/test_lit.py | 12 +- py-polars/tests/parametric/test_series.py | 14 +- py-polars/tests/parametric/test_testing.py | 153 ++-- .../parametric/time_series/test_ewm_by.py | 2 +- .../time_series/test_to_datetime.py | 38 +- py-polars/tests/unit/conftest.py | 2 +- py-polars/tests/unit/dataframe/test_df.py | 12 +- .../tests/unit/interchange/test_roundtrip.py | 15 +- .../tests/unit/interop/test_to_pandas.py | 6 +- py-polars/tests/unit/operations/test_cast.py | 6 +- py-polars/tests/unit/operations/test_clear.py | 4 +- py-polars/tests/unit/operations/test_ewm.py | 16 +- .../parametric/strategies/test_core.py | 24 + .../parametric/strategies/test_dtype.py | 54 ++ .../parametric/strategies/test_legacy.py | 20 + .../parametric/strategies/test_utils.py | 22 + 35 files changed, 1672 insertions(+), 1436 deletions(-) create mode 100644 py-polars/polars/_utils/constants.py delete mode 100644 py-polars/polars/testing/parametric/primitives.py delete mode 100644 py-polars/polars/testing/parametric/strategies.py create mode 100644 py-polars/polars/testing/parametric/strategies/__init__.py create mode 100644 py-polars/polars/testing/parametric/strategies/_utils.py create mode 100644 py-polars/polars/testing/parametric/strategies/core.py create mode 100644 py-polars/polars/testing/parametric/strategies/data.py create mode 100644 py-polars/polars/testing/parametric/strategies/dtype.py create mode 100644 py-polars/polars/testing/parametric/strategies/legacy.py create mode 100644 py-polars/tests/unit/testing/parametric/strategies/test_core.py create mode 100644 py-polars/tests/unit/testing/parametric/strategies/test_dtype.py create mode 100644 py-polars/tests/unit/testing/parametric/strategies/test_legacy.py create mode 100644 py-polars/tests/unit/testing/parametric/strategies/test_utils.py diff --git a/py-polars/docs/requirements-docs.txt b/py-polars/docs/requirements-docs.txt index 6b33ddc135b3..36f02988b0c7 100644 --- 
a/py-polars/docs/requirements-docs.txt +++ b/py-polars/docs/requirements-docs.txt @@ -2,7 +2,7 @@ numpy pandas pyarrow -hypothesis==6.97.4 +hypothesis==6.100.4 sphinx==7.2.4 diff --git a/py-polars/docs/source/reference/testing.rst b/py-polars/docs/source/reference/testing.rst index 78ce4c96a0bd..6cdc5ddbba78 100644 --- a/py-polars/docs/source/reference/testing.rst +++ b/py-polars/docs/source/reference/testing.rst @@ -40,17 +40,18 @@ and library integrations: * `Quick start guide `_ -Polars primitives +Polars strategies ~~~~~~~~~~~~~~~~~ Polars provides the following `hypothesis `_ -testing primitives and strategy generators/helpers to make it easy to generate -suitable test DataFrames and Series. +testing strategies: .. autosummary:: :toctree: api/ testing.parametric.dataframes + testing.parametric.dtypes + testing.parametric.lists testing.parametric.series @@ -112,20 +113,21 @@ of any generated value being ``null`` (this is distinct from ``NaN``). .. code-block:: python + import polars as pl from polars.testing.parametric import dataframes from polars import NUMERIC_DTYPES - from hypothesis import given + from hypothesis import given @given( dataframes( cols=5, - null_probabililty=0.1, + null_probability=0.1, allowed_dtypes=NUMERIC_DTYPES, ) ) - def test_numeric(df): - assert all(df[col].is_numeric() for col in df.columns) + def test_numeric(df: pl.DataFrame): + assert all(df[col].dtype.is_numeric() for col in df.columns) # Example frame: # ┌──────┬────────┬───────┬────────────┬────────────┐ @@ -145,27 +147,27 @@ conform to the given strategies: .. code-block:: python + import polars as pl from polars.testing.parametric import column, dataframes - from hypothesis.strategies import floats, sampled_from, text - from hypothesis import given + import hypothesis.strategies as st + from hypothesis import given from string import ascii_letters, digits id_chars = ascii_letters + digits - @given( dataframes( cols=[ - column("id", strategy=text(min_size=4, max_size=4, alphabet=id_chars)), - column("ccy", strategy=sampled_from(["GBP", "EUR", "JPY", "USD"])), - column("price", strategy=floats(min_value=0.0, max_value=1000.0)), + column("id", strategy=st.text(min_size=4, max_size=4, alphabet=id_chars)), + column("ccy", strategy=st.sampled_from(["GBP", "EUR", "JPY", "USD"])), + column("price", strategy=st.floats(min_value=0.0, max_value=1000.0)), ], min_size=5, lazy=True, ) ) - def test_price_calculations(lf): + def test_price_calculations(lf: pl.LazyFrame): ... print(lf.collect()) @@ -189,17 +191,18 @@ is always less than or equal to the second value: .. code-block:: python - from polars.testing.parametric import create_list_strategy, dataframes, column - from hypothesis.strategies import composite - from hypothesis import given + import polars as pl + from polars.testing.parametric import column, dataframes, lists + import hypothesis.strategies as st + from hypothesis import given - @composite - def uint8_pairs(draw, uints=create_list_strategy(pl.UInt8, size=2)): + @st.composite + def uint8_pairs(draw: st.DrawFn): + uints = lists(pl.UInt8, size=2) pairs = list(zip(draw(uints), draw(uints))) return [sorted(ints) for ints in pairs] - @given( dataframes( cols=[ @@ -207,11 +210,11 @@ is always less than or equal to the second value: column("coly", strategy=uint8_pairs()), column("colz", strategy=uint8_pairs()), ], - size=3, + min_size=3, + max_size=3, ) ) - def test_miscellaneous(df): - ... + def test_miscellaneous(df: pl.DataFrame): ... 
# Example frame: # ┌─────────────────────────┬─────────────────────────┬──────────────────────────┐ diff --git a/py-polars/polars/_utils/constants.py b/py-polars/polars/_utils/constants.py new file mode 100644 index 000000000000..5b13a0157d16 --- /dev/null +++ b/py-polars/polars/_utils/constants.py @@ -0,0 +1,26 @@ +from datetime import date, datetime, timezone + +# Integer ranges +I8_MIN = -(2**7) +I16_MIN = -(2**15) +I32_MIN = -(2**31) +I64_MIN = -(2**63) +I8_MAX = 2**7 - 1 +I16_MAX = 2**15 - 1 +I32_MAX = 2**31 - 1 +I64_MAX = 2**63 - 1 +U8_MAX = 2**8 - 1 +U16_MAX = 2**16 - 1 +U32_MAX = 2**32 - 1 +U64_MAX = 2**64 - 1 + +# Temporal +SECONDS_PER_DAY = 86_400 +SECONDS_PER_HOUR = 3_600 +NS_PER_SECOND = 1_000_000_000 +US_PER_SECOND = 1_000_000 +MS_PER_SECOND = 1_000 + +EPOCH_DATE = date(1970, 1, 1) +EPOCH = datetime(1970, 1, 1).replace(tzinfo=None) +EPOCH_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc) diff --git a/py-polars/polars/_utils/convert.py b/py-polars/polars/_utils/convert.py index 92ae98feb67a..894f870ca91b 100644 --- a/py-polars/polars/_utils/convert.py +++ b/py-polars/polars/_utils/convert.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import date, datetime, time, timedelta, timezone +from datetime import datetime, time, timedelta, timezone from decimal import Context from functools import lru_cache from typing import ( @@ -13,26 +13,25 @@ overload, ) +from polars._utils.constants import ( + EPOCH, + EPOCH_DATE, + EPOCH_UTC, + MS_PER_SECOND, + NS_PER_SECOND, + SECONDS_PER_DAY, + SECONDS_PER_HOUR, + US_PER_SECOND, +) from polars.dependencies import _ZONEINFO_AVAILABLE, zoneinfo if TYPE_CHECKING: - from datetime import tzinfo + from datetime import date, tzinfo from decimal import Decimal from polars.type_aliases import TimeUnit -SECONDS_PER_DAY = 86_400 -SECONDS_PER_HOUR = 3_600 -NS_PER_SECOND = 1_000_000_000 -US_PER_SECOND = 1_000_000 -MS_PER_SECOND = 1_000 - -EPOCH_DATE = date(1970, 1, 1) -EPOCH = datetime(1970, 1, 1).replace(tzinfo=None) -EPOCH_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc) - - @overload def parse_as_duration_string(td: None) -> None: ... 
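The integer bounds and epoch/second constants that `convert.py` previously defined inline now live in the shared `polars._utils.constants` module added above. A small sketch of the kind of arithmetic these constants back; the names come from the new module, while the timestamp value is arbitrary:

    from datetime import timedelta
    from polars._utils.constants import EPOCH, US_PER_SECOND

    # An arbitrary microsecond timestamp converted to a naive datetime
    # relative to the shared epoch constant.
    ts_us = 1_715_000_000 * US_PER_SECOND
    dt = EPOCH + timedelta(microseconds=ts_us)
    print(dt)  # 2024-05-06 12:53:20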
diff --git a/py-polars/polars/testing/parametric/__init__.py b/py-polars/polars/testing/parametric/__init__.py index 862b0b0d923a..1ce6c77af71c 100644 --- a/py-polars/polars/testing/parametric/__init__.py +++ b/py-polars/polars/testing/parametric/__init__.py @@ -1,34 +1,33 @@ -from typing import Any - from polars.dependencies import _HYPOTHESIS_AVAILABLE -if _HYPOTHESIS_AVAILABLE: - from polars.testing.parametric.primitives import column, columns, dataframes, series - from polars.testing.parametric.profiles import load_profile, set_profile - from polars.testing.parametric.strategies import ( - all_strategies, - create_array_strategy, - create_list_strategy, - nested_strategies, - scalar_strategies, +if not _HYPOTHESIS_AVAILABLE: + msg = ( + "polars.testing.parametric requires the 'hypothesis' module\n" + "Please install it using the command: pip install hypothesis" ) -else: - - def __getattr__(*args: Any, **kwargs: Any) -> Any: - msg = f"polars.testing.parametric.{args[0]} requires the 'hypothesis' module" - raise ModuleNotFoundError(msg) from None + raise ModuleNotFoundError(msg) +from polars.testing.parametric.profiles import load_profile, set_profile +from polars.testing.parametric.strategies import ( + column, + columns, + create_list_strategy, + dataframes, + dtypes, + lists, + series, +) __all__ = [ - "all_strategies", + # strategies + "dataframes", + "series", "column", "columns", - "create_array_strategy", + "dtypes", + "lists", "create_list_strategy", - "dataframes", + # profiles "load_profile", - "nested_strategies", - "scalar_strategies", - "series", "set_profile", ] diff --git a/py-polars/polars/testing/parametric/primitives.py b/py-polars/polars/testing/parametric/primitives.py deleted file mode 100644 index fc41af3e8b7a..000000000000 --- a/py-polars/polars/testing/parametric/primitives.py +++ /dev/null @@ -1,724 +0,0 @@ -from __future__ import annotations - -import random -import warnings -from dataclasses import dataclass -from math import isfinite -from textwrap import dedent -from typing import TYPE_CHECKING, Any, Collection, Sequence, overload - -from hypothesis.errors import InvalidArgument, NonInteractiveExampleWarning -from hypothesis.strategies import booleans, composite, lists, sampled_from -from hypothesis.strategies._internal.utils import defines_strategy - -from polars.dataframe import DataFrame -from polars.datatypes import ( - DTYPE_TEMPORAL_UNITS, - Array, - Categorical, - DataType, - DataTypeClass, - Datetime, - Duration, - List, - is_polars_dtype, - py_type_to_dtype, -) -from polars.series import Series -from polars.string_cache import StringCache -from polars.testing.parametric.strategies import ( - _flexhash, - between, - create_array_strategy, - create_list_strategy, - dtype_strategies, - scalar_strategies, -) - -if TYPE_CHECKING: - from typing import Literal - - from hypothesis.strategies import DrawFn, SearchStrategy - - from polars import LazyFrame - from polars.type_aliases import OneOrMoreDataTypes, PolarsDataType - -_time_units = list(DTYPE_TEMPORAL_UNITS) - - -def empty_list(value: Any, *, nested: bool) -> bool: - """Check if value is an empty list, or a list that contains only empty lists.""" - if isinstance(value, list): - return ( - True - if value and not nested - else all(empty_list(v, nested=True) for v in value) - ) - return False - - -# ==================================================================== -# Polars 'hypothesis' primitives for Series, DataFrame, and LazyFrame -# See: https://hypothesis.readthedocs.io/ -# 
==================================================================== -MAX_DATA_SIZE = 10 # max generated frame/series length -MAX_COLS = 8 # max number of generated cols - -# note: there is a rare 'list' dtype failure that needs to be tracked -# down before re-enabling selection from "all_strategies" ... -strategy_dtypes = list( - {dtype.base_type() for dtype in scalar_strategies} # all_strategies} -) - - -@dataclass -class column: - """ - Define a column for use with the @dataframes strategy. - - Parameters - ---------- - name : str - string column name. - dtype : PolarsDataType - a recognised polars dtype. - strategy : strategy, optional - supports overriding the default strategy for the given dtype. - null_probability : float, optional - percentage chance (expressed between 0.0 => 1.0) that a generated value is - None. this is applied independently of any None values generated by the - underlying strategy. - unique : bool, optional - flag indicating that all values generated for the column should be unique. - - Examples - -------- - >>> from hypothesis.strategies import sampled_from - >>> from polars.testing.parametric import column - >>> - >>> column(name="unique_small_ints", dtype=pl.UInt8, unique=True) - column(name='unique_small_ints', dtype=UInt8, strategy=None, null_probability=None, unique=True) - >>> column(name="ccy", strategy=sampled_from(["GBP", "EUR", "JPY"])) - column(name='ccy', dtype=String, strategy=sampled_from(['GBP', 'EUR', 'JPY']), null_probability=None, unique=False) - """ # noqa: W505 - - name: str - dtype: PolarsDataType | None = None - strategy: SearchStrategy[Any] | None = None - null_probability: float | None = None - unique: bool = False - - def __post_init__(self) -> None: - if (self.null_probability is not None) and ( - self.null_probability < 0 or self.null_probability > 1 - ): - msg = f"`null_probability` should be between 0.0 and 1.0, or None; found {self.null_probability!r}" - raise InvalidArgument(msg) - - if self.dtype is None: - tp = getattr(self.strategy, "_dtype", None) - if is_polars_dtype(tp): - self.dtype = tp - - if self.dtype is None and self.strategy is None: - self.dtype = random.choice(strategy_dtypes) - - elif self.dtype in (Array, List): - if self.strategy is not None: - self.dtype = getattr(self.strategy, "_dtype", self.dtype) - else: - if self.dtype == Array: - self.strategy = create_array_strategy( - getattr(self.dtype, "inner", None), - getattr(self.dtype, "width", None), - ) - else: - self.strategy = create_list_strategy( - getattr(self.dtype, "inner", None) - ) - self.dtype = self.strategy._dtype # type: ignore[attr-defined] - - # elif self.dtype == Struct: - # ... - - elif self.dtype not in scalar_strategies: - if self.dtype is not None: - msg = f"no strategy (currently) available for {self.dtype!r} type" - raise InvalidArgument(msg) - else: - # given a custom strategy, but no explicit dtype. infer one - # from the first non-None value that the strategy produces. - with warnings.catch_warnings(): - # note: usually you should not call "example()" outside of an - # interactive shell, hence the warning. 
however, here it is - # reasonable to do so, so we catch and ignore it - warnings.simplefilter("ignore", NonInteractiveExampleWarning) - sample_value_iter = ( - self.strategy.example() # type: ignore[union-attr] - for _ in range(100) - ) - try: - sample_value_type = type( - next( - e - for e in sample_value_iter - if e is not None and not empty_list(e, nested=True) - ) - ) - except StopIteration: - msg = "unable to determine dtype for strategy" - raise InvalidArgument(msg) from None - - if sample_value_type is not None: - value_dtype = py_type_to_dtype(sample_value_type) - if value_dtype is not Array and value_dtype is not List: - self.dtype = value_dtype - - -def columns( - cols: int | Sequence[str] | None = None, - *, - dtype: OneOrMoreDataTypes | None = None, - min_cols: int | None = 0, - max_cols: int | None = MAX_COLS, - unique: bool = False, -) -> list[column]: - """ - Define multiple columns for use with the @dataframes strategy. - - Generate a fixed sequence of `column` objects suitable for passing to the - @dataframes strategy, or using standalone (note that this function is not itself - a strategy). - - Notes - ----- - Additional control is available by creating a sequence of columns explicitly, - using the `column` class (an especially useful option is to override the default - data-generating strategy for a given col/dtype). - - Parameters - ---------- - cols : {int, [str]}, optional - integer number of cols to create, or explicit list of column names. if - omitted a random number of columns (between mincol and max_cols) are - created. - dtype : PolarsDataType, optional - a single dtype for all cols, or list of dtypes (the same length as `cols`). - if omitted, each generated column is assigned a random dtype. - min_cols : int, optional - if not passing an exact size, can set a minimum here (defaults to 0). - max_cols : int, optional - if not passing an exact size, can set a maximum value here (defaults to - MAX_COLS). - unique : bool, optional - indicate if the values generated for these columns should be unique - (per-column). - - Examples - -------- - >>> from polars.testing.parametric import columns, dataframes - >>> from hypothesis import given - >>> - >>> @given(dataframes(columns(["x", "y", "z"], unique=True))) - ... def test_unique_xyz(df: pl.DataFrame) -> None: - ... assert_something(df) - - Note, as 'columns' creates a list of native polars column definitions it can - also be used independently of parametric/hypothesis tests: - - >>> from string import punctuation - >>> - >>> def test_special_char_colname_init() -> None: - ... df = pl.DataFrame(schema=[(c.name, c.dtype) for c in columns(punctuation)]) - ... assert len(cols) == len(df.columns) - ... 
assert 0 == len(df.rows()) - """ - # create/assign named columns - if cols is None: - cols = random.randint( - a=min_cols or 0, - b=max_cols or MAX_COLS, - ) - if isinstance(cols, int): - names: list[str] = [f"col{n}" for n in range(cols)] - else: - names = list(cols) - - if isinstance(dtype, Sequence): - if len(dtype) != len(names): - msg = f"given {len(dtype)} dtypes for {len(names)} names" - raise InvalidArgument(msg) - dtypes = list(dtype) - elif dtype is None: - dtypes = [random.choice(strategy_dtypes) for _ in range(len(names))] - elif is_polars_dtype(dtype): - dtypes = [dtype] * len(names) - else: - msg = f"{dtype!r} is not a valid polars datatype" - raise InvalidArgument(msg) - - # init list of named/typed columns - return [column(name=nm, dtype=tp, unique=unique) for nm, tp in zip(names, dtypes)] - - -@defines_strategy() -def series( - *, - name: str | SearchStrategy[str] | None = None, - dtype: PolarsDataType | None = None, - size: int | None = None, - min_size: int | None = 0, - max_size: int | None = MAX_DATA_SIZE, - strategy: SearchStrategy[object] | None = None, - null_probability: float = 0.0, - allow_infinities: bool = True, - unique: bool = False, - chunked: bool | None = None, - allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, - excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, -) -> SearchStrategy[Series]: - """ - Hypothesis strategy for producing polars Series. - - Parameters - ---------- - name : {str, strategy}, optional - literal string or a strategy for strings (or None), passed to the Series - constructor name-param. - dtype : PolarsDataType, optional - a valid polars DataType for the resulting series. - size : int, optional - if set, creates a Series of exactly this size (ignoring min_size/max_size - params). - min_size : int, optional - if not passing an exact size, can set a minimum here (defaults to 0). - no-op if `size` is set. - max_size : int, optional - if not passing an exact size, can set a maximum value here (defaults to - MAX_DATA_SIZE). no-op if `size` is set. - strategy : strategy, optional - supports overriding the default strategy for the given dtype. - null_probability : float, optional - percentage chance (expressed between 0.0 => 1.0) that a generated value is - None. this is applied independently of any None values generated by the - underlying strategy. - allow_infinities : bool, optional - optionally disallow generation of +/-inf values for floating-point dtypes. - unique : bool, optional - indicate whether Series values should all be distinct. - chunked : bool, optional - ensure that Series with more than one element have `n_chunks` > 1. - if omitted, chunking is applied at random. - allowed_dtypes : {list,set}, optional - when automatically generating Series data, allow only these dtypes. - excluded_dtypes : {list,set}, optional - when automatically generating Series data, exclude these dtypes. - - Notes - ----- - In actual usage this is deployed as a unit test decorator, providing a strategy - that generates multiple Series with the given dtype/size characteristics for the - unit test. While developing a strategy/test, it can also be useful to call - `.example()` directly on a given strategy to see concrete instances of the - generated data. - - Examples - -------- - >>> from polars.testing.parametric import series - >>> from hypothesis import given - - In normal usage, as a simple unit test: - - >>> @given(s=series(null_probability=0.1)) - ... 
def test_repr_is_valid_string(s: pl.Series) -> None: - ... assert isinstance(repr(s), str) - - Experimenting locally with a custom List dtype strategy: - - >>> from polars.testing.parametric import create_list_strategy - >>> s = series( - ... strategy=create_list_strategy( - ... inner_dtype=pl.String, - ... select_from=["xx", "yy", "zz"], - ... ), - ... min_size=2, - ... max_size=4, - ... ) - >>> s.example() # doctest: +SKIP - shape: (4,) - Series: '' [list[str]] - [ - [] - ["yy", "yy", "zz"] - ["zz", "yy", "zz"] - ["xx"] - ] - """ - if isinstance(allowed_dtypes, (DataType, DataTypeClass)): - allowed_dtypes = [allowed_dtypes] - if isinstance(excluded_dtypes, (DataType, DataTypeClass)): - excluded_dtypes = [excluded_dtypes] - - selectable_dtypes = [ - dtype - for dtype in (allowed_dtypes or strategy_dtypes) - if dtype not in (excluded_dtypes or ()) - ] - if null_probability and not (0 <= null_probability <= 1): - msg = f"`null_probability` should be between 0.0 and 1.0, or None; found {null_probability}" - raise InvalidArgument(msg) - null_probability = float(null_probability or 0.0) - - @composite - def draw_series(draw: DrawFn) -> Series: - with StringCache(): - # create/assign series dtype and retrieve matching strategy - series_dtype: PolarsDataType = ( - draw(sampled_from(selectable_dtypes)) # type: ignore[assignment] - if dtype is None and strategy is None - else dtype - ) - if strategy is None: - if series_dtype is Datetime or series_dtype is Duration: - series_dtype = series_dtype(random.choice(_time_units)) # type: ignore[operator] - dtype_strategy = draw(dtype_strategies(series_dtype)) - else: - dtype_strategy = strategy - - if not allow_infinities and series_dtype.is_float(): - dtype_strategy = dtype_strategy.filter( - lambda x: not isinstance(x, float) or isfinite(x) - ) - - # create/assign series size - series_size = ( - between( - draw, int, min_=(min_size or 0), max_=(max_size or MAX_DATA_SIZE) - ) - if size is None - else size - ) - # assign series name - series_name = name if isinstance(name, str) or name is None else draw(name) - - # create series using dtype-specific strategy to generate values - if series_size == 0: - series_values = [] - elif null_probability == 1: - series_values = [None] * series_size - else: - series_values = draw( - lists( - dtype_strategy, - min_size=series_size, - max_size=series_size, - unique_by=(_flexhash if unique else None), - ) - ) - - # apply null values (custom frequency) - if null_probability and null_probability != 1: - for idx in range(series_size): - if random.random() < null_probability: - series_values[idx] = None - - # init series with strategy-generated data - s = Series( - name=series_name, - dtype=series_dtype, - values=series_values, - ) - if dtype == Categorical: - s = s.cast(Categorical) - if series_size and (chunked or (chunked is None and draw(booleans()))): - split_at = series_size // 2 - s = s[:split_at].append(s[split_at:]) - return s - - return draw_series() - - -_failed_frame_init_msgs_: set[str] = set() - - -@overload -def dataframes( - cols: int | column | Sequence[column] | None = None, - *, - lazy: Literal[False] = ..., - min_cols: int | None = 0, - max_cols: int | None = MAX_COLS, - size: int | None = None, - min_size: int | None = 0, - max_size: int | None = MAX_DATA_SIZE, - chunked: bool | None = None, - include_cols: Sequence[column] | column | None = None, - null_probability: float | dict[str, float] = 0.0, - allow_infinities: bool = True, - allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = 
None, - excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, -) -> SearchStrategy[DataFrame]: ... - - -@overload -def dataframes( - cols: int | column | Sequence[column] | None = None, - *, - lazy: Literal[True], - min_cols: int | None = 0, - max_cols: int | None = MAX_COLS, - size: int | None = None, - min_size: int | None = 0, - max_size: int | None = MAX_DATA_SIZE, - chunked: bool | None = None, - include_cols: Sequence[column] | column | None = None, - null_probability: float | dict[str, float] = 0.0, - allow_infinities: bool = True, - allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, - excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, -) -> SearchStrategy[LazyFrame]: ... - - -@defines_strategy() -def dataframes( - cols: int | column | Sequence[column] | None = None, - *, - lazy: bool = False, - min_cols: int | None = 0, - max_cols: int | None = MAX_COLS, - size: int | None = None, - min_size: int | None = 0, - max_size: int | None = MAX_DATA_SIZE, - chunked: bool | None = None, - include_cols: Sequence[column] | column | None = None, - null_probability: float | dict[str, float] = 0.0, - allow_infinities: bool = True, - allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, - excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, -) -> SearchStrategy[DataFrame | LazyFrame]: - """ - Hypothesis strategy for producing polars DataFrames or LazyFrames. - - Parameters - ---------- - cols : {int, columns}, optional - integer number of columns to create, or a sequence of `column` objects - that describe the desired DataFrame column data. - lazy : bool, optional - produce a LazyFrame instead of a DataFrame. - min_cols : int, optional - if not passing an exact size, can set a minimum here (defaults to 0). - max_cols : int, optional - if not passing an exact size, can set a maximum value here (defaults to - MAX_COLS). - size : int, optional - if set, will create a DataFrame of exactly this size (and ignore - the min_size/max_size len params). - min_size : int, optional - if not passing an exact size, set the minimum number of rows in the - DataFrame. - max_size : int, optional - if not passing an exact size, set the maximum number of rows in the - DataFrame. - chunked : bool, optional - ensure that DataFrames with more than row have `n_chunks` > 1. if - omitted, chunking will be randomised at the level of individual Series. - include_cols : [column], optional - a list of `column` objects to include in the generated DataFrame. note that - explicitly provided columns are appended onto the list of existing columns - (if any present). - null_probability : {float, dict[str,float]}, optional - percentage chance (expressed between 0.0 => 1.0) that a generated value is - None. this is applied independently of any None values generated by the - underlying strategy, and can be applied either on a per-column basis (if - given as a `{col:pct}` dict), or globally. if null_probability is defined - on a column, it takes precedence over the global value. - allow_infinities : bool, optional - optionally disallow generation of +/-inf values for floating-point dtypes. - allowed_dtypes : {list,set}, optional - when automatically generating data, allow only these dtypes. - excluded_dtypes : {list,set}, optional - when automatically generating data, exclude these dtypes. 
- - Notes - ----- - In actual usage this is deployed as a unit test decorator, providing a strategy - that generates DataFrames or LazyFrames with the given characteristics for - the unit test. While developing a strategy/test, it can also be useful to - call `.example()` directly on a given strategy to see concrete instances of - the generated data. - - Examples - -------- - Use `column` or `columns` to specify the schema of the types of DataFrame to - generate. Note: in actual use the strategy is applied as a test decorator, not - used standalone. - - >>> from polars.testing.parametric import column, columns, dataframes - >>> from hypothesis import given - - Generate arbitrary DataFrames (as part of a unit test): - - >>> @given(df=dataframes()) - ... def test_repr(df: pl.DataFrame) -> None: - ... assert isinstance(repr(df), str) - - Generate LazyFrames with at least 1 column, random dtypes, and specific size: - - >>> dfs = dataframes(min_cols=1, max_size=5, lazy=True) - >>> dfs.example() # doctest: +SKIP - - - Generate DataFrames with known colnames, random dtypes (per test, not per-frame): - - >>> dfs = dataframes(columns(["x", "y", "z"])) - >>> dfs.example() # doctest: +SKIP - shape: (3, 3) - ┌────────────┬───────┬────────────────────────────┐ - │ x ┆ y ┆ z │ - │ --- ┆ --- ┆ --- │ - │ date ┆ u16 ┆ datetime[μs] │ - ╞════════════╪═══════╪════════════════════════════╡ - │ 0565-08-12 ┆ 34715 ┆ 5844-09-20 00:33:31.076854 │ - │ 3382-10-17 ┆ 48662 ┆ 7540-01-29 11:20:14.836271 │ - │ 4063-06-17 ┆ 39092 ┆ 1889-05-05 13:25:41.874455 │ - └────────────┴───────┴────────────────────────────┘ - - Generate frames with explicitly named/typed columns and a fixed size: - - >>> dfs = dataframes( - ... [ - ... column("x", dtype=pl.Int32), - ... column("y", dtype=pl.Float64), - ... ], - ... size=2, - ... 
) - >>> dfs.example() # doctest: +SKIP - shape: (2, 2) - ┌───────────┬────────────┐ - │ x ┆ y │ - │ --- ┆ --- │ - │ i32 ┆ f64 │ - ╞═══════════╪════════════╡ - │ -15836 ┆ 1.1755e-38 │ - │ 575050513 ┆ NaN │ - └───────────┴────────────┘ - """ - _failed_frame_init_msgs_.clear() - - if isinstance(min_size, int) and min_cols in (0, None): - min_cols = 1 - if isinstance(allowed_dtypes, (DataType, DataTypeClass)): - allowed_dtypes = [allowed_dtypes] - if isinstance(excluded_dtypes, (DataType, DataTypeClass)): - excluded_dtypes = [excluded_dtypes] - if isinstance(include_cols, column): - include_cols = [include_cols] - - selectable_dtypes = [ - dtype - for dtype in (allowed_dtypes or strategy_dtypes) - if dtype in strategy_dtypes and dtype not in (excluded_dtypes or ()) - ] - - @composite - def draw_frames(draw: DrawFn) -> DataFrame | LazyFrame: - """Reproducibly generate random DataFrames according to the given spec.""" - with StringCache(): - # if not given, create 'n' cols with random dtypes - if cols is None or isinstance(cols, int): - n = cols or between( - draw, int, min_=(min_cols or 0), max_=(max_cols or MAX_COLS) - ) - dtypes_ = [draw(sampled_from(selectable_dtypes)) for _ in range(n)] - coldefs = columns(cols=n, dtype=dtypes_) - elif isinstance(cols, column): - coldefs = [cols] - else: - coldefs = list(cols) - - # append any explicitly provided cols - coldefs.extend(include_cols or ()) - - # assign dataframe/series size - series_size = ( - between( - draw, int, min_=(min_size or 0), max_=(max_size or MAX_DATA_SIZE) - ) - if size is None - else size - ) - - # assign names, null probability - for idx, c in enumerate(coldefs): - if c.name is None: - c.name = f"col{idx}" - if c.null_probability is None: - if isinstance(null_probability, dict): - c.null_probability = null_probability.get(c.name, 0.0) - else: - c.null_probability = null_probability - - # init dataframe from generated series data; series data is - # given as a python-native sequence. - data = { - c.name: draw( - series( - name=c.name, - dtype=c.dtype, - size=series_size, - null_probability=(c.null_probability or 0.0), - allow_infinities=allow_infinities, - strategy=c.strategy, - unique=c.unique, - chunked=(chunked is None and draw(booleans())), - ) - ) - for c in coldefs - } - - # note: randomly change between column-wise and row-wise frame init - orient = "col" - if draw(booleans()) and not any(c.dtype in (Array, List) for c in coldefs): - data = list(zip(*data.values())) # type: ignore[assignment] - orient = "row" - - schema = [(c.name, c.dtype) for c in coldefs] - try: - df = DataFrame(data=data, schema=schema, orient=orient) # type: ignore[arg-type] - - # optionally generate chunked frames - if series_size > 1 and chunked is True: - split_at = series_size // 2 - df = df[:split_at].vstack(df[split_at:]) - - _failed_frame_init_msgs_.clear() - return df.lazy() if lazy else df - - except Exception: - # print code that will allow any init failure to be reproduced - if isinstance(data, dict): - frame_cols = ", ".join( - f"{col!r}: {s.to_init_repr()}" for col, s in data.items() - ) - frame_data = f"{{{frame_cols}}}" - else: - frame_data = repr(data) - - failed_frame_init = dedent( - f""" - # failed frame init: reproduce with... 
- pl.DataFrame( - data={frame_data}, - schema={repr(schema).replace("', ", "', pl.")}, - orient={orient!r}, - ) - """.replace("datetime.", "") - ) - # note: this avoids printing the repro twice - if failed_frame_init not in _failed_frame_init_msgs_: - _failed_frame_init_msgs_.add(failed_frame_init) - print(failed_frame_init) - raise - - return draw_frames() diff --git a/py-polars/polars/testing/parametric/profiles.py b/py-polars/polars/testing/parametric/profiles.py index 76682af6d7f2..9c31f7b5b9c5 100644 --- a/py-polars/polars/testing/parametric/profiles.py +++ b/py-polars/polars/testing/parametric/profiles.py @@ -31,7 +31,7 @@ def load_profile( Examples -------- >>> # load a custom profile that will run with 1500 iterations - >>> from polars.testing.parametric.profiles import load_profile + >>> from polars.testing.parametric import load_profile >>> load_profile(1500) """ common_settings = {"print_blob": True, "deadline": None} @@ -84,7 +84,7 @@ def set_profile(profile: ParametricProfileNames | int) -> None: Examples -------- >>> # prefer the 'balanced' profile for running parametric tests - >>> from polars.testing.parametric.profiles import set_profile + >>> from polars.testing.parametric import set_profile >>> set_profile("balanced") """ profile_name = str(profile).split(".")[-1] diff --git a/py-polars/polars/testing/parametric/strategies.py b/py-polars/polars/testing/parametric/strategies.py deleted file mode 100644 index 5d481811dedd..000000000000 --- a/py-polars/polars/testing/parametric/strategies.py +++ /dev/null @@ -1,494 +0,0 @@ -from __future__ import annotations - -import decimal -from datetime import datetime, timedelta -from itertools import chain -from random import choice, randint, shuffle -from string import ascii_uppercase -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterator, - Mapping, - MutableMapping, - Sequence, -) - -import hypothesis.strategies as st -from hypothesis.strategies import ( - SearchStrategy, - binary, - booleans, - characters, - composite, - dates, - datetimes, - floats, - from_type, - integers, - lists, - sampled_from, - sets, - text, - timedeltas, - times, -) - -from polars.datatypes import ( - Array, - Binary, - Boolean, - Categorical, - Date, - Datetime, - Decimal, - Duration, - Float32, - Float64, - Int8, - Int16, - Int32, - Int64, - List, - String, - Time, - UInt8, - UInt16, - UInt32, - UInt64, -) -from polars.type_aliases import PolarsDataType - -if TYPE_CHECKING: - import sys - - from hypothesis.strategies import DrawFn - - if sys.version_info >= (3, 11): - from typing import Self - else: - from typing_extensions import Self - - -@composite -def dtype_strategies(draw: DrawFn, dtype: PolarsDataType) -> SearchStrategy[Any]: - """Returns a strategy which generates valid values for the given data type.""" - if (strategy := all_strategies.get(dtype)) is not None: - return strategy - elif (strategy_base := all_strategies.get(dtype.base_type())) is not None: - return strategy_base - - if dtype == Decimal: - return draw( - decimal_strategies( - precision=getattr(dtype, "precision", None), - scale=getattr(dtype, "scale", None), - ) - ) - else: - msg = f"unsupported data type: {dtype}" - raise TypeError(msg) - - -def between(draw: DrawFn, type_: type, min_: Any, max_: Any) -> Any: - """Draw a value in a given range from a type-inferred strategy.""" - strategy_init = from_type(type_).function # type: ignore[attr-defined] - return draw(strategy_init(min_, max_)) - - -# scalar dtype strategies are largely straightforward, mapping directly -# 
onto the associated hypothesis strategy, with dtype-defined limits -strategy_bool = booleans() -strategy_f32 = floats(width=32) -strategy_f64 = floats(width=64) -strategy_i8 = integers(min_value=-(2**7), max_value=(2**7) - 1) -strategy_i16 = integers(min_value=-(2**15), max_value=(2**15) - 1) -strategy_i32 = integers(min_value=-(2**31), max_value=(2**31) - 1) -strategy_i64 = integers(min_value=-(2**63), max_value=(2**63) - 1) -strategy_u8 = integers(min_value=0, max_value=(2**8) - 1) -strategy_u16 = integers(min_value=0, max_value=(2**16) - 1) -strategy_u32 = integers(min_value=0, max_value=(2**32) - 1) -strategy_u64 = integers(min_value=0, max_value=(2**64) - 1) - -strategy_categorical = text(max_size=2, alphabet=ascii_uppercase) -strategy_string = text( - alphabet=characters(max_codepoint=1000, exclude_categories=["Cs", "Cc"]), - max_size=8, -) -strategy_binary = binary() -strategy_datetime_ns = datetimes( - min_value=datetime(1677, 9, 22, 0, 12, 43, 145225), - max_value=datetime(2262, 4, 11, 23, 47, 16, 854775), -) -strategy_datetime_us = strategy_datetime_ms = datetimes( - min_value=datetime(1, 1, 1), - max_value=datetime(9999, 12, 31, 23, 59, 59, 999000), -) -strategy_time = times() -strategy_date = dates() -strategy_duration = timedeltas( - min_value=timedelta(microseconds=-(2**46)), - max_value=timedelta(microseconds=(2**46) - 1), -) -strategy_closed = sampled_from(["left", "right", "both", "none"]) -strategy_time_unit = sampled_from(["ns", "us", "ms"]) - - -@composite -def decimal_strategies( - draw: DrawFn, precision: int | None = None, scale: int | None = None -) -> SearchStrategy[decimal.Decimal]: - """Returns a strategy which generates instances of Python `Decimal`.""" - if precision is None: - precision = draw(integers(min_value=scale or 1, max_value=38)) - if scale is None: - scale = draw(integers(min_value=0, max_value=precision)) - - c = decimal.Context(prec=precision) - exclusive_limit = c.create_decimal(f"1E+{precision - scale}") - inclusive_limit = c.next_minus(exclusive_limit) - - return st.decimals( - allow_nan=False, - allow_infinity=False, - min_value=-inclusive_limit, - max_value=inclusive_limit, - places=scale, - ) - - -@composite -def strategy_datetime_format(draw: DrawFn) -> str: - """Draw a random datetime format string.""" - fmt = draw( - sets( - sampled_from( - [ - "%m", - "%b", - "%B", - "%d", - "%j", - "%a", - "%A", - "%w", - "%H", - "%I", - "%p", - "%M", - "%S", - "%U", - "%W", - "%%", - ] - ), - ) - ) - - # Make sure year is always present - fmt.add("%Y") - - return " ".join(fmt) - - -class StrategyLookup(MutableMapping[PolarsDataType, SearchStrategy[Any]]): - """ - Mapping from polars DataTypes to hypothesis Strategies. - - We customise this so that retrieval of nested strategies respects the inner dtype - of List/Struct types; nested strategies are stored as callables that create the - given strategy on demand (there are infinitely many possible nested dtypes). - """ - - _items: dict[ - PolarsDataType, SearchStrategy[Any] | Callable[..., SearchStrategy[Any]] - ] - - def __init__( - self, - items: ( - Mapping[ - PolarsDataType, SearchStrategy[Any] | Callable[..., SearchStrategy[Any]] - ] - | None - ) = None, - ): - """ - Initialise lookup with the given dtype/strategy items. - - Parameters - ---------- - items - A dtype to strategy dict/mapping. 
- """ - self._items = {} - if items is not None: - self._items.update(items) - - def __setitem__( - self, - item: PolarsDataType, - value: SearchStrategy[Any] | Callable[..., SearchStrategy[Any]], - ) -> None: - """Add a dtype and its associated strategy to the lookup.""" - self._items[item] = value - - def __delitem__(self, item: PolarsDataType) -> None: - """Remove the given dtype from the lookup.""" - del self._items[item] - - def __getitem__(self, item: PolarsDataType) -> SearchStrategy[Any]: - """Retrieve a hypothesis strategy for the given dtype.""" - strat = self._items[item] - - # if the item is a scalar strategy, return it directly - if isinstance(strat, SearchStrategy): - return strat - - # instantiate nested strategies on demand, using the inner dtype. - # if no inner dtype, a randomly selected dtype is assigned. - return strat(inner_dtype=getattr(item, "inner", None)) - - def __len__(self) -> int: - """Return the number of items in the lookup.""" - return len(self._items) - - def __iter__(self) -> Iterator[PolarsDataType]: - """Iterate over the lookup's dtype keys.""" - yield from self._items - - def __or__(self, other: StrategyLookup) -> StrategyLookup: - """Create a new StrategyLookup from the union of this lookup and another.""" - return StrategyLookup().update(self).update(other) - - def update(self, items: StrategyLookup) -> Self: # type: ignore[override] - """Add new strategy items to the lookup.""" - self._items.update(items) - return self - - -scalar_strategies: StrategyLookup = StrategyLookup( - { - Boolean: strategy_bool, - Float32: strategy_f32, - Float64: strategy_f64, - Int8: strategy_i8, - Int16: strategy_i16, - Int32: strategy_i32, - Int64: strategy_i64, - UInt8: strategy_u8, - UInt16: strategy_u16, - UInt32: strategy_u32, - UInt64: strategy_u64, - Time: strategy_time, - Date: strategy_date, - Datetime("ns"): strategy_datetime_ns, - Datetime("us"): strategy_datetime_us, - Datetime("ms"): strategy_datetime_ms, - # Datetime("ns", "*"): strategy_datetime_ns_tz, - # Datetime("us", "*"): strategy_datetime_us_tz, - # Datetime("ms", "*"): strategy_datetime_ms_tz, - Datetime: strategy_datetime_us, - Duration("ns"): strategy_duration, - Duration("us"): strategy_duration, - Duration("ms"): strategy_duration, - Duration: strategy_duration, - Categorical: strategy_categorical, - String: strategy_string, - Binary: strategy_binary, - } -) -nested_strategies: StrategyLookup = StrategyLookup() - - -def _get_strategy_dtypes() -> list[PolarsDataType]: - """Get a list of all the dtypes for which we have a strategy.""" - strategy_dtypes = list(chain(scalar_strategies.keys(), nested_strategies.keys())) - return [tp.base_type() for tp in strategy_dtypes] - - -def _flexhash(elem: Any) -> int: - """Hashing that also handles lists/dicts (for 'unique' check).""" - if isinstance(elem, list): - return hash(tuple(_flexhash(e) for e in elem)) - elif isinstance(elem, dict): - return hash((_flexhash(k), _flexhash(v)) for k, v in elem.items()) - return hash(elem) - - -def create_array_strategy( - inner_dtype: PolarsDataType | None = None, - width: int | None = None, - *, - select_from: Sequence[Any] | None = None, - unique: bool = False, -) -> SearchStrategy[list[Any]]: - """ - Hypothesis strategy for producing polars Array data. - - Parameters - ---------- - inner_dtype : PolarsDataType - type of the inner array elements (can also be another Array). - width : int, optional - generated arrays will have this length. 
- select_from : list, optional - randomly select the innermost values from this list (otherwise - the default strategy associated with the innermost dtype is used). - unique : bool, optional - ensure that the generated lists contain unique values. - - Examples - -------- - Create a strategy that generates arrays of i32 values: - - >>> arr = create_array_strategy(inner_dtype=pl.Int32, width=3) - >>> arr.example() # doctest: +SKIP - [-11330, 24030, 116] - - Create a strategy that generates arrays of specific strings: - - >>> arr = create_array_strategy(inner_dtype=pl.String, width=2) - >>> arr.example() # doctest: +SKIP - ['xx', 'yy'] - """ - if width is None: - width = randint(a=1, b=8) - - if inner_dtype is None: - strats = list(_get_strategy_dtypes()) - shuffle(strats) - inner_dtype = choice(strats) - - strat = create_list_strategy( - inner_dtype=inner_dtype, - select_from=select_from, - size=width, - unique=unique, - ) - strat._dtype = Array(inner_dtype, width=width) # type: ignore[attr-defined] - return strat - - -def create_list_strategy( - inner_dtype: PolarsDataType | None = None, - *, - select_from: Sequence[Any] | None = None, - size: int | None = None, - min_size: int | None = None, - max_size: int | None = None, - unique: bool = False, -) -> SearchStrategy[list[Any]]: - """ - Hypothesis strategy for producing polars List data. - - Parameters - ---------- - inner_dtype : PolarsDataType - type of the inner list elements (can also be another List). - select_from : list, optional - randomly select the innermost values from this list (otherwise - the default strategy associated with the innermost dtype is used). - size : int, optional - if set, generated lists will be of exactly this size (and - ignore the min_size/max_size params). - min_size : int, optional - set the minimum size of the generated lists (default: 0 if unset). - max_size : int, optional - set the maximum size of the generated lists (default: 3 if - min_size is unset or zero, otherwise 2x min_size). - unique : bool, optional - ensure that the generated lists contain unique values. - - Examples - -------- - Create a strategy that generates a list of i32 values: - - >>> lst = create_list_strategy(inner_dtype=pl.Int32) - >>> lst.example() # doctest: +SKIP - [-11330, 24030, 116] - - Create a strategy that generates lists of lists of specific strings: - - >>> lst = create_list_strategy( - ... inner_dtype=pl.List(pl.String), - ... select_from=["xx", "yy", "zz"], - ... ) - >>> lst.example() # doctest: +SKIP - [['yy', 'xx'], [], ['zz']] - - Create a UInt8 dtype strategy as a hypothesis composite that generates - pairs of small int values where the first is always <= the second: - - >>> from hypothesis.strategies import composite - >>> - >>> @composite - ... def uint8_pairs(draw, uints=create_list_strategy(pl.UInt8, size=2)): - ... pairs = list(zip(draw(uints), draw(uints))) - ... 
return [sorted(ints) for ints in pairs] - >>> uint8_pairs().example() # doctest: +SKIP - [(12, 22), (15, 131)] - >>> uint8_pairs().example() # doctest: +SKIP - [(59, 176), (149, 149)] - """ - if select_from and inner_dtype is None: - msg = "if specifying `select_from`, must also specify `inner_dtype`" - raise ValueError(msg) - - if inner_dtype is None: - strats = list(_get_strategy_dtypes()) - shuffle(strats) - inner_dtype = choice(strats) - if size: - min_size = max_size = size - else: - min_size = min_size or 0 - if max_size is None: - max_size = 3 if not min_size else (min_size * 2) - - if inner_dtype in (Array, List): - if inner_dtype == Array: - if (width := getattr(inner_dtype, "width", None)) is None: - width = randint(a=1, b=8) - st = create_array_strategy( - inner_dtype=inner_dtype.inner, # type: ignore[union-attr] - select_from=select_from, - width=width, - ) - else: - st = create_list_strategy( - inner_dtype=inner_dtype.inner, # type: ignore[union-attr] - select_from=select_from, - min_size=min_size, - max_size=max_size, - ) - - if inner_dtype.inner is None and hasattr(st, "_dtype"): # type: ignore[union-attr] - inner_dtype = st._dtype - else: - st = ( - sampled_from(list(select_from)) - if select_from - else scalar_strategies[inner_dtype] - ) - - ls = lists( - elements=st, - min_size=min_size, - max_size=max_size, - unique_by=(_flexhash if unique else None), - ) - ls._dtype = List(inner_dtype) # type: ignore[attr-defined, arg-type] - return ls - - -# TODO: strategy for Struct dtype. -# def create_struct_strategy( - - -nested_strategies[Array] = create_array_strategy -nested_strategies[List] = create_list_strategy -# nested_strategies[Struct] = create_struct_strategy(inner_dtype=None) - -all_strategies = scalar_strategies | nested_strategies diff --git a/py-polars/polars/testing/parametric/strategies/__init__.py b/py-polars/polars/testing/parametric/strategies/__init__.py new file mode 100644 index 000000000000..2165db4dea52 --- /dev/null +++ b/py-polars/polars/testing/parametric/strategies/__init__.py @@ -0,0 +1,22 @@ +from polars.testing.parametric.strategies.core import ( + column, + dataframes, + series, +) +from polars.testing.parametric.strategies.data import lists +from polars.testing.parametric.strategies.dtype import dtypes +from polars.testing.parametric.strategies.legacy import columns, create_list_strategy + +__all__ = [ + # core + "dataframes", + "series", + "column", + # dtype + "dtypes", + # data + "lists", + # legacy + "columns", + "create_list_strategy", +] diff --git a/py-polars/polars/testing/parametric/strategies/_utils.py b/py-polars/polars/testing/parametric/strategies/_utils.py new file mode 100644 index 000000000000..8efdffbe60fd --- /dev/null +++ b/py-polars/polars/testing/parametric/strategies/_utils.py @@ -0,0 +1,14 @@ +from typing import Any + + +def flexhash(elem: Any) -> int: + """ + Hashing function that also handles lists and dictionaries. + + Used for `unique` check in nested strategies. 
+ """ + if isinstance(elem, list): + return hash(tuple(flexhash(e) for e in elem)) + elif isinstance(elem, dict): + return hash(tuple((k, flexhash(v)) for k, v in elem.items())) + return hash(elem) diff --git a/py-polars/polars/testing/parametric/strategies/core.py b/py-polars/polars/testing/parametric/strategies/core.py new file mode 100644 index 000000000000..08e751a18d74 --- /dev/null +++ b/py-polars/polars/testing/parametric/strategies/core.py @@ -0,0 +1,451 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Collection, Sequence, overload + +import hypothesis.strategies as st +from hypothesis.errors import InvalidArgument + +import polars.functions as F +from polars.dataframe import DataFrame +from polars.datatypes import Boolean, DataType, DataTypeClass +from polars.series import Series +from polars.string_cache import StringCache +from polars.testing.parametric.strategies._utils import flexhash +from polars.testing.parametric.strategies.data import data +from polars.testing.parametric.strategies.dtype import _instantiate_dtype, dtypes + +if TYPE_CHECKING: + from typing import Literal + + from hypothesis.strategies import DrawFn, SearchStrategy + + from polars import LazyFrame + from polars.type_aliases import PolarsDataType + + +_ROW_LIMIT = 5 # max generated frame/series length +_COL_LIMIT = 5 # max number of generated cols + + +@st.composite +def series( # noqa: D417 + draw: DrawFn, + *, + name: str | SearchStrategy[str] | None = None, + dtype: PolarsDataType | None = None, + size: int | None = None, + min_size: int = 0, + max_size: int = _ROW_LIMIT, + strategy: SearchStrategy[Any] | None = None, + null_probability: float = 0.0, + allow_infinities: bool = True, + unique: bool = False, + chunked: bool | None = None, + allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, +) -> Series: + """ + Hypothesis strategy for producing polars Series. + + Parameters + ---------- + name : {str, strategy}, optional + literal string or a strategy for strings (or None), passed to the Series + constructor name-param. + dtype : PolarsDataType, optional + a valid polars DataType for the resulting series. + size : int, optional + if set, creates a Series of exactly this size (ignoring min_size/max_size + params). + min_size : int + if not passing an exact size, can set a minimum here (defaults to 0). + no-op if `size` is set. + max_size : int + if not passing an exact size, can set a maximum value here (defaults to + MAX_DATA_SIZE). no-op if `size` is set. + strategy : strategy, optional + supports overriding the default strategy for the given dtype. + null_probability : float + percentage chance (expressed between 0.0 => 1.0) that a generated value is + None. this is applied independently of any None values generated by the + underlying strategy. + allow_infinities : bool, optional + optionally disallow generation of +/-inf values for floating-point dtypes. + unique : bool, optional + indicate whether Series values should all be distinct. + chunked : bool, optional + ensure that Series with more than one element have `n_chunks` > 1. + if omitted, chunking is applied at random. + allowed_dtypes : {list,set}, optional + when automatically generating Series data, allow only these dtypes. + excluded_dtypes : {list,set}, optional + when automatically generating Series data, exclude these dtypes. 
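# Sketch of the null-injection approach used later in this function (assumed
# behaviour, mirroring the `F.select(F.when(validity).then(s))` call further
# down): a boolean validity mask decides which generated values become null.
import polars as pl

s = pl.Series("a", [1, 2, 3])
validity = pl.Series([True, False, True], dtype=pl.Boolean)
masked = pl.select(pl.when(validity).then(s)).to_series()
# masked -> [1, null, 3]; with null_probability=p, each row is masked with chance ~p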
+ + Notes + ----- + In actual usage this is deployed as a unit test decorator, providing a strategy + that generates multiple Series with the given dtype/size characteristics for the + unit test. While developing a strategy/test, it can also be useful to call + `.example()` directly on a given strategy to see concrete instances of the + generated data. + + Examples + -------- + The strategy is generally used to generate series in a unit test: + + >>> from polars.testing.parametric import series + >>> from hypothesis import given + >>> @given(s=series(min_size=3, max_size=5)) + ... def test_series_len(s: pl.Series) -> None: + ... assert 3 <= s.len() <= 5 + + Drawing examples interactively is also possible with the `.example()` method. + This should be avoided while running tests. + + >>> from polars.testing.parametric import lists + >>> s = series(strategy=lists(pl.String, select_from=["xx", "yy", "zz"])) + >>> s.example() # doctest: +SKIP + shape: (4,) + Series: '' [list[str]] + [ + ["zz", "zz"] + ["zz", "xx", "yy"] + [] + ["xx"] + ] + """ + if not (0.0 <= null_probability <= 1.0): + msg = ( + f"`null_probability` should be between 0.0 and 1.0, got {null_probability}" + ) + raise InvalidArgument(msg) + + if isinstance(allowed_dtypes, (DataType, DataTypeClass)): + allowed_dtypes = [allowed_dtypes] + elif allowed_dtypes is not None and not isinstance(allowed_dtypes, Sequence): + allowed_dtypes = list(allowed_dtypes) + if isinstance(excluded_dtypes, (DataType, DataTypeClass)): + excluded_dtypes = [excluded_dtypes] + elif excluded_dtypes is not None and not isinstance(excluded_dtypes, Sequence): + excluded_dtypes = list(excluded_dtypes) + + if strategy is None: + if dtype is None: + dtype = draw( + dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes) + ) + else: + dtype = draw( + _instantiate_dtype( + dtype, + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, + ) + ) + + if size is None: + size = draw(st.integers(min_value=min_size, max_value=max_size)) + + if isinstance(name, st.SearchStrategy): + name = draw(name) + + if size == 0: + values = [] + elif null_probability == 1.0: + values = [None] * size + else: + # Create series using dtype-specific strategy to generate values + if strategy is None: + strategy = data(dtype, allow_infinity=allow_infinities) # type: ignore[arg-type] + + values = draw( + st.lists( + strategy, + min_size=size, + max_size=size, + unique_by=(flexhash if unique else None), + ) + ) + + s = Series(name=name, values=values, dtype=dtype) + + # Set null values + if 0.0 < null_probability < 1.0: + random = draw(st.randoms(use_true_random=True)) + validity = [random.random() > null_probability for _ in range(size)] + s = F.select(F.when(Series(validity, dtype=Boolean)).then(s)).to_series() + + # Apply chunking + if size > 1: + if chunked is None: + chunk_probability = 0.5 + chunked = draw(st.floats(0.0, 1.0)) < chunk_probability + if chunked: + split_at = size // 2 + s = s[:split_at].append(s[split_at:]) + + return s + + +@overload +def dataframes( + cols: int | column | Sequence[column] | None = None, + *, + lazy: Literal[False] = ..., + min_cols: int = 0, + max_cols: int = _COL_LIMIT, + size: int | None = None, + min_size: int = 0, + max_size: int = _ROW_LIMIT, + chunked: bool | None = None, + include_cols: Sequence[column] | column | None = None, + null_probability: float | dict[str, float] = 0.0, + allow_infinities: bool = True, + allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + excluded_dtypes: 
Collection[PolarsDataType] | PolarsDataType | None = None, +) -> SearchStrategy[DataFrame]: ... + + +@overload +def dataframes( + cols: int | column | Sequence[column] | None = None, + *, + lazy: Literal[True], + min_cols: int = 0, + max_cols: int = _COL_LIMIT, + size: int | None = None, + min_size: int = 0, + max_size: int = _ROW_LIMIT, + chunked: bool | None = None, + include_cols: Sequence[column] | column | None = None, + null_probability: float | dict[str, float] = 0.0, + allow_infinities: bool = True, + allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, +) -> SearchStrategy[LazyFrame]: ... + + +@st.composite +def dataframes( # noqa: D417 + draw: DrawFn, + cols: int | column | Sequence[column] | None = None, + *, + lazy: bool = False, + min_cols: int = 1, + max_cols: int = _COL_LIMIT, + size: int | None = None, + min_size: int = 0, + max_size: int = _ROW_LIMIT, + chunked: bool | None = None, + include_cols: Sequence[column] | column | None = None, + null_probability: float | dict[str, float] = 0.0, + allow_infinities: bool = True, + allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, +) -> DataFrame | LazyFrame: + """ + Hypothesis strategy for producing polars DataFrames or LazyFrames. + + Parameters + ---------- + cols : {int, columns}, optional + integer number of columns to create, or a sequence of `column` objects + that describe the desired DataFrame column data. + lazy : bool, optional + produce a LazyFrame instead of a DataFrame. + min_cols : int, optional + if not passing an exact size, can set a minimum here (defaults to 0). + max_cols : int, optional + if not passing an exact size, can set a maximum value here (defaults to + MAX_COLS). + size : int, optional + if set, will create a DataFrame of exactly this size (and ignore + the min_size/max_size len params). + min_size : int, optional + if not passing an exact size, set the minimum number of rows in the + DataFrame. + max_size : int, optional + if not passing an exact size, set the maximum number of rows in the + DataFrame. + chunked : bool, optional + ensure that DataFrames with more than row have `n_chunks` > 1. if + omitted, chunking will be randomised at the level of individual Series. + include_cols : [column], optional + a list of `column` objects to include in the generated DataFrame. note that + explicitly provided columns are appended onto the list of existing columns + (if any present). + null_probability : {float, dict[str,float]}, optional + percentage chance (expressed between 0.0 => 1.0) that a generated value is + None. this is applied independently of any None values generated by the + underlying strategy, and can be applied either on a per-column basis (if + given as a `{col:pct}` dict), or globally. if null_probability is defined + on a column, it takes precedence over the global value. + allow_infinities : bool, optional + optionally disallow generation of +/-inf values for floating-point dtypes. + allowed_dtypes : {list,set}, optional + when automatically generating data, allow only these dtypes. + excluded_dtypes : {list,set}, optional + when automatically generating data, exclude these dtypes. + + Notes + ----- + In actual usage this is deployed as a unit test decorator, providing a strategy + that generates DataFrames or LazyFrames with the given characteristics for + the unit test. 
While developing a strategy/test, it can also be useful to + call `.example()` directly on a given strategy to see concrete instances of + the generated data. + + Examples + -------- + The strategy is generally used to generate series in a unit test: + + >>> from polars.testing.parametric import dataframes + >>> from hypothesis import given + >>> @given(df=dataframes(min_size=3, max_size=5)) + ... def test_df_height(df: pl.DataFrame) -> None: + ... assert 3 <= df.height <= 5 + + Drawing examples interactively is also possible with the `.example()` method. + This should be avoided while running tests. + + >>> df = dataframes(allowed_dtypes=[pl.Datetime, pl.Float64], max_cols=3) + >>> df.example() # doctest: +SKIP + shape: (3, 3) + ┌─────────────┬────────────────────────────┬───────────┐ + │ col0 ┆ col1 ┆ col2 │ + │ --- ┆ --- ┆ --- │ + │ f64 ┆ datetime[ns] ┆ f64 │ + ╞═════════════╪════════════════════════════╪═══════════╡ + │ NaN ┆ 1844-07-05 06:19:48.848808 ┆ 3.1436e16 │ + │ -1.9914e218 ┆ 2068-12-01 23:05:11.412277 ┆ 2.7415e16 │ + │ 0.5 ┆ 2095-11-19 22:05:17.647961 ┆ -0.5 │ + └─────────────┴────────────────────────────┴───────────┘ + + Use :class:`column` for more control over which exactly which columns are generated. + + >>> from polars.testing.parametric import column + >>> dfs = dataframes( + ... [ + ... column("x", dtype=pl.Int32), + ... column("y", dtype=pl.Float64), + ... ], + ... size=2, + ... ) + >>> dfs.example() # doctest: +SKIP + shape: (2, 2) + ┌───────────┬────────────┐ + │ x ┆ y │ + │ --- ┆ --- │ + │ i32 ┆ f64 │ + ╞═══════════╪════════════╡ + │ -15836 ┆ 1.1755e-38 │ + │ 575050513 ┆ NaN │ + └───────────┴────────────┘ + """ + if isinstance(include_cols, column): + include_cols = [include_cols] + + if cols is None: + n_cols = draw(st.integers(min_value=min_cols, max_value=max_cols)) + cols = [column() for _ in range(n_cols)] + elif isinstance(cols, int): + cols = [column() for _ in range(cols)] + elif isinstance(cols, column): + cols = [cols] + else: + cols = list(cols) + + if include_cols: + cols.extend(list(include_cols)) + + if size is None: + size = draw(st.integers(min_value=min_size, max_value=max_size)) + + # assign names, null probability + for idx, c in enumerate(cols): + if c.name is None: + c.name = f"col{idx}" + if c.null_probability is None: + if isinstance(null_probability, dict): + c.null_probability = null_probability.get(c.name, 0.0) + else: + c.null_probability = null_probability + + # init dataframe from generated series data; series data is + # given as a python-native sequence. + with StringCache(): + data = { + c.name: draw( + series( + name=c.name, + dtype=c.dtype, + size=size, + null_probability=c.null_probability, # type: ignore[arg-type] + allow_infinities=allow_infinities, + strategy=c.strategy, + unique=c.unique, + chunked=None if chunked is None else False, + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, + ) + ) + for c in cols + } + + df = DataFrame(data) + + # optionally generate chunked frames + if size > 1 and chunked: + split_at = size // 2 + df = df[:split_at].vstack(df[split_at:]) + + if lazy: + return df.lazy() + + return df + + +@dataclass +class column: + """ + Define a column for use with the @dataframes strategy. + + Parameters + ---------- + name : str + string column name. + dtype : PolarsDataType + a polars dtype. + strategy : strategy, optional + supports overriding the default strategy for the given dtype. 
+ null_probability : float, optional + percentage chance (expressed between 0.0 => 1.0) that a generated value is + None. this is applied independently of any None values generated by the + underlying strategy. + unique : bool, optional + flag indicating that all values generated for the column should be unique. + + Examples + -------- + >>> from polars.testing.parametric import column + >>> column(name="unique_small_ints", dtype=pl.UInt8, unique=True) + column(name='unique_small_ints', dtype=UInt8, strategy=None, null_probability=None, unique=True) + + >>> from hypothesis.strategies import sampled_from + >>> column(name="ccy", strategy=sampled_from(["GBP", "EUR", "JPY"])) + column(name='ccy', dtype=None, strategy=sampled_from(['GBP', 'EUR', 'JPY']), null_probability=None, unique=False) + """ # noqa: W505 + + name: str | None = None + dtype: PolarsDataType | None = None + strategy: SearchStrategy[Any] | None = None + null_probability: float | None = None + unique: bool = False + + def __post_init__(self) -> None: + if (self.null_probability is not None) and ( + self.null_probability < 0 or self.null_probability > 1 + ): + msg = f"`null_probability` should be between 0.0 and 1.0, or None; found {self.null_probability!r}" + raise InvalidArgument(msg) diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py new file mode 100644 index 000000000000..ad1373689c2c --- /dev/null +++ b/py-polars/polars/testing/parametric/strategies/data.py @@ -0,0 +1,334 @@ +"""Strategies for generating various forms of data.""" + +from __future__ import annotations + +import decimal +import string +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Literal, Sequence + +import hypothesis.strategies as st +from hypothesis.errors import InvalidArgument + +from polars._utils.constants import ( + EPOCH, + I8_MAX, + I8_MIN, + I16_MAX, + I16_MIN, + I32_MAX, + I32_MIN, + I64_MAX, + I64_MIN, + U8_MAX, + U16_MAX, + U32_MAX, + U64_MAX, +) +from polars.datatypes import ( + Array, + Binary, + Boolean, + Categorical, + Date, + Datetime, + Decimal, + Duration, + Float32, + Float64, + Int8, + Int16, + Int32, + Int64, + List, + Null, + String, + Time, + UInt8, + UInt16, + UInt32, + UInt64, +) +from polars.testing.parametric.strategies._utils import flexhash +from polars.testing.parametric.strategies.dtype import _DEFAULT_ARRAY_WIDTH_LIMIT + +if TYPE_CHECKING: + from datetime import date, datetime, time + + from hypothesis.strategies import SearchStrategy + + from polars.datatypes import DataType, DataTypeClass + from polars.type_aliases import PolarsDataType, TimeUnit + +_DEFAULT_LIST_LEN_LIMIT = 3 + +_INTEGER_STRATEGIES: dict[bool, dict[int, SearchStrategy[int]]] = { + True: { + 8: st.integers(I8_MIN, I8_MAX), + 16: st.integers(I16_MIN, I16_MAX), + 32: st.integers(I32_MIN, I32_MAX), + 64: st.integers(I64_MIN, I64_MAX), + }, + False: { + 8: st.integers(0, U8_MAX), + 16: st.integers(0, U16_MAX), + 32: st.integers(0, U32_MAX), + 64: st.integers(0, U64_MAX), + }, +} + + +def integers( + bit_width: Literal[8, 16, 32, 64] = 64, *, signed: bool = True +) -> SearchStrategy[int]: + """Create a strategy for generating integers.""" + return _INTEGER_STRATEGIES[signed][bit_width] + + +def floats( + bit_width: Literal[32, 64] = 64, *, allow_infinity: bool = True +) -> SearchStrategy[float]: + """Create a strategy for generating integers.""" + return st.floats(width=bit_width, allow_infinity=allow_infinity) + + +def booleans() -> SearchStrategy[bool]: + 
"""Create a strategy for generating booleans.""" + return st.booleans() + + +def strings() -> SearchStrategy[str]: + """Create a strategy for generating string values.""" + alphabet = st.characters(max_codepoint=1000, exclude_categories=["Cs", "Cc"]) + return st.text(alphabet=alphabet, max_size=8) + + +def binary() -> SearchStrategy[bytes]: + """Create a strategy for generating bytes.""" + return st.binary() + + +def categories() -> SearchStrategy[str]: + """Create a strategy for generating category strings.""" + return st.text(alphabet=string.ascii_uppercase, min_size=1, max_size=2) + + +def times() -> SearchStrategy[time]: + """Create a strategy for generating `time` objects.""" + return st.times() + + +def dates() -> SearchStrategy[date]: + """Create a strategy for generating `date` objects.""" + return st.dates() + + +def datetimes(time_unit: TimeUnit = "us") -> SearchStrategy[datetime]: + """ + Create a strategy for generating `datetime` objects in the time unit's range. + + Parameters + ---------- + time_unit + Time unit for which the datetime objects are valid. + """ + if time_unit in ("us", "ms"): + # datetime.min/max fall within the range + return st.datetimes() + elif time_unit == "ns": + return st.datetimes( + min_value=EPOCH + timedelta(microseconds=I64_MIN // 1000), + max_value=EPOCH + timedelta(microseconds=I64_MAX // 1000), + ) + else: + msg = f"invalid time unit: {time_unit}" + raise InvalidArgument(msg) + + +def durations(time_unit: TimeUnit = "us") -> SearchStrategy[timedelta]: + """ + Create a strategy for generating `timedelta` objects in the time unit's range. + + Parameters + ---------- + time_unit + Time unit for which the timedelta objects are valid. + """ + if time_unit == "us": + return st.timedeltas( + min_value=timedelta(microseconds=I64_MIN), + max_value=timedelta(microseconds=I64_MAX), + ) + elif time_unit == "ns": + return st.timedeltas( + min_value=timedelta(microseconds=I64_MIN // 1000), + max_value=timedelta(microseconds=I64_MAX // 1000), + ) + elif time_unit == "ms": + # TODO: Enable full range of millisecond durations + # timedelta.min/max fall within the range + # return st.timedeltas() + return st.timedeltas( + min_value=timedelta(microseconds=I64_MIN), + max_value=timedelta(microseconds=I64_MAX), + ) + else: + msg = f"invalid time unit: {time_unit}" + raise InvalidArgument(msg) + + +def decimals( + precision: int | None = 38, scale: int = 0 +) -> SearchStrategy[decimal.Decimal]: + """ + Create a strategy for generating `Decimal` objects. + + Parameters + ---------- + precision + Maximum number of digits in each number. + If set to `None`, the precision is set to 38 (the maximum supported by Polars). + scale + Number of digits to the right of the decimal point in each number. + """ + if precision is None: + precision = 38 + + c = decimal.Context(prec=precision) + exclusive_limit = c.create_decimal(f"1E+{precision - scale}") + max_value = c.next_minus(exclusive_limit) + min_value = c.copy_negate(max_value) + + return st.decimals( + min_value=min_value, + max_value=max_value, + allow_nan=False, + allow_infinity=False, + places=scale, + ) + + +def lists( + inner_dtype: DataType, + *, + select_from: Sequence[Any] | None = None, + min_len: int = 0, + max_len: int | None = None, + unique: bool = False, + **kwargs: Any, +) -> SearchStrategy[list[Any]]: + """ + Create a strategy for generating lists of the given data type. + + Parameters + ---------- + inner_dtype + Data type of the list elements. 
If the data type is not fully instantiated, + defaults will be used, e.g. `Datetime` will become `Datetime('us')`. + select_from + The values to use for the innermost lists. If set to `None` (default), + the default strategy associated with the innermost data type is used. + min_len + The minimum length of the generated lists. + max_len + The maximum length of the generated lists. If set to `None` (default), the + maximum is set based on `min_size`: `3` if `min_len` is zero, + otherwise `2 * min_len`. + unique + Ensure that the generated lists contain unique values. + **kwargs + Additional arguments that are passed to nested data generation strategies. + + Examples + -------- + ... + """ + if max_len is None: + max_len = _DEFAULT_LIST_LEN_LIMIT if min_len == 0 else min_len * 2 + + if select_from is not None and not inner_dtype.is_nested(): + inner_strategy = st.sampled_from(select_from) + else: + inner_strategy = data( + inner_dtype, + select_from=select_from, + min_size=min_len, + max_size=max_len, + unique=unique, + **kwargs, + ) + + return st.lists( + elements=inner_strategy, + min_size=min_len, + max_size=max_len, + unique_by=(flexhash if unique else None), + ) + + +def nulls() -> SearchStrategy[None]: + """Create a strategy for generating null values.""" + return st.none() + + +# Strategies that are not customizable through parameters +_STATIC_STRATEGIES: dict[DataTypeClass, SearchStrategy[Any]] = { + Boolean: booleans(), + Int8: integers(8, signed=True), + Int16: integers(16, signed=True), + Int32: integers(32, signed=True), + Int64: integers(64, signed=True), + UInt8: integers(8, signed=False), + UInt16: integers(16, signed=False), + UInt32: integers(32, signed=False), + UInt64: integers(64, signed=False), + Time: times(), + Date: dates(), + Categorical: categories(), + String: strings(), + Binary: binary(), + Null: nulls(), +} + + +def data(dtype: PolarsDataType, **kwargs: Any) -> SearchStrategy[Any]: + """ + Create a strategy for generating data for the given data type. + + Parameters + ---------- + dtype + A Polars data type. If the data type is not fully instantiated, defaults will + be used, e.g. `Datetime` will become `Datetime('us')`. + **kwargs + Additional parameters for the strategy associated with the given `dtype`. 
+ """ + if (strategy := _STATIC_STRATEGIES.get(dtype.base_type())) is not None: + return strategy + + if dtype == Float32: + return floats(32, allow_infinity=kwargs.pop("allow_infinity", True)) + elif dtype == Float64: + return floats(64, allow_infinity=kwargs.pop("allow_infinity", True)) + elif dtype == Datetime: + # TODO: Handle time zones + return datetimes(time_unit=getattr(dtype, "time_unit", None) or "us") + elif dtype == Duration: + return durations(time_unit=getattr(dtype, "time_unit", None) or "us") + elif dtype == Decimal: + return decimals(getattr(dtype, "precision", None), getattr(dtype, "scale", 0)) + elif dtype == List: + inner = getattr(dtype, "inner", None) or Null() + return lists(inner, **kwargs) + elif dtype == Array: + inner = getattr(dtype, "inner", None) or Null() + width = getattr(dtype, "width", _DEFAULT_ARRAY_WIDTH_LIMIT) + return lists( + inner, + min_len=width, + max_len=width, + **kwargs, + ) + else: + msg = f"unsupported data type: {dtype}" + raise InvalidArgument(msg) diff --git a/py-polars/polars/testing/parametric/strategies/dtype.py b/py-polars/polars/testing/parametric/strategies/dtype.py new file mode 100644 index 000000000000..ae54a5e405b4 --- /dev/null +++ b/py-polars/polars/testing/parametric/strategies/dtype.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Collection, Sequence + +import hypothesis.strategies as st +from hypothesis.errors import InvalidArgument + +from polars.datatypes import ( + Array, + Binary, + Boolean, + Categorical, + DataType, + Date, + Datetime, + Decimal, + Duration, + Float32, + Float64, + Int8, + Int16, + Int32, + Int64, + List, + Null, + String, + Struct, + Time, + UInt8, + UInt16, + UInt32, + UInt64, +) + +if TYPE_CHECKING: + from hypothesis.strategies import DrawFn, SearchStrategy + + from polars.datatypes import DataTypeClass + from polars.type_aliases import CategoricalOrdering, PolarsDataType, TimeUnit + + +# Supported data type classes which do not take any arguments +_SIMPLE_DTYPES: list[DataTypeClass] = [ + Int64, + Int32, + Int16, + Int8, + Float64, + Float32, + Boolean, + UInt8, + UInt16, + UInt32, + UInt64, + String, + Binary, + Date, + Time, +] +# Supported data type classes with arguments +_COMPLEX_DTYPES: list[DataTypeClass] = [ + Datetime, + Duration, + Categorical, + Decimal, +] +# Supported data type classes that contain other data types +_NESTED_DTYPES: list[DataTypeClass] = [ + # TODO: Enable nested types by default when various issues are solved. + # List, + # Array, + # Struct, +] +# Supported data type classes that do not contain other data types +_FLAT_DTYPES = _SIMPLE_DTYPES + _COMPLEX_DTYPES + +_DEFAULT_ARRAY_WIDTH_LIMIT = 3 +_DEFAULT_STRUCT_FIELDS_LIMIT = 3 + + +def dtypes( + *, + allowed_dtypes: Collection[PolarsDataType] | None = None, + excluded_dtypes: Sequence[PolarsDataType] | None = None, + nesting_level: int = 3, +) -> SearchStrategy[DataType]: + """ + Create a strategy for generating Polars :class:`DataType` objects. + + Parameters + ---------- + allowed_dtypes + Data types the strategy will pick from. If set to `None` (default), + all supported data types are included. + excluded_dtypes + Data types the strategy will *not* pick from. This takes priority over + data types specified in `allowed_dtypes`. + nesting_level + The complexity of nested data types. If set to 0, nested data types are + disabled. 
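# Illustrative sketch of drawing dtypes in a test (assumed usage of the export
# added in strategies/__init__.py above; nested dtypes are disabled by default,
# so only flat, fully instantiated dtypes are produced).
import polars as pl
from hypothesis import given
from polars.testing.parametric.strategies import dtypes


@given(dtype=dtypes(excluded_dtypes=[pl.Categorical]))
def test_generated_dtype_is_instantiated(dtype: pl.DataType) -> None:
    assert isinstance(dtype, pl.DataType)  # instances, not DataTypeClass objects
    assert dtype != pl.Categorical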
+ """ + flat_dtypes, nested_dtypes = _parse_allowed_dtypes(allowed_dtypes) + + if nesting_level > 0 and nested_dtypes: + if not flat_dtypes: + return _nested_dtypes( + inner=st.just(Null()), + allowed_dtypes=nested_dtypes, + excluded_dtypes=excluded_dtypes, + ) + return st.recursive( + base=_flat_dtypes( + allowed_dtypes=flat_dtypes, excluded_dtypes=excluded_dtypes + ), + extend=lambda s: _nested_dtypes( + s, allowed_dtypes=nested_dtypes, excluded_dtypes=excluded_dtypes + ), + max_leaves=nesting_level, + ) + else: + return _flat_dtypes(allowed_dtypes=flat_dtypes, excluded_dtypes=excluded_dtypes) + + +def _parse_allowed_dtypes( + allowed_dtypes: Collection[PolarsDataType] | None = None, +) -> tuple[Sequence[PolarsDataType], Sequence[PolarsDataType]]: + """Split allowed dtypes into flat and nested data types.""" + if allowed_dtypes is None: + return _FLAT_DTYPES, _NESTED_DTYPES + + allowed_dtypes_flat = [] + allowed_dtypes_nested = [] + for dt in allowed_dtypes: + if dt.is_nested(): + allowed_dtypes_nested.append(dt) + else: + allowed_dtypes_flat.append(dt) + + return allowed_dtypes_flat, allowed_dtypes_nested + + +@st.composite +def _flat_dtypes( + draw: DrawFn, + allowed_dtypes: Sequence[PolarsDataType] | None = None, + excluded_dtypes: Sequence[PolarsDataType] | None = None, +) -> DataType: + """Create a strategy for generating non-nested Polars :class:`DataType` objects.""" + if allowed_dtypes is None: + allowed_dtypes = _FLAT_DTYPES + if excluded_dtypes is None: + excluded_dtypes = [] + + dtype = draw(st.sampled_from(allowed_dtypes)) + return draw( + _instantiate_flat_dtype(dtype).filter(lambda x: x not in excluded_dtypes) + ) + + +@st.composite +def _instantiate_flat_dtype(draw: DrawFn, dtype: PolarsDataType) -> DataType: + """Take a flat data type and instantiate it.""" + if isinstance(dtype, DataType): + return dtype + elif dtype in _SIMPLE_DTYPES: + return dtype() + elif dtype == Datetime: + # TODO: Add time zones + time_unit = draw(_time_units()) + return Datetime(time_unit) + elif dtype == Duration: + time_unit = draw(_time_units()) + return Duration(time_unit) + elif dtype == Categorical: + ordering = draw(_categorical_orderings()) + return Categorical(ordering) + elif dtype == Decimal: + precision = draw(st.integers(min_value=1, max_value=38) | st.none()) + scale = draw(st.integers(min_value=0, max_value=precision or 38)) + return Decimal(precision, scale) + else: + msg = f"unsupported data type: {dtype}" + raise InvalidArgument(msg) + + +@st.composite +def _nested_dtypes( + draw: DrawFn, + inner: SearchStrategy[DataType], + allowed_dtypes: Sequence[PolarsDataType] | None = None, + excluded_dtypes: Sequence[PolarsDataType] | None = None, +) -> DataType: + """Create a strategy for generating nested Polars :class:`DataType` objects.""" + if allowed_dtypes is None: + allowed_dtypes = _NESTED_DTYPES + if excluded_dtypes is None: + excluded_dtypes = [] + + dtype = draw(st.sampled_from(allowed_dtypes)) + return draw( + _instantiate_nested_dtype(dtype, inner).filter( + lambda x: x not in excluded_dtypes + ) + ) + + +@st.composite +def _instantiate_nested_dtype( + draw: DrawFn, + dtype: PolarsDataType, + inner: SearchStrategy[DataType], +) -> DataType: + """Take a nested data type and instantiate it.""" + + def instantiate_inner(dtype: PolarsDataType) -> DataType: + inner_dtype = getattr(dtype, "inner", None) + if inner_dtype is None: + return draw(inner) + elif inner_dtype.is_nested(): + return draw(_instantiate_nested_dtype(inner_dtype, inner)) + else: + return 
draw(_instantiate_flat_dtype(inner_dtype)) + + if dtype == List: + inner_dtype = instantiate_inner(dtype) + return List(inner_dtype) + elif dtype == Array: + inner_dtype = instantiate_inner(dtype) + width = getattr( + dtype, + "width", + draw(st.integers(min_value=1, max_value=_DEFAULT_ARRAY_WIDTH_LIMIT)), + ) + return Array(inner_dtype, width) + elif dtype == Struct: + # TODO: Recursively instantiate struct field dtypes + if isinstance(dtype, DataType): + return dtype + n_fields = draw( + st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT) + ) + return Struct({f"f{i}": draw(inner) for i in range(n_fields)}) + else: + msg = f"unsupported data type: {dtype}" + raise InvalidArgument(msg) + + +def _time_units() -> SearchStrategy[TimeUnit]: + """Create a strategy for generating valid units of time.""" + return st.sampled_from(["us", "ns", "ms"]) + + +def _categorical_orderings() -> SearchStrategy[CategoricalOrdering]: + """Create a strategy for generating valid ordering types for categorical data.""" + return st.sampled_from(["physical", "lexical"]) + + +@st.composite +def _instantiate_dtype( + draw: DrawFn, + dtype: PolarsDataType, + *, + allowed_dtypes: Collection[PolarsDataType] | None = None, + excluded_dtypes: Sequence[PolarsDataType] | None = None, + nesting_level: int = 3, +) -> DataType: + """Take a data type and instantiate it.""" + if not dtype.is_nested(): + if allowed_dtypes is None: + allowed_dtypes = [dtype] + else: + allowed_dtypes = [dt for dt in allowed_dtypes if dt == dtype] + return draw( + _flat_dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes) + ) + + def draw_inner(dtype: PolarsDataType) -> DataType: + if isinstance(dtype, DataType): + return draw( + _instantiate_dtype( + dtype.inner, # type: ignore[attr-defined] + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, + nesting_level=nesting_level - 1, + ) + ) + else: + return draw( + dtypes( + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, + nesting_level=nesting_level - 1, + ) + ) + + if dtype == List: + inner = draw_inner(dtype) + return List(inner) + elif dtype == Array: + inner = draw_inner(dtype) + width = getattr( + dtype, + "width", + draw(st.integers(min_value=1, max_value=_DEFAULT_ARRAY_WIDTH_LIMIT)), + ) + return Array(inner, width) + elif dtype == Struct: + if isinstance(dtype, DataType): + return dtype + n_fields = draw( + st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT) + ) + inner_strategy = dtypes( + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, + nesting_level=nesting_level - 1, + ) + return Struct({f"f{i}": draw(inner_strategy) for i in range(n_fields)}) + else: + msg = f"unsupported data type: {dtype}" + raise InvalidArgument(msg) diff --git a/py-polars/polars/testing/parametric/strategies/legacy.py b/py-polars/polars/testing/parametric/strategies/legacy.py new file mode 100644 index 000000000000..434d791c604b --- /dev/null +++ b/py-polars/polars/testing/parametric/strategies/legacy.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +import hypothesis.strategies as st +from hypothesis.errors import InvalidArgument + +from polars._utils.deprecation import deprecate_function +from polars.datatypes import is_polars_dtype +from polars.testing.parametric.strategies.core import _COL_LIMIT, column +from polars.testing.parametric.strategies.data import lists +from polars.testing.parametric.strategies.dtype import _instantiate_dtype, dtypes + +if 
TYPE_CHECKING: + from hypothesis.strategies import SearchStrategy + + from polars.type_aliases import OneOrMoreDataTypes, PolarsDataType + + +@deprecate_function( + "Use `column` instead in conjunction with a list comprehension.", version="0.20.26" +) +def columns( + cols: int | Sequence[str] | None = None, + *, + dtype: OneOrMoreDataTypes | None = None, + min_cols: int = 0, + max_cols: int = _COL_LIMIT, + unique: bool = False, +) -> list[column]: + """ + Define multiple columns for use with the @dataframes strategy. + + .. deprecated:: 0.20.26 + Use :class:`column` instead in conjunction with a list comprehension. + + Generate a fixed sequence of `column` objects suitable for passing to the + @dataframes strategy, or using standalone (note that this function is not itself + a strategy). + + Notes + ----- + Additional control is available by creating a sequence of columns explicitly, + using the `column` class (an especially useful option is to override the default + data-generating strategy for a given col/dtype). + + Parameters + ---------- + cols : {int, [str]}, optional + integer number of cols to create, or explicit list of column names. if + omitted a random number of columns (between mincol and max_cols) are + created. + dtype : PolarsDataType, optional + a single dtype for all cols, or list of dtypes (the same length as `cols`). + if omitted, each generated column is assigned a random dtype. + min_cols : int, optional + if not passing an exact size, can set a minimum here (defaults to 0). + max_cols : int, optional + if not passing an exact size, can set a maximum value here (defaults to + MAX_COLS). + unique : bool, optional + indicate if the values generated for these columns should be unique + (per-column). + + Examples + -------- + >>> from polars.testing.parametric import columns, dataframes + >>> from hypothesis import given + >>> @given(dataframes(columns(["x", "y", "z"], unique=True))) # doctest: +SKIP + ... def test_unique_xyz(df: pl.DataFrame) -> None: + ... assert_something(df) + """ + # create/assign named columns + if cols is None: + cols = st.integers(min_value=min_cols, max_value=max_cols).example() + if isinstance(cols, int): + names: Sequence[str] = [f"col{n}" for n in range(cols)] + else: + names = cols + n_cols = len(names) + + if dtype is None: + dtypes: Sequence[PolarsDataType | None] = [None] * n_cols + elif is_polars_dtype(dtype): + dtypes = [dtype] * n_cols + elif isinstance(dtype, Sequence): + if (n_dtypes := len(dtype)) != n_cols: + msg = f"given {n_dtypes} dtypes for {n_cols} names" + raise InvalidArgument(msg) + dtypes = dtype + else: + msg = f"{dtype!r} is not a valid polars datatype" + raise InvalidArgument(msg) + + # init list of named/typed columns + return [column(name=nm, dtype=tp, unique=unique) for nm, tp in zip(names, dtypes)] + + +@deprecate_function("Use `lists` instead.", version="0.20.26") +def create_list_strategy( + inner_dtype: PolarsDataType | None = None, + *, + select_from: Sequence[Any] | None = None, + size: int | None = None, + min_size: int = 0, + max_size: int | None = None, + unique: bool = False, +) -> SearchStrategy[list[Any]]: + """ + Create a strategy for generating Polars :class:`List` data. + + .. deprecated:: 0.20.26 + Use :func:`lists` instead. + + Parameters + ---------- + inner_dtype : PolarsDataType + type of the inner list elements (can also be another List). 
+ select_from : list, optional + randomly select the innermost values from this list (otherwise + the default strategy associated with the innermost dtype is used). + size : int, optional + if set, generated lists will be of exactly this size (and + ignore the min_size/max_size params). + min_size : int, optional + set the minimum size of the generated lists (default: 0 if unset). + max_size : int, optional + set the maximum size of the generated lists (default: 3 if + min_size is unset or zero, otherwise 2x min_size). + unique : bool, optional + ensure that the generated lists contain unique values. + + Examples + -------- + Create a strategy that generates a list of i32 values: + + >>> from polars.testing.parametric import create_list_strategy + >>> lst = create_list_strategy(inner_dtype=pl.Int32) # doctest: +SKIP + >>> lst.example() # doctest: +SKIP + [-11330, 24030, 116] + """ + if size is not None: + min_size = max_size = size + + if inner_dtype is None: + inner_dtype = dtypes().example() + else: + inner_dtype = _instantiate_dtype(inner_dtype).example() + + return lists( + inner_dtype, + select_from=select_from, + min_len=min_size, + max_len=max_size, + unique=unique, + ) diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 6afb0dfb85a2..70c366a73e51 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -55,7 +55,7 @@ nest_asyncio # TOOLING # ------- -hypothesis==6.97.4 +hypothesis==6.100.4 pytest==8.2.0 pytest-codspeed==2.2.1 pytest-cov==5.0.0 diff --git a/py-polars/tests/parametric/conftest.py b/py-polars/tests/parametric/conftest.py index beb1befe32a1..5f0d33ec5073 100644 --- a/py-polars/tests/parametric/conftest.py +++ b/py-polars/tests/parametric/conftest.py @@ -1,6 +1,6 @@ import os -from polars.testing.parametric.profiles import load_profile +from polars.testing.parametric import load_profile load_profile( profile=os.environ.get("POLARS_HYPOTHESIS_PROFILE", "fast"), # type: ignore[arg-type] diff --git a/py-polars/tests/parametric/test_dataframe.py b/py-polars/tests/parametric/test_dataframe.py index 1895e7146530..5a782d1042bc 100644 --- a/py-polars/tests/parametric/test_dataframe.py +++ b/py-polars/tests/parametric/test_dataframe.py @@ -3,8 +3,8 @@ # ---------------------------------------------------- from __future__ import annotations +import hypothesis.strategies as st from hypothesis import example, given, settings -from hypothesis.strategies import integers import polars as pl from polars.testing import assert_frame_equal @@ -65,19 +65,21 @@ def test_null_count(df: pl.DataFrame) -> None: "start", dtype=pl.Int8, null_probability=0.15, - strategy=integers(min_value=-8, max_value=8), + strategy=st.integers(min_value=-8, max_value=8), ), column( "stop", dtype=pl.Int8, null_probability=0.15, - strategy=integers(min_value=-6, max_value=6), + strategy=st.integers(min_value=-6, max_value=6), ), column( "step", dtype=pl.Int8, null_probability=0.15, - strategy=integers(min_value=-4, max_value=4).filter(lambda x: x != 0), + strategy=st.integers(min_value=-4, max_value=4).filter( + lambda x: x != 0 + ), ), column("misc", dtype=pl.Int32), ], diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index 9b1eb75e20c6..e048eba378e5 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -10,13 +10,20 @@ import polars as pl from polars._utils.convert import parse_as_duration_string from polars.testing 
import assert_frame_equal -from polars.testing.parametric.primitives import column, dataframes -from polars.testing.parametric.strategies import strategy_closed, strategy_time_unit +from polars.testing.parametric import column, dataframes +from polars.testing.parametric.strategies.dtype import _time_units if TYPE_CHECKING: + from hypothesis.strategies import SearchStrategy + from polars.type_aliases import ClosedInterval, TimeUnit +def interval_defs() -> SearchStrategy[ClosedInterval]: + closed: list[ClosedInterval] = ["left", "right", "both", "none"] + return st.sampled_from(closed) + + @given( period=st.timedeltas( min_value=timedelta(microseconds=0), max_value=timedelta(days=1000) @@ -24,9 +31,9 @@ offset=st.timedeltas( min_value=timedelta(microseconds=0), max_value=timedelta(days=1000) ).map(parse_as_duration_string), - closed=strategy_closed, + closed=interval_defs(), data=st.data(), - time_unit=strategy_time_unit, + time_unit=_time_units(), ) def test_rolling( period: str, @@ -86,9 +93,9 @@ def test_rolling( window_size=st.timedeltas( min_value=timedelta(microseconds=0), max_value=timedelta(days=2) ).map(parse_as_duration_string), - closed=strategy_closed, + closed=interval_defs(), data=st.data(), - time_unit=strategy_time_unit, + time_unit=_time_units(), aggregation=st.sampled_from( [ "min", diff --git a/py-polars/tests/parametric/test_lazyframe.py b/py-polars/tests/parametric/test_lazyframe.py index d9d3ff0eb4a9..24420ce7fb16 100644 --- a/py-polars/tests/parametric/test_lazyframe.py +++ b/py-polars/tests/parametric/test_lazyframe.py @@ -1,8 +1,8 @@ # ---------------------------------------------------- # Validate LazyFrame behaviour with parametric tests # ---------------------------------------------------- +import hypothesis.strategies as st from hypothesis import example, given -from hypothesis.strategies import integers import polars as pl from polars.testing.parametric import column, dataframes @@ -17,19 +17,21 @@ "start", dtype=pl.Int8, null_probability=0.3, - strategy=integers(min_value=-3, max_value=4), + strategy=st.integers(min_value=-3, max_value=4), ), column( "stop", dtype=pl.Int8, null_probability=0.3, - strategy=integers(min_value=-2, max_value=6), + strategy=st.integers(min_value=-2, max_value=6), ), column( "step", dtype=pl.Int8, null_probability=0.3, - strategy=integers(min_value=-3, max_value=3).filter(lambda x: x != 0), + strategy=st.integers(min_value=-3, max_value=3).filter( + lambda x: x != 0 + ), ), column("misc", dtype=pl.Int32), ], diff --git a/py-polars/tests/parametric/test_lit.py b/py-polars/tests/parametric/test_lit.py index e291f4e04dda..73df1aa98012 100644 --- a/py-polars/tests/parametric/test_lit.py +++ b/py-polars/tests/parametric/test_lit.py @@ -3,20 +3,16 @@ from hypothesis import given import polars as pl -from polars.testing.parametric.strategies import ( - strategy_datetime_ms, - strategy_datetime_ns, - strategy_datetime_us, -) +from polars.testing.parametric.strategies.data import datetimes -@given(value=strategy_datetime_ns) +@given(value=datetimes("ns")) def test_datetime_ns(value: datetime) -> None: result = pl.select(pl.lit(value, dtype=pl.Datetime("ns")))["literal"][0] assert result == value -@given(value=strategy_datetime_us) +@given(value=datetimes("us")) def test_datetime_us(value: datetime) -> None: result = pl.select(pl.lit(value, dtype=pl.Datetime("us")))["literal"][0] assert result == value @@ -24,7 +20,7 @@ def test_datetime_us(value: datetime) -> None: assert result == value -@given(value=strategy_datetime_ms) 
+@given(value=datetimes("ms")) def test_datetime_ms(value: datetime) -> None: result = pl.select(pl.lit(value, dtype=pl.Datetime("ms")))["literal"][0] expected_microsecond = value.microsecond // 1000 * 1000 diff --git a/py-polars/tests/parametric/test_series.py b/py-polars/tests/parametric/test_series.py index 3aba5240fe76..cf298eda401c 100644 --- a/py-polars/tests/parametric/test_series.py +++ b/py-polars/tests/parametric/test_series.py @@ -3,8 +3,9 @@ # ------------------------------------------------- from __future__ import annotations +import hypothesis.strategies as st +import pytest from hypothesis import given, settings -from hypothesis.strategies import sampled_from import polars as pl from polars.testing import assert_series_equal @@ -27,6 +28,10 @@ def test_series_datetime_timeunits( @given( s=series(min_size=1, max_size=10, dtype=pl.Duration), ) +@pytest.mark.skip( + "These functions are currently bugged for large values: " + "https://github.com/pola-rs/polars/issues/16057" +) def test_series_duration_timeunits( s: pl.Series, ) -> None: @@ -45,7 +50,6 @@ def test_series_duration_timeunits( # special handling for ns timeunit (as we may generate a microsecs-based # timedelta that results in 64bit overflow on conversion to nanosecs) - micros = s.dt.total_microseconds().to_list() lower_bound, upper_bound = -(2**63), (2**63) - 1 if all( (lower_bound <= (us * 1000) <= upper_bound) @@ -58,9 +62,9 @@ def test_series_duration_timeunits( @given( srs=series(max_size=10, dtype=pl.Int64), - start=sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), - stop=sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), - step=sampled_from([-5, -4, -3, -2, -1, None, 1, 2, 3, 4, 5]), + start=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), + stop=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), + step=st.sampled_from([-5, -4, -3, -2, -1, None, 1, 2, 3, 4, 5]), ) @settings(max_examples=500) def test_series_slice( diff --git a/py-polars/tests/parametric/test_testing.py b/py-polars/tests/parametric/test_testing.py index b6d8a79e9ce6..98c5f83252bb 100644 --- a/py-polars/tests/parametric/test_testing.py +++ b/py-polars/tests/parametric/test_testing.py @@ -7,25 +7,15 @@ from datetime import datetime from typing import Any +import hypothesis.strategies as st import pytest from hypothesis import given, settings from hypothesis.errors import InvalidArgument, NonInteractiveExampleWarning -from hypothesis.strategies import sampled_from import polars as pl -from polars.datatypes import TEMPORAL_DTYPES -from polars.testing.parametric import ( - column, - columns, - create_list_strategy, - dataframes, - series, -) +from polars.testing.parametric import column, dataframes, lists, series -# TODO: add parametric strategy generator that supports timezones -TEMPORAL_DTYPES_ = { - tp for tp in TEMPORAL_DTYPES if getattr(tp, "time_zone", None) != "*" -} +TEMPORAL_DTYPES = {pl.Date, pl.Time, pl.Datetime, pl.Duration} @given(df=dataframes(), lf=dataframes(lazy=True), srs=series()) @@ -59,10 +49,6 @@ def test_strategy_shape( assert s1.name == "" assert s2.name == "col" - from polars.testing.parametric.primitives import MAX_COLS - - assert 0 <= len(columns(None)) <= MAX_COLS - @given( lf=dataframes( @@ -70,10 +56,10 @@ def test_strategy_shape( lazy=True, min_size=1, # test mix & match of bulk-assigned cols with custom cols - cols=columns(["a", "b"], dtype=pl.UInt8, unique=True), + cols=[column(n, dtype=pl.UInt8, unique=True) for n in ["a", "b"]], include_cols=[ column("c", 
dtype=pl.Boolean), - column("d", strategy=sampled_from(["x", "y", "z"])), + column("d", strategy=st.sampled_from(["x", "y", "z"])), ], ) ) @@ -102,11 +88,11 @@ def test_strategy_frame_columns(lf: pl.LazyFrame) -> None: @given( - df=dataframes(allowed_dtypes=TEMPORAL_DTYPES_, max_size=1, max_cols=5), - lf=dataframes(excluded_dtypes=TEMPORAL_DTYPES_, max_size=1, max_cols=5, lazy=True), + df=dataframes(allowed_dtypes=TEMPORAL_DTYPES, max_size=1, max_cols=5), + lf=dataframes(excluded_dtypes=TEMPORAL_DTYPES, max_size=1, max_cols=5, lazy=True), s1=series(dtype=pl.Boolean, max_size=1), - s2=series(allowed_dtypes=TEMPORAL_DTYPES_, max_size=1), - s3=series(excluded_dtypes=TEMPORAL_DTYPES_, max_size=1), + s2=series(allowed_dtypes=TEMPORAL_DTYPES, max_size=1), + s3=series(excluded_dtypes=TEMPORAL_DTYPES, max_size=1), ) @settings(max_examples=50) def test_strategy_dtypes( @@ -126,39 +112,38 @@ def test_strategy_dtypes( assert not s3.dtype.is_temporal() +@given(s=series()) +def test_series_null_probability_default(s: pl.Series) -> None: + assert s.null_count() == 0 + + +@given(s=series(null_probability=0.1)) +def test_series_null_probability(s: pl.Series) -> None: + assert 0 <= s.null_count() <= s.len() + + +@given(df=dataframes(cols=1, null_probability=0.3)) +def test_dataframes_null_probability_global(df: pl.DataFrame) -> None: + null_count = sum(df.null_count().row(0)) + assert 0 <= null_count <= df.height * df.width + + +@given(df=dataframes(cols=2, null_probability={"col0": 0.7})) +def test_dataframes_null_probability_column(df: pl.DataFrame) -> None: + null_count = sum(df.null_count().row(0)) + assert 0 <= null_count <= df.height * df.width + + @given( - # set global, per-column, and overridden null-probabilities - s=series(size=50, null_probability=0.10), - df1=dataframes(cols=1, size=50, null_probability=0.30), - df2=dataframes(cols=2, size=50, null_probability={"col0": 0.70}), - df3=dataframes( + df=dataframes( cols=1, - size=50, null_probability=1.0, - include_cols=[column(name="colx", null_probability=0.20)], - ), + include_cols=[column(name="colx", null_probability=0.2)], + ) ) -@settings(max_examples=50) -def test_strategy_null_probability( - s: pl.Series, - df1: pl.DataFrame, - df2: pl.DataFrame, - df3: pl.DataFrame, -) -> None: - for obj in (s, df1, df2, df3): - assert len(obj) == 50 # type: ignore[arg-type] - - assert s.null_count() < df1.null_count().fold(sum).sum() - assert df1.null_count().fold(sum).sum() < df2.null_count().fold(sum).sum() - assert df2.null_count().fold(sum).sum() < df3.null_count().fold(sum).sum() - - nulls_col0, nulls_col1 = df2.null_count().rows()[0] - assert nulls_col0 > nulls_col1 - assert nulls_col0 < 50 - - nulls_col0, nulls_colx = df3.null_count().rows()[0] - assert nulls_col0 > nulls_colx - assert nulls_col0 == 50 +def test_dataframes_null_probability_override(df: pl.DataFrame) -> None: + assert df.get_column("col0").null_count() == df.height + assert 0 <= df.get_column("col0").null_count() <= df.height @given( @@ -195,12 +180,10 @@ def test_infinities( df: pl.DataFrame, s: pl.Series, ) -> None: - from math import isfinite + from math import isfinite, isnan def finite_float(value: Any) -> bool: - if isinstance(value, float): - return isfinite(value) - return False + return isfinite(value) or isnan(value) assert all(finite_float(val) for val in s.to_list()) for col in df.columns: @@ -214,10 +197,11 @@ def finite_float(value: Any) -> bool: column("coly", dtype=pl.List(pl.Datetime("ms"))), column( name="colz", - strategy=create_list_strategy( + 
dtype=pl.List(pl.List(pl.String)), + strategy=lists( inner_dtype=pl.List(pl.String), select_from=["aa", "bb", "cc"], - min_size=1, + min_len=1, ), ), ] @@ -240,42 +224,23 @@ def test_sequence_strategies(df: pl.DataFrame) -> None: @pytest.mark.hypothesis() -def test_invalid_arguments() -> None: - for invalid_probability in (-1.0, +2.0): - with pytest.raises(InvalidArgument, match="between 0.0 and 1.0"): - column("colx", dtype=pl.Boolean, null_probability=invalid_probability) - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=NonInteractiveExampleWarning) - with pytest.raises(InvalidArgument, match="between 0.0 and 1.0"): - series(name="colx", null_probability=invalid_probability).example() - with pytest.raises(InvalidArgument, match="between 0.0 and 1.0"): - dataframes( - cols=column(None), # type: ignore[arg-type] - null_probability=invalid_probability, - ).example() - - with pytest.raises(InvalidArgument): - # TODO: add support for remaining compound types - column("colx", dtype=pl.Struct) - - with pytest.raises(InvalidArgument, match="not a valid polars datatype"): - columns(["colx", "coly"], dtype=pl.DataFrame) # type: ignore[arg-type] - - with pytest.raises(InvalidArgument, match=r"\d dtypes for \d names"): - columns(["colx", "coly"], dtype=[pl.Date, pl.Date, pl.Datetime]) +@pytest.mark.parametrize("invalid_probability", [-1.0, +2.0]) +def test_invalid_argument_null_probability(invalid_probability: float) -> None: + with pytest.raises(InvalidArgument, match="between 0.0 and 1.0"): + column("colx", dtype=pl.Boolean, null_probability=invalid_probability) - with pytest.raises(InvalidArgument, match="unable to determine dtype"): - column("colx", strategy=sampled_from([None])) - - -@given(s=series(dtype=pl.Binary)) -@settings(max_examples=5) -def test_strategy_dtype_binary(s: pl.Series) -> None: - assert s.dtype == pl.Binary + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=NonInteractiveExampleWarning) + with pytest.raises(InvalidArgument, match="between 0.0 and 1.0"): + series(name="colx", null_probability=invalid_probability).example() + with pytest.raises(InvalidArgument, match="between 0.0 and 1.0"): + dataframes( + cols=column(None), + null_probability=invalid_probability, + ).example() -@given(s=series(dtype=pl.Decimal)) -@settings(max_examples=5) -def test_strategy_dtype_decimal(s: pl.Series) -> None: - assert s.dtype == pl.Decimal +@pytest.mark.hypothesis() +def test_column_invalid_probability() -> None: + with pytest.raises(InvalidArgument): + column("col", null_probability=2.0) diff --git a/py-polars/tests/parametric/time_series/test_ewm_by.py b/py-polars/tests/parametric/time_series/test_ewm_by.py index e283a1f14c0f..6c7f1da4cf9f 100644 --- a/py-polars/tests/parametric/time_series/test_ewm_by.py +++ b/py-polars/tests/parametric/time_series/test_ewm_by.py @@ -5,7 +5,7 @@ import polars as pl from polars.testing import assert_frame_equal -from polars.testing.parametric.primitives import column, dataframes +from polars.testing.parametric import column, dataframes @given( diff --git a/py-polars/tests/parametric/time_series/test_to_datetime.py b/py-polars/tests/parametric/time_series/test_to_datetime.py index 65785c60fe86..64a4fefd7f47 100644 --- a/py-polars/tests/parametric/time_series/test_to_datetime.py +++ b/py-polars/tests/parametric/time_series/test_to_datetime.py @@ -1,12 +1,44 @@ +from __future__ import annotations + from datetime import datetime +from typing import TYPE_CHECKING import hypothesis.strategies as st from 
hypothesis import given import polars as pl from polars.exceptions import ComputeError -from polars.testing.parametric.strategies import strategy_datetime_format -from polars.type_aliases import TimeUnit + +if TYPE_CHECKING: + from hypothesis.strategies import DrawFn + + from polars.type_aliases import TimeUnit + + +@st.composite +def datetime_formats(draw: DrawFn) -> str: + """Returns a strategy which generates datetime format strings.""" + parts = [ + "%m", + "%b", + "%B", + "%d", + "%j", + "%a", + "%A", + "%w", + "%H", + "%I", + "%p", + "%M", + "%S", + "%U", + "%W", + "%%", + ] + fmt = draw(st.sets(st.sampled_from(parts))) + fmt.add("%Y") # Make sure year is always present + return " ".join(fmt) @given( @@ -14,7 +46,7 @@ min_value=datetime(1699, 1, 1), max_value=datetime(9999, 12, 31), ), - fmt=strategy_datetime_format(), + fmt=datetime_formats(), ) def test_to_datetime(datetimes: datetime, fmt: str) -> None: input = datetimes.strftime(fmt) diff --git a/py-polars/tests/unit/conftest.py b/py-polars/tests/unit/conftest.py index cc4fe80fc1e1..1b27f23cb730 100644 --- a/py-polars/tests/unit/conftest.py +++ b/py-polars/tests/unit/conftest.py @@ -12,7 +12,7 @@ import pytest import polars as pl -from polars.testing.parametric.profiles import load_profile +from polars.testing.parametric import load_profile load_profile( profile=os.environ.get("POLARS_HYPOTHESIS_PROFILE", "fast"), # type: ignore[arg-type] diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index d9fd0f519e80..f122e71c530b 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -24,7 +24,6 @@ assert_frame_not_equal, assert_series_equal, ) -from polars.testing.parametric import columns if TYPE_CHECKING: from zoneinfo import ZoneInfo @@ -61,13 +60,12 @@ def test_init_empty() -> None: def test_special_char_colname_init() -> None: from string import punctuation - with pl.StringCache(): - cols = [(c.name, c.dtype) for c in columns(punctuation)] - df = pl.DataFrame(schema=cols) + cols = [(c, pl.Int8) for c in punctuation] + df = pl.DataFrame(schema=cols) - assert len(cols) == len(df.columns) - assert len(df.rows()) == 0 - assert df.is_empty() + assert len(cols) == len(df.columns) + assert len(df.rows()) == 0 + assert df.is_empty() def test_comparisons() -> None: diff --git a/py-polars/tests/unit/interchange/test_roundtrip.py b/py-polars/tests/unit/interchange/test_roundtrip.py index 582f15f061aa..31e639e3479c 100644 --- a/py-polars/tests/unit/interchange/test_roundtrip.py +++ b/py-polars/tests/unit/interchange/test_roundtrip.py @@ -13,7 +13,7 @@ from polars.testing import assert_frame_equal from polars.testing.parametric import dataframes -protocol_dtypes = [ +integer_dtypes: list[pl.PolarsDataType] = [ pl.Int8, pl.Int16, pl.Int32, @@ -22,12 +22,15 @@ pl.UInt16, pl.UInt32, pl.UInt64, +] +protocol_dtypes: list[pl.PolarsDataType] = integer_dtypes + [ pl.Float32, pl.Float64, pl.Boolean, pl.String, pl.Datetime, - pl.Categorical, + # TODO: Enable lexically ordered categoricals + pl.Categorical("physical"), # TODO: Add Enum # pl.Enum, ] @@ -153,7 +156,9 @@ def test_from_dataframe_pandas_parametric(df: pl.DataFrame) -> None: @given( dataframes( - allowed_dtypes=protocol_dtypes, + allowed_dtypes=( + integer_dtypes + [pl.Datetime] # Smaller selection to improve performance + ), excluded_dtypes=[ pl.String, # Polars String type does not match protocol spec pl.Categorical, # Categoricals come back as Enums @@ -198,7 +203,9 @@ def 
test_from_dataframe_pandas_native_parametric(df: pl.DataFrame) -> None: @given( dataframes( - allowed_dtypes=protocol_dtypes, + allowed_dtypes=( + integer_dtypes + [pl.Datetime] # Smaller selection to improve performance + ), excluded_dtypes=[ pl.String, # Polars String type does not match protocol spec pl.Categorical, # Categoricals come back as Enums diff --git a/py-polars/tests/unit/interop/test_to_pandas.py b/py-polars/tests/unit/interop/test_to_pandas.py index 061affd14954..44cb31052061 100644 --- a/py-polars/tests/unit/interop/test_to_pandas.py +++ b/py-polars/tests/unit/interop/test_to_pandas.py @@ -3,12 +3,12 @@ from datetime import date, datetime from typing import Literal +import hypothesis.strategies as st import numpy as np import pandas as pd import pyarrow as pa import pytest from hypothesis import given -from hypothesis.strategies import just, lists, one_of import polars as pl @@ -89,8 +89,8 @@ def test_cat_to_pandas(dtype: pl.DataType) -> None: @given( - column_type_names=lists( - one_of(just("Object"), just("Int32")), min_size=1, max_size=8 + column_type_names=st.lists( + st.one_of(st.just("Object"), st.just("Int32")), min_size=1, max_size=8 ) ) def test_object_to_pandas(column_type_names: list[Literal["Object", "Int32"]]) -> None: diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index 7d5804473eb6..3ecdf33aae6a 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -7,11 +7,7 @@ import pytest import polars as pl -from polars._utils.convert import ( - MS_PER_SECOND, - NS_PER_SECOND, - US_PER_SECOND, -) +from polars._utils.constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND from polars.testing import assert_frame_equal from polars.testing.asserts.series import assert_series_equal diff --git a/py-polars/tests/unit/operations/test_clear.py b/py-polars/tests/unit/operations/test_clear.py index c9a2d29c1492..7799e7e05ce2 100644 --- a/py-polars/tests/unit/operations/test_clear.py +++ b/py-polars/tests/unit/operations/test_clear.py @@ -1,14 +1,14 @@ from __future__ import annotations +import hypothesis.strategies as st import pytest from hypothesis import given -from hypothesis.strategies import integers import polars as pl from polars.testing.parametric import series -@given(s=series(), n=integers(min_value=0, max_value=10)) +@given(s=series(), n=st.integers(min_value=0, max_value=10)) def test_clear_series_parametric(s: pl.Series, n: int) -> None: result = s.clear() diff --git a/py-polars/tests/unit/operations/test_ewm.py b/py-polars/tests/unit/operations/test_ewm.py index faf0750c689b..66e94c20aee7 100644 --- a/py-polars/tests/unit/operations/test_ewm.py +++ b/py-polars/tests/unit/operations/test_ewm.py @@ -2,10 +2,10 @@ from typing import Any +import hypothesis.strategies as st import numpy as np import pytest from hypothesis import given -from hypothesis.strategies import booleans, floats import polars as pl from polars.expr.expr import _prepare_alpha @@ -224,16 +224,16 @@ def alpha_guard(**decay_param: float) -> bool: min_size=4, dtype=pl.Float64, null_probability=0.05, - strategy=floats(min_value=-1e8, max_value=1e8), + strategy=st.floats(min_value=-1e8, max_value=1e8), ), - half_life=floats(min_value=0, max_value=4, exclude_min=True).filter( + half_life=st.floats(min_value=0, max_value=4, exclude_min=True).filter( lambda x: alpha_guard(half_life=x) ), - com=floats(min_value=0, max_value=99).filter(lambda x: alpha_guard(com=x)), - 
span=floats(min_value=1, max_value=10).filter(lambda x: alpha_guard(span=x)), - ignore_nulls=booleans(), - adjust=booleans(), - bias=booleans(), + com=st.floats(min_value=0, max_value=99).filter(lambda x: alpha_guard(com=x)), + span=st.floats(min_value=1, max_value=10).filter(lambda x: alpha_guard(span=x)), + ignore_nulls=st.booleans(), + adjust=st.booleans(), + bias=st.booleans(), ) def test_ewm_methods( s: pl.Series, diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_core.py b/py-polars/tests/unit/testing/parametric/strategies/test_core.py new file mode 100644 index 000000000000..3917197fe3e4 --- /dev/null +++ b/py-polars/tests/unit/testing/parametric/strategies/test_core.py @@ -0,0 +1,24 @@ +import hypothesis.strategies as st +from hypothesis import given, settings + +import polars as pl +from polars.testing.parametric.strategies import dtypes, series + + +@given(st.data()) +def test_dtype(data: st.DataObject) -> None: + dtype = data.draw(dtypes()) + s = data.draw(series(dtype=dtype)) + assert s.dtype == dtype + + +@given(s=series(dtype=pl.Binary)) +@settings(max_examples=5) +def test_strategy_dtype_binary(s: pl.Series) -> None: + assert s.dtype == pl.Binary + + +@given(s=series(dtype=pl.Decimal)) +@settings(max_examples=5) +def test_strategy_dtype_decimal(s: pl.Series) -> None: + assert s.dtype == pl.Decimal diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_dtype.py b/py-polars/tests/unit/testing/parametric/strategies/test_dtype.py new file mode 100644 index 000000000000..fe9ce390002f --- /dev/null +++ b/py-polars/tests/unit/testing/parametric/strategies/test_dtype.py @@ -0,0 +1,54 @@ +import hypothesis.strategies as st +from hypothesis import given + +import polars as pl +from polars.testing.parametric.strategies.dtype import dtypes + + +@given(dtype=dtypes()) +def test_dtypes(dtype: pl.DataType) -> None: + assert isinstance(dtype, pl.DataType) + + +@given(dtype=dtypes(nesting_level=0)) +def test_dtypes_nesting_level(dtype: pl.DataType) -> None: + assert not dtype.is_nested() + + +@given(st.data()) +def test_dtypes_allowed(data: st.DataObject) -> None: + allowed_dtype = data.draw(dtypes()) + result = data.draw(dtypes(allowed_dtypes=[allowed_dtype])) + assert result == allowed_dtype + + +@given(st.data()) +def test_dtypes_excluded(data: st.DataObject) -> None: + excluded_dtype = data.draw(dtypes()) + result = data.draw(dtypes(excluded_dtypes=[excluded_dtype])) + assert result != excluded_dtype + + +@given(dtype=dtypes(allowed_dtypes=[pl.Duration], excluded_dtypes=[pl.Duration("ms")])) +def test_dtypes_allowed_excluded_instance(dtype: pl.DataType) -> None: + assert isinstance(dtype, pl.Duration) + assert dtype.time_unit != "ms" + + +@given( + dtype=dtypes( + allowed_dtypes=[pl.Duration("ns"), pl.Date], excluded_dtypes=[pl.Duration] + ) +) +def test_dtypes_allowed_excluded_priority(dtype: pl.DataType) -> None: + assert dtype == pl.Date + + +@given(dtype=dtypes(allowed_dtypes=[pl.Int8(), pl.Duration("ms")])) +def test_dtypes_allowed_instantiated(dtype: pl.DataType) -> None: + assert dtype in (pl.Int8(), pl.Duration("ms")) + + +@given(dtype=dtypes(allowed_dtypes=[pl.List(pl.List), pl.Int64])) +def test_dtypes_allowed_uninstantiated_nested(dtype: pl.DataType) -> None: + assert dtype in (pl.List, pl.Int64) diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py b/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py new file mode 100644 index 000000000000..eab2574dc2d6 --- /dev/null +++ 
b/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py @@ -0,0 +1,20 @@ +import pytest +from hypothesis.errors import NonInteractiveExampleWarning + +from polars.testing.parametric import columns, create_list_strategy +from polars.testing.parametric.strategies.core import _COL_LIMIT + + +@pytest.mark.hypothesis() +def test_columns_deprecated() -> None: + with pytest.deprecated_call(), pytest.warns(NonInteractiveExampleWarning): + result = columns() + assert 0 <= len(result) <= _COL_LIMIT + + +@pytest.mark.hypothesis() +def test_create_list_strategy_deprecated() -> None: + with pytest.deprecated_call(), pytest.warns(NonInteractiveExampleWarning): + result = create_list_strategy(size=5) + with pytest.warns(NonInteractiveExampleWarning): + assert len(result.example()) == 5 diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_utils.py b/py-polars/tests/unit/testing/parametric/strategies/test_utils.py new file mode 100644 index 000000000000..f192bc410e21 --- /dev/null +++ b/py-polars/tests/unit/testing/parametric/strategies/test_utils.py @@ -0,0 +1,22 @@ +from typing import Any + +import pytest + +from polars.testing.parametric.strategies._utils import flexhash + + +@pytest.mark.parametrize( + ("left", "right"), + [ + (1, 2), + (1.0, 2.0), + ("x", "y"), + ([1, 2], [3, 4]), + ({"a": 1, "b": 2}, {"a": 1, "b": 3}), + ({"a": 1, "b": [1.0]}, {"a": 1, "b": [1.5]}), + ], +) +def test_flexhash_flat(left: Any, right: Any) -> None: + assert flexhash(left) != flexhash(right) + assert flexhash(left) == flexhash(left) + assert flexhash(right) == flexhash(right) From 2adb030d22ae1b1c03c228f9dae5e7f1d3c029ef Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 13 May 2024 07:49:02 +0200 Subject: [PATCH 17/29] perf(python): Avoid needless copy when converting chunked Series to NumPy (#16178) --- py-polars/polars/series/series.py | 32 +++++------ py-polars/src/conversion/chunked_array.rs | 12 ++-- py-polars/src/series/export.rs | 67 ++++++++++++----------- 3 files changed, 55 insertions(+), 56 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index a7109c3c21d8..e45a9301c6f6 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4409,9 +4409,10 @@ def to_numpy( if the array was created without copy, as the underlying Arrow data is immutable. use_pyarrow - Use `pyarrow.Array.to_numpy + First convert to PyArrow, then call `pyarrow.Array.to_numpy `_ - for the conversion to NumPy. + to convert to NumPy. If set to `False`, Polars' own conversion logic is + used. zero_copy_only Raise an exception if the conversion to a NumPy would require copying the underlying data. 
Data copy occurs, for example, when the Series contains @@ -4438,18 +4439,20 @@ def to_numpy( ) allow_copy = not zero_copy_only - def raise_on_copy() -> None: - if not allow_copy and not self.is_empty(): + if ( + use_pyarrow + and _PYARROW_AVAILABLE + and self.dtype not in (Object, Datetime, Duration, Date, Array) + ): + if not allow_copy and self.n_chunks() > 1 and not self.is_empty(): msg = "cannot return a zero-copy array" raise ValueError(msg) - if self.n_chunks() > 1: - raise_on_copy() - self = self.rechunk() - - dtype = self.dtype + return self.to_arrow().to_numpy( + zero_copy_only=not allow_copy, writable=writable + ) - if dtype == Array: + if self.dtype == Array: np_array = self.explode().to_numpy( allow_copy=allow_copy, writable=writable, @@ -4458,15 +4461,6 @@ def raise_on_copy() -> None: np_array.shape = (self.len(), self.dtype.width) # type: ignore[attr-defined] return np_array - if ( - use_pyarrow - and _PYARROW_AVAILABLE - and dtype not in (Object, Datetime, Duration, Date) - ): - return self.to_arrow().to_numpy( - zero_copy_only=not allow_copy, writable=writable - ) - return self._s.to_numpy(allow_copy=allow_copy, writable=writable) def to_torch(self) -> torch.Tensor: diff --git a/py-polars/src/conversion/chunked_array.rs b/py-polars/src/conversion/chunked_array.rs index 35b5a4427a5f..455b3a0fd525 100644 --- a/py-polars/src/conversion/chunked_array.rs +++ b/py-polars/src/conversion/chunked_array.rs @@ -8,7 +8,7 @@ use crate::py_modules::UTILS; impl ToPyObject for Wrap<&StringChunked> { fn to_object(&self, py: Python) -> PyObject { - let iter = self.0.into_iter(); + let iter = self.0.iter(); PyList::new_bound(py, iter).into_py(py) } } @@ -17,7 +17,7 @@ impl ToPyObject for Wrap<&BinaryChunked> { fn to_object(&self, py: Python) -> PyObject { let iter = self .0 - .into_iter() + .iter() .map(|opt_bytes| opt_bytes.map(|bytes| PyBytes::new_bound(py, bytes))); PyList::new_bound(py, iter).into_py(py) } @@ -48,7 +48,7 @@ impl ToPyObject for Wrap<&DurationChunked> { let time_unit = self.0.time_unit().to_ascii(); let iter = self .0 - .into_iter() + .iter() .map(|opt_v| opt_v.map(|v| convert.call1((v, time_unit)).unwrap())); PyList::new_bound(py, iter).into_py(py) } @@ -62,7 +62,7 @@ impl ToPyObject for Wrap<&DatetimeChunked> { let time_zone = self.0.time_zone().to_object(py); let iter = self .0 - .into_iter() + .iter() .map(|opt_v| opt_v.map(|v| convert.call1((v, time_unit, &time_zone)).unwrap())); PyList::new_bound(py, iter).into_py(py) } @@ -81,7 +81,7 @@ pub(crate) fn time_to_pyobject_iter<'a>( ) -> impl ExactSizeIterator>> { let utils = UTILS.bind(py); let convert = utils.getattr(intern!(py, "to_py_time")).unwrap().clone(); - ca.0.into_iter() + ca.0.iter() .map(move |opt_v| opt_v.map(|v| convert.call1((v,)).unwrap())) } @@ -113,7 +113,7 @@ pub(crate) fn decimal_to_pyobject_iter<'a>( let py_scale = (-(ca.scale() as i32)).to_object(py); // if we don't know precision, the only safe bet is to set it to 39 let py_precision = ca.precision().unwrap_or(39).to_object(py); - ca.into_iter().map(move |opt_v| { + ca.iter().map(move |opt_v| { opt_v.map(|v| { // TODO! use AnyValue so that we have a single impl. const N: usize = 3; diff --git a/py-polars/src/series/export.rs b/py-polars/src/series/export.rs index de4fe0d4e0c8..c59c6129d02b 100644 --- a/py-polars/src/series/export.rs +++ b/py-polars/src/series/export.rs @@ -167,12 +167,14 @@ impl PySeries { /// is required. Set `writable` to make sure the resulting array is writable, possibly requiring /// copying the data. 
fn to_numpy(&self, py: Python, allow_copy: bool, writable: bool) -> PyResult { - let is_empty = self.series.is_empty(); - - if self.series.null_count() == 0 { + if self.series.is_empty() { + // Take this path to ensure a writable array. + // This does not actually copy for empty Series. + return series_to_numpy_with_copy(py, &self.series); + } else if self.series.null_count() == 0 { if let Some(mut arr) = self.to_numpy_view(py) { - if writable || is_empty { - if !allow_copy && !is_empty { + if writable { + if !allow_copy { return Err(PyValueError::new_err( "cannot return a zero-copy writable array", )); @@ -183,7 +185,7 @@ impl PySeries { } } - if !allow_copy & !is_empty { + if !allow_copy { return Err(PyValueError::new_err("cannot return a zero-copy array")); } @@ -239,30 +241,28 @@ fn series_to_numpy_with_copy(py: Python, s: &Series) -> PyResult { }, Time => { let ca = s.time().unwrap(); - let iter = time_to_pyobject_iter(py, ca); - let np_arr = PyArray1::from_iter_bound(py, iter.map(|v| v.into_py(py))); - np_arr.into_py(py) + let values = time_to_pyobject_iter(py, ca).map(|v| v.into_py(py)); + PyArray1::from_iter_bound(py, values).into_py(py) }, String => { let ca = s.str().unwrap(); - let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(|s| s.into_py(py))); - np_arr.into_py(py) + let values = ca.iter().map(|s| s.into_py(py)); + PyArray1::from_iter_bound(py, values).into_py(py) }, Binary => { let ca = s.binary().unwrap(); - let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(|s| s.into_py(py))); - np_arr.into_py(py) + let values = ca.iter().map(|s| s.into_py(py)); + PyArray1::from_iter_bound(py, values).into_py(py) }, Categorical(_, _) | Enum(_, _) => { let ca = s.categorical().unwrap(); - let np_arr = PyArray1::from_iter_bound(py, ca.iter_str().map(|s| s.into_py(py))); - np_arr.into_py(py) + let values = ca.iter_str().map(|s| s.into_py(py)); + PyArray1::from_iter_bound(py, values).into_py(py) }, Decimal(_, _) => { let ca = s.decimal().unwrap(); - let iter = decimal_to_pyobject_iter(py, ca); - let np_arr = PyArray1::from_iter_bound(py, iter.map(|v| v.into_py(py))); - np_arr.into_py(py) + let values = decimal_to_pyobject_iter(py, ca).map(|v| v.into_py(py)); + PyArray1::from_iter_bound(py, values).into_py(py) }, #[cfg(feature = "object")] Object(_, _) => { @@ -270,14 +270,13 @@ fn series_to_numpy_with_copy(py: Python, s: &Series) -> PyResult { .as_any() .downcast_ref::>() .unwrap(); - let np_arr = - PyArray1::from_iter_bound(py, ca.into_iter().map(|opt_v| opt_v.to_object(py))); - np_arr.into_py(py) + let values = ca.iter().map(|v| v.to_object(py)); + PyArray1::from_iter_bound(py, values).into_py(py) }, Null => { let n = s.len(); - let np_arr = PyArray1::from_iter_bound(py, std::iter::repeat(f32::NAN).take(n)); - np_arr.into_py(py) + let values = std::iter::repeat(f32::NAN).take(n); + PyArray1::from_iter_bound(py, values).into_py(py) }, dt => { raise_err!( @@ -293,15 +292,21 @@ fn series_to_numpy_with_copy(py: Python, s: &Series) -> PyResult { fn numeric_series_to_numpy(py: Python, s: &Series) -> PyObject where T: PolarsNumericType, + T::Native: numpy::Element, U: Float + numpy::Element, { let ca: &ChunkedArray = s.as_ref().as_ref(); - let mapper = |opt_v: Option| match opt_v { - Some(v) => NumCast::from(v).unwrap(), - None => U::nan(), - }; - let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(mapper)); - np_arr.into_py(py) + if s.null_count() == 0 { + let values = ca.into_no_null_iter(); + PyArray1::::from_iter_bound(py, values).into_py(py) + } else { + let mapper = |opt_v: 
Option| match opt_v { + Some(v) => NumCast::from(v).unwrap(), + None => U::nan(), + }; + let values = ca.iter().map(mapper); + PyArray1::from_iter_bound(py, values).into_py(py) + } } /// Convert booleans to u8 if no nulls are present, otherwise convert to objects. fn boolean_series_to_numpy(py: Python, s: &Series) -> PyObject { @@ -344,6 +349,6 @@ where { let s_phys = s.to_physical_repr(); let ca = s_phys.i64().unwrap(); - let iter = ca.iter().map(|v| v.unwrap_or(i64::MIN).into()); - PyArray1::::from_iter_bound(py, iter).into_py(py) + let values = ca.iter().map(|v| v.unwrap_or(i64::MIN).into()); + PyArray1::::from_iter_bound(py, values).into_py(py) } From db18aa9059db3cc4e74e53b02363ddec86a4ac61 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Mon, 13 May 2024 10:48:17 +0400 Subject: [PATCH 18/29] feat: Add SQL support for `GROUP BY ALL` syntax and fix several issues with aliased group keys (#16179) Co-authored-by: ritchie --- crates/polars-plan/src/dsl/expr.rs | 8 +- .../src/dsl/functions/selectors.rs | 4 + .../src/logical_plan/expr_expansion.rs | 19 ++- crates/polars-plan/src/utils.rs | 2 +- crates/polars-sql/src/context.rs | 140 ++++++++++++------ crates/polars-sql/tests/functions_string.rs | 48 ++---- crates/polars-sql/tests/simple_exprs.rs | 55 ++++++- py-polars/tests/unit/sql/test_group_by.py | 129 ++++++++++++++++ 8 files changed, 304 insertions(+), 101 deletions(-) diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index fa447f51914a..778a4efeaa0f 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -67,9 +67,9 @@ impl AsRef for AggExpr { #[must_use] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum Expr { - Alias(Arc, Arc), - Column(Arc), - Columns(Vec), + Alias(Arc, ColumnName), + Column(ColumnName), + Columns(Arc<[ColumnName]>), DtypeColumn(Vec), Literal(LiteralValue), BinaryExpr { @@ -291,7 +291,7 @@ impl Default for Expr { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum Excluded { - Name(Arc), + Name(ColumnName), Dtype(DataType), } diff --git a/crates/polars-plan/src/dsl/functions/selectors.rs b/crates/polars-plan/src/dsl/functions/selectors.rs index 554c1fb37341..617b6497435e 100644 --- a/crates/polars-plan/src/dsl/functions/selectors.rs +++ b/crates/polars-plan/src/dsl/functions/selectors.rs @@ -39,6 +39,10 @@ pub fn all() -> Expr { /// Select multiple columns by name. pub fn cols>(names: I) -> Expr { let names = names.into_vec(); + let names = names + .into_iter() + .map(|v| ColumnName::from(v.as_str())) + .collect(); Expr::Columns(names) } diff --git a/crates/polars-plan/src/logical_plan/expr_expansion.rs b/crates/polars-plan/src/logical_plan/expr_expansion.rs index ed1f7924d87b..f221ab19d7ad 100644 --- a/crates/polars-plan/src/logical_plan/expr_expansion.rs +++ b/crates/polars-plan/src/logical_plan/expr_expansion.rs @@ -157,16 +157,15 @@ fn replace_regex( fn expand_columns( expr: &Expr, result: &mut Vec, - names: &[String], + names: &[ColumnName], schema: &Schema, - exclude: &PlHashSet>, + exclude: &PlHashSet, ) -> PolarsResult<()> { let mut is_valid = true; for name in names { - if !exclude.contains(name.as_str()) { + if !exclude.contains(name) { let new_expr = expr.clone(); - let (new_expr, new_expr_valid) = - replace_columns_with_column(new_expr, names, name.as_str()); + let (new_expr, new_expr_valid) = replace_columns_with_column(new_expr, names, name); is_valid &= new_expr_valid; // we may have regex col in columns. 
#[allow(clippy::collapsible_else_if)] @@ -233,15 +232,15 @@ fn struct_index_to_field(expr: Expr, schema: &Schema) -> PolarsResult { /// expression chain. pub(super) fn replace_columns_with_column( mut expr: Expr, - names: &[String], - column_name: &str, + names: &[ColumnName], + column_name: &ColumnName, ) -> (Expr, bool) { let mut is_valid = true; expr = expr.map_expr(|e| match e { Expr::Columns(members) => { // `col([a, b]) + col([c, d])` - if members == names { - Expr::Column(ColumnName::from(column_name)) + if members.as_ref() == names { + Expr::Column(column_name.clone()) } else { is_valid = false; Expr::Columns(members) @@ -586,7 +585,7 @@ fn replace_selector(expr: Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult< let Expr::Column(name) = e else { unreachable!() }; - name.to_string() + name }) .collect(), )) diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index 056e0de10b3d..24764f1910f4 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -106,7 +106,7 @@ pub fn has_aexpr_literal(current_node: Node, arena: &Arena) -> bool { /// Can check if an expression tree has a matching_expr. This /// requires a dummy expression to be created that will be used to pattern match against. -pub(crate) fn has_expr(current_expr: &Expr, matches: F) -> bool +pub fn has_expr(current_expr: &Expr, matches: F) -> bool where F: Fn(&Expr) -> bool, { diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index e594ce5b9e0c..cc90dd1e10fe 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -1,5 +1,4 @@ use std::cell::RefCell; -use std::collections::BTreeSet; use polars_core::prelude::*; use polars_error::to_compute_err; @@ -8,8 +7,8 @@ use polars_plan::prelude::*; use sqlparser::ast::{ Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, - SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, Value as SQLValue, - WildcardAdditionalOptions, + SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, + Value as SQLValue, WildcardAdditionalOptions, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; @@ -407,32 +406,64 @@ impl SQLContext { }) .collect::>()?; - // Check for group by (after projections as there may be ordinal/position ints). - let group_by_keys: Vec; - if let GroupByExpr::Expressions(group_by_exprs) = &select_stmt.group_by { - group_by_keys = group_by_exprs.iter() - .map(|e| match e { - SQLExpr::Value(SQLValue::Number(idx, _)) => { - let idx = match idx.parse::() { - Ok(0) | Err(_) => Err(polars_err!( + // Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints). 
+ let mut group_by_keys: Vec = Vec::new(); + match &select_stmt.group_by { + // Standard "GROUP BY x, y, z" syntax + GroupByExpr::Expressions(group_by_exprs) => { + group_by_keys = group_by_exprs + .iter() + .map(|e| match e { + SQLExpr::UnaryOp { + op: UnaryOperator::Minus, + expr, + } if matches!(**expr, SQLExpr::Value(SQLValue::Number(_, _))) => { + if let SQLExpr::Value(SQLValue::Number(ref idx, _)) = **expr { + Err(polars_err!( ComputeError: - "group_by error: a positive number or an expression expected, got {}", + "group_by error: expected a positive integer or valid expression; got -{}", idx - )), - Ok(idx) => Ok(idx), - }?; - // note: sql queries are 1-indexed - Ok(projections[idx - 1].clone()) + )) + } else { + unreachable!() + } + }, + SQLExpr::Value(SQLValue::Number(idx, _)) => { + // note: sql queries are 1-indexed + let idx = idx.parse::().unwrap(); + Ok(projections[idx - 1].clone()) + }, + SQLExpr::Value(v) => Err(polars_err!( + ComputeError: + "group_by error: expected a positive integer or valid expression; got {}", v, + )), + _ => parse_sql_expr(e, self, schema.as_deref()), + }) + .collect::>()? + }, + // "GROUP BY ALL" syntax; automatically adds expressions that do not contain + // nested agg/window funcs to the group key (also ignores literals). + GroupByExpr::All => { + projections.iter().for_each(|expr| match expr { + // immediately match the most common cases (col|agg|lit, optionally aliased). + Expr::Agg(_) | Expr::Literal(_) => (), + Expr::Column(_) => group_by_keys.push(expr.clone()), + Expr::Alias(e, _) if matches!(&**e, Expr::Agg(_) | Expr::Literal(_)) => (), + Expr::Alias(e, _) if matches!(&**e, Expr::Column(_)) => { + if let Expr::Column(name) = &**e { + group_by_keys.push(col(name)); + } }, - SQLExpr::Value(_) => Err(polars_err!( - ComputeError: - "group_by error: a positive number or an expression expected", - )), - _ => parse_sql_expr(e, self, schema.as_deref()), - }) - .collect::>()? - } else { - polars_bail!(ComputeError: "not implemented"); + _ => { + // If not quick-matched, add if no nested agg/window expressions + if !has_expr(expr, |e| { + matches!(e, Expr::Agg(_)) || matches!(e, Expr::Window { .. 
}) + }) { + group_by_keys.push(expr.clone()) + } + }, + }); + }, }; lf = if group_by_keys.is_empty() { @@ -441,28 +472,27 @@ impl SQLContext { } else if !contains_wildcard { let schema = lf.schema()?; let mut column_names = schema.get_names(); - let mut retained_names: BTreeSet = BTreeSet::new(); + let mut retained_names = PlHashSet::new(); projections.iter().for_each(|expr| match expr { Expr::Alias(_, name) => { - retained_names.insert((name).to_string()); + retained_names.insert(name.clone()); }, Expr::Column(name) => { - retained_names.insert((name).to_string()); + retained_names.insert(name.clone()); }, Expr::Columns(names) => names.iter().for_each(|name| { - retained_names.insert((name).to_string()); + retained_names.insert(name.clone()); }), Expr::Exclude(inner_expr, excludes) => { if let Expr::Columns(names) = (*inner_expr).as_ref() { names.iter().for_each(|name| { - retained_names.insert((name).to_string()); + retained_names.insert(name.clone()); }) } - excludes.iter().for_each(|excluded| { if let Excluded::Name(name) = excluded { - retained_names.remove(&(name.to_string())); + retained_names.remove(name); } }); }, @@ -476,7 +506,6 @@ impl SQLContext { lf.drop(column_names) } else if contains_wildcard_exclude { let mut dropped_names = Vec::with_capacity(projections.len()); - let exclude_expr = projections.iter().find(|expr| { if let Expr::Exclude(_, excludes) = expr { for excluded in excludes.iter() { @@ -489,7 +518,6 @@ impl SQLContext { false } }); - if exclude_expr.is_some() { lf = lf.with_columns(projections); lf = self.process_order_by(lf, &query.order_by)?; @@ -692,8 +720,6 @@ impl SQLContext { group_by_keys: &[Expr], projections: &[Expr], ) -> PolarsResult { - // Check group_by and projection due to difference between SQL and polars. - // Return error on wild card, shouldn't process this. polars_ensure!( !contains_wildcard, ComputeError: "group_by error: can't process wildcard in group_by" @@ -704,30 +730,51 @@ impl SQLContext { // Remove the group_by keys as polars adds those implicitly. let mut aggregation_projection = Vec::with_capacity(projections.len()); - let mut aliases: BTreeSet<&str> = BTreeSet::new(); + let mut projection_aliases = PlHashSet::new(); + let mut group_key_aliases = PlHashSet::new(); for mut e in projections { // If simple aliased expression we defer aliasing until after the group_by. - if e.clone().meta().is_simple_projection() { - if let Expr::Alias(expr, name) = e { - aliases.insert(name); + let is_agg_or_window = has_expr(e, |e| matches!(e, Expr::Agg(_) | Expr::Window { .. 
})); + if let Expr::Alias(expr, alias) = e { + if e.clone().meta().is_simple_projection() { + group_key_aliases.insert(alias.as_ref()); e = expr + } else if !is_agg_or_window && !group_by_keys_schema.contains(alias) { + projection_aliases.insert(alias.as_ref()); } } let field = e.to_field(&schema_before, Context::Default)?; - if group_by_keys_schema.get(&field.name).is_none() { - aggregation_projection.push(e.clone()) + if group_by_keys_schema.get(&field.name).is_none() && is_agg_or_window { + let mut e = e.clone(); + if let Expr::Agg(AggExpr::Implode(expr)) = &e { + e = (**expr).clone(); + } else if let Expr::Alias(expr, name) = &e { + if let Expr::Agg(AggExpr::Implode(expr)) = expr.as_ref() { + e = (**expr).clone().alias(name.as_ref()); + } + } + aggregation_projection.push(e); + } else if let Expr::Column(_) = e { + // Non-aggregated columns must be part of the GROUP BY clause + if !group_by_keys_schema.contains(&field.name) { + polars_bail!(ComputeError: "'{}' should participate in the GROUP BY clause or an aggregate function", &field.name); + } } } + let aggregated = lf.group_by(group_by_keys).agg(&aggregation_projection); let projection_schema = expressions_to_schema(projections, &schema_before, Context::Default)?; - // A final projection to get the proper order. + // A final projection to get the proper order and any deferred transforms/aliases. let final_projection = projection_schema .iter_names() .zip(projections) .map(|(name, projection_expr)| { - if group_by_keys_schema.get(name).is_some() || aliases.contains(name.as_str()) { + if group_by_keys_schema.get(name).is_some() + || projection_aliases.contains(name.as_str()) + || group_key_aliases.contains(name.as_str()) + { projection_expr.clone() } else { col(name) @@ -735,7 +782,6 @@ impl SQLContext { }) .collect::>(); - let aggregated = lf.group_by(group_by_keys).agg(&aggregation_projection); Ok(aggregated.select(&final_projection)) } @@ -817,7 +863,7 @@ impl SQLContext { contains_wildcard_exclude: &mut bool, ) -> PolarsResult { if options.opt_except.is_some() { - polars_bail!(InvalidOperation: "EXCEPT not supported. Use EXCLUDE instead") + polars_bail!(InvalidOperation: "EXCEPT not supported; use EXCLUDE instead") } Ok(match &options.opt_exclude { Some(ExcludeSelectItem::Single(ident)) => { diff --git a/crates/polars-sql/tests/functions_string.rs b/crates/polars-sql/tests/functions_string.rs index a1e56ea55134..952d3a6484f0 100644 --- a/crates/polars-sql/tests/functions_string.rs +++ b/crates/polars-sql/tests/functions_string.rs @@ -89,41 +89,23 @@ fn test_array_to_string() { "b" => &[1, 1, 42], } .unwrap(); + let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); - let sql = context - .execute( - r#" - SELECT - b, - a - FROM df - GROUP BY - b"#, - ) - .unwrap(); - context.register("df_1", sql.clone()); + let sql = r#" - SELECT - b, - array_to_string(a, ', ') as as, - FROM df_1 - ORDER BY - b, - as"#; + SELECT b, ARRAY_TO_STRING(a,', ') AS a2s, + FROM ( + SELECT b, ARRAY_AGG(a) + FROM df + GROUP BY b + ) tbl + ORDER BY a2s"#; let df_sql = context.execute(sql).unwrap().collect().unwrap(); - - let df_pl = df - .lazy() - .group_by([col("b")]) - .agg([col("a")]) - .select(&[col("b"), col("a").list().join(lit(", "), true).alias("as")]) - .sort_by_exprs( - vec![col("b"), col("as")], - SortMultipleOptions::default().with_maintain_order(true), - ) - .collect() - .unwrap(); - - assert!(df_sql.equals_missing(&df_pl)); + let df_expected = df! 
{ + "b" => &[1, 42], + "a2s" => &["first, first", "third"], + } + .unwrap(); + assert!(df_sql.equals(&df_expected)); } diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 9a1338adf5fe..e24f6351bd34 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -62,30 +62,73 @@ fn test_group_by_simple() -> PolarsResult<()> { let df_sql = context .execute( r#" - SELECT a, sum(b) as b , sum(a + b) as c, count(a) as total_count + SELECT + a AS "aa", + SUM(b) AS "bb", + SUM(a + b) AS "cc", + COUNT(a) AS "total_count" FROM df GROUP BY a LIMIT 100 "#, )? - .sort(["a"], Default::default()) + .sort(["aa"], Default::default()) .collect()?; let df_pl = df .lazy() - .group_by(&[col("a")]) + .group_by(&[col("a").alias("aa")]) .agg(&[ - col("b").sum().alias("b"), - (col("a") + col("b")).sum().alias("c"), + col("b").sum().alias("bb"), + (col("a") + col("b")).sum().alias("cc"), col("a").count().alias("total_count"), ]) .limit(100) - .sort(["a"], Default::default()) + .sort(["aa"], Default::default()) .collect()?; assert_eq!(df_sql, df_pl); Ok(()) } +#[test] +fn test_group_by_expression_key() -> PolarsResult<()> { + let df = df! { + "a" => &["xx", "yy", "xx", "yy", "xx", "zz"], + "b" => &[1, 2, 3, 4, 5, 6], + "c" => &[99, 99, 66, 66, 66, 66], + } + .unwrap(); + + let mut context = SQLContext::new(); + context.register("df", df.clone().lazy()); + + // check how we handle grouping by a key that gets used in select transform + let df_sql = context + .execute( + r#" + SELECT + CASE WHEN a = 'zz' THEN 'xx' ELSE a END AS grp, + SUM(b) AS sum_b, + SUM(c) AS sum_c, + FROM df + GROUP BY a + ORDER BY sum_c + "#, + )? + .sort(["sum_c"], Default::default()) + .collect()?; + + let df_expected = df! 
{ + "grp" => ["xx", "yy", "xx"], + "sum_b" => [6, 6, 9], + "sum_c" => [66, 165, 231], + } + .unwrap(); + + assert_eq!(df_sql, df_expected); + Ok(()) +} + #[test] fn test_cast_exprs() { let df = create_sample_df().unwrap(); diff --git a/py-polars/tests/unit/sql/test_group_by.py b/py-polars/tests/unit/sql/test_group_by.py index 03f15778e605..37f242e86a95 100644 --- a/py-polars/tests/unit/sql/test_group_by.py +++ b/py-polars/tests/unit/sql/test_group_by.py @@ -1,10 +1,12 @@ from __future__ import annotations +from datetime import date from pathlib import Path import pytest import polars as pl +from polars.exceptions import ComputeError from polars.testing import assert_frame_equal @@ -65,6 +67,105 @@ def test_group_by(foods_ipc_path: Path) -> None: assert out.to_dict(as_series=False) == {"grp": ["c"], "n_dist_attr": [2]} +def test_group_by_all() -> None: + df = pl.DataFrame( + { + "a": ["xx", "yy", "xx", "yy", "xx", "zz"], + "b": [1, 2, 3, 4, 5, 6], + "c": [99, 99, 66, 66, 66, 66], + } + ) + + # basic group/agg + res = df.sql( + """ + SELECT + a, + SUM(b), + SUM(c) + FROM self + GROUP BY ALL + ORDER BY a + """ + ) + expected = pl.DataFrame( + { + "a": ["xx", "yy", "zz"], + "b": [9, 6, 6], + "c": [231, 165, 66], + } + ) + assert_frame_equal(expected, res) + + # more involved determination of agg/group columns + res = df.sql( + """ + SELECT + SUM(b) AS sum_b, + SUM(c) AS sum_c, + (SUM(b) + SUM(c)) / 2.0 AS sum_bc_over_2, -- nested agg + a as grp, --aliased group key + FROM self + GROUP BY ALL + ORDER BY grp + """ + ) + expected = pl.DataFrame( + { + "sum_b": [9, 6, 6], + "sum_c": [231, 165, 66], + "sum_bc_over_2": [120.0, 85.5, 36.0], + "grp": ["xx", "yy", "zz"], + } + ) + assert_frame_equal(expected, res.sort(by="grp")) + + +def test_group_by_all_multi() -> None: + dt1 = date(1999, 12, 31) + dt2 = date(2028, 7, 5) + + df = pl.DataFrame( + { + "key": ["xx", "yy", "xx", "yy", "xx", "xx"], + "dt": [dt1, dt1, dt1, dt2, dt2, dt2], + "value": [10.5, -5.5, 20.5, 8.0, -3.0, 5.0], + } + ) + expected = pl.DataFrame( + { + "dt": [dt1, dt1, dt2, dt2], + "key": ["xx", "yy", "xx", "yy"], + "sum_value": [31.0, -5.5, 2.0, 8.0], + "ninety_nine": [99, 99, 99, 99], + }, + schema_overrides={"ninety_nine": pl.Int16}, + ) + + # the following groupings should all be equivalent + for group in ( + "ALL", + "1, 2", + "dt, key", + ): + res = df.sql( + f""" + SELECT dt, key, sum_value, ninety_nine::int2 FROM + ( + SELECT + dt, + key, + SUM(value) AS sum_value, + 99 AS ninety_nine + FROM self + GROUP BY {group} + ORDER BY dt, key + ) AS grp + """ + ) + assert_frame_equal(expected, res) + + def test_group_by_ordinal_position() -> None: df = pl.DataFrame( { @@ -96,3 +197,31 @@ def test_group_by_ordinal_position() -> None: SELECT c, total_b FROM grp ORDER BY c""" ) assert_frame_equal(res2, expected) + + +def test_group_by_errors() -> None: + df = pl.DataFrame( + { + "a": ["xx", "yy", "xx"], + "b": [10, 20, 30], + "c": [99, 99, 66], + } + ) + + with pytest.raises( + ComputeError, + match=r"expected a positive integer or valid expression; got -99", + ): + df.sql("SELECT a, SUM(b) FROM self GROUP BY -99, a") + + with pytest.raises( + ComputeError, + match=r"expected a positive integer or valid expression; got '!!!'", + ): + df.sql("SELECT a, SUM(b) FROM self GROUP BY a, '!!!'") + + with pytest.raises( + ComputeError, + match=r"'a' should participate in the GROUP BY clause or an aggregate function", + ): + df.sql("SELECT a, SUM(b) FROM self GROUP BY b") From 3fef569fca9d57ce6a599acbb66f94f2cde95378 Mon Sep 17 00:00:00 2001 From: 
Ritchie Vink Date: Mon, 13 May 2024 10:45:38 +0200 Subject: [PATCH 19/29] fix: Fix panic on empty frame joins (#16181) --- crates/polars-core/src/series/implementations/null.rs | 6 +++++- py-polars/tests/unit/operations/test_join.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index 8869855ba453..bf4afebdc52b 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -35,7 +35,11 @@ impl NullChunked { } } } -impl PrivateSeriesNumeric for NullChunked {} +impl PrivateSeriesNumeric for NullChunked { + fn bit_repr_small(&self) -> UInt32Chunked { + UInt32Chunked::full_null(self.name.as_ref(), self.len()) + } +} impl PrivateSeries for NullChunked { fn compute_len(&mut self) { diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index f06a40f8d103..e6a2faa316eb 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -990,3 +990,13 @@ def test_join_coalesce(how: str) -> None: out = q.collect() assert q.schema == out.schema assert out.columns == ["a", "b", "c"] + + +@pytest.mark.parametrize("how", ["left", "inner", "outer"]) +@typing.no_type_check +def test_join_empties(how: str) -> None: + df1 = pl.DataFrame({"col1": [], "col2": [], "col3": []}) + df2 = pl.DataFrame({"col2": [], "col4": [], "col5": []}) + + df = df1.join(df2, on="col2", how=how) + assert df.height == 0 From 77aaec3272ce6e46d7e658b131edfb5a688c8be1 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 13 May 2024 10:46:00 +0200 Subject: [PATCH 20/29] depr(python): Deprecate `allow_infinities` and `null_probability` args to parametric test strategies (#16183) --- py-polars/docs/source/reference/testing.rst | 2 +- .../testing/parametric/strategies/core.py | 189 ++++++++++++------ .../testing/parametric/strategies/data.py | 53 +++-- py-polars/tests/parametric/test_dataframe.py | 8 +- py-polars/tests/parametric/test_lazyframe.py | 6 +- py-polars/tests/parametric/test_testing.py | 67 +++---- .../tests/unit/operations/test_drop_nulls.py | 2 +- py-polars/tests/unit/operations/test_ewm.py | 2 +- .../tests/unit/operations/test_is_null.py | 2 +- .../unit/series/buffers/test_from_buffers.py | 2 +- .../parametric/strategies/test_data.py | 21 ++ 11 files changed, 234 insertions(+), 120 deletions(-) create mode 100644 py-polars/tests/unit/testing/parametric/strategies/test_data.py diff --git a/py-polars/docs/source/reference/testing.rst b/py-polars/docs/source/reference/testing.rst index 6cdc5ddbba78..1a84a5228263 100644 --- a/py-polars/docs/source/reference/testing.rst +++ b/py-polars/docs/source/reference/testing.rst @@ -122,7 +122,7 @@ of any generated value being ``null`` (this is distinct from ``NaN``). 
@given( dataframes( cols=5, - null_probability=0.1, + allow_null=True, allowed_dtypes=NUMERIC_DTYPES, ) ) diff --git a/py-polars/polars/testing/parametric/strategies/core.py b/py-polars/polars/testing/parametric/strategies/core.py index 08e751a18d74..3657d465e802 100644 --- a/py-polars/polars/testing/parametric/strategies/core.py +++ b/py-polars/polars/testing/parametric/strategies/core.py @@ -1,14 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Collection, Sequence, overload +from typing import TYPE_CHECKING, Any, Collection, Mapping, Sequence, overload import hypothesis.strategies as st from hypothesis.errors import InvalidArgument -import polars.functions as F +from polars._utils.deprecation import issue_deprecation_warning from polars.dataframe import DataFrame -from polars.datatypes import Boolean, DataType, DataTypeClass +from polars.datatypes import DataType, DataTypeClass from polars.series import Series from polars.string_cache import StringCache from polars.testing.parametric.strategies._utils import flexhash @@ -31,6 +31,7 @@ @st.composite def series( # noqa: D417 draw: DrawFn, + /, *, name: str | SearchStrategy[str] | None = None, dtype: PolarsDataType | None = None, @@ -38,12 +39,12 @@ def series( # noqa: D417 min_size: int = 0, max_size: int = _ROW_LIMIT, strategy: SearchStrategy[Any] | None = None, - null_probability: float = 0.0, - allow_infinities: bool = True, + allow_null: bool = False, unique: bool = False, chunked: bool | None = None, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + **kwargs: Any, ) -> Series: """ Hypothesis strategy for producing polars Series. @@ -66,12 +67,8 @@ def series( # noqa: D417 MAX_DATA_SIZE). no-op if `size` is set. strategy : strategy, optional supports overriding the default strategy for the given dtype. - null_probability : float - percentage chance (expressed between 0.0 => 1.0) that a generated value is - None. this is applied independently of any None values generated by the - underlying strategy. - allow_infinities : bool, optional - optionally disallow generation of +/-inf values for floating-point dtypes. + allow_null : bool + Allow nulls as possible values. unique : bool, optional indicate whether Series values should all be distinct. chunked : bool, optional @@ -81,6 +78,23 @@ def series( # noqa: D417 when automatically generating Series data, allow only these dtypes. excluded_dtypes : {list,set}, optional when automatically generating Series data, exclude these dtypes. + **kwargs + Additional keyword arguments that are passed to the underlying data generation + strategies. + + null_probability : float + Percentage chance (expressed between 0.0 => 1.0) that any Series value is null. + This is applied independently of any None values generated by the underlying + strategy. + + .. deprecated:: 0.20.26 + Use `allow_null` instead. + + allow_infinities : bool, optional + Allow generation of +/-inf values for floating-point dtypes. + + .. deprecated:: 0.20.26 + Use `allow_infinity` instead. 
Notes ----- @@ -115,11 +129,14 @@ def series( # noqa: D417 ["xx"] ] """ - if not (0.0 <= null_probability <= 1.0): - msg = ( - f"`null_probability` should be between 0.0 and 1.0, got {null_probability}" + if (null_prob := kwargs.pop("null_probability", None)) is not None: + allow_null = _handle_null_probability_deprecation(null_prob) # type: ignore[assignment] + if (allow_inf := kwargs.pop("allow_infinities", None)) is not None: + issue_deprecation_warning( + "`allow_infinities` is deprecated. Use `allow_infinity` instead.", + version="0.20.26", ) - raise InvalidArgument(msg) + kwargs["allow_infinity"] = allow_inf if isinstance(allowed_dtypes, (DataType, DataTypeClass)): allowed_dtypes = [allowed_dtypes] @@ -152,12 +169,14 @@ def series( # noqa: D417 if size == 0: values = [] - elif null_probability == 1.0: - values = [None] * size else: # Create series using dtype-specific strategy to generate values if strategy is None: - strategy = data(dtype, allow_infinity=allow_infinities) # type: ignore[arg-type] + strategy = data( + dtype, # type: ignore[arg-type] + allow_null=allow_null, + **kwargs, + ) values = draw( st.lists( @@ -170,12 +189,6 @@ def series( # noqa: D417 s = Series(name=name, values=values, dtype=dtype) - # Set null values - if 0.0 < null_probability < 1.0: - random = draw(st.randoms(use_true_random=True)) - validity = [random.random() > null_probability for _ in range(size)] - s = F.select(F.when(Series(validity, dtype=Boolean)).then(s)).to_series() - # Apply chunking if size > 1: if chunked is None: @@ -200,10 +213,10 @@ def dataframes( max_size: int = _ROW_LIMIT, chunked: bool | None = None, include_cols: Sequence[column] | column | None = None, - null_probability: float | dict[str, float] = 0.0, - allow_infinities: bool = True, + allow_null: bool | Mapping[str, bool] = False, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + **kwargs: Any, ) -> SearchStrategy[DataFrame]: ... @@ -219,16 +232,17 @@ def dataframes( max_size: int = _ROW_LIMIT, chunked: bool | None = None, include_cols: Sequence[column] | column | None = None, - null_probability: float | dict[str, float] = 0.0, - allow_infinities: bool = True, + allow_null: bool | Mapping[str, bool] = False, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + **kwargs: Any, ) -> SearchStrategy[LazyFrame]: ... @st.composite def dataframes( # noqa: D417 draw: DrawFn, + /, cols: int | column | Sequence[column] | None = None, *, lazy: bool = False, @@ -239,10 +253,10 @@ def dataframes( # noqa: D417 max_size: int = _ROW_LIMIT, chunked: bool | None = None, include_cols: Sequence[column] | column | None = None, - null_probability: float | dict[str, float] = 0.0, - allow_infinities: bool = True, + allow_null: bool | Mapping[str, bool] = False, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + **kwargs: Any, ) -> DataFrame | LazyFrame: """ Hypothesis strategy for producing polars DataFrames or LazyFrames. @@ -275,18 +289,31 @@ def dataframes( # noqa: D417 a list of `column` objects to include in the generated DataFrame. note that explicitly provided columns are appended onto the list of existing columns (if any present). + allow_null : bool or Mapping[str, bool] + Allow nulls as possible values. 
+ allowed_dtypes : {list,set}, optional + when automatically generating data, allow only these dtypes. + excluded_dtypes : {list,set}, optional + when automatically generating data, exclude these dtypes. + **kwargs + Additional keyword arguments that are passed to the underlying data generation + strategies. + null_probability : {float, dict[str,float]}, optional percentage chance (expressed between 0.0 => 1.0) that a generated value is None. this is applied independently of any None values generated by the underlying strategy, and can be applied either on a per-column basis (if given as a `{col:pct}` dict), or globally. if null_probability is defined on a column, it takes precedence over the global value. + + .. deprecated:: 0.20.26 + Use `allow_null` instead. + allow_infinities : bool, optional optionally disallow generation of +/-inf values for floating-point dtypes. - allowed_dtypes : {list,set}, optional - when automatically generating data, allow only these dtypes. - excluded_dtypes : {list,set}, optional - when automatically generating data, exclude these dtypes. + + .. deprecated:: 0.20.26 + Use `allow_infinity` instead. Notes ----- @@ -343,6 +370,15 @@ def dataframes( # noqa: D417 │ 575050513 ┆ NaN │ └───────────┴────────────┘ """ + if (null_prob := kwargs.pop("null_probability", None)) is not None: + allow_null = _handle_null_probability_deprecation(null_prob) + if (allow_inf := kwargs.pop("allow_infinities", None)) is not None: + issue_deprecation_warning( + "`allow_infinities` is deprecated. Use `allow_infinity` instead.", + version="0.20.26", + ) + kwargs["allow_infinity"] = allow_inf + if isinstance(include_cols, column): include_cols = [include_cols] @@ -362,15 +398,15 @@ def dataframes( # noqa: D417 if size is None: size = draw(st.integers(min_value=min_size, max_value=max_size)) - # assign names, null probability + # Process columns for idx, c in enumerate(cols): if c.name is None: c.name = f"col{idx}" - if c.null_probability is None: - if isinstance(null_probability, dict): - c.null_probability = null_probability.get(c.name, 0.0) + if c.allow_null is None: + if isinstance(allow_null, Mapping): + c.allow_null = allow_null.get(c.name, False) else: - c.null_probability = null_probability + c.allow_null = allow_null # init dataframe from generated series data; series data is # given as a python-native sequence. @@ -381,13 +417,13 @@ def dataframes( # noqa: D417 name=c.name, dtype=c.dtype, size=size, - null_probability=c.null_probability, # type: ignore[arg-type] - allow_infinities=allow_infinities, + allow_null=c.allow_null, # type: ignore[arg-type] strategy=c.strategy, unique=c.unique, chunked=None if chunked is None else False, allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes, + **kwargs, ) ) for c in cols @@ -395,7 +431,7 @@ def dataframes( # noqa: D417 df = DataFrame(data) - # optionally generate chunked frames + # Optionally generate chunked frames if size > 1 and chunked: split_at = size // 2 df = df[:split_at].vstack(df[split_at:]) @@ -409,7 +445,7 @@ def dataframes( # noqa: D417 @dataclass class column: """ - Define a column for use with the @dataframes strategy. + Define a column for use with the `dataframes` strategy. Parameters ---------- @@ -419,33 +455,72 @@ class column: a polars dtype. strategy : strategy, optional supports overriding the default strategy for the given dtype. + allow_null : bool, optional + Allow nulls as possible values. + unique : bool, optional + flag indicating that all values generated for the column should be unique. 
+ null_probability : float, optional percentage chance (expressed between 0.0 => 1.0) that a generated value is None. this is applied independently of any None values generated by the underlying strategy. - unique : bool, optional - flag indicating that all values generated for the column should be unique. + + .. deprecated:: 0.20.26 + Use `allow_null` instead. Examples -------- >>> from polars.testing.parametric import column - >>> column(name="unique_small_ints", dtype=pl.UInt8, unique=True) - column(name='unique_small_ints', dtype=UInt8, strategy=None, null_probability=None, unique=True) - - >>> from hypothesis.strategies import sampled_from - >>> column(name="ccy", strategy=sampled_from(["GBP", "EUR", "JPY"])) - column(name='ccy', dtype=None, strategy=sampled_from(['GBP', 'EUR', 'JPY']), null_probability=None, unique=False) - """ # noqa: W505 + >>> dfs = dataframes( + ... [ + ... column("x", dtype=pl.Int32, allow_null=True), + ... column("y", dtype=pl.Float64), + ... ], + ... size=2, + ... ) + >>> dfs.example() # doctest: +SKIP + shape: (2, 2) + ┌───────────┬────────────┐ + │ x ┆ y │ + │ --- ┆ --- │ + │ i32 ┆ f64 │ + ╞═══════════╪════════════╡ + │ null ┆ 1.1755e-38 │ + │ 575050513 ┆ inf │ + └───────────┴────────────┘ + """ name: str | None = None dtype: PolarsDataType | None = None strategy: SearchStrategy[Any] | None = None - null_probability: float | None = None + allow_null: bool | None = None unique: bool = False + null_probability: float | None = None + def __post_init__(self) -> None: - if (self.null_probability is not None) and ( - self.null_probability < 0 or self.null_probability > 1 - ): - msg = f"`null_probability` should be between 0.0 and 1.0, or None; found {self.null_probability!r}" + if self.null_probability is not None: + self.allow_null = _handle_null_probability_deprecation( # type: ignore[assignment] + self.null_probability + ) + + +def _handle_null_probability_deprecation( + null_probability: float | Mapping[str, float], +) -> bool | dict[str, bool]: + issue_deprecation_warning( + "`null_probability` is deprecated. 
Use `allow_null` instead.",
+        version="0.20.26",
+    )
+
+    def prob_to_bool(prob: float) -> bool:
+        if not (0.0 <= prob <= 1.0):
+            msg = f"`null_probability` should be between 0.0 and 1.0, got {prob!r}"
             raise InvalidArgument(msg)
+
+        return bool(prob)
+
+    if isinstance(null_probability, Mapping):
+        return {col: prob_to_bool(prob) for col, prob in null_probability.items()}
+    else:
+        return prob_to_bool(null_probability)
diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py
index ad1373689c2c..78dc119c8929 100644
--- a/py-polars/polars/testing/parametric/strategies/data.py
+++ b/py-polars/polars/testing/parametric/strategies/data.py
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import decimal
-import string
 from datetime import timedelta
 from typing import TYPE_CHECKING, Any, Literal, Sequence
 
@@ -61,6 +60,7 @@
     from polars.type_aliases import PolarsDataType, TimeUnit
 
 _DEFAULT_LIST_LEN_LIMIT = 3
+_DEFAULT_N_CATEGORIES = 10
 
 _INTEGER_STRATEGIES: dict[bool, dict[int, SearchStrategy[int]]] = {
     True: {
@@ -108,9 +108,17 @@ def binary() -> SearchStrategy[bytes]:
     return st.binary()
 
 
-def categories() -> SearchStrategy[str]:
-    """Create a strategy for generating category strings."""
-    return st.text(alphabet=string.ascii_uppercase, min_size=1, max_size=2)
+def categories(n_categories: int = _DEFAULT_N_CATEGORIES) -> SearchStrategy[str]:
+    """
+    Create a strategy for generating category strings.
+
+    Parameters
+    ----------
+    n_categories
+        The number of categories.
+    """
+    categories = [f"c{i}" for i in range(n_categories)]
+    return st.sampled_from(categories)
 
 
 def times() -> SearchStrategy[time]:
@@ -284,14 +292,15 @@ def nulls() -> SearchStrategy[None]:
     UInt64: integers(64, signed=False),
     Time: times(),
     Date: dates(),
-    Categorical: categories(),
     String: strings(),
     Binary: binary(),
     Null: nulls(),
 }
 
 
-def data(dtype: PolarsDataType, **kwargs: Any) -> SearchStrategy[Any]:
+def data(
+    dtype: PolarsDataType, *, allow_null: bool = False, **kwargs: Any
+) -> SearchStrategy[Any]:
     """
     Create a strategy for generating data for the given data type.
 
@@ -300,30 +309,37 @@ def data(dtype: PolarsDataType, **kwargs: Any) -> SearchStrategy[Any]:
     dtype
         A Polars data type. If the data type is not fully instantiated, defaults will
         be used, e.g. `Datetime` will become `Datetime('us')`.
+    allow_null
+        Allow nulls as possible values.
     **kwargs
         Additional parameters for the strategy associated with the given `dtype`.
""" if (strategy := _STATIC_STRATEGIES.get(dtype.base_type())) is not None: - return strategy - - if dtype == Float32: - return floats(32, allow_infinity=kwargs.pop("allow_infinity", True)) + strategy = strategy + elif dtype == Float32: + strategy = floats(32, allow_infinity=kwargs.pop("allow_infinity", True)) elif dtype == Float64: - return floats(64, allow_infinity=kwargs.pop("allow_infinity", True)) + strategy = floats(64, allow_infinity=kwargs.pop("allow_infinity", True)) elif dtype == Datetime: # TODO: Handle time zones - return datetimes(time_unit=getattr(dtype, "time_unit", None) or "us") + strategy = datetimes(time_unit=getattr(dtype, "time_unit", None) or "us") elif dtype == Duration: - return durations(time_unit=getattr(dtype, "time_unit", None) or "us") + strategy = durations(time_unit=getattr(dtype, "time_unit", None) or "us") + elif dtype == Categorical: + strategy = categories( + n_categories=kwargs.pop("n_categories", _DEFAULT_N_CATEGORIES) + ) elif dtype == Decimal: - return decimals(getattr(dtype, "precision", None), getattr(dtype, "scale", 0)) + strategy = decimals( + getattr(dtype, "precision", None), getattr(dtype, "scale", 0) + ) elif dtype == List: inner = getattr(dtype, "inner", None) or Null() - return lists(inner, **kwargs) + strategy = lists(inner, **kwargs) elif dtype == Array: inner = getattr(dtype, "inner", None) or Null() width = getattr(dtype, "width", _DEFAULT_ARRAY_WIDTH_LIMIT) - return lists( + strategy = lists( inner, min_len=width, max_len=width, @@ -332,3 +348,8 @@ def data(dtype: PolarsDataType, **kwargs: Any) -> SearchStrategy[Any]: else: msg = f"unsupported data type: {dtype}" raise InvalidArgument(msg) + + if allow_null: + strategy = nulls() | strategy + + return strategy diff --git a/py-polars/tests/parametric/test_dataframe.py b/py-polars/tests/parametric/test_dataframe.py index 5a782d1042bc..a69c5880407e 100644 --- a/py-polars/tests/parametric/test_dataframe.py +++ b/py-polars/tests/parametric/test_dataframe.py @@ -40,7 +40,7 @@ def test_dtype_integer_cols(df: pl.DataFrame) -> None: df=dataframes( min_size=1, min_cols=1, - null_probability=0.25, + allow_null=True, excluded_dtypes=[pl.String, pl.List], ) ) @@ -64,19 +64,19 @@ def test_null_count(df: pl.DataFrame) -> None: column( "start", dtype=pl.Int8, - null_probability=0.15, + allow_null=True, strategy=st.integers(min_value=-8, max_value=8), ), column( "stop", dtype=pl.Int8, - null_probability=0.15, + allow_null=True, strategy=st.integers(min_value=-6, max_value=6), ), column( "step", dtype=pl.Int8, - null_probability=0.15, + allow_null=True, strategy=st.integers(min_value=-4, max_value=4).filter( lambda x: x != 0 ), diff --git a/py-polars/tests/parametric/test_lazyframe.py b/py-polars/tests/parametric/test_lazyframe.py index 24420ce7fb16..a40a5fbc3d22 100644 --- a/py-polars/tests/parametric/test_lazyframe.py +++ b/py-polars/tests/parametric/test_lazyframe.py @@ -16,19 +16,19 @@ column( "start", dtype=pl.Int8, - null_probability=0.3, + allow_null=True, strategy=st.integers(min_value=-3, max_value=4), ), column( "stop", dtype=pl.Int8, - null_probability=0.3, + allow_null=True, strategy=st.integers(min_value=-2, max_value=6), ), column( "step", dtype=pl.Int8, - null_probability=0.3, + allow_null=True, strategy=st.integers(min_value=-3, max_value=3).filter( lambda x: x != 0 ), diff --git a/py-polars/tests/parametric/test_testing.py b/py-polars/tests/parametric/test_testing.py index 98c5f83252bb..053cbf262e92 100644 --- a/py-polars/tests/parametric/test_testing.py +++ 
b/py-polars/tests/parametric/test_testing.py @@ -3,14 +3,13 @@ # ------------------------------------------------ from __future__ import annotations -import warnings from datetime import datetime from typing import Any import hypothesis.strategies as st import pytest from hypothesis import given, settings -from hypothesis.errors import InvalidArgument, NonInteractiveExampleWarning +from hypothesis.errors import InvalidArgument import polars as pl from polars.testing.parametric import column, dataframes, lists, series @@ -113,23 +112,23 @@ def test_strategy_dtypes( @given(s=series()) -def test_series_null_probability_default(s: pl.Series) -> None: +def test_series_allow_null_default(s: pl.Series) -> None: assert s.null_count() == 0 -@given(s=series(null_probability=0.1)) -def test_series_null_probability(s: pl.Series) -> None: +@given(s=series(allow_null=True)) +def test_series_allow_null(s: pl.Series) -> None: assert 0 <= s.null_count() <= s.len() -@given(df=dataframes(cols=1, null_probability=0.3)) -def test_dataframes_null_probability_global(df: pl.DataFrame) -> None: +@given(df=dataframes(cols=1, allow_null=True)) +def test_dataframes_allow_null_global(df: pl.DataFrame) -> None: null_count = sum(df.null_count().row(0)) assert 0 <= null_count <= df.height * df.width -@given(df=dataframes(cols=2, null_probability={"col0": 0.7})) -def test_dataframes_null_probability_column(df: pl.DataFrame) -> None: +@given(df=dataframes(cols=2, allow_null={"col0": True})) +def test_dataframes_allow_null_column(df: pl.DataFrame) -> None: null_count = sum(df.null_count().row(0)) assert 0 <= null_count <= df.height * df.width @@ -137,13 +136,13 @@ def test_dataframes_null_probability_column(df: pl.DataFrame) -> None: @given( df=dataframes( cols=1, - null_probability=1.0, - include_cols=[column(name="colx", null_probability=0.2)], + allow_null=False, + include_cols=[column(name="colx", allow_null=True)], ) ) -def test_dataframes_null_probability_override(df: pl.DataFrame) -> None: - assert df.get_column("col0").null_count() == df.height - assert 0 <= df.get_column("col0").null_count() <= df.height +def test_dataframes_allow_null_override(df: pl.DataFrame) -> None: + assert df.get_column("col0").null_count() == 0 + assert 0 <= df.get_column("colx").null_count() <= df.height @given( @@ -170,11 +169,9 @@ def test_chunking( @given( df=dataframes( - allowed_dtypes=[pl.Float32, pl.Float64], - allow_infinities=False, - max_cols=4, + allowed_dtypes=[pl.Float32, pl.Float64], max_cols=4, allow_infinity=False ), - s=series(dtype=pl.Float64, allow_infinities=False), + s=series(dtype=pl.Float64, allow_infinity=False), ) def test_infinities( df: pl.DataFrame, @@ -224,23 +221,23 @@ def test_sequence_strategies(df: pl.DataFrame) -> None: @pytest.mark.hypothesis() -@pytest.mark.parametrize("invalid_probability", [-1.0, +2.0]) -def test_invalid_argument_null_probability(invalid_probability: float) -> None: - with pytest.raises(InvalidArgument, match="between 0.0 and 1.0"): - column("colx", dtype=pl.Boolean, null_probability=invalid_probability) - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=NonInteractiveExampleWarning) - with pytest.raises(InvalidArgument, match="between 0.0 and 1.0"): - series(name="colx", null_probability=invalid_probability).example() - with pytest.raises(InvalidArgument, match="between 0.0 and 1.0"): - dataframes( - cols=column(None), - null_probability=invalid_probability, - ).example() +def test_column_invalid_probability() -> None: + with pytest.deprecated_call(), 
pytest.raises(InvalidArgument): + column("col", null_probability=2.0) @pytest.mark.hypothesis() -def test_column_invalid_probability() -> None: - with pytest.raises(InvalidArgument): - column("col", null_probability=2.0) +def test_column_null_probability_deprecated() -> None: + with pytest.deprecated_call(): + col = column("col", allow_null=False, null_probability=0.5) + assert col.null_probability == 0.5 + assert col.allow_null is True # null_probability takes precedence + + +@given(st.data()) +def test_allow_infinities_deprecated(data: st.DataObject) -> None: + with pytest.deprecated_call(): + strategy = series(dtype=pl.Float64, allow_infinities=False) + s = data.draw(strategy) + + assert all(v not in (float("inf"), float("-inf")) for v in s) diff --git a/py-polars/tests/unit/operations/test_drop_nulls.py b/py-polars/tests/unit/operations/test_drop_nulls.py index 1ca966f8314a..4250ecad154e 100644 --- a/py-polars/tests/unit/operations/test_drop_nulls.py +++ b/py-polars/tests/unit/operations/test_drop_nulls.py @@ -7,7 +7,7 @@ from polars.testing.parametric import series -@given(s=series(null_probability=0.5)) +@given(s=series(allow_null=True)) def test_drop_nulls_parametric(s: pl.Series) -> None: result = s.drop_nulls() assert result.len() == s.len() - s.null_count() diff --git a/py-polars/tests/unit/operations/test_ewm.py b/py-polars/tests/unit/operations/test_ewm.py index 66e94c20aee7..57b2e32b10d3 100644 --- a/py-polars/tests/unit/operations/test_ewm.py +++ b/py-polars/tests/unit/operations/test_ewm.py @@ -223,7 +223,7 @@ def alpha_guard(**decay_param: float) -> bool: s=series( min_size=4, dtype=pl.Float64, - null_probability=0.05, + allow_null=True, strategy=st.floats(min_value=-1e8, max_value=1e8), ), half_life=st.floats(min_value=0, max_value=4, exclude_min=True).filter( diff --git a/py-polars/tests/unit/operations/test_is_null.py b/py-polars/tests/unit/operations/test_is_null.py index 7e1a53fa04c9..ec58ca68629e 100644 --- a/py-polars/tests/unit/operations/test_is_null.py +++ b/py-polars/tests/unit/operations/test_is_null.py @@ -7,7 +7,7 @@ from polars.testing.parametric import series -@given(s=series(null_probability=0.5)) +@given(s=series(allow_null=True)) def test_is_null_parametric(s: pl.Series) -> None: is_null = s.is_null() is_not_null = s.is_not_null() diff --git a/py-polars/tests/unit/series/buffers/test_from_buffers.py b/py-polars/tests/unit/series/buffers/test_from_buffers.py index 83ee12086b36..497591e94a5c 100644 --- a/py-polars/tests/unit/series/buffers/test_from_buffers.py +++ b/py-polars/tests/unit/series/buffers/test_from_buffers.py @@ -37,7 +37,7 @@ def test_series_from_buffers_numeric_with_validity(s: pl.Series) -> None: s=series( allowed_dtypes=(pl.INTEGER_DTYPES | pl.FLOAT_DTYPES | {pl.Boolean}), chunked=False, - null_probability=0.0, + allow_null=False, ) ) def test_series_from_buffers_numeric(s: pl.Series) -> None: diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_data.py b/py-polars/tests/unit/testing/parametric/strategies/test_data.py new file mode 100644 index 000000000000..0820015158dc --- /dev/null +++ b/py-polars/tests/unit/testing/parametric/strategies/test_data.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from hypothesis import given + +import polars as pl +from polars.testing.parametric.strategies.data import categories, data + + +@given(cat=categories(3)) +def test_categories(cat: str) -> None: + assert cat in ("c0", "c1", "c2") + + +@given(cat=data(pl.Categorical, n_categories=3)) +def test_data_kwargs(cat: str) -> 
None: + assert cat in ("c0", "c1", "c2") + + +@given(categories=data(pl.List(pl.Categorical), n_categories=3)) +def test_data_nested_kwargs(categories: list[str]) -> None: + assert all(c in ("c0", "c1", "c2") for c in categories) From 5a990ffaba4916d4027791b93048b399238d73fa Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 13 May 2024 10:46:09 +0200 Subject: [PATCH 21/29] =?UTF-8?q?feat:=20Raise=20when=20encountering=20inv?= =?UTF-8?q?alid=20supertype=20in=20functions=20during=20c=E2=80=A6=20(#161?= =?UTF-8?q?82)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/dsl/function_expr/fill_null.rs | 14 +---------- .../src/logical_plan/conversion/mod.rs | 1 + .../type_coercion/binary.rs | 0 .../type_coercion/mod.rs | 18 +++++++------ .../src/logical_plan/optimizer/mod.rs | 3 +-- crates/polars-utils/src/macros.rs | 25 +++++++++++++++++++ py-polars/tests/unit/test_errors.py | 8 ++++++ 7 files changed, 47 insertions(+), 22 deletions(-) rename crates/polars-plan/src/logical_plan/{optimizer => conversion}/type_coercion/binary.rs (100%) rename crates/polars-plan/src/logical_plan/{optimizer => conversion}/type_coercion/mod.rs (97%) diff --git a/crates/polars-plan/src/dsl/function_expr/fill_null.rs b/crates/polars-plan/src/dsl/function_expr/fill_null.rs index 96629e40c994..d5e408c0082d 100644 --- a/crates/polars-plan/src/dsl/function_expr/fill_null.rs +++ b/crates/polars-plan/src/dsl/function_expr/fill_null.rs @@ -4,19 +4,7 @@ pub(super) fn fill_null(s: &[Series]) -> PolarsResult { let series = s[0].clone(); let fill_value = s[1].clone(); - // let (series, fill_value) = if matches!(super_type, DataType::Unknown(_)) { - // let fill_value = fill_value.cast(series.dtype()).map_err(|_| { - // polars_err!( - // SchemaMismatch: - // "`fill_null` supertype could not be determined; set correct literal value or \ - // ensure the type of the expression is known" - // ) - // })?; - // (series.clone(), fill_value) - // } else { - // (series.cast(super_type)?, fill_value.cast(super_type)?) 
- // }; - // nothing to fill, so return early + // Nothing to fill, so return early // this is done after casting as the output type must be correct if series.null_count() == 0 { return Ok(series); diff --git a/crates/polars-plan/src/logical_plan/conversion/mod.rs b/crates/polars-plan/src/logical_plan/conversion/mod.rs index 62f9c06ae66f..230c2fd2b4e3 100644 --- a/crates/polars-plan/src/logical_plan/conversion/mod.rs +++ b/crates/polars-plan/src/logical_plan/conversion/mod.rs @@ -14,6 +14,7 @@ pub use ir_to_dsl::*; use polars_core::prelude::*; use polars_utils::vec::ConvertVec; use recursive::recursive; +pub(crate) mod type_coercion; use crate::constants::get_len_name; use crate::prelude::*; diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs b/crates/polars-plan/src/logical_plan/conversion/type_coercion/binary.rs similarity index 100% rename from crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs rename to crates/polars-plan/src/logical_plan/conversion/type_coercion/binary.rs diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/conversion/type_coercion/mod.rs similarity index 97% rename from crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs rename to crates/polars-plan/src/logical_plan/conversion/type_coercion/mod.rs index d38d58b027ef..b86fb13f2254 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/conversion/type_coercion/mod.rs @@ -3,13 +3,13 @@ mod binary; use std::borrow::Cow; use arrow::legacy::utils::CustomIterTools; +use binary::process_binary; use polars_core::prelude::*; use polars_core::utils::{get_supertype, materialize_dyn_int}; use polars_utils::idx_vec::UnitVec; -use polars_utils::unitvec; +use polars_utils::{format_list, unitvec}; use super::*; -use crate::logical_plan::optimizer::type_coercion::binary::process_binary; pub struct TypeCoercionRule {} @@ -345,6 +345,8 @@ impl OptimizationRule for TypeCoercionRule { for e in input { let (_, dtype) = unpack!(get_aexpr_and_type(expr_arena, e.node(), &input_schema)); + // Ignore Unknown in the inputs. + // We will raise if we cannot find the supertype later. 
match dtype { DataType::Unknown(UnknownKind::Any) => { options.cast_to_supertypes = false; @@ -369,11 +371,9 @@ impl OptimizationRule for TypeCoercionRule { let (other, type_other) = unpack!(get_aexpr_and_type(expr_arena, other.node(), &input_schema)); - // early return until Unknown is set - if matches!(type_other, DataType::Unknown(UnknownKind::Any)) { - return Ok(None); - } - let new_st = unpack!(get_supertype(&super_type, &type_other)); + let Some(new_st) = get_supertype(&super_type, &type_other) else { + polars_bail!(InvalidOperation: "could not determine supertype of: {}", format_list!(dtypes)); + }; if input.len() == 2 { // modify_supertype is a bit more conservative of casting columns // to literals @@ -385,6 +385,10 @@ impl OptimizationRule for TypeCoercionRule { } } + if matches!(super_type, DataType::Unknown(UnknownKind::Any)) { + polars_bail!(InvalidOperation: "could not determine supertype of: {}", format_list!(dtypes)); + } + let function = function.clone(); let input = input.clone(); diff --git a/crates/polars-plan/src/logical_plan/optimizer/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/mod.rs index e0c9e6dd2118..f980781166b2 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/mod.rs @@ -20,7 +20,6 @@ mod simplify_functions; mod slice_pushdown_expr; mod slice_pushdown_lp; mod stack_opt; -mod type_coercion; use collapse_and_project::SimpleProjectionAndCollapse; use delay_rechunk::DelayRechunk; @@ -31,10 +30,10 @@ pub use projection_pushdown::ProjectionPushDown; pub use simplify_expr::{SimplifyBooleanRule, SimplifyExprRule}; use slice_pushdown_lp::SlicePushDown; pub use stack_opt::{OptimizationRule, StackOptimizer}; -pub use type_coercion::TypeCoercionRule; use self::flatten_union::FlattenUnionRule; pub use crate::frame::{AllowedOptimizations, OptState}; +pub use crate::logical_plan::conversion::type_coercion::TypeCoercionRule; use crate::logical_plan::optimizer::count_star::CountStar; #[cfg(feature = "cse")] use crate::logical_plan::optimizer::cse::prune_unused_caches; diff --git a/crates/polars-utils/src/macros.rs b/crates/polars-utils/src/macros.rs index 264a7f5a148e..00d16315e0f8 100644 --- a/crates/polars-utils/src/macros.rs +++ b/crates/polars-utils/src/macros.rs @@ -16,3 +16,28 @@ macro_rules! unreachable_unchecked_release { } }; } + +#[macro_export] +macro_rules! 
format_list { + ($e:expr) => {{ + use std::fmt::Write; + let mut out = String::new(); + out.push('['); + let mut iter = $e.into_iter(); + let mut next = iter.next(); + + loop { + if let Some(val) = next { + write!(out, "{val}").unwrap(); + }; + next = iter.next(); + if next.is_some() { + out.push_str(", ") + } else { + break; + } + } + out.push_str("]\n"); + out + };}; +} diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index dc8a6e7b98de..b06cf74d2abc 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -704,3 +704,11 @@ def test_invalid_product_type() -> None: match="`product` operation not supported for dtype", ): pl.Series([[1, 2, 3]]).product() + + +def test_fill_null_invalid_supertype() -> None: + df = pl.DataFrame({"date": [date(2022, 1, 1), None]}) + with pytest.raises( + pl.InvalidOperationError, match="could not determine supertype of" + ): + df.select(pl.col("date").fill_null(1.0)) From 3a630eca62fbb40fcfb7f0581701c08680d1a3b2 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Mon, 13 May 2024 10:25:22 +0100 Subject: [PATCH 22/29] fix: offset=-0i was being treated differently to offset=0i in rolling (#16184) --- crates/polars-time/src/windows/group_by.rs | 4 ++-- py-polars/tests/parametric/test_groupby_rolling.py | 2 +- py-polars/tests/unit/operations/test_rolling.py | 12 ++++++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index 7b48db38c8e6..1754cde91865 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs @@ -576,7 +576,7 @@ pub fn group_by_values( let run_parallel = !POOL.current_thread_has_pending_tasks().unwrap_or(false); // we have a (partial) lookbehind window - if offset.negative { + if offset.negative && offset.duration_ns() > 0 { // lookbehind if offset.duration_ns() == period.duration_ns() { // t is right at the end of the window @@ -647,7 +647,7 @@ pub fn group_by_values( iter.map(|result| result.map(|(offset, len)| [offset, len])) .collect::>() } - } else if offset != Duration::parse("0ns") + } else if offset.duration_ns() != 0 || closed_window == ClosedWindow::Right || closed_window == ClosedWindow::None { diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index e048eba378e5..8597634008f4 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -29,7 +29,7 @@ def interval_defs() -> SearchStrategy[ClosedInterval]: min_value=timedelta(microseconds=0), max_value=timedelta(days=1000) ).map(parse_as_duration_string), offset=st.timedeltas( - min_value=timedelta(microseconds=0), max_value=timedelta(days=1000) + min_value=timedelta(days=-1000), max_value=timedelta(days=1000) ).map(parse_as_duration_string), closed=interval_defs(), data=st.data(), diff --git a/py-polars/tests/unit/operations/test_rolling.py b/py-polars/tests/unit/operations/test_rolling.py index 89899d20ad46..6f6e2b24e7a4 100644 --- a/py-polars/tests/unit/operations/test_rolling.py +++ b/py-polars/tests/unit/operations/test_rolling.py @@ -321,3 +321,15 @@ def test_multiple_rolling_in_single_expression() -> None: front_count.alias("front"), (back_count - front_count).alias("back - front"), )["back - front"].to_list() == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5] + + +def test_negative_zero_offset_16168() -> None: + df = 
pl.DataFrame({"foo": [1] * 3}).sort("foo").with_row_index() + result = df.rolling(index_column="foo", period="1i", offset="0i").agg("index") + expected = pl.DataFrame( + {"foo": [1, 1, 1], "index": [[], [], []]}, + schema_overrides={"index": pl.List(pl.UInt32)}, + ) + assert_frame_equal(result, expected) + result = df.rolling(index_column="foo", period="1i", offset="-0i").agg("index") + assert_frame_equal(result, expected) From 0654b7d109a92c690e8973e281f819790506a4b6 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 13 May 2024 12:39:52 +0200 Subject: [PATCH 23/29] test(python): Move hypothesis tests into unit test module (#16185) --- py-polars/tests/parametric/conftest.py | 7 - .../tests/parametric/test_groupby_rolling.py | 170 ----------- py-polars/tests/parametric/test_lit.py | 27 -- py-polars/tests/parametric/test_series.py | 83 ------ py-polars/tests/parametric/test_testing.py | 243 ---------------- .../time_series/test_add_business_days.py | 51 ---- .../time_series/test_business_day_count.py | 44 --- .../dataframe/test_getitem.py} | 52 +--- .../tests/unit/dataframe/test_null_count.py | 27 ++ py-polars/tests/unit/dataframe/test_repr.py | 15 + .../{business => }/test_business_day_count.py | 41 +++ .../functions}/test_ewm_by.py | 0 py-polars/tests/unit/functions/test_lit.py | 23 ++ .../lazyframe/test_getitem.py} | 2 +- .../{ => operations}/namespaces/__init__.py | 0 .../operations/namespaces/array}/__init__.py | 0 .../namespaces/array/test_array.py | 0 .../namespaces/array/test_contains.py | 0 .../namespaces/array/test_to_list.py | 0 .../{ => operations}/namespaces/conftest.py | 0 .../namespaces/files/test_tree_fmt.txt | 0 .../namespaces/list}/__init__.py | 0 .../namespaces/list/test_list.py | 0 .../namespaces/list/test_set_operations.py | 0 .../namespaces/string}/__init__.py | 0 .../namespaces/string/test_concat.py | 0 .../namespaces/string/test_pad.py | 0 .../namespaces/string/test_string.py | 0 .../namespaces/temporal}/__init__.py | 0 .../temporal}/test_add_business_days.py | 44 ++- .../namespaces/temporal}/test_datetime.py | 50 ++++ .../namespaces/temporal}/test_to_datetime.py | 0 .../namespaces/temporal}/test_truncate.py | 0 .../namespaces/test_binary.py | 0 .../namespaces/test_categorical.py | 0 .../{ => operations}/namespaces/test_meta.py | 0 .../{ => operations}/namespaces/test_name.py | 0 .../{ => operations}/namespaces/test_plot.py | 0 .../namespaces/test_strptime.py | 0 .../namespaces/test_struct.py | 0 .../unit/operations/rolling/test_rolling.py | 158 ++++++++++ py-polars/tests/unit/series/test_getitem.py | 30 ++ .../parametric/strategies/test_core.py | 273 +++++++++++++++++- .../unit/testing/test_assert_frame_equal.py | 7 + 44 files changed, 661 insertions(+), 686 deletions(-) delete mode 100644 py-polars/tests/parametric/conftest.py delete mode 100644 py-polars/tests/parametric/test_groupby_rolling.py delete mode 100644 py-polars/tests/parametric/test_lit.py delete mode 100644 py-polars/tests/parametric/test_series.py delete mode 100644 py-polars/tests/parametric/test_testing.py delete mode 100644 py-polars/tests/parametric/time_series/test_add_business_days.py delete mode 100644 py-polars/tests/parametric/time_series/test_business_day_count.py rename py-polars/tests/{parametric/test_dataframe.py => unit/dataframe/test_getitem.py} (60%) create mode 100644 py-polars/tests/unit/dataframe/test_null_count.py create mode 100644 py-polars/tests/unit/dataframe/test_repr.py rename py-polars/tests/unit/functions/{business => }/test_business_day_count.py (77%) rename 
py-polars/tests/{parametric/time_series => unit/functions}/test_ewm_by.py (100%) rename py-polars/tests/{parametric/test_lazyframe.py => unit/lazyframe/test_getitem.py} (97%) rename py-polars/tests/unit/{ => operations}/namespaces/__init__.py (100%) rename py-polars/tests/{parametric => unit/operations/namespaces/array}/__init__.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/array/test_array.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/array/test_contains.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/array/test_to_list.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/conftest.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/files/test_tree_fmt.txt (100%) rename py-polars/tests/unit/{namespaces/array => operations/namespaces/list}/__init__.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/list/test_list.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/list/test_set_operations.py (100%) rename py-polars/tests/unit/{namespaces/list => operations/namespaces/string}/__init__.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/string/test_concat.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/string/test_pad.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/string/test_string.py (100%) rename py-polars/tests/unit/{namespaces/string => operations/namespaces/temporal}/__init__.py (100%) rename py-polars/tests/unit/{functions/business => operations/namespaces/temporal}/test_add_business_days.py (86%) rename py-polars/tests/unit/{namespaces => operations/namespaces/temporal}/test_datetime.py (96%) rename py-polars/tests/{parametric/time_series => unit/operations/namespaces/temporal}/test_to_datetime.py (100%) rename py-polars/tests/{parametric/time_series => unit/operations/namespaces/temporal}/test_truncate.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/test_binary.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/test_categorical.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/test_meta.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/test_name.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/test_plot.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/test_strptime.py (100%) rename py-polars/tests/unit/{ => operations}/namespaces/test_struct.py (100%) create mode 100644 py-polars/tests/unit/series/test_getitem.py diff --git a/py-polars/tests/parametric/conftest.py b/py-polars/tests/parametric/conftest.py deleted file mode 100644 index 5f0d33ec5073..000000000000 --- a/py-polars/tests/parametric/conftest.py +++ /dev/null @@ -1,7 +0,0 @@ -import os - -from polars.testing.parametric import load_profile - -load_profile( - profile=os.environ.get("POLARS_HYPOTHESIS_PROFILE", "fast"), # type: ignore[arg-type] -) diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py deleted file mode 100644 index 8597634008f4..000000000000 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ /dev/null @@ -1,170 +0,0 @@ -from __future__ import annotations - -import datetime as dt -from datetime import timedelta -from typing import TYPE_CHECKING - -import hypothesis.strategies as st -from hypothesis import assume, given - -import polars as pl -from polars._utils.convert import parse_as_duration_string -from polars.testing import assert_frame_equal -from polars.testing.parametric 
import column, dataframes -from polars.testing.parametric.strategies.dtype import _time_units - -if TYPE_CHECKING: - from hypothesis.strategies import SearchStrategy - - from polars.type_aliases import ClosedInterval, TimeUnit - - -def interval_defs() -> SearchStrategy[ClosedInterval]: - closed: list[ClosedInterval] = ["left", "right", "both", "none"] - return st.sampled_from(closed) - - -@given( - period=st.timedeltas( - min_value=timedelta(microseconds=0), max_value=timedelta(days=1000) - ).map(parse_as_duration_string), - offset=st.timedeltas( - min_value=timedelta(days=-1000), max_value=timedelta(days=1000) - ).map(parse_as_duration_string), - closed=interval_defs(), - data=st.data(), - time_unit=_time_units(), -) -def test_rolling( - period: str, - offset: str, - closed: ClosedInterval, - data: st.DataObject, - time_unit: TimeUnit, -) -> None: - assume(period != "") - dataframe = data.draw( - dataframes( - [ - column( - "ts", - strategy=st.datetimes( - min_value=dt.datetime(2000, 1, 1), - max_value=dt.datetime(2001, 1, 1), - ), - dtype=pl.Datetime(time_unit), - ), - column( - "value", - strategy=st.integers(min_value=-100, max_value=100), - dtype=pl.Int64, - ), - ], - min_size=1, - ) - ) - df = dataframe.sort("ts") - result = df.rolling("ts", period=period, offset=offset, closed=closed).agg( - pl.col("value") - ) - - expected_dict: dict[str, list[object]] = {"ts": [], "value": []} - for ts, _ in df.iter_rows(): - window = df.filter( - pl.col("ts").is_between( - pl.lit(ts, dtype=pl.Datetime(time_unit)).dt.offset_by(offset), - pl.lit(ts, dtype=pl.Datetime(time_unit)) - .dt.offset_by(offset) - .dt.offset_by(period), - closed=closed, - ) - ) - value = window["value"].to_list() - expected_dict["ts"].append(ts) - expected_dict["value"].append(value) - expected = pl.DataFrame(expected_dict).select( - pl.col("ts").cast(pl.Datetime(time_unit)), - pl.col("value").cast(pl.List(pl.Int64)), - ) - assert_frame_equal(result, expected) - - -@given( - window_size=st.timedeltas( - min_value=timedelta(microseconds=0), max_value=timedelta(days=2) - ).map(parse_as_duration_string), - closed=interval_defs(), - data=st.data(), - time_unit=_time_units(), - aggregation=st.sampled_from( - [ - "min", - "max", - "mean", - "sum", - "std", - "var", - "median", - ] - ), -) -def test_rolling_aggs( - window_size: str, - closed: ClosedInterval, - data: st.DataObject, - time_unit: TimeUnit, - aggregation: str, -) -> None: - assume(window_size != "") - - # Testing logic can be faulty when window is more precise than time unit - # https://github.com/pola-rs/polars/issues/11754 - assume(not (time_unit == "ms" and "us" in window_size)) - - dataframe = data.draw( - dataframes( - [ - column( - "ts", - strategy=st.datetimes( - min_value=dt.datetime(2000, 1, 1), - max_value=dt.datetime(2001, 1, 1), - ), - dtype=pl.Datetime(time_unit), - ), - column( - "value", - strategy=st.integers(min_value=-100, max_value=100), - dtype=pl.Int64, - ), - ], - ) - ) - df = dataframe.sort("ts") - func = f"rolling_{aggregation}_by" - result = df.with_columns( - getattr(pl.col("value"), func)("ts", window_size=window_size, closed=closed) - ) - - expected_dict: dict[str, list[object]] = {"ts": [], "value": []} - for ts, _ in df.iter_rows(): - window = df.filter( - pl.col("ts").is_between( - pl.lit(ts, dtype=pl.Datetime(time_unit)).dt.offset_by( - f"-{window_size}" - ), - pl.lit(ts, dtype=pl.Datetime(time_unit)), - closed=closed, - ) - ) - expected_dict["ts"].append(ts) - if window.is_empty(): - expected_dict["value"].append(None) - else: - value = 
getattr(window["value"], aggregation)() - expected_dict["value"].append(value) - expected = pl.DataFrame(expected_dict).select( - pl.col("ts").cast(pl.Datetime(time_unit)), - pl.col("value").cast(result["value"].dtype), - ) - assert_frame_equal(result, expected) diff --git a/py-polars/tests/parametric/test_lit.py b/py-polars/tests/parametric/test_lit.py deleted file mode 100644 index 73df1aa98012..000000000000 --- a/py-polars/tests/parametric/test_lit.py +++ /dev/null @@ -1,27 +0,0 @@ -from datetime import datetime - -from hypothesis import given - -import polars as pl -from polars.testing.parametric.strategies.data import datetimes - - -@given(value=datetimes("ns")) -def test_datetime_ns(value: datetime) -> None: - result = pl.select(pl.lit(value, dtype=pl.Datetime("ns")))["literal"][0] - assert result == value - - -@given(value=datetimes("us")) -def test_datetime_us(value: datetime) -> None: - result = pl.select(pl.lit(value, dtype=pl.Datetime("us")))["literal"][0] - assert result == value - result = pl.select(pl.lit(value, dtype=pl.Datetime))["literal"][0] - assert result == value - - -@given(value=datetimes("ms")) -def test_datetime_ms(value: datetime) -> None: - result = pl.select(pl.lit(value, dtype=pl.Datetime("ms")))["literal"][0] - expected_microsecond = value.microsecond // 1000 * 1000 - assert result == value.replace(microsecond=expected_microsecond) diff --git a/py-polars/tests/parametric/test_series.py b/py-polars/tests/parametric/test_series.py deleted file mode 100644 index cf298eda401c..000000000000 --- a/py-polars/tests/parametric/test_series.py +++ /dev/null @@ -1,83 +0,0 @@ -# ------------------------------------------------- -# Validate Series behaviour with parametric tests -# ------------------------------------------------- -from __future__ import annotations - -import hypothesis.strategies as st -import pytest -from hypothesis import given, settings - -import polars as pl -from polars.testing import assert_series_equal -from polars.testing.parametric import series - - -@given( - s=series(min_size=1, max_size=10, dtype=pl.Datetime), -) -def test_series_datetime_timeunits( - s: pl.Series, -) -> None: - # datetime - assert s.to_list() == list(s) - assert list(s.dt.millisecond()) == [v.microsecond // 1000 for v in s] - assert list(s.dt.nanosecond()) == [v.microsecond * 1000 for v in s] - assert list(s.dt.microsecond()) == [v.microsecond for v in s] - - -@given( - s=series(min_size=1, max_size=10, dtype=pl.Duration), -) -@pytest.mark.skip( - "These functions are currently bugged for large values: " - "https://github.com/pola-rs/polars/issues/16057" -) -def test_series_duration_timeunits( - s: pl.Series, -) -> None: - nanos = s.dt.total_nanoseconds().to_list() - micros = s.dt.total_microseconds().to_list() - millis = s.dt.total_milliseconds().to_list() - - scale = { - "ns": 1, - "us": 1_000, - "ms": 1_000_000, - } - assert nanos == [v * scale[s.dtype.time_unit] for v in s.to_physical()] # type: ignore[attr-defined] - assert micros == [int(v / 1_000) for v in nanos] - assert millis == [int(v / 1_000) for v in micros] - - # special handling for ns timeunit (as we may generate a microsecs-based - # timedelta that results in 64bit overflow on conversion to nanosecs) - lower_bound, upper_bound = -(2**63), (2**63) - 1 - if all( - (lower_bound <= (us * 1000) <= upper_bound) - for us in micros - if isinstance(us, int) - ): - for ns, us in zip(s.dt.total_nanoseconds(), micros): - assert ns == (us * 1000) - - -@given( - srs=series(max_size=10, dtype=pl.Int64), - 
start=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), - stop=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), - step=st.sampled_from([-5, -4, -3, -2, -1, None, 1, 2, 3, 4, 5]), -) -@settings(max_examples=500) -def test_series_slice( - srs: pl.Series, - start: int | None, - stop: int | None, - step: int | None, -) -> None: - py_data = srs.to_list() - - s = slice(start, stop, step) - sliced_py_data = py_data[s] - sliced_pl_data = srs[s].to_list() - - assert sliced_py_data == sliced_pl_data, f"slice [{start}:{stop}:{step}] failed" - assert_series_equal(srs, srs, check_exact=True) diff --git a/py-polars/tests/parametric/test_testing.py b/py-polars/tests/parametric/test_testing.py deleted file mode 100644 index 053cbf262e92..000000000000 --- a/py-polars/tests/parametric/test_testing.py +++ /dev/null @@ -1,243 +0,0 @@ -# ------------------------------------------------ -# Test/validate Polars' hypothesis strategy units -# ------------------------------------------------ -from __future__ import annotations - -from datetime import datetime -from typing import Any - -import hypothesis.strategies as st -import pytest -from hypothesis import given, settings -from hypothesis.errors import InvalidArgument - -import polars as pl -from polars.testing.parametric import column, dataframes, lists, series - -TEMPORAL_DTYPES = {pl.Date, pl.Time, pl.Datetime, pl.Duration} - - -@given(df=dataframes(), lf=dataframes(lazy=True), srs=series()) -@settings(max_examples=5) -def test_strategy_classes(df: pl.DataFrame, lf: pl.LazyFrame, srs: pl.Series) -> None: - assert isinstance(df, pl.DataFrame) - assert isinstance(lf, pl.LazyFrame) - assert isinstance(srs, pl.Series) - - -@given( - s1=series(dtype=pl.Boolean, size=5), - s2=series(dtype=pl.Boolean, min_size=3, max_size=8, name="col"), - df1=dataframes(allowed_dtypes=[pl.Boolean], cols=5, size=5), - df2=dataframes( - allowed_dtypes=[pl.Boolean], min_cols=2, max_cols=5, min_size=3, max_size=8 - ), -) -@settings(max_examples=50) -def test_strategy_shape( - s1: pl.Series, s2: pl.Series, df1: pl.DataFrame, df2: pl.DataFrame -) -> None: - assert df1.shape == (5, 5) - assert df1.columns == ["col0", "col1", "col2", "col3", "col4"] - - assert 2 <= len(df2.columns) <= 5 - assert 3 <= len(df2) <= 8 - - assert s1.len() == 5 - assert 3 <= s2.len() <= 8 - assert s1.name == "" - assert s2.name == "col" - - -@given( - lf=dataframes( - # generate lazyframes with at least one row - lazy=True, - min_size=1, - # test mix & match of bulk-assigned cols with custom cols - cols=[column(n, dtype=pl.UInt8, unique=True) for n in ["a", "b"]], - include_cols=[ - column("c", dtype=pl.Boolean), - column("d", strategy=st.sampled_from(["x", "y", "z"])), - ], - ) -) -def test_strategy_frame_columns(lf: pl.LazyFrame) -> None: - assert lf.schema == {"a": pl.UInt8, "b": pl.UInt8, "c": pl.Boolean, "d": pl.String} - assert lf.columns == ["a", "b", "c", "d"] - df = lf.collect() - - # confirm uint cols bounds - uint8_max = (2**8) - 1 - assert df["a"].min() >= 0 # type: ignore[operator] - assert df["b"].min() >= 0 # type: ignore[operator] - assert df["a"].max() <= uint8_max # type: ignore[operator] - assert df["b"].max() <= uint8_max # type: ignore[operator] - - # confirm uint cols uniqueness - assert df["a"].is_unique().all() - assert df["b"].is_unique().all() - - # boolean col - assert all(isinstance(v, bool) for v in df["c"].to_list()) - - # string col, entries selected from custom values - xyz = {"x", "y", "z"} - assert all(v in xyz for v in df["d"].to_list()) - - -@given( - 
df=dataframes(allowed_dtypes=TEMPORAL_DTYPES, max_size=1, max_cols=5), - lf=dataframes(excluded_dtypes=TEMPORAL_DTYPES, max_size=1, max_cols=5, lazy=True), - s1=series(dtype=pl.Boolean, max_size=1), - s2=series(allowed_dtypes=TEMPORAL_DTYPES, max_size=1), - s3=series(excluded_dtypes=TEMPORAL_DTYPES, max_size=1), -) -@settings(max_examples=50) -def test_strategy_dtypes( - df: pl.DataFrame, - lf: pl.LazyFrame, - s1: pl.Series, - s2: pl.Series, - s3: pl.Series, -) -> None: - # dataframe, lazyframe - assert all(tp.is_temporal() for tp in df.dtypes) - assert all(not tp.is_temporal() for tp in lf.dtypes) - - # series - assert s1.dtype == pl.Boolean - assert s2.dtype.is_temporal() - assert not s3.dtype.is_temporal() - - -@given(s=series()) -def test_series_allow_null_default(s: pl.Series) -> None: - assert s.null_count() == 0 - - -@given(s=series(allow_null=True)) -def test_series_allow_null(s: pl.Series) -> None: - assert 0 <= s.null_count() <= s.len() - - -@given(df=dataframes(cols=1, allow_null=True)) -def test_dataframes_allow_null_global(df: pl.DataFrame) -> None: - null_count = sum(df.null_count().row(0)) - assert 0 <= null_count <= df.height * df.width - - -@given(df=dataframes(cols=2, allow_null={"col0": True})) -def test_dataframes_allow_null_column(df: pl.DataFrame) -> None: - null_count = sum(df.null_count().row(0)) - assert 0 <= null_count <= df.height * df.width - - -@given( - df=dataframes( - cols=1, - allow_null=False, - include_cols=[column(name="colx", allow_null=True)], - ) -) -def test_dataframes_allow_null_override(df: pl.DataFrame) -> None: - assert df.get_column("col0").null_count() == 0 - assert 0 <= df.get_column("colx").null_count() <= df.height - - -@given( - df1=dataframes(chunked=False, min_size=1), - df2=dataframes(chunked=True, min_size=1), - s1=series(chunked=False, min_size=1), - s2=series(chunked=True, min_size=1), -) -@settings(max_examples=10) -def test_chunking( - df1: pl.DataFrame, - df2: pl.DataFrame, - s1: pl.Series, - s2: pl.Series, -) -> None: - assert df1.n_chunks() == 1 - if len(df2) > 1: - assert df2.n_chunks("all") == [2] * len(df2.columns) - - assert s1.n_chunks() == 1 - if len(s2) > 1: - assert s2.n_chunks() > 1 - - -@given( - df=dataframes( - allowed_dtypes=[pl.Float32, pl.Float64], max_cols=4, allow_infinity=False - ), - s=series(dtype=pl.Float64, allow_infinity=False), -) -def test_infinities( - df: pl.DataFrame, - s: pl.Series, -) -> None: - from math import isfinite, isnan - - def finite_float(value: Any) -> bool: - return isfinite(value) or isnan(value) - - assert all(finite_float(val) for val in s.to_list()) - for col in df.columns: - assert all(finite_float(val) for val in df[col].to_list()) - - -@given( - df=dataframes( - cols=[ - column("colx", dtype=pl.Array(pl.UInt8, width=3)), - column("coly", dtype=pl.List(pl.Datetime("ms"))), - column( - name="colz", - dtype=pl.List(pl.List(pl.String)), - strategy=lists( - inner_dtype=pl.List(pl.String), - select_from=["aa", "bb", "cc"], - min_len=1, - ), - ), - ] - ), -) -def test_sequence_strategies(df: pl.DataFrame) -> None: - assert df.schema == { - "colx": pl.Array(pl.UInt8, width=3), - "coly": pl.List(pl.Datetime("ms")), - "colz": pl.List(pl.List(pl.String)), - } - uint8_max = (2**8) - 1 - - for colx, coly, colz in df.iter_rows(): - assert len(colx) == 3 - assert all(i <= uint8_max for i in colx) - assert all(isinstance(d, datetime) for d in coly) - for inner_list in colz: - assert all(s in ("aa", "bb", "cc") for s in inner_list) - - -@pytest.mark.hypothesis() -def 
test_column_invalid_probability() -> None: - with pytest.deprecated_call(), pytest.raises(InvalidArgument): - column("col", null_probability=2.0) - - -@pytest.mark.hypothesis() -def test_column_null_probability_deprecated() -> None: - with pytest.deprecated_call(): - col = column("col", allow_null=False, null_probability=0.5) - assert col.null_probability == 0.5 - assert col.allow_null is True # null_probability takes precedence - - -@given(st.data()) -def test_allow_infinities_deprecated(data: st.DataObject) -> None: - with pytest.deprecated_call(): - strategy = series(dtype=pl.Float64, allow_infinities=False) - s = data.draw(strategy) - - assert all(v not in (float("inf"), float("-inf")) for v in s) diff --git a/py-polars/tests/parametric/time_series/test_add_business_days.py b/py-polars/tests/parametric/time_series/test_add_business_days.py deleted file mode 100644 index a4328c4efdd1..000000000000 --- a/py-polars/tests/parametric/time_series/test_add_business_days.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -import datetime as dt -from typing import TYPE_CHECKING - -import hypothesis.strategies as st -import numpy as np -from hypothesis import assume, given - -import polars as pl - -if TYPE_CHECKING: - from polars.type_aliases import Roll - - -@given( - start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), - n=st.integers(min_value=-100, max_value=100), - week_mask=st.lists( - st.sampled_from([True, False]), - min_size=7, - max_size=7, - ), - holidays=st.lists( - st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), - min_size=0, - max_size=100, - ), - roll=st.sampled_from(["forward", "backward"]), -) -def test_against_np_busday_offset( - start: dt.date, - n: int, - week_mask: tuple[bool, ...], - holidays: list[dt.date], - roll: Roll, -) -> None: - assume(any(week_mask)) - result = ( - pl.DataFrame({"start": [start]}) - .select( - res=pl.col("start").dt.add_business_days( - n, week_mask=week_mask, holidays=holidays, roll=roll - ) - )["res"] - .item() - ) - expected = np.busday_offset( - start, n, weekmask=week_mask, holidays=holidays, roll=roll - ) - assert result == expected diff --git a/py-polars/tests/parametric/time_series/test_business_day_count.py b/py-polars/tests/parametric/time_series/test_business_day_count.py deleted file mode 100644 index 437e8a7208a8..000000000000 --- a/py-polars/tests/parametric/time_series/test_business_day_count.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -import datetime as dt - -import hypothesis.strategies as st -import numpy as np -from hypothesis import assume, given, reject - -import polars as pl -from polars._utils.various import parse_version - - -@given( - start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), - end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), - week_mask=st.lists( - st.sampled_from([True, False]), - min_size=7, - max_size=7, - ), - holidays=st.lists( - st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), - min_size=0, - max_size=100, - ), -) -def test_against_np_busday_count( - start: dt.date, end: dt.date, week_mask: tuple[bool, ...], holidays: list[dt.date] -) -> None: - assume(any(week_mask)) - result = ( - pl.DataFrame({"start": [start], "end": [end]}) - .select( - n=pl.business_day_count( - "start", "end", week_mask=week_mask, holidays=holidays - ) - )["n"] - .item() - ) - expected = np.busday_count(start, end, weekmask=week_mask, holidays=holidays) 
- if start > end and parse_version(np.__version__) < parse_version("1.25"): - # Bug in old versions of numpy - reject() - assert result == expected diff --git a/py-polars/tests/parametric/test_dataframe.py b/py-polars/tests/unit/dataframe/test_getitem.py similarity index 60% rename from py-polars/tests/parametric/test_dataframe.py rename to py-polars/tests/unit/dataframe/test_getitem.py index a69c5880407e..1e4dd95fed3e 100644 --- a/py-polars/tests/parametric/test_dataframe.py +++ b/py-polars/tests/unit/dataframe/test_getitem.py @@ -1,62 +1,12 @@ -# ---------------------------------------------------- -# Validate DataFrame behaviour with parametric tests -# ---------------------------------------------------- from __future__ import annotations import hypothesis.strategies as st -from hypothesis import example, given, settings +from hypothesis import given import polars as pl -from polars.testing import assert_frame_equal from polars.testing.parametric import column, dataframes -@given(df=dataframes()) -def test_repr(df: pl.DataFrame) -> None: - assert isinstance(repr(df), str) - - -@given(df=dataframes()) -def test_equal(df: pl.DataFrame) -> None: - assert_frame_equal(df, df.clone(), check_exact=True) - - -@given( - df=dataframes( - cols=10, - max_size=1, - allowed_dtypes=[pl.Int8, pl.UInt16, pl.List(pl.Int32)], - ) -) -@settings(max_examples=3) -def test_dtype_integer_cols(df: pl.DataFrame) -> None: - # ensure dtype constraint works in conjunction with 'n' cols - assert all( - tp in (pl.Int8, pl.UInt16, pl.List(pl.Int32)) for tp in df.schema.values() - ) - - -@given( - df=dataframes( - min_size=1, - min_cols=1, - allow_null=True, - excluded_dtypes=[pl.String, pl.List], - ) -) -@example(df=pl.DataFrame(schema=["x", "y", "z"])) -@example(df=pl.DataFrame()) -def test_null_count(df: pl.DataFrame) -> None: - # note: the zero-row and zero-col cases are always passed as explicit examples - null_count, ncols = df.null_count(), len(df.columns) - if ncols == 0: - assert null_count.shape == (0, 0) - else: - assert null_count.shape == (1, ncols) - for idx, count in enumerate(null_count.rows()[0]): - assert count == sum(v is None for v in df.to_series(idx).to_list()) - - @given( df=dataframes( max_size=10, diff --git a/py-polars/tests/unit/dataframe/test_null_count.py b/py-polars/tests/unit/dataframe/test_null_count.py new file mode 100644 index 000000000000..11755bbdcb9b --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_null_count.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from hypothesis import example, given + +import polars as pl +from polars.testing.parametric import dataframes + + +@given( + df=dataframes( + min_size=1, + min_cols=1, + allow_null=True, + excluded_dtypes=[pl.String, pl.List], + ) +) +@example(df=pl.DataFrame(schema=["x", "y", "z"])) +@example(df=pl.DataFrame()) +def test_null_count(df: pl.DataFrame) -> None: + # note: the zero-row and zero-col cases are always passed as explicit examples + null_count, ncols = df.null_count(), len(df.columns) + if ncols == 0: + assert null_count.shape == (0, 0) + else: + assert null_count.shape == (1, ncols) + for idx, count in enumerate(null_count.rows()[0]): + assert count == sum(v is None for v in df.to_series(idx).to_list()) diff --git a/py-polars/tests/unit/dataframe/test_repr.py b/py-polars/tests/unit/dataframe/test_repr.py new file mode 100644 index 000000000000..e0a137c718ef --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_repr.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import 
TYPE_CHECKING + +from hypothesis import given + +from polars.testing.parametric import dataframes + +if TYPE_CHECKING: + import polars as pl + + +@given(df=dataframes()) +def test_repr(df: pl.DataFrame) -> None: + assert isinstance(repr(df), str) diff --git a/py-polars/tests/unit/functions/business/test_business_day_count.py b/py-polars/tests/unit/functions/test_business_day_count.py similarity index 77% rename from py-polars/tests/unit/functions/business/test_business_day_count.py rename to py-polars/tests/unit/functions/test_business_day_count.py index f1ce93268810..7101d525016a 100644 --- a/py-polars/tests/unit/functions/business/test_business_day_count.py +++ b/py-polars/tests/unit/functions/test_business_day_count.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +import datetime as dt from datetime import date +import hypothesis.strategies as st +import numpy as np import pytest +from hypothesis import assume, given, reject import polars as pl +from polars._utils.various import parse_version from polars.testing import assert_series_equal @@ -120,3 +127,37 @@ def test_business_day_count_w_holidays() -> None: )["business_day_count"] expected = pl.Series("business_day_count", [0, 5, 5], pl.Int32) assert_series_equal(result, expected) + + +@given( + start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + week_mask=st.lists( + st.sampled_from([True, False]), + min_size=7, + max_size=7, + ), + holidays=st.lists( + st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + min_size=0, + max_size=100, + ), +) +def test_against_np_busday_count( + start: dt.date, end: dt.date, week_mask: tuple[bool, ...], holidays: list[dt.date] +) -> None: + assume(any(week_mask)) + result = ( + pl.DataFrame({"start": [start], "end": [end]}) + .select( + n=pl.business_day_count( + "start", "end", week_mask=week_mask, holidays=holidays + ) + )["n"] + .item() + ) + expected = np.busday_count(start, end, weekmask=week_mask, holidays=holidays) + if start > end and parse_version(np.__version__) < parse_version("1.25"): + # Bug in old versions of numpy + reject() + assert result == expected diff --git a/py-polars/tests/parametric/time_series/test_ewm_by.py b/py-polars/tests/unit/functions/test_ewm_by.py similarity index 100% rename from py-polars/tests/parametric/time_series/test_ewm_by.py rename to py-polars/tests/unit/functions/test_ewm_by.py diff --git a/py-polars/tests/unit/functions/test_lit.py b/py-polars/tests/unit/functions/test_lit.py index 351e7fc9715c..5c6fbe1d6ab3 100644 --- a/py-polars/tests/unit/functions/test_lit.py +++ b/py-polars/tests/unit/functions/test_lit.py @@ -5,9 +5,11 @@ import numpy as np import pytest +from hypothesis import given import polars as pl from polars.testing import assert_frame_equal +from polars.testing.parametric.strategies.data import datetimes @pytest.mark.parametrize( @@ -95,3 +97,24 @@ def test_lit_unsupported_type() -> None: match="cannot create expression literal for value of type LazyFrame: ", ): pl.lit(pl.LazyFrame({"a": [1, 2, 3]})) + + +@given(value=datetimes("ns")) +def test_datetime_ns(value: datetime) -> None: + result = pl.select(pl.lit(value, dtype=pl.Datetime("ns")))["literal"][0] + assert result == value + + +@given(value=datetimes("us")) +def test_datetime_us(value: datetime) -> None: + result = pl.select(pl.lit(value, dtype=pl.Datetime("us")))["literal"][0] + assert result == value + result = pl.select(pl.lit(value, 
dtype=pl.Datetime))["literal"][0] + assert result == value + + +@given(value=datetimes("ms")) +def test_datetime_ms(value: datetime) -> None: + result = pl.select(pl.lit(value, dtype=pl.Datetime("ms")))["literal"][0] + expected_microsecond = value.microsecond // 1000 * 1000 + assert result == value.replace(microsecond=expected_microsecond) diff --git a/py-polars/tests/parametric/test_lazyframe.py b/py-polars/tests/unit/lazyframe/test_getitem.py similarity index 97% rename from py-polars/tests/parametric/test_lazyframe.py rename to py-polars/tests/unit/lazyframe/test_getitem.py index a40a5fbc3d22..b3f27bd05d58 100644 --- a/py-polars/tests/parametric/test_lazyframe.py +++ b/py-polars/tests/unit/lazyframe/test_getitem.py @@ -47,7 +47,7 @@ } ) ) -def test_lazyframe_slice(ldf: pl.LazyFrame) -> None: +def test_lazyframe_getitem(ldf: pl.LazyFrame) -> None: py_data = ldf.collect().rows() for start, stop, step, _ in py_data: diff --git a/py-polars/tests/unit/namespaces/__init__.py b/py-polars/tests/unit/operations/namespaces/__init__.py similarity index 100% rename from py-polars/tests/unit/namespaces/__init__.py rename to py-polars/tests/unit/operations/namespaces/__init__.py diff --git a/py-polars/tests/parametric/__init__.py b/py-polars/tests/unit/operations/namespaces/array/__init__.py similarity index 100% rename from py-polars/tests/parametric/__init__.py rename to py-polars/tests/unit/operations/namespaces/array/__init__.py diff --git a/py-polars/tests/unit/namespaces/array/test_array.py b/py-polars/tests/unit/operations/namespaces/array/test_array.py similarity index 100% rename from py-polars/tests/unit/namespaces/array/test_array.py rename to py-polars/tests/unit/operations/namespaces/array/test_array.py diff --git a/py-polars/tests/unit/namespaces/array/test_contains.py b/py-polars/tests/unit/operations/namespaces/array/test_contains.py similarity index 100% rename from py-polars/tests/unit/namespaces/array/test_contains.py rename to py-polars/tests/unit/operations/namespaces/array/test_contains.py diff --git a/py-polars/tests/unit/namespaces/array/test_to_list.py b/py-polars/tests/unit/operations/namespaces/array/test_to_list.py similarity index 100% rename from py-polars/tests/unit/namespaces/array/test_to_list.py rename to py-polars/tests/unit/operations/namespaces/array/test_to_list.py diff --git a/py-polars/tests/unit/namespaces/conftest.py b/py-polars/tests/unit/operations/namespaces/conftest.py similarity index 100% rename from py-polars/tests/unit/namespaces/conftest.py rename to py-polars/tests/unit/operations/namespaces/conftest.py diff --git a/py-polars/tests/unit/namespaces/files/test_tree_fmt.txt b/py-polars/tests/unit/operations/namespaces/files/test_tree_fmt.txt similarity index 100% rename from py-polars/tests/unit/namespaces/files/test_tree_fmt.txt rename to py-polars/tests/unit/operations/namespaces/files/test_tree_fmt.txt diff --git a/py-polars/tests/unit/namespaces/array/__init__.py b/py-polars/tests/unit/operations/namespaces/list/__init__.py similarity index 100% rename from py-polars/tests/unit/namespaces/array/__init__.py rename to py-polars/tests/unit/operations/namespaces/list/__init__.py diff --git a/py-polars/tests/unit/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py similarity index 100% rename from py-polars/tests/unit/namespaces/list/test_list.py rename to py-polars/tests/unit/operations/namespaces/list/test_list.py diff --git a/py-polars/tests/unit/namespaces/list/test_set_operations.py 
b/py-polars/tests/unit/operations/namespaces/list/test_set_operations.py similarity index 100% rename from py-polars/tests/unit/namespaces/list/test_set_operations.py rename to py-polars/tests/unit/operations/namespaces/list/test_set_operations.py diff --git a/py-polars/tests/unit/namespaces/list/__init__.py b/py-polars/tests/unit/operations/namespaces/string/__init__.py similarity index 100% rename from py-polars/tests/unit/namespaces/list/__init__.py rename to py-polars/tests/unit/operations/namespaces/string/__init__.py diff --git a/py-polars/tests/unit/namespaces/string/test_concat.py b/py-polars/tests/unit/operations/namespaces/string/test_concat.py similarity index 100% rename from py-polars/tests/unit/namespaces/string/test_concat.py rename to py-polars/tests/unit/operations/namespaces/string/test_concat.py diff --git a/py-polars/tests/unit/namespaces/string/test_pad.py b/py-polars/tests/unit/operations/namespaces/string/test_pad.py similarity index 100% rename from py-polars/tests/unit/namespaces/string/test_pad.py rename to py-polars/tests/unit/operations/namespaces/string/test_pad.py diff --git a/py-polars/tests/unit/namespaces/string/test_string.py b/py-polars/tests/unit/operations/namespaces/string/test_string.py similarity index 100% rename from py-polars/tests/unit/namespaces/string/test_string.py rename to py-polars/tests/unit/operations/namespaces/string/test_string.py diff --git a/py-polars/tests/unit/namespaces/string/__init__.py b/py-polars/tests/unit/operations/namespaces/temporal/__init__.py similarity index 100% rename from py-polars/tests/unit/namespaces/string/__init__.py rename to py-polars/tests/unit/operations/namespaces/temporal/__init__.py diff --git a/py-polars/tests/unit/functions/business/test_add_business_days.py b/py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py similarity index 86% rename from py-polars/tests/unit/functions/business/test_add_business_days.py rename to py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py index 9e082a9f5109..fb86302fc2a4 100644 --- a/py-polars/tests/unit/functions/business/test_add_business_days.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py @@ -1,15 +1,19 @@ from __future__ import annotations +import datetime as dt from datetime import date, datetime, timedelta from typing import TYPE_CHECKING +import hypothesis.strategies as st +import numpy as np import pytest +from hypothesis import assume, given import polars as pl from polars.testing import assert_series_equal if TYPE_CHECKING: - from polars.type_aliases import TimeUnit + from polars.type_aliases import Roll, TimeUnit def test_add_business_days() -> None: @@ -234,3 +238,41 @@ def test_add_business_days_w_nulls() -> None: )["result"] expected = pl.Series("result", [None], dtype=pl.Date) assert_series_equal(result, expected) + + +@given( + start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + n=st.integers(min_value=-100, max_value=100), + week_mask=st.lists( + st.sampled_from([True, False]), + min_size=7, + max_size=7, + ), + holidays=st.lists( + st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + min_size=0, + max_size=100, + ), + roll=st.sampled_from(["forward", "backward"]), +) +def test_against_np_busday_offset( + start: dt.date, + n: int, + week_mask: tuple[bool, ...], + holidays: list[dt.date], + roll: Roll, +) -> None: + assume(any(week_mask)) + result = ( + pl.DataFrame({"start": [start]}) + .select( + 
res=pl.col("start").dt.add_business_days( + n, week_mask=week_mask, holidays=holidays, roll=roll + ) + )["res"] + .item() + ) + expected = np.busday_offset( + start, n, weekmask=week_mask, holidays=holidays, roll=roll + ) + assert result == expected diff --git a/py-polars/tests/unit/namespaces/test_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py similarity index 96% rename from py-polars/tests/unit/namespaces/test_datetime.py rename to py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py index 5307a2e7e129..bdc31aa00d59 100644 --- a/py-polars/tests/unit/namespaces/test_datetime.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py @@ -4,11 +4,13 @@ from typing import TYPE_CHECKING import pytest +from hypothesis import given import polars as pl from polars.datatypes import DTYPE_TEMPORAL_UNITS from polars.exceptions import ComputeError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import series if TYPE_CHECKING: from zoneinfo import ZoneInfo @@ -1310,3 +1312,51 @@ def test_agg_median_expr() -> None: ) assert_frame_equal(df.select(pl.all().median()), expected) + + +@given( + s=series(min_size=1, max_size=10, dtype=pl.Duration), +) +@pytest.mark.skip( + "These functions are currently bugged for large values: " + "https://github.com/pola-rs/polars/issues/16057" +) +def test_series_duration_timeunits( + s: pl.Series, +) -> None: + nanos = s.dt.total_nanoseconds().to_list() + micros = s.dt.total_microseconds().to_list() + millis = s.dt.total_milliseconds().to_list() + + scale = { + "ns": 1, + "us": 1_000, + "ms": 1_000_000, + } + assert nanos == [v * scale[s.dtype.time_unit] for v in s.to_physical()] # type: ignore[attr-defined] + assert micros == [int(v / 1_000) for v in nanos] + assert millis == [int(v / 1_000) for v in micros] + + # special handling for ns timeunit (as we may generate a microsecs-based + # timedelta that results in 64bit overflow on conversion to nanosecs) + lower_bound, upper_bound = -(2**63), (2**63) - 1 + if all( + (lower_bound <= (us * 1000) <= upper_bound) + for us in micros + if isinstance(us, int) + ): + for ns, us in zip(s.dt.total_nanoseconds(), micros): + assert ns == (us * 1000) + + +@given( + s=series(min_size=1, max_size=10, dtype=pl.Datetime), +) +def test_series_datetime_timeunits( + s: pl.Series, +) -> None: + # datetime + assert s.to_list() == list(s) + assert list(s.dt.millisecond()) == [v.microsecond // 1000 for v in s] + assert list(s.dt.nanosecond()) == [v.microsecond * 1000 for v in s] + assert list(s.dt.microsecond()) == [v.microsecond for v in s] diff --git a/py-polars/tests/parametric/time_series/test_to_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py similarity index 100% rename from py-polars/tests/parametric/time_series/test_to_datetime.py rename to py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py diff --git a/py-polars/tests/parametric/time_series/test_truncate.py b/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py similarity index 100% rename from py-polars/tests/parametric/time_series/test_truncate.py rename to py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py diff --git a/py-polars/tests/unit/namespaces/test_binary.py b/py-polars/tests/unit/operations/namespaces/test_binary.py similarity index 100% rename from py-polars/tests/unit/namespaces/test_binary.py rename to 
py-polars/tests/unit/operations/namespaces/test_binary.py diff --git a/py-polars/tests/unit/namespaces/test_categorical.py b/py-polars/tests/unit/operations/namespaces/test_categorical.py similarity index 100% rename from py-polars/tests/unit/namespaces/test_categorical.py rename to py-polars/tests/unit/operations/namespaces/test_categorical.py diff --git a/py-polars/tests/unit/namespaces/test_meta.py b/py-polars/tests/unit/operations/namespaces/test_meta.py similarity index 100% rename from py-polars/tests/unit/namespaces/test_meta.py rename to py-polars/tests/unit/operations/namespaces/test_meta.py diff --git a/py-polars/tests/unit/namespaces/test_name.py b/py-polars/tests/unit/operations/namespaces/test_name.py similarity index 100% rename from py-polars/tests/unit/namespaces/test_name.py rename to py-polars/tests/unit/operations/namespaces/test_name.py diff --git a/py-polars/tests/unit/namespaces/test_plot.py b/py-polars/tests/unit/operations/namespaces/test_plot.py similarity index 100% rename from py-polars/tests/unit/namespaces/test_plot.py rename to py-polars/tests/unit/operations/namespaces/test_plot.py diff --git a/py-polars/tests/unit/namespaces/test_strptime.py b/py-polars/tests/unit/operations/namespaces/test_strptime.py similarity index 100% rename from py-polars/tests/unit/namespaces/test_strptime.py rename to py-polars/tests/unit/operations/namespaces/test_strptime.py diff --git a/py-polars/tests/unit/namespaces/test_struct.py b/py-polars/tests/unit/operations/namespaces/test_struct.py similarity index 100% rename from py-polars/tests/unit/namespaces/test_struct.py rename to py-polars/tests/unit/operations/namespaces/test_struct.py diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 8898c8a29d31..8fa1b87bbe86 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -3,15 +3,22 @@ from datetime import date, datetime, timedelta from typing import TYPE_CHECKING +import hypothesis.strategies as st import numpy as np import pytest +from hypothesis import assume, given from numpy import nan import polars as pl +from polars._utils.convert import parse_as_duration_string from polars.exceptions import ComputeError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import column, dataframes +from polars.testing.parametric.strategies.dtype import _time_units if TYPE_CHECKING: + from hypothesis.strategies import SearchStrategy + from polars.type_aliases import ClosedInterval, PolarsDataType, TimeUnit @@ -1012,3 +1019,154 @@ def test_temporal_windows_size_without_by_15977() -> None: match=r"Passing a str to `rolling_\*` is deprecated" ): df.select(pl.col("a").rolling_mean("3d")) + + +def interval_defs() -> SearchStrategy[ClosedInterval]: + closed: list[ClosedInterval] = ["left", "right", "both", "none"] + return st.sampled_from(closed) + + +@given( + period=st.timedeltas( + min_value=timedelta(microseconds=0), max_value=timedelta(days=1000) + ).map(parse_as_duration_string), + offset=st.timedeltas( + min_value=timedelta(days=-1000), max_value=timedelta(days=1000) + ).map(parse_as_duration_string), + closed=interval_defs(), + data=st.data(), + time_unit=_time_units(), +) +def test_rolling_parametric( + period: str, + offset: str, + closed: ClosedInterval, + data: st.DataObject, + time_unit: TimeUnit, +) -> None: + assume(period != "") + dataframe = data.draw( + 
dataframes( + [ + column( + "ts", + strategy=st.datetimes( + min_value=datetime(2000, 1, 1), + max_value=datetime(2001, 1, 1), + ), + dtype=pl.Datetime(time_unit), + ), + column( + "value", + strategy=st.integers(min_value=-100, max_value=100), + dtype=pl.Int64, + ), + ], + min_size=1, + ) + ) + df = dataframe.sort("ts") + result = df.rolling("ts", period=period, offset=offset, closed=closed).agg( + pl.col("value") + ) + + expected_dict: dict[str, list[object]] = {"ts": [], "value": []} + for ts, _ in df.iter_rows(): + window = df.filter( + pl.col("ts").is_between( + pl.lit(ts, dtype=pl.Datetime(time_unit)).dt.offset_by(offset), + pl.lit(ts, dtype=pl.Datetime(time_unit)) + .dt.offset_by(offset) + .dt.offset_by(period), + closed=closed, + ) + ) + value = window["value"].to_list() + expected_dict["ts"].append(ts) + expected_dict["value"].append(value) + expected = pl.DataFrame(expected_dict).select( + pl.col("ts").cast(pl.Datetime(time_unit)), + pl.col("value").cast(pl.List(pl.Int64)), + ) + assert_frame_equal(result, expected) + + +@given( + window_size=st.timedeltas( + min_value=timedelta(microseconds=0), max_value=timedelta(days=2) + ).map(parse_as_duration_string), + closed=interval_defs(), + data=st.data(), + time_unit=_time_units(), + aggregation=st.sampled_from( + [ + "min", + "max", + "mean", + "sum", + "std", + "var", + "median", + ] + ), +) +def test_rolling_aggs( + window_size: str, + closed: ClosedInterval, + data: st.DataObject, + time_unit: TimeUnit, + aggregation: str, +) -> None: + assume(window_size != "") + + # Testing logic can be faulty when window is more precise than time unit + # https://github.com/pola-rs/polars/issues/11754 + assume(not (time_unit == "ms" and "us" in window_size)) + + dataframe = data.draw( + dataframes( + [ + column( + "ts", + strategy=st.datetimes( + min_value=datetime(2000, 1, 1), + max_value=datetime(2001, 1, 1), + ), + dtype=pl.Datetime(time_unit), + ), + column( + "value", + strategy=st.integers(min_value=-100, max_value=100), + dtype=pl.Int64, + ), + ], + ) + ) + df = dataframe.sort("ts") + func = f"rolling_{aggregation}_by" + result = df.with_columns( + getattr(pl.col("value"), func)("ts", window_size=window_size, closed=closed) + ) + + expected_dict: dict[str, list[object]] = {"ts": [], "value": []} + for ts, _ in df.iter_rows(): + window = df.filter( + pl.col("ts").is_between( + pl.lit(ts, dtype=pl.Datetime(time_unit)).dt.offset_by( + f"-{window_size}" + ), + pl.lit(ts, dtype=pl.Datetime(time_unit)), + closed=closed, + ) + ) + expected_dict["ts"].append(ts) + if window.is_empty(): + expected_dict["value"].append(None) + else: + value = getattr(window["value"], aggregation)() + expected_dict["value"].append(value) + expected = pl.DataFrame(expected_dict).select( + pl.col("ts").cast(pl.Datetime(time_unit)), + pl.col("value").cast(result["value"].dtype), + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/series/test_getitem.py b/py-polars/tests/unit/series/test_getitem.py new file mode 100644 index 000000000000..07fb8979f211 --- /dev/null +++ b/py-polars/tests/unit/series/test_getitem.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import hypothesis.strategies as st +from hypothesis import given + +import polars as pl +from polars.testing import assert_series_equal +from polars.testing.parametric import series + + +@given( + srs=series(max_size=10, dtype=pl.Int64), + start=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), + stop=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), + 
step=st.sampled_from([-5, -4, -3, -2, -1, None, 1, 2, 3, 4, 5]), +) +def test_series_getitem( + srs: pl.Series, + start: int | None, + stop: int | None, + step: int | None, +) -> None: + py_data = srs.to_list() + + s = slice(start, stop, step) + sliced_py_data = py_data[s] + sliced_pl_data = srs[s].to_list() + + assert sliced_py_data == sliced_pl_data, f"slice [{start}:{stop}:{step}] failed" + assert_series_equal(srs, srs, check_exact=True) diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_core.py b/py-polars/tests/unit/testing/parametric/strategies/test_core.py index 3917197fe3e4..f32807f52452 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_core.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_core.py @@ -1,24 +1,281 @@ +from datetime import datetime +from typing import Any + import hypothesis.strategies as st +import pytest from hypothesis import given, settings +from hypothesis.errors import InvalidArgument import polars as pl -from polars.testing.parametric.strategies import dtypes, series +from polars.testing.parametric import ( + column, + dataframes, + dtypes, + lists, + series, +) + +TEMPORAL_DTYPES = {pl.Date, pl.Time, pl.Datetime, pl.Duration} + + +@given(s=series()) +@settings(max_examples=5) +def test_series_defaults(s: pl.Series) -> None: + assert isinstance(s, pl.Series) + assert s.name == "" + assert s.null_count() == 0 + + +@given(s=series(name="hello")) +@settings(max_examples=5) +def test_series_name(s: pl.Series) -> None: + assert s.name == "hello" @given(st.data()) -def test_dtype(data: st.DataObject) -> None: +def test_series_dtype(data: st.DataObject) -> None: dtype = data.draw(dtypes()) s = data.draw(series(dtype=dtype)) assert s.dtype == dtype -@given(s=series(dtype=pl.Binary)) +@given(s=series(dtype=pl.Boolean, size=5)) +@settings(max_examples=5) +def test_series_size(s: pl.Series) -> None: + assert s.len() == 5 + + +@given(s=series(min_size=3, max_size=8)) +@settings(max_examples=5) +def test_series_size_range(s: pl.Series) -> None: + assert 3 <= s.len() <= 8 + + +@given(s=series(allow_null=True)) +def test_series_allow_null(s: pl.Series) -> None: + assert 0 <= s.null_count() <= s.len() + + +@given(df=dataframes()) +@settings(max_examples=5) +def test_dataframes_defaults(df: pl.DataFrame) -> None: + assert isinstance(df, pl.DataFrame) + assert df.columns == [f"col{i}" for i in range(df.width)] + + +@given(lf=dataframes(lazy=True)) @settings(max_examples=5) -def test_strategy_dtype_binary(s: pl.Series) -> None: - assert s.dtype == pl.Binary +def test_dataframes_lazy(lf: pl.LazyFrame) -> None: + assert isinstance(lf, pl.LazyFrame) -@given(s=series(dtype=pl.Decimal)) +@given(df=dataframes(cols=3, size=5)) @settings(max_examples=5) -def test_strategy_dtype_decimal(s: pl.Series) -> None: - assert s.dtype == pl.Decimal +def test_dataframes_size(df: pl.DataFrame) -> None: + assert df.height == 5 + assert df.width == 3 + + +@given(df=dataframes(min_cols=2, max_cols=5, min_size=3, max_size=8)) +@settings(max_examples=5) +def test_dataframes_size_range(df: pl.DataFrame) -> None: + assert 3 <= df.height <= 8 + assert 2 <= df.width <= 5 + + +@given(df=dataframes(cols=1, allow_null=True)) +@settings(max_examples=5) +def test_dataframes_allow_null_global(df: pl.DataFrame) -> None: + null_count = sum(df.null_count().row(0)) + assert 0 <= null_count <= df.height * df.width + + +@given(df=dataframes(cols=2, allow_null={"col0": True})) +@settings(max_examples=5) +def test_dataframes_allow_null_column(df: pl.DataFrame) -> None: 
+ null_count = sum(df.null_count().row(0)) + assert 0 <= null_count <= df.height * df.width + + +@given( + df=dataframes( + cols=1, + allow_null=False, + include_cols=[column(name="colx", allow_null=True)], + ) +) +def test_dataframes_allow_null_override(df: pl.DataFrame) -> None: + assert df.get_column("col0").null_count() == 0 + assert 0 <= df.get_column("colx").null_count() <= df.height + + +@given( + lf=dataframes( + # generate lazyframes with at least one row + lazy=True, + min_size=1, + # test mix & match of bulk-assigned cols with custom cols + cols=[column(n, dtype=pl.UInt8, unique=True) for n in ["a", "b"]], + include_cols=[ + column("c", dtype=pl.Boolean), + column("d", strategy=st.sampled_from(["x", "y", "z"])), + ], + ) +) +def test_dataframes_columns(lf: pl.LazyFrame) -> None: + assert lf.schema == {"a": pl.UInt8, "b": pl.UInt8, "c": pl.Boolean, "d": pl.String} + assert lf.columns == ["a", "b", "c", "d"] + df = lf.collect() + + # confirm uint cols bounds + uint8_max = (2**8) - 1 + assert df["a"].min() >= 0 # type: ignore[operator] + assert df["b"].min() >= 0 # type: ignore[operator] + assert df["a"].max() <= uint8_max # type: ignore[operator] + assert df["b"].max() <= uint8_max # type: ignore[operator] + + # confirm uint cols uniqueness + assert df["a"].is_unique().all() + assert df["b"].is_unique().all() + + # boolean col + assert all(isinstance(v, bool) for v in df["c"].to_list()) + + # string col, entries selected from custom values + xyz = {"x", "y", "z"} + assert all(v in xyz for v in df["d"].to_list()) + + +@pytest.mark.hypothesis() +def test_column_invalid_probability() -> None: + with pytest.deprecated_call(), pytest.raises(InvalidArgument): + column("col", null_probability=2.0) + + +@pytest.mark.hypothesis() +def test_column_null_probability_deprecated() -> None: + with pytest.deprecated_call(): + col = column("col", allow_null=False, null_probability=0.5) + assert col.null_probability == 0.5 + assert col.allow_null is True # null_probability takes precedence + + +@given(st.data()) +def test_allow_infinities_deprecated(data: st.DataObject) -> None: + with pytest.deprecated_call(): + strategy = series(dtype=pl.Float64, allow_infinities=False) + s = data.draw(strategy) + + assert all(v not in (float("inf"), float("-inf")) for v in s) + + +@given( + df=dataframes( + cols=[ + column("colx", dtype=pl.Array(pl.UInt8, width=3)), + column("coly", dtype=pl.List(pl.Datetime("ms"))), + column( + name="colz", + dtype=pl.List(pl.List(pl.String)), + strategy=lists( + inner_dtype=pl.List(pl.String), + select_from=["aa", "bb", "cc"], + min_len=1, + ), + ), + ] + ), +) +def test_dataframes_nested_strategies(df: pl.DataFrame) -> None: + assert df.schema == { + "colx": pl.Array(pl.UInt8, width=3), + "coly": pl.List(pl.Datetime("ms")), + "colz": pl.List(pl.List(pl.String)), + } + uint8_max = (2**8) - 1 + + for colx, coly, colz in df.iter_rows(): + assert len(colx) == 3 + assert all(i <= uint8_max for i in colx) + assert all(isinstance(d, datetime) for d in coly) + for inner_list in colz: + assert all(s in ("aa", "bb", "cc") for s in inner_list) + + +@given( + df=dataframes(allowed_dtypes=TEMPORAL_DTYPES, max_size=1, max_cols=5), + lf=dataframes(excluded_dtypes=TEMPORAL_DTYPES, max_size=1, max_cols=5, lazy=True), + s1=series(allowed_dtypes=TEMPORAL_DTYPES, max_size=1), + s2=series(excluded_dtypes=TEMPORAL_DTYPES, max_size=1), +) +@settings(max_examples=50) +def test_strategy_dtypes( + df: pl.DataFrame, + lf: pl.LazyFrame, + s1: pl.Series, + s2: pl.Series, +) -> None: + # dataframe, 
lazyframe + assert all(tp.is_temporal() for tp in df.dtypes) + assert all(not tp.is_temporal() for tp in lf.dtypes) + + # series + assert s1.dtype.is_temporal() + assert not s2.dtype.is_temporal() + + +@given( + df1=dataframes(chunked=False, min_size=1), + df2=dataframes(chunked=True, min_size=1), + s1=series(chunked=False, min_size=1), + s2=series(chunked=True, min_size=1), +) +@settings(max_examples=10) +def test_chunking( + df1: pl.DataFrame, + df2: pl.DataFrame, + s1: pl.Series, + s2: pl.Series, +) -> None: + assert df1.n_chunks() == 1 + if len(df2) > 1: + assert df2.n_chunks("all") == [2] * len(df2.columns) + + assert s1.n_chunks() == 1 + if len(s2) > 1: + assert s2.n_chunks() > 1 + + +@given( + df=dataframes( + allowed_dtypes=[pl.Float32, pl.Float64], max_cols=4, allow_infinity=False + ), + s=series(dtype=pl.Float64, allow_infinity=False), +) +def test_infinities( + df: pl.DataFrame, + s: pl.Series, +) -> None: + from math import isfinite, isnan + + def finite_float(value: Any) -> bool: + return isfinite(value) or isnan(value) + + assert all(finite_float(val) for val in s.to_list()) + for col in df.columns: + assert all(finite_float(val) for val in df[col].to_list()) + + +@given( + df=dataframes( + cols=10, + max_size=1, + allowed_dtypes=[pl.Int8, pl.UInt16, pl.List(pl.Int32)], + ) +) +@settings(max_examples=3) +def test_dataframes_allowed_dtypes_integer_cols(df: pl.DataFrame) -> None: + # ensure dtype constraint works in conjunction with 'n' cols + assert all( + tp in (pl.Int8, pl.UInt16, pl.List(pl.Int32)) for tp in df.schema.values() + ) diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py index 5b2e9f92092a..69021a978ffc 100644 --- a/py-polars/tests/unit/testing/test_assert_frame_equal.py +++ b/py-polars/tests/unit/testing/test_assert_frame_equal.py @@ -4,15 +4,22 @@ from typing import Any import pytest +from hypothesis import given import polars as pl from polars.exceptions import InvalidAssert from polars.testing import assert_frame_equal, assert_frame_not_equal +from polars.testing.parametric import dataframes nan = float("nan") pytest_plugins = ["pytester"] +@given(df=dataframes()) +def test_equal(df: pl.DataFrame) -> None: + assert_frame_equal(df, df.clone(), check_exact=True) + + @pytest.mark.parametrize( ("df1", "df2", "kwargs"), [ From 54ddfa14a6f95fb553205e1b227294dbee10db9c Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 13 May 2024 12:48:14 +0200 Subject: [PATCH 24/29] fix: Fix rolling empty group OOB (#16186) --- crates/polars-time/src/group_by/dynamic.rs | 8 +- .../tests/unit/datatypes/test_temporal.py | 177 --------------- .../tests/unit/operations/test_rolling.py | 202 ++++++++++++++++++ 3 files changed, 209 insertions(+), 178 deletions(-) diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 4a22d21f8a0c..589119440ec4 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -631,7 +631,13 @@ fn update_subgroups_slice(sub_groups: &[[IdxSize; 2]], base_g: [IdxSize; 2]) -> sub_groups .iter() .map(|&[first, len]| { - let new_first = base_g[0] + first; + let new_first = if len == 0 { + // In case the group is empty, keep the original first so that the + // group_by keys still point to the original group. 
+ base_g[0] + } else { + base_g[0] + first + }; [new_first, len] }) .collect_trusted::>() diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index 79ef41e9a8cf..a34ffcb1ea4a 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -531,36 +531,6 @@ def test_explode_date() -> None: ] -def test_rolling() -> None: - dates = [ - "2020-01-01 13:45:48", - "2020-01-01 16:42:13", - "2020-01-01 16:45:09", - "2020-01-02 18:12:48", - "2020-01-03 19:45:32", - "2020-01-08 23:16:43", - ] - - df = ( - pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1]}) - .with_columns(pl.col("dt").str.strptime(pl.Datetime)) - .set_sorted("dt") - ) - - period: str | timedelta - for period in ("2d", timedelta(days=2)): # type: ignore[assignment] - out = df.rolling(index_column="dt", period=period).agg( - [ - pl.sum("a").alias("sum_a"), - pl.min("a").alias("min_a"), - pl.max("a").alias("max_a"), - ] - ) - assert out["sum_a"].to_list() == [3, 10, 15, 24, 11, 1] - assert out["max_a"].to_list() == [3, 7, 7, 9, 9, 1] - assert out["min_a"].to_list() == [3, 3, 3, 3, 2, 1] - - @pytest.mark.parametrize( ("time_zone", "tzinfo"), [ @@ -926,35 +896,6 @@ def test_asof_join_tolerance_grouper() -> None: assert_frame_equal(out, expected) -def test_rolling_group_by_by_argument() -> None: - df = pl.DataFrame({"times": range(10), "groups": [1] * 4 + [2] * 6}) - - out = df.rolling("times", period="5i", group_by=["groups"]).agg( - pl.col("times").alias("agg_list") - ) - - expected = pl.DataFrame( - { - "groups": [1, 1, 1, 1, 2, 2, 2, 2, 2, 2], - "times": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], - "agg_list": [ - [0], - [0, 1], - [0, 1, 2], - [0, 1, 2, 3], - [4], - [4, 5], - [4, 5, 6], - [4, 5, 6, 7], - [4, 5, 6, 7, 8], - [5, 6, 7, 8, 9], - ], - } - ) - - assert_frame_equal(out, expected) - - def test_rolling_mean_3020() -> None: df = pl.DataFrame( { @@ -1376,96 +1317,6 @@ def test_datetime_instance_selection() -> None: assert [] == list(df.select(pl.exclude(DATETIME_DTYPES))) -def test_rolling_by_ordering() -> None: - # we must check that the keys still match the time labels after the rolling window - # with a `by` argument. 
- df = pl.DataFrame( - { - "dt": [ - datetime(2022, 1, 1, 0, 1), - datetime(2022, 1, 1, 0, 2), - datetime(2022, 1, 1, 0, 3), - datetime(2022, 1, 1, 0, 4), - datetime(2022, 1, 1, 0, 5), - datetime(2022, 1, 1, 0, 6), - datetime(2022, 1, 1, 0, 7), - ], - "key": ["A", "A", "B", "B", "A", "B", "A"], - "val": [1, 1, 1, 1, 1, 1, 1], - } - ).set_sorted("dt") - - assert df.rolling( - index_column="dt", - period="2m", - closed="both", - offset="-1m", - group_by="key", - ).agg( - [ - pl.col("val").sum().alias("sum val"), - ] - ).to_dict(as_series=False) == { - "key": ["A", "A", "A", "A", "B", "B", "B"], - "dt": [ - datetime(2022, 1, 1, 0, 1), - datetime(2022, 1, 1, 0, 2), - datetime(2022, 1, 1, 0, 5), - datetime(2022, 1, 1, 0, 7), - datetime(2022, 1, 1, 0, 3), - datetime(2022, 1, 1, 0, 4), - datetime(2022, 1, 1, 0, 6), - ], - "sum val": [2, 2, 1, 1, 2, 2, 1], - } - - -def test_rolling_by_() -> None: - df = pl.DataFrame({"group": pl.arange(0, 3, eager=True)}).join( - pl.DataFrame( - { - "datetime": pl.datetime_range( - datetime(2020, 1, 1), datetime(2020, 1, 5), "1d", eager=True - ), - } - ), - how="cross", - ) - out = ( - df.sort("datetime") - .rolling(index_column="datetime", group_by="group", period=timedelta(days=3)) - .agg([pl.len().alias("count")]) - ) - - expected = ( - df.sort(["group", "datetime"]) - .rolling(index_column="datetime", group_by="group", period="3d") - .agg([pl.len().alias("count")]) - ) - assert_frame_equal(out.sort(["group", "datetime"]), expected) - assert out.to_dict(as_series=False) == { - "group": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], - "datetime": [ - datetime(2020, 1, 1, 0, 0), - datetime(2020, 1, 2, 0, 0), - datetime(2020, 1, 3, 0, 0), - datetime(2020, 1, 4, 0, 0), - datetime(2020, 1, 5, 0, 0), - datetime(2020, 1, 1, 0, 0), - datetime(2020, 1, 2, 0, 0), - datetime(2020, 1, 3, 0, 0), - datetime(2020, 1, 4, 0, 0), - datetime(2020, 1, 5, 0, 0), - datetime(2020, 1, 1, 0, 0), - datetime(2020, 1, 2, 0, 0), - datetime(2020, 1, 3, 0, 0), - datetime(2020, 1, 4, 0, 0), - datetime(2020, 1, 5, 0, 0), - ], - "count": [1, 2, 3, 3, 3, 1, 2, 3, 3, 3, 1, 2, 3, 3, 3], - } - - def test_sum_duration() -> None: assert pl.DataFrame( [ @@ -2785,22 +2636,6 @@ def test_datetime_cum_agg_schema() -> None: } -def test_rolling_group_by_empty_groups_by_take_6330() -> None: - df1 = pl.DataFrame({"Event": ["Rain", "Sun"]}) - df2 = pl.DataFrame({"Date": [1, 2, 3, 4]}) - df = df1.join(df2, how="cross").set_sorted("Date") - - result = df.rolling( - index_column="Date", period="2i", offset="-2i", group_by="Event", closed="left" - ).agg(pl.len()) - - assert result.to_dict(as_series=False) == { - "Event": ["Rain", "Rain", "Rain", "Rain", "Sun", "Sun", "Sun", "Sun"], - "Date": [1, 2, 3, 4, 1, 2, 3, 4], - "len": [0, 1, 2, 2, 0, 1, 2, 2], - } - - def test_infer_iso8601_datetime(iso8601_format_datetime: str) -> None: # construct an example time string time_string = ( @@ -2958,18 +2793,6 @@ def test_pytime_conversion(tm: time) -> None: assert s.to_list() == [tm] -def test_rolling_duplicates() -> None: - df = pl.DataFrame( - { - "ts": [datetime(2000, 1, 1, 0, 0), datetime(2000, 1, 1, 0, 0)], - "value": [0, 1], - } - ) - assert df.sort("ts").with_columns(pl.col("value").rolling_max_by("ts", "1d"))[ - "value" - ].to_list() == [1, 1] - - def test_datetime_time_unit_none_deprecated() -> None: with pytest.deprecated_call(): dtype = pl.Datetime(time_unit=None) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/operations/test_rolling.py b/py-polars/tests/unit/operations/test_rolling.py index 
6f6e2b24e7a4..e08bdafeb69b 100644 --- a/py-polars/tests/unit/operations/test_rolling.py +++ b/py-polars/tests/unit/operations/test_rolling.py @@ -12,6 +12,36 @@ from polars.type_aliases import ClosedInterval, PolarsIntegerType +def test_rolling() -> None: + dates = [ + "2020-01-01 13:45:48", + "2020-01-01 16:42:13", + "2020-01-01 16:45:09", + "2020-01-02 18:12:48", + "2020-01-03 19:45:32", + "2020-01-08 23:16:43", + ] + + df = ( + pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1]}) + .with_columns(pl.col("dt").str.strptime(pl.Datetime)) + .set_sorted("dt") + ) + + period: str | timedelta + for period in ("2d", timedelta(days=2)): # type: ignore[assignment] + out = df.rolling(index_column="dt", period=period).agg( + [ + pl.sum("a").alias("sum_a"), + pl.min("a").alias("min_a"), + pl.max("a").alias("max_a"), + ] + ) + assert out["sum_a"].to_list() == [3, 10, 15, 24, 11, 1] + assert out["max_a"].to_list() == [3, 7, 7, 9, 9, 1] + assert out["min_a"].to_list() == [3, 3, 3, 3, 2, 1] + + @pytest.mark.parametrize("dtype", [pl.UInt32, pl.UInt64, pl.Int32, pl.Int64]) def test_rolling_group_by_overlapping_groups(dtype: PolarsIntegerType) -> None: # this first aggregates overlapping groups so they cannot be naively flattened @@ -333,3 +363,175 @@ def test_negative_zero_offset_16168() -> None: assert_frame_equal(result, expected) result = df.rolling(index_column="foo", period="1i", offset="-0i").agg("index") assert_frame_equal(result, expected) + + +def test_rolling_sorted_empty_groups_16145() -> None: + df = pl.DataFrame( + { + "id": [1, 2], + "time": [ + datetime(year=1989, month=12, day=1, hour=12, minute=3), + datetime(year=1989, month=12, day=1, hour=13, minute=14), + ], + } + ) + + assert ( + df.sort("id") + .rolling( + index_column="time", + group_by="id", + period="1d", + offset="0d", + closed="right", + ) + .agg() + .select("id") + )["id"].to_list() == [1, 2] + + +def test_rolling_by_() -> None: + df = pl.DataFrame({"group": pl.arange(0, 3, eager=True)}).join( + pl.DataFrame( + { + "datetime": pl.datetime_range( + datetime(2020, 1, 1), datetime(2020, 1, 5), "1d", eager=True + ), + } + ), + how="cross", + ) + out = ( + df.sort("datetime") + .rolling(index_column="datetime", group_by="group", period=timedelta(days=3)) + .agg([pl.len().alias("count")]) + ) + + expected = ( + df.sort(["group", "datetime"]) + .rolling(index_column="datetime", group_by="group", period="3d") + .agg([pl.len().alias("count")]) + ) + assert_frame_equal(out.sort(["group", "datetime"]), expected) + assert out.to_dict(as_series=False) == { + "group": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], + "datetime": [ + datetime(2020, 1, 1, 0, 0), + datetime(2020, 1, 2, 0, 0), + datetime(2020, 1, 3, 0, 0), + datetime(2020, 1, 4, 0, 0), + datetime(2020, 1, 5, 0, 0), + datetime(2020, 1, 1, 0, 0), + datetime(2020, 1, 2, 0, 0), + datetime(2020, 1, 3, 0, 0), + datetime(2020, 1, 4, 0, 0), + datetime(2020, 1, 5, 0, 0), + datetime(2020, 1, 1, 0, 0), + datetime(2020, 1, 2, 0, 0), + datetime(2020, 1, 3, 0, 0), + datetime(2020, 1, 4, 0, 0), + datetime(2020, 1, 5, 0, 0), + ], + "count": [1, 2, 3, 3, 3, 1, 2, 3, 3, 3, 1, 2, 3, 3, 3], + } + + +def test_rolling_group_by_empty_groups_by_take_6330() -> None: + df1 = pl.DataFrame({"Event": ["Rain", "Sun"]}) + df2 = pl.DataFrame({"Date": [1, 2, 3, 4]}) + df = df1.join(df2, how="cross").set_sorted("Date") + + result = df.rolling( + index_column="Date", period="2i", offset="-2i", group_by="Event", closed="left" + ).agg(pl.len()) + + assert result.to_dict(as_series=False) == { + "Event": ["Rain", 
"Rain", "Rain", "Rain", "Sun", "Sun", "Sun", "Sun"], + "Date": [1, 2, 3, 4, 1, 2, 3, 4], + "len": [0, 1, 2, 2, 0, 1, 2, 2], + } + + +def test_rolling_duplicates() -> None: + df = pl.DataFrame( + { + "ts": [datetime(2000, 1, 1, 0, 0), datetime(2000, 1, 1, 0, 0)], + "value": [0, 1], + } + ) + assert df.sort("ts").with_columns(pl.col("value").rolling_max_by("ts", "1d"))[ + "value" + ].to_list() == [1, 1] + + +def test_rolling_group_by_by_argument() -> None: + df = pl.DataFrame({"times": range(10), "groups": [1] * 4 + [2] * 6}) + + out = df.rolling("times", period="5i", group_by=["groups"]).agg( + pl.col("times").alias("agg_list") + ) + + expected = pl.DataFrame( + { + "groups": [1, 1, 1, 1, 2, 2, 2, 2, 2, 2], + "times": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "agg_list": [ + [0], + [0, 1], + [0, 1, 2], + [0, 1, 2, 3], + [4], + [4, 5], + [4, 5, 6], + [4, 5, 6, 7], + [4, 5, 6, 7, 8], + [5, 6, 7, 8, 9], + ], + } + ) + + assert_frame_equal(out, expected) + + +def test_rolling_by_ordering() -> None: + # we must check that the keys still match the time labels after the rolling window + # with a `by` argument. + df = pl.DataFrame( + { + "dt": [ + datetime(2022, 1, 1, 0, 1), + datetime(2022, 1, 1, 0, 2), + datetime(2022, 1, 1, 0, 3), + datetime(2022, 1, 1, 0, 4), + datetime(2022, 1, 1, 0, 5), + datetime(2022, 1, 1, 0, 6), + datetime(2022, 1, 1, 0, 7), + ], + "key": ["A", "A", "B", "B", "A", "B", "A"], + "val": [1, 1, 1, 1, 1, 1, 1], + } + ).set_sorted("dt") + + assert df.rolling( + index_column="dt", + period="2m", + closed="both", + offset="-1m", + group_by="key", + ).agg( + [ + pl.col("val").sum().alias("sum val"), + ] + ).to_dict(as_series=False) == { + "key": ["A", "A", "A", "A", "B", "B", "B"], + "dt": [ + datetime(2022, 1, 1, 0, 1), + datetime(2022, 1, 1, 0, 2), + datetime(2022, 1, 1, 0, 5), + datetime(2022, 1, 1, 0, 7), + datetime(2022, 1, 1, 0, 3), + datetime(2022, 1, 1, 0, 4), + datetime(2022, 1, 1, 0, 6), + ], + "sum val": [2, 2, 1, 1, 2, 2, 1], + } From 2e0064721a299ed8bad9e89ba2adc4912c792a65 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 13 May 2024 13:10:30 +0200 Subject: [PATCH 25/29] feat(python): Support `Enum` types in parametric testing (#16188) --- .../polars/testing/parametric/strategies/data.py | 16 +++++++++++++++- .../testing/parametric/strategies/dtype.py | 9 +++++++++ .../testing/parametric/strategies/test_core.py | 7 +++++++ .../testing/parametric/strategies/test_data.py | 10 ++++++++++ 4 files changed, 41 insertions(+), 1 deletion(-) diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py index 78dc119c8929..7cfda456c390 100644 --- a/py-polars/polars/testing/parametric/strategies/data.py +++ b/py-polars/polars/testing/parametric/strategies/data.py @@ -33,6 +33,7 @@ Datetime, Decimal, Duration, + Enum, Float32, Float64, Int8, @@ -49,7 +50,10 @@ UInt64, ) from polars.testing.parametric.strategies._utils import flexhash -from polars.testing.parametric.strategies.dtype import _DEFAULT_ARRAY_WIDTH_LIMIT +from polars.testing.parametric.strategies.dtype import ( + _DEFAULT_ARRAY_WIDTH_LIMIT, + _DEFAULT_ENUM_CATEGORIES_LIMIT, +) if TYPE_CHECKING: from datetime import date, datetime, time @@ -329,6 +333,16 @@ def data( strategy = categories( n_categories=kwargs.pop("n_categories", _DEFAULT_N_CATEGORIES) ) + elif dtype == Enum: + if isinstance(dtype, Enum): + if (cats := dtype.categories).is_empty(): + strategy = nulls() + else: + strategy = st.sampled_from(cats.to_list()) + else: + strategy = categories( + 
n_categories=kwargs.pop("n_categories", _DEFAULT_ENUM_CATEGORIES_LIMIT) + ) elif dtype == Decimal: strategy = decimals( getattr(dtype, "precision", None), getattr(dtype, "scale", 0) diff --git a/py-polars/polars/testing/parametric/strategies/dtype.py b/py-polars/polars/testing/parametric/strategies/dtype.py index ae54a5e405b4..91bae2317b93 100644 --- a/py-polars/polars/testing/parametric/strategies/dtype.py +++ b/py-polars/polars/testing/parametric/strategies/dtype.py @@ -15,6 +15,7 @@ Datetime, Decimal, Duration, + Enum, Float32, Float64, Int8, @@ -63,6 +64,7 @@ Duration, Categorical, Decimal, + Enum, ] # Supported data type classes that contain other data types _NESTED_DTYPES: list[DataTypeClass] = [ @@ -76,6 +78,7 @@ _DEFAULT_ARRAY_WIDTH_LIMIT = 3 _DEFAULT_STRUCT_FIELDS_LIMIT = 3 +_DEFAULT_ENUM_CATEGORIES_LIMIT = 3 def dtypes( @@ -174,6 +177,12 @@ def _instantiate_flat_dtype(draw: DrawFn, dtype: PolarsDataType) -> DataType: elif dtype == Categorical: ordering = draw(_categorical_orderings()) return Categorical(ordering) + elif dtype == Enum: + n_categories = draw( + st.integers(min_value=1, max_value=_DEFAULT_ENUM_CATEGORIES_LIMIT) + ) + categories = [f"c{i}" for i in range(n_categories)] + return Enum(categories) elif dtype == Decimal: precision = draw(st.integers(min_value=1, max_value=38) | st.none()) scale = draw(st.integers(min_value=0, max_value=precision or 38)) diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_core.py b/py-polars/tests/unit/testing/parametric/strategies/test_core.py index f32807f52452..13a7c31cb291 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_core.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_core.py @@ -39,6 +39,13 @@ def test_series_dtype(data: st.DataObject) -> None: assert s.dtype == dtype +@given(s=series(dtype=pl.Enum)) +@settings(max_examples=5) +def test_series_dtype_enum(s: pl.Series) -> None: + assert isinstance(s.dtype, pl.Enum) + assert all(v in s.dtype.categories for v in s) + + @given(s=series(dtype=pl.Boolean, size=5)) @settings(max_examples=5) def test_series_size(s: pl.Series) -> None: diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_data.py b/py-polars/tests/unit/testing/parametric/strategies/test_data.py index 0820015158dc..a7316b0a7a7b 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_data.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_data.py @@ -19,3 +19,13 @@ def test_data_kwargs(cat: str) -> None: @given(categories=data(pl.List(pl.Categorical), n_categories=3)) def test_data_nested_kwargs(categories: list[str]) -> None: assert all(c in ("c0", "c1", "c2") for c in categories) + + +@given(cat=data(pl.Enum)) +def test_data_enum(cat: str) -> None: + assert cat in ("c0", "c1", "c2") + + +@given(cat=data(pl.Enum(["hello", "world"]))) +def test_data_enum_instantiated(cat: str) -> None: + assert cat in ("hello", "world") From 8066952478d3c6ea1865f07221efe6f2de9d9c28 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 13 May 2024 14:31:48 +0200 Subject: [PATCH 26/29] fix: Fix get expression group-by state (#16189) --- .../src/physical_plan/expressions/mod.rs | 39 ------ .../src/physical_plan/expressions/take.rs | 125 ++++++++++-------- .../src/physical_plan/planner/expr.rs | 2 +- .../unit/operations/test_group_by_dynamic.py | 26 ++++ 4 files changed, 94 insertions(+), 98 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/expressions/mod.rs b/crates/polars-lazy/src/physical_plan/expressions/mod.rs index 
4642654a9fb6..6d496e82b716 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/mod.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/mod.rs @@ -60,21 +60,6 @@ pub(crate) enum AggState { } impl AggState { - // Literal series are not safe to aggregate - fn safe_to_agg(&self, groups: &GroupsProxy) -> bool { - match self { - AggState::NotAggregated(s) => { - !(s.len() == 1 - // or more then one group - && (groups.len() > 1 - // or single groups with more than one index - || !groups.is_empty() - && groups.get(0).len() > 1)) - }, - _ => true, - } - } - fn try_map(&self, func: F) -> PolarsResult where F: FnOnce(&Series) -> PolarsResult, @@ -331,30 +316,6 @@ impl<'a> AggregationContext<'a> { self.update_groups = UpdateGroups::No; } - /// In a binary expression one state can be aggregated and the other not. - /// If both would be flattened naively one would be sorted and the other not. - /// Calling this function will ensure both are sorted. This will be a no-op - /// if already aggregated. - pub(crate) fn sort_by_groups(&mut self) { - // make sure that the groups are updated before we use them to sort. - self.groups(); - match &self.state { - AggState::NotAggregated(s) => { - // We should not aggregate literals!! - if self.state.safe_to_agg(&self.groups) { - // SAFETY: - // groups are in bounds - let agg = unsafe { s.agg_list(&self.groups) }; - self.update_groups = UpdateGroups::WithGroupsLen; - self.state = AggState::AggregatedList(agg); - } - }, - AggState::AggregatedScalar(_) => {}, - AggState::AggregatedList(_) => {}, - AggState::Literal(_) => {}, - } - } - /// # Arguments /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its /// the columns dtype) diff --git a/crates/polars-lazy/src/physical_plan/expressions/take.rs b/crates/polars-lazy/src/physical_plan/expressions/take.rs index 9408635de332..5f0e40ab6a29 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/take.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/take.rs @@ -3,18 +3,19 @@ use polars_core::chunked_array::builder::get_list_builder; use polars_core::prelude::*; use polars_core::utils::NoNull; use polars_ops::prelude::{convert_to_unsigned_index, is_positive_idx_uncertain}; +use polars_utils::slice::GetSaferUnchecked; use crate::physical_plan::state::ExecutionState; use crate::prelude::*; -pub struct TakeExpr { +pub struct GatherExpr { pub(crate) phys_expr: Arc, pub(crate) idx: Arc, pub(crate) expr: Expr, pub(crate) returns_scalar: bool, } -impl PhysicalExpr for TakeExpr { +impl PhysicalExpr for GatherExpr { fn as_expression(&self) -> Option<&Expr> { Some(&self.expr) } @@ -93,7 +94,7 @@ impl PhysicalExpr for TakeExpr { } } -impl TakeExpr { +impl GatherExpr { fn finish( &self, df: &DataFrame, @@ -114,54 +115,75 @@ impl TakeExpr { mut ac: AggregationContext<'b>, idx: &IdxCa, ) -> PolarsResult> { - // The indexes are AggregatedScalar, meaning they are a single values pointing into - // a group. If we zip this with the first of each group -> `idx + first` then we can - // simply use a take operation on the whole array instead of per group. + if ac.is_not_aggregated() { + // A previous aggregation may have updated the groups. + let groups = ac.groups(); - // The groups maybe scattered all over the place, so we sort by group. - ac.sort_by_groups(); - - // A previous aggregation may have updated the groups. - let groups = ac.groups(); + // Determine the gather indices. 
+ let idx: IdxCa = match groups.as_ref() { + GroupsProxy::Idx(groups) => { + if groups.all().iter().zip(idx).any(|(g, idx)| match idx { + None => true, + Some(idx) => idx >= g.len() as IdxSize, + }) { + self.oob_err()?; + } - // Determine the gather indices. - let idx: IdxCa = match groups.as_ref() { - GroupsProxy::Idx(groups) => { - if groups.all().iter().zip(idx).any(|(g, idx)| match idx { - None => true, - Some(idx) => idx >= g.len() as IdxSize, - }) { - self.oob_err()?; - } + idx.into_iter() + .zip(groups.iter()) + .map(|(idx, (_first, groups))| { + idx.map(|idx| { + // SAFETY: + // we checked bounds + unsafe { + *groups.get_unchecked_release(usize::try_from(idx).unwrap()) + } + }) + }) + .collect_trusted() + }, + GroupsProxy::Slice { groups, .. } => { + if groups.iter().zip(idx).any(|(g, idx)| match idx { + None => true, + Some(idx) => idx >= g[1], + }) { + self.oob_err()?; + } - idx.into_iter() - .zip(groups.first().iter()) - .map(|(idx, first)| idx.map(|idx| idx + first)) - .collect_trusted() - }, - GroupsProxy::Slice { groups, .. } => { - if groups.iter().zip(idx).any(|(g, idx)| match idx { - None => true, - Some(idx) => idx >= g[1], - }) { - self.oob_err()?; - } + idx.into_iter() + .zip(groups.iter()) + .map(|(idx, g)| idx.map(|idx| idx + g[0])) + .collect_trusted() + }, + }; - idx.into_iter() - .zip(groups.iter()) - .map(|(idx, g)| idx.map(|idx| idx + g[0])) - .collect_trusted() - }, - }; + let taken = ac.flat_naive().take(&idx)?; + let taken = if self.returns_scalar { + taken + } else { + taken.as_list().into_series() + }; - let taken = ac.flat_naive().take(&idx)?; - let taken = if self.returns_scalar { - taken + ac.with_series(taken, true, Some(&self.expr))?; + Ok(ac) } else { - taken.as_list().into_series() - }; + self.gather_aggregated_expensive(ac, idx) + } + } + + fn gather_aggregated_expensive<'b>( + &self, + mut ac: AggregationContext<'b>, + idx: &IdxCa, + ) -> PolarsResult> { + let out = ac + .aggregated() + .list() + .unwrap() + .try_apply_amortized(|s| s.as_ref().take(idx))?; - ac.with_series(taken, true, Some(&self.expr))?; + ac.with_series(out.into_series(), true, Some(&self.expr))?; + ac.with_update_groups(UpdateGroups::WithGroupsLen); Ok(ac) } @@ -174,11 +196,6 @@ impl TakeExpr { match idx.get(0) { None => polars_bail!(ComputeError: "cannot take by a null"), Some(idx) => { - if idx != 0 { - // We must make sure that the column we take from is sorted by - // groups otherwise we might point into the wrong group. - ac.sort_by_groups() - } // Make sure that we look at the updated groups. 
let groups = ac.groups(); @@ -213,15 +230,7 @@ impl TakeExpr { }, } } else { - let out = ac - .aggregated() - .list() - .unwrap() - .try_apply_amortized(|s| s.as_ref().take(idx))?; - - ac.with_series(out.into_series(), true, Some(&self.expr))?; - ac.with_update_groups(UpdateGroups::WithGroupsLen); - Ok(ac) + self.gather_aggregated_expensive(ac, idx) } } diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs index fd7a6aebc653..3aaeadfa6546 100644 --- a/crates/polars-lazy/src/physical_plan/planner/expr.rs +++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs @@ -262,7 +262,7 @@ fn create_physical_expr_inner( } => { let phys_expr = create_physical_expr_inner(expr, ctxt, expr_arena, schema, state)?; let phys_idx = create_physical_expr_inner(idx, ctxt, expr_arena, schema, state)?; - Ok(Arc::new(TakeExpr { + Ok(Arc::new(GatherExpr { phys_expr, idx: phys_idx, expr: node_to_expr(expression, expr_arena), diff --git a/py-polars/tests/unit/operations/test_group_by_dynamic.py b/py-polars/tests/unit/operations/test_group_by_dynamic.py index 1def99e885d6..72c9dfcbb124 100644 --- a/py-polars/tests/unit/operations/test_group_by_dynamic.py +++ b/py-polars/tests/unit/operations/test_group_by_dynamic.py @@ -1038,3 +1038,29 @@ def test_group_by_dynamic_invalid() -> None: .group_by_dynamic("index", every="3000d") .agg(pl.col("values").sum().alias("sum")) ) + + +def test_group_by_dynamic_get() -> None: + df = pl.DataFrame( + { + "time": pl.date_range(pl.date(2021, 1, 1), pl.date(2021, 1, 8), eager=True), + "data": pl.arange(8, eager=True), + } + ) + + assert df.group_by_dynamic( + index_column="time", + every="2d", + period="3d", + start_by="datapoint", + ).agg( + get=pl.col("data").get(1), + ).to_dict(as_series=False) == { + "time": [ + date(2021, 1, 1), + date(2021, 1, 3), + date(2021, 1, 5), + date(2021, 1, 7), + ], + "get": [1, 3, 5, 7], + } From f6b4f48f3f7a981f6d022b5ca8a50b40a467d74c Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 13 May 2024 15:54:35 +0200 Subject: [PATCH 27/29] feat(python): Enable Null datatype and null values by default in parametric testing (#16192) --- _typos.toml | 6 +-- py-polars/polars/series/utils.py | 2 +- .../testing/parametric/strategies/core.py | 47 ++++++++++--------- .../testing/parametric/strategies/data.py | 3 +- .../testing/parametric/strategies/dtype.py | 1 + .../unit/functions/range/test_date_range.py | 2 +- .../tests/unit/interchange/test_roundtrip.py | 17 ++++++- .../interop/numpy/test_to_numpy_series.py | 17 +++++-- .../namespaces/temporal/test_datetime.py | 2 +- .../unit/series/buffers/test_from_buffer.py | 1 + .../parametric/strategies/test_core.py | 32 +++++++++---- 11 files changed, 87 insertions(+), 43 deletions(-) diff --git a/_typos.toml b/_typos.toml index e2c2490664d5..43ba08246dbe 100644 --- a/_typos.toml +++ b/_typos.toml @@ -4,7 +4,6 @@ extend-ignore-identifiers-re = [ ] [default.extend-identifiers] -arange = "arange" bck = "bck" Fo = "Fo" ND = "ND" @@ -12,11 +11,10 @@ ba = "ba" nd = "nd" opt_nd = "opt_nd" ser = "ser" -strat = "strat" -width_strat = "width_strat" [default.extend-words] -iif = "iif" +arange = "arange" +strat = "strat" '"r0ot"' = "r0ot" wee = "wee" diff --git a/py-polars/polars/series/utils.py b/py-polars/polars/series/utils.py index 237b55a396da..34b09fb3d0da 100644 --- a/py-polars/polars/series/utils.py +++ b/py-polars/polars/series/utils.py @@ -32,7 +32,7 @@ def expr_dispatch(cls: type[T]) -> type[T]: * Applied to the Series class, and/or any Series 
'NameSpace' classes. * Walks the class attributes, looking for methods that have empty function bodies, with signatures compatible with an existing Expr function. - * IIF both conditions are met, the empty method is decorated with @call_expr. + * IFF both conditions are met, the empty method is decorated with @call_expr. """ # create lookup of expression functions in this namespace namespace = getattr(cls, "_accessor", None) diff --git a/py-polars/polars/testing/parametric/strategies/core.py b/py-polars/polars/testing/parametric/strategies/core.py index 3657d465e802..2e9e4a13b0e3 100644 --- a/py-polars/polars/testing/parametric/strategies/core.py +++ b/py-polars/polars/testing/parametric/strategies/core.py @@ -8,7 +8,7 @@ from polars._utils.deprecation import issue_deprecation_warning from polars.dataframe import DataFrame -from polars.datatypes import DataType, DataTypeClass +from polars.datatypes import DataType, DataTypeClass, Null from polars.series import Series from polars.string_cache import StringCache from polars.testing.parametric.strategies._utils import flexhash @@ -39,7 +39,7 @@ def series( # noqa: D417 min_size: int = 0, max_size: int = _ROW_LIMIT, strategy: SearchStrategy[Any] | None = None, - allow_null: bool = False, + allow_null: bool = True, unique: bool = False, chunked: bool | None = None, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, @@ -68,7 +68,7 @@ def series( # noqa: D417 strategy : strategy, optional supports overriding the default strategy for the given dtype. allow_null : bool - Allow nulls as possible values. + Allow nulls as possible values and allow the `Null` data type by default. unique : bool, optional indicate whether Series values should all be distinct. chunked : bool, optional @@ -144,22 +144,28 @@ def series( # noqa: D417 allowed_dtypes = list(allowed_dtypes) if isinstance(excluded_dtypes, (DataType, DataTypeClass)): excluded_dtypes = [excluded_dtypes] - elif excluded_dtypes is not None and not isinstance(excluded_dtypes, Sequence): - excluded_dtypes = list(excluded_dtypes) + elif excluded_dtypes is not None: + if not isinstance(excluded_dtypes, list): + excluded_dtypes = list(excluded_dtypes) + + if not allow_null and not (allowed_dtypes is not None and Null in allowed_dtypes): + if excluded_dtypes is None: + excluded_dtypes = [Null] + else: + excluded_dtypes.append(Null) if strategy is None: if dtype is None: - dtype = draw( - dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes) + dtype_strat = dtypes( + allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes ) else: - dtype = draw( - _instantiate_dtype( - dtype, - allowed_dtypes=allowed_dtypes, - excluded_dtypes=excluded_dtypes, - ) + dtype_strat = _instantiate_dtype( + dtype, + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, ) + dtype = draw(dtype_strat) if size is None: size = draw(st.integers(min_value=min_size, max_value=max_size)) @@ -213,7 +219,7 @@ def dataframes( max_size: int = _ROW_LIMIT, chunked: bool | None = None, include_cols: Sequence[column] | column | None = None, - allow_null: bool | Mapping[str, bool] = False, + allow_null: bool | Mapping[str, bool] = True, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, **kwargs: Any, @@ -232,7 +238,7 @@ def dataframes( max_size: int = _ROW_LIMIT, chunked: bool | None = None, include_cols: Sequence[column] | column | None = None, - allow_null: bool | Mapping[str, bool] = 
False, + allow_null: bool | Mapping[str, bool] = True, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, **kwargs: Any, @@ -253,7 +259,7 @@ def dataframes( # noqa: D417 max_size: int = _ROW_LIMIT, chunked: bool | None = None, include_cols: Sequence[column] | column | None = None, - allow_null: bool | Mapping[str, bool] = False, + allow_null: bool | Mapping[str, bool] = True, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, **kwargs: Any, @@ -290,7 +296,8 @@ def dataframes( # noqa: D417 explicitly provided columns are appended onto the list of existing columns (if any present). allow_null : bool or Mapping[str, bool] - Allow nulls as possible values. + Allow nulls as possible values and allow the `Null` data type by default. + Accepts either a boolean or a mapping of column names to booleans. allowed_dtypes : {list,set}, optional when automatically generating data, allow only these dtypes. excluded_dtypes : {list,set}, optional @@ -404,12 +411,10 @@ def dataframes( # noqa: D417 c.name = f"col{idx}" if c.allow_null is None: if isinstance(allow_null, Mapping): - c.allow_null = allow_null.get(c.name, False) + c.allow_null = allow_null.get(c.name, True) else: c.allow_null = allow_null - # init dataframe from generated series data; series data is - # given as a python-native sequence. with StringCache(): data = { c.name: draw( @@ -456,7 +461,7 @@ class column: strategy : strategy, optional supports overriding the default strategy for the given dtype. allow_null : bool, optional - Allow nulls as possible values. + Allow nulls as possible values and allow the `Null` data type by default. unique : bool, optional flag indicating that all values generated for the column should be unique. 
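As a usage illustration of the `allow_null` default change above: a minimal sketch, assuming the `series`/`dataframes` strategies and the `column` helper behave as described in this patch (the test names and column names below are illustrative, not part of the change):

from hypothesis import given

import polars as pl
from polars.testing.parametric import column, dataframes, series


@given(s=series(allow_null=False))
def test_no_nulls_generated(s: pl.Series) -> None:
    # Opting out restores the previous behaviour: no null values and no Null dtype.
    assert s.null_count() == 0
    assert s.dtype != pl.Null


@given(
    df=dataframes(
        cols=[column("a", dtype=pl.Int64), column("b", dtype=pl.String)],
        # Per-column override: nulls may appear only in column "b".
        allow_null={"a": False, "b": True},
    )
)
def test_per_column_nulls(df: pl.DataFrame) -> None:
    assert df.get_column("a").null_count() == 0
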
diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py index 7cfda456c390..2bb3345c3ce3 100644 --- a/py-polars/polars/testing/parametric/strategies/data.py +++ b/py-polars/polars/testing/parametric/strategies/data.py @@ -349,7 +349,7 @@ def data( ) elif dtype == List: inner = getattr(dtype, "inner", None) or Null() - strategy = lists(inner, **kwargs) + strategy = lists(inner, allow_null=allow_null, **kwargs) elif dtype == Array: inner = getattr(dtype, "inner", None) or Null() width = getattr(dtype, "width", _DEFAULT_ARRAY_WIDTH_LIMIT) @@ -357,6 +357,7 @@ def data( inner, min_len=width, max_len=width, + allow_null=allow_null, **kwargs, ) else: diff --git a/py-polars/polars/testing/parametric/strategies/dtype.py b/py-polars/polars/testing/parametric/strategies/dtype.py index 91bae2317b93..d3a192e462a1 100644 --- a/py-polars/polars/testing/parametric/strategies/dtype.py +++ b/py-polars/polars/testing/parametric/strategies/dtype.py @@ -57,6 +57,7 @@ Binary, Date, Time, + Null, ] # Supported data type classes with arguments _COMPLEX_DTYPES: list[DataTypeClass] = [ diff --git a/py-polars/tests/unit/functions/range/test_date_range.py b/py-polars/tests/unit/functions/range/test_date_range.py index 8651ee0f4b95..92753e53f2eb 100644 --- a/py-polars/tests/unit/functions/range/test_date_range.py +++ b/py-polars/tests/unit/functions/range/test_date_range.py @@ -14,7 +14,7 @@ def test_date_range() -> None: - # if low/high are both date, range is also be date _iif_ the granularity is >= 1d + # if low/high are both date, range is also be date _iff_ the granularity is >= 1d result = pl.date_range(date(2022, 1, 1), date(2022, 3, 1), "1mo", eager=True) assert result.to_list() == [date(2022, 1, 1), date(2022, 2, 1), date(2022, 3, 1)] diff --git a/py-polars/tests/unit/interchange/test_roundtrip.py b/py-polars/tests/unit/interchange/test_roundtrip.py index 31e639e3479c..e0bb7f0d65cc 100644 --- a/py-polars/tests/unit/interchange/test_roundtrip.py +++ b/py-polars/tests/unit/interchange/test_roundtrip.py @@ -36,7 +36,12 @@ ] -@given(dataframes(allowed_dtypes=protocol_dtypes)) +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 + ) +) def test_to_dataframe_pyarrow_parametric(df: pl.DataFrame) -> None: dfi = df.__dataframe__() df_pa = pa.interchange.from_dataframe(dfi) @@ -71,7 +76,12 @@ def test_to_dataframe_pyarrow_zero_copy_parametric(df: pl.DataFrame) -> None: @pytest.mark.filterwarnings( "ignore:.*PEP3118 format string that does not match its itemsize:RuntimeWarning" ) -@given(dataframes(allowed_dtypes=protocol_dtypes)) +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 + ) +) def test_to_dataframe_pandas_parametric(df: pl.DataFrame) -> None: dfi = df.__dataframe__() df_pd = pd.api.interchange.from_dataframe(dfi) @@ -94,6 +104,7 @@ def test_to_dataframe_pandas_parametric(df: pl.DataFrame) -> None: pl.Categorical, ], chunked=False, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 ) ) def test_to_dataframe_pandas_zero_copy_parametric(df: pl.DataFrame) -> None: @@ -193,6 +204,7 @@ def test_from_dataframe_pandas_zero_copy_parametric(df: pl.DataFrame) -> None: # Empty string columns cause an error due to a bug in pandas. 
# https://github.com/pandas-dev/pandas/issues/56703 min_size=1, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 ) ) def test_from_dataframe_pandas_native_parametric(df: pl.DataFrame) -> None: @@ -217,6 +229,7 @@ def test_from_dataframe_pandas_native_parametric(df: pl.DataFrame) -> None: # https://github.com/pandas-dev/pandas/issues/56700 min_size=1, chunked=False, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 ) ) def test_from_dataframe_pandas_native_zero_copy_parametric(df: pl.DataFrame) -> None: diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py index 31e845e1bbe3..b559918ae7d0 100644 --- a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py @@ -326,11 +326,19 @@ def test_series_to_numpy_temporal() -> None: @given( s=series( - min_size=1, max_size=10, excluded_dtypes=[pl.Categorical, pl.List, pl.Struct] + min_size=1, + max_size=10, + excluded_dtypes=[ + pl.Categorical, + pl.List, + pl.Struct, + pl.Datetime("ms"), + pl.Duration("ms"), + ], + allow_null=False, ).filter( lambda s: ( - getattr(s.dtype, "time_unit", None) != "ms" - and not (s.dtype == pl.String and s.str.contains("\x00").any()) + not (s.dtype == pl.String and s.str.contains("\x00").any()) and not (s.dtype == pl.Binary and s.bin.contains(b"\x00").any()) ) ), @@ -345,8 +353,9 @@ def test_series_to_numpy(s: pl.Series) -> None: pl.Datetime("us"): "datetime64[us]", pl.Duration("ns"): "timedelta64[ns]", pl.Duration("us"): "timedelta64[us]", + pl.Null(): "float32", } - np_dtype = dtype_map.get(s.dtype) # type: ignore[call-overload] + np_dtype = dtype_map.get(s.dtype) expected = np.array(values, dtype=np_dtype) assert_array_equal(result, expected) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py index bdc31aa00d59..96b27b814d65 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py @@ -1350,7 +1350,7 @@ def test_series_duration_timeunits( @given( - s=series(min_size=1, max_size=10, dtype=pl.Datetime), + s=series(min_size=1, max_size=10, dtype=pl.Datetime, allow_null=False), ) def test_series_datetime_timeunits( s: pl.Series, diff --git a/py-polars/tests/unit/series/buffers/test_from_buffer.py b/py-polars/tests/unit/series/buffers/test_from_buffer.py index 34250f30ecf0..5eeecb3adf35 100644 --- a/py-polars/tests/unit/series/buffers/test_from_buffer.py +++ b/py-polars/tests/unit/series/buffers/test_from_buffer.py @@ -14,6 +14,7 @@ s=series( allowed_dtypes=(pl.INTEGER_DTYPES | pl.FLOAT_DTYPES | {pl.Boolean}), chunked=False, + allow_null=False, ) ) def test_series_from_buffer(s: pl.Series) -> None: diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_core.py b/py-polars/tests/unit/testing/parametric/strategies/test_core.py index 13a7c31cb291..cec48b8acbd9 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_core.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_core.py @@ -23,7 +23,6 @@ def test_series_defaults(s: pl.Series) -> None: assert isinstance(s, pl.Series) assert s.name == "" - assert s.null_count() == 0 @given(s=series(name="hello")) @@ -39,7 +38,7 @@ def test_series_dtype(data: st.DataObject) -> None: assert s.dtype == dtype -@given(s=series(dtype=pl.Enum)) 
+@given(s=series(dtype=pl.Enum, allow_null=False)) @settings(max_examples=5) def test_series_dtype_enum(s: pl.Series) -> None: assert isinstance(s.dtype, pl.Enum) @@ -58,9 +57,21 @@ def test_series_size_range(s: pl.Series) -> None: assert 3 <= s.len() <= 8 -@given(s=series(allow_null=True)) -def test_series_allow_null(s: pl.Series) -> None: - assert 0 <= s.null_count() <= s.len() +@given(s=series(allow_null=False)) +def test_series_allow_null_false(s: pl.Series) -> None: + assert s.null_count() == 0 + assert s.dtype != pl.Null + + +@given(s=series(allowed_dtypes=[pl.Null], allow_null=False)) +def test_series_allow_null_allowed_dtypes(s: pl.Series) -> None: + assert s.dtype == pl.Null + + +@given(s=series(allowed_dtypes=[pl.List(pl.Int8)], allow_null=False)) +def test_series_allow_null_nested(s: pl.Series) -> None: + for v in s: + assert v.null_count() == 0 @given(df=dataframes()) @@ -121,6 +132,7 @@ def test_dataframes_allow_null_override(df: pl.DataFrame) -> None: # generate lazyframes with at least one row lazy=True, min_size=1, + allow_null=False, # test mix & match of bulk-assigned cols with custom cols cols=[column(n, dtype=pl.UInt8, unique=True) for n in ["a", "b"]], include_cols=[ @@ -190,7 +202,8 @@ def test_allow_infinities_deprecated(data: st.DataObject) -> None: min_len=1, ), ), - ] + ], + allow_null=False, ), ) def test_dataframes_nested_strategies(df: pl.DataFrame) -> None: @@ -255,9 +268,12 @@ def test_chunking( @given( df=dataframes( - allowed_dtypes=[pl.Float32, pl.Float64], max_cols=4, allow_infinity=False + allowed_dtypes=[pl.Float32, pl.Float64], + max_cols=4, + allow_null=False, + allow_infinity=False, ), - s=series(dtype=pl.Float64, allow_infinity=False), + s=series(dtype=pl.Float64, allow_null=False, allow_infinity=False), ) def test_infinities( df: pl.DataFrame, From 9bfa30c4c68dd16a5482b4de3f9b3419d9be3aa0 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Mon, 13 May 2024 15:42:55 +0100 Subject: [PATCH 28/29] chore(rust): Use `Duration.is_zero` instead of comparing Duration.duration_ns to 0 (#16195) --- .../polars-time/src/chunkedarray/rolling_window/dispatch.rs | 2 +- crates/polars-time/src/group_by/dynamic.rs | 2 +- crates/polars-time/src/windows/group_by.rs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs index 5feb3f9f99cb..eaf503db5a68 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -83,7 +83,7 @@ where } let ca = ca.rechunk(); ensure_duration_matches_data_type(options.window_size, by.dtype(), "window_size")?; - polars_ensure!(options.window_size.duration_ns()>0 && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive"); + polars_ensure!(!options.window_size.is_zero() && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive"); if by.is_sorted_flag() != IsSorted::Ascending && options.warn_if_unsorted { polars_warn!(format!( "Series is not known to be sorted by `by` column in `rolling_*_by` operation.\n\ diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 589119440ec4..30c352b62bb0 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -128,7 +128,7 @@ impl Wrap<&DataFrame> { options: &RollingGroupOptions, ) -> PolarsResult<(Series, 
Vec, GroupsProxy)> { polars_ensure!( - options.period.duration_ns() > 0 && !options.period.negative, + !options.period.is_zero() && !options.period.negative, ComputeError: "rolling window period should be strictly positive", ); diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index 1754cde91865..ce4d9c62767b 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs @@ -576,7 +576,7 @@ pub fn group_by_values( let run_parallel = !POOL.current_thread_has_pending_tasks().unwrap_or(false); // we have a (partial) lookbehind window - if offset.negative && offset.duration_ns() > 0 { + if offset.negative && !offset.is_zero() { // lookbehind if offset.duration_ns() == period.duration_ns() { // t is right at the end of the window @@ -647,7 +647,7 @@ pub fn group_by_values( iter.map(|result| result.map(|(offset, len)| [offset, len])) .collect::>() } - } else if offset.duration_ns() != 0 + } else if !offset.is_zero() || closed_window == ClosedWindow::Right || closed_window == ClosedWindow::None { From 81cc802c62279f8f9c92ae4c7635461680dad037 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 13 May 2024 21:28:15 +0200 Subject: [PATCH 29/29] feat(python): Implement support for Struct types in parametric tests (#16197) --- .../testing/parametric/strategies/data.py | 33 +++++++++++++++++-- .../testing/parametric/strategies/dtype.py | 2 +- .../tests/unit/dataframe/test_null_count.py | 6 +++- .../tests/unit/dataframe/test_to_dict.py | 8 ++++- py-polars/tests/unit/operations/test_clear.py | 14 ++++++-- .../tests/unit/operations/test_drop_nulls.py | 9 ++++- .../parametric/strategies/test_data.py | 6 ++++ 7 files changed, 69 insertions(+), 9 deletions(-) diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py index 2bb3345c3ce3..71313436dc36 100644 --- a/py-polars/polars/testing/parametric/strategies/data.py +++ b/py-polars/polars/testing/parametric/strategies/data.py @@ -4,7 +4,7 @@ import decimal from datetime import timedelta -from typing import TYPE_CHECKING, Any, Literal, Sequence +from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence import hypothesis.strategies as st from hypothesis.errors import InvalidArgument @@ -34,6 +34,7 @@ Decimal, Duration, Enum, + Field, Float32, Float64, Int8, @@ -43,6 +44,7 @@ List, Null, String, + Struct, Time, UInt8, UInt16, @@ -58,10 +60,10 @@ if TYPE_CHECKING: from datetime import date, datetime, time - from hypothesis.strategies import SearchStrategy + from hypothesis.strategies import DrawFn, SearchStrategy from polars.datatypes import DataType, DataTypeClass - from polars.type_aliases import PolarsDataType, TimeUnit + from polars.type_aliases import PolarsDataType, SchemaDict, TimeUnit _DEFAULT_LIST_LEN_LIMIT = 3 _DEFAULT_N_CATEGORIES = 10 @@ -278,6 +280,28 @@ def lists( ) +@st.composite +def structs( # noqa: D417 + draw: DrawFn, /, fields: Sequence[Field] | SchemaDict, **kwargs: Any +) -> dict[str, Any]: + """ + Create a strategy for generating structs with the given fields. + + Parameters + ---------- + fields + The fields that make up the struct. Can be either a sequence of Field + objects or a mapping of column names to data types. + **kwargs + Additional arguments that are passed to nested data generation strategies. 
+ """ + if isinstance(fields, Mapping): + fields = [Field(name, dtype) for name, dtype in fields.items()] + + strats = {f.name: data(f.dtype, **kwargs) for f in fields} + return {col: draw(strat) for col, strat in strats.items()} + + def nulls() -> SearchStrategy[None]: """Create a strategy for generating null values.""" return st.none() @@ -360,6 +384,9 @@ def data( allow_null=allow_null, **kwargs, ) + elif dtype == Struct: + fields = getattr(dtype, "fields", None) or [Field("f0", Null())] + strategy = structs(fields, **kwargs) else: msg = f"unsupported data type: {dtype}" raise InvalidArgument(msg) diff --git a/py-polars/polars/testing/parametric/strategies/dtype.py b/py-polars/polars/testing/parametric/strategies/dtype.py index d3a192e462a1..dac7049def65 100644 --- a/py-polars/polars/testing/parametric/strategies/dtype.py +++ b/py-polars/polars/testing/parametric/strategies/dtype.py @@ -72,7 +72,7 @@ # TODO: Enable nested types by default when various issues are solved. # List, # Array, - # Struct, + Struct, ] # Supported data type classes that do not contain other data types _FLAT_DTYPES = _SIMPLE_DTYPES + _COMPLEX_DTYPES diff --git a/py-polars/tests/unit/dataframe/test_null_count.py b/py-polars/tests/unit/dataframe/test_null_count.py index 11755bbdcb9b..a9b1141a2a67 100644 --- a/py-polars/tests/unit/dataframe/test_null_count.py +++ b/py-polars/tests/unit/dataframe/test_null_count.py @@ -11,7 +11,11 @@ min_size=1, min_cols=1, allow_null=True, - excluded_dtypes=[pl.String, pl.List], + excluded_dtypes=[ + pl.String, + pl.List, + pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462 + ], ) ) @example(df=pl.DataFrame(schema=["x", "y", "z"])) diff --git a/py-polars/tests/unit/dataframe/test_to_dict.py b/py-polars/tests/unit/dataframe/test_to_dict.py index 30414f7c4a23..e95fc014caf5 100644 --- a/py-polars/tests/unit/dataframe/test_to_dict.py +++ b/py-polars/tests/unit/dataframe/test_to_dict.py @@ -10,7 +10,13 @@ from polars.testing.parametric import dataframes -@given(df=dataframes()) +@given( + df=dataframes( + excluded_dtypes=[ + pl.Categorical, # Bug: https://github.com/pola-rs/polars/issues/16196 + ] + ) +) def test_to_dict(df: pl.DataFrame) -> None: d = df.to_dict(as_series=False) result = pl.from_dict(d, schema=df.schema) diff --git a/py-polars/tests/unit/operations/test_clear.py b/py-polars/tests/unit/operations/test_clear.py index 7799e7e05ce2..ff2e94fecb94 100644 --- a/py-polars/tests/unit/operations/test_clear.py +++ b/py-polars/tests/unit/operations/test_clear.py @@ -8,14 +8,24 @@ from polars.testing.parametric import series -@given(s=series(), n=st.integers(min_value=0, max_value=10)) -def test_clear_series_parametric(s: pl.Series, n: int) -> None: +@given(s=series()) +def test_clear_series_parametric(s: pl.Series) -> None: result = s.clear() assert result.dtype == s.dtype assert result.name == s.name assert result.is_empty() + +@given( + s=series( + excluded_dtypes=[ + pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462 + ] + ), + n=st.integers(min_value=0, max_value=10), +) +def test_clear_series_n_parametric(s: pl.Series, n: int) -> None: result = s.clear(n) assert result.dtype == s.dtype diff --git a/py-polars/tests/unit/operations/test_drop_nulls.py b/py-polars/tests/unit/operations/test_drop_nulls.py index 4250ecad154e..287a7ec2b7b0 100644 --- a/py-polars/tests/unit/operations/test_drop_nulls.py +++ b/py-polars/tests/unit/operations/test_drop_nulls.py @@ -7,7 +7,14 @@ from polars.testing.parametric import series -@given(s=series(allow_null=True)) 
+@given( + s=series( + allow_null=True, + excluded_dtypes=[ + pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462 + ], + ) +) def test_drop_nulls_parametric(s: pl.Series) -> None: result = s.drop_nulls() assert result.len() == s.len() - s.null_count() diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_data.py b/py-polars/tests/unit/testing/parametric/strategies/test_data.py index a7316b0a7a7b..d3c5282ac959 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_data.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_data.py @@ -29,3 +29,9 @@ def test_data_enum(cat: str) -> None: @given(cat=data(pl.Enum(["hello", "world"]))) def test_data_enum_instantiated(cat: str) -> None: assert cat in ("hello", "world") + + +@given(struct=data(pl.Struct({"a": pl.Int8, "b": pl.String}))) +def test_data_struct(struct: dict[str, int | str]) -> None: + assert isinstance(struct["a"], int) + assert isinstance(struct["b"], str)
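
As a usage illustration of the new Struct support: a minimal sketch, assuming the `data` strategy entry point and the `structs` helper behave as in this patch (the import path and field names below are assumptions for illustration, mirroring the parametric test added above):

from hypothesis import given

import polars as pl
from polars.testing.parametric.strategies.data import data


@given(value=data(pl.Struct({"x": pl.Int32, "y": pl.String})))
def test_struct_values(value: dict[str, object]) -> None:
    # The struct strategy draws one Python dict per value, keyed by field name.
    assert set(value) == {"x", "y"}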