
ChainRulesTestUtils crashes with a custom regularization function on a Flux Dense layer #266

Open
TPU22 opened this issue Nov 15, 2022 · 4 comments

TPU22 commented Nov 15, 2022

I have been trying to write a custom reverse rule for a simple regularization function on a Flux Dense layer and to evaluate it with ChainRulesTestUtils. Zygote's gradient works fine with the rules, but ChainRulesTestUtils crashes. The code below runs fine up to the test_rrule calls. The first test_rrule is meant to check the one-layer regularization function, but instead it raises an error

Got exception outside of a @test 
MethodError: no method matching zero(::typeof(tanh))

The second test_rrule crashes with:

Got exception outside of a @test
return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}}}} does not match inferred return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}}

Any idea what could be the issue here? A bug somewhere?

using ChainRulesCore
using Flux
using Random 
using ChainRulesTestUtils

Flux.trainable(nn::Dense) = (nn.weight, nn.bias,)

function weightregularization(nn::Dense)
    return sum((nn.weight).^2.0)
end

function ChainRulesCore.rrule(::typeof(weightregularization), nn::Dense)
    y = weightregularization(nn)
    project_w = ProjectTo(nn.weight)
    function weightregularization_pullback(ȳ)
        pullb = Tangent{Dense}(weight=project_w(ȳ * 2.0 * nn.weight), bias=ZeroTangent(), σ=NoTangent())
        return NoTangent(), pullb
    end
    return y, weightregularization_pullback
end


function totalregularization(ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    a = 0.0
    for i in ch
        a = a + sum(i.weight.^2.0)
    end
    return a
end

function ChainRulesCore.rrule(::typeof(totalregularization), ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    y = totalregularization(ch)
    function totalregularization_pullback(ȳ)
        totalpullback = []
        N = length(ch)
        for i = 1:N
            project_w = ProjectTo(ch[i].weight)
            push!(totalpullback, (weight=project_w(ȳ * 2.0 * ch[i].weight), bias=ZeroTangent(), σ=NoTangent()))
        end
        pullb = Tangent{Chain{T}}(layers=Tuple(totalpullback))
        return NoTangent(), pullb
    end
    return y, totalregularization_pullback
end




nn = Dense(randn(1,2), randn(1), tanh)
gr1 = gradient(weightregularization,nn)


l1 = Dense(randn(2,2), randn(2), tanh)
l2 = Dense(randn(1,2), randn(1), tanh)
ch = Chain(l1,l2)
gr2 = gradient(totalregularization,ch)


test_rrule(weightregularization,nn) 
test_rrule(totalregularization,ch)
oxinabox (Member) commented:
Odds are that, as a workaround, you need to implement either FiniteDifferences.to_vec or perhaps rand_tangent (defined as rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)) for your type.
I would need to see the stack trace to know which.

Either the to_vec method or the rand_tangent method is not smart enough to handle a struct where some fields are differentiable and others are not.

Probably the extra smartness needed is to know that objects that have no fields (like non-closure/functor functions) have a NoTangent() tangent.
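For illustration only, the kind of overload being suggested might look like this minimal sketch. The extension point rand_tangent(rng, x) is assumed from the one-argument definition quoted above, not verified against this issue:

```julia
using ChainRulesCore: NoTangent
using ChainRulesTestUtils
using Random

# Sketch only: tanh is a plain function with no fields, hence no
# differentiable state, so a random tangent for it can only be NoTangent().
ChainRulesTestUtils.rand_tangent(::Random.AbstractRNG, ::typeof(tanh)) = NoTangent()
```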

TPU22 commented Nov 24, 2022

Here's the stacktrace for test_rrule(weightregularization,nn):

 Stacktrace:
    [1] test_approx(::AbstractZero, x::Any, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:42
    [2] test_approx(actual::Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}, expected::Any, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:134
    [3] _test_cotangent(accum_cotangent::Any, ad_cotangent::Any, fd_cotangent::Any; check_inferred::Any, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:299
    [4] (::ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:226
    [5] (::Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}})(#unused#::Nothing, xs::Tuple{Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}})
      @ Base ./tuple.jl:556
    [6] BottomRF
      @ ./reduce.jl:81 [inlined]
    [7] _foldl_impl(op::Base.BottomRF{Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}})
      @ Base ./reduce.jl:62
    [8] foldl_impl(op::Base.BottomRF{Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}}}, nt::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}})
      @ Base ./reduce.jl:48
    [9] mapfoldl_impl(f::typeof(identity), op::Base.var"#59#60"{ChainRulesTestUtils.var"#60#64"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}}}, nt::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}})
      @ Base ./reduce.jl:44
   [10] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}; init::Nothing)
      @ Base ./reduce.jl:170
   [11] #foldl#260
      @ ./reduce.jl:193 [inlined]
   [12] foreach(::Function, ::Tuple{NoTangent, Tangent{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, NoTangent}}}}, ::Tuple{NoTangent, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, ZeroTangent, NoTangent}}}}, ::Tuple{NoTangent, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}})
      @ Base ./tuple.jl:556
   [13] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:225 [inlined]
   [14] macro expansion
      @ /opt/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:1360 [inlined]
   [15] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
   [16] test_rrule(config::RuleConfig, f::Any, args::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:173
   [17] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
   [18] test_rrule(::Any, ::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:169
   [19] top-level scope
      @ ~/Julia/customreg.jl:61

and here's the stacktrace for test_rrule(totalregularization,ch):

Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] _test_inferred(f::Any, args::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:255
    [3] _test_inferred
      @ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:253 [inlined]
    [4] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:211 [inlined]
    [5] macro expansion
      @ /opt/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:1360 [inlined]
    [6] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
    [7] test_rrule(config::RuleConfig, f::Any, args::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:173
    [8] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
    [9] test_rrule(::Any, ::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:169
   [10] top-level scope
      @ ~/Julia/customreg.jl:62

oxinabox commented Dec 8, 2022

So the problem with the second one is just that it isn't inferrable, so ChainRulesTestUtils is working correctly there.
Tuple(totalpullback) is not an inferrable operation, because the compiler can't tell how long the Tuple will be.
Rather than writing it with a for loop and push!-ing to a vector, consider using the ntuple function or map (on a Tuple) -- those tend to be inferrable.
Alternatively, you can use test_rrule(...; check_inferred=false)
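For instance, the pullback from the original post could be restructured along these lines (an untested sketch; ch and T are closed over from the enclosing rrule as in the original code):

```julia
function totalregularization_pullback(ȳ)
    # map over a Tuple returns a Tuple whose length is known from the type,
    # so the return type is inferrable -- no untyped Vector, no push!
    layer_tangents = map(ch.layers) do layer
        project_w = ProjectTo(layer.weight)
        (weight=project_w(ȳ * 2.0 * layer.weight), bias=ZeroTangent(), σ=NoTangent())
    end
    return NoTangent(), Tangent{Chain{T}}(layers=layer_tangents)
end
```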

oxinabox commented Dec 8, 2022

For the first case, what seems to be happening is that finite differencing is returning a result that contains a tanh element.
Which is weird.

Something is going wrong here:

function _make_j′vp_call(fdm, f, ȳ, xs, ignores)
    f2 = _wrap_function(f, xs, ignores)
    ignores = collect(ignores)
    args = Any[NoTangent() for _ in 1:length(xs)]
    all(ignores) && return (args...,)
    sigargs = xs[.!ignores]
    arginds = (1:length(xs))[.!ignores]
    fd = j′vp(fdm, f2, ȳ, sigargs...)
    @assert length(fd) == length(arginds)
    for (dx, ind) in zip(fd, arginds)
        args[ind] = ProjectTo(xs[ind])(dx)
    end
    return (args...,)
end

But I have no idea what.

FiniteDifferences.to_vec looks right

julia> nn = Dense(randn(1,2), randn(1), tanh)
Dense(2 => 1, tanh)  # 3 parameters

julia> FiniteDifferences.to_vec(nn)
([0.6512775053740293, 0.09461593565834911, 2.580817436122576], FiniteDifferences.var"#structtype_from_vec#29"{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, FiniteDifferences.var"#Tuple_from_vec#52"{Tuple{Int64, Int64, Int64}, Tuple{Int64, Int64, Int64}, Tuple{typeof(identity), typeof(identity), typeof(identity)}}, Tuple{FiniteDifferences.var"#Array_from_vec#34"{Matrix{Float64}, typeof(identity)}, typeof(identity), FiniteDifferences.var"#24#27"{typeof(tanh)}}}(FiniteDifferences.var"#Tuple_from_vec#52"{Tuple{Int64, Int64, Int64}, Tuple{Int64, Int64, Int64}, Tuple{typeof(identity), typeof(identity), typeof(identity)}}((2, 3, 3), (2, 1, 0), (identity, identity, identity)), (FiniteDifferences.var"#Array_from_vec#34"{Matrix{Float64}, typeof(identity)}([0.6512775053740293 0.09461593565834911], identity), identity, FiniteDifferences.var"#24#27"{typeof(tanh)}(tanh))))
