Skip to content

Commit

Permalink
Add basic support for Tensor and FixedShapeTensor types.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Jun 21, 2023
1 parent 64ab8ce commit bbc9eca
Show file tree
Hide file tree
Showing 16 changed files with 558 additions and 37 deletions.
24 changes: 24 additions & 0 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,30 @@ def image(
)
return cls._from_pydatatype(PyDataType.image(mode, height, width))

@classmethod
def tensor(
    cls,
    dtype: DataType,
    shape: tuple[int, ...] | None = None,
) -> DataType:
    """Create a tensor DataType: tensor arrays contain n-dimensional arrays of data of the
    provided ``dtype`` as elements, each of the provided ``shape``.

    If a ``shape`` is given, each ndarray in the column will have this shape.

    If ``shape`` is not given, the ndarrays in the column can have different shapes. This is
    much more flexible, but will result in a less compact representation and may make some
    operations less efficient.

    Args:
        dtype: The type of the data contained within the tensor elements.
        shape: The shape of each tensor in the column. This is ``None`` by default, which
            allows the shapes of each tensor element to vary.

    Raises:
        ValueError: If ``shape`` is given but is not a non-empty tuple of ints.
    """
    if shape is not None:
        if not isinstance(shape, tuple) or not shape or any(not isinstance(n, int) for n in shape):
            # Format the offending value into the message itself: passing it as a second
            # positional argument would make str(exc) render as a tuple of both arguments.
            raise ValueError(f"Tensor shape must be a non-empty tuple of ints, but got: {shape}")
    return cls._from_pydatatype(PyDataType.tensor(dtype._dtype, shape))

@classmethod
def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
"""Maps a PyArrow DataType to a Daft DataType"""
Expand Down
23 changes: 22 additions & 1 deletion src/array/ops/as_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use arrow2::array;
use crate::{
array::DataArray,
datatypes::{
logical::{DateArray, EmbeddingArray, FixedShapeImageArray, ImageArray, TimestampArray},
logical::{
DateArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, ImageArray,
TensorArray, TimestampArray,
},
BinaryArray, BooleanArray, DaftNumericType, FixedSizeListArray, ListArray, StructArray,
Utf8Array,
},
Expand Down Expand Up @@ -142,3 +145,21 @@ impl AsArrow for FixedShapeImageArray {
self.physical.data().as_any().downcast_ref().unwrap()
}
}

impl AsArrow for TensorArray {
    type Output = array::StructArray;

    /// Borrows the physical data of this logical tensor array as an Arrow2 `StructArray`.
    ///
    /// # Panics
    ///
    /// Panics if the physical data cannot be downcast to `StructArray`.
    fn as_arrow(&self) -> &Self::Output {
        let physical_data = self.physical.data();
        physical_data
            .as_any()
            .downcast_ref::<Self::Output>()
            .unwrap()
    }
}

impl AsArrow for FixedShapeTensorArray {
    type Output = array::FixedSizeListArray;

    /// Borrows the physical data of this logical fixed-shape tensor array as an Arrow2
    /// `FixedSizeListArray`.
    ///
    /// # Panics
    ///
    /// Panics if the physical data cannot be downcast to `FixedSizeListArray`.
    fn as_arrow(&self) -> &Self::Output {
        let physical_data = self.physical.data();
        physical_data
            .as_any()
            .downcast_ref::<Self::Output>()
            .unwrap()
    }
}
Loading

0 comments on commit bbc9eca

Please sign in to comment.