diff --git a/docs/blog/polars-styling/index.qmd b/docs/blog/polars-styling/index.qmd index b834bb7cf..7196dfdb7 100644 --- a/docs/blog/polars-styling/index.qmd +++ b/docs/blog/polars-styling/index.qmd @@ -198,12 +198,12 @@ gt_with_spanners = ( # Table column spanners ---- .tab_spanner( - label = "Time", - columns = time_cols + label="Time", + columns=time_cols ) .tab_spanner( - label = "Measurement", - columns = cs.all().exclude(time_cols) + label="Measurement", + columns=cs.exclude(time_cols) ) ) diff --git a/docs/blog/superbowl-squares/_code.py b/docs/blog/superbowl-squares/_code.py index 6c46c08cf..5a06dd135 100644 --- a/docs/blog/superbowl-squares/_code.py +++ b/docs/blog/superbowl-squares/_code.py @@ -45,7 +45,7 @@ def team_final_digits(game: pl.DataFrame, team_code: str) -> pl.DataFrame: .with_columns(joint=pl.col("prop") * pl.col("prop_right")) .sort("final_digit", "final_digit_right") .pivot(values="joint", columns="final_digit_right", index="final_digit") - .with_columns((cs.all().exclude("final_digit") * 100).round(1)) + .with_columns((cs.exclude("final_digit") * 100).round(1)) ) # Display ----- diff --git a/docs/get-started/basic-styling.qmd b/docs/get-started/basic-styling.qmd index 09ffa87ba..1e1d9c713 100644 --- a/docs/get-started/basic-styling.qmd +++ b/docs/get-started/basic-styling.qmd @@ -222,7 +222,7 @@ import polars.selectors as cs gt_pl_air.tab_style( style=style.fill(color="yellow"), locations=loc.body( - columns=cs.all().exclude(["Month", "Day"]), + columns=cs.exclude(["Month", "Day"]), rows=pl.col("Temp") == pl.col("Temp").max() ) ) diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index 970a775c4..03e38a95e 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -307,26 +307,47 @@ def _( def _(data: PlDataFrame, expr: Union[list[str], _selector_proxy_], strict: bool = True) -> _NamePos: # TODO: how to annotate type of a polars selector? # Seems to be polars.selectors._selector_proxy_. + import polars.selectors as cs + + from functools import reduce + from operator import or_ from polars import Expr - from polars import selectors if isinstance(expr, (str, int)): expr = [expr] + if isinstance(expr, list): + all_selectors = [ + cs.by_name(x) if isinstance(x, str) else cs.by_index(x) if isinstance(x, int) else x + for x in expr + ] + + _validate_selector_list(all_selectors) + + expr = reduce(or_, all_selectors, cs.by_name()) + col_pos = {k: ii for ii, k in enumerate(data.columns)} # just in case _selector_proxy_ gets renamed or something # it inherits from Expr, so we can just use that in a pinch - cls_selector = getattr(selectors, "_selector_proxy_", Expr) + cls_selector = getattr(cs, "_selector_proxy_", Expr) - if not isinstance(expr, (list, cls_selector)): + if not isinstance(expr, cls_selector): raise TypeError(f"Unsupported selection expr type: {type(expr)}") # I don't think there's a way to get the columns w/o running the selection - final_columns = selectors.expand_selector(data, expr) + final_columns = cs.expand_selector(data, expr) return [(col, col_pos[col]) for col in final_columns] +def _validate_selector_list(selectors: list): + from polars.selectors import is_selector + + for ii, sel in enumerate(selectors): + if not is_selector(sel): + raise TypeError(f"Expected a list of selectors, but entry {ii} is type: {type(sel)}.") + + def _eval_select_from_list( columns: list[str], expr: list[Union[str, int]] ) -> list[tuple[str, int]]: diff --git a/tests/test_formats.py b/tests/test_formats.py index 96047922b..647f24661 100644 --- a/tests/test_formats.py +++ b/tests/test_formats.py @@ -79,7 +79,7 @@ def test_format_repr_snap(snapshot): assert_repr_html(snapshot, new_gt) -@pytest.mark.parametrize("expr", [[0, -1], pl.selectors.all().exclude("y")]) +@pytest.mark.parametrize("expr", [[0, -1], pl.selectors.exclude("y")]) def test_format_col_selection_multi(expr: Any): df = pd.DataFrame({"x": [1], "y": [2], "z": [3]}) diff --git a/tests/test_tbl_data.py b/tests/test_tbl_data.py index 9d633837b..d30894844 100644 --- a/tests/test_tbl_data.py +++ b/tests/test_tbl_data.py @@ -70,14 +70,61 @@ def test_reorder(df: DataFrameLike): assert_frame_equal(res, dst) -@pytest.mark.parametrize( - "expr", [["col2", "col1"], [1, 0], ["col2", 0], pl.selectors.all().exclude("col3")] -) +@pytest.mark.parametrize("expr", [["col2", "col1"], [1, 0], ["col2", 0], [1, "col1"]]) def test_eval_select_with_list(df: DataFrameLike, expr): - sel = eval_select(df, ["col2", "col1"]) + sel = eval_select(df, expr) assert sel == [("col2", 1), ("col1", 0)] +@pytest.mark.parametrize( + "expr", + [ + pl.selectors.exclude("col3"), + pl.selectors.starts_with("col1") | pl.selectors.starts_with("col2"), + ], +) +def test_eval_select_with_list_pl_selector(expr): + df = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]}) + sel = eval_select(df, expr) + assert sel == [("col1", 0), ("col2", 1)] + + +@pytest.mark.parametrize("expr", [["col2", 1.2]]) +def test_eval_select_pandas_raises1(expr): + df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]}) + with pytest.raises(TypeError) as exc_info: + eval_select(df, expr) + + assert "Only int and str are supported." in str(exc_info.value.args[0]) + + +@pytest.mark.parametrize("expr", [3.45, {"col2"}, ("col2",)]) +def test_eval_select_pandas_raises2(expr): + df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]}) + with pytest.raises(NotImplementedError) as exc_info: + eval_select(df, expr) + + assert "Unsupported selection expr: " in str(exc_info.value.args[0]) + + +@pytest.mark.parametrize("expr", [3.45, {6}, (7.8,)]) +def test_eval_select_polars_raises(expr): + df = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"], "col3": [4.0, 5.0, 6.0]}) + with pytest.raises(TypeError) as exc_info: + eval_select(df, expr) + + assert "Unsupported selection expr type:" in str(exc_info.value.args[0]) + + +def test_eval_selector_polars_list_raises(): + expr = ["col1", 1.2] + df = pl.DataFrame({"col1": [], "col2": [], "col3": []}) + with pytest.raises(TypeError) as exc_info: + eval_select(df, expr) + + assert "entry 1 is type: " in str(exc_info.value.args[0]) + + def test_create_empty_frame(df: DataFrameLike): res = create_empty_frame(df) col = [None] * 3