diff --git a/daft/series.py b/daft/series.py index 553303628e..80e52aa272 100644 --- a/daft/series.py +++ b/daft/series.py @@ -97,7 +97,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") -> raise ValueError(f"pyobj: expected either 'allow', 'disallow', or 'force', but got {pyobj})") if pyobj == "force": - pys = PySeries.from_pylist(name, data) + pys = PySeries.from_pylist(name, data, pyobj=pyobj) return Series._from_pyseries(pys) try: diff --git a/daft/viz/html_viz_hooks.py b/daft/viz/html_viz_hooks.py index e341ed38f6..0602cc6484 100644 --- a/daft/viz/html_viz_hooks.py +++ b/daft/viz/html_viz_hooks.py @@ -14,13 +14,13 @@ _VIZ_HOOKS_REGISTRY = {} -def register_viz_hook(klass: type[HookClass], hook: Callable[[HookClass], str]): +def register_viz_hook(klass: type[HookClass], hook: Callable[[object], str]): """Registers a visualization hook that returns the appropriate HTML for visualizing a specific class in HTML""" _VIZ_HOOKS_REGISTRY[klass] = hook -def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None: +def get_viz_hook(val: object) -> Callable[[object], str] | None: for klass in _VIZ_HOOKS_REGISTRY: if isinstance(val, klass): return _VIZ_HOOKS_REGISTRY[klass] @@ -45,7 +45,7 @@ def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None: if HAS_PILLOW: - def _viz_pil_image(val: PIL.Image.Image): + def _viz_pil_image(val: PIL.Image.Image) -> str: img = val.copy() img.thumbnail((128, 128)) bio = io.BytesIO() @@ -57,7 +57,7 @@ def _viz_pil_image(val: PIL.Image.Image): if HAS_NUMPY: - def _viz_numpy(val: np.ndarray): + def _viz_numpy(val: np.ndarray) -> str: return f"<np.ndarray
shape={val.shape}
dtype={val.dtype}>" register_viz_hook(np.ndarray, _viz_numpy) diff --git a/src/array/ops/repr.rs b/src/array/ops/repr.rs index aabb178d47..becdcac90a 100644 --- a/src/array/ops/repr.rs +++ b/src/array/ops/repr.rs @@ -216,15 +216,43 @@ impl_array_html_value!(ListArray); impl_array_html_value!(FixedSizeListArray); impl_array_html_value!(StructArray); impl_array_html_value!(ExtensionArray); - -#[cfg(feature = "python")] -impl_array_html_value!(crate::datatypes::PythonArray); - impl_array_html_value!(DateArray); impl_array_html_value!(DurationArray); impl_array_html_value!(TimestampArray); impl_array_html_value!(EmbeddingArray); +#[cfg(feature = "python")] +impl crate::datatypes::PythonArray { + pub fn html_value(&self, idx: usize) -> String { + use pyo3::prelude::*; + + let val = self.get(idx); + + let custom_viz_hook_result: Option = Python::with_gil(|py| { + // Find visualization hooks for this object's class + let pyany = val.into_ref(py); + let get_viz_hook = py + .import("daft.viz.html_viz_hooks")? + .getattr("get_viz_hook")?; + let hook = get_viz_hook.call1((pyany,))?; + + if hook.is_none() { + Ok(None) + } else { + hook.call1((pyany,))?.extract() + } + }) + .unwrap(); + + match custom_viz_hook_result { + None => html_escape::encode_text(&self.str_value(idx).unwrap()) + .into_owned() + .replace('\n', "
"), + Some(result) => result, + } + } +} + impl DataArray where T: DaftNumericType, diff --git a/src/python/series.rs b/src/python/series.rs index bb889e4953..daa616d035 100644 --- a/src/python/series.rs +++ b/src/python/series.rs @@ -31,19 +31,22 @@ impl PySeries { // This ingests a Python list[object] directly into a Rust PythonArray. #[staticmethod] - pub fn from_pylist(name: &str, pylist: &PyAny) -> PyResult { + pub fn from_pylist(name: &str, pylist: &PyAny, pyobj: &str) -> PyResult { let vec_pyobj: Vec = pylist.extract()?; let py = pylist.py(); - let dtype = infer_daft_dtype_for_sequence(&vec_pyobj, py)?; + + let dtype = match pyobj { + "force" => DataType::Python, + "allow" => infer_daft_dtype_for_sequence(&vec_pyobj, py)?.unwrap_or(DataType::Python), + "disallow" => panic!("Cannot create a Series from a pylist and being strict about only using Arrow types by setting pyobj=disallow"), + _ => panic!("Unsupported pyobj behavior when creating Series from pylist: {}", pyobj) + }; let arrow_array: Box = Box::new(PseudoArrowArray::::from_pyobj_vec(vec_pyobj)); let field = Field::new(name, DataType::Python); let data_array = DataArray::::new(field.into(), arrow_array)?; - let series = match dtype { - Some(dtype) => data_array.cast(&dtype)?, - None => data_array.into_series(), - }; + let series = data_array.cast(&dtype)?; Ok(series.into()) }