diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 8d9d03903b..8689c1e744 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -1,6 +1,6 @@ use super::as_arrow::AsArrow; use crate::{ - array::{ops::image::ImageArraySidecarData, DataArray}, + array::{ops::from_arrow::FromArrow, ops::image::ImageArraySidecarData, DataArray}, datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, @@ -132,12 +132,11 @@ where if dtype.is_logical() { with_match_daft_logical_types!(dtype, |$T| { - let physical = DataArray::try_from((Field::new(to_cast.name(), dtype.to_physical()), result_arrow_physical_array))?; - return Ok(LogicalArray::<$T>::new(new_field.clone(), physical).into_series()); + return Ok(LogicalArray::<$T>::from_arrow(new_field.as_ref(), result_arrow_physical_array)?.into_series()) }) } with_match_arrow_daft_types!(dtype, |$T| { - Ok(DataArray::<$T>::try_from((new_field.clone(), result_arrow_physical_array))?.into_series()) + Ok(DataArray::<$T>::from_arrow(new_field.as_ref(), result_arrow_physical_array)?.into_series()) }) } @@ -224,12 +223,11 @@ where if dtype.is_logical() { with_match_daft_logical_types!(dtype, |$T| { - let physical = DataArray::try_from((Field::new(to_cast.name(), target_physical_type), result_array))?; - return Ok(LogicalArray::<$T>::new(new_field.clone(), physical).into_series()); + return Ok(LogicalArray::<$T>::from_arrow(new_field.as_ref(), result_array)?.into_series()); }) } with_match_arrow_daft_types!(dtype, |$T| { - Ok(DataArray::<$T>::try_from((new_field.clone(), result_array))?.into_series()) + return Ok(DataArray::<$T>::from_arrow(new_field.as_ref(), result_array)?.into_series()); }) } diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs new file mode 100644 index 0000000000..12e2eaf694 --- /dev/null +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -0,0 +1,28 @@ +use common_error::DaftResult; + +use crate::{ + array::DataArray, + datatypes::{logical::LogicalArray, DaftLogicalType, DaftPhysicalType, Field}, +}; + +/// Arrays that implement [`FromArrow`] can be instantiated from a Box +pub trait FromArrow +where + Self: Sized, +{ + fn from_arrow(field: &Field, arrow_arr: Box) -> DaftResult; +} + +impl FromArrow for DataArray { + fn from_arrow(field: &Field, arrow_arr: Box) -> DaftResult { + DataArray::::try_from((field.clone(), arrow_arr)) + } +} + +impl FromArrow for LogicalArray { + fn from_arrow(field: &Field, arrow_arr: Box) -> DaftResult { + let data_array_field = Field::new(field.name.clone(), field.dtype.to_physical()); + let physical = DataArray::try_from((data_array_field, arrow_arr))?; + Ok(LogicalArray::::new(field.clone(), physical)) + } +} diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 61d68f02db..755aab5b0d 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -14,6 +14,7 @@ mod count; mod date; mod filter; mod float; +pub mod from_arrow; mod full; mod get; pub(crate) mod groups; diff --git a/src/daft-core/src/series/from.rs b/src/daft-core/src/series/from.rs index a81bd09df7..602d4a10c8 100644 --- a/src/daft-core/src/series/from.rs +++ b/src/daft-core/src/series/from.rs @@ -9,6 +9,7 @@ use common_error::{DaftError, DaftResult}; use super::Series; +use crate::array::ops::from_arrow::FromArrow; use crate::series::array_impl::IntoSeries; impl TryFrom<(&str, Box)> for Series { @@ -52,8 +53,7 @@ impl TryFrom<(&str, Box)> for Series { }; let res = with_match_daft_logical_types!(dtype, |$T| { - let physical = DataArray::try_from((Field::new(name, physical_type), physical_arrow_array))?; - LogicalArray::<$T>::new(field, physical).into_series() + LogicalArray::<$T>::from_arrow(field.as_ref(), physical_arrow_array)?.into_series() }); return Ok(res); } @@ -70,12 +70,12 @@ impl TryFrom<(&str, Box)> for Series { }, )?; return Ok( - with_match_physical_daft_types!(physical_type, |$T| DataArray::<$T>::new(field, casted_array)?.into_series()), + with_match_physical_daft_types!(physical_type, |$T| DataArray::<$T>::from_arrow(field.as_ref(), casted_array)?.into_series()), ); } Ok( - with_match_physical_daft_types!(dtype, |$T| DataArray::<$T>::new(field, array.into())?.into_series()), + with_match_physical_daft_types!(dtype, |$T| DataArray::<$T>::from_arrow(field.as_ref(), array.into())?.into_series()), ) } }