Skip to content

Commit

Permalink
Call calloc instead of malloc or after realloc for derivative vars
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 22, 2024
1 parent cf7ed92 commit 66103cf
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 65 deletions.
10 changes: 10 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,16 @@ namespace clad {
/// \returns The atomicAdd call expression.
clang::Expr* BuildCallToCudaAtomicAdd(clang::Expr* LHS, clang::Expr* RHS);

/// Check whether this is an assignment to a malloc or realloc call for a
/// derivative variable and build a call to calloc instead, to properly
/// intialize the memory to zero. Currently these configurations of size are
/// supported in malloc or realloc:
/// 1. x * sizeof(T)
/// 2. sizeof(T) * x
/// \param[in] RHS The right-hand side expression of the assignment.
/// @returns The call to calloc if the condition is met, otherwise nullptr.
clang::Expr* CheckAndBuildCallToCalloc(clang::Expr* RHS);

static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

Expand Down
9 changes: 0 additions & 9 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,15 +546,6 @@ namespace clad {
bool useRefQualifiedThisObj = false,
const clang::CXXScopeSpec* SS = nullptr);

/// Build a call to a free function. Search for it using its name and args.
///
/// \param[in] funcName function name
/// \param[in] argExprs function arguments expressions
/// \returns Built call expression
clang::Expr*
BuildCallExprToFunction(std::string funcName,
llvm::MutableArrayRef<clang::Expr*> args);

/// Build a call to templated free function inside the clad namespace.
///
/// \param[in] name name of the function
Expand Down
77 changes: 56 additions & 21 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1736,11 +1736,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
.get();
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
Expr* call_dx = nullptr;
if (FD->getNameAsString() == "malloc")
call_dx = CheckAndBuildCallToCalloc(Clone(CE));
if (!call_dx)
call_dx = m_Sema
.ActOnCallExpr(
getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
return StmtDiff(call, call_dx);
}
// For calls to C-style memory deallocation functions, we do not need to
Expand Down Expand Up @@ -2856,6 +2860,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR);
addToCurrentBlock(BuildOp(opCode, derivedL, derivedR),
direction::forward);
if (opCode == BO_Assign && derivedR)
if (Expr* callocCall =
CheckAndBuildCallToCalloc(derivedR->IgnoreParenCasts())) {
Expr* cast =
m_Sema
.BuildCStyleCastExpr(
SourceLocation(),
m_Context.getTrivialTypeSourceInfo(derivedL->getType()),
SourceLocation(), callocCall)
.get();
addToCurrentBlock(BuildOp(BO_Assign, derivedL, cast),
direction::forward);
}
}
}
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
Expand Down Expand Up @@ -3209,8 +3226,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool dxInForward = false;
if (auto* callExpr = dyn_cast_or_null<CallExpr>(stmtDx))
if (auto* FD = dyn_cast<FunctionDecl>(callExpr->getCalleeDecl()))
if (utils::IsMemoryFunction(FD))
if (utils::IsMemoryFunction(FD)) {
printf("%s\n", FD->getNameAsString().c_str());
dxInForward = true;
}
if (stmtDx) {
if (dxInForward)
addToCurrentBlock(stmtDx, direction::forward);
Expand All @@ -3224,6 +3243,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(SDiff.getStmt(), ReverseResult);
}

Expr* ReverseModeVisitor::CheckAndBuildCallToCalloc(Expr* RHS) {
Expr* size = nullptr;
if (auto* callExpr = dyn_cast<CallExpr>(RHS))
if (auto* implCast = dyn_cast<ImplicitCastExpr>(callExpr->getCallee()))
if (auto* declRef = dyn_cast<DeclRefExpr>(implCast->getSubExpr()))
if (auto* FD = dyn_cast<FunctionDecl>(declRef->getDecl())) {
if (FD->getNameAsString() == "malloc")
size = callExpr->getArg(0);
else if (FD->getNameAsString() == "realloc")
size = callExpr->getArg(1);
}

if (size) {
llvm::SmallVector<Expr*, 2> args;
if (auto BinOp = dyn_cast<BinaryOperator>(size)) {
if (BinOp->getOpcode() == BO_Mul) {
Expr* lhs = BinOp->getLHS();
Expr* rhs = BinOp->getRHS();
if (auto* sizeofCall = dyn_cast<UnaryExprOrTypeTraitExpr>(rhs))
args = {lhs, sizeofCall};
else if (auto* sizeofCall = dyn_cast<UnaryExprOrTypeTraitExpr>(lhs))
args = {rhs, sizeofCall};
if (!args.empty())

Check warning on line 3268 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L3267-L3268

Added lines #L3267 - L3268 were not covered by tests
return GetFunctionCall("calloc", "", args);
}
}
}

return {};
}

