Skip to content

Commit

Permalink
Merge "_t" variable declarations with their initialization points on …
Browse files Browse the repository at this point in the history
…function global scope.

The purpose of this commit is to merge decl stmts (e.g. ``double _t0;``) with their initialization points (e.g.``_t0 = x;``) in the reverse mode whenever this is possible. All of the simplifications happen when the original statement is used on the function global scope because otherwise, the declaration has to be "promoted to the function global scope". The simplification is done for `_t` variables created by `GlobalStoreAndRef`, `DelayedGlobalStoreAndRef`, and `StoreAndRestore`, e.g.
```
double _t0;
...
_t0 = x;
```
can be refactored to
```
...
double _t0 = x;
```
Fixes vgvassilev#525.
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Jul 30, 2024
1 parent da53c3d commit 7a32271
Show file tree
Hide file tree
Showing 26 changed files with 405 additions and 682 deletions.
10 changes: 7 additions & 3 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,18 @@ namespace clad {
struct DelayedStoreResult {
ReverseModeVisitor& V;
StmtDiff Result;
clang::VarDecl* Declaration;
bool isConstant;
bool isInsideLoop;
bool isFnScope;
bool needsUpdate;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
bool pIsConstant, bool pIsInsideLoop,
clang::VarDecl* pDeclaration, bool pIsConstant,
bool pIsInsideLoop, bool pIsFnScope,
bool pNeedsUpdate = false)
: V(pV), Result(pResult), isConstant(pIsConstant),
isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {}
: V(pV), Result(pResult), Declaration(pDeclaration),
isConstant(pIsConstant), isInsideLoop(pIsInsideLoop),
isFnScope(pIsFnScope), needsUpdate(pNeedsUpdate) {}
void Finalize(clang::Expr* New);
};

Expand Down
108 changes: 75 additions & 33 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

auto VisitBranch = [&](const Expr* Branch,
Expr* dfdx) -> std::pair<StmtDiff, StmtDiff> {
beginScope(Scope::DeclScope);
auto Result = DifferentiateSingleExpr(Branch, dfdx);
endScope();
StmtDiff BranchDiff = Result.first;
StmtDiff ExprDiff = Result.second;
Stmt* Forward = utils::unwrapIfSingleStmt(BranchDiff.getStmt());
Expand Down Expand Up @@ -975,10 +977,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
Expand Down Expand Up @@ -2366,16 +2368,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// and restore it in the reverse pass
if (m_DiffReq.shouldBeRecorded(L)) {
StmtDiff pushPop = StoreAndRestore(LCloned);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
addToCurrentBlock(pushPop.getStmt(), direction::forward);
addToCurrentBlock(pushPop.getStmt_dx(), direction::reverse);
}

// We need to store values of derivative pointer variables in forward pass
// and restore them in reverse pass.
if (isPointerOp) {
StmtDiff pushPop = StoreAndRestore(Ldiff.getExpr_dx());
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
addToCurrentBlock(pushPop.getStmt(), direction::forward);
addToCurrentBlock(pushPop.getStmt_dx(), direction::reverse);
}

if (m_ExternalSource)
Expand Down Expand Up @@ -2750,8 +2752,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(assignDerivativeE);
if (isInsideLoop) {
StmtDiff pushPop = StoreAndRestore(derivedVDE);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getExpr_dx());
addToCurrentBlock(pushPop.getStmt(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getStmt_dx());
}
derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE);
}
Expand Down Expand Up @@ -3167,9 +3169,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CladTape.Last();
}

Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix));
Expr* Set = BuildOp(BO_Assign, Ref, E);
addToCurrentBlock(Set, direction::forward);
VarDecl* VD = BuildGlobalVarDecl(Type, prefix);
DeclStmt* decl = BuildDeclStmt(VD);
Expr* Ref = BuildDeclRef(VD);
bool isFnScope = getCurrentScope()->isFunctionScope() ||
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
addToCurrentBlock(decl, direction::forward);
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
} else {
addToBlock(decl, m_Globals);
Expr* Set = BuildOp(BO_Assign, Ref, E);
addToCurrentBlock(Set, direction::forward);
}

