Skip to content

Commit

Permalink
Introduce To-Be-Recorded Analysis in Clad.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 1, 2023
1 parent 8647f13 commit 567d0f9
Show file tree
Hide file tree
Showing 70 changed files with 4,252 additions and 3,787 deletions.
23 changes: 17 additions & 6 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace clad {
std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD);

/// Creates and returns a compound statement having statements as follows:
/// {`S`, all the statement of `initial` in sequence}
/// {`S`, all the statement of `initial` in sequence}
clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);
Expand All @@ -38,7 +38,7 @@ namespace clad {
clang::CompoundStmt* AppendAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);

/// Shorthand to issues a warning or error.
template <std::size_t N>
void EmitDiag(clang::Sema& semaRef,
Expand Down Expand Up @@ -126,8 +126,8 @@ namespace clad {
///
/// \param S
/// \param namespc
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param DC
clang::NamespaceDecl* LookupNSD(clang::Sema& S, llvm::StringRef namespc,
bool shouldExist,
Expand Down Expand Up @@ -162,7 +162,10 @@ namespace clad {
llvm::StringRef str);

/// Returns true if `QT` is Array or Pointer Type, otherwise returns false.
bool isArrayOrPointerType(const clang::QualType QT);
bool isArrayOrPointerType(clang::QualType QT);

/// Returns true if `T` is auto or auto* type, otherwise returns false.
bool IsAutoOrAutoPtrType(clang::QualType T);

clang::DeclarationNameInfo BuildDeclarationNameInfo(clang::Sema& S,
llvm::StringRef name);
Expand Down Expand Up @@ -234,7 +237,7 @@ namespace clad {

bool IsCladValueAndPushforwardType(clang::QualType T);

/// Returns a valid `SourceRange` to be used in places where clang
/// Returns a valid `SourceRange` to be used in places where clang
/// requires a valid `SourceRange`.
clang::SourceRange GetValidSRange(clang::Sema& semaRef);

Expand Down Expand Up @@ -313,6 +316,14 @@ namespace clad {
bool hasNonDifferentiableAttribute(const clang::Decl* D);

bool hasNonDifferentiableAttribute(const clang::Expr* E);

/// Collects every DeclRefExpr, MemberExpr, ArraySubscriptExpr in an
/// assignment operator or a ternary if operator. This is useful to when we
/// need to decide what needs to be stored on tape in reverse mode.
void GetInnermostReturnExpr(const clang::Expr* E,
llvm::SmallVectorImpl<clang::Expr*>& Exprs);

bool ContainsFunctionCalls(const clang::Stmt* E);
} // namespace utils
}

Expand Down
14 changes: 14 additions & 0 deletions include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ static inline bool Expr_EvaluateAsInt(const Expr *E,
#endif
}

// Clang 12: bool Expr::EvaluateAsConstantExpr(EvalResult &Result,
// ConstExprUsage Usage, ASTContext &)
// => bool Expr::EvaluateAsConstantExpr(EvalResult &Result, ASTContext &)

static inline bool Expr_EvaluateAsConstantExpr(const Expr* E,
Expr::EvalResult& res,
const ASTContext& Ctx) {
#if CLANG_VERSION_MAJOR < 12
return E->EvaluateAsConstantExpr(res, Expr::EvaluateForCodeGen, Ctx);
#else
return E->EvaluateAsConstantExpr(res, Ctx);
#endif
}

// Compatibility helper function for creation IfStmt.
// Clang 8 and above use Create.
// Clang 12 and above use two extra params.
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ namespace clad {
bool CallUpdateRequired = false;
/// A flag to enable/disable diag warnings/errors during differentiation.
bool VerboseDiags = false;
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
Expand Down
54 changes: 37 additions & 17 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,20 @@ namespace clad {
Stmts m_Globals;
//// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before being changed or not.
std::set<clang::SourceLocation> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool isInsideLoop = false;
/// Output variable of vector-valued function
std::string outputArrayStr;
/// Stores the pop index values for arrays in reverse mode.This is required
/// to maintain the correct statement order when the current block has
/// delayed emission i.e. assignment LHS.
Stmts m_PopIdxValues;
std::vector<Stmts> m_LoopBlock;
unsigned outputArrayCursor = 0;
unsigned numParams = 0;
bool isVectorValued = false;
bool use_enzyme = false;
bool enableTBR = false;
// FIXME: Should we make this an object instead of a pointer?
// Downside of making it an object: We will need to include
// 'MultiplexExternalRMVSource.h' file
Expand Down Expand Up @@ -142,24 +143,33 @@ namespace clad {
/// Create new block.
Stmts& beginBlock(direction d = direction::forward) {
if (d == direction::forward)
m_Blocks.push_back({});
m_Blocks.emplace_back();
else
m_Reverse.push_back({});
m_Reverse.emplace_back();
return getCurrentBlock(d);
}
/// Remove the block from the stack, wrap it in CompoundStmt and return it.
clang::CompoundStmt* endBlock(direction d = direction::forward) {
if (d == direction::forward) {
auto CS = MakeCompoundStmt(getCurrentBlock(direction::forward));
auto* CS = MakeCompoundStmt(getCurrentBlock(direction::forward));
m_Blocks.pop_back();
return CS;
} else {
auto CS = MakeCompoundStmt(getCurrentBlock(direction::reverse));
auto* CS = MakeCompoundStmt(getCurrentBlock(direction::reverse));
std::reverse(CS->body_begin(), CS->body_end());
m_Reverse.pop_back();
return CS;
}
}

Stmts EndBlockWithoutCreatingCS(direction d = direction::forward) {
auto blk = getCurrentBlock(d);
if (d == direction::forward)
m_Blocks.pop_back();
else
m_Reverse.pop_back();
return blk;
}
/// Output a statement to the current block. If Stmt is null or is an unused
/// expression, it is not output and false is returned.
bool addToCurrentBlock(clang::Stmt* S, direction d = direction::forward) {
Expand Down Expand Up @@ -237,6 +247,10 @@ namespace clad {
StmtDiff GlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t",
bool force = false);
StmtDiff BuildPushPop(clang::Expr* E, clang::QualType Type,
llvm::StringRef prefix = "_t", bool force = false);
StmtDiff StoreAndRestore(clang::Expr* E, llvm::StringRef prefix = "_t",
bool force = false);

//// A type returned by DelayedGlobalStoreAndRef
/// .Result is a reference to the created (yet uninitialized) global
Expand All @@ -250,6 +264,12 @@ namespace clad {
StmtDiff Result;
bool isConstant;
bool isInsideLoop;
bool needsUpdate;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
bool pIsConstant, bool pIsInsideLoop,
bool pNeedsUpdate = false)
: V(pV), Result(pResult), isConstant(pIsConstant),
isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {}
void Finalize(clang::Expr* New);
};

Expand Down Expand Up @@ -393,7 +413,7 @@ namespace clad {
clang::QualType xType);

/// Allows to easily create and manage a counter for counting the number of
/// executed iterations of a loop.
/// executed iterations of a loop.
///
/// It is required to save the number of executed iterations to use the
/// same number of iterations in the reverse pass.
Expand All @@ -412,11 +432,11 @@ namespace clad {
/// for counter; otherwise, returns nullptr.
clang::Expr* getPush() const { return m_Push; }

/// Returns `clad::pop(_t)` expression if clad tape is used for
/// Returns `clad::pop(_t)` expression if clad tape is used for
/// for counter; otherwise, returns nullptr.
clang::Expr* getPop() const { return m_Pop; }

/// Returns reference to the last object of the clad tape if clad tape
/// Returns reference to the last object of the clad tape if clad tape
/// is used as the counter; otherwise returns reference to the counter
/// variable.
clang::Expr* getRef() const { return m_Ref; }
Expand Down Expand Up @@ -458,11 +478,11 @@ namespace clad {

/// This class modifies forward and reverse blocks of the loop
/// body so that `break` and `continue` statements are correctly
/// handled. `break` and `continue` statements are handled by
/// handled. `break` and `continue` statements are handled by
/// enclosing entire reverse block loop body in a switch statement
/// and only executing the statements, with the help of case labels,
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// was hit in which iteration and that in turn helps to determine which
/// case label should be selected.
///
Expand Down Expand Up @@ -490,7 +510,7 @@ namespace clad {
/// \note `m_ControlFlowTape` is only initialized if the body contains
/// `continue` or `break` statement.
std::unique_ptr<CladTapeResult> m_ControlFlowTape;

/// Each `break` and `continue` statement is assigned a unique number,
/// starting from 1, that is used as the case label corresponding to that `break`/`continue`
/// statement. `m_CaseCounter` stores the value that was used for last
Expand Down Expand Up @@ -529,7 +549,7 @@ namespace clad {
/// control flow switch statement.
clang::CaseStmt* GetNextCFCaseStmt();

/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// expression, where `TapeRef` and `m_CurrentCounter` are replaced
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();
Expand All @@ -552,7 +572,7 @@ namespace clad {
void PopBreakContStmtHandler() {
m_BreakContStmtHandlers.pop_back();
}

/// Registers an external RMV source.
///
/// Multiple external RMV source can be registered by calling this function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
namespace clad {
namespace rmv {
/// An enum to operate between forward and reverse passes.
enum direction : int { forward, reverse };
enum direction : int { forward, reverse };
} // namespace rmv
} // namespace clad

Expand Down
9 changes: 7 additions & 2 deletions include/clad/Differentiator/StmtClone.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "clang/Sema/Scope.h"

#include "llvm/ADT/DenseMap.h"
#include <unordered_map>

namespace clang {
class Stmt;
Expand Down Expand Up @@ -153,9 +154,13 @@ namespace utils {
clang::Sema& m_Sema; // We don't own.
clang::Scope* m_CurScope; // We don't own.
const clang::FunctionDecl* m_Function; // We don't own.
const std::unordered_map<const clang::VarDecl*, clang::VarDecl*>&
m_DeclReplacements; // We don't own.
public:
ReferencesUpdater(clang::Sema& SemaRef, clang::Scope* S,
const clang::FunctionDecl* FD);
ReferencesUpdater(
clang::Sema& SemaRef, clang::Scope* S, const clang::FunctionDecl* FD,
const std::unordered_map<const clang::VarDecl*, clang::VarDecl*>&
DeclReplacements);
bool VisitDeclRefExpr(clang::DeclRefExpr* DRE);
bool VisitStmt(clang::Stmt* S);
/// Used to update the size expression of QT
Expand Down
20 changes: 18 additions & 2 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ namespace clad {
private:
std::array<clang::Stmt*, 2> data;
clang::Stmt* m_DerivativeForForwSweep;
clang::Stmt* m_ValueForRevSweep;

public:
StmtDiff(clang::Stmt* orig = nullptr, clang::Stmt* diff = nullptr,
clang::Stmt* forwSweepDiff = nullptr)
: m_DerivativeForForwSweep(forwSweepDiff) {
clang::Stmt* forwSweepDiff = nullptr,
clang::Stmt* valueForRevSweep = nullptr)
: m_DerivativeForForwSweep(forwSweepDiff),
m_ValueForRevSweep(valueForRevSweep) {
data[1] = orig;
data[0] = diff;
}
Expand All @@ -57,6 +61,18 @@ namespace clad {

clang::Stmt* getForwSweepStmt_dx() { return m_DerivativeForForwSweep; }

clang::Expr* getRevSweepAsExpr() {
return llvm::cast_or_null<clang::Expr>(getRevSweepStmt());
}

clang::Stmt* getRevSweepStmt() {
/// If there is no specific value for
/// the reverse sweep, use Stmt_dx.
if (!m_ValueForRevSweep)
return data[1];
return m_ValueForRevSweep;
}

clang::Expr* getForwSweepExpr_dx() {
return llvm::cast_or_null<clang::Expr>(m_DerivativeForForwSweep);
}
Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ add_llvm_library(cladDifferentiator
MultiplexExternalRMVSource.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
TBRAnalyzer.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
Version.cpp
Expand Down
Loading

0 comments on commit 567d0f9

Please sign in to comment.