diff --git a/src/daft-core/src/array/mod.rs b/src/daft-core/src/array/mod.rs index fbcc4ef01f..d7e5fc7df4 100644 --- a/src/daft-core/src/array/mod.rs +++ b/src/daft-core/src/array/mod.rs @@ -5,7 +5,7 @@ pub mod pseudo_arrow; use std::{marker::PhantomData, sync::Arc}; -use crate::datatypes::{DaftPhysicalType, DataType, Field}; +use crate::datatypes::{DaftArrayType, DaftPhysicalType, DataType, Field}; use common_error::{DaftError, DaftResult}; @@ -22,6 +22,8 @@ impl Clone for DataArray { } } +impl DaftArrayType for DataArray {} + impl DataArray where T: DaftPhysicalType, diff --git a/src/daft-core/src/array/ops/arange.rs b/src/daft-core/src/array/ops/arange.rs index 4f5a04eb29..1145e8d9f0 100644 --- a/src/daft-core/src/array/ops/arange.rs +++ b/src/daft-core/src/array/ops/arange.rs @@ -19,7 +19,7 @@ where let arrow_array = Box::new(arrow2::array::PrimitiveArray::::from_vec(data)); let data_array = Int64Array::from((name.as_ref(), arrow_array)); let casted_array = data_array.cast(&T::get_dtype())?; - let downcasted = casted_array.downcast::()?; + let downcasted = casted_array.downcast::>()?; Ok(downcasted.clone()) } } diff --git a/src/daft-core/src/array/ops/arithmetic.rs b/src/daft-core/src/array/ops/arithmetic.rs index ce94bfb7b4..1066a7a5f4 100644 --- a/src/daft-core/src/array/ops/arithmetic.rs +++ b/src/daft-core/src/array/ops/arithmetic.rs @@ -10,7 +10,7 @@ use crate::{ use common_error::{DaftError, DaftResult}; -use super::as_arrow::AsArrow; +use super::{as_arrow::AsArrow, full::FullNull}; /// Helper function to perform arithmetic operations on a DataArray /// Takes both Kernel (array x array operation) and operation (scalar x scalar) functions /// The Kernel is used for when both arrays are non-unit length and the operation is used when broadcasting diff --git a/src/daft-core/src/array/ops/broadcast.rs b/src/daft-core/src/array/ops/broadcast.rs index 7a465bebde..d1d3dba764 100644 --- a/src/daft-core/src/array/ops/broadcast.rs +++ b/src/daft-core/src/array/ops/broadcast.rs @@ -8,7 +8,7 @@ use crate::{ use common_error::{DaftError, DaftResult}; -use super::as_arrow::AsArrow; +use super::{as_arrow::AsArrow, full::FullNull}; #[cfg(feature = "python")] use crate::array::pseudo_arrow::PseudoArrowArray; diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 8689c1e744..676cbc6635 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -1,13 +1,18 @@ use super::as_arrow::AsArrow; use crate::{ - array::{ops::from_arrow::FromArrow, ops::image::ImageArraySidecarData, DataArray}, + array::{ + ops::image::ImageArraySidecarData, + ops::{from_arrow::FromArrow, full::FullNull}, + DataArray, + }, datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, LogicalArray, TensorArray, TimestampArray, + FixedShapeTensorArray, ImageArray, LogicalArray, LogicalArrayImpl, TensorArray, + TimestampArray, }, - DaftArrowBackedType, DaftLogicalType, DataType, Field, FixedShapeTensorType, - FixedSizeListArray, ImageMode, StructArray, TensorType, TimeUnit, Utf8Array, + DaftArrowBackedType, DaftLogicalType, DataType, Field, FixedSizeListArray, ImageMode, + StructArray, TimeUnit, Utf8Array, }, series::{IntoSeries, Series}, with_match_arrow_daft_types, with_match_daft_logical_primitive_types, @@ -39,9 +44,13 @@ use { std::ops::Deref, }; -fn arrow_logical_cast(to_cast: &LogicalArray, dtype: &DataType) -> DaftResult +fn arrow_logical_cast( + to_cast: &LogicalArrayImpl>, + dtype: &DataType, +) -> DaftResult where T: DaftLogicalType, + T::PhysicalType: DaftArrowBackedType, { // Cast from LogicalArray to the target DataType // using Arrow's casting mechanisms. @@ -49,7 +58,7 @@ where // Note that Arrow Logical->Logical direct casts (what this method exposes) // have different behaviour than Arrow Logical->Physical->Logical casts. - let source_dtype = to_cast.logical_type(); + let source_dtype = to_cast.data_type(); let source_arrow_type = source_dtype.to_arrow()?; let target_arrow_type = dtype.to_arrow()?; @@ -360,8 +369,8 @@ impl TimestampArray { match dtype { DataType::Timestamp(..) => arrow_logical_cast(self, dtype), DataType::Utf8 => { - let DataType::Timestamp(unit, timezone) = self.logical_type() else { - panic!("Wrong dtype for TimestampArray: {}", self.logical_type()) + let DataType::Timestamp(unit, timezone) = self.data_type() else { + panic!("Wrong dtype for TimestampArray: {}", self.data_type()) }; let str_array: arrow2::array::Utf8Array = timezone.as_ref().map_or_else( @@ -1079,7 +1088,7 @@ impl PythonArray { impl EmbeddingArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { - match (dtype, self.logical_type()) { + match (dtype, self.data_type()) { #[cfg(feature = "python")] (DataType::Python, DataType::Embedding(_, size)) => Python::with_gil(|py| { let shape = (self.len(), *size); @@ -1111,7 +1120,7 @@ impl EmbeddingArray { DataType::FixedShapeTensor(Box::new(inner_dtype.clone().dtype), image_shape); let fixed_shape_tensor_array = self.cast(&fixed_shape_tensor_dtype)?; let fixed_shape_tensor_array = - fixed_shape_tensor_array.downcast_logical::()?; + fixed_shape_tensor_array.downcast::()?; fixed_shape_tensor_array.cast(dtype) } // NOTE(Clark): Casting to FixedShapeTensor is supported by the physical array cast. @@ -1174,7 +1183,7 @@ impl ImageArray { } })?; let fixed_shape_tensor_array = - fixed_shape_tensor_array.downcast_logical::()?; + fixed_shape_tensor_array.downcast::()?; fixed_shape_tensor_array.cast(dtype) } DataType::Tensor(_) => { @@ -1230,7 +1239,7 @@ impl ImageArray { DataType::FixedShapeTensor(inner_dtype, _) => { let tensor_dtype = DataType::Tensor(inner_dtype.clone()); let tensor_array = self.cast(&tensor_dtype)?; - let tensor_array = tensor_array.downcast_logical::()?; + let tensor_array = tensor_array.downcast::()?; tensor_array.cast(dtype) } _ => self.physical.cast(dtype), @@ -1240,7 +1249,7 @@ impl ImageArray { impl FixedShapeImageArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { - match (dtype, self.logical_type()) { + match (dtype, self.data_type()) { #[cfg(feature = "python")] (DataType::Python, DataType::FixedShapeImage(mode, height, width)) => { pyo3::Python::with_gil(|py| { @@ -1299,13 +1308,13 @@ impl FixedShapeImageArray { } })?; let fixed_shape_tensor_array = - fixed_shape_tensor_array.downcast_logical::()?; + fixed_shape_tensor_array.downcast::()?; fixed_shape_tensor_array.cast(dtype) } (DataType::Image(_), DataType::FixedShapeImage(mode, _, _)) => { let tensor_dtype = DataType::Tensor(Box::new(mode.get_dtype())); let tensor_array = self.cast(&tensor_dtype)?; - let tensor_array = tensor_array.downcast_logical::()?; + let tensor_array = tensor_array.downcast::()?; tensor_array.cast(dtype) } // NOTE(Clark): Casting to FixedShapeTensor is supported by the physical array cast. @@ -1505,7 +1514,7 @@ impl TensorArray { } })?; let fixed_shape_tensor_array = - fixed_shape_tensor_array.downcast_logical::()?; + fixed_shape_tensor_array.downcast::()?; fixed_shape_tensor_array.cast(dtype) } _ => self.physical.cast(dtype), @@ -1515,7 +1524,7 @@ impl TensorArray { impl FixedShapeTensorArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { - match (dtype, self.logical_type()) { + match (dtype, self.data_type()) { #[cfg(feature = "python")] (DataType::Python, DataType::FixedShapeTensor(_, shape)) => { pyo3::Python::with_gil(|py| { @@ -1616,11 +1625,12 @@ impl FixedShapeTensorArray { fn cast_logical_to_python_array(array: &LogicalArray, dtype: &DataType) -> DaftResult where T: DaftLogicalType, + T::PhysicalType: DaftArrowBackedType, LogicalArray: AsArrow, as AsArrow>::Output: arrow2::array::Array, { Python::with_gil(|py| { - let arrow_dtype = array.logical_type().to_arrow()?; + let arrow_dtype = array.data_type().to_arrow()?; let arrow_array = array.as_arrow().to_type(arrow_dtype).with_validity(None); let pyarrow = py.import("pyarrow")?; let py_array: Vec = ffi::to_py_array(arrow_array.to_boxed(), py, pyarrow)? diff --git a/src/daft-core/src/array/ops/compare_agg.rs b/src/daft-core/src/array/ops/compare_agg.rs index 3c7fdf52a3..ec1f829943 100644 --- a/src/daft-core/src/array/ops/compare_agg.rs +++ b/src/daft-core/src/array/ops/compare_agg.rs @@ -1,5 +1,5 @@ use super::{DaftCompareAggable, GroupIndices}; -use crate::{array::DataArray, datatypes::*}; +use crate::{array::ops::full::FullNull, array::DataArray, datatypes::*}; use arrow2::array::PrimitiveArray; use arrow2::{self, array::Array}; diff --git a/src/daft-core/src/array/ops/comparison.rs b/src/daft-core/src/array/ops/comparison.rs index 4d70d95004..9bf61bed3e 100644 --- a/src/daft-core/src/array/ops/comparison.rs +++ b/src/daft-core/src/array/ops/comparison.rs @@ -12,7 +12,7 @@ use common_error::{DaftError, DaftResult}; use std::ops::Not; -use super::{DaftCompare, DaftLogical}; +use super::{full::FullNull, DaftCompare, DaftLogical}; use super::as_arrow::AsArrow; use arrow2::{compute::comparison, scalar::PrimitiveScalar}; diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs index 12e2eaf694..e4b3ce8ed7 100644 --- a/src/daft-core/src/array/ops/from_arrow.rs +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -2,7 +2,7 @@ use common_error::DaftResult; use crate::{ array::DataArray, - datatypes::{logical::LogicalArray, DaftLogicalType, DaftPhysicalType, Field}, + datatypes::{logical::LogicalArray, DaftDataType, DaftLogicalType, DaftPhysicalType, Field}, }; /// Arrays that implement [`FromArrow`] can be instantiated from a Box @@ -19,10 +19,17 @@ impl FromArrow for DataArray { } } -impl FromArrow for LogicalArray { +impl FromArrow for LogicalArray +where + ::ArrayType: FromArrow, +{ fn from_arrow(field: &Field, arrow_arr: Box) -> DaftResult { let data_array_field = Field::new(field.name.clone(), field.dtype.to_physical()); - let physical = DataArray::try_from((data_array_field, arrow_arr))?; + let physical_arrow_arr = arrow_arr.to_type(data_array_field.dtype.to_arrow()?); + let physical = ::ArrayType::from_arrow( + &data_array_field, + physical_arrow_arr, + )?; Ok(LogicalArray::::new(field.clone(), physical)) } } diff --git a/src/daft-core/src/array/ops/full.rs b/src/daft-core/src/array/ops/full.rs index ff0b59af16..6e1244f1fe 100644 --- a/src/daft-core/src/array/ops/full.rs +++ b/src/daft-core/src/array/ops/full.rs @@ -5,15 +5,22 @@ use pyo3::Python; use crate::{ array::{pseudo_arrow::PseudoArrowArray, DataArray}, - datatypes::{DaftPhysicalType, DataType, Field}, + datatypes::{ + logical::LogicalArray, DaftDataType, DaftLogicalType, DaftPhysicalType, DataType, Field, + }, }; -impl DataArray +pub trait FullNull { + fn full_null(name: &str, dtype: &DataType, length: usize) -> Self; + fn empty(name: &str, dtype: &DataType) -> Self; +} + +impl FullNull for DataArray where T: DaftPhysicalType, { /// Creates a DataArray of size `length` that is filled with all nulls. - pub fn full_null(name: &str, dtype: &DataType, length: usize) -> Self { + fn full_null(name: &str, dtype: &DataType, length: usize) -> Self { let field = Field::new(name, dtype.clone()); #[cfg(feature = "python")] if dtype.is_python() { @@ -37,7 +44,7 @@ where } } - pub fn empty(name: &str, dtype: &DataType) -> Self { + fn empty(name: &str, dtype: &DataType) -> Self { let field = Field::new(name, dtype.clone()); #[cfg(feature = "python")] if dtype.is_python() { @@ -59,3 +66,20 @@ where } } } + +impl FullNull for LogicalArray +where + ::ArrayType: FullNull, +{ + fn full_null(name: &str, dtype: &DataType, length: usize) -> Self { + let physical = ::ArrayType::full_null(name, dtype, length); + Self::new(Field::new(name, dtype.clone()), physical) + } + + fn empty(field_name: &str, dtype: &DataType) -> Self { + let physical = + ::ArrayType::empty(field_name, &dtype.to_physical()); + let field = Field::new(field_name, dtype.clone()); + Self::new(field, physical) + } +} diff --git a/src/daft-core/src/array/ops/if_else.rs b/src/daft-core/src/array/ops/if_else.rs index 07facbcdff..95da88eb92 100644 --- a/src/daft-core/src/array/ops/if_else.rs +++ b/src/daft-core/src/array/ops/if_else.rs @@ -1,3 +1,4 @@ +use crate::array::ops::full::FullNull; use crate::array::DataArray; use crate::datatypes::logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, diff --git a/src/daft-core/src/array/ops/image.rs b/src/daft-core/src/array/ops/image.rs index d6e7c05c8c..70a0a14f5e 100644 --- a/src/daft-core/src/array/ops/image.rs +++ b/src/daft-core/src/array/ops/image.rs @@ -4,15 +4,15 @@ use std::vec; use image::{ColorType, DynamicImage, ImageBuffer}; +use crate::datatypes::FixedSizeListArray; use crate::datatypes::{ logical::{DaftImageryType, FixedShapeImageArray, ImageArray, LogicalArray}, - BinaryArray, DaftLogicalType, DataType, Field, FixedSizeListArray, ImageFormat, ImageMode, - StructArray, + BinaryArray, DataType, Field, ImageFormat, ImageMode, StructArray, }; use common_error::{DaftError, DaftResult}; use image::{Luma, LumaA, Rgb, Rgba}; -use super::as_arrow::AsArrow; +use super::{as_arrow::AsArrow, from_arrow::FromArrow}; use num_traits::FromPrimitive; use std::ops::Deref; @@ -315,22 +315,24 @@ pub struct ImageArraySidecarData { } pub trait AsImageObj { + fn name(&self) -> &str; + fn len(&self) -> usize; fn as_image_obj(&self, idx: usize) -> Option>; } -pub struct ImageBufferIter<'a, T> +pub struct ImageBufferIter<'a, Arr> where - T: DaftLogicalType + DaftImageryType, + Arr: AsImageObj, { cursor: usize, - image_array: &'a LogicalArray, + image_array: &'a Arr, } -impl<'a, T> ImageBufferIter<'a, T> +impl<'a, Arr> ImageBufferIter<'a, Arr> where - T: DaftLogicalType + DaftImageryType, + Arr: AsImageObj, { - pub fn new(image_array: &'a LogicalArray) -> Self { + pub fn new(image_array: &'a Arr) -> Self { Self { cursor: 0usize, image_array, @@ -338,10 +340,9 @@ where } } -impl<'a, T> Iterator for ImageBufferIter<'a, T> +impl<'a, Arr> Iterator for ImageBufferIter<'a, Arr> where - T: DaftLogicalType + DaftImageryType, - LogicalArray: AsImageObj, + Arr: AsImageObj, { type Item = Option>; @@ -358,7 +359,7 @@ where impl ImageArray { pub fn image_mode(&self) -> &Option { - match self.logical_type() { + match self.data_type() { DataType::Image(mode) => mode, _ => panic!("Expected dtype to be Image"), } @@ -579,6 +580,14 @@ impl ImageArray { } impl AsImageObj for ImageArray { + fn len(&self) -> usize { + ImageArray::len(self) + } + + fn name(&self) -> &str { + ImageArray::name(self) + } + fn as_image_obj<'a>(&'a self, idx: usize) -> Option> { assert!(idx < self.len()); if !self.physical.is_valid(idx) { @@ -685,10 +694,8 @@ impl FixedShapeImageArray { Box::new(arrow2::array::PrimitiveArray::from_vec(data)), validity, )); - let physical_array = FixedSizeListArray::new( - Field::new(name, (&arrow_dtype).into()).into(), - arrow_array.boxed(), - )?; + let physical_array = + FixedSizeListArray::from_arrow(&Field::new(name, (&arrow_dtype).into()), arrow_array)?; let logical_dtype = DataType::FixedShapeImage(*image_mode, height, width); Ok(Self::new(Field::new(name, logical_dtype), physical_array)) } @@ -699,7 +706,7 @@ impl FixedShapeImageArray { pub fn resize(&self, w: u32, h: u32) -> DaftResult { let result = resize_images(self, w, h); - match self.logical_type() { + match self.data_type() { DataType::FixedShapeImage(mode, _, _) => Self::from_daft_image_buffers(self.name(), result.as_slice(), mode, h, w), dt => panic!("FixedShapeImageArray should always have DataType::FixedShapeImage() as it's dtype, but got {}", dt), } @@ -727,13 +734,21 @@ impl FixedShapeImageArray { } impl AsImageObj for FixedShapeImageArray { + fn len(&self) -> usize { + FixedShapeImageArray::len(self) + } + + fn name(&self) -> &str { + FixedShapeImageArray::name(self) + } + fn as_image_obj<'a>(&'a self, idx: usize) -> Option> { assert!(idx < self.len()); if !self.physical.is_valid(idx) { return None; } - match self.logical_type() { + match self.data_type() { DataType::FixedShapeImage(mode, height, width) => { let arrow_array = self.as_arrow().values().as_any().downcast_ref::().unwrap(); let num_channels = mode.num_channels(); @@ -772,7 +787,7 @@ where LogicalArray: AsImageObj, { type Item = Option>; - type IntoIter = ImageBufferIter<'a, T>; + type IntoIter = ImageBufferIter<'a, LogicalArray>; fn into_iter(self) -> Self::IntoIter { ImageBufferIter::new(self) @@ -815,15 +830,10 @@ impl BinaryArray { } } -fn encode_images<'a, T>( - images: &'a LogicalArray, - image_format: ImageFormat, -) -> DaftResult +fn encode_images<'a, Arr>(images: &'a Arr, image_format: ImageFormat) -> DaftResult where - T: DaftImageryType, - LogicalArray: AsImageObj, - &'a LogicalArray: - IntoIterator>, IntoIter = ImageBufferIter<'a, T>>, + Arr: AsImageObj, + &'a Arr: IntoIterator>, IntoIter = ImageBufferIter<'a, Arr>>, { let arrow_array = match image_format { ImageFormat::TIFF => { @@ -911,12 +921,10 @@ where ) } -fn resize_images<'a, T>(images: &'a LogicalArray, w: u32, h: u32) -> Vec> +fn resize_images<'a, Arr>(images: &'a Arr, w: u32, h: u32) -> Vec> where - T: DaftImageryType, - LogicalArray: AsImageObj, - &'a LogicalArray: - IntoIterator>, IntoIter = ImageBufferIter<'a, T>>, + Arr: AsImageObj, + &'a Arr: IntoIterator>, IntoIter = ImageBufferIter<'a, Arr>>, { images .into_iter() @@ -924,15 +932,13 @@ where .collect::>() } -fn crop_images<'a, T>( - images: &'a LogicalArray, +fn crop_images<'a, Arr>( + images: &'a Arr, bboxes: &mut dyn Iterator>, ) -> Vec>> where - T: DaftImageryType, - LogicalArray: AsImageObj, - &'a LogicalArray: - IntoIterator>, IntoIter = ImageBufferIter<'a, T>>, + Arr: AsImageObj, + &'a Arr: IntoIterator>, IntoIter = ImageBufferIter<'a, Arr>>, { images .into_iter() diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 755aab5b0d..cc6af69f1d 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -15,7 +15,7 @@ mod date; mod filter; mod float; pub mod from_arrow; -mod full; +pub mod full; mod get; pub(crate) mod groups; mod hash; diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index 9c52f21d20..262559776f 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -3,7 +3,7 @@ use arrow2; use common_error::{DaftError, DaftResult}; -use super::as_arrow::AsArrow; +use super::{as_arrow::AsArrow, full::FullNull}; impl Utf8Array { pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult { diff --git a/src/daft-core/src/datatypes/logical.rs b/src/daft-core/src/datatypes/logical.rs index 059ee47740..48c85c3bec 100644 --- a/src/daft-core/src/datatypes/logical.rs +++ b/src/daft-core/src/datatypes/logical.rs @@ -1,121 +1,134 @@ use std::{marker::PhantomData, sync::Arc}; -use crate::datatypes::{BooleanArray, DaftLogicalType, DateType, Field}; +use crate::{ + datatypes::{DaftLogicalType, DateType, Field}, + with_match_daft_logical_primitive_types, +}; use common_error::DaftResult; use super::{ - DataArray, DataType, Decimal128Type, DurationType, EmbeddingType, FixedShapeImageType, - FixedShapeTensorType, ImageType, TensorType, TimestampType, + DaftArrayType, DaftDataType, DataArray, DataType, Decimal128Type, DurationType, EmbeddingType, + FixedShapeImageType, FixedShapeTensorType, ImageType, TensorType, TimestampType, }; -pub struct LogicalArray { + +/// A LogicalArray is a wrapper on top of some underlying array, applying the semantic meaning of its +/// field.datatype() to the underlying array. +#[derive(Clone)] +pub struct LogicalArrayImpl { pub field: Arc, - pub physical: DataArray, + pub physical: PhysicalArray, marker_: PhantomData, } -impl Clone for LogicalArray { - fn clone(&self) -> Self { - LogicalArray::new(self.field.clone(), self.physical.clone()) - } -} +impl DaftArrayType for LogicalArrayImpl {} -impl LogicalArray { - pub fn new>>(field: F, physical: DataArray) -> Self { +impl LogicalArrayImpl { + pub fn new>>(field: F, physical: P) -> Self { let field = field.into(); assert!( field.dtype.is_logical(), "Can only construct Logical Arrays on Logical Types, got {}", field.dtype ); - assert_eq!( - physical.data_type(), - &field.dtype.to_physical(), - "Expected {} for Physical Array, got {}", - &field.dtype.to_physical(), - physical.data_type() - ); - - LogicalArray { + // TODO(FixedSizeList): How to do this assert on the physical datatype? + // assert_eq!( + // physical.data_type(), + // &field.dtype.to_physical(), + // "Expected {} for Physical Array, got {}", + // &field.dtype.to_physical(), + // physical.data_type() + // ); + LogicalArrayImpl { physical, field, marker_: PhantomData, } } - pub fn empty(name: &str, dtype: &DataType) -> Self { - let field = Field::new(name, dtype.clone()); - Self::new(field, DataArray::empty(name, &dtype.to_physical())) - } - pub fn name(&self) -> &str { self.field.name.as_ref() } - pub fn rename(&self, name: &str) -> Self { - let new_field = self.field.rename(name); - let new_array = self.physical.rename(name); - Self::new(new_field, new_array) - } - pub fn field(&self) -> &Field { &self.field } - pub fn logical_type(&self) -> &DataType { + pub fn data_type(&self) -> &DataType { &self.field.dtype } +} - pub fn physical_type(&self) -> &DataType { - self.physical.data_type() - } - - pub fn len(&self) -> usize { - self.physical.len() - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn size_bytes(&self) -> DaftResult { - self.physical.size_bytes() - } - - pub fn slice(&self, start: usize, end: usize) -> DaftResult { - let new_array = self.physical.slice(start, end)?; - Ok(Self::new(self.field.clone(), new_array)) - } +macro_rules! impl_logical_type { + ($physical_array_type:ident) => { + pub fn len(&self) -> usize { + self.physical.len() + } - pub fn head(&self, num: usize) -> DaftResult { - self.slice(0, num) - } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } - pub fn concat(arrays: &[&Self]) -> DaftResult { - if arrays.is_empty() { - return Err(common_error::DaftError::ValueError( - "Need at least 1 logical array to concat".to_string(), - )); + pub fn concat(arrays: &[&Self]) -> DaftResult { + if arrays.is_empty() { + return Err(common_error::DaftError::ValueError( + "Need at least 1 logical array to concat".to_string(), + )); + } + let physicals: Vec<_> = arrays.iter().map(|a| &a.physical).collect(); + let concatd = $physical_array_type::concat(physicals.as_slice())?; + Ok(Self::new(arrays.first().unwrap().field.clone(), concatd)) } - let physicals: Vec<_> = arrays.iter().map(|a| &a.physical).collect(); - let concatd = DataArray::::concat(physicals.as_slice())?; - Ok(Self::new(arrays.first().unwrap().field.clone(), concatd)) - } + }; +} - pub fn filter(&self, mask: &BooleanArray) -> DaftResult { - let new_array = self.physical.filter(mask)?; - Ok(Self::new(self.field.clone(), new_array)) +/// Implementation for a LogicalArray that wraps a DataArray +impl LogicalArrayImpl> { + impl_logical_type!(DataArray); + + pub fn to_arrow(&self) -> Box { + let daft_type = self.data_type(); + let arrow_logical_type = daft_type.to_arrow().unwrap(); + let physical_arrow_array = self.physical.data(); + use crate::datatypes::DataType::*; + match daft_type { + // For wrapped primitive types, switch the datatype label on the arrow2 Array. + Decimal128(..) | Date | Timestamp(..) | Duration(..) => { + with_match_daft_logical_primitive_types!(daft_type, |$P| { + use arrow2::array::Array; + physical_arrow_array + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .to(arrow_logical_type) + .to_boxed() + }) + } + // Otherwise, use arrow cast to make sure the result arrow2 array is of the correct type. + _ => arrow2::compute::cast::cast( + physical_arrow_array, + &arrow_logical_type, + arrow2::compute::cast::CastOptions { + wrapped: true, + partial: false, + }, + ) + .unwrap(), + } } } +pub type LogicalArray = + LogicalArrayImpl::PhysicalType as DaftDataType>::ArrayType>; pub type Decimal128Array = LogicalArray; pub type DateArray = LogicalArray; pub type DurationArray = LogicalArray; -pub type EmbeddingArray = LogicalArray; pub type ImageArray = LogicalArray; -pub type FixedShapeImageArray = LogicalArray; pub type TimestampArray = LogicalArray; pub type TensorArray = LogicalArray; +pub type EmbeddingArray = LogicalArray; pub type FixedShapeTensorArray = LogicalArray; +pub type FixedShapeImageArray = LogicalArray; pub trait DaftImageryType: DaftLogicalType {} diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 98caecf218..d3f53f1d30 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -23,8 +23,14 @@ use num_traits::{Bounded, Float, FromPrimitive, Num, NumCast, ToPrimitive, Zero} pub use time_unit::TimeUnit; pub mod logical; +/// Trait that is implemented by all Array types +pub trait DaftArrayType {} + /// Trait to wrap DataType Enum -pub trait DaftDataType: Sync + Send { +pub trait DaftDataType: Sync + Send + Clone { + // Concrete ArrayType that backs data of this DataType + type ArrayType: DaftArrayType; + // returns Daft DataType Enum fn get_dtype() -> DataType where @@ -36,11 +42,12 @@ pub trait DaftPhysicalType: Send + Sync + DaftDataType {} pub trait DaftArrowBackedType: Send + Sync + DaftPhysicalType + 'static {} pub trait DaftLogicalType: Send + Sync + DaftDataType + 'static { - type PhysicalType: DaftArrowBackedType; + type PhysicalType: DaftPhysicalType; } macro_rules! impl_daft_arrow_datatype { ($ca:ident, $variant:ident) => { + #[derive(Clone)] pub struct $ca {} impl DaftDataType for $ca { @@ -48,6 +55,8 @@ macro_rules! impl_daft_arrow_datatype { fn get_dtype() -> DataType { DataType::$variant } + + type ArrayType = DataArray<$ca>; } impl DaftArrowBackedType for $ca {} @@ -57,6 +66,7 @@ macro_rules! impl_daft_arrow_datatype { macro_rules! impl_daft_non_arrow_datatype { ($ca:ident, $variant:ident) => { + #[derive(Clone)] pub struct $ca {} impl DaftDataType for $ca { @@ -64,6 +74,8 @@ macro_rules! impl_daft_non_arrow_datatype { fn get_dtype() -> DataType { DataType::$variant } + + type ArrayType = DataArray<$ca>; } impl DaftPhysicalType for $ca {} }; @@ -71,6 +83,7 @@ macro_rules! impl_daft_non_arrow_datatype { macro_rules! impl_daft_logical_datatype { ($ca:ident, $variant:ident, $physical_type:ident) => { + #[derive(Clone)] pub struct $ca {} impl DaftDataType for $ca { @@ -78,6 +91,8 @@ macro_rules! impl_daft_logical_datatype { fn get_dtype() -> DataType { DataType::$variant } + + type ArrayType = logical::LogicalArray<$ca>; } impl DaftLogicalType for $ca { diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 2dc1b91f74..ddc7d6fc45 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -5,7 +5,7 @@ use pyo3::{exceptions::PyValueError, prelude::*, pyclass::CompareOp, types::PyLi use crate::{ array::{ops::DaftLogical, pseudo_arrow::PseudoArrowArray, DataArray}, count_mode::CountMode, - datatypes::{DataType, Field, ImageFormat, ImageMode, PythonType, UInt64Type}, + datatypes::{DataType, Field, ImageFormat, ImageMode, PythonType}, ffi, series::{self, IntoSeries, Series}, utils::arrow::{cast_array_for_daft_if_needed, cast_array_from_daft_if_needed}, @@ -155,7 +155,7 @@ impl PySeries { ))); } seed_series = s.series; - seed_array = Some(seed_series.downcast::()?); + seed_array = Some(seed_series.u64()?); } Ok(self.series.hash(seed_array)?.into_series().into()) } diff --git a/src/daft-core/src/series/array_impl/binary_ops.rs b/src/daft-core/src/series/array_impl/binary_ops.rs index 0c3e11db96..1f3cfb5183 100644 --- a/src/daft-core/src/series/array_impl/binary_ops.rs +++ b/src/daft-core/src/series/array_impl/binary_ops.rs @@ -4,7 +4,7 @@ use common_error::DaftResult; use crate::{ array::ops::{DaftCompare, DaftLogical}, - datatypes::{logical::Decimal128Array, BooleanType, Float64Type, Int128Array, Utf8Type}, + datatypes::{logical::Decimal128Array, Int128Array}, series::series_like::SeriesLike, with_match_comparable_daft_types, with_match_numeric_daft_types, DataType, }; @@ -77,7 +77,13 @@ macro_rules! py_numeric_binary_op { Python => Ok(py_binary_op!(lhs, $rhs, $pyop)), output_type if output_type.is_numeric() => { with_match_numeric_daft_types!(output_type, |$T| { - cast_downcast_op_into_series!(lhs, $rhs, output_type, $T, $op) + cast_downcast_op_into_series!( + lhs, + $rhs, + output_type, + <$T as DaftDataType>::ArrayType, + $op + ) }) } _ => binary_op_unimplemented!(lhs, $pyop, $rhs, output_type), @@ -94,9 +100,9 @@ macro_rules! physical_logic_op { match (&lhs.data_type(), &$rhs.data_type()) { #[cfg(feature = "python")] (Python, _) | (_, Python) => py_binary_op_bool!(lhs, $rhs, $pyop) - .downcast::() + .downcast::() .cloned(), - _ => cast_downcast_op!(lhs, $rhs, &Boolean, BooleanType, $op), + _ => cast_downcast_op!(lhs, $rhs, &Boolean, BooleanArray, $op), } } else { unreachable!() @@ -113,10 +119,10 @@ macro_rules! physical_compare_op { match comp_type { #[cfg(feature = "python")] Python => py_binary_op_bool!(lhs, $rhs, $pyop) - .downcast::() + .downcast::() .cloned(), _ => with_match_comparable_daft_types!(comp_type, |$T| { - cast_downcast_op!(lhs, $rhs, &comp_type, $T, $op) + cast_downcast_op!(lhs, $rhs, &comp_type, <$T as DaftDataType>::ArrayType, $op) }), } } else { @@ -133,10 +139,10 @@ pub(crate) trait SeriesBinaryOps: SeriesLike { match &output_type { #[cfg(feature = "python")] Python => Ok(py_binary_op!(lhs, rhs, "add")), - Utf8 => cast_downcast_op_into_series!(lhs, rhs, &Utf8, Utf8Type, add), + Utf8 => cast_downcast_op_into_series!(lhs, rhs, &Utf8, Utf8Array, add), output_type if output_type.is_numeric() => { with_match_numeric_daft_types!(output_type, |$T| { - cast_downcast_op_into_series!(lhs, rhs, output_type, $T, add) + cast_downcast_op_into_series!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, add) }) } _ => binary_op_unimplemented!(lhs, "+", rhs, output_type), @@ -155,7 +161,7 @@ pub(crate) trait SeriesBinaryOps: SeriesLike { match &output_type { #[cfg(feature = "python")] Python => Ok(py_binary_op!(lhs, rhs, "truediv")), - Float64 => cast_downcast_op_into_series!(lhs, rhs, &Float64, Float64Type, div), + Float64 => cast_downcast_op_into_series!(lhs, rhs, &Float64, Float64Array, div), _ => binary_op_unimplemented!(lhs, "/", rhs, output_type), } } diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index a8706d3d5e..b0c26d9ea2 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -244,7 +244,10 @@ macro_rules! impl_series_like_for_data_array { } fn take(&self, idx: &Series) -> DaftResult { with_match_integer_daft_types!(idx.data_type(), |$S| { - Ok(self.0.take(idx.downcast::<$S>()?)?.into_series()) + Ok(self + .0 + .take(idx.downcast::<<$S as DaftDataType>::ArrayType>()?)? + .into_series()) }) } diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index b851cd2e5b..d3eabe3dac 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -1,62 +1,38 @@ use crate::datatypes::logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, + FixedShapeTensorArray, ImageArray, LogicalArray, TensorArray, TimestampArray, }; -use crate::datatypes::BooleanArray; +use crate::datatypes::{BooleanArray, DaftLogicalType, Field}; use super::{ArrayWrapper, IntoSeries, Series}; use crate::array::ops::GroupIndices; use crate::series::array_impl::binary_ops::SeriesBinaryOps; use crate::series::DaftResult; use crate::series::SeriesLike; -use crate::with_match_daft_logical_primitive_types; use crate::with_match_integer_daft_types; +use crate::DataType; use std::sync::Arc; -macro_rules! impl_series_like_for_logical_array { - ($da:ident) => { - impl IntoSeries for $da { - fn into_series(self) -> Series { - Series { - inner: Arc::new(ArrayWrapper(self)), - } - } +impl IntoSeries for LogicalArray +where + L: DaftLogicalType, + ArrayWrapper>: SeriesLike, +{ + fn into_series(self) -> Series { + Series { + inner: Arc::new(ArrayWrapper(self)), } + } +} +macro_rules! impl_series_like_for_logical_array { + ($da:ident) => { impl SeriesLike for ArrayWrapper<$da> { fn into_series(&self) -> Series { self.0.clone().into_series() } fn to_arrow(&self) -> Box { - let daft_type = self.0.logical_type(); - let arrow_logical_type = daft_type.to_arrow().unwrap(); - let physical_arrow_array = self.0.physical.data(); - use crate::datatypes::DataType::*; - match daft_type { - // For wrapped primitive types, switch the datatype label on the arrow2 Array. - Decimal128(..) | Date | Timestamp(..) | Duration(..) => { - with_match_daft_logical_primitive_types!(daft_type, |$P| { - use arrow2::array::Array; - physical_arrow_array - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - .to(arrow_logical_type) - .to_boxed() - }) - } - // Otherwise, use arrow cast to make sure the result arrow2 array is of the correct type. - _ => arrow2::compute::cast::cast( - physical_arrow_array, - &arrow_logical_type, - arrow2::compute::cast::CastOptions { - wrapped: true, - partial: false, - }, - ) - .unwrap(), - } + self.0.to_arrow() } fn as_any(&self) -> &dyn std::any::Any { @@ -69,30 +45,31 @@ macro_rules! impl_series_like_for_logical_array { Ok($da::new(self.0.field.clone(), data_array).into_series()) } - fn cast(&self, datatype: &crate::datatypes::DataType) -> DaftResult { + fn cast(&self, datatype: &DataType) -> DaftResult { self.0.cast(datatype) } - fn data_type(&self) -> &crate::datatypes::DataType { - self.0.logical_type() + fn data_type(&self) -> &DataType { + self.0.data_type() } - fn field(&self) -> &crate::datatypes::Field { + fn field(&self) -> &Field { self.0.field() } fn filter(&self, mask: &crate::datatypes::BooleanArray) -> DaftResult { - Ok(self.0.filter(mask)?.into_series()) + let new_array = self.0.physical.filter(mask)?; + Ok($da::new(self.0.field.clone(), new_array).into_series()) } fn head(&self, num: usize) -> DaftResult { - Ok(self.0.head(num)?.into_series()) + self.slice(0, num) } fn if_else(&self, other: &Series, predicate: &Series) -> DaftResult { Ok(self .0 - .if_else(other.downcast_logical()?, predicate.downcast()?)? + .if_else(other.downcast()?, predicate.downcast()?)? .into_series()) } @@ -103,23 +80,26 @@ macro_rules! impl_series_like_for_logical_array { } fn len(&self) -> usize { - self.0.len() + self.0.physical.len() } fn size_bytes(&self) -> DaftResult { - self.0.size_bytes() + self.0.physical.size_bytes() } fn name(&self) -> &str { - self.0.name() + self.0.field.name.as_str() } fn rename(&self, name: &str) -> Series { - self.0.rename(name).into_series() + let new_array = self.0.physical.rename(name); + let new_field = self.0.field.rename(name); + $da::new(new_field, new_array).into_series() } fn slice(&self, start: usize, end: usize) -> DaftResult { - Ok(self.0.slice(start, end)?.into_series()) + let new_array = self.0.physical.slice(start, end)?; + Ok($da::new(self.0.field.clone(), new_array).into_series()) } fn sort(&self, descending: bool) -> DaftResult { @@ -136,7 +116,10 @@ macro_rules! impl_series_like_for_logical_array { fn take(&self, idx: &Series) -> DaftResult { with_match_integer_daft_types!(idx.data_type(), |$S| { - Ok(self.0.take(idx.downcast::<$S>()?)?.into_series()) + Ok(self + .0 + .take(idx.downcast::<<$S as DaftDataType>::ArrayType>()?)? + .into_series()) }) } @@ -220,9 +203,9 @@ macro_rules! impl_series_like_for_logical_array { impl_series_like_for_logical_array!(Decimal128Array); impl_series_like_for_logical_array!(DateArray); impl_series_like_for_logical_array!(DurationArray); -impl_series_like_for_logical_array!(EmbeddingArray); impl_series_like_for_logical_array!(ImageArray); -impl_series_like_for_logical_array!(FixedShapeImageArray); impl_series_like_for_logical_array!(TimestampArray); impl_series_like_for_logical_array!(TensorArray); +impl_series_like_for_logical_array!(EmbeddingArray); +impl_series_like_for_logical_array!(FixedShapeImageArray); impl_series_like_for_logical_array!(FixedShapeTensorArray); diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index 8619cac03b..00888f671d 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -13,7 +13,7 @@ use common_error::DaftResult; pub use array_impl::IntoSeries; -use self::series_like::SeriesLike; +pub(crate) use self::series_like::SeriesLike; #[derive(Clone)] pub struct Series { diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 65b7d1bc61..37afde8304 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -11,8 +11,8 @@ impl Series { let s = self.as_physical()?; with_match_physical_daft_types!(s.data_type(), |$T| { match groups { - Some(groups) => Ok(DaftCountAggable::grouped_count(&s.downcast::<$T>()?, groups, mode)?.into_series()), - None => Ok(DaftCountAggable::count(&s.downcast::<$T>()?, mode)?.into_series()) + Some(groups) => Ok(DaftCountAggable::grouped_count(&s.downcast::<<$T as DaftDataType>::ArrayType>()?, groups, mode)?.into_series()), + None => Ok(DaftCountAggable::count(&s.downcast::<<$T as DaftDataType>::ArrayType>()?, mode)?.into_series()) } }) } @@ -45,19 +45,19 @@ impl Series { // floatX -> floatX (in line with numpy) Float32 => match groups { Some(groups) => Ok(DaftSumAggable::grouped_sum( - &self.downcast::()?, + &self.downcast::()?, groups, )? .into_series()), - None => Ok(DaftSumAggable::sum(&self.downcast::()?)?.into_series()), + None => Ok(DaftSumAggable::sum(&self.downcast::()?)?.into_series()), }, Float64 => match groups { Some(groups) => Ok(DaftSumAggable::grouped_sum( - &self.downcast::()?, + &self.downcast::()?, groups, )? .into_series()), - None => Ok(DaftSumAggable::sum(&self.downcast::()?)?.into_series()), + None => Ok(DaftSumAggable::sum(&self.downcast::()?)?.into_series()), }, other => Err(DaftError::TypeError(format!( "Numeric sum is not implemented for type {}", @@ -104,7 +104,7 @@ impl Series { use crate::array::ops::DaftConcatAggable; match self.data_type() { DataType::List(..) => { - let downcasted = self.downcast::()?; + let downcasted = self.downcast::()?; match groups { Some(groups) => { Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series()) @@ -114,7 +114,7 @@ impl Series { } #[cfg(feature = "python")] DataType::Python => { - let downcasted = self.downcast::()?; + let downcasted = self.downcast::()?; match groups { Some(groups) => { Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series()) diff --git a/src/daft-core/src/series/ops/arithmetic.rs b/src/daft-core/src/series/ops/arithmetic.rs index 7417f49403..4f0c5f21ef 100644 --- a/src/daft-core/src/series/ops/arithmetic.rs +++ b/src/daft-core/src/series/ops/arithmetic.rs @@ -30,6 +30,7 @@ impl_arithmetic_for_series!(Rem, rem); #[cfg(test)] mod tests { + use crate::array::ops::full::FullNull; use crate::datatypes::{DataType, Float64Array, Int64Array, Utf8Array}; use crate::series::IntoSeries; use common_error::DaftResult; diff --git a/src/daft-core/src/series/ops/broadcast.rs b/src/daft-core/src/series/ops/broadcast.rs index 87cc2e6cd4..7c3c0fbe90 100644 --- a/src/daft-core/src/series/ops/broadcast.rs +++ b/src/daft-core/src/series/ops/broadcast.rs @@ -9,6 +9,7 @@ impl Series { #[cfg(test)] mod tests { + use crate::array::ops::full::FullNull; use crate::datatypes::{DataType, Int64Array, Utf8Array}; use crate::series::array_impl::IntoSeries; use common_error::DaftResult; diff --git a/src/daft-core/src/series/ops/concat.rs b/src/daft-core/src/series/ops/concat.rs index 6faee693ba..5812342bdf 100644 --- a/src/daft-core/src/series/ops/concat.rs +++ b/src/daft-core/src/series/ops/concat.rs @@ -29,13 +29,13 @@ impl Series { } if first_dtype.is_logical() { return Ok(with_match_daft_logical_types!(first_dtype, |$T| { - let downcasted = series.into_iter().map(|s| s.downcast_logical::<$T>()).collect::>>()?; + let downcasted = series.into_iter().map(|s| s.downcast::<<$T as DaftDataType>::ArrayType>()).collect::>>()?; LogicalArray::<$T>::concat(downcasted.as_slice())?.into_series() })); } with_match_physical_daft_types!(first_dtype, |$T| { - let downcasted = series.into_iter().map(|s| s.downcast::<$T>()).collect::>>()?; + let downcasted = series.into_iter().map(|s| s.downcast::<<$T as DaftDataType>::ArrayType>()).collect::>>()?; Ok(DataArray::<$T>::concat(downcasted.as_slice())?.into_series()) }) } diff --git a/src/daft-core/src/series/ops/date.rs b/src/daft-core/src/series/ops/date.rs index 267303cf5b..c3147bc81c 100644 --- a/src/daft-core/src/series/ops/date.rs +++ b/src/daft-core/src/series/ops/date.rs @@ -1,6 +1,6 @@ use crate::series::array_impl::IntoSeries; use crate::{ - datatypes::{DataType, DateType}, + datatypes::{logical::DateArray, DataType}, series::Series, }; use common_error::{DaftError, DaftResult}; @@ -14,7 +14,7 @@ impl Series { ))); } - let downcasted = self.downcast_logical::()?; + let downcasted = self.downcast::()?; Ok(downcasted.day()?.into_series()) } @@ -26,7 +26,7 @@ impl Series { ))); } - let downcasted = self.downcast_logical::()?; + let downcasted = self.downcast::()?; Ok(downcasted.month()?.into_series()) } @@ -38,7 +38,7 @@ impl Series { ))); } - let downcasted = self.downcast_logical::()?; + let downcasted = self.downcast::()?; Ok(downcasted.year()?.into_series()) } @@ -50,7 +50,7 @@ impl Series { ))); } - let downcasted = self.downcast_logical::()?; + let downcasted = self.downcast::()?; Ok(downcasted.day_of_week()?.into_series()) } } diff --git a/src/daft-core/src/series/ops/downcast.rs b/src/daft-core/src/series/ops/downcast.rs index a4e500af26..6b4b5a7e88 100644 --- a/src/daft-core/src/series/ops/downcast.rs +++ b/src/daft-core/src/series/ops/downcast.rs @@ -1,33 +1,24 @@ +use std::marker::PhantomData; + use crate::datatypes::*; -use crate::datatypes::logical::{FixedShapeImageArray, ImageArray, LogicalArray}; +use crate::datatypes::logical::{FixedShapeImageArray, ImageArray}; use crate::series::array_impl::ArrayWrapper; use crate::series::Series; use common_error::DaftResult; impl Series { - pub fn downcast(&self) -> DaftResult<&DataArray> - where - T: DaftPhysicalType + 'static, - { - match self.inner.as_any().downcast_ref() { - Some(ArrayWrapper(arr)) => Ok(arr), - None => panic!( - "Attempting to downcast {:?} to {:?}", - self.data_type(), - T::get_dtype() - ), //Err(DaftError::SchemaMismatch(format!( - } - } - - pub fn downcast_logical(&self) -> DaftResult<&LogicalArray> { + pub fn downcast(&self) -> DaftResult<&Arr> { match self.inner.as_any().downcast_ref() { Some(ArrayWrapper(arr)) => Ok(arr), - None => panic!( - "Attempting to downcast {:?} to {:?}", - self.data_type(), - L::get_dtype() - ), //Err(DaftError::SchemaMismatch(format!( + None => { + let phantom: PhantomData = PhantomData {}; + panic!( + "Attempting to downcast {:?} to {:?}", + self.data_type(), + phantom + ) + } } } @@ -128,11 +119,11 @@ impl Series { } pub fn image(&self) -> DaftResult<&ImageArray> { - self.downcast_logical() + self.downcast() } pub fn fixed_size_image(&self) -> DaftResult<&FixedShapeImageArray> { - self.downcast_logical() + self.downcast() } #[cfg(feature = "python")] diff --git a/src/daft-core/src/series/ops/float.rs b/src/daft-core/src/series/ops/float.rs index c500c6f3d7..8a8795c277 100644 --- a/src/daft-core/src/series/ops/float.rs +++ b/src/daft-core/src/series/ops/float.rs @@ -8,7 +8,7 @@ impl Series { pub fn is_nan(&self) -> DaftResult { use crate::array::ops::DaftIsNan; with_match_float_and_null_daft_types!(self.data_type(), |$T| { - Ok(DaftIsNan::is_nan(self.downcast::<$T>()?)?.into_series()) + Ok(DaftIsNan::is_nan(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series()) }) } } diff --git a/src/daft-core/src/series/ops/groups.rs b/src/daft-core/src/series/ops/groups.rs index b19876db2e..d15b411f3e 100644 --- a/src/daft-core/src/series/ops/groups.rs +++ b/src/daft-core/src/series/ops/groups.rs @@ -9,7 +9,7 @@ impl IntoGroups for Series { fn make_groups(&self) -> DaftResult { let s = self.as_physical()?; with_match_comparable_daft_types!(s.data_type(), |$T| { - let array = s.downcast::<$T>()?; + let array = s.downcast::<<$T as DaftDataType>::ArrayType>()?; array.make_groups() }) } diff --git a/src/daft-core/src/series/ops/hash.rs b/src/daft-core/src/series/ops/hash.rs index c3bb53a6f1..37aa74296a 100644 --- a/src/daft-core/src/series/ops/hash.rs +++ b/src/daft-core/src/series/ops/hash.rs @@ -5,7 +5,7 @@ impl Series { pub fn hash(&self, seed: Option<&UInt64Array>) -> DaftResult { let s = self.as_physical()?; with_match_comparable_daft_types!(s.data_type(), |$T| { - let downcasted = s.downcast::<$T>()?; + let downcasted = s.downcast::<<$T as DaftDataType>::ArrayType>()?; downcasted.hash(seed) }) } diff --git a/src/daft-core/src/series/ops/image.rs b/src/daft-core/src/series/ops/image.rs index e5644d19cf..8f8791a93e 100644 --- a/src/daft-core/src/series/ops/image.rs +++ b/src/daft-core/src/series/ops/image.rs @@ -1,4 +1,5 @@ -use crate::datatypes::{DataType, Field, FixedShapeImageType, ImageFormat, ImageType}; +use crate::datatypes::logical::{FixedShapeImageArray, ImageArray}; +use crate::datatypes::{DataType, Field, ImageFormat}; use crate::series::{IntoSeries, Series}; use common_error::{DaftError, DaftResult}; @@ -16,11 +17,11 @@ impl Series { pub fn image_encode(&self, image_format: ImageFormat) -> DaftResult { match self.data_type() { DataType::Image(..) => Ok(self - .downcast_logical::()? + .downcast::()? .encode(image_format)? .into_series()), DataType::FixedShapeImage(..) => Ok(self - .downcast_logical::()? + .downcast::()? .encode(image_format)? .into_series()), dtype => Err(DaftError::ValueError(format!( @@ -33,7 +34,7 @@ impl Series { pub fn image_resize(&self, w: u32, h: u32) -> DaftResult { match self.data_type() { DataType::Image(mode) => { - let array = self.downcast_logical::()?; + let array = self.downcast::()?; match mode { // If the image mode is specified at the type-level (and is therefore guaranteed to be consistent // across all images across all partitions), store the resized image in a fixed shape image array, @@ -45,7 +46,7 @@ impl Series { } } DataType::FixedShapeImage(..) => Ok(self - .downcast_logical::()? + .downcast::()? .resize(w, h)? .into_series()), _ => Err(DaftError::ValueError(format!( diff --git a/src/daft-core/src/series/ops/not.rs b/src/daft-core/src/series/ops/not.rs index fced540140..41638c5246 100644 --- a/src/daft-core/src/series/ops/not.rs +++ b/src/daft-core/src/series/ops/not.rs @@ -1,6 +1,6 @@ use std::ops::Not; -use crate::datatypes::BooleanType; +use crate::datatypes::BooleanArray; use crate::series::array_impl::IntoSeries; use crate::series::Series; use common_error::DaftResult; @@ -8,7 +8,7 @@ use common_error::DaftResult; impl Not for &Series { type Output = DaftResult; fn not(self) -> Self::Output { - let array = self.downcast::()?; + let array = self.downcast::()?; Ok((!array)?.into_series()) } } diff --git a/src/daft-core/src/series/ops/search_sorted.rs b/src/daft-core/src/series/ops/search_sorted.rs index 6a084ca1fc..a9773e503e 100644 --- a/src/daft-core/src/series/ops/search_sorted.rs +++ b/src/daft-core/src/series/ops/search_sorted.rs @@ -12,8 +12,8 @@ impl Series { let rhs = rhs.as_physical()?; with_match_comparable_daft_types!(lhs.data_type(), |$T| { - let lhs = lhs.downcast::<$T>().unwrap(); - let rhs = rhs.downcast::<$T>().unwrap(); + let lhs = lhs.downcast::<<$T as DaftDataType>::ArrayType>().unwrap(); + let rhs = rhs.downcast::<<$T as DaftDataType>::ArrayType>().unwrap(); lhs.search_sorted(rhs, descending) }) } diff --git a/src/daft-core/src/series/ops/sort.rs b/src/daft-core/src/series/ops/sort.rs index 9aea2b53b6..da66c40e0b 100644 --- a/src/daft-core/src/series/ops/sort.rs +++ b/src/daft-core/src/series/ops/sort.rs @@ -9,7 +9,7 @@ impl Series { pub fn argsort(&self, descending: bool) -> DaftResult { let series = self.as_physical()?; with_match_comparable_daft_types!(series.data_type(), |$T| { - let downcasted = series.downcast::<$T>()?; + let downcasted = series.downcast::<<$T as DaftDataType>::ArrayType>()?; Ok(downcasted.argsort::(descending)?.into_series()) }) } @@ -32,7 +32,7 @@ impl Series { let first = sort_keys.first().unwrap().as_physical()?; with_match_comparable_daft_types!(first.data_type(), |$T| { - let downcasted = first.downcast::<$T>()?; + let downcasted = first.downcast::<<$T as DaftDataType>::ArrayType>()?; let result = downcasted.argsort_multikey::(&sort_keys[1..], descending)?; Ok(result.into_series()) }) diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 61e17217a3..d6db43e364 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -4,8 +4,8 @@ use std::{ }; use crate::expr::Expr; -use daft_core::datatypes::DataType; use daft_core::series::Series; +use daft_core::{array::ops::full::FullNull, datatypes::DataType}; use serde::{Deserialize, Serialize}; #[cfg(feature = "python")] diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index be9c8adba3..752db9e070 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -3,13 +3,14 @@ use std::collections::HashSet; use std::fmt::{Display, Formatter, Result}; +use daft_core::array::ops::full::FullNull; use num_traits::ToPrimitive; use daft_core::array::ops::GroupIndices; use common_error::{DaftError, DaftResult}; use daft_core::datatypes::logical::LogicalArray; -use daft_core::datatypes::{BooleanType, DataType, Field, UInt64Array}; +use daft_core::datatypes::{BooleanArray, DataType, Field, UInt64Array}; use daft_core::schema::{Schema, SchemaRef}; use daft_core::series::{IntoSeries, Series}; use daft_core::{with_match_daft_logical_types, with_match_physical_daft_types}; @@ -202,7 +203,7 @@ impl Table { ))); } - let mask = mask.downcast::().unwrap(); + let mask = mask.downcast::().unwrap(); let new_series: DaftResult> = self.columns.iter().map(|s| s.filter(mask)).collect(); Ok(Table { schema: self.schema.clone(), diff --git a/src/daft-table/src/ops/groups.rs b/src/daft-table/src/ops/groups.rs index bc728246b8..76e6c04c33 100644 --- a/src/daft-table/src/ops/groups.rs +++ b/src/daft-table/src/ops/groups.rs @@ -4,7 +4,7 @@ use daft_core::{ arrow2::comparison::build_multi_array_is_equal, as_arrow::AsArrow, GroupIndicesPair, IntoGroups, }, - datatypes::{UInt64Array, UInt64Type}, + datatypes::UInt64Array, series::Series, }; @@ -56,7 +56,7 @@ impl Table { // Begin by doing the argsort. let argsort_series = Series::argsort_multikey(self.columns.as_slice(), &vec![false; self.columns.len()])?; - let argsort_array = argsort_series.downcast::()?; + let argsort_array = argsort_series.downcast::()?; // The result indices. let mut key_indices: Vec = vec![]; diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs index b46ed549be..68c3b7dfaf 100644 --- a/src/daft-table/src/ops/joins/hash_join.rs +++ b/src/daft-table/src/ops/joins/hash_join.rs @@ -1,5 +1,5 @@ use daft_core::{ - array::ops::arrow2::comparison::build_multi_array_is_equal, + array::ops::{arrow2::comparison::build_multi_array_is_equal, full::FullNull}, datatypes::{DataType, UInt64Array}, series::{IntoSeries, Series}, };