diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 15effe2fa..adc88c955 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -417,7 +417,7 @@ namespace clad { StmtDiff VisitCaseStmt(const clang::CaseStmt* CS); StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS); DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD, - bool AddToBlock = true); + bool keepLocal = false); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 347b1fffb..cae22121b 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -976,6 +976,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) { + const auto* RangeDecl = cast(FRS->getRangeStmt()->getSingleDecl()); + const auto* BeginDecl = cast(FRS->getBeginStmt()->getSingleDecl()); + DeclDiff VisitRange = + DifferentiateVarDecl(RangeDecl, /*keepLocal=*/true); + DeclDiff VisitBegin = + DifferentiateVarDecl(BeginDecl, /*keepLocal=*/true); + beginBlock(direction::reverse); LoopCounter loopCounter(*this); beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | @@ -991,33 +998,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveIsInside(isInsideLoop, /*NewValue=*/false); - const auto* RangeDecl = cast(FRS->getRangeStmt()->getSingleDecl()); - const auto* BeginDecl = cast(FRS->getBeginStmt()->getSingleDecl()); - - DeclDiff VisitRange = DifferentiateVarDecl(RangeDecl, false); - DeclDiff VisitBegin = DifferentiateVarDecl(BeginDecl, false); - beginBlock(direction::reverse); // Create all declarations needed. DeclRefExpr* beginDeclRef = BuildDeclRef(VisitBegin.getDecl()); Expr* d_beginDeclRef = m_Variables[beginDeclRef->getDecl()]; - DeclRefExpr* rangeDeclRef = BuildDeclRef(VisitRange.getDecl()); - Expr* d_rangeDeclRef = m_Variables[rangeDeclRef->getDecl()]; - - Expr* rangeInit = Clone(FRS->getRangeInit()); - Expr* d_rangeInitDeclRef = - m_Variables[cast(rangeInit)->getDecl()]; - VisitRange.getDecl_dx()->setInit(BuildOp(UO_AddrOf, d_rangeInitDeclRef)); - Expr* assignAdjBegin = BuildOp(BO_Assign, d_beginDeclRef, d_rangeDeclRef); - Expr* assignRange = - BuildOp(BO_Assign, rangeDeclRef, BuildOp(UO_AddrOf, rangeInit)); - addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl())); addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl_dx())); addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl())); addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl_dx())); - addToCurrentBlock(assignAdjBegin); - addToCurrentBlock(assignRange); const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); QualType endType = CloneType(EndDecl->getType()); @@ -2718,7 +2706,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } DeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD, - bool addToBlock) { + bool keepLocal) { StmtDiff initDiff; Expr* VDDerivedInit = nullptr; @@ -2728,7 +2716,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // declarations don't have to be moved to the function global scope. bool promoteToFnScope = !getCurrentScope()->isFunctionScope() && - m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass; + m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass && !keepLocal; QualType VDCloneType; QualType VDDerivedType; // If the cloned declaration is moved to the function global scope, @@ -2896,7 +2884,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, getZeroInit(VDDerivedType)); else assignToZero = GetCladZeroInit(declRef); - if (addToBlock) + if (!keepLocal) addToCurrentBlock(assignToZero, direction::reverse); } } @@ -2913,11 +2901,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getForwSweepExpr_dx())); - if (addToBlock) - addToCurrentBlock(assignDerivativeE); + addToCurrentBlock(assignDerivativeE); if (isInsideLoop) { StmtDiff pushPop = StoreAndRestore(derivedVDE); - if (addToBlock) + if (!keepLocal) addToCurrentBlock(pushPop.getStmt(), direction::forward); m_LoopBlock.back().push_back(pushPop.getStmt_dx()); } @@ -2944,11 +2931,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, initDiff.getExpr_dx()); - if (addToBlock) - addToCurrentBlock(assignDerivativeE, direction::forward); + addToCurrentBlock(assignDerivativeE, direction::forward); if (isInsideLoop) { auto tape = MakeCladTapeFor(derivedVDE); - if (addToBlock) + if (!keepLocal) addToCurrentBlock(tape.Push); auto* reverseSweepDerivativePointerE = BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop); diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 6b0a120d4..d3a568aff 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2702,14 +2702,12 @@ double fn34(double x, double y){ //CHECK-NEXT: double r = 0; //CHECK-NEXT: double _d_a[3] = {0}; //CHECK-NEXT: double a[3] = {y, x * y, x * x + y}; -//CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}}; -//CHECK-NEXT: double (*__range10)[3] = &a; -//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; -//CHECK-NEXT: double *__begin10 = *__range10; -//CHECK-NEXT: double *_d___begin1 = 0; -//CHECK-NEXT: _d___begin1 = *_d___range1; -//CHECK-NEXT: __range10 = &a; -//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL|0ULL}}; +//CHECK-NEXT: double (&__range10)[3] = a; +//CHECK-NEXT: double (&_d___range1)[3] = _d_a; +//CHECK-NEXT: double *__begin10 = __range10; +//CHECK-NEXT: double *_d___begin1 = _d___range1; +//CHECK-NEXT: double *__end10 = __range10 + {{3|3L}}; //CHECK-NEXT: double *_d_i = 0; //CHECK-NEXT: double *i = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { @@ -2778,14 +2776,12 @@ double fn35(double x, double y){ //CHECK-NEXT: double r = 0; //CHECK-NEXT: double _d_a[3] = {0}; //CHECK-NEXT: double a[3] = {x, x * y, 0}; -//CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}}; -//CHECK-NEXT: double (*__range10)[3] = &a; -//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; -//CHECK-NEXT: double *__begin10 = *__range10; -//CHECK-NEXT: double *_d___begin1 = 0; -//CHECK-NEXT: _d___begin1 = *_d___range1; -//CHECK-NEXT: __range10 = &a; -//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL|0ULL}}; +//CHECK-NEXT: double (&__range10)[3] = a; +//CHECK-NEXT: double (&_d___range1)[3] = _d_a; +//CHECK-NEXT: double *__begin10 = __range10; +//CHECK-NEXT: double *_d___begin1 = _d___range1; +//CHECK-NEXT: double *__end10 = __range10 + {{3|3L}}; //CHECK-NEXT: double *_d_i = 0; //CHECK-NEXT: double *i = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { @@ -2797,13 +2793,11 @@ double fn35(double x, double y){ //CHECK-NEXT: } //CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, {{0U|0UL|0ULL}}); -//CHECK-NEXT: double (*__range20)[3] = &a; -//CHECK-NEXT: double (*_d___range2)[3] = &_d_a; -//CHECK-NEXT: double *__begin20 = *__range20; -//CHECK-NEXT: double *_d___begin2 = 0; -//CHECK-NEXT: _d___begin2 = *_d___range2; -//CHECK-NEXT: __range20 = &a; -//CHECK-NEXT: double *__end20 = *__range20 + {{3|3L}}; +//CHECK-NEXT: double (&__range20)[3] = a; +//CHECK-NEXT: double (&_d___range2)[3] = _d_a; +//CHECK-NEXT: double *__begin20 = __range20; +//CHECK-NEXT: double *_d___begin2 = _d___range2; +//CHECK-NEXT: double *__end20 = __range20 + {{3|3L}}; //CHECK-NEXT: double *_d_j = 0; //CHECK-NEXT: double *j = 0; //CHECK-NEXT: for (; __begin20 != __end20; ++__begin20 , ++_d___begin2) { @@ -2906,14 +2900,12 @@ double fn36(double x, double y){ //CHECK-NEXT: double a[3] = {1, 2, 3}; //CHECK-NEXT: double _d_sum = 0; //CHECK-NEXT: double sum = 0; -//CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}}; -//CHECK-NEXT: double (*__range10)[3] = &a; -//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; -//CHECK-NEXT: double *__begin10 = *__range10; -//CHECK-NEXT: double *_d___begin1 = 0; -//CHECK-NEXT: _d___begin1 = *_d___range1; -//CHECK-NEXT: __range10 = &a; -//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL|0ULL}}; +//CHECK-NEXT: double (&__range10)[3] = a; +//CHECK-NEXT: double (&_d___range1)[3] = _d_a; +//CHECK-NEXT: double *__begin10 = __range10; +//CHECK-NEXT: double *_d___begin1 = _d___range1; +//CHECK-NEXT: double *__end10 = __range10 + {{3|3L}}; //CHECK-NEXT: double _d_i = 0; //CHECK-NEXT: double i = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { @@ -2989,14 +2981,12 @@ double fn37(double x, double y) { //CHECK-NEXT: double range[3] = {x, 4., y}; //CHECK-NEXT: double _d_sum = 0; //CHECK-NEXT: double sum = 0; -//CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}}; -//CHECK-NEXT: double (*__range10)[3] = ⦥ -//CHECK-NEXT: double (*_d___range1)[3] = &_d_range; -//CHECK-NEXT: double *__begin10 = *__range10; -//CHECK-NEXT: double *_d___begin1 = 0; -//CHECK-NEXT: _d___begin1 = *_d___range1; -//CHECK-NEXT: __range10 = ⦥ -//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL|0ULL}}; +//CHECK-NEXT: double (&__range10)[3] = range; +//CHECK-NEXT: double (&_d___range1)[3] = _d_range; +//CHECK-NEXT: double *__begin10 = __range10; +//CHECK-NEXT: double *_d___begin1 = _d___range1; +//CHECK-NEXT: double *__end10 = __range10 + {{3|3L}}; //CHECK-NEXT: double _d_elem = 0; //CHECK-NEXT: double elem = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {