Skip to content

Commit

Permalink
[CHORE] Refactor Series downcast and LogicalArrayImpl (#1289)
Browse files Browse the repository at this point in the history
# Summary

1. Refactors `Series::downcast` to use the concrete array type instead
of a DaftDataType
2. Refactors `LogicalArray<L>` so that `LogicalArray<L> = LogicalArrayImpl<L,
L::PhysicalType::ArrayType>`. The underlying `LogicalArrayImpl` is now
generic over both its LogicalType and its underlying container
type.

Necessary drive-bys:

1. Refactors `image.rs`, which used to make a lot of assumptions about
LogicalArrays holding DataArrays
2. Refactors `::full_null()` and `::empty()`: we now need an implementation
that is generic over all LogicalArray container types, which requires
introducing a `FullNull` trait.

This paves the way for adding more LogicalArray container types that
aren't `DataArray` (Coming soon: FixedSizeListArray, StructArray and
more)

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia committed Aug 23, 2023
1 parent 2bff0cb commit 671e340
Show file tree
Hide file tree
Showing 37 changed files with 335 additions and 270 deletions.
4 changes: 3 additions & 1 deletion src/daft-core/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod pseudo_arrow;

use std::{marker::PhantomData, sync::Arc};

use crate::datatypes::{DaftPhysicalType, DataType, Field};
use crate::datatypes::{DaftArrayType, DaftPhysicalType, DataType, Field};

use common_error::{DaftError, DaftResult};

Expand All @@ -22,6 +22,8 @@ impl<T: DaftPhysicalType> Clone for DataArray<T> {
}
}

// Marks `DataArray<T>` as a `DaftArrayType`. The impl body is empty, so no methods are added here.
// NOTE(review): presumably this is what lets `DataArray<T>` be named as a concrete array type
// in calls such as `downcast::<DataArray<T>>()` — confirm against the trait definition.
impl<T: DaftPhysicalType> DaftArrayType for DataArray<T> {}

