From 0524b0886882245e091caffe1d3a38d3cfaa88a3 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 19:48:41 +0800 Subject: [PATCH] if all partials AbstractZero don't call frule --- src/stage1/forward.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 14ddfc55..a699caee 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -109,7 +109,7 @@ function shuffle_base(r) end function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) - r = frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...) + r = _frule(map(first_partial, args), map(primal, args)...) if r === nothing return ∂☆recurse{1}()(args...) else @@ -117,6 +117,14 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) end end +_frule(partials, primals...) = frule(DiffractorRuleConfig(), partials, primals...) +function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...) + # frules are linear in partials, so zero maps to zero, no need to evaluate the frule + # If all partials are immutable AbstractZero subtyoes we know we don't have to worry about a mutating frule either + r = f(primal_args...) + return r, zero_tangent(r) +end + function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...) bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args) result = ∂☆internal{1}()(bundles...)