diff --git a/docs/src/faq.md b/docs/src/faq.md
index 88c0cce3b9..6b3bbce6b4 100644
--- a/docs/src/faq.md
+++ b/docs/src/faq.md
@@ -193,7 +193,7 @@ That is why Enzyme provides a helper function `Enzyme.make_zero` that does this
 ```jldoctest sparse
 Enzyme.make_zero(a)
-Enzyme.gradient(Reverse, sum, a) # This calls make_zero(a)
+Enzyme.gradient(Reverse, sum, a)[1] # This calls make_zero(a)
 
 # output
 
diff --git a/docs/src/index.md b/docs/src/index.md
index 1f7f092a99..2643b87a1e 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -76,24 +76,32 @@ Both the inplace and "normal" variant return the gradient. The difference is tha
 
 ## Forward mode
 
-The return value of forward mode with a `Duplicated` return is a tuple containing as the first value
-the primal return value and as the second value the derivative.
+The return value when using `ForwardWithPrimal` is a tuple containing the derivative as the first value
+and the original (primal) return value as the second value.
+
+The return value when using `Forward` is a single-element tuple containing the derivative.
 
 In forward mode `Duplicated(x, 0.0)` is equivalent to `Const(x)`, except that we can perform more optimizations for `Const`.
 
 ```jldoctest rosenbrock
-julia> autodiff(Forward, rosenbrock, Duplicated, Const(1.0), Duplicated(3.0, 1.0))
+julia> autodiff(ForwardWithPrimal, rosenbrock, Const(1.0), Duplicated(3.0, 1.0))
 (400.0, 400.0)
 
-julia> autodiff(Forward, rosenbrock, Duplicated, Duplicated(1.0, 1.0), Const(3.0))
-(400.0, -800.0)
+julia> autodiff(Forward, rosenbrock, Const(1.0), Duplicated(3.0, 1.0))
+(400.0,)
+
+julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Const(3.0))
+(-800.0, 400.0)
+
+julia> autodiff(Forward, rosenbrock, Duplicated(1.0, 1.0), Const(3.0))
+(-800.0,)
 ```
 
 Of note, when we seed both arguments at once the tangent return is the sum of both.
 
 ```jldoctest rosenbrock
-julia> autodiff(Forward, rosenbrock, Duplicated, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0))
-(400.0, -400.0)
+julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0))
+(-400.0, 400.0)
 ```
 
@@ -121,7 +129,7 @@ Note the seeding through `dx`.
 We can also use vector mode to calculate both derivatives at once.
 
 ```jldoctest rosenbrock
-julia> autodiff(Forward, rosenbrock, BatchDuplicated, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0)))
-(400.0, (var"1" = -800.0, var"2" = 400.0))
+julia> autodiff(ForwardWithPrimal, rosenbrock, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0)))
+((var"1" = -800.0, var"2" = 400.0), 400.0)
 
 julia> x = [1.0, 3.0]
@@ -131,7 +139,7 @@ julia> x = [1.0, 3.0]
 
 julia> dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0];
 
-julia> autodiff(Forward, rosenbrock_inp, BatchDuplicated, BatchDuplicated(x, (dx_1, dx_2)))
-(400.0, (var"1" = -800.0, var"2" = 400.0))
+julia> autodiff(ForwardWithPrimal, rosenbrock_inp, BatchDuplicated(x, (dx_1, dx_2)))
+((var"1" = -800.0, var"2" = 400.0), 400.0)
 ```
 
@@ -145,18 +153,20 @@ Like [`autodiff`](@ref), the mode (forward or reverse) is determined by the firs
 
 The functions [`gradient`](@ref) and [`gradient!`](@ref) compute the gradient of function with vector input and scalar return.
 
+Gradient functions take a mode as the first argument. If the mode is `Reverse` or `Forward`, the return type is a tuple of gradients of each argument.
+If the mode is `ReverseWithPrimal` or `ForwardWithPrimal`, the return type is a named tuple containing both the derivatives and the original return result.
+
 ```jldoctest rosenbrock
 julia> gradient(Reverse, rosenbrock_inp, [1.0, 2.0])
-2-element Vector{Float64}:
- -400.0
-  200.0
+([-400.0, 200.0],)
+
+julia> gradient(ReverseWithPrimal, rosenbrock_inp, [1.0, 2.0])
+(derivs = ([-400.0, 200.0],), val = 100.0)
 
 julia> # inplace variant
        dx = [0.0, 0.0];
        gradient!(Reverse, dx, rosenbrock_inp, [1.0, 2.0])
-2-element Vector{Float64}:
- -400.0
-  200.0
+([-400.0, 200.0],)
 
 julia> dx
 2-element Vector{Float64}:
@@ -164,14 +174,16 @@ julia> dx
  200.0
 
 julia> gradient(Forward, rosenbrock_inp, [1.0, 2.0])
-(-400.0, 200.0)
+([-400.0, 200.0],)
+
+julia> gradient(ForwardWithPrimal, rosenbrock_inp, [1.0, 2.0])
+(derivs = ([-400.0, 200.0],), val = 100.0)
 
 julia> # in forward mode, we can also optionally pass a chunk size
        # to specify the number of derivatives computed simulateneously
        # using vector forward mode
-       chunk_size = Val(2)
-       gradient(Forward, rosenbrock_inp, [1.0, 2.0], chunk_size)
-(-400.0, 200.0)
+       gradient(Forward, rosenbrock_inp, [1.0, 2.0]; chunk=Val(1))
+([-400.0, 200.0],)
 ```
 
 ## Jacobian Convenience functions
@@ -179,31 +191,31 @@ julia> # in forward mode, we can also optionally pass a chunk size
 The function [`jacobian`](@ref) computes the Jacobian of a function vector input and vector return. Like [`autodiff`](@ref) and [`gradient`](@ref), the mode (forward or reverse) is determined by the first argument.
 
+Again like [`gradient`](@ref), if the mode is `Reverse` or `Forward`, the return type is a tuple of jacobians of each argument.
+If the mode is `ReverseWithPrimal` or `ForwardWithPrimal`, the return type is a named tuple containing both the derivatives and the original return result.
+
+Both forward and reverse modes take an optional chunk size to compute several derivatives simultaneously using vector mode, and reverse mode optionally takes `n_outs` which describes the shape of the output value.
+
 ```jldoctest rosenbrock
 julia> foo(x) = [rosenbrock_inp(x), prod(x)];
 
-julia> output_size = Val(2) # here we have to provide the output size of `foo` since it cannot be statically inferred
-       jacobian(Reverse, foo, [1.0, 2.0], output_size)
-2×2 transpose(::Matrix{Float64}) with eltype Float64:
- -400.0  200.0
-    2.0    1.0
+julia> jacobian(Reverse, foo, [1.0, 2.0])
+([-400.0 200.0; 2.0 1.0],)
 
-julia> chunk_size = Val(2) # By specifying the optional chunk size argument, we can use vector inverse mode to propogate derivatives of multiple outputs at once.
-       jacobian(Reverse, foo, [1.0, 2.0], output_size, chunk_size)
-2×2 transpose(::Matrix{Float64}) with eltype Float64:
- -400.0  200.0
-    2.0    1.0
+julia> jacobian(ReverseWithPrimal, foo, [1.0, 2.0])
+(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0])
+
+julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2))
+([-400.0 200.0; 2.0 1.0],)
+
+julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2), n_outs=Val((2,)))
+([-400.0 200.0; 2.0 1.0],)
 
 julia> jacobian(Forward, foo, [1.0, 2.0])
-2×2 Matrix{Float64}:
- -400.0  200.0
-    2.0    1.0
-
-julia> # Again, the optinal chunk size argument allows us to use vector forward mode
-       jacobian(Forward, foo, [1.0, 2.0], chunk_size)
-2×2 Matrix{Float64}:
- -400.0  200.0
-    2.0    1.0
+([-400.0 200.0; 2.0 1.0],)
+
+julia> jacobian(Forward, foo, [1.0, 2.0], chunk=Val(2))
+([-400.0 200.0; 2.0 1.0],)
 ```
 
 ## Hessian Vector Product Convenience functions
@@ -257,4 +269,4 @@ julia> grad
 2-element Vector{Float64}:
  2.880510859951098
  1.920340573300732
-```
\ No newline at end of file
+```
diff --git a/src/Enzyme.jl b/src/Enzyme.jl
index 66551f2958..c4994fb363 100644
--- a/src/Enzyme.jl
+++ b/src/Enzyme.jl
@@ -1082,21 +1082,21 @@ a tuple where the first element contains the derivatives, and the second element
 grad = gradient(ReverseWithPrimal, f, [2.0, 3.0])
 
 # output
 
-(([3.0, 2.0],), 6.0)
+(derivs = ([3.0, 2.0],), val = 6.0)
 ```
 
 ```jldoctest gradient
 grad = gradient(ReverseWithPrimal, mul, [2.0], [3.0])
 
 # output
 
-(([3.0], [2.0]), 6.0)
+(derivs = ([3.0], [2.0]), val = 6.0)
 ```
 
 ```jldoctest gradient
 grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0]))
 
 # output
 
-(([3.0], nothing), 6.0)
+(derivs = ([3.0], nothing), val = 6.0)
 ```
 """
@@ -1161,7 +1161,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0]))
         return quote
             Base.@_inline_meta
             $(toemit...)
-            (($(resargs...),), res[2])
+            (; derivs=($(resargs...),), val=res[2])
         end
     else
         return quote
@@ -1196,14 +1196,14 @@ dx = [0.0, 0.0]
 gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0])
 
 # output
 
-(([3.0, 2.0],), 6.0)
+(derivs = ([3.0, 2.0],), val = 6.0)
 ```
 """
 @inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten}
     make_zero!(dx)
     res = autodiff(rm, f, Active, Duplicated(x, dx))
     return if ReturnPrimal
-        ((dx,), res[2])
+        (; derivs=(dx,), val=res[2])
     else
         (dx,)
     end
@@ -1300,7 +1300,7 @@ gradient(Forward, f, [2.0, 3.0])
 gradient(ForwardWithPrimal, f, [2.0, 3.0])
 
 # output
 
-(([3.0, 2.0],), 6.0)
+(derivs = ([3.0, 2.0],), val = 6.0)
 ```
 
 ```jldoctest gradfwd
@@ -1315,7 +1315,7 @@ gradient(Forward, f, [2.0, 3.0]; chunk=Val(1))
 gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1))
 
 # output
 
-(([3.0, 2.0],), 6.0)
+(derivs = ([3.0, 2.0],), val = 6.0)
 ```
 
 For functions which return an AbstractArray or scalar, this function will return an AbstracttArray
@@ -1336,10 +1336,10 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0])
 """
 @inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; chunk::CS=nothing, shadows=create_shadows(chunk, x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity, CS}
     if length(shadows[1]) == 0
-        if ReturnPrimal
-            ((x,), f(x.val))
+        return if ReturnPrimal
+            (; derivs=(x,), val=f(x.val))
         else
-            return (x,)
+            (x,)
         end
     end
     if chunk == Val(0)
@@ -1430,7 +1430,7 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0])
         cols
     end
     if ReturnPrimal
-        ((res,), gradtup[2])
+        (; derivs=(res,), val=gradtup[2])
     else
         (res,)
    end
@@ -1498,7 +1498,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t
 
     end
     return if ReturnPrimal
-        (jac, res)
+        (; derivs=jac, val=res)
     else
         jac
     end
@@ -1606,7 +1606,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t
     end
     if ReturnPrimal
         # TODO optimize away redundant fwd pass
-        (res, if f isa Enzyme.Const
+        (; derivs=res, val=if f isa Enzyme.Const
             f.val(x)
         else
             f(x)
diff --git a/test/ext/bfloat16s.jl b/test/ext/bfloat16s.jl
index 0a47f48f03..daaf6ef74c 100644
--- a/test/ext/bfloat16s.jl
+++ b/test/ext/bfloat16s.jl
@@ -2,6 +2,6 @@ using Enzyme
 using Test
 using BFloat16s
 
-@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10)
+@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10))[1] ≈ ones(BFloat16, 10)
 
-@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10)
+@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10))[1] ≈ ones(BFloat16, 10)
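
For quick reference, here is a minimal standalone sketch of the return conventions this patch documents. It assumes Enzyme.jl v0.13 or newer is installed; the `rosenbrock` definitions are reconstructed here for illustration (they live earlier in `docs/src/index.md`), and the expected results shown as comments are taken from the doctests above.

```julia
# Sketch of the Enzyme 0.13 return conventions documented in the diff above.
# Assumes Enzyme.jl v0.13+; expected values (from the doctests) are shown as comments.
using Enzyme

# Reconstructed example functions matching the doctest outputs above.
rosenbrock(x, y) = (1.0 - x)^2 + 100.0 * (y - x^2)^2
rosenbrock_inp(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

# `Reverse`/`Forward` gradient: a tuple with one entry per differentiated argument.
Enzyme.gradient(Reverse, rosenbrock_inp, [1.0, 2.0])               # ([-400.0, 200.0],)

# `ReverseWithPrimal`/`ForwardWithPrimal`: a named tuple carrying derivatives and primal.
g = Enzyme.gradient(ReverseWithPrimal, rosenbrock_inp, [1.0, 2.0])
g.derivs                                                           # ([-400.0, 200.0],)
g.val                                                              # 100.0

# Forward-mode autodiff with the primal: derivative first, original value last.
autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Const(3.0))  # (-800.0, 400.0)
autodiff(Forward, rosenbrock, Duplicated(1.0, 1.0), Const(3.0))            # (-800.0,)
```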