Skip to content

Commit

Permalink
Implement forward-mode jacobians
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Oct 22, 2024
1 parent 3ec9af2 commit c936d60
Show file tree
Hide file tree
Showing 10 changed files with 359 additions and 50 deletions.
10 changes: 5 additions & 5 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,13 +541,13 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
annotate("J")))
jacobian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by Jacobian*/, code);
}

/// Specialization for differentiating functors.
Expand All @@ -557,12 +557,12 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
annotate("J")))
jacobian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by Jacobian*/, code, f);
}

Expand Down
57 changes: 24 additions & 33 deletions include/clad/Differentiator/FunctionTraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define FUNCTION_TRAITS

#include "clad/Differentiator/ArrayRef.h"
#include "clad/Differentiator/Matrix.h"

#include <type_traits>

Expand Down Expand Up @@ -548,15 +549,25 @@ namespace clad {
using type = NoFunction*;
};

template <class... Args> struct SelectLast;
// OutputVecParamType is used to deduce the type of derivative arguments
// for vector forward mode.
template <class T, class R> struct OutputVecParamType {
using type = array_ref<typename std::remove_pointer<R>::type>;
};

template <class... Args>
using SelectLast_t = typename SelectLast<Args...>::type;
template <class T, class R>
using OutputVecParamType_t = typename OutputVecParamType<T, R>::type;

/// Specialization for vector forward mode type.
template <class F, class = void> struct ExtractDerivedFnTraitsVecForwMode {};

template <class T> struct SelectLast<T> { using type = T; };
template <class F>
using ExtractDerivedFnTraitsVecForwMode_t =
typename ExtractDerivedFnTraitsVecForwMode<F>::type;

