Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward mode mutable struct support #219

Merged
merged 23 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8a177b2
WIP: begin switching forward mode over to zero_bundles for mutation s…
oxinabox Sep 28, 2023
1bffce5
Handle first order constructors
oxinabox Sep 28, 2023
15230b8
make 2nd order mutation not error
oxinabox Sep 29, 2023
8652053
fix 2nd order test
oxinabox Oct 5, 2023
1d91273
add extra second order test
oxinabox Oct 5, 2023
67a15da
Fix closures
oxinabox Oct 6, 2023
7d481d4
remove debug statements
oxinabox Dec 29, 2023
6ae4e13
update for more aggressive iszero and take NoTangent more seriously
oxinabox Dec 29, 2023
b742b3e
type of type erasure in ∂☆new
oxinabox Oct 16, 2023
e41490f
tests for type of type
oxinabox Oct 17, 2023
d0a825b
correct test
oxinabox Jan 2, 2024
f47d5f0
Comment out reverse tests that are making it hang on nigthtly
oxinabox Jan 5, 2024
7c7c628
fix quotenodes not to error
oxinabox Jan 12, 2024
2b79ef3
Correct code for higher order ∂☆new
oxinabox Jan 16, 2024
55eb407
more extra rules for static arrays
oxinabox Jan 19, 2024
b9940e3
if all partials AbstractZero don't call frule
oxinabox Jan 23, 2024
d4db011
Remove unrelated confusing comment
oxinabox Jan 24, 2024
6f57657
Remove case in frule for purely AbstractZero tangent
oxinabox Jan 24, 2024
00224ca
Require version of ChainRulesCore with MutableTangent etc
oxinabox Jan 29, 2024
a61d05b
Remove outdated comment
oxinabox Feb 1, 2024
c9fb55a
Update comment
oxinabox Feb 1, 2024
0cfbc41
Add comment
oxinabox Feb 1, 2024
8badafd
Mark tests that broke in 1.11 only broken in 1.11
oxinabox Feb 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
[compat]
AbstractDifferentiation = "0.5"
ChainRules = "1.44.6"
ChainRulesCore = "1.15.3"
ChainRulesCore = "1.20"
Combinatorics = "1"
Cthulhu = "2"
OffsetArrays = "1"
Expand Down
6 changes: 3 additions & 3 deletions src/codegen/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function fwd_transform!(ci, mi, nargs, N)
args = map(stmt.args) do stmt
emit!(mapstmt!(stmt))
end
return Expr(:call, Core._apply_iterate, FwdIterate(ZeroBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
elseif isa(stmt, SSAValue)
return SSAValue(ssa_mapping[stmt.id])
elseif isa(stmt, Core.SlotNumber)
Expand Down Expand Up @@ -64,14 +64,14 @@ function fwd_transform!(ci, mi, nargs, N)
# Always disable `@inbounds`, as we don't actually know if the AD'd
# code is truly `@inbounds` or not.
elseif isexpr(stmt, :boundscheck)
return ZeroBundle{N}(true)
return DNEBundle{N}(true)
else
# Fallback case, for literals.
# If it is an Expr, then it is not a literal
if isa(stmt, Expr)
error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt")
end
return Expr(:call, ZeroBundle{N}, stmt)
return Expr(:call, zero_bundle{N}(), stmt)
end
end

Expand Down
12 changes: 6 additions & 6 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,12 @@
return transform!(ir, arg, order, maparg)
elseif isa(arg, GlobalRef)
@assert isconst(arg)
return ZeroBundle{order}(getfield(arg.mod, arg.name))
return zero_bundle{order}()(getfield(arg.mod, arg.name))
elseif isa(arg, QuoteNode)
return ZeroBundle{order}(arg.value)
return zero_bundle{order}()(arg.value)

Check warning on line 269 in src/codegen/forward_demand.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/forward_demand.jl#L269

Added line #L269 was not covered by tests
end
@assert !isa(arg, Expr)
return ZeroBundle{order}(arg)
return zero_bundle{order}()(arg)
end

for (ssa, (order, custom)) in enumerate(ssa_orders)
Expand Down Expand Up @@ -309,7 +309,7 @@
stmt = insert_node!(ir, ssa, NewInstruction(inst))
end

replace_call!(ir, SSAValue(ssa), Expr(:call, ZeroBundle{order}, stmt))
replace_call!(ir, SSAValue(ssa), Expr(:call, zero_bundle{order}(), stmt))

Check warning on line 312 in src/codegen/forward_demand.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/forward_demand.jl#L312

Added line #L312 was not covered by tests
elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode)
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
inst[:type] = Any
Expand All @@ -329,7 +329,7 @@
inst[:type] = Any
inst[:flag] |= CC.IR_FLAG_REFINED
else
val = ZeroBundle{order}(inst[:inst])
val = zero_bundle{order}()(inst[:inst])

Check warning on line 332 in src/codegen/forward_demand.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/forward_demand.jl#L332

Added line #L332 was not covered by tests
inst[:inst] = val
inst[:type] = Const(val)
end
Expand Down Expand Up @@ -362,6 +362,6 @@
rt = CC._ir_abstract_constant_propagation(interp, irsv)

ir = compact!(ir)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

return ir
end
21 changes: 20 additions & 1 deletion src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,13 @@
end

function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x.backing)
Δx = SArray{S, T, N, L}(ChainRulesCore.backing(∂x))
SArray{S, T, N, L}(x), Δx

