Skip to content

Commit

Permalink
Add support for diff of ref return types in rev mode
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 authored and vgvassilev committed Aug 19, 2023
1 parent 24a15f9 commit 771d709
Show file tree
Hide file tree
Showing 18 changed files with 655 additions and 102 deletions.
3 changes: 2 additions & 1 deletion demos/ComputerGraphics/smallpt/SmallPT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
// ./SmallPT 500 && xv image.ppm

// A typical invocation would be:
// ../../../../../obj/Debug+Asserts/bin/clang++ -O3 -Xclang -add-plugin -Xclang clad \
// ../../../../../obj/Debug+Asserts/bin/clang++ -O3 -Xclang -add-plugin -Xclang
// clad \
// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \
// -I../../include/ -std=c++11 SmallPT.cpp -fopenmp=libiomp5 -o SmallPT
// ./SmallPT 500 && xv image.ppm
Expand Down
8 changes: 5 additions & 3 deletions demos/OpenCL/RosenbrockFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

// To run the demo please type:
// path/to/clang++ -Xclang -add-plugin -Xclang clad -Xclang -load -Xclang \
// path/to/libclad.so -I../include/ -framework opencl -std=c++11 RosenbrockFunction.cpp
// path/to/libclad.so -I../include/ -framework opencl -std=c++11
// RosenbrockFunction.cpp
//
// A typical invocation would be:
// ../../../../../obj/Debug+Asserts/bin/clang++ -Xclang -add-plugin -Xclang clad \
// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \
// ../../../../../obj/Debug+Asserts/bin/clang++ -Xclang -add-plugin -Xclang
// clad \
// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \
// -I../../include/ -framework opencl -std=c++11 RosenbrockFunction.cpp

// Necessary for clad to work include
Expand Down
32 changes: 17 additions & 15 deletions include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ static inline NamespaceDecl*
NamespaceDecl_Create(ASTContext& C, DeclContext* DC, bool Inline,
SourceLocation StartLoc, SourceLocation IdLoc,
IdentifierInfo* Id, NamespaceDecl* PrevDecl) {
return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl);
return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl);
}
#else
static inline NamespaceDecl*
NamespaceDecl_Create(ASTContext& C, DeclContext* DC, bool Inline,
SourceLocation StartLoc, SourceLocation IdLoc,
IdentifierInfo* Id, NamespaceDecl* PrevDecl) {
return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl,
/*Nested=*/false);
return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl,
/*Nested=*/false);
}
#endif

Expand Down Expand Up @@ -249,17 +249,19 @@ static inline void ExprSetDeps(Expr* result, Expr* Node) {
#define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsPar ,clang::CallExpr::ADLCallKind UsesADL
#define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsUse ,UsesADL
#define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride FPOptionsOverride
#if CLANG_VERSION_MAJOR >= 16
#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams /**/
#else
#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams Ctx.getLangOpts()
#endif
#define CLAD_COMPAT_CLANG11_Ctx_ExtraParams Ctx,
#define CLAD_COMPAT_CREATE11(CLASS, CTORARGS) (CLASS::Create CTORARGS)
#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed /**/
#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved ,Node->getComputationLHSType(),Node->getComputationResultType()
#define CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed /**/
#define CLAD_COMPAT_CLANG11_WhileStmt_ExtraParams ,Node->getLParenLoc(),Node->getRParenLoc()
#if CLANG_VERSION_MAJOR >= 16
#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams /**/
#else
#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams Ctx.getLangOpts()
#endif
#define CLAD_COMPAT_CLANG11_Ctx_ExtraParams Ctx,
#define CLAD_COMPAT_CREATE11(CLASS, CTORARGS) (CLASS::Create CTORARGS)
#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed /**/
#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved \
, Node->getComputationLHSType(), Node->getComputationResultType()
#define CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed /**/
#define CLAD_COMPAT_CLANG11_WhileStmt_ExtraParams \
, Node->getLParenLoc(), Node->getRParenLoc()
#endif

// Compatibility helper function for creation CXXOperatorCallExpr. Clang 8 and above use Create.
Expand Down Expand Up @@ -725,7 +727,7 @@ ArraySize_GetValue(const llvm::Optional<const Expr*>& opt) {
#else
static inline const Expr*
ArraySize_GetValue(const std::optional<const Expr*>& opt) {
return opt.value();
return opt.value();
}
#endif

Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace clad {
friend class ReverseModeVisitor;
friend class HessianModeVisitor;
friend class JacobianModeVisitor;

friend class ReverseModeForwPassVisitor;
clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enum class DiffMode {
reverse,
hessian,
jacobian,
reverse_mode_forward_pass,
error_estimation
};
}
Expand Down
31 changes: 18 additions & 13 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
#include <cstring>

namespace clad {
template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
};

/// \returns the size of a c-style string
CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
unsigned int count;
Expand Down Expand Up @@ -299,9 +304,9 @@ namespace clad {
differentiate(F fn, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code);
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code);
}

/// Specialization for differentiating functors.
Expand All @@ -319,8 +324,8 @@ namespace clad {
differentiate(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code, f);
}

/// Generates function which computes derivative of `fn` argument w.r.t
Expand All @@ -342,9 +347,9 @@ namespace clad {
differentiate(F fn, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn, code);
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn, code);
}

