Skip to content

Commit

Permalink
[CHORE] Fix List/FixedSizeList DataType to hold a dtype instead of Field (#1351)
Browse files Browse the repository at this point in the history

Closes: #1264 
Closes: #994 

Previously, `DataType::List` and `DataType::FixedSizeList` held a child
`Field` instead of a child `DataType`. This was causing a few problems:

1. When comparing two schemas for equality, we often hit spurious
mismatches where only the name of the subfield differed
2. When creating a new ListArray, it was easy to forget to name the
child array correctly (e.g. when using a growable), which would result
in pretty nasty failures

This PR:
1. Hardcodes the child Series' name to `"list"`
2. Changes List/FixedSizeList DataType to hold a dtype instead of Field,
which essentially removes the name from the type and makes it
"anonymous"

Note that this is consistent with PyArrow's behavior as well: more
recent versions of PyArrow ignore the subfield's name when comparing
list types for equality.

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Sep 7, 2023
1 parent 3f3de3e commit 4d6f716
Show file tree
Hide file tree
Showing 32 changed files with 175 additions and 235 deletions.
16 changes: 8 additions & 8 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,16 @@ def duration(cls, timeunit: TimeUnit) -> DataType:
return cls._from_pydatatype(PyDataType.duration(timeunit._timeunit))

@classmethod
def list(cls, name: str, dtype: DataType) -> DataType:
def list(cls, dtype: DataType) -> DataType:
"""Create a List DataType: Variable-length list, where each element in the list has type ``dtype``
Args:
dtype: DataType of each element in the list
"""
return cls._from_pydatatype(PyDataType.list(name, dtype._dtype))
return cls._from_pydatatype(PyDataType.list(dtype._dtype))

@classmethod
def fixed_size_list(cls, name: str, dtype: DataType, size: int) -> DataType:
def fixed_size_list(cls, dtype: DataType, size: int) -> DataType:
"""Create a FixedSizeList DataType: Fixed-size list, where each element in the list has type ``dtype``
and each list has length ``size``.
Expand All @@ -223,7 +223,7 @@ def fixed_size_list(cls, name: str, dtype: DataType, size: int) -> DataType:
"""
if not isinstance(size, int) or size <= 0:
raise ValueError("The size for a fixed-size list must be a positive integer, but got: ", size)
return cls._from_pydatatype(PyDataType.fixed_size_list(name, dtype._dtype, size))
return cls._from_pydatatype(PyDataType.fixed_size_list(dtype._dtype, size))

@classmethod
def struct(cls, fields: dict[str, DataType]) -> DataType:
Expand All @@ -239,7 +239,7 @@ def extension(cls, name: str, storage_dtype: DataType, metadata: str | None = No
return cls._from_pydatatype(PyDataType.extension(name, storage_dtype._dtype, metadata))

@classmethod
def embedding(cls, name: str, dtype: DataType, size: int) -> DataType:
def embedding(cls, dtype: DataType, size: int) -> DataType:
"""Create an Embedding DataType: embeddings are fixed size arrays, where each element
in the array has a **numeric** ``dtype`` and each array has a fixed length of ``size``.
Expand All @@ -249,7 +249,7 @@ def embedding(cls, name: str, dtype: DataType, size: int) -> DataType:
"""
if not isinstance(size, int) or size <= 0:
raise ValueError("The size for a embedding must be a positive integer, but got: ", size)
return cls._from_pydatatype(PyDataType.embedding(name, dtype._dtype, size))
return cls._from_pydatatype(PyDataType.embedding(dtype._dtype, size))

@classmethod
def image(
Expand Down Expand Up @@ -360,11 +360,11 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
elif pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type):
assert isinstance(arrow_type, (pa.ListType, pa.LargeListType))
field = arrow_type.value_field
return cls.list(field.name, cls.from_arrow_type(field.type))
return cls.list(cls.from_arrow_type(field.type))
elif pa.types.is_fixed_size_list(arrow_type):
assert isinstance(arrow_type, pa.FixedSizeListType)
field = arrow_type.value_field
return cls.fixed_size_list(field.name, cls.from_arrow_type(field.type), arrow_type.list_size)
return cls.fixed_size_list(cls.from_arrow_type(field.type), arrow_type.list_size)
elif pa.types.is_struct(arrow_type):
assert isinstance(arrow_type, pa.StructType)
fields = [arrow_type[i] for i in range(arrow_type.num_fields)]
Expand Down
2 changes: 1 addition & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,6 @@ def crop(self, bbox: tuple[int, int, int, int] | Expression) -> Expression:
raise ValueError(
f"Expected `bbox` to be either a tuple of 4 ints or an Expression but received: {bbox}"
)
bbox = Expression._to_expression(bbox).cast(DataType.fixed_size_list("", DataType.uint64(), 4))
bbox = Expression._to_expression(bbox).cast(DataType.fixed_size_list(DataType.uint64(), 4))
assert isinstance(bbox, Expression)
return Expression._from_pyexpr(self._expr.image_crop(bbox._expr))
2 changes: 1 addition & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def from_arrow(array: pa.Array | pa.ChunkedArray, name: str = "arrow_series") ->
storage_series = Series.from_arrow(array.storage, name=name)
series = storage_series.cast(
DataType.fixed_size_list(
"item", DataType.from_arrow_type(array.type.scalar_type), int(np.prod(array.type.shape))
DataType.from_arrow_type(array.type.scalar_type), int(np.prod(array.type.shape))
)
)
return series.cast(DataType.from_arrow_type(array.type))
Expand Down
17 changes: 7 additions & 10 deletions src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ impl FixedSizeListArray {
) -> Self {
let field: Arc<Field> = field.into();
match &field.as_ref().dtype {
DataType::FixedSizeList(child_field, size) => {
DataType::FixedSizeList(child_dtype, size) => {
if let Some(validity) = validity.as_ref() && (validity.len() * size) != flat_child.len() {
panic!(
"FixedSizeListArray::new received values with len {} but expected it to match len of validity * size: {}",
flat_child.len(),
(validity.len() * size),
)
}
if child_field.as_ref() != flat_child.field() {
if child_dtype.as_ref() != flat_child.data_type() {
panic!(
"FixedSizeListArray::new expects the child series to have field {}, but received: {}",
child_field,
flat_child.field(),
"FixedSizeListArray::new expects the child series to have dtype {}, but received: {}",
child_dtype,
flat_child.data_type(),
)
}
}
Expand Down Expand Up @@ -103,7 +103,7 @@ impl FixedSizeListArray {

pub fn child_data_type(&self) -> &DataType {
match &self.field.dtype {
DataType::FixedSizeList(child, _) => &child.dtype,
DataType::FixedSizeList(child, _) => child.as_ref(),
_ => unreachable!("FixedSizeListArray must have DataType::FixedSizeList(..)"),
}
}
Expand Down Expand Up @@ -163,10 +163,7 @@ mod tests {

/// Helper that returns a FixedSizeListArray, with each list element at len=3
fn get_i32_fixed_size_list_array(validity: &[bool]) -> FixedSizeListArray {
let field = Field::new(
"foo",
DataType::FixedSizeList(Box::new(Field::new("foo", DataType::Int32)), 3),
);
let field = Field::new("foo", DataType::FixedSizeList(Box::new(DataType::Int32), 3));
let flat_child = Int32Array::from((
"foo",
(0i32..(validity.len() * 3) as i32).collect::<Vec<i32>>(),
Expand Down
6 changes: 3 additions & 3 deletions src/daft-core/src/array/growable/fixed_size_list_growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ impl<'a> FixedSizeListGrowable<'a> {
capacity: usize,
) -> Self {
match dtype {
DataType::FixedSizeList(child_field, element_fixed_len) => {
DataType::FixedSizeList(child_dtype, element_fixed_len) => {
let child_growable = make_growable(
child_field.name.as_str(),
&child_field.dtype,
"item",
child_dtype.as_ref(),
arrays.iter().map(|a| &a.flat_child).collect::<Vec<_>>(),
use_validity,
capacity * element_fixed_len,
Expand Down
6 changes: 3 additions & 3 deletions src/daft-core/src/array/growable/list_growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ impl<'a> ListGrowable<'a> {
child_capacity: usize,
) -> Self {
match dtype {
DataType::List(child_field) => {
DataType::List(child_dtype) => {
let child_growable = make_growable(
child_field.name.as_str(),
&child_field.dtype,
"list",
child_dtype.as_ref(),
arrays.iter().map(|a| &a.flat_child).collect::<Vec<_>>(),
use_validity,
child_capacity,
Expand Down
10 changes: 5 additions & 5 deletions src/daft-core/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ impl ListArray {
) -> Self {
let field: Arc<Field> = field.into();
match &field.as_ref().dtype {
DataType::List(child_field) => {
DataType::List(child_dtype) => {
if let Some(validity) = validity.as_ref() && validity.len() != offsets.len_proxy() {
panic!("ListArray::new validity length does not match computed length from offsets")
}
if child_field.as_ref() != flat_child.field() {
if child_dtype.as_ref() != flat_child.data_type() {
panic!(
"ListArray::new expects the child series to have field {}, but received: {}",
child_field,
flat_child.field(),
child_dtype,
flat_child.data_type(),
)
}
if *offsets.last() > flat_child.len() as i64 {
Expand Down Expand Up @@ -116,7 +116,7 @@ impl ListArray {

pub fn child_data_type(&self) -> &DataType {
match &self.field.dtype {
DataType::List(child) => &child.dtype,
DataType::List(child_dtype) => child_dtype.as_ref(),
_ => unreachable!("ListArray must have DataType::List(..)"),
}
}
Expand Down
Loading

0 comments on commit 4d6f716

Please sign in to comment.