Skip to content

Commit

Permalink
update tests for 'diagonal_relaxed'
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Oct 8, 2023
1 parent 683829f commit 2638078
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 16 deletions.
2 changes: 1 addition & 1 deletion crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ impl SQLContext {
},
// UNION ALL BY NAME
// TODO: add recognition for SetQuantifier::DistinctByName
// when "https://github.com/sqlparser-rs/sqlparser-rs/pull/997" is merged
// when "https://github.com/sqlparser-rs/sqlparser-rs/pull/997" is available
SetQuantifier::AllByName => concat_lf_diagonal(vec![left, right], opts),
// UNION [DISTINCT] BY NAME
SetQuantifier::ByName => {
Expand Down
16 changes: 12 additions & 4 deletions py-polars/polars/functions/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,11 @@ def concat(
elif how == "horizontal":
out = wrap_df(plr.concat_df_horizontal(elems))
else:
allowed = get_args(ConcatMethod)
raise ValueError(f"DataFrame `how` must be one of {allowed!r}, got {how!r}")
allowed = ", ".join(repr(m) for m in get_args(ConcatMethod))
raise ValueError(
f"DataFrame `how` must be one of {{{allowed}}}, got {how!r}"
)

elif isinstance(first, pl.LazyFrame):
if how in ("vertical", "vertical_relaxed"):
return wrap_ldf(
Expand All @@ -215,8 +218,13 @@ def concat(
)
)
else:
allowed = tuple(m for m in get_args(ConcatMethod) if m != "horizontal")
raise ValueError(f"LazyFrame `how` must be one of {allowed!r}, got {how!r}")
allowed = ", ".join(
repr(m) for m in get_args(ConcatMethod) if m != "horizontal"
)
raise ValueError(
f"LazyFrame `how` must be one of {{{allowed}}}, got {how!r}"
)

elif isinstance(first, pl.Series):
if how == "vertical":
out = wrap_s(plr.concat_series(elems))
Expand Down
20 changes: 9 additions & 11 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,33 +276,31 @@ def test_window_expression_different_group_length() -> None:


def test_lazy_concat_err() -> None:
df1 = pl.DataFrame(
df = pl.DataFrame(
{
"foo": [1, 2],
"bar": [6, 7],
"ham": ["a", "b"],
}
)
df2 = pl.DataFrame(
{
"foo": [3, 4],
"ham": ["c", "d"],
"bar": [8, 9],
}
)
with pytest.raises(
ValueError,
match="LazyFrame `how` must be one of ('vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'align'), got 'horizontal'",
match="DataFrame `how` must be one of {'vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'horizontal', 'align'}, got 'sausage'",
):
pl.concat([df, df], how="sausage") # type: ignore[arg-type]
with pytest.raises(
ValueError,
match="LazyFrame `how` must be one of {'vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'align'}, got 'horizontal'",
):
pl.concat([df1.lazy(), df2.lazy()], how="horizontal").collect()
pl.concat([df.lazy(), df.lazy()], how="horizontal").collect()


@pytest.mark.parametrize("how", ["horizontal", "diagonal"])
def test_series_concat_err(how: ConcatMethod) -> None:
s = pl.Series([1, 2, 3])
with pytest.raises(
ValueError,
match="Series only allows 'vertical' concat strategy",
match="Series only supports 'vertical' concat strategy",
):
pl.concat([s, s], how=how)

Expand Down

0 comments on commit 2638078

Please sign in to comment.