diff --git a/demos/ComputerGraphics/smallpt/SmallPT.cpp b/demos/ComputerGraphics/smallpt/SmallPT.cpp index 54c21c62a..468bca694 100644 --- a/demos/ComputerGraphics/smallpt/SmallPT.cpp +++ b/demos/ComputerGraphics/smallpt/SmallPT.cpp @@ -16,7 +16,8 @@ // ./SmallPT 500 && xv image.ppm // A typical invocation would be: -// ../../../../../obj/Debug+Asserts/bin/clang++ -O3 -Xclang -add-plugin -Xclang clad \ +// ../../../../../obj/Debug+Asserts/bin/clang++ -O3 -Xclang -add-plugin -Xclang +// clad \ // -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \ // -I../../include/ -std=c++11 SmallPT.cpp -fopenmp=libiomp5 -o SmallPT // ./SmallPT 500 && xv image.ppm diff --git a/demos/OpenCL/RosenbrockFunction.cpp b/demos/OpenCL/RosenbrockFunction.cpp index efa97c133..9b7ef154d 100644 --- a/demos/OpenCL/RosenbrockFunction.cpp +++ b/demos/OpenCL/RosenbrockFunction.cpp @@ -8,11 +8,13 @@ // To run the demo please type: // path/to/clang++ -Xclang -add-plugin -Xclang clad -Xclang -load -Xclang \ -// path/to/libclad.so -I../include/ -framework opencl -std=c++11 RosenbrockFunction.cpp +// path/to/libclad.so -I../include/ -framework opencl -std=c++11 +// RosenbrockFunction.cpp // // A typical invocation would be: -// ../../../../../obj/Debug+Asserts/bin/clang++ -Xclang -add-plugin -Xclang clad \ -// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \ +// ../../../../../obj/Debug+Asserts/bin/clang++ -Xclang -add-plugin -Xclang +// clad \ +// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \ // -I../../include/ -framework opencl -std=c++11 RosenbrockFunction.cpp // Necessary for clad to work include diff --git a/include/clad/Differentiator/Compatibility.h b/include/clad/Differentiator/Compatibility.h index fdc76a9b4..fc43f6554 100644 --- a/include/clad/Differentiator/Compatibility.h +++ b/include/clad/Differentiator/Compatibility.h @@ -75,15 +75,15 @@ static inline NamespaceDecl* NamespaceDecl_Create(ASTContext& C, DeclContext* DC, bool Inline, SourceLocation StartLoc, SourceLocation IdLoc, IdentifierInfo* Id, NamespaceDecl* PrevDecl) { - return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl); + return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl); } #else static inline NamespaceDecl* NamespaceDecl_Create(ASTContext& C, DeclContext* DC, bool Inline, SourceLocation StartLoc, SourceLocation IdLoc, IdentifierInfo* Id, NamespaceDecl* PrevDecl) { - return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl, - /*Nested=*/false); + return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl, + /*Nested=*/false); } #endif @@ -249,17 +249,19 @@ static inline void ExprSetDeps(Expr* result, Expr* Node) { #define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsPar ,clang::CallExpr::ADLCallKind UsesADL #define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsUse ,UsesADL #define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride FPOptionsOverride - #if CLANG_VERSION_MAJOR >= 16 - #define CLAD_COMPAT_CLANG11_LangOptions_EtraParams /**/ - #else - #define CLAD_COMPAT_CLANG11_LangOptions_EtraParams Ctx.getLangOpts() - #endif - #define CLAD_COMPAT_CLANG11_Ctx_ExtraParams Ctx, - #define CLAD_COMPAT_CREATE11(CLASS, CTORARGS) (CLASS::Create CTORARGS) - #define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed /**/ - #define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved ,Node->getComputationLHSType(),Node->getComputationResultType() - #define CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed /**/ - #define CLAD_COMPAT_CLANG11_WhileStmt_ExtraParams ,Node->getLParenLoc(),Node->getRParenLoc() +#if CLANG_VERSION_MAJOR >= 16 +#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams /**/ +#else +#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams Ctx.getLangOpts() +#endif +#define CLAD_COMPAT_CLANG11_Ctx_ExtraParams Ctx, +#define CLAD_COMPAT_CREATE11(CLASS, CTORARGS) (CLASS::Create CTORARGS) +#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed /**/ +#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved \ + , Node->getComputationLHSType(), Node->getComputationResultType() +#define CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed /**/ +#define CLAD_COMPAT_CLANG11_WhileStmt_ExtraParams \ + , Node->getLParenLoc(), Node->getRParenLoc() #endif // Compatibility helper function for creation CXXOperatorCallExpr. Clang 8 and above use Create. @@ -725,7 +727,7 @@ ArraySize_GetValue(const llvm::Optional& opt) { #else static inline const Expr* ArraySize_GetValue(const std::optional& opt) { - return opt.value(); + return opt.value(); } #endif diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 39c31691c..7ef3aad1f 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -77,7 +77,7 @@ namespace clad { friend class ReverseModeVisitor; friend class HessianModeVisitor; friend class JacobianModeVisitor; - + friend class ReverseModeForwPassVisitor; clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index a9c27a935..a03e77e49 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -11,6 +11,7 @@ enum class DiffMode { reverse, hessian, jacobian, + reverse_mode_forward_pass, error_estimation }; } diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index b59f77189..f75692803 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -21,6 +21,11 @@ #include namespace clad { +template struct ValueAndAdjoint { + T value; + U adjoint; +}; + /// \returns the size of a c-style string CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { unsigned int count; @@ -299,9 +304,9 @@ namespace clad { differentiate(F fn, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(fn && "Must pass in a non-0 argument"); - return CladFunction>(derivedFn, - code); + assert(fn && "Must pass in a non-0 argument"); + return CladFunction>(derivedFn, + code); } /// Specialization for differentiating functors. @@ -319,8 +324,8 @@ namespace clad { differentiate(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction>(derivedFn, - code, f); + return CladFunction>(derivedFn, + code, f); } /// Generates function which computes derivative of `fn` argument w.r.t @@ -342,9 +347,9 @@ namespace clad { differentiate(F fn, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(fn && "Must pass in a non-0 argument"); - return CladFunction, true>( - derivedFn, code); + assert(fn && "Must pass in a non-0 argument"); + return CladFunction, true>( + derivedFn, code); } /// Generates function which computes gradient of the given function wrt the @@ -364,9 +369,9 @@ namespace clad { gradient(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); - return CladFunction, true>( - derivedFn /* will be replaced by gradient*/, code); + assert(f && "Must pass in a non-0 argument"); + return CladFunction, true>( + derivedFn /* will be replaced by gradient*/, code); } /// Specialization for differentiating functors. @@ -382,8 +387,8 @@ namespace clad { gradient(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction, true>( - derivedFn /* will be replaced by gradient*/, code, f); + return CladFunction, true>( + derivedFn /* will be replaced by gradient*/, code, f); } /// Generates function which computes hessian matrix of the given function wrt diff --git a/include/clad/Differentiator/ReverseModeForwPassVisitor.h b/include/clad/Differentiator/ReverseModeForwPassVisitor.h new file mode 100644 index 000000000..7070a73f4 --- /dev/null +++ b/include/clad/Differentiator/ReverseModeForwPassVisitor.h @@ -0,0 +1,38 @@ +#ifndef CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H +#define CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H + +#include "clad/Differentiator/ParseDiffArgsTypes.h" +#include "clad/Differentiator/ReverseModeVisitor.h" + +#include "clang/AST/StmtVisitor.h" +#include "clang/Sema/Sema.h" + +#include "llvm/ADT/SmallVector.h" + +namespace clad { +class ReverseModeForwPassVisitor : public ReverseModeVisitor { +private: + Stmts m_Globals; + + llvm::SmallVector + ComputeParamTypes(const DiffParams& diffParams); + clang::QualType ComputeReturnType(); + llvm::SmallVector BuildParams(DiffParams& diffParams); + clang::QualType GetParameterDerivativeType(clang::QualType yType, + clang::QualType xType); + +public: + ReverseModeForwPassVisitor(DerivativeBuilder& builder); + DerivativeAndOverload Derive(const clang::FunctionDecl* FD, + const DiffRequest& request); + + StmtDiff ProcessSingleStmt(const clang::Stmt* S); + + StmtDiff VisitStmt(const clang::Stmt* S) override; + StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override; + StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override; + StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; +}; +} // namespace clad + +#endif \ No newline at end of file diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 3b668ff87..c49c497fb 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -30,8 +30,8 @@ namespace clad { class ReverseModeVisitor : public clang::ConstStmtVisitor, public VisitorBase { - - private: + + protected: // FIXME: We should remove friend-dependency of the plugin classes here. // For this we will need to separate out AST related functions in // a separate namespace, as well as add getters/setters function of @@ -292,7 +292,7 @@ namespace clad { public: ReverseModeVisitor(DerivativeBuilder& builder); - ~ReverseModeVisitor(); + virtual ~ReverseModeVisitor(); ///\brief Produces the gradient of a given function. /// @@ -321,11 +321,11 @@ namespace clad { StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); StmtDiff VisitCallExpr(const clang::CallExpr* CE); - StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); + virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); - StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); + virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL); StmtDiff VisitForStmt(const clang::ForStmt* FS); @@ -335,8 +335,8 @@ namespace clad { StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL); StmtDiff VisitMemberExpr(const clang::MemberExpr* ME); StmtDiff VisitParenExpr(const clang::ParenExpr* PE); - StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); - StmtDiff VisitStmt(const clang::Stmt* S); + virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); + virtual StmtDiff VisitStmt(const clang::Stmt* S); StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC); /// Decl is not Stmt, so it cannot be visited directly. diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index 61695990e..f5eddb2c6 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -31,6 +31,7 @@ add_llvm_library(cladDifferentiator HessianModeVisitor.cpp JacobianModeVisitor.cpp MultiplexExternalRMVSource.cpp + ReverseModeForwPassVisitor.cpp ReverseModeVisitor.cpp StmtClone.cpp VectorForwardModeVisitor.cpp diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index d5a15580b..1eb32f525 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -21,6 +21,7 @@ #include "clad/Differentiator/ForwardModeVisitor.h" #include "clad/Differentiator/HessianModeVisitor.h" #include "clad/Differentiator/JacobianModeVisitor.h" +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" #include "clad/Differentiator/ReverseModeVisitor.h" #include "clad/Differentiator/StmtClone.h" #include "clad/Differentiator/VectorForwardModeVisitor.h" @@ -230,6 +231,9 @@ namespace clad { result = V.DerivePullback(FD, request); if (!m_ErrorEstHandler.empty()) CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); + } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { + ReverseModeForwPassVisitor V(*this); + result = V.Derive(FD, request); } else if (request.Mode == DiffMode::hessian) { HessianModeVisitor H(*this); result = H.Derive(FD, request); diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 804671dd5..2097fb676 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -390,20 +390,20 @@ namespace clad { // The string is not a range just a single index size_t index; if (firstStr.getAsInteger(Radix, index)) { - utils::EmitDiag(semaRef, DiagnosticsEngine::Error, - diffArgs->getEndLoc(), - "Could not parse index '%0'", {diffSpec}); - return; + utils::EmitDiag(semaRef, DiagnosticsEngine::Error, + diffArgs->getEndLoc(), + "Could not parse index '%0'", {diffSpec}); + return; } dVarInfo.paramIndexInterval = IndexInterval(index); } else { size_t first, last; if (firstStr.getAsInteger(Radix, first) || lastStr.getAsInteger(Radix, last)) { - utils::EmitDiag(semaRef, DiagnosticsEngine::Error, - diffArgs->getEndLoc(), - "Could not parse range '%0'", {diffSpec}); - return; + utils::EmitDiag(semaRef, DiagnosticsEngine::Error, + diffArgs->getEndLoc(), + "Could not parse range '%0'", {diffSpec}); + return; } if (first >= last) { utils::EmitDiag(semaRef, DiagnosticsEngine::Error, diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index a4e5820f0..330dc4191 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -380,4 +380,4 @@ namespace clad { return DerivativeAndOverload{result.first, /*OverloadFunctionDecl=*/nullptr}; } -} // end namespace clad + } // end namespace clad diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp new file mode 100644 index 000000000..8089ffc88 --- /dev/null +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -0,0 +1,277 @@ +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" + +#include "clad/Differentiator/CladUtils.h" +#include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/ErrorEstimator.h" + +#include "llvm/Support/SaveAndRestore.h" + +#include + +using namespace clang; + +namespace clad { + +ReverseModeForwPassVisitor::ReverseModeForwPassVisitor( + DerivativeBuilder& builder) + : ReverseModeVisitor(builder) {} + +DerivativeAndOverload +ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, + const DiffRequest& request) { + silenceDiags = !request.VerboseDiags; + m_Function = FD; + + m_Mode = DiffMode::reverse_mode_forward_pass; + + assert(m_Function && "Must not be null."); + + DiffParams args{}; + std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); + + auto fnName = m_Function->getNameAsString() + "_forw"; + auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName); + + auto paramTypes = ComputeParamTypes(args); + auto returnType = ComputeReturnType(); + const auto* sourceFnType = dyn_cast(m_Function->getType()); + auto fnType = m_Context.getFunctionType(returnType, paramTypes, + sourceFnType->getExtProtoInfo()); + + llvm::SaveAndRestore saveContext(m_Sema.CurContext); + llvm::SaveAndRestore saveScope(m_CurScope); + m_Sema.CurContext = const_cast(m_Function->getDeclContext()); + + DeclWithContext fnBuildRes = + m_Builder.cloneFunction(m_Function, *this, m_Sema.CurContext, m_Sema, + m_Context, noLoc, fnDNI, fnType); + m_Derivative = fnBuildRes.first; + + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); + + auto params = BuildParams(args); + m_Derivative->setParams(params); + m_Derivative->setBody(nullptr); + + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); + + beginBlock(); + + StmtDiff bodyDiff = Visit(m_Function->getBody()); + Stmt* forward = bodyDiff.getStmt(); + + for (Stmt* S : ReverseModeVisitor::m_Globals) + addToCurrentBlock(S); + + if (auto* CS = dyn_cast(forward)) + for (Stmt* S : CS->body()) + addToCurrentBlock(S); + + Stmt* fnBody = endBlock(); + m_Derivative->setBody(fnBody); + endScope(); + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); + return DerivativeAndOverload{m_Derivative, nullptr}; +} + +// FIXME: This function is copied from ReverseModeVisitor. Find a suitable place +// for it. +QualType +ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType, + QualType xType) { + assert(yType.getNonReferenceType()->isRealType() && + "yType should be a builtin-numerical scalar type!!"); + QualType xValueType = utils::GetValueType(xType); + // derivative variables should always be of non-const type. + xValueType.removeLocalConst(); + QualType nonRefXValueType = xValueType.getNonReferenceType(); + if (nonRefXValueType->isRealType()) + return GetCladArrayRefOfType(yType); + return GetCladArrayRefOfType(nonRefXValueType); +} + +llvm::SmallVector +ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { + llvm::SmallVector paramTypes; + paramTypes.reserve(m_Function->getNumParams() * 2); + for (auto* PVD : m_Function->parameters()) + paramTypes.push_back(PVD->getType()); + + QualType effectiveReturnType = + m_Function->getReturnType().getNonReferenceType(); + + if (const auto* MD = dyn_cast(m_Function)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) { + QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); + paramTypes.push_back( + GetParameterDerivativeType(effectiveReturnType, thisType)); + } + } + + for (auto* PVD : m_Function->parameters()) { + const auto* it = + std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) { + paramTypes.push_back( + GetParameterDerivativeType(effectiveReturnType, PVD->getType())); + } + } + return paramTypes; +} + +clang::QualType ReverseModeForwPassVisitor::ComputeReturnType() { + auto* valAndAdjointTempDecl = + LookupTemplateDeclInCladNamespace("ValueAndAdjoint"); + auto RT = m_Function->getReturnType(); + auto T = InstantiateTemplate(valAndAdjointTempDecl, {RT, RT}); + return T; +} + +llvm::SmallVector +ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { + llvm::SmallVector params; + llvm::SmallVector paramDerivatives; + params.reserve(m_Function->getNumParams() + diffParams.size()); + const auto* derivativeFnType = + cast(m_Derivative->getType()); + + std::size_t dParamTypesIdx = m_Function->getNumParams(); + + if (const auto* MD = dyn_cast(m_Function)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) { + auto* thisDerivativePVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"), + derivativeFnType->getParamType(dParamTypesIdx)); + paramDerivatives.push_back(thisDerivativePVD); + + if (thisDerivativePVD->getIdentifier()) + m_Sema.PushOnScopeChains(thisDerivativePVD, getCurrentScope(), + /*AddToContext=*/false); + + Expr* deref = + BuildOp(UnaryOperatorKind::UO_Deref, BuildDeclRef(thisDerivativePVD)); + m_ThisExprDerivative = utils::BuildParenExpr(m_Sema, deref); + ++dParamTypesIdx; + } + } + for (auto* PVD : m_Function->parameters()) { + // FIXME: Call expression may contain default arguments that we are now + // removing. This may cause issues. + auto* newPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), + PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); + params.push_back(newPVD); + + if (newPVD->getIdentifier()) + m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), + /*AddToContext=*/false); + + auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) { + *it = newPVD; + QualType dType = derivativeFnType->getParamType(dParamTypesIdx); + IdentifierInfo* dII = + CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, + PVD->getStorageClass()); + paramDerivatives.push_back(dPVD); + ++dParamTypesIdx; + + if (dPVD->getIdentifier()) + m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), + /*AddToContext=*/false); + + if (utils::isArrayOrPointerType(PVD->getType())) { + m_Variables[*it] = (Expr*)BuildDeclRef(dPVD); + } else { + QualType valueType = DetermineCladArrayValueType(dPVD->getType()); + m_Variables[*it] = + BuildOp(UO_Deref, BuildDeclRef(dPVD), m_Function->getLocation()); + // Add additional paranthesis if derivative is of record type + // because `*derivative.someField` will be incorrectly evaluated if + // the derived function is compiled standalone. + if (valueType->isRecordType()) + m_Variables[*it] = utils::BuildParenExpr(m_Sema, m_Variables[*it]); + } + } + } + params.insert(params.end(), paramDerivatives.begin(), paramDerivatives.end()); + return params; +} + +StmtDiff ReverseModeForwPassVisitor::ProcessSingleStmt(const clang::Stmt* S) { + StmtDiff SDiff = Visit(S); + return {SDiff.getStmt()}; +} + +StmtDiff ReverseModeForwPassVisitor::VisitStmt(const clang::Stmt* S) { + return {Clone(S)}; +} + +StmtDiff +ReverseModeForwPassVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) { + beginScope(Scope::DeclScope); + beginBlock(); + for (Stmt* S : CS->body()) { + StmtDiff SDiff = ProcessSingleStmt(S); + addToCurrentBlock(SDiff.getStmt()); + } + CompoundStmt* forward = endBlock(); + endScope(); + return {forward}; +} + +StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { + DeclRefExpr* clonedDRE = nullptr; + // Check if referenced Decl was "replaced" with another identifier inside + // the derivative + if (const auto* VD = dyn_cast(DRE->getDecl())) { + auto it = m_DeclReplacements.find(VD); + if (it != std::end(m_DeclReplacements)) + clonedDRE = BuildDeclRef(it->second); + else + clonedDRE = cast(Clone(DRE)); + // If current context is different than the context of the original + // declaration (e.g. we are inside lambda), rebuild the DeclRefExpr + // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. + // Sema::BuildDeclRefExpr is responsible for adding captured fields + // to the underlying struct of a lambda. + if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { + auto* referencedDecl = cast(clonedDRE->getDecl()); + clonedDRE = cast(BuildDeclRef(referencedDecl)); + } + } else + clonedDRE = cast(Clone(DRE)); + + if (auto* decl = dyn_cast(clonedDRE->getDecl())) { + // Check DeclRefExpr is a reference to an independent variable. + auto it = m_Variables.find(decl); + if (it == std::end(m_Variables)) { + // Is not an independent variable, ignored. + return StmtDiff(clonedDRE); + } + return StmtDiff(clonedDRE, it->second); + } + + return StmtDiff(clonedDRE); +} + +StmtDiff +ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { + const Expr* value = RS->getRetValue(); + auto returnDiff = Visit(value); + llvm::SmallVector returnArgs = {returnDiff.getExpr(), + returnDiff.getExpr_dx()}; + Expr* returnInitList = m_Sema.ActOnInitList(noLoc, returnArgs, noLoc).get(); + Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get(); + return {newRS}; +} +} // namespace clad \ No newline at end of file diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4e9805b64..a2a95bdd6 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1673,15 +1673,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ArgDeclStmts.push_back(BuildDeclStmt(gradVarDecl)); idx++; } + Expr* pullback = dfdx(); + if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType()) + pullback = getZeroInit(FD->getReturnType().getNonReferenceType()); + // FIXME: Remove this restriction. if (!FD->getReturnType()->isVoidType()) { - assert((dfdx() && !FD->getReturnType()->isVoidType()) && + assert((pullback && !FD->getReturnType()->isVoidType()) && "Call to function returning non-void type with no dfdx() is not " "supported!"); } if (FD->getReturnType()->isVoidType()) { - assert(dfdx() == nullptr && FD->getReturnType()->isVoidType() && + assert(pullback == nullptr && FD->getReturnType()->isVoidType() && "Call to function returning void type should not have any " "corresponding dfdx()."); } @@ -1691,9 +1695,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivedCallOutputArgs.end()); pullbackCallArgs = DerivedCallArgs; - if (dfdx()) + if (pullback) pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(), - dfdx()); + pullback); // Try to find it in builtin derivatives std::string customPullback = FD->getNameAsString() + "_pullback"; @@ -1857,15 +1861,83 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::end(CallArgs), std::begin(CallArgs), [this](Expr* E) { return Clone(E); }); - // Recreate the original call expression. - Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), - Clone(CE->getCallee()), - noLoc, - CallArgs, - noLoc) - .get(); + + Expr* call = nullptr; + + if (FD->getReturnType()->isReferenceType()) { + DiffRequest calleeFnForwPassReq; + calleeFnForwPassReq.Function = FD; + calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; + calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString(); + calleeFnForwPassReq.VerboseDiags = true; + FunctionDecl* calleeFnForwPassFD = + plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); + + assert(calleeFnForwPassFD && + "Clad failed to generate callee function forward pass function"); + + // FIXME: We are using the derivatives in forward pass here + // If `expr_dx()` is only meant to be used in reverse pass, + // (for example, `clad::pop(...)` expression and a corresponding + // `clad::push(...)` in the forward pass), then this can result in + // incorrect derivative or crash at runtime. Ideally, we should have + // a separate routine to use derivative in the forward pass. + + // We cannot reuse the derivatives previously computed because + // they might contain 'clad::pop(..)` expression. + if (isa(CE)) { + Expr* derivedBase = baseDiff.getExpr_dx(); + // FIXME: We may need this if-block once we support pointers, and + // passing pointers-by-reference if + // (isCladArrayType(derivedBase->getType())) + // CallArgs.push_back(derivedBase); + // else + CallArgs.push_back( + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc)); + } + + for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + const Expr* arg = CE->getArg(i); + const ParmVarDecl* PVD = FD->getParamDecl(i); + StmtDiff argDiff = Visit(arg); + if ((argDiff.getExpr_dx() != nullptr) && + PVD->getType()->isReferenceType()) { + Expr* derivedArg = argDiff.getExpr_dx(); + // FIXME: We may need this if-block once we support pointers, and + // passing pointers-by-reference if + // (isCladArrayType(derivedArg->getType())) + // CallArgs.push_back(derivedArg); + // else + CallArgs.push_back( + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, noLoc)); + } else + CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get()); + } + if (isa(CE)) { + Expr* baseE = baseDiff.getExpr(); + call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), + CallArgs, calleeFnForwPassFD); + } else { + call = m_Sema + .ActOnCallExpr(getCurrentScope(), + BuildDeclRef(calleeFnForwPassFD), noLoc, + CallArgs, noLoc) + .get(); + } + auto* callRes = StoreAndRef(call); + auto* resValue = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); + auto* resAdjoint = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); + return StmtDiff(resValue, nullptr, resAdjoint); + } // Recreate the original call expression. + call = m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, + CallArgs, noLoc) + .get(); return StmtDiff(call); + + return {}; } StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { @@ -2312,7 +2384,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (isDerivativeOfRefType) { initDiff = Visit(VD->getInit()); - if (!initDiff.getExpr_dx()) { + if (!initDiff.getForwSweepExpr_dx()) { VDDerivedType = ComputeAdjointType(VD->getType().getNonReferenceType()); isDerivativeOfRefType = false; @@ -3136,7 +3208,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // TODO: Add DiffMode::experimental_pullback support here as well. if (m_Mode == DiffMode::reverse || m_Mode == DiffMode::experimental_pullback) { - QualType effectiveReturnType = m_Function->getReturnType(); + QualType effectiveReturnType = + m_Function->getReturnType().getNonReferenceType(); if (m_Mode == DiffMode::experimental_pullback) { // FIXME: Generally, we use the function's return type as the argument's // derivative type. We cannot follow this strategy for `void` function @@ -3150,7 +3223,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (effectiveReturnType->isVoidType()) effectiveReturnType = m_Context.DoubleTy; else - paramTypes.push_back(m_Function->getReturnType()); + paramTypes.push_back(effectiveReturnType); } if (auto MD = dyn_cast(m_Function)) { diff --git a/lib/Differentiator/StmtClone.cpp b/lib/Differentiator/StmtClone.cpp index 14c49c894..f897d49f3 100644 --- a/lib/Differentiator/StmtClone.cpp +++ b/lib/Differentiator/StmtClone.cpp @@ -76,7 +76,12 @@ DEFINE_CLONE_EXPR(CharacterLiteral, (Node->getValue(), Node->getKind(), Node->ge DEFINE_CLONE_EXPR(ImaginaryLiteral, (Clone(Node->getSubExpr()), Node->getType())) DEFINE_CLONE_EXPR(ParenExpr, (Node->getLParen(), Node->getRParen(), Clone(Node->getSubExpr()))) DEFINE_CLONE_EXPR(ArraySubscriptExpr, (Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getRBracketLoc())) -DEFINE_CREATE_EXPR(CXXDefaultArgExpr, (Ctx, SourceLocation(), Node->getParam() CLAD_COMPAT_CLANG16_CXXDefaultArgExpr_getRewrittenExpr_Param(Node) CLAD_COMPAT_CLANG9_CXXDefaultArgExpr_getUsedContext_Param(Node))) +DEFINE_CREATE_EXPR( + CXXDefaultArgExpr, + (Ctx, SourceLocation(), + Node->getParam() + CLAD_COMPAT_CLANG16_CXXDefaultArgExpr_getRewrittenExpr_Param(Node) + CLAD_COMPAT_CLANG9_CXXDefaultArgExpr_getUsedContext_Param(Node))) Stmt* StmtClone::VisitMemberExpr(MemberExpr* Node) { TemplateArgumentListInfo TemplateArgs; @@ -108,7 +113,14 @@ DEFINE_CREATE_EXPR(CXXStaticCastExpr, (Ctx, Node->getType(), Node->getValueKind( DEFINE_CREATE_EXPR(CXXDynamicCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) DEFINE_CREATE_EXPR(CXXReinterpretCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) DEFINE_CREATE_EXPR(CXXConstCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Clone(Node->getSubExpr()), Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXConstructExpr, (Ctx, Node->getType(), Node->getLocation(), Node->getConstructor(), Node->isElidable(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization(), Node->getConstructionKind(), Node->getParenOrBraceRange())) +DEFINE_CREATE_EXPR( + CXXConstructExpr, + (Ctx, Node->getType(), Node->getLocation(), Node->getConstructor(), + Node->isElidable(), + clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), + Node->hadMultipleCandidates(), Node->isListInitialization(), + Node->isStdInitListInitialization(), Node->requiresZeroInitialization(), + Node->getConstructionKind(), Node->getParenOrBraceRange())) DEFINE_CREATE_EXPR(CXXFunctionalCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getTypeInfoAsWritten(), Node->getCastKind(), Clone(Node->getSubExpr()), 0 /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getLParenLoc(), Node->getRParenLoc())) DEFINE_CREATE_EXPR(ExprWithCleanups, (Ctx, Node->getSubExpr(), Node->cleanupsHaveSideEffects(), {})) @@ -117,7 +129,13 @@ DEFINE_CREATE_EXPR(ExprWithCleanups, (Ctx, Node->getSubExpr(), DEFINE_CREATE_EXPR(ConstantExpr, (Ctx, Clone(Node->getSubExpr()) CLAD_COMPAT_ConstantExpr_Create_ExtraParams)) #endif -DEFINE_CLONE_EXPR_CO(CXXTemporaryObjectExpr, (Ctx, Node->getConstructor(), Node->getType(), Node->getTypeSourceInfo(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->getSourceRange(), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization())) +DEFINE_CLONE_EXPR_CO( + CXXTemporaryObjectExpr, + (Ctx, Node->getConstructor(), Node->getType(), Node->getTypeSourceInfo(), + clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), + Node->getSourceRange(), Node->hadMultipleCandidates(), + Node->isListInitialization(), Node->isStdInitListInitialization(), + Node->requiresZeroInitialization())) DEFINE_CLONE_EXPR(MaterializeTemporaryExpr, (Node->getType(), CLAD_COMPAT_CLANG10_GetTemporaryExpr(Node), Node->isBoundToLvalueReference())) DEFINE_CLONE_EXPR_CO11(CompoundAssignOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getOpcode(), Node->getType(), @@ -137,7 +155,11 @@ DEFINE_CLONE_EXPR(CXXThrowExpr, (Clone(Node->getSubExpr()), Node->getType(), Nod #if CLANG_VERSION_MAJOR < 16 DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), Node->getParameter(), CLAD_COMPAT_SubstNonTypeTemplateParmExpr_isReferenceParameter_ExtraParam(Node) Node->getReplacement())) #else -DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), Node->getReplacement(), Node->getAssociatedDecl(), Node->getIndex(), Node->getPackIndex(), Node->isReferenceParameter())); +DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, + (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), + Node->getReplacement(), Node->getAssociatedDecl(), + Node->getIndex(), Node->getPackIndex(), + Node->isReferenceParameter())); #endif DEFINE_CREATE_EXPR(PseudoObjectExpr, (Ctx, Node->getSyntacticForm(), llvm::SmallVector(Node->semantics_begin(), Node->semantics_end()), Node->getResultExprIndex())) //BlockExpr diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 725731bc4..c6944c985 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -351,6 +351,92 @@ double fn6(double i=0, double j=0) { return i*j; } +double& identity(double& i) { + return i; +} + +double fn7(double i, double j) { + double& k = identity(i); + double& l = identity(j); + k += 7*j; + l += 9*i; + return i + j; +} + +// CHECK: void fn6_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: _t1 = i; +// CHECK-NEXT: _t0 = j; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 1 * _t0; +// CHECK-NEXT: * _d_i += _r0; +// CHECK-NEXT: double _r1 = _t1 * 1; +// CHECK-NEXT: * _d_j += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK: void identity_pullback(double &i, double _d_y, clad::array_ref _d_i) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: * _d_i += _d_y; +// CHECK-NEXT: } +// CHECK: clad::ValueAndAdjoint identity_forw(double &i, clad::array_ref _d_i) { +// CHECK-NEXT: return {i, * _d_i}; +// CHECK-NEXT: } +// CHECK: void fn7_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double *_d_k = 0; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double *_d_l = 0; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: _t0 = i; +// CHECK-NEXT: clad::ValueAndAdjoint _t1 = identity_forw(i, &* _d_i); +// CHECK-NEXT: _d_k = &_t1.adjoint; +// CHECK-NEXT: double &k = _t1.value; +// CHECK-NEXT: _t2 = j; +// CHECK-NEXT: clad::ValueAndAdjoint _t3 = identity_forw(j, &* _d_j); +// CHECK-NEXT: _d_l = &_t3.adjoint; +// CHECK-NEXT: double &l = _t3.value; +// CHECK-NEXT: _t4 = j; +// CHECK-NEXT: k += 7 * _t4; +// CHECK-NEXT: _t5 = i; +// CHECK-NEXT: l += 9 * _t5; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: * _d_i += 1; +// CHECK-NEXT: * _d_j += 1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d1 = *_d_l; +// CHECK-NEXT: *_d_l += _r_d1; +// CHECK-NEXT: double _r4 = _r_d1 * _t5; +// CHECK-NEXT: double _r5 = 9 * _r_d1; +// CHECK-NEXT: * _d_i += _r5; +// CHECK-NEXT: *_d_l -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = *_d_k; +// CHECK-NEXT: *_d_k += _r_d0; +// CHECK-NEXT: double _r2 = _r_d0 * _t4; +// CHECK-NEXT: double _r3 = 7 * _r_d0; +// CHECK-NEXT: * _d_j += _r3; +// CHECK-NEXT: *_d_k -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: identity_pullback(_t2, 0, &* _d_j); +// CHECK-NEXT: double _r1 = * _d_j; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: identity_pullback(_t0, 0, &* _d_i); +// CHECK-NEXT: double _r0 = * _d_i; +// CHECK-NEXT: } +// CHECK-NEXT: } + + template void reset(T* arr, int n) { for (int i=0; i _d_i, clad::array_ref _d_j); void const_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); void volatile_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); @@ -756,6 +758,35 @@ double fn(double i,double j) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn2(SimpleFunctions& sf, double i) { + return sf.ref_mem_fn(i); +} + +// CHECK: void ref_mem_fn_pullback(double i, double _d_y, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: (* _d_this).x += _d_y; +// CHECK-NEXT: } +// CHECK: clad::ValueAndAdjoint ref_mem_fn_forw(double i, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: return {this->x, (* _d_this).x}; +// CHECK-NEXT: } +// CHECK: void fn2_grad(SimpleFunctions &sf, double i, clad::array_ref _d_sf, clad::array_ref _d_i) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: SimpleFunctions _t1; +// CHECK-NEXT: _t0 = i; +// CHECK-NEXT: _t1 = sf; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = _t1.ref_mem_fn_forw(_t0, &(* _d_sf), nullptr); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: _t1.ref_mem_fn_pullback(_t0, 1, &(* _d_sf), &_grad0); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: * _d_i += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + + int main() { auto d_mem_fn = clad::gradient(&SimpleFunctions::mem_fn); auto d_const_mem_fn = clad::gradient(&SimpleFunctions::const_mem_fn); @@ -790,6 +821,12 @@ int main() { printf("%.2f ",result[i]); //CHECK-EXEC: 40.00 16.00 } + SimpleFunctions sf(2, 3); + SimpleFunctions d_sf; + auto d_fn2 = clad::gradient(fn2); + d_fn2.execute(sf, 2, &d_sf, &result[0]); + printf("%.2f", result[0]); //CHECK-EXEC: 40.00 + auto d_const_volatile_lval_ref_mem_fn_i = clad::gradient(&SimpleFunctions::const_volatile_lval_ref_mem_fn, "i"); // CHECK: void const_volatile_lval_ref_mem_fn_grad_0(double i, double j, clad::array_ref _d_this, clad::array_ref _d_i) const volatile & { diff --git a/tools/DerivedFnInfo.h b/tools/DerivedFnInfo.h index dc07c3cf2..160cbc1af 100644 --- a/tools/DerivedFnInfo.h +++ b/tools/DerivedFnInfo.h @@ -6,39 +6,41 @@ #include "clad/Differentiator/ParseDiffArgsTypes.h" namespace clad { - struct DiffRequest; - - /// `DerivedFnInfo` is designed to effectively store information about a - /// derived function. - struct DerivedFnInfo { - const clang::FunctionDecl* m_OriginalFn = nullptr; - clang::FunctionDecl* m_DerivedFn = nullptr; - clang::FunctionDecl* m_OverloadedDerivedFn = nullptr; - DiffMode m_Mode = DiffMode::unknown; - unsigned m_DerivativeOrder = 0; - DiffInputVarsInfo m_DiffVarsInfo; - bool m_UsesEnzyme = false; - - DerivedFnInfo() {} - DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn, - clang::FunctionDecl* overloadedDerivedFn); - - /// Returns true if the derived function represented by the object, - /// satisfies the requirements of the given differentiation request. - bool SatisfiesRequest(const DiffRequest& request) const; - - /// Returns true if the object represents any derived function; otherwise - /// returns false. - bool IsValid() const; - - const clang::FunctionDecl* OriginalFn() const { return m_OriginalFn; } - clang::FunctionDecl* DerivedFn() const { return m_DerivedFn; } - clang::FunctionDecl* OverloadedDerivedFn() const { return m_OverloadedDerivedFn; } - - /// Returns true if `lhs` and `rhs` represents same derivative. - /// Here derivative is any function derived by clad. - static bool RepresentsSameDerivative(const DerivedFnInfo& lhs, - const DerivedFnInfo& rhs); +struct DiffRequest; + +/// `DerivedFnInfo` is designed to effectively store information about a +/// derived function. +struct DerivedFnInfo { + const clang::FunctionDecl* m_OriginalFn = nullptr; + clang::FunctionDecl* m_DerivedFn = nullptr; + clang::FunctionDecl* m_OverloadedDerivedFn = nullptr; + DiffMode m_Mode = DiffMode::unknown; + unsigned m_DerivativeOrder = 0; + DiffInputVarsInfo m_DiffVarsInfo; + bool m_UsesEnzyme = false; + + DerivedFnInfo() {} + DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn, + clang::FunctionDecl* overloadedDerivedFn); + + /// Returns true if the derived function represented by the object, + /// satisfies the requirements of the given differentiation request. + bool SatisfiesRequest(const DiffRequest& request) const; + + /// Returns true if the object represents any derived function; otherwise + /// returns false. + bool IsValid() const; + + const clang::FunctionDecl* OriginalFn() const { return m_OriginalFn; } + clang::FunctionDecl* DerivedFn() const { return m_DerivedFn; } + clang::FunctionDecl* OverloadedDerivedFn() const { + return m_OverloadedDerivedFn; + } + + /// Returns true if `lhs` and `rhs` represents same derivative. + /// Here derivative is any function derived by clad. + static bool RepresentsSameDerivative(const DerivedFnInfo& lhs, + const DerivedFnInfo& rhs); }; } // namespace clad