From 8a177b2f513e9c22489325770f3d9eb3ac486216 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Sep 2023 16:58:47 +0800 Subject: [PATCH 01/23] WIP: begin switching forward mode over to zero_bundles for mutation support --- src/codegen/forward.jl | 2 +- src/codegen/forward_demand.jl | 10 +++++----- src/stage1/recurse_fwd.jl | 2 +- src/tangent.jl | 15 +++++++++++++++ test/forward_mutation.jl | 16 ++++++++++++++++ test/runtests.jl | 1 + test/tangent.jl | 15 +++++++++++++++ 7 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 test/forward_mutation.jl diff --git a/src/codegen/forward.jl b/src/codegen/forward.jl index 2df865da..1210cdfc 100644 --- a/src/codegen/forward.jl +++ b/src/codegen/forward.jl @@ -71,7 +71,7 @@ function fwd_transform!(ci, mi, nargs, N) if isa(stmt, Expr) error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt") end - return Expr(:call, ZeroBundle{N}, stmt) + return Expr(:call, zero_bundle{order}(), stmt) end end diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 169fba90..368d9b55 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -264,12 +264,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; return transform!(ir, arg, order, maparg) elseif isa(arg, GlobalRef) @assert isconst(arg) - return ZeroBundle{order}(getfield(arg.mod, arg.name)) + return zero_bundle{order}()(getfield(arg.mod, arg.name)) elseif isa(arg, QuoteNode) - return ZeroBundle{order}(arg.value) + return zero_bundle{order}(){order}(arg.value) end @assert !isa(arg, Expr) - return ZeroBundle{order}(arg) + return zero_bundle{order}()(arg) end for (ssa, (order, custom)) in enumerate(ssa_orders) @@ -309,7 +309,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; stmt = insert_node!(ir, ssa, NewInstruction(inst)) end - replace_call!(ir, SSAValue(ssa), Expr(:call, ZeroBundle{order}, stmt)) + replace_call!(ir, 
SSAValue(ssa), Expr(:call, zero_bundle{order}(), stmt)) elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode) inst[:inst] = maparg(stmt, SSAValue(ssa), order) inst[:type] = Any @@ -329,7 +329,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; inst[:type] = Any inst[:flag] |= CC.IR_FLAG_REFINED else - val = ZeroBundle{order}(inst[:inst]) + val = zero_bundle{order}()(inst[:inst]) inst[:inst] = val inst[:type] = Const(val) end diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 2c561e73..8630fd06 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -50,7 +50,7 @@ _construct(::Type{B}, args) where B<:Tuple = B(args) # Hack for making things that do not have public constructors constructable: @generated _construct(B::Type, args) = Expr(:splatnew, :B, :args) -@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B)))) +@generated (::∂☆new{N})(B::Type) where {N} = return :(zero_bundle{order}()($(Expr(:new, :B)))) # Sometimes we don't know whether or not we need to the ZeroBundle when doing # the transform, so this can happen - allow it for now. diff --git a/src/tangent.jl b/src/tangent.jl index 5bad9e29..fbc470d0 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -418,3 +418,18 @@ end function ChainRulesCore.rrule(::typeof(rebundle), atb) rebundle(atb), Δ->throw(Δ) end + + +""" + (::zero_bundle{N})(primal) + +Creates a bundle with a zero tangent. 
+""" +struct zero_bundle{N} end +function (::zero_bundle{N})(primal) where N + if zero_tangent(primal) isa ZeroTangent + return ZeroBundle{N}(primal) + else + return TaylorBundle{N}(primal, ntuple(_->zero_tangent(primal), N)) + end +end \ No newline at end of file diff --git a/test/forward_mutation.jl b/test/forward_mutation.jl new file mode 100644 index 00000000..4d250737 --- /dev/null +++ b/test/forward_mutation.jl @@ -0,0 +1,16 @@ +using Diffractor +using Diffractor: ∂☆ +using Diffractor: bundle + +mutable struct MDemo1 + x::Float64 +end +function double!(val::MDemo1) + val.x *= 2.0 + return val +end +function wrap_and_double(x) + val = MDemo1(x) + double!(val) +end +∂☆{1}()(ZeroBundle{1}(wrap_and_double), TaylorBundle{1}(1.5, (1.0,))) diff --git a/test/runtests.jl b/test/runtests.jl index 8b75d5fc..7c073a8f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,7 @@ const bwd = Diffractor.PrimeDerivativeBack "tangent.jl", "forward_diff_no_inf.jl", "forward.jl", + "forward_mutation.jl", "reverse.jl", "regression.jl", "AbstractDifferentiationTests.jl" diff --git a/test/tangent.jl b/test/tangent.jl index dc271710..5d09c9a3 100644 --- a/test/tangent.jl +++ b/test/tangent.jl @@ -122,4 +122,19 @@ end @test truncate(et, Val(1)) == TaylorTangent((1.0,)) end +@testset "zero_bundle" begin + zero_bundle = Diffractor.zero_bundle + + tup_zb = zero_bundle{2}()((1, 0)) + @test tup_zb isa ZeroBundle{2} + @test iszero(tup_zb[TaylorTangentIndex(1)]) + @test iszero(tup_zb[TaylorTangentIndex(2)]) + + + ref_zb = zero_bundle{2}()(Ref(1.5)) + @test ref_zb isa TaylorBundle{2} + @test iszero(ref_zb[TaylorTangentIndex(1)]) + @test iszero(ref_zb[TaylorTangentIndex(2)]) +end + end # module From 1bffce5f0eeee75e7a087ee609830e674ca971fe Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Sep 2023 18:17:08 +0800 Subject: [PATCH 02/23] Handle first order constructors --- src/codegen/forward.jl | 2 +- src/codegen/forward_demand.jl | 1 + src/stage1/generated.jl | 2 +- 
src/stage1/recurse_fwd.jl | 7 ++++--- test/forward_mutation.jl | 5 ++++- 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/codegen/forward.jl b/src/codegen/forward.jl index 1210cdfc..50fe54b9 100644 --- a/src/codegen/forward.jl +++ b/src/codegen/forward.jl @@ -71,7 +71,7 @@ function fwd_transform!(ci, mi, nargs, N) if isa(stmt, Expr) error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt") end - return Expr(:call, zero_bundle{order}(), stmt) + return Expr(:call, zero_bundle{N}(), stmt) end end diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 368d9b55..451c3412 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -105,6 +105,7 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I Δbacking = insert_node!(ir, ssa, NewInstruction(Expr(:splatnew, widenconst(tup_typ), Δbacking), tup_typ_typ.val)) end tangentT = Core.Compiler.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val + # TODO do we need to make sure this inserts right Δtangent = insert_node!(ir, ssa, NewInstruction(Expr(:new, tangentT, Δbacking), tangentT)) return Δtangent else # general frule handling diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 278be4a2..73ff47e7 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -390,7 +390,7 @@ end lifted_getfield(x::ZeroTangent, s) = ZeroTangent() lifted_getfield(x::NoTangent, s) = NoTangent() -lifted_getfield(x::Tangent, s) = getproperty(x, s) +lifted_getfield(x::StructuralTangent, s) = getproperty(x, s) function lifted_getfield(x::Tangent{<:Tangent{T}}, s) where T bb = getfield(x.backing, 1) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 8630fd06..40315ab7 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -15,14 +15,14 @@ struct ∂☆new{N}; end function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) 
primal_args = map(primal, xs) the_primal = _construct(B, primal_args) - + @info "∂☆new{1}" tangent_tup = map(first_partial, xs) the_partial = if B<:Tuple Tangent{B, typeof(tangent_tup)}(tangent_tup) else names = fieldnames(B) tangent_nt = NamedTuple{names}(tangent_tup) - Tangent{B, typeof(tangent_nt)}(tangent_nt) + StructuralTangent{B}(tangent_nt) end return TaylorBundle{1, B}(the_primal, (the_partial,)) end @@ -30,13 +30,14 @@ end function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} primal_args = map(primal, xs) the_primal = _construct(B, primal_args) - + @info "∂☆new{N}" the_partials = ntuple(Val{N}()) do ii iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking tangent_tup = map(x->partial(x, ii), xs) tangent = if B<:Tuple Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup) else + # TODO support mutation names = fieldnames(B) tangent_nt = NamedTuple{names}(tangent_tup) Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt) diff --git a/test/forward_mutation.jl b/test/forward_mutation.jl index 4d250737..d37fe676 100644 --- a/test/forward_mutation.jl +++ b/test/forward_mutation.jl @@ -1,10 +1,13 @@ using Diffractor -using Diffractor: ∂☆ +using Diffractor: ∂☆, ZeroBundle, TaylorBundle using Diffractor: bundle mutable struct MDemo1 x::Float64 end + +∂☆{1}()(ZeroBundle{1}(MDemo1), TaylorBundle{1}(1.5, (1.0,))) + function double!(val::MDemo1) val.x *= 2.0 return val From 15230b82bcf8ad2a76e57d72f3c5d7fb9fd84ab5 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Sep 2023 16:58:42 +0800 Subject: [PATCH 03/23] make 2nd order mutation not error --- src/extra_rules.jl | 15 +++++++++++++++ src/stage1/recurse_fwd.jl | 8 ++++---- test/forward_mutation.jl | 38 +++++++++++++++++++++++++++++++++++--- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 0b0e8b51..b2622695 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -262,3 +262,18 @@ 
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDi # Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495 ChainRulesCore._backing_error(P::Type{<:Base.Pairs}, G::Type{<:NamedTuple}, E::Type{<:AbstractDict}) = nothing + +# Needed for higher order so we don't see the `backing` field of StructuralTangents, just the contents +# SHould these be in ChainRules/ChainRulesCore? +# is this always the right behavour, or just because of how we do higher order +function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getproperty), strct::StructuralTangent, sym::Union{Int,Symbol}, inbounds) + return (getproperty(strct, sym, inbounds), getproperty(Δ, sym)) +end + + +function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setproperty!), obj::MutableTangent, field, x) + ȯbj::MutableTangent + y = setproperty!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ +end diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 40315ab7..13c51961 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -32,15 +32,15 @@ function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} the_primal = _construct(B, primal_args) @info "∂☆new{N}" the_partials = ntuple(Val{N}()) do ii - iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking tangent_tup = map(x->partial(x, ii), xs) tangent = if B<:Tuple - Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup) + Tangent{B, typeof(tangent_tup)}(tangent_tup) else - # TODO support mutation + # It is a little dubious using StructuralTangent{B} for >1st order, but it is isomorphic. + # Just watch out for order mixing bugs. 
names = fieldnames(B) tangent_nt = NamedTuple{names}(tangent_tup) - Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt) + StructuralTangent{B}(tangent_nt) end return tangent end diff --git a/test/forward_mutation.jl b/test/forward_mutation.jl index d37fe676..5a0fffb9 100644 --- a/test/forward_mutation.jl +++ b/test/forward_mutation.jl @@ -1,12 +1,26 @@ +# module forward_mutation using Diffractor using Diffractor: ∂☆, ZeroBundle, TaylorBundle -using Diffractor: bundle +using Diffractor: bundle, first_partial, TaylorTangentIndex +using ChainRulesCore +using Test + mutable struct MDemo1 x::Float64 end -∂☆{1}()(ZeroBundle{1}(MDemo1), TaylorBundle{1}(1.5, (1.0,))) +@testset "construction" begin + 🍞 = ∂☆{1}()(ZeroBundle{1}(MDemo1), TaylorBundle{1}(1.5, (1.0,))) + @test 🍞[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} + @test 🍞[TaylorTangentIndex(1)].x == 1.0 + + 🥯 = ∂☆{2}()(ZeroBundle{2}(MDemo1), TaylorBundle{2}(1.5, (1.0, 10.0))) + @test 🥯[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} + @test 🥯[TaylorTangentIndex(1)].x == 1.0 + @test 🥯[TaylorTangentIndex(2)] isa MutableTangent + @test 🥯[TaylorTangentIndex(2)].x == 10.0 +end function double!(val::MDemo1) val.x *= 2.0 @@ -16,4 +30,22 @@ function wrap_and_double(x) val = MDemo1(x) double!(val) end -∂☆{1}()(ZeroBundle{1}(wrap_and_double), TaylorBundle{1}(1.5, (1.0,))) +🐰 = ∂☆{1}()(ZeroBundle{1}(wrap_and_double), TaylorBundle{1}(1.5, (1.0,))) +@test first_partial(🐰) isa MutableTangent{MDemo1} +@test first_partial(🐰).x == 2.0 + +# second derivative +🐇 = ∂☆{2}()(ZeroBundle{2}(wrap_and_double), TaylorBundle{2}(1.5, (1.0, 10.0))) +@test 🐇[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} +@test 🐇[TaylorTangentIndex(1)].x == 2.0 +@test 🐇[TaylorTangentIndex(2)] isa MutableTangent +@test 🐇[TaylorTangentIndex(2)] == 0.0 # returns 20 + + + +foo(val) = val^2 +🥖 = ∂☆{2}()(ZeroBundle{2}(foo), TaylorBundle{2}(1.0, (0.0, 10.0))) +🥖[TaylorTangentIndex(1)] # returns 0 +🥖[TaylorTangentIndex(2)] # returns 20 + +# end # module \ No 
newline at end of file From 865205320b26a1667873291dd5daafc5d5619bcd Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 5 Oct 2023 17:59:39 +0800 Subject: [PATCH 04/23] fix 2nd order test --- test/forward_mutation.jl | 42 +++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/test/forward_mutation.jl b/test/forward_mutation.jl index 5a0fffb9..e809a258 100644 --- a/test/forward_mutation.jl +++ b/test/forward_mutation.jl @@ -22,30 +22,28 @@ end @test 🥯[TaylorTangentIndex(2)].x == 10.0 end -function double!(val::MDemo1) - val.x *= 2.0 - return val +@testset "basis struct work" begin + function double!(val::MDemo1) + val.x *= 2.0 + return val + end + function wrap_and_double(x) + val = MDemo1(x) + double!(val) + end + # first derivative + 🐰 = ∂☆{1}()(ZeroBundle{1}(wrap_and_double), TaylorBundle{1}(1.5, (1.0,))) + @test first_partial(🐰) isa MutableTangent{MDemo1} + @test first_partial(🐰).x == 2.0 + + # second derivative + 🐇 = ∂☆{2}()(ZeroBundle{2}(wrap_and_double), TaylorBundle{2}(1.5, (1.0, 0.0))) + @test 🐇[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} + @test 🐇[TaylorTangentIndex(1)].x == 2.0 + @test 🐇[TaylorTangentIndex(2)] isa MutableTangent + @test 🐇[TaylorTangentIndex(2)].x == 0.0 # returns 20 end -function wrap_and_double(x) - val = MDemo1(x) - double!(val) -end -🐰 = ∂☆{1}()(ZeroBundle{1}(wrap_and_double), TaylorBundle{1}(1.5, (1.0,))) -@test first_partial(🐰) isa MutableTangent{MDemo1} -@test first_partial(🐰).x == 2.0 - -# second derivative -🐇 = ∂☆{2}()(ZeroBundle{2}(wrap_and_double), TaylorBundle{2}(1.5, (1.0, 10.0))) -@test 🐇[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} -@test 🐇[TaylorTangentIndex(1)].x == 2.0 -@test 🐇[TaylorTangentIndex(2)] isa MutableTangent -@test 🐇[TaylorTangentIndex(2)] == 0.0 # returns 20 - -foo(val) = val^2 -🥖 = ∂☆{2}()(ZeroBundle{2}(foo), TaylorBundle{2}(1.0, (0.0, 10.0))) -🥖[TaylorTangentIndex(1)] # returns 0 -🥖[TaylorTangentIndex(2)] # returns 20 # end # module \ No newline 
at end of file From 1d912732095e21b27d33c305f5281d1c7ee6d484 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 5 Oct 2023 23:22:49 +0800 Subject: [PATCH 05/23] add extra second order test --- test/forward_mutation.jl | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/test/forward_mutation.jl b/test/forward_mutation.jl index e809a258..98904f46 100644 --- a/test/forward_mutation.jl +++ b/test/forward_mutation.jl @@ -22,7 +22,7 @@ end @test 🥯[TaylorTangentIndex(2)].x == 10.0 end -@testset "basis struct work" begin +@testset "basis struct work: double" begin function double!(val::MDemo1) val.x *= 2.0 return val @@ -44,6 +44,27 @@ end @test 🐇[TaylorTangentIndex(2)].x == 0.0 # returns 20 end +@testset "basis struct work: square" begin + function square!(val::MDemo1) + val.x ^= 2.0 + return val + end + function wrap_and_square(x) + val = MDemo1(x) + square!(val) + end + # first derivative + 🐰 = ∂☆{1}()(ZeroBundle{1}(wrap_and_square), TaylorBundle{1}(10.0, (1.0,))) + @test first_partial(🐰) isa MutableTangent{MDemo1} + @test first_partial(🐰).x == 20.0 + + # second derivative + 🐇 = ∂☆{2}()(ZeroBundle{2}(wrap_and_square), TaylorBundle{2}(1, (1.0, 0.0))) + @test 🐇[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} + @test 🐇[TaylorTangentIndex(1)].x == 20.0 + @test 🐇[TaylorTangentIndex(2)] isa MutableTangent + @test 🐇[TaylorTangentIndex(2)].x == 2.0 # returns 20 +end # end # module \ No newline at end of file From 67a15da093e6a99a69d60920aeb3a95aa7a99098 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 6 Oct 2023 23:31:33 +0800 Subject: [PATCH 06/23] Fix closures --- src/stage1/forward.jl | 4 ++-- test/forward_mutation.jl | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 0279ffba..05f96e22 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -131,12 +131,12 @@ end function (::∂☆internal{N})(f::AbstractZeroBundle{N}, 
args::AbstractZeroBundle{N}...) where {N} f_v = primal(f) args_v = map(primal, args) - return ZeroBundle{N}(f_v(args_v...)) + return zero_bundle{N}()(f_v(args_v...)) end function (::∂☆internal{1})(f::AbstractZeroBundle{1}, args::AbstractZeroBundle{1}...) f_v = primal(f) args_v = map(primal, args) - return ZeroBundle{1}(f_v(args_v...)) + return zero_bundle{1}()(f_v(args_v...)) end function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N} diff --git a/test/forward_mutation.jl b/test/forward_mutation.jl index 98904f46..1ce9cd29 100644 --- a/test/forward_mutation.jl +++ b/test/forward_mutation.jl @@ -66,5 +66,22 @@ end @test 🐇[TaylorTangentIndex(2)].x == 2.0 # returns 20 end +@testset "closure" begin + function bar(x) + z = x + 1.0 + function foo!(y) + z = z * y + return z + end + + foo!(2) + foo!(2) + return z + end + + 🥯 = ∂☆{1}()(ZeroBundle{1}(bar), TaylorBundle{1}(10.0, (1.0,))) + @test 🥯[TaylorTangentIndex(1)] == 4.0 +end + # end # module \ No newline at end of file From 7d481d42d079c8091f38b31094a8c2e9745496a9 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 15:51:34 +0800 Subject: [PATCH 07/23] remove debug statements --- src/stage1/recurse_fwd.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 13c51961..372412b6 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -15,7 +15,6 @@ struct ∂☆new{N}; end function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) primal_args = map(primal, xs) the_primal = _construct(B, primal_args) - @info "∂☆new{1}" tangent_tup = map(first_partial, xs) the_partial = if B<:Tuple Tangent{B, typeof(tangent_tup)}(tangent_tup) @@ -30,7 +29,6 @@ end function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) 
where {N} primal_args = map(primal, xs) the_primal = _construct(B, primal_args) - @info "∂☆new{N}" the_partials = ntuple(Val{N}()) do ii tangent_tup = map(x->partial(x, ii), xs) tangent = if B<:Tuple From 6ae4e136746587b98b4b871ab4855a860ce90220 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 20:23:45 +0800 Subject: [PATCH 08/23] update for more aggressive iszero and take NoTangent more seriously --- src/codegen/forward.jl | 4 ++-- src/codegen/forward_demand.jl | 2 +- src/higher_fwd_rules.jl | 10 +++++----- src/stage1/broadcast.jl | 2 +- src/stage1/forward.jl | 33 ++++++++++++++++++--------------- src/stage1/mixed.jl | 10 +++++----- src/stage1/recurse_fwd.jl | 2 ++ src/tangent.jl | 6 ++++-- test/stage2_fwd.jl | 15 ++++++++++----- test/tangent.jl | 2 +- 10 files changed, 49 insertions(+), 37 deletions(-) diff --git a/src/codegen/forward.jl b/src/codegen/forward.jl index 50fe54b9..e8f5bcdd 100644 --- a/src/codegen/forward.jl +++ b/src/codegen/forward.jl @@ -35,7 +35,7 @@ function fwd_transform!(ci, mi, nargs, N) args = map(stmt.args) do stmt emit!(mapstmt!(stmt)) end - return Expr(:call, Core._apply_iterate, FwdIterate(ZeroBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) + return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) elseif isa(stmt, SSAValue) return SSAValue(ssa_mapping[stmt.id]) elseif isa(stmt, Core.SlotNumber) @@ -64,7 +64,7 @@ function fwd_transform!(ci, mi, nargs, N) # Always disable `@inbounds`, as we don't actually know if the AD'd # code is truly `@inbounds` or not. elseif isexpr(stmt, :boundscheck) - return ZeroBundle{N}(true) + return DNEBundle{N}(true) else # Fallback case, for literals. 
# If it is an Expr, then it is not a literal diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 451c3412..ac67d50b 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -363,6 +363,6 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met rt = CC._ir_abstract_constant_propagation(interp, irsv) ir = compact!(ir) - + return ir end diff --git a/src/higher_fwd_rules.jl b/src/higher_fwd_rules.jl index d67a44e1..8486b8bd 100644 --- a/src/higher_fwd_rules.jl +++ b/src/higher_fwd_rules.jl @@ -19,10 +19,10 @@ end jeval(j, x) = j(x) for f in (sin, cos, exp) - function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N} + function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N} njet(Val{N}(), primal(fb), primal(x))(x) end - function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::ZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M} + function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::AbstractZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M} ∂⃖ₙ(jeval, njet(Val{N+M}(), primal(fb), primal(x)), x) end end @@ -30,16 +30,16 @@ end # TODO: It's a bit embarassing that we need to write these out, but currently the # compiler is not strong enough to automatically lift the frule. Let's hope we # can delete these in the near future. 
-function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} +function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} TaylorBundle{N}(primal(a) + primal(b), map(+, a.tangent.coeffs, b.tangent.coeffs)) end -function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::ZeroBundle{N}) where {N} +function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::AbstractZeroBundle{N}) where {N} TaylorBundle{N}(primal(a) + primal(b), a.tangent.coeffs) end -function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} +function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} TaylorBundle{N}(primal(a) - primal(b), map(-, a.tangent.coeffs, b.tangent.coeffs)) end diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index e04bd60a..e32dcfef 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -9,7 +9,7 @@ end n_getfield(∂ₙ::∂☆{N}, b::ATB{N}, x::Union{Symbol, Int}) where {N} = ∂ₙ(ZeroBundle{N}(getfield), b, ZeroBundle{N}(x)) -function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)}, +function (∂ₙ::∂☆{N})(zc::AbstractZeroBundle{N, typeof(copy)}, bc::ATB{N, <:Broadcasted}) where {N} bc = ∂ₙ(ZeroBundle{N}(Broadcast.flatten), bc) args = n_getfield(∂ₙ, bc, :args) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 05f96e22..14ddfc55 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -98,9 +98,12 @@ struct ∂☆shuffle{N}; end function shuffle_base(r) (primal, dual) = r - if isa(dual, Union{NoTangent, ZeroTangent}) + if dual isa NoTangent UniformBundle{1}(primal, dual) else + if dual isa ZeroTangent # Normalize zero for type-stability reasons + dual = zero_tangent(primal) + end TaylorBundle{1}(primal, (dual,)) end end @@ -193,25 +196,25 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}} end 
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...) -function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N} +function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N} ∂vararg{N}()(map(FwdMap(f), destructure(tup))...) end -function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N} +function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N} # TODO: This could do an inplace map! to avoid the extra rebundling rebundle(map(FwdMap(f), map(unbundle, args)...)) end -function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N} +function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N} ∂☆recurse{N}()(ZeroBundle{N, typeof(map)}(map), f, args...) end -function (::∂☆{N})(f::ZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N} +function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N} ifelse(arg.primal, args...) end -function (::∂☆{N})(f::ZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N} +function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N} Core.ifelse(arg.primal, args...) end @@ -233,48 +236,48 @@ end primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2)))) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N} Core._apply_iterate(FwdIterate(iterate), this, (f,), args...) 
end -function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N} r = iterate(destructure(t)) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} r = iterate(destructure(t), primal(a), map(primal, args)...) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N} r = Base.indexed_iterate(destructure(t), primal(i)) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) 
where {N} ∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1)) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::AbstractZeroBundle) where {N} field_ind = primal(i) the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N) TaylorBundle{N}(primal(t)[field_ind], the_partials) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N} DNEBundle{N}(typeof(primal(x))) end -function (this::∂☆{N})(f::ZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N} +function (this::∂☆{N})(f::AbstractZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N} ff = primal(f) if ff === Base.not_int DNEBundle{N}(ff(map(primal, args)...)) diff --git a/src/stage1/mixed.jl b/src/stage1/mixed.jl index a1d30a86..874f73b4 100644 --- a/src/stage1/mixed.jl +++ b/src/stage1/mixed.jl @@ -70,12 +70,12 @@ function (f::FwdIterate)(arg::ATB{N}, st) where {N} primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2)))) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N} +function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N} Core._apply_iterate(FwdIterate(iterate), this, (f,), args...) end =# -function (this::∂⃖{N})(that::∂☆{M}, ::ZeroBundle{M, typeof(Core._apply_iterate)}, +function (this::∂⃖{N})(that::∂☆{M}, ::AbstractZeroBundle{M, typeof(Core._apply_iterate)}, iterate, f, args::ATB{M, <:Tuple}...) where {N, M} @assert primal(iterate) === Base.iterate x, ∂⃖f = Core._apply_iterate(FwdIterate(iterate), this, (that, f), args...) 
@@ -83,13 +83,13 @@ function (this::∂⃖{N})(that::∂☆{M}, ::ZeroBundle{M, typeof(Core._apply_i end -function ChainRules.rrule(∂::∂☆{N}, m::ZeroBundle{N, typeof(map)}, p::ZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N} +function ChainRules.rrule(∂::∂☆{N}, m::AbstractZeroBundle{N, typeof(map)}, p::AbstractZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N} ∂(m, p, A, B), Δ->(NoTangent(), NoTangent(), NoTangent(), Δ, Δ) end mapev_unbundled(_, js, a) = rebundle(mapev(js, unbundle(a))) -function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::ZeroBundle{M, typeof(map)}, - f::ZeroBundle{M}, a::ATB{M, <:Array}) where {N, M} +function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::AbstractZeroBundle{M, typeof(map)}, + f::AbstractZeroBundle{M}, a::ATB{M, <:Array}) where {N, M} @assert Base.issingletontype(typeof(primal(f))) js = map(primal(a)) do x ∂f = ∂☆{N+M}()(ZeroBundle{N+M}(primal(f)), diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 372412b6..c1232bca 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -23,6 +23,8 @@ function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) tangent_nt = NamedTuple{names}(tangent_tup) StructuralTangent{B}(tangent_nt) end + @show typeof(the_partial) + #TODO: I think we need https://github.com/JuliaDiff/Diffractor.jl/pull/236/files here return TaylorBundle{1, B}(the_primal, (the_partial,)) end diff --git a/src/tangent.jl b/src/tangent.jl index fbc470d0..d771227e 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -427,9 +427,11 @@ Creates a bundle with a zero tangent. """ struct zero_bundle{N} end function (::zero_bundle{N})(primal) where N - if zero_tangent(primal) isa ZeroTangent - return ZeroBundle{N}(primal) + if zero_tangent(primal) isa AbstractZero + return UniformBundle{N}(primal, zero_tangent(primal) ) else + # Note: it is important that zero_tangent(primal) is called in ntuple + # so it gets distinct values for each order, so it doesn't alias if mutated.
return TaylorBundle{N}(primal, ntuple(_->zero_tangent(primal), N)) end end \ No newline at end of file diff --git a/test/stage2_fwd.jl b/test/stage2_fwd.jl index 987fbb3f..9a0248f2 100644 --- a/test/stage2_fwd.jl +++ b/test/stage2_fwd.jl @@ -6,7 +6,8 @@ module stage2_fwd @test sin′(1.0) == cos(1.0) end let sin′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(mysin), Float64}, 2) - @test isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + # This broke some time between 1.10 and 1.11-DEV.10001 + @test_broken isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test sin′′(1.0) == -sin(1.0) end @@ -14,24 +15,28 @@ module stage2_fwd self_minus(a) = myminus(a, a) ChainRulesCore.@scalar_rule myminus(x, y) (true, -1) let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}) - @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + # This broke some time between 1.10 and 1.11-DEV.10001 + @test_broken isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test self_minus′(1.0) == 0. end ChainRulesCore.@scalar_rule myminus(x, y) (true, true) # frule for `x - y` let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}) - @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + # This broke some time between 1.10 and 1.11-DEV.10001 + @test_broken isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test self_minus′(1.0) == 2. end myminus2(a, b) = a - b self_minus2(a) = myminus2(a, a) let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64}) - @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + # This broke some time between 1.10 and 1.11-DEV.10001 + @test_broken isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test self_minus2′(1.0) == 0. 
end ChainRulesCore.@scalar_rule myminus2(x, y) (true, true) # frule for `x - y` let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64}) - @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + # This broke some time between 1.10 and 1.11-DEV.10001 + @test_broken isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test self_minus2′(1.0) == 2. end diff --git a/test/tangent.jl b/test/tangent.jl index 5d09c9a3..8bf2cac5 100644 --- a/test/tangent.jl +++ b/test/tangent.jl @@ -126,7 +126,7 @@ end zero_bundle = Diffractor.zero_bundle tup_zb = zero_bundle{2}()((1, 0)) - @test tup_zb isa ZeroBundle{2} + @test tup_zb isa Diffractor.AbstractTangentBundle{2} @test iszero(tup_zb[TaylorTangentIndex(1)]) @test iszero(tup_zb[TaylorTangentIndex(2)]) From b742b3e10251c9a02efed0709239583da2e78737 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 16 Oct 2023 14:38:51 +0800 Subject: [PATCH 09/23] =?UTF-8?q?type=20of=20type=20erasure=20in=20?= =?UTF-8?q?=E2=88=82=E2=98=86new?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/stage1/recurse_fwd.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index c1232bca..7f38b7d0 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -23,9 +23,8 @@ function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) tangent_nt = NamedTuple{names}(tangent_tup) StructuralTangent{B}(tangent_nt) end - @show typeof(the_partial) - #TODO: I think we need https://github.com/JuliaDiff/Diffractor.jl/pull/236/files here - return TaylorBundle{1, B}(the_primal, (the_partial,)) + B2 = typeof(the_primal) # HACK: if the_primal actually has types in it then we want to make sure we get DataType not Type(...) + return TaylorBundle{1, B2}(the_primal, (the_partial,)) end function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) 
where {N} From e41490f5652ff7229112d29a50d20bb55c3ee4a4 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 17 Oct 2023 16:11:00 +0800 Subject: [PATCH 10/23] tests for type of type --- test/forward.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/forward.jl b/test/forward.jl index dec11d97..f8040639 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -161,6 +161,18 @@ end end +@testset "types in tuples" begin + function foo(a) + tup = (a, 2a, Int) + return tup[2] + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + end +end + + @testset "taylor_compatible" begin taylor_compatible = Diffractor.taylor_compatible From d0a825b6c721f80e9d4e3de7fde16271b7a576ff Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 2 Jan 2024 12:20:03 +0800 Subject: [PATCH 11/23] correct test --- test/forward_mutation.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/forward_mutation.jl b/test/forward_mutation.jl index 1ce9cd29..04e6f0ff 100644 --- a/test/forward_mutation.jl +++ b/test/forward_mutation.jl @@ -41,7 +41,7 @@ end @test 🐇[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} @test 🐇[TaylorTangentIndex(1)].x == 2.0 @test 🐇[TaylorTangentIndex(2)] isa MutableTangent - @test 🐇[TaylorTangentIndex(2)].x == 0.0 # returns 20 + @test 🐇[TaylorTangentIndex(2)].x == 0.0 end @testset "basis struct work: square" begin @@ -59,11 +59,11 @@ end @test first_partial(🐰).x == 20.0 # second derivative - 🐇 = ∂☆{2}()(ZeroBundle{2}(wrap_and_square), TaylorBundle{2}(1, (1.0, 0.0))) + 🐇 = ∂☆{2}()(ZeroBundle{2}(wrap_and_square), TaylorBundle{2}(100.0, (1.0, 0.0))) @test 🐇[TaylorTangentIndex(1)] isa MutableTangent{MDemo1} - @test 🐇[TaylorTangentIndex(1)].x == 20.0 + @test 🐇[TaylorTangentIndex(1)].x == 200.0 @test 🐇[TaylorTangentIndex(2)] isa MutableTangent - @test 🐇[TaylorTangentIndex(2)].x == 2.0 # returns 20 + @test 🐇[TaylorTangentIndex(2)].x == 2.0 end @testset "closure" begin From 
f47d5f046b67846383caaf683df32777a012c40a Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 5 Jan 2024 17:07:02 +0800 Subject: [PATCH 12/23] Comment out reverse tests that are making it hang on nightly --- test/reverse.jl | 10 ++++++---- test/runtests.jl | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/test/reverse.jl b/test/reverse.jl index 4e81b921..f525b10e 100644 --- a/test/reverse.jl +++ b/test/reverse.jl @@ -98,14 +98,16 @@ let var"'" = Diffractor.PrimeDerivativeBack # TODO This currently causes a segfault, c.f. https://github.com/JuliaLang/julia/pull/48742 # @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) broken=true - # Control flow cases + # Control flow cases: + # if @test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0) @test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0) @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;] - @test times_three_while'(1.0) == 3.0 - + + # while + # @test times_three_while'(1.0) == 3.0 # hangs in 1.11 pow5p(x) = (x->mypow(x, 5))'(x) - @test pow5p(1.0) == 5.0 + #@test pow5p(1.0) == 5.0 # hangs in 1.11 end end diff --git a/test/runtests.jl b/test/runtests.jl index 7c073a8f..847f3cce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,7 @@ const bwd = Diffractor.PrimeDerivativeBack "AbstractDifferentiationTests.jl" #"pinn.jl", # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) ) + @info "testing" file include(file) end From 7c7c6287ced4f75da97652d7ffc35791560e7e14 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 12 Jan 2024 17:27:44 +0800 Subject: [PATCH 13/23] fix quotenodes not to error --- src/codegen/forward_demand.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index ac67d50b..38701f24 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -267,7
+267,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; @assert isconst(arg) return zero_bundle{order}()(getfield(arg.mod, arg.name)) elseif isa(arg, QuoteNode) - return zero_bundle{order}(){order}(arg.value) + return zero_bundle{order}()(arg.value) end @assert !isa(arg, Expr) return zero_bundle{order}()(arg) From 2b79ef384dfd48cc34608f1788a1a278acd58f74 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 16 Jan 2024 13:38:46 +0800 Subject: [PATCH 14/23] =?UTF-8?q?Correct=20code=20for=20higher=20order=20?= =?UTF-8?q?=E2=88=82=E2=98=86new?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/stage1/recurse_fwd.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 7f38b7d0..eb892cc9 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -50,7 +50,7 @@ _construct(::Type{B}, args) where B<:Tuple = B(args) # Hack for making things that do not have public constructors constructable: @generated _construct(B::Type, args) = Expr(:splatnew, :B, :args) -@generated (::∂☆new{N})(B::Type) where {N} = return :(zero_bundle{order}()($(Expr(:new, :B)))) +@generated (::∂☆new{N})(B::Type) where {N} = return :(zero_bundle{$N}()($(Expr(:new, :B)))) # Sometimes we don't know whether or not we need to the ZeroBundle when doing # the transform, so this can happen - allow it for now. 
From 55eb407df8eec33c9caa0f15d49538fb56cb3b3a Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 19 Jan 2024 18:47:04 +0800 Subject: [PATCH 15/23] more extra rules for static arrays more overloads for StaticArrays --- src/extra_rules.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/extra_rules.jl b/src/extra_rules.jl index b2622695..7acfeb85 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -172,9 +172,14 @@ function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x: end function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L} - SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x.backing) + #TODO: we really shouldn't actually see the isa(∂x, AbstractZero) case since the frule should be called then + Δx = isa(∂x, AbstractZero) ? ∂x : SArray{S, T, N, L}(ChainRulesCore.backing(∂x)) + SArray{S, T, N, L}(x), Δx end +Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds) +Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind] + function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L} SArray{S, T, N, L}(x), SArray{S}(∂x) end From b9940e30c993dbe7eaa6ed613776b6874056ee2a Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 19:48:41 +0800 Subject: [PATCH 16/23] if all partials AbstractZero don't call frule --- src/stage1/forward.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 14ddfc55..a699caee 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -109,7 +109,7 @@ function shuffle_base(r) end function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) - r = frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...) + r = _frule(map(first_partial, args), map(primal, args)...) if r === nothing return ∂☆recurse{1}()(args...) 
else @@ -117,6 +117,14 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) end end +_frule(partials, primals...) = frule(DiffractorRuleConfig(), partials, primals...) +function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...) + # frules are linear in partials, so zero maps to zero, no need to evaluate the frule + # If all partials are immutable AbstractZero subtypes we know we don't have to worry about a mutating frule either + r = f(primal_args...) + return r, zero_tangent(r) +end + function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...) bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args) result = ∂☆internal{1}()(bundles...) From d4db01141296c6acd9333101043f171b8151e8f7 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 24 Jan 2024 09:26:45 +0800 Subject: [PATCH 17/23] Remove unrelated confusing comment --- src/codegen/forward_demand.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 38701f24..74c78694 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -105,7 +105,6 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I Δbacking = insert_node!(ir, ssa, NewInstruction(Expr(:splatnew, widenconst(tup_typ), Δbacking), tup_typ_typ.val)) end tangentT = Core.Compiler.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val - # TODO do we need to make sure this inserts right Δtangent = insert_node!(ir, ssa, NewInstruction(Expr(:new, tangentT, Δbacking), tangentT)) return Δtangent else # general frule handling From 6f5765794331cc50ba39148834cf8e2ef7af314b Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 24 Jan 2024 09:35:30 +0800 Subject: [PATCH 18/23] Remove case in frule for purely AbstractZero tangent --- src/extra_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 7acfeb85..0ff2dcbb
100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -173,7 +173,7 @@ end function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L} #TODO: we really shouldn't actually see the isa(∂x, AbstractZero) case since the frule should be called then - Δx = isa(∂x, AbstractZero) ? ∂x : SArray{S, T, N, L}(ChainRulesCore.backing(∂x)) + Δx = SArray{S, T, N, L}(ChainRulesCore.backing(∂x)) SArray{S, T, N, L}(x), Δx end From 00224ca2450b7674060b949cfdee6c0f3a79c182 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 29 Jan 2024 16:44:17 +0800 Subject: [PATCH 19/23] Require version of ChainRulesCore with MutableTangent etc --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1d9ae724..8aa9aadd 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] AbstractDifferentiation = "0.5" ChainRules = "1.44.6" -ChainRulesCore = "1.15.3" +ChainRulesCore = "1.20" Combinatorics = "1" Cthulhu = "2" OffsetArrays = "1" From a61d05b5afdf81df7e4986ce61befd1cd4fc6340 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 1 Feb 2024 18:35:28 +0800 Subject: [PATCH 20/23] Remove outdated comment --- src/extra_rules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 0ff2dcbb..b9bcff7e 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -172,7 +172,6 @@ function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x: end function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L} - #TODO: we really shouldn't actually see the isa(∂x, AbstractZero) case since the frule should be called then Δx = SArray{S, T, N, L}(ChainRulesCore.backing(∂x)) SArray{S, T, N, L}(x), Δx end From c9fb55ad4a6f3e03b0976e3e42209fe6bc806079 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 1 Feb 2024 19:09:43 +0800 Subject: [PATCH 21/23] 
Update comment --- src/stage1/recurse_fwd.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index eb892cc9..fa8a99fe 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -35,8 +35,8 @@ function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} tangent = if B<:Tuple Tangent{B, typeof(tangent_tup)}(tangent_tup) else - # It is a little dubious using StructuralTangent{B} for >1st order, but it is isomorphic. - # Just watch out for order mixing bugs. + # No matter the order we use `StructuralTangent{B}` for the partial + # It follows all required properties of the tangent to the n-1th order tangent names = fieldnames(B) tangent_nt = NamedTuple{names}(tangent_tup) StructuralTangent{B}(tangent_nt) From 0cfbc41e54f526bd995b2302cab6938b9927d519 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 1 Feb 2024 19:11:33 +0800 Subject: [PATCH 22/23] Add comment --- src/tangent.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tangent.jl b/src/tangent.jl index d771227e..146cd8cf 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -427,7 +427,8 @@ Creates a bundle with a zero tangent. """ struct zero_bundle{N} end function (::zero_bundle{N})(primal) where N - if zero_tangent(primal) isa AbstractZero + # We still use a Uniform bundle e.g. 
if primal has NoTangent + if zero_tangent(primal) isa AbstractZero return UniformBundle{N}(primal, zero_tangent(primal) ) else # Note: it is important that zero_tangent(primal) is called in ntuple From 8badafd5ee4c60dd9ecdaed9a6482ab502888c72 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 1 Feb 2024 19:39:45 +0800 Subject: [PATCH 23/23] Mark tests that broke in 1.11 only broken in 1.11 --- test/stage2_fwd.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/stage2_fwd.jl b/test/stage2_fwd.jl index 9a0248f2..6bc48d08 100644 --- a/test/stage2_fwd.jl +++ b/test/stage2_fwd.jl @@ -7,7 +7,7 @@ module stage2_fwd end let sin′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(mysin), Float64}, 2) # This broke some time between 1.10 and 1.11-DEV.10001 - @test_broken isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + @test isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test sin′′(1.0) == -sin(1.0) end @@ -16,27 +16,26 @@ module stage2_fwd ChainRulesCore.@scalar_rule myminus(x, y) (true, -1) let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}) # This broke some time between 1.10 and 1.11-DEV.10001 - @test_broken isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test self_minus′(1.0) == 0. end ChainRulesCore.@scalar_rule myminus(x, y) (true, true) # frule for `x - y` let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}) # This broke some time between 1.10 and 1.11-DEV.10001 - @test_broken isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test self_minus′(1.0) == 2. 
end myminus2(a, b) = a - b self_minus2(a) = myminus2(a, a) let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64}) - # This broke some time between 1.10 and 1.11-DEV.10001 - @test_broken isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test self_minus2′(1.0) == 0. end ChainRulesCore.@scalar_rule myminus2(x, y) (true, true) # frule for `x - y` let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64}) # This broke some time between 1.10 and 1.11-DEV.10001 - @test_broken isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" @test self_minus2′(1.0) == 2. end