Skip to content

Commit

Permalink
feat(python): add "match_any" param/option to the by_name selector
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed May 15, 2024
1 parent a34ca2c commit 610804d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
40 changes: 34 additions & 6 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from polars import functions as F
from polars._utils.deprecation import deprecate_nonkeyword_arguments
from polars._utils.parse_expr_input import _parse_inputs_as_iterable
from polars._utils.various import is_column
from polars._utils.various import is_column, re_escape
from polars.datatypes import (
FLOAT_DTYPES,
INTEGER_DTYPES,
Expand Down Expand Up @@ -284,7 +284,7 @@ def __repr__(self) -> str:
op = set_ops[selector_name]
return "({})".format(f" {op} ".join(repr(p) for p in params.values()))
else:
str_params = ",".join(
str_params = ", ".join(
(repr(v)[1:-1] if k.startswith("*") else f"{k}={v!r}")
for k, v in (params or {}).items()
).rstrip(",")
Expand Down Expand Up @@ -357,15 +357,15 @@ def as_expr(self) -> Expr:
def _re_string(string: str | Collection[str], *, escape: bool = True) -> str:
"""Return escaped regex, potentially representing multiple string fragments."""
if isinstance(string, str):
rx = f"{re.escape(string)}" if escape else string
rx = f"{re_escape(string)}" if escape else string
else:
strings: list[str] = []
for st in string:
if isinstance(st, Collection) and not isinstance(st, str): # type: ignore[redundant-expr]
strings.extend(st)
else:
strings.append(st)
rx = "|".join((re.escape(x) if escape else x) for x in strings)
rx = "|".join((re_escape(x) if escape else x) for x in strings)
return f"({rx})"


Expand Down Expand Up @@ -689,14 +689,21 @@ def by_index(*indices: int | range | Sequence[int | range]) -> SelectorType:
)


def by_name(*names: str | Collection[str]) -> SelectorType:
def by_name(*names: str | Collection[str], match_any: bool = False) -> SelectorType:
"""
Select all columns matching the given names.
Parameters
----------
*names
One or more names of columns to select.
match_any
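If False (the default), *all* of the given names must match a column (a
name that is not present raises `ColumnNotFoundError`); if True, select
whichever of the given names are present and ignore those that are not.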
Notes
-----
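By default (`match_any=False`), matching columns are returned in the order
in which they are declared in the selector, not the original schema order.
With `match_any=True` the names are combined into a single regex pattern, and
(as the `match_any` example below shows) the matched columns do not follow
the order in which the names were given to the selector.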
See Also
--------
Expand Down Expand Up @@ -728,6 +735,19 @@ def by_name(*names: str | Collection[str]) -> SelectorType:
│ y ┆ 456 │
└─────┴─────┘
Match *any* of the given columns by name:
>>> df.select(cs.by_name("baz", "moose", "foo", "bear", match_any=True))
shape: (2, 2)
┌─────┬─────┐
│ foo ┆ baz │
│ --- ┆ --- │
│ str ┆ f64 │
╞═════╪═════╡
│ x ┆ 2.0 │
│ y ┆ 5.5 │
└─────┴─────┘
Match all columns *except* for those given:
>>> df.select(~cs.by_name("foo", "bar"))
Expand Down Expand Up @@ -755,8 +775,16 @@ def by_name(*names: str | Collection[str]) -> SelectorType:
msg = f"invalid name: {nm!r}"
raise TypeError(msg)

selector_params: dict[str, Any] = {"*names": all_names}
match_cols: list[str] | str = all_names
if match_any:
match_cols = f"^({'|'.join(re_escape(nm) for nm in all_names)})$"
selector_params["match_any"] = match_any

return _selector_proxy_(
F.col(all_names), name="by_name", parameters={"*names": all_names}
F.col(match_cols),
name="by_name",
parameters=selector_params,
)


Expand Down
17 changes: 16 additions & 1 deletion py-polars/tests/unit/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,13 @@ def test_selector_by_name(df: pl.DataFrame) -> None:
assert df.select(cs.by_name()).columns == []
assert df.select(cs.by_name([])).columns == []

selected_cols = df.select(cs.by_name("???", "fgg", "!!!", match_any=True)).columns
assert selected_cols == ["fgg"]

# expected errors
with pytest.raises(ColumnNotFoundError, match="xxx"):
df.select(cs.by_name("xxx", "fgg", "!!!"))

with pytest.raises(ColumnNotFoundError):
df.select(cs.by_name("stroopwafel"))

Expand Down Expand Up @@ -487,7 +493,16 @@ def test_selector_repr() -> None:
assert_repr_equals(~cs.starts_with("a", "b"), "~cs.starts_with('a', 'b')")
assert_repr_equals(cs.float() | cs.by_name("x"), "(cs.float() | cs.by_name('x'))")
assert_repr_equals(
cs.integer() & cs.matches("z"), "(cs.integer() & cs.matches(pattern='z'))"
cs.integer() & cs.matches("z"),
"(cs.integer() & cs.matches(pattern='z'))",
)
assert_repr_equals(
cs.by_name("baz", "moose", "foo", "bear"),
"cs.by_name('baz', 'moose', 'foo', 'bear')",
)
assert_repr_equals(
cs.by_name("baz", "moose", "foo", "bear", match_any=True),
"cs.by_name('baz', 'moose', 'foo', 'bear', match_any=True)",
)
assert_repr_equals(
cs.temporal() | cs.by_dtype(pl.String) & cs.string(include_categorical=False),
Expand Down

0 comments on commit 610804d

Please sign in to comment.