Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] [Images] [9/N] Infer Image type for PIL images on ingress. #1067

Merged
merged 7 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ fn extract_python_to_vec<
|| supports_array_interface_protocol
|| supports_array_protocol
{
// Path if object is supports buffer/array protocols.
// Path if object supports buffer/array protocols.
let np_as_array_fn = py.import("numpy")?.getattr(pyo3::intern!(py, "asarray"))?;
let pyarray = np_as_array_fn.call1((object,))?;
let num_values = append_values_from_numpy(
Expand Down
59 changes: 56 additions & 3 deletions src/python/series.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::ops::{Add, Div, Mul, Rem, Sub};
use std::{
ops::{Add, Div, Mul, Rem, Sub},
str::FromStr,
};

use pyo3::{exceptions::PyValueError, prelude::*, pyclass::CompareOp, types::PyList};

use crate::{
array::{ops::DaftLogical, pseudo_arrow::PseudoArrowArray, DataArray},
datatypes::{DataType, Field, ImageFormat, PythonType, UInt64Type},
datatypes::{DataType, Field, ImageFormat, ImageMode, PythonType, UInt64Type},
ffi,
series::{self, IntoSeries, Series},
utils::arrow::{cast_array_for_daft_if_needed, cast_array_from_daft_if_needed},
Expand Down Expand Up @@ -33,12 +36,18 @@
#[staticmethod]
pub fn from_pylist(name: &str, pylist: &PyAny) -> PyResult<Self> {
let vec_pyobj: Vec<PyObject> = pylist.extract()?;
let py = pylist.py();
let dtype = infer_daft_dtype_for_sequence(&vec_pyobj, py)?;
let arrow_array: Box<dyn arrow2::array::Array> =
Box::new(PseudoArrowArray::<PyObject>::from_pyobj_vec(vec_pyobj));
let field = Field::new(name, DataType::Python);

let data_array = DataArray::<PythonType>::new(field.into(), arrow_array)?;
Ok(data_array.into_series().into())
let series = match dtype {
Some(dtype) => data_array.cast(&dtype)?,
None => data_array.into_series(),
};
Ok(series.into())
}

// This is for PythonArrays only,
Expand Down Expand Up @@ -312,3 +321,47 @@
item.series
}
}

fn infer_daft_dtype_for_sequence(
vec_pyobj: &[PyObject],
py: pyo3::Python,
) -> PyResult<Option<DataType>> {
let py_pil_image_type = py
.import("PIL.Image")
clarkzinzow marked this conversation as resolved.
Show resolved Hide resolved
.and_then(|m| m.getattr(pyo3::intern!(py, "Image")));
let mut dtype: Option<DataType> = None;
for obj in vec_pyobj.iter() {
let obj = obj.as_ref(py);
if let Ok(pil_image_type) = py_pil_image_type {
if obj.is_instance(pil_image_type)? {
let mode_str = obj.getattr("mode")?.extract::<String>()?;
clarkzinzow marked this conversation as resolved.
Show resolved Hide resolved
let mode = ImageMode::from_str(&mode_str)?;
clarkzinzow marked this conversation as resolved.
Show resolved Hide resolved
match dtype {
Some(DataType::Image(Some(existing_mode))) => {
if existing_mode != mode {
// Mixed-mode case, set mode to None.
dtype = Some(DataType::Image(None));
}
}
None => {
// Set to (currently) uniform mode image dtype.
dtype = Some(DataType::Image(Some(mode)));
}
// No-op, since dtype is already for mixed-mode images.
Some(DataType::Image(None)) => {}
_ => {
// Images mixed with non-images; short-circuit since union dtypes are not (yet) supported.
dtype = None;
break;

Check warning on line 355 in src/python/series.rs

View check run for this annotation

Codecov / codecov/patch

src/python/series.rs#L354-L355

Added lines #L354 - L355 were not covered by tests
}
}
}
} else if !obj.is_none() {

Check warning on line 359 in src/python/series.rs

View check run for this annotation

Codecov / codecov/patch

src/python/series.rs#L359

Added line #L359 was not covered by tests
// Non-image types; short-circuit since only image types are supported and union dtypes are not (yet)
// supported.
dtype = None;
break;
}

Check warning on line 364 in src/python/series.rs

View check run for this annotation

Codecov / codecov/patch

src/python/series.rs#L362-L364

Added lines #L362 - L364 were not covered by tests
}
Ok(dtype)
}
16 changes: 12 additions & 4 deletions tests/dataframe/test_logical_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np
import pandas as pd
import pytest
from PIL import Image

import daft
from daft import DataType, Series, col
Expand All @@ -21,16 +23,22 @@ def test_embedding_type_df() -> None:
assert isinstance(arrow_table["embeddings"].type, DaftExtension)


def test_image_type_df() -> None:
@pytest.mark.parametrize("from_pil_imgs", [True, False])
def test_image_type_df(from_pil_imgs) -> None:
data = [
np.arange(12, dtype=np.uint8).reshape((3, 2, 2)),
np.arange(12, dtype=np.uint8).reshape((2, 2, 3)),
np.arange(12, 39, dtype=np.uint8).reshape((3, 3, 3)),
None,
]
if from_pil_imgs:
data = [Image.fromarray(arr, mode="RGB") if arr is not None else None for arr in data]
df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="force")})

