Skip to content

Commit

Permalink
[FEAT] [Images] [9/N] Infer Image type for PIL images on ingress. (#…
Browse files Browse the repository at this point in the history
…1067)

This PR infers the `Image` type for PIL images provided to
`Series.from_pylist()`, and therefore for `Table.from_pydict()` and
`DataFrame.from_pydict()`.
  • Loading branch information
clarkzinzow committed Jun 21, 2023
1 parent d3aa3b4 commit 6d1feb7
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ fn extract_python_to_vec<
|| supports_array_interface_protocol
|| supports_array_protocol
{
// Path if object is supports buffer/array protocols.
// Path if object supports buffer/array protocols.
let np_as_array_fn = py.import("numpy")?.getattr(pyo3::intern!(py, "asarray"))?;
let pyarray = np_as_array_fn.call1((object,))?;
let num_values = append_values_from_numpy(
Expand Down
20 changes: 20 additions & 0 deletions src/datatypes/image_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ impl ImageMode {
}

impl ImageMode {
pub fn from_pil_mode_str(mode: &str) -> DaftResult<Self> {
use ImageMode::*;

match mode {
"L" => Ok(L),
"LA" => Ok(LA),
"RGB" => Ok(RGB),
"RGBA" => Ok(RGBA),
"1" | "P" | "CMYK" | "YCbCr" | "LAB" | "HSV" | "I" | "F" | "PA" | "RGBX" | "RGBa" | "La" | "I;16" | "I;16L" | "I;16B" | "I;16N" | "BGR;15" | "BGR;16" | "BGR;24" => Err(DaftError::TypeError(format!(
"PIL image mode {} is not supported; only the following modes are supported: {:?}",
mode,
ImageMode::iterator().as_slice()
))),
_ => Err(DaftError::TypeError(format!(
"Image mode {} is not a valid PIL image mode; see https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes for valid PIL image modes. Of these, only the following modes are supported by Daft: {:?}",
mode,
ImageMode::iterator().as_slice()
))),
}
}
pub fn try_from_num_channels(num_channels: u16, dtype: &DataType) -> DaftResult<Self> {
use ImageMode::*;

Expand Down
56 changes: 54 additions & 2 deletions src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::{exceptions::PyValueError, prelude::*, pyclass::CompareOp, types::PyLi

use crate::{
array::{ops::DaftLogical, pseudo_arrow::PseudoArrowArray, DataArray},
datatypes::{DataType, Field, ImageFormat, PythonType, UInt64Type},
datatypes::{DataType, Field, ImageFormat, ImageMode, PythonType, UInt64Type},
ffi,
series::{self, IntoSeries, Series},
utils::arrow::{cast_array_for_daft_if_needed, cast_array_from_daft_if_needed},
Expand Down Expand Up @@ -33,12 +33,18 @@ impl PySeries {
/// Builds a series called `name` from a Python list of arbitrary Python objects.
///
/// The objects are first wrapped in a `DataType::Python` array. If
/// `infer_daft_dtype_for_sequence` returns a dtype for the elements, the array is
/// cast to that dtype; otherwise it is returned as a Python-object series.
///
/// Errors from extracting the list, inferring the dtype, constructing the array,
/// or casting are propagated with `?`.
#[staticmethod]
pub fn from_pylist(name: &str, pylist: &PyAny) -> PyResult<Self> {
// Extract the elements into owned Python object handles.
let vec_pyobj: Vec<PyObject> = pylist.extract()?;
let py = pylist.py();
// Inference only borrows the objects, so it runs before `vec_pyobj` is moved below.
let dtype = infer_daft_dtype_for_sequence(&vec_pyobj, py)?;
let arrow_array: Box<dyn arrow2::array::Array> =
Box::new(PseudoArrowArray::<PyObject>::from_pyobj_vec(vec_pyobj));
let field = Field::new(name, DataType::Python);

let data_array = DataArray::<PythonType>::new(field.into(), arrow_array)?;
// NOTE(review): the next line looks like the pre-change return shown by the diff
// (its removed side); the `let series = ...` block after it looks like the
// replacement. Only one of the two can be present in the real file — confirm.
Ok(data_array.into_series().into())
// Cast to the inferred dtype when there is one; otherwise keep Python objects.
let series = match dtype {
Some(dtype) => data_array.cast(&dtype)?,
None => data_array.into_series(),
};
Ok(series.into())
}

// This is for PythonArrays only,
Expand Down Expand Up @@ -312,3 +318,49 @@ impl From<PySeries> for series::Series {
item.series
}
}

/// Infers a Daft dtype for a sequence of Python objects, if one can be determined.
///
/// Only PIL images are recognised. `None` elements are skipped. The result is:
///
/// * `Some(DataType::Image(Some(mode)))` when every non-`None` element is a
///   `PIL.Image.Image` and all of them share the same mode;
/// * `Some(DataType::Image(None))` when every non-`None` element is a PIL image
///   but their modes differ;
/// * `None` when any non-`None` element is not a PIL image, when `PIL.Image`
///   cannot be imported, or when the sequence has no non-`None` elements.
///
/// # Errors
///
/// Propagates errors from the `isinstance` check, from reading the image's `mode`
/// attribute, and from `ImageMode::from_pil_mode_str` (e.g. an unsupported mode).
fn infer_daft_dtype_for_sequence(
    vec_pyobj: &[PyObject],
    py: pyo3::Python,
) -> PyResult<Option<DataType>> {
    // A failed import is kept as an `Err` rather than propagated: without PIL no
    // element can be a PIL image, and that is not an error for the caller.
    let py_pil_image_type = py
        .import(pyo3::intern!(py, "PIL.Image"))
        .and_then(|m| m.getattr(pyo3::intern!(py, "Image")));
    let mut dtype: Option<DataType> = None;
    for obj in vec_pyobj.iter() {
        let obj = obj.as_ref(py);
        // Decide image-ness up front so that the non-image fallback below also
        // fires when PIL is importable but this element is not an image.
        let is_pil_image = match py_pil_image_type {
            Ok(pil_image_type) => obj.is_instance(pil_image_type)?,
            Err(_) => false,
        };
        if is_pil_image {
            let mode_str = obj
                .getattr(pyo3::intern!(py, "mode"))?
                .extract::<String>()?;
            let mode = ImageMode::from_pil_mode_str(&mode_str)?;
            match dtype {
                Some(DataType::Image(Some(existing_mode))) => {
                    if existing_mode != mode {
                        // Mixed-mode case, set mode to None.
                        dtype = Some(DataType::Image(None));
                    }
                }
                None => {
                    // Set to (currently) uniform mode image dtype.
                    dtype = Some(DataType::Image(Some(mode)));
                }
                // No-op, since dtype is already for mixed-mode images.
                Some(DataType::Image(None)) => {}
                _ => {
                    // Defensive: this loop only ever stores image dtypes, but if
                    // any other dtype shows up, give up since union dtypes are not
                    // (yet) supported.
                    dtype = None;
                    break;
                }
            }
        } else if !obj.is_none() {
            // Non-image types; short-circuit since only image types are supported
            // and union dtypes are not (yet) supported.
            dtype = None;
            break;
        }
    }
    Ok(dtype)
}
16 changes: 12 additions & 4 deletions tests/dataframe/test_logical_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np
import pandas as pd
import pytest
from PIL import Image

import daft
from daft import DataType, Series, col
Expand All @@ -21,16 +23,22 @@ def test_embedding_type_df() -> None:
assert isinstance(arrow_table["embeddings"].type, DaftExtension)


def test_image_type_df() -> None:
@pytest.mark.parametrize("from_pil_imgs", [True, False])
def test_image_type_df(from_pil_imgs) -> None:
data = [
np.arange(12, dtype=np.uint8).reshape((3, 2, 2)),
np.arange(12, dtype=np.uint8).reshape((2, 2, 3)),
np.arange(12, 39, dtype=np.uint8).reshape((3, 3, 3)),
None,
]
if from_pil_imgs:
data = [Image.fromarray(arr, mode="RGB") if arr is not None else None for arr in data]
df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="force")})

target = DataType.image("RGB")
df = df.select(col("index"), col("image").cast(target))
image_expr = col("image")
if not from_pil_imgs:
target = DataType.image("RGB")
image_expr = image_expr.cast(target)
df = df.select(col("index"), image_expr)
df = df.repartition(4, "index")
df = df.sort("index")
df = df.collect()
Expand Down
23 changes: 0 additions & 23 deletions tests/series/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pandas as pd
import pyarrow as pa
import pytest
from PIL import Image

from daft.datatype import DataType, ImageMode, TimeUnit
from daft.series import Series
Expand Down Expand Up @@ -189,28 +188,6 @@ def test_series_cast_python_to_embedding(dtype) -> None:
np.testing.assert_equal([np.asarray(arr, dtype=dtype.to_pandas_dtype()) for arr in data[:-1]], pydata[:-1])


def test_series_cast_pil_to_image() -> None:
    """Cast a series of two PIL images and a trailing None to the RGB image dtype."""
    first_pixels = np.arange(12).reshape((2, 2, 3)).astype(np.uint8)
    second_pixels = np.arange(12, 39).reshape((3, 3, 3)).astype(np.uint8)
    data = [Image.fromarray(first_pixels), Image.fromarray(second_pixels), None]

    target_dtype = DataType.image("RGB")
    s = Series.from_pylist(data, pyobj="force")
    t = s.cast(target_dtype)

    # The cast keeps the row count and reports the requested dtype.
    assert t.datatype() == target_dtype
    assert len(t) == len(data)

    # Flattened lengths per row: 2*2*3, 3*3*3, and null for the None entry.
    assert t.arr.lengths().to_pylist() == [12, 27, None]

    pydata = t.to_pylist()
    assert pydata[-1] is None
    expected_pixels = [np.asarray(data[0]), np.asarray(data[1])]
    np.testing.assert_equal(expected_pixels, pydata[:-1])

def test_series_cast_numpy_to_image() -> None:
data = [
np.arange(12, dtype=np.uint8).reshape((3, 2, 2)),
Expand Down
77 changes: 77 additions & 0 deletions tests/series/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
"RGBA32F": 4,
}

# Lookup from a channel count to an image mode string; the tests use it to pick
# the `mode` passed to `Image.fromarray` for an array with that many channels.
NUM_CHANNELS_TO_MODE = {
1: "L",
2: "LA",
3: "RGB",
4: "RGBA",
}

MODE_TO_OPENCV_COLOR_CONVERSION = {
"RGB": cv2.COLOR_RGB2BGR,
"RGBA": cv2.COLOR_RGBA2BGRA,
Expand Down Expand Up @@ -136,6 +143,76 @@ def test_fixed_shape_image_round_trip():
np.testing.assert_equal(t_copy.to_pylist(), t.to_pylist())


@pytest.mark.parametrize(
    "mode",
    [
        "L",
        "LA",
        "RGB",
        "RGBA",
    ],
)
@pytest.mark.parametrize("fixed_shape", [True, False])
def test_image_pil_inference(fixed_shape, mode):
    """Series.from_pylist should infer DataType.image(mode) for uniform-mode PIL images.

    Covers both same-shaped and differently-shaped images, each list ending with a
    None entry, and checks that the pixel data round-trips through to_pylist().
    """
    np_dtype = MODE_TO_NP_DTYPE[mode]
    num_channels = MODE_TO_NUM_CHANNELS[mode]
    if fixed_shape:
        height = 4
        width = 4
        shape = (height, width)
        if num_channels > 1:
            shape += (num_channels,)
        arr = np.arange(np.prod(shape)).reshape(shape).astype(np_dtype)
        arrs = [arr, arr, None]
    else:
        shape1 = (2, 2)
        shape2 = (3, 3)
        if num_channels > 1:
            shape1 += (num_channels,)
            shape2 += (num_channels,)
        arr1 = np.arange(np.prod(shape1)).reshape(shape1).astype(np_dtype)
        arr2 = np.arange(np.prod(shape1), np.prod(shape1) + np.prod(shape2)).reshape(shape2).astype(np_dtype)
        arrs = [arr1, arr2, None]
    if mode in ("LA", "RGBA"):
        # Overwrite the last (alpha) channel with 255 before building the images.
        for arr in arrs:
            if arr is not None:
                arr[..., -1] = 255
    imgs = [Image.fromarray(arr, mode=mode) if arr is not None else None for arr in arrs]
    s = Series.from_pylist(imgs, pyobj="force")
    assert s.datatype() == DataType.image(mode)
    out = s.to_pylist()
    if num_channels == 1:
        # Single-channel inputs are 2-D but are compared against output with a
        # trailing channel axis. Leave the None entry as None: expand_dims(None)
        # would produce array([None], dtype=object), and the null row would then
        # only match through scalar broadcasting instead of being checked as None.
        arrs = [np.expand_dims(arr, -1) if arr is not None else None for arr in arrs]
    np.testing.assert_equal(out, arrs)


def test_image_pil_inference_mixed():
    """Series.from_pylist should infer the mode-less image dtype for mixed-mode PIL images."""
    # 2x2 RGBA image whose every pixel is (1, 2, 3, 4).
    rgba = np.tile(np.array([1, 2, 3, 4], dtype=np.uint8), (2, 2, 1))

    arrs = [
        rgba[..., :3],  # RGB
        rgba,  # RGBA
        np.arange(12, dtype=np.uint8).reshape((1, 4, 3)),  # RGB
        np.arange(12, dtype=np.uint8).reshape((3, 4)) * 10,  # L
        np.ones(24, dtype=np.uint8).reshape((3, 4, 2)) * 10,  # LA
        None,
    ]

    def pil_mode_for(arr):
        # 2-D arrays count as one channel; 3-D arrays use their last axis.
        channels = arr.shape[-1] if arr.ndim == 3 else 1
        return NUM_CHANNELS_TO_MODE[channels]

    imgs = []
    for arr in arrs:
        if arr is None:
            imgs.append(None)
        else:
            imgs.append(Image.fromarray(arr, mode=pil_mode_for(arr)))

    s = Series.from_pylist(imgs, pyobj="force")
    assert s.datatype() == DataType.image()

    out = s.to_pylist()
    # The 2-D "L" input is compared against output with a trailing channel axis.
    arrs[3] = np.expand_dims(arrs[3], axis=-1)
    np.testing.assert_equal(out, arrs)


@pytest.mark.parametrize(
["mode", "file_format"],
[
Expand Down

0 comments on commit 6d1feb7

Please sign in to comment.