Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Incorrect derivative when loops contains continue #833

Merged
merged 11 commits into from
Oct 30, 2024
2 changes: 1 addition & 1 deletion include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
/// N.
template <typename T> CUDA_HOST_DEVICE void zero_init(T* x, std::size_t N) {
for (std::size_t i = 0; i < N; ++i)
zero_init(x[i]);
zero_init(x[i]);
kchristin22 marked this conversation as resolved.
Show resolved Hide resolved
}

/// Initialize a const sized array.
Expand Down
9 changes: 9 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ namespace clad {
clang::Expr *m_Pop = nullptr;
clang::Expr *m_Push = nullptr;
ReverseModeVisitor& m_RMV;
clang::VarDecl* m_numRevIterations = nullptr;

public:
LoopCounter(ReverseModeVisitor& RMV);
Expand Down Expand Up @@ -549,6 +550,14 @@ namespace clad {
m_Ref,
clang::Sema::ConditionKind::Boolean);
}

/// Sets the number of reverse iterations to be executed.
void setNumRevIterations(clang::VarDecl* numRevIterations) {
m_numRevIterations = numRevIterations;
}

/// Returns the number of reverse iterations to be executed.
clang::VarDecl* getNumRevIterations() const { return m_numRevIterations; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: function 'getNumRevIterations' should be marked [[nodiscard]] [modernize-use-nodiscard]

Suggested change
clang::VarDecl* getNumRevIterations() const { return m_numRevIterations; }
[[nodiscard]] clang::VarDecl* getNumRevIterations() const { return m_numRevIterations; }

};

/// Helper function to differentiate a loop body.
Expand Down
46 changes: 35 additions & 11 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1372,8 +1372,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts));
}

Stmt* revInit = loopCounter.getNumRevIterations()
? BuildDeclStmt(loopCounter.getNumRevIterations())
: nullptr;
Stmt* Reverse = new (m_Context)
ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement,
ForStmt(m_Context, revInit, nullptr, nullptr, CounterDecrement,
BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc);

addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
Expand Down Expand Up @@ -3791,6 +3794,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;

Expr* condClone = (WS->getCond() ? Clone(WS->getCond()) : nullptr);
const VarDecl* condVarDecl = WS->getConditionVariable();
Expand Down Expand Up @@ -3849,6 +3855,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;

Expr* clonedCond = (DS->getCond() ? Clone(DS->getCond()) : nullptr);

Expand Down Expand Up @@ -4105,22 +4114,38 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock));
m_LoopBlock.pop_back();

// Increment statement in the for-loop is only executed if the iteration
// did not end with a break/continue statement. Therefore, forLoopIncDiff
// should be inside the last switch case in the reverse pass.
activeBreakContHandler->EndCFSwitchStmtScope();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use auto when initializing with a template cast to avoid duplicating the type name [modernize-use-auto]

Suggested change
activeBreakContHandler->EndCFSwitchStmtScope();
auto* forwardSS =

activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();

Expr* revCounter = loopCounter.getCounterConditionResult().get().second;
if (m_CurrentBreakFlagExpr) {
VarDecl* numRevIterations = BuildVarDecl(m_Context.getSizeType(),
"_numRevIterations", revCounter);
loopCounter.setNumRevIterations(numRevIterations);
}

// Increment statement in the for-loop is executed for every case
if (forLoopIncDiff) {
Stmt* forLoopIncDiffExpr = forLoopIncDiff;
if (m_CurrentBreakFlagExpr) {
m_CurrentBreakFlagExpr =
BuildOp(BinaryOperatorKind::BO_LOr,
BuildOp(BinaryOperatorKind::BO_NE, revCounter,
BuildDeclRef(loopCounter.getNumRevIterations())),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: initializing non-owner 'DefaultStmt *' with a newly created 'gsl::owner<>' [cppcoreguidelines-owning-memory]

    auto* newDefaultStmt =
    ^

BuildParens(m_CurrentBreakFlagExpr));
forLoopIncDiffExpr = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, m_CurrentBreakFlagExpr,
noLoc, noLoc, forLoopIncDiff, noLoc, nullptr);
}
if (bodyDiff.getStmt_dx()) {
bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, bodyDiff.getStmt_dx(), forLoopIncDiff));
m_Context, bodyDiff.getStmt_dx(), forLoopIncDiffExpr));
} else {
bodyDiff.updateStmtDx(forLoopIncDiff);
bodyDiff.updateStmtDx(forLoopIncDiffExpr);
}
}

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();

Expr* counterDecrement = loopCounter.getCounterDecrement();

// Create reverse pass loop body statements by arranging various
Expand Down Expand Up @@ -4169,7 +4194,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_CurrentBreakFlagExpr =
BuildOp(BinaryOperatorKind::BO_LAnd, m_CurrentBreakFlagExpr,
tapeBackExprForCurrentCase);

} else {
m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase;
}
Expand Down
4 changes: 2 additions & 2 deletions test/Analyses/TBR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ double f2(double val) {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: --i;
//CHECK-NEXT: switch (clad::pop(_t1)) {
//CHECK-NEXT: case {{2U|2UL}}:
//CHECK-NEXT: ;
//CHECK-NEXT: --i;
//CHECK-NEXT: {
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_i += _r_d0 * val;
Expand Down Expand Up @@ -167,6 +167,6 @@ double f3 (double x){
int main() {
double result[3] = {};
TEST(f1, 3); // CHECK-EXEC: {27.00}
TEST(f2, 3); // CHECK-EXEC: {9.00}
TEST(f2, 3); // CHECK-EXEC: {7.00}
TEST(f3, 3); // CHECK-EXEC: {2.00}
}
Loading
Loading