target = DataType.image("RGB")
df = df.select(col("index"), col("image").cast(target))
image_expr = col("image")
if not from_pil_imgs:
target = DataType.image("RGB")
image_expr = image_expr.cast(target)
df = df.select(col("index"), image_expr)
df = df.repartition(4, "index")
df = df.sort("index")
df = df.collect()
Expand Down
23 changes: 0 additions & 23 deletions tests/series/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pandas as pd
import pyarrow as pa
import pytest
from PIL import Image

from daft.datatype import DataType, ImageMode, TimeUnit
from daft.series import Series
Expand Down Expand Up @@ -189,28 +188,6 @@ def test_series_cast_python_to_embedding(dtype) -> None:
np.testing.assert_equal([np.asarray(arr, dtype=dtype.to_pandas_dtype()) for arr in data[:-1]], pydata[:-1])


def test_series_cast_pil_to_image() -> None:
data = [
Image.fromarray(np.arange(12).reshape((2, 2, 3)).astype(np.uint8)),
Image.fromarray(np.arange(12, 39).reshape((3, 3, 3)).astype(np.uint8)),
None,
]
s = Series.from_pylist(data, pyobj="force")

target_dtype = DataType.image("RGB")

t = s.cast(target_dtype)

assert t.datatype() == target_dtype
assert len(t) == len(data)

assert t.arr.lengths().to_pylist() == [12, 27, None]

pydata = t.to_pylist()
assert pydata[-1] is None
np.testing.assert_equal([np.asarray(data[0]), np.asarray(data[1])], pydata[:-1])


def test_series_cast_numpy_to_image() -> None:
data = [
np.arange(12, dtype=np.uint8).reshape((3, 2, 2)),
Expand Down
78 changes: 78 additions & 0 deletions tests/series/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
"RGBA32F": 4,
}

NUM_CHANNELS_TO_MODE = {
1: "L",
2: "LA",
3: "RGB",
4: "RGBA",
}

MODE_TO_OPENCV_COLOR_CONVERSION = {
"RGB": cv2.COLOR_RGB2BGR,
"RGBA": cv2.COLOR_RGBA2BGRA,
Expand Down Expand Up @@ -136,6 +143,77 @@ def test_fixed_shape_image_round_trip():
np.testing.assert_equal(t_copy.to_pylist(), t.to_pylist())


@pytest.mark.parametrize(
"mode",
[
"L",
"LA",
"RGB",
"RGBA",
],
)
@pytest.mark.parametrize("fixed_shape", [True, False])
def test_image_pil_inference(fixed_shape, mode):
np_dtype = MODE_TO_NP_DTYPE[mode]
num_channels = MODE_TO_NUM_CHANNELS[mode]
if fixed_shape:
height = 4
width = 4
shape = (height, width)
if num_channels > 1:
shape += (num_channels,)
arr = np.arange(np.prod(shape)).reshape(shape).astype(np_dtype)
arrs = [arr, arr, None]
else:
shape1 = (2, 2)
shape2 = (3, 3)
if num_channels > 1:
shape1 += (num_channels,)
shape2 += (num_channels,)
arr1 = np.arange(np.prod(shape1)).reshape(shape1).astype(np_dtype)
arr2 = np.arange(np.prod(shape1), np.prod(shape1) + np.prod(shape2)).reshape(shape2).astype(np_dtype)
arrs = [arr1, arr2, None]
if mode in ("LA", "RGBA"):
for arr in arrs:
if arr is not None:
arr[..., -1] = 255
imgs = [Image.fromarray(arr, mode=mode) if arr is not None else None for arr in arrs]
s = Series.from_pylist(imgs, pyobj="force")
assert s.datatype() == DataType.image(mode)
out = s.to_pylist()
if num_channels == 1:
arrs = [np.expand_dims(arr, -1) for arr in arrs]
np.testing.assert_equal(out, arrs)


def test_image_pil_inference_mixed():
rgba = np.ones((2, 2, 4), dtype=np.uint8)
rgba[..., 1] = 2
rgba[..., 2] = 3
rgba[..., 3] = 4

arrs = [
rgba[..., :3], # RGB
rgba, # RGBA
np.arange(12, dtype=np.uint8).reshape((1, 4, 3)), # RGB
np.arange(12, dtype=np.uint8).reshape((3, 4)) * 10, # L
np.ones(24, dtype=np.uint8).reshape((3, 4, 2)) * 10, # LA
None,
]
print([arr.shape if arr is not None else None for arr in arrs])
imgs = [
Image.fromarray(arr, mode=NUM_CHANNELS_TO_MODE[arr.shape[-1] if arr.ndim == 3 else 1])
if arr is not None
else None
for arr in arrs
]
s = Series.from_pylist(imgs, pyobj="force")
assert s.datatype() == DataType.image()
out = s.to_pylist()
arrs[3] = np.expand_dims(arrs[3], axis=-1)
np.testing.assert_equal(out, arrs)


@pytest.mark.parametrize(
["mode", "file_format"],
[
Expand Down