diff --git a/src/extra_rules.jl b/src/extra_rules.jl index b2622695..7acfeb85 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -172,9 +172,14 @@ function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x: end function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L} - SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x.backing) + #TODO: we really shouldn't actually see the isa(∂x, AbstractZero) case since the frule should be called then + Δx = isa(∂x, AbstractZero) ? ∂x : SArray{S, T, N, L}(ChainRulesCore.backing(∂x)) + SArray{S, T, N, L}(x), Δx end +Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds) +Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind] + function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L} SArray{S, T, N, L}(x), SArray{S}(∂x) end