diff --git a/daft/daft.pyi b/daft/daft.pyi
index aba9693083..18feaea513 100644
--- a/daft/daft.pyi
+++ b/daft/daft.pyi
@@ -1146,6 +1146,7 @@ class PyExpr:
     def list_mean(self) -> PyExpr: ...
     def list_min(self) -> PyExpr: ...
     def list_max(self) -> PyExpr: ...
+    def list_slice(self, start: PyExpr, end: PyExpr) -> PyExpr: ...
     def struct_get(self, name: str) -> PyExpr: ...
     def map_get(self, key: PyExpr) -> PyExpr: ...
     def url_download(
@@ -1298,6 +1299,7 @@ class PySeries:
     def partitioning_iceberg_truncate(self, w: int) -> PySeries: ...
     def list_count(self, mode: CountMode) -> PySeries: ...
     def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
+    def list_slice(self, start: PySeries, end: PySeries) -> PySeries: ...
     def map_get(self, key: PySeries) -> PySeries: ...
     def image_decode(self, raise_error_on_failure: bool) -> PySeries: ...
     def image_encode(self, image_format: ImageFormat) -> PySeries: ...
diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py
index 100fec0881..239e33dbc6 100644
--- a/daft/expressions/expressions.py
+++ b/daft/expressions/expressions.py
@@ -1807,6 +1807,20 @@ def get(self, idx: int | Expression, default: object = None) -> Expression:
         default_expr = lit(default)
         return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr))
 
+    def slice(self, start: int | Expression, end: int | Expression) -> Expression:
+        """Gets a subset of each list.
+
+        Args:
+            start: index or column of indices. The slice includes elements starting from this index. If `start` is negative, it is interpreted as an offset from the end of the list.
+            end: index or column of indices. The slice excludes elements from this index onwards. If `end` is negative, it is interpreted as an offset from the end of the list.
+
+        Returns:
+            Expression: an expression of list type with the same element type as the input lists
+        """
+        start_expr = Expression._to_expression(start)
+        end_expr = Expression._to_expression(end)
+        return Expression._from_pyexpr(self._expr.list_slice(start_expr._expr, end_expr._expr))
+
     def sum(self) -> Expression:
         """Sums each list. Empty lists and lists with all nulls yield null.
diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs
index 13a8e80de6..f7849d8dd5 100644
--- a/src/daft-core/src/array/ops/list.rs
+++ b/src/daft-core/src/array/ops/list.rs
@@ -1,6 +1,7 @@
 use std::iter::repeat;
+use std::sync::Arc;
 
-use crate::datatypes::{Int64Array, Utf8Array};
+use crate::datatypes::{Field, Int64Array, Utf8Array};
 use crate::{
     array::{
         growable::{make_growable, Growable},
@@ -10,7 +11,7 @@
 };
 use crate::{CountMode, DataType};
 
-use crate::series::Series;
+use crate::series::{IntoSeries, Series};
 
 use common_error::DaftResult;
 
@@ -42,6 +43,81 @@ fn join_arrow_list_of_utf8s(
     })
 }
 
+// Given an i64 array that may have either 1 or `len` elements, create an iterator with `len`
+// elements. If the array originally had 1 element, we repeat that element `len` times; otherwise
+// we simply iterate over the original array.
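+// For example, with `len == 3`, a 1-element array `[7]` broadcasts to the sequence `7, 7, 7`,
+// while a 3-element array `[1, 2, 3]` is passed through unchanged.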
+fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box<dyn Iterator<Item = i64> + '_> {
+    match arr.len() {
+        1 => Box::new(repeat(arr.get(0).unwrap()).take(len)),
+        arr_len => {
+            assert_eq!(arr_len, len);
+            Box::new(arr.as_arrow().iter().map(|x| *x.unwrap()))
+        }
+    }
+}
+
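+// Helper that computes the slice of every list in a list array. For each row, the requested
+// (start, end) pair (negative values count from the end of that row's list) is translated into a
+// range of child-array indices using the parent offsets; valid ranges are copied into a growable
+// child array and the slice lengths become the new offsets, while null rows and empty or
+// out-of-range slices contribute an empty range. For example, with parent offsets [0, 3, 5],
+// starts [1, 0] and ends [-1, 5], the rows keep child ranges [1, 2) and [3, 5), giving new
+// offsets [0, 1, 3].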
+pub fn get_slices_helper(
+    mut parent_offsets: impl Iterator<Item = i64>,
+    field: Arc<Field>,
+    child_data_type: &DataType,
+    flat_child: &Series,
+    validity: Option<&arrow2::bitmap::Bitmap>,
+    start_iter: impl Iterator<Item = i64>,
+    end_iter: impl Iterator<Item = i64>,
+) -> DaftResult<Series> {
+    let mut slicing_indexes = Vec::with_capacity(flat_child.len());
+    let mut new_offsets = Vec::with_capacity(flat_child.len() + 1);
+    new_offsets.push(0);
+    let mut starting_idx = parent_offsets.next().unwrap();
+    for (i, ((start, end), ending_idx)) in start_iter.zip(end_iter).zip(parent_offsets).enumerate()
+    {
+        let is_valid = match validity {
+            None => true,
+            Some(v) => v.get(i).unwrap(),
+        };
+        let slice_start = if start >= 0 {
+            starting_idx + start
+        } else {
+            (ending_idx + start).max(starting_idx)
+        };
+        let slice_end = if end >= 0 {
+            (starting_idx + end).min(ending_idx)
+        } else {
+            ending_idx + end
+        };
+        let slice_length = slice_end - slice_start;
+        if is_valid && slice_start >= starting_idx && slice_length > 0 {
+            slicing_indexes.push(slice_start);
+            new_offsets.push(new_offsets.last().unwrap() + slice_length);
+        } else {
+            slicing_indexes.push(-1);
+            new_offsets.push(*new_offsets.last().unwrap());
+        }
+        starting_idx = ending_idx;
+    }
+    let total_capacity = *new_offsets.last().unwrap();
+    let mut growable: Box<dyn Growable> = make_growable(
+        &field.name,
+        child_data_type,
+        vec![flat_child],
+        false, // We don't set validity because we can simply copy the parent's validity.
+        total_capacity as usize,
+    );
+    for (i, start) in slicing_indexes.iter().enumerate() {
+        if *start >= 0 {
+            let slice_len = new_offsets.get(i + 1).unwrap() - new_offsets.get(i).unwrap();
+            growable.extend(0, *start as usize, slice_len as usize);
+        }
+    }
+    Ok(ListArray::new(
+        field,
+        growable.build()?,
+        arrow2::offset::OffsetsBuffer::try_from(new_offsets)?,
+        validity.cloned(),
+    )
+    .into_series())
+}
+
 impl ListArray {
     pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> {
         let counts = match (mode, self.flat_child.validity()) {
@@ -181,17 +257,22 @@ impl ListArray {
     }
 
     pub fn get_children(&self, idx: &Int64Array, default: &Series) -> DaftResult<Series> {
-        match idx.len() {
-            1 => {
-                let idx_iter = repeat(idx.get(0).unwrap()).take(self.len());
-                self.get_children_helper(idx_iter, default)
-            }
-            len => {
-                assert_eq!(len, self.len());
-                let idx_iter = idx.as_arrow().iter().map(|x| *x.unwrap());
-                self.get_children_helper(idx_iter, default)
-            }
-        }
+        let idx_iter = create_iter(idx, self.len());
+        self.get_children_helper(idx_iter, default)
+    }
+
+    pub fn get_slices(&self, start: &Int64Array, end: &Int64Array) -> DaftResult<Series> {
+        let start_iter = create_iter(start, self.len());
+        let end_iter = create_iter(end, self.len());
+        get_slices_helper(
+            self.offsets().iter().copied(),
+            self.field.clone(),
+            self.child_data_type(),
+            &self.flat_child,
+            self.validity(),
+            start_iter,
+            end_iter,
+        )
     }
 }
 
@@ -320,17 +401,24 @@ impl FixedSizeListArray {
     }
 
     pub fn get_children(&self, idx: &Int64Array, default: &Series) -> DaftResult<Series> {
-        match idx.len() {
-            1 => {
-                let idx_iter = repeat(idx.get(0).unwrap()).take(self.len());
-                self.get_children_helper(idx_iter, default)
-            }
-            len => {
-                assert_eq!(len, self.len());
-                let idx_iter = idx.as_arrow().iter().map(|x| *x.unwrap());
-                self.get_children_helper(idx_iter, default)
-            }
-        }
+        let idx_iter = create_iter(idx, self.len());
+        self.get_children_helper(idx_iter, default)
+    }
+
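+    // Fixed-size lists store no offsets, so we synthesize them as multiples of the fixed list
+    // size and reuse get_slices_helper; the result is an ordinary (variable-size) list array.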
+    pub fn get_slices(&self, start: &Int64Array, end: &Int64Array) -> DaftResult<Series> {
+        let start_iter = create_iter(start, self.len());
+        let end_iter = create_iter(end, self.len());
+        let new_field = Arc::new(self.field.to_exploded_field()?.to_list_field()?);
+        let list_size = self.fixed_element_len();
+        get_slices_helper(
+            (0..=((self.len() * list_size) as i64)).step_by(list_size),
+            new_field,
+            self.child_data_type(),
+            &self.flat_child,
+            self.validity(),
+            start_iter,
+            end_iter,
+        )
     }
 }
 
diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs
index 56a9ae010e..ec171c6425 100644
--- a/src/daft-core/src/python/series.rs
+++ b/src/daft-core/src/python/series.rs
@@ -641,6 +641,10 @@ impl PySeries {
         Ok(self.series.list_get(&idx.series, &default.series)?.into())
     }
 
+    pub fn list_slice(&self, start: &Self, end: &Self) -> PyResult<Self> {
+        Ok(self.series.list_slice(&start.series, &end.series)?.into())
+    }
+
     pub fn map_get(&self, key: &Self) -> PyResult<Self> {
         Ok(self.series.map_get(&key.series)?.into())
     }
diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs
index 961e160d6c..9ba23e4d68 100644
--- a/src/daft-core/src/series/ops/list.rs
+++ b/src/daft-core/src/series/ops/list.rs
@@ -69,6 +69,20 @@ impl Series {
         }
     }
 
+    pub fn list_slice(&self, start: &Series, end: &Series) -> DaftResult<Series> {
+        let start = start.cast(&DataType::Int64)?;
+        let start_arr = start.i64().unwrap();
+        let end = end.cast(&DataType::Int64)?;
+        let end_arr = end.i64().unwrap();
+        match self.data_type() {
+            DataType::List(_) => self.list()?.get_slices(start_arr, end_arr),
+            DataType::FixedSizeList(..) => self.fixed_size_list()?.get_slices(start_arr, end_arr),
+            dt => Err(DaftError::TypeError(format!(
+                "list slice not implemented for {dt}"
+            ))),
+        }
+    }
+
     pub fn list_sum(&self) -> DaftResult<Series> {
         match self.data_type() {
             DataType::List(_) => self.list()?.sum(),
diff --git a/src/daft-dsl/src/functions/list/mod.rs b/src/daft-dsl/src/functions/list/mod.rs
index 764f5cdbbf..08994d66f4 100644
--- a/src/daft-dsl/src/functions/list/mod.rs
+++ b/src/daft-dsl/src/functions/list/mod.rs
@@ -5,6 +5,7 @@ mod join;
 mod max;
 mod mean;
 mod min;
+mod slice;
 mod sum;
 
 use count::CountEvaluator;
@@ -16,6 +17,7 @@ use max::MaxEvaluator;
 use mean::MeanEvaluator;
 use min::MinEvaluator;
 use serde::{Deserialize, Serialize};
+use slice::SliceEvaluator;
 use sum::SumEvaluator;
 
 use crate::{Expr, ExprRef};
@@ -32,6 +34,7 @@ pub enum ListExpr {
     Mean,
     Min,
     Max,
+    Slice,
 }
 
 impl ListExpr {
@@ -47,6 +50,7 @@
             Mean => &MeanEvaluator {},
             Min => &MinEvaluator {},
             Max => &MaxEvaluator {},
+            Slice => &SliceEvaluator {},
         }
     }
 }
@@ -114,3 +118,11 @@ pub fn max(input: ExprRef) -> ExprRef {
     }
     .into()
 }
+
+pub fn slice(input: ExprRef, start: ExprRef, end: ExprRef) -> ExprRef {
+    Expr::Function {
+        func: super::FunctionExpr::List(ListExpr::Slice),
+        inputs: vec![input, start, end],
+    }
+    .into()
+}
diff --git a/src/daft-dsl/src/functions/list/slice.rs b/src/daft-dsl/src/functions/list/slice.rs
new file mode 100644
index 0000000000..643c7a63ce
--- /dev/null
+++ b/src/daft-dsl/src/functions/list/slice.rs
@@ -0,0 +1,53 @@
+use crate::ExprRef;
+use daft_core::{datatypes::Field, schema::Schema, series::Series};
+
+use super::super::FunctionEvaluator;
+use crate::functions::FunctionExpr;
+use common_error::{DaftError, DaftResult};
+
+pub(super) struct SliceEvaluator {}
+
+impl FunctionEvaluator for SliceEvaluator {
+    fn fn_name(&self) -> &'static str {
+        "slice"
+    }
+
+    fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
+        match inputs {
+            [input, start, end] => {
+                let input_field = input.to_field(schema)?;
+                let start_field = start.to_field(schema)?;
+                let end_field = end.to_field(schema)?;
+
+                if !start_field.dtype.is_integer() {
+                    return Err(DaftError::TypeError(format!(
+                        "Expected start index to be integer, received: {}",
+                        start_field.dtype
+                    )));
+                }
+
+                if !end_field.dtype.is_integer() {
+                    return Err(DaftError::TypeError(format!(
+                        "Expected end index to be integer, received: {}",
+                        end_field.dtype
+                    )));
+                }
+                Ok(input_field.to_exploded_field()?.to_list_field()?)
+            }
+            _ => Err(DaftError::SchemaMismatch(format!(
+                "Expected 3 input args, got {}",
+                inputs.len()
+            ))),
+        }
+    }
+
+    fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
+        match inputs {
+            [input, start, end] => input.list_slice(start, end),
+            _ => Err(DaftError::ValueError(format!(
+                "Expected 3 input args, got {}",
+                inputs.len()
+            ))),
+        }
+    }
+}
diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs
index 73252fa4e1..aff6c64a31 100644
--- a/src/daft-dsl/src/python.rs
+++ b/src/daft-dsl/src/python.rs
@@ -794,6 +794,11 @@ impl PyExpr {
         Ok(max(self.into()).into())
     }
 
+    pub fn list_slice(&self, start: &Self, end: &Self) -> PyResult<Self> {
+        use crate::functions::list::slice;
+        Ok(slice(self.into(), start.into(), end.into()).into())
+    }
+
     pub fn struct_get(&self, name: &str) -> PyResult<Self> {
         use crate::functions::struct_::get;
         Ok(get(self.into(), name).into())
diff --git a/tests/table/list/test_list_slice.py b/tests/table/list/test_list_slice.py
new file mode 100644
index 0000000000..0946deff84
--- /dev/null
+++ b/tests/table/list/test_list_slice.py
@@ -0,0 +1,211 @@
+import pyarrow as pa
+import pytest
+
+from daft.datatype import DataType
+from daft.expressions import col
+from daft.table import MicroPartition
+
+
+def test_list_slice_empty_series():
+    table = MicroPartition.from_pydict(
+        {
+            "col": pa.array([], type=pa.list_(pa.int64())),
+            "start": pa.array([], type=pa.int64()),
+            "end": pa.array([], type=pa.int64()),
+        }
+    )
+
+    result = table.eval_expression_list(
+        [
+            col("col").list.slice(0, 1).alias("col"),
+            col("col").list.slice(col("start"), 1).alias("col-start"),
+            col("col").list.slice(col("start"), col("end")).alias("col-start-end"),
+            col("col").list.slice(0, col("end")).alias("col-end"),
+        ]
+    )
+
+    assert result.to_pydict() == {
+        "col": [],
+        "col-start": [],
+        "col-start-end": [],
+        "col-end": [],
+    }
+
+
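+# The cases below exercise the slicing semantics: negative indices count from the end of each
+# list, empty or out-of-range slices yield empty lists, and null rows stay null.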
+ "col2": [ + [[1]], + [[3, 3], [4], [5, 5]], + [], + [[], []], + None, + ], + "start": [-1, 1, 0, 2, -2], + "end": [1, 2, 0, 4, 3], + "edge_start": [-1, -2, -5, 0, -2], + "edge_end": [-1, -2, -1, -2, -1], + } + ) + + result = table.eval_expression_list( + [ + col("col1").list.slice(0, 1).alias("col1"), + col("col1").list.slice(col("start"), 1).alias("col1-start"), + col("col1").list.slice(col("start"), col("end")).alias("col1-start-end"), + col("col1").list.slice(1, col("end")).alias("col1-end"), + col("col1").list.slice(20, 25).alias("col1-invalid-start"), + col("col2").list.slice(0, 1).alias("col2"), + col("col2").list.slice(col("start"), 1).alias("col2-start"), + col("col2").list.slice(col("start"), col("end")).alias("col2-start-end"), + col("col2").list.slice(0, col("end")).alias("col2-end"), + col("col2").list.slice(20, 25).alias("col2-invalid-start"), + # Test edge cases. + col("col1").list.slice(-10, -20).alias("col1-edge1"), + col("col1").list.slice(-20, -10).alias("col1-edge2"), + col("col1").list.slice(-20, 10).alias("col1-edge3"), + col("col1").list.slice(-20, -1).alias("col1-edge4"), + col("col1").list.slice(col("edge_start"), col("edge_end")).alias("col1-edge5"), + col("col1").list.slice(10, 1).alias("col1-edge6"), + col("col1").list.slice(1, -1).alias("col1-edge7"), + ] + ) + + assert result.to_pydict() == { + "col1": [["a"], ["ab"], [None], None, ["a"]], + "col1-start": [["a"], [], [None], None, ["a"]], + "col1-start-end": [["a"], ["a"], [], None, ["a", ""]], + "col1-end": [[], ["a"], [], None, [""]], + "col1-invalid-start": [[], [], [], None, []], + "col2": [[[1]], [[3, 3]], [], [[]], None], + "col2-start": [[[1]], [], [], [], None], + "col2-start-end": [[[1]], [[4]], [], [], None], + "col2-end": [[[1]], [[3, 3], [4]], [], [[], []], None], + "col2-invalid-start": [[], [], [], [], None], + "col1-edge1": [[], [], [], None, []], + "col1-edge2": [[], [], [], None, []], + "col1-edge3": [["a"], ["ab", "a"], [None, "a", "", "b", "c"], None, ["a", ""]], + "col1-edge4": [[], ["ab"], [None, "a", "", "b"], None, ["a"]], + "col1-edge5": [[], [], [None, "a", "", "b"], None, ["a"]], + "col1-edge6": [[], [], [], None, []], + "col1-edge7": [[], [], ["a", "", "b"], None, []], + } + + +def test_fixed_size_list_slice(): + table = MicroPartition.from_pydict( + { + # Test list of an atomic type. + "col1": [["a", "b"], ["aa", "bb"], None, [None, "bbbb"], ["aaaaa", None]], + # Test lists of a nested type. 
+ "col2": [ + [[1], [2]], + [[11, 111], [22, 222]], + None, + [None, [3333]], + [[], []], + ], + "start": [-1, 1, 0, 2, -2], + "end": [1, 1, 0, 2, 3], + "edge_start": [-1, -2, -5, 0, -2], + "edge_end": [-1, -2, -1, -2, -1], + } + ) + + dtype1 = DataType.fixed_size_list(DataType.string(), 2) + dtype2 = DataType.fixed_size_list(DataType.list(DataType.int32()), 2) + + table = table.eval_expression_list( + [ + col("col1").cast(dtype1), + col("col2").cast(dtype2), + col("start"), + col("end"), + col("edge_start"), + col("edge_end"), + ] + ) + + result = table.eval_expression_list( + [ + col("col1").list.slice(0, 1).alias("col1"), + col("col1").list.slice(col("start"), 1).alias("col1-start"), + col("col1").list.slice(col("start"), col("end")).alias("col1-start-end"), + col("col1").list.slice(1, col("end")).alias("col1-end"), + col("col1").list.slice(20, 25).alias("col1-invalid-start"), + col("col2").list.slice(0, 1).alias("col2"), + col("col2").list.slice(col("start"), 2).alias("col2-start"), + col("col2").list.slice(col("start"), col("end")).alias("col2-start-end"), + col("col2").list.slice(0, col("end")).alias("col2-end"), + col("col2").list.slice(20, 25).alias("col2-invalid-start"), + # Test edge cases. + col("col1").list.slice(-10, -20).alias("col1-edge1"), + col("col1").list.slice(-20, -10).alias("col1-edge2"), + col("col1").list.slice(-20, 10).alias("col1-edge3"), + col("col1").list.slice(-20, -1).alias("col1-edge4"), + col("col1").list.slice(col("edge_start"), col("edge_end")).alias("col1-edge5"), + col("col1").list.slice(10, 1).alias("col1-edge6"), + col("col1").list.slice(0, -1).alias("col1-edge7"), + ] + ) + + assert result.to_pydict() == { + "col1": [["a"], ["aa"], None, [None], ["aaaaa"]], + "col1-start": [[], [], None, [], ["aaaaa"]], + "col1-start-end": [[], [], None, [], ["aaaaa", None]], + "col1-end": [[], [], None, ["bbbb"], [None]], + "col1-invalid-start": [[], [], None, [], []], + "col2": [[[1]], [[11, 111]], None, [None], [[]]], + "col2-start": [[[2]], [[22, 222]], None, [], [[], []]], + "col2-start-end": [[], [], None, [], [[], []]], + "col2-end": [[[1]], [[11, 111]], None, [None, [3333]], [[], []]], + "col2-invalid-start": [[], [], None, [], []], + "col1-edge1": [[], [], None, [], []], + "col1-edge2": [[], [], None, [], []], + "col1-edge3": [["a", "b"], ["aa", "bb"], None, [None, "bbbb"], ["aaaaa", None]], + "col1-edge4": [["a"], ["aa"], None, [None], ["aaaaa"]], + "col1-edge5": [[], [], None, [], ["aaaaa"]], + "col1-edge6": [[], [], None, [], []], + "col1-edge7": [["a"], ["aa"], None, [None], ["aaaaa"]], + } + + +def test_list_slice_invalid_parameters(): + table = MicroPartition.from_pydict( + { + "col": [["a", "b", "c"], ["aa", "bb", "cc"], None, [None, "bbbb"], ["aaaaa", None]], + "start": [0, -1, 1, 3, -4], + "end": [1, 2, 3, -1, 0], + } + ) + with pytest.raises(ValueError, match="Expected start index to be integer"): + table.eval_expression_list([col("col").list.slice(1.0, 0)]) + with pytest.raises(ValueError, match="Expected end index to be integer"): + table.eval_expression_list([col("col").list.slice(0, 1.0)]) + with pytest.raises(TypeError, match="missing 1 required positional argument: 'end'"): + table.eval_expression_list([col("col").list.slice(0)]) + with pytest.raises(TypeError, match="missing 2 required positional arguments"): + table.eval_expression_list([col("col").list.slice()]) + with pytest.raises(TypeError, match="takes 3 positional arguments but 4 were given"): + table.eval_expression_list([col("col").list.slice(0, 0, 0)]) + + +def 
+def test_list_slice_non_list_type():
+    table = MicroPartition.from_pydict(
+        {
+            "structcol": [{"a": 1}, {"b": 1}, {"c": 1}],
+            "stringcol": ["a", "b", "c"],
+            "intcol": [1, 2, 3],
+        },
+    )
+
+    with pytest.raises(ValueError):
+        table.eval_expression_list([col("structcol").list.slice(0, 2)])
+    with pytest.raises(ValueError):
+        table.eval_expression_list([col("stringcol").list.slice(0, 2)])
+    with pytest.raises(ValueError):
+        table.eval_expression_list([col("intcol").list.slice(0, 2)])
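+
+
+# A minimal usage sketch of the new expression from the DataFrame API (illustrative only; the
+# column name and expected output below are made up for this example):
+#
+#     import daft
+#     from daft.expressions import col
+#
+#     df = daft.from_pydict({"scores": [[1, 2, 3], [4, 5], []]})
+#     df = df.select(col("scores").list.slice(0, 2))
+#     # "scores" -> [[1, 2], [4, 5], []]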