Skip to content

Commit

Permalink
Adapt pow/2 to follow the same rules from Polars
Browse files Browse the repository at this point in the history
If exponent is float, it follows dtype of exponent. Otherwise, it follows dtype of base.
See: pola-rs/polars#15506
  • Loading branch information
philss committed Jul 19, 2024
1 parent 54991c0 commit cb3886c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
6 changes: 5 additions & 1 deletion lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3595,9 +3595,13 @@ defmodule Explorer.Series do
Raises a numeric series to the power of the exponent.
At least one of the arguments must be a series. If both
sizes are series, the series must have the same size or
sides are series, the series must have the same size or
at last one of them must have size of 1.
Note that this operation can fail if the exponent is a
signed integer series or scalar containing negative values,
and the base is also of an integer type.
## Supported dtypes
* floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)}
Expand Down
37 changes: 23 additions & 14 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2283,7 +2283,7 @@ defmodule Explorer.SeriesTest do
end
end

test "pow(uint, sint) == float64" do
test "pow(uint, sint) == sint" do
for u_base <- [8, 16, 32, 64], s_power <- [8, 16, 32, 64] do
base = Series.from_list([1, 2, 3], dtype: {:u, u_base})
power = Series.from_list([3, 2, 1], dtype: {:s, s_power})
Expand All @@ -2308,7 +2308,7 @@ defmodule Explorer.SeriesTest do
end
end

test "pow(sint, sint) == float64" do
test "pow(sint, sint) == sint" do
for s_base <- [8, 16, 32, 64], s_power <- [8, 16, 32, 64] do
base = Series.from_list([1, 2, 3], dtype: {:s, s_base})
power = Series.from_list([3, 2, 1], dtype: {:s, s_power})
Expand Down Expand Up @@ -2360,10 +2360,13 @@ defmodule Explorer.SeriesTest do
s1 = Series.from_list([1, 2, 3])
s2 = Series.from_list([1, -2, 3])

result = Series.pow(s1, s2)
message =
"Polars Error: invalid operation: invalid operation: conversion from `i64` to `u32` failed in column 'exponent' for 1 out of 3 values: [-2]\n\n" <>
"Hint: if you were trying to raise an integer to a negative integer power, please cast your base or exponent to float first."

assert result.dtype == {:f, 64}
assert Series.to_list(result) === [1.0, 0.25, 27.0]
assert_raise RuntimeError, message, fn ->
Series.pow(s1, s2)
end
end

test "pow of an integer series with a float series" do
Expand Down Expand Up @@ -2428,10 +2431,13 @@ defmodule Explorer.SeriesTest do
test "pow of an integer series with a negative integer scalar value on the right-hand side" do
s1 = Series.from_list([1, 2, 3])

result = Series.pow(s1, -2)
message =
"Polars Error: invalid operation: invalid operation: conversion from `i64` to `u32` failed in column 'literal' for 1 out of 1 values: [-2]\n\n" <>
"Hint: if you were trying to raise an integer to a negative integer power, please cast your base or exponent to float first."

assert result.dtype == {:f, 64}
assert Series.to_list(result) === [1.0, 1 / 4, 1 / 9]
assert_raise RuntimeError, message, fn ->
Series.pow(s1, -2)
end
end

test "pow of an integer series with a float scalar value on the right-hand side" do
Expand Down Expand Up @@ -2484,25 +2490,28 @@ defmodule Explorer.SeriesTest do

result = Series.pow(2, s1)

assert result.dtype == {:f, 64}
assert Series.to_list(result) === [2.0, 4.0, 8.0]
assert result.dtype == {:s, 64}
assert Series.to_list(result) === [2, 4, 8]
end

test "pow of an integer series that contains negative integer with an integer scalar value on the left-hand side" do
s1 = Series.from_list([1, -2, 3])

result = Series.pow(2, s1)
message =
"Polars Error: invalid operation: invalid operation: conversion from `i64` to `u32` failed in column 'exponent' for 1 out of 3 values: [-2]\n\n" <>
"Hint: if you were trying to raise an integer to a negative integer power, please cast your base or exponent to float first."

assert result.dtype == {:f, 64}
assert Series.to_list(result) === [2.0, 0.25, 8.0]
assert_raise RuntimeError, message, fn ->
Series.pow(2, s1)
end
end

test "pow of an integer series with a negative integer scalar value on the left-hand side" do
s1 = Series.from_list([1, 2, 3])

result = Series.pow(-2, s1)

assert result.dtype == {:f, 64}
assert result.dtype == {:s, 64}
assert Series.to_list(result) == [-2, 4, -8]
end

Expand Down

0 comments on commit cb3886c

Please sign in to comment.