Skip to content

Commit

Permalink
if all partials AbstractZero don't call frule
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jan 23, 2024
1 parent e9c1348 commit 0524b08
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,22 @@ function shuffle_base(r)
end

function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
r = frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...)
r = _frule(map(first_partial, args), map(primal, args)...)
if r === nothing
return ∂☆recurse{1}()(args...)
else
return shuffle_base(r)
end
end

_frule(partials, primals...) = frule(DiffractorRuleConfig(), partials, primals...)
function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...)
# frules are linear in partials, so zero maps to zero, no need to evaluate the frule
# If all partials are immutable AbstractZero subtyoes we know we don't have to worry about a mutating frule either
r = f(primal_args...)
return r, zero_tangent(r)
end

function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args)
result = ∂☆internal{1}()(bundles...)
Expand Down

0 comments on commit 0524b08

Please sign in to comment.