Skip to content

Commit

Permalink
[FEAT] List slice expression (#2479)
Browse files Browse the repository at this point in the history
Adds an expression that returns a specified subset of elements from a
list.

The API to use this expression is `daft.Expression.list.slice(start: int
| Expression, end: int | Expression)`.

- `start` is the 0-indexed position in the list at which to start retrieving
elements (inclusive). If this value is negative, it is counted from the
_end_ of the list, i.e. `-1` points to the last element of the list.
- `end` is the 0-indexed position in the list at which to stop retrieving
elements (exclusive: the element at `end` itself is not included). If this
value is negative, it is counted from the _end_ of the list, i.e. `-1`
points to the last element of the list.

---------

Co-authored-by: Desmond Cheong <[email protected]>
  • Loading branch information
desmondcheongzx and Desmond Cheong authored Jul 6, 2024
1 parent 6655d30 commit 7b81de6
Show file tree
Hide file tree
Showing 9 changed files with 427 additions and 24 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,7 @@ class PyExpr:
def list_mean(self) -> PyExpr: ...
def list_min(self) -> PyExpr: ...
def list_max(self) -> PyExpr: ...
# Builds a list-slice expression: for each list, the elements between `start` and `end`.
def list_slice(self, start: PyExpr, end: PyExpr) -> PyExpr: ...
def struct_get(self, name: str) -> PyExpr: ...
def map_get(self, key: PyExpr) -> PyExpr: ...
def url_download(
Expand Down Expand Up @@ -1298,6 +1299,7 @@ class PySeries:
def partitioning_iceberg_truncate(self, w: int) -> PySeries: ...
def list_count(self, mode: CountMode) -> PySeries: ...
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
# Slices each list in this series using the per-row (or single broadcast) `start` and `end` indices.
def list_slice(self, start: PySeries, end: PySeries) -> PySeries: ...
def map_get(self, key: PySeries) -> PySeries: ...
def image_decode(self, raise_error_on_failure: bool) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
Expand Down
14 changes: 14 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,6 +1807,20 @@ def get(self, idx: int | Expression, default: object = None) -> Expression:
default_expr = lit(default)
return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr))

def slice(self, start: int | Expression, end: int | Expression) -> Expression:
    """Extracts a sub-list from every list in the column.

    Args:
        start: A single index or a column of indices marking the first element
            to keep. A negative value is interpreted as an offset from the end
            of the list.
        end: A single index or a column of indices marking where the slice
            stops; the element at this index and everything after it is left
            out. A negative value is interpreted as an offset from the end of
            the list.

    Returns:
        Expression: a list-typed expression whose element type matches that of
        the input lists.
    """
    start_pyexpr, end_pyexpr = (Expression._to_expression(bound)._expr for bound in (start, end))
    sliced = self._expr.list_slice(start_pyexpr, end_pyexpr)
    return Expression._from_pyexpr(sliced)

