diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 6c041adf..f690c605 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.5.2" +version = "0.5.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -88,6 +88,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -106,5 +107,6 @@ test = [ "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", + "StableRNGs", "Test", ] diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 9d95bc80..3d19b5fe 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -22,11 +22,14 @@ using Enzyme: DuplicatedNoNeed, Forward, ForwardMode, + Mode, Reverse, ReverseWithPrimal, ReverseSplitWithPrimal, ReverseMode, autodiff, + autodiff_deferred, + autodiff_deferred_thunk, autodiff_thunk, chunkedonehot, gradient, @@ -34,16 +37,25 @@ using Enzyme: jacobian, make_zero -const AutoForwardEnzyme = AutoEnzyme{<:ForwardMode} -const AutoForwardOrNothingEnzyme = Union{AutoEnzyme{<:ForwardMode},AutoEnzyme{Nothing}} -const AutoReverseEnzyme = AutoEnzyme{<:ReverseMode} -const AutoReverseOrNothingEnzyme = Union{AutoEnzyme{<:ReverseMode},AutoEnzyme{Nothing}} +struct AutoDeferredEnzyme{M} <: ADTypes.AbstractADType + mode::M +end + +ADTypes.mode(backend::AutoDeferredEnzyme) = ADTypes.mode(AutoEnzyme(backend.mode)) + +DI.backend_package_name(::AutoDeferredEnzyme) = "DeferredEnzyme" + +DI.nested(backend::AutoEnzyme) = AutoDeferredEnzyme(backend.mode) -forward_mode(backend::AutoEnzyme{<:ForwardMode}) = backend.mode -forward_mode(::AutoEnzyme{Nothing}) = Forward +const AnyAutoEnzyme{M} = Union{AutoEnzyme{M},AutoDeferredEnzyme{M}} -reverse_mode(backend::AutoEnzyme{<:ReverseMode}) = backend.mode -reverse_mode(::AutoEnzyme{Nothing}) = Reverse +# forward mode if possible +forward_mode(backend::AnyAutoEnzyme{<:Mode}) = backend.mode +forward_mode(::AnyAutoEnzyme{Nothing}) = Forward + +# reverse mode if possible +reverse_mode(backend::AnyAutoEnzyme{<:Mode}) = backend.mode +reverse_mode(::AnyAutoEnzyme{Nothing}) = Reverse DI.check_available(::AutoEnzyme) = true @@ -54,12 +66,6 @@ function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T return b end -function zero_sametype!(x_target, x) - x_sametype = convert(typeof(x), x_target) - x_sametype .= zero(eltype(x_sametype)) - return x_sametype -end - include("forward_onearg.jl") include("forward_twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index a04ba42c..95ccbb5f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -1,31 +1,42 @@ ## Pushforward -DI.prepare_pushforward(f, ::AutoForwardOrNothingEnzyme, x, dx) = NoPushforwardExtras() +function DI.prepare_pushforward(f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx) + return NoPushforwardExtras() +end function DI.value_and_pushforward( - f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras + f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras ) dx_sametype = convert(typeof(x), dx) - y, new_dy = autodiff( - forward_mode(backend), Const(f), Duplicated, Duplicated(x, dx_sametype) - ) + x_and_dx = Duplicated(x, dx_sametype) + y, new_dy = if backend isa AutoDeferredEnzyme + autodiff_deferred(forward_mode(backend), f, Duplicated, x_and_dx) + else + autodiff(forward_mode(backend), Const(f), Duplicated, x_and_dx) + end return y, new_dy end function DI.pushforward( - f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras + f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras ) dx_sametype = convert(typeof(x), dx) - new_dy = only( - autodiff( - forward_mode(backend), Const(f), DuplicatedNoNeed, Duplicated(x, dx_sametype) - ), - ) + x_and_dx = Duplicated(x, dx_sametype) + new_dy = if backend isa AutoDeferredEnzyme + only(autodiff_deferred(forward_mode(backend), f, DuplicatedNoNeed, x_and_dx)) + else + only(autodiff(forward_mode(backend), Const(f), DuplicatedNoNeed, x_and_dx)) + end return new_dy end function DI.value_and_pushforward!( - f, dy, backend::AutoForwardOrNothingEnzyme, x, dx, extras::NoPushforwardExtras + f, + dy, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, + x, + dx, + extras::NoPushforwardExtras, ) # dy cannot be passed anyway y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras) @@ -33,7 +44,12 @@ function DI.value_and_pushforward!( end function DI.pushforward!( - f, dy, backend::AutoForwardOrNothingEnzyme, x, dx, extras::NoPushforwardExtras + f, + dy, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, + x, + dx, + extras::NoPushforwardExtras, ) # dy cannot be passed anyway return copyto!(dy, DI.pushforward(f, backend, x, dx, extras)) @@ -45,34 +61,34 @@ struct EnzymeForwardGradientExtras{C,O} <: GradientExtras shadow::O end -function DI.prepare_gradient(f, ::AutoForwardEnzyme, x) +function DI.prepare_gradient(f, ::AutoEnzyme{<:ForwardMode}, x) C = pick_chunksize(length(x)) shadow = chunkedonehot(x, Val(C)) return EnzymeForwardGradientExtras{C,typeof(shadow)}(shadow) end function DI.gradient( - f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C} + f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C} ) where {C} grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow) return reshape(collect(grad_tup), size(x)) end function DI.value_and_gradient( - f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras + f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras ) return f(x), DI.gradient(f, backend, x, extras) end function DI.gradient!( - f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C} + f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C} ) where {C} grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow) return copyto!(grad, grad_tup) end function DI.value_and_gradient!( - f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C} + f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C} ) where {C} grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow) return f(x), copyto!(grad, grad_tup) @@ -84,14 +100,17 @@ struct EnzymeForwardOneArgJacobianExtras{C,O} <: JacobianExtras shadow::O end -function DI.prepare_jacobian(f, ::AutoForwardOrNothingEnzyme, x) +function DI.prepare_jacobian(f, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x) C = pick_chunksize(length(x)) shadow = chunkedonehot(x, Val(C)) return EnzymeForwardOneArgJacobianExtras{C,typeof(shadow)}(shadow) end function DI.jacobian( - f, backend::AutoForwardOrNothingEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras{C} + f, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + x, + extras::EnzymeForwardOneArgJacobianExtras{C}, ) where {C} jac_wrongshape = jacobian(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow) nx = length(x) @@ -100,7 +119,10 @@ function DI.jacobian( end function DI.value_and_jacobian( - f, backend::AutoForwardOrNothingEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras + f, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + x, + extras::EnzymeForwardOneArgJacobianExtras, ) return f(x), DI.jacobian(f, backend, x, extras) end @@ -108,7 +130,7 @@ end function DI.jacobian!( f, jac, - backend::AutoForwardOrNothingEnzyme, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, extras::EnzymeForwardOneArgJacobianExtras, ) @@ -118,7 +140,7 @@ end function DI.value_and_jacobian!( f, jac, - backend::AutoForwardOrNothingEnzyme, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, extras::EnzymeForwardOneArgJacobianExtras, ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 5de9a9c6..ed242c50 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -1,18 +1,25 @@ ## Pushforward -DI.prepare_pushforward(f!, y, ::AutoForwardOrNothingEnzyme, x, dx) = NoPushforwardExtras() +function DI.prepare_pushforward(f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx) + return NoPushforwardExtras() +end function DI.value_and_pushforward( - f!, y, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras + f!, + y, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, + x, + dx, + ::NoPushforwardExtras, ) dx_sametype = convert(typeof(x), dx) - dy_sametype = zero(y) - autodiff( - forward_mode(backend), - Const(f!), - Const, - Duplicated(y, dy_sametype), - Duplicated(x, dx_sametype), - ) + dy_sametype = make_zero(y) + y_and_dy = Duplicated(y, dy_sametype) + x_and_dx = Duplicated(x, dx_sametype) + if backend isa AutoDeferredEnzyme + autodiff_deferred(forward_mode(backend), f!, Const, y_and_dy, x_and_dx) + else + autodiff(forward_mode(backend), Const(f!), Const, y_and_dy, x_and_dx) + end return y, dy_sametype end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index a859562f..5c520d43 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -1,23 +1,36 @@ ## Pullback -DI.prepare_pullback(f, ::AutoReverseOrNothingEnzyme, x, dy) = NoPullbackExtras() +function DI.prepare_pullback(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy) + return NoPullbackExtras() +end ### Out-of-place function DI.value_and_pullback( - f, ::AutoReverseOrNothingEnzyme, x::Number, dy::Number, ::NoPullbackExtras + f, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + x::Number, + dy::Number, + ::NoPullbackExtras, ) - der, y = autodiff(ReverseWithPrimal, Const(f), Active, Active(x)) + der, y = if backend isa AutoDeferredEnzyme + autodiff_deferred(ReverseWithPrimal, f, Active, Active(x)) + else + autodiff(ReverseWithPrimal, Const(f), Active, Active(x)) + end new_dx = dy * only(der) return y, new_dx end function DI.value_and_pullback( - f, ::AutoReverseOrNothingEnzyme, x::Number, dy::AbstractArray, ::NoPullbackExtras + f, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + x::Number, + dy::AbstractArray, + ::NoPullbackExtras, ) - forw, rev = autodiff_thunk( - ReverseSplitWithPrimal, Const{typeof(f)}, Duplicated, Active{typeof(x)} - ) + tf, tx = typeof(f), typeof(x) + forw, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{tf}, Duplicated, Active{tx}) tape, y, new_dy = forw(Const(f), Active(x)) copyto!(new_dy, dy) new_dx = only(only(rev(Const(f), Active(x), tape))) @@ -25,14 +38,18 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f, backend::AutoReverseOrNothingEnzyme, x::AbstractArray, dy, extras::NoPullbackExtras + f, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + x::AbstractArray, + dy, + extras::NoPullbackExtras, ) dx = similar(x) return DI.value_and_pullback!(f, dx, backend, x, dy, extras) end function DI.pullback( - f, backend::AutoReverseOrNothingEnzyme, x, dy, extras::NoPullbackExtras + f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, extras::NoPullbackExtras ) return DI.value_and_pullback(f, backend, x, dy, extras)[2] end @@ -40,10 +57,21 @@ end ### In-place function DI.value_and_pullback!( - f, dx, ::AutoReverseOrNothingEnzyme, x::AbstractArray, dy::Number, ::NoPullbackExtras + f, + dx, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + x::AbstractArray, + dy::Number, + ::NoPullbackExtras, ) - dx_sametype = zero_sametype!(dx, x) - _, y = autodiff(ReverseWithPrimal, Const(f), Active, Duplicated(x, dx_sametype)) + dx_sametype = convert(typeof(x), dx) + dx_sametype .= zero(eltype(x)) + x_and_dx = Duplicated(x, dx_sametype) + _, y = if backend isa AutoDeferredEnzyme + autodiff_deferred(ReverseWithPrimal, Const(f), Active, x_and_dx) + else + autodiff(ReverseWithPrimal, Const(f), Active, x_and_dx) + end dx_sametype .*= dy return y, copyto!(dx, dx_sametype) end @@ -51,47 +79,77 @@ end function DI.value_and_pullback!( f, dx, - ::AutoReverseOrNothingEnzyme, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x::AbstractArray, dy::AbstractArray, ::NoPullbackExtras, ) - dx_sametype = zero_sametype!(dx, x) + tf, tx = typeof(f), typeof(x) forw, rev = autodiff_thunk( - ReverseSplitWithPrimal, Const{typeof(f)}, Duplicated, Duplicated{typeof(x)} + ReverseSplitWithPrimal, Const{tf}, Duplicated, Duplicated{tx} ) + dx_sametype = convert(typeof(x), dx) + dx_sametype .= zero(eltype(x)) tape, y, new_dy = forw(Const(f), Duplicated(x, dx_sametype)) copyto!(new_dy, dy) rev(Const(f), Duplicated(x, dx_sametype), tape) return y, copyto!(dx, dx_sametype) end -function DI.pullback!(f, dx, backend::AutoReverseEnzyme, x, dy, extras::NoPullbackExtras) +function DI.pullback!( + f, + dx, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + x, + dy, + extras::NoPullbackExtras, +) return DI.value_and_pullback!(f, dx, backend, x, dy, extras)[2] end ## Gradient -DI.prepare_gradient(f, ::AutoReverseOrNothingEnzyme, x) = NoGradientExtras() +function DI.prepare_gradient(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x) + return NoGradientExtras() +end -function DI.gradient(f, backend::AutoReverseOrNothingEnzyme, x, ::NoGradientExtras) - return gradient(reverse_mode(backend), f, x) +function DI.gradient( + f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras +) + if backend isa AutoDeferredEnzyme + grad = make_zero(x) + autodiff_deferred(reverse_mode(backend), f, Active, Duplicated(x, grad)) + return grad + else + return gradient(reverse_mode(backend), f, x) + end end -function DI.gradient!(f, grad, backend::AutoReverseOrNothingEnzyme, x, ::NoGradientExtras) +function DI.gradient!( + f, + grad, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + x, + extras::NoGradientExtras, +) grad_sametype = convert(typeof(x), grad) - gradient!(reverse_mode(backend), grad_sametype, f, x) + grad_sametype .= zero(eltype(x)) + if backend isa AutoDeferredEnzyme + autodiff_deferred(reverse_mode(backend), f, Active, Duplicated(x, grad_sametype)) + else + gradient!(reverse_mode(backend), grad_sametype, f, x) + end return copyto!(grad, grad_sametype) end function DI.value_and_gradient( - f, backend::AutoReverseOrNothingEnzyme, x, ::NoGradientExtras + f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras ) return DI.value_and_pullback(f, backend, x, one(eltype(x)), NoPullbackExtras()) end function DI.value_and_gradient!( - f, grad, backend::AutoReverseOrNothingEnzyme, x, ::NoGradientExtras + f, grad, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras ) return DI.value_and_pullback!(f, grad, backend, x, one(eltype(x)), NoPullbackExtras()) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 818e1896..4648e7b6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -1,30 +1,43 @@ ## Pullback -DI.prepare_pullback(f!, y, ::AutoReverseOrNothingEnzyme, x, dy) = NoPullbackExtras() +function DI.prepare_pullback(f!, y, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy) + return NoPullbackExtras() +end function DI.value_and_pullback( - f!, y, backend::AutoReverseOrNothingEnzyme, x::Number, dy, ::NoPullbackExtras + f!, + y, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + x::Number, + dy, + ::NoPullbackExtras, ) dy_sametype = convert(typeof(y), copy(dy)) - _, new_dx = only( - autodiff( - reverse_mode(backend), Const(f!), Const, Duplicated(y, dy_sametype), Active(x) - ), - ) + y_and_dy = Duplicated(y, dy_sametype) + _, new_dx = if backend isa AutoDeferredEnzyme + only(autodiff_deferred(reverse_mode(backend), f!, Const, y_and_dy, Active(x))) + else + only(autodiff(reverse_mode(backend), Const(f!), Const, y_and_dy, Active(x))) + end return y, new_dx end function DI.value_and_pullback( - f!, y, backend::AutoReverseOrNothingEnzyme, x::AbstractArray, dy, ::NoPullbackExtras + f!, + y, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + x::AbstractArray, + dy, + ::NoPullbackExtras, ) - dx_sametype = zero(x) + dx_sametype = make_zero(x) dy_sametype = convert(typeof(y), copy(dy)) - autodiff( - reverse_mode(backend), - Const(f!), - Const, - Duplicated(y, dy_sametype), - Duplicated(x, dx_sametype), - ) + y_and_dy = Duplicated(y, dy_sametype) + x_and_dx = Duplicated(x, dx_sametype) + if backend isa AutoDeferredEnzyme + autodiff_deferred(reverse_mode(backend), f!, Const, y_and_dy, x_and_dx) + else + autodiff(reverse_mode(backend), Const(f!), Const, y_and_dy, x_and_dx) + end return y, dx_sametype end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 98d7a805..9b3433a3 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -83,7 +83,8 @@ end function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverForward) where {F} # pushforward of many pushforwards in theory, but pushforward of gradient in practice - inner_gradient_closure(z) = gradient(f, inner(backend), z) + inner_backend = nested(inner(backend)) + inner_gradient_closure(z) = gradient(f, inner_backend, z) outer_pushforward_extras = prepare_pushforward( inner_gradient_closure, outer(backend), x, v ) @@ -92,7 +93,8 @@ end function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverReverse) where {F} # pushforward of gradient - inner_gradient_closure(z) = gradient(f, inner(backend), z) + inner_backend = nested(inner(backend)) + inner_gradient_closure(z) = gradient(f, inner_backend, z) outer_pushforward_extras = prepare_pushforward( inner_gradient_closure, outer(backend), x, v ) @@ -102,7 +104,11 @@ end function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ReverseOverForward) where {F} # gradient of pushforward # uses v in the closure - inner_pushforward_closure_generator(v) = z -> pushforward(f, inner(backend), z, v) + inner_backend = nested(inner(backend)) + function inner_pushforward_closure_generator(v) + inner_pushforward_closure(z) = pushforward(f, inner_backend, z, v) + return inner_pushforward_closure + end outer_gradient_extras = prepare_gradient( inner_pushforward_closure_generator(v), outer(backend), x ) @@ -113,7 +119,8 @@ end function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ReverseOverReverse) where {F} # pullback of the gradient - inner_gradient_closure(z) = gradient(f, inner(backend), z) + inner_backend = nested(inner(backend)) + inner_gradient_closure(z) = gradient(f, inner_backend, z) outer_pullback_extras = prepare_pullback(inner_gradient_closure, outer(backend), x, v) return ReverseOverReverseHVPExtras(inner_gradient_closure, outer_pullback_extras) end diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 664742ff..22e0a2e3 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -43,7 +43,8 @@ end prepare_second_derivative(f::F, ::AbstractADType, x) where {F} = NoSecondDerivativeExtras() function prepare_second_derivative(f::F, backend::SecondOrder, x) where {F} - inner_derivative_closure(z) = derivative(f, inner(backend), z) + inner_backend = nested(inner(backend)) + inner_derivative_closure(z) = derivative(f, inner_backend, z) outer_derivative_extras = prepare_derivative( inner_derivative_closure, outer(backend), x ) diff --git a/DifferentiationInterface/src/second_order/second_order.jl b/DifferentiationInterface/src/second_order/second_order.jl index b0fd8f95..5c9cae11 100644 --- a/DifferentiationInterface/src/second_order/second_order.jl +++ b/DifferentiationInterface/src/second_order/second_order.jl @@ -45,3 +45,12 @@ outer(backend::SecondOrder) = backend.outer Return the _outer_ mode of the second-order backend. """ ADTypes.mode(backend::SecondOrder) = mode(outer(backend)) + +""" + nested(backend) + +Return a possibly modified `backend` that can work while nested inside another differentiation procedure. + +At the moment, this is only useful for Enzyme, which needs `autodiff_deferred` to be compatible with higher-order differentiation. +""" +nested(backend::AbstractADType) = backend diff --git a/DifferentiationInterface/src/utils/check.jl b/DifferentiationInterface/src/utils/check.jl index 1db3d242..77ff8f6d 100644 --- a/DifferentiationInterface/src/utils/check.jl +++ b/DifferentiationInterface/src/utils/check.jl @@ -16,7 +16,7 @@ Check whether `backend` supports differentiation of two-argument functions. """ check_twoarg(backend::AbstractADType) = Bool(twoarg_support(backend)) -hess_checker(x::AbstractArray) = abs2(x[1]) * abs2(x[2]) +hess_checker(x::AbstractArray) = x[1] * x[1] * x[2] * x[2] """ check_hessian(backend) diff --git a/DifferentiationInterface/src/utils/exceptions.jl b/DifferentiationInterface/src/utils/exceptions.jl index 3a8d93c4..a2d59c15 100644 --- a/DifferentiationInterface/src/utils/exceptions.jl +++ b/DifferentiationInterface/src/utils/exceptions.jl @@ -9,7 +9,7 @@ function Base.showerror(io::IO, e::MissingBackendError) io, """Backend package is not loaded. To fix, run - using $(backend_package_name(e.backend)) + import $(backend_package_name(e.backend)) """, ) else diff --git a/DifferentiationInterface/src/utils/printing.jl b/DifferentiationInterface/src/utils/printing.jl index 63b9f5e4..8529b6f6 100644 --- a/DifferentiationInterface/src/utils/printing.jl +++ b/DifferentiationInterface/src/utils/printing.jl @@ -24,7 +24,7 @@ function backend_str(backend::AbstractADType) elseif mode(backend) isa SymbolicMode return "$bs (symbolic)" elseif mode(backend) isa ForwardOrReverseMode - return "$bs (forward/reverse)" + return "$bs (forward|reverse)" else error("Unknown mode") end diff --git a/DifferentiationInterface/test/Double/Enzyme-ForwardDiff.jl b/DifferentiationInterface/test/Double/Enzyme-ForwardDiff.jl index 1ad635ab..574530a0 100644 --- a/DifferentiationInterface/test/Double/Enzyme-ForwardDiff.jl +++ b/DifferentiationInterface/test/Double/Enzyme-ForwardDiff.jl @@ -5,8 +5,8 @@ using ForwardDiff: ForwardDiff using Test backends = [ - SecondOrder(AutoForwardDiff(), AutoEnzyme(Enzyme.Forward)), - SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoForwardDiff()), + SecondOrder(AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Forward)), + SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoForwardDiff()), ] for backend in backends diff --git a/DifferentiationInterface/test/Single/Enzyme.jl b/DifferentiationInterface/test/Single/Enzyme.jl index 0a497d50..2c5ec20e 100644 --- a/DifferentiationInterface/test/Single/Enzyme.jl +++ b/DifferentiationInterface/test/Single/Enzyme.jl @@ -1,7 +1,8 @@ +using ADTypes: ADTypes using DifferentiationInterface, DifferentiationInterfaceTest using Enzyme: Enzyme -using SparseConnectivityTracer -using SparseMatrixColorings +using SparseConnectivityTracer, SparseMatrixColorings +using StableRNGs using Test dense_backends = [ @@ -10,28 +11,53 @@ dense_backends = [ AutoEnzyme(; mode=Enzyme.Reverse), ] -sparse_backends = [ - AutoSparse( - AutoEnzyme(; mode=Enzyme.Forward); - sparsity_detector=TracerSparsityDetector(), - coloring_algorithm=GreedyColoringAlgorithm(), - ), - AutoSparse( - AutoEnzyme(; mode=Enzyme.Reverse); +nested_dense_backends = [ + DifferentiationInterface.nested(AutoEnzyme(; mode=Enzyme.Forward)), + DifferentiationInterface.nested(AutoEnzyme(; mode=Enzyme.Reverse)), +] + +sparse_backends = + AutoSparse.( + dense_backends, sparsity_detector=TracerSparsityDetector(), coloring_algorithm=GreedyColoringAlgorithm(), - ), -] + ) -for backend in vcat(dense_backends, sparse_backends) - @test check_available(backend) - @test check_twoarg(backend) - @test !check_hessian(backend; verbose=false) +@testset "Checks" begin + @testset "Check $(typeof(backend))" for backend in vcat(dense_backends, sparse_backends) + @test check_available(backend) + @test check_twoarg(backend) + @test check_hessian(backend; verbose=false) + end end ## Dense backends -test_differentiation(dense_backends; second_order=false, logging=LOGGING); +test_differentiation( + vcat(dense_backends, nested_dense_backends), + default_scenarios(); + second_order=false, + logging=LOGGING, +); + +test_differentiation( + [ + AutoEnzyme(; mode=nothing), + AutoEnzyme(; mode=Enzyme.Reverse), + SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Reverse)), + SecondOrder(AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)), + ]; + first_order=false, + excluded=[SecondDerivativeScenario], + logging=LOGGING, +); + +test_differentiation( + [AutoEnzyme(; mode=nothing), AutoEnzyme(; mode=Enzyme.Forward)]; + first_order=false, + excluded=[HessianScenario, HVPScenario], + logging=LOGGING, +); test_differentiation( AutoEnzyme(; mode=Enzyme.Forward); # TODO: add more diff --git a/DifferentiationInterface/test/Single/ForwardDiff.jl b/DifferentiationInterface/test/Single/ForwardDiff.jl index 9a100193..fdfc80c7 100644 --- a/DifferentiationInterface/test/Single/ForwardDiff.jl +++ b/DifferentiationInterface/test/Single/ForwardDiff.jl @@ -1,7 +1,6 @@ using DifferentiationInterface, DifferentiationInterfaceTest using ForwardDiff: ForwardDiff -using SparseConnectivityTracer -using SparseMatrixColorings +using SparseConnectivityTracer, SparseMatrixColorings using Test dense_backends = [AutoForwardDiff(), AutoForwardDiff(; chunksize=2, tag=:hello)] diff --git a/DifferentiationInterface/test/Single/Zygote.jl b/DifferentiationInterface/test/Single/Zygote.jl index 225e81d5..13634893 100644 --- a/DifferentiationInterface/test/Single/Zygote.jl +++ b/DifferentiationInterface/test/Single/Zygote.jl @@ -1,6 +1,5 @@ using DifferentiationInterface, DifferentiationInterfaceTest -using SparseConnectivityTracer -using SparseMatrixColorings +using SparseConnectivityTracer, SparseMatrixColorings using Test using Zygote: Zygote diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 7482e3e6..e4853713 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -14,6 +14,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -29,6 +30,7 @@ JET = "0.4 - 0.8, 0.9" JLArrays = "0.1" LinearAlgebra = "<0.0.1,1" ProgressMeter = "1" +Random = "<0.0.1,1" SparseArrays = "<0.0.1,1" SparseConnectivityTracer = "0.4.2" StaticArrays = "1.9" diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index ae3e8c9c..f722f31e 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -46,6 +46,7 @@ using JET: JET using JLArrays: jl using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent using ProgressMeter: ProgressUnknown, next! +using Random: AbstractRNG, default_rng using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm using StaticArrays: MMatrix, MVector, SMatrix, SVector using Test: @testset, @test diff --git a/DifferentiationInterfaceTest/src/scenarios/component.jl b/DifferentiationInterfaceTest/src/scenarios/component.jl index ba010c2e..349a95d1 100644 --- a/DifferentiationInterfaceTest/src/scenarios/component.jl +++ b/DifferentiationInterfaceTest/src/scenarios/component.jl @@ -45,17 +45,17 @@ end const CVEC = ComponentVector(; a=collect(1:4), b=collect(5:6)) """ - component_scenarios() + component_scenarios(rng=Random.default_rng()) Create a vector of [`AbstractScenario`](@ref)s with component array types from [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl). """ -function component_scenarios() - x = ComponentVector(; a=randn(4), b=randn(2)) +function component_scenarios(rng::AbstractRNG=default_rng()) + x = ComponentVector(; a=randn(rng, 4), b=randn(rng, 2)) return vcat( # one argument - num_to_arr_scenarios_onearg(randn(), CVEC), + num_to_arr_scenarios_onearg(randn(rng), CVEC), comp_to_num_scenarios_onearg(x::ComponentVector), # two arguments - num_to_arr_scenarios_twoarg(randn(), CVEC), + num_to_arr_scenarios_twoarg(randn(rng), CVEC), ) end diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index a5e8a547..b2108728 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -205,7 +205,6 @@ function arr_to_num_scenarios_onearg(x::AbstractArray; linalg=true) [ PullbackScenario(arr_to_num; x=x, ref=arr_to_num_pullback, place=place), GradientScenario(arr_to_num; x=x, ref=arr_to_num_gradient, place=place), - GradientScenario(arr_to_num; x=x, ref=arr_to_num_gradient, place=place), HVPScenario(arr_to_num; x=x, ref=arr_to_num_hvp, place=place), HessianScenario(arr_to_num; x=x, ref=arr_to_num_hessian, place=place), ], @@ -492,28 +491,28 @@ const IVEC = Vector(1:6) const IMAT = Matrix((1:2) .* transpose(1:3)) """ - default_scenarios() + default_scenarios(rng=Random.default_rng()) Create a vector of [`AbstractScenario`](@ref)s with standard array types. """ -function default_scenarios(; linalg=true) +function default_scenarios(rng::AbstractRNG=default_rng(); linalg=true) return vcat( # one argument - num_to_num_scenarios_onearg(rand()), - num_to_arr_scenarios_onearg(rand(), IVEC), - num_to_arr_scenarios_onearg(rand(), IMAT), - arr_to_num_scenarios_onearg(rand(6); linalg), - arr_to_num_scenarios_onearg(rand(2, 3); linalg), - vec_to_vec_scenarios_onearg(rand(6)), - vec_to_mat_scenarios_onearg(rand(6)), - mat_to_vec_scenarios_onearg(rand(2, 3)), - mat_to_mat_scenarios_onearg(rand(2, 3)), + num_to_num_scenarios_onearg(rand(rng)), + num_to_arr_scenarios_onearg(rand(rng), IVEC), + num_to_arr_scenarios_onearg(rand(rng), IMAT), + arr_to_num_scenarios_onearg(rand(rng, 6); linalg), + arr_to_num_scenarios_onearg(rand(rng, 2, 3); linalg), + vec_to_vec_scenarios_onearg(rand(rng, 6)), + vec_to_mat_scenarios_onearg(rand(rng, 6)), + mat_to_vec_scenarios_onearg(rand(rng, 2, 3)), + mat_to_mat_scenarios_onearg(rand(rng, 2, 3)), # two arguments - num_to_arr_scenarios_twoarg(rand(), IVEC), - num_to_arr_scenarios_twoarg(rand(), IMAT), - vec_to_vec_scenarios_twoarg(rand(6)), - vec_to_mat_scenarios_twoarg(rand(6)), - mat_to_vec_scenarios_twoarg(rand(2, 3)), - mat_to_mat_scenarios_twoarg(rand(2, 3)), + num_to_arr_scenarios_twoarg(rand(rng), IVEC), + num_to_arr_scenarios_twoarg(rand(rng), IMAT), + vec_to_vec_scenarios_twoarg(rand(rng, 6)), + vec_to_mat_scenarios_twoarg(rand(rng, 6)), + mat_to_vec_scenarios_twoarg(rand(rng, 2, 3)), + mat_to_mat_scenarios_twoarg(rand(rng, 2, 3)), ) end diff --git a/DifferentiationInterfaceTest/src/scenarios/gpu.jl b/DifferentiationInterfaceTest/src/scenarios/gpu.jl index 2ba88923..b69e7df4 100644 --- a/DifferentiationInterfaceTest/src/scenarios/gpu.jl +++ b/DifferentiationInterfaceTest/src/scenarios/gpu.jl @@ -2,27 +2,27 @@ const JLVEC = jl(IVEC) const JLMAT = jl(IMAT) """ - gpu_scenarios() + gpu_scenarios(rng=Random.default_rng()) Create a vector of [`AbstractScenario`](@ref)s with GPU array types from [JLArrays.jl](https://github.com/JuliaGPU/GPUArrays.jl/tree/master/lib/JLArrays). """ -function gpu_scenarios(; linalg=true) +function gpu_scenarios(rng::AbstractRNG=default_rng(); linalg=true) return vcat( # one argument - num_to_arr_scenarios_onearg(rand(), JLVEC), - num_to_arr_scenarios_onearg(rand(), JLMAT), - arr_to_num_scenarios_onearg(jl(rand(6)); linalg), - arr_to_num_scenarios_onearg(jl(rand(2, 3)); linalg), - vec_to_vec_scenarios_onearg(jl(rand(6))), - vec_to_mat_scenarios_onearg(jl(rand(6))), - mat_to_vec_scenarios_onearg(jl(rand(2, 3))), - mat_to_mat_scenarios_onearg(jl(rand(2, 3))), + num_to_arr_scenarios_onearg(rand(rng), JLVEC), + num_to_arr_scenarios_onearg(rand(rng), JLMAT), + arr_to_num_scenarios_onearg(jl(rand(rng, 6)); linalg), + arr_to_num_scenarios_onearg(jl(rand(rng, 2, 3)); linalg), + vec_to_vec_scenarios_onearg(jl(rand(rng, 6))), + vec_to_mat_scenarios_onearg(jl(rand(rng, 6))), + mat_to_vec_scenarios_onearg(jl(rand(rng, 2, 3))), + mat_to_mat_scenarios_onearg(jl(rand(rng, 2, 3))), # two arguments - num_to_arr_scenarios_twoarg(rand(), JLVEC), - num_to_arr_scenarios_twoarg(rand(), JLMAT), - vec_to_vec_scenarios_twoarg(jl(rand(6))), - vec_to_mat_scenarios_twoarg(jl(rand(6))), - mat_to_vec_scenarios_twoarg(jl(rand(2, 3))), - mat_to_mat_scenarios_twoarg(jl(rand(2, 3))), + num_to_arr_scenarios_twoarg(rand(rng), JLVEC), + num_to_arr_scenarios_twoarg(rand(rng), JLMAT), + vec_to_vec_scenarios_twoarg(jl(rand(rng, 6))), + vec_to_mat_scenarios_twoarg(jl(rand(rng, 6))), + mat_to_vec_scenarios_twoarg(jl(rand(rng, 2, 3))), + mat_to_mat_scenarios_twoarg(jl(rand(rng, 2, 3))), ) end diff --git a/DifferentiationInterfaceTest/src/scenarios/sparse.jl b/DifferentiationInterfaceTest/src/scenarios/sparse.jl index 8831f781..fc6e4c5d 100644 --- a/DifferentiationInterfaceTest/src/scenarios/sparse.jl +++ b/DifferentiationInterfaceTest/src/scenarios/sparse.jl @@ -223,17 +223,17 @@ end ## Gather """ - sparse_scenarios() + sparse_scenarios(rng=Random.default_rng()) Create a vector of [`AbstractScenario`](@ref)s with sparse array types, focused on sparse Jacobians and Hessians. """ -function sparse_scenarios() +function sparse_scenarios(rng::AbstractRNG=default_rng()) return vcat( - sparse_vec_to_vec_scenarios(rand(6)), - sparse_vec_to_mat_scenarios(rand(6)), - sparse_mat_to_vec_scenarios(rand(2, 3)), - sparse_mat_to_mat_scenarios(rand(2, 3)), - sparse_vec_to_num_scenarios(rand(6)), - sparse_mat_to_num_scenarios(rand(2, 3)), + sparse_vec_to_vec_scenarios(rand(rng, 6)), + sparse_vec_to_mat_scenarios(rand(rng, 6)), + sparse_mat_to_vec_scenarios(rand(rng, 2, 3)), + sparse_mat_to_mat_scenarios(rand(rng, 2, 3)), + sparse_vec_to_num_scenarios(rand(rng, 6)), + sparse_mat_to_num_scenarios(rand(rng, 2, 3)), ) end diff --git a/DifferentiationInterfaceTest/src/scenarios/static.jl b/DifferentiationInterfaceTest/src/scenarios/static.jl index faa2de63..6db19ab1 100644 --- a/DifferentiationInterfaceTest/src/scenarios/static.jl +++ b/DifferentiationInterfaceTest/src/scenarios/static.jl @@ -2,28 +2,28 @@ const SVEC = SVector{length(IVEC)}(IVEC) const SMAT = SMatrix{size(IMAT, 1),size(IMAT, 2)}(IMAT) """ - static_scenarios() + static_scenarios(rng=Random.default_rng()) Create a vector of [`AbstractScenario`](@ref)s with static array types from [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl). """ -function static_scenarios(; linalg=true) +function static_scenarios(rng::AbstractRNG=default_rng(); linalg=true) scens = vcat( # one argument - num_to_arr_scenarios_onearg(rand(), SVEC), - num_to_arr_scenarios_onearg(rand(), SMAT), - arr_to_num_scenarios_onearg(SVector{6}(rand(6)); linalg), - arr_to_num_scenarios_onearg(SMatrix{2,3}(rand(2, 3)); linalg), - vec_to_vec_scenarios_onearg(SVector{6}(rand(6))), - vec_to_mat_scenarios_onearg(SVector{6}(rand(6))), - mat_to_vec_scenarios_onearg(SMatrix{2,3}(rand(2, 3))), - mat_to_mat_scenarios_onearg(SMatrix{2,3}(rand(2, 3))), + num_to_arr_scenarios_onearg(rand(rng), SVEC), + num_to_arr_scenarios_onearg(rand(rng), SMAT), + arr_to_num_scenarios_onearg(SVector{6}(rand(rng, 6)); linalg), + arr_to_num_scenarios_onearg(SMatrix{2,3}(rand(rng, 2, 3)); linalg), + vec_to_vec_scenarios_onearg(SVector{6}(rand(rng, 6))), + vec_to_mat_scenarios_onearg(SVector{6}(rand(rng, 6))), + mat_to_vec_scenarios_onearg(SMatrix{2,3}(rand(rng, 2, 3))), + mat_to_mat_scenarios_onearg(SMatrix{2,3}(rand(rng, 2, 3))), # two arguments - num_to_arr_scenarios_twoarg(rand(), SVEC), - num_to_arr_scenarios_twoarg(rand(), SMAT), - vec_to_vec_scenarios_twoarg(MVector{6}(rand(6))), - vec_to_mat_scenarios_twoarg(MVector{6}(rand(6))), - mat_to_vec_scenarios_twoarg(MMatrix{2,3}(rand(2, 3))), - mat_to_mat_scenarios_twoarg(MMatrix{2,3}(rand(2, 3))), + num_to_arr_scenarios_twoarg(rand(rng), SVEC), + num_to_arr_scenarios_twoarg(rand(rng), SMAT), + vec_to_vec_scenarios_twoarg(MVector{6}(rand(rng, 6))), + vec_to_mat_scenarios_twoarg(MVector{6}(rand(rng, 6))), + mat_to_vec_scenarios_twoarg(MMatrix{2,3}(rand(rng, 2, 3))), + mat_to_mat_scenarios_twoarg(MMatrix{2,3}(rand(rng, 2, 3))), ) scens = filter(scens) do s operator_place(s) == :outofplace || s.x isa Union{Number,MVector,MMatrix}