Skip to content

Commit

Permalink
depr(python): Rename Series.equals parameter strict to `check_dty…
Browse files Browse the repository at this point in the history
…pes` and rename assertion utils parameter `check_dtype` to `check_dtypes` (#16573)
  • Loading branch information
stinodego authored May 30, 2024
1 parent f1be8d9 commit d190e02
Show file tree
Hide file tree
Showing 16 changed files with 107 additions and 64 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl Series {
&& {
let eq = self.equal_missing(other);
match eq {
Ok(b) => b.sum().map(|s| s as usize).unwrap_or(0) == self.len(),
Ok(b) => b.all(),
Err(_) => false,
}
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4954,7 +4954,7 @@ def equals(self, other: DataFrame, *, null_equal: bool = True) -> bool:
>>> df1.equals(df2)
False
"""
return self._df.equals(other._df, null_equal)
return self._df.equals(other._df, null_equal=null_equal)

@deprecate_function(
"DataFrame.replace is deprecated and will be removed in a future version. "
Expand Down
18 changes: 12 additions & 6 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4111,8 +4111,13 @@ def explode(self) -> Series:
]
"""

@deprecate_renamed_parameter("strict", "check_dtypes", version="0.20.31")
def equals(
self, other: Series, *, null_equal: bool = True, strict: bool = False
self,
other: Series,
*,
check_dtypes: bool = False,
null_equal: bool = True,
) -> bool:
"""
Check whether the Series is equal to another Series.
Expand All @@ -4121,11 +4126,10 @@ def equals(
----------
other
Series to compare with.
check_dtypes
Require data types to match.
null_equal
Consider null values as equal.
strict
Don't allow different numerical dtypes, e.g. comparing `pl.UInt32` with a
`pl.Int64` will return `False`.
See Also
--------
Expand All @@ -4140,7 +4144,9 @@ def equals(
>>> s1.equals(s2)
False
"""
return self._s.equals(other._s, null_equal, strict)
return self._s.equals(
other._s, check_dtypes=check_dtypes, null_equal=null_equal
)

def cast(
self,
Expand Down Expand Up @@ -8112,7 +8118,7 @@ def series_equal(
Don't allow different numerical dtypes, e.g. comparing `pl.UInt32` with a
`pl.Int64` will return `False`.
"""
return self.equals(other, null_equal=null_equal, strict=strict)
return self.equals(other, check_dtypes=strict, null_equal=null_equal)

# Keep the `list` and `str` properties below at the end of the definition of Series,
# as to not confuse mypy with the type annotation `str` and `list`
Expand Down
19 changes: 11 additions & 8 deletions py-polars/polars/testing/asserts/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@

from typing import cast

from polars._utils.deprecation import deprecate_renamed_parameter
from polars.dataframe import DataFrame
from polars.exceptions import ComputeError, InvalidAssert
from polars.lazyframe import LazyFrame
from polars.testing.asserts.series import _assert_series_values_equal
from polars.testing.asserts.utils import raise_assertion_error


@deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31")
def assert_frame_equal(
left: DataFrame | LazyFrame,
right: DataFrame | LazyFrame,
*,
check_row_order: bool = True,
check_column_order: bool = True,
check_dtype: bool = True,
check_dtypes: bool = True,
check_exact: bool = False,
rtol: float = 1e-5,
atol: float = 1e-8,
Expand All @@ -41,7 +43,7 @@ def assert_frame_equal(
frames that contain unsortable columns.
check_column_order
Require column order to match.
check_dtype
check_dtypes
Require data types to match.
check_exact
Require float values to match exactly. If set to `False`, values are considered
Expand Down Expand Up @@ -94,7 +96,7 @@ def assert_frame_equal(
left,
right,
check_column_order=check_column_order,
check_dtype=check_dtype,
check_dtypes=check_dtypes,
objects=objects,
)

Expand Down Expand Up @@ -153,7 +155,7 @@ def _assert_frame_schema_equal(
left: DataFrame | LazyFrame,
right: DataFrame | LazyFrame,
*,
check_dtype: bool,
check_dtypes: bool,
check_column_order: bool,
objects: str,
) -> None:
Expand Down Expand Up @@ -181,7 +183,7 @@ def _assert_frame_schema_equal(
detail = "columns are not in the same order"
raise_assertion_error(objects, detail, left_columns, right_columns)

if check_dtype:
if check_dtypes:
left_schema_dict, right_schema_dict = dict(left_schema), dict(right_schema)
if check_column_order or left_schema_dict != right_schema_dict:
detail = "dtypes do not match"
Expand All @@ -199,13 +201,14 @@ def _sort_dataframes(left: DataFrame, right: DataFrame) -> tuple[DataFrame, Data
return left, right


@deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31")
def assert_frame_not_equal(
left: DataFrame | LazyFrame,
right: DataFrame | LazyFrame,
*,
check_row_order: bool = True,
check_column_order: bool = True,
check_dtype: bool = True,
check_dtypes: bool = True,
check_exact: bool = False,
rtol: float = 1e-5,
atol: float = 1e-8,
Expand All @@ -230,7 +233,7 @@ def assert_frame_not_equal(
frames that contain unsortable columns.
check_column_order
Require column order to match.
check_dtype
check_dtypes
Require data types to match.
check_exact
Require float values to match exactly. If set to `False`, values are considered
Expand Down Expand Up @@ -267,7 +270,7 @@ def assert_frame_not_equal(
right=right,
check_column_order=check_column_order,
check_row_order=check_row_order,
check_dtype=check_dtype,
check_dtypes=check_dtypes,
check_exact=check_exact,
rtol=rtol,
atol=atol,
Expand Down
15 changes: 9 additions & 6 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING

from polars._utils.deprecation import deprecate_renamed_parameter
from polars.datatypes import (
FLOAT_DTYPES,
Array,
Expand All @@ -19,11 +20,12 @@
from polars import DataType


@deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31")
def assert_series_equal(
left: Series,
right: Series,
*,
check_dtype: bool = True,
check_dtypes: bool = True,
check_names: bool = True,
check_exact: bool = False,
rtol: float = 1e-5,
Expand All @@ -42,7 +44,7 @@ def assert_series_equal(
The first Series to compare.
right
The second Series to compare.
check_dtype
check_dtypes
Require data types to match.
check_names
Require names to match.
Expand Down Expand Up @@ -99,7 +101,7 @@ def assert_series_equal(
if check_names and left.name != right.name:
raise_assertion_error("Series", "name mismatch", left.name, right.name)

if check_dtype and left.dtype != right.dtype:
if check_dtypes and left.dtype != right.dtype:
raise_assertion_error("Series", "dtype mismatch", left.dtype, right.dtype)

_assert_series_values_equal(
Expand Down Expand Up @@ -295,11 +297,12 @@ def _assert_series_values_within_tolerance(
)


@deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31")
def assert_series_not_equal(
left: Series,
right: Series,
*,
check_dtype: bool = True,
check_dtypes: bool = True,
check_names: bool = True,
check_exact: bool = False,
rtol: float = 1e-5,
Expand All @@ -317,7 +320,7 @@ def assert_series_not_equal(
The first Series to compare.
right
The second Series to compare.
check_dtype
check_dtypes
Require data types to match.
check_names
Require names to match.
Expand Down Expand Up @@ -355,7 +358,7 @@ def assert_series_not_equal(
assert_series_equal(
left=left,
right=right,
check_dtype=check_dtype,
check_dtypes=check_dtypes,
check_names=check_names,
check_exact=check_exact,
rtol=rtol,
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,8 @@ impl PySeries {
self.series.has_validity()
}

fn equals(&self, other: &PySeries, null_equal: bool, strict: bool) -> bool {
if strict && (self.series.dtype() != other.series.dtype()) {
fn equals(&self, other: &PySeries, check_dtypes: bool, null_equal: bool) -> bool {
if check_dtypes && (self.series.dtype() != other.series.dtype()) {
return false;
}
if null_equal {
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1953,8 +1953,8 @@ def test_product() -> None:
expected = pl.DataFrame(
{"int": [6], "flt": [-108.0], "bool_0": [0], "bool_1": [1], "str": [None]}
)
assert_frame_not_equal(out, expected, check_dtype=True)
assert_frame_equal(out, expected, check_dtype=False)
assert_frame_not_equal(out, expected, check_dtypes=True)
assert_frame_equal(out, expected, check_dtypes=False)


def test_first_last_nth_expressions(fruits_cars: pl.DataFrame) -> None:
Expand Down
16 changes: 8 additions & 8 deletions py-polars/tests/unit/io/test_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_scan_delta(delta_table_path: Path) -> None:
ldf = pl.scan_delta(str(delta_table_path), version=0)

expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]})
assert_frame_equal(expected, ldf.collect(), check_dtype=False)
assert_frame_equal(expected, ldf.collect(), check_dtypes=False)


def test_scan_delta_version(delta_table_path: Path) -> None:
Expand Down Expand Up @@ -66,7 +66,7 @@ def test_scan_delta_columns(delta_table_path: Path) -> None:
ldf = pl.scan_delta(str(delta_table_path), version=0).select("name")

expected = pl.DataFrame({"name": ["Joey", "Ivan"]})
assert_frame_equal(expected, ldf.collect(), check_dtype=False)
assert_frame_equal(expected, ldf.collect(), check_dtypes=False)


def test_scan_delta_filesystem(delta_table_path: Path) -> None:
Expand All @@ -78,7 +78,7 @@ def test_scan_delta_filesystem(delta_table_path: Path) -> None:
)

expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]})
assert_frame_equal(expected, ldf.collect(), check_dtype=False)
assert_frame_equal(expected, ldf.collect(), check_dtypes=False)


def test_scan_delta_relative(delta_table_path: Path) -> None:
Expand All @@ -87,7 +87,7 @@ def test_scan_delta_relative(delta_table_path: Path) -> None:
ldf = pl.scan_delta(rel_delta_table_path, version=0)

expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]})
assert_frame_equal(expected, ldf.collect(), check_dtype=False)
assert_frame_equal(expected, ldf.collect(), check_dtypes=False)

ldf = pl.scan_delta(rel_delta_table_path, version=1)
assert_frame_not_equal(expected, ldf.collect())
Expand All @@ -97,7 +97,7 @@ def test_read_delta(delta_table_path: Path) -> None:
df = pl.read_delta(str(delta_table_path), version=0)

expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]})
assert_frame_equal(expected, df, check_dtype=False)
assert_frame_equal(expected, df, check_dtypes=False)


def test_read_delta_version(delta_table_path: Path) -> None:
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_read_delta_columns(delta_table_path: Path) -> None:
df = pl.read_delta(str(delta_table_path), version=0, columns=["name"])

expected = pl.DataFrame({"name": ["Joey", "Ivan"]})
assert_frame_equal(expected, df, check_dtype=False)
assert_frame_equal(expected, df, check_dtypes=False)


def test_read_delta_filesystem(delta_table_path: Path) -> None:
Expand All @@ -151,7 +151,7 @@ def test_read_delta_filesystem(delta_table_path: Path) -> None:
)

expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]})
assert_frame_equal(expected, df, check_dtype=False)
assert_frame_equal(expected, df, check_dtypes=False)


def test_read_delta_relative(delta_table_path: Path) -> None:
Expand All @@ -160,7 +160,7 @@ def test_read_delta_relative(delta_table_path: Path) -> None:
df = pl.read_delta(rel_delta_table_path, version=0)

expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]})
assert_frame_equal(expected, df, check_dtype=False)
assert_frame_equal(expected, df, check_dtypes=False)


@pytest.mark.write_disk()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
assert_frame_equal(
result_frame,
expected_frame,
check_dtype=(".dt." not in suggested_expression),
check_dtypes=(".dt." not in suggested_expression),
)


Expand Down
11 changes: 9 additions & 2 deletions py-polars/tests/unit/series/test_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_equals() -> None:
s2 = pl.Series("a", [1, 2, None], pl.Int64)

assert s1.equals(s2) is True
assert s1.equals(s2, strict=True) is False
assert s1.equals(s2, check_dtypes=True) is False
assert s1.equals(s2, null_equal=False) is False

df = pl.DataFrame(
Expand All @@ -25,7 +25,7 @@ def test_equals() -> None:
s4 = df["s4"].rename("b")

assert s3.equals(s4) is False
assert s3.equals(s4, strict=True) is False
assert s3.equals(s4, check_dtypes=True) is False
assert s3.equals(s4, null_equal=False) is False
assert s3.dt.convert_time_zone("Asia/Tokyo").equals(s4) is True

Expand Down Expand Up @@ -91,3 +91,10 @@ def test_ne_missing_expr() -> None:
result_evaluated = pl.select(result).to_series()
expected = pl.Series([False, True])
assert_series_equal(result_evaluated, expected)


def test_series_equals_strict_deprecated() -> None:
s1 = pl.Series("a", [1.0, 2.0, None], pl.Float64)
s2 = pl.Series("a", [1, 2, None], pl.Int64)
with pytest.deprecated_call():
assert not s1.equals(s2, strict=True) # type: ignore[call-arg]
2 changes: 1 addition & 1 deletion py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ def test_empty() -> None:

assert_series_equal(pl.Series(), pl.Series())
assert_series_equal(
pl.Series(dtype=pl.Int32), pl.Series(dtype=pl.Int64), check_dtype=False
pl.Series(dtype=pl.Int32), pl.Series(dtype=pl.Int64), check_dtypes=False
)

with pytest.raises(TypeError, match="ambiguous"):
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/sql/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_group_by_all() -> None:
"n": [3, 2, 1],
}
)
assert_frame_equal(expected, res, check_dtype=False)
assert_frame_equal(expected, res, check_dtypes=False)

# more involved determination of agg/group columns
res = df.sql(
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_group_by_ordinal_position() -> None:
ORDER BY c
"""
)
assert_frame_equal(res1, expected, check_dtype=False)
assert_frame_equal(res1, expected, check_dtypes=False)

res2 = ctx.execute(
"""
Expand Down
Loading

0 comments on commit d190e02

Please sign in to comment.