From 4d6f71669338590b6a39bd713bb0167e0060399b Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:44:37 -0700 Subject: [PATCH] [CHORE] Fix List/FixedSizeList DataType to hold a dtype instead of Field (#1351) Closes: #1264 Closes: #994 Previously, our `DataType::List/FixedSizeList` would hold a child Field instead of DataType. This was causing a few problems: 1. When comparing equality of two schemas, we often hit weird cases where the naming of the subfield was wrong 2. When creating a new ListArray, sometimes we could forget to correctly name the child array (e.g. when using a growable) which would result in pretty nasty failures This PR: 1. Hardcodes the child Series' name to `"list"` 2. Changes List/FixedSizeList DataType to hold a dtype instead of Field, which essentially removes the name from the type and makes it "anonymous" Note that this is consistent with PyArrow behavior as well - in more recent versions of PyArrow, they have started to ignore equality of the subfield's name. 
--------- Co-authored-by: Jay Chia --- daft/datatype.py | 16 ++-- daft/expressions/expressions.py | 2 +- daft/series.py | 2 +- .../src/array/fixed_size_list_array.rs | 17 ++-- .../growable/fixed_size_list_growable.rs | 6 +- .../src/array/growable/list_growable.rs | 6 +- src/daft-core/src/array/list_array.rs | 10 +- src/daft-core/src/array/ops/cast.rs | 96 ++++++++----------- src/daft-core/src/array/ops/concat_agg.rs | 15 +-- src/daft-core/src/array/ops/from_arrow.rs | 10 +- src/daft-core/src/array/ops/full.rs | 23 +++-- src/daft-core/src/array/ops/get.rs | 10 +- src/daft-core/src/array/ops/image.rs | 5 +- src/daft-core/src/datatypes/dtype.rs | 71 ++++++-------- src/daft-core/src/datatypes/field.rs | 2 +- src/daft-core/src/python/datatype.rs | 20 ++-- src/daft-core/src/python/series.rs | 4 +- src/daft-core/src/series/ops/image.rs | 4 +- src/daft-core/src/utils/supertype.rs | 7 +- src/daft-dsl/src/functions/image/crop.rs | 4 +- tests/dataframe/test_creation.py | 8 +- tests/dataframe/test_logical_type.py | 2 +- tests/series/test_cast.py | 10 +- tests/series/test_concat.py | 4 +- tests/series/test_embedding.py | 2 +- tests/series/test_if_else.py | 4 +- tests/table/table_io/test_json.py | 2 +- tests/table/table_io/test_parquet.py | 14 +-- tests/table/test_broadcasts.py | 2 +- tests/table/test_from_py.py | 10 +- tests/table/test_table_aggs.py | 20 ++-- tests/test_schema.py | 2 +- 32 files changed, 175 insertions(+), 235 deletions(-) diff --git a/daft/datatype.py b/daft/datatype.py index 52b3e68c19..bdfa65390b 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -204,16 +204,16 @@ def duration(cls, timeunit: TimeUnit) -> DataType: return cls._from_pydatatype(PyDataType.duration(timeunit._timeunit)) @classmethod - def list(cls, name: str, dtype: DataType) -> DataType: + def list(cls, dtype: DataType) -> DataType: """Create a List DataType: Variable-length list, where each element in the list has type ``dtype`` Args: dtype: DataType of each element in the list """ - 
return cls._from_pydatatype(PyDataType.list(name, dtype._dtype)) + return cls._from_pydatatype(PyDataType.list(dtype._dtype)) @classmethod - def fixed_size_list(cls, name: str, dtype: DataType, size: int) -> DataType: + def fixed_size_list(cls, dtype: DataType, size: int) -> DataType: """Create a FixedSizeList DataType: Fixed-size list, where each element in the list has type ``dtype`` and each list has length ``size``. @@ -223,7 +223,7 @@ def fixed_size_list(cls, name: str, dtype: DataType, size: int) -> DataType: """ if not isinstance(size, int) or size <= 0: raise ValueError("The size for a fixed-size list must be a positive integer, but got: ", size) - return cls._from_pydatatype(PyDataType.fixed_size_list(name, dtype._dtype, size)) + return cls._from_pydatatype(PyDataType.fixed_size_list(dtype._dtype, size)) @classmethod def struct(cls, fields: dict[str, DataType]) -> DataType: @@ -239,7 +239,7 @@ def extension(cls, name: str, storage_dtype: DataType, metadata: str | None = No return cls._from_pydatatype(PyDataType.extension(name, storage_dtype._dtype, metadata)) @classmethod - def embedding(cls, name: str, dtype: DataType, size: int) -> DataType: + def embedding(cls, dtype: DataType, size: int) -> DataType: """Create an Embedding DataType: embeddings are fixed size arrays, where each element in the array has a **numeric** ``dtype`` and each array has a fixed length of ``size``. 
@@ -249,7 +249,7 @@ def embedding(cls, name: str, dtype: DataType, size: int) -> DataType: """ if not isinstance(size, int) or size <= 0: raise ValueError("The size for a embedding must be a positive integer, but got: ", size) - return cls._from_pydatatype(PyDataType.embedding(name, dtype._dtype, size)) + return cls._from_pydatatype(PyDataType.embedding(dtype._dtype, size)) @classmethod def image( @@ -360,11 +360,11 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: elif pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type): assert isinstance(arrow_type, (pa.ListType, pa.LargeListType)) field = arrow_type.value_field - return cls.list(field.name, cls.from_arrow_type(field.type)) + return cls.list(cls.from_arrow_type(field.type)) elif pa.types.is_fixed_size_list(arrow_type): assert isinstance(arrow_type, pa.FixedSizeListType) field = arrow_type.value_field - return cls.fixed_size_list(field.name, cls.from_arrow_type(field.type), arrow_type.list_size) + return cls.fixed_size_list(cls.from_arrow_type(field.type), arrow_type.list_size) elif pa.types.is_struct(arrow_type): assert isinstance(arrow_type, pa.StructType) fields = [arrow_type[i] for i in range(arrow_type.num_fields)] diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 9266450c54..be70827895 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -800,6 +800,6 @@ def crop(self, bbox: tuple[int, int, int, int] | Expression) -> Expression: raise ValueError( f"Expected `bbox` to be either a tuple of 4 ints or an Expression but received: {bbox}" ) - bbox = Expression._to_expression(bbox).cast(DataType.fixed_size_list("", DataType.uint64(), 4)) + bbox = Expression._to_expression(bbox).cast(DataType.fixed_size_list(DataType.uint64(), 4)) assert isinstance(bbox, Expression) return Expression._from_pyexpr(self._expr.image_crop(bbox._expr)) diff --git a/daft/series.py b/daft/series.py index 45b18a2d88..84b8af6da8 100644 --- 
a/daft/series.py +++ b/daft/series.py @@ -68,7 +68,7 @@ def from_arrow(array: pa.Array | pa.ChunkedArray, name: str = "arrow_series") -> storage_series = Series.from_arrow(array.storage, name=name) series = storage_series.cast( DataType.fixed_size_list( - "item", DataType.from_arrow_type(array.type.scalar_type), int(np.prod(array.type.shape)) + DataType.from_arrow_type(array.type.scalar_type), int(np.prod(array.type.shape)) ) ) return series.cast(DataType.from_arrow_type(array.type)) diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index d8bd03f08d..ff48317aee 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -24,7 +24,7 @@ impl FixedSizeListArray { ) -> Self { let field: Arc = field.into(); match &field.as_ref().dtype { - DataType::FixedSizeList(child_field, size) => { + DataType::FixedSizeList(child_dtype, size) => { if let Some(validity) = validity.as_ref() && (validity.len() * size) != flat_child.len() { panic!( "FixedSizeListArray::new received values with len {} but expected it to match len of validity * size: {}", @@ -32,11 +32,11 @@ impl FixedSizeListArray { (validity.len() * size), ) } - if child_field.as_ref() != flat_child.field() { + if child_dtype.as_ref() != flat_child.data_type() { panic!( - "FixedSizeListArray::new expects the child series to have field {}, but received: {}", - child_field, - flat_child.field(), + "FixedSizeListArray::new expects the child series to have dtype {}, but received: {}", + child_dtype, + flat_child.data_type(), ) } } @@ -103,7 +103,7 @@ impl FixedSizeListArray { pub fn child_data_type(&self) -> &DataType { match &self.field.dtype { - DataType::FixedSizeList(child, _) => &child.dtype, + DataType::FixedSizeList(child, _) => child.as_ref(), _ => unreachable!("FixedSizeListArray must have DataType::FixedSizeList(..)"), } } @@ -163,10 +163,7 @@ mod tests { /// Helper that returns a 
FixedSizeListArray, with each list element at len=3 fn get_i32_fixed_size_list_array(validity: &[bool]) -> FixedSizeListArray { - let field = Field::new( - "foo", - DataType::FixedSizeList(Box::new(Field::new("foo", DataType::Int32)), 3), - ); + let field = Field::new("foo", DataType::FixedSizeList(Box::new(DataType::Int32), 3)); let flat_child = Int32Array::from(( "foo", (0i32..(validity.len() * 3) as i32).collect::>(), diff --git a/src/daft-core/src/array/growable/fixed_size_list_growable.rs b/src/daft-core/src/array/growable/fixed_size_list_growable.rs index cb2a2fea99..2a7cd51635 100644 --- a/src/daft-core/src/array/growable/fixed_size_list_growable.rs +++ b/src/daft-core/src/array/growable/fixed_size_list_growable.rs @@ -25,10 +25,10 @@ impl<'a> FixedSizeListGrowable<'a> { capacity: usize, ) -> Self { match dtype { - DataType::FixedSizeList(child_field, element_fixed_len) => { + DataType::FixedSizeList(child_dtype, element_fixed_len) => { let child_growable = make_growable( - child_field.name.as_str(), - &child_field.dtype, + "item", + child_dtype.as_ref(), arrays.iter().map(|a| &a.flat_child).collect::>(), use_validity, capacity * element_fixed_len, diff --git a/src/daft-core/src/array/growable/list_growable.rs b/src/daft-core/src/array/growable/list_growable.rs index a7f2f2b5e1..3f9ffd4356 100644 --- a/src/daft-core/src/array/growable/list_growable.rs +++ b/src/daft-core/src/array/growable/list_growable.rs @@ -28,10 +28,10 @@ impl<'a> ListGrowable<'a> { child_capacity: usize, ) -> Self { match dtype { - DataType::List(child_field) => { + DataType::List(child_dtype) => { let child_growable = make_growable( - child_field.name.as_str(), - &child_field.dtype, + "list", + child_dtype.as_ref(), arrays.iter().map(|a| &a.flat_child).collect::>(), use_validity, child_capacity, diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index 1f6b2d0118..1f8775801d 100644 --- a/src/daft-core/src/array/list_array.rs +++ 
b/src/daft-core/src/array/list_array.rs @@ -26,15 +26,15 @@ impl ListArray { ) -> Self { let field: Arc = field.into(); match &field.as_ref().dtype { - DataType::List(child_field) => { + DataType::List(child_dtype) => { if let Some(validity) = validity.as_ref() && validity.len() != offsets.len_proxy() { panic!("ListArray::new validity length does not match computed length from offsets") } - if child_field.as_ref() != flat_child.field() { + if child_dtype.as_ref() != flat_child.data_type() { panic!( "ListArray::new expects the child series to have field {}, but received: {}", - child_field, - flat_child.field(), + child_dtype, + flat_child.data_type(), ) } if *offsets.last() > flat_child.len() as i64 { @@ -116,7 +116,7 @@ impl ListArray { pub fn child_data_type(&self) -> &DataType { match &self.field.dtype { - DataType::List(child) => &child.dtype, + DataType::List(child_dtype) => child_dtype.as_ref(), _ => unreachable!("ListArray must have DataType::List(..)"), } } diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 0e01d90d6a..139f305caa 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -721,23 +721,20 @@ fn extract_python_like_to_fixed_size_list< >( py: Python<'_>, python_objects: &PythonArray, - child_field: &Field, + child_dtype: &DataType, list_size: usize, ) -> DaftResult { - let (values_vec, _, _, _) = extract_python_to_vec::( - py, - python_objects, - &child_field.dtype, - None, - Some(list_size), - None, - )?; + let (values_vec, _, _, _) = + extract_python_to_vec::(py, python_objects, child_dtype, None, Some(list_size), None)?; let values_array: Box = Box::new(arrow2::array::PrimitiveArray::from_vec(values_vec)); - let inner_field = child_field.to_arrow()?; - let list_dtype = arrow2::datatypes::DataType::FixedSizeList(Box::new(inner_field), list_size); + let inner_dtype = child_dtype.to_arrow()?; + let list_dtype = arrow2::datatypes::DataType::FixedSizeList( + 
Box::new(arrow2::datatypes::Field::new("item", inner_dtype, true)), + list_size, + ); let daft_type = (&list_dtype).into(); let list_array = arrow2::array::FixedSizeListArray::new( @@ -758,19 +755,21 @@ fn extract_python_like_to_list< >( py: Python<'_>, python_objects: &PythonArray, - child_field: &Field, + child_dtype: &DataType, ) -> DaftResult { let (values_vec, offsets, _, _) = - extract_python_to_vec::(py, python_objects, &child_field.dtype, None, None, None)?; + extract_python_to_vec::(py, python_objects, child_dtype, None, None, None)?; let offsets = offsets.expect("Offsets should but non-None for dynamic list"); let values_array: Box = Box::new(arrow2::array::PrimitiveArray::from_vec(values_vec)); - let inner_field = child_field.to_arrow()?; + let inner_dtype = child_dtype.to_arrow()?; - let list_dtype = arrow2::datatypes::DataType::LargeList(Box::new(inner_field)); + let list_dtype = arrow2::datatypes::DataType::LargeList(Box::new( + arrow2::datatypes::Field::new("item", inner_dtype, true), + )); let daft_type = (&list_dtype).into(); @@ -995,32 +994,32 @@ impl PythonArray { dt @ DataType::Float32 | dt @ DataType::Float64 => { pycast_then_arrowcast!(self, dt, "float") } - DataType::List(field) => { - if !field.dtype.is_numeric() { + DataType::List(child_dtype) => { + if !child_dtype.is_numeric() { return Err(DaftError::ValueError(format!( "We can only convert numeric python types to List, got {}", - field.dtype + child_dtype ))); } - with_match_numeric_daft_types!(field.dtype, |$T| { + with_match_numeric_daft_types!(child_dtype.as_ref(), |$T| { type Tgt = <$T as DaftNumericType>::Native; pyo3::Python::with_gil(|py| { - let result = extract_python_like_to_list::(py, self, field)?; + let result = extract_python_like_to_list::(py, self, child_dtype.as_ref())?; Ok(result.into_series()) }) }) } - DataType::FixedSizeList(field, size) => { - if !field.dtype.is_numeric() { + DataType::FixedSizeList(child_dtype, size) => { + if !child_dtype.is_numeric() { return 
Err(DaftError::ValueError(format!( "We can only convert numeric python types to FixedSizeList, got {}", - field.dtype + child_dtype, ))); } - with_match_numeric_daft_types!(field.dtype, |$T| { + with_match_numeric_daft_types!(child_dtype.as_ref(), |$T| { type Tgt = <$T as DaftNumericType>::Native; pyo3::Python::with_gil(|py| { - let result = extract_python_like_to_fixed_size_list::(py, self, field, *size)?; + let result = extract_python_like_to_fixed_size_list::(py, self, child_dtype.as_ref(), *size)?; Ok(result.into_series()) }) }) @@ -1107,7 +1106,7 @@ impl EmbeddingArray { (DataType::Tensor(_), DataType::Embedding(inner_dtype, size)) => { let image_shape = vec![*size as u64]; let fixed_shape_tensor_dtype = - DataType::FixedShapeTensor(Box::new(inner_dtype.clone().dtype), image_shape); + DataType::FixedShapeTensor(Box::new(inner_dtype.as_ref().clone()), image_shape); let fixed_shape_tensor_array = self.cast(&fixed_shape_tensor_dtype)?; let fixed_shape_tensor_array = fixed_shape_tensor_array.downcast::()?; @@ -1194,7 +1193,7 @@ impl ImageArray { shapes.push(wa.value(i) as u64); shapes.push(ca.value(i) as u64); } - let shapes_dtype = DataType::List(Box::new(Field::new("shape", DataType::UInt64))); + let shapes_dtype = DataType::List(Box::new(DataType::UInt64)); let shape_offsets = arrow2::offset::OffsetsBuffer::try_from(shape_offsets)?; let shapes_array = ListArray::new( Field::new("shape", shapes_dtype), @@ -1350,7 +1349,7 @@ impl TensorArray { let size = shape.iter().product::() as usize; let result = da.cast(&DataType::FixedSizeList( - Box::new(Field::new("data", inner_dtype.as_ref().clone())), + Box::new(inner_dtype.as_ref().clone()), size, ))?; let tensor_array = FixedShapeTensorArray::new( @@ -1528,19 +1527,13 @@ impl FixedShapeTensorArray { // FixedSizeList -> List let list_arr = physical_arr - .cast(&DataType::List(Box::new(Field::new( - "data", - inner_dtype.as_ref().clone(), - ))))? + .cast(&DataType::List(Box::new(inner_dtype.as_ref().clone())))? 
.rename("data"); // List -> Struct let shape_offsets = arrow2::offset::OffsetsBuffer::try_from(shape_offsets)?; let shapes_array = ListArray::new( - Field::new( - "shape", - DataType::List(Box::new(Field::new("shape", DataType::UInt64))), - ), + Field::new("shape", DataType::List(Box::new(DataType::UInt64))), Series::try_from(( "shape", Box::new(arrow2::array::PrimitiveArray::from_vec(shapes)) @@ -1569,7 +1562,7 @@ impl FixedShapeTensorArray { impl FixedSizeListArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { match dtype { - DataType::FixedSizeList(child, size) => { + DataType::FixedSizeList(child_dtype, size) => { if size != &self.fixed_element_len() { return Err(DaftError::ValueError(format!( "Cannot cast from FixedSizeListSeries with size {} to size: {}", @@ -1577,10 +1570,7 @@ impl FixedSizeListArray { size ))); } - let casted_child = self - .flat_child - .cast(&child.dtype)? - .rename(child.name.as_str()); + let casted_child = self.flat_child.cast(child_dtype.as_ref())?; Ok(FixedSizeListArray::new( Field::new(self.name().to_string(), dtype.clone()), casted_child, @@ -1588,12 +1578,9 @@ impl FixedSizeListArray { ) .into_series()) } - DataType::List(child) => { + DataType::List(child_dtype) => { let element_size = self.fixed_element_len(); - let casted_child = self - .flat_child - .cast(&child.dtype)? - .rename(child.name.as_str()); + let casted_child = self.flat_child.cast(child_dtype.as_ref())?; let offsets: Offsets = match self.validity() { None => Offsets::try_from_iter(repeat(element_size).take(self.len()))?, Some(validity) => Offsets::try_from_iter(validity.iter().map(|v| { @@ -1655,16 +1642,14 @@ impl FixedSizeListArray { impl ListArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { match dtype { - DataType::List(child_field) => Ok(ListArray::new( + DataType::List(child_dtype) => Ok(ListArray::new( Field::new(self.name(), dtype.clone()), - self.flat_child - .cast(&child_field.dtype)? 
- .rename(child_field.name.as_str()), + self.flat_child.cast(child_dtype.as_ref())?, self.offsets().clone(), self.validity().cloned(), ) .into_series()), - DataType::FixedSizeList(child_field, size) => { + DataType::FixedSizeList(child_dtype, size) => { // Validate lengths of elements are equal to `size` let lengths_ok = match self.validity() { None => self.offsets().lengths().all(|l| l == *size), @@ -1682,10 +1667,7 @@ impl ListArray { } // Cast child - let casted_child = self - .flat_child - .cast(&child_field.dtype)? - .rename(child_field.name.as_str()); + let casted_child = self.flat_child.cast(child_dtype.as_ref())?; // Build a FixedSizeListArray match self.validity() { @@ -1699,8 +1681,8 @@ impl ListArray { // Some invalids, we need to insert nulls into the child Some(validity) => { let mut child_growable = make_growable( - child_field.name.as_str(), - &child_field.dtype, + "item", + child_dtype.as_ref(), vec![&casted_child], true, self.validity() diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index 60f91eb9d9..97b588fd0b 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -162,10 +162,7 @@ mod test { fn test_list_concat_agg_all_null() -> DaftResult<()> { // [None, None, None] let list_array = ListArray::new( - Field::new( - "foo", - DataType::List(Box::new(Field::new("item", DataType::Int64))), - ), + Field::new("foo", DataType::List(Box::new(DataType::Int64))), Int64Array::from(( "item", Box::new(arrow2::array::Int64Array::from_iter(vec![].iter())), @@ -189,10 +186,7 @@ mod test { fn test_list_concat_agg_with_nulls() -> DaftResult<()> { // [[0], [1, 1], [2, None], [None], [], None, None] let list_array = ListArray::new( - Field::new( - "foo", - DataType::List(Box::new(Field::new("item", DataType::Int64))), - ), + Field::new("foo", DataType::List(Box::new(DataType::Int64))), Int64Array::from(( "item", Box::new(arrow2::array::Int64Array::from_iter( @@ 
-227,10 +221,7 @@ mod test { // [[0], [0, 0], [1, None], [None], [2, None], None, None, None] // | group0 | | group1 | | group 2 | group 3 | let list_array = ListArray::new( - Field::new( - "foo", - DataType::List(Box::new(Field::new("item", DataType::Int64))), - ), + Field::new("foo", DataType::List(Box::new(DataType::Int64))), Int64Array::from(( "item", Box::new(arrow2::array::Int64Array::from_iter( diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs index 30c734420f..d48f339c82 100644 --- a/src/daft-core/src/array/ops/from_arrow.rs +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -38,14 +38,14 @@ where impl FromArrow for FixedSizeListArray { fn from_arrow(field: &Field, arrow_arr: Box) -> DaftResult { match (&field.dtype, arrow_arr.data_type()) { - (DataType::FixedSizeList(daft_child_field, daft_size), arrow2::datatypes::DataType::FixedSizeList(_arrow_child_field, arrow_size)) => { + (DataType::FixedSizeList(daft_child_dtype, daft_size), arrow2::datatypes::DataType::FixedSizeList(_arrow_child_field, arrow_size)) => { if daft_size != arrow_size { return Err(DaftError::TypeError(format!("Attempting to create Daft FixedSizeListArray with element length {} from Arrow FixedSizeList array with element length {}", daft_size, arrow_size))); } let arrow_arr = arrow_arr.as_ref().as_any().downcast_ref::().unwrap(); let arrow_child_array = arrow_arr.values(); - let child_series = Series::from_arrow(daft_child_field.as_ref(), arrow_child_array.clone())?; + let child_series = Series::from_arrow(&Field::new("item", daft_child_dtype.as_ref().clone()), arrow_child_array.clone())?; Ok(FixedSizeListArray::new( field.clone(), child_series, @@ -60,13 +60,13 @@ impl FromArrow for FixedSizeListArray { impl FromArrow for ListArray { fn from_arrow(field: &Field, arrow_arr: Box) -> DaftResult { match (&field.dtype, arrow_arr.data_type()) { - (DataType::List(daft_child_field), arrow2::datatypes::DataType::List(arrow_child_field)) | - 
(DataType::List(daft_child_field), arrow2::datatypes::DataType::LargeList(arrow_child_field)) + (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::List(arrow_child_field)) | + (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::LargeList(arrow_child_field)) => { let arrow_arr = arrow_arr.to_type(arrow2::datatypes::DataType::LargeList(arrow_child_field.clone())); let arrow_arr = arrow_arr.as_any().downcast_ref::>().unwrap(); let arrow_child_array = arrow_arr.values(); - let child_series = Series::from_arrow(daft_child_field.as_ref(), arrow_child_array.clone())?; + let child_series = Series::from_arrow(&Field::new("list", daft_child_dtype.as_ref().clone()), arrow_child_array.clone())?; Ok(ListArray::new( field.clone(), child_series, diff --git a/src/daft-core/src/array/ops/full.rs b/src/daft-core/src/array/ops/full.rs index 2b0e817ccf..f39c904709 100644 --- a/src/daft-core/src/array/ops/full.rs +++ b/src/daft-core/src/array/ops/full.rs @@ -93,9 +93,8 @@ impl FullNull for FixedSizeListArray { let validity = arrow2::bitmap::Bitmap::from_iter(repeat(false).take(length)); match dtype { - DataType::FixedSizeList(child, size) => { - let flat_child = - Series::full_null(child.name.as_str(), &child.dtype, length * size); + DataType::FixedSizeList(child_dtype, size) => { + let flat_child = Series::full_null("item", child_dtype, length * size); Self::new(Field::new(name, dtype.clone()), flat_child, Some(validity)) } _ => panic!( @@ -107,9 +106,9 @@ impl FullNull for FixedSizeListArray { fn empty(name: &str, dtype: &DataType) -> Self { match dtype { - DataType::FixedSizeList(child, _) => { + DataType::FixedSizeList(child_dtype, _) => { let field = Field::new(name, dtype.clone()); - let empty_child = Series::empty(child.name.as_str(), &child.dtype); + let empty_child = Series::empty("item", child_dtype.as_ref()); Self::new(field, empty_child, None) } _ => panic!( @@ -125,8 +124,8 @@ impl FullNull for ListArray { let validity = 
arrow2::bitmap::Bitmap::from_iter(repeat(false).take(length)); match dtype { - DataType::List(child) => { - let empty_flat_child = Series::empty(child.name.as_str(), &child.dtype); + DataType::List(child_dtype) => { + let empty_flat_child = Series::empty("list", child_dtype.as_ref()); Self::new( Field::new(name, dtype.clone()), empty_flat_child, @@ -143,9 +142,9 @@ impl FullNull for ListArray { fn empty(name: &str, dtype: &DataType) -> Self { match dtype { - DataType::List(child) => { + DataType::List(child_dtype) => { let field = Field::new(name, dtype.clone()); - let empty_child = Series::empty(child.name.as_str(), &child.dtype); + let empty_child = Series::empty("list", child_dtype.as_ref()); Self::new(field, empty_child, OffsetsBuffer::default(), None) } _ => panic!("Cannot create empty ListArray with dtype: {}", dtype), @@ -198,7 +197,7 @@ mod tests { fn create_fixed_size_list_full_null() -> DaftResult<()> { let arr = FixedSizeListArray::full_null( "foo", - &DataType::FixedSizeList(Box::new(Field::new("bar", DataType::Int64)), 3), + &DataType::FixedSizeList(Box::new(DataType::Int64), 3), 3, ); assert_eq!(arr.len(), 3); @@ -226,7 +225,7 @@ mod tests { fn create_fixed_size_list_full_null_empty() -> DaftResult<()> { let arr = FixedSizeListArray::full_null( "foo", - &DataType::FixedSizeList(Box::new(Field::new("bar", DataType::Int64)), 3), + &DataType::FixedSizeList(Box::new(DataType::Int64), 3), 0, ); assert_eq!(arr.len(), 0); @@ -248,7 +247,7 @@ mod tests { fn create_fixed_size_list_empty() -> DaftResult<()> { let arr = FixedSizeListArray::empty( "foo", - &DataType::FixedSizeList(Box::new(Field::new("bar", DataType::Int64)), 3), + &DataType::FixedSizeList(Box::new(DataType::Int64), 3), ); assert_eq!(arr.len(), 0); Ok(()) diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index 340809a015..a0f534da07 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -168,10 +168,7 @@ mod tests { #[test] 
fn test_fixed_size_list_get_all_valid() -> DaftResult<()> { - let field = Field::new( - "foo", - DataType::FixedSizeList(Box::new(Field::new("foo", DataType::Int32)), 3), - ); + let field = Field::new("foo", DataType::FixedSizeList(Box::new(DataType::Int32), 3)); let flat_child = Int32Array::from(("foo", (0..9).collect::>())); let validity = None; let arr = FixedSizeListArray::new(field, flat_child.into_series(), validity); @@ -201,10 +198,7 @@ mod tests { #[test] fn test_fixed_size_list_get_some_valid() -> DaftResult<()> { - let field = Field::new( - "foo", - DataType::FixedSizeList(Box::new(Field::new("foo", DataType::Int32)), 3), - ); + let field = Field::new("foo", DataType::FixedSizeList(Box::new(DataType::Int32), 3)); let flat_child = Int32Array::from(("foo", (0..9).collect::>())); let raw_validity = vec![true, false, true]; let validity = Some(arrow2::bitmap::Bitmap::from(raw_validity.as_slice())); diff --git a/src/daft-core/src/array/ops/image.rs b/src/daft-core/src/array/ops/image.rs index 09662a19de..7685c1e92f 100644 --- a/src/daft-core/src/array/ops/image.rs +++ b/src/daft-core/src/array/ops/image.rs @@ -420,10 +420,7 @@ impl ImageArray { } } let data_array = ListArray::new( - Field::new( - "data", - DataType::List(Box::new(Field::new("data", (&arrow_dtype).into()))), - ), + Field::new("data", DataType::List(Box::new((&arrow_dtype).into()))), Series::try_from(( "data", Box::new(arrow2::array::PrimitiveArray::from_vec(data)) diff --git a/src/daft-core/src/datatypes/dtype.rs b/src/daft-core/src/datatypes/dtype.rs index 24c650d0ca..293f0b29ce 100644 --- a/src/daft-core/src/datatypes/dtype.rs +++ b/src/daft-core/src/datatypes/dtype.rs @@ -71,16 +71,16 @@ pub enum DataType { /// A variable-length UTF-8 encoded string whose offsets are represented as [`i64`]. Utf8, /// A list of some logical data type with a fixed number of elements. 
- FixedSizeList(Box, usize), + FixedSizeList(Box, usize), /// A list of some logical data type whose offsets are represented as [`i64`]. - List(Box), + List(Box), /// A nested [`DataType`] with a given number of [`Field`]s. Struct(Vec), /// Extension type. Extension(String, Box, Option), // Stop ArrowTypes /// A logical type for embeddings. - Embedding(Box, usize), + Embedding(Box, usize), /// A logical type for images with variable shapes. Image(Option), /// A logical type for images with the same size (height x width). @@ -142,10 +142,17 @@ impl DataType { DataType::Duration(unit) => Ok(ArrowType::Duration(unit.to_arrow())), DataType::Binary => Ok(ArrowType::LargeBinary), DataType::Utf8 => Ok(ArrowType::LargeUtf8), - DataType::FixedSizeList(field, size) => { - Ok(ArrowType::FixedSizeList(Box::new(field.to_arrow()?), *size)) - } - DataType::List(field) => Ok(ArrowType::LargeList(Box::new(field.to_arrow()?))), + DataType::FixedSizeList(child_dtype, size) => Ok(ArrowType::FixedSizeList( + Box::new(arrow2::datatypes::Field::new( + "item", + child_dtype.to_arrow()?, + true, + )), + *size, + )), + DataType::List(field) => Ok(ArrowType::LargeList(Box::new( + arrow2::datatypes::Field::new("item", field.to_arrow()?, true), + ))), DataType::Struct(fields) => Ok({ let fields = fields .iter() @@ -187,28 +194,15 @@ impl DataType { Decimal128(..) => Int128, Date => Int32, Duration(_) | Timestamp(..) 
| Time(_) => Int64, - List(field) => List(Box::new( - Field::new(field.name.clone(), field.dtype.to_physical()) - .with_metadata(field.metadata.clone()), - )), - FixedSizeList(field, size) => FixedSizeList( - Box::new( - Field::new(field.name.clone(), field.dtype.to_physical()) - .with_metadata(field.metadata.clone()), - ), - *size, - ), - Embedding(field, size) => FixedSizeList( - Box::new(Field::new(field.name.clone(), field.dtype.to_physical())), - *size, - ), + List(child_dtype) => List(Box::new(child_dtype.to_physical())), + FixedSizeList(child_dtype, size) => { + FixedSizeList(Box::new(child_dtype.to_physical()), *size) + } + Embedding(dtype, size) => FixedSizeList(Box::new(dtype.to_physical()), *size), Image(mode) => Struct(vec![ Field::new( "data", - List(Box::new(Field::new( - "data", - mode.map_or(DataType::UInt8, |m| m.get_dtype()), - ))), + List(Box::new(mode.map_or(DataType::UInt8, |m| m.get_dtype()))), ), Field::new("channel", UInt16), Field::new("height", UInt32), @@ -216,18 +210,15 @@ impl DataType { Field::new("mode", UInt8), ]), FixedShapeImage(mode, height, width) => FixedSizeList( - Box::new(Field::new("data", mode.get_dtype())), + Box::new(mode.get_dtype()), usize::try_from(mode.num_channels() as u32 * height * width).unwrap(), ), Tensor(dtype) => Struct(vec![ - Field::new("data", List(Box::new(Field::new("data", *dtype.clone())))), - Field::new( - "shape", - List(Box::new(Field::new("shape", DataType::UInt64))), - ), + Field::new("data", List(Box::new(*dtype.clone()))), + Field::new("shape", List(Box::new(DataType::UInt64))), ]), FixedShapeTensor(dtype, shape) => FixedSizeList( - Box::new(Field::new("data", *dtype.clone())), + Box::new(*dtype.clone()), usize::try_from(shape.iter().product::()).unwrap(), ), _ => { @@ -373,8 +364,8 @@ impl DataType { #[inline] pub fn get_exploded_dtype(&self) -> DaftResult<&DataType> { match self { - DataType::List(child_field) | DataType::FixedSizeList(child_field, _) => { - Ok(&child_field.dtype) + 
DataType::List(child_dtype) | DataType::FixedSizeList(child_dtype, _) => { + Ok(child_dtype.as_ref()) } _ => Err(DaftError::ValueError(format!( "Datatype cannot be exploded: {self}" @@ -422,10 +413,10 @@ impl From<&ArrowType> for DataType { ArrowType::Utf8 | ArrowType::LargeUtf8 => DataType::Utf8, ArrowType::Decimal(precision, scale) => DataType::Decimal128(*precision, *scale), ArrowType::List(field) | ArrowType::LargeList(field) => { - DataType::List(Box::new(field.as_ref().into())) + DataType::List(Box::new(field.as_ref().data_type().into())) } ArrowType::FixedSizeList(field, size) => { - DataType::FixedSizeList(Box::new(field.as_ref().into()), *size) + DataType::FixedSizeList(Box::new(field.as_ref().data_type().into()), *size) } ArrowType::Struct(fields) => { let fields: Vec = fields.iter().map(|fld| fld.into()).collect(); @@ -467,9 +458,9 @@ impl Display for DataType { // `f` is a buffer, and this method must write the formatted string into it fn fmt(&self, f: &mut Formatter) -> Result { match self { - DataType::List(nested) => write!(f, "List[{}:{}]", nested.name, nested.dtype), + DataType::List(nested) => write!(f, "List[{}]", nested), DataType::FixedSizeList(inner, size) => { - write!(f, "FixedSizeList[{}; {}]", inner.dtype, size) + write!(f, "FixedSizeList[{}; {}]", inner, size) } DataType::Struct(fields) => { let fields: String = fields @@ -480,7 +471,7 @@ impl Display for DataType { write!(f, "Struct[{fields}]") } DataType::Embedding(inner, size) => { - write!(f, "Embedding[{}; {}]", inner.dtype, size) + write!(f, "Embedding[{}; {}]", inner, size) } DataType::Image(mode) => { write!( diff --git a/src/daft-core/src/datatypes/field.rs b/src/daft-core/src/datatypes/field.rs index 229bf7b683..dd37dd94b3 100644 --- a/src/daft-core/src/datatypes/field.rs +++ b/src/daft-core/src/datatypes/field.rs @@ -87,7 +87,7 @@ impl Field { if self.dtype.is_python() { return Ok(self.clone()); } - let list_dtype = DataType::List(Box::new(self.clone())); + let list_dtype = 
DataType::List(Box::new(self.dtype.clone())); Ok(Self { name: self.name.clone(), dtype: list_dtype, diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs index 259e1c4060..1efe646147 100644 --- a/src/daft-core/src/python/datatype.rs +++ b/src/daft-core/src/python/datatype.rs @@ -181,23 +181,19 @@ impl PyDataType { } #[staticmethod] - pub fn list(name: &str, data_type: Self) -> PyResult { - Ok(DataType::List(Box::new(Field::new(name, data_type.dtype))).into()) + pub fn list(data_type: Self) -> PyResult { + Ok(DataType::List(Box::new(data_type.dtype)).into()) } #[staticmethod] - pub fn fixed_size_list(name: &str, data_type: Self, size: i64) -> PyResult { + pub fn fixed_size_list(data_type: Self, size: i64) -> PyResult { if size <= 0 { return Err(PyValueError::new_err(format!( "The size for fixed-size list types must be a positive integer, but got: {}", size ))); } - Ok(DataType::FixedSizeList( - Box::new(Field::new(name, data_type.dtype)), - usize::try_from(size)?, - ) - .into()) + Ok(DataType::FixedSizeList(Box::new(data_type.dtype), usize::try_from(size)?).into()) } #[staticmethod] @@ -231,7 +227,7 @@ impl PyDataType { } #[staticmethod] - pub fn embedding(name: &str, data_type: Self, size: i64) -> PyResult { + pub fn embedding(data_type: Self, size: i64) -> PyResult { if size <= 0 { return Err(PyValueError::new_err(format!( "The size for embedding types must be a positive integer, but got: {}", @@ -245,11 +241,7 @@ impl PyDataType { ))); } - Ok(DataType::Embedding( - Box::new(Field::new(name, data_type.dtype)), - usize::try_from(size)?, - ) - .into()) + Ok(DataType::Embedding(Box::new(data_type.dtype), usize::try_from(size)?).into()) } #[staticmethod] diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index ddc7d6fc45..3e08961db3 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -325,7 +325,7 @@ impl From for series::Series { fn 
infer_daft_dtype_for_sequence( vec_pyobj: &[PyObject], py: pyo3::Python, - name: &str, + _name: &str, ) -> PyResult> { let py_pil_image_type = py .import(pyo3::intern!(py, "PIL.Image")) @@ -373,7 +373,7 @@ fn infer_daft_dtype_for_sequence( let inferred_inner_dtype = from_numpy_dtype.call1((np_dtype,)).map(|dt| dt.getattr(pyo3::intern!(py, "_dtype")).unwrap().extract::().unwrap().dtype); let shape: Vec = obj.getattr(pyo3::intern!(py, "shape"))?.extract()?; let inferred_dtype = match inferred_inner_dtype { - Ok(inferred_inner_dtype) if shape.len() == 1 => Some(DataType::List(Box::new(Field::new(name, inferred_inner_dtype)))), + Ok(inferred_inner_dtype) if shape.len() == 1 => Some(DataType::List(Box::new(inferred_inner_dtype))), Ok(inferred_inner_dtype) if shape.len() > 1 => Some(DataType::Tensor(Box::new(inferred_inner_dtype))), _ => None, }; diff --git a/src/daft-core/src/series/ops/image.rs b/src/daft-core/src/series/ops/image.rs index 098e873941..3f9947ee65 100644 --- a/src/daft-core/src/series/ops/image.rs +++ b/src/daft-core/src/series/ops/image.rs @@ -1,5 +1,5 @@ use crate::datatypes::logical::{FixedShapeImageArray, ImageArray}; -use crate::datatypes::{DataType, Field, ImageFormat}; +use crate::datatypes::{DataType, ImageFormat}; use crate::series::{IntoSeries, Series}; use common_error::{DaftError, DaftResult}; @@ -58,7 +58,7 @@ impl Series { } pub fn image_crop(&self, bbox: &Series) -> DaftResult { - let bbox_type = DataType::FixedSizeList(Box::new(Field::new("bbox", DataType::UInt32)), 4); + let bbox_type = DataType::FixedSizeList(Box::new(DataType::UInt32), 4); let bbox = bbox.cast(&bbox_type)?; let bbox = bbox.fixed_size_list()?; diff --git a/src/daft-core/src/utils/supertype.rs b/src/daft-core/src/utils/supertype.rs index fb1774defa..6ee1db850a 100644 --- a/src/daft-core/src/utils/supertype.rs +++ b/src/daft-core/src/utils/supertype.rs @@ -1,5 +1,4 @@ use crate::datatypes::DataType; -use crate::datatypes::Field; use crate::datatypes::TimeUnit; use 
common_error::DaftError; use common_error::DaftResult; @@ -192,9 +191,9 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option { //TODO(sammy): add time, struct related dtypes (Boolean, Float32) => Some(Float32), (Boolean, Float64) => Some(Float64), - (List(inner_left_field), List(inner_right_field)) => { - let inner_st = get_supertype(&inner_left_field.dtype, &inner_right_field.dtype)?; - Some(DataType::List(Box::new(Field::new(inner_left_field.name.clone(), inner_st)))) + (List(inner_left_dtype), List(inner_right_dtype)) => { + let inner_st = get_supertype(inner_left_dtype.as_ref(), inner_right_dtype.as_ref())?; + Some(DataType::List(Box::new(inner_st))) } // TODO(Clark): Add support for getting supertype for two fixed size lists once Arrow2 supports such a cast. // (FixedSizeList(inner_left_field, inner_left_size), FixedSizeList(inner_right_field, inner_right_size)) if inner_left_size == inner_right_size => { diff --git a/src/daft-dsl/src/functions/image/crop.rs b/src/daft-dsl/src/functions/image/crop.rs index 21eb75454a..6d4111d7ef 100644 --- a/src/daft-dsl/src/functions/image/crop.rs +++ b/src/daft-dsl/src/functions/image/crop.rs @@ -27,8 +27,8 @@ impl FunctionEvaluator for CropEvaluator { "bbox FixedSizeList field must have size 4 for cropping".to_string(), )); } - DataType::FixedSizeList(field, _) | DataType::List(field) - if !field.dtype.is_numeric() => + DataType::FixedSizeList(child_dtype, _) | DataType::List(child_dtype) + if !child_dtype.is_numeric() => { return Err(DaftError::TypeError( "bbox list field must have numeric child type".to_string(), diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index 616745f6bb..06b752197d 100644 --- a/tests/dataframe/test_creation.py +++ b/tests/dataframe/test_creation.py @@ -322,16 +322,16 @@ def test_create_dataframe_pandas_tensor(valid_data: list[dict[str, float]]) -> N ), pytest.param( [np.array([1]), np.array([2]), np.array([3])], - DataType.list("item", DataType.int64()), + 
DataType.list(DataType.int64()), id="numpy_1d_arrays", ), - pytest.param(pa.array([[1, 2, 3], [1, 2], [1]]), DataType.list("item", DataType.int64()), id="pa_nested"), + pytest.param(pa.array([[1, 2, 3], [1, 2], [1]]), DataType.list(DataType.int64()), id="pa_nested"), pytest.param( pa.chunked_array([pa.array([[1, 2, 3], [1, 2], [1]])]), - DataType.list("item", DataType.int64()), + DataType.list(DataType.int64()), id="pa_nested_chunked", ), - pytest.param(np.ones((3, 3)), DataType.list("item", DataType.float64()), id="np_nested_1d"), + pytest.param(np.ones((3, 3)), DataType.list(DataType.float64()), id="np_nested_1d"), pytest.param(np.ones((3, 3, 3)), DataType.tensor(DataType.float64()), id="np_nested_nd"), ], ) diff --git a/tests/dataframe/test_logical_type.py b/tests/dataframe/test_logical_type.py index 342824a185..946dd5906d 100644 --- a/tests/dataframe/test_logical_type.py +++ b/tests/dataframe/test_logical_type.py @@ -18,7 +18,7 @@ def test_embedding_type_df() -> None: data = [[1, 2, 3], np.arange(3), ["1", "2", "3"], [1, "2", 3.0], pd.Series([1.1, 2, 3]), (1, 2, 3), None] df = daft.from_pydict({"index": np.arange(len(data)), "embeddings": Series.from_pylist(data, pyobj="force")}) - target = DataType.embedding("arr", DataType.float32(), 3) + target = DataType.embedding(DataType.float32(), 3) df = df.select(col("index"), col("embeddings").cast(target)) df = df.repartition(4, "index") df = df.sort("index") diff --git a/tests/series/test_cast.py b/tests/series/test_cast.py index a64201577f..62b4e6ebad 100644 --- a/tests/series/test_cast.py +++ b/tests/series/test_cast.py @@ -132,7 +132,7 @@ def test_series_cast_python_to_list(dtype) -> None: data = [[1, 2, 3], np.arange(3), ["1", "2", "3"], [1, "2", 3.0], pd.Series([1.1, 2]), (1, 2), None] s = Series.from_pylist(data, pyobj="force") - target_dtype = DataType.list("arr", DataType.from_arrow_type(dtype)) + target_dtype = DataType.list(DataType.from_arrow_type(dtype)) t = s.cast(target_dtype) @@ -153,7 +153,7 @@ def 
test_series_cast_python_to_fixed_size_list(dtype) -> None: data = [[1, 2, 3], np.arange(3), ["1", "2", "3"], [1, "2", 3.0], pd.Series([1.1, 2, 3]), (1, 2, 3), None] s = Series.from_pylist(data, pyobj="force") - target_dtype = DataType.fixed_size_list("arr", DataType.from_arrow_type(dtype), 3) + target_dtype = DataType.fixed_size_list(DataType.from_arrow_type(dtype), 3) t = s.cast(target_dtype) @@ -174,7 +174,7 @@ def test_series_cast_python_to_embedding(dtype) -> None: data = [[1, 2, 3], np.arange(3), ["1", "2", "3"], [1, "2", 3.0], pd.Series([1.1, 2, 3]), (1, 2, 3), None] s = Series.from_pylist(data, pyobj="force") - target_dtype = DataType.embedding("arr", DataType.from_arrow_type(dtype), 3) + target_dtype = DataType.embedding(DataType.from_arrow_type(dtype), 3) t = s.cast(target_dtype) @@ -471,7 +471,7 @@ def test_series_cast_embedding_to_fixed_shape_tensor() -> None: ] s = Series.from_pylist(data, pyobj="force") - target_dtype = DataType.embedding("arr", DataType.uint8(), 4) + target_dtype = DataType.embedding(DataType.uint8(), 4) t = s.cast(target_dtype) @@ -498,7 +498,7 @@ def test_series_cast_embedding_to_tensor() -> None: ] s = Series.from_pylist(data, pyobj="force") - target_dtype = DataType.embedding("arr", DataType.uint8(), 4) + target_dtype = DataType.embedding(DataType.uint8(), 4) t = s.cast(target_dtype) diff --git a/tests/series/test_concat.py b/tests/series/test_concat.py index 69d370baa4..48f7bd06e0 100644 --- a/tests/series/test_concat.py +++ b/tests/series/test_concat.py @@ -53,9 +53,9 @@ def test_series_concat_list_array(chunks, fixed) -> None: concated = Series.concat(series) if fixed: - assert concated.datatype() == DataType.fixed_size_list("item", DataType.int64(), 2) + assert concated.datatype() == DataType.fixed_size_list(DataType.int64(), 2) else: - assert concated.datatype() == DataType.list("item", DataType.int64()) + assert concated.datatype() == DataType.list(DataType.int64()) concated_list = concated.to_pylist() counter = 0 diff --git 
a/tests/series/test_embedding.py b/tests/series/test_embedding.py index f309110270..3181fe7e78 100644 --- a/tests/series/test_embedding.py +++ b/tests/series/test_embedding.py @@ -13,7 +13,7 @@ def test_embedding_arrow_round_trip(): data = [[1, 2, 3], np.arange(3), ["1", "2", "3"], [1, "2", 3.0], pd.Series([1.1, 2, 3]), (1, 2, 3), None] s = Series.from_pylist(data, pyobj="force") - target_dtype = DataType.embedding("arr", DataType.int32(), 3) + target_dtype = DataType.embedding(DataType.int32(), 3) t = s.cast(target_dtype) diff --git a/tests/series/test_if_else.py b/tests/series/test_if_else.py index 9a946ecc9e..591cd449bf 100644 --- a/tests/series/test_if_else.py +++ b/tests/series/test_if_else.py @@ -186,7 +186,7 @@ def test_series_if_else_list(if_true, if_false, expected) -> None: if_false_series = Series.from_arrow(if_false) predicate_series = Series.from_arrow(pa.array([True, False, None, True])) result = predicate_series.if_else(if_true_series, if_false_series) - assert result.datatype() == DataType.list("item", DataType.int64()) + assert result.datatype() == DataType.list(DataType.int64()) assert result.to_pylist() == expected @@ -227,7 +227,7 @@ def test_series_if_else_fixed_size_list(if_true, if_false, expected) -> None: if_false_series = Series.from_arrow(if_false) predicate_series = Series.from_arrow(pa.array([True, False, None, True])) result = predicate_series.if_else(if_true_series, if_false_series) - assert result.datatype() == DataType.fixed_size_list("item", DataType.int64(), 2) + assert result.datatype() == DataType.fixed_size_list(DataType.int64(), 2) assert result.to_pylist() == expected diff --git a/tests/table/table_io/test_json.py b/tests/table/table_io/test_json.py index 569856cc56..d8ebdfea0a 100644 --- a/tests/table/table_io/test_json.py +++ b/tests/table/table_io/test_json.py @@ -59,7 +59,7 @@ def _json_write_helper(data: dict[str, list[Any]]): (True, DataType.bool()), (None, DataType.null()), ({"foo": 1}, DataType.struct({"foo": 
DataType.int64()})), - ([1, None, 2], DataType.list("item", DataType.int64())), + ([1, None, 2], DataType.list(DataType.int64())), ], ) def test_json_infer_schema(data, expected_dtype): diff --git a/tests/table/table_io/test_parquet.py b/tests/table/table_io/test_parquet.py index 570473de5f..ec85766983 100644 --- a/tests/table/table_io/test_parquet.py +++ b/tests/table/table_io/test_parquet.py @@ -55,16 +55,16 @@ def _parquet_write_helper(data: pa.Table, row_group_size: int = None, papq_write (True, DataType.bool()), (None, DataType.null()), ({"foo": 1}, DataType.struct({"foo": DataType.int64()})), - ([1, None, 2], DataType.list("item", DataType.int64())), + ([1, None, 2], DataType.list(DataType.int64())), ], ) @pytest.mark.parametrize("use_native_downloader", [True, False]) def test_parquet_infer_schema(data, expected_dtype, use_native_downloader): - # HACK: Pyarrow 13 changed their schema parsing behavior so we receive DataType.list("element", ..) instead of DataType.list("item", ..) - # However, our native downloader still parses DataType.list("item", ..) regardless of PyArrow version - if PYARROW_GE_13_0_0 and not use_native_downloader and expected_dtype == DataType.list("item", DataType.int64()): - expected_dtype = DataType.list("element", DataType.int64()) + # HACK: Pyarrow 13 changed their schema parsing behavior so the inferred child field is named "element" instead of "item" + # However, our native downloader still parses DataType.list(..)
regardless of PyArrow version + if PYARROW_GE_13_0_0 and not use_native_downloader and expected_dtype == DataType.list(DataType.int64()): + expected_dtype = DataType.list(DataType.int64()) with _parquet_write_helper( pa.Table.from_pydict( @@ -264,9 +264,9 @@ def test_parquet_read_int96_timestamps_schema_inference(coerce_to, store_schema) } schema = [ ("timestamp", DataType.timestamp(coerce_to)), - ("nested_timestamp", DataType.list("item", DataType.timestamp(coerce_to))), + ("nested_timestamp", DataType.list(DataType.timestamp(coerce_to))), ("struct_timestamp", DataType.struct({"foo": DataType.timestamp(coerce_to)})), - ("struct_nested_timestamp", DataType.struct({"foo": DataType.list("item", DataType.timestamp(coerce_to))})), + ("struct_nested_timestamp", DataType.struct({"foo": DataType.list(DataType.timestamp(coerce_to))})), ] expected = Schema._from_field_name_and_types(schema) diff --git a/tests/table/test_broadcasts.py b/tests/table/test_broadcasts.py index 196cc64dd4..561e9c0c25 100644 --- a/tests/table/test_broadcasts.py +++ b/tests/table/test_broadcasts.py @@ -18,6 +18,6 @@ def test_broadcast_fixed_size_list(): data = [1, 2, 3] table = Table.from_pydict({"x": [1, 2, 3]}) new_table = table.eval_expression_list( - [col("x"), lit(data).cast(daft.DataType.fixed_size_list("foo", daft.DataType.int64(), 3))] + [col("x"), lit(data).cast(daft.DataType.fixed_size_list(daft.DataType.int64(), 3))] ) assert new_table.to_pydict() == {"x": [1, 2, 3], "literal": [data for _ in range(3)]} diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py index 9186c99722..265de512a6 100644 --- a/tests/table/test_from_py.py +++ b/tests/table/test_from_py.py @@ -40,7 +40,7 @@ "str": DataType.string(), "binary": DataType.binary(), "date": DataType.date(), - "list": DataType.list("item", DataType.int64()), + "list": DataType.list(DataType.int64()), "struct": DataType.struct({"a": DataType.int64(), "b": DataType.float64()}), "empty_struct": DataType.struct({"": 
DataType.null()}), "null": DataType.null(), @@ -100,8 +100,8 @@ ], pa.struct( { - "data": pa.large_list(pa.field("data", pa.int64())), - "shape": pa.large_list(pa.field("shape", pa.uint64())), + "data": pa.large_list(pa.field("item", pa.int64())), + "shape": pa.large_list(pa.field("item", pa.uint64())), } ), ), @@ -508,7 +508,7 @@ def test_nested_list_dates(levels: int) -> None: expected_dtype = DataType.date() expected_arrow_type = pa.date32() for _ in range(levels): - expected_dtype = DataType.list("item", expected_dtype) + expected_dtype = DataType.list(expected_dtype) expected_arrow_type = pa.large_list(pa.field("item", expected_arrow_type)) assert dtype == expected_dtype @@ -528,7 +528,7 @@ def test_nested_fixed_size_list_dates(levels: int) -> None: expected_dtype = DataType.date() expected_arrow_type = pa.date32() for _ in range(levels): - expected_dtype = DataType.fixed_size_list("item", expected_dtype, 2) + expected_dtype = DataType.fixed_size_list(expected_dtype, 2) expected_arrow_type = pa.list_(expected_arrow_type, 2) pa_data = pa.array(data, type=expected_arrow_type) diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index a621ca2d99..7e59ce3749 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -427,7 +427,7 @@ def test_global_list_aggs(dtype) -> None: daft_table = Table.from_pydict({"input": input}) daft_table = daft_table.eval_expression_list([col("input").cast(dtype)]) result = daft_table.eval_expression_list([col("input").alias("list")._agg_list()]) - assert result.get_column("list").datatype() == DataType.list("list", dtype) + assert result.get_column("list").datatype() == DataType.list(dtype) assert result.to_pydict() == {"list": [daft_table.to_pydict()["input"]]} @@ -452,7 +452,7 @@ def test_grouped_list_aggs(dtype) -> None: daft_table = Table.from_pydict({"groups": groups, "input": input}) daft_table = daft_table.eval_expression_list([col("groups"), col("input").cast(dtype)]) result = 
daft_table.agg([col("input").alias("list")._agg_list()], group_by=[col("groups")]).sort([col("groups")]) - assert result.get_column("list").datatype() == DataType.list("list", dtype) + assert result.get_column("list").datatype() == DataType.list(dtype) input_as_dtype = daft_table.get_column("input").to_pylist() expected_groups = [[input_as_dtype[i] for i in group] for group in expected_idx] @@ -478,7 +478,7 @@ def test_list_aggs_empty() -> None: [col("col_A").cast(DataType.int32()).alias("list")._agg_list()], group_by=[col("col_B")], ) - assert daft_table.get_column("list").datatype() == DataType.list("list", DataType.int32()) + assert daft_table.get_column("list").datatype() == DataType.list(DataType.int32()) res = daft_table.to_pydict() assert res == {"col_B": [], "list": []} @@ -498,11 +498,9 @@ def test_global_concat_aggs(dtype, with_null) -> None: if with_null: input += [None] - daft_table = Table.from_pydict({"input": input}).eval_expression_list( - [col("input").cast(DataType.list("item", dtype))] - ) + daft_table = Table.from_pydict({"input": input}).eval_expression_list([col("input").cast(DataType.list(dtype))]) concated = daft_table.agg([col("input").alias("concat")._agg_concat()]) - assert concated.get_column("concat").datatype() == DataType.list("item", dtype) + assert concated.get_column("concat").datatype() == DataType.list(dtype) input_as_dtype = daft_table.get_column("input").to_pylist() # We should ignore Null Array elements when performing the concat agg @@ -537,12 +535,12 @@ def test_grouped_concat_aggs(dtype) -> None: input = [[x] for x in input] + [None] groups = [1, 2, 3, 4, 5, 6, 7] daft_table = Table.from_pydict({"groups": groups, "input": input}).eval_expression_list( - [col("groups"), col("input").cast(DataType.list("item", dtype))] + [col("groups"), col("input").cast(DataType.list(dtype))] ) concat_grouped = daft_table.agg([col("input").alias("concat")._agg_concat()], group_by=[col("groups") % 2]).sort( [col("groups")] ) - assert 
concat_grouped.get_column("concat").datatype() == DataType.list("item", dtype) + assert concat_grouped.get_column("concat").datatype() == DataType.list(dtype) input_as_dtype = daft_table.get_column("input").to_pylist() # We should ignore Null Array elements when performing the concat agg @@ -580,11 +578,11 @@ def test_concat_aggs_empty() -> None: daft_table = Table.from_pydict({"col_A": [], "col_B": []}) daft_table = daft_table.agg( - [col("col_A").cast(DataType.list("list", DataType.int32())).alias("concat")._agg_concat()], + [col("col_A").cast(DataType.list(DataType.int32())).alias("concat")._agg_concat()], group_by=[col("col_B")], ) - assert daft_table.get_column("concat").datatype() == DataType.list("list", DataType.int32()) + assert daft_table.get_column("concat").datatype() == DataType.list(DataType.int32()) res = daft_table.to_pydict() assert res == {"col_B": [], "concat": []} diff --git a/tests/test_schema.py b/tests/test_schema.py index 50d9c6ab8f..0ce2912973 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -146,7 +146,7 @@ def test_schema_from_pyarrow(): [ ("int", DataType.int64()), ("str", DataType.string()), - ("list", DataType.list("item", DataType.int64())), + ("list", DataType.list(DataType.int64())), ] )