diff --git a/daft/series.py b/daft/series.py
index 553303628e..80e52aa272 100644
--- a/daft/series.py
+++ b/daft/series.py
@@ -97,7 +97,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") ->
raise ValueError(f"pyobj: expected either 'allow', 'disallow', or 'force', but got {pyobj})")
if pyobj == "force":
- pys = PySeries.from_pylist(name, data)
+ pys = PySeries.from_pylist(name, data, pyobj=pyobj)
return Series._from_pyseries(pys)
try:
diff --git a/daft/viz/html_viz_hooks.py b/daft/viz/html_viz_hooks.py
index e341ed38f6..0602cc6484 100644
--- a/daft/viz/html_viz_hooks.py
+++ b/daft/viz/html_viz_hooks.py
@@ -14,13 +14,13 @@
_VIZ_HOOKS_REGISTRY = {}
-def register_viz_hook(klass: type[HookClass], hook: Callable[[HookClass], str]):
+def register_viz_hook(klass: type[HookClass], hook: Callable[[object], str]):
"""Registers a visualization hook that returns the appropriate HTML for
visualizing a specific class in HTML"""
_VIZ_HOOKS_REGISTRY[klass] = hook
-def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None:
+def get_viz_hook(val: object) -> Callable[[object], str] | None:
for klass in _VIZ_HOOKS_REGISTRY:
if isinstance(val, klass):
return _VIZ_HOOKS_REGISTRY[klass]
@@ -45,7 +45,7 @@ def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None:
if HAS_PILLOW:
- def _viz_pil_image(val: PIL.Image.Image):
+ def _viz_pil_image(val: PIL.Image.Image) -> str:
img = val.copy()
img.thumbnail((128, 128))
bio = io.BytesIO()
@@ -57,7 +57,7 @@ def _viz_pil_image(val: PIL.Image.Image):
if HAS_NUMPY:
- def _viz_numpy(val: np.ndarray):
+ def _viz_numpy(val: np.ndarray) -> str:
return f"<np.ndarray
shape={val.shape}
dtype={val.dtype}>"
register_viz_hook(np.ndarray, _viz_numpy)
diff --git a/src/array/ops/repr.rs b/src/array/ops/repr.rs
index aabb178d47..becdcac90a 100644
--- a/src/array/ops/repr.rs
+++ b/src/array/ops/repr.rs
@@ -216,15 +216,43 @@ impl_array_html_value!(ListArray);
impl_array_html_value!(FixedSizeListArray);
impl_array_html_value!(StructArray);
impl_array_html_value!(ExtensionArray);
-
-#[cfg(feature = "python")]
-impl_array_html_value!(crate::datatypes::PythonArray);
-
impl_array_html_value!(DateArray);
impl_array_html_value!(DurationArray);
impl_array_html_value!(TimestampArray);
impl_array_html_value!(EmbeddingArray);
+#[cfg(feature = "python")]
+impl crate::datatypes::PythonArray {
+ pub fn html_value(&self, idx: usize) -> String {
+ use pyo3::prelude::*;
+
+ let val = self.get(idx);
+
+ let custom_viz_hook_result: Option = Python::with_gil(|py| {
+ // Find visualization hooks for this object's class
+ let pyany = val.into_ref(py);
+ let get_viz_hook = py
+ .import("daft.viz.html_viz_hooks")?
+ .getattr("get_viz_hook")?;
+ let hook = get_viz_hook.call1((pyany,))?;
+
+ if hook.is_none() {
+ Ok(None)
+ } else {
+ hook.call1((pyany,))?.extract()
+ }
+ })
+ .unwrap();
+
+ match custom_viz_hook_result {
+ None => html_escape::encode_text(&self.str_value(idx).unwrap())
+ .into_owned()
+ .replace('\n', "
"),
+ Some(result) => result,
+ }
+ }
+}
+
impl DataArray
where
T: DaftNumericType,
diff --git a/src/python/series.rs b/src/python/series.rs
index bb889e4953..daa616d035 100644
--- a/src/python/series.rs
+++ b/src/python/series.rs
@@ -31,19 +31,22 @@ impl PySeries {
// This ingests a Python list[object] directly into a Rust PythonArray.
#[staticmethod]
- pub fn from_pylist(name: &str, pylist: &PyAny) -> PyResult {
+ pub fn from_pylist(name: &str, pylist: &PyAny, pyobj: &str) -> PyResult {
let vec_pyobj: Vec = pylist.extract()?;
let py = pylist.py();
- let dtype = infer_daft_dtype_for_sequence(&vec_pyobj, py)?;
+
+ let dtype = match pyobj {
+ "force" => DataType::Python,
+ "allow" => infer_daft_dtype_for_sequence(&vec_pyobj, py)?.unwrap_or(DataType::Python),
+ "disallow" => panic!("Cannot create a Series from a pylist and being strict about only using Arrow types by setting pyobj=disallow"),
+ _ => panic!("Unsupported pyobj behavior when creating Series from pylist: {}", pyobj)
+ };
let arrow_array: Box =
Box::new(PseudoArrowArray::::from_pyobj_vec(vec_pyobj));
let field = Field::new(name, DataType::Python);
let data_array = DataArray::::new(field.into(), arrow_array)?;
- let series = match dtype {
- Some(dtype) => data_array.cast(&dtype)?,
- None => data_array.into_series(),
- };
+ let series = data_array.cast(&dtype)?;
Ok(series.into())
}