Skip to content

Commit

Permalink
Merge pull request #360 from jrycw/pl-16250
Browse files Browse the repository at this point in the history
fix: refactor column selection logic for `Polars`
  • Loading branch information
machow authored May 23, 2024
2 parents 9f366c1 + 7a7f825 commit 6b2624b
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 15 deletions.
8 changes: 4 additions & 4 deletions docs/blog/polars-styling/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,12 @@ gt_with_spanners = (
# Table column spanners ----
.tab_spanner(
label = "Time",
columns = time_cols
label="Time",
columns=time_cols
)
.tab_spanner(
label = "Measurement",
columns = cs.all().exclude(time_cols)
label="Measurement",
columns=cs.exclude(time_cols)
)
)
Expand Down
2 changes: 1 addition & 1 deletion docs/blog/superbowl-squares/_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def team_final_digits(game: pl.DataFrame, team_code: str) -> pl.DataFrame:
.with_columns(joint=pl.col("prop") * pl.col("prop_right"))
.sort("final_digit", "final_digit_right")
.pivot(values="joint", columns="final_digit_right", index="final_digit")
.with_columns((cs.all().exclude("final_digit") * 100).round(1))
.with_columns((cs.exclude("final_digit") * 100).round(1))
)

# Display -----
Expand Down
2 changes: 1 addition & 1 deletion docs/get-started/basic-styling.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ import polars.selectors as cs
gt_pl_air.tab_style(
style=style.fill(color="yellow"),
locations=loc.body(
columns=cs.all().exclude(["Month", "Day"]),
columns=cs.exclude(["Month", "Day"]),
rows=pl.col("Temp") == pl.col("Temp").max()
)
)
Expand Down
29 changes: 25 additions & 4 deletions great_tables/_tbl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,26 +307,47 @@ def _(
def _(data: PlDataFrame, expr: Union[list[str], _selector_proxy_], strict: bool = True) -> _NamePos:
    """Resolve a column selection against a Polars DataFrame.

    Parameters
    ----------
    data
        The Polars frame whose columns are being selected.
    expr
        A single column name or integer position, a list mixing names,
        positions and Polars selectors, or a single Polars selector.
    strict
        Not read anywhere in this function body.

    Returns
    -------
    A list of ``(column_name, position)`` tuples, where ``position`` is the
    index of the column in ``data.columns``.

    Raises
    ------
    TypeError
        If a list entry is not a str, int, or selector (raised by
        ``_validate_selector_list``), or if ``expr`` is not a selector after
        normalization.
    """
    # TODO: how to annotate type of a polars selector?
    # Seems to be polars.selectors._selector_proxy_.
    import polars.selectors as cs

    from functools import reduce
    from operator import or_
    from polars import Expr

    # A bare name or position is treated as a one-element list.
    if isinstance(expr, (str, int)):
        expr = [expr]

    if isinstance(expr, list):
        # Promote names and positions to selectors; anything else is passed
        # through unchanged so the validation step can reject it.
        all_selectors = [
            cs.by_name(x) if isinstance(x, str) else cs.by_index(x) if isinstance(x, int) else x
            for x in expr
        ]

        _validate_selector_list(all_selectors)

        # Combine every entry into one selector with `|`.
        # NOTE(review): cs.by_name() with no names seeds the fold, presumably
        # as an empty selection -- confirm against the polars version in use.
        expr = reduce(or_, all_selectors, cs.by_name())

    col_pos = {k: ii for ii, k in enumerate(data.columns)}

    # just in case _selector_proxy_ gets renamed or something
    # it inherits from Expr, so we can just use that in a pinch
    cls_selector = getattr(cs, "_selector_proxy_", Expr)

    # Lists were reduced to a single selector above, so only a selector is
    # acceptable at this point.
    if not isinstance(expr, cls_selector):
        raise TypeError(f"Unsupported selection expr type: {type(expr)}")

    # I don't think there's a way to get the columns w/o running the selection
    final_columns = cs.expand_selector(data, expr)
    return [(col, col_pos[col]) for col in final_columns]


def _validate_selector_list(selectors: list):
    """Raise TypeError for the first entry of `selectors` that is not a Polars selector.

    Entries are checked in list order with `polars.selectors.is_selector`; the
    error message reports the index and type of the first entry that fails.
    Returns None when every entry passes.
    """
    from polars.selectors import is_selector

    # Lazily scan for the first offending entry so later entries are not
    # inspected once a failure is found.
    offenders = ((ii, sel) for ii, sel in enumerate(selectors) if not is_selector(sel))
    first_bad = next(offenders, None)

    if first_bad is not None:
        ii, sel = first_bad
        raise TypeError(f"Expected a list of selectors, but entry {ii} is type: {type(sel)}.")


def _eval_select_from_list(
columns: list[str], expr: list[Union[str, int]]
) -> list[tuple[str, int]]:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_format_repr_snap(snapshot):
assert_repr_html(snapshot, new_gt)


@pytest.mark.parametrize("expr", [[0, -1], pl.selectors.all().exclude("y")])
@pytest.mark.parametrize("expr", [[0, -1], pl.selectors.exclude("y")])
def test_format_col_selection_multi(expr: Any):
df = pd.DataFrame({"x": [1], "y": [2], "z": [3]})

Expand Down
55 changes: 51 additions & 4 deletions tests/test_tbl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,61 @@ def test_reorder(df: DataFrameLike):
assert_frame_equal(res, dst)


@pytest.mark.parametrize("expr", [["col2", "col1"], [1, 0], ["col2", 0], [1, "col1"]])
def test_eval_select_with_list(df: DataFrameLike, expr):
    """Lists of names, positions, or a mix of both select col2 then col1.

    Each parametrized `expr` is passed to `eval_select` itself, so every
    list form is exercised rather than a single hard-coded one.
    """
    sel = eval_select(df, expr)
    assert sel == [("col2", 1), ("col1", 0)]


@pytest.mark.parametrize(
    "expr",
    [
        pl.selectors.exclude("col3"),
        pl.selectors.starts_with("col1") | pl.selectors.starts_with("col2"),
    ],
)
def test_eval_select_with_list_pl_selector(expr):
    """Polars selectors covering col1 and col2 resolve to those two columns."""
    frame = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]})
    expected = [("col1", 0), ("col2", 1)]

    sel = eval_select(frame, expr)

    assert sel == expected


