Skip to content

Commit

Permalink
[BUG] Reenable HTML viz hooks for np.ndarray and PIL Images (#1078)
Browse files Browse the repository at this point in the history
Closes: #1077 and #1075 

* Adds a fix for `Series.from_pylist(..., pyobj="force")` when the list
is a list of PIL Images (cc @clarkzinzow)
* Re-enables our HTML viz hooks in Python, for Python objects when
calling into the HTML repr for Python arrays

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia committed Jun 22, 2023
1 parent 58b19e1 commit 1c41ad8
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 18 deletions.
4 changes: 2 additions & 2 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") ->
raise ValueError(f"pyobj: expected either 'allow', 'disallow', or 'force', but got {pyobj})")

if pyobj == "force":
pys = PySeries.from_pylist(name, data)
pys = PySeries.from_pylist(name, data, pyobj=pyobj)
return Series._from_pyseries(pys)

try:
Expand All @@ -106,7 +106,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") ->
except pa.lib.ArrowInvalid:
if pyobj == "disallow":
raise
pys = PySeries.from_pylist(name, data)
pys = PySeries.from_pylist(name, data, pyobj=pyobj)
return Series._from_pyseries(pys)

@classmethod
Expand Down
8 changes: 4 additions & 4 deletions daft/viz/html_viz_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
_VIZ_HOOKS_REGISTRY = {}


def register_viz_hook(klass: type[HookClass], hook: Callable[[HookClass], str]):
def register_viz_hook(klass: type[HookClass], hook: Callable[[object], str]):
"""Registers a visualization hook that returns the appropriate HTML for
visualizing a specific class in HTML"""
_VIZ_HOOKS_REGISTRY[klass] = hook


def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None:
def get_viz_hook(val: object) -> Callable[[object], str] | None:
for klass in _VIZ_HOOKS_REGISTRY:
if isinstance(val, klass):
return _VIZ_HOOKS_REGISTRY[klass]
Expand All @@ -45,7 +45,7 @@ def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None:

if HAS_PILLOW:

def _viz_pil_image(val: PIL.Image.Image):
def _viz_pil_image(val: PIL.Image.Image) -> str:
img = val.copy()
img.thumbnail((128, 128))
bio = io.BytesIO()
Expand All @@ -57,7 +57,7 @@ def _viz_pil_image(val: PIL.Image.Image):

if HAS_NUMPY:

def _viz_numpy(val: np.ndarray):
def _viz_numpy(val: np.ndarray) -> str:
return f"&ltnp.ndarray<br>shape={val.shape}<br>dtype={val.dtype}&gt"

register_viz_hook(np.ndarray, _viz_numpy)
36 changes: 32 additions & 4 deletions src/array/ops/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,43 @@ impl_array_html_value!(ListArray);
impl_array_html_value!(FixedSizeListArray);
impl_array_html_value!(StructArray);
impl_array_html_value!(ExtensionArray);

#[cfg(feature = "python")]
impl_array_html_value!(crate::datatypes::PythonArray);

impl_array_html_value!(DateArray);
impl_array_html_value!(DurationArray);
impl_array_html_value!(TimestampArray);
impl_array_html_value!(EmbeddingArray);

#[cfg(feature = "python")]
impl crate::datatypes::PythonArray {
pub fn html_value(&self, idx: usize) -> String {
use pyo3::prelude::*;

let val = self.get(idx);

let custom_viz_hook_result: Option<String> = Python::with_gil(|py| {
// Find visualization hooks for this object's class
let pyany = val.into_ref(py);
let get_viz_hook = py
.import("daft.viz.html_viz_hooks")?
.getattr("get_viz_hook")?;
let hook = get_viz_hook.call1((pyany,))?;

if hook.is_none() {
Ok(None)
} else {
hook.call1((pyany,))?.extract()
}
})
.unwrap();

match custom_viz_hook_result {
None => html_escape::encode_text(&self.str_value(idx).unwrap())
.into_owned()
.replace('\n', "<br />"),
Some(result) => result,
}
}
}

impl<T> DataArray<T>
where
T: DaftNumericType,
Expand Down
15 changes: 9 additions & 6 deletions src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,22 @@ impl PySeries {

// This ingests a Python list[object] directly into a Rust PythonArray.
#[staticmethod]
pub fn from_pylist(name: &str, pylist: &PyAny) -> PyResult<Self> {
pub fn from_pylist(name: &str, pylist: &PyAny, pyobj: &str) -> PyResult<Self> {
let vec_pyobj: Vec<PyObject> = pylist.extract()?;
let py = pylist.py();
let dtype = infer_daft_dtype_for_sequence(&vec_pyobj, py)?;

let dtype = match pyobj {
"force" => DataType::Python,
"allow" => infer_daft_dtype_for_sequence(&vec_pyobj, py)?.unwrap_or(DataType::Python),
"disallow" => panic!("Cannot create a Series from a pylist and being strict about only using Arrow types by setting pyobj=disallow"),
_ => panic!("Unsupported pyobj behavior when creating Series from pylist: {}", pyobj)
};
let arrow_array: Box<dyn arrow2::array::Array> =
Box::new(PseudoArrowArray::<PyObject>::from_pyobj_vec(vec_pyobj));
let field = Field::new(name, DataType::Python);

let data_array = DataArray::<PythonType>::new(field.into(), arrow_array)?;
let series = match dtype {
Some(dtype) => data_array.cast(&dtype)?,
None => data_array.into_series(),
};
let series = data_array.cast(&dtype)?;
Ok(series.into())
}

Expand Down
2 changes: 1 addition & 1 deletion tests/dataframe/test_logical_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_image_type_df(from_pil_imgs) -> None:
]
if from_pil_imgs:
data = [Image.fromarray(arr, mode="RGB") if arr is not None else None for arr in data]
df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="force")})
df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="allow")})

image_expr = col("image")
if not from_pil_imgs:
Expand Down
55 changes: 55 additions & 0 deletions tests/dataframe/test_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import re

import numpy as np
import pandas as pd
from PIL import Image

import daft

Expand Down Expand Up @@ -176,3 +178,56 @@ def test_repr_with_html_string():
for i in range(3):
assert f"<div>body{i}</div>" in non_html_table
assert f"<tr><td>&lt;div&gt;body{i}&lt;/div&gt;</td></tr>" in html_table


class MyObj:
def __repr__(self) -> str:
return "myobj-custom-repr"


def test_repr_html_custom_hooks():
img = Image.fromarray(np.ones((3, 3)).astype(np.uint8))
arr = np.ones((3, 3))

df = daft.from_pydict(
{
"objects": daft.Series.from_pylist([MyObj() for _ in range(3)], pyobj="force"),
"np": daft.Series.from_pylist([arr for _ in range(3)], pyobj="force"),
"pil": daft.Series.from_pylist([img for _ in range(3)], pyobj="force"),
}
)
df.collect()

assert (
df.__repr__()
== """+-------------------+-------------+----------------------------------+
| objects | np | pil |
| Python | Python | Python |
+-------------------+-------------+----------------------------------+
| myobj-custom-repr | [[1. 1. 1.] | <PIL.Image.Image image mode=L... |
| | [1. 1. 1.] | |
| | [1. ... | |
+-------------------+-------------+----------------------------------+
| myobj-custom-repr | [[1. 1. 1.] | <PIL.Image.Image image mode=L... |
| | [1. 1. 1.] | |
| | [1. ... | |
+-------------------+-------------+----------------------------------+
| myobj-custom-repr | [[1. 1. 1.] | <PIL.Image.Image image mode=L... |
| | [1. 1. 1.] | |
| | [1. ... | |
+-------------------+-------------+----------------------------------+
(Showing first 3 of 3 rows)"""
)

html_repr = df._repr_html_()

# Assert that MyObj is correctly displayed in html repr (falls back to __repr__)
assert "myobj-custom-repr" in html_repr

# Assert that PIL viz hook correctly triggers in html repr
assert 'alt="<PIL.Image.Image image mode=L size=3x3' in html_repr
assert '<img style="max-height:128px;width:auto" src="data:image/png;base64,' in html_repr

# Assert that numpy array viz hook correctly triggers in html repr
assert "<td>&ltnp.ndarray<br>shape=(3, 3)<br>dtype=float64&gt</td><td>" in html_repr
7 changes: 6 additions & 1 deletion tests/series/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_image_pil_inference(fixed_shape, mode):
if arr is not None:
arr[..., -1] = 255
imgs = [Image.fromarray(arr, mode=mode) if arr is not None else None for arr in arrs]
s = Series.from_pylist(imgs, pyobj="force")
s = Series.from_pylist(imgs, pyobj="allow")
assert s.datatype() == DataType.image(mode)
out = s.to_pylist()
if num_channels == 1:
Expand Down Expand Up @@ -206,7 +206,12 @@ def test_image_pil_inference_mixed():
else None
for arr in arrs
]

# Forcing should still create Python Series
s = Series.from_pylist(imgs, pyobj="force")
assert s.datatype() == DataType.python()

s = Series.from_pylist(imgs, pyobj="allow")
assert s.datatype() == DataType.image()
out = s.to_pylist()
arrs[3] = np.expand_dims(arrs[3], axis=-1)
Expand Down

0 comments on commit 1c41ad8

Please sign in to comment.