Skip to content

Commit

Permalink
Optimize memory usage in analysis.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 30, 2023
1 parent 9dfaa85 commit a02a5ba
Show file tree
Hide file tree
Showing 8 changed files with 431 additions and 309 deletions.
12 changes: 6 additions & 6 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace clad {
std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD);

/// Creates and returns a compound statement having statements as follows:
/// {`S`, all the statement of `initial` in sequence}
/// {`S`, all the statement of `initial` in sequence}
clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);
Expand All @@ -38,7 +38,7 @@ namespace clad {
clang::CompoundStmt* AppendAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);

/// Shorthand to issues a warning or error.
template <std::size_t N>
void EmitDiag(clang::Sema& semaRef,
Expand Down Expand Up @@ -126,8 +126,8 @@ namespace clad {
///
/// \param S
/// \param namespc
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param DC
clang::NamespaceDecl* LookupNSD(clang::Sema& S, llvm::StringRef namespc,
bool shouldExist,
Expand Down Expand Up @@ -234,7 +234,7 @@ namespace clad {

bool IsCladValueAndPushforwardType(clang::QualType T);

/// Returns a valid `SourceRange` to be used in places where clang
/// Returns a valid `SourceRange` to be used in places where clang
/// requires a valid `SourceRange`.
clang::SourceRange GetValidSRange(clang::Sema& semaRef);

Expand Down Expand Up @@ -314,7 +314,7 @@ namespace clad {

bool hasNonDifferentiableAttribute(const clang::Expr* E);
/// FIXME: add documentation
std::vector<clang::Expr*> GetInnermostReturnExpr(clang::Expr* E);
std::vector<clang::Expr*> GetInnermostReturnExpr(const clang::Expr* E);
} // namespace utils
}

Expand Down
31 changes: 15 additions & 16 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace clad {
class ReverseModeVisitor
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>,
public VisitorBase {

private:
// FIXME: We should remove friend-dependency of the plugin classes here.
// For this we will need to separate out AST related functions in
Expand All @@ -52,6 +52,10 @@ namespace clad {
Stmts m_Globals;
//// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before being changed or not.
std::map<clang::SourceLocation, bool> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool isInsideLoop = false;
/// Output variable of vector-valued function
Expand Down Expand Up @@ -136,7 +140,7 @@ namespace clad {
return m_Blocks.back();
else if (d == direction::reverse)
return m_Reverse.back();
else
else
return m_EssentialReverse.back();
}
/// Create new block.
Expand All @@ -145,7 +149,7 @@ namespace clad {
m_Blocks.push_back({});
else if (d == direction::reverse)
m_Reverse.push_back({});
else
else
m_EssentialReverse.push_back({});
return getCurrentBlock(d);
}
Expand Down Expand Up @@ -228,11 +232,6 @@ namespace clad {
forceDeclCreation, IS);
}

/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before changed or not.
std::map<clang::SourceLocation, bool> m_ToBeRecorded;

/// For an expr E, decides if it is useful to store it in a global temporary
/// variable and replace E's further usage by a reference to that variable
/// to avoid recomputiation.
Expand Down Expand Up @@ -429,7 +428,7 @@ namespace clad {
clang::QualType xType);

/// Allows to easily create and manage a counter for counting the number of
/// executed iterations of a loop.
/// executed iterations of a loop.
///
/// It is required to save the number of executed iterations to use the
/// same number of iterations in the reverse pass.
Expand All @@ -448,11 +447,11 @@ namespace clad {
/// for counter; otherwise, returns nullptr.
clang::Expr* getPush() const { return m_Push; }

/// Returns `clad::pop(_t)` expression if clad tape is used for
/// Returns `clad::pop(_t)` expression if clad tape is used for
/// for counter; otherwise, returns nullptr.
clang::Expr* getPop() const { return m_Pop; }

/// Returns reference to the last object of the clad tape if clad tape
/// Returns reference to the last object of the clad tape if clad tape
/// is used as the counter; otherwise returns reference to the counter
/// variable.
clang::Expr* getRef() const { return m_Ref; }
Expand Down Expand Up @@ -494,11 +493,11 @@ namespace clad {

/// This class modifies forward and reverse blocks of the loop
/// body so that `break` and `continue` statements are correctly
/// handled. `break` and `continue` statements are handled by
/// handled. `break` and `continue` statements are handled by
/// enclosing entire reverse block loop body in a switch statement
/// and only executing the statements, with the help of case labels,
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// was hit in which iteration and that in turn helps to determine which
/// case label should be selected.
///
Expand Down Expand Up @@ -526,7 +525,7 @@ namespace clad {
/// \note `m_ControlFlowTape` is only initialized if the body contains
/// `continue` or `break` statement.
std::unique_ptr<CladTapeResult> m_ControlFlowTape;

/// Each `break` and `continue` statement is assigned a unique number,
/// starting from 1, that is used as the case label corresponding to that `break`/`continue`
/// statement. `m_CaseCounter` stores the value that was used for last
Expand Down Expand Up @@ -565,7 +564,7 @@ namespace clad {
/// control flow switch statement.
clang::CaseStmt* GetNextCFCaseStmt();

/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// expression, where `TapeRef` and `m_CurrentCounter` are replaced
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();
Expand Down
85 changes: 46 additions & 39 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,18 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// a row, which seems uncommon. It's worth considering analysing arrays as
/// whole structures instead (just one VarData for the whole array).

struct VarData;
using ObjMap = std::unordered_map<const clang::FieldDecl*, VarData*>;
using ArrMap = std::unordered_map<const llvm::APInt, VarData*, APIntHash>;

struct VarData {
enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE };
union VarDataValue {
bool fundData;
std::unordered_map<const clang::FieldDecl*, VarData*> objData;
std::unordered_map<const llvm::APInt, VarData*, APIntHash> arrData;
/// objData, arrData are stored as pointers for VarDataValue to take
/// less space.
ObjMap* objData;
ArrMap* arrData;
VarData* refData;
VarDataValue() {}
~VarDataValue() {}
Expand All @@ -85,20 +91,17 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
VarDataValue val;
bool isReferenced = false;

/// For non-fundamental type variables, all the child nodes have to be
/// deleted.
~VarData() {
if (type == OBJ_TYPE) {
for (auto pair : val.objData) {
for (auto& pair : *val.objData) {
delete pair.second;
}
} else if (type == ARR_TYPE) {
for (auto pair : val.arrData) {
for (auto& pair : *val.arrData) {
delete pair.second;
}
}
}

/// Recursively sets all the leaves' bools to isReq.
void setIsRequired(bool isReq = true);
/// Returns true if there is at least one required to store node among
Expand All @@ -122,7 +125,9 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// corresponding original nodes in case those are referenced (a referenced
/// node is a child to multiple nodes, therefore, we need to make sure we
/// don't make multiple copies of it).
VarData* copy();
VarData* copy(std::unordered_map<VarData*, VarData*>& refVars);
void restoreRefs(std::unordered_map<VarData*, VarData*>& refVars);
};

/// Given a MemberExpr*/ArraySubscriptExpr* return a pointer to its
Expand All @@ -138,7 +143,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Given an Expr* returns its corresponding VarData.
VarData* getExprVarData(const clang::Expr* E, bool addNonConstIdx = false);
/// Adds the field FD to objData.
void addField(std::unordered_map<const clang::FieldDecl*, VarData*>& objData,
void addField(std::unordered_map<const clang::FieldDecl*, VarData*>* objData,
const FieldDecl* FD);
/// Whenever an array element with a non-constant index is set to required
/// this function is used to set to required all the array elements that
Expand All @@ -164,7 +169,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
std::map<clang::SourceLocation, bool> TBRLocs;
/// Stores VarsData for every branch in control flow (e.g. if-else statements,
/// loops).
std::vector<VarsData> reqStack;
std::vector<std::vector<VarsData>> reqStack;
/// Stores modes in a stack (used to retrieve the old mode after entering
/// a new one).
std::vector<int> modeStack;
Expand All @@ -179,7 +184,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {

/// The index of the innermost branch corresponding to a loop (used to handle
/// break/continue statements).
size_t innermostLoopBranch = 0;
size_t innermostLoopLayer = 0;
/// Tells if the current branch should be deleted instead of merged with
/// others. This happens when the branch has a break/continue statement or a
/// return expression in it.
Expand All @@ -204,22 +209,28 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
void setIsRequired(const clang::Expr* E, bool isReq = true);

//// Control Flow
/// Creates a new branch as a copy of the last used branch.
void addBranch();
/// Merges the last into the one right before it and deletes it.
/// If keepNewVars==false, it removes all the variables that are present
/// in the last branch but not the other. If keepNewVars==true, all the new
/// variables are kept.
/// Note: The branch we are merging into is not supposed to have its own
/// local variables (this doesn't matter to the branch being merged).
void mergeAndDelete(bool keepNewVars = false);
/// Swaps the last two branches in the stack.
void swapLastPairOfBranches();
/// Merges the current branch to a branch with a given index in the stack.
/// Current branch is NOT deleted.
/// Note: The branch we are merging into is not supposed to have its own
/// local variables (this doesn't matter to the branch being merged).
void mergeCurBranchTo(size_t targetBranchNum);
/// Returns the current branch.
VarsData& getCurBranch() { return reqStack.back().back(); }
/// Adds a new layer.
void addLayer() { reqStack.emplace_back(); }
/// Creates a new empty branch.
void addBranch() { reqStack.back().emplace_back(); }
/// Deletes the last branch.
void deleteBranch() {
for (auto& pair : getCurBranch())
delete pair.second;
reqStack.back().pop_back();
}
/// Merges the last layer into the one last branch on the previous layer
/// right and deletes the last layer.
void mergeLayer();
/// Merges the last layer but, unlike the previous method, basically replaces
/// the last branch on the previous layer with the result of merging. After
/// that, removes the last layer.
void mergeLayerOnTop();
/// Merges the branch with index targetBranch into a sourceBranchNum.
/// No branches are deleted.
void mergeBranchTo(size_t sourceBranchNum, VarsData& targetBranch);
/// Removes local variables from the current branch (uses localVarsStack).
/// This is necessary when merging if-else branches.
void removeLocalVars();
Expand All @@ -239,29 +250,25 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
void resetMode() { modeStack.pop_back(); }

public:
// Constructor
/// Constructor
TBRAnalyzer(ASTContext* m_Context) : m_Context(m_Context) {
modeStack.push_back(0);
reqStack.push_back(VarsData());
addLayer();
addBranch();
}

// Destructor
/// Destructor
~TBRAnalyzer() {
/// By the end of analysis, reqStack is supposed have just one branch
/// but it's better to iterate through it just to make sure there's no
/// memory leak.
for (auto& branch : reqStack) {
for (auto pair : branch) {
delete pair.second;
}
}
for (auto& layer : reqStack)
for (auto& branch : layer)
for (auto& pair : branch)
delete pair.second;
}

/// Returns the result of the whole analysis
const std::map<clang::SourceLocation, bool> getResult() { return TBRLocs; }
std::map<clang::SourceLocation, bool> getResult() { return TBRLocs; }

/// Visitors

void Analyze(const clang::FunctionDecl* FD);

void Visit(const clang::Stmt* stmt) {
Expand Down
10 changes: 3 additions & 7 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,15 +534,11 @@ namespace clad {
return false;
}

std::vector<clang::Expr*> GetInnermostReturnExpr(clang::Expr* E) {
std::vector<clang::Expr*> GetInnermostReturnExpr(const clang::Expr* E) {
struct Finder : public ConstStmtVisitor<Finder> {
std::vector<clang::Expr*> m_return_exprs;
// Sema* m_Sema;
// ASTContext* m_Context;

public:
Finder(/*Sema* S*/) /* : m_Sema(S), m_Context(S.getASTContext()) */ {}

std::vector<clang::Expr*> Find(const clang::Expr* E) {
Visit(E);
return m_return_exprs;
Expand All @@ -551,7 +547,7 @@ namespace clad {
void VisitBinaryOperator(const clang::BinaryOperator* BO) {
if (BO->isAssignmentOp() || BO->isCompoundAssignmentOp()) {
Visit(BO->getLHS());
} else if (BO->isCommaOp()) {
} else if (BO->getOpcode() == clang::BO_Comma) {
/**/
} else {
assert("Unexpected binary operator!!");
Expand All @@ -576,7 +572,7 @@ namespace clad {
}

void VisitExpr(const clang::Expr* E) {
if (auto PE = dyn_cast<clang::ParenExpr>(E)) {
if (const auto PE = dyn_cast<clang::ParenExpr>(E)) {
Visit(PE->getSubExpr());
}
}
Expand Down
8 changes: 2 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DerivativeAndOverload
ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request) {
TBRAnalyzer* analyzer = new TBRAnalyzer(&m_Context);
auto* analyzer = new TBRAnalyzer(&m_Context);
analyzer->Analyze(FD);
m_ToBeRecorded = analyzer->getResult();
delete analyzer;
Expand Down Expand Up @@ -571,7 +571,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

void ReverseModeVisitor::DifferentiateWithClad() {
TBRAnalyzer* analyzer = new TBRAnalyzer(&m_Context);
auto* analyzer = new TBRAnalyzer(&m_Context);
analyzer->Analyze(m_Function);
m_ToBeRecorded = analyzer->getResult();
delete analyzer;
Expand Down Expand Up @@ -1897,10 +1897,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// call to gradient and call to original function. At this point, each arg
// is either a simple expression or a reference to a temporary variable.
// Therefore cloning it has constant complexity.
std::transform(std::begin(CallArgs),
std::end(CallArgs),
std::begin(CallArgs),
[this](Expr* E) { return Clone(E); });
// Recreate the original call expression.
Expr* call = m_Sema
.ActOnCallExpr(getCurrentScope(),
Expand Down
Loading

0 comments on commit a02a5ba

Please sign in to comment.