diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml
index af8437d6..bd98bd3a 100644
--- a/DifferentiationInterface/Project.toml
+++ b/DifferentiationInterface/Project.toml
@@ -1,7 +1,7 @@
 name = "DifferentiationInterface"
 uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 authors = ["Guillaume Dalle", "Adrian Hill"]
-version = "0.6.13"
+version = "0.6.14"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl
index f2dc28ac..7fb7d021 100644
--- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl
+++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl
@@ -61,40 +61,6 @@ end
 
 ### Out-of-place
 
-function DI.value_and_pullback(
-    f::F,
-    ::NoPullbackPrep,
-    backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
-    x::Number,
-    ty::NTuple{1},
-    contexts::Vararg{Context,C},
-) where {F,C}
-    f_and_df = force_annotation(get_f_and_df(f, backend))
-    mode = reverse_split_withprimal(backend)
-    RA = eltype(ty) <: Number ? Active : Duplicated
-    dinputs, result = seeded_autodiff_thunk(
-        mode, only(ty), f_and_df, RA, Active(x), map(translate, contexts)...
-    )
-    return result, (first(dinputs),)
-end
-
-function DI.value_and_pullback(
-    f::F,
-    ::NoPullbackPrep,
-    backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
-    x::Number,
-    ty::NTuple{B},
-    contexts::Vararg{Context,C},
-) where {F,B,C}
-    f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
-    mode = reverse_split_withprimal(backend)
-    RA = eltype(ty) <: Number ? Active : BatchDuplicated
-    dinputs, result = batch_seeded_autodiff_thunk(
-        mode, ty, f_and_df, RA, Active(x), map(translate, contexts)...
-    )
-    return result, values(first(dinputs))
-end
-
 function DI.value_and_pullback(
     f::F,
     ::NoPullbackPrep,
@@ -105,12 +71,18 @@ function DI.value_and_pullback(
 ) where {F,C}
     f_and_df = force_annotation(get_f_and_df(f, backend))
     mode = reverse_split_withprimal(backend)
-    RA = eltype(ty) <: Number ? Active : Duplicated
+    IA = guess_activity(typeof(x), mode)
+    RA = guess_activity(eltype(ty), mode)
     dx = make_zero(x)
-    _, result = seeded_autodiff_thunk(
-        mode, only(ty), f_and_df, RA, Duplicated(x, dx), map(translate, contexts)...
+    dinputs, result = seeded_autodiff_thunk(
+        mode, only(ty), f_and_df, RA, annotate(IA, x, dx), map(translate, contexts)...
     )
-    return result, (dx,)
+    new_dx = first(dinputs)
+    if isnothing(new_dx)
+        return result, (dx,)
+    else
+        return result, (new_dx,)
+    end
 end
 
 function DI.value_and_pullback(
@@ -123,12 +95,18 @@ function DI.value_and_pullback(
 ) where {F,B,C}
     f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
     mode = reverse_split_withprimal(backend)
-    RA = eltype(ty) <: Number ? Active : BatchDuplicated
+    IA = batchify_activity(guess_activity(typeof(x), mode), Val(B))
+    RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
     tx = ntuple(_ -> make_zero(x), Val(B))
-    _, result = batch_seeded_autodiff_thunk(
-        mode, ty, f_and_df, RA, BatchDuplicated(x, tx), map(translate, contexts)...
+    dinputs, result = batch_seeded_autodiff_thunk(
+        mode, ty, f_and_df, RA, annotate(IA, x, tx), map(translate, contexts)...
     )
-    return result, tx
+    new_tx = values(first(dinputs))
+    if isnothing(new_tx)
+        return result, tx
+    else
+        return result, new_tx
+    end
 end
 
 function DI.pullback(
@@ -155,7 +133,7 @@ function DI.value_and_pullback!(
 ) where {F,C}
     f_and_df = force_annotation(get_f_and_df(f, backend))
     mode = reverse_split_withprimal(backend)
-    RA = eltype(ty) <: Number ? Active : Duplicated
+    RA = guess_activity(eltype(ty), mode)
     dx_righttype = convert(typeof(x), only(tx))
     make_zero!(dx_righttype)
     _, result = seeded_autodiff_thunk(
@@ -181,7 +159,7 @@ function DI.value_and_pullback!(
 ) where {F,B,C}
     f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
     mode = reverse_split_withprimal(backend)
-    RA = eltype(ty) <: Number ? Active : BatchDuplicated
+    RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
     tx_righttype = map(Fix1(convert, typeof(x)), tx)
     make_zero!(tx_righttype)
     _, result = batch_seeded_autodiff_thunk(
@@ -213,29 +191,39 @@ end
 ### Without preparation
 
 function DI.gradient(
-    f::F,
-    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
-    x,
-    contexts::Vararg{Context,C},
+    f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{Context,C}
 ) where {F,C}
     f_and_df = get_f_and_df(f, backend)
-    ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
-    grad = first(ders)
-    return grad
+    mode = reverse_noprimal(backend)
+    IA = guess_activity(typeof(x), mode)
+    grad = make_zero(x)
+    dinputs = only(
+        autodiff(mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...)
+    )
+    new_grad = first(dinputs)
+    if isnothing(new_grad)
+        return grad
+    else
+        return new_grad
+    end
 end
 
 function DI.value_and_gradient(
-    f::F,
-    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
-    x,
-    contexts::Vararg{Context,C},
+    f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{Context,C}
 ) where {F,C}
     f_and_df = get_f_and_df(f, backend)
-    ders, y = gradient(
-        reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
+    mode = reverse_withprimal(backend)
+    IA = guess_activity(typeof(x), mode)
+    grad = make_zero(x)
+    dinputs, result = autodiff(
+        mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...
     )
-    grad = first(ders)
-    return y, grad
+    new_grad = first(dinputs)
+    if isnothing(new_grad)
+        return result, grad
+    else
+        return result, new_grad
+    end
 end
 
 ### With preparation
@@ -245,10 +233,7 @@ struct EnzymeGradientPrep{G} <: GradientPrep
 end
 
 function DI.prepare_gradient(
-    f::F,
-    ::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
-    x,
-    contexts::Vararg{Context,C},
+    f::F, ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{Context,C}
 ) where {F,C}
     grad_righttype = make_zero(x)
     return EnzymeGradientPrep(grad_righttype)
@@ -257,21 +242,18 @@ end
 function DI.gradient(
     f::F,
     ::EnzymeGradientPrep,
-    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
+    backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
     x,
     contexts::Vararg{Context,C},
 ) where {F,C}
-    f_and_df = get_f_and_df(f, backend)
-    ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
-    grad = first(ders)
-    return grad
+    return DI.gradient(f, backend, x, contexts...)
 end
 
 function DI.gradient!(
     f::F,
     grad,
     prep::EnzymeGradientPrep,
-    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
+    backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
     x,
     contexts::Vararg{Context,C},
 ) where {F,C}
@@ -292,23 +274,18 @@ end
 function DI.value_and_gradient(
     f::F,
     ::EnzymeGradientPrep,
-    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
+    backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
     x,
     contexts::Vararg{Context,C},
 ) where {F,C}
-    f_and_df = get_f_and_df(f, backend)
-    ders, y = gradient(
-        reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
-    )
-    grad = first(ders)
-    return y, grad
+    return DI.value_and_gradient(f, backend, x, contexts...)
 end
 
 function DI.value_and_gradient!(
     f::F,
     grad,
     prep::EnzymeGradientPrep,
-    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
+    backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
     x,
     contexts::Vararg{Context,C},
 ) where {F,C}
@@ -328,6 +305,9 @@ end
 
 ## Jacobian
 
+# TODO: does not support static arrays
+
+#=
 struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end
 
 function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
@@ -385,3 +365,4 @@ function DI.value_and_jacobian!(
     y, new_jac = DI.value_and_jacobian(f, prep, backend, x)
     return y, copyto!(jac, new_jac)
 end
+=#
diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl
index 55c8e3ab..085e77ea 100644
--- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl
+++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl
@@ -76,3 +76,13 @@ end
 function maybe_reshape(A::AbstractArray, m, n)
     return reshape(A, m, n)
 end
+
+annotate(::Type{Active{T}}, x, dx) where {T} = Active(x)
+annotate(::Type{Duplicated{T}}, x, dx) where {T} = Duplicated(x, dx)
+
+function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B}
+    return BatchDuplicated(x, tx)
+end
+
+batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T}
+batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B}
diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl
index 2afa2af9..947eb875 100644
--- a/DifferentiationInterface/test/Back/Enzyme/test.jl
+++ b/DifferentiationInterface/test/Back/Enzyme/test.jl
@@ -95,3 +95,16 @@ test_differentiation(
     sparsity=true,
     logging=LOGGING,
 );
+
+##
+
+filtered_static_scenarios = filter(static_scenarios()) do s
+    DIT.operator_place(s) == :out && DIT.function_place(s) == :out
+end
+
+test_differentiation(
+    [AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)],
+    filtered_static_scenarios;
+    excluded=SECOND_ORDER,
+    logging=LOGGING,
+)
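A self-contained Julia sketch (not part of the patch) of how the two helpers added in utils.jl compose with Enzyme's activity guessing. The helper definitions are copied from the diff above; the sample inputs `x_scalar` and `x_vector` and the results noted in the comments are illustrative assumptions based on Enzyme's usual reverse-mode guesses (Active for scalars, Duplicated for arrays).

using Enzyme: Enzyme, Active, Duplicated, BatchDuplicated, Reverse, make_zero

# Helpers copied from the utils.jl hunk above.
annotate(::Type{Active{T}}, x, dx) where {T} = Active(x)
annotate(::Type{Duplicated{T}}, x, dx) where {T} = Duplicated(x, dx)
annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B} = BatchDuplicated(x, tx)
batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T}
batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B}

x_scalar, x_vector = 1.0, [1.0, 2.0]                     # illustrative inputs
IA_s = Enzyme.guess_activity(typeof(x_scalar), Reverse)  # expected: Active{Float64}
IA_v = Enzyme.guess_activity(typeof(x_vector), Reverse)  # expected: Duplicated{Vector{Float64}}
annotate(IA_s, x_scalar, make_zero(x_scalar))            # Active(x); the shadow argument is unused
annotate(IA_v, x_vector, make_zero(x_vector))            # Duplicated(x, dx) wrapping the zero shadow
batchify_activity(IA_v, Val(2))                          # BatchDuplicated{Vector{Float64},2} for two seeds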