Skip to content

Commit

Permalink
fix(python): Fix Array constructor when inner type is another Array (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored May 31, 2024
1 parent bfcc1ee commit 65b8cdc
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
11 changes: 6 additions & 5 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override]
# allow comparing object instances to class
if type(other) is DataTypeClass and issubclass(other, List):
return True
if isinstance(other, List):
elif isinstance(other, List):
if self.inner is None or other.inner is None:
return True
else:
Expand Down Expand Up @@ -758,19 +758,20 @@ def __init__(
raise TypeError(msg)

inner_parsed = polars.datatypes.py_type_to_dtype(inner)
inner_shape = inner_parsed.shape if isinstance(inner_parsed, Array) else ()

if isinstance(shape, int):
self.inner = inner_parsed
self.size = shape
self.shape = (shape,)
self.shape = (shape,) + inner_shape

elif isinstance(shape, tuple):
if len(shape) > 1:
inner_parsed = Array(inner_parsed, shape[1:])

self.inner = inner_parsed
self.size = shape[0]
self.shape = shape
self.shape = shape + inner_shape

else:
msg = f"invalid input for shape: {shape!r}"
Expand All @@ -786,8 +787,8 @@ def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override]
# allow comparing object instances to class
if type(other) is DataTypeClass and issubclass(other, Array):
return True
if isinstance(other, Array):
if self.size != other.size:
elif isinstance(other, Array):
if self.shape != other.shape:
return False
elif self.inner is None or other.inner is None:
return True
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def strptime(
if dtype == Date:
return self.to_date(format, strict=strict, exact=exact, cache=cache)
elif dtype == Datetime:
time_unit = dtype.time_unit # type: ignore[union-attr]
time_zone = dtype.time_zone # type: ignore[union-attr]
time_unit = getattr(dtype, "time_unit", None)
time_zone = getattr(dtype, "time_zone", None)
return self.to_datetime(
format,
time_unit=time_unit,
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/datatypes/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def test_array_data_type_equality() -> None:
assert pl.Array(pl.Int64, 2) != pl.Array(pl.String, 2)
assert pl.Array(pl.Int64, 2) != pl.List(pl.Int64)

assert pl.Array(pl.Int64, (4, 2)) == pl.Array
assert pl.Array(pl.Array(pl.Int64, 2), 4) == pl.Array(pl.Int64, (4, 2))
assert pl.Array(pl.Int64, (4, 2)) == pl.Array(pl.Int64, (4, 2))
assert pl.Array(pl.Int64, (4, 2)) != pl.Array(pl.String, (4, 2))
assert pl.Array(pl.Int64, (4, 2)) != pl.Array(pl.Int64, 4)
assert pl.Array(pl.Int64, (4,)) != pl.Array(pl.Int64, (4, 2))


@pytest.mark.parametrize(
("data", "inner_type"),
Expand Down

0 comments on commit 65b8cdc

Please sign in to comment.