Skip to content

Commit

Permalink
stage2(forward): add method table backedge for non-existing frule m…
Browse files Browse the repository at this point in the history
…ethod (#182)
  • Loading branch information
aviatesk authored Jul 13, 2023
1 parent 2462184 commit 3d5bee0
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 19 deletions.
96 changes: 83 additions & 13 deletions Manifest.toml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion src/analysis/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}]
frule_argtypes = append!(frule_preargtypes, arginfo.argtypes)
frule_arginfo = ArgInfo(nothing, frule_argtypes)
frule_si = StmtInfo(true)
frule_atype = CC.argtypes_to_type(frule_argtypes)
# turn off frule analysis in the frule to avoid cycling
interp′ = disable_forward(interp)
frule_call = CC.abstract_call_known(interp′, ChainRulesCore.frule, frule_arginfo, StmtInfo(true), sv, #=max_methods=#-1)
frule_call = CC.abstract_call_gf_by_type(interp′,
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)
if frule_call.rt !== Const(nothing)
return CallMeta(primal_call.rt, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
else
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end

return nothing
end

const frule_mt = methods(ChainRulesCore.frule).mt
26 changes: 21 additions & 5 deletions test/stage2_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,28 @@ module stage2_fwd
end

myminus(a, b) = a - b
self_minus(a) = myminus(a, a)
ChainRulesCore.@scalar_rule myminus(x, y) (true, -1)
let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64})
@test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64})
@test self_minus′(1.0) == 0.
end
ChainRulesCore.@scalar_rule myminus(x, y) (true, true) # frule for `x - y`
let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64})
@test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64})
@test self_minus′(1.0) == 2.
end

self_minus(a) = myminus(a, a)
let self_minus′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}, 2)
@test isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64})
@test self_minus′′(1.0) == 0.
myminus2(a, b) = a - b
self_minus2(a) = myminus2(a, a)
let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64})
@test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64})
@test self_minus2′(1.0) == 0.
end
ChainRulesCore.@scalar_rule myminus2(x, y) (true, true) # frule for `x - y`
let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64})
@test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64})
@test self_minus2′(1.0) == 2.
end

@testset "structs" begin
Expand All @@ -43,4 +59,4 @@ module stage2_fwd
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10f0, (1.0,)))
end
end
end

0 comments on commit 3d5bee0

Please sign in to comment.