Check warning on line 176 in src/extra_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/extra_rules.jl#L175-L176

Added lines #L175 - L176 were not covered by tests
end

Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds)
Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind]

Check warning on line 180 in src/extra_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/extra_rules.jl#L179-L180

Added lines #L179 - L180 were not covered by tests

function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
SArray{S, T, N, L}(x), SArray{S}(∂x)
end
Expand Down Expand Up @@ -262,3 +266,18 @@

# Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495
ChainRulesCore._backing_error(P::Type{<:Base.Pairs}, G::Type{<:NamedTuple}, E::Type{<:AbstractDict}) = nothing

# Needed for higher order so we don't see the `backing` field of StructuralTangents, just the contents
# SHould these be in ChainRules/ChainRulesCore?
# is this always the right behavour, or just because of how we do higher order
function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getproperty), strct::StructuralTangent, sym::Union{Int,Symbol}, inbounds)
return (getproperty(strct, sym, inbounds), getproperty(Δ, sym))

Check warning on line 274 in src/extra_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/extra_rules.jl#L273-L274

Added lines #L273 - L274 were not covered by tests
end


function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setproperty!), obj::MutableTangent, field, x)
ȯbj::MutableTangent
y = setproperty!(obj, field, x)
ẏ = setproperty!(ȯbj, field, ẋ)
return y, ẏ
end
10 changes: 5 additions & 5 deletions src/higher_fwd_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,27 @@

jeval(j, x) = j(x)
for f in (sin, cos, exp)
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N}
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N}
njet(Val{N}(), primal(fb), primal(x))(x)
end
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::ZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M}
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::AbstractZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M}

Check warning on line 25 in src/higher_fwd_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/higher_fwd_rules.jl#L25

Added line #L25 was not covered by tests
∂⃖ₙ(jeval, njet(Val{N+M}(), primal(fb), primal(x)), x)
end
end

# TODO: It's a bit embarassing that we need to write these out, but currently the
# compiler is not strong enough to automatically lift the frule. Let's hope we
# can delete these in the near future.
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
TaylorBundle{N}(primal(a) + primal(b),
map(+, a.tangent.coeffs, b.tangent.coeffs))
end

function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::ZeroBundle{N}) where {N}
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::AbstractZeroBundle{N}) where {N}
TaylorBundle{N}(primal(a) + primal(b), a.tangent.coeffs)
end

function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
TaylorBundle{N}(primal(a) - primal(b),
map(-, a.tangent.coeffs, b.tangent.coeffs))
end
Expand Down
2 changes: 1 addition & 1 deletion src/stage1/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end

n_getfield(∂ₙ::∂☆{N}, b::ATB{N}, x::Union{Symbol, Int}) where {N} = ∂ₙ(ZeroBundle{N}(getfield), b, ZeroBundle{N}(x))

function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
function (∂ₙ::∂☆{N})(zc::AbstractZeroBundle{N, typeof(copy)},
bc::ATB{N, <:Broadcasted}) where {N}
bc = ∂ₙ(ZeroBundle{N}(Broadcast.flatten), bc)
args = n_getfield(∂ₙ, bc, :args)
Expand Down
47 changes: 29 additions & 18 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,33 @@

