Skip to content

Commit

Permalink
Implement automatic preparation with eval macro (#409)
Browse files Browse the repository at this point in the history
* Implement automatic preparation with eval macro

* Typos
  • Loading branch information
gdalle authored Aug 12, 2024
1 parent ad1328f commit e56be86
Show file tree
Hide file tree
Showing 13 changed files with 103 additions and 396 deletions.
2 changes: 2 additions & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ include("second_order/hvp.jl")
include("second_order/hvp_batched.jl")
include("second_order/hessian.jl")

include("fallbacks/no_extras.jl")

include("sparse/fallbacks.jl")
include("sparse/matrices.jl")
include("sparse/jacobian.jl")
Expand Down
101 changes: 101 additions & 0 deletions DifferentiationInterface/src/fallbacks/no_extras.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
for op in (:derivative, :gradient, :jacobian)
op! = Symbol(op, "!")
val_prefix = "value_and_"
val_and_op = Symbol(val_prefix, op)
val_and_op! = Symbol(val_prefix, op!)
prep_op = Symbol("prepare_", op)
# 1-arg
@eval function $op(f::F, backend::AbstractADType, x) where {F}
return $op(f, backend, x, $prep_op(f, backend, x))
end
@eval function $op!(f::F, result, backend::AbstractADType, x) where {F}
return $op!(f, result, backend, x, $prep_op(f, backend, x))
end
@eval function $val_and_op(f::F, backend::AbstractADType, x) where {F}
return $val_and_op(f, backend, x, $prep_op(f, backend, x))
end
@eval function $val_and_op!(f::F, result, backend::AbstractADType, x) where {F}
return $val_and_op!(f, result, backend, x, $prep_op(f, backend, x))
end
op == :gradient && continue
# 2-arg
@eval function $op(f!::F, y, backend::AbstractADType, x) where {F}
return $op(f!, y, backend, x, $prep_op(f!, y, backend, x))
end
@eval function $op!(f!::F, y, result, backend::AbstractADType, x) where {F}
return $op!(f!, y, result, backend, x, $prep_op(f!, y, backend, x))
end
@eval function $val_and_op(f!::F, y, backend::AbstractADType, x) where {F}
return $val_and_op(f!, y, backend, x, $prep_op(f!, y, backend, x))
end
@eval function $val_and_op!(f!::F, y, result, backend::AbstractADType, x) where {F}
return $val_and_op!(f!, y, result, backend, x, $prep_op(f!, y, backend, x))
end
end

for op in (:second_derivative, :hessian)
op! = Symbol(op, "!")
val_prefix = if op == :second_derivative
"value_derivative_and_"
elseif op == :hessian
"value_gradient_and_"
end
val_and_op = Symbol(val_prefix, op)
val_and_op! = Symbol(val_prefix, op!)
prep_op = Symbol("prepare_", op)
# 1-arg
@eval function $op(f::F, backend::AbstractADType, x) where {F}
return $op(f, backend, x, $prep_op(f, backend, x))
end
@eval function $op!(f::F, result2, backend::AbstractADType, x) where {F}
return $op!(f, result2, backend, x, $prep_op(f, backend, x))
end
@eval function $val_and_op(f::F, backend::AbstractADType, x) where {F}
return $val_and_op(f, backend, x, $prep_op(f, backend, x))
end
@eval function $val_and_op!(
f::F, result1, result2, backend::AbstractADType, x
) where {F}
return $val_and_op!(f, result1, result2, backend, x, $prep_op(f, backend, x))
end
end

for op in
(:pushforward, :pushforward_batched, :pullback, :pullback_batched, :hvp, :hvp_batched)
op! = Symbol(op, "!")
val_prefix = "value_and_"
val_and_op = Symbol(val_prefix, op)
val_and_op! = Symbol(val_prefix, op!)
prep_op = Symbol("prepare_", op)
# 1-arg
@eval function $op(f::F, backend::AbstractADType, x, seed) where {F}
return $op(f, backend, x, seed, $prep_op(f, backend, x, seed))
end
@eval function $op!(f::F, result, backend::AbstractADType, x, seed) where {F}
return $op!(f, result, backend, x, seed, $prep_op(f, backend, x, seed))
end
op == :hvp && continue
# 2-arg
@eval function $val_and_op(f::F, backend::AbstractADType, x, seed) where {F}
return $val_and_op(f, backend, x, seed, $prep_op(f, backend, x, seed))
end
@eval function $val_and_op!(f::F, result, backend::AbstractADType, x, seed) where {F}
return $val_and_op!(f, result, backend, x, seed, $prep_op(f, backend, x, seed))
end
@eval function $op(f!::F, y, backend::AbstractADType, x, seed) where {F}
return $op(f!, y, backend, x, seed, $prep_op(f!, y, backend, x, seed))
end
@eval function $op!(f!::F, y, result, backend::AbstractADType, x, seed) where {F}
return $op!(f!, y, result, backend, x, seed, $prep_op(f!, y, backend, x, seed))
end
@eval function $val_and_op(f!::F, y, backend::AbstractADType, x, seed) where {F}
return $val_and_op(f!, y, backend, x, seed, $prep_op(f!, y, backend, x, seed))
end
@eval function $val_and_op!(
f!::F, y, result, backend::AbstractADType, x, seed
) where {F}
return $val_and_op!(
f!, y, result, backend, x, seed, $prep_op(f!, y, backend, x, seed)
)
end
end
42 changes: 0 additions & 42 deletions DifferentiationInterface/src/first_order/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,6 @@ end

## One argument

### Without extras

function value_and_derivative(f::F, backend::AbstractADType, x) where {F}
return value_and_derivative(f, backend, x, prepare_derivative(f, backend, x))
end

function value_and_derivative!(f::F, der, backend::AbstractADType, x) where {F}
return value_and_derivative!(f, der, backend, x, prepare_derivative(f, backend, x))
end

function derivative(f::F, backend::AbstractADType, x) where {F}
return derivative(f, backend, x, prepare_derivative(f, backend, x))
end

function derivative!(f::F, der, backend::AbstractADType, x) where {F}
return derivative!(f, der, backend, x, prepare_derivative(f, backend, x))
end

### With extras

function value_and_derivative(
f::F, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
) where {F}
Expand All @@ -126,28 +106,6 @@ end

## Two arguments

### Without extras

function value_and_derivative(f!::F, y, backend::AbstractADType, x) where {F}
return value_and_derivative(f!, y, backend, x, prepare_derivative(f!, y, backend, x))
end

function value_and_derivative!(f!::F, y, der, backend::AbstractADType, x) where {F}
return value_and_derivative!(
f!, y, der, backend, x, prepare_derivative(f!, y, backend, x)
)
end

function derivative(f!::F, y, backend::AbstractADType, x) where {F}
return derivative(f!, y, backend, x, prepare_derivative(f!, y, backend, x))
end

function derivative!(f!::F, y, der, backend::AbstractADType, x) where {F}
return derivative!(f!, y, der, backend, x, prepare_derivative(f!, y, backend, x))
end

### With extras

function value_and_derivative(
f!::F, y, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
) where {F}
Expand Down
20 changes: 0 additions & 20 deletions DifferentiationInterface/src/first_order/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,6 @@ end

## One argument

### Without extras

function value_and_gradient(f::F, backend::AbstractADType, x) where {F}
return value_and_gradient(f, backend, x, prepare_gradient(f, backend, x))
end

function value_and_gradient!(f::F, der, backend::AbstractADType, x) where {F}
return value_and_gradient!(f, der, backend, x, prepare_gradient(f, backend, x))
end

function gradient(f::F, backend::AbstractADType, x) where {F}
return gradient(f, backend, x, prepare_gradient(f, backend, x))
end

function gradient!(f::F, der, backend::AbstractADType, x) where {F}
return gradient!(f, der, backend, x, prepare_gradient(f, backend, x))
end

### With extras

function value_and_gradient(
f::F, backend::AbstractADType, x, extras::PullbackGradientExtras
) where {F}
Expand Down
40 changes: 0 additions & 40 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,26 +134,6 @@ end

## One argument

### Without extras

function jacobian(f::F, backend::AbstractADType, x) where {F}
return jacobian(f, backend, x, prepare_jacobian(f, backend, x))
end

function jacobian!(f::F, jac, backend::AbstractADType, x) where {F}
return jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x))
end

