From b983578c6681658137cfb323bd1f3eff0323c001 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Tue, 24 Sep 2024 18:51:47 -0700 Subject: [PATCH] tests pass yay --- .vscode/launch.json | 36 +++ .vscode/settings.json | 7 +- daft/daft/__init__.pyi | 1 + daft/expressions/expressions.py | 11 + src/arrow2/src/array/map/mod.rs | 26 +- src/arrow2/src/offset.rs | 33 ++- .../src/array/fixed_size_list_array.rs | 1 + src/daft-core/src/array/list_array.rs | 2 + .../src/array/ops/arrow2/comparison.rs | 2 +- src/daft-core/src/array/ops/cast.rs | 2 +- src/daft-core/src/array/ops/from_arrow.rs | 4 +- src/daft-core/src/array/ops/list.rs | 278 ++++++++++++++++-- src/daft-core/src/array/ops/map.rs | 21 +- src/daft-core/src/array/struct_array.rs | 2 + src/daft-core/src/datatypes/matching.rs | 2 +- src/daft-core/src/lib.rs | 1 + src/daft-core/src/series/from.rs | 3 +- src/daft-core/src/series/ops/list.rs | 18 +- src/daft-core/src/series/ops/map.rs | 2 +- src/daft-core/src/series/serdes.rs | 2 +- src/daft-dsl/src/functions/map/get.rs | 16 +- src/daft-dsl/src/functions/utf8/split.rs | 1 + src/daft-functions/src/list/mod.rs | 2 + src/daft-functions/src/list/value_counts.rs | 70 +++++ src/daft-schema/src/dtype.rs | 119 +++++--- src/daft-schema/src/python/datatype.rs | 8 +- src/daft-stats/src/column_stats/mod.rs | 2 +- src/daft-table/src/repr_html.rs | 2 +- tests/expressions/test_expressions.py | 58 ++++ 29 files changed, 600 insertions(+), 132 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 src/daft-functions/src/list/value_counts.rs diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..84a59357e2 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,36 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Debug Daft Python/Rust", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/tests/expressions/test_expressions.py", + "args": [], + "console": "integratedTerminal", + "justMyCode": false, + "env": { + "PYTHONPATH": "${workspaceFolder}" + }, + "serverReadyAction": { + "pattern": "pID = ([0-9]+)", + "action": "startDebugging", + "name": "Daft Rust LLDB" + } + }, + { + "name": "Daft Rust LLDB", + "pid": "0", + "type": "lldb", + "request": "attach", + "program": "${workspaceFolder}/.venv/bin/python", + "stopOnEntry": false, + "sourceLanguages": [ + "rust" + ], + "presentation": { + "hidden": true + } + } + ] + } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 2f8da01d92..1038ded445 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,5 +3,10 @@ "CARGO_TARGET_DIR": "target/analyzer" }, "rust-analyzer.check.features": "all", - "rust-analyzer.cargo.features": "all" + "rust-analyzer.cargo.features": "all", + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 1a5dc99f0f..b3486f5d65 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1267,6 +1267,7 @@ def dt_truncate(expr: PyExpr, interval: str, relative_to: PyExpr) -> PyExpr: ... # --- def explode(expr: PyExpr) -> PyExpr: ... def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ... +def list_value_counts(expr: PyExpr) -> PyExpr: ... def list_join(expr: PyExpr, delimiter: PyExpr) -> PyExpr: ... def list_count(expr: PyExpr, mode: CountMode) -> PyExpr: ... def list_get(expr: PyExpr, idx: PyExpr, default: PyExpr) -> PyExpr: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 1ae7e90dac..a97ac924e7 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -2922,6 +2922,17 @@ def join(self, delimiter: str | Expression) -> Expression: delimiter_expr = Expression._to_expression(delimiter) return Expression._from_pyexpr(native.list_join(self._expr, delimiter_expr._expr)) + def value_counts(self) -> Expression: + """Counts the occurrences of each unique value in the list. + + Returns: + Expression: A list of structs, where each struct contains a 'value' field + representing a unique element from the original list, and a 'count' field + representing the number of times that value appears in the list. + """ + return Expression._from_pyexpr(native.list_value_counts(self._expr)) + + def count(self, mode: CountMode = CountMode.Valid) -> Expression: """Counts the number of elements in each list diff --git a/src/arrow2/src/array/map/mod.rs b/src/arrow2/src/array/map/mod.rs index d0dcb46efb..a870de834f 100644 --- a/src/arrow2/src/array/map/mod.rs +++ b/src/arrow2/src/array/map/mod.rs @@ -41,20 +41,24 @@ impl MapArray { try_check_offsets_bounds(&offsets, field.len())?; let inner_field = Self::try_get_field(&data_type)?; - if let DataType::Struct(inner) = inner_field.data_type() { - if inner.len() != 2 { - return Err(Error::InvalidArgumentError( - "MapArray's inner `Struct` must have 2 fields (keys and maps)".to_string(), - )); - } - } else { + + let inner_data_type = inner_field.data_type(); + let DataType::Struct(inner) = inner_data_type else { return Err(Error::InvalidArgumentError( - "MapArray expects `DataType::Struct` as its inner logical type".to_string(), + format!("MapArray expects `DataType::Struct` as its inner logical type, but found {inner_data_type:?}"), )); + }; + + let inner_len = inner.len(); + if inner_len != 2 { + let msg = format!("MapArray's inner `Struct` must have 2 fields (keys and maps), but found {} fields", inner_len); + return Err(Error::InvalidArgumentError(msg)); } - if field.data_type() != inner_field.data_type() { + + let field_data_type = field.data_type(); + if field_data_type != inner_field.data_type() { return Err(Error::InvalidArgumentError( - "MapArray expects `field.data_type` to match its inner DataType".to_string(), + format!("MapArray expects `field.data_type` to match its inner DataType, but found \n{field_data_type:?}\nvs\n\n\n{inner_field:?}"), )); } @@ -66,7 +70,7 @@ impl MapArray { "validity mask length must match the number of values", )); } - + Ok(Self { data_type, field, diff --git a/src/arrow2/src/offset.rs b/src/arrow2/src/offset.rs index 80b45d6680..1ab6f15105 100644 --- a/src/arrow2/src/offset.rs +++ b/src/arrow2/src/offset.rs @@ -71,7 +71,7 @@ impl Offsets { /// Creates a new [`Offsets`] from an iterator of lengths #[inline] - pub fn try_from_iter>(iter: I) -> Result { + pub fn try_from_iter>(iter: I) -> Result { let iterator = iter.into_iter(); let (lower, _) = iterator.size_hint(); let mut offsets = Self::with_capacity(lower); @@ -144,10 +144,7 @@ impl Offsets { /// Returns the last offset of this container. #[inline] pub fn last(&self) -> &O { - match self.0.last() { - Some(element) => element, - None => unsafe { unreachable_unchecked() }, - } + self.0.last().unwrap_or_else(|| unsafe { unreachable_unchecked() }) } /// Returns a range (start, end) corresponding to the position `index` @@ -215,7 +212,7 @@ impl Offsets { /// # Errors /// This function errors iff this operation overflows for the maximum value of `O`. #[inline] - pub fn try_from_lengths>(lengths: I) -> Result { + pub fn try_from_lengths>(lengths: I) -> Result { let mut self_ = Self::with_capacity(lengths.size_hint().0); self_.try_extend_from_lengths(lengths)?; Ok(self_) @@ -225,7 +222,7 @@ impl Offsets { /// # Errors /// This function errors iff this operation overflows for the maximum value of `O`. #[inline] - pub fn try_extend_from_lengths>( + pub fn try_extend_from_lengths>( &mut self, lengths: I, ) -> Result<(), Error> { @@ -401,22 +398,26 @@ impl OffsetsBuffer { *self.last() - *self.first() } + pub fn ranges(&self) -> impl Iterator> + '_ { + self.0.windows(2).map(|w| { + let from = w[0]; + let to = w[1]; + debug_assert!(from <= to, "offsets must be monotonically increasing"); + from..to + }) + } + + /// Returns the first offset. #[inline] pub fn first(&self) -> &O { - match self.0.first() { - Some(element) => element, - None => unsafe { unreachable_unchecked() }, - } + self.0.first().unwrap_or_else(|| unsafe { unreachable_unchecked() }) } /// Returns the last offset. #[inline] pub fn last(&self) -> &O { - match self.0.last() { - Some(element) => element, - None => unsafe { unreachable_unchecked() }, - } + self.0.last().unwrap_or_else(|| unsafe { unreachable_unchecked() }) } /// Returns a range (start, end) corresponding to the position `index` @@ -460,7 +461,7 @@ impl OffsetsBuffer { /// Returns an iterator with the lengths of the offsets #[inline] - pub fn lengths(&self) -> impl Iterator + '_ { + pub fn lengths(&self) -> impl Iterator + '_ { self.0.windows(2).map(|w| (w[1] - w[0]).to_usize()) } diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index d265f42929..f577d990cf 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -11,6 +11,7 @@ use crate::{ #[derive(Clone, Debug)] pub struct FixedSizeListArray { pub field: Arc, + /// contains all the elements of the nested lists flattened into a single contiguous array. pub flat_child: Series, validity: Option, } diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index 964503b271..f81c398f3f 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -12,6 +12,8 @@ use crate::{ pub struct ListArray { pub field: Arc, pub flat_child: Series, + + /// Where does each row start offsets: arrow2::offset::OffsetsBuffer, validity: Option, } diff --git a/src/daft-core/src/array/ops/arrow2/comparison.rs b/src/daft-core/src/array/ops/arrow2/comparison.rs index 37f7b2a37b..700ab4f8d0 100644 --- a/src/daft-core/src/array/ops/arrow2/comparison.rs +++ b/src/daft-core/src/array/ops/arrow2/comparison.rs @@ -49,7 +49,7 @@ fn build_is_equal_with_nan( } } -fn build_is_equal( +pub fn build_is_equal( left: &dyn Array, right: &dyn Array, nulls_equal: bool, diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 3f14b8f2f5..ec5f27a72e 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -2091,7 +2091,7 @@ impl ListArray { } } } - DataType::Map(..) => Ok(MapArray::new( + DataType::Map { .. } => Ok(MapArray::new( Field::new(self.name(), dtype.clone()), self.clone(), ) diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs index a635fe6e21..3c0763e523 100644 --- a/src/daft-core/src/array/ops/from_arrow.rs +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -35,7 +35,7 @@ where // TODO: Consolidate Map to use the same .to_type conversion as other logical types // Currently, .to_type does not work for Map in Arrow2 because it requires physical types to be equivalent, // but the physical type of MapArray in Arrow2 is a MapArray, not a ListArray - DataType::Map(..) => arrow_arr, + DataType::Map { .. } => arrow_arr, _ => arrow_arr.to_type(data_array_field.dtype.to_arrow()?), }; let physical = ::ArrayType::from_arrow( @@ -98,7 +98,7 @@ impl FromArrow for ListArray { arrow_arr.validity().cloned(), )) } - (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map(..)) => { + (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map { .. }) => { let map_arr = arrow_arr .as_any() .downcast_ref::() diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 4dd8cee2a8..c94cc7fb74 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -1,5 +1,7 @@ use std::{iter::repeat, sync::Arc}; - +use indexmap::IndexMap; +use indexmap::map::raw_entry_v1::RawEntryMut; +use indexmap::map::RawEntryApiV1; use arrow2::offset::OffsetsBuffer; use common_error::DaftResult; @@ -13,6 +15,11 @@ use crate::{ datatypes::{BooleanArray, DataType, Field, Int64Array, UInt64Array, Utf8Array}, series::{IntoSeries, Series}, }; +use crate::array::ops::arrow2::comparison::build_is_equal; +use crate::array::StructArray; +use crate::kernels::search_sorted::build_is_valid; +use crate::prelude::MapArray; +use crate::utils::identity_hash_set::IdentityBuildHasher; fn join_arrow_list_of_utf8s( list_element: Option<&dyn arrow2::array::Array>, @@ -43,7 +50,7 @@ fn join_arrow_list_of_utf8s( // Given an i64 array that may have either 1 or `self.len()` elements, create an iterator with // `self.len()` elements. If there was originally 1 element, we repeat this element `self.len()` // times, otherwise we simply take the original array. -fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box + '_> { +fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box + '_> { match arr.len() { 1 => Box::new(repeat(arr.get(0).unwrap()).take(len)), arr_len => { @@ -54,13 +61,13 @@ fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box, + mut parent_offsets: impl Iterator, field: Arc, child_data_type: &DataType, flat_child: &Series, validity: Option<&arrow2::bitmap::Bitmap>, - start_iter: impl Iterator, - end_iter: impl Iterator, + start_iter: impl Iterator, + end_iter: impl Iterator, ) -> DaftResult { let mut slicing_indexes = Vec::with_capacity(flat_child.len()); let mut new_offsets = Vec::with_capacity(flat_child.len() + 1); @@ -112,7 +119,7 @@ fn get_slices_helper( arrow2::offset::OffsetsBuffer::try_from(new_offsets)?, validity.cloned(), ) - .into_series()) + .into_series()) } /// Helper function that gets chunks of a given `size` from each list in the Series. Discards excess @@ -146,7 +153,7 @@ fn get_chunks_helper( validity: Option<&arrow2::bitmap::Bitmap>, size: usize, total_elements_to_skip: usize, - to_skip: Option>, + to_skip: Option>, new_offsets: Vec, ) -> DaftResult { if total_elements_to_skip == 0 { @@ -155,9 +162,9 @@ fn get_chunks_helper( inner_list_field.clone(), flat_child.clone(), None, // Since we're creating an extra layer of lists, this layer doesn't have any - // validity information. The topmost list takes the parent's validity, and the - // child list is unaffected by the chunking operation and maintains its validity. - // This reasoning applies to the places that follow where validity is set. + // validity information. The topmost list takes the parent's validity, and the + // child list is unaffected by the chunking operation and maintains its validity. + // This reasoning applies to the places that follow where validity is set. ); Ok(ListArray::new( inner_list_field.to_list_field()?, @@ -165,7 +172,7 @@ fn get_chunks_helper( arrow2::offset::OffsetsBuffer::try_from(new_offsets)?, validity.cloned(), // Copy the parent's validity. ) - .into_series()) + .into_series()) } else { let mut growable: Box = make_growable( &field.name, @@ -189,15 +196,15 @@ fn get_chunks_helper( arrow2::offset::OffsetsBuffer::try_from(new_offsets)?, validity.cloned(), // Copy the parent's validity. ) - .into_series()) + .into_series()) } } fn list_sort_helper( flat_child: &Series, offsets: &OffsetsBuffer, - desc_iter: impl Iterator, - validity: impl Iterator, + desc_iter: impl Iterator, + validity: impl Iterator, ) -> DaftResult> { desc_iter .zip(validity) @@ -221,8 +228,8 @@ fn list_sort_helper( fn list_sort_helper_fixed_size( flat_child: &Series, fixed_size: usize, - desc_iter: impl Iterator, - validity: impl Iterator, + desc_iter: impl Iterator, + validity: impl Iterator, ) -> DaftResult> { desc_iter .zip(validity) @@ -244,6 +251,132 @@ fn list_sort_helper_fixed_size( } impl ListArray { + pub fn value_counts(&self) -> DaftResult { + struct IndexRef { + index: usize, + hash: u64, + } + + impl std::hash::Hash for IndexRef { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } + } + + let original_name = self.name(); + + let hashes = self.flat_child.hash(None)?; + + let flat_child = self.flat_child.to_arrow(); + let flat_child = &*flat_child; + + let is_equal = build_is_equal( + flat_child, + flat_child, + true, // todo: should nulls and nans be considered equal? + true, + )?; + + let is_valid = build_is_valid(flat_child); + + let key_type = self.flat_child.data_type().clone(); + let count_type = DataType::UInt64; + + + let mut include_mask = Vec::with_capacity(self.flat_child.len()); + let mut count_array = Vec::new(); + + let mut offsets = Vec::with_capacity(self.len()); + + offsets.push(0_i64); + + let mut map: IndexMap = IndexMap::default(); + for range in self.offsets().ranges() { + map.clear(); + + for index in range { + let index = index as usize; + if !is_valid(index) { + include_mask.push(false); + // skip nulls + continue; + } + + let hash = hashes.get(index).unwrap(); + + let entry = map.raw_entry_mut_v1().from_hash(hash, |other| { + is_equal(other.index, index) + }); + + match entry { + RawEntryMut::Occupied(mut entry) => { + include_mask.push(false); + *entry.get_mut() += 1; + } + RawEntryMut::Vacant(vacant) => { + include_mask.push(true); + vacant.insert(IndexRef { hash, index }, 1); + } + } + } + + + // indexmap so ordered + for v in map.values() { + count_array.push(*v); + } + + offsets.push(count_array.len() as i64); + } + + let values = UInt64Array::from(("count", count_array)).into_series(); + let boolean_array = BooleanArray::from(("boolean", include_mask.as_slice())); + + let keys = self.flat_child.filter(&boolean_array)?; + + // todo: probably inefficient + let keys = Series::try_from_field_and_arrow_array( + Field::new("key", key_type.clone()), + keys.to_arrow(), + )?; + + // todo: probably inefficient + let values = Series::try_from_field_and_arrow_array( + Field::new("value", count_type.clone()), + values.to_arrow(), + )?; + + + let struct_type = DataType::Struct(vec![ + Field::new("key", key_type.clone()), + Field::new("value", count_type.clone()), + ]); + + let struct_array = StructArray::new( + Arc::new(Field::new("entries", struct_type.clone())), + vec![keys, values], + None, + ); + + let list_type = DataType::List(Box::new(struct_type)); + + let offsets = OffsetsBuffer::try_from(offsets)?; + + let list_array = ListArray::new( + Arc::new(Field::new("entries", list_type.clone())), + struct_array.into_series(), + offsets, + None, + ); + + let map_type = DataType::Map { key: Box::new(key_type), value: Box::new(count_type) }; + + Ok(MapArray::new( + Field::new(original_name, map_type.clone()), + list_array, + )) + } + pub fn count(&self, mode: CountMode) -> DaftResult { let counts = match (mode, self.flat_child.validity()) { (CountMode::All, _) | (CountMode::Valid, None) => { @@ -315,7 +448,7 @@ impl ListArray { pub fn join(&self, delimiter: &Utf8Array) -> DaftResult { assert_eq!(self.child_data_type(), &DataType::Utf8,); - let delimiter_iter: Box>> = if delimiter.len() == 1 { + let delimiter_iter: Box>> = if delimiter.len() == 1 { Box::new(repeat(delimiter.get(0)).take(self.len())) } else { assert_eq!(delimiter.len(), self.len()); @@ -340,7 +473,7 @@ impl ListArray { fn get_children_helper( &self, - idx_iter: impl Iterator, + idx_iter: impl Iterator, default: &Series, ) -> DaftResult { assert!( @@ -472,6 +605,111 @@ impl ListArray { } impl FixedSizeListArray { + // this DaftResult? or something or Series or what + + pub fn value_counts(&self) -> DaftResult { + struct IndexRef { + index: usize, + hash: u64, + } + + impl std::hash::Hash for IndexRef { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } + } + + let hashes = self.flat_child.hash(None)?; + let is_equal = build_is_equal( + self.flat_child.to_arrow().as_ref(), + self.flat_child.to_arrow().as_ref(), + true, // todo: should nulls and nans be considered equal? + true, + )?; + + let key_type = self.flat_child.data_type().clone(); + let count_type = DataType::UInt64; + + let fixed_size = self.fixed_element_len(); + + + let mut map: IndexMap = IndexMap::default(); + + let mut booleans = Vec::with_capacity(self.flat_child.len()); + let mut count_array = Vec::new(); + let mut offsets = Vec::with_capacity(self.len()); + + offsets.push(0_i64); + + for i in 0..self.len() { + map.clear(); + let start_index = i * fixed_size; + for j in 0..fixed_size { + let index = start_index + j; + + let hash = hashes.get(index).unwrap(); + + let entry = map.raw_entry_mut_v1().from_hash(hash, |other| { + is_equal(other.index, index) + }); + + match entry { + RawEntryMut::Occupied(mut entry) => { + booleans.push(false); + *entry.get_mut() += 1; + } + RawEntryMut::Vacant(vacant) => { + booleans.push(true); + vacant.insert(IndexRef { hash, index }, 1); + } + } + } + + + // indexmap so ordered + for v in map.values() { + count_array.push(*v); + } + + offsets.push(count_array.len() as i64); + } + + let values = UInt64Array::from(("count", count_array)).into_series(); + let boolean_array = BooleanArray::from(("boolean", booleans.as_slice())); + + let keys = self.flat_child.filter(&boolean_array)?; + + + let struct_type = DataType::Struct(vec![ + Field::new("key", key_type.clone()), + Field::new("value", count_type.clone()), + ]); + + let struct_array = StructArray::new( + Arc::new(Field::new("entries", struct_type.clone())), + vec![keys, values], + None, + ); + + let list_type = DataType::List(Box::new(struct_type)); + + let offsets = OffsetsBuffer::try_from(offsets)?; + + let list_array = ListArray::new( + Arc::new(Field::new("entries", list_type.clone())), + struct_array.into_series(), + offsets, + None, + ); + + let map_type = DataType::Map { key: Box::new(key_type), value: Box::new(count_type) }; + + Ok(MapArray::new( + Field::new("entries", map_type.clone()), + list_array, + )) + } + pub fn count(&self, mode: CountMode) -> DaftResult { let size = self.fixed_element_len(); let counts = match (mode, self.flat_child.validity()) { @@ -531,7 +769,7 @@ impl FixedSizeListArray { pub fn join(&self, delimiter: &Utf8Array) -> DaftResult { assert_eq!(self.child_data_type(), &DataType::Utf8,); - let delimiter_iter: Box>> = if delimiter.len() == 1 { + let delimiter_iter: Box>> = if delimiter.len() == 1 { Box::new(repeat(delimiter.get(0)).take(self.len())) } else { assert_eq!(delimiter.len(), self.len()); @@ -556,7 +794,7 @@ impl FixedSizeListArray { fn get_children_helper( &self, - idx_iter: impl Iterator, + idx_iter: impl Iterator, default: &Series, ) -> DaftResult { assert!( diff --git a/src/daft-core/src/array/ops/map.rs b/src/daft-core/src/array/ops/map.rs index 3b2f6ffd8c..a1613ce19c 100644 --- a/src/daft-core/src/array/ops/map.rs +++ b/src/daft-core/src/array/ops/map.rs @@ -24,19 +24,10 @@ fn single_map_get(structs: &Series, key_to_get: &Series) -> DaftResult { impl MapArray { pub fn map_get(&self, key_to_get: &Series) -> DaftResult { - let value_type = if let DataType::Map(inner_dtype) = self.data_type() { - match *inner_dtype.clone() { - DataType::Struct(fields) if fields.len() == 2 => { - fields[1].dtype.clone() - } - _ => { - return Err(DaftError::TypeError(format!( - "Expected inner type to be a struct type with two fields: key and value, got {:?}", - inner_dtype - ))) - } - } - } else { + let DataType::Map { + value: value_type, .. + } = self.data_type() + else { return Err(DaftError::TypeError(format!( "Expected input to be a map type, got {:?}", self.data_type() @@ -49,7 +40,7 @@ impl MapArray { for series in self.physical.into_iter() { match series { Some(s) if !s.is_empty() => result.push(single_map_get(&s, key_to_get)?), - _ => result.push(Series::full_null("value", &value_type, 1)), + _ => result.push(Series::full_null("value", value_type, 1)), } } Series::concat(&result.iter().collect::>()) @@ -59,7 +50,7 @@ impl MapArray { for (i, series) in self.physical.into_iter().enumerate() { match (series, key_to_get.slice(i, i + 1)?) { (Some(s), k) if !s.is_empty() => result.push(single_map_get(&s, &k)?), - _ => result.push(Series::full_null("value", &value_type, 1)), + _ => result.push(Series::full_null("value", value_type, 1)), } } Series::concat(&result.iter().collect::>()) diff --git a/src/daft-core/src/array/struct_array.rs b/src/daft-core/src/array/struct_array.rs index fb0c50fb25..b9d909bcbc 100644 --- a/src/daft-core/src/array/struct_array.rs +++ b/src/daft-core/src/array/struct_array.rs @@ -11,6 +11,8 @@ use crate::{ #[derive(Clone, Debug)] pub struct StructArray { pub field: Arc, + + /// Column representations pub children: Vec, validity: Option, len: usize, diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index b8b8e1660f..bae597393c 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -31,7 +31,7 @@ macro_rules! with_match_daft_types {( FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, List(_) => __with_ty__! { ListType }, Struct(_) => __with_ty__! { StructType }, - Map(_) => __with_ty__! { MapType }, + Map{..} => __with_ty__! { MapType }, Extension(_, _, _) => __with_ty__! { ExtensionType }, #[cfg(feature = "python")] Python => __with_ty__! { PythonType }, diff --git a/src/daft-core/src/lib.rs b/src/daft-core/src/lib.rs index 322a0db3ec..5892f75ffb 100644 --- a/src/daft-core/src/lib.rs +++ b/src/daft-core/src/lib.rs @@ -2,6 +2,7 @@ #![feature(int_roundings)] #![feature(iterator_try_reduce)] #![feature(if_let_guard)] +#![feature(hash_raw_entry)] pub mod array; pub mod count_mode; diff --git a/src/daft-core/src/series/from.rs b/src/daft-core/src/series/from.rs index 99776edf64..746ba136eb 100644 --- a/src/daft-core/src/series/from.rs +++ b/src/daft-core/src/series/from.rs @@ -12,9 +12,10 @@ use crate::{ impl Series { pub fn try_from_field_and_arrow_array( - field: Arc, + field: impl Into>, array: Box, ) -> DaftResult { + let field = field.into(); // TODO(Nested): Refactor this out with nested logical types in StructArray and ListArray // Corner-case nested logical types that have not yet been migrated to new Array formats // to hold only casted physical arrow arrays. diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index d00f5440c2..b7a2626e73 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -7,6 +7,22 @@ use crate::{ }; impl Series { + pub fn list_value_counts(&self) -> DaftResult { + let series = match self.data_type() { + DataType::List(_) => self.list()?.value_counts(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.value_counts(), + dt => { + return Err(DaftError::TypeError(format!( + "List contains not implemented for {}", + dt + ))) + } + }? + .into_series(); + + Ok(series) + } + pub fn explode(&self) -> DaftResult { match self.data_type() { DataType::List(_) => self.list()?.explode(), @@ -33,7 +49,7 @@ impl Series { arrow2::array::PrimitiveArray::from_vec( offsets.lengths().map(|l| l as u64).collect(), ) - .with_validity(data_array.validity().cloned()), + .with_validity(data_array.validity().cloned()), ); Ok(UInt64Array::from((self.name(), array))) } diff --git a/src/daft-core/src/series/ops/map.rs b/src/daft-core/src/series/ops/map.rs index 85461b1fe0..58f9a5b046 100644 --- a/src/daft-core/src/series/ops/map.rs +++ b/src/daft-core/src/series/ops/map.rs @@ -5,7 +5,7 @@ use crate::{datatypes::DataType, series::Series}; impl Series { pub fn map_get(&self, key: &Series) -> DaftResult { match self.data_type() { - DataType::Map(_) => self.map()?.map_get(key), + DataType::Map { .. } => self.map()?.map_get(key), dt => Err(DaftError::TypeError(format!( "map.get not implemented for {}", dt diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs index bf7e42a1e0..3ce3b6f881 100644 --- a/src/daft-core/src/series/serdes.rs +++ b/src/daft-core/src/series/serdes.rs @@ -163,7 +163,7 @@ impl<'d> serde::Deserialize<'d> for Series { .unwrap() .into_series()) } - DataType::Map(..) => { + DataType::Map { .. } => { let physical = map.next_value::()?; Ok(MapArray::new( Arc::new(field), diff --git a/src/daft-dsl/src/functions/map/get.rs b/src/daft-dsl/src/functions/map/get.rs index ab6eb148f8..bf5f9efdf0 100644 --- a/src/daft-dsl/src/functions/map/get.rs +++ b/src/daft-dsl/src/functions/map/get.rs @@ -13,18 +13,14 @@ impl FunctionEvaluator for GetEvaluator { fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { match inputs { + // what is input and what is key + // input is a map field [input, key] => match (input.to_field(schema), key.to_field(schema)) { (Ok(input_field), Ok(_)) => match input_field.dtype { - DataType::Map(inner) => match inner.as_ref() { - DataType::Struct(fields) if fields.len() == 2 => { - let value_dtype = &fields[1].dtype; - Ok(Field::new("value", value_dtype.clone())) - } - _ => Err(DaftError::TypeError(format!( - "Expected input map to have struct values with 2 fields, got {}", - inner - ))), - }, + DataType::Map { value, .. } => { + // todo: perhaps better naming + Ok(Field::new("value", *value)) + } _ => Err(DaftError::TypeError(format!( "Expected input to be a map, got {}", input_field.dtype diff --git a/src/daft-dsl/src/functions/utf8/split.rs b/src/daft-dsl/src/functions/utf8/split.rs index 0518786055..c0e121b393 100644 --- a/src/daft-dsl/src/functions/utf8/split.rs +++ b/src/daft-dsl/src/functions/utf8/split.rs @@ -8,6 +8,7 @@ pub(super) struct SplitEvaluator {} impl FunctionEvaluator for SplitEvaluator { fn fn_name(&self) -> &'static str { + println!("hi"); "split" } diff --git a/src/daft-functions/src/list/mod.rs b/src/daft-functions/src/list/mod.rs index 2ba3f197be..0d9fdf0a9c 100644 --- a/src/daft-functions/src/list/mod.rs +++ b/src/daft-functions/src/list/mod.rs @@ -9,6 +9,7 @@ mod min; mod slice; mod sort; mod sum; +mod value_counts; pub use chunk::{list_chunk as chunk, ListChunk}; pub use count::{list_count as count, ListCount}; @@ -31,6 +32,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(count::py_list_count, parent)?)?; parent.add_function(wrap_pyfunction_bound!(get::py_list_get, parent)?)?; parent.add_function(wrap_pyfunction_bound!(join::py_list_join, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(value_counts::py_list_value_counts, parent)?)?; parent.add_function(wrap_pyfunction_bound!(max::py_list_max, parent)?)?; parent.add_function(wrap_pyfunction_bound!(min::py_list_min, parent)?)?; diff --git a/src/daft-functions/src/list/value_counts.rs b/src/daft-functions/src/list/value_counts.rs new file mode 100644 index 0000000000..83607fc9d7 --- /dev/null +++ b/src/daft-functions/src/list/value_counts.rs @@ -0,0 +1,70 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::prelude::{DataType, Field, Schema, Series}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + python::PyExpr, + ExprRef, +}; +use pyo3::{pyfunction, PyResult}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +struct ListValueCountsFunction; + +#[typetag::serde] +impl ScalarUDF for ListValueCountsFunction { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "list_value_counts" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + let [data] = inputs else { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + + let data_field = data.to_field(schema)?; + + let DataType::List(inner_type) = &data_field.dtype else { + return Err(DaftError::TypeError(format!( + "Expected list, got {}", + data_field.dtype + ))); + }; + + let map_type = DataType::Map { + key: inner_type.clone(), + value: Box::new(DataType::UInt64), + }; + + Ok(Field::new(data_field.name, map_type)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + let [data] = inputs else { + return Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + + data.list_value_counts() + } +} + +pub fn list_value_counts(expr: ExprRef) -> ExprRef { + ScalarFunction::new(ListValueCountsFunction, vec![expr]).into() +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_value_counts")] +pub fn py_list_value_counts(expr: PyExpr) -> PyResult { + Ok(list_value_counts(expr.into()).into()) +} diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index 48d9414aab..43c67ec6f4 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -107,8 +107,11 @@ pub enum DataType { Struct(Vec), /// A nested [`DataType`] that is represented as List>. - #[display("Map[{_0}]")] - Map(Box), + #[display("Map[{key}: {value}]")] + Map { + key: Box, + value: Box, + }, /// Extension type. #[display("{_1}")] @@ -233,14 +236,31 @@ impl DataType { DataType::List(field) => Ok(ArrowType::LargeList(Box::new( arrow2::datatypes::Field::new("item", field.to_arrow()?, true), ))), - DataType::Map(field) => Ok(ArrowType::Map( - Box::new(arrow2::datatypes::Field::new( - "item", - field.to_arrow()?, - true, - )), - false, - )), + DataType::Map { key, value } => { + let struct_type = ArrowType::Struct(vec![ + // We never allow null keys in maps for several reasons: + // 1. Null typically represents the absence of a value, which doesn't make sense for a key. + // 2. Null comparisons can be problematic (similar to how f64::NAN != f64::NAN). + // 3. It maintains consistency with common map implementations in arrow (no null keys). + // 4. It simplifies map operations + // + // This decision aligns with the thoughts of team members like Jay and Sammy, who argue that: + // - Nulls in keys could lead to unintuitive behavior + // - If users need to count or group by null values, they can use other constructs like + // group_by operations on non-map types, which offer more explicit control. + // + // By disallowing null keys, we encourage more robust data modeling practices and + // provide a clearer semantic meaning for map types in our system. + arrow2::datatypes::Field::new("key", key.to_arrow()?, true), + arrow2::datatypes::Field::new("value", value.to_arrow()?, true), + ]); + + // entries + let struct_field = + arrow2::datatypes::Field::new("entries", struct_type.clone(), true); + + Ok(ArrowType::Map(Box::new(struct_field), false)) + } DataType::Struct(fields) => Ok({ let fields = fields .iter() @@ -288,7 +308,10 @@ impl DataType { FixedSizeList(child_dtype, size) => { FixedSizeList(Box::new(child_dtype.to_physical()), *size) } - Map(child_dtype) => List(Box::new(child_dtype.to_physical())), + Map { key, value } => List(Box::new(Struct(vec![ + Field::new("key", key.to_physical()), + Field::new("value", value.to_physical()), + ]))), Embedding(dtype, size) => FixedSizeList(Box::new(dtype.to_physical()), *size), Image(mode) => Struct(vec![ Field::new( @@ -328,20 +351,6 @@ impl DataType { } } - #[inline] - pub fn nested_dtype(&self) -> Option<&DataType> { - match self { - DataType::Map(dtype) - | DataType::List(dtype) - | DataType::FixedSizeList(dtype, _) - | DataType::FixedShapeTensor(dtype, _) - | DataType::SparseTensor(dtype) - | DataType::FixedShapeSparseTensor(dtype, _) - | DataType::Tensor(dtype) => Some(dtype), - _ => None, - } - } - #[inline] pub fn is_arrow(&self) -> bool { self.to_arrow().is_ok() @@ -350,21 +359,21 @@ impl DataType { #[inline] pub fn is_numeric(&self) -> bool { match self { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Int128 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - // DataType::Float16 - | DataType::Float32 - | DataType::Float64 => true, - DataType::Extension(_, inner, _) => inner.is_numeric(), - _ => false - } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Int128 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + // DataType::Float16 + | DataType::Float32 + | DataType::Float64 => true, + DataType::Extension(_, inner, _) => inner.is_numeric(), + _ => false + } } #[inline] @@ -453,7 +462,7 @@ impl DataType { #[inline] pub fn is_map(&self) -> bool { - matches!(self, DataType::Map(..)) + matches!(self, DataType::Map { .. }) } #[inline] @@ -576,7 +585,7 @@ impl DataType { | DataType::FixedShapeTensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) - | DataType::Map(..) + | DataType::Map { .. } ) } @@ -593,7 +602,7 @@ impl DataType { DataType::List(..) | DataType::FixedSizeList(..) | DataType::Struct(..) - | DataType::Map(..) + | DataType::Map { .. } ) } @@ -643,7 +652,29 @@ impl From<&ArrowType> for DataType { ArrowType::FixedSizeList(field, size) => { DataType::FixedSizeList(Box::new(field.as_ref().data_type().into()), *size) } - ArrowType::Map(field, ..) => DataType::Map(Box::new(field.as_ref().data_type().into())), + ArrowType::Map(field, ..) => { + // todo: TryFrom in future? want in second pass maybe + + // field should be a struct + let ArrowType::Struct(fields) = &field.data_type else { + panic!("Map should have a struct as its key") + }; + + let [key, value] = fields.as_slice() else { + panic!("Map should have two fields") + }; + + let key = &key.data_type; + let value = &value.data_type; + + let key = DataType::from(key); + let value = DataType::from(value); + + let key = Box::new(key); + let value = Box::new(value); + + DataType::Map { key, value } + } ArrowType::Struct(fields) => { let fields: Vec = fields.iter().map(|fld| fld.into()).collect(); DataType::Struct(fields) diff --git a/src/daft-schema/src/python/datatype.rs b/src/daft-schema/src/python/datatype.rs index 9128eceec2..32642fae58 100644 --- a/src/daft-schema/src/python/datatype.rs +++ b/src/daft-schema/src/python/datatype.rs @@ -209,10 +209,10 @@ impl PyDataType { #[staticmethod] pub fn map(key_type: Self, value_type: Self) -> PyResult { - Ok(DataType::Map(Box::new(DataType::Struct(vec![ - Field::new("key", key_type.dtype), - Field::new("value", value_type.dtype), - ]))) + Ok(DataType::Map { + key: Box::new(key_type.dtype), + value: Box::new(value_type.dtype), + } .into()) } diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs index e8dc82f2f8..81f4dd5488 100644 --- a/src/daft-stats/src/column_stats/mod.rs +++ b/src/daft-stats/src/column_stats/mod.rs @@ -71,7 +71,7 @@ impl ColumnRangeStatistics { // UNSUPPORTED TYPES: // Types that don't support comparisons and can't be used as ColumnRangeStatistics - DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map(..) | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false, + DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map { .. } | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false, #[cfg(feature = "python")] DataType::Python => false, } diff --git a/src/daft-table/src/repr_html.rs b/src/daft-table/src/repr_html.rs index 79ecaf063a..0e46bb80b2 100644 --- a/src/daft-table/src/repr_html.rs +++ b/src/daft-table/src/repr_html.rs @@ -102,7 +102,7 @@ pub fn html_value(s: &Series, idx: usize) -> String { let arr = s.struct_().unwrap(); arr.html_value(idx) } - DataType::Map(_) => { + DataType::Map { .. } => { let arr = s.map().unwrap(); arr.html_value(idx) } diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index d3727c2ac3..23e4674ba3 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -508,3 +508,61 @@ def test_repr_series_lit() -> None: s = lit(Series.from_pylist([1, 2, 3])) output = repr(s) assert output == "lit([1, 2, 3])" + + +def test_list_value_counts(): + # Create a MicroPartition with a list column + mp = MicroPartition.from_pydict({ + "list_col": [ + ["a", "b", "a", "c"], + ["b", "b", "c"], + ["a", "a", "a"], + [], + ["d", None, "d"] + ] + }) + + print("mp is ", mp) + + # mp = MicroPartition.from_pydict({ + # "list_col": [ + # ["a", "b", "a", "c"], + # ["b", "b", "c"], + # ["a", "a", "a"], + # [], + # ["d", "d"] + # ] + # }) + + # # Apply list_value_counts operation + result = mp.eval_expression_list([col("list_col").list.value_counts().alias("value_counts")]) + value_counts = result.to_pydict()["value_counts"] + + # Expected output + expected = [ + [("a", 2), ("b", 1), ("c", 1)], + [("b", 2), ("c", 1)], + [("a", 3)], + [], + [("d", 2)] + ] + + assert value_counts == expected + + # # Check the result + # value_counts = result.to_pydict()["value_counts"] + # print(value_counts) + # assert value_counts == expected + + # # Test with empty input + # empty_mp = MicroPartition.from_pydict({"list_col": []}) + # empty_result = empty_mp.eval_expression_list([col("list_col").list.value_counts().alias("value_counts")]) + # assert empty_result.to_pydict()["value_counts"] == [] + + # # Test with all None input + # none_mp = MicroPartition.from_pydict({"list_col": [None, None, None]}) + # none_result = none_mp.eval_expression_list([col("list_col").list.value_counts().alias("value_counts")]) + # assert none_result.to_pydict()["value_counts"] == [None, None, None] + + +