Skip to content

Commit

Permalink
fix: pl.repeat raise instead of panic if dtype can't be inferred
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Mar 29, 2024
1 parent 9c46183 commit 28e8e55
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
6 changes: 5 additions & 1 deletion py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,11 @@ pub fn repeat(value: PyExpr, n: PyExpr, dtype: Option<Wrap<DataType>>) -> PyResu
}

if let Expr::Literal(lv) = &value {
let av = lv.to_any_value().unwrap();
let av = lv.to_any_value().ok_or_else(||{
PyPolarsErr::from(polars_err!(ComputeError:
"The data type of the resulting column can't be inferred from the given value, please explicitly specify the dtype."
))
})?;
// Integer inputs that fit in Int32 are parsed as such
if let DataType::Int64 = av.dtype() {
let int_value = av.try_extract::<i64>().unwrap();
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/unit/functions/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import polars as pl
from polars import ComputeError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -201,3 +202,8 @@ def test_repeat_by_none_13053(data: list[Any], expected_data: list[list[Any]]) -
res = df.select(repeat=pl.col("x").repeat_by("by"))
expected = pl.Series("repeat", expected_data)
assert_series_equal(res.to_series(), expected)


def test_repeat_raise() -> None:
with pytest.raises(ComputeError, match="please explicitly specify the dtype"):
pl.repeat([1, 2], n=2)

0 comments on commit 28e8e55

Please sign in to comment.