Skip to content

Commit

Permalink
Introduce To-Be-Recorded Analysis in Clad. (#655)
Browse files Browse the repository at this point in the history
This patch optimizes storing and restoring in the reverse mode of Clad and
introduces TBR analysis to determine when variables should be stored.

Fixes #465, #441, #439, #429. Partially resolves #606.
  • Loading branch information
PetroZarytskyi authored Dec 1, 2023
1 parent 8647f13 commit 0ab0c6e
Show file tree
Hide file tree
Showing 70 changed files with 4,464 additions and 3,787 deletions.
23 changes: 17 additions & 6 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace clad {
std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD);

/// Creates and returns a compound statement having statements as follows:
/// {`S`, all the statement of `initial` in sequence}
/// {`S`, all the statement of `initial` in sequence}
clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);
Expand All @@ -38,7 +38,7 @@ namespace clad {
clang::CompoundStmt* AppendAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);

/// Shorthand to issues a warning or error.
template <std::size_t N>
void EmitDiag(clang::Sema& semaRef,
Expand Down Expand Up @@ -126,8 +126,8 @@ namespace clad {
///
/// \param S
/// \param namespc
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param DC
clang::NamespaceDecl* LookupNSD(clang::Sema& S, llvm::StringRef namespc,
bool shouldExist,
Expand Down Expand Up @@ -162,7 +162,10 @@ namespace clad {
llvm::StringRef str);

/// Returns true if `QT` is Array or Pointer Type, otherwise returns false.
bool isArrayOrPointerType(const clang::QualType QT);
bool isArrayOrPointerType(clang::QualType QT);

/// Returns true if `T` is auto or auto* type, otherwise returns false.
bool IsAutoOrAutoPtrType(clang::QualType T);

clang::DeclarationNameInfo BuildDeclarationNameInfo(clang::Sema& S,
llvm::StringRef name);
Expand Down Expand Up @@ -234,7 +237,7 @@ namespace clad {

bool IsCladValueAndPushforwardType(clang::QualType T);

/// Returns a valid `SourceRange` to be used in places where clang
/// Returns a valid `SourceRange` to be used in places where clang
/// requires a valid `SourceRange`.
clang::SourceRange GetValidSRange(clang::Sema& semaRef);

Expand Down Expand Up @@ -313,6 +316,14 @@ namespace clad {
bool hasNonDifferentiableAttribute(const clang::Decl* D);

bool hasNonDifferentiableAttribute(const clang::Expr* E);

/// Collects every DeclRefExpr, MemberExpr, ArraySubscriptExpr in an
/// assignment operator or a ternary if operator. This is useful to when we
/// need to decide what needs to be stored on tape in reverse mode.
void GetInnermostReturnExpr(const clang::Expr* E,
llvm::SmallVectorImpl<clang::Expr*>& Exprs);

bool ContainsFunctionCalls(const clang::Stmt* E);
} // namespace utils
}

Expand Down
14 changes: 14 additions & 0 deletions include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ static inline bool Expr_EvaluateAsInt(const Expr *E,
#endif
}

// Clang 12: bool Expr::EvaluateAsConstantExpr(EvalResult &Result,
// ConstExprUsage Usage, ASTContext &)
// => bool Expr::EvaluateAsConstantExpr(EvalResult &Result, ASTContext &)

static inline bool Expr_EvaluateAsConstantExpr(const Expr* E,
Expr::EvalResult& res,
const ASTContext& Ctx) {
#if CLANG_VERSION_MAJOR < 12
return E->EvaluateAsConstantExpr(res, Expr::EvaluateForCodeGen, Ctx);
#else
return E->EvaluateAsConstantExpr(res, Ctx);
#endif
}

// Compatibility helper function for creation IfStmt.
// Clang 8 and above use Create.
// Clang 12 and above use two extra params.
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ namespace clad {
bool CallUpdateRequired = false;
/// A flag to enable/disable diag warnings/errors during differentiation.
bool VerboseDiags = false;
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
Expand Down
54 changes: 37 additions & 17 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,20 @@ namespace clad {
Stmts m_Globals;
//// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before being changed or not.
std::set<clang::SourceLocation> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool isInsideLoop = false;
/// Output variable of vector-valued function
std::string outputArrayStr;
/// Stores the pop index values for arrays in reverse mode.This is required
/// to maintain the correct statement order when the current block has
/// delayed emission i.e. assignment LHS.
Stmts m_PopIdxValues;
std::vector<Stmts> m_LoopBlock;
unsigned outputArrayCursor = 0;
unsigned numParams = 0;
bool isVectorValued = false;
bool use_enzyme = false;
bool enableTBR = false;
// FIXME: Should we make this an object instead of a pointer?
// Downside of making it an object: We will need to include
// 'MultiplexExternalRMVSource.h' file
Expand Down Expand Up @@ -142,24 +143,33 @@ namespace clad {
/// Create new block.
Stmts& beginBlock(direction d = direction::forward) {
if (d == direction::forward)
m_Blocks.push_back({});
m_Blocks.emplace_back();
else
m_Reverse.push_back({});
m_Reverse.emplace_back();
return getCurrentBlock(d);
}
/// Remove the block from the stack, wrap it in CompoundStmt and return it.
clang::CompoundStmt* endBlock(direction d = direction::forward) {
if (d == direction::forward) {
auto CS = MakeCompoundStmt(getCurrentBlock(direction::forward));
auto* CS = MakeCompoundStmt(getCurrentBlock(direction::forward));
m_Blocks.pop_back();
return CS;
} else {
auto CS = MakeCompoundStmt(getCurrentBlock(direction::reverse));
auto* CS = MakeCompoundStmt(getCurrentBlock(direction::reverse));
std::reverse(CS->body_begin(), CS->body_end());
m_Reverse.pop_back();
return CS;
}
}

Stmts EndBlockWithoutCreatingCS(direction d = direction::forward) {
auto blk = getCurrentBlock(d);
if (d == direction::forward)
m_Blocks.pop_back();
else
m_Reverse.pop_back();
return blk;
}
/// Output a statement to the current block. If Stmt is null or is an unused
/// expression, it is not output and false is returned.
bool addToCurrentBlock(clang::Stmt* S, direction d = direction::forward) {
Expand Down Expand Up @@ -237,6 +247,10 @@ namespace clad {
StmtDiff GlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t",
bool force = false);
StmtDiff BuildPushPop(clang::Expr* E, clang::QualType Type,
llvm::StringRef prefix = "_t", bool force = false);
StmtDiff StoreAndRestore(clang::Expr* E, llvm::StringRef prefix = "_t",
bool force = false);

//// A type returned by DelayedGlobalStoreAndRef
/// .Result is a reference to the created (yet uninitialized) global
Expand All @@ -250,6 +264,12 @@ namespace clad {
StmtDiff Result;
bool isConstant;
bool isInsideLoop;
bool needsUpdate;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
bool pIsConstant, bool pIsInsideLoop,
bool pNeedsUpdate = false)
: V(pV), Result(pResult), isConstant(pIsConstant),
isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {}
void Finalize(clang::Expr* New);
};

Expand Down Expand Up @@ -393,7 +413,7 @@ namespace clad {
clang::QualType xType);

/// Allows to easily create and manage a counter for counting the number of
/// executed iterations of a loop.
/// executed iterations of a loop.
///
/// It is required to save the number of executed iterations to use the
/// same number of iterations in the reverse pass.
Expand All @@ -412,11 +432,11 @@ namespace clad {
/// for counter; otherwise, returns nullptr.
clang::Expr* getPush() const { return m_Push; }

/// Returns `clad::pop(_t)` expression if clad tape is used for
/// Returns `clad::pop(_t)` expression if clad tape is used for
/// for counter; otherwise, returns nullptr.
clang::Expr* getPop() const { return m_Pop; }

/// Returns reference to the last object of the clad tape if clad tape
/// Returns reference to the last object of the clad tape if clad tape
/// is used as the counter; otherwise returns reference to the counter
/// variable.
clang::Expr* getRef() const { return m_Ref; }
Expand Down Expand Up @@ -458,11 +478,11 @@ namespace clad {

/// This class modifies forward and reverse blocks of the loop
/// body so that `break` and `continue` statements are correctly
/// handled. `break` and `continue` statements are handled by
/// handled. `break` and `continue` statements are handled by
/// enclosing entire reverse block loop body in a switch statement
/// and only executing the statements, with the help of case labels,
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// was hit in which iteration and that in turn helps to determine which
/// case label should be selected.
///
Expand Down Expand Up @@ -490,7 +510,7 @@ namespace clad {
/// \note `m_ControlFlowTape` is only initialized if the body contains
/// `continue` or `break` statement.
std::unique_ptr<CladTapeResult> m_ControlFlowTape;

/// Each `break` and `continue` statement is assigned a unique number,
/// starting from 1, that is used as the case label corresponding to that `break`/`continue`
/// statement. `m_CaseCounter` stores the value that was used for last
Expand Down Expand Up @@ -529,7 +549,7 @@ namespace clad {
/// control flow switch statement.
clang::CaseStmt* GetNextCFCaseStmt();

/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// expression, where `TapeRef` and `m_CurrentCounter` are replaced
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();
Expand All @@ -552,7 +572,7 @@ namespace clad {
void PopBreakContStmtHandler() {
m_BreakContStmtHandlers.pop_back();
}

/// Registers an external RMV source.
///
/// Multiple external RMV source can be registered by calling this function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
namespace clad {
namespace rmv {
/// An enum to operate between forward and reverse passes.
enum direction : int { forward, reverse };
enum direction : int { forward, reverse };
} // namespace rmv
} // namespace clad

Expand Down
9 changes: 7 additions & 2 deletions include/clad/Differentiator/StmtClone.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "clang/Sema/Scope.h"

#include "llvm/ADT/DenseMap.h"
#include <unordered_map>

namespace clang {
class Stmt;
Expand Down Expand Up @@ -153,9 +154,13 @@ namespace utils {
clang::Sema& m_Sema; // We don't own.
clang::Scope* m_CurScope; // We don't own.
const clang::FunctionDecl* m_Function; // We don't own.
const std::unordered_map<const clang::VarDecl*, clang::VarDecl*>&
m_DeclReplacements; // We don't own.
public:
ReferencesUpdater(clang::Sema& SemaRef, clang::Scope* S,
const clang::FunctionDecl* FD);
ReferencesUpdater(
clang::Sema& SemaRef, clang::Scope* S, const clang::FunctionDecl* FD,
const std::unordered_map<const clang::VarDecl*, clang::VarDecl*>&
DeclReplacements);
bool VisitDeclRefExpr(clang::DeclRefExpr* DRE);
bool VisitStmt(clang::Stmt* S);
/// Used to update the size expression of QT
Expand Down
20 changes: 18 additions & 2 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ namespace clad {
private:
std::array<clang::Stmt*, 2> data;
clang::Stmt* m_DerivativeForForwSweep;
clang::Stmt* m_ValueForRevSweep;

public:
StmtDiff(clang::Stmt* orig = nullptr, clang::Stmt* diff = nullptr,
clang::Stmt* forwSweepDiff = nullptr)
: m_DerivativeForForwSweep(forwSweepDiff) {
clang::Stmt* forwSweepDiff = nullptr,
clang::Stmt* valueForRevSweep = nullptr)
: m_DerivativeForForwSweep(forwSweepDiff),
m_ValueForRevSweep(valueForRevSweep) {
data[1] = orig;
data[0] = diff;
}
Expand All @@ -57,6 +61,18 @@ namespace clad {

clang::Stmt* getForwSweepStmt_dx() { return m_DerivativeForForwSweep; }

clang::Expr* getRevSweepAsExpr() {
return llvm::cast_or_null<clang::Expr>(getRevSweepStmt());
}

clang::Stmt* getRevSweepStmt() {
/// If there is no specific value for
/// the reverse sweep, use Stmt_dx.
if (!m_ValueForRevSweep)
return data[1];
return m_ValueForRevSweep;
}

clang::Expr* getForwSweepExpr_dx() {
return llvm::cast_or_null<clang::Expr>(m_DerivativeForForwSweep);
}
Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ add_llvm_library(cladDifferentiator
MultiplexExternalRMVSource.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
TBRAnalyzer.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
Version.cpp
Expand Down
Loading

0 comments on commit 0ab0c6e

Please sign in to comment.