Skip to content

Commit

Permalink
Make recursive_acc/accumulate more recursive
Browse files Browse the repository at this point in the history
  • Loading branch information
danielwe committed Sep 18, 2024
1 parent 6a19be2 commit 6d3fabb
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6562,7 +6562,13 @@ end
Base.@_inline_meta
prev = getfield(x, i)
next = getfield(y, i)
recursive_add(prev, next, f, forcelhs)
ST = Core.Typeof(prev)
if !mutable_register(ST)
recursive_add(prev, next, f, forcelhs)
elseif !(ST <: Integer)
recursive_accumulate(prev, next, f)
prev
end
end)
end

Expand Down Expand Up @@ -6591,18 +6597,19 @@ end

# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y)
@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F=identity) where {T, F}
if !mutable_register(T)
for I in eachindex(x)
prev = x[I]
for I in eachindex(x, y)
if !mutable_register(T)
@inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register)
elseif !(T <: Integer)
recursive_accumulate((@inbounds x[I]), (@inbounds y[I]), f)
end
end
end


# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y)
@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F=identity) where {F}
recursive_accumulate(x.contents, y.contents, seen, f)
recursive_accumulate(x.contents, y.contents, f)
end

@inline function recursive_accumulate(x::T, y::T, f::F=identity) where {T, F}
Expand All @@ -6613,12 +6620,14 @@ end
for i in 1:nf
if isdefined(x, i)
xi = getfield(x, i)
yi = getfield(y, i)
ST = Core.Typeof(xi)
if !mutable_register(ST)
@assert ismutable(x)
yi = getfield(y, i)
nexti = recursive_add(xi, yi, f, mutable_register)
setfield!(x, i, nexti)
elseif !(ST <: Integer)
recursive_accumulate(xi, yi, f)
end
end
end
Expand Down

0 comments on commit 6d3fabb

Please sign in to comment.