diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index c2b88ed83..d65871ec4 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -450,16 +450,13 @@ namespace clad { /// \returns The atomicAdd call expression. clang::Expr* BuildCallToCudaAtomicAdd(clang::Expr* LHS, clang::Expr* RHS); - /// Check whether this is an assignment to a malloc or realloc call for a - /// derivative variable and build a call to calloc instead if it's a malloc - /// call or add a calloc call after a realloc call, to properly intialize - /// the memory to zero. Currently these configurations of size are supported - /// in malloc or realloc: - /// 1. x * sizeof(T) - /// 2. sizeof(T) * x - /// \param[in] RHS The right-hand side expression of the assignment. - /// @returns The call to calloc if the condition is met, otherwise nullptr. - clang::Expr* CheckAndBuildCallToCalloc(clang::Expr* RHS); + /// Check whether this is an assignment to a malloc or a realloc call for a + /// derivative variable and build a call to memset to follow the memory + /// allocation in order to properly intialize the memory to zero. \param[in] + /// LHS The left-hand side expression of the assignment. \param[in] RHS The + /// right-hand side expression of the assignment. + /// @returns The call to memset if the condition is met, otherwise nullptr. + clang::Expr* CheckAndBuildCallToMemset(clang::Expr* LHS, clang::Expr* RHS); static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 45bb0fe54..94841450b 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -153,35 +153,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return atomicAddCall; } - Expr* ReverseModeVisitor::CheckAndBuildCallToCalloc(Expr* RHS) { + Expr* ReverseModeVisitor::CheckAndBuildCallToMemset(Expr* LHS, Expr* RHS) { Expr* size = nullptr; if (auto* callExpr = dyn_cast(RHS)) - if (auto* implCast = dyn_cast(callExpr->getCallee())) - if (auto* declRef = dyn_cast(implCast->getSubExpr())) - if (auto* FD = dyn_cast(declRef->getDecl())) { - if (FD->getNameAsString() == "malloc") - size = callExpr->getArg(0); - else if (FD->getNameAsString() == "realloc") - size = callExpr->getArg(1); - } + if (auto* declRef = + dyn_cast(callExpr->getCallee()->IgnoreImpCasts())) + if (auto* FD = dyn_cast(declRef->getDecl())) { + if (FD->getNameAsString() == "malloc") + size = callExpr->getArg(0); + else if (FD->getNameAsString() == "realloc") + size = callExpr->getArg(1); + } if (size) { - llvm::SmallVector args; - if (auto* BinOp = dyn_cast(size)) { - if (BinOp->getOpcode() == BO_Mul) { - Expr* lhs = BinOp->getLHS(); - Expr* rhs = BinOp->getRHS(); - if (auto* sizeofCall = dyn_cast(rhs)) - args = {lhs, sizeofCall}; - else if (auto* sizeofCall = dyn_cast(lhs)) - args = {rhs, sizeofCall}; - if (!args.empty()) - return GetFunctionCall("calloc", "", args); - } - } + llvm::SmallVector args = {LHS, getZeroInit(m_Context.IntTy), + size}; + return GetFunctionCall("memset", "", args); } - return {}; + return nullptr; } ReverseModeVisitor::ReverseModeVisitor(DerivativeBuilder& builder, @@ -1757,15 +1747,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, llvm::MutableArrayRef(CallArgs), Loc) .get(); - Expr* call_dx = nullptr; - if (FD->getNameAsString() == "malloc") - call_dx = CheckAndBuildCallToCalloc(Clone(CE)); - if (!call_dx) - call_dx = m_Sema - .ActOnCallExpr( - getCurrentScope(), Clone(CE->getCallee()), Loc, - llvm::MutableArrayRef(DerivedCallArgs), Loc) - .get(); + Expr* call_dx = + m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, + llvm::MutableArrayRef(DerivedCallArgs), Loc) + .get(); return StmtDiff(call, call_dx); } // For calls to C-style memory deallocation functions, we do not need to @@ -2877,19 +2863,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR); addToCurrentBlock(BuildOp(opCode, derivedL, derivedR), direction::forward); - if (opCode == BO_Assign && derivedR) - if (Expr* callocCall = - CheckAndBuildCallToCalloc(derivedR->IgnoreParenCasts())) { - Expr* cast = - m_Sema - .BuildCStyleCastExpr( - SourceLocation(), - m_Context.getTrivialTypeSourceInfo(derivedL->getType()), - SourceLocation(), callocCall) - .get(); - addToCurrentBlock(BuildOp(BO_Assign, derivedL, cast), - direction::forward); - } + if (opCode == BO_Assign && derivedL && derivedR) + if (Expr* memsetCall = CheckAndBuildCallToMemset( + derivedL, derivedR->IgnoreParenCasts())) + addToCurrentBlock(memsetCall, direction::forward); } } return StmtDiff(op, ResultRef, nullptr, valueForRevPass); @@ -3276,6 +3253,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector inits; llvm::SmallVector decls; llvm::SmallVector declsDiff; + llvm::SmallVector memsetCalls; // Need to put array decls inlined. llvm::SmallVector localDeclsDiff; // reverse_mode_forward_pass does not have a reverse pass so declarations @@ -3365,8 +3343,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (VDDiff.getDecl_dx()) { if (isa(VD->getType())) localDeclsDiff.push_back(VDDiff.getDecl_dx()); - else - declsDiff.push_back(VDDiff.getDecl_dx()); + else { + VarDecl* VDDerived = VDDiff.getDecl_dx(); + declsDiff.push_back(VDDerived); + if (Stmt* memsetCall = CheckAndBuildCallToMemset( + BuildDeclRef(VDDerived), + VDDerived->getInit()->IgnoreCasts())) + memsetCalls.push_back(memsetCall); + } } } else if (auto* SAD = dyn_cast(D)) { DeclDiff SADDiff = DifferentiateStaticAssertDecl(SAD); @@ -3405,6 +3389,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Stmts& block = promoteToFnScope ? m_Globals : getCurrentBlock(direction::forward); addToBlock(DSDiff, block); + if (memsetCalls.empty()) + printf("memsetCalls is empty\n"); + for (Stmt* memset : memsetCalls) + addToBlock(memset, block); } if (m_ExternalSource) { diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index b6eb05bae..70fd6cf7d 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -407,7 +407,8 @@ double cStyleMemoryAlloc(double x, size_t n) { // CHECK: void cStyleMemoryAlloc_grad_0(double x, size_t n, double *_d_x) { // CHECK-NEXT: size_t _d_n = 0UL; -// CHECK-NEXT: T *_d_t = (T *)calloc(n, sizeof(T)); +// CHECK-NEXT: T *_d_t = (T *)malloc(sizeof(T) * n); +// CHECK-NEXT: memset(_d_t, 0, sizeof(T) * n); // CHECK-NEXT: T *t = (T *)malloc(sizeof(T) * n); // CHECK-NEXT: memset(_d_t, 0, n * sizeof(T)); // CHECK-NEXT: memset(t, 0, n * sizeof(T)); @@ -422,7 +423,7 @@ double cStyleMemoryAlloc(double x, size_t n) { // CHECK-NEXT: double *_t2 = p; // CHECK-NEXT: double *_t3 = _d_p; // CHECK-NEXT: _d_p = (double *)realloc(_d_p, 2 * sizeof(double)); -// CHECK-NEXT: _d_p = (double *)calloc(2, sizeof(double)); +// CHECK-NEXT: memset(_d_p, 0, 2 * sizeof(double)); // CHECK-NEXT: p = (double *)realloc(p, 2 * sizeof(double)); // CHECK-NEXT: double _t4 = p[1]; // CHECK-NEXT: p[1] = 2 * x;