Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Reenable HTML viz hooks for np.ndarray and PIL Images #1078

Merged
merged 3 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") ->
raise ValueError(f"pyobj: expected either 'allow', 'disallow', or 'force', but got {pyobj})")

if pyobj == "force":
pys = PySeries.from_pylist(name, data)
pys = PySeries.from_pylist(name, data, pyobj=pyobj)
return Series._from_pyseries(pys)

try:
Expand All @@ -106,7 +106,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") ->
except pa.lib.ArrowInvalid:
if pyobj == "disallow":
raise
pys = PySeries.from_pylist(name, data)
pys = PySeries.from_pylist(name, data, pyobj=pyobj)
return Series._from_pyseries(pys)

@classmethod
Expand Down
8 changes: 4 additions & 4 deletions daft/viz/html_viz_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
_VIZ_HOOKS_REGISTRY = {}


def register_viz_hook(klass: type[HookClass], hook: Callable[[HookClass], str]):
def register_viz_hook(klass: type[HookClass], hook: Callable[[object], str]):
"""Registers a visualization hook that returns the appropriate HTML for
visualizing a specific class in HTML"""
_VIZ_HOOKS_REGISTRY[klass] = hook


def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None:
def get_viz_hook(val: object) -> Callable[[object], str] | None:
for klass in _VIZ_HOOKS_REGISTRY:
if isinstance(val, klass):
return _VIZ_HOOKS_REGISTRY[klass]
Expand All @@ -45,7 +45,7 @@ def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None:

if HAS_PILLOW:

def _viz_pil_image(val: PIL.Image.Image):
def _viz_pil_image(val: PIL.Image.Image) -> str:
img = val.copy()
img.thumbnail((128, 128))
bio = io.BytesIO()
Expand All @@ -57,7 +57,7 @@ def _viz_pil_image(val: PIL.Image.Image):

if HAS_NUMPY:

def _viz_numpy(val: np.ndarray):
def _viz_numpy(val: np.ndarray) -> str:
return f"&ltnp.ndarray<br>shape={val.shape}<br>dtype={val.dtype}&gt"

register_viz_hook(np.ndarray, _viz_numpy)
36 changes: 32 additions & 4 deletions src/array/ops/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,43 @@ impl_array_html_value!(ListArray);
impl_array_html_value!(FixedSizeListArray);
impl_array_html_value!(StructArray);
impl_array_html_value!(ExtensionArray);

#[cfg(feature = "python")]
impl_array_html_value!(crate::datatypes::PythonArray);

impl_array_html_value!(DateArray);
impl_array_html_value!(DurationArray);
impl_array_html_value!(TimestampArray);
impl_array_html_value!(EmbeddingArray);

#[cfg(feature = "python")]
impl crate::datatypes::PythonArray {
    /// Renders the element at `idx` as an HTML fragment.
    ///
    /// The element is handed to `get_viz_hook` from the `daft.viz.html_viz_hooks`
    /// Python module. If that returns a hook, the hook is called on the element and
    /// its extracted result is returned as-is. If no hook is returned (or the hook
    /// yields no string), the element's `str_value` is HTML-escaped and every
    /// newline is replaced with `<br />`.
    ///
    /// # Panics
    ///
    /// Panics if any of the Python calls (import, attribute lookup, hook lookup,
    /// hook invocation, result extraction) fails, or if unwrapping
    /// `str_value(idx)` fails.
    pub fn html_value(&self, idx: usize) -> String {
        use pyo3::prelude::*;

        let element = self.get(idx);

        let hook_output = Python::with_gil(|py| -> PyResult<Option<String>> {
            let obj = element.into_ref(py);

            // Ask the Python side whether a visualization hook is registered
            // for this object's class.
            let viz_module = py.import("daft.viz.html_viz_hooks")?;
            let viz_hook = viz_module.getattr("get_viz_hook")?.call1((obj,))?;

            if viz_hook.is_none() {
                return Ok(None);
            }
            viz_hook.call1((obj,))?.extract()
        })
        .unwrap();

        // Without a hook result, fall back to the escaped string representation.
        hook_output.unwrap_or_else(|| {
            html_escape::encode_text(&self.str_value(idx).unwrap())
                .into_owned()
                .replace('\n', "<br />")
        })
    }
}

impl<T> DataArray<T>
where
T: DaftNumericType,
Expand Down
15 changes: 9 additions & 6 deletions src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,22 @@ impl PySeries {

// This ingests a Python list[object] directly into a Rust PythonArray.
#[staticmethod]
pub fn from_pylist(name: &str, pylist: &PyAny) -> PyResult<Self> {
pub fn from_pylist(name: &str, pylist: &PyAny, pyobj: &str) -> PyResult<Self> {
let vec_pyobj: Vec<PyObject> = pylist.extract()?;
let py = pylist.py();
let dtype = infer_daft_dtype_for_sequence(&vec_pyobj, py)?;

let dtype = match pyobj {
"force" => DataType::Python,
"allow" => infer_daft_dtype_for_sequence(&vec_pyobj, py)?.unwrap_or(DataType::Python),
"disallow" => panic!("Cannot create a Series from a pylist and being strict about only using Arrow types by setting pyobj=disallow"),
_ => panic!("Unsupported pyobj behavior when creating Series from pylist: {}", pyobj)
};
let arrow_array: Box<dyn arrow2::array::Array> =
Box::new(PseudoArrowArray::<PyObject>::from_pyobj_vec(vec_pyobj));
let field = Field::new(name, DataType::Python);

let data_array = DataArray::<PythonType>::new(field.into(), arrow_array)?;
let series = match dtype {
Some(dtype) => data_array.cast(&dtype)?,
None => data_array.into_series(),
};
let series = data_array.cast(&dtype)?;
Ok(series.into())
}

Expand Down
2 changes: 1 addition & 1 deletion tests/dataframe/test_logical_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_image_type_df(from_pil_imgs) -> None:
]
if from_pil_imgs:
data = [Image.fromarray(arr, mode="RGB") if arr is not None else None for arr in data]
df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="force")})
df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="allow")})

image_expr = col("image")
if not from_pil_imgs:
Expand Down
55 changes: 55 additions & 0 deletions tests/dataframe/test_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import re

import numpy as np
import pandas as pd
from PIL import Image

import daft

Expand Down Expand Up @@ -176,3 +178,56 @@ def test_repr_with_html_string():
for i in range(3):
assert f"<div>body{i}</div>" in non_html_table
assert f"<tr><td>&lt;div&gt;body{i}&lt;/div&gt;</td></tr>" in html_table


class MyObj:
def __repr__(self) -> str:
return "myobj-custom-repr"


def test_repr_html_custom_hooks():
img = Image.fromarray(np.ones((3, 3)).astype(np.uint8))
arr = np.ones((3, 3))

df = daft.from_pydict(
{
"objects": daft.Series.from_pylist([MyObj() for _ in range(3)], pyobj="force"),
"np": daft.Series.from_pylist([arr for _ in range(3)], pyobj="force"),
"pil": daft.Series.from_pylist([img for _ in range(3)], pyobj="force"),
}
)
df.collect()

assert (
df.__repr__()
== """+-------------------+-------------+----------------------------------+
| objects | np | pil |
| Python | Python | Python |
+-------------------+-------------+----------------------------------+
| myobj-custom-repr | [[1. 1. 1.] | <PIL.Image.Image image mode=L... |
| | [1. 1. 1.] | |
| | [1. ... | |
+-------------------+-------------+----------------------------------+
| myobj-custom-repr | [[1. 1. 1.] | <PIL.Image.Image image mode=L... |
| | [1. 1. 1.] | |
| | [1. ... | |
+-------------------+-------------+----------------------------------+
| myobj-custom-repr | [[1. 1. 1.] | <PIL.Image.Image image mode=L... |
| | [1. 1. 1.] | |
| | [1. ... | |
+-------------------+-------------+----------------------------------+

(Showing first 3 of 3 rows)"""
)

html_repr = df._repr_html_()

# Assert that MyObj is correctly displayed in html repr (falls back to __repr__)
assert "myobj-custom-repr" in html_repr

# Assert that PIL viz hook correctly triggers in html repr
assert 'alt="<PIL.Image.Image image mode=L size=3x3' in html_repr
assert '<img style="max-height:128px;width:auto" src="data:image/png;base64,' in html_repr

# Assert that numpy array viz hook correctly triggers in html repr
assert "<td>&ltnp.ndarray<br>shape=(3, 3)<br>dtype=float64&gt</td><td>" in html_repr
7 changes: 6 additions & 1 deletion tests/series/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_image_pil_inference(fixed_shape, mode):
if arr is not None:
arr[..., -1] = 255
imgs = [Image.fromarray(arr, mode=mode) if arr is not None else None for arr in arrs]
s = Series.from_pylist(imgs, pyobj="force")
s = Series.from_pylist(imgs, pyobj="allow")
assert s.datatype() == DataType.image(mode)
out = s.to_pylist()
if num_channels == 1:
Expand Down Expand Up @@ -206,7 +206,12 @@ def test_image_pil_inference_mixed():
else None
for arr in arrs
]

# Forcing should still create Python Series
s = Series.from_pylist(imgs, pyobj="force")
assert s.datatype() == DataType.python()

s = Series.from_pylist(imgs, pyobj="allow")
assert s.datatype() == DataType.image()
out = s.to_pylist()
arrs[3] = np.expand_dims(arrs[3], axis=-1)
Expand Down