template <class T, class... Args> struct SelectLast<T, Args...> {
using type = typename SelectLast<Args...>::type;
template <class ReturnType, class... Args>
struct ExtractDerivedFnTraitsVecForwMode<ReturnType (*)(Args...)> {
using type = void (*)(Args..., OutputVecParamType_t<Args, void>...);
};

template <class T, class = void> struct JacobianDerivedFnTraits {};
Expand All @@ -569,7 +580,7 @@ namespace clad {
// JacobianDerivedFnTraits specializations for pure function pointer types
template <class ReturnType, class... Args>
struct JacobianDerivedFnTraits<ReturnType (*)(Args...)> {
using type = void (*)(Args..., SelectLast_t<Args...>);
using type = void (*)(Args..., OutputVecParamType_t<Args, void>...);
};

/// These macro expansions are used to cover all possible cases of
Expand All @@ -581,11 +592,12 @@ namespace clad {
/// qualifier and reference respectively. The AddNOEX adds cases for noexcept
/// qualifier only if it is supported and finally AddSPECS declares the
/// function with all the cases
#define JacobianDerivedFnTraits_AddSPECS(var, cv, vol, ref, noex) \
template <typename R, typename C, typename... Args> \
struct JacobianDerivedFnTraits<R (C::*)(Args...) cv vol ref noex> { \
using type = void (C::*)(Args..., SelectLast_t<Args...>) cv vol ref noex; \
};
#define JacobianDerivedFnTraits_AddSPECS(var, cv, vol, ref, noex) \
template <typename R, typename C, typename... Args> \
struct JacobianDerivedFnTraits<R (C::*)(Args...) cv vol ref noex> { \
using type = void (C::*)( \
Args..., OutputVecParamType_t<Args, void>...) cv vol ref noex; \
};

#if __cpp_noexcept_function_type > 0
#define JacobianDerivedFnTraits_AddNOEX(var, con, vol, ref) \
Expand Down Expand Up @@ -739,27 +751,6 @@ namespace clad {
using ExtractDerivedFnTraitsForwMode_t =
typename ExtractDerivedFnTraitsForwMode<F>::type;

// OutputVecParamType is used to deduce the type of derivative arguments
// for vector forward mode.
template <class T, class R> struct OutputVecParamType {
using type = array_ref<typename std::remove_pointer<R>::type>;
};

template <class T, class R>
using OutputVecParamType_t = typename OutputVecParamType<T, R>::type;

/// Specialization for vector forward mode type.
template <class F, class = void> struct ExtractDerivedFnTraitsVecForwMode {};

template <class F>
using ExtractDerivedFnTraitsVecForwMode_t =
typename ExtractDerivedFnTraitsVecForwMode<F>::type;

template <class ReturnType, class... Args>
struct ExtractDerivedFnTraitsVecForwMode<ReturnType (*)(Args...)> {
using type = void (*)(Args..., OutputVecParamType_t<Args, void>...);
};

/// Specialization for free function pointer type
template <class F>
struct ExtractDerivedFnTraitsForwMode<
Expand Down
20 changes: 20 additions & 0 deletions include/clad/Differentiator/JacobianModeVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef CLAD_DIFFERENTIATOR_JACOBIANMODEVISITOR_H
#define CLAD_DIFFERENTIATOR_JACOBIANMODEVISITOR_H

#include "VectorPushForwardModeVisitor.h"

namespace clad {
class JacobianModeVisitor : public VectorPushForwardModeVisitor {

public:
JacobianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);

DerivativeAndOverload DeriveJacobian();

clang::QualType getParamAdjointType(clang::QualType T);

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // end namespace clad

#endif // CLAD_DIFFERENTIATOR_JACOBIANMODEVISITOR_H
2 changes: 1 addition & 1 deletion include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace clad {
/// A visitor for processing the function code in vector forward mode.
/// Used to compute derivatives by clad::vector_forward_differentiate.
class VectorForwardModeVisitor : public BaseForwardModeVisitor {
private:
protected:
llvm::SmallVector<const clang::ValueDecl*, 16> m_IndependentVars;
/// Map used to keep track of parameter variables w.r.t which the
/// the derivative is being computed. This is separate from the
Expand Down
6 changes: 5 additions & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() {
assert(m_DiffReq.Mode == GetPushForwardMode());
QualType originalFnRT = m_DiffReq->getReturnType();
if (m_DiffReq.Mode == DiffMode::jacobian)
return GetPushForwardDerivativeType(originalFnRT);
if (originalFnRT->isVoidType())
return m_Context.VoidTy;
TemplateDecl* valueAndPushforward =
Expand Down Expand Up @@ -1446,7 +1448,9 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
derivedR = BuildParens(derivedR);
opDiff = BuildOp(opCode, derivedL, derivedR);
} else if (BinOp->isAssignmentOp()) {
if (Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
if ((Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) !=
Expr::MLV_Valid) &&
!isCladArrayType(Ldiff.getExpr_dx()->getType())) {
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ llvm_add_library(cladDifferentiator
DiffPlanner.cpp
ErrorEstimator.cpp
EstimationModel.cpp
JacobianModeVisitor.cpp
HessianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
PushForwardModeVisitor.cpp
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/PushForwardModeVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
Expand Down Expand Up @@ -429,8 +430,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
HessianModeVisitor H(*this, request);
result = H.Derive();
} else if (request.Mode == DiffMode::jacobian) {
ReverseModeVisitor R(*this, request);
result = R.Derive();
JacobianModeVisitor J(*this, request);
result = J.DeriveJacobian();
} else if (request.Mode == DiffMode::error_estimation) {
ReverseModeVisitor R(*this, request);
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
Expand Down
9 changes: 1 addition & 8 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,21 +563,14 @@ namespace clad {

// If the function has no parameters, then we cannot differentiate it."
// and if the DiffMode is Jacobian, we must have atleast 2 parameters.
if (params.empty() || (params.size()==1 && this->Mode == DiffMode::jacobian)) {
if (params.empty()) {
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
CallContext->getEndLoc(),
"Attempted to differentiate a function without "
"parameters");
return;
}

// If it is a Vector valued function, the last parameter is to store the
// output vector and hence is not a differentiable parameter, so we must
// pop it out
if (this->Mode == DiffMode::jacobian){
params.pop_back();
}

// insert an empty index for each parameter.
for (unsigned i=0; i<params.size(); ++i) {
DiffInputVarInfo dVarInfo(params[i], IndexInterval());
Expand Down
Loading

0 comments on commit c936d60

Please sign in to comment.