Skip to content

Commit

Permalink
Redesign of the rangebased for loops body (#1034)
Browse files Browse the repository at this point in the history
Fixes:#1019
Fixes:#1033
  • Loading branch information
ovdiiuv authored Aug 10, 2024
1 parent 1b81084 commit 6cc83ee
Show file tree
Hide file tree
Showing 3 changed files with 358 additions and 142 deletions.
3 changes: 2 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ namespace clad {
StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS);
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD,
bool AddToBlock = true);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
Expand Down
205 changes: 103 additions & 102 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,120 +976,128 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff
ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) {
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
beginBlock(direction::reverse);

LoopCounter loopCounter(*this);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
auto* activeBreakContHandler = PushBreakContStmtHandler();
activeBreakContHandler->BeginCFSwitchStmtScope();
const VarDecl* LoopVD = FRS->getLoopVariable();

const Stmt* RangeDecl = FRS->getRangeStmt();
const Stmt* BeginDecl = FRS->getBeginStmt();
StmtDiff VisitRange = Visit(RangeDecl);
StmtDiff VisitBegin = Visit(BeginDecl);
Expr* BeginExpr = cast<BinaryOperator>(VisitBegin.getStmt())->getLHS();
llvm::SaveAndRestore<bool> SaveIsInside(isInsideLoop,
/*NewValue=*/false);

const auto* RangeDecl = cast<VarDecl>(FRS->getRangeStmt()->getSingleDecl());
const auto* BeginDecl = cast<VarDecl>(FRS->getBeginStmt()->getSingleDecl());

DeclDiff<VarDecl> VisitRange = DifferentiateVarDecl(RangeDecl, false);
DeclDiff<VarDecl> VisitBegin = DifferentiateVarDecl(BeginDecl, false);

beginBlock(direction::reverse);
// Create all declarations needed.
auto* BeginDeclRef = cast<DeclRefExpr>(BeginExpr);
Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()];

auto* RangeExpr =
cast<DeclRefExpr>(cast<BinaryOperator>(VisitRange.getStmt())->getLHS());

Expr* RangeInit = Clone(FRS->getRangeInit());
Expr* AssignRange =
BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit));
Expr* AssignBegin =
BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr));
addToCurrentBlock(AssignRange);
addToCurrentBlock(AssignBegin);
const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());
DeclRefExpr* beginDeclRef = BuildDeclRef(VisitBegin.getDecl());
Expr* d_beginDeclRef = m_Variables[beginDeclRef->getDecl()];
DeclRefExpr* rangeDeclRef = BuildDeclRef(VisitRange.getDecl());
Expr* d_rangeDeclRef = m_Variables[rangeDeclRef->getDecl()];

Expr* rangeInit = Clone(FRS->getRangeInit());
Expr* d_rangeInitDeclRef =
m_Variables[cast<DeclRefExpr>(rangeInit)->getDecl()];
VisitRange.getDecl_dx()->setInit(BuildOp(UO_AddrOf, d_rangeInitDeclRef));
Expr* assignAdjBegin = BuildOp(BO_Assign, d_beginDeclRef, d_rangeDeclRef);
Expr* assignRange =
BuildOp(BO_Assign, rangeDeclRef, BuildOp(UO_AddrOf, rangeInit));

addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl()));
addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl_dx()));
addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl()));
addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl_dx()));
addToCurrentBlock(assignAdjBegin);
addToCurrentBlock(assignRange);

Expr* EndInit = cast<BinaryOperator>(EndDecl->getInit())->getRHS();
QualType EndType = CloneType(EndDecl->getType());
std::string EndName = EndDecl->getNameAsString();
Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit);
VarDecl* EndVarDecl =
BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false);
DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl);

addToCurrentBlock(AssignEnd);
auto* AssignEndVarDecl =
cast<VarDecl>(cast<DeclStmt>(AssignEnd)->getSingleDecl());
DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl);
Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef);
const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());
QualType endType = CloneType(EndDecl->getType());
std::string endName = EndDecl->getNameAsString();
Expr* endInit = Visit(EndDecl->getInit()).getExpr();
VarDecl* endVarDecl =
BuildGlobalVarDecl(endType, endName, endInit, /*DirectInit=*/false);
addToCurrentBlock(BuildDeclStmt(endVarDecl));
DeclRefExpr* endExpr = BuildDeclRef(endVarDecl);
Expr* incBegin = BuildOp(UO_PreInc, beginDeclRef);

