Skip to content

Commit

Permalink
Support pointer-valued functions in the reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 16, 2024
1 parent 2e3a53a commit f5f56a1
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const Expr* value = RS->getRetValue();
QualType type = value->getType();
auto* dfdf = m_Pullback;
if (isa<FloatingLiteral>(dfdf) || isa<IntegerLiteral>(dfdf)) {
if (dfdf && (isa<FloatingLiteral>(dfdf) || isa<IntegerLiteral>(dfdf))) {
ExprResult tmp = dfdf;
dfdf = m_Sema
.ImpCastExprToType(tmp.get(), type,
Expand Down Expand Up @@ -2007,8 +2007,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* call = nullptr;

QualType returnType = FD->getReturnType();
if (returnType->isReferenceType() &&
!returnType.getNonReferenceType().isConstQualified()) {
if ((returnType->isReferenceType() &&
!returnType.getNonReferenceType().isConstQualified())
|| returnType->isPointerType()) {
DiffRequest calleeFnForwPassReq;
calleeFnForwPassReq.Function = FD;
calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass;
Expand Down Expand Up @@ -2066,7 +2067,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
auto* resAdjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, nullptr, resAdjoint);
return StmtDiff(resValue, resAdjoint, resAdjoint);
} // Recreate the original call expression.
call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
Expand Down Expand Up @@ -4097,7 +4098,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// respect to variable of type X, then the derivative should be of type
// X. Check this related issue for more details:
// https://github.com/vgvassilev/clad/issues/385
if (effectiveReturnType->isVoidType())
if (effectiveReturnType->isVoidType() || effectiveReturnType->isPointerType())
effectiveReturnType = m_Context.DoubleTy;
else
paramTypes.push_back(effectiveReturnType);
Expand Down Expand Up @@ -4137,7 +4138,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::size_t dParamTypesIdx = m_DiffReq->getNumParams();

if (m_DiffReq.Mode == DiffMode::experimental_pullback &&
!m_DiffReq->getReturnType()->isVoidType()) {
!m_DiffReq->getReturnType()->isVoidType() && !m_DiffReq->getReturnType()->isPointerType()) {
++dParamTypesIdx;
}

Expand Down Expand Up @@ -4208,7 +4209,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

if (m_DiffReq.Mode == DiffMode::experimental_pullback &&
!m_DiffReq->getReturnType()->isVoidType()) {
!m_DiffReq->getReturnType()->isVoidType() && !m_DiffReq->getReturnType()->isPointerType()) {
IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y");
QualType pullbackType =
derivativeFnType->getParamType(m_DiffReq->getNumParams());
Expand Down

0 comments on commit f5f56a1

Please sign in to comment.