From 6d1feb777fec0a80062164f79f3e629342b7406b Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Wed, 21 Jun 2023 10:09:46 -0700 Subject: [PATCH] [FEAT] [Images] [9/N] Infer `Image` type for PIL images on ingress. (#1067) This PR infers the `Image` type for PIL images provided to `Series.from_pylist()`, and therefore for `Table.from_pydict()` and `DataFrame.from_pydict()`. --- src/array/ops/cast.rs | 2 +- src/datatypes/image_mode.rs | 20 ++++++++ src/python/series.rs | 56 +++++++++++++++++++- tests/dataframe/test_logical_type.py | 16 ++++-- tests/series/test_cast.py | 23 --------- tests/series/test_image.py | 77 ++++++++++++++++++++++++++++ 6 files changed, 164 insertions(+), 30 deletions(-) diff --git a/src/array/ops/cast.rs b/src/array/ops/cast.rs index 0f2a0797ed..df474cb9fd 100644 --- a/src/array/ops/cast.rs +++ b/src/array/ops/cast.rs @@ -437,7 +437,7 @@ fn extract_python_to_vec< || supports_array_interface_protocol || supports_array_protocol { - // Path if object is supports buffer/array protocols. + // Path if object supports buffer/array protocols. let np_as_array_fn = py.import("numpy")?.getattr(pyo3::intern!(py, "asarray"))?; let pyarray = np_as_array_fn.call1((object,))?; let num_values = append_values_from_numpy( diff --git a/src/datatypes/image_mode.rs b/src/datatypes/image_mode.rs index 1a1da18e39..5ae8fff0ab 100644 --- a/src/datatypes/image_mode.rs +++ b/src/datatypes/image_mode.rs @@ -59,6 +59,26 @@ impl ImageMode { } impl ImageMode { + pub fn from_pil_mode_str(mode: &str) -> DaftResult { + use ImageMode::*; + + match mode { + "L" => Ok(L), + "LA" => Ok(LA), + "RGB" => Ok(RGB), + "RGBA" => Ok(RGBA), + "1" | "P" | "CMYK" | "YCbCr" | "LAB" | "HSV" | "I" | "F" | "PA" | "RGBX" | "RGBa" | "La" | "I;16" | "I;16L" | "I;16B" | "I;16N" | "BGR;15" | "BGR;16" | "BGR;24" => Err(DaftError::TypeError(format!( + "PIL image mode {} is not supported; only the following modes are supported: {:?}", + mode, + ImageMode::iterator().as_slice() + ))), + _ => Err(DaftError::TypeError(format!( + "Image mode {} is not a valid PIL image mode; see https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes for valid PIL image modes. Of these, only the following modes are supported by Daft: {:?}", + mode, + ImageMode::iterator().as_slice() + ))), + } + } pub fn try_from_num_channels(num_channels: u16, dtype: &DataType) -> DaftResult { use ImageMode::*; diff --git a/src/python/series.rs b/src/python/series.rs index a6c5d3382a..bb889e4953 100644 --- a/src/python/series.rs +++ b/src/python/series.rs @@ -4,7 +4,7 @@ use pyo3::{exceptions::PyValueError, prelude::*, pyclass::CompareOp, types::PyLi use crate::{ array::{ops::DaftLogical, pseudo_arrow::PseudoArrowArray, DataArray}, - datatypes::{DataType, Field, ImageFormat, PythonType, UInt64Type}, + datatypes::{DataType, Field, ImageFormat, ImageMode, PythonType, UInt64Type}, ffi, series::{self, IntoSeries, Series}, utils::arrow::{cast_array_for_daft_if_needed, cast_array_from_daft_if_needed}, @@ -33,12 +33,18 @@ impl PySeries { #[staticmethod] pub fn from_pylist(name: &str, pylist: &PyAny) -> PyResult { let vec_pyobj: Vec = pylist.extract()?; + let py = pylist.py(); + let dtype = infer_daft_dtype_for_sequence(&vec_pyobj, py)?; let arrow_array: Box = Box::new(PseudoArrowArray::::from_pyobj_vec(vec_pyobj)); let field = Field::new(name, DataType::Python); let data_array = DataArray::::new(field.into(), arrow_array)?; - Ok(data_array.into_series().into()) + let series = match dtype { + Some(dtype) => data_array.cast(&dtype)?, + None => data_array.into_series(), + }; + Ok(series.into()) } // This is for PythonArrays only, @@ -312,3 +318,49 @@ impl From for series::Series { item.series } } + +fn infer_daft_dtype_for_sequence( + vec_pyobj: &[PyObject], + py: pyo3::Python, +) -> PyResult> { + let py_pil_image_type = py + .import(pyo3::intern!(py, "PIL.Image")) + .and_then(|m| m.getattr(pyo3::intern!(py, "Image"))); + let mut dtype: Option = None; + for obj in vec_pyobj.iter() { + let obj = obj.as_ref(py); + if let Ok(pil_image_type) = py_pil_image_type { + if obj.is_instance(pil_image_type)? { + let mode_str = obj + .getattr(pyo3::intern!(py, "mode"))? + .extract::()?; + let mode = ImageMode::from_pil_mode_str(&mode_str)?; + match dtype { + Some(DataType::Image(Some(existing_mode))) => { + if existing_mode != mode { + // Mixed-mode case, set mode to None. + dtype = Some(DataType::Image(None)); + } + } + None => { + // Set to (currently) uniform mode image dtype. + dtype = Some(DataType::Image(Some(mode))); + } + // No-op, since dtype is already for mixed-mode images. + Some(DataType::Image(None)) => {} + _ => { + // Images mixed with non-images; short-circuit since union dtypes are not (yet) supported. + dtype = None; + break; + } + } + } + } else if !obj.is_none() { + // Non-image types; short-circuit since only image types are supported and union dtypes are not (yet) + // supported. + dtype = None; + break; + } + } + Ok(dtype) +} diff --git a/tests/dataframe/test_logical_type.py b/tests/dataframe/test_logical_type.py index e3f03b71a0..d3d5224b72 100644 --- a/tests/dataframe/test_logical_type.py +++ b/tests/dataframe/test_logical_type.py @@ -2,6 +2,8 @@ import numpy as np import pandas as pd +import pytest +from PIL import Image import daft from daft import DataType, Series, col @@ -21,16 +23,22 @@ def test_embedding_type_df() -> None: assert isinstance(arrow_table["embeddings"].type, DaftExtension) -def test_image_type_df() -> None: +@pytest.mark.parametrize("from_pil_imgs", [True, False]) +def test_image_type_df(from_pil_imgs) -> None: data = [ - np.arange(12, dtype=np.uint8).reshape((3, 2, 2)), + np.arange(12, dtype=np.uint8).reshape((2, 2, 3)), np.arange(12, 39, dtype=np.uint8).reshape((3, 3, 3)), None, ] + if from_pil_imgs: + data = [Image.fromarray(arr, mode="RGB") if arr is not None else None for arr in data] df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="force")}) - target = DataType.image("RGB") - df = df.select(col("index"), col("image").cast(target)) + image_expr = col("image") + if not from_pil_imgs: + target = DataType.image("RGB") + image_expr = image_expr.cast(target) + df = df.select(col("index"), image_expr) df = df.repartition(4, "index") df = df.sort("index") df = df.collect() diff --git a/tests/series/test_cast.py b/tests/series/test_cast.py index e0f2b33ca7..46a4ddef8f 100644 --- a/tests/series/test_cast.py +++ b/tests/series/test_cast.py @@ -7,7 +7,6 @@ import pandas as pd import pyarrow as pa import pytest -from PIL import Image from daft.datatype import DataType, ImageMode, TimeUnit from daft.series import Series @@ -189,28 +188,6 @@ def test_series_cast_python_to_embedding(dtype) -> None: np.testing.assert_equal([np.asarray(arr, dtype=dtype.to_pandas_dtype()) for arr in data[:-1]], pydata[:-1]) -def test_series_cast_pil_to_image() -> None: - data = [ - Image.fromarray(np.arange(12).reshape((2, 2, 3)).astype(np.uint8)), - Image.fromarray(np.arange(12, 39).reshape((3, 3, 3)).astype(np.uint8)), - None, - ] - s = Series.from_pylist(data, pyobj="force") - - target_dtype = DataType.image("RGB") - - t = s.cast(target_dtype) - - assert t.datatype() == target_dtype - assert len(t) == len(data) - - assert t.arr.lengths().to_pylist() == [12, 27, None] - - pydata = t.to_pylist() - assert pydata[-1] is None - np.testing.assert_equal([np.asarray(data[0]), np.asarray(data[1])], pydata[:-1]) - - def test_series_cast_numpy_to_image() -> None: data = [ np.arange(12, dtype=np.uint8).reshape((3, 2, 2)), diff --git a/tests/series/test_image.py b/tests/series/test_image.py index 5d2a044e54..99c658aaf9 100644 --- a/tests/series/test_image.py +++ b/tests/series/test_image.py @@ -38,6 +38,13 @@ "RGBA32F": 4, } +NUM_CHANNELS_TO_MODE = { + 1: "L", + 2: "LA", + 3: "RGB", + 4: "RGBA", +} + MODE_TO_OPENCV_COLOR_CONVERSION = { "RGB": cv2.COLOR_RGB2BGR, "RGBA": cv2.COLOR_RGBA2BGRA, @@ -136,6 +143,76 @@ def test_fixed_shape_image_round_trip(): np.testing.assert_equal(t_copy.to_pylist(), t.to_pylist()) +@pytest.mark.parametrize( + "mode", + [ + "L", + "LA", + "RGB", + "RGBA", + ], +) +@pytest.mark.parametrize("fixed_shape", [True, False]) +def test_image_pil_inference(fixed_shape, mode): + np_dtype = MODE_TO_NP_DTYPE[mode] + num_channels = MODE_TO_NUM_CHANNELS[mode] + if fixed_shape: + height = 4 + width = 4 + shape = (height, width) + if num_channels > 1: + shape += (num_channels,) + arr = np.arange(np.prod(shape)).reshape(shape).astype(np_dtype) + arrs = [arr, arr, None] + else: + shape1 = (2, 2) + shape2 = (3, 3) + if num_channels > 1: + shape1 += (num_channels,) + shape2 += (num_channels,) + arr1 = np.arange(np.prod(shape1)).reshape(shape1).astype(np_dtype) + arr2 = np.arange(np.prod(shape1), np.prod(shape1) + np.prod(shape2)).reshape(shape2).astype(np_dtype) + arrs = [arr1, arr2, None] + if mode in ("LA", "RGBA"): + for arr in arrs: + if arr is not None: + arr[..., -1] = 255 + imgs = [Image.fromarray(arr, mode=mode) if arr is not None else None for arr in arrs] + s = Series.from_pylist(imgs, pyobj="force") + assert s.datatype() == DataType.image(mode) + out = s.to_pylist() + if num_channels == 1: + arrs = [np.expand_dims(arr, -1) for arr in arrs] + np.testing.assert_equal(out, arrs) + + +def test_image_pil_inference_mixed(): + rgba = np.ones((2, 2, 4), dtype=np.uint8) + rgba[..., 1] = 2 + rgba[..., 2] = 3 + rgba[..., 3] = 4 + + arrs = [ + rgba[..., :3], # RGB + rgba, # RGBA + np.arange(12, dtype=np.uint8).reshape((1, 4, 3)), # RGB + np.arange(12, dtype=np.uint8).reshape((3, 4)) * 10, # L + np.ones(24, dtype=np.uint8).reshape((3, 4, 2)) * 10, # LA + None, + ] + imgs = [ + Image.fromarray(arr, mode=NUM_CHANNELS_TO_MODE[arr.shape[-1] if arr.ndim == 3 else 1]) + if arr is not None + else None + for arr in arrs + ] + s = Series.from_pylist(imgs, pyobj="force") + assert s.datatype() == DataType.image() + out = s.to_pylist() + arrs[3] = np.expand_dims(arrs[3], axis=-1) + np.testing.assert_equal(out, arrs) + + @pytest.mark.parametrize( ["mode", "file_format"], [