diff --git a/Project.toml b/Project.toml index 1d9ae724..8aa9aadd 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] AbstractDifferentiation = "0.5" ChainRules = "1.44.6" -ChainRulesCore = "1.15.3" +ChainRulesCore = "1.20" Combinatorics = "1" Cthulhu = "2" OffsetArrays = "1" diff --git a/src/codegen/forward.jl b/src/codegen/forward.jl index 2df865da..e8f5bcdd 100644 --- a/src/codegen/forward.jl +++ b/src/codegen/forward.jl @@ -35,7 +35,7 @@ function fwd_transform!(ci, mi, nargs, N) args = map(stmt.args) do stmt emit!(mapstmt!(stmt)) end - return Expr(:call, Core._apply_iterate, FwdIterate(ZeroBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) + return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) elseif isa(stmt, SSAValue) return SSAValue(ssa_mapping[stmt.id]) elseif isa(stmt, Core.SlotNumber) @@ -64,14 +64,14 @@ function fwd_transform!(ci, mi, nargs, N) # Always disable `@inbounds`, as we don't actually know if the AD'd # code is truly `@inbounds` or not. elseif isexpr(stmt, :boundscheck) - return ZeroBundle{N}(true) + return DNEBundle{N}(true) else # Fallback case, for literals. # If it is an Expr, then it is not a literal if isa(stmt, Expr) error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt") end - return Expr(:call, ZeroBundle{N}, stmt) + return Expr(:call, zero_bundle{N}(), stmt) end end diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 169fba90..74c78694 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -264,12 +264,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; return transform!(ir, arg, order, maparg) elseif isa(arg, GlobalRef) @assert isconst(arg) - return ZeroBundle{order}(getfield(arg.mod, arg.name)) + return zero_bundle{order}()(getfield(arg.mod, arg.name)) elseif isa(arg, QuoteNode) - return ZeroBundle{order}(arg.value) + return zero_bundle{order}()(arg.value) end @assert !isa(arg, Expr) - return ZeroBundle{order}(arg) + return zero_bundle{order}()(arg) end for (ssa, (order, custom)) in enumerate(ssa_orders) @@ -309,7 +309,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; stmt = insert_node!(ir, ssa, NewInstruction(inst)) end - replace_call!(ir, SSAValue(ssa), Expr(:call, ZeroBundle{order}, stmt)) + replace_call!(ir, SSAValue(ssa), Expr(:call, zero_bundle{order}(), stmt)) elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode) inst[:inst] = maparg(stmt, SSAValue(ssa), order) inst[:type] = Any @@ -329,7 +329,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; inst[:type] = Any inst[:flag] |= CC.IR_FLAG_REFINED else - val = ZeroBundle{order}(inst[:inst]) + val = zero_bundle{order}()(inst[:inst]) inst[:inst] = val inst[:type] = Const(val) end @@ -362,6 +362,6 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met rt = CC._ir_abstract_constant_propagation(interp, irsv) ir = compact!(ir) - + return ir end diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 0b0e8b51..b9bcff7e 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -172,9 +172,13 @@ function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x: end function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L} - SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x.backing) + Δx = SArray{S, T, N, L}(ChainRulesCore.backing(∂x)) + SArray{S, T, N, L}(x), Δx end +Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds) +Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind] + function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L} SArray{S, T, N, L}(x), SArray{S}(∂x) end @@ -262,3 +266,18 @@ Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDi # Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495 ChainRulesCore._backing_error(P::Type{<:Base.Pairs}, G::Type{<:NamedTuple}, E::Type{<:AbstractDict}) = nothing + +# Needed for higher order so we don't see the `backing` field of StructuralTangents, just the contents +# SHould these be in ChainRules/ChainRulesCore? +# is this always the right behavour, or just because of how we do higher order +function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getproperty), strct::StructuralTangent, sym::Union{Int,Symbol}, inbounds) + return (getproperty(strct, sym, inbounds), getproperty(Δ, sym)) +end + + +function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setproperty!), obj::MutableTangent, field, x) + ȯbj::MutableTangent + y = setproperty!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ +end diff --git a/src/higher_fwd_rules.jl b/src/higher_fwd_rules.jl index d67a44e1..8486b8bd 100644 --- a/src/higher_fwd_rules.jl +++ b/src/higher_fwd_rules.jl @@ -19,10 +19,10 @@ end jeval(j, x) = j(x) for f in (sin, cos, exp) - function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N} + function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N} njet(Val{N}(), primal(fb), primal(x))(x) end - function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::ZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M} + function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::AbstractZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M} ∂⃖ₙ(jeval, njet(Val{N+M}(), primal(fb), primal(x)), x) end end @@ -30,16 +30,16 @@ end # TODO: It's a bit embarassing that we need to write these out, but currently the # compiler is not strong enough to automatically lift the frule. Let's hope we # can delete these in the near future. -function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} +function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} TaylorBundle{N}(primal(a) + primal(b), map(+, a.tangent.coeffs, b.tangent.coeffs)) end -function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::ZeroBundle{N}) where {N} +function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::AbstractZeroBundle{N}) where {N} TaylorBundle{N}(primal(a) + primal(b), a.tangent.coeffs) end -function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} +function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} TaylorBundle{N}(primal(a) - primal(b), map(-, a.tangent.coeffs, b.tangent.coeffs)) end diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index e04bd60a..e32dcfef 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -9,7 +9,7 @@ end n_getfield(∂ₙ::∂☆{N}, b::ATB{N}, x::Union{Symbol, Int}) where {N} = ∂ₙ(ZeroBundle{N}(getfield), b, ZeroBundle{N}(x)) -function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)}, +function (∂ₙ::∂☆{N})(zc::AbstractZeroBundle{N, typeof(copy)}, bc::ATB{N, <:Broadcasted}) where {N} bc = ∂ₙ(ZeroBundle{N}(Broadcast.flatten), bc) args = n_getfield(∂ₙ, bc, :args) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 0279ffba..a699caee 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -98,15 +98,18 @@ struct ∂☆shuffle{N}; end function shuffle_base(r) (primal, dual) = r - if isa(dual, Union{NoTangent, ZeroTangent}) + if dual isa NoTangent UniformBundle{1}(primal, dual) else + if dual isa ZeroTangent # Normalize zero for type-stability reasons + dual = zero_tangent(primal) + end TaylorBundle{1}(primal, (dual,)) end end function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) - r = frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...) + r = _frule(map(first_partial, args), map(primal, args)...) if r === nothing return ∂☆recurse{1}()(args...) else @@ -114,6 +117,14 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) end end +_frule(partials, primals...) = frule(DiffractorRuleConfig(), partials, primals...) +function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...) + # frules are linear in partials, so zero maps to zero, no need to evaluate the frule + # If all partials are immutable AbstractZero subtyoes we know we don't have to worry about a mutating frule either + r = f(primal_args...) + return r, zero_tangent(r) +end + function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...) bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args) result = ∂☆internal{1}()(bundles...) @@ -131,12 +142,12 @@ end function (::∂☆internal{N})(f::AbstractZeroBundle{N}, args::AbstractZeroBundle{N}...) where {N} f_v = primal(f) args_v = map(primal, args) - return ZeroBundle{N}(f_v(args_v...)) + return zero_bundle{N}()(f_v(args_v...)) end function (::∂☆internal{1})(f::AbstractZeroBundle{1}, args::AbstractZeroBundle{1}...) f_v = primal(f) args_v = map(primal, args) - return ZeroBundle{1}(f_v(args_v...)) + return zero_bundle{1}()(f_v(args_v...)) end function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N} @@ -193,25 +204,25 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}} end (f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...) -function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N} +function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N} ∂vararg{N}()(map(FwdMap(f), destructure(tup))...) end -function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N} +function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N} # TODO: This could do an inplace map! to avoid the extra rebundling rebundle(map(FwdMap(f), map(unbundle, args)...)) end -function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N} +function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N} ∂☆recurse{N}()(ZeroBundle{N, typeof(map)}(map), f, args...) end -function (::∂☆{N})(f::ZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N} +function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N} ifelse(arg.primal, args...) end -function (::∂☆{N})(f::ZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N} +function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N} Core.ifelse(arg.primal, args...) end @@ -233,48 +244,48 @@ end primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2)))) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N} Core._apply_iterate(FwdIterate(iterate), this, (f,), args...) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N} r = iterate(destructure(t)) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} r = iterate(destructure(t), primal(a), map(primal, args)...) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N} r = Base.indexed_iterate(destructure(t), primal(i)) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N} ∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1)) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::AbstractZeroBundle) where {N} field_ind = primal(i) the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N) TaylorBundle{N}(primal(t)[field_ind], the_partials) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N} DNEBundle{N}(typeof(primal(x))) end -function (this::∂☆{N})(f::ZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N} +function (this::∂☆{N})(f::AbstractZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N} ff = primal(f) if ff === Base.not_int DNEBundle{N}(ff(map(primal, args)...)) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 278be4a2..73ff47e7 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -390,7 +390,7 @@ end lifted_getfield(x::ZeroTangent, s) = ZeroTangent() lifted_getfield(x::NoTangent, s) = NoTangent() -lifted_getfield(x::Tangent, s) = getproperty(x, s) +lifted_getfield(x::StructuralTangent, s) = getproperty(x, s) function lifted_getfield(x::Tangent{<:Tangent{T}}, s) where T bb = getfield(x.backing, 1) diff --git a/src/stage1/mixed.jl b/src/stage1/mixed.jl index a1d30a86..874f73b4 100644 --- a/src/stage1/mixed.jl +++ b/src/stage1/mixed.jl @@ -70,12 +70,12 @@ function (f::FwdIterate)(arg::ATB{N}, st) where {N} primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2)))) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N} Core._apply_iterate(FwdIterate(iterate), this, (f,), args...) end =# -function (this::∂⃖{N})(that::∂☆{M}, ::ZeroBundle{M, typeof(Core._apply_iterate)}, +function (this::∂⃖{N})(that::∂☆{M}, ::AbstractZeroBundle{M, typeof(Core._apply_iterate)}, iterate, f, args::ATB{M, <:Tuple}...) where {N, M} @assert primal(iterate) === Base.iterate x, ∂⃖f = Core._apply_iterate(FwdIterate(iterate), this, (that, f), args...) @@ -83,13 +83,13 @@ function (this::∂⃖{N})(that::∂☆{M}, ::ZeroBundle{M, typeof(Core._apply_i end -function ChainRules.rrule(∂::∂☆{N}, m::ZeroBundle{N, typeof(map)}, p::ZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N} +function ChainRules.rrule(∂::∂☆{N}, m::AbstractZeroBundle{N, typeof(map)}, p::AbstractZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N} ∂(m, p, A, B), Δ->(NoTangent(), NoTangent(), NoTangent(), Δ, Δ) end mapev_unbundled(_, js, a) = rebundle(mapev(js, unbundle(a))) -function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::ZeroBundle{M, typeof(map)}, - f::ZeroBundle{M}, a::ATB{M, <:Array}) where {N, M} +function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::AbstractZeroBundle{M, typeof(map)}, + f::AbstractZeroBundle{M}, a::ATB{M, <:Array}) where {N, M} @assert Base.issingletontype(typeof(primal(f))) js = map(primal(a)) do x ∂f = ∂☆{N+M}()(ZeroBundle{N+M}(primal(f)), diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 2c561e73..fa8a99fe 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -15,31 +15,31 @@ struct ∂☆new{N}; end function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) primal_args = map(primal, xs) the_primal = _construct(B, primal_args) - tangent_tup = map(first_partial, xs) the_partial = if B<:Tuple Tangent{B, typeof(tangent_tup)}(tangent_tup) else names = fieldnames(B) tangent_nt = NamedTuple{names}(tangent_tup) - Tangent{B, typeof(tangent_nt)}(tangent_nt) + StructuralTangent{B}(tangent_nt) end - return TaylorBundle{1, B}(the_primal, (the_partial,)) + B2 = typeof(the_primal) # HACK: if the_primal actually has types in it then we want to make sure we get DataType not Type(...) + return TaylorBundle{1, B2}(the_primal, (the_partial,)) end function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} primal_args = map(primal, xs) the_primal = _construct(B, primal_args) - the_partials = ntuple(Val{N}()) do ii - iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking tangent_tup = map(x->partial(x, ii), xs) tangent = if B<:Tuple - Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup) + Tangent{B, typeof(tangent_tup)}(tangent_tup) else + # No matter the order we use `StructuralTangent{B}` for the partial + # It follows all required properties of the tangent to the n-1th order tangent names = fieldnames(B) tangent_nt = NamedTuple{names}(tangent_tup) - Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt) + StructuralTangent{B}(tangent_nt) end return tangent end @@ -50,7 +50,7 @@ _construct(::Type{B}, args) where B<:Tuple = B(args) # Hack for making things that do not have public constructors constructable: @generated _construct(B::Type, args) = Expr(:splatnew, :B, :args) -@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B)))) +@generated (::∂☆new{N})(B::Type) where {N} = return :(zero_bundle{$N}()($(Expr(:new, :B)))) # Sometimes we don't know whether or not we need to the ZeroBundle when doing # the transform, so this can happen - allow it for now. diff --git a/src/tangent.jl b/src/tangent.jl index 5bad9e29..146cd8cf 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -418,3 +418,21 @@ end function ChainRulesCore.rrule(::typeof(rebundle), atb) rebundle(atb), Δ->throw(Δ) end + + +""" + (::zero_bundle{N})(primal) + +Creates a bundle with a zero tangent. +""" +struct zero_bundle{N} end +function (::zero_bundle{N})(primal) where N + # We still use a Uniform bundle e.g. if primal has NoTangent + if zero_tangent(primal) isa AbstractZero + return UniformBundle{N}(primal, zero_tangent(primal) ) + else + # Note: it is important that zero_tangent(primal) is called in ntuple + # so it gets distrinct values for each order, so it doesn't alias if mutated. + return TaylorBundle{N}(primal, ntuple(_->zero_tangent(primal), N)) + end +end \ No newline at end of file diff --git a/test/forward.jl b/test/forward.jl index dec11d97..f8040639 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -161,6 +161,18 @@ end end +@testset "types in tuples" begin + function foo(a) + tup = (a, 2a, Int) + return tup[2] + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + end +end + + @testset "taylor_compatible" begin taylor_compatible = Diffractor.taylor_compatible diff --git a/test/forward_mutation.jl b/test/forward_mutation.jl new file mode 100644 index 00000000..04e6f0ff --- /dev/null +++ b/test/forward_mutation.jl @@ -0,0 +1,87 @@ +# module forward_mutation +using Diffractor +using Diffractor: ∂☆, ZeroBundle, TaylorBundle +using Diffractor: bundle, first_partial, TaylorTangentIndex +using ChainRulesCore +using Test + + +mutable struct MDemo1 + x::Float64 +end + +@testset "construction" begin + 🍞 = ∂☆{1}()(ZeroBundle{1}(MDemo1), TaylorBundle{1}(1.5, (1.0,))) + @test 🍞[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} + @test 🍞[TaylorTangentIndex(1)].x == 1.0 + + 🥯 = ∂☆{2}()(ZeroBundle{2}(MDemo1), TaylorBundle{2}(1.5, (1.0, 10.0))) + @test 🥯[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} + @test 🥯[TaylorTangentIndex(1)].x == 1.0 + @test 🥯[TaylorTangentIndex(2)] isa MutableTangent + @test 🥯[TaylorTangentIndex(2)].x == 10.0 +end + +@testset "basis struct work: double" begin + function double!(val::MDemo1) + val.x *= 2.0 + return val + end + function wrap_and_double(x) + val = MDemo1(x) + double!(val) + end + # first derivative + 🐰 = ∂☆{1}()(ZeroBundle{1}(wrap_and_double), TaylorBundle{1}(1.5, (1.0,))) + @test first_partial(🐰) isa MutableTangent{MDemo1} + @test first_partial(🐰).x == 2.0 + + # second derivative + 🐇 = ∂☆{2}()(ZeroBundle{2}(wrap_and_double), TaylorBundle{2}(1.5, (1.0, 0.0))) + @test 🐇[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} + @test 🐇[TaylorTangentIndex(1)].x == 2.0 + @test 🐇[TaylorTangentIndex(2)] isa MutableTangent + @test 🐇[TaylorTangentIndex(2)].x == 0.0 +end + +@testset "basis struct work: square" begin + function square!(val::MDemo1) + val.x ^= 2.0 + return val + end + function wrap_and_square(x) + val = MDemo1(x) + square!(val) + end + # first derivative + 🐰 = ∂☆{1}()(ZeroBundle{1}(wrap_and_square), TaylorBundle{1}(10.0, (1.0,))) + @test first_partial(🐰) isa MutableTangent{MDemo1} + @test first_partial(🐰).x == 20.0 + + # second derivative + 🐇 = ∂☆{2}()(ZeroBundle{2}(wrap_and_square), TaylorBundle{2}(100.0, (1.0, 0.0))) + @test 🐇[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} + @test 🐇[TaylorTangentIndex(1)].x == 200.0 + @test 🐇[TaylorTangentIndex(2)] isa MutableTangent + @test 🐇[TaylorTangentIndex(2)].x == 2.0 +end + +@testset "closure" begin + function bar(x) + z = x + 1.0 + function foo!(y) + z = z * y + return z + end + + foo!(2) + foo!(2) + return z + end + + 🥯 = ∂☆{1}()(ZeroBundle{1}(bar), TaylorBundle{1}(10.0, (1.0,))) + @test 🥯[TaylorTangentIndex(1)] == 4.0 +end + + +# end # module \ No newline at end of file diff --git a/test/reverse.jl b/test/reverse.jl index 4e81b921..f525b10e 100644 --- a/test/reverse.jl +++ b/test/reverse.jl @@ -98,14 +98,16 @@ let var"'" = Diffractor.PrimeDerivativeBack # TODO This currently causes a segfault, c.f. https://github.com/JuliaLang/julia/pull/48742 # @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) broken=true - # Control flow cases + # Control flow cases: + # if @test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0) @test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0) @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;] - @test times_three_while'(1.0) == 3.0 - + + # while + # @test times_three_while'(1.0) == 3.0 # hangs in 1.11 pow5p(x) = (x->mypow(x, 5))'(x) - @test pow5p(1.0) == 5.0 + #@test pow5p(1.0) == 5.0 # hangs in 1.11 end end diff --git a/test/runtests.jl b/test/runtests.jl index 8b75d5fc..847f3cce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,11 +18,13 @@ const bwd = Diffractor.PrimeDerivativeBack "tangent.jl", "forward_diff_no_inf.jl", "forward.jl", + "forward_mutation.jl", "reverse.jl", "regression.jl", "AbstractDifferentiationTests.jl" #"pinn.jl", # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) ) + @info "testing" file include(file) end diff --git a/test/stage2_fwd.jl b/test/stage2_fwd.jl index 987fbb3f..6bc48d08 100644 --- a/test/stage2_fwd.jl +++ b/test/stage2_fwd.jl @@ -6,7 +6,8 @@ module stage2_fwd @test sin′(1.0) == cos(1.0) end let sin′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(mysin), Float64}, 2) - @test isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + # This broke some time between 1.10 and 1.11-DEV.10001 + @test isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test sin′′(1.0) == -sin(1.0) end @@ -14,24 +15,27 @@ module stage2_fwd self_minus(a) = myminus(a, a) ChainRulesCore.@scalar_rule myminus(x, y) (true, -1) let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}) - @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + # This broke some time between 1.10 and 1.11-DEV.10001 + @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test self_minus′(1.0) == 0. end ChainRulesCore.@scalar_rule myminus(x, y) (true, true) # frule for `x - y` let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}) - @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + # This broke some time between 1.10 and 1.11-DEV.10001 + @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test self_minus′(1.0) == 2. end myminus2(a, b) = a - b self_minus2(a) = myminus2(a, a) let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64}) - @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test self_minus2′(1.0) == 0. end ChainRulesCore.@scalar_rule myminus2(x, y) (true, true) # frule for `x - y` let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64}) - @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + # This broke some time between 1.10 and 1.11-DEV.10001 + @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test self_minus2′(1.0) == 2. end diff --git a/test/tangent.jl b/test/tangent.jl index dc271710..8bf2cac5 100644 --- a/test/tangent.jl +++ b/test/tangent.jl @@ -122,4 +122,19 @@ end @test truncate(et, Val(1)) == TaylorTangent((1.0,)) end +@testset "zero_bundle" begin + zero_bundle = Diffractor.zero_bundle + + tup_zb = zero_bundle{2}()((1, 0)) + @test tup_zb isa Diffractor.AbstractTangentBundle{2} + @test iszero(tup_zb[TaylorTangentIndex(1)]) + @test iszero(tup_zb[TaylorTangentIndex(2)]) + + + ref_zb = zero_bundle{2}()(Ref(1.5)) + @test ref_zb isa TaylorBundle{2} + @test iszero(ref_zb[TaylorTangentIndex(1)]) + @test iszero(ref_zb[TaylorTangentIndex(2)]) +end + end # module