Skip to content

Commit

Permalink
Merge "_t" variable declarations with their initialization points on …
Browse files Browse the repository at this point in the history
…function global scope
  • Loading branch information
PetroZarytskyi committed Jul 30, 2024
1 parent e326768 commit 725f63a
Show file tree
Hide file tree
Showing 26 changed files with 405 additions and 682 deletions.
10 changes: 7 additions & 3 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,18 @@ namespace clad {
struct DelayedStoreResult {
ReverseModeVisitor& V;
StmtDiff Result;
clang::VarDecl* Declaration;
bool isConstant;
bool isInsideLoop;
bool isFnScope;
bool needsUpdate;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
bool pIsConstant, bool pIsInsideLoop,
clang::VarDecl* pDeclaration, bool pIsConstant,
bool pIsInsideLoop, bool pIsFnScope,
bool pNeedsUpdate = false)
: V(pV), Result(pResult), isConstant(pIsConstant),
isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {}
: V(pV), Result(pResult), Declaration(pDeclaration),
isConstant(pIsConstant), isInsideLoop(pIsInsideLoop),
isFnScope(pIsFnScope), needsUpdate(pNeedsUpdate) {}
void Finalize(clang::Expr* New);
};

Expand Down
108 changes: 75 additions & 33 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

auto VisitBranch = [&](const Expr* Branch,
Expr* dfdx) -> std::pair<StmtDiff, StmtDiff> {
beginScope(Scope::DeclScope);
auto Result = DifferentiateSingleExpr(Branch, dfdx);
endScope();
StmtDiff BranchDiff = Result.first;
StmtDiff ExprDiff = Result.second;
Stmt* Forward = utils::unwrapIfSingleStmt(BranchDiff.getStmt());
Expand Down Expand Up @@ -975,10 +977,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
Expand Down Expand Up @@ -2366,16 +2368,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// and restore it in the reverse pass
if (m_DiffReq.shouldBeRecorded(L)) {
StmtDiff pushPop = StoreAndRestore(LCloned);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
addToCurrentBlock(pushPop.getStmt(), direction::forward);
addToCurrentBlock(pushPop.getStmt_dx(), direction::reverse);
}

// We need to store values of derivative pointer variables in forward pass
// and restore them in reverse pass.
if (isPointerOp) {
StmtDiff pushPop = StoreAndRestore(Ldiff.getExpr_dx());
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
addToCurrentBlock(pushPop.getStmt(), direction::forward);
addToCurrentBlock(pushPop.getStmt_dx(), direction::reverse);
}

if (m_ExternalSource)
Expand Down Expand Up @@ -2750,8 +2752,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(assignDerivativeE);
if (isInsideLoop) {
StmtDiff pushPop = StoreAndRestore(derivedVDE);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getExpr_dx());
addToCurrentBlock(pushPop.getStmt(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getStmt_dx());
}
derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE);
}
Expand Down Expand Up @@ -3167,9 +3169,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CladTape.Last();
}

Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix));
Expr* Set = BuildOp(BO_Assign, Ref, E);
addToCurrentBlock(Set, direction::forward);
VarDecl* VD = BuildGlobalVarDecl(Type, prefix);
DeclStmt* decl = BuildDeclStmt(VD);
Expr* Ref = BuildDeclRef(VD);
bool isFnScope = getCurrentScope()->isFunctionScope() ||
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
addToCurrentBlock(decl, direction::forward);
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
} else {
addToBlock(decl, m_Globals);
Expr* Set = BuildOp(BO_Assign, Ref, E);
addToCurrentBlock(Set, direction::forward);
}

return Ref;
}
Expand All @@ -3183,6 +3196,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff ReverseModeVisitor::StoreAndRestore(clang::Expr* E,
llvm::StringRef prefix) {
assert(E && "must be provided");
auto Type = getNonConstType(E->getType(), m_Context, m_Sema);

if (isInsideLoop) {
Expand All @@ -3197,17 +3211,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (const auto* AT = dyn_cast<ArrayType>(Type))
init = getArraySizeExpr(AT, m_Context, *this);

Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix, init));
if (E) {
Expr* Store = BuildOp(BO_Assign, Ref, Clone(E));
Expr* Restore = nullptr;
if (E->isModifiableLvalue(m_Context) == Expr::MLV_Valid) {
auto* r = Clone(E);
Restore = BuildOp(BO_Assign, r, Ref);
}
return {Store, Restore};
VarDecl* VD = BuildGlobalVarDecl(Type, prefix, init);
DeclStmt* decl = BuildDeclStmt(VD);
Expr* Ref = BuildDeclRef(VD);
Stmt* Store = nullptr;
bool isFnScope = getCurrentScope()->isFunctionScope() ||
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
Store = decl;
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
} else {
addToBlock(decl, m_Globals);
Store = BuildOp(BO_Assign, Ref, Clone(E));
}
return {};

Stmt* Restore = nullptr;
if (E->isModifiableLvalue(m_Context) == Expr::MLV_Valid)
Restore = BuildOp(BO_Assign, Clone(E), Ref);

return {Store, Restore};
}

void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) {
Expand All @@ -3217,6 +3240,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* Push = cast<CallExpr>(Result.getExpr());
unsigned lastArg = Push->getNumArgs() - 1;
Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New).get());
} else if (isFnScope) {
V.m_Sema.AddInitializerToDecl(Declaration, New, true);
Declaration->setInitStyle(VarDecl::InitializationStyle::CInit);
V.addToCurrentBlock(V.BuildDeclStmt(Declaration), direction::forward);
} else {
V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New),
direction::forward);
Expand All @@ -3232,26 +3259,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr::EvalResult evalRes;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context);
return DelayedStoreResult{*this, Ediff,
/*isConstant*/ isConst,
/*isInsideLoop*/ false,
return DelayedStoreResult{*this,
Ediff,
/*Declaration=*/nullptr,
/*isConstant=*/isConst,
/*isInsideLoop=*/false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/false};
}
if (isInsideLoop) {
Expr* dummy = E;
auto CladTape = MakeCladTapeFor(dummy);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return DelayedStoreResult{*this, StmtDiff{Push, nullptr, nullptr, Pop},
/*isConstant*/ false,
/*isInsideLoop*/ true, /*pNeedsUpdate=*/true};
return DelayedStoreResult{*this,
StmtDiff{Push, nullptr, nullptr, Pop},
/*Declaration=*/nullptr,
/*isConstant=*/false,
/*isInsideLoop=*/true,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true};
}
Expr* Ref = BuildDeclRef(GlobalStoreImpl(
getNonConstType(E->getType(), m_Context, m_Sema), prefix));
bool isFnScope = getCurrentScope()->isFunctionScope() ||
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
VarDecl* VD = BuildGlobalVarDecl(
getNonConstType(E->getType(), m_Context, m_Sema), prefix);
Expr* Ref = BuildDeclRef(VD);
if (!isFnScope)
addToBlock(BuildDeclStmt(VD), m_Globals);
// Return reference to the declaration instead of original expression.
return DelayedStoreResult{*this, StmtDiff{Ref, nullptr, nullptr, Ref},
/*isConstant*/ false,
/*isInsideLoop*/ false, /*pNeedsUpdate=*/true};
return DelayedStoreResult{*this,
StmtDiff{Ref, nullptr, nullptr, Ref},
/*Declaration=*/VD,
/*isConstant=*/false,
/*isInsideLoop=*/false,
/*isFnScope=*/isFnScope,
/*pNeedsUpdate=*/true};
}

ReverseModeVisitor::LoopCounter::LoopCounter(ReverseModeVisitor& RMV)
Expand Down Expand Up @@ -3323,7 +3366,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitDoStmt(const DoStmt* DS) {

beginBlock(direction::reverse);
LoopCounter loopCounter(*this);

Expand Down
Loading

0 comments on commit 725f63a

Please sign in to comment.