/// Generates function which computes gradient of the given function wrt the
Expand All @@ -364,9 +369,9 @@ namespace clad {
gradient(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code);
}

/// Specialization for differentiating functors.
Expand All @@ -382,8 +387,8 @@ namespace clad {
gradient(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code, f);
}

/// Generates function which computes hessian matrix of the given function wrt
Expand Down
38 changes: 38 additions & 0 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#ifndef CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H
#define CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H

#include "clad/Differentiator/ParseDiffArgsTypes.h"
#include "clad/Differentiator/ReverseModeVisitor.h"

#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/SmallVector.h"

namespace clad {
class ReverseModeForwPassVisitor : public ReverseModeVisitor {
private:
Stmts m_Globals;

llvm::SmallVector<clang::QualType, 8>
ComputeParamTypes(const DiffParams& diffParams);
clang::QualType ComputeReturnType();
llvm::SmallVector<clang::ParmVarDecl*, 8> BuildParams(DiffParams& diffParams);
clang::QualType GetParameterDerivativeType(clang::QualType yType,
clang::QualType xType);

public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

StmtDiff ProcessSingleStmt(const clang::Stmt* S);

StmtDiff VisitStmt(const clang::Stmt* S) override;
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // namespace clad

#endif
14 changes: 7 additions & 7 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ namespace clad {
class ReverseModeVisitor
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>,
public VisitorBase {
private:

protected:
// FIXME: We should remove friend-dependency of the plugin classes here.
// For this we will need to separate out AST related functions in
// a separate namespace, as well as add getters/setters function of
Expand Down Expand Up @@ -292,7 +292,7 @@ namespace clad {

public:
ReverseModeVisitor(DerivativeBuilder& builder);
~ReverseModeVisitor();
virtual ~ReverseModeVisitor();

///\brief Produces the gradient of a given function.
///
Expand Down Expand Up @@ -321,11 +321,11 @@ namespace clad {
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL);
StmtDiff VisitForStmt(const clang::ForStmt* FS);
Expand All @@ -335,8 +335,8 @@ namespace clad {
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
StmtDiff VisitMemberExpr(const clang::MemberExpr* ME);
StmtDiff VisitParenExpr(const clang::ParenExpr* PE);
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
StmtDiff VisitStmt(const clang::Stmt* S);
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
virtual StmtDiff VisitStmt(const clang::Stmt* S);
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
/// Decl is not Stmt, so it cannot be visited directly.
Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_llvm_library(cladDifferentiator
HessianModeVisitor.cpp
JacobianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
Expand Down
4 changes: 4 additions & 0 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "clad/Differentiator/ForwardModeVisitor.h"
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/VectorForwardModeVisitor.h"
Expand Down Expand Up @@ -230,6 +231,9 @@ namespace clad {
result = V.DerivePullback(FD, request);
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::hessian) {
HessianModeVisitor H(*this);
result = H.Derive(FD, request);
Expand Down
16 changes: 8 additions & 8 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,20 +390,20 @@ namespace clad {
// The string is not a range just a single index
size_t index;
if (firstStr.getAsInteger(Radix, index)) {
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
diffArgs->getEndLoc(),
"Could not parse index '%0'", {diffSpec});
return;
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
diffArgs->getEndLoc(),
"Could not parse index '%0'", {diffSpec});
return;
}
dVarInfo.paramIndexInterval = IndexInterval(index);
} else {
size_t first, last;
if (firstStr.getAsInteger(Radix, first) ||
lastStr.getAsInteger(Radix, last)) {
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
diffArgs->getEndLoc(),
"Could not parse range '%0'", {diffSpec});
return;
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
diffArgs->getEndLoc(),
"Could not parse range '%0'", {diffSpec});
return;
}
if (first >= last) {
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,4 @@ namespace clad {
return DerivativeAndOverload{result.first,
/*OverloadFunctionDecl=*/nullptr};
}
} // end namespace clad
} // end namespace clad
Loading

0 comments on commit 771d709

Please sign in to comment.