Skip to content

Commit

Permalink
more extra rules for static arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jan 19, 2024
1 parent e1c7c7e commit b642a9f
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,13 @@ function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x:
end

function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x.backing)
#TODO: we really shouldn't actually see the isa(∂x, AbstractZero) case since the frule should be called then
Δx = isa(∂x, AbstractZero) ? ∂x : SArray{S, T, N, L}(ChainRulesCore.backing(∂x))
SArray{S, T, N, L}(x), Δx
end

Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds)

function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
SArray{S, T, N, L}(x), SArray{S}(∂x)
end
Expand Down

0 comments on commit b642a9f

Please sign in to comment.