beginBlock(direction::forward);
DeclDiff<VarDecl> LoopVDDiff = DifferentiateVarDecl(LoopVD);
Stmt* AdjLoopVDAddAssign =
Stmt* adjLoopVDAddAssign =
utils::unwrapIfSingleStmt(endBlock(direction::forward));

if ((LoopVDDiff.getDecl()->getDeclName() != LoopVD->getDeclName() ||
LoopVD->getType() != LoopVDDiff.getDecl()->getType()))
m_DeclReplacements[LoopVD] = LoopVDDiff.getDecl();
llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop,
/*NewValue=*/true);

Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginDeclRef);
Expr* d_DecBegin = BuildOp(UO_PostDec, d_BeginDeclRef);
Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr);
// Add item assignment statement to the body.
Expr* d_incBegin = BuildOp(UO_PreInc, d_beginDeclRef);
Expr* d_decBegin = BuildOp(UO_PostDec, d_beginDeclRef);
Expr* forwardCond = BuildOp(BO_NE, beginDeclRef, endExpr);
const Stmt* body = FRS->getBody();
StmtDiff bodyDiff = Visit(body);
StmtDiff bodyDiff =
DifferentiateLoopBody(body, loopCounter, nullptr, nullptr,
/*isForLoop=*/true);

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();

StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl()));
StmtDiff storeAdjLoop =
StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx()));

addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx()));
Expr* CounterIncrement = loopCounter.getCounterIncrement();

Expr* LoopInit = LoopVDDiff.getDecl()->getInit();
Expr* loopInit = LoopVDDiff.getDecl()->getInit();
LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType()));
addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl()));
Expr* AssignLoop =
BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), LoopInit);
Expr* assignLoop =
BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), loopInit);

if (!LoopVD->getType()->isReferenceType()) {
Expr* d_LoopVD = BuildDeclRef(LoopVDDiff.getDecl_dx());
AdjLoopVDAddAssign =
BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_BeginDeclRef));
adjLoopVDAddAssign =
BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_beginDeclRef));
}

beginBlock(direction::forward);
addToCurrentBlock(CounterIncrement);
addToCurrentBlock(AdjLoopVDAddAssign);
addToCurrentBlock(AssignLoop);
addToCurrentBlock(adjLoopVDAddAssign);
addToCurrentBlock(assignLoop);
addToCurrentBlock(storeLoop.getStmt());
addToCurrentBlock(storeAdjLoop.getStmt());
CompoundStmt* LoopVDForwardDiff = endBlock(direction::forward);
CompoundStmt* bodyForward = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), bodyDiff.getStmt(), LoopVDForwardDiff);

beginBlock(direction::forward);
addToCurrentBlock(d_DecBegin);
addToCurrentBlock(d_decBegin);
addToCurrentBlock(storeLoop.getStmt_dx());
addToCurrentBlock(storeAdjLoop.getStmt_dx());
CompoundStmt* LoopVDReverseDiff = endBlock(direction::forward);
CompoundStmt* bodyReverse = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), bodyDiff.getStmt_dx(), LoopVDReverseDiff);

Expr* Inc = BuildOp(BO_Comma, IncBegin, d_IncBegin);
Expr* inc = BuildOp(BO_Comma, incBegin, d_incBegin);
Stmt* Forward = new (m_Context) ForStmt(
m_Context, /*Init=*/nullptr, ForwardCond, /*CondVar=*/nullptr, Inc,
m_Context, /*Init=*/nullptr, forwardCond, /*CondVar=*/nullptr, inc,
bodyForward, FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc());
Expr* CounterCondition =
Expr* counterCondition =
loopCounter.getCounterConditionResult().get().second;
Expr* CounterDecrement = loopCounter.getCounterDecrement();
Expr* counterDecrement = loopCounter.getCounterDecrement();

Stmt* Reverse = bodyReverse;
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);

