diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 39c31691c..7ef3aad1f 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -77,7 +77,7 @@ namespace clad { friend class ReverseModeVisitor; friend class HessianModeVisitor; friend class JacobianModeVisitor; - + friend class ReverseModeForwPassVisitor; clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index a9c27a935..a03e77e49 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -11,6 +11,7 @@ enum class DiffMode { reverse, hessian, jacobian, + reverse_mode_forward_pass, error_estimation }; } diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index b59f77189..9a5d7e3d6 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -21,6 +21,12 @@ #include namespace clad { + template + struct ValueAndAdjoint { + T value; + U adjoint; + }; + /// \returns the size of a c-style string CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { unsigned int count; diff --git a/include/clad/Differentiator/ReverseModeForwPassVisitor.h b/include/clad/Differentiator/ReverseModeForwPassVisitor.h new file mode 100644 index 000000000..32edf3e6d --- /dev/null +++ b/include/clad/Differentiator/ReverseModeForwPassVisitor.h @@ -0,0 +1,38 @@ +#ifndef CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H +#define CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H + +#include "clad/Differentiator/ParseDiffArgsTypes.h" +#include "clad/Differentiator/ReverseModeVisitor.h" + +#include "clang/AST/StmtVisitor.h" +#include "clang/Sema/Sema.h" + +#include "llvm/ADT/SmallVector.h" + +namespace clad { +class ReverseModeForwPassVisitor : public ReverseModeVisitor { +private: + Stmts m_Globals; + + llvm::SmallVector + ComputeParamTypes(const DiffParams& diffParams); + clang::QualType ComputeReturnType(); + llvm::SmallVector BuildParams(DiffParams& diffParams); + clang::QualType GetParameterDerivativeType(clang::QualType yType, + clang::QualType xType); + +public: + ReverseModeForwPassVisitor(DerivativeBuilder& builder); + DerivativeAndOverload Derive(const clang::FunctionDecl* FD, + const DiffRequest& request); + + StmtDiff ProcessSingleStmt(const clang::Stmt* S); + + StmtDiff VisitStmt(const clang::Stmt* S) override; + StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override; + StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override; + StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; +}; +} // namespace clad + +#endif \ No newline at end of file diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 3b668ff87..2b05c5bd5 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -31,7 +31,7 @@ namespace clad { : public clang::ConstStmtVisitor, public VisitorBase { - private: + protected: // FIXME: We should remove friend-dependency of the plugin classes here. // For this we will need to separate out AST related functions in // a separate namespace, as well as add getters/setters function of @@ -321,11 +321,11 @@ namespace clad { StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); StmtDiff VisitCallExpr(const clang::CallExpr* CE); - StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); + virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); - StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); + virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL); StmtDiff VisitForStmt(const clang::ForStmt* FS); @@ -335,8 +335,8 @@ namespace clad { StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL); StmtDiff VisitMemberExpr(const clang::MemberExpr* ME); StmtDiff VisitParenExpr(const clang::ParenExpr* PE); - StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); - StmtDiff VisitStmt(const clang::Stmt* S); + virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); + virtual StmtDiff VisitStmt(const clang::Stmt* S); StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC); /// Decl is not Stmt, so it cannot be visited directly. diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index 61695990e..f5eddb2c6 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -31,6 +31,7 @@ add_llvm_library(cladDifferentiator HessianModeVisitor.cpp JacobianModeVisitor.cpp MultiplexExternalRMVSource.cpp + ReverseModeForwPassVisitor.cpp ReverseModeVisitor.cpp StmtClone.cpp VectorForwardModeVisitor.cpp diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index d5a15580b..a67f6ed44 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -22,6 +22,8 @@ #include "clad/Differentiator/HessianModeVisitor.h" #include "clad/Differentiator/JacobianModeVisitor.h" #include "clad/Differentiator/ReverseModeVisitor.h" +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" +#include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/StmtClone.h" #include "clad/Differentiator/VectorForwardModeVisitor.h" @@ -230,6 +232,9 @@ namespace clad { result = V.DerivePullback(FD, request); if (!m_ErrorEstHandler.empty()) CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); + } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { + ReverseModeForwPassVisitor V(*this); + result = V.Derive(FD, request); } else if (request.Mode == DiffMode::hessian) { HessianModeVisitor H(*this); result = H.Derive(FD, request); diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp new file mode 100644 index 000000000..87eb965e3 --- /dev/null +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -0,0 +1,274 @@ +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" + +#include "clad/Differentiator/CladUtils.h" +#include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/ErrorEstimator.h" + +#include "llvm/Support/SaveAndRestore.h" + +#include + +using namespace clang; + +namespace clad { + +ReverseModeForwPassVisitor::ReverseModeForwPassVisitor( + DerivativeBuilder& builder) + : ReverseModeVisitor(builder) {} + +DerivativeAndOverload +ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, + const DiffRequest& request) { + silenceDiags = !request.VerboseDiags; + m_Function = FD; + + m_Mode = DiffMode::reverse_mode_forward_pass; + + assert(m_Function && "Must not be null."); + + DiffParams args{}; + std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); + + auto fnName = m_Function->getNameAsString() + "_forw"; + auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName); + + auto paramTypes = ComputeParamTypes(args); + auto returnType = ComputeReturnType(); + const auto *sourceFnType = dyn_cast(m_Function->getType()); + auto fnType = m_Context.getFunctionType(returnType, paramTypes, + sourceFnType->getExtProtoInfo()); + + llvm::SaveAndRestore saveContext(m_Sema.CurContext); + llvm::SaveAndRestore saveScope(m_CurScope); + m_Sema.CurContext = const_cast(m_Function->getDeclContext()); + + DeclWithContext fnBuildRes = + m_Builder.cloneFunction(m_Function, *this, m_Sema.CurContext, m_Sema, + m_Context, noLoc, fnDNI, fnType); + m_Derivative = fnBuildRes.first; + + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); + + auto params = BuildParams(args); + m_Derivative->setParams(params); + m_Derivative->setBody(nullptr); + + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); + + beginBlock(); + + StmtDiff bodyDiff = Visit(m_Function->getBody()); + Stmt* forward = bodyDiff.getStmt(); + + for (Stmt* S : ReverseModeVisitor::m_Globals) + addToCurrentBlock(S); + + if (auto *CS = dyn_cast(forward)) + for (Stmt* S : CS->body()) + addToCurrentBlock(S); + + Stmt* fnBody = endBlock(); + m_Derivative->setBody(fnBody); + endScope(); + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); + return DerivativeAndOverload{m_Derivative, nullptr}; +} + +// FIXME: This function is copied from ReverseModeVisitor. Find a suitable place +// for it. +QualType +ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType, + QualType xType) { + assert(yType.getNonReferenceType()->isRealType() && + "yType should be a builtin-numerical scalar type!!"); + QualType xValueType = utils::GetValueType(xType); + // derivative variables should always be of non-const type. + xValueType.removeLocalConst(); + QualType nonRefXValueType = xValueType.getNonReferenceType(); + if (nonRefXValueType->isRealType()) + return GetCladArrayRefOfType(yType); + return GetCladArrayRefOfType(nonRefXValueType); +} + +llvm::SmallVector +ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { + llvm::SmallVector paramTypes; + paramTypes.reserve(m_Function->getNumParams() * 2); + for (auto *PVD : m_Function->parameters()) + paramTypes.push_back(PVD->getType()); + + QualType effectiveReturnType = + m_Function->getReturnType().getNonReferenceType(); + + if (const auto *MD = dyn_cast(m_Function)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) { + QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); + paramTypes.push_back( + GetParameterDerivativeType(effectiveReturnType, thisType)); + } + } + + for (auto *PVD : m_Function->parameters()) { + const auto *it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) { + paramTypes.push_back( + GetParameterDerivativeType(effectiveReturnType, PVD->getType())); + } + } + return paramTypes; +} + +clang::QualType ReverseModeForwPassVisitor::ComputeReturnType() { + auto *valAndAdjointTempDecl = LookupTemplateDeclInCladNamespace("ValueAndAdjoint"); + auto RT = m_Function->getReturnType(); + auto T = InstantiateTemplate(valAndAdjointTempDecl, {RT, RT}); + return T; +} + +llvm::SmallVector +ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { + llvm::SmallVector params; + llvm::SmallVector paramDerivatives; + params.reserve(m_Function->getNumParams() + diffParams.size()); + const auto *derivativeFnType = cast(m_Derivative->getType()); + + std::size_t dParamTypesIdx = m_Function->getNumParams(); + + if (const auto *MD = dyn_cast(m_Function)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) { + auto *thisDerivativePVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"), + derivativeFnType->getParamType(dParamTypesIdx)); + paramDerivatives.push_back(thisDerivativePVD); + + if (thisDerivativePVD->getIdentifier()) + m_Sema.PushOnScopeChains(thisDerivativePVD, getCurrentScope(), + /*AddToContext=*/false); + + Expr* deref = + BuildOp(UnaryOperatorKind::UO_Deref, BuildDeclRef(thisDerivativePVD)); + m_ThisExprDerivative = utils::BuildParenExpr(m_Sema, deref); + ++dParamTypesIdx; + } + } + for (auto *PVD : m_Function->parameters()) { + // FIXME: Call expression may contain default arguments that we are now + // removing. This may cause issues. + auto *newPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), + PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); + params.push_back(newPVD); + + if (newPVD->getIdentifier()) + m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), + /*AddToContext=*/false); + + auto *it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) { + *it = newPVD; + QualType dType = derivativeFnType->getParamType(dParamTypesIdx); + IdentifierInfo* dII = + CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + auto *dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, + PVD->getStorageClass()); + paramDerivatives.push_back(dPVD); + ++dParamTypesIdx; + + if (dPVD->getIdentifier()) + m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), + /*AddToContext=*/false); + + if (utils::isArrayOrPointerType(PVD->getType())) { + m_Variables[*it] = (Expr*)BuildDeclRef(dPVD); + } else { + QualType valueType = DetermineCladArrayValueType(dPVD->getType()); + m_Variables[*it] = + BuildOp(UO_Deref, BuildDeclRef(dPVD), m_Function->getLocation()); + // Add additional paranthesis if derivative is of record type + // because `*derivative.someField` will be incorrectly evaluated if + // the derived function is compiled standalone. + if (valueType->isRecordType()) + m_Variables[*it] = utils::BuildParenExpr(m_Sema, m_Variables[*it]); + } + } + } + params.insert(params.end(), paramDerivatives.begin(), paramDerivatives.end()); + return params; +} + +StmtDiff ReverseModeForwPassVisitor::ProcessSingleStmt(const clang::Stmt* S) { + StmtDiff SDiff = Visit(S); + return {SDiff.getStmt()}; +} + +StmtDiff ReverseModeForwPassVisitor::VisitStmt(const clang::Stmt* S) { + return {Clone(S)}; +} + +StmtDiff +ReverseModeForwPassVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) { + beginScope(Scope::DeclScope); + beginBlock(); + for (Stmt* S : CS->body()) { + StmtDiff SDiff = ProcessSingleStmt(S); + addToCurrentBlock(SDiff.getStmt()); + } + CompoundStmt* forward = endBlock(); + endScope(); + return {forward}; +} + +StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { + DeclRefExpr* clonedDRE = nullptr; + // Check if referenced Decl was "replaced" with another identifier inside + // the derivative + if (const auto *VD = dyn_cast(DRE->getDecl())) { + auto it = m_DeclReplacements.find(VD); + if (it != std::end(m_DeclReplacements)) + clonedDRE = BuildDeclRef(it->second); + else + clonedDRE = cast(Clone(DRE)); + // If current context is different than the context of the original + // declaration (e.g. we are inside lambda), rebuild the DeclRefExpr + // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. + // Sema::BuildDeclRefExpr is responsible for adding captured fields + // to the underlying struct of a lambda. + if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { + auto *referencedDecl = cast(clonedDRE->getDecl()); + clonedDRE = cast(BuildDeclRef(referencedDecl)); + } + } else + clonedDRE = cast(Clone(DRE)); + + if (auto *decl = dyn_cast(clonedDRE->getDecl())) { + // Check DeclRefExpr is a reference to an independent variable. + auto it = m_Variables.find(decl); + if (it == std::end(m_Variables)) { + // Is not an independent variable, ignored. + return StmtDiff(clonedDRE); + } + return StmtDiff(clonedDRE, it->second); + } + + return StmtDiff(clonedDRE); +} + +StmtDiff +ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { + const Expr* value = RS->getRetValue(); + auto returnDiff = Visit(value); + llvm::SmallVector returnArgs = {returnDiff.getExpr(), + returnDiff.getExpr_dx()}; + Expr* returnInitList = m_Sema.ActOnInitList(noLoc, returnArgs, noLoc).get(); + Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get(); + return {newRS}; +} +} // namespace clad \ No newline at end of file diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4e9805b64..a5463fb93 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1673,15 +1673,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ArgDeclStmts.push_back(BuildDeclStmt(gradVarDecl)); idx++; } + Expr* pullback = dfdx(); + if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType()) + pullback = getZeroInit(FD->getReturnType().getNonReferenceType()); + // FIXME: Remove this restriction. if (!FD->getReturnType()->isVoidType()) { - assert((dfdx() && !FD->getReturnType()->isVoidType()) && + assert((pullback && !FD->getReturnType()->isVoidType()) && "Call to function returning non-void type with no dfdx() is not " "supported!"); } if (FD->getReturnType()->isVoidType()) { - assert(dfdx() == nullptr && FD->getReturnType()->isVoidType() && + assert(pullback == nullptr && FD->getReturnType()->isVoidType() && "Call to function returning void type should not have any " "corresponding dfdx()."); } @@ -1691,9 +1695,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivedCallOutputArgs.end()); pullbackCallArgs = DerivedCallArgs; - if (dfdx()) + if (pullback) pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(), - dfdx()); + pullback); // Try to find it in builtin derivatives std::string customPullback = FD->getNameAsString() + "_pullback"; @@ -1857,15 +1861,83 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::end(CallArgs), std::begin(CallArgs), [this](Expr* E) { return Clone(E); }); - // Recreate the original call expression. - Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), - Clone(CE->getCallee()), - noLoc, - CallArgs, - noLoc) - .get(); - return StmtDiff(call); + + Expr* call = nullptr; + + if (FD->getReturnType()->isReferenceType()) { + DiffRequest calleeFnForwPassReq; + calleeFnForwPassReq.Function = FD; + calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; + calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString(); + calleeFnForwPassReq.VerboseDiags = true; + FunctionDecl* calleeFnForwPassFD = + plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); + + assert(calleeFnForwPassFD && + "Clad failed to generate callee function forward pass function"); + + // FIXME: We are using the derivatives in forward pass here + // If `expr_dx()` is only meant to be used in reverse pass, + // (for example, `clad::pop(...)` expression and a corresponding + // `clad::push(...)` in the forward pass), then this can result in + // incorrect derivative or crash at runtime. Ideally, we should have + // a separate routine to use derivative in the forward pass. + + // We cannot reuse the derivatives previously computed because + // they might contain 'clad::pop(..)` expression. + if (isa(CE)) { + Expr* derivedBase = baseDiff.getExpr_dx(); + // FIXME: We may need this if-block once we support pointers, and passing pointers-by-reference + // if (isCladArrayType(derivedBase->getType())) + // CallArgs.push_back(derivedBase); + // else + CallArgs.push_back( + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc)); + } + + for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + const Expr* arg = CE->getArg(i); + const ParmVarDecl* PVD = FD->getParamDecl(i); + StmtDiff argDiff = Visit(arg); + if ((argDiff.getExpr_dx() != nullptr) && PVD->getType()->isReferenceType()) { + Expr* derivedArg = argDiff.getExpr_dx(); + // FIXME: We may need this if-block once we support pointers, and passing pointers-by-reference + // if (isCladArrayType(derivedArg->getType())) + // CallArgs.push_back(derivedArg); + // else + CallArgs.push_back( + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, noLoc)); + } else + CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get()); + } + if (isa(CE)) { + Expr* baseE = baseDiff.getExpr(); + call = BuildCallExprToMemFn( + baseE, calleeFnForwPassFD->getName(), CallArgs, calleeFnForwPassFD); + } else { + call = m_Sema + .ActOnCallExpr(getCurrentScope(), + BuildDeclRef(calleeFnForwPassFD), noLoc, + CallArgs, noLoc) + .get(); + } + auto *callRes = StoreAndRef(call); + auto *resValue = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); + auto *resAdjoint = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); + return StmtDiff(resValue, nullptr, resAdjoint); + } // Recreate the original call expression. + call = m_Sema + .ActOnCallExpr(getCurrentScope(), + Clone(CE->getCallee()), + noLoc, + CallArgs, + noLoc) + .get(); + return StmtDiff(call); + + return {}; } StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { @@ -2312,7 +2384,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (isDerivativeOfRefType) { initDiff = Visit(VD->getInit()); - if (!initDiff.getExpr_dx()) { + if (!initDiff.getForwSweepExpr_dx()) { VDDerivedType = ComputeAdjointType(VD->getType().getNonReferenceType()); isDerivativeOfRefType = false; @@ -3136,7 +3208,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // TODO: Add DiffMode::experimental_pullback support here as well. if (m_Mode == DiffMode::reverse || m_Mode == DiffMode::experimental_pullback) { - QualType effectiveReturnType = m_Function->getReturnType(); + QualType effectiveReturnType = m_Function->getReturnType().getNonReferenceType(); if (m_Mode == DiffMode::experimental_pullback) { // FIXME: Generally, we use the function's return type as the argument's // derivative type. We cannot follow this strategy for `void` function @@ -3150,7 +3222,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (effectiveReturnType->isVoidType()) effectiveReturnType = m_Context.DoubleTy; else - paramTypes.push_back(m_Function->getReturnType()); + paramTypes.push_back(effectiveReturnType); } if (auto MD = dyn_cast(m_Function)) { diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 725731bc4..c6944c985 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -351,6 +351,92 @@ double fn6(double i=0, double j=0) { return i*j; } +double& identity(double& i) { + return i; +} + +double fn7(double i, double j) { + double& k = identity(i); + double& l = identity(j); + k += 7*j; + l += 9*i; + return i + j; +} + +// CHECK: void fn6_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: _t1 = i; +// CHECK-NEXT: _t0 = j; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 1 * _t0; +// CHECK-NEXT: * _d_i += _r0; +// CHECK-NEXT: double _r1 = _t1 * 1; +// CHECK-NEXT: * _d_j += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK: void identity_pullback(double &i, double _d_y, clad::array_ref _d_i) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: * _d_i += _d_y; +// CHECK-NEXT: } +// CHECK: clad::ValueAndAdjoint identity_forw(double &i, clad::array_ref _d_i) { +// CHECK-NEXT: return {i, * _d_i}; +// CHECK-NEXT: } +// CHECK: void fn7_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double *_d_k = 0; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double *_d_l = 0; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: _t0 = i; +// CHECK-NEXT: clad::ValueAndAdjoint _t1 = identity_forw(i, &* _d_i); +// CHECK-NEXT: _d_k = &_t1.adjoint; +// CHECK-NEXT: double &k = _t1.value; +// CHECK-NEXT: _t2 = j; +// CHECK-NEXT: clad::ValueAndAdjoint _t3 = identity_forw(j, &* _d_j); +// CHECK-NEXT: _d_l = &_t3.adjoint; +// CHECK-NEXT: double &l = _t3.value; +// CHECK-NEXT: _t4 = j; +// CHECK-NEXT: k += 7 * _t4; +// CHECK-NEXT: _t5 = i; +// CHECK-NEXT: l += 9 * _t5; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: * _d_i += 1; +// CHECK-NEXT: * _d_j += 1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d1 = *_d_l; +// CHECK-NEXT: *_d_l += _r_d1; +// CHECK-NEXT: double _r4 = _r_d1 * _t5; +// CHECK-NEXT: double _r5 = 9 * _r_d1; +// CHECK-NEXT: * _d_i += _r5; +// CHECK-NEXT: *_d_l -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = *_d_k; +// CHECK-NEXT: *_d_k += _r_d0; +// CHECK-NEXT: double _r2 = _r_d0 * _t4; +// CHECK-NEXT: double _r3 = 7 * _r_d0; +// CHECK-NEXT: * _d_j += _r3; +// CHECK-NEXT: *_d_k -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: identity_pullback(_t2, 0, &* _d_j); +// CHECK-NEXT: double _r1 = * _d_j; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: identity_pullback(_t0, 0, &* _d_i); +// CHECK-NEXT: double _r0 = * _d_i; +// CHECK-NEXT: } +// CHECK-NEXT: } + + template void reset(T* arr, int n) { for (int i=0; i _d_i, clad::array_ref _d_j); void const_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); void volatile_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); @@ -756,6 +758,35 @@ double fn(double i,double j) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn2(SimpleFunctions& sf, double i) { + return sf.ref_mem_fn(i); +} + +// CHECK: void ref_mem_fn_pullback(double i, double _d_y, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: (* _d_this).x += _d_y; +// CHECK-NEXT: } +// CHECK: clad::ValueAndAdjoint ref_mem_fn_forw(double i, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: return {this->x, (* _d_this).x}; +// CHECK-NEXT: } +// CHECK: void fn2_grad(SimpleFunctions &sf, double i, clad::array_ref _d_sf, clad::array_ref _d_i) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: SimpleFunctions _t1; +// CHECK-NEXT: _t0 = i; +// CHECK-NEXT: _t1 = sf; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = _t1.ref_mem_fn_forw(_t0, &(* _d_sf), nullptr); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: _t1.ref_mem_fn_pullback(_t0, 1, &(* _d_sf), &_grad0); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: * _d_i += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + + int main() { auto d_mem_fn = clad::gradient(&SimpleFunctions::mem_fn); auto d_const_mem_fn = clad::gradient(&SimpleFunctions::const_mem_fn); @@ -790,6 +821,12 @@ int main() { printf("%.2f ",result[i]); //CHECK-EXEC: 40.00 16.00 } + SimpleFunctions sf(2, 3); + SimpleFunctions d_sf; + auto d_fn2 = clad::gradient(fn2); + d_fn2.execute(sf, 2, &d_sf, &result[0]); + printf("%.2f", result[0]); //CHECK-EXEC: 40.00 + auto d_const_volatile_lval_ref_mem_fn_i = clad::gradient(&SimpleFunctions::const_volatile_lval_ref_mem_fn, "i"); // CHECK: void const_volatile_lval_ref_mem_fn_grad_0(double i, double j, clad::array_ref _d_this, clad::array_ref _d_i) const volatile & {