function value_and_jacobian(f::F, backend::AbstractADType, x) where {F}
return value_and_jacobian(f, backend, x, prepare_jacobian(f, backend, x))
end

function value_and_jacobian!(f::F, jac, backend::AbstractADType, x) where {F}
return value_and_jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x))
end

### With extras

function jacobian(f::F, backend::AbstractADType, x, extras::JacobianExtras) where {F}
return jacobian_aux((f,), backend, x, extras)
end
Expand All @@ -176,26 +156,6 @@ end

## Two arguments

### Without extras

function jacobian(f!::F, y, backend::AbstractADType, x) where {F}
return jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x))
end

function jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F}
return jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x))
end

function value_and_jacobian(f!::F, y, backend::AbstractADType, x) where {F}
return value_and_jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x))
end

function value_and_jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F}
return value_and_jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x))
end

### With extras

function jacobian(f!::F, y, backend::AbstractADType, x, extras::JacobianExtras) where {F}
return jacobian_aux((f!, y), backend, x, extras)
end
Expand Down
44 changes: 0 additions & 44 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,26 +160,6 @@ end

## One argument

### Without extras

function value_and_pullback(f::F, backend::AbstractADType, x, dy) where {F}
return value_and_pullback(f, backend, x, dy, prepare_pullback(f, backend, x, dy))
end