std::pair<StmtDiff, StmtDiff>
ReverseModeVisitor::DifferentiateSingleExpr(const Expr* E, Expr* dfdE) {
beginBlock(direction::forward);
Expand Down Expand Up @@ -3335,21 +3385,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
else {
VarDecl* VDDerived = VDDiff.getDecl_dx();
declsDiff.push_back(VDDerived);
if (auto* cast = dyn_cast<CStyleCastExpr>(VDDerived->getInit()))
if (auto* callExpr =
dyn_cast<CallExpr>(cast->getSubExpr()->IgnoreCasts()))
if (auto* implCast =
dyn_cast<ImplicitCastExpr>(callExpr->getCallee()))
if (auto* declRef =
dyn_cast<DeclRefExpr>(implCast->getSubExpr()))
if (auto* FD = dyn_cast<FunctionDecl>(declRef->getDecl()))
if (FD->getNameAsString() == "malloc") {
llvm::SmallVector<Expr*, 3> memsetArgs{
BuildDeclRef(VDDerived),
getZeroInit(m_Context.IntTy), callExpr->getArg(0)};
callToMemsets.push_back(
BuildCallExprToFunction("memset", memsetArgs));
}
}
}
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
Expand Down
42 changes: 9 additions & 33 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,14 @@ namespace clad {
Expr* VisitorBase::GetFunctionCall(const std::string& funcName,
const std::string& nmspace,
llvm::SmallVectorImpl<Expr*>& callArgs) {
NamespaceDecl* NSD =
utils::LookupNSD(m_Sema, nmspace, /*shouldExist=*/true);
DeclContext* DC = NSD;
NamespaceDecl* NSD = nullptr;
CXXScopeSpec SS;
SS.Extend(m_Context, NSD, noLoc, noLoc);

if (!nmspace.empty()) {
NSD = utils::LookupNSD(m_Sema, nmspace, /*shouldExist=*/true);
SS.Extend(m_Context, NSD, noLoc, noLoc);
}
DeclContext* DC = NSD;

IdentifierInfo* II = &m_Context.Idents.get(funcName);
DeclarationName name(II);
Expand All @@ -544,6 +547,8 @@ namespace clad {

if (DC)
m_Sema.LookupQualifiedName(R, DC);
else
m_Sema.LookupQualifiedName(R, m_Context.getTranslationUnitDecl());
Expr* UnresolvedLookup = nullptr;
if (!R.empty())
UnresolvedLookup =
Expand Down Expand Up @@ -683,35 +688,6 @@ namespace clad {
return call;
}

clang::Expr*
VisitorBase::BuildCallExprToFunction(std::string funcName,
llvm::MutableArrayRef<Expr*> args) {
DeclarationName Id = &m_Context.Idents.get(funcName);
LookupResult lookupResult(m_Sema, Id, SourceLocation(),
Sema::LookupOrdinaryName);
m_Sema.LookupQualifiedName(lookupResult,
m_Context.getTranslationUnitDecl());

CXXScopeSpec SS;
Expr* UnresolvedLookup =
m_Sema.BuildDeclarationNameExpr(SS, lookupResult, /*ADL=*/true).get();
for (auto arg : args)
arg->dump();

assert(!m_Builder.noOverloadExists(UnresolvedLookup, args) &&
"memset function not found");

Expr* call = m_Sema
.ActOnCallExpr(getCurrentScope(),
/*Fn=*/UnresolvedLookup,
/*LParenLoc=*/noLoc,
/*ArgExprs=*/args,
/*RParenLoc=*/m_DiffReq->getLocation())
.get();

return call;
}

Expr* VisitorBase::BuildCallExprToCladFunction(
llvm::StringRef name, llvm::MutableArrayRef<clang::Expr*> argExprs,
llvm::ArrayRef<clang::TemplateArgument> templateArgs,
Expand Down
2 changes: 1 addition & 1 deletion test/CUDA/GradientKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ double fn_memory(double *out, double *in) {
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: kernel_call<<<1, 10>>>(out, in);
//CHECK-NEXT: cudaDeviceSynchronize();
//CHECK-NEXT: double *_d_out_host = (double *)malloc(10 * sizeof(double));
//CHECK-NEXT: double *_d_out_host = (double *)calloc(10, sizeof(double));
//CHECK-NEXT: double *out_host = (double *)malloc(10 * sizeof(double));
//CHECK-NEXT: cudaMemcpy(out_host, out, 10 * sizeof(double), cudaMemcpyDeviceToHost);
//CHECK-NEXT: double _d_res = 0.;
Expand Down
3 changes: 2 additions & 1 deletion test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ double cStyleMemoryAlloc(double x, size_t n) {

// CHECK: void cStyleMemoryAlloc_grad_0(double x, size_t n, double *_d_x) {
// CHECK-NEXT: size_t _d_n = 0UL;
// CHECK-NEXT: T *_d_t = (T *)malloc(n * sizeof(T));
// CHECK-NEXT: T *_d_t = (T *)calloc(n, sizeof(T));
// CHECK-NEXT: T *t = (T *)malloc(n * sizeof(T));
// CHECK-NEXT: memset(_d_t, 0, n * sizeof(T));
// CHECK-NEXT: memset(t, 0, n * sizeof(T));
Expand All @@ -422,6 +422,7 @@ double cStyleMemoryAlloc(double x, size_t n) {
// CHECK-NEXT: double *_t2 = p;
// CHECK-NEXT: double *_t3 = _d_p;
// CHECK-NEXT: _d_p = (double *)realloc(_d_p, 2 * sizeof(double));
// CHECK-NEXT: _d_p = (double *)calloc(2, sizeof(double));
// CHECK-NEXT: p = (double *)realloc(p, 2 * sizeof(double));
// CHECK-NEXT: double _t4 = p[1];
// CHECK-NEXT: p[1] = 2 * x;
Expand Down

0 comments on commit 66103cf

Please sign in to comment.