Skip to content

Commit

Permalink
Revert "Attempt to fix "pow/2" after Polars changes"
Browse files Browse the repository at this point in the history
This reverts commit 2644d27.
  • Loading branch information
philss committed Jul 19, 2024
1 parent 2644d27 commit 54991c0
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 43 deletions.
11 changes: 1 addition & 10 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ defmodule Explorer.Backend.LazySeries do

@comparison_operations [:equal, :not_equal, :greater, :greater_equal, :less, :less_equal]

@basic_arithmetic_operations [:add, :subtract, :multiply, :divide]
@basic_arithmetic_operations [:add, :subtract, :multiply, :divide, :pow]
@other_arithmetic_operations [:quotient, :remainder]

@aggregation_operations [
Expand Down Expand Up @@ -453,15 +453,6 @@ defmodule Explorer.Backend.LazySeries do
end
end

@impl true
def pow(dtype, %Series{} = left, %Series{} = right) do
# Cast from the main module is needed because we may be seeing a series from another backend.
args = [data!(Explorer.Series.cast(left, dtype)), data!(right)]
data = new(:pow, args, dtype, aggregations?(args))

Backend.Series.new(data, dtype)
end

for op <- @other_arithmetic_operations do
@impl true
def unquote(op)(left, right) do
Expand Down
4 changes: 2 additions & 2 deletions lib/explorer/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2774,11 +2774,11 @@ defmodule Explorer.DataFrame do
You can overwrite existing columns as well:
iex> df = Explorer.DataFrame.new(a: ["a", "b", "c"], b: [1, 2, 3])
iex> Explorer.DataFrame.mutate_with(df, &[b: Explorer.Series.add(&1["b"], 2)])
iex> Explorer.DataFrame.mutate_with(df, &[b: Explorer.Series.pow(&1["b"], 2)])
#Explorer.DataFrame<
Polars[3 x 2]
a string ["a", "b", "c"]
b s64 [3, 4, 5]
b s64 [1, 4, 9]
>
It's possible to "reuse" a variable for different computations:
Expand Down
12 changes: 1 addition & 11 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3598,9 +3598,6 @@ defmodule Explorer.Series do
sizes are series, the series must have the same size or
at last one of them must have size of 1.
In case the expoent is a signed integer number or series,
the resultant series will be of `{:f, 64}` dtype.
## Supported dtypes
* floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)}
Expand All @@ -3617,13 +3614,6 @@ defmodule Explorer.Series do
iex> s = [2, 4, 6] |> Explorer.Series.from_list()
iex> Explorer.Series.pow(s, 3)
#Explorer.Series<
Polars[3]
f64 [8.0, 64.0, 216.0]
>
iex> s = [2, 4, 6] |> Explorer.Series.from_list()
iex> Explorer.Series.pow(s, Explorer.Series.from_list([3], dtype: :u32))
#Explorer.Series<
Polars[3]
s64 [8, 64, 216]
Expand Down Expand Up @@ -3667,7 +3657,7 @@ defmodule Explorer.Series do
defp cast_to_pow({:f, l}, {:f, r}), do: {:f, max(l, r)}
defp cast_to_pow({:f, l}, {n, _}) when K.in(n, [:u, :s]), do: {:f, l}
defp cast_to_pow({n, _}, {:f, r}) when K.in(n, [:u, :s]), do: {:f, r}
defp cast_to_pow({n, _}, {:s, _}) when K.in(n, [:u, :s]), do: {:f, 64}
defp cast_to_pow({n, _}, {:s, _}) when K.in(n, [:u, :s]), do: {:s, 64}
defp cast_to_pow(_, _), do: nil

