Skip to content

Commit

Permalink
add interface v2
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Jul 19, 2023
1 parent 05810c2 commit 11d8d0e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ function (::Type{∇})(f, x1, args...)
unthunk.((f)(x1, args...))
end

const gradient =

# Star Trek has their prime directive. We have the...
"""
    AbstractPrimeDerivative{N, T}

Common supertype for prime-derivative operators, parameterized by the
derivative order `N` and a type parameter `T` (presumably the wrapped
callable's type — confirm against the concrete subtypes).
"""
abstract type AbstractPrimeDerivative{N, T} end

Expand Down Expand Up @@ -181,8 +179,8 @@ struct PrimeDerivative{N, T}
end

"""
    (f::PrimeDerivative{N, T})(x)

Evaluate the order-`N` prime derivative of the wrapped function `f.f` at `x`.

NOTE(review): the extracted span contained both the old (backwards-mode) and
new (forward-mode) diff bodies plus coverage-annotation text; this is the
post-commit version, which dispatches to the forward-mode wrapper.
"""
function (f::PrimeDerivative{N, T})(x) where {N, T}
    # For now, this is forward mode, since that's more fully implemented
    return PrimeDerivativeFwd{N, T}(f.f)(x)
end

"""
Expand Down Expand Up @@ -227,3 +225,7 @@ will compute the derivative `∂^3 f/∂x^2 ∂y` at `(x,y)`.
# NOTE(review): the macro's name appears to have been lost during extraction
# (likely a Unicode identifier, given the surrounding code uses `∇` etc.) —
# confirm against the upstream source before relying on this snippet.
# The body is an intentional not-yet-implemented placeholder: expanding the
# macro raises an error.
macro (expr)
error("Write me")
end
"""
    derivative(f, x)

Compute the derivative of `f` at `x` via the forward-mode prime-derivative
wrapper.
"""
function derivative(f, x)
    return Diffractor.PrimeDerivativeFwd(f)(x)
end
# NOTE(review): the right-hand side was dropped during extraction. It is
# reconstructed here as the Unicode `∇` callable defined earlier in this file
# (see the `(::Type{∇})(f, x1, args...)` method), which `jacobian`/`hessian`
# below rely on via `gradient` — TODO confirm against the upstream commit.
const gradient = ∇
# Jacobian of `f` at `x`: flatten each per-component gradient with `vec` and
# stack the resulting columns with `hcat`.
jacobian(f, x::AbstractArray) = reduce(hcat, map(vec, gradient(f, x)))
# Hessian as the Jacobian of the gradient map; `float` promotes integer
# inputs so differentiation happens over a floating-point domain.
hessian(f, x::AbstractArray) = jacobian(Base.Fix1(gradient, f), float(x))

0 comments on commit 11d8d0e

Please sign in to comment.