Skip to content

Commit

Permalink
[BUG] Hit viz hooks for Python object display
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Jun 22, 2023
1 parent 1c49647 commit c167ef8
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 15 deletions.
2 changes: 1 addition & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") ->
raise ValueError(f"pyobj: expected either 'allow', 'disallow', or 'force', but got {pyobj})")

if pyobj == "force":
pys = PySeries.from_pylist(name, data)
pys = PySeries.from_pylist(name, data, pyobj=pyobj)
return Series._from_pyseries(pys)

try:
Expand Down
8 changes: 4 additions & 4 deletions daft/viz/html_viz_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
_VIZ_HOOKS_REGISTRY = {}


def register_viz_hook(klass: type[HookClass], hook: Callable[[HookClass], str]):
def register_viz_hook(klass: type[HookClass], hook: Callable[[object], str]):
"""Registers a visualization hook that returns the appropriate HTML for
visualizing a specific class in HTML"""
_VIZ_HOOKS_REGISTRY[klass] = hook


def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None:
def get_viz_hook(val: object) -> Callable[[object], str] | None:
for klass in _VIZ_HOOKS_REGISTRY:
if isinstance(val, klass):
return _VIZ_HOOKS_REGISTRY[klass]
Expand All @@ -45,7 +45,7 @@ def get_viz_hook(val: HookClass) -> Callable[[HookClass], str] | None:

if HAS_PILLOW:

def _viz_pil_image(val: PIL.Image.Image):
def _viz_pil_image(val: PIL.Image.Image) -> str:
img = val.copy()
img.thumbnail((128, 128))
bio = io.BytesIO()
Expand All @@ -57,7 +57,7 @@ def _viz_pil_image(val: PIL.Image.Image):

if HAS_NUMPY:

def _viz_numpy(val: np.ndarray):
def _viz_numpy(val: np.ndarray) -> str:
return f"&ltnp.ndarray<br>shape={val.shape}<br>dtype={val.dtype}&gt"

register_viz_hook(np.ndarray, _viz_numpy)
36 changes: 32 additions & 4 deletions src/array/ops/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,43 @@ impl_array_html_value!(ListArray);
impl_array_html_value!(FixedSizeListArray);
impl_array_html_value!(StructArray);
impl_array_html_value!(ExtensionArray);

#[cfg(feature = "python")]
impl_array_html_value!(crate::datatypes::PythonArray);

impl_array_html_value!(DateArray);
impl_array_html_value!(DurationArray);
impl_array_html_value!(TimestampArray);
impl_array_html_value!(EmbeddingArray);

#[cfg(feature = "python")]
impl crate::datatypes::PythonArray {
pub fn html_value(&self, idx: usize) -> String {
use pyo3::prelude::*;

let val = self.get(idx);

let custom_viz_hook_result: Option<String> = Python::with_gil(|py| {
// Find visualization hooks for this object's class
let pyany = val.into_ref(py);
let get_viz_hook = py
.import("daft.viz.html_viz_hooks")?
.getattr("get_viz_hook")?;
let hook = get_viz_hook.call1((pyany,))?;

if hook.is_none() {
Ok(None)
} else {
hook.call1((pyany,))?.extract()
}
})
.unwrap();

match custom_viz_hook_result {
None => html_escape::encode_text(&self.str_value(idx).unwrap())
.into_owned()
.replace('\n', "<br />"),
Some(result) => result,
}
}
}

impl<T> DataArray<T>
where
T: DaftNumericType,
Expand Down
15 changes: 9 additions & 6 deletions src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,22 @@ impl PySeries {

// This ingests a Python list[object] directly into a Rust PythonArray.
#[staticmethod]
pub fn from_pylist(name: &str, pylist: &PyAny) -> PyResult<Self> {
pub fn from_pylist(name: &str, pylist: &PyAny, pyobj: &str) -> PyResult<Self> {
let vec_pyobj: Vec<PyObject> = pylist.extract()?;
let py = pylist.py();
let dtype = infer_daft_dtype_for_sequence(&vec_pyobj, py)?;

let dtype = match pyobj {
"force" => DataType::Python,
"allow" => infer_daft_dtype_for_sequence(&vec_pyobj, py)?.unwrap_or(DataType::Python),
"disallow" => panic!("Cannot create a Series from a pylist and being strict about only using Arrow types by setting pyobj=disallow"),
_ => panic!("Unsupported pyobj behavior when creating Series from pylist: {}", pyobj)
};
let arrow_array: Box<dyn arrow2::array::Array> =
Box::new(PseudoArrowArray::<PyObject>::from_pyobj_vec(vec_pyobj));
let field = Field::new(name, DataType::Python);

let data_array = DataArray::<PythonType>::new(field.into(), arrow_array)?;
let series = match dtype {
Some(dtype) => data_array.cast(&dtype)?,
None => data_array.into_series(),
};
let series = data_array.cast(&dtype)?;
Ok(series.into())
}

Expand Down

0 comments on commit c167ef8

Please sign in to comment.