Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor column selection logic for Polars #360

Merged
merged 5 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/blog/polars-styling/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,12 @@ gt_with_spanners = (

# Table column spanners ----
.tab_spanner(
label = "Time",
columns = time_cols
label="Time",
columns=time_cols
)
.tab_spanner(
label = "Measurement",
columns = cs.all().exclude(time_cols)
label="Measurement",
columns=cs.exclude(time_cols)
)
)

Expand Down
2 changes: 1 addition & 1 deletion docs/blog/superbowl-squares/_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def team_final_digits(game: pl.DataFrame, team_code: str) -> pl.DataFrame:
.with_columns(joint=pl.col("prop") * pl.col("prop_right"))
.sort("final_digit", "final_digit_right")
.pivot(values="joint", columns="final_digit_right", index="final_digit")
.with_columns((cs.all().exclude("final_digit") * 100).round(1))
.with_columns((cs.exclude("final_digit") * 100).round(1))
)

# Display -----
Expand Down
2 changes: 1 addition & 1 deletion docs/get-started/basic-styling.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ import polars.selectors as cs
gt_pl_air.tab_style(
style=style.fill(color="yellow"),
locations=loc.body(
columns=cs.all().exclude(["Month", "Day"]),
columns=cs.exclude(["Month", "Day"]),
rows=pl.col("Temp") == pl.col("Temp").max()
)
)
Expand Down
29 changes: 25 additions & 4 deletions great_tables/_tbl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,26 +307,47 @@ def _(
def _(data: PlDataFrame, expr: Union[list[str], _selector_proxy_], strict: bool = True) -> _NamePos:
# TODO: how to annotate type of a polars selector?
# Seems to be polars.selectors._selector_proxy_.
import polars.selectors as cs

from functools import reduce
from operator import or_
from polars import Expr
from polars import selectors

if isinstance(expr, (str, int)):
expr = [expr]

if isinstance(expr, list):
all_selectors = [
cs.by_name(x) if isinstance(x, str) else cs.by_index(x) if isinstance(x, int) else x
for x in expr
]

_validate_selector_list(all_selectors)

expr = reduce(or_, all_selectors, cs.by_name())

col_pos = {k: ii for ii, k in enumerate(data.columns)}

# just in case _selector_proxy_ gets renamed or something
# it inherits from Expr, so we can just use that in a pinch
cls_selector = getattr(selectors, "_selector_proxy_", Expr)
cls_selector = getattr(cs, "_selector_proxy_", Expr)

if not isinstance(expr, (list, cls_selector)):
if not isinstance(expr, cls_selector):
raise TypeError(f"Unsupported selection expr type: {type(expr)}")

# I don't think there's a way to get the columns w/o running the selection
final_columns = selectors.expand_selector(data, expr)
final_columns = cs.expand_selector(data, expr)
return [(col, col_pos[col]) for col in final_columns]


def _validate_selector_list(selectors: list):
from polars.selectors import is_selector

for ii, sel in enumerate(selectors):
if not is_selector(sel):
raise TypeError(f"Expected a list of selectors, but entry {ii} is type: {type(sel)}.")


def _eval_select_from_list(
columns: list[str], expr: list[Union[str, int]]
) -> list[tuple[str, int]]:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_format_repr_snap(snapshot):
assert_repr_html(snapshot, new_gt)


@pytest.mark.parametrize("expr", [[0, -1], pl.selectors.all().exclude("y")])
@pytest.mark.parametrize("expr", [[0, -1], pl.selectors.exclude("y")])
def test_format_col_selection_multi(expr: Any):
df = pd.DataFrame({"x": [1], "y": [2], "z": [3]})

Expand Down
55 changes: 51 additions & 4 deletions tests/test_tbl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,61 @@ def test_reorder(df: DataFrameLike):
assert_frame_equal(res, dst)


@pytest.mark.parametrize(
"expr", [["col2", "col1"], [1, 0], ["col2", 0], pl.selectors.all().exclude("col3")]
)
@pytest.mark.parametrize("expr", [["col2", "col1"], [1, 0], ["col2", 0], [1, "col1"]])
def test_eval_select_with_list(df: DataFrameLike, expr):
sel = eval_select(df, ["col2", "col1"])
sel = eval_select(df, expr)
assert sel == [("col2", 1), ("col1", 0)]


@pytest.mark.parametrize(
"expr",
[
pl.selectors.exclude("col3"),
pl.selectors.starts_with("col1") | pl.selectors.starts_with("col2"),
],
)
def test_eval_select_with_list_pl_selector(expr):
df = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]})
sel = eval_select(df, expr)
assert sel == [("col1", 0), ("col2", 1)]


@pytest.mark.parametrize("expr", [["col2", 1.2]])
def test_eval_select_pandas_raises1(expr):
df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]})
with pytest.raises(TypeError) as exc_info:
eval_select(df, expr)

assert "Only int and str are supported." in str(exc_info.value.args[0])


@pytest.mark.parametrize("expr", [3.45, {"col2"}, ("col2",)])
def test_eval_select_pandas_raises2(expr):
df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]})
with pytest.raises(NotImplementedError) as exc_info:
eval_select(df, expr)

assert "Unsupported selection expr: " in str(exc_info.value.args[0])


@pytest.mark.parametrize("expr", [3.45, {6}, (7.8,)])
def test_eval_select_polars_raises(expr):
df = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]})
with pytest.raises(TypeError) as exc_info:
eval_select(df, expr)

assert "Unsupported selection expr type:" in str(exc_info.value.args[0])


def test_eval_selector_polars_list_raises():
expr = ["col1", 1.2]
df = pl.DataFrame({"col1": [], "col2": [], "col3": []})
with pytest.raises(TypeError) as exc_info:
eval_select(df, expr)

assert "entry 1 is type: <class 'float'>" in str(exc_info.value.args[0])


def test_create_empty_frame(df: DataFrameLike):
res = create_empty_frame(df)
col = [None] * 3
Expand Down