def sum(self) -> Expression:
"""Sums each list. Empty lists and lists with all nulls yield null.
Expand Down
136 changes: 112 additions & 24 deletions src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::iter::repeat;
use std::sync::Arc;

use crate::datatypes::{Int64Array, Utf8Array};
use crate::datatypes::{Field, Int64Array, Utf8Array};
use crate::{
array::{
growable::{make_growable, Growable},
Expand All @@ -10,7 +11,7 @@ use crate::{
};
use crate::{CountMode, DataType};

use crate::series::Series;
use crate::series::{IntoSeries, Series};

use common_error::DaftResult;

Expand Down Expand Up @@ -42,6 +43,81 @@ fn join_arrow_list_of_utf8s(
})
}

// Given an i64 array that may have either 1 or `self.len()` elements, create an iterator with
// `self.len()` elements. If there was originally 1 element, we repeat this element `self.len()`
// times, otherwise we simply take the original array.
fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box<dyn Iterator<Item = i64> + '_> {
match arr.len() {
1 => Box::new(repeat(arr.get(0).unwrap()).take(len)),
arr_len => {
assert_eq!(arr_len, len);
Box::new(arr.as_arrow().iter().map(|x| *x.unwrap()))
}
}
}

/// Builds a list `Series` in which every row is a slice of the corresponding
/// row of the parent list array.
///
/// * `parent_offsets` - offsets into `flat_child` delimiting each parent list;
///   consecutive items are the start and end of one list.
/// * `field` - field used for the resulting list array.
/// * `child_data_type` - data type of the elements held in `flat_child`.
/// * `flat_child` - the flattened child values of the parent array.
/// * `validity` - validity of the parent rows; it is copied to the result and
///   rows marked invalid contribute no elements.
/// * `start_iter` / `end_iter` - per-row slice bounds. A non-negative bound is
///   relative to the start of the row's list, a negative bound is relative to
///   its end. `start` is inclusive and `end` is exclusive.
///
/// Bounds that fall outside of a list are clamped, and a slice whose end does
/// not lie after its start is empty.
///
/// # Panics
///
/// Panics if `parent_offsets` yields no items, or if `validity` has fewer bits
/// than there are rows.
pub fn get_slices_helper(
    mut parent_offsets: impl Iterator<Item = i64>,
    field: Arc<Field>,
    child_data_type: &DataType,
    flat_child: &Series,
    validity: Option<&arrow2::bitmap::Bitmap>,
    start_iter: impl Iterator<Item = i64>,
    end_iter: impl Iterator<Item = i64>,
) -> DaftResult<Series> {
    // One offset per list plus a trailing one, so the lower size hint minus one
    // is the best available estimate of the number of lists.
    let num_lists_hint = parent_offsets.size_hint().0.saturating_sub(1);
    // (position in `flat_child`, length) of every non-empty slice, in row order.
    let mut kept_slices: Vec<(usize, usize)> = Vec::with_capacity(num_lists_hint);
    let mut new_offsets: Vec<i64> = Vec::with_capacity(num_lists_hint + 1);
    let mut total_len: i64 = 0;
    new_offsets.push(total_len);

    let mut starting_idx = parent_offsets
        .next()
        .expect("parent offsets must contain at least one element");
    for (i, ((start, end), ending_idx)) in start_iter.zip(end_iter).zip(parent_offsets).enumerate()
    {
        let is_valid = validity.map_or(true, |v| v.get(i).unwrap());
        // Saturating arithmetic keeps extreme bounds (e.g. `i64::MAX` used as
        // "until the end") from overflowing; the saturated value is clamped or
        // yields an empty slice exactly like any other out-of-range bound.
        let slice_start = if start >= 0 {
            starting_idx.saturating_add(start)
        } else {
            ending_idx.saturating_add(start).max(starting_idx)
        };
        let slice_end = if end >= 0 {
            starting_idx.saturating_add(end).min(ending_idx)
        } else {
            ending_idx.saturating_add(end)
        };
        let slice_length = slice_end.saturating_sub(slice_start);
        // `slice_start >= starting_idx` holds by construction, so a positive
        // length is all that is needed for the slice to lie inside this list.
        if is_valid && slice_length > 0 {
            // Both values are non-negative here: the slice is non-empty and
            // contained in `[starting_idx, ending_idx)`.
            kept_slices.push((slice_start as usize, slice_length as usize));
            total_len += slice_length;
        }
        new_offsets.push(total_len);
        starting_idx = ending_idx;
    }

    let mut growable: Box<dyn Growable> = make_growable(
        &field.name,
        child_data_type,
        vec![flat_child],
        false, // We don't set validity because we can simply copy the parent's validity.
        total_len as usize,
    );
    for (child_start, len) in kept_slices {
        growable.extend(0, child_start, len);
    }
    Ok(ListArray::new(
        field,
        growable.build()?,
        arrow2::offset::OffsetsBuffer::try_from(new_offsets)?,
        validity.cloned(),
    )
    .into_series())
}

impl ListArray {
pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> {
let counts = match (mode, self.flat_child.validity()) {
Expand Down Expand Up @@ -181,17 +257,22 @@ impl ListArray {
}

pub fn get_children(&self, idx: &Int64Array, default: &Series) -> DaftResult<Series> {
match idx.len() {
1 => {
let idx_iter = repeat(idx.get(0).unwrap()).take(self.len());
self.get_children_helper(idx_iter, default)
}
len => {
assert_eq!(len, self.len());
let idx_iter = idx.as_arrow().iter().map(|x| *x.unwrap());
self.get_children_helper(idx_iter, default)
}
}
let idx_iter = create_iter(idx, self.len());
self.get_children_helper(idx_iter, default)
}

/// Slices every list in this array with the given `start` / `end` bounds.
///
/// Each bound array is expanded by `create_iter` to one value per list before
/// the work is handed to `get_slices_helper` together with this array's
/// offsets, field, child type, child values and validity.
pub fn get_slices(&self, start: &Int64Array, end: &Int64Array) -> DaftResult<Series> {
    let num_lists = self.len();
    let starts = create_iter(start, num_lists);
    let ends = create_iter(end, num_lists);
    let offsets = self.offsets().iter().copied();
    let field = self.field.clone();
    get_slices_helper(
        offsets,
        field,
        self.child_data_type(),
        &self.flat_child,
        self.validity(),
        starts,
        ends,
    )
}
}

Expand Down Expand Up @@ -320,17 +401,24 @@ impl FixedSizeListArray {
}

pub fn get_children(&self, idx: &Int64Array, default: &Series) -> DaftResult<Series> {
match idx.len() {
1 => {
let idx_iter = repeat(idx.get(0).unwrap()).take(self.len());
self.get_children_helper(idx_iter, default)
}
len => {
assert_eq!(len, self.len());
let idx_iter = idx.as_arrow().iter().map(|x| *x.unwrap());
self.get_children_helper(idx_iter, default)
}
}
let idx_iter = create_iter(idx, self.len());
self.get_children_helper(idx_iter, default)
}

/// Slices every fixed-size list in this array with the given `start` / `end`
/// bounds, producing a variable-length list series.
///
/// The result field is derived from this array's field via
/// `to_exploded_field` followed by `to_list_field`. Each bound array is
/// expanded by `create_iter` to one value per list.
pub fn get_slices(&self, start: &Int64Array, end: &Int64Array) -> DaftResult<Series> {
    let start_iter = create_iter(start, self.len());
    let end_iter = create_iter(end, self.len());
    let new_field = Arc::new(self.field.to_exploded_field()?.to_list_field()?);
    let list_size = self.fixed_element_len();
    // Offsets of a fixed-size list are the multiples of `list_size`. They are
    // computed by multiplication rather than `step_by(list_size)`, because
    // `step_by` panics on a step of zero (a fixed-size list with no elements).
    let parent_offsets = (0..=self.len()).map(move |i| (i * list_size) as i64);
    get_slices_helper(
        parent_offsets,
        new_field,
        self.child_data_type(),
        &self.flat_child,
        self.validity(),
        start_iter,
        end_iter,
    )
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,10 @@ impl PySeries {
Ok(self.series.list_get(&idx.series, &default.series)?.into())
}

/// Python-facing wrapper that slices each list of the wrapped series using
/// the `start` and `end` series.
pub fn list_slice(&self, start: &Self, end: &Self) -> PyResult<Self> {
    let sliced = self.series.list_slice(&start.series, &end.series)?;
    Ok(sliced.into())
}

pub fn map_get(&self, key: &Self) -> PyResult<Self> {
Ok(self.series.map_get(&key.series)?.into())
}
Expand Down
14 changes: 14 additions & 0 deletions src/daft-core/src/series/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ impl Series {
}
}

/// Slices each list of this series between `start` (inclusive) and `end`
/// (exclusive).
///
/// `start` and `end` are cast to `Int64` and must each contain either a
/// single value, applied to every list, or one value per list.
///
/// # Errors
///
/// * the cast of `start` or `end` to `Int64` fails,
/// * `TypeError` if this series is neither a `List` nor a `FixedSizeList`,
/// * `ValueError` if `start` or `end` has a length other than 1 or the length
///   of this series.
pub fn list_slice(&self, start: &Series, end: &Series) -> DaftResult<Series> {
    let start = start.cast(&DataType::Int64)?;
    let start_arr = start.i64()?;
    let end = end.cast(&DataType::Int64)?;
    let end_arr = end.i64()?;

    // Report a mismatched bound length as an error here; further down the call
    // chain the same condition is only caught by an assertion.
    let num_lists = self.len();
    let check_len = |name: &str, bound_len: usize| -> DaftResult<()> {
        if bound_len == 1 || bound_len == num_lists {
            Ok(())
        } else {
            Err(DaftError::ValueError(format!(
                "list slice expected `{name}` to have length 1 or {num_lists}, got {bound_len}"
            )))
        }
    };

    match self.data_type() {
        DataType::List(_) => {
            check_len("start", start_arr.len())?;
            check_len("end", end_arr.len())?;
            self.list()?.get_slices(start_arr, end_arr)
        }
        DataType::FixedSizeList(..) => {
            check_len("start", start_arr.len())?;
            check_len("end", end_arr.len())?;
            self.fixed_size_list()?.get_slices(start_arr, end_arr)
        }
        dt => Err(DaftError::TypeError(format!(
            "list slice not implemented for {dt}"
        ))),
    }
}

pub fn list_sum(&self) -> DaftResult<Series> {
match self.data_type() {
DataType::List(_) => self.list()?.sum(),
Expand Down
12 changes: 12 additions & 0 deletions src/daft-dsl/src/functions/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod join;
mod max;
mod mean;
mod min;
mod slice;
mod sum;

use count::CountEvaluator;
Expand All @@ -16,6 +17,7 @@ use max::MaxEvaluator;
use mean::MeanEvaluator;
use min::MinEvaluator;
use serde::{Deserialize, Serialize};
use slice::SliceEvaluator;
use sum::SumEvaluator;

use crate::{Expr, ExprRef};
Expand All @@ -32,6 +34,7 @@ pub enum ListExpr {
Mean,
Min,
Max,
Slice,
}

impl ListExpr {
Expand All @@ -47,6 +50,7 @@ impl ListExpr {
Mean => &MeanEvaluator {},
Min => &MinEvaluator {},
Max => &MaxEvaluator {},
Slice => &SliceEvaluator {},
}
}
}
Expand Down Expand Up @@ -114,3 +118,11 @@ pub fn max(input: ExprRef) -> ExprRef {
}
.into()
}

/// Creates the list-slice function expression over `input`, with `start` and
/// `end` passed as the second and third inputs.
pub fn slice(input: ExprRef, start: ExprRef, end: ExprRef) -> ExprRef {
    let func = super::FunctionExpr::List(ListExpr::Slice);
    let inputs = vec![input, start, end];
    Expr::Function { func, inputs }.into()
}
53 changes: 53 additions & 0 deletions src/daft-dsl/src/functions/list/slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use crate::ExprRef;
use daft_core::{datatypes::Field, schema::Schema, series::Series};

use super::super::FunctionEvaluator;
use crate::functions::FunctionExpr;
use common_error::{DaftError, DaftResult};

/// Evaluator backing the list `slice` function expression.
pub(super) struct SliceEvaluator {}

impl FunctionEvaluator for SliceEvaluator {
    fn fn_name(&self) -> &'static str {
        "slice"
    }

    /// Resolves the output field: the input's exploded field turned back into
    /// a list field. Requires exactly three inputs, with `start` and `end`
    /// both of an integer type.
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
        if inputs.len() != 3 {
            return Err(DaftError::SchemaMismatch(format!(
                "Expected 3 input args, got {}",
                inputs.len()
            )));
        }

        // Resolve all three fields first so resolution errors take precedence
        // over the integer checks below.
        let list_field = inputs[0].to_field(schema)?;
        let start_field = inputs[1].to_field(schema)?;
        let end_field = inputs[2].to_field(schema)?;

        if !start_field.dtype.is_integer() {
            return Err(DaftError::TypeError(format!(
                "Expected start index to be integer, received: {}",
                start_field.dtype
            )));
        }
        if !end_field.dtype.is_integer() {
            return Err(DaftError::TypeError(format!(
                "Expected end index to be integer, received: {}",
                end_field.dtype
            )));
        }

        Ok(list_field.to_exploded_field()?.to_list_field()?)
    }

    /// Slices the first input series using the second and third inputs as the
    /// `start` and `end` bounds.
    fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
        if let [input, start, end] = inputs {
            input.list_slice(start, end)
        } else {
            Err(DaftError::ValueError(format!(
                "Expected 3 input args, got {}",
                inputs.len()
            )))
        }
    }
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,11 @@ impl PyExpr {
Ok(max(self.into()).into())
}

/// Python-facing wrapper that builds a list-slice expression from this
/// expression and the `start` / `end` expressions.
pub fn list_slice(&self, start: &Self, end: &Self) -> PyResult<Self> {
    use crate::functions::list::slice;
    let sliced = slice(self.into(), start.into(), end.into());
    Ok(sliced.into())
}

pub fn struct_get(&self, name: &str) -> PyResult<Self> {
use crate::functions::struct_::get;
Ok(get(self.into(), name).into())
Expand Down
Loading

0 comments on commit 7b81de6

Please sign in to comment.