Skip to content

Commit

Permalink
Generalize custom derivative templates to be compatible with the vect…
Browse files Browse the repository at this point in the history
…orized frw mode and jacobians
  • Loading branch information
PetroZarytskyi committed Oct 22, 2024
1 parent c936d60 commit c479267
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 20 deletions.
59 changes: 40 additions & 19 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,31 +150,31 @@ CUDA_HOST_DEVICE inline void __builtin_powf_pullback(float x, float exponent,
// FIXME: Add the rest of the __builtin_ routines for log, sqrt, abs, etc.

namespace std {
template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> abs_pushforward(T x, T d_x) {
template <typename T, typename dT>
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> abs_pushforward(T x, dT d_x) {
if (x >= 0)
return {x, d_x};
else
return {-x, -d_x};
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> exp_pushforward(T x, T d_x) {
template <typename T, typename dT>
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> exp_pushforward(T x, dT d_x) {
return {::std::exp(x), ::std::exp(x) * d_x};
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> sin_pushforward(T x, T d_x) {
template <typename T, typename dT>
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> sin_pushforward(T x, dT d_x) {
return {::std::sin(x), ::std::cos(x) * d_x};
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> cos_pushforward(T x, T d_x) {
template <typename T, typename dT>
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> cos_pushforward(T x, dT d_x) {
return {::std::cos(x), (-1) * ::std::sin(x) * d_x};
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> sqrt_pushforward(T x, T d_x) {
template <typename T, typename dT>
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> sqrt_pushforward(T x, dT d_x) {
return {::std::sqrt(x), (((T)1) / (((T)2) * ::std::sqrt(x))) * d_x};
}

Expand All @@ -183,9 +183,9 @@ CUDA_HOST_DEVICE ValueAndPushforward<T, T> floor_pushforward(T x, T /*d_x*/) {
return {::std::floor(x), (T)0};
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> atan2_pushforward(T y, T x, T d_y,
T d_x) {
template <typename T, typename dT>
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> atan2_pushforward(T y, T x, dT d_y,
dT d_x) {
return {::std::atan2(y, x),
-(y / ((x * x) + (y * y))) * d_x + x / ((x * x) + (y * y)) * d_y};
}
Expand All @@ -197,8 +197,8 @@ CUDA_HOST_DEVICE void atan2_pullback(T y, T x, U d_z, T* d_y, T* d_x) {
*d_x += -(y / ((x * x) + (y * y))) * d_z;
}

template <typename T>
CUDA_HOST_DEVICE ValueAndPushforward<T, T> acos_pushforward(T x, T d_x) {
template <typename T, typename dT>
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> acos_pushforward(T x, dT d_x) {
return {::std::acos(x), ((-1) / (::std::sqrt(1 - x * x))) * d_x};
}

Expand All @@ -221,10 +221,31 @@ pow_pushforward(T1 x, T2 exponent, T1 d_x, T2 d_exponent) {
auto val = ::std::pow(x, exponent);
auto derivative = (exponent * ::std::pow(x, exponent - 1)) * d_x;
// Only add directional derivative of base^exp w.r.t exp if the directional
// seed d_exponent is non-zero. This is required because if base is less than or
// equal to 0, then log(base) is undefined, and therefore if user only requested
// directional derivative of base^exp w.r.t base -- which is valid --, the result would
// be undefined because as per C++ valid number + NaN * 0 = NaN.
// seed d_exponent is non-zero. This is required because if base is less than
// or equal to 0, then log(base) is undefined, and therefore if user only
// requested directional derivative of base^exp w.r.t base -- which is valid
// --, the result would be undefined because as per C++ valid number + NaN * 0
// = NaN.
if (d_exponent)
derivative += (::std::pow(x, exponent) * ::std::log(x)) * d_exponent;
return {val, derivative};
}

template <typename T1, typename T2>
CUDA_HOST_DEVICE
ValueAndPushforward<decltype(::std::pow(T1(), T2())),
clad::array<decltype(::std::pow(T1(), T2()))>>
pow_pushforward(T1 x, T2 exponent, clad::array<T1> d_x,
clad::array<T2> d_exponent) {
decltype(::std::pow(T1(), T2())) val = ::std::pow(x, exponent);
clad::array<decltype(::std::pow(T1(), T2()))> derivative =
(exponent * ::std::pow(x, exponent - 1)) * d_x;
// Only add directional derivative of base^exp w.r.t exp if the directional
// seed d_exponent is non-zero. This is required because if base is less than
// or equal to 0, then log(base) is undefined, and therefore if user only
// requested directional derivative of base^exp w.r.t base -- which is valid
// --, the result would be undefined because as per C++ valid number + NaN * 0
// = NaN.
if (d_exponent)
derivative += (::std::pow(x, exponent) * ::std::log(x)) * d_exponent;
return {val, derivative};
Expand Down
11 changes: 10 additions & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,16 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));

// Custom derivative templates can be written in a
// general way that works for both vectorized and non-vectorized
// modes. We have to also look for the pushforward with the regular name.
if (!callDiff && m_DiffReq.Mode != DiffMode::forward) {
customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + "_pushforward";
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
}
if (!isLambda) {
// Check if it is a recursive call.
if (!callDiff && (FD == m_DiffReq.Function) &&
Expand Down

0 comments on commit c479267

Please sign in to comment.