From d190e02693ead521cf38a603b31722db258ef491 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 30 May 2024 10:42:22 +0200 Subject: [PATCH] depr(python): Rename `Series.equals` parameter `strict` to `check_dtypes` and rename assertion utils parameter `check_dtype` to `check_dtypes` (#16573) --- crates/polars-core/src/testing.rs | 2 +- py-polars/polars/dataframe/frame.py | 2 +- py-polars/polars/series/series.py | 18 +++++--- py-polars/polars/testing/asserts/frame.py | 19 ++++---- py-polars/polars/testing/asserts/series.py | 15 +++--- py-polars/src/series/mod.rs | 4 +- py-polars/tests/unit/dataframe/test_df.py | 4 +- py-polars/tests/unit/io/test_delta.py | 16 +++---- .../map/test_inefficient_map_warning.py | 2 +- py-polars/tests/unit/series/test_equals.py | 11 ++++- py-polars/tests/unit/series/test_series.py | 2 +- py-polars/tests/unit/sql/test_group_by.py | 4 +- .../unit/streaming/test_streaming_join.py | 4 +- py-polars/tests/unit/test_lazy.py | 2 +- .../unit/testing/test_assert_frame_equal.py | 20 ++++++-- .../unit/testing/test_assert_series_equal.py | 46 ++++++++++++------- 16 files changed, 107 insertions(+), 64 deletions(-) diff --git a/crates/polars-core/src/testing.rs b/crates/polars-core/src/testing.rs index 99c28a617b2b..6c630b02af96 100644 --- a/crates/polars-core/src/testing.rs +++ b/crates/polars-core/src/testing.rs @@ -34,7 +34,7 @@ impl Series { && { let eq = self.equal_missing(other); match eq { - Ok(b) => b.sum().map(|s| s as usize).unwrap_or(0) == self.len(), + Ok(b) => b.all(), Err(_) => false, } } diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 893c04a06076..03cac5e3df61 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4954,7 +4954,7 @@ def equals(self, other: DataFrame, *, null_equal: bool = True) -> bool: >>> df1.equals(df2) False """ - return self._df.equals(other._df, null_equal) + return self._df.equals(other._df, null_equal=null_equal) @deprecate_function( "DataFrame.replace is deprecated and will be removed in a future version. " diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 7c62d1765045..102fbf7f1191 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4111,8 +4111,13 @@ def explode(self) -> Series: ] """ + @deprecate_renamed_parameter("strict", "check_dtypes", version="0.20.31") def equals( - self, other: Series, *, null_equal: bool = True, strict: bool = False + self, + other: Series, + *, + check_dtypes: bool = False, + null_equal: bool = True, ) -> bool: """ Check whether the Series is equal to another Series. @@ -4121,11 +4126,10 @@ def equals( ---------- other Series to compare with. + check_dtypes + Require data types to match. null_equal Consider null values as equal. - strict - Don't allow different numerical dtypes, e.g. comparing `pl.UInt32` with a - `pl.Int64` will return `False`. See Also -------- @@ -4140,7 +4144,9 @@ def equals( >>> s1.equals(s2) False """ - return self._s.equals(other._s, null_equal, strict) + return self._s.equals( + other._s, check_dtypes=check_dtypes, null_equal=null_equal + ) def cast( self, @@ -8112,7 +8118,7 @@ def series_equal( Don't allow different numerical dtypes, e.g. comparing `pl.UInt32` with a `pl.Int64` will return `False`. """ - return self.equals(other, null_equal=null_equal, strict=strict) + return self.equals(other, check_dtypes=strict, null_equal=null_equal) # Keep the `list` and `str` properties below at the end of the definition of Series, # as to not confuse mypy with the type annotation `str` and `list` diff --git a/py-polars/polars/testing/asserts/frame.py b/py-polars/polars/testing/asserts/frame.py index ff2f8fc04c39..c3c939fb41ea 100644 --- a/py-polars/polars/testing/asserts/frame.py +++ b/py-polars/polars/testing/asserts/frame.py @@ -2,6 +2,7 @@ from typing import cast +from polars._utils.deprecation import deprecate_renamed_parameter from polars.dataframe import DataFrame from polars.exceptions import ComputeError, InvalidAssert from polars.lazyframe import LazyFrame @@ -9,13 +10,14 @@ from polars.testing.asserts.utils import raise_assertion_error +@deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31") def assert_frame_equal( left: DataFrame | LazyFrame, right: DataFrame | LazyFrame, *, check_row_order: bool = True, check_column_order: bool = True, - check_dtype: bool = True, + check_dtypes: bool = True, check_exact: bool = False, rtol: float = 1e-5, atol: float = 1e-8, @@ -41,7 +43,7 @@ def assert_frame_equal( frames that contain unsortable columns. check_column_order Require column order to match. - check_dtype + check_dtypes Require data types to match. check_exact Require float values to match exactly. If set to `False`, values are considered @@ -94,7 +96,7 @@ def assert_frame_equal( left, right, check_column_order=check_column_order, - check_dtype=check_dtype, + check_dtypes=check_dtypes, objects=objects, ) @@ -153,7 +155,7 @@ def _assert_frame_schema_equal( left: DataFrame | LazyFrame, right: DataFrame | LazyFrame, *, - check_dtype: bool, + check_dtypes: bool, check_column_order: bool, objects: str, ) -> None: @@ -181,7 +183,7 @@ def _assert_frame_schema_equal( detail = "columns are not in the same order" raise_assertion_error(objects, detail, left_columns, right_columns) - if check_dtype: + if check_dtypes: left_schema_dict, right_schema_dict = dict(left_schema), dict(right_schema) if check_column_order or left_schema_dict != right_schema_dict: detail = "dtypes do not match" @@ -199,13 +201,14 @@ def _sort_dataframes(left: DataFrame, right: DataFrame) -> tuple[DataFrame, Data return left, right +@deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31") def assert_frame_not_equal( left: DataFrame | LazyFrame, right: DataFrame | LazyFrame, *, check_row_order: bool = True, check_column_order: bool = True, - check_dtype: bool = True, + check_dtypes: bool = True, check_exact: bool = False, rtol: float = 1e-5, atol: float = 1e-8, @@ -230,7 +233,7 @@ def assert_frame_not_equal( frames that contain unsortable columns. check_column_order Require column order to match. - check_dtype + check_dtypes Require data types to match. check_exact Require float values to match exactly. If set to `False`, values are considered @@ -267,7 +270,7 @@ def assert_frame_not_equal( right=right, check_column_order=check_column_order, check_row_order=check_row_order, - check_dtype=check_dtype, + check_dtypes=check_dtypes, check_exact=check_exact, rtol=rtol, atol=atol, diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index 5bf691037ea9..da9c0e183e6d 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from polars._utils.deprecation import deprecate_renamed_parameter from polars.datatypes import ( FLOAT_DTYPES, Array, @@ -19,11 +20,12 @@ from polars import DataType +@deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31") def assert_series_equal( left: Series, right: Series, *, - check_dtype: bool = True, + check_dtypes: bool = True, check_names: bool = True, check_exact: bool = False, rtol: float = 1e-5, @@ -42,7 +44,7 @@ def assert_series_equal( The first Series to compare. right The second Series to compare. - check_dtype + check_dtypes Require data types to match. check_names Require names to match. @@ -99,7 +101,7 @@ def assert_series_equal( if check_names and left.name != right.name: raise_assertion_error("Series", "name mismatch", left.name, right.name) - if check_dtype and left.dtype != right.dtype: + if check_dtypes and left.dtype != right.dtype: raise_assertion_error("Series", "dtype mismatch", left.dtype, right.dtype) _assert_series_values_equal( @@ -295,11 +297,12 @@ def _assert_series_values_within_tolerance( ) +@deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31") def assert_series_not_equal( left: Series, right: Series, *, - check_dtype: bool = True, + check_dtypes: bool = True, check_names: bool = True, check_exact: bool = False, rtol: float = 1e-5, @@ -317,7 +320,7 @@ def assert_series_not_equal( The first Series to compare. right The second Series to compare. - check_dtype + check_dtypes Require data types to match. check_names Require names to match. @@ -355,7 +358,7 @@ def assert_series_not_equal( assert_series_equal( left=left, right=right, - check_dtype=check_dtype, + check_dtypes=check_dtypes, check_names=check_names, check_exact=check_exact, rtol=rtol, diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index cf1b7dc03a9b..bda26f533eb5 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -324,8 +324,8 @@ impl PySeries { self.series.has_validity() } - fn equals(&self, other: &PySeries, null_equal: bool, strict: bool) -> bool { - if strict && (self.series.dtype() != other.series.dtype()) { + fn equals(&self, other: &PySeries, check_dtypes: bool, null_equal: bool) -> bool { + if check_dtypes && (self.series.dtype() != other.series.dtype()) { return false; } if null_equal { diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index f3aaf8737506..43c1cae543de 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -1953,8 +1953,8 @@ def test_product() -> None: expected = pl.DataFrame( {"int": [6], "flt": [-108.0], "bool_0": [0], "bool_1": [1], "str": [None]} ) - assert_frame_not_equal(out, expected, check_dtype=True) - assert_frame_equal(out, expected, check_dtype=False) + assert_frame_not_equal(out, expected, check_dtypes=True) + assert_frame_equal(out, expected, check_dtypes=False) def test_first_last_nth_expressions(fruits_cars: pl.DataFrame) -> None: diff --git a/py-polars/tests/unit/io/test_delta.py b/py-polars/tests/unit/io/test_delta.py index b94997106ec5..ca939b2dc153 100644 --- a/py-polars/tests/unit/io/test_delta.py +++ b/py-polars/tests/unit/io/test_delta.py @@ -24,7 +24,7 @@ def test_scan_delta(delta_table_path: Path) -> None: ldf = pl.scan_delta(str(delta_table_path), version=0) expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) - assert_frame_equal(expected, ldf.collect(), check_dtype=False) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) def test_scan_delta_version(delta_table_path: Path) -> None: @@ -66,7 +66,7 @@ def test_scan_delta_columns(delta_table_path: Path) -> None: ldf = pl.scan_delta(str(delta_table_path), version=0).select("name") expected = pl.DataFrame({"name": ["Joey", "Ivan"]}) - assert_frame_equal(expected, ldf.collect(), check_dtype=False) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) def test_scan_delta_filesystem(delta_table_path: Path) -> None: @@ -78,7 +78,7 @@ def test_scan_delta_filesystem(delta_table_path: Path) -> None: ) expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) - assert_frame_equal(expected, ldf.collect(), check_dtype=False) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) def test_scan_delta_relative(delta_table_path: Path) -> None: @@ -87,7 +87,7 @@ def test_scan_delta_relative(delta_table_path: Path) -> None: ldf = pl.scan_delta(rel_delta_table_path, version=0) expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) - assert_frame_equal(expected, ldf.collect(), check_dtype=False) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) ldf = pl.scan_delta(rel_delta_table_path, version=1) assert_frame_not_equal(expected, ldf.collect()) @@ -97,7 +97,7 @@ def test_read_delta(delta_table_path: Path) -> None: df = pl.read_delta(str(delta_table_path), version=0) expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) - assert_frame_equal(expected, df, check_dtype=False) + assert_frame_equal(expected, df, check_dtypes=False) def test_read_delta_version(delta_table_path: Path) -> None: @@ -139,7 +139,7 @@ def test_read_delta_columns(delta_table_path: Path) -> None: df = pl.read_delta(str(delta_table_path), version=0, columns=["name"]) expected = pl.DataFrame({"name": ["Joey", "Ivan"]}) - assert_frame_equal(expected, df, check_dtype=False) + assert_frame_equal(expected, df, check_dtypes=False) def test_read_delta_filesystem(delta_table_path: Path) -> None: @@ -151,7 +151,7 @@ def test_read_delta_filesystem(delta_table_path: Path) -> None: ) expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) - assert_frame_equal(expected, df, check_dtype=False) + assert_frame_equal(expected, df, check_dtypes=False) def test_read_delta_relative(delta_table_path: Path) -> None: @@ -160,7 +160,7 @@ def test_read_delta_relative(delta_table_path: Path) -> None: df = pl.read_delta(rel_delta_table_path, version=0) expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) - assert_frame_equal(expected, df, check_dtype=False) + assert_frame_equal(expected, df, check_dtypes=False) @pytest.mark.write_disk() diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index 7214e73e4931..adf75e8fbe85 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -317,7 +317,7 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: assert_frame_equal( result_frame, expected_frame, - check_dtype=(".dt." not in suggested_expression), + check_dtypes=(".dt." not in suggested_expression), ) diff --git a/py-polars/tests/unit/series/test_equals.py b/py-polars/tests/unit/series/test_equals.py index 509a8f6b3072..da607b934936 100644 --- a/py-polars/tests/unit/series/test_equals.py +++ b/py-polars/tests/unit/series/test_equals.py @@ -11,7 +11,7 @@ def test_equals() -> None: s2 = pl.Series("a", [1, 2, None], pl.Int64) assert s1.equals(s2) is True - assert s1.equals(s2, strict=True) is False + assert s1.equals(s2, check_dtypes=True) is False assert s1.equals(s2, null_equal=False) is False df = pl.DataFrame( @@ -25,7 +25,7 @@ def test_equals() -> None: s4 = df["s4"].rename("b") assert s3.equals(s4) is False - assert s3.equals(s4, strict=True) is False + assert s3.equals(s4, check_dtypes=True) is False assert s3.equals(s4, null_equal=False) is False assert s3.dt.convert_time_zone("Asia/Tokyo").equals(s4) is True @@ -91,3 +91,10 @@ def test_ne_missing_expr() -> None: result_evaluated = pl.select(result).to_series() expected = pl.Series([False, True]) assert_series_equal(result_evaluated, expected) + + +def test_series_equals_strict_deprecated() -> None: + s1 = pl.Series("a", [1.0, 2.0, None], pl.Float64) + s2 = pl.Series("a", [1, 2, None], pl.Int64) + with pytest.deprecated_call(): + assert not s1.equals(s2, strict=True) # type: ignore[call-arg] diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 2df1f1b9083a..73cd54db7505 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1105,7 +1105,7 @@ def test_empty() -> None: assert_series_equal(pl.Series(), pl.Series()) assert_series_equal( - pl.Series(dtype=pl.Int32), pl.Series(dtype=pl.Int64), check_dtype=False + pl.Series(dtype=pl.Int32), pl.Series(dtype=pl.Int64), check_dtypes=False ) with pytest.raises(TypeError, match="ambiguous"): diff --git a/py-polars/tests/unit/sql/test_group_by.py b/py-polars/tests/unit/sql/test_group_by.py index 1e15f0a36365..9f5b2d17c6cd 100644 --- a/py-polars/tests/unit/sql/test_group_by.py +++ b/py-polars/tests/unit/sql/test_group_by.py @@ -97,7 +97,7 @@ def test_group_by_all() -> None: "n": [3, 2, 1], } ) - assert_frame_equal(expected, res, check_dtype=False) + assert_frame_equal(expected, res, check_dtypes=False) # more involved determination of agg/group columns res = df.sql( @@ -198,7 +198,7 @@ def test_group_by_ordinal_position() -> None: ORDER BY c """ ) - assert_frame_equal(res1, expected, check_dtype=False) + assert_frame_equal(res1, expected, check_dtypes=False) res2 = ctx.execute( """ diff --git a/py-polars/tests/unit/streaming/test_streaming_join.py b/py-polars/tests/unit/streaming/test_streaming_join.py index cc783fa69ed3..2470c93555fa 100644 --- a/py-polars/tests/unit/streaming/test_streaming_join.py +++ b/py-polars/tests/unit/streaming/test_streaming_join.py @@ -83,7 +83,7 @@ def test_streaming_joins() -> None: .with_columns(pl.all().cast(int)) .sort(["a", "b"], maintain_order=True) ) - assert_frame_equal(a, pl_result, check_dtype=False) + assert_frame_equal(a, pl_result, check_dtypes=False) pd_result = dfa.merge(dfb, on=["a", "b"], how=how) @@ -96,7 +96,7 @@ def test_streaming_joins() -> None: # we cast to integer because pandas joins creates floats a = pl.from_pandas(pd_result).with_columns(pl.all().cast(int)).sort(["a", "b"]) - assert_frame_equal(a, pl_result, check_dtype=False) + assert_frame_equal(a, pl_result, check_dtypes=False) def test_sorted_flag_after_streaming_join() -> None: diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 168ec05f7cfd..72ea8b7a387b 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -137,7 +137,7 @@ def test_count_suffix_10783() -> None: .name.suffix("_suffix") ) df_expect = df.with_columns(pl.Series("len_suffix", [3, 3, 1, 3])) - assert_frame_equal(df_with_cnt, df_expect, check_dtype=False) + assert_frame_equal(df_with_cnt, df_expect, check_dtypes=False) def test_or() -> None: diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py index 69021a978ffc..417c612340ce 100644 --- a/py-polars/tests/unit/testing/test_assert_frame_equal.py +++ b/py-polars/tests/unit/testing/test_assert_frame_equal.py @@ -50,13 +50,13 @@ def test_equal(df: pl.DataFrame) -> None: pytest.param( pl.DataFrame({"a": [0.0, 1.0, 2.0]}, schema={"a": pl.Float64}), pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), - {"check_dtype": False}, + {"check_dtypes": False}, id="equal_int_float_integer_no_check_dtype", ), pytest.param( pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float64}), pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float32}), - {"check_dtype": False}, + {"check_dtypes": False}, id="equal_int_float_integer_no_check_dtype", ), pytest.param( @@ -161,7 +161,7 @@ def test_assert_frame_equal_passes_assertion( pytest.param( pl.DataFrame({"a": [[2.0, 3.0]]}), pl.DataFrame({"a": [[2, 3]]}), - {"check_exact": False, "check_dtype": True}, + {"check_exact": False, "check_dtypes": True}, id="list_of_float_list_of_int_check_dtype_true", ), pytest.param( @@ -270,7 +270,7 @@ def test_compare_frame_equal_nested_nans() -> None: assert_frame_not_equal(df3, df4) for check_dtype in (True, False): with pytest.raises(AssertionError, match="mismatch|different"): - assert_frame_equal(df3, df4, check_dtype=check_dtype) + assert_frame_equal(df3, df4, check_dtypes=check_dtype) def test_assert_frame_equal_pass() -> None: @@ -380,6 +380,18 @@ def test_assert_frame_not_equal() -> None: assert_frame_not_equal(df, df) +def test_assert_frame_equal_check_dtype_deprecated() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1.0, 2.0]}) + df3 = pl.DataFrame({"a": [2, 1]}) + + with pytest.deprecated_call(): + assert_frame_equal(df1, df2, check_dtype=False) # type: ignore[call-arg] + + with pytest.deprecated_call(): + assert_frame_not_equal(df1, df3, check_dtype=False) # type: ignore[call-arg] + + def test_tracebackhide(testdir: pytest.Testdir) -> None: testdir.makefile( ".py", diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py index e676be77b1ac..beb3ff43674c 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -58,10 +58,10 @@ def test_compare_series_nans_assert_equal() -> None: srs5 = pl.Series([1.0, 2.0, 3.0, 4.0, nan, 6.0]) srs6 = pl.Series([1, 2, 3, 4, None, 6]) - assert_series_equal(srs4, srs6, check_dtype=False) + assert_series_equal(srs4, srs6, check_dtypes=False) with pytest.raises(AssertionError): - assert_series_equal(srs5, srs6, check_dtype=False) - assert_series_not_equal(srs5, srs6, check_dtype=True) + assert_series_equal(srs5, srs6, check_dtypes=False) + assert_series_not_equal(srs5, srs6, check_dtypes=True) # nested for float_type in (pl.Float32, pl.Float64): @@ -218,13 +218,13 @@ def test_assert_series_equal_temporal(data1: Any, data2: Any) -> None: pytest.param( pl.Series([0.0, 1.0, 2.0], dtype=pl.Float64), pl.Series([0, 1, 2], dtype=pl.Int64), - {"check_dtype": False}, + {"check_dtypes": False}, id="equal_int_float_integer_no_check_dtype", ), pytest.param( pl.Series([0, 1, 2], dtype=pl.Float64), pl.Series([0, 1, 2], dtype=pl.Float32), - {"check_dtype": False}, + {"check_dtypes": False}, id="equal_int_float_integer_no_check_dtype", ), pytest.param( @@ -290,7 +290,7 @@ def test_assert_series_equal_temporal(data1: Any, data2: Any) -> None: pytest.param( pl.Series([[2.0, 3.0]]), pl.Series([[2, 3]]), - {"check_exact": False, "check_dtype": False}, + {"check_exact": False, "check_dtypes": False}, id="list_of_float_list_of_int_check_dtype_false", ), pytest.param( @@ -383,13 +383,13 @@ def test_assert_series_equal_passes_assertion( pytest.param( pl.Series([0, 1, 2], dtype=pl.Float64), pl.Series([0, 1, 2], dtype=pl.Int64), - {"check_dtype": True}, + {"check_dtypes": True}, id="equal_int_float_integer_check_dtype", ), pytest.param( pl.Series([0, 1, 2], dtype=pl.Float64), pl.Series([0, 1, 2], dtype=pl.Float32), - {"check_dtype": True}, + {"check_dtypes": True}, id="equal_int_float_integer_check_dtype", ), pytest.param( @@ -443,19 +443,19 @@ def test_assert_series_equal_passes_assertion( pytest.param( pl.Series([[2.0, 3.0]]), pl.Series([[2, 3]]), - {"check_exact": False, "check_dtype": True}, + {"check_exact": False, "check_dtypes": True}, id="list_of_float_list_of_int_check_dtype_true", ), pytest.param( pl.struct(a=0, b=1.1, eager=True), pl.struct(a=0, b=1, eager=True), - {"atol": 0.1, "rtol": 0, "check_dtype": True}, + {"atol": 0.1, "rtol": 0, "check_dtypes": True}, id="struct_approx_equal_different_type", ), pytest.param( pl.struct(a=0, b=1.09, eager=True), pl.struct(a=0, b=1, eager=True), - {"atol": 0.1, "rtol": 0, "check_dtype": False}, + {"atol": 0.1, "rtol": 0, "check_dtypes": False}, id="struct_approx_equal_different_type", ), ], @@ -477,8 +477,8 @@ def test_assert_series_equal_categorical_vs_str() -> None: with pytest.raises(AssertionError, match="dtype mismatch"): assert_series_equal(s1, s2, categorical_as_str=True) - assert_series_equal(s1, s2, check_dtype=False, categorical_as_str=True) - assert_series_equal(s2, s1, check_dtype=False, categorical_as_str=True) + assert_series_equal(s1, s2, check_dtypes=False, categorical_as_str=True) + assert_series_equal(s2, s1, check_dtypes=False, categorical_as_str=True) def test_assert_series_equal_incompatible_data_types() -> None: @@ -486,7 +486,7 @@ def test_assert_series_equal_incompatible_data_types() -> None: s2 = pl.Series([0, 1, 0], dtype=pl.Int8) with pytest.raises(AssertionError, match="incompatible data types"): - assert_series_equal(s1, s2, check_dtype=False) + assert_series_equal(s1, s2, check_dtypes=False) def test_assert_series_equal_full_series() -> None: @@ -537,7 +537,7 @@ def test_assert_series_equal_full_null_incompatible_dtypes_raises() -> None: # You could argue this should pass, but it's rare enough not to warrant the # additional check with pytest.raises(AssertionError, match="incompatible data types"): - assert_series_equal(s1, s2, check_dtype=False) + assert_series_equal(s1, s2, check_dtypes=False) def test_assert_series_equal_full_null_nested_list() -> None: @@ -549,7 +549,7 @@ def test_assert_series_equal_full_null_nested_not_nested() -> None: s1 = pl.Series([None, None], dtype=pl.List(pl.Float64)) s2 = pl.Series([None, None], dtype=pl.Float64) - assert_series_equal(s1, s2, check_dtype=False) + assert_series_equal(s1, s2, check_dtypes=False) def test_assert_series_equal_nested_list_nan() -> None: @@ -590,7 +590,7 @@ def test_assert_series_equal_uint_always_checked_exactly() -> None: s2 = pl.Series([2, 4], dtype=pl.Int64) with pytest.raises(AssertionError): - assert_series_equal(s1, s2, atol=1, check_dtype=False) + assert_series_equal(s1, s2, atol=1, check_dtypes=False) def test_assert_series_equal_nested_int_always_checked_exactly() -> None: @@ -639,6 +639,18 @@ def test_assert_series_equal_w_large_integers_12328() -> None: assert_series_equal(left, right) +def test_assert_series_equal_check_dtype_deprecated() -> None: + s1 = pl.Series("a", [1, 2]) + s2 = pl.Series("a", [1.0, 2.0]) + s3 = pl.Series("a", [2, 1]) + + with pytest.deprecated_call(): + assert_series_equal(s1, s2, check_dtype=False) # type: ignore[call-arg] + + with pytest.deprecated_call(): + assert_series_not_equal(s1, s3, check_dtype=False) # type: ignore[call-arg] + + def test_tracebackhide(testdir: pytest.Testdir) -> None: testdir.makefile( ".py",