function shuffle_base(r)
(primal, dual) = r
if isa(dual, Union{NoTangent, ZeroTangent})
if dual isa NoTangent
UniformBundle{1}(primal, dual)
else
if dual isa ZeroTangent # Normalize zero for type-stability reasons
dual = zero_tangent(primal)
end
TaylorBundle{1}(primal, (dual,))
end
end

function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
r = frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...)
r = _frule(map(first_partial, args), map(primal, args)...)
if r === nothing
return ∂☆recurse{1}()(args...)
else
return shuffle_base(r)
end
end

_frule(partials, primals...) = frule(DiffractorRuleConfig(), partials, primals...)
function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...)
# frules are linear in partials, so zero maps to zero, no need to evaluate the frule
# If all partials are immutable AbstractZero subtyoes we know we don't have to worry about a mutating frule either
r = f(primal_args...)
return r, zero_tangent(r)
end

function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args)
result = ∂☆internal{1}()(bundles...)
Expand All @@ -131,12 +142,12 @@
function (::∂☆internal{N})(f::AbstractZeroBundle{N}, args::AbstractZeroBundle{N}...) where {N}
f_v = primal(f)
args_v = map(primal, args)
return ZeroBundle{N}(f_v(args_v...))
return zero_bundle{N}()(f_v(args_v...))
end
function (::∂☆internal{1})(f::AbstractZeroBundle{1}, args::AbstractZeroBundle{1}...)
f_v = primal(f)
args_v = map(primal, args)
return ZeroBundle{1}(f_v(args_v...))
return zero_bundle{1}()(f_v(args_v...))
end

function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}
Expand Down Expand Up @@ -193,25 +204,25 @@
end
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)

function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
∂vararg{N}()(map(FwdMap(f), destructure(tup))...)
end

function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
# TODO: This could do an inplace map! to avoid the extra rebundling
rebundle(map(FwdMap(f), map(unbundle, args)...))
end

function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N}
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N}
∂☆recurse{N}()(ZeroBundle{N, typeof(map)}(map), f, args...)
end


function (::∂☆{N})(f::ZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}

Check warning on line 221 in src/stage1/forward.jl

View check run for this annotation

Codecov / codecov/patch

src/stage1/forward.jl#L221

Added line #L221 was not covered by tests
ifelse(arg.primal, args...)
end

function (::∂☆{N})(f::ZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}

Check warning on line 225 in src/stage1/forward.jl

View check run for this annotation

Codecov / codecov/patch

src/stage1/forward.jl#L225

Added line #L225 was not covered by tests
Core.ifelse(arg.primal, args...)
end

Expand All @@ -233,48 +244,48 @@
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
end

function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
end


function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N}
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N}
r = iterate(destructure(t))
r === nothing && return ZeroBundle{N}(nothing)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
r = iterate(destructure(t), primal(a), map(primal, args)...)
r === nothing && return ZeroBundle{N}(nothing)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N}
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N}
r = Base.indexed_iterate(destructure(t), primal(i))
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N}
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N}

Check warning on line 274 in src/stage1/forward.jl

View check run for this annotation

Codecov / codecov/patch

src/stage1/forward.jl#L274

Added line #L274 was not covered by tests
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
end

function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N}
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::AbstractZeroBundle) where {N}
field_ind = primal(i)
the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N)
TaylorBundle{N}(primal(t)[field_ind], the_partials)
end

function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
DNEBundle{N}(typeof(primal(x)))
end

function (this::∂☆{N})(f::ZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N}
function (this::∂☆{N})(f::AbstractZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N}
ff = primal(f)
if ff === Base.not_int
DNEBundle{N}(ff(map(primal, args)...))
Expand Down
2 changes: 1 addition & 1 deletion src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ end
lifted_getfield(x::ZeroTangent, s) = ZeroTangent()
lifted_getfield(x::NoTangent, s) = NoTangent()

lifted_getfield(x::Tangent, s) = getproperty(x, s)
lifted_getfield(x::StructuralTangent, s) = getproperty(x, s)

function lifted_getfield(x::Tangent{<:Tangent{T}}, s) where T
bb = getfield(x.backing, 1)
Expand Down
10 changes: 5 additions & 5 deletions src/stage1/mixed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,26 @@
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
end

