Skip to content

Commit

Permalink
Use calloc call instead of malloc or add it after a realloc call for …
Browse files Browse the repository at this point in the history
…derivative pointers
  • Loading branch information
kchristin22 committed Oct 23, 2024
1 parent d82f7fd commit f545a08
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 29 deletions.
11 changes: 11 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,17 @@ namespace clad {
/// \returns The atomicAdd call expression.
clang::Expr* BuildCallToCudaAtomicAdd(clang::Expr* LHS, clang::Expr* RHS);

/// Check whether this is an assignment to a malloc or realloc call for a
/// derivative variable and build a call to calloc instead if it's a malloc
/// call or add a calloc call after a realloc call, to properly initialize
/// the memory to zero. Currently these configurations of size are supported
/// in malloc or realloc:
/// 1. x * sizeof(T)
/// 2. sizeof(T) * x
/// Any other form of the size argument is left untouched.
/// \param[in] RHS The right-hand side expression of the assignment.
/// \returns The call to calloc if the condition is met, otherwise nullptr.
clang::Expr* CheckAndBuildCallToCalloc(clang::Expr* RHS);

static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

Expand Down
96 changes: 72 additions & 24 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return atomicAddCall;
}

/// Inspects \p RHS and, when it is a direct call to a function named
/// `malloc` or `realloc` whose size argument is a product with a
/// sizeof-style operand (`x * sizeof(T)` or `sizeof(T) * x`), builds the
/// equivalent `calloc(x, sizeof(T))` call so the memory is zero-initialized.
/// \param[in] RHS The expression to inspect.
/// \returns The `calloc` call, or nullptr when \p RHS does not match the
/// supported pattern.
Expr* ReverseModeVisitor::CheckAndBuildCallToCalloc(Expr* RHS) {
  // Only plain calls of the form `name(...)` are considered: the callee must
  // be an implicit cast wrapping a reference to a function declaration.
  auto* callExpr = dyn_cast<CallExpr>(RHS);
  if (!callExpr)
    return nullptr;
  auto* calleeCast = dyn_cast<ImplicitCastExpr>(callExpr->getCallee());
  if (!calleeCast)
    return nullptr;
  auto* calleeRef = dyn_cast<DeclRefExpr>(calleeCast->getSubExpr());
  if (!calleeRef)
    return nullptr;
  auto* FD = dyn_cast<FunctionDecl>(calleeRef->getDecl());
  if (!FD)
    return nullptr;

  // The size is the first argument of malloc and the second of realloc.
  const auto calleeName = FD->getNameAsString();
  unsigned sizeIdx = 0;
  if (calleeName == "malloc")
    sizeIdx = 0;
  else if (calleeName == "realloc")
    sizeIdx = 1;
  else
    return nullptr;

  // Guard against a same-named function declared with fewer parameters;
  // getArg() must not be called with an out-of-range index.
  if (callExpr->getNumArgs() <= sizeIdx)
    return nullptr;

  // Look through redundant parentheses / implicit casts around the size so
  // that e.g. `malloc((n * sizeof(T)))` is recognized as well.
  Expr* size = callExpr->getArg(sizeIdx)->IgnoreParenImpCasts();
  auto* product = dyn_cast<BinaryOperator>(size);
  if (!product || product->getOpcode() != BO_Mul)
    return nullptr;

  Expr* lhs = product->getLHS();
  Expr* rhs = product->getRHS();

  // calloc takes (count, element size); place the sizeof-style operand
  // second. The right operand is checked first, as `x * sizeof(T)` is the
  // first supported configuration.
  llvm::SmallVector<Expr*, 2> args;
  if (auto* sizeofRhs =
          dyn_cast<UnaryExprOrTypeTraitExpr>(rhs->IgnoreParenImpCasts())) {
    args.push_back(lhs);
    args.push_back(sizeofRhs);
  } else if (auto* sizeofLhs = dyn_cast<UnaryExprOrTypeTraitExpr>(
                 lhs->IgnoreParenImpCasts())) {
    args.push_back(rhs);
    args.push_back(sizeofLhs);
  }

  if (args.empty())
    return nullptr;

  // An empty namespace string requests lookup of `calloc` without a
  // namespace qualifier.
  return GetFunctionCall("calloc", "", args);
}

ReverseModeVisitor::ReverseModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request)
: VisitorBase(builder, request), m_Result(nullptr) {}
Expand Down Expand Up @@ -1697,25 +1728,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
!isa<CXXOperatorCallExpr>(CE))
return StmtDiff(Clone(CE));

// If all arguments are constant literals, then this does not contribute to
// the gradient.
// FIXME: revert this when this is integrated in the activity analysis pass.
if (!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
bool allArgsAreConstantLiterals = true;
for (const Expr* arg : CE->arguments()) {
// if it's of type MaterializeTemporaryExpr, then check its
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts();
if (!arg->isEvaluatable(m_Context)) {
allArgsAreConstantLiterals = false;
break;
}
}
if (allArgsAreConstantLiterals)
return StmtDiff(Clone(CE), Clone(CE));
}

