Skip to content

Commit

Permalink
Pass the DiffRequest down to the visitors. NFC.
Browse files Browse the repository at this point in the history
The intent of this patch is to centralize the information about the diff request
in one place. This will help upcoming patches to reduce the state of the
visitors and refactor some of the implementation in a common place.

That will unblock work on vgvassilev#721.
  • Loading branch information
vgvassilev committed Jun 11, 2024
1 parent 340ad91 commit 5595016
Show file tree
Hide file tree
Showing 23 changed files with 254 additions and 238 deletions.
5 changes: 3 additions & 2 deletions demos/ErrorEstimation/CustomModel/CustomModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
// FPErrorEstimationModel class.
class CustomModel : public clad::FPErrorEstimationModel {
public:
CustomModel(clad::DerivativeBuilder& builder)
: FPErrorEstimationModel(builder) {}
CustomModel(clad::DerivativeBuilder& builder,
const clad::DiffRequest& request)
: FPErrorEstimationModel(builder, request) {}
/// Return an expression of the following kind:
/// dfdx * delta_x
clang::Expr* AssignError(clad::StmtDiff refExpr,
Expand Down
4 changes: 2 additions & 2 deletions demos/ErrorEstimation/PrintModel/PrintModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
// FPErrorEstimationModel class.
class PrintModel : public clad::FPErrorEstimationModel {
public:
PrintModel(clad::DerivativeBuilder& builder)
: FPErrorEstimationModel(builder) {}
PrintModel(clad::DerivativeBuilder& builder, const clad::DiffRequest& request)
: FPErrorEstimationModel(builder, request) {}
// Return an expression of the following kind:
// dfdx * delta_x
clang::Expr* AssignError(clad::StmtDiff refExpr, const std::string& name) override;
Expand Down
6 changes: 3 additions & 3 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class BaseForwardModeVisitor
unsigned m_ArgIndex = ~0;

public:
BaseForwardModeVisitor(DerivativeBuilder& builder);
BaseForwardModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
virtual ~BaseForwardModeVisitor();

///\brief Produces the first derivative of a given function.
Expand All @@ -41,8 +42,7 @@ class BaseForwardModeVisitor
const DiffRequest& request);

/// Returns the return type for the pushforward function of the function
/// `m_Function`.
/// \note `m_Function` field should be set before using this function.
/// `m_DiffReq->Function`.
clang::QualType ComputePushforwardFnReturnType();

virtual void ExecuteInsidePushforwardFunctionBlock();
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ struct DiffRequest {
DeclarationOnly == other.DeclarationOnly;
}

const clang::FunctionDecl* operator->() const { return Function; }

// String operator for printing the node.
operator std::string() const {
std::string res = BaseFunctionName + "__order_" +
Expand Down
20 changes: 14 additions & 6 deletions include/clad/Differentiator/EstimationModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ namespace clad {
std::unordered_map<const clang::VarDecl*, clang::Expr*> m_EstimateVar;

public:
FPErrorEstimationModel(DerivativeBuilder& builder) : VisitorBase(builder) {}
// FIXME: Add a proper parameter for the DiffRequest here.
FPErrorEstimationModel(DerivativeBuilder& builder,
const DiffRequest& request)
: VisitorBase(builder, request) {}
virtual ~FPErrorEstimationModel();

/// Clear the variable estimate map so that we can start afresh.
Expand Down Expand Up @@ -83,10 +86,13 @@ namespace clad {
/// custom model.
/// \param[in] builder A build instance to pass to the custom model
/// constructor.
/// \param[in] request The differentiation configuration passed to the
/// custom model
/// \returns A reference to the custom class wrapped in the
/// FPErrorEstimationModel class.
virtual std::unique_ptr<FPErrorEstimationModel>
InstantiateCustomModel(DerivativeBuilder& builder) = 0;
InstantiateCustomModel(DerivativeBuilder& builder,
const DiffRequest& request) = 0;
};

/// A class used to register custom plugins.
Expand All @@ -99,16 +105,18 @@ namespace clad {
///
/// \param[in] builder The current instance of derivative builder.
std::unique_ptr<FPErrorEstimationModel>
InstantiateCustomModel(DerivativeBuilder& builder) override {
return std::unique_ptr<FPErrorEstimationModel>(new CustomClass(builder));
InstantiateCustomModel(DerivativeBuilder& builder,
const DiffRequest& request) override {
return std::unique_ptr<FPErrorEstimationModel>(
new CustomClass(builder, request));
}
};

/// Example class for taylor series approximation based error estimation.
class TaylorApprox : public FPErrorEstimationModel {
public:
TaylorApprox(DerivativeBuilder& builder)
: FPErrorEstimationModel(builder) {}
TaylorApprox(DerivativeBuilder& builder, const DiffRequest& request)
: FPErrorEstimationModel(builder, request) {}
// Return an expression of the following kind:
// std::abs(dfdx * delta_x * Em)
clang::Expr* AssignError(StmtDiff refExpr,
Expand Down
4 changes: 2 additions & 2 deletions include/clad/Differentiator/HessianModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace clad {
size_t TotalIndependentArgsSize, std::string hessianFuncName);

public:
HessianModeVisitor(DerivativeBuilder& builder);
HessianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
~HessianModeVisitor();

///\brief Produces the hessian second derivative columns of a given
Expand All @@ -53,4 +53,4 @@ namespace clad {
};
} // end namespace clad

#endif // CLAD_HESSIAN_MODE_VISITOR_H
#endif // CLAD_HESSIAN_MODE_VISITOR_H
4 changes: 3 additions & 1 deletion include/clad/Differentiator/PushForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
#include "BaseForwardModeVisitor.h"

namespace clad {

/// A visitor for processing the function code in forward mode.
/// Used to compute derivatives by clad::differentiate.
class PushForwardModeVisitor : public BaseForwardModeVisitor {

public:
PushForwardModeVisitor(DerivativeBuilder& builder);
PushForwardModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
~PushForwardModeVisitor() override;

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
Expand Down
5 changes: 3 additions & 2 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
clang::QualType xType);

public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder);
ReverseModeForwPassVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

Expand All @@ -34,4 +35,4 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
};
} // namespace clad

#endif
#endif
10 changes: 3 additions & 7 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ namespace clad {
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);

public:
ReverseModeVisitor(DerivativeBuilder& builder);
ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
virtual ~ReverseModeVisitor();

///\brief Produces the gradient of a given function.
Expand Down Expand Up @@ -629,18 +629,14 @@ namespace clad {
/// Computes and returns the sequence of derived function parameter types.
///
/// Information about the original function and the differentiation mode
/// are taken from the data member variables. In particular, `m_Function`,
/// `m_Mode` data members should be correctly set before using this
/// function.
/// are taken from the data member variables.
llvm::SmallVector<clang::QualType, 8> ComputeParamTypes(const DiffParams& diffParams);

/// Builds and returns the sequence of derived function parameters.
///
/// Information about the original function, derived function, derived
/// function parameter types and the differentiation mode are implicitly
/// taken from the data member variables. In particular, `m_Function`,
/// `m_Mode` and `m_Derivative` should be correctly set before using this
/// function.
/// taken from the data member variables.
llvm::SmallVector<clang::ParmVarDecl*, 8>
BuildParams(DiffParams& diffParams);

Expand Down
7 changes: 3 additions & 4 deletions include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
clang::Expr* m_IndVarCountExpr;

public:
VectorForwardModeVisitor(DerivativeBuilder& builder);
VectorForwardModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
~VectorForwardModeVisitor();

///\brief Produces the first derivative of a given function with
Expand Down Expand Up @@ -53,9 +54,7 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
///
/// Information about the original function, derived function, derived
/// function parameter types and the differentiation mode are implicitly
/// taken from the data member variables. In particular, `m_Function`,
/// `m_Mode` and `m_Derivative` should be correctly set before using this
/// function.
/// taken from the data member variables.
llvm::SmallVector<clang::ParmVarDecl*, 8>
BuildVectorModeParams(DiffParams& diffParams);

Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/VectorPushForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace clad {
class VectorPushForwardModeVisitor : public VectorForwardModeVisitor {

public:
VectorPushForwardModeVisitor(DerivativeBuilder& builder);
VectorPushForwardModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
~VectorPushForwardModeVisitor() override;

void ExecuteInsidePushforwardFunctionBlock() override;
Expand Down
8 changes: 4 additions & 4 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ namespace clad {
/// A base class for all common functionality for visitors
class VisitorBase {
protected:
VisitorBase(DerivativeBuilder& builder)
VisitorBase(DerivativeBuilder& builder, const DiffRequest& request)
: m_Builder(builder), m_Sema(builder.m_Sema),
m_CladPlugin(builder.m_CladPlugin), m_Context(builder.m_Context),
m_DerivativeFnScope(nullptr), m_DerivativeInFlight(false),
m_Derivative(nullptr), m_Function(nullptr) {}
m_Derivative(nullptr), m_DiffReq(request) {}

using Stmts = llvm::SmallVector<clang::Stmt*, 16>;

Expand All @@ -117,8 +117,8 @@ namespace clad {
bool m_DerivativeInFlight;
/// The Derivative function that is being generated.
clang::FunctionDecl* m_Derivative;
/// The function that is currently differentiated.
const clang::FunctionDecl* m_Function;
/// The differentiation request that is being currently processed.
const DiffRequest& m_DiffReq;
DiffMode m_Mode;
/// Map used to keep track of variable declarations and match them
/// with their derivatives.
Expand Down
41 changes: 22 additions & 19 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
using namespace clang;

namespace clad {
BaseForwardModeVisitor::BaseForwardModeVisitor(DerivativeBuilder& builder)
: VisitorBase(builder) {}
BaseForwardModeVisitor::BaseForwardModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request)
: VisitorBase(builder, request) {}

BaseForwardModeVisitor::~BaseForwardModeVisitor() {}

Expand All @@ -60,8 +61,8 @@ bool IsRealNonReferenceType(QualType T) {
DerivativeAndOverload
BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
assert(m_DiffReq == request && "Can't pass two different requests!");
silenceDiags = !request.VerboseDiags;
m_Function = FD;
m_Functor = request.Functor;
m_Mode = DiffMode::forward;
assert(!m_DerivativeInFlight &&
Expand Down Expand Up @@ -144,7 +145,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
// then the specified independent argument is a member variable of the
// class defining the call operator.
// Thus, we need to find index of the member variable instead.
if (m_Function->param_empty() && m_Functor) {
if (m_DiffReq->param_empty() && m_Functor) {
m_ArgIndex =
std::distance(m_Functor->field_begin(),
std::find(m_Functor->field_begin(),
Expand All @@ -161,11 +162,11 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

IdentifierInfo* II = &m_Context.Idents.get(
request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix);
SourceLocation validLoc{m_Function->getLocation()};
SourceLocation validLoc{m_DiffReq->getLocation()};
DeclarationNameInfo name(II, validLoc);
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
DeclContext* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result =
m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType());
Expand Down Expand Up @@ -368,7 +369,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() {
assert(m_Mode == GetPushForwardMode());
QualType originalFnRT = m_Function->getReturnType();
QualType originalFnRT = m_DiffReq->getReturnType();
if (originalFnRT->isVoidType())
return m_Context.VoidTy;
TemplateDecl* valueAndPushforward =
Expand All @@ -382,7 +383,7 @@ clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() {
}

void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
Stmt* bodyDiff = Visit(m_Function->getBody()).getStmt();
Stmt* bodyDiff = Visit(m_DiffReq->getBody()).getStmt();
auto* CS = cast<CompoundStmt>(bodyDiff);
for (Stmt* S : CS->body())
addToCurrentBlock(S);
Expand All @@ -391,19 +392,20 @@ void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
DerivativeAndOverload
BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
const DiffRequest& request) {
m_Function = FD;
const_cast<DiffRequest&>(m_DiffReq) = request;
m_Functor = request.Functor;
m_DerivativeOrder = request.CurrentDerivativeOrder;
m_Mode = GetPushForwardMode();
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;

auto originalFnEffectiveName = utils::ComputeEffectiveFnName(m_Function);
auto originalFnEffectiveName =
utils::ComputeEffectiveFnName(m_DiffReq.Function);

IdentifierInfo* derivedFnII = &m_Context.Idents.get(
originalFnEffectiveName + GetPushForwardFunctionSuffix());
DeclarationNameInfo derivedFnName(derivedFnII, m_Function->getLocation());
DeclarationNameInfo derivedFnName(derivedFnII, m_DiffReq->getLocation());
llvm::SmallVector<QualType, 16> paramTypes;
llvm::SmallVector<QualType, 16> derivedParamTypes;

Expand All @@ -417,7 +419,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
}
}

for (auto* PVD : m_Function->parameters()) {
for (auto* PVD : m_DiffReq->parameters()) {
paramTypes.push_back(PVD->getType());

if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
Expand All @@ -428,19 +430,19 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
derivedParamTypes.end());

const auto* originalFnType =
dyn_cast<FunctionProtoType>(m_Function->getType());
dyn_cast<FunctionProtoType>(m_DiffReq->getType());
QualType returnType = ComputePushforwardFnReturnType();
QualType derivedFnType = m_Context.getFunctionType(
returnType, paramTypes, originalFnType->getExtProtoInfo());
llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;

SourceLocation loc{m_Function->getLocation()};
SourceLocation loc{m_DiffReq->getLocation()};
DeclWithContext cloneFunctionResult = m_Builder.cloneFunction(
m_Function, *this, DC, loc, derivedFnName, derivedFnType);
m_DiffReq.Function, *this, DC, loc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

llvm::SmallVector<ParmVarDecl*, 16> params;
Expand All @@ -466,9 +468,9 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
}
}

std::size_t numParamsOriginalFn = m_Function->getNumParams();
std::size_t numParamsOriginalFn = m_DiffReq->getNumParams();
for (std::size_t i = 0; i < numParamsOriginalFn; ++i) {
const auto* PVD = m_Function->getParamDecl(i);
const auto* PVD = m_DiffReq->getParamDecl(i);
// Some of the special member functions created implicitly by compilers
// have missing parameter identifier.
bool identifierMissing = false;
Expand Down Expand Up @@ -1137,7 +1139,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
const_cast<DeclContext*>(FD->getDeclContext()));

// Check if it is a recursive call.
if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) {
if (!callDiff && (FD == m_DiffReq.Function) &&
m_Mode == GetPushForwardMode()) {
// The differentiated function is called recursively.
Expr* derivativeRef =
m_Sema
Expand Down
Loading

0 comments on commit 5595016

Please sign in to comment.