return Ref;
}
Expand All @@ -3183,6 +3196,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff ReverseModeVisitor::StoreAndRestore(clang::Expr* E,
llvm::StringRef prefix) {
assert(E && "must be provided");
auto Type = getNonConstType(E->getType(), m_Context, m_Sema);

if (isInsideLoop) {
Expand All @@ -3197,17 +3211,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (const auto* AT = dyn_cast<ArrayType>(Type))
init = getArraySizeExpr(AT, m_Context, *this);

Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix, init));
if (E) {
Expr* Store = BuildOp(BO_Assign, Ref, Clone(E));
Expr* Restore = nullptr;
if (E->isModifiableLvalue(m_Context) == Expr::MLV_Valid) {
auto* r = Clone(E);
Restore = BuildOp(BO_Assign, r, Ref);
}
return {Store, Restore};
VarDecl* VD = BuildGlobalVarDecl(Type, prefix, init);
DeclStmt* decl = BuildDeclStmt(VD);
Expr* Ref = BuildDeclRef(VD);
Stmt* Store = nullptr;
bool isFnScope = getCurrentScope()->isFunctionScope() ||
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
Store = decl;
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
} else {
addToBlock(decl, m_Globals);
Store = BuildOp(BO_Assign, Ref, Clone(E));
}
return {};

Stmt* Restore = nullptr;
if (E->isModifiableLvalue(m_Context) == Expr::MLV_Valid)
Restore = BuildOp(BO_Assign, Clone(E), Ref);

return {Store, Restore};
}

void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) {
Expand All @@ -3217,6 +3240,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* Push = cast<CallExpr>(Result.getExpr());
unsigned lastArg = Push->getNumArgs() - 1;
Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New).get());
} else if (isFnScope) {
V.m_Sema.AddInitializerToDecl(Declaration, New, true);
Declaration->setInitStyle(VarDecl::InitializationStyle::CInit);
V.addToCurrentBlock(V.BuildDeclStmt(Declaration), direction::forward);
} else {
V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New),
direction::forward);
Expand All @@ -3232,26 +3259,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr::EvalResult evalRes;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context);
return DelayedStoreResult{*this, Ediff,
/*isConstant*/ isConst,
/*isInsideLoop*/ false,
return DelayedStoreResult{*this,
Ediff,
/*Declaration=*/nullptr,
/*isConstant=*/isConst,
/*isInsideLoop=*/false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/false};
}
if (isInsideLoop) {
Expr* dummy = E;
auto CladTape = MakeCladTapeFor(dummy);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return DelayedStoreResult{*this, StmtDiff{Push, nullptr, nullptr, Pop},
/*isConstant*/ false,
/*isInsideLoop*/ true, /*pNeedsUpdate=*/true};
return DelayedStoreResult{*this,
StmtDiff{Push, nullptr, nullptr, Pop},
/*Declaration=*/nullptr,
/*isConstant=*/false,
/*isInsideLoop=*/true,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true};
}
Expr* Ref = BuildDeclRef(GlobalStoreImpl(
getNonConstType(E->getType(), m_Context, m_Sema), prefix));
bool isFnScope = getCurrentScope()->isFunctionScope() ||
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
VarDecl* VD = BuildGlobalVarDecl(
getNonConstType(E->getType(), m_Context, m_Sema), prefix);
Expr* Ref = BuildDeclRef(VD);
if (!isFnScope)
addToBlock(BuildDeclStmt(VD), m_Globals);
// Return reference to the declaration instead of original expression.
return DelayedStoreResult{*this, StmtDiff{Ref, nullptr, nullptr, Ref},
/*isConstant*/ false,
/*isInsideLoop*/ false, /*pNeedsUpdate=*/true};
return DelayedStoreResult{*this,
StmtDiff{Ref, nullptr, nullptr, Ref},
/*Declaration=*/VD,
/*isConstant=*/false,
/*isInsideLoop=*/false,
/*isFnScope=*/isFnScope,
/*pNeedsUpdate=*/true};
}

ReverseModeVisitor::LoopCounter::LoopCounter(ReverseModeVisitor& RMV)
Expand Down Expand Up @@ -3323,7 +3366,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitDoStmt(const DoStmt* DS) {

beginBlock(direction::reverse);
LoopCounter loopCounter(*this);

Expand Down
Loading

0 comments on commit 7a32271

Please sign in to comment.