Skip to content

Commit

Permalink
fix: Invalid selectors not being recognized
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelzw committed Oct 8, 2024
1 parent 1f48036 commit 8ae6f94
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
24 changes: 13 additions & 11 deletions crates/polars-plan/src/plans/conversion/expr_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -921,22 +921,24 @@ pub(super) fn expand_selector(
let mut members = PlIndexSet::new();
replace_selector_inner(s, &mut members, &mut vec![], schema, keys)?;

if members.len() <= 1 {
members
.into_iter()
.map(|e| {
let Expr::Column(name) = e else {
polars_bail!(InvalidOperation: "invalid selector expression: {}", e)
};
Ok(name)
})
.collect()
let column_names = members
.into_iter()
.map(|e| {
let Expr::Column(name) = e else {
polars_bail!(InvalidOperation: "invalid selector expression: {}", e)
};
Ok(name)
})
.collect::<PolarsResult<Arc<[PlSmallStr]>>>()?;

if column_names.len() <= 1 {
Ok(column_names)
} else {
// Ensure that multiple columns returned from combined/nested selectors remain in schema order
let selected = schema
.iter_fields()
.map(|field| field.name().clone())
.filter(|field_name| members.contains(&Expr::Column(field_name.clone())))
.filter(|field_name| column_names.contains(&field_name))
.collect();

Ok(selected)
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,11 @@ def test_selector_list_of_lists_18499() -> None:

with pytest.raises(InvalidOperationError, match="invalid selector expression"):
lf.unique(subset=[["bar", "ham"]]) # type: ignore[list-item]


def test_invalid_selector() -> None:
df = pl.DataFrame(data={"x": [1, 2], "z": ["a", "b"]})
with pytest.raises(InvalidOperationError, match="invalid selector expression"):
df.drop(pl.col("x", "z") + 2)
with pytest.raises(InvalidOperationError, match="invalid selector expression"):
df.drop(pl.col("x") + 2)

0 comments on commit 8ae6f94

Please sign in to comment.