From 55950162feff74c89c6f9590c8592b1bb7376a1f Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Tue, 11 Jun 2024 09:21:49 +0000 Subject: [PATCH] Pass the DiffRequest down to the visitors. NFC. The intent of this patch is to centralize the information about the diff request in one place. This will help upcoming patches to reduce the state of the visitors and refactor some of the implementation in a common place. That will unblock work on #721. --- .../ErrorEstimation/CustomModel/CustomModel.h | 5 +- demos/ErrorEstimation/PrintModel/PrintModel.h | 4 +- .../Differentiator/BaseForwardModeVisitor.h | 6 +- include/clad/Differentiator/DiffPlanner.h | 2 + include/clad/Differentiator/EstimationModel.h | 20 ++- .../clad/Differentiator/HessianModeVisitor.h | 4 +- .../Differentiator/PushForwardModeVisitor.h | 4 +- .../ReverseModeForwPassVisitor.h | 5 +- .../clad/Differentiator/ReverseModeVisitor.h | 10 +- .../Differentiator/VectorForwardModeVisitor.h | 7 +- .../VectorPushForwardModeVisitor.h | 3 +- include/clad/Differentiator/VisitorBase.h | 8 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 41 +++--- lib/Differentiator/DerivativeBuilder.cpp | 29 ++-- lib/Differentiator/ErrorEstimator.cpp | 2 +- lib/Differentiator/HessianModeVisitor.cpp | 71 +++++----- lib/Differentiator/PushForwardModeVisitor.cpp | 5 +- .../ReverseModeForwPassVisitor.cpp | 45 +++--- lib/Differentiator/ReverseModeVisitor.cpp | 132 +++++++++--------- .../VectorForwardModeVisitor.cpp | 65 +++++---- .../VectorPushForwardModeVisitor.cpp | 4 +- lib/Differentiator/VisitorBase.cpp | 17 ++- tools/ClangPlugin.cpp | 3 +- 23 files changed, 254 insertions(+), 238 deletions(-) diff --git a/demos/ErrorEstimation/CustomModel/CustomModel.h b/demos/ErrorEstimation/CustomModel/CustomModel.h index b889391f7..728864bf8 100644 --- a/demos/ErrorEstimation/CustomModel/CustomModel.h +++ b/demos/ErrorEstimation/CustomModel/CustomModel.h @@ -15,8 +15,9 @@ // FPErrorEstimationModel class. class CustomModel : public clad::FPErrorEstimationModel { public: - CustomModel(clad::DerivativeBuilder& builder) - : FPErrorEstimationModel(builder) {} + CustomModel(clad::DerivativeBuilder& builder, + const clad::DiffRequest& request) + : FPErrorEstimationModel(builder, request) {} /// Return an expression of the following kind: /// dfdx * delta_x clang::Expr* AssignError(clad::StmtDiff refExpr, diff --git a/demos/ErrorEstimation/PrintModel/PrintModel.h b/demos/ErrorEstimation/PrintModel/PrintModel.h index 45e95d83e..d97dcfe63 100644 --- a/demos/ErrorEstimation/PrintModel/PrintModel.h +++ b/demos/ErrorEstimation/PrintModel/PrintModel.h @@ -15,8 +15,8 @@ // FPErrorEstimationModel class. class PrintModel : public clad::FPErrorEstimationModel { public: - PrintModel(clad::DerivativeBuilder& builder) - : FPErrorEstimationModel(builder) {} + PrintModel(clad::DerivativeBuilder& builder, const clad::DiffRequest& request) + : FPErrorEstimationModel(builder, request) {} // Return an expression of the following kind: // dfdx * delta_x clang::Expr* AssignError(clad::StmtDiff refExpr, const std::string& name) override; diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 3fda3a018..17e6ba6a4 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -24,7 +24,8 @@ class BaseForwardModeVisitor unsigned m_ArgIndex = ~0; public: - BaseForwardModeVisitor(DerivativeBuilder& builder); + BaseForwardModeVisitor(DerivativeBuilder& builder, + const DiffRequest& request); virtual ~BaseForwardModeVisitor(); ///\brief Produces the first derivative of a given function. @@ -41,8 +42,7 @@ class BaseForwardModeVisitor const DiffRequest& request); /// Returns the return type for the pushforward function of the function - /// `m_Function`. - /// \note `m_Function` field should be set before using this function. + /// `m_DiffReq->Function`. clang::QualType ComputePushforwardFnReturnType(); virtual void ExecuteInsidePushforwardFunctionBlock(); diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 50aae2898..58359a14b 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -105,6 +105,8 @@ struct DiffRequest { DeclarationOnly == other.DeclarationOnly; } + const clang::FunctionDecl* operator->() const { return Function; } + // String operator for printing the node. operator std::string() const { std::string res = BaseFunctionName + "__order_" + diff --git a/include/clad/Differentiator/EstimationModel.h b/include/clad/Differentiator/EstimationModel.h index 2cb87fbde..7de38c661 100644 --- a/include/clad/Differentiator/EstimationModel.h +++ b/include/clad/Differentiator/EstimationModel.h @@ -28,7 +28,10 @@ namespace clad { std::unordered_map m_EstimateVar; public: - FPErrorEstimationModel(DerivativeBuilder& builder) : VisitorBase(builder) {} + // FIXME: Add a proper parameter for the DiffRequest here. + FPErrorEstimationModel(DerivativeBuilder& builder, + const DiffRequest& request) + : VisitorBase(builder, request) {} virtual ~FPErrorEstimationModel(); /// Clear the variable estimate map so that we can start afresh. @@ -83,10 +86,13 @@ namespace clad { /// custom model. /// \param[in] builder A build instance to pass to the custom model /// constructor. + /// \param[in] request The differentiation configuration passed to the + /// custom model /// \returns A reference to the custom class wrapped in the /// FPErrorEstimationModel class. virtual std::unique_ptr - InstantiateCustomModel(DerivativeBuilder& builder) = 0; + InstantiateCustomModel(DerivativeBuilder& builder, + const DiffRequest& request) = 0; }; /// A class used to register custom plugins. @@ -99,16 +105,18 @@ namespace clad { /// /// \param[in] builder The current instance of derivative builder. std::unique_ptr - InstantiateCustomModel(DerivativeBuilder& builder) override { - return std::unique_ptr(new CustomClass(builder)); + InstantiateCustomModel(DerivativeBuilder& builder, + const DiffRequest& request) override { + return std::unique_ptr( + new CustomClass(builder, request)); } }; /// Example class for taylor series approximation based error estimation. class TaylorApprox : public FPErrorEstimationModel { public: - TaylorApprox(DerivativeBuilder& builder) - : FPErrorEstimationModel(builder) {} + TaylorApprox(DerivativeBuilder& builder, const DiffRequest& request) + : FPErrorEstimationModel(builder, request) {} // Return an expression of the following kind: // std::abs(dfdx * delta_x * Em) clang::Expr* AssignError(StmtDiff refExpr, diff --git a/include/clad/Differentiator/HessianModeVisitor.h b/include/clad/Differentiator/HessianModeVisitor.h index 7a8f2f88c..2740add92 100644 --- a/include/clad/Differentiator/HessianModeVisitor.h +++ b/include/clad/Differentiator/HessianModeVisitor.h @@ -33,7 +33,7 @@ namespace clad { size_t TotalIndependentArgsSize, std::string hessianFuncName); public: - HessianModeVisitor(DerivativeBuilder& builder); + HessianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request); ~HessianModeVisitor(); ///\brief Produces the hessian second derivative columns of a given @@ -53,4 +53,4 @@ namespace clad { }; } // end namespace clad -#endif // CLAD_HESSIAN_MODE_VISITOR_H \ No newline at end of file +#endif // CLAD_HESSIAN_MODE_VISITOR_H diff --git a/include/clad/Differentiator/PushForwardModeVisitor.h b/include/clad/Differentiator/PushForwardModeVisitor.h index 82347ab2f..c8e1226d3 100644 --- a/include/clad/Differentiator/PushForwardModeVisitor.h +++ b/include/clad/Differentiator/PushForwardModeVisitor.h @@ -10,12 +10,14 @@ #include "BaseForwardModeVisitor.h" namespace clad { + /// A visitor for processing the function code in forward mode. /// Used to compute derivatives by clad::differentiate. class PushForwardModeVisitor : public BaseForwardModeVisitor { public: - PushForwardModeVisitor(DerivativeBuilder& builder); + PushForwardModeVisitor(DerivativeBuilder& builder, + const DiffRequest& request); ~PushForwardModeVisitor() override; StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; diff --git a/include/clad/Differentiator/ReverseModeForwPassVisitor.h b/include/clad/Differentiator/ReverseModeForwPassVisitor.h index fc2236306..574836934 100644 --- a/include/clad/Differentiator/ReverseModeForwPassVisitor.h +++ b/include/clad/Differentiator/ReverseModeForwPassVisitor.h @@ -22,7 +22,8 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor { clang::QualType xType); public: - ReverseModeForwPassVisitor(DerivativeBuilder& builder); + ReverseModeForwPassVisitor(DerivativeBuilder& builder, + const DiffRequest& request); DerivativeAndOverload Derive(const clang::FunctionDecl* FD, const DiffRequest& request); @@ -34,4 +35,4 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor { }; } // namespace clad -#endif \ No newline at end of file +#endif diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 7e9c23082..02604b41d 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -336,7 +336,7 @@ namespace clad { llvm::SmallVectorImpl& outputArgs); public: - ReverseModeVisitor(DerivativeBuilder& builder); + ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request); virtual ~ReverseModeVisitor(); ///\brief Produces the gradient of a given function. @@ -629,18 +629,14 @@ namespace clad { /// Computes and returns the sequence of derived function parameter types. /// /// Information about the original function and the differentiation mode - /// are taken from the data member variables. In particular, `m_Function`, - /// `m_Mode` data members should be correctly set before using this - /// function. + /// are taken from the data member variables. llvm::SmallVector ComputeParamTypes(const DiffParams& diffParams); /// Builds and returns the sequence of derived function parameters. /// /// Information about the original function, derived function, derived /// function parameter types and the differentiation mode are implicitly - /// taken from the data member variables. In particular, `m_Function`, - /// `m_Mode` and `m_Derivative` should be correctly set before using this - /// function. + /// taken from the data member variables. llvm::SmallVector BuildParams(DiffParams& diffParams); diff --git a/include/clad/Differentiator/VectorForwardModeVisitor.h b/include/clad/Differentiator/VectorForwardModeVisitor.h index 37d824f33..c5bccdeda 100644 --- a/include/clad/Differentiator/VectorForwardModeVisitor.h +++ b/include/clad/Differentiator/VectorForwardModeVisitor.h @@ -23,7 +23,8 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { clang::Expr* m_IndVarCountExpr; public: - VectorForwardModeVisitor(DerivativeBuilder& builder); + VectorForwardModeVisitor(DerivativeBuilder& builder, + const DiffRequest& request); ~VectorForwardModeVisitor(); ///\brief Produces the first derivative of a given function with @@ -53,9 +54,7 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { /// /// Information about the original function, derived function, derived /// function parameter types and the differentiation mode are implicitly - /// taken from the data member variables. In particular, `m_Function`, - /// `m_Mode` and `m_Derivative` should be correctly set before using this - /// function. + /// taken from the data member variables. llvm::SmallVector BuildVectorModeParams(DiffParams& diffParams); diff --git a/include/clad/Differentiator/VectorPushForwardModeVisitor.h b/include/clad/Differentiator/VectorPushForwardModeVisitor.h index eabeefae0..20b0c272f 100644 --- a/include/clad/Differentiator/VectorPushForwardModeVisitor.h +++ b/include/clad/Differentiator/VectorPushForwardModeVisitor.h @@ -8,7 +8,8 @@ namespace clad { class VectorPushForwardModeVisitor : public VectorForwardModeVisitor { public: - VectorPushForwardModeVisitor(DerivativeBuilder& builder); + VectorPushForwardModeVisitor(DerivativeBuilder& builder, + const DiffRequest& request); ~VectorPushForwardModeVisitor() override; void ExecuteInsidePushforwardFunctionBlock() override; diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 122231896..1142bb0e0 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -99,11 +99,11 @@ namespace clad { /// A base class for all common functionality for visitors class VisitorBase { protected: - VisitorBase(DerivativeBuilder& builder) + VisitorBase(DerivativeBuilder& builder, const DiffRequest& request) : m_Builder(builder), m_Sema(builder.m_Sema), m_CladPlugin(builder.m_CladPlugin), m_Context(builder.m_Context), m_DerivativeFnScope(nullptr), m_DerivativeInFlight(false), - m_Derivative(nullptr), m_Function(nullptr) {} + m_Derivative(nullptr), m_DiffReq(request) {} using Stmts = llvm::SmallVector; @@ -117,8 +117,8 @@ namespace clad { bool m_DerivativeInFlight; /// The Derivative function that is being generated. clang::FunctionDecl* m_Derivative; - /// The function that is currently differentiated. - const clang::FunctionDecl* m_Function; + /// The differentiation request that is being currently processed. + const DiffRequest& m_DiffReq; DiffMode m_Mode; /// Map used to keep track of variable declarations and match them /// with their derivatives. diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 5a3ed4394..fd3b7f70c 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -33,8 +33,9 @@ using namespace clang; namespace clad { -BaseForwardModeVisitor::BaseForwardModeVisitor(DerivativeBuilder& builder) - : VisitorBase(builder) {} +BaseForwardModeVisitor::BaseForwardModeVisitor(DerivativeBuilder& builder, + const DiffRequest& request) + : VisitorBase(builder, request) {} BaseForwardModeVisitor::~BaseForwardModeVisitor() {} @@ -60,8 +61,8 @@ bool IsRealNonReferenceType(QualType T) { DerivativeAndOverload BaseForwardModeVisitor::Derive(const FunctionDecl* FD, const DiffRequest& request) { + assert(m_DiffReq == request && "Can't pass two different requests!"); silenceDiags = !request.VerboseDiags; - m_Function = FD; m_Functor = request.Functor; m_Mode = DiffMode::forward; assert(!m_DerivativeInFlight && @@ -144,7 +145,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, // then the specified independent argument is a member variable of the // class defining the call operator. // Thus, we need to find index of the member variable instead. - if (m_Function->param_empty() && m_Functor) { + if (m_DiffReq->param_empty() && m_Functor) { m_ArgIndex = std::distance(m_Functor->field_begin(), std::find(m_Functor->field_begin(), @@ -161,11 +162,11 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, IdentifierInfo* II = &m_Context.Idents.get( request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix); - SourceLocation validLoc{m_Function->getLocation()}; + SourceLocation validLoc{m_DiffReq->getLocation()}; DeclarationNameInfo name(II, validLoc); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope()); - DeclContext* DC = const_cast(m_Function->getDeclContext()); + DeclContext* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType()); @@ -368,7 +369,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() { assert(m_Mode == GetPushForwardMode()); - QualType originalFnRT = m_Function->getReturnType(); + QualType originalFnRT = m_DiffReq->getReturnType(); if (originalFnRT->isVoidType()) return m_Context.VoidTy; TemplateDecl* valueAndPushforward = @@ -382,7 +383,7 @@ clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() { } void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { - Stmt* bodyDiff = Visit(m_Function->getBody()).getStmt(); + Stmt* bodyDiff = Visit(m_DiffReq->getBody()).getStmt(); auto* CS = cast(bodyDiff); for (Stmt* S : CS->body()) addToCurrentBlock(S); @@ -391,7 +392,7 @@ void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, const DiffRequest& request) { - m_Function = FD; + const_cast(m_DiffReq) = request; m_Functor = request.Functor; m_DerivativeOrder = request.CurrentDerivativeOrder; m_Mode = GetPushForwardMode(); @@ -399,11 +400,12 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, "Doesn't support recursive diff. Use DiffPlan."); m_DerivativeInFlight = true; - auto originalFnEffectiveName = utils::ComputeEffectiveFnName(m_Function); + auto originalFnEffectiveName = + utils::ComputeEffectiveFnName(m_DiffReq.Function); IdentifierInfo* derivedFnII = &m_Context.Idents.get( originalFnEffectiveName + GetPushForwardFunctionSuffix()); - DeclarationNameInfo derivedFnName(derivedFnII, m_Function->getLocation()); + DeclarationNameInfo derivedFnName(derivedFnII, m_DiffReq->getLocation()); llvm::SmallVector paramTypes; llvm::SmallVector derivedParamTypes; @@ -417,7 +419,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, } } - for (auto* PVD : m_Function->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) { paramTypes.push_back(PVD->getType()); if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) @@ -428,19 +430,19 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, derivedParamTypes.end()); const auto* originalFnType = - dyn_cast(m_Function->getType()); + dyn_cast(m_DiffReq->getType()); QualType returnType = ComputePushforwardFnReturnType(); QualType derivedFnType = m_Context.getFunctionType( returnType, paramTypes, originalFnType->getExtProtoInfo()); llvm::SaveAndRestore saveContext(m_Sema.CurContext); llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); - auto* DC = const_cast(m_Function->getDeclContext()); + auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; - SourceLocation loc{m_Function->getLocation()}; + SourceLocation loc{m_DiffReq->getLocation()}; DeclWithContext cloneFunctionResult = m_Builder.cloneFunction( - m_Function, *this, DC, loc, derivedFnName, derivedFnType); + m_DiffReq.Function, *this, DC, loc, derivedFnName, derivedFnType); m_Derivative = cloneFunctionResult.first; llvm::SmallVector params; @@ -466,9 +468,9 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, } } - std::size_t numParamsOriginalFn = m_Function->getNumParams(); + std::size_t numParamsOriginalFn = m_DiffReq->getNumParams(); for (std::size_t i = 0; i < numParamsOriginalFn; ++i) { - const auto* PVD = m_Function->getParamDecl(i); + const auto* PVD = m_DiffReq->getParamDecl(i); // Some of the special member functions created implicitly by compilers // have missing parameter identifier. bool identifierMissing = false; @@ -1137,7 +1139,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { const_cast(FD->getDeclContext())); // Check if it is a recursive call. - if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) { + if (!callDiff && (FD == m_DiffReq.Function) && + m_Mode == GetPushForwardMode()) { // The differentiated function is called recursively. Expr* derivativeRef = m_Sema diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 27ad198d7..1865eb11e 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -268,7 +268,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { void InitErrorEstimation( llvm::SmallVectorImpl>& handler, llvm::SmallVectorImpl>& model, - DerivativeBuilder& builder) { + DerivativeBuilder& builder, const DiffRequest& request) { // Set the handler. std::unique_ptr pHandler( new ErrorEstimationHandler()); @@ -276,7 +276,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { // Set error estimation model. If no custom model provided by user, // use the built in Taylor approximation model. if (model.size() != handler.size()) { - std::unique_ptr pModel(new TaylorApprox(builder)); + std::unique_ptr pModel( + new TaylorApprox(builder, request)); model.push_back(std::move(pModel)); } handler.back()->SetErrorEstimationModel(model.back().get()); @@ -340,41 +341,41 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { DerivativeAndOverload result{}; if (request.Mode == DiffMode::forward) { - BaseForwardModeVisitor V(*this); + BaseForwardModeVisitor V(*this, request); result = V.Derive(FD, request); } else if (request.Mode == DiffMode::experimental_pushforward) { - PushForwardModeVisitor V(*this); + PushForwardModeVisitor V(*this, request); result = V.DerivePushforward(FD, request); } else if (request.Mode == DiffMode::vector_forward_mode) { - VectorForwardModeVisitor V(*this); + VectorForwardModeVisitor V(*this, request); result = V.DeriveVectorMode(FD, request); } else if (request.Mode == DiffMode::experimental_vector_pushforward) { - VectorPushForwardModeVisitor V(*this); + VectorPushForwardModeVisitor V(*this, request); result = V.DerivePushforward(FD, request); } else if (request.Mode == DiffMode::reverse) { - ReverseModeVisitor V(*this); + ReverseModeVisitor V(*this, request); result = V.Derive(FD, request); } else if (request.Mode == DiffMode::experimental_pullback) { - ReverseModeVisitor V(*this); + ReverseModeVisitor V(*this, request); if (!m_ErrorEstHandler.empty()) { - InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this); + InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request); V.AddExternalSource(*m_ErrorEstHandler.back()); } result = V.DerivePullback(FD, request); if (!m_ErrorEstHandler.empty()) CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { - ReverseModeForwPassVisitor V(*this); + ReverseModeForwPassVisitor V(*this, request); result = V.Derive(FD, request); } else if (request.Mode == DiffMode::hessian) { - HessianModeVisitor H(*this); + HessianModeVisitor H(*this, request); result = H.Derive(FD, request); } else if (request.Mode == DiffMode::jacobian) { - ReverseModeVisitor R(*this); + ReverseModeVisitor R(*this, request); result = R.Derive(FD, request); } else if (request.Mode == DiffMode::error_estimation) { - ReverseModeVisitor R(*this); - InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this); + ReverseModeVisitor R(*this, request); + InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request); R.AddExternalSource(*m_ErrorEstHandler.back()); // Finally begin estimation. result = R.Derive(FD, request); diff --git a/lib/Differentiator/ErrorEstimator.cpp b/lib/Differentiator/ErrorEstimator.cpp index da3396dbc..368076483 100644 --- a/lib/Differentiator/ErrorEstimator.cpp +++ b/lib/Differentiator/ErrorEstimator.cpp @@ -313,7 +313,7 @@ void ErrorEstimationHandler::ActBeforeCreatingDerivedFnBodyScope() { void ErrorEstimationHandler::ActOnEndOfDerivedFnBody() { // Since 'return' is not an assignment, add its error to _final_error // given it is not a DeclRefExpr. - EmitFinalErrorStmts(*m_Params, m_RMV->m_Function->getNumParams()); + EmitFinalErrorStmts(*m_Params, m_RMV->m_DiffReq->getNumParams()); } void ErrorEstimationHandler::ActBeforeDifferentiatingStmtInVisitCompoundStmt() { diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 6fe5ecbfd..2521b312f 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -27,25 +27,26 @@ using namespace clang; namespace clad { - HessianModeVisitor::HessianModeVisitor(DerivativeBuilder& builder) - : VisitorBase(builder) {} - - HessianModeVisitor::~HessianModeVisitor() {} - - /// Converts the string str into a StringLiteral - static const StringLiteral* CreateStringLiteral(ASTContext& C, - std::string str) { - QualType CharTyConst = C.CharTy.withConst(); - QualType StrTy = clad_compat::getConstantArrayType( - C, CharTyConst, llvm::APInt(/*numBits=*/32, str.size() + 1), - /*SizeExpr=*/nullptr, - /*ASM=*/clad_compat::ArraySizeModifier_Normal, - /*IndexTypeQuals*/ 0); - const StringLiteral* SL = StringLiteral::Create( - C, str, /*Kind=*/clad_compat::StringLiteralKind_Ordinary, - /*Pascal=*/false, StrTy, noLoc); - return SL; - } +HessianModeVisitor::HessianModeVisitor(DerivativeBuilder& builder, + const DiffRequest& request) + : VisitorBase(builder, request) {} + +HessianModeVisitor::~HessianModeVisitor() {} + +/// Converts the string str into a StringLiteral +static const StringLiteral* CreateStringLiteral(ASTContext& C, + std::string str) { + QualType CharTyConst = C.CharTy.withConst(); + QualType StrTy = clad_compat::getConstantArrayType( + C, CharTyConst, llvm::APInt(/*numBits=*/32, str.size() + 1), + /*SizeExpr=*/nullptr, + /*ASM=*/clad_compat::ArraySizeModifier_Normal, + /*IndexTypeQuals*/ 0); + const StringLiteral* SL = StringLiteral::Create( + C, str, /*Kind=*/clad_compat::StringLiteralKind_Ordinary, + /*Pascal=*/false, StrTy, noLoc); + return SL; +} /// Derives the function w.r.t both forward and reverse mode and returns the /// FunctionDecl obtained from reverse mode differentiation @@ -119,18 +120,18 @@ namespace clad { size_t TotalIndependentArgsSize = 0; // request.Function is original function passed in from clad::hessian - m_Function = request.Function; + assert(m_DiffReq == request); std::string hessianFuncName = request.BaseFunctionName + "_hessian"; // To be consistent with older tests, nothing is appended to 'f_hessian' if // we differentiate w.r.t. all the parameters at once. if (args.size() != FD->getNumParams() || - !std::equal(m_Function->param_begin(), m_Function->param_end(), + !std::equal(m_DiffReq->param_begin(), m_DiffReq->param_end(), args.begin())) { for (auto arg : args) { auto it = - std::find(m_Function->param_begin(), m_Function->param_end(), arg); - auto idx = std::distance(m_Function->param_begin(), it); + std::find(m_DiffReq->param_begin(), m_DiffReq->param_end(), arg); + auto idx = std::distance(m_DiffReq->param_begin(), it); hessianFuncName += ('_' + std::to_string(idx)); } } @@ -226,23 +227,21 @@ namespace clad { size_t TotalIndependentArgsSize, std::string hessianFuncName) { DiffParams args; - std::copy(m_Function->param_begin(), - m_Function->param_end(), + std::copy(m_DiffReq->param_begin(), m_DiffReq->param_end(), std::back_inserter(args)); IdentifierInfo* II = &m_Context.Idents.get(hessianFuncName); DeclarationNameInfo name(II, noLoc); - llvm::SmallVector paramTypes(m_Function->getNumParams() + 1); + llvm::SmallVector paramTypes(m_DiffReq->getNumParams() + 1); - std::transform(m_Function->param_begin(), - m_Function->param_end(), + std::transform(m_DiffReq->param_begin(), m_DiffReq->param_end(), std::begin(paramTypes), [](const ParmVarDecl* PVD) { return PVD->getType(); }); - paramTypes.back() = m_Context.getPointerType(m_Function->getReturnType()); + paramTypes.back() = m_Context.getPointerType(m_DiffReq->getReturnType()); - auto originalFnProtoType = cast(m_Function->getType()); + auto originalFnProtoType = cast(m_DiffReq->getType()); QualType hessianFunctionType = m_Context.getFunctionType( m_Context.VoidTy, llvm::ArrayRef(paramTypes.data(), paramTypes.size()), @@ -250,14 +249,14 @@ namespace clad { originalFnProtoType->getExtProtoInfo()); // Create the gradient function declaration. - DeclContext* DC = const_cast(m_Function->getDeclContext()); + DeclContext* DC = const_cast(m_DiffReq->getDeclContext()); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction( - m_Function, *this, DC, noLoc, name, hessianFunctionType); + m_DiffReq.Function, *this, DC, noLoc, name, hessianFunctionType); FunctionDecl* hessianFD = result.first; beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | @@ -266,10 +265,8 @@ namespace clad { m_Sema.PushDeclContext(getCurrentScope(), hessianFD); llvm::SmallVector params(paramTypes.size()); - std::transform(m_Function->param_begin(), - m_Function->param_end(), - std::begin(params), - [&](const ParmVarDecl* PVD) { + std::transform(m_DiffReq->param_begin(), m_DiffReq->param_end(), + std::begin(params), [&](const ParmVarDecl* PVD) { auto VD = ParmVarDecl::Create(m_Context, hessianFD, @@ -344,7 +341,7 @@ namespace clad { // FIXME: Add support for class type in the hessian matrix. For this, we // need to add a way to represent hessian matrix when class type objects // are involved. - if (auto MD = dyn_cast(m_Function)) { + if (auto MD = dyn_cast(m_DiffReq.Function)) { const CXXRecordDecl* RD = MD->getParent(); if (MD->isInstance() && !RD->isLambda()) { QualType thisObjectType = diff --git a/lib/Differentiator/PushForwardModeVisitor.cpp b/lib/Differentiator/PushForwardModeVisitor.cpp index b44779149..962eda0b9 100644 --- a/lib/Differentiator/PushForwardModeVisitor.cpp +++ b/lib/Differentiator/PushForwardModeVisitor.cpp @@ -14,8 +14,9 @@ using namespace clang; namespace clad { -PushForwardModeVisitor::PushForwardModeVisitor(DerivativeBuilder& builder) - : BaseForwardModeVisitor(builder) {} +PushForwardModeVisitor::PushForwardModeVisitor(DerivativeBuilder& builder, + const DiffRequest& request) + : BaseForwardModeVisitor(builder, request) {} PushForwardModeVisitor::~PushForwardModeVisitor() = default; diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index 7c444415f..8e51e95ba 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -13,28 +13,29 @@ using namespace clang; namespace clad { ReverseModeForwPassVisitor::ReverseModeForwPassVisitor( - DerivativeBuilder& builder) - : ReverseModeVisitor(builder) {} + DerivativeBuilder& builder, const DiffRequest& request) + : ReverseModeVisitor(builder, request) {} DerivativeAndOverload ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, const DiffRequest& request) { + assert(m_DiffReq == request); silenceDiags = !request.VerboseDiags; - m_Function = FD; m_Mode = DiffMode::reverse_mode_forward_pass; - assert(m_Function && "Must not be null."); + assert(m_DiffReq.Function && "Must not be null."); DiffParams args{}; std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - auto fnName = clad::utils::ComputeEffectiveFnName(m_Function) + "_forw"; + auto fnName = + clad::utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_forw"; auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName); auto paramTypes = ComputeParamTypes(args); auto returnType = ComputeReturnType(); - const auto* sourceFnType = dyn_cast(m_Function->getType()); + const auto* sourceFnType = dyn_cast(m_DiffReq->getType()); auto fnType = m_Context.getFunctionType(returnType, paramTypes, sourceFnType->getExtProtoInfo()); @@ -42,11 +43,11 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - m_Sema.CurContext = const_cast(m_Function->getDeclContext()); + m_Sema.CurContext = const_cast(m_DiffReq->getDeclContext()); - SourceLocation validLoc{m_Function->getLocation()}; + SourceLocation validLoc{m_DiffReq->getLocation()}; DeclWithContext fnBuildRes = m_Builder.cloneFunction( - m_Function, *this, m_Sema.CurContext, validLoc, fnDNI, fnType); + m_DiffReq.Function, *this, m_Sema.CurContext, validLoc, fnDNI, fnType); m_Derivative = fnBuildRes.first; beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | @@ -65,7 +66,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, beginBlock(); beginBlock(direction::reverse); - StmtDiff bodyDiff = Visit(m_Function->getBody()); + StmtDiff bodyDiff = Visit(m_DiffReq->getBody()); Stmt* forward = bodyDiff.getStmt(); for (Stmt* S : ReverseModeVisitor::m_Globals) @@ -105,14 +106,14 @@ ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType, llvm::SmallVector ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { llvm::SmallVector paramTypes; - paramTypes.reserve(m_Function->getNumParams() * 2); - for (auto* PVD : m_Function->parameters()) + paramTypes.reserve(m_DiffReq->getNumParams() * 2); + for (auto* PVD : m_DiffReq->parameters()) paramTypes.push_back(PVD->getType()); QualType effectiveReturnType = - m_Function->getReturnType().getNonReferenceType(); + m_DiffReq->getReturnType().getNonReferenceType(); - if (const auto* MD = dyn_cast(m_Function)) { + if (const auto* MD = dyn_cast(m_DiffReq.Function)) { const CXXRecordDecl* RD = MD->getParent(); if (MD->isInstance() && !RD->isLambda()) { QualType thisType = MD->getThisType(); @@ -121,7 +122,7 @@ ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { } } - for (auto* PVD : m_Function->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) { const auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); if (it != std::end(diffParams)) { @@ -135,7 +136,7 @@ ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { clang::QualType ReverseModeForwPassVisitor::ComputeReturnType() { auto* valAndAdjointTempDecl = LookupTemplateDeclInCladNamespace("ValueAndAdjoint"); - auto RT = m_Function->getReturnType(); + auto RT = m_DiffReq->getReturnType(); auto T = InstantiateTemplate(valAndAdjointTempDecl, {RT, RT}); return T; } @@ -144,13 +145,13 @@ llvm::SmallVector ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { llvm::SmallVector params; llvm::SmallVector paramDerivatives; - params.reserve(m_Function->getNumParams() + diffParams.size()); + params.reserve(m_DiffReq->getNumParams() + diffParams.size()); const auto* derivativeFnType = cast(m_Derivative->getType()); - std::size_t dParamTypesIdx = m_Function->getNumParams(); + std::size_t dParamTypesIdx = m_DiffReq->getNumParams(); - if (const auto* MD = dyn_cast(m_Function)) { + if (const auto* MD = dyn_cast(m_DiffReq.Function)) { const CXXRecordDecl* RD = MD->getParent(); if (MD->isInstance() && !RD->isLambda()) { auto* thisDerivativePVD = utils::BuildParmVarDecl( @@ -168,7 +169,7 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { ++dParamTypesIdx; } } - for (auto* PVD : m_Function->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) { // FIXME: Call expression may contain default arguments that we are now // removing. This may cause issues. auto* newPVD = utils::BuildParmVarDecl( @@ -195,7 +196,7 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), /*AddToContext=*/false); m_Variables[*it] = - BuildOp(UO_Deref, BuildDeclRef(dPVD), m_Function->getLocation()); + BuildOp(UO_Deref, BuildDeclRef(dPVD), m_DiffReq->getLocation()); } } params.insert(params.end(), paramDerivatives.begin(), paramDerivatives.end()); @@ -256,7 +257,7 @@ ReverseModeForwPassVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { // and we should return it. Expr* ResultRef = nullptr; if (opCode == UnaryOperatorKind::UO_Deref) { - if (const auto* MD = dyn_cast(m_Function)) { + if (const auto* MD = dyn_cast(m_DiffReq.Function)) { if (MD->isInstance()) { diff = Visit(UnOp->getSubExpr()); Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr()); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index c1bbb5a2d..7a1f2e3dc 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -78,7 +78,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildDeclRef(GlobalStoreImpl(TapeType, prefix, getZeroInit(TapeType))); auto* VD = cast(cast(TapeRef)->getDecl()); // Add fake location, since Clang AST does assert(Loc.isValid()) somewhere. - VD->setLocation(m_Function->getLocation()); + VD->setLocation(m_DiffReq->getLocation()); CXXScopeSpec CSS; CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); auto* PopDRE = m_Sema @@ -99,8 +99,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return CladTapeResult{*this, PushExpr, PopExpr, TapeRef}; } - ReverseModeVisitor::ReverseModeVisitor(DerivativeBuilder& builder) - : VisitorBase(builder), m_Result(nullptr) {} + ReverseModeVisitor::ReverseModeVisitor(DerivativeBuilder& builder, + const DiffRequest& request) + : VisitorBase(builder, request), m_Result(nullptr) {} ReverseModeVisitor::~ReverseModeVisitor() { if (m_ExternalSource) { @@ -121,11 +122,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // requested. // FIXME: Here we are assuming all function parameters are of differentiable // type. Ideally, we should not make any such assumption. - std::size_t totalDerivedParamsSize = m_Function->getNumParams() * 2; - std::size_t numOfDerivativeParams = m_Function->getNumParams(); + std::size_t totalDerivedParamsSize = m_DiffReq->getNumParams() * 2; + std::size_t numOfDerivativeParams = m_DiffReq->getNumParams(); // Account for the this pointer. - if (isa(m_Function) && !utils::IsStaticMethod(m_Function)) + if (isa(m_DiffReq.Function) && + !utils::IsStaticMethod(m_DiffReq.Function)) ++numOfDerivativeParams; // All output parameters will be of type `void*`. These // parameters will be casted to correct type before the call to the actual @@ -138,7 +140,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector paramTypes; // Add types for representing original function parameters. - for (auto* PVD : m_Function->parameters()) + for (auto* PVD : m_DiffReq->parameters()) paramTypes.push_back(PVD->getType()); // Add types for representing parameter derivatives. // FIXME: We are assuming all function parameters are differentiable. We @@ -147,17 +149,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, paramTypes.push_back(outputParamType); auto gradFuncOverloadEPI = - dyn_cast(m_Function->getType())->getExtProtoInfo(); + dyn_cast(m_DiffReq->getType())->getExtProtoInfo(); QualType gradientFunctionOverloadType = m_Context.getFunctionType(m_Context.VoidTy, paramTypes, // Cast to function pointer. gradFuncOverloadEPI); - auto* DC = const_cast(m_Function->getDeclContext()); + auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext gradientOverloadFDWC = - m_Builder.cloneFunction(m_Function, *this, DC, noLoc, gradientNameInfo, - gradientFunctionOverloadType); + m_Builder.cloneFunction(m_DiffReq.Function, *this, DC, noLoc, + gradientNameInfo, gradientFunctionOverloadType); FunctionDecl* gradientOverloadFD = gradientOverloadFDWC.first; beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | @@ -171,7 +173,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, overloadParams.reserve(totalDerivedParamsSize); callArgs.reserve(gradientParams.size()); - for (auto* PVD : m_Function->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) { auto* VD = utils::BuildParmVarDecl( m_Sema, gradientOverloadFD, PVD->getIdentifier(), PVD->getType(), PVD->getStorageClass(), /*defArg=*/nullptr, PVD->getTypeSourceInfo()); @@ -182,7 +184,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (std::size_t i = 0; i < numOfDerivativeParams; ++i) { IdentifierInfo* II = nullptr; StorageClass SC = StorageClass::SC_None; - std::size_t effectiveGradientIndex = m_Function->getNumParams() + i; + std::size_t effectiveGradientIndex = m_DiffReq->getNumParams() + i; // `effectiveGradientIndex < gradientParams.size()` implies that this // parameter represents an actual derivative of one of the function // original parameters. @@ -213,7 +215,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Build derivatives to be used in the call to the actual derived function. // These are initialised by effectively casting the derivative parameters of // overloaded derived function to the correct type. - for (std::size_t i = m_Function->getNumParams(); i < gradientParams.size(); + for (std::size_t i = m_DiffReq->getNumParams(); i < gradientParams.size(); ++i) { auto* overloadParam = overloadParams[i]; auto* gradientParam = gradientParams[i]; @@ -252,7 +254,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerive(); silenceDiags = !request.VerboseDiags; - m_Function = FD; + assert(m_DiffReq == request); // reverse mode plugins may have request mode other than // `DiffMode::reverse`, but they still need the `DiffMode::reverse` mode @@ -262,7 +264,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Mode = DiffMode::jacobian; m_Pullback = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1); - assert(m_Function && "Must not be null."); + assert(m_DiffReq.Function && "Must not be null."); DiffParams args{}; DiffInputVarsInfo DVI; @@ -282,8 +284,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // derived function if (request.Mode == DiffMode::jacobian) { isVectorValued = true; - unsigned lastArgN = m_Function->getNumParams() - 1; - outputArrayStr = m_Function->getParamDecl(lastArgN)->getNameAsString(); + unsigned lastArgN = m_DiffReq->getNumParams() - 1; + outputArrayStr = m_DiffReq->getParamDecl(lastArgN)->getNameAsString(); } // Check if DiffRequest asks for TBR analysis to be enabled @@ -345,7 +347,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, shouldCreateOverload = false; const auto* originalFnType = - dyn_cast(m_Function->getType()); + dyn_cast(m_DiffReq->getType()); // For a function f of type R(A1, A2, ..., An), // the type of the gradient function is void(A1, A2, ..., An, R*, R*, ..., // R*) . the type of the jacobian function is void(A1, A2, ..., An, R*, R*) @@ -361,10 +363,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); - auto* DC = const_cast(m_Function->getDeclContext()); + auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction( - m_Function, *this, DC, noLoc, name, gradientFunctionType); + m_DiffReq.Function, *this, DC, noLoc, name, gradientFunctionType); FunctionDecl* gradientFD = result.first; m_Derivative = gradientFD; @@ -475,9 +477,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerive(); silenceDiags = !request.VerboseDiags; - m_Function = FD; + const_cast(m_DiffReq) = request; m_Mode = DiffMode::experimental_pullback; - assert(m_Function && "Must not be null."); + assert(m_DiffReq.Function && "Must not be null."); DiffParams args{}; if (!request.DVI.empty()) @@ -496,12 +498,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActAfterParsingDiffArgs(request, args); auto derivativeName = - utils::ComputeEffectiveFnName(m_Function) + "_pullback"; + utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback"; auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); auto paramTypes = ComputeParamTypes(args); const auto* originalFnType = - dyn_cast(m_Function->getType()); + dyn_cast(m_DiffReq->getType()); if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); @@ -512,11 +514,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore saveContext(m_Sema.CurContext); llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); - m_Sema.CurContext = const_cast(m_Function->getDeclContext()); + m_Sema.CurContext = const_cast(m_DiffReq->getDeclContext()); - SourceLocation validLoc{m_Function->getLocation()}; - DeclWithContext fnBuildRes = m_Builder.cloneFunction( - m_Function, *this, m_Sema.CurContext, validLoc, DNI, pullbackFnType); + SourceLocation validLoc{m_DiffReq->getLocation()}; + DeclWithContext fnBuildRes = + m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext, + validLoc, DNI, pullbackFnType); m_Derivative = fnBuildRes.first; if (m_ExternalSource) @@ -548,7 +551,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerivedFnBody(request); - StmtDiff bodyDiff = Visit(m_Function->getBody()); + StmtDiff bodyDiff = Visit(m_DiffReq->getBody()); Stmt* forward = bodyDiff.getStmt(); Stmt* reverse = bodyDiff.getStmt_dx(); @@ -586,7 +589,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, void ReverseModeVisitor::DifferentiateWithClad() { TBRAnalyzer analyzer(m_Context); if (enableTBR) { - analyzer.Analyze(m_Function); + analyzer.Analyze(m_DiffReq.Function); m_ToBeRecorded = analyzer.getResult(); } @@ -594,13 +597,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // create derived variables for parameters which are not part of // independent variables (args). - for (std::size_t i = 0; i < m_Function->getNumParams(); ++i) { + for (std::size_t i = 0; i < m_DiffReq->getNumParams(); ++i) { ParmVarDecl* param = paramsRef[i]; // derived variables are already created for independent variables. if (m_Variables.count(param)) continue; // in vector mode last non diff parameter is output parameter. - if (isVectorValued && i == m_Function->getNumParams() - 1) + if (isVectorValued && i == m_DiffReq->getNumParams() - 1) continue; auto VDDerivedType = param->getType(); // We cannot initialize derived variable for pointer types because @@ -615,7 +618,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } // Start the visitation process which outputs the statements in the // current block. - StmtDiff BodyDiff = Visit(m_Function->getBody()); + StmtDiff BodyDiff = Visit(m_DiffReq->getBody()); Stmt* Forward = BodyDiff.getStmt(); Stmt* Reverse = BodyDiff.getStmt_dx(); // Create the body of the function. @@ -647,11 +650,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } void ReverseModeVisitor::DifferentiateWithEnzyme() { - unsigned numParams = m_Function->getNumParams(); - auto origParams = m_Function->parameters(); + unsigned numParams = m_DiffReq->getNumParams(); + auto origParams = m_DiffReq->parameters(); llvm::ArrayRef paramsRef = m_Derivative->parameters(); const auto* originalFnType = - dyn_cast(m_Function->getType()); + dyn_cast(m_DiffReq->getType()); // Prepare Arguments and Parameters to enzyme_autodiff llvm::SmallVector enzymeArgs; @@ -660,11 +663,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector enzymeRealParamsDerived; // First add the function itself as a parameter/argument - enzymeArgs.push_back(BuildDeclRef(const_cast(m_Function))); - auto* fdDeclContext = - const_cast(m_Function->getDeclContext()); + enzymeArgs.push_back( + BuildDeclRef(const_cast(m_DiffReq.Function))); + auto* fdDeclContext = const_cast(m_DiffReq->getDeclContext()); enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( - fdDeclContext, noLoc, m_Function->getType())); + fdDeclContext, noLoc, m_DiffReq->getType())); // Add rest of the parameters/arguments for (unsigned i = 0; i < numParams; i++) { @@ -709,7 +712,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Prepare Function call std::string enzymeCallName = - "__enzyme_autodiff_" + m_Function->getNameAsString(); + "__enzyme_autodiff_" + m_DiffReq->getNameAsString(); IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName); DeclarationName nameEnzyme(IIEnzyme); QualType enzymeFunctionType = @@ -717,7 +720,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, originalFnType->getExtProtoInfo()); FunctionDecl* enzymeCallFD = FunctionDecl::Create( m_Context, fdDeclContext, noLoc, noLoc, nameEnzyme, enzymeFunctionType, - m_Function->getTypeSourceInfo(), SC_Extern); + m_DiffReq->getTypeSourceInfo(), SC_Extern); enzymeCallFD->setParams(enzymeParams); Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); @@ -1589,7 +1592,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, gradArgExpr = argDerivative; else gradArgExpr = - BuildOp(UO_AddrOf, argDerivative, m_Function->getLocation()); + BuildOp(UO_AddrOf, argDerivative, m_DiffReq->getLocation()); DerivedCallOutputArgs.push_back(gradArgExpr); idx++; } @@ -1631,7 +1634,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Derivative was not found, check if it is a recursive call if (!OverloadedDerivedFn) { - if (FD == m_Function && m_Mode == DiffMode::experimental_pullback) { + if (FD == m_DiffReq.Function && + m_Mode == DiffMode::experimental_pullback) { // Recursive call. Expr* selfRef = m_Sema @@ -1890,7 +1894,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Context.IntTy, m_Context, printErrorInf)); // Build the tape push expressions. - VD->setLocation(m_Function->getLocation()); + VD->setLocation(m_DiffReq->getLocation()); for (unsigned i = 0, e = numArgs; i < e; i++) { Expr* gradRef = BuildDeclRef(VD); Expr* idx = @@ -1987,7 +1991,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* diff_dx = diff.getExpr_dx(); bool specialDThisCase = false; Expr* derivedE = nullptr; - if (const auto* MD = dyn_cast(m_Function)) { + if (const auto* MD = dyn_cast(m_DiffReq.Function)) { if (MD->isInstance() && !diff_dx->getType()->isPointerType()) specialDThisCase = true; // _d_this is already dereferenced. } @@ -2474,7 +2478,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Computation of hessian requires this code to be correctly // differentiated. bool specialThisDiffCase = false; - if (const auto* MD = dyn_cast(m_Function)) { + if (const auto* MD = dyn_cast(m_DiffReq.Function)) { if (VDDerivedType->isPointerType() && MD->isInstance()) { specialThisDiffCase = true; } @@ -3716,14 +3720,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector ReverseModeVisitor::ComputeParamTypes(const DiffParams& diffParams) { llvm::SmallVector paramTypes; - paramTypes.reserve(m_Function->getNumParams() * 2); - for (auto* PVD : m_Function->parameters()) + paramTypes.reserve(m_DiffReq->getNumParams() * 2); + for (auto* PVD : m_DiffReq->parameters()) paramTypes.push_back(PVD->getType()); // TODO: Add DiffMode::experimental_pullback support here as well. if (m_Mode == DiffMode::reverse || m_Mode == DiffMode::experimental_pullback) { QualType effectiveReturnType = - m_Function->getReturnType().getNonReferenceType(); + m_DiffReq->getReturnType().getNonReferenceType(); if (m_Mode == DiffMode::experimental_pullback) { // FIXME: Generally, we use the function's return type as the argument's // derivative type. We cannot follow this strategy for `void` function @@ -3740,7 +3744,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, paramTypes.push_back(effectiveReturnType); } - if (const auto* MD = dyn_cast(m_Function)) { + if (const auto* MD = dyn_cast(m_DiffReq.Function)) { const CXXRecordDecl* RD = MD->getParent(); if (MD->isInstance() && !RD->isLambda()) { QualType thisType = MD->getThisType(); @@ -3749,16 +3753,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - for (auto* PVD : m_Function->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) { const auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); if (it != std::end(diffParams)) paramTypes.push_back(ComputeParamType(PVD->getType())); } } else if (m_Mode == DiffMode::jacobian) { - std::size_t lastArgIdx = m_Function->getNumParams() - 1; + std::size_t lastArgIdx = m_DiffReq->getNumParams() - 1; QualType derivativeParamType = - m_Function->getParamDecl(lastArgIdx)->getType(); + m_DiffReq->getParamDecl(lastArgIdx)->getType(); paramTypes.push_back(derivativeParamType); } return paramTypes; @@ -3768,17 +3772,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ReverseModeVisitor::BuildParams(DiffParams& diffParams) { llvm::SmallVector params; llvm::SmallVector paramDerivatives; - params.reserve(m_Function->getNumParams() + diffParams.size()); + params.reserve(m_DiffReq->getNumParams() + diffParams.size()); const auto* derivativeFnType = cast(m_Derivative->getType()); - std::size_t dParamTypesIdx = m_Function->getNumParams(); + std::size_t dParamTypesIdx = m_DiffReq->getNumParams(); if (m_Mode == DiffMode::experimental_pullback && - !m_Function->getReturnType()->isVoidType()) { + !m_DiffReq->getReturnType()->isVoidType()) { ++dParamTypesIdx; } - if (const auto* MD = dyn_cast(m_Function)) { + if (const auto* MD = dyn_cast(m_DiffReq.Function)) { const CXXRecordDecl* RD = MD->getParent(); if (!isVectorValued && MD->isInstance() && !RD->isLambda()) { auto* thisDerivativePVD = utils::BuildParmVarDecl( @@ -3799,7 +3803,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - for (auto* PVD : m_Function->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) { auto* newPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); @@ -3830,8 +3834,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Variables[*it] = (Expr*)BuildDeclRef(dPVD); } else { QualType valueType = dPVD->getType()->getPointeeType(); - m_Variables[*it] = BuildOp(UO_Deref, BuildDeclRef(dPVD), - m_Function->getLocation()); + m_Variables[*it] = + BuildOp(UO_Deref, BuildDeclRef(dPVD), m_DiffReq->getLocation()); // Add additional paranthesis if derivative is of record type // because `*derivative.someField` will be incorrectly evaluated if // the derived function is compiled standalone. @@ -3844,10 +3848,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } if (m_Mode == DiffMode::experimental_pullback && - !m_Function->getReturnType()->isVoidType()) { + !m_DiffReq->getReturnType()->isVoidType()) { IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y"); QualType pullbackType = - derivativeFnType->getParamType(m_Function->getNumParams()); + derivativeFnType->getParamType(m_DiffReq->getNumParams()); ParmVarDecl* pullbackPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, pullbackParamII, pullbackType); paramDerivatives.insert(paramDerivatives.begin(), pullbackPVD); diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index 2c8cabe21..5e08a6b84 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -10,8 +10,9 @@ using namespace clang; namespace clad { -VectorForwardModeVisitor::VectorForwardModeVisitor(DerivativeBuilder& builder) - : BaseForwardModeVisitor(builder), m_IndVarCountExpr(nullptr) {} +VectorForwardModeVisitor::VectorForwardModeVisitor(DerivativeBuilder& builder, + const DiffRequest& request) + : BaseForwardModeVisitor(builder, request), m_IndVarCountExpr(nullptr) {} VectorForwardModeVisitor::~VectorForwardModeVisitor() {} @@ -55,7 +56,7 @@ void VectorForwardModeVisitor::SetIndependentVarsExpr(Expr* IndVarCountExpr) { DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, const DiffRequest& request) { - m_Function = FD; + assert(m_DiffReq == request); m_Mode = DiffMode::vector_forward_mode; DiffParams args{}; @@ -74,16 +75,15 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, } } IdentifierInfo* II = &m_Context.Idents.get(derivedFnName); - SourceLocation loc{m_Function->getLocation()}; + SourceLocation loc{m_DiffReq->getLocation()}; DeclarationNameInfo name(II, loc); // Generate the function type for the derivative. llvm::SmallVector paramTypes; - paramTypes.reserve(m_Function->getNumParams() + args.size()); - for (auto PVD : m_Function->parameters()) { + paramTypes.reserve(m_DiffReq->getNumParams() + args.size()); + for (auto PVD : m_DiffReq->parameters()) paramTypes.push_back(PVD->getType()); - } - for (auto PVD : m_Function->parameters()) { + for (auto PVD : m_DiffReq->parameters()) { auto it = std::find(std::begin(args), std::end(args), PVD); if (it == std::end(args)) continue; // This parameter is not in the diff list. @@ -105,13 +105,13 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, m_Context.VoidTy, llvm::ArrayRef(paramTypes.data(), paramTypes.size()), // Cast to function pointer. - dyn_cast(m_Function->getType())->getExtProtoInfo()); + dyn_cast(m_DiffReq->getType())->getExtProtoInfo()); // Create the function declaration for the derivative. - DeclContext* DC = const_cast(m_Function->getDeclContext()); + DeclContext* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction( - m_Function, *this, DC, loc, name, vectorDiffFunctionType); + m_DiffReq.Function, *this, DC, loc, name, vectorDiffFunctionType); FunctionDecl* vectorDiffFD = result.first; m_Derivative = vectorDiffFD; @@ -153,15 +153,15 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, // Current Index of independent variable in the param list of the function. size_t independentVarIndex = 0; - for (size_t i = 0; i < m_Function->getNumParams(); ++i) { + for (size_t i = 0; i < m_DiffReq->getNumParams(); ++i) { bool is_array = - utils::isArrayOrPointerType(m_Function->getParamDecl(i)->getType()); + utils::isArrayOrPointerType(m_DiffReq->getParamDecl(i)->getType()); auto param = params[i]; QualType dParamType = clad::utils::GetValueType(param->getType()); Expr* dVectorParam = nullptr; if (m_IndependentVars.size() > independentVarIndex && - m_IndependentVars[independentVarIndex] == m_Function->getParamDecl(i)) { + m_IndependentVars[independentVarIndex] == m_DiffReq->getParamDecl(i)) { // Current offset for independent variable. Expr* offsetExpr = arrayIndVarCountExpr; @@ -175,8 +175,8 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, if (is_array) { // Get size of the array. - Expr* getSize = BuildArrayRefSizeExpr( - m_ParamVariables[m_Function->getParamDecl(i)]); + Expr* getSize = + BuildArrayRefSizeExpr(m_ParamVariables[m_DiffReq->getParamDecl(i)]); // Create an identity matrix for the parameter, // with number of rows equal to the size of the array, @@ -259,36 +259,35 @@ clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() { // Calculate the total number of parameters that would be required for // automatic differentiation in the derived function if all args are // requested. - std::size_t totalDerivedParamsSize = m_Function->getNumParams() * 2; - std::size_t numDerivativeParams = m_Function->getNumParams(); + std::size_t totalDerivedParamsSize = m_DiffReq->getNumParams() * 2; + std::size_t numDerivativeParams = m_DiffReq->getNumParams(); // Generate the function type for the derivative. llvm::SmallVector paramTypes; paramTypes.reserve(totalDerivedParamsSize); - for (auto* PVD : m_Function->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) paramTypes.push_back(PVD->getType()); - } // instantiate output parameter type as void* QualType outputParamType = GetCladArrayRefOfType(m_Context.VoidTy); // Push param types for derived params. - for (std::size_t i = 0; i < m_Function->getNumParams(); ++i) + for (std::size_t i = 0; i < m_DiffReq->getNumParams(); ++i) paramTypes.push_back(outputParamType); auto vectorModeFuncOverloadEPI = - dyn_cast(m_Function->getType())->getExtProtoInfo(); + dyn_cast(m_DiffReq->getType())->getExtProtoInfo(); QualType vectorModeFuncOverloadType = m_Context.getFunctionType( m_Context.VoidTy, llvm::ArrayRef(paramTypes.data(), paramTypes.size()), vectorModeFuncOverloadEPI); // Create the function declaration for the derivative. - auto* DC = const_cast(m_Function->getDeclContext()); + auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = - m_Builder.cloneFunction(m_Function, *this, DC, noLoc, vectorModeNameInfo, - vectorModeFuncOverloadType); + m_Builder.cloneFunction(m_DiffReq.Function, *this, DC, noLoc, + vectorModeNameInfo, vectorModeFuncOverloadType); FunctionDecl* vectorModeOverloadFD = result.first; // Function declaration scope @@ -304,7 +303,7 @@ clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() { // vectormode function. callArgs.reserve(vectorModeParams.size()); - for (auto* PVD : m_Function->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) { auto* VD = utils::BuildParmVarDecl( m_Sema, vectorModeOverloadFD, PVD->getIdentifier(), PVD->getType(), PVD->getStorageClass(), /*defArg=*/nullptr, PVD->getTypeSourceInfo()); @@ -314,7 +313,7 @@ clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() { for (std::size_t i = 0; i < numDerivativeParams; ++i) { ParmVarDecl* PVD = nullptr; - std::size_t effectiveIndex = m_Function->getNumParams() + i; + std::size_t effectiveIndex = m_DiffReq->getNumParams() + i; if (effectiveIndex < vectorModeParams.size()) { // This parameter represents an actual derivative parameter. @@ -349,7 +348,7 @@ clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() { // Build derivatives to be used in the call to the actual derived function. // These are initialised by effectively casting the derivative parameters of // overloaded derived function to the correct type. - for (std::size_t i = m_Function->getNumParams(); i < vectorModeParams.size(); + for (std::size_t i = m_DiffReq->getNumParams(); i < vectorModeParams.size(); ++i) { auto* overloadParam = overloadParams[i]; auto* vectorModeParam = vectorModeParams[i]; @@ -393,15 +392,15 @@ clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() { llvm::SmallVector VectorForwardModeVisitor::BuildVectorModeParams(DiffParams& diffParams) { llvm::SmallVector params, paramDerivatives; - params.reserve(m_Function->getNumParams() + diffParams.size()); + params.reserve(m_DiffReq->getNumParams() + diffParams.size()); auto derivativeFnType = cast(m_Derivative->getType()); - std::size_t dParamTypesIdx = m_Function->getNumParams(); + std::size_t dParamTypesIdx = m_DiffReq->getNumParams(); // Count the number of non-array independent variables requested for // differentiation. size_t nonArrayIndVarCount = 0; - for (auto PVD : m_Function->parameters()) { + for (auto PVD : m_DiffReq->parameters()) { auto newPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); @@ -499,7 +498,7 @@ StmtDiff VectorForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { Expr* derivedRetValE = retValDiff.getExpr_dx(); // If we are in vector mode, we need to wrap the return value in a // vector. - SourceLocation loc{m_Function->getLocation()}; + SourceLocation loc{m_DiffReq->getLocation()}; llvm::SmallVector args = {m_IndVarCountExpr, derivedRetValE}; QualType cladArrayType = GetCladArrayOfType(utils::GetValueType(retType)); TypeSourceInfo* TSI = m_Context.getTrivialTypeSourceInfo(cladArrayType, loc); @@ -597,7 +596,7 @@ VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { // clad::array _d_vector_y(2, 1); // this means that we have to initialize the derivative vector of // size 2 with all elements equal to 1. - SourceLocation loc{m_Function->getLocation()}; + SourceLocation loc{m_DiffReq->getLocation()}; llvm::SmallVector args = {m_IndVarCountExpr, initDiff.getExpr_dx()}; QualType cladArrayType = GetCladArrayOfType(utils::GetValueType(VD->getType())); diff --git a/lib/Differentiator/VectorPushForwardModeVisitor.cpp b/lib/Differentiator/VectorPushForwardModeVisitor.cpp index bfd8da43e..49d31374b 100644 --- a/lib/Differentiator/VectorPushForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorPushForwardModeVisitor.cpp @@ -9,8 +9,8 @@ using namespace clang; namespace clad { VectorPushForwardModeVisitor::VectorPushForwardModeVisitor( - DerivativeBuilder& builder) - : VectorForwardModeVisitor(builder) {} + DerivativeBuilder& builder, const DiffRequest& request) + : VectorForwardModeVisitor(builder, request) {} VectorPushForwardModeVisitor::~VectorPushForwardModeVisitor() = default; diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 6fd70910d..ec0a8883c 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -115,10 +115,9 @@ namespace clad { VarDecl::InitializationStyle IS) { // add namespace specifier in variable declaration if needed. Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type); - auto VD = - VarDecl::Create(m_Context, m_Sema.CurContext, m_Function->getLocation(), - m_Function->getLocation(), Identifier, Type, TSI, - SC_None); + auto VD = VarDecl::Create( + m_Context, m_Sema.CurContext, m_DiffReq->getLocation(), + m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None); if (Init) { m_Sema.AddInitializerToDecl(VD, Init, DirectInit); @@ -133,7 +132,7 @@ namespace clad { } void VisitorBase::updateReferencesOf(Stmt* InSubtree) { - utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function, + utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_DiffReq.Function, m_DeclReplacements); up.TraverseStmt(InSubtree); } @@ -356,7 +355,7 @@ namespace clad { QualType VisitorBase::CloneType(const QualType QT) { auto clonedType = m_Builder.m_NodeCloner->CloneType(QT); - utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function, + utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_DiffReq.Function, m_DeclReplacements); up.updateType(clonedType); return clonedType; @@ -531,7 +530,7 @@ namespace clad { MutableArrayRef ArgExprs, SourceLocation Loc /*=noLoc*/) { if (Loc.isInvalid()) - Loc = m_Function->getLocation(); + Loc = m_DiffReq->getLocation(); UnqualifiedId Member; Member.setIdentifier(&m_Context.Idents.get(MemberFunctionName), Loc); CXXScopeSpec SS; @@ -566,7 +565,7 @@ namespace clad { Expr* thisExpr = clad_compat::Sema_BuildCXXThisExpr(m_Sema, FD); bool isArrow = true; if (Loc.isInvalid()) - Loc = m_Function->getLocation(); + Loc = m_DiffReq->getLocation(); // C++ does not support perfect forwarding of `*this` object inside // a member function. @@ -621,7 +620,7 @@ namespace clad { /*Fn=*/exprFunc, /*LParenLoc=*/noLoc, /*ArgExprs=*/llvm::MutableArrayRef(argExprs), - /*RParenLoc=*/m_Function->getLocation()) + /*RParenLoc=*/m_DiffReq->getLocation()) .get(); } return call; diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 415ffbeb0..5cf698a36 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -179,7 +179,8 @@ namespace clad { it != ie; ++it) { auto estimationPlugin = it->instantiate(); m_DerivativeBuilder->AddErrorEstimationModel( - estimationPlugin->InstantiateCustomModel(*m_DerivativeBuilder)); + estimationPlugin->InstantiateCustomModel(*m_DerivativeBuilder, + request)); } }