function value_and_pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F}
return value_and_pullback!(f, dx, backend, x, dy, prepare_pullback(f, backend, x, dy))
end

function pullback(f::F, backend::AbstractADType, x, dy) where {F}
return pullback(f, backend, x, dy, prepare_pullback(f, backend, x, dy))
end

function pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F}
return pullback!(f, dx, backend, x, dy, prepare_pullback(f, backend, x, dy))
end

### With extras

function value_and_pullback(
f::F, backend::AbstractADType, x, dy, extras::PushforwardPullbackExtras
) where {F}
Expand Down Expand Up @@ -220,30 +200,6 @@ end

## Two arguments

### Without extras

function value_and_pullback(f!::F, y, backend::AbstractADType, x, dy) where {F}
return value_and_pullback(
f!, y, backend, x, dy, prepare_pullback(f!, y, backend, x, dy)
)
end

function value_and_pullback!(f!::F, y, dx, backend::AbstractADType, x, dy) where {F}
return value_and_pullback!(
f!, y, dx, backend, x, dy, prepare_pullback(f!, y, backend, x, dy)
)
end

function pullback(f!::F, y, backend::AbstractADType, x, dy) where {F}
return pullback(f!, y, backend, x, dy, prepare_pullback(f!, y, backend, x, dy))
end

function pullback!(f!::F, y, dx, backend::AbstractADType, x, dy) where {F}
return pullback!(f!, y, dx, backend, x, dy, prepare_pullback(f!, y, backend, x, dy))
end

### With extras

function value_and_pullback(
f!::F, y, backend::AbstractADType, x, dy, extras::PushforwardPullbackExtras
) where {F}
Expand Down
62 changes: 0 additions & 62 deletions DifferentiationInterface/src/first_order/pullback_batched.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,34 +50,6 @@ end

## One argument

### Without extras

function value_and_pullback_batched(f::F, backend::AbstractADType, x, dy::Batch) where {F}
return value_and_pullback_batched(
f, backend, x, dy, prepare_pullback_batched(f, backend, x, dy)
)
end

function value_and_pullback_batched!(
f::F, dx::Batch, backend::AbstractADType, x, dy::Batch
) where {F}
return value_and_pullback_batched!(
f, dx, backend, x, dy, prepare_pullback_batched(f, backend, x, dy)
)
end

function pullback_batched(f::F, backend::AbstractADType, x, dy::Batch) where {F}
return pullback_batched(f, backend, x, dy, prepare_pullback_batched(f, backend, x, dy))
end

function pullback_batched!(f::F, dx::Batch, backend::AbstractADType, x, dy::Batch) where {F}
return pullback_batched!(
f, dx, backend, x, dy, prepare_pullback_batched(f, backend, x, dy)
)
end

### With extras

function pullback_batched(
f::F, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
) where {F}
Expand Down Expand Up @@ -108,40 +80,6 @@ end

## Two arguments

### Without extras

function value_and_pullback_batched(
f!::F, y, backend::AbstractADType, x, dy::Batch
) where {F}
return value_and_pullback_batched(
f!, y, backend, x, dy, prepare_pullback_batched(f!, y, backend, x, dy)
)
end

function value_and_pullback_batched!(
f!::F, y, dx::Batch, backend::AbstractADType, x, dy::Batch
) where {F}
return value_and_pullback_batched!(
f!, y, dx, backend, x, dy, prepare_pullback_batched(f!, y, backend, x, dy)
)
end

function pullback_batched(f!::F, y, backend::AbstractADType, x, dy::Batch) where {F}
return pullback_batched(
f!, y, backend, x, dy, prepare_pullback_batched(f!, y, backend, x, dy)
)
end

function pullback_batched!(
f!::F, y, dx::Batch, backend::AbstractADType, x, dy::Batch
) where {F}
return pullback_batched!(
f!, y, dx, backend, x, dy, prepare_pullback_batched(f!, y, backend, x, dy)
)
end

### With extras

function pullback_batched(
f!::F, y, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
) where {F}
Expand Down
Loading

0 comments on commit e56be86

Please sign in to comment.