Skip to content

Commit

Permalink
Add support for diff of ref return types in rev mode
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 authored and phrygiangates committed Aug 19, 2023
1 parent 24a15f9 commit d89426c
Show file tree
Hide file tree
Showing 11 changed files with 544 additions and 22 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace clad {
friend class ReverseModeVisitor;
friend class HessianModeVisitor;
friend class JacobianModeVisitor;

friend class ReverseModeForwPassVisitor;
clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enum class DiffMode {
reverse,
hessian,
jacobian,
reverse_mode_forward_pass,
error_estimation
};
}
Expand Down
6 changes: 6 additions & 0 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
#include <cstring>

namespace clad {
template<typename T, typename U>
struct ValueAndAdjoint {
T value;
U adjoint;
};

/// \returns the size of a c-style string
CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
unsigned int count;
Expand Down
38 changes: 38 additions & 0 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#ifndef CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H
#define CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H

#include "clad/Differentiator/ParseDiffArgsTypes.h"
#include "clad/Differentiator/ReverseModeVisitor.h"

#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/SmallVector.h"

namespace clad {
class ReverseModeForwPassVisitor : public ReverseModeVisitor {
private:
Stmts m_Globals;

llvm::SmallVector<clang::QualType, 8>
ComputeParamTypes(const DiffParams& diffParams);
clang::QualType ComputeReturnType();
llvm::SmallVector<clang::ParmVarDecl*, 8> BuildParams(DiffParams& diffParams);
clang::QualType GetParameterDerivativeType(clang::QualType yType,
clang::QualType xType);

public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

StmtDiff ProcessSingleStmt(const clang::Stmt* S);

StmtDiff VisitStmt(const clang::Stmt* S) override;
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // namespace clad

#endif
10 changes: 5 additions & 5 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace clad {
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>,
public VisitorBase {

private:
protected:
// FIXME: We should remove friend-dependency of the plugin classes here.
// For this we will need to separate out AST related functions in
// a separate namespace, as well as add getters/setters function of
Expand Down Expand Up @@ -321,11 +321,11 @@ namespace clad {
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL);
StmtDiff VisitForStmt(const clang::ForStmt* FS);
Expand All @@ -335,8 +335,8 @@ namespace clad {
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
StmtDiff VisitMemberExpr(const clang::MemberExpr* ME);
StmtDiff VisitParenExpr(const clang::ParenExpr* PE);
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
StmtDiff VisitStmt(const clang::Stmt* S);
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
virtual StmtDiff VisitStmt(const clang::Stmt* S);
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
/// Decl is not Stmt, so it cannot be visited directly.
Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_llvm_library(cladDifferentiator
HessianModeVisitor.cpp
JacobianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
Expand Down
5 changes: 5 additions & 0 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/VectorForwardModeVisitor.h"

Expand Down Expand Up @@ -230,6 +232,9 @@ namespace clad {
result = V.DerivePullback(FD, request);
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::hessian) {
HessianModeVisitor H(*this);
result = H.Derive(FD, request);
Expand Down
Loading

0 comments on commit d89426c

Please sign in to comment.