SourceLocation Loc = CE->getExprLoc();

// Stores the call arguments for the function to be derived
Expand Down Expand Up @@ -1745,11 +1757,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
.get();
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
Expr* call_dx = nullptr;
if (FD->getNameAsString() == "malloc")
call_dx = CheckAndBuildCallToCalloc(Clone(CE));
if (!call_dx)
call_dx = m_Sema
.ActOnCallExpr(
getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
return StmtDiff(call, call_dx);
}
// For calls to C-style memory deallocation functions, we do not need to
Expand Down Expand Up @@ -1788,6 +1804,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff();
}

// If all arguments are constant literals, then this does not contribute to
// the gradient.
// FIXME: revert this when this is integrated in the activity analysis pass.
if (!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
bool allArgsAreConstantLiterals = true;
for (const Expr* arg : CE->arguments()) {
// if it's of type MaterializeTemporaryExpr, then check its
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts();
if (!arg->isEvaluatable(m_Context)) {
allArgsAreConstantLiterals = false;
break;
}
}
if (allArgsAreConstantLiterals)
return StmtDiff(Clone(CE), Clone(CE));
}

// If the result does not depend on the result of the call, just clone
// the call and visit arguments (since they may contain side-effects like
// f(x = y))
Expand Down Expand Up @@ -2842,6 +2877,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR);
addToCurrentBlock(BuildOp(opCode, derivedL, derivedR),
direction::forward);
if (opCode == BO_Assign && derivedR)
if (Expr* callocCall =
CheckAndBuildCallToCalloc(derivedR->IgnoreParenCasts())) {
Expr* cast =
m_Sema
.BuildCStyleCastExpr(
SourceLocation(),
m_Context.getTrivialTypeSourceInfo(derivedL->getType()),
SourceLocation(), callocCall)
.get();
addToCurrentBlock(BuildOp(BO_Assign, derivedL, cast),
direction::forward);
}
}
}
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
Expand Down
13 changes: 9 additions & 4 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,14 @@ namespace clad {
Expr* VisitorBase::GetFunctionCall(const std::string& funcName,
const std::string& nmspace,
llvm::SmallVectorImpl<Expr*>& callArgs) {
NamespaceDecl* NSD =
utils::LookupNSD(m_Sema, nmspace, /*shouldExist=*/true);
DeclContext* DC = NSD;
NamespaceDecl* NSD = nullptr;
CXXScopeSpec SS;
SS.Extend(m_Context, NSD, noLoc, noLoc);

if (!nmspace.empty()) {
NSD = utils::LookupNSD(m_Sema, nmspace, /*shouldExist=*/true);
SS.Extend(m_Context, NSD, noLoc, noLoc);
}
DeclContext* DC = NSD;

IdentifierInfo* II = &m_Context.Idents.get(funcName);
DeclarationName name(II);
Expand All @@ -544,6 +547,8 @@ namespace clad {

if (DC)
m_Sema.LookupQualifiedName(R, DC);
else
m_Sema.LookupQualifiedName(R, m_Context.getTranslationUnitDecl());
Expr* UnresolvedLookup = nullptr;
if (!R.empty())
UnresolvedLookup =
Expand Down
3 changes: 2 additions & 1 deletion test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ double cStyleMemoryAlloc(double x, size_t n) {

// CHECK: void cStyleMemoryAlloc_grad_0(double x, size_t n, double *_d_x) {
// CHECK-NEXT: size_t _d_n = 0UL;
// CHECK-NEXT: T *_d_t = (T *)malloc(n * sizeof(T));
// CHECK-NEXT: T *_d_t = (T *)calloc(n, sizeof(T));
// CHECK-NEXT: T *t = (T *)malloc(n * sizeof(T));
// CHECK-NEXT: memset(_d_t, 0, n * sizeof(T));
// CHECK-NEXT: memset(t, 0, n * sizeof(T));
Expand All @@ -422,6 +422,7 @@ double cStyleMemoryAlloc(double x, size_t n) {
// CHECK-NEXT: double *_t2 = p;
// CHECK-NEXT: double *_t3 = _d_p;
// CHECK-NEXT: _d_p = (double *)realloc(_d_p, 2 * sizeof(double));
// CHECK-NEXT: _d_p = (double *)calloc(2, sizeof(double));
// CHECK-NEXT: p = (double *)realloc(p, 2 * sizeof(double));
// CHECK-NEXT: double _t4 = p[1];
// CHECK-NEXT: p[1] = 2 * x;
Expand Down

0 comments on commit f545a08

Please sign in to comment.