Skip to content

Commit

Permalink
Revert back to memset call after malloc and realloc
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 25, 2024
1 parent fa16a53 commit f11aea6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 59 deletions.
17 changes: 7 additions & 10 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,16 +450,13 @@ namespace clad {
/// \returns The atomicAdd call expression.
clang::Expr* BuildCallToCudaAtomicAdd(clang::Expr* LHS, clang::Expr* RHS);

/// Check whether this is an assignment to a malloc or realloc call for a
/// derivative variable and build a call to calloc instead if it's a malloc
/// call or add a calloc call after a realloc call, to properly intialize
/// the memory to zero. Currently these configurations of size are supported
/// in malloc or realloc:
/// 1. x * sizeof(T)
/// 2. sizeof(T) * x
/// \param[in] RHS The right-hand side expression of the assignment.
/// @returns The call to calloc if the condition is met, otherwise nullptr.
clang::Expr* CheckAndBuildCallToCalloc(clang::Expr* RHS);
/// Check whether this is an assignment to a malloc or a realloc call for a
/// derivative variable and build a call to memset to follow the memory
/// allocation in order to properly intialize the memory to zero. \param[in]
/// LHS The left-hand side expression of the assignment. \param[in] RHS The
/// right-hand side expression of the assignment.
/// @returns The call to memset if the condition is met, otherwise nullptr.
clang::Expr* CheckAndBuildCallToMemset(clang::Expr* LHS, clang::Expr* RHS);

static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);
Expand Down
82 changes: 35 additions & 47 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,35 +153,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return atomicAddCall;
}

Expr* ReverseModeVisitor::CheckAndBuildCallToCalloc(Expr* RHS) {
Expr* ReverseModeVisitor::CheckAndBuildCallToMemset(Expr* LHS, Expr* RHS) {
Expr* size = nullptr;
if (auto* callExpr = dyn_cast<CallExpr>(RHS))
if (auto* implCast = dyn_cast<ImplicitCastExpr>(callExpr->getCallee()))
if (auto* declRef = dyn_cast<DeclRefExpr>(implCast->getSubExpr()))
if (auto* FD = dyn_cast<FunctionDecl>(declRef->getDecl())) {
if (FD->getNameAsString() == "malloc")
size = callExpr->getArg(0);
else if (FD->getNameAsString() == "realloc")
size = callExpr->getArg(1);
}
if (auto* declRef =
dyn_cast<DeclRefExpr>(callExpr->getCallee()->IgnoreImpCasts()))
if (auto* FD = dyn_cast<FunctionDecl>(declRef->getDecl())) {
if (FD->getNameAsString() == "malloc")
size = callExpr->getArg(0);
else if (FD->getNameAsString() == "realloc")
size = callExpr->getArg(1);
}

if (size) {
llvm::SmallVector<Expr*, 2> args;
if (auto* BinOp = dyn_cast<BinaryOperator>(size)) {
if (BinOp->getOpcode() == BO_Mul) {
Expr* lhs = BinOp->getLHS();
Expr* rhs = BinOp->getRHS();
if (auto* sizeofCall = dyn_cast<UnaryExprOrTypeTraitExpr>(rhs))
args = {lhs, sizeofCall};
else if (auto* sizeofCall = dyn_cast<UnaryExprOrTypeTraitExpr>(lhs))
args = {rhs, sizeofCall};
if (!args.empty())
return GetFunctionCall("calloc", "", args);
}
}
llvm::SmallVector<Expr*, 3> args = {LHS, getZeroInit(m_Context.IntTy),
size};
return GetFunctionCall("memset", "", args);
}

return {};
return nullptr;
}

ReverseModeVisitor::ReverseModeVisitor(DerivativeBuilder& builder,
Expand Down Expand Up @@ -1757,15 +1747,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
.get();
Expr* call_dx = nullptr;
if (FD->getNameAsString() == "malloc")
call_dx = CheckAndBuildCallToCalloc(Clone(CE));
if (!call_dx)
call_dx = m_Sema
.ActOnCallExpr(
getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
return StmtDiff(call, call_dx);
}
// For calls to C-style memory deallocation functions, we do not need to
Expand Down Expand Up @@ -2877,19 +2863,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR);
addToCurrentBlock(BuildOp(opCode, derivedL, derivedR),
direction::forward);
if (opCode == BO_Assign && derivedR)
if (Expr* callocCall =
CheckAndBuildCallToCalloc(derivedR->IgnoreParenCasts())) {
Expr* cast =
m_Sema
.BuildCStyleCastExpr(
SourceLocation(),
m_Context.getTrivialTypeSourceInfo(derivedL->getType()),
SourceLocation(), callocCall)
.get();
addToCurrentBlock(BuildOp(BO_Assign, derivedL, cast),
direction::forward);
}
if (opCode == BO_Assign && derivedL && derivedR)
if (Expr* memsetCall = CheckAndBuildCallToMemset(
derivedL, derivedR->IgnoreParenCasts()))
addToCurrentBlock(memsetCall, direction::forward);
}
}
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
Expand Down Expand Up @@ -3276,6 +3253,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVector<Stmt*, 16> inits;
llvm::SmallVector<Decl*, 4> decls;
llvm::SmallVector<Decl*, 4> declsDiff;
llvm::SmallVector<Stmt*, 4> memsetCalls;
// Need to put array decls inlined.
llvm::SmallVector<Decl*, 4> localDeclsDiff;
// reverse_mode_forward_pass does not have a reverse pass so declarations
Expand Down Expand Up @@ -3365,8 +3343,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (VDDiff.getDecl_dx()) {
if (isa<VariableArrayType>(VD->getType()))
localDeclsDiff.push_back(VDDiff.getDecl_dx());
else
declsDiff.push_back(VDDiff.getDecl_dx());
else {
VarDecl* VDDerived = VDDiff.getDecl_dx();
declsDiff.push_back(VDDerived);
if (Stmt* memsetCall = CheckAndBuildCallToMemset(
BuildDeclRef(VDDerived),
VDDerived->getInit()->IgnoreCasts()))
memsetCalls.push_back(memsetCall);
}
}
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
Expand Down Expand Up @@ -3405,6 +3389,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Stmts& block =
promoteToFnScope ? m_Globals : getCurrentBlock(direction::forward);
addToBlock(DSDiff, block);
if (memsetCalls.empty())
printf("memsetCalls is empty\n");
for (Stmt* memset : memsetCalls)
addToBlock(memset, block);
}

