diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index a161f1f58..8d899a4ac 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -411,7 +411,8 @@ namespace clad { StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS); StmtDiff VisitCaseStmt(const clang::CaseStmt* CS); StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS); - DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); + DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD, + bool AddToBlock = true); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index fb1102d66..8308da4b7 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -976,91 +976,99 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) { + beginBlock(direction::reverse); + LoopCounter loopCounter(*this); beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | Scope::ContinueScope); - beginBlock(direction::reverse); - LoopCounter loopCounter(*this); + llvm::SaveAndRestore SaveCurrentBreakFlagExpr( + m_CurrentBreakFlagExpr); + m_CurrentBreakFlagExpr = nullptr; + auto* activeBreakContHandler = PushBreakContStmtHandler(); + activeBreakContHandler->BeginCFSwitchStmtScope(); const VarDecl* LoopVD = FRS->getLoopVariable(); - const Stmt* RangeDecl = FRS->getRangeStmt(); - const Stmt* BeginDecl = FRS->getBeginStmt(); - StmtDiff VisitRange = Visit(RangeDecl); - StmtDiff VisitBegin = Visit(BeginDecl); - Expr* BeginExpr = cast(VisitBegin.getStmt())->getLHS(); + llvm::SaveAndRestore SaveIsInside(isInsideLoop, + /*NewValue=*/false); + + const auto* RangeDecl = cast(FRS->getRangeStmt()->getSingleDecl()); + const auto* BeginDecl = cast(FRS->getBeginStmt()->getSingleDecl()); + + DeclDiff VisitRange = DifferentiateVarDecl(RangeDecl, false); + DeclDiff VisitBegin = DifferentiateVarDecl(BeginDecl, false); beginBlock(direction::reverse); // Create all declarations needed. - auto* BeginDeclRef = cast(BeginExpr); - Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()]; - - auto* RangeExpr = - cast(cast(VisitRange.getStmt())->getLHS()); - - Expr* RangeInit = Clone(FRS->getRangeInit()); - Expr* AssignRange = - BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit)); - Expr* AssignBegin = - BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr)); - addToCurrentBlock(AssignRange); - addToCurrentBlock(AssignBegin); - const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); + DeclRefExpr* beginDeclRef = BuildDeclRef(VisitBegin.getDecl()); + Expr* d_beginDeclRef = m_Variables[beginDeclRef->getDecl()]; + DeclRefExpr* rangeDeclRef = BuildDeclRef(VisitRange.getDecl()); + Expr* d_rangeDeclRef = m_Variables[rangeDeclRef->getDecl()]; + + Expr* rangeInit = Clone(FRS->getRangeInit()); + Expr* d_rangeInitDeclRef = + m_Variables[cast(rangeInit)->getDecl()]; + VisitRange.getDecl_dx()->setInit(BuildOp(UO_AddrOf, d_rangeInitDeclRef)); + Expr* assignAdjBegin = BuildOp(BO_Assign, d_beginDeclRef, d_rangeDeclRef); + Expr* assignRange = + BuildOp(BO_Assign, rangeDeclRef, BuildOp(UO_AddrOf, rangeInit)); + + addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl())); + addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl_dx())); + addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl())); + addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl_dx())); + addToCurrentBlock(assignAdjBegin); + addToCurrentBlock(assignRange); - Expr* EndInit = cast(EndDecl->getInit())->getRHS(); - QualType EndType = CloneType(EndDecl->getType()); - std::string EndName = EndDecl->getNameAsString(); - Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit); - VarDecl* EndVarDecl = - BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false); - DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl); - - addToCurrentBlock(AssignEnd); - auto* AssignEndVarDecl = - cast(cast(AssignEnd)->getSingleDecl()); - DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl); - Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef); + const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); + QualType endType = CloneType(EndDecl->getType()); + std::string endName = EndDecl->getNameAsString(); + Expr* endInit = Visit(EndDecl->getInit()).getExpr(); + VarDecl* endVarDecl = + BuildGlobalVarDecl(endType, endName, endInit, /*DirectInit=*/false); + addToCurrentBlock(BuildDeclStmt(endVarDecl)); + DeclRefExpr* endExpr = BuildDeclRef(endVarDecl); + Expr* incBegin = BuildOp(UO_PreInc, beginDeclRef); beginBlock(direction::forward); DeclDiff LoopVDDiff = DifferentiateVarDecl(LoopVD); - Stmt* AdjLoopVDAddAssign = + Stmt* adjLoopVDAddAssign = utils::unwrapIfSingleStmt(endBlock(direction::forward)); - if ((LoopVDDiff.getDecl()->getDeclName() != LoopVD->getDeclName() || - LoopVD->getType() != LoopVDDiff.getDecl()->getType())) - m_DeclReplacements[LoopVD] = LoopVDDiff.getDecl(); llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop, /*NewValue=*/true); - Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginDeclRef); - Expr* d_DecBegin = BuildOp(UO_PostDec, d_BeginDeclRef); - Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr); - // Add item assignment statement to the body. + Expr* d_incBegin = BuildOp(UO_PreInc, d_beginDeclRef); + Expr* d_decBegin = BuildOp(UO_PostDec, d_beginDeclRef); + Expr* forwardCond = BuildOp(BO_NE, beginDeclRef, endExpr); const Stmt* body = FRS->getBody(); - StmtDiff bodyDiff = Visit(body); + StmtDiff bodyDiff = + DifferentiateLoopBody(body, loopCounter, nullptr, nullptr, + /*isForLoop=*/true); + + activeBreakContHandler->EndCFSwitchStmtScope(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + PopBreakContStmtHandler(); StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl())); StmtDiff storeAdjLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx())); - addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx())); - Expr* CounterIncrement = loopCounter.getCounterIncrement(); - Expr* LoopInit = LoopVDDiff.getDecl()->getInit(); + Expr* loopInit = LoopVDDiff.getDecl()->getInit(); LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType())); addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl())); - Expr* AssignLoop = - BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), LoopInit); + Expr* assignLoop = + BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), loopInit); if (!LoopVD->getType()->isReferenceType()) { Expr* d_LoopVD = BuildDeclRef(LoopVDDiff.getDecl_dx()); - AdjLoopVDAddAssign = - BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_BeginDeclRef)); + adjLoopVDAddAssign = + BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_beginDeclRef)); } beginBlock(direction::forward); - addToCurrentBlock(CounterIncrement); - addToCurrentBlock(AdjLoopVDAddAssign); - addToCurrentBlock(AssignLoop); + addToCurrentBlock(adjLoopVDAddAssign); + addToCurrentBlock(assignLoop); addToCurrentBlock(storeLoop.getStmt()); addToCurrentBlock(storeAdjLoop.getStmt()); CompoundStmt* LoopVDForwardDiff = endBlock(direction::forward); @@ -1068,28 +1076,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Sema.getASTContext(), bodyDiff.getStmt(), LoopVDForwardDiff); beginBlock(direction::forward); - addToCurrentBlock(d_DecBegin); + addToCurrentBlock(d_decBegin); addToCurrentBlock(storeLoop.getStmt_dx()); addToCurrentBlock(storeAdjLoop.getStmt_dx()); CompoundStmt* LoopVDReverseDiff = endBlock(direction::forward); CompoundStmt* bodyReverse = utils::PrependAndCreateCompoundStmt( m_Sema.getASTContext(), bodyDiff.getStmt_dx(), LoopVDReverseDiff); - Expr* Inc = BuildOp(BO_Comma, IncBegin, d_IncBegin); + Expr* inc = BuildOp(BO_Comma, incBegin, d_incBegin); Stmt* Forward = new (m_Context) ForStmt( - m_Context, /*Init=*/nullptr, ForwardCond, /*CondVar=*/nullptr, Inc, + m_Context, /*Init=*/nullptr, forwardCond, /*CondVar=*/nullptr, inc, bodyForward, FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc()); - Expr* CounterCondition = + Expr* counterCondition = loopCounter.getCounterConditionResult().get().second; - Expr* CounterDecrement = loopCounter.getCounterDecrement(); + Expr* counterDecrement = loopCounter.getCounterDecrement(); Stmt* Reverse = bodyReverse; addToCurrentBlock(Reverse, direction::reverse); Reverse = endBlock(direction::reverse); Reverse = new (m_Context) - ForStmt(m_Context, /*Init=*/nullptr, CounterCondition, - /*CondVar=*/nullptr, CounterDecrement, Reverse, + ForStmt(m_Context, /*Init=*/nullptr, counterCondition, + /*CondVar=*/nullptr, counterDecrement, Reverse, FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc()); addToCurrentBlock(Reverse, direction::reverse); Reverse = endBlock(direction::reverse); @@ -2647,18 +2655,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!BinOp->isComparisonOp() && !BinOp->isLogicalOp()) unsupportedOpWarn(BinOp->getEndLoc()); - // If either LHS or RHS is a declaration reference, visit it to avoid - // naming collision - auto* LDRE = dyn_cast(L); - auto* RDRE = dyn_cast(R); - - if (!LDRE && !RDRE) - return Clone(BinOp); - - Expr* LExpr = LDRE ? Visit(L).getExpr() : L; - Expr* RExpr = RDRE ? Visit(R).getExpr() : R; - - return BuildOp(opCode, LExpr, RExpr); + return BuildOp(opCode, Visit(L).getExpr(), Visit(R).getExpr()); } Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr()); @@ -2685,10 +2682,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } - DeclDiff - ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { + DeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD, + bool addToBlock) { StmtDiff initDiff; Expr* VDDerivedInit = nullptr; + // Local declarations are promoted to the function global scope. This // procedure is done to make declarations visible in the reverse sweep. // The reverse_mode_forward_pass mode does not have a reverse pass so @@ -2863,7 +2861,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, getZeroInit(VDDerivedType)); else assignToZero = GetCladZeroInit(declRef); - addToCurrentBlock(assignToZero, direction::reverse); + if (addToBlock) + addToCurrentBlock(assignToZero, direction::reverse); } } @@ -2879,10 +2878,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getForwSweepExpr_dx())); - addToCurrentBlock(assignDerivativeE); + if (addToBlock) + addToCurrentBlock(assignDerivativeE); if (isInsideLoop) { StmtDiff pushPop = StoreAndRestore(derivedVDE); - addToCurrentBlock(pushPop.getStmt(), direction::forward); + if (addToBlock) + addToCurrentBlock(pushPop.getStmt(), direction::forward); m_LoopBlock.back().push_back(pushPop.getStmt_dx()); } derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE); @@ -2908,10 +2909,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, initDiff.getExpr_dx()); - addToCurrentBlock(assignDerivativeE, direction::forward); + if (addToBlock) + addToCurrentBlock(assignDerivativeE, direction::forward); if (isInsideLoop) { auto tape = MakeCladTapeFor(derivedVDE); - addToCurrentBlock(tape.Push); + if (addToBlock) + addToCurrentBlock(tape.Push); auto* reverseSweepDerivativePointerE = BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop); m_LoopBlock.back().push_back( @@ -2925,6 +2928,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } if (derivedVDE) m_Variables.emplace(VDClone, derivedVDE); + // Check if decl's name is the same as before. The name may be changed + // if decl name collides with something in the derivative body. + // This can happen in rare cases, e.g. when the original function + // has both y and _d_y (here _d_y collides with the name produced by + // the derivation process), e.g. + // double f(double x) { + // double y = x; + // double _d_y = x; + // } + // -> + // double f_darg0(double x) { + // double _d_x = 1; + // double _d_y = _d_x; // produced as a derivative for y + // double y = x; + // double _d__d_y = _d_x; + // double _d_y = x; // copied from original function, collides with + // _d_y + // } + if ((VD->getDeclName() != VDClone->getDeclName() || + VD->getType() != VDClone->getType())) + m_DeclReplacements[VD] = VDClone; return DeclDiff(VDClone, VDDerived); } @@ -3027,29 +3051,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!isLambda) VDDiff = DifferentiateVarDecl(VD); - // Check if decl's name is the same as before. The name may be changed - // if decl name collides with something in the derivative body. - // This can happen in rare cases, e.g. when the original function - // has both y and _d_y (here _d_y collides with the name produced by - // the derivation process), e.g. - // double f(double x) { - // double y = x; - // double _d_y = x; - // } - // -> - // double f_darg0(double x) { - // double _d_x = 1; - // double _d_y = _d_x; // produced as a derivative for y - // double y = x; - // double _d__d_y = _d_x; - // double _d_y = x; // copied from original function, collides with - // _d_y - // } - if (!isLambda && - (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || - VD->getType() != VDDiff.getDecl()->getType())) - m_DeclReplacements[VD] = VDDiff.getDecl(); - // Here, we move the declaration to the function global scope. // Initialization is replaced with an assignment operation at the same // place as the original declaration. This procedure is done to make the diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 7dfd0ba9a..e8d4e6ca2 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2695,11 +2695,6 @@ double fn34(double x, double y){ } //CHECK: void fn34_grad(double x, double y, double *_d_x, double *_d_y) { -//CHECK-NEXT: unsigned {{int|long}} _t0; -//CHECK-NEXT: double (*_d___range1)[3] = 0; -//CHECK-NEXT: double (*__range10)[3] = {}; -//CHECK-NEXT: double *_d___begin1 = 0; -//CHECK-NEXT: double *__begin10 = 0; //CHECK-NEXT: clad::tape _t1 = {}; //CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: clad::tape _t3 = {}; @@ -2707,22 +2702,24 @@ double fn34(double x, double y){ //CHECK-NEXT: double r = 0; //CHECK-NEXT: double _d_a[3] = {0}; //CHECK-NEXT: double a[3] = {y, x * y, x * x + y}; -//CHECK-NEXT: _t0 = {{0U|0UL}}; -//CHECK-NEXT: _d___range1 = &_d_a; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = &a; +//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = 0; //CHECK-NEXT: _d___begin1 = *_d___range1; //CHECK-NEXT: __range10 = &a; -//CHECK-NEXT: __begin10 = *__range10; //CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; //CHECK-NEXT: double *_d_i = 0; //CHECK-NEXT: double *i = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { //CHECK-NEXT: { -//CHECK-NEXT: _t0++; //CHECK-NEXT: _d_i = &*_d___begin1; //CHECK-NEXT: i = &*__begin10; //CHECK-NEXT: clad::push(_t2, i); //CHECK-NEXT: clad::push(_t3, _d_i); //CHECK-NEXT: } +//CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, r); //CHECK-NEXT: r += *i * *i; //CHECK-NEXT: } @@ -2752,72 +2749,287 @@ double fn34(double x, double y){ //CHECK-NEXT: } //CHECK-NEXT: } - double fn35(double x, double y){ + double r = 0; + double a[] = {x, x*y, 0}; + for(auto& i: a){ + for(auto& j:a){ + if(r<=x*x){ + r+=i*j; + }else if(r>x*x){ + break; + } + } + } + return r; +} + +//CHECK: void fn35_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _cond0 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _cond1 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: clad::tape _t4 = {}; +//CHECK-NEXT: clad::tape _t5 = {}; +//CHECK-NEXT: clad::tape _t6 = {}; +//CHECK-NEXT: clad::tape _t7 = {}; +//CHECK-NEXT: double _d_r = 0; +//CHECK-NEXT: double r = 0; +//CHECK-NEXT: double _d_a[3] = {0}; +//CHECK-NEXT: double a[3] = {x, x * y, 0}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = &a; +//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = 0; +//CHECK-NEXT: _d___begin1 = *_d___range1; +//CHECK-NEXT: __range10 = &a; +//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: double *_d_i = 0; +//CHECK-NEXT: double *i = 0; +//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_i = &*_d___begin1; +//CHECK-NEXT: i = &*__begin10; +//CHECK-NEXT: clad::push(_t6, i); +//CHECK-NEXT: clad::push(_t7, _d_i); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, {{0U|0UL}}); +//CHECK-NEXT: double (*__range20)[3] = &a; +//CHECK-NEXT: double (*_d___range2)[3] = &_d_a; +//CHECK-NEXT: double *__begin20 = *__range20; +//CHECK-NEXT: double *_d___begin2 = 0; +//CHECK-NEXT: _d___begin2 = *_d___range2; +//CHECK-NEXT: __range20 = &a; +//CHECK-NEXT: double *__end20 = *__range20 + {{3|3L}}; +//CHECK-NEXT: double *_d_j = 0; +//CHECK-NEXT: double *j = 0; +//CHECK-NEXT: for (; __begin20 != __end20; ++__begin20 , ++_d___begin2) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_j = &*_d___begin2; +//CHECK-NEXT: j = &*__begin20; +//CHECK-NEXT: clad::push(_t4, j); +//CHECK-NEXT: clad::push(_t5, _d_j); +//CHECK-NEXT: } +//CHECK-NEXT: clad::back(_t1)++; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_cond0, r <= x * x); +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: clad::push(_t2, r); +//CHECK-NEXT: r += *i * *j; +//CHECK-NEXT: } else { +//CHECK-NEXT: clad::push(_cond1, r > x * x); +//CHECK-NEXT: if (clad::back(_cond1)) { +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_t3, {{1U|1UL}}); +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t3, {{2U|2UL}}); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: _d_r += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin1--; +//CHECK-NEXT: i = clad::pop(_t6); +//CHECK-NEXT: _d_i = clad::pop(_t7); +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: for (; clad::back(_t1); clad::back(_t1)--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin2--; +//CHECK-NEXT: j = clad::pop(_t4); +//CHECK-NEXT: _d_j = clad::pop(_t5); +//CHECK-NEXT: } +//CHECK-NEXT: switch (clad::pop(_t3)) { +//CHECK-NEXT: case {{2U|2UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: { +//CHECK-NEXT: r = clad::pop(_t2); +//CHECK-NEXT: double _r_d0 = _d_r; +//CHECK-NEXT: *_d_i += _r_d0 * *j; +//CHECK-NEXT: *_d_j += *i * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } else { +//CHECK-NEXT: if (clad::back(_cond1)) { +//CHECK-NEXT: case {{1U|1UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_cond1); +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_cond0); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_t1); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_a[0]; +//CHECK-NEXT: *_d_x += _d_a[1] * y; +//CHECK-NEXT: *_d_y += x * _d_a[1]; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double fn36(double x, double y){ double a[] = {1, 2, 3}; double sum = 0; - for(auto i:a){ - sum += sin(i)*x; + for(auto i: a){ + if(sum > x){ + continue; + }else if(1){ + sum += sin(i)*x; + } } return sum; } -//CHECK: void fn35_grad(double x, double y, double *_d_x, double *_d_y) { -//CHECK-NEXT: unsigned {{int|long}} _t0; -//CHECK-NEXT: double (*_d___range1)[3] = 0; -//CHECK-NEXT: double (*__range10)[3] = {}; -//CHECK-NEXT: double *_d___begin1 = 0; -//CHECK-NEXT: double *__begin10 = 0; -//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK: void fn36_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: clad::tape _cond0 = {}; +//CHECK-NEXT: clad::tape _t1 = {}; //CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: clad::tape _t3 = {}; //CHECK-NEXT: clad::tape _t4 = {}; +//CHECK-NEXT: clad::tape _t5 = {}; //CHECK-NEXT: double _d_a[3] = {0}; //CHECK-NEXT: double a[3] = {1, 2, 3}; //CHECK-NEXT: double _d_sum = 0; //CHECK-NEXT: double sum = 0; -//CHECK-NEXT: _t0 = {{0U|0UL}}; -//CHECK-NEXT: _d___range1 = &_d_a; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = &a; +//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = 0; //CHECK-NEXT: _d___begin1 = *_d___range1; //CHECK-NEXT: __range10 = &a; -//CHECK-NEXT: __begin10 = *__range10; //CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; //CHECK-NEXT: double _d_i = 0; //CHECK-NEXT: double i = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { //CHECK-NEXT: { -//CHECK-NEXT: _t0++; //CHECK-NEXT: _d_i = *_d___begin1; //CHECK-NEXT: i = *__begin10; -//CHECK-NEXT: clad::push(_t3, i); -//CHECK-NEXT: clad::push(_t4, _d_i); +//CHECK-NEXT: clad::push(_t4, i); +//CHECK-NEXT: clad::push(_t5, _d_i); //CHECK-NEXT: } -//CHECK-NEXT: clad::push(_t1, sum); -//CHECK-NEXT: clad::push(_t2, sin(i)); -//CHECK-NEXT: sum += clad::back(_t2) * x; +//CHECK-NEXT: _t0++; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_cond0, sum > x); +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_t1, {{1U|1UL}}); +//CHECK-NEXT: continue; +//CHECK-NEXT: } +//CHECK-NEXT: } else if (1) { +//CHECK-NEXT: clad::push(_t2, sum); +//CHECK-NEXT: clad::push(_t3, sin(i)); +//CHECK-NEXT: sum += clad::back(_t3) * x; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t1, {{2U|2UL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_sum += 1; //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: { //CHECK-NEXT: { //CHECK-NEXT: _d___begin1--; -//CHECK-NEXT: i = clad::pop(_t3); -//CHECK-NEXT: _d_i = clad::pop(_t4); +//CHECK-NEXT: i = clad::pop(_t4); +//CHECK-NEXT: _d_i = clad::pop(_t5); //CHECK-NEXT: } -//CHECK-NEXT: { -//CHECK-NEXT: sum = clad::pop(_t1); -//CHECK-NEXT: double _r_d0 = _d_sum; -//CHECK-NEXT: double _r0 = 0; -//CHECK-NEXT: _r0 += _r_d0 * x * clad::custom_derivatives::sin_pushforward(i, 1.).pushforward; -//CHECK-NEXT: _d_i += _r0; -//CHECK-NEXT: *_d_x += clad::back(_t2) * _r_d0; -//CHECK-NEXT: clad::pop(_t2); +//CHECK-NEXT: switch (clad::pop(_t1)) { +//CHECK-NEXT: case {{2U|2UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: case {{1U|1UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: } else if (1) { +//CHECK-NEXT: { +//CHECK-NEXT: sum = clad::pop(_t2); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: _r0 += _r_d0 * x * clad::custom_derivatives::sin_pushforward(i, 1.).pushforward; +//CHECK-NEXT: _d_i += _r0; +//CHECK-NEXT: *_d_x += clad::back(_t3) * _r_d0; +//CHECK-NEXT: clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_cond0); +//CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: *_d___begin1 += _d_i; //CHECK-NEXT: } //CHECK-NEXT: } +double fn37(double x, double y) { + double range[] = {x, 4., y}; + double sum = 0; + for (auto elem: range) + sum += elem; + return sum; +} + +//CHECK: void fn37_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: double _d_range[3] = {0}; +//CHECK-NEXT: double range[3] = {x, 4., y}; +//CHECK-NEXT: double _d_sum = 0; +//CHECK-NEXT: double sum = 0; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = ⦥ +//CHECK-NEXT: double (*_d___range1)[3] = &_d_range; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = 0; +//CHECK-NEXT: _d___begin1 = *_d___range1; +//CHECK-NEXT: __range10 = ⦥ +//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: double _d_elem = 0; +//CHECK-NEXT: double elem = 0; +//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_elem = *_d___begin1; +//CHECK-NEXT: elem = *__begin10; +//CHECK-NEXT: clad::push(_t2, elem); +//CHECK-NEXT: clad::push(_t3, _d_elem); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, sum); +//CHECK-NEXT: sum += elem; +//CHECK-NEXT: } +//CHECK-NEXT: _d_sum += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin1--; +//CHECK-NEXT: elem = clad::pop(_t2); +//CHECK-NEXT: _d_elem = clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: sum = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: _d_elem += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: *_d___begin1 += _d_elem; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_range[0]; +//CHECK-NEXT: *_d_y += _d_range[2]; +//CHECK-NEXT: } +//CHECK-NEXT: } + #define TEST(F, x) { \ result[0] = 0; \ @@ -2904,7 +3116,9 @@ int main() { TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00} TEST_2(fn34, 2, 2); // CHECK-EXEC: {64.00, 32.00} - TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00} + TEST_2(fn35, 2, 2); // CHECK-EXEC: {12.00, 4.00} + TEST_2(fn36, 1, 1); // CHECK-EXEC: {1.75, 0.00} + TEST_2(fn37, 1, 1); // CHECK-EXEC: {1.00, 1.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {