Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust, python): Add null_on_oob parameter to expr.list.get #15395

Merged
merged 6 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/polars-arrow/src/legacy/kernels/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ pub fn sublist_get(arr: &ListArray<i64>, index: i64) -> ArrayRef {
unsafe { take_unchecked(&**values, &take_by) }
}

/// Check if an index is out of bounds for at least one valid (non-null) sublist.
///
/// `index` is resolved against the length of every sublist with
/// `negative_to_usize`; a sublist for which that resolution yields `None`
/// counts as out of bounds.
///
/// Sublists that are masked out by the validity bitmap are skipped: a null
/// sublist has no elements to index into, so it must not cause an
/// out-of-bounds report (it simply stays null in the result).
///
/// Returns `true` if at least one valid sublist cannot be indexed with `index`,
/// and `false` otherwise (including for an empty array).
pub fn index_is_oob(arr: &ListArray<i64>, index: i64) -> bool {
    match arr.validity() {
        // Some rows may be null: only inspect the lengths of valid rows.
        Some(validity) => arr
            .offsets()
            .lengths()
            .zip(validity.iter())
            .any(|(len, is_valid)| is_valid && index.negative_to_usize(len).is_none()),
        // No validity bitmap means every row is valid: check all lengths.
        None => arr
            .offsets()
            .lengths()
            .any(|len| index.negative_to_usize(len).is_none()),
    }
}

/// Convert a list `[1, 2, 3]` to a list type of `[[1], [2], [3]]`
pub fn array_to_unit_list(array: ArrayRef) -> ListArray<i64> {
let len = array.len();
Expand Down
8 changes: 6 additions & 2 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::Write;

use arrow::array::ValueSize;
use arrow::legacy::kernels::list::sublist_get;
use arrow::legacy::kernels::list::{index_is_oob, sublist_get};
use polars_core::chunked_array::builder::get_list_builder;
#[cfg(feature = "list_gather")]
use polars_core::export::num::ToPrimitive;
Expand Down Expand Up @@ -342,8 +342,12 @@ pub trait ListNameSpaceImpl: AsList {
/// So index `0` would return the first item of every sublist
/// and index `-1` would return the last item of every sublist
/// If an index is out of bounds, the result is `None` when `null_on_oob` is `true`; otherwise an error is returned.
fn lst_get(&self, idx: i64) -> PolarsResult<Series> {
fn lst_get(&self, idx: i64, null_on_oob: bool) -> PolarsResult<Series> {
let ca = self.as_list();
if !null_on_oob && ca.downcast_iter().any(|arr| index_is_oob(arr, idx)) {
polars_bail!(ComputeError: "get index is out of bounds");
}
JamesCE2001 marked this conversation as resolved.
Show resolved Hide resolved

let chunks = ca
.downcast_iter()
.map(|arr| sublist_get(arr, idx))
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-ops/src/chunked_array/list/to_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub trait ToStruct: AsList {
(0..n_fields)
.into_par_iter()
.map(|i| {
ca.lst_get(i as i64).map(|mut s| {
ca.lst_get(i as i64, true).map(|mut s| {
s.rename(&name_generator(i));
s
})
Expand Down
33 changes: 19 additions & 14 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub enum ListFunction {
},
Slice,
Shift,
Get,
Get(bool),
#[cfg(feature = "list_gather")]
Gather(bool),
#[cfg(feature = "list_gather")]
Expand Down Expand Up @@ -71,7 +71,7 @@ impl ListFunction {
Sample { .. } => mapper.with_same_dtype(),
Slice => mapper.with_same_dtype(),
Shift => mapper.with_same_dtype(),
Get => mapper.map_to_list_and_array_inner_dtype(),
Get(_) => mapper.map_to_list_and_array_inner_dtype(),
#[cfg(feature = "list_gather")]
Gather(_) => mapper.with_same_dtype(),
#[cfg(feature = "list_gather")]
Expand Down Expand Up @@ -136,7 +136,7 @@ impl Display for ListFunction {
},
Slice => "slice",
Shift => "shift",
Get => "get",
Get(_) => "get",
#[cfg(feature = "list_gather")]
Gather(_) => "gather",
#[cfg(feature = "list_gather")]
Expand Down Expand Up @@ -203,9 +203,9 @@ impl From<ListFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
},
Slice => wrap!(slice),
Shift => map_as_slice!(shift),
Get => wrap!(get),
Get(null_on_oob) => wrap!(get, null_on_oob),
#[cfg(feature = "list_gather")]
Gather(null_ob_oob) => map_as_slice!(gather, null_ob_oob),
Gather(null_on_oob) => map_as_slice!(gather, null_on_oob),
#[cfg(feature = "list_gather")]
GatherEvery => map_as_slice!(gather_every),
#[cfg(feature = "list_count")]
Expand Down Expand Up @@ -414,7 +414,7 @@ pub(super) fn concat(s: &mut [Series]) -> PolarsResult<Option<Series>> {
first_ca.lst_concat(other).map(|ca| Some(ca.into_series()))
}

pub(super) fn get(s: &mut [Series]) -> PolarsResult<Option<Series>> {
pub(super) fn get(s: &mut [Series], null_on_oob: bool) -> PolarsResult<Option<Series>> {
let ca = s[0].list()?;
let index = s[1].cast(&DataType::Int64)?;
let index = index.i64().unwrap();
Expand All @@ -423,7 +423,7 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult<Option<Series>> {
1 => {
let index = index.get(0);
if let Some(index) = index {
ca.lst_get(index).map(Some)
ca.lst_get(index, null_on_oob).map(Some)
} else {
Ok(Some(Series::full_null(
ca.name(),
Expand All @@ -440,19 +440,24 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult<Option<Series>> {
let take_by = index
.into_iter()
.enumerate()
.map(|(i, opt_idx)| {
opt_idx.and_then(|idx| {
.map(|(i, opt_idx)| match opt_idx {
Some(idx) => {
let (start, end) =
unsafe { (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) };
let offset = if idx >= 0 { start + idx } else { end + idx };
if offset >= end || offset < start || start == end {
None
if null_on_oob {
Ok(None)
} else {
polars_bail!(ComputeError: "get index is out of bounds");
}
} else {
Some(offset as IdxSize)
Ok(Some(offset as IdxSize))
}
})
},
None => Ok(None),
})
.collect::<IdxCa>();
.collect::<Result<IdxCa, _>>()?;
let s = Series::try_from((ca.name(), arr.values().clone())).unwrap();
unsafe { s.take_unchecked(&take_by) }
.cast(&ca.inner_dtype())
Expand All @@ -475,7 +480,7 @@ pub(super) fn gather(args: &[Series], null_on_oob: bool) -> PolarsResult<Series>
if idx.len() == 1 && null_on_oob {
// fast path
let idx = idx.get(0)?.try_extract::<i64>()?;
let out = ca.lst_get(idx)?;
let out = ca.lst_get(idx, null_on_oob)?;
// make sure we return a list
out.reshape(&[-1, 1])
} else {
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,14 @@ macro_rules! wrap {
($e:expr) => {
SpecialEq::new(Arc::new($e))
};

($e:expr, $($args:expr),*) => {{
let f = move |s: &mut [Series]| {
$e(s, $($args),*)
};

SpecialEq::new(Arc::new(f))
}};
}

// Fn(&[Series], args)
Expand Down
8 changes: 4 additions & 4 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ impl ListNameSpace {
}

/// Get items in every sublist by index.
pub fn get(self, index: Expr) -> Expr {
pub fn get(self, index: Expr, null_on_oob: bool) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Get),
FunctionExpr::ListExpr(ListFunction::Get(null_on_oob)),
&[index],
false,
false,
Expand Down Expand Up @@ -187,12 +187,12 @@ impl ListNameSpace {

/// Get first item of every sublist.
pub fn first(self) -> Expr {
self.get(lit(0i64))
self.get(lit(0i64), true)
}

/// Get last item of every sublist.
pub fn last(self) -> Expr {
self.get(lit(-1i64))
self.get(lit(-1i64), true)
}

/// Join all string items in a sublist and place a separator between them.
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ impl SQLFunctionVisitor<'_> {
// Array functions
// ----
ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s)),
ArrayGet => self.visit_binary(|e, i| e.list().get(i)),
ArrayGet => self.visit_binary(|e, i| e.list().get(i, false)),
JamesCE2001 marked this conversation as resolved.
Show resolved Hide resolved
ArrayLength => self.visit_unary(|e| e.list().len()),
ArrayMax => self.visit_unary(|e| e.list().max()),
ArrayMean => self.visit_unary(|e| e.list().mean()),
Expand Down
19 changes: 14 additions & 5 deletions py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,12 @@ def concat(self, other: list[Expr | str] | Expr | str | Series | list[Any]) -> E
other_list.insert(0, wrap_expr(self._pyexpr))
return F.concat_list(other_list)

def get(self, index: int | Expr | str) -> Expr:
def get(
self,
index: int | Expr | str,
*,
null_on_oob: bool = False,
) -> Expr:
"""
Get the value by index in the sublists.

Expand All @@ -517,11 +522,15 @@ def get(self, index: int | Expr | str) -> Expr:
----------
index
Index to return per sublist
null_on_oob
Behavior if an index is out of bounds:
True -> set as null
False -> raise an error

Examples
--------
>>> df = pl.DataFrame({"a": [[3, 2, 1], [], [1, 2]]})
>>> df.with_columns(get=pl.col("a").list.get(0))
>>> df.with_columns(get=pl.col("a").list.get(0, null_on_oob=True))
shape: (3, 2)
┌───────────┬──────┐
│ a ┆ get │
Expand All @@ -534,7 +543,7 @@ def get(self, index: int | Expr | str) -> Expr:
└───────────┴──────┘
"""
index = parse_as_expression(index)
return wrap_expr(self._pyexpr.list_get(index))
return wrap_expr(self._pyexpr.list_get(index, null_on_oob))

def gather(
self,
Expand Down Expand Up @@ -641,7 +650,7 @@ def first(self) -> Expr:
│ [1, 2] ┆ 1 │
└───────────┴───────┘
"""
return self.get(0)
return self.get(0, null_on_oob=True)

def last(self) -> Expr:
"""
Expand All @@ -662,7 +671,7 @@ def last(self) -> Expr:
│ [1, 2] ┆ 2 │
└───────────┴──────┘
"""
return self.get(-1)
return self.get(-1, null_on_oob=True)

def contains(
self, item: float | str | bool | int | date | datetime | time | IntoExprColumn
Expand Down
13 changes: 11 additions & 2 deletions py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,12 @@ def concat(self, other: list[Series] | Series | list[Any]) -> Series:
]
"""

def get(self, index: int | Series | list[int]) -> Series:
def get(
self,
index: int | Series | list[int],
*,
null_on_oob: bool = False,
) -> Series:
"""
Get the value by index in the sublists.

Expand All @@ -359,11 +364,15 @@ def get(self, index: int | Series | list[int]) -> Series:
----------
index
Index to return per sublist
null_on_oob
Behavior if an index is out of bounds:
True -> set as null
False -> raise an error

Examples
--------
>>> s = pl.Series("a", [[3, 2, 1], [], [1, 2]])
>>> s.list.get(0)
>>> s.list.get(0, null_on_oob=True)
shape: (3,)
Series: 'a' [i64]
[
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ impl PyExpr {
self.inner.clone().list().eval(expr.inner, parallel).into()
}

fn list_get(&self, index: PyExpr) -> Self {
self.inner.clone().list().get(index.inner).into()
fn list_get(&self, index: PyExpr, null_on_oob: bool) -> Self {
self.inner
.clone()
.list()
.get(index.inner, null_on_oob)
.into()
}

fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ def test_list_gather_null_struct_14927() -> None:
{"index": [1], "col_0": [None], "field_0": [None]},
schema={**df.schema, "field_0": pl.Float64},
)
expr = pl.col("col_0").list.get(0).struct.field("field_0")
expr = pl.col("col_0").list.get(0, null_on_oob=True).struct.field("field_0")
out = df.filter(pl.col("index") > 0).with_columns(expr)
assert_frame_equal(out, expected)

Expand Down
Loading
Loading