test_rrule() fails in iterate() #263

Open
dfdx opened this issue Sep 25, 2022 · 0 comments
Labels
bug Something isn't working

dfdx commented Sep 25, 2022

It looks like the finite difference implementation has a hard time going through `iterate` (see the MWE and full stacktrace below):

julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
  Got exception outside of a @test
  DimensionMismatch: second dimension of A, 2, does not match length of x, 1
  Stacktrace:
    [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
      @ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
...
    [7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
...

Below I provide an `rrule()` implementation for `iterate` on tuples for convenience, but the example can perhaps be narrowed down to a direct call to `_make_j′vp_call()`. I also see the same error when testing with arrays.
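As a rough sketch of that narrowing-down (not verified against the package versions in this report), one could bypass both `test_rrule` and `_make_j′vp_call()` and call `FiniteDifferences.j′vp` directly, with argument types mirroring the `j′vp` frame of the complete stacktrace. The closure `t -> iterate(t)` is an assumption standing in for the closure ChainRulesTestUtils builds, and `central_fdm(5, 1)` is assumed to correspond to the `AdaptedFiniteDifferenceMethod{5, 1, ...}` shown in the trace. Any exception is captured instead of rethrown so it can be inspected.

```julia
using ChainRulesCore      # Tangent, NoTangent
using FiniteDifferences   # central_fdm, j′vp

# Hypothetical standalone reproducer: call j′vp directly instead of going
# through test_rrule / _make_j′vp_call.
fdm = central_fdm(5, 1)   # 5-point central difference, first derivative
f = t -> iterate(t)       # assumed stand-in for the ChainRulesTestUtils closure
x = (3.0, 5.0)            # primal input tuple

# Output cotangent for iterate(x) == (3.0, 2), shaped like the one in the trace:
# Tangent{Tuple{Float64, Int64}, Tuple{Float64, NoTangent}}
ȳ = Tangent{Tuple{Float64, Int64}}(1.0, NoTangent())

result = try
    j′vp(fdm, f, ȳ, x)
catch err
    err                   # keep the exception object for inspection
end
println(result)
```

If this raises the same `DimensionMismatch`, the problem would be reproducible without any custom `rrule`.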

MWE
using ChainRulesCore
import ChainRulesCore.rrule
using ChainRulesTestUtils

# Build a Tangent for tuple `s` that is `dy` at position `f` and zero elsewhere
function ungetfield(dy, s::Tuple, f::Int)
    T = typeof(s)
    return Tangent{T}([i == f ? dy : ZeroTangent() for i=1:length(s)]...)
end

# iterate(t) returns (t[1], 2); only the element t[1] receives a tangent
function rrule(::typeof(iterate), t::Tuple)
    y = iterate(t)
    function iterate_pullback(dy)
        dy = unthunk(dy)
        return NoTangent(), ungetfield(dy[1], t, 1)
    end
    return y, iterate_pullback
end

# iterate(t, i) returns (t[i], i + 1) while i is in range; the integer state gets a zero tangent
function rrule(::typeof(iterate), t::Tuple, i::Integer)
    y = iterate(t, i)
    function iterate_pullback(dy)
        dy = unthunk(dy)
        return NoTangent(), ungetfield(dy[1], t, i), ZeroTangent()
    end
    return y, iterate_pullback
end

test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
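To check that the rule itself is not the culprit, here is a minimal sketch (not part of the original report) that calls the one-argument `rrule` by hand with a cotangent shaped like the one `test_rrule` generated: a `Tangent{Tuple{Float64, Int64}}` with `NoTangent()` for the integer state, as seen in the `j′vp` frame of the complete stacktrace. It repeats `ungetfield` and the rule from the MWE so it runs on its own.

```julia
using ChainRulesCore
import ChainRulesCore: rrule

# Same helper as in the MWE: tangent that is `dy` at position `f`, zero elsewhere
function ungetfield(dy, s::Tuple, f::Int)
    T = typeof(s)
    return Tangent{T}([i == f ? dy : ZeroTangent() for i in 1:length(s)]...)
end

# Same one-argument rule as in the MWE
function rrule(::typeof(iterate), t::Tuple)
    y = iterate(t)
    function iterate_pullback(dy)
        dy = unthunk(dy)
        return NoTangent(), ungetfield(dy[1], t, 1)
    end
    return y, iterate_pullback
end

# Call the rule manually with a cotangent matching the one test_rrule used
y, pb = rrule(iterate, (3.0, 5.0))
ȳ = Tangent{Tuple{Float64, Int64}}(1.0, NoTangent())
df, dt = pb(ȳ)

println(y)    # primal output of iterate
println(dt)   # tangent with respect to the input tuple
```

If `dt` holds `1.0` in its first slot and `ZeroTangent()` in its second, the pullback behaves as intended, which points the failure at the finite-difference side.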
Complete stacktrace
julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
  Got exception outside of a @test
  DimensionMismatch: second dimension of A, 2, does not match length of x, 1
  Stacktrace:
    [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
      @ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
    [2] mul!
      @ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:93 [inlined]
    [3] mul!
      @ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
    [4] *(tA::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
      @ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:86
    [5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:80
    [6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#53"{ChainRulesTestUtils.var"#call#63"{NamedTuple{(), Tuple{}}}, Tuple{typeof(iterate), Tuple{Float64, Float64}}, Tuple{Bool, Bool}}, ȳ::Tangent{Tuple{Float64, Int64}, Tuple{Float64, NoTangent}}, x::Tuple{Float64, Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:73
    [7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
    [8] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:224 [inlined]
    [9] macro expansion
      @ /opt/julia-1.8.0/share/julia/stdlib/v1.8/Test/src/Test.jl:1357 [inlined]
   [10] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
   [11] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
   [12] top-level scope
      @ REPL[1]:1
   [13] eval
      @ ./boot.jl:368 [inlined]
   [14] eval
      @ ./Base.jl:65 [inlined]
   [15] repleval(m::Module, code::Expr, #unused#::String)
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:222
   [16] (::VSCodeServer.var"#107#109"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:186
   [17] with_logstate(f::Function, logstate::Any)
      @ Base.CoreLogging ./logging.jl:511
   [18] with_logger
      @ ./logging.jl:623 [inlined]
   [19] (::VSCodeServer.var"#106#108"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:187
   [20] #invokelatest#2
      @ ./essentials.jl:729 [inlined]
   [21] invokelatest(::Any)
      @ Base ./essentials.jl:726
   [22] macro expansion
      @ ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined]
   [23] (::VSCodeServer.var"#61#62")()
      @ VSCodeServer ./task.jl:484
Test Summary:                          | Pass  Error  Total  Time
test_rrule: iterate on Float64,Float64 |    3      1      4  0.0s
ERROR: Some tests did not pass: 3 passed, 0 failed, 1 errored, 0 broken.
oxinabox added the bug label Oct 25, 2022