Skip to content

Commit

Permalink
fix: Raise on invalid shape combination of a length-1 and an empty column
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 6, 2024
1 parent f7de80c commit b82410e
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 25 deletions.
23 changes: 13 additions & 10 deletions crates/polars-mem-engine/src/executors/projection_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ pub(super) fn check_expand_literals(
}
}
}

// If all series have the same length, nothing needs to be done. Otherwise, Series of length one can be broadcast.
if !all_equal_len && should_broadcast {
selected_columns = selected_columns
Expand All @@ -300,32 +301,34 @@ pub(super) fn check_expand_literals(
Ok(match series.len() {
0 if df_height == 1 => series,
1 => {
if has_empty {
polars_ensure!(df_height == 1,
ComputeError: "Series length {} doesn't match the DataFrame height of {}",
series.len(), df_height
);
series.slice(0, 0)
} else if df_height == 1 {
if !has_empty && df_height == 1 {
series
} else {
if has_empty {
polars_ensure!(df_height == 1,
ShapeMismatch: "Series length {} doesn't match the DataFrame height of {}",
series.len(), df_height
);

}

if verify_scalar {
polars_ensure!(phys.is_scalar(),
InvalidOperation: "Series: {}, length {} doesn't match the DataFrame height of {}\n\n\
ShapeMismatch: "Series: {}, length {} doesn't match the DataFrame height of {}\n\n\
If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()').",
series.name(), series.len(), df_height
);

}
series.new_from_index(0, df_height)
series.new_from_index(0, df_height * (!has_empty as usize) )
}
},
len if len == df_height => {
series
},
_ => {
polars_bail!(
ComputeError: "Series length {} doesn't match the DataFrame height of {}",
ShapeMismatch: "Series length {} doesn't match the DataFrame height of {}",
series.len(), df_height
)
}
Expand Down
12 changes: 0 additions & 12 deletions py-polars/tests/unit/constructors/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,18 +151,6 @@ def test_df_init_nested_mixed_types() -> None:
assert df.to_dicts() == [{"key": [{"value": 1.0}, {"value": 1.0}]}]


def test_unit_and_empty_construction_15896() -> None:
# This is still incorrect.
# We should raise, but currently for len 1 dfs,
# we cannot tell if they come from a literal or expression.
assert "shape: (0, 2)" in str(
pl.DataFrame({"A": [0]}).select(
C="A",
A=pl.int_range("A"), # creates empty series
)
)


class CustomSchema(Mapping[str, Any]):
"""Dummy schema object for testing compatibility with Mapping."""

Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/dataframe/test_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest

import polars as pl


def test_raise_invalid_shape_19108() -> None:
    """Selecting an empty column together with a length-1 column must raise.

    Regression test (numbered 19108): the two projected columns have
    lengths 0 and 1, and the select is expected to fail with a ShapeError.
    """
    frame = pl.DataFrame({"foo": [1, 2], "bar": [3, 4]})
    with pytest.raises(pl.exceptions.ShapeError):
        # `head(0)` yields no rows, `head(1)` yields exactly one row.
        empty_foo = pl.col("foo").head(0)
        single_bar = pl.col("bar").head(1)
        frame.select(empty_foo, single_bar)
3 changes: 1 addition & 2 deletions py-polars/tests/unit/lazyframe/test_with_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal


Expand All @@ -19,7 +18,7 @@ def test_with_context() -> None:

with pytest.deprecated_call():
context = df_a.with_context(df_b.lazy())
with pytest.raises(ComputeError):
with pytest.raises(pl.exceptions.ShapeError):
context.select("a", "c").collect()


Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/operations/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def test_slice_pushdown_literal_projection_14349() -> None:
q = lf.select("a", x=1).head(0)
# slice isn't in plan if it has been pushed down to the dataframe
assert "SLICE" not in q.explain()
print(q.explain())
assert q.collect().height == 0

# For with_columns, slice pushdown should happen if the input has at least 1 column
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_invalid_broadcast() -> None:
"group": [0, 1],
}
)
with pytest.raises(pl.exceptions.InvalidOperationError):
with pytest.raises(pl.exceptions.ShapeError):
df.select(pl.col("group").filter(pl.col("group") == 0), "a")


Expand Down

0 comments on commit b82410e

Please sign in to comment.