diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 3a3656572..fd6f48c55 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1314,7 +1314,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const Expr* value = RS->getRetValue(); QualType type = value->getType(); auto* dfdf = m_Pullback; - if (isa(dfdf) || isa(dfdf)) { + if (dfdf && (isa(dfdf) || isa(dfdf))) { ExprResult tmp = dfdf; dfdf = m_Sema .ImpCastExprToType(tmp.get(), type, @@ -2007,8 +2007,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* call = nullptr; QualType returnType = FD->getReturnType(); - if (returnType->isReferenceType() && - !returnType.getNonReferenceType().isConstQualified()) { + if ((returnType->isReferenceType() && + !returnType.getNonReferenceType().isConstQualified()) + || returnType->isPointerType()) { DiffRequest calleeFnForwPassReq; calleeFnForwPassReq.Function = FD; calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; @@ -2066,7 +2067,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); auto* resAdjoint = utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); - return StmtDiff(resValue, nullptr, resAdjoint); + return StmtDiff(resValue, resAdjoint, resAdjoint); } // Recreate the original call expression. call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, @@ -4097,7 +4098,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // respect to variable of type X, then the derivative should be of type // X. Check this related issue for more details: // https://github.com/vgvassilev/clad/issues/385 - if (effectiveReturnType->isVoidType()) + if (effectiveReturnType->isVoidType() || effectiveReturnType->isPointerType()) effectiveReturnType = m_Context.DoubleTy; else paramTypes.push_back(effectiveReturnType); @@ -4137,7 +4138,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::size_t dParamTypesIdx = m_DiffReq->getNumParams(); if (m_DiffReq.Mode == DiffMode::experimental_pullback && - !m_DiffReq->getReturnType()->isVoidType()) { + !m_DiffReq->getReturnType()->isVoidType() && !m_DiffReq->getReturnType()->isPointerType()) { ++dParamTypesIdx; } @@ -4208,7 +4209,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } if (m_DiffReq.Mode == DiffMode::experimental_pullback && - !m_DiffReq->getReturnType()->isVoidType()) { + !m_DiffReq->getReturnType()->isVoidType() && !m_DiffReq->getReturnType()->isPointerType()) { IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y"); QualType pullbackType = derivativeFnType->getParamType(m_DiffReq->getNumParams());