Reverse = new (m_Context)
ForStmt(m_Context, /*Init=*/nullptr, CounterCondition,
/*CondVar=*/nullptr, CounterDecrement, Reverse,
ForStmt(m_Context, /*Init=*/nullptr, counterCondition,
/*CondVar=*/nullptr, counterDecrement, Reverse,
FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc());
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
Expand Down Expand Up @@ -2647,18 +2655,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (!BinOp->isComparisonOp() && !BinOp->isLogicalOp())
unsupportedOpWarn(BinOp->getEndLoc());

// If either LHS or RHS is a declaration reference, visit it to avoid
// naming collision
auto* LDRE = dyn_cast<DeclRefExpr>(L);
auto* RDRE = dyn_cast<DeclRefExpr>(R);

if (!LDRE && !RDRE)
return Clone(BinOp);

Expr* LExpr = LDRE ? Visit(L).getExpr() : L;
Expr* RExpr = RDRE ? Visit(R).getExpr() : R;

return BuildOp(opCode, LExpr, RExpr);
return BuildOp(opCode, Visit(L).getExpr(), Visit(R).getExpr());
}
Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr());

Expand All @@ -2685,10 +2682,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}

DeclDiff<VarDecl>
ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
DeclDiff<VarDecl> ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD,
bool addToBlock) {
StmtDiff initDiff;
Expr* VDDerivedInit = nullptr;

// Local declarations are promoted to the function global scope. This
// procedure is done to make declarations visible in the reverse sweep.
// The reverse_mode_forward_pass mode does not have a reverse pass so
Expand Down Expand Up @@ -2863,7 +2861,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
getZeroInit(VDDerivedType));
else
assignToZero = GetCladZeroInit(declRef);
addToCurrentBlock(assignToZero, direction::reverse);
if (addToBlock)
addToCurrentBlock(assignToZero, direction::reverse);
}
}

Expand All @@ -2879,10 +2878,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getForwSweepExpr_dx()));
addToCurrentBlock(assignDerivativeE);
if (addToBlock)
addToCurrentBlock(assignDerivativeE);
if (isInsideLoop) {
StmtDiff pushPop = StoreAndRestore(derivedVDE);
addToCurrentBlock(pushPop.getStmt(), direction::forward);
if (addToBlock)
addToCurrentBlock(pushPop.getStmt(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getStmt_dx());
}
derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE);
Expand All @@ -2908,10 +2909,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (promoteToFnScope) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
if (addToBlock)
addToCurrentBlock(assignDerivativeE, direction::forward);
if (isInsideLoop) {
auto tape = MakeCladTapeFor(derivedVDE);
addToCurrentBlock(tape.Push);
if (addToBlock)
addToCurrentBlock(tape.Push);
auto* reverseSweepDerivativePointerE =
BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop);
m_LoopBlock.back().push_back(
Expand All @@ -2925,6 +2928,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
if (derivedVDE)
m_Variables.emplace(VDClone, derivedVDE);
// Check if decl's name is the same as before. The name may be changed
// if decl name collides with something in the derivative body.
// This can happen in rare cases, e.g. when the original function
// has both y and _d_y (here _d_y collides with the name produced by
// the derivation process), e.g.
// double f(double x) {
// double y = x;
// double _d_y = x;
// }
// ->
// double f_darg0(double x) {
// double _d_x = 1;
// double _d_y = _d_x; // produced as a derivative for y
// double y = x;
// double _d__d_y = _d_x;
// double _d_y = x; // copied from original function, collides with
// _d_y
// }
if ((VD->getDeclName() != VDClone->getDeclName() ||
VD->getType() != VDClone->getType()))
m_DeclReplacements[VD] = VDClone;

return DeclDiff<VarDecl>(VDClone, VDDerived);
}
Expand Down Expand Up @@ -3027,29 +3051,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (!isLambda)
VDDiff = DifferentiateVarDecl(VD);

// Check if decl's name is the same as before. The name may be changed
// if decl name collides with something in the derivative body.
// This can happen in rare cases, e.g. when the original function
// has both y and _d_y (here _d_y collides with the name produced by
// the derivation process), e.g.
// double f(double x) {
// double y = x;
// double _d_y = x;
// }
// ->
// double f_darg0(double x) {
// double _d_x = 1;
// double _d_y = _d_x; // produced as a derivative for y
// double y = x;
// double _d__d_y = _d_x;
// double _d_y = x; // copied from original function, collides with
// _d_y
// }
if (!isLambda &&
(VDDiff.getDecl()->getDeclName() != VD->getDeclName() ||
VD->getType() != VDDiff.getDecl()->getType()))
m_DeclReplacements[VD] = VDDiff.getDecl();

// Here, we move the declaration to the function global scope.
// Initialization is replaced with an assignment operation at the same
// place as the original declaration. This procedure is done to make the
Expand Down
Loading

0 comments on commit 6cc83ee

Please sign in to comment.