diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 8e57a4cb4..121b82ab8 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -401,6 +401,8 @@ namespace clad { StmtDiff VisitDoStmt(const clang::DoStmt* DS); StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS); StmtDiff VisitBreakStmt(const clang::BreakStmt* BS); + StmtDiff + VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE); StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE); StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE); StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9efe02a98..9e34cec58 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -760,6 +760,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(enzymeCall); } } + + StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr( + const clang::CXXStdInitializerListExpr* ILE) { + return Visit(ILE->getSubExpr(), dfdx()); + } + StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { diag( DiagnosticsEngine::Warning, S->getBeginLoc(), @@ -1479,7 +1485,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // global. Ref-type declarations cannot be moved to the function global // scope because they can't be separated from their inits. if (DRE->getDecl()->getType()->isReferenceType() && - !VD->getType()->isReferenceType()) + VD->getType()->isPointerType()) clonedDRE = BuildOp(UO_Deref, clonedDRE); if (m_DiffReq.Mode == DiffMode::jacobian) { if (m_VectorOutput.size() <= outputArrayCursor) @@ -2711,11 +2717,43 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VDCloneType = CloneType(VD->getType()); VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema); } - bool isDerivativeOfRefType = VD->getType()->isReferenceType(); + + bool isRefType = VD->getType()->isLValueReferenceType(); VarDecl* VDDerived = nullptr; bool isPointerType = VD->getType()->isPointerType(); bool isInitializedByNewExpr = false; bool initializeDerivedVar = true; + std::string typeName; + if (const auto* RT = + utils::GetValueType(VD->getType())->getAs()) + typeName = RT->getDecl()->getNameAsString(); + if (typeName == "initializer_list") { + if (VD->getInit()) { + if (const auto* CXXILE = dyn_cast( + VD->getInit()->IgnoreImplicit())) { + if (const auto* ILE = dyn_cast( + CXXILE->getSubExpr()->IgnoreImplicit())) { + VDDerivedType = GetCladArrayOfType((*ILE->getInits())->getType()); + unsigned numInits = ILE->getNumInits(); + VDDerivedInit = ConstantFolder::synthesizeLiteral( + m_Context.getSizeType(), m_Context, numInits); + VDCloneType = VDDerivedType; + } + } else if (isRefType) { + initDiff = Visit(VD->getInit()); + if (promoteToFnScope) { + VDDerivedInit = BuildOp(UO_AddrOf, initDiff.getExpr_dx()); + VDDerivedType = VDDerivedInit->getType(); + } else { + VDDerivedInit = initDiff.getExpr_dx(); + VDDerivedType = + m_Context.getLValueReferenceType(VDDerivedInit->getType()); + } + VDCloneType = VDDerivedType; + } + } + } + // Check if the variable is pointer type and initialized by new expression if (isPointerType && VD->getInit() && isa(VD->getInit())) isInitializedByNewExpr = true; @@ -2746,7 +2784,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // `VDDerivedType` is the corresponding non-reference type and the initial // value is set to 0. // Otherwise, for non-reference types, the initial value is set to 0. - VDDerivedInit = getZeroInit(VD->getType()); + if (!VDDerivedInit) + VDDerivedInit = getZeroInit(VD->getType()); // `specialThisDiffCase` is only required for correctly differentiating // the following code: @@ -2763,14 +2802,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - if (isDerivativeOfRefType) { + if (isRefType) { initDiff = Visit(VD->getInit()); if (!initDiff.getForwSweepExpr_dx()) { VDDerivedType = ComputeAdjointType(VD->getType().getNonReferenceType()); - isDerivativeOfRefType = false; + isRefType = false; } - if (promoteToFnScope || !isDerivativeOfRefType) + if (promoteToFnScope || !isRefType) VDDerivedInit = getZeroInit(VDDerivedType); else VDDerivedInit = initDiff.getForwSweepExpr_dx(); @@ -2827,7 +2866,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // differentiated and should not be differentiated again. // If `VD` is a reference to a non-local variable then also there's no // need to call `Visit` since non-local variables are not differentiated. - if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) { + if (!isRefType && (!isPointerType || isInitializedByNewExpr)) { Expr* derivedE = nullptr; if (!clad::utils::hasNonDifferentiableAttribute(VD)) { @@ -2876,7 +2915,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: Add extra parantheses if derived variable pointer is pointing to a // class type object. - if (isDerivativeOfRefType && promoteToFnScope) { + if (isRefType && promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, BuildOp(UnaryOperatorKind::UO_AddrOf, @@ -2898,7 +2937,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // -> // double* ref; // ref = &x; - if (isDerivativeOfRefType && promoteToFnScope) + if (isRefType && promoteToFnScope) VDClone = BuildGlobalVarDecl( VDCloneType, VD->getNameAsString(), BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()), diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 2315008bf..e04422e0f 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -3020,6 +3020,119 @@ double fn37(double x, double y) { //CHECK-NEXT: } //CHECK-NEXT: } +double fn38(double x, double y) { + double sum = 0; + if (x > 0) { + auto&& range = {1., x, 2., y, 3.}; + for (auto elem : range) + sum += elem; + } + return sum; +} + +//CHECK: void fn38_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: bool _cond0; +//CHECK-NEXT: clad::array _d_range = {{5U|5UL}}; +//CHECK-NEXT: clad::array range = {}; +//CHECK-NEXT: unsigned {{int|long}} _t0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: double _d_sum = 0; +//CHECK-NEXT: double sum = 0; +//CHECK-NEXT: { +//CHECK-NEXT: _cond0 = x > 0; +//CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: range = {1., x, 2., y, 3.}; +//CHECK-NEXT: _t0 = {{0U|0UL}}; +//CHECK-NEXT: clad::array &__range20 = range; +//CHECK-NEXT: clad::array &_d___range2 = _d_range; +//CHECK-NEXT: {{const double *\*|const_iterator }}__begin20 = std::begin(__range20); +//CHECK-NEXT: double *_d___begin2 = std::begin(_d___range2); +//CHECK-NEXT: {{const double *\*|const_iterator }}__end20 = std::end(__range20); +//CHECK-NEXT: double _d_elem = 0; +//CHECK-NEXT: double elem = 0; +//CHECK-NEXT: for (; __begin20 != __end20; ++__begin20 , ++_d___begin2) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_elem = *_d___begin2; +//CHECK-NEXT: elem = *__begin20; +//CHECK-NEXT: clad::push(_t2, elem); +//CHECK-NEXT: clad::push(_t3, _d_elem); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, sum); +//CHECK-NEXT: sum += elem; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: _d_sum += 1; +//CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin2--; +//CHECK-NEXT: elem = clad::pop(_t2); +//CHECK-NEXT: _d_elem = clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: sum = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: _d_elem += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: *_d___begin2 += _d_elem; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_range[1]; +//CHECK-NEXT: *_d_y += _d_range[3]; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } + +double fn39(double x) { + double res = 0; + auto &&range = {1, 2, 3}; + for (auto i = range.begin(); i != range.end(); i++) { + res += x * (*i); + } + return res; +} + +//CHECK: void fn39_grad(double x, double *_d_x) { +//CHECK-NEXT: int *_d_i = 0; +//CHECK-NEXT: {{const int *\*|const_iterator }}i = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: double _d_res = 0; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: clad::array _d_range = {{3U|3UL}}; +//CHECK-NEXT: clad::array range = {1, 2, 3}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: _d_i = std::begin(_d_range); +//CHECK-NEXT: for (i = std::begin(range); ; _d_i++ , i++) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!(i != std::end(range))) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, res); +//CHECK-NEXT: res += x * (*i); +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!_t0) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: i--; +//CHECK-NEXT: _d_i--; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: res = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: *_d_x += _r_d0 * (*i); +//CHECK-NEXT: *_d_i += x * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } #define TEST(F, x) { \ result[0] = 0; \ @@ -3109,6 +3222,8 @@ int main() { TEST_2(fn35, 2, 2); // CHECK-EXEC: {12.00, 4.00} TEST_2(fn36, 1, 1); // CHECK-EXEC: {1.75, 0.00} TEST_2(fn37, 1, 1); // CHECK-EXEC: {1.00, 1.00} + TEST_2(fn38, 6, 3); // CHECK-EXEC: {1.00, 1.00} + TEST(fn39, 9); // CHECK-EXEC: {6.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {