Skip to content

Commit

Permalink
Optimize memory usage in analysis.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 28, 2023
1 parent 9dfaa85 commit 78c88be
Show file tree
Hide file tree
Showing 7 changed files with 368 additions and 252 deletions.
31 changes: 15 additions & 16 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace clad {
class ReverseModeVisitor
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>,
public VisitorBase {

private:
// FIXME: We should remove friend-dependency of the plugin classes here.
// For this we will need to separate out AST related functions in
Expand All @@ -52,6 +52,10 @@ namespace clad {
Stmts m_Globals;
//// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before being changed or not.
std::map<clang::SourceLocation, bool> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool isInsideLoop = false;
/// Output variable of vector-valued function
Expand Down Expand Up @@ -136,7 +140,7 @@ namespace clad {
return m_Blocks.back();
else if (d == direction::reverse)
return m_Reverse.back();
else
else
return m_EssentialReverse.back();
}
/// Create new block.
Expand All @@ -145,7 +149,7 @@ namespace clad {
m_Blocks.push_back({});
else if (d == direction::reverse)
m_Reverse.push_back({});
else
else
m_EssentialReverse.push_back({});
return getCurrentBlock(d);
}
Expand Down Expand Up @@ -228,11 +232,6 @@ namespace clad {
forceDeclCreation, IS);
}

/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before changed or not.
std::map<clang::SourceLocation, bool> m_ToBeRecorded;

/// For an expr E, decides if it is useful to store it in a global temporary
/// variable and replace E's further usage by a reference to that variable
/// to avoid recomputiation.
Expand Down Expand Up @@ -429,7 +428,7 @@ namespace clad {
clang::QualType xType);

/// Allows to easily create and manage a counter for counting the number of
/// executed iterations of a loop.
/// executed iterations of a loop.
///
/// It is required to save the number of executed iterations to use the
/// same number of iterations in the reverse pass.
Expand All @@ -448,11 +447,11 @@ namespace clad {
/// for counter; otherwise, returns nullptr.
clang::Expr* getPush() const { return m_Push; }

/// Returns `clad::pop(_t)` expression if clad tape is used for
/// Returns `clad::pop(_t)` expression if clad tape is used for
/// for counter; otherwise, returns nullptr.
clang::Expr* getPop() const { return m_Pop; }

/// Returns reference to the last object of the clad tape if clad tape
/// Returns reference to the last object of the clad tape if clad tape
/// is used as the counter; otherwise returns reference to the counter
/// variable.
clang::Expr* getRef() const { return m_Ref; }
Expand Down Expand Up @@ -494,11 +493,11 @@ namespace clad {

/// This class modifies forward and reverse blocks of the loop
/// body so that `break` and `continue` statements are correctly
/// handled. `break` and `continue` statements are handled by
/// handled. `break` and `continue` statements are handled by
/// enclosing entire reverse block loop body in a switch statement
/// and only executing the statements, with the help of case labels,
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// was hit in which iteration and that in turn helps to determine which
/// case label should be selected.
///
Expand Down Expand Up @@ -526,7 +525,7 @@ namespace clad {
/// \note `m_ControlFlowTape` is only initialized if the body contains
/// `continue` or `break` statement.
std::unique_ptr<CladTapeResult> m_ControlFlowTape;

/// Each `break` and `continue` statement is assigned a unique number,
/// starting from 1, that is used as the case label corresponding to that `break`/`continue`
/// statement. `m_CaseCounter` stores the value that was used for last
Expand Down Expand Up @@ -565,7 +564,7 @@ namespace clad {
/// control flow switch statement.
clang::CaseStmt* GetNextCFCaseStmt();

/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// expression, where `TapeRef` and `m_CurrentCounter` are replaced
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();
Expand Down
98 changes: 52 additions & 46 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,18 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// a row, which seems uncommon. It's worth considering analysing arrays as
/// whole structures instead (just one VarData for the whole array).

struct VarData;
using ObjMap =
std::unordered_map<const clang::FieldDecl*, std::unique_ptr<VarData>>;
using ArrMap = std::unordered_map<const llvm::APInt, std::unique_ptr<VarData>,
APIntHash>;

struct VarData {
enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE };
union VarDataValue {
bool fundData;
std::unordered_map<const clang::FieldDecl*, VarData*> objData;
std::unordered_map<const llvm::APInt, VarData*, APIntHash> arrData;
std::unique_ptr<ObjMap> objData;
std::unique_ptr<ArrMap> arrData;
VarData* refData;
VarDataValue() {}
~VarDataValue() {}
Expand All @@ -85,18 +91,21 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
VarDataValue val;
bool isReferenced = false;

/// For non-fundamental type variables, all the child nodes have to be
/// deleted.
~VarData() {
if (type == OBJ_TYPE) {
for (auto pair : val.objData) {
delete pair.second;
}
VarData() = default;
VarData(VarData&& other) { *this = std::move(other); }
VarData& operator=(VarData&& other) {
type = other.type;
isReferenced = other.isReferenced;
if (type == FUND_TYPE) {
val.fundData = other.val.fundData;
} else if (type == ARR_TYPE) {
for (auto pair : val.arrData) {
delete pair.second;
}
val.arrData = std::move(other.val.arrData);
} else if (type == OBJ_TYPE) {
val.objData = std::move(other.val.objData);
} else if (type == REF_TYPE) {
val.refData = other.val.refData;
}
return *this;
}

/// Recursively sets all the leaves' bools to isReq.
Expand All @@ -116,13 +125,19 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// (e.g. after an if-else statements). Look at the Control Flow section for
/// more information.
void merge(VarData* mergeData);
void merge(const std::unique_ptr<VarData>& mergeData) {
merge(mergeData.get());
}
/// Used to recursively copy VarData when separating into different branches
/// (e.g. when entering an if-else statements). Look at the Control Flow
/// section for more information. refVars stores copied nodes for
/// corresponding original nodes in case those are referenced (a referenced
/// node is a child to multiple nodes, therefore, we need to make sure we
/// don't make multiple copies of it).
VarData* copy(std::unordered_map<VarData*, VarData*>& refVars);
std::unique_ptr<VarData> copy();
std::unique_ptr<VarData>
copy(std::unordered_map<VarData*, VarData*>& refVars);
void restoreRefs(std::unordered_map<VarData*, VarData*>& refVars);
};

/// Given a MemberExpr*/ArraySubscriptExpr* return a pointer to its
Expand All @@ -138,7 +153,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Given an Expr* returns its corresponding VarData.
VarData* getExprVarData(const clang::Expr* E, bool addNonConstIdx = false);
/// Adds the field FD to objData.
void addField(std::unordered_map<const clang::FieldDecl*, VarData*>& objData,
void addField(std::unordered_map<const clang::FieldDecl*,
std::unique_ptr<VarData>>* objData,
const FieldDecl* FD);
/// Whenever an array element with a non-constant index is set to required
/// this function is used to set to required all the array elements that
Expand All @@ -152,7 +168,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// particular moment.
/// Note: 'this' pointer does not have a declaration so nullptr is used as
/// its key instead.
using VarsData = std::unordered_map<const clang::VarDecl*, VarData*>;
using VarsData =
std::unordered_map<const clang::VarDecl*, std::unique_ptr<VarData>>;
/// Used to find DeclRefExpr's that will be used in the backwards pass.
/// In order to be marked as required, a variables has to appear in a place
/// where it would have a differential influence and will appear non-linearly
Expand All @@ -164,7 +181,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
std::map<clang::SourceLocation, bool> TBRLocs;
/// Stores VarsData for every branch in control flow (e.g. if-else statements,
/// loops).
std::vector<VarsData> reqStack;
std::vector<std::vector<VarsData>> reqStack;
/// Stores modes in a stack (used to retrieve the old mode after entering
/// a new one).
std::vector<int> modeStack;
Expand All @@ -179,7 +196,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {

/// The index of the innermost branch corresponding to a loop (used to handle
/// break/continue statements).
size_t innermostLoopBranch = 0;
size_t innermostLoopLayer = 0;
/// Tells if the current branch should be deleted instead of merged with
/// others. This happens when the branch has a break/continue statement or a
/// return expression in it.
Expand All @@ -204,22 +221,22 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
void setIsRequired(const clang::Expr* E, bool isReq = true);

//// Control Flow
/// Returns the current branch.
VarsData& getCurBranch() { return reqStack.back().back(); }
/// Adds a new layer.
void addLayer() { reqStack.push_back(std::vector<VarsData>()); }
/// Creates a new branch as a copy of the last used branch.
void addBranch();
/// Merges the last into the one right before it and deletes it.
/// If keepNewVars==false, it removes all the variables that are present
/// in the last branch but not the other. If keepNewVars==true, all the new
/// variables are kept.
/// Note: The branch we are merging into is not supposed to have its own
/// local variables (this doesn't matter to the branch being merged).
void mergeAndDelete(bool keepNewVars = false);
/// Swaps the last two branches in the stack.
void swapLastPairOfBranches();
/// Merges the current branch to a branch with a given index in the stack.
/// Current branch is NOT deleted.
/// Note: The branch we are merging into is not supposed to have its own
/// local variables (this doesn't matter to the branch being merged).
void mergeCurBranchTo(size_t targetBranchNum);
void addBranch() { reqStack.back().push_back(VarsData()); }
/// Merges the last layer into the one last branch on the previous layer
/// right and deletes the last layer.
void mergeLayer();
/// Merges the last layer but, unlike the previous method, basically replaces
/// the last branch on the previous layer with the result of merging. After
/// that, removes the last layer.
void mergeLayerOnTop();
/// Merges the branch with index targetBranch into a sourceBranchNum.
/// No branches are deleted.
void mergeBranchTo(size_t sourceBranchNum, VarsData& targetBranch);
/// Removes local variables from the current branch (uses localVarsStack).
/// This is necessary when merging if-else branches.
void removeLocalVars();
Expand All @@ -242,23 +259,12 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
// Constructor
TBRAnalyzer(ASTContext* m_Context) : m_Context(m_Context) {
modeStack.push_back(0);
reqStack.push_back(VarsData());
}

// Destructor
~TBRAnalyzer() {
/// By the end of analysis, reqStack is supposed have just one branch
/// but it's better to iterate through it just to make sure there's no
/// memory leak.
for (auto& branch : reqStack) {
for (auto pair : branch) {
delete pair.second;
}
}
addLayer();
addBranch();
}

/// Returns the result of the whole analysis
const std::map<clang::SourceLocation, bool> getResult() { return TBRLocs; }
std::map<clang::SourceLocation, bool> getResult() { return TBRLocs; }

/// Visitors

Expand Down
8 changes: 2 additions & 6 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,8 @@ namespace clad {
std::vector<clang::Expr*> GetInnermostReturnExpr(clang::Expr* E) {
struct Finder : public ConstStmtVisitor<Finder> {
std::vector<clang::Expr*> m_return_exprs;
// Sema* m_Sema;
// ASTContext* m_Context;

public:
Finder(/*Sema* S*/) /* : m_Sema(S), m_Context(S.getASTContext()) */ {}

std::vector<clang::Expr*> Find(const clang::Expr* E) {
Visit(E);
return m_return_exprs;
Expand All @@ -551,7 +547,7 @@ namespace clad {
void VisitBinaryOperator(const clang::BinaryOperator* BO) {
if (BO->isAssignmentOp() || BO->isCompoundAssignmentOp()) {
Visit(BO->getLHS());
} else if (BO->isCommaOp()) {
} else if (BO->getOpcode() == clang::BO_Comma) {
/**/
} else {
assert("Unexpected binary operator!!");
Expand All @@ -576,7 +572,7 @@ namespace clad {
}

void VisitExpr(const clang::Expr* E) {
if (auto PE = dyn_cast<clang::ParenExpr>(E)) {
if (const auto PE = dyn_cast<clang::ParenExpr>(E)) {
Visit(PE->getSubExpr());
}
}
Expand Down
8 changes: 2 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DerivativeAndOverload
ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request) {
TBRAnalyzer* analyzer = new TBRAnalyzer(&m_Context);
auto* analyzer = new TBRAnalyzer(&m_Context);
analyzer->Analyze(FD);
m_ToBeRecorded = analyzer->getResult();
delete analyzer;
Expand Down Expand Up @@ -571,7 +571,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

void ReverseModeVisitor::DifferentiateWithClad() {
TBRAnalyzer* analyzer = new TBRAnalyzer(&m_Context);
auto* analyzer = new TBRAnalyzer(&m_Context);
analyzer->Analyze(m_Function);
m_ToBeRecorded = analyzer->getResult();
delete analyzer;
Expand Down Expand Up @@ -1897,10 +1897,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// call to gradient and call to original function. At this point, each arg
// is either a simple expression or a reference to a temporary variable.
// Therefore cloning it has constant complexity.
std::transform(std::begin(CallArgs),
std::end(CallArgs),
std::begin(CallArgs),
[this](Expr* E) { return Clone(E); });
// Recreate the original call expression.
Expr* call = m_Sema
.ActOnCallExpr(getCurrentScope(),
Expand Down
Loading

0 comments on commit 78c88be

Please sign in to comment.