Skip to content

Commit

Permalink
fix: Don't access out-of-bounds for null indices in bitmap gather (po…
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Mar 9, 2024
1 parent 6b23f79 commit dc95ac8
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 10 deletions.
27 changes: 26 additions & 1 deletion crates/polars-arrow/src/compute/take/bitmap.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
use polars_utils::IdxSize;

use crate::array::Array;
use crate::bitmap::Bitmap;
use crate::datatypes::IdxArr;

/// Gathers bits out of `values`: bit `i` of the result is the bit of `values` found at
/// position `indices[i]`.
///
/// # Safety
/// No bounds checks are performed in release builds. The caller must guarantee that every
/// entry of `indices` is strictly smaller than `values.len()`.
pub unsafe fn take_bitmap_unchecked(values: &Bitmap, indices: &[IdxSize]) -> Bitmap {
    let gathered = indices.iter().map(|&idx| {
        let pos = idx as usize;
        // Only verified in debug builds; release builds trust the caller.
        debug_assert!(pos < values.len());
        values.get_bit_unchecked(pos)
    });
    Bitmap::from_trusted_len_iter(gathered)
}

/// Gathers bits out of `values` at the positions in `indices`, without ever using the index
/// stored in a null slot.
///
/// Null slots are read from position 0 instead, so the bit produced for them is only a
/// placeholder. When `values` is empty the result is `Bitmap::new_zeroed(indices.len())`.
///
/// # Safety
/// Bounds are not checked for non-null indices: each of them must be strictly smaller than
/// `values.len()`. As a consequence, `indices` must be entirely null when `values` is empty.
pub unsafe fn take_bitmap_nulls_unchecked(values: &Bitmap, indices: &IdxArr) -> Bitmap {
    let len = indices.len();
    let null_count = indices.null_count();

    // Without nulls every slot holds a real index, so the plain gather applies.
    if null_count == 0 {
        return take_bitmap_unchecked(values, indices.values());
    }

    // An empty `values` has no in-bounds position at all, hence every index must be null.
    if values.is_empty() {
        debug_assert!(null_count == len);
        return Bitmap::new_zeroed(len);
    }

    let gathered = indices.iter().map(|slot| {
        // `values` is non-empty here, which makes position 0 a safe stand-in for nulls.
        let pos = match slot {
            Some(&idx) => idx as usize,
            None => 0,
        };
        debug_assert!(pos < values.len());
        values.get_bit_unchecked(pos)
    });
    Bitmap::from_trusted_len_iter(gathered)
}
14 changes: 6 additions & 8 deletions crates/polars-arrow/src/compute/take/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use polars_utils::IdxSize;

use super::bitmap::take_bitmap_unchecked;
use super::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked};
use crate::array::{Array, BooleanArray, PrimitiveArray};
use crate::bitmap::{Bitmap, MutableBitmap};

// Take implementation for the case where neither the values nor the indices contain nulls.
// The gathered bits are returned as-is and no validity bitmap is produced.
unsafe fn take_no_validity(values: &Bitmap, indices: &[IdxSize]) -> (Bitmap, Option<Bitmap>) {
    let gathered = take_bitmap_unchecked(values, indices);
    (gathered, None)
}

// take implementation when only values contain nulls
// Take implementation when only values contain nulls.
unsafe fn take_values_validity(
values: &BooleanArray,
indices: &[IdxSize],
Expand All @@ -23,18 +23,16 @@ unsafe fn take_values_validity(
(buffer, validity.into())
}

// Take implementation when only indices contain nulls.
//
// The gather goes through `take_bitmap_nulls_unchecked` rather than a plain
// `take_bitmap_unchecked` over `indices.values()`: the index stored in a null slot is
// arbitrary and may be out of bounds, so it must never be used to read from `values`.
// The validity of the result is a clone of the indices' validity, so whatever placeholder
// bit was produced for a null slot is masked out.
//
// Safety: bounds are not checked for non-null indices (see `take_bitmap_nulls_unchecked`).
unsafe fn take_indices_validity(
    values: &Bitmap,
    indices: &PrimitiveArray<IdxSize>,
) -> (Bitmap, Option<Bitmap>) {
    let buffer = take_bitmap_nulls_unchecked(values, indices);
    (buffer, indices.validity().cloned())
}

// take implementation when both values and indices contain nulls
// Take implementation when both values and indices contain nulls.
unsafe fn take_values_indices_validity(
values: &BooleanArray,
indices: &PrimitiveArray<IdxSize>,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/compute/take/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub(super) unsafe fn take_unchecked(array: &StructArray, indices: &IdxArr) -> St

let validity = array
.validity()
.map(|b| super::bitmap::take_bitmap_unchecked(b, indices.values()));
.map(|b| super::bitmap::take_bitmap_nulls_unchecked(b, indices));
let validity = combine_validities_and(validity.as_ref(), indices.validity());
StructArray::new(array.data_type().clone(), values, validity)
}
23 changes: 23 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,3 +752,26 @@ def test_list_median(data_dispersion: pl.DataFrame) -> None:
)

assert_frame_equal(result, expected)


def test_list_gather_null_struct_14927() -> None:
    """Gathering element 0 of a null list-of-struct row must yield a null struct field.

    Only the row whose list column is null survives the filter; extracting
    ``field_0`` from its (non-existent) first element is expected to give null.
    The number in the test name presumably refers to the originating issue.
    """
    rows = [
        {"index": 0, "col_0": [{"field_0": 1.0}]},
        {"index": 1, "col_0": None},
    ]
    df = pl.DataFrame(rows)

    # Keep only the row with the null list, then pull the struct field out of it.
    field_expr = pl.col("col_0").list.get(0).struct.field("field_0")
    out = df.filter(pl.col("index") > 0).with_columns(field_expr)

    # Same schema as the input, plus the extracted Float64 field.
    expected_schema = {**df.schema, "field_0": pl.Float64}
    expected = pl.DataFrame(
        {"index": [1], "col_0": [None], "field_0": [None]},
        schema=expected_schema,
    )
    assert_frame_equal(out, expected)

0 comments on commit dc95ac8

Please sign in to comment.