Skip to content

Commit

Permalink
test: Re-enable struct related tests (pola-rs#17597)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 12, 2024
1 parent 6d72fb9 commit 693181c
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 42 deletions.
18 changes: 14 additions & 4 deletions crates/polars-core/src/chunked_array/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl StructChunked2 {

let mut out = Self::from_series(self.name(), &new_fields)?;
if self.null_count > 0 {
out.merge_validities(self.chunks());
out.zip_outer_validity(self);
}
Ok(out.into_series())
},
Expand Down Expand Up @@ -228,7 +228,7 @@ impl StructChunked2 {
.collect::<PolarsResult<Vec<_>>>()?;
let mut out = Self::from_series(self.name(), &fields)?;
if self.null_count > 0 {
out.merge_validities(self.chunks());
out.zip_outer_validity(self);
}
Ok(out.into_series())
},
Expand Down Expand Up @@ -311,9 +311,19 @@ impl StructChunked2 {
}

/// Combine the validities of two structs.
/// # Panics
/// Panics if the chunks don't align.
pub fn zip_outer_validity(&mut self, other: &StructChunked2) {
if self.chunks.len() != other.chunks.len()
|| !self
.chunks
.iter()
.zip(other.chunks.iter())
.map(|(a, b)| a.len() == b.len())
.all_equal()
{
*self = self.rechunk();
let other = other.rechunk();
return self.zip_outer_validity(&other);
}
if other.null_count > 0 {
// SAFETY:
// We keep length and dtypes the same.
Expand Down
41 changes: 21 additions & 20 deletions crates/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,18 @@ fn any_values_to_list(
inner_type: &DataType,
strict: bool,
) -> PolarsResult<ListChunked> {
let target_dtype = DataType::List(Box::new(inner_type.clone()));
let it = match inner_type {
// Structs don't support empty fields yet.
// We must ensure the data-types match what we do physical
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) if fields.is_empty() => {
DataType::Struct(vec![Field::new("", DataType::Null)])
},
_ => inner_type.clone(),
};
let target_dtype = DataType::List(Box::new(it));

// This is handled downstream. The builder will choose the first non null type.
// This is handled downstream. The builder will choose the first non-null type.
let mut valid = true;
#[allow(unused_mut)]
let mut out: ListChunked = if inner_type == &DataType::Null {
Expand Down Expand Up @@ -657,8 +666,9 @@ fn any_values_to_struct(
// The physical series fields of the struct.
let mut series_fields = Vec::with_capacity(fields.len());
let mut has_outer_validity = false;
let mut field_avs = Vec::with_capacity(values.len());
for (i, field) in fields.iter().enumerate() {
let mut field_avs = Vec::with_capacity(values.len());
field_avs.clear();

for av in values.iter() {
match av {
Expand All @@ -669,29 +679,20 @@ fn any_values_to_struct(

let mut append_by_search = || {
// Search for the name.
let mut pushed = false;
for (av_fld, av_val) in av_fields.iter().zip(av_values) {
if av_fld.name == field.name {
field_avs.push(av_val.clone());
pushed = true;
break;
}
}
if !pushed {
field_avs.push(AnyValue::Null)
if let Some(i) = av_fields
.iter()
.position(|av_fld| av_fld.name == field.name)
{
field_avs.push(av_values[i].clone());
return;
}
field_avs.push(AnyValue::Null)
};

// All fields are available in this single value.
// We can use the index to get value.
if fields.len() == av_fields.len() {
let mut search = false;
for (l, r) in fields.iter().zip(av_fields.iter()) {
if l.name() != r.name() {
search = true;
}
}
if search {
if fields.iter().zip(av_fields.iter()).any(|(l, r)| l != r) {
append_by_search()
} else {
let av_val = av_values.get(i).cloned().unwrap_or(AnyValue::Null);
Expand Down
6 changes: 1 addition & 5 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,7 @@ impl Series {
let mut ca = StructChunked2::from_series(self.name(), &fields).unwrap();

if arr.null_count() > 0 {
unsafe {
ca.downcast_iter_mut()
.zip(arr.downcast_iter().map(|arr| arr.validity()))
.for_each(|(arr, validity)| arr.set_validity(validity.cloned()))
}
ca.zip_outer_validity(arr);
}
Cow::Owned(ca.into_series())
},
Expand Down
9 changes: 6 additions & 3 deletions py-polars/tests/unit/dataframe/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,23 @@
from polars._typing import SerializationFormat


@pytest.mark.skip(reason="struct-refactor")
@given(df=dataframes())
@given(
df=dataframes(
excluded_dtypes=[pl.Struct], # Outer nullability not supported
)
)
def test_df_serde_roundtrip_binary(df: pl.DataFrame) -> None:
serialized = df.serialize()
result = pl.DataFrame.deserialize(io.BytesIO(serialized), format="binary")
assert_frame_equal(result, df, categorical_as_str=True)


@pytest.mark.skip(reason="struct-refactor")
@given(
df=dataframes(
excluded_dtypes=[
pl.Float32, # Bug, see: https://github.com/pola-rs/polars/issues/17211
pl.Float64, # Bug, see: https://github.com/pola-rs/polars/issues/17211
pl.Struct, # Outer nullability not supported
],
)
)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/dataframe/test_to_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
df=dataframes(
excluded_dtypes=[
pl.Categorical, # Bug: https://github.com/pola-rs/polars/issues/16196
pl.Struct, # @pytest.mark.skip(reason="struct-refactor")
pl.Struct,
],
# Roundtrip doesn't work with time zones:
# https://github.com/pola-rs/polars/issues/16297
Expand Down
6 changes: 4 additions & 2 deletions py-polars/tests/unit/io/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,16 @@ def test_read_ndjson_empty_array() -> None:
) == {"foo": [{"bar": []}]}


@pytest.mark.skip(reason="struct-refactor")
def test_ndjson_nested_null() -> None:
json_payload = """{"foo":{"bar":[{}]}}"""
df = pl.read_ndjson(io.StringIO(json_payload))

# 'bar' represents an empty list of structs; check the schema is correct (eg: picks
# up that it IS a list of structs), but confirm that list is empty (ref: #11301)
assert df.schema == {"foo": pl.Struct([pl.Field("bar", pl.List(pl.Struct([])))])}
# We don't support empty structs yet. So Null is closest.
assert df.schema == {
"foo": pl.Struct([pl.Field("bar", pl.List(pl.Struct({"": pl.Null})))])
}
assert df.to_dict(as_series=False) == {"foo": [{"bar": []}]}


Expand Down
10 changes: 7 additions & 3 deletions py-polars/tests/unit/lazyframe/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,26 @@
from polars._typing import SerializationFormat


@pytest.mark.skip(reason="struct-refactor")
@given(lf=dataframes(lazy=True))
@given(
lf=dataframes(
lazy=True,
excluded_dtypes=[pl.Struct],
)
)
@example(lf=pl.LazyFrame({"foo": ["a", "b", "a"]}, schema={"foo": pl.Enum(["b", "a"])}))
def test_lf_serde_roundtrip_binary(lf: pl.LazyFrame) -> None:
serialized = lf.serialize(format="binary")
result = pl.LazyFrame.deserialize(io.BytesIO(serialized), format="binary")
assert_frame_equal(result, lf, categorical_as_str=True)


@pytest.mark.skip(reason="struct-refactor")
@given(
lf=dataframes(
lazy=True,
excluded_dtypes=[
pl.Float32, # Bug, see: https://github.com/pola-rs/polars/issues/17211
pl.Float64, # Bug, see: https://github.com/pola-rs/polars/issues/17211
pl.Struct, # Outer nullability not supported
],
)
)
Expand Down
2 changes: 0 additions & 2 deletions py-polars/tests/unit/series/test_to_list.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import pytest
from hypothesis import example, given

import polars as pl
Expand All @@ -16,7 +15,6 @@
)
)
@example(s=pl.Series(dtype=pl.Array(pl.Date, 1)))
@pytest.mark.skip(reason="struct-refactor")
def test_to_list(s: pl.Series) -> None:
values = s.to_list()
result = pl.Series(values, dtype=s.dtype)
Expand Down
1 change: 0 additions & 1 deletion py-polars/tests/unit/testing/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@


@given(df=dataframes())
@pytest.mark.skip(reason="struct-refactor")
def test_equal(df: pl.DataFrame) -> None:
assert_frame_equal(df, df.clone(), check_exact=True)

Expand Down
1 change: 0 additions & 1 deletion py-polars/tests/unit/testing/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@


@given(s=series())
@pytest.mark.skip(reason="struct-refactor")
def test_assert_series_equal_parametric(s: pl.Series) -> None:
assert_series_equal(s, s)

Expand Down

0 comments on commit 693181c

Please sign in to comment.