function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
end
=#

function (this::∂⃖{N})(that::∂☆{M}, ::ZeroBundle{M, typeof(Core._apply_iterate)},
function (this::∂⃖{N})(that::∂☆{M}, ::AbstractZeroBundle{M, typeof(Core._apply_iterate)},

Check warning on line 78 in src/stage1/mixed.jl

View check run for this annotation

Codecov / codecov/patch

src/stage1/mixed.jl#L78

Added line #L78 was not covered by tests
iterate, f, args::ATB{M, <:Tuple}...) where {N, M}
@assert primal(iterate) === Base.iterate
x, ∂⃖f = Core._apply_iterate(FwdIterate(iterate), this, (that, f), args...)
return x, ApplyOdd{1, c_order(N)}(UnApply{map(x->length(primal(x)), args)}(), ∂⃖f)
end


function ChainRules.rrule(∂::∂☆{N}, m::ZeroBundle{N, typeof(map)}, p::ZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N}
function ChainRules.rrule(∂::∂☆{N}, m::AbstractZeroBundle{N, typeof(map)}, p::AbstractZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N}

Check warning on line 86 in src/stage1/mixed.jl

View check run for this annotation

Codecov / codecov/patch

src/stage1/mixed.jl#L86

Added line #L86 was not covered by tests
∂(m, p, A, B), Δ->(NoTangent(), NoTangent(), NoTangent(), Δ, Δ)
end

mapev_unbundled(_, js, a) = rebundle(mapev(js, unbundle(a)))
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::ZeroBundle{M, typeof(map)},
f::ZeroBundle{M}, a::ATB{M, <:Array}) where {N, M}
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::AbstractZeroBundle{M, typeof(map)},

Check warning on line 91 in src/stage1/mixed.jl

View check run for this annotation

Codecov / codecov/patch

src/stage1/mixed.jl#L91

Added line #L91 was not covered by tests
f::AbstractZeroBundle{M}, a::ATB{M, <:Array}) where {N, M}
@assert Base.issingletontype(typeof(primal(f)))
js = map(primal(a)) do x
∂f = ∂☆{N+M}()(ZeroBundle{N+M}(primal(f)),
Expand Down
16 changes: 8 additions & 8 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,31 @@ struct ∂☆new{N}; end
function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...)
primal_args = map(primal, xs)
the_primal = _construct(B, primal_args)

tangent_tup = map(first_partial, xs)
the_partial = if B<:Tuple
Tangent{B, typeof(tangent_tup)}(tangent_tup)
else
names = fieldnames(B)
tangent_nt = NamedTuple{names}(tangent_tup)
Tangent{B, typeof(tangent_nt)}(tangent_nt)
StructuralTangent{B}(tangent_nt)
end
return TaylorBundle{1, B}(the_primal, (the_partial,))
B2 = typeof(the_primal) # HACK: if the_primal actually has types in it then we want to make sure we get DataType not Type(...)
return TaylorBundle{1, B2}(the_primal, (the_partial,))
end

function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N}
primal_args = map(primal, xs)
the_primal = _construct(B, primal_args)

the_partials = ntuple(Val{N}()) do ii
iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking
tangent_tup = map(x->partial(x, ii), xs)
tangent = if B<:Tuple
Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup)
Tangent{B, typeof(tangent_tup)}(tangent_tup)
else
# No matter the order we use `StructuralTangent{B}` for the partial
# It follows all required properties of the tangent to the n-1th order tangent
names = fieldnames(B)
tangent_nt = NamedTuple{names}(tangent_tup)
Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt)
StructuralTangent{B}(tangent_nt)
end
return tangent
end
Expand All @@ -50,7 +50,7 @@ _construct(::Type{B}, args) where B<:Tuple = B(args)
# Hack for making things that do not have public constructors constructable:
@generated _construct(B::Type, args) = Expr(:splatnew, :B, :args)

@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B))))
@generated (::∂☆new{N})(B::Type) where {N} = return :(zero_bundle{$N}()($(Expr(:new, :B))))

# Sometimes we don't know whether or not we need to the ZeroBundle when doing
# the transform, so this can happen - allow it for now.
Expand Down
Loading
Loading