@doc """
Expand Down
26 changes: 14 additions & 12 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ defmodule Explorer.DataFrameTest do
df = DF.new(a: [1, 2, 3, 4, 5, 6, 5], b: [9, 8, 7, 6, 5, 4, 3])

message =
"expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:f, 64}"
"expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:s, 64}"

assert_raise ArgumentError, message, fn ->
DF.filter_with(df, fn ldf ->
Expand Down Expand Up @@ -948,7 +948,7 @@ defmodule Explorer.DataFrameTest do
calc2: [-1, 0, 2],
calc3: [2, 4, 8],
calc4: [0.5, 1.0, 2.0],
calc5: [1.0, 4.0, 16.0],
calc5: [1, 4, 16],
calc6: [0, 1, 2],
calc7: [1, 0, 0],
calc8: [:nan, :nan, :nan],
Expand All @@ -964,7 +964,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:f, 64},
"calc5" => {:s, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64},
"calc8" => {:f, 64},
Expand All @@ -985,6 +985,7 @@ defmodule Explorer.DataFrameTest do
calc3: multiply(2, a),
calc4: divide(2, a),
calc5: pow(2, a),
calc5_1: pow(2.0, a),
calc6: quotient(2, a),
calc7: remainder(2, a)
)
Expand All @@ -995,7 +996,8 @@ defmodule Explorer.DataFrameTest do
calc2: [1, 0, -2],
calc3: [2, 4, 8],
calc4: [2.0, 1.0, 0.5],
calc5: [2.0, 4.0, 16.0],
calc5: [2, 4, 16],
calc5_1: [2.0, 4.0, 16.0],
calc6: [2, 1, 0],
calc7: [0, 0, 2]
}
Expand All @@ -1006,15 +1008,15 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:f, 64},
"calc5" => {:s, 64},
"calc5_1" => {:f, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64}
}
end

test "adds some columns with arithmetic operations on (lazy series, series)" do
df = DF.new(a: [1, 2, 4])
# TODO: check remainder and quotient in case they have a u32 on the right side.
series = Explorer.Series.from_list([2, 1, 2])

df1 =
Expand All @@ -1034,7 +1036,7 @@ defmodule Explorer.DataFrameTest do
calc2: [-1, 1, 2],
calc3: [2, 2, 8],
calc4: [0.5, 2.0, 2.0],
calc5: [1.0, 2.0, 16.0],
calc5: [1, 2, 16],
calc6: [0, 2, 2],
calc7: [1, 0, 0]
}
Expand All @@ -1045,7 +1047,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:f, 64},
"calc5" => {:s, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64}
}
Expand All @@ -1072,7 +1074,7 @@ defmodule Explorer.DataFrameTest do
calc2: [-1, 1, 2],
calc3: [2, 2, 8],
calc4: [0.5, 2.0, 2.0],
calc5: [1.0, 2.0, 16.0],
calc5: [1, 2, 16],
calc6: [0, 2, 2],
calc7: [1, 0, 0]
}
Expand All @@ -1083,7 +1085,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:f, 64},
"calc5" => {:s, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64}
}
Expand Down Expand Up @@ -1112,7 +1114,7 @@ defmodule Explorer.DataFrameTest do
calc2: [19, 38, 57],
calc3: [3, 4, 3],
calc4: [2.0, :infinity, 7.5],
calc5: [1.0, 4.0, 3.0],
calc5: [1, 4, 3],
calc6: [2, nil, 7],
calc7: [0, nil, 4]
}
Expand All @@ -1126,7 +1128,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:f, 64},
"calc5" => {:s, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64}
}
Expand Down
16 changes: 8 additions & 8 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2290,8 +2290,8 @@ defmodule Explorer.SeriesTest do

result = Series.pow(base, power)

assert result.dtype == {:f, 64}
assert Series.to_list(result) == [1.0, 4.0, 3.0]
assert result.dtype == {:s, 64}
assert Series.to_list(result) == [1, 4, 3]
end
end

Expand All @@ -2315,8 +2315,8 @@ defmodule Explorer.SeriesTest do

result = Series.pow(base, power)

assert result.dtype == {:f, 64}
assert Series.to_list(result) === [1.0, 4.0, 3.0]
assert result.dtype == {:s, 64}
assert Series.to_list(result) === [1, 4, 3]
end
end

Expand Down Expand Up @@ -2392,13 +2392,13 @@ defmodule Explorer.SeriesTest do

result = Series.pow(s1, s2)

assert result.dtype == {:f, 64}
assert result.dtype == {:s, 64}
assert Series.to_list(result) == [1, nil, 3]
end

test "pow of an integer series that contains nil with an integer series" do
s1 = Series.from_list([1, nil, 3])
s2 = Series.from_list([3, 2, 1], dtype: :u32)
s2 = Series.from_list([3, 2, 1])

result = Series.pow(s1, s2)

Expand All @@ -2408,7 +2408,7 @@ defmodule Explorer.SeriesTest do

test "pow of an integer series that contains nil with an integer series also with nil" do
s1 = Series.from_list([1, nil, 3])
s2 = Series.from_list([3, nil, 1], dtype: :u32)
s2 = Series.from_list([3, nil, 1])

result = Series.pow(s1, s2)

Expand All @@ -2421,7 +2421,7 @@ defmodule Explorer.SeriesTest do

result = Series.pow(s1, 2)

assert result.dtype == {:f, 64}
assert result.dtype == {:s, 64}
assert Series.to_list(result) == [1, 4, 9]
end

Expand Down

0 comments on commit 54991c0

Please sign in to comment.