∇eachslice causes UndefRefError when input array contains references #807

Closed
BioTurboNick opened this issue Sep 14, 2024 · 4 comments · Fixed by #808

@BioTurboNick
Contributor

```julia
function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
    dys = unthunk(dys_raw)
    i1 = findfirst(dy -> dy isa AbstractArray, dys)
    if i1 === nothing  # all slices are Zero!
        return _zero_fill!(similar(x, float(eltype(x)), axes(x)))
    end
    T = promote_type(eltype(dys[i1]), eltype(x))
    # The whole point of this gradient is that we can allocate one `dx` array:
    dx = similar(x, T, axes(x))
    for i in axes(x, dim)
        slice = selectdim(dx, dim, i)
        if dys[i] isa AbstractZero
            _zero_fill!(slice)  # Avoids this: copyto!([1,2,3], ZeroTangent()) == [0,2,3]
        else
            copyto!(slice, dys[i])
        end
    end
    return ProjectTo(x)(dx)
end
```

  1. `similar` makes `dx` an uninitialized array.
  2. When `dys[i]` is a zero tangent, the intent of `_zero_fill!(slice)` is presumably to populate that slice with zeros.
  3. For a non-`Number` element type, `_zero_fill!` on the uninitialized slice calls `map!(zero, dx, dx)`.
  4. `map!` tries to apply `zero` to the existing elements of `dx`, which are `#undef`, so it throws `UndefRefError`.

So this branch never works when the element type of `dx` is a reference type rather than an inline value type.
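The failure in steps 1 to 4 can be reproduced without ChainRules at all. The sketch below assumes only Base Julia: it mimics `dx = similar(x, T, axes(x))` with `T = Any`, shows that `map!(zero, dx, dx)` throws because it must read each `#undef` slot first, and shows that `fill!`, which only writes, succeeds on the same array.

```julia
# Mimic `dx = similar(x, T, axes(x))` with an element type that is not isbits.
x = [1.0, 2.0, 3.0]
dx = similar(x, Any, axes(x))      # entries start out #undef, not as zeros
println(isassigned(dx, 1))         # false

# map!(zero, dx, dx) has to read dx[j] before it can call zero on it,
# so the first read of an #undef slot throws.
caught = try
    map!(zero, dx, dx)
    nothing
catch err
    err
end
println(caught isa UndefRefError)  # true

# fill! only writes, so it succeeds on the same uninitialized array.
fill!(dx, 0.0)
println(all(iszero, dx))           # true
```

With a bits type such as `Float64`, `similar` returns arbitrary numbers rather than `#undef`, so the same `map!` call would silently run instead of throwing.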

@BioTurboNick
Contributor Author

Here, the elements of `dys[i1]` are of type `ChainRulesCore.Tangent{Any, @NamedTuple{μ::Float64, σ::Float64}}` and the elements of `x` are of type `Distributions.Normal{Float64}`, which `promote_type` promotes to `Any` for `dx`.

Based on that comment, should the `#undef` positions be filled with `ZeroTangent()`?

Or is there some upstream issue that's only surfacing here?
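To see why the element type collapses to `Any`, here is a minimal sketch using two hypothetical stand-in structs (`NormalLike` and `TangentLike` are illustrative names, not the real `Distributions` or `ChainRulesCore` types). When no `promote_rule` relates two concrete types, `promote_type` falls back to their common supertype, which for unrelated structs is `Any`, and `similar` then returns an array of `#undef` references.

```julia
# Hypothetical stand-ins: neither type is the real one.
struct NormalLike            # stands in for Distributions.Normal{Float64}
    μ::Float64
    σ::Float64
end
struct TangentLike           # stands in for Tangent{Any, @NamedTuple{μ::Float64, σ::Float64}}
    μ::Float64
    σ::Float64
end

# With no promote_rule between them, promote_type falls back to typejoin, i.e. Any.
T = promote_type(TangentLike, NormalLike)
println(T)                   # Any
println(isbitstype(T))       # false: a Vector{Any} stores references

dx = similar([NormalLike(0.0, 1.0), NormalLike(1.0, 2.0)], T)
println(isassigned(dx, 1))   # false: the slot is #undef
```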

@mcabbott
Member

I wonder why this code does `T = promote_type(eltype(dys[i1]), eltype(x))`. It seems that if it did `T = eltype(dys[i1])`, this would just work.

I'm also not entirely sure what the second method of `_zero_fill!` was intended for; perhaps arrays of arrays? As you say, it will fail for a `dx` made by `similar` with reference types:

```julia
_zero_fill!(dx::AbstractArray{<:Number}) = fill!(dx, zero(eltype(dx)))
_zero_fill!(dx::AbstractArray) = map!(zero, dx, dx)
```

The tests for this are not extensive:

```julia
# Make sure pulling back an array that mixes some AbstractZeros in works right
_, back = rrule(eachcol, rand(3, 4))
@test back([1:3, ZeroTangent(), 7:9, NoTangent()]) == (NoTangent(), [1 0 7 0; 2 0 8 0; 3 0 9 0])
@test back([1:3, ZeroTangent(), 7:9, NoTangent()])[2] isa Matrix{Float64}
@test back([ZeroTangent(), ZeroTangent(), NoTangent(), NoTangent()]) == (NoTangent(), [0 0 0 0; 0 0 0 0; 0 0 0 0])
```
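For comparison with these tests, here is a dependency-free sketch of the allocate-once strategy that `∇eachslice` follows. It is not the ChainRules code: `sketch_eachslice_pullback` is a hypothetical name, `nothing` stands in for an `AbstractZero` tangent, and zero slices are written with `fill!(slice, zero(T))`, which never reads the uninitialized contents. It reproduces the expected matrices in the tests above for the numeric case.

```julia
# Hypothetical sketch of the allocate-once strategy in ∇eachslice.
# `nothing` stands in for an AbstractZero tangent; this is not the ChainRules code.
function sketch_eachslice_pullback(dys, x::AbstractArray, dim::Integer)
    # Every slice is a zero: return a zero array shaped like x.
    i1 = findfirst(dy -> dy isa AbstractArray, dys)
    i1 === nothing && return zeros(float(eltype(x)), size(x))

    T = promote_type(eltype(dys[i1]), eltype(x))
    dx = similar(x, T)  # one allocation, contents uninitialized
    for i in axes(x, dim)
        slice = selectdim(dx, dim, i)
        if dys[i] === nothing
            fill!(slice, zero(T))  # write-only: never reads the uninitialized contents
        else
            copyto!(slice, dys[i])
        end
    end
    return dx
end

x = rand(3, 4)
result = sketch_eachslice_pullback(Any[1:3, nothing, 7:9, nothing], x, 2)
# result == [1 0 7 0; 2 0 8 0; 3 0 9 0], stored as a Matrix{Float64}
```

Note that `zero(T)` still requires a numeric `T`: for an abstract element type such as `Any` it throws a `MethodError`, so this sketch does not cover the reference-type case raised in this issue.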

@BioTurboNick
Contributor Author

BioTurboNick commented Sep 15, 2024

Unfortunately that doesn't just work, because the output needs to store different `AbstractTangent` types.

So I think `T = promote_type(eltype.(dys)...)` handles that. But if the input mixes `AbstractTangent` types, the array gets an abstract element type and would still start out `#undef`.

EDIT: Ah, if I also remove the special `_zero_fill!` method that uses `map!` and drop the `Number` type restriction on the other one, that seems to work. I'll work on a PR for this with tests.

@BioTurboNick
Contributor Author

There's another case where something unexpected happens.

Zygote uses `wrap_chainrules_input` to convert an `Array{Union{Float64, Nothing}}` to an `Array{Any}` by replacing `nothing` with `ZeroTangent()`. But it seems the intent inside `∇eachslice` is to use `zero(T)` instead of `ZeroTangent()` wherever possible.

I'm wondering if there's a function that decomposes the `NamedTuple` tangents into individual values, where the `ZeroTangent`s should become `0`s but instead become `nothing`, causing the above.

Whatever the cause, the following method appears to at least intercept these arrays and convert them before they are passed back to ChainRules:

```julia
@inline Zygote.wrap_chainrules_input(dxs::AbstractArray{<:Union{Nothing,T}}) where T <: Number =
    map(x -> x === nothing ? zero(T) : x, dxs)
```
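The effect of that method can be checked without Zygote. The sketch below copies its signature and body onto a hypothetical standalone function, `replace_nothing_with_zero` (not a Zygote API). Because the static parameter `T` binds to the non-`nothing` numeric type, every `nothing` becomes `zero(T)` and the result has a concrete `Number` element type instead of `Any`.

```julia
# Hypothetical helper name; mirrors the method signature quoted above.
replace_nothing_with_zero(dxs::AbstractArray{<:Union{Nothing,T}}) where {T<:Number} =
    map(x -> x === nothing ? zero(T) : x, dxs)

dxs = Union{Float64,Nothing}[1.0, nothing, 3.0]
out = replace_nothing_with_zero(dxs)
println(out)            # [1.0, 0.0, 3.0]
println(eltype(out))    # Float64
```

Since the output array is concrete (`Vector{Float64}` here), anything downstream that calls `similar` on it gets bits-type storage rather than `#undef` references.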