if (m_ExternalSource) {
Expand Down
5 changes: 3 additions & 2 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ double cStyleMemoryAlloc(double x, size_t n) {

// CHECK: void cStyleMemoryAlloc_grad_0(double x, size_t n, double *_d_x) {
// CHECK-NEXT: size_t _d_n = 0UL;
// CHECK-NEXT: T *_d_t = (T *)calloc(n, sizeof(T));
// CHECK-NEXT: T *_d_t = (T *)malloc(sizeof(T) * n);
// CHECK-NEXT: memset(_d_t, 0, sizeof(T) * n);
// CHECK-NEXT: T *t = (T *)malloc(sizeof(T) * n);
// CHECK-NEXT: memset(_d_t, 0, n * sizeof(T));
// CHECK-NEXT: memset(t, 0, n * sizeof(T));
Expand All @@ -422,7 +423,7 @@ double cStyleMemoryAlloc(double x, size_t n) {
// CHECK-NEXT: double *_t2 = p;
// CHECK-NEXT: double *_t3 = _d_p;
// CHECK-NEXT: _d_p = (double *)realloc(_d_p, 2 * sizeof(double));
// CHECK-NEXT: _d_p = (double *)calloc(2, sizeof(double));
// CHECK-NEXT: memset(_d_p, 0, 2 * sizeof(double));
// CHECK-NEXT: p = (double *)realloc(p, 2 * sizeof(double));
// CHECK-NEXT: double _t4 = p[1];
// CHECK-NEXT: p[1] = 2 * x;
Expand Down

0 comments on commit f11aea6

Please sign in to comment.