diff --git a/daft/series.py b/daft/series.py index 553303628e..74724c740e 100644 --- a/daft/series.py +++ b/daft/series.py @@ -97,7 +97,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") -> raise ValueError(f"pyobj: expected either 'allow', 'disallow', or 'force', but got {pyobj})") if pyobj == "force": - pys = PySeries.from_pylist(name, data) + pys = PySeries.from_pylist(name, data, pyobj=pyobj) return Series._from_pyseries(pys) try: @@ -106,7 +106,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") -> except pa.lib.ArrowInvalid: if pyobj == "disallow": raise - pys = PySeries.from_pylist(name, data) + pys = PySeries.from_pylist(name, data, pyobj=pyobj) return Series._from_pyseries(pys) @classmethod diff --git a/daft/viz/html_viz_hooks.py b/daft/viz/html_viz_hooks.py index e341ed38f6..0602cc6484 100644 --- a/daft/viz/html_viz_hooks.py +++ b/daft/viz/html_viz_hooks.py @@ -14,13 +14,13 @@ _VIZ_HOOKS_REGISTRY = {} -def register_viz_hook(klass: type[HookClass], hook: Callable[[HookClass], str]): +def register_viz_hook(klass: type[HookClass], hook: Callable[[object], str]): """Registers a visualization hook that returns the appropriate HTML for visualizing a specific class in HTML""" _VIZ_HOOKS_REGISTRY[klass] = hook -def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None: +def get_viz_hook(val: object) -> Callable[[object], str] | None: for klass in _VIZ_HOOKS_REGISTRY: if isinstance(val, klass): return _VIZ_HOOKS_REGISTRY[klass] @@ -45,7 +45,7 @@ def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None: if HAS_PILLOW: - def _viz_pil_image(val: PIL.Image.Image): + def _viz_pil_image(val: PIL.Image.Image) -> str: img = val.copy() img.thumbnail((128, 128)) bio = io.BytesIO() @@ -57,7 +57,7 @@ def _viz_pil_image(val: PIL.Image.Image): if HAS_NUMPY: - def _viz_numpy(val: np.ndarray): + def _viz_numpy(val: np.ndarray) -> str: return f"<np.ndarray
shape={val.shape}
dtype={val.dtype}>" register_viz_hook(np.ndarray, _viz_numpy) diff --git a/src/array/ops/repr.rs b/src/array/ops/repr.rs index aabb178d47..becdcac90a 100644 --- a/src/array/ops/repr.rs +++ b/src/array/ops/repr.rs @@ -216,15 +216,43 @@ impl_array_html_value!(ListArray); impl_array_html_value!(FixedSizeListArray); impl_array_html_value!(StructArray); impl_array_html_value!(ExtensionArray); - -#[cfg(feature = "python")] -impl_array_html_value!(crate::datatypes::PythonArray); - impl_array_html_value!(DateArray); impl_array_html_value!(DurationArray); impl_array_html_value!(TimestampArray); impl_array_html_value!(EmbeddingArray); +#[cfg(feature = "python")] +impl crate::datatypes::PythonArray { + pub fn html_value(&self, idx: usize) -> String { + use pyo3::prelude::*; + + let val = self.get(idx); + + let custom_viz_hook_result: Option = Python::with_gil(|py| { + // Find visualization hooks for this object's class + let pyany = val.into_ref(py); + let get_viz_hook = py + .import("daft.viz.html_viz_hooks")? + .getattr("get_viz_hook")?; + let hook = get_viz_hook.call1((pyany,))?; + + if hook.is_none() { + Ok(None) + } else { + hook.call1((pyany,))?.extract() + } + }) + .unwrap(); + + match custom_viz_hook_result { + None => html_escape::encode_text(&self.str_value(idx).unwrap()) + .into_owned() + .replace('\n', "
"), + Some(result) => result, + } + } +} + impl DataArray where T: DaftNumericType, diff --git a/src/python/series.rs b/src/python/series.rs index bb889e4953..daa616d035 100644 --- a/src/python/series.rs +++ b/src/python/series.rs @@ -31,19 +31,22 @@ impl PySeries { // This ingests a Python list[object] directly into a Rust PythonArray. #[staticmethod] - pub fn from_pylist(name: &str, pylist: &PyAny) -> PyResult { + pub fn from_pylist(name: &str, pylist: &PyAny, pyobj: &str) -> PyResult { let vec_pyobj: Vec = pylist.extract()?; let py = pylist.py(); - let dtype = infer_daft_dtype_for_sequence(&vec_pyobj, py)?; + + let dtype = match pyobj { + "force" => DataType::Python, + "allow" => infer_daft_dtype_for_sequence(&vec_pyobj, py)?.unwrap_or(DataType::Python), + "disallow" => panic!("Cannot create a Series from a pylist and being strict about only using Arrow types by setting pyobj=disallow"), + _ => panic!("Unsupported pyobj behavior when creating Series from pylist: {}", pyobj) + }; let arrow_array: Box = Box::new(PseudoArrowArray::::from_pyobj_vec(vec_pyobj)); let field = Field::new(name, DataType::Python); let data_array = DataArray::::new(field.into(), arrow_array)?; - let series = match dtype { - Some(dtype) => data_array.cast(&dtype)?, - None => data_array.into_series(), - }; + let series = data_array.cast(&dtype)?; Ok(series.into()) } diff --git a/tests/dataframe/test_logical_type.py b/tests/dataframe/test_logical_type.py index d3d5224b72..60660fc96e 100644 --- a/tests/dataframe/test_logical_type.py +++ b/tests/dataframe/test_logical_type.py @@ -32,7 +32,7 @@ def test_image_type_df(from_pil_imgs) -> None: ] if from_pil_imgs: data = [Image.fromarray(arr, mode="RGB") if arr is not None else None for arr in data] - df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="force")}) + df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="allow")}) image_expr = col("image") if not from_pil_imgs: diff --git a/tests/dataframe/test_repr.py b/tests/dataframe/test_repr.py index 7aaf6049e9..026df3bf75 100644 --- a/tests/dataframe/test_repr.py +++ b/tests/dataframe/test_repr.py @@ -2,7 +2,9 @@ import re +import numpy as np import pandas as pd +from PIL import Image import daft @@ -176,3 +178,56 @@ def test_repr_with_html_string(): for i in range(3): assert f"
body{i}
" in non_html_table assert f"<div>body{i}</div>" in html_table + + +class MyObj: + def __repr__(self) -> str: + return "myobj-custom-repr" + + +def test_repr_html_custom_hooks(): + img = Image.fromarray(np.ones((3, 3)).astype(np.uint8)) + arr = np.ones((3, 3)) + + df = daft.from_pydict( + { + "objects": daft.Series.from_pylist([MyObj() for _ in range(3)], pyobj="force"), + "np": daft.Series.from_pylist([arr for _ in range(3)], pyobj="force"), + "pil": daft.Series.from_pylist([img for _ in range(3)], pyobj="force"), + } + ) + df.collect() + + assert ( + df.__repr__() + == """+-------------------+-------------+----------------------------------+ +| objects | np | pil | +| Python | Python | Python | ++-------------------+-------------+----------------------------------+ +| myobj-custom-repr | [[1. 1. 1.] | <np.ndarray
shape=(3, 3)
dtype=float64>" in html_repr diff --git a/tests/series/test_image.py b/tests/series/test_image.py index 99c658aaf9..68933206ac 100644 --- a/tests/series/test_image.py +++ b/tests/series/test_image.py @@ -178,7 +178,7 @@ def test_image_pil_inference(fixed_shape, mode): if arr is not None: arr[..., -1] = 255 imgs = [Image.fromarray(arr, mode=mode) if arr is not None else None for arr in arrs] - s = Series.from_pylist(imgs, pyobj="force") + s = Series.from_pylist(imgs, pyobj="allow") assert s.datatype() == DataType.image(mode) out = s.to_pylist() if num_channels == 1: @@ -206,7 +206,12 @@ def test_image_pil_inference_mixed(): else None for arr in arrs ] + + # Forcing should still create Python Series s = Series.from_pylist(imgs, pyobj="force") + assert s.datatype() == DataType.python() + + s = Series.from_pylist(imgs, pyobj="allow") assert s.datatype() == DataType.image() out = s.to_pylist() arrs[3] = np.expand_dims(arrs[3], axis=-1)