Skip to content

Commit

Permalink
feat!: Add check_names parameter to Series.equals and default to `False` (#16610)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jun 4, 2024
1 parent 889b00a commit 9ea3bb6
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 19 deletions.
32 changes: 15 additions & 17 deletions crates/polars-core/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ use crate::prelude::*;
impl Series {
/// Check if series are equal. Note that `None == None` evaluates to `false`
pub fn equals(&self, other: &Series) -> bool {
if self.null_count() > 0 || other.null_count() > 0 || self.dtype() != other.dtype() {
if self.null_count() > 0 || other.null_count() > 0 {
false
} else {
self.equals_missing(other)
}
}

/// Check if all values in series are equal where `None == None` evaluates to `true`.
/// Two [`Datetime`](DataType::Datetime) series are *not* equal if their timezones are different, regardless
/// if they represent the same UTC time or not.
pub fn equals_missing(&self, other: &Series) -> bool {
match (self.dtype(), other.dtype()) {
// Two [`Datetime`](DataType::Datetime) series are *not* equal if their timezones
// are different, regardless if they represent the same UTC time or not.
#[cfg(feature = "timezones")]
(DataType::Datetime(_, tz_lhs), DataType::Datetime(_, tz_rhs)) => {
if tz_lhs != tz_rhs {
Expand All @@ -27,17 +27,14 @@ impl Series {
_ => {},
}

// differences from Partial::eq in that numerical dtype may be different
self.len() == other.len()
&& self.name() == other.name()
&& self.null_count() == other.null_count()
&& {
let eq = self.equal_missing(other);
match eq {
Ok(b) => b.all(),
Err(_) => false,
}
// Differs from Partial::eq in that numerical dtype may be different
self.len() == other.len() && self.null_count() == other.null_count() && {
let eq = self.equal_missing(other);
match eq {
Ok(b) => b.all(),
Err(_) => false,
}
}
}

/// Get a pointer to the underlying data of this [`Series`].
Expand Down Expand Up @@ -99,7 +96,7 @@ impl DataFrame {
return false;
}
for (left, right) in self.get_columns().iter().zip(other.get_columns()) {
if !left.equals(right) {
if left.name() != right.name() || !left.equals(right) {
return false;
}
}
Expand All @@ -125,7 +122,7 @@ impl DataFrame {
return false;
}
for (left, right) in self.get_columns().iter().zip(other.get_columns()) {
if !left.equals_missing(right) {
if left.name() != right.name() || !left.equals_missing(right) {
return false;
}
}
Expand Down Expand Up @@ -191,10 +188,11 @@ mod test {
}

#[test]
fn test_series_dtype_noteq() {
fn test_series_dtype_not_equal() {
let s_i32 = Series::new("a", &[1_i32, 2_i32]);
let s_i64 = Series::new("a", &[1_i64, 2_i64]);
assert!(!s_i32.equals(&s_i64));
assert!(s_i32.dtype() != s_i64.dtype());
assert!(s_i32.equals(&s_i64));
}

#[test]
Expand Down
8 changes: 7 additions & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4120,6 +4120,7 @@ def equals(
other: Series,
*,
check_dtypes: bool = False,
check_names: bool = False,
null_equal: bool = True,
) -> bool:
"""
Expand All @@ -4131,6 +4132,8 @@ def equals(
Series to compare with.
check_dtypes
Require data types to match.
check_names
Require names to match.
null_equal
Consider null values as equal.
Expand All @@ -4148,7 +4151,10 @@ def equals(
False
"""
return self._s.equals(
other._s, check_dtypes=check_dtypes, null_equal=null_equal
other._s,
check_dtypes=check_dtypes,
check_names=check_names,
null_equal=null_equal,
)

def cast(
Expand Down
11 changes: 10 additions & 1 deletion py-polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,19 @@ impl PySeries {
self.series.has_validity()
}

fn equals(&self, other: &PySeries, check_dtypes: bool, null_equal: bool) -> bool {
fn equals(
&self,
other: &PySeries,
check_dtypes: bool,
check_names: bool,
null_equal: bool,
) -> bool {
if check_dtypes && (self.series.dtype() != other.series.dtype()) {
return false;
}
if check_names && (self.series.name() != other.series.name()) {
return false;
}
if null_equal {
self.series.equals_missing(&other.series)
} else {
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/series/test_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ def test_equals() -> None:
assert s3.dt.convert_time_zone("Asia/Tokyo").equals(s4) is True


def test_series_equals_check_names() -> None:
    """Series with equal values but different names: `equals` honors `check_names`."""
    left = pl.Series("foo", [1, 2, 3])
    right = pl.Series("bar", [1, 2, 3])

    # Without `check_names`, the differing names do not affect the result.
    assert left.equals(right) is True
    # With `check_names=True`, the differing names make the comparison fail.
    assert left.equals(right, check_names=True) is False


def test_eq_list_cmp_list() -> None:
s = pl.Series([[1], [1, 2]])
result = s == [1, 2]
Expand Down

0 comments on commit 9ea3bb6

Please sign in to comment.