@pytest.mark.parametrize("expr", [["col2", 1.2]])
def test_eval_select_pandas_raises1(expr):
    """A list containing a float entry raises TypeError for a pandas frame."""
    frame = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]})

    with pytest.raises(TypeError) as exc_info:
        eval_select(frame, expr)

    # Inspect the first positional argument of the raised exception.
    message = str(exc_info.value.args[0])
    assert "Only int and str are supported." in message


@pytest.mark.parametrize("expr", [3.45, {"col2"}, ("col2",)])
def test_eval_select_pandas_raises2(expr):
    """A float, set, or tuple selection raises NotImplementedError for a pandas frame."""
    frame = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]})

    with pytest.raises(NotImplementedError) as exc_info:
        eval_select(frame, expr)

    # Inspect the first positional argument of the raised exception.
    message = str(exc_info.value.args[0])
    assert "Unsupported selection expr: " in message


@pytest.mark.parametrize("expr", [3.45, {6}, (7.8,)])
def test_eval_select_polars_raises(expr):
    """A float, set, or tuple selection raises TypeError for a Polars frame."""
    frame = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]})

    with pytest.raises(TypeError) as exc_info:
        eval_select(frame, expr)

    # Inspect the first positional argument of the raised exception.
    message = str(exc_info.value.args[0])
    assert "Unsupported selection expr type:" in message


def test_eval_selector_polars_list_raises():
    """A float inside a selection list is reported with its index and type."""
    frame = pl.DataFrame({"col1": [], "col2": [], "col3": []})
    expr = ["col1", 1.2]

    with pytest.raises(TypeError) as exc_info:
        eval_select(frame, expr)

    # The float sits at index 1 of the list, which the message should name.
    message = str(exc_info.value.args[0])
    assert "entry 1 is type: <class 'float'>" in message


def test_create_empty_frame(df: DataFrameLike):
res = create_empty_frame(df)
col = [None] * 3
Expand Down

0 comments on commit 6b2624b

Please sign in to comment.