impl<T> DataArray<T>
where
T: DaftPhysicalType,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/arange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
let arrow_array = Box::new(arrow2::array::PrimitiveArray::<i64>::from_vec(data));
let data_array = Int64Array::from((name.as_ref(), arrow_array));
let casted_array = data_array.cast(&T::get_dtype())?;
let downcasted = casted_array.downcast::<T>()?;
let downcasted = casted_array.downcast::<DataArray<T>>()?;
Ok(downcasted.clone())
}
}
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{

use common_error::{DaftError, DaftResult};

use super::as_arrow::AsArrow;
use super::{as_arrow::AsArrow, full::FullNull};
/// Helper function to perform arithmetic operations on a DataArray
/// Takes both Kernel (array x array operation) and operation (scalar x scalar) functions
/// The Kernel is used for when both arrays are non-unit length and the operation is used when broadcasting
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{

use common_error::{DaftError, DaftResult};

use super::as_arrow::AsArrow;
use super::{as_arrow::AsArrow, full::FullNull};

#[cfg(feature = "python")]
use crate::array::pseudo_arrow::PseudoArrowArray;
Expand Down
46 changes: 28 additions & 18 deletions src/daft-core/src/array/ops/cast.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use super::as_arrow::AsArrow;
use crate::{
array::{ops::from_arrow::FromArrow, ops::image::ImageArraySidecarData, DataArray},
array::{
ops::image::ImageArraySidecarData,
ops::{from_arrow::FromArrow, full::FullNull},
DataArray,
},
datatypes::{
logical::{
DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray,
FixedShapeTensorArray, ImageArray, LogicalArray, TensorArray, TimestampArray,
FixedShapeTensorArray, ImageArray, LogicalArray, LogicalArrayImpl, TensorArray,
TimestampArray,
},
DaftArrowBackedType, DaftLogicalType, DataType, Field, FixedShapeTensorType,
FixedSizeListArray, ImageMode, StructArray, TensorType, TimeUnit, Utf8Array,
DaftArrowBackedType, DaftLogicalType, DataType, Field, FixedSizeListArray, ImageMode,
StructArray, TimeUnit, Utf8Array,
},
series::{IntoSeries, Series},
with_match_arrow_daft_types, with_match_daft_logical_primitive_types,
Expand Down Expand Up @@ -39,17 +44,21 @@ use {
std::ops::Deref,
};

fn arrow_logical_cast<T>(to_cast: &LogicalArray<T>, dtype: &DataType) -> DaftResult<Series>
fn arrow_logical_cast<T>(
to_cast: &LogicalArrayImpl<T, DataArray<T::PhysicalType>>,
dtype: &DataType,
) -> DaftResult<Series>
where
T: DaftLogicalType,
T::PhysicalType: DaftArrowBackedType,
{
// Cast from LogicalArray to the target DataType
// using Arrow's casting mechanisms.

// Note that Arrow Logical->Logical direct casts (what this method exposes)
// have different behaviour than Arrow Logical->Physical->Logical casts.

let source_dtype = to_cast.logical_type();
let source_dtype = to_cast.data_type();
let source_arrow_type = source_dtype.to_arrow()?;
let target_arrow_type = dtype.to_arrow()?;

Expand Down Expand Up @@ -360,8 +369,8 @@ impl TimestampArray {
match dtype {
DataType::Timestamp(..) => arrow_logical_cast(self, dtype),
DataType::Utf8 => {
let DataType::Timestamp(unit, timezone) = self.logical_type() else {
panic!("Wrong dtype for TimestampArray: {}", self.logical_type())
let DataType::Timestamp(unit, timezone) = self.data_type() else {
panic!("Wrong dtype for TimestampArray: {}", self.data_type())
};

let str_array: arrow2::array::Utf8Array<i64> = timezone.as_ref().map_or_else(
Expand Down Expand Up @@ -1079,7 +1088,7 @@ impl PythonArray {

impl EmbeddingArray {
pub fn cast(&self, dtype: &DataType) -> DaftResult<Series> {
match (dtype, self.logical_type()) {
match (dtype, self.data_type()) {
#[cfg(feature = "python")]
(DataType::Python, DataType::Embedding(_, size)) => Python::with_gil(|py| {
let shape = (self.len(), *size);
Expand Down Expand Up @@ -1111,7 +1120,7 @@ impl EmbeddingArray {
DataType::FixedShapeTensor(Box::new(inner_dtype.clone().dtype), image_shape);
let fixed_shape_tensor_array = self.cast(&fixed_shape_tensor_dtype)?;
let fixed_shape_tensor_array =
fixed_shape_tensor_array.downcast_logical::<FixedShapeTensorType>()?;
fixed_shape_tensor_array.downcast::<FixedShapeTensorArray>()?;
fixed_shape_tensor_array.cast(dtype)
}
// NOTE(Clark): Casting to FixedShapeTensor is supported by the physical array cast.
Expand Down Expand Up @@ -1174,7 +1183,7 @@ impl ImageArray {
}
})?;
let fixed_shape_tensor_array =
fixed_shape_tensor_array.downcast_logical::<FixedShapeTensorType>()?;
fixed_shape_tensor_array.downcast::<FixedShapeTensorArray>()?;
fixed_shape_tensor_array.cast(dtype)
}
DataType::Tensor(_) => {
Expand Down Expand Up @@ -1230,7 +1239,7 @@ impl ImageArray {
DataType::FixedShapeTensor(inner_dtype, _) => {
let tensor_dtype = DataType::Tensor(inner_dtype.clone());
let tensor_array = self.cast(&tensor_dtype)?;
let tensor_array = tensor_array.downcast_logical::<TensorType>()?;
let tensor_array = tensor_array.downcast::<TensorArray>()?;
tensor_array.cast(dtype)
}
_ => self.physical.cast(dtype),
Expand All @@ -1240,7 +1249,7 @@ impl ImageArray {

impl FixedShapeImageArray {
pub fn cast(&self, dtype: &DataType) -> DaftResult<Series> {
match (dtype, self.logical_type()) {
match (dtype, self.data_type()) {
#[cfg(feature = "python")]
(DataType::Python, DataType::FixedShapeImage(mode, height, width)) => {
pyo3::Python::with_gil(|py| {
Expand Down Expand Up @@ -1299,13 +1308,13 @@ impl FixedShapeImageArray {
}
})?;
let fixed_shape_tensor_array =
fixed_shape_tensor_array.downcast_logical::<FixedShapeTensorType>()?;
fixed_shape_tensor_array.downcast::<FixedShapeTensorArray>()?;
fixed_shape_tensor_array.cast(dtype)
}
(DataType::Image(_), DataType::FixedShapeImage(mode, _, _)) => {
let tensor_dtype = DataType::Tensor(Box::new(mode.get_dtype()));
let tensor_array = self.cast(&tensor_dtype)?;
let tensor_array = tensor_array.downcast_logical::<TensorType>()?;
let tensor_array = tensor_array.downcast::<TensorArray>()?;
tensor_array.cast(dtype)
}
// NOTE(Clark): Casting to FixedShapeTensor is supported by the physical array cast.
Expand Down Expand Up @@ -1505,7 +1514,7 @@ impl TensorArray {
}
})?;
let fixed_shape_tensor_array =
fixed_shape_tensor_array.downcast_logical::<FixedShapeTensorType>()?;
fixed_shape_tensor_array.downcast::<FixedShapeTensorArray>()?;
fixed_shape_tensor_array.cast(dtype)
}
_ => self.physical.cast(dtype),
Expand All @@ -1515,7 +1524,7 @@ impl TensorArray {

impl FixedShapeTensorArray {
pub fn cast(&self, dtype: &DataType) -> DaftResult<Series> {
match (dtype, self.logical_type()) {
match (dtype, self.data_type()) {
#[cfg(feature = "python")]
(DataType::Python, DataType::FixedShapeTensor(_, shape)) => {
pyo3::Python::with_gil(|py| {
Expand Down Expand Up @@ -1616,11 +1625,12 @@ impl FixedShapeTensorArray {
fn cast_logical_to_python_array<T>(array: &LogicalArray<T>, dtype: &DataType) -> DaftResult<Series>
where
T: DaftLogicalType,
T::PhysicalType: DaftArrowBackedType,
LogicalArray<T>: AsArrow,
<LogicalArray<T> as AsArrow>::Output: arrow2::array::Array,
{
Python::with_gil(|py| {
let arrow_dtype = array.logical_type().to_arrow()?;
let arrow_dtype = array.data_type().to_arrow()?;
let arrow_array = array.as_arrow().to_type(arrow_dtype).with_validity(None);
let pyarrow = py.import("pyarrow")?;
let py_array: Vec<PyObject> = ffi::to_py_array(arrow_array.to_boxed(), py, pyarrow)?
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/compare_agg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{DaftCompareAggable, GroupIndices};
use crate::{array::DataArray, datatypes::*};
use crate::{array::ops::full::FullNull, array::DataArray, datatypes::*};
use arrow2::array::PrimitiveArray;
use arrow2::{self, array::Array};

Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use common_error::{DaftError, DaftResult};

use std::ops::Not;

use super::{DaftCompare, DaftLogical};
use super::{full::FullNull, DaftCompare, DaftLogical};

use super::as_arrow::AsArrow;
use arrow2::{compute::comparison, scalar::PrimitiveScalar};
Expand Down
13 changes: 10 additions & 3 deletions src/daft-core/src/array/ops/from_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use common_error::DaftResult;

use crate::{
array::DataArray,
datatypes::{logical::LogicalArray, DaftLogicalType, DaftPhysicalType, Field},
datatypes::{logical::LogicalArray, DaftDataType, DaftLogicalType, DaftPhysicalType, Field},
};

/// Arrays that implement [`FromArrow`] can be instantiated from a Box<dyn arrow2::array::Array>
Expand All @@ -19,10 +19,17 @@ impl<T: DaftPhysicalType> FromArrow for DataArray<T> {
}
}

// Builds a `LogicalArray<L>` from an arrow2 array: the arrow data is first turned into the
// physical array, which is then wrapped together with the original (logical) field.
impl<L: DaftLogicalType> FromArrow for LogicalArray<L> {
impl<L: DaftLogicalType> FromArrow for LogicalArray<L>
where
<L::PhysicalType as DaftDataType>::ArrayType: FromArrow,
{
fn from_arrow(field: &Field, arrow_arr: Box<dyn arrow2::array::Array>) -> DaftResult<Self> {
// The physical array keeps the field's name but carries the physical dtype.
let data_array_field = Field::new(field.name.clone(), field.dtype.to_physical());
let physical = DataArray::try_from((data_array_field, arrow_arr))?;
// Re-type the arrow array to the physical field's arrow type before passing it on.
let physical_arrow_arr = arrow_arr.to_type(data_array_field.dtype.to_arrow()?);
// Delegate construction to the physical container's own `FromArrow` implementation.
let physical = <L::PhysicalType as DaftDataType>::ArrayType::from_arrow(
&data_array_field,
physical_arrow_arr,
)?;
// The logical array is built from a clone of the caller's field and the physical array.
Ok(LogicalArray::<L>::new(field.clone(), physical))
}
}
32 changes: 28 additions & 4 deletions src/daft-core/src/array/ops/full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@ use pyo3::Python;

use crate::{
array::{pseudo_arrow::PseudoArrowArray, DataArray},
datatypes::{DaftPhysicalType, DataType, Field},
datatypes::{
logical::LogicalArray, DaftDataType, DaftLogicalType, DaftPhysicalType, DataType, Field,
},
};

impl<T> DataArray<T>
/// Constructors for all-null and empty arrays.
///
/// Both methods are associated functions (no `self`) that return `Self` directly rather
/// than a `Result`.
pub trait FullNull {
/// Creates an array named `name` with type `dtype`, of size `length`, filled with all nulls.
fn full_null(name: &str, dtype: &DataType, length: usize) -> Self;
/// Creates an array named `name` with type `dtype`.
/// NOTE(review): going by the method name this is a zero-length array — confirm in the impls.
fn empty(name: &str, dtype: &DataType) -> Self;
}

impl<T> FullNull for DataArray<T>
where
T: DaftPhysicalType,
{
/// Creates a DataArray<T> of size `length` that is filled with all nulls.
pub fn full_null(name: &str, dtype: &DataType, length: usize) -> Self {
fn full_null(name: &str, dtype: &DataType, length: usize) -> Self {
let field = Field::new(name, dtype.clone());
#[cfg(feature = "python")]
if dtype.is_python() {
Expand All @@ -37,7 +44,7 @@ where
}
}

pub fn empty(name: &str, dtype: &DataType) -> Self {
fn empty(name: &str, dtype: &DataType) -> Self {
let field = Field::new(name, dtype.clone());
#[cfg(feature = "python")]
if dtype.is_python() {
Expand All @@ -59,3 +66,20 @@ where
}
}
}

/// [`FullNull`] for logical arrays, delegating to the physical container type.
///
/// In both constructors the backing array is built with the *physical* dtype
/// (`dtype.to_physical()`), while the wrapping [`Field`] keeps the logical `dtype`
/// passed in by the caller.
impl<L: DaftLogicalType> FullNull for LogicalArray<L>
where
    <L::PhysicalType as DaftDataType>::ArrayType: FullNull,
{
    /// Creates a logical array named `name` with logical type `dtype`, backed by a
    /// physical array produced by the container's `full_null` for `length` elements.
    fn full_null(name: &str, dtype: &DataType, length: usize) -> Self {
        // The physical container must be given the physical dtype, not the logical one;
        // this mirrors what `empty` does.
        let physical = <L::PhysicalType as DaftDataType>::ArrayType::full_null(
            name,
            &dtype.to_physical(),
            length,
        );
        Self::new(Field::new(name, dtype.clone()), physical)
    }

    /// Creates a logical array named `field_name` with logical type `dtype`, backed by
    /// the physical container's `empty` array.
    fn empty(field_name: &str, dtype: &DataType) -> Self {
        let physical =
            <L::PhysicalType as DaftDataType>::ArrayType::empty(field_name, &dtype.to_physical());
        let field = Field::new(field_name, dtype.clone());
        Self::new(field, physical)
    }
}
1 change: 1 addition & 0 deletions src/daft-core/src/array/ops/if_else.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::array::ops::full::FullNull;
use crate::array::DataArray;
use crate::datatypes::logical::{
DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray,
Expand Down
Loading

0 comments on commit 671e340

Please sign in to comment.