Skip to content

Commit

Permalink
Fix image tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Jun 22, 2023
1 parent ee2b0a2 commit 129d436
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def from_pylist(data: list, name: str = "list_series", pyobj: str = "allow") ->
except pa.lib.ArrowInvalid:
if pyobj == "disallow":
raise
pys = PySeries.from_pylist(name, data)
pys = PySeries.from_pylist(name, data, pyobj=pyobj)
return Series._from_pyseries(pys)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/dataframe/test_logical_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_image_type_df(from_pil_imgs) -> None:
]
if from_pil_imgs:
data = [Image.fromarray(arr, mode="RGB") if arr is not None else None for arr in data]
df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="force")})
df = daft.from_pydict({"index": np.arange(len(data)), "image": Series.from_pylist(data, pyobj="allow")})

image_expr = col("image")
if not from_pil_imgs:
Expand Down
7 changes: 6 additions & 1 deletion tests/series/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_image_pil_inference(fixed_shape, mode):
if arr is not None:
arr[..., -1] = 255
imgs = [Image.fromarray(arr, mode=mode) if arr is not None else None for arr in arrs]
s = Series.from_pylist(imgs, pyobj="force")
s = Series.from_pylist(imgs, pyobj="allow")
assert s.datatype() == DataType.image(mode)
out = s.to_pylist()
if num_channels == 1:
Expand Down Expand Up @@ -206,7 +206,12 @@ def test_image_pil_inference_mixed():
else None
for arr in arrs
]

# Forcing should still create Python Series
s = Series.from_pylist(imgs, pyobj="force")
assert s.datatype() == DataType.python()

s = Series.from_pylist(imgs, pyobj="allow")
assert s.datatype() == DataType.image()
out = s.to_pylist()
arrs[3] = np.expand_dims(arrs[3], axis=-1)
Expand Down

0 comments on commit 129d436

Please sign in to comment.