Skip to content

Commit

Permalink
Move the checks of call arguments for being temporary exprs to a routine
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Mar 18, 2024
1 parent d017d19 commit 925700c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
9 changes: 5 additions & 4 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,12 @@ namespace clad {
/// otherwise returns false.
bool HasAnyReferenceOrPointerArgument(const clang::FunctionDecl* FD);

/// Returns true if `T` is a reference, pointer or array type.
/// Returns true if `arg` is an argument passed by reference or is of
/// pointer/array type.
///
/// \note Please note that this function returns true for array types as
/// well.
bool IsReferenceOrPointerType(clang::QualType T);
/// \note Please note that this function returns false for temporary
/// expressions.
bool IsReferenceOrPointerArg(const clang::Expr* arg);

/// Returns true if `T1` and `T2` have same cononical type; otherwise
/// returns false.
Expand Down
8 changes: 6 additions & 2 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,12 @@ namespace clad {
return false;
}

bool IsReferenceOrPointerType(QualType T) {
return T->isReferenceType() || isArrayOrPointerType(T);
bool IsReferenceOrPointerArg(const Expr* arg) {
// The argument is passed by reference if it's passed as an L-value.
// However, if arg is a MaterializeTemporaryExpr, then arg is a
// temporary variable passed as a const reference.
bool isRefType = arg->isLValue() && !isa<MaterializeTemporaryExpr>(arg);
return isRefType || isArrayOrPointerType(arg->getType());
}

bool SameCanonicalType(clang::QualType T1, clang::QualType T2) {
Expand Down
7 changes: 2 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1537,8 +1537,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
// modified by the derived callee function.
if (utils::IsReferenceOrPointerType(PVD->getType()) &&
!isa<MaterializeTemporaryExpr>(arg)) {
if (utils::IsReferenceOrPointerArg(arg)) {
argDiff = Visit(arg);
CallArgDx.push_back(argDiff.getExpr_dx());
} else {
Expand Down Expand Up @@ -1723,9 +1722,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (auto* argDerivative : CallArgDx) {
Expr* gradArgExpr = nullptr;
const Expr* arg = CE->getArg(idx);
const auto* PVD = FD->getParamDecl(idx);
if (utils::IsReferenceOrPointerType(PVD->getType()) &&
!isa<MaterializeTemporaryExpr>(arg)) {
if (utils::IsReferenceOrPointerArg(arg)) {
if (argDerivative) {
if (utils::isArrayOrPointerType(argDerivative->getType()) ||
isCladArrayType(argDerivative->getType()) ||
Expand Down

0 comments on commit 925700c

Please sign in to comment.