Skip to content

Commit

Permalink
Avoid StackOverflowError in generic recursive dot (#53030)
Browse files Browse the repository at this point in the history
A quick check to throw an error early if `first(x) == x && first(y) ==
y`, in which case the recursive `dot` will lead to a stack overflow.

Close #35654
  • Loading branch information
jishnub authored Jan 25, 2024
1 parent 14cf64f commit 936701e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,8 @@ function dot(x, y) # arbitrary iterables
end
(vx, xs) = ix
(vy, ys) = iy
typeof(vx) == typeof(x) && typeof(vy) == typeof(y) && throw(ArgumentError(
"cannot evaluate dot recursively if the type of an element is identical to that of the container"))
s = dot(vx, vy)
while true
ix = iterate(x, xs)
Expand Down
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,12 @@ end
end
end

@testset "avoid stackoverflow in dot" begin
@test_throws "cannot evaluate dot recursively" dot('a', 'c')
@test_throws "cannot evaluate dot recursively" dot('a', 'b':'c')
@test_throws "x and y are of different lengths" dot(1, 1:2)
end

@testset "generalized dot #32739" begin
for elty in (Int, Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat})
n = 10
Expand Down

0 comments on commit 936701e

Please sign in to comment.