diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index 61f530280324..71c23d5de2bf 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -271,7 +271,7 @@ impl ChunkZip for StructChunked { let if_true = if_true.as_ref(); let if_false = if_false.as_ref(); - let (l, r, mask) = align_chunks_ternary(if_true, if_false, mask); + let (if_true, if_false, mask) = align_chunks_ternary(if_true, if_false, mask); // Prepare the boolean arrays such that Null maps to false. // This prevents every field doing that. @@ -287,10 +287,10 @@ impl ChunkZip for StructChunked { } // Zip all the fields. - let fields = l + let fields = if_true .fields_as_series() .iter() - .zip(r.fields_as_series()) + .zip(if_false.fields_as_series()) .map(|(lhs, rhs)| lhs.zip_with_same_type(&mask, &rhs)) .collect::>>()?; @@ -330,138 +330,145 @@ impl ChunkZip for StructChunked { // We need to take two things into account: // 1. The chunk lengths of `out` might not necessarily match `l`, `r` and `mask`. // 2. `l` and `r` might still need to be broadcasted. - if (l.null_count + r.null_count) > 0 { + if (if_true.null_count + if_false.null_count) > 0 { // Create one validity mask that spans the entirety of out. - let rechunked_validity = match (l.len(), r.len()) { - (1, 1) if length != 1 => match (l.null_count() == 0, r.null_count() == 0) { - (true, true) => None, - (false, true) => { - if mask.chunks().len() == 1 { - let m = mask.chunks()[0] - .as_any() - .downcast_ref::() - .unwrap() - .values(); - Some(!m) - } else { - rechunk_bitmaps( - length, - mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))), - ) - } - }, - (true, false) => { - if mask.chunks().len() == 1 { - let m = mask.chunks()[0] - .as_any() - .downcast_ref::() - .unwrap() - .values(); - Some(m.clone()) - } else { - rechunk_bitmaps( - length, - mask.downcast_iter() - .map(|m| (m.len(), Some(m.values().clone()))), - ) - } - }, - (false, false) => Some(Bitmap::new_zeroed(length)), + let rechunked_validity = match (if_true.len(), if_false.len()) { + (1, 1) if length != 1 => { + match (if_true.null_count() == 0, if_false.null_count() == 0) { + (true, true) => None, + (false, true) => { + if mask.chunks().len() == 1 { + let m = mask.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + Some(!m) + } else { + rechunk_bitmaps( + length, + mask.downcast_iter() + .map(|m| (m.len(), Some(m.values().clone()))), + ) + } + }, + (true, false) => { + if mask.chunks().len() == 1 { + let m = mask.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + Some(m.clone()) + } else { + rechunk_bitmaps( + length, + mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))), + ) + } + }, + (false, false) => Some(Bitmap::new_zeroed(length)), + } }, (1, _) if length != 1 => { - debug_assert!(r + debug_assert!(if_false .chunk_lengths() .zip(mask.chunk_lengths()) .all(|(r, m)| r == m)); - let combine = if l.null_count() == 0 { - |r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or(r, m)) + let combine = if if_true.null_count() == 0 { + |if_false: Option<&Bitmap>, m: &Bitmap| { + if_false.map(|v| arrow::bitmap::or(v, m)) + } } else { - |r: Option<&Bitmap>, m: &Bitmap| { - Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and_not(r, m))) + |if_false: Option<&Bitmap>, m: &Bitmap| { + Some(if_false.map_or_else(|| !m, |v| arrow::bitmap::and_not(v, m))) } }; - if r.chunks().len() == 1 { - let r = r.chunks()[0].validity(); + if if_false.chunks().len() == 1 { + let if_false = if_false.chunks()[0].validity(); let m = mask.chunks()[0] .as_any() .downcast_ref::() .unwrap() .values(); - let validity = combine(r, m); - validity.and_then(|v| (v.unset_bits() > 0).then_some(v)) + let validity = combine(if_false, m); + validity.filter(|v| v.unset_bits() > 0) } else { rechunk_bitmaps( length, - r.chunks() - .iter() - .zip(mask.downcast_iter()) - .map(|(chunk, mask)| { + if_false.chunks().iter().zip(mask.downcast_iter()).map( + |(chunk, mask)| { (mask.len(), combine(chunk.validity(), mask.values())) - }), + }, + ), ) } }, (_, 1) if length != 1 => { - debug_assert!(l + debug_assert!(if_true .chunk_lengths() .zip(mask.chunk_lengths()) .all(|(l, m)| l == m)); - let combine = if r.null_count() == 0 { - |l: Option<&Bitmap>, m: &Bitmap| l.map(|l| arrow::bitmap::or_not(l, m)) + let combine = if if_false.null_count() == 0 { + |if_true: Option<&Bitmap>, m: &Bitmap| { + if_true.map(|v| arrow::bitmap::or_not(v, m)) + } } else { - |l: Option<&Bitmap>, m: &Bitmap| { - Some(l.map_or_else(|| m.clone(), |l| arrow::bitmap::and(l, m))) + |if_true: Option<&Bitmap>, m: &Bitmap| { + Some(if_true.map_or_else(|| m.clone(), |v| arrow::bitmap::and(v, m))) } }; - if l.chunks().len() == 1 { - let l = l.chunks()[0].validity(); + if if_true.chunks().len() == 1 { + let if_true = if_true.chunks()[0].validity(); let m = mask.chunks()[0] .as_any() .downcast_ref::() .unwrap() .values(); - let validity = combine(l, m); - validity.and_then(|v| (v.unset_bits() > 0).then_some(v)) + let validity = combine(if_true, m); + validity.filter(|v| v.unset_bits() > 0) } else { rechunk_bitmaps( length, - l.chunks() - .iter() - .zip(mask.downcast_iter()) - .map(|(chunk, mask)| { + if_true.chunks().iter().zip(mask.downcast_iter()).map( + |(chunk, mask)| { (mask.len(), combine(chunk.validity(), mask.values())) - }), + }, + ), ) } }, (_, _) => { - debug_assert!(l + debug_assert!(if_true .chunk_lengths() - .zip(r.chunk_lengths()) + .zip(if_false.chunk_lengths()) .all(|(l, r)| l == r)); - debug_assert!(l + debug_assert!(if_true .chunk_lengths() .zip(mask.chunk_lengths()) .all(|(l, r)| l == r)); - let validities = l + let validities = if_true .chunks() .iter() - .zip(r.chunks()) + .zip(if_false.chunks()) .map(|(l, r)| (l.validity(), r.validity())); rechunk_bitmaps( length, validities .zip(mask.downcast_iter()) - .map(|((lv, rv), mask)| { - (mask.len(), if_then_else_validity(mask.values(), lv, rv)) + .map(|((if_true, if_false), mask)| { + ( + mask.len(), + if_then_else_validity(mask.values(), if_true, if_false), + ) }), ) }, diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py index dc458086c943..6e79c874a01b 100644 --- a/py-polars/tests/unit/functions/test_when_then.py +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -690,3 +690,49 @@ def test_when_then_chunked_structs_18673() -> None: df.select(pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x"))), pl.DataFrame({"x": [{"a": 1}, {"a": 1}]}), ) + + +some_scalar = pl.Series("a", [{"x": 2}], pl.Struct) +none_scalar = pl.Series("a", [None], pl.Struct({"x": pl.Int64})) +column = pl.Series("a", [{"x": 2}, {"x": 2}], pl.Struct) + + +@pytest.mark.parametrize( + "values", + [ + (some_scalar, some_scalar), + (some_scalar, pl.col.a), + (some_scalar, none_scalar), + (some_scalar, column), + (none_scalar, pl.col.a), + (none_scalar, none_scalar), + (none_scalar, column), + (pl.col.a, pl.col.a), + (pl.col.a, column), + (column, column), + ], +) +def test_struct_when_then_broadcasting_combinations_19122( + values: tuple[Any, Any], +) -> None: + lv, rv = values + + df = pl.Series("a", [{"x": 1}, {"x": 1}], pl.Struct).to_frame() + + assert_frame_equal( + df.select( + pl.when(pl.col.a.struct.field("x") == 0).then(lv).otherwise(rv).alias("a") + ), + df.select( + pl.when(pl.col.a.struct.field("x") == 0).then(None).otherwise(rv).alias("a") + ), + ) + + assert_frame_equal( + df.select( + pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(lv).alias("a") + ), + df.select( + pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(None).alias("a") + ), + )