Skip to content

Commit

Permalink
Address complaints from clang-tidy.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Jun 11, 2024
1 parent 5595016 commit acb6f83
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 13 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/HessianModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace clad {

public:
HessianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
~HessianModeVisitor();
~HessianModeVisitor() = default;

///\brief Produces the hessian second derivative columns of a given
/// function.
Expand Down
9 changes: 8 additions & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
DeclarationNameInfo name(II, validLoc);
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());
DeclContext* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result =
m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType());
Expand Down Expand Up @@ -392,6 +395,8 @@ void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
DerivativeAndOverload
BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
const DiffRequest& request) {
// FIXME: We must not reset the diff request here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<DiffRequest&>(m_DiffReq) = request;
m_Functor = request.Functor;
m_DerivativeOrder = request.CurrentDerivativeOrder;
Expand Down Expand Up @@ -437,6 +442,8 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;

Expand Down
13 changes: 7 additions & 6 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@ HessianModeVisitor::HessianModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request)
: VisitorBase(builder, request) {}

HessianModeVisitor::~HessianModeVisitor() {}

/// Converts the string str into a StringLiteral
static const StringLiteral* CreateStringLiteral(ASTContext& C,
std::string str) {
const std::string& str) {
QualType CharTyConst = C.CharTy.withConst();
QualType StrTy = clad_compat::getConstantArrayType(
C, CharTyConst, llvm::APInt(/*numBits=*/32, str.size() + 1),
Expand Down Expand Up @@ -241,15 +239,18 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C,

paramTypes.back() = m_Context.getPointerType(m_DiffReq->getReturnType());

auto originalFnProtoType = cast<FunctionProtoType>(m_DiffReq->getType());
const auto* originalFnProtoType =
cast<FunctionProtoType>(m_DiffReq->getType());
QualType hessianFunctionType = m_Context.getFunctionType(
m_Context.VoidTy,
llvm::ArrayRef<QualType>(paramTypes.data(), paramTypes.size()),
// Cast to function pointer.
originalFnProtoType->getExtProtoInfo());

// Create the gradient function declaration.
DeclContext* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
Expand Down Expand Up @@ -341,7 +342,7 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C,
// FIXME: Add support for class type in the hessian matrix. For this, we
// need to add a way to represent hessian matrix when class type objects
// are involved.
if (auto MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
const CXXRecordDecl* RD = MD->getParent();
if (MD->isInstance() && !RD->isLambda()) {
QualType thisObjectType =
Expand Down
12 changes: 12 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Cast to function pointer.
gradFuncOverloadEPI);

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext gradientOverloadFDWC =
Expand Down Expand Up @@ -363,6 +365,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result = m_Builder.cloneFunction(
Expand Down Expand Up @@ -477,6 +481,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActOnStartOfDerive();
silenceDiags = !request.VerboseDiags;
// FIXME: We should not use const_cast to get the decl request here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<DiffRequest&>(m_DiffReq) = request;
m_Mode = DiffMode::experimental_pullback;
assert(m_DiffReq.Function && "Must not be null.");
Expand Down Expand Up @@ -514,6 +520,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
m_Sema.CurContext = const_cast<DeclContext*>(m_DiffReq->getDeclContext());

SourceLocation validLoc{m_DiffReq->getLocation()};
Expand Down Expand Up @@ -663,8 +671,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVector<ParmVarDecl*, 16> enzymeRealParamsDerived;

// First add the function itself as a parameter/argument
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
enzymeArgs.push_back(
BuildDeclRef(const_cast<FunctionDecl*>(m_DiffReq.Function)));
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* fdDeclContext = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, m_DiffReq->getType()));
Expand Down
12 changes: 8 additions & 4 deletions lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD,
// Generate the function type for the derivative.
llvm::SmallVector<clang::QualType, 8> paramTypes;
paramTypes.reserve(m_DiffReq->getNumParams() + args.size());
for (auto PVD : m_DiffReq->parameters())
for (auto* PVD : m_DiffReq->parameters())
paramTypes.push_back(PVD->getType());
for (auto PVD : m_DiffReq->parameters()) {
for (auto* PVD : m_DiffReq->parameters()) {
auto it = std::find(std::begin(args), std::end(args), PVD);
if (it == std::end(args))
continue; // This parameter is not in the diff list.
Expand All @@ -108,7 +108,9 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD,
dyn_cast<FunctionProtoType>(m_DiffReq->getType())->getExtProtoInfo());

// Create the function declaration for the derivative.
DeclContext* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result = m_Builder.cloneFunction(
m_DiffReq.Function, *this, DC, loc, name, vectorDiffFunctionType);
Expand Down Expand Up @@ -283,6 +285,8 @@ clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() {
vectorModeFuncOverloadEPI);

// Create the function declaration for the derivative.
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result =
Expand Down Expand Up @@ -400,7 +404,7 @@ VectorForwardModeVisitor::BuildVectorModeParams(DiffParams& diffParams) {
// differentiation.
size_t nonArrayIndVarCount = 0;

for (auto PVD : m_DiffReq->parameters()) {
for (auto* PVD : m_DiffReq->parameters()) {
auto newPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(),
PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo());
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ namespace clad {
VarDecl::InitializationStyle IS) {
// add namespace specifier in variable declaration if needed.
Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type);
auto VD = VarDecl::Create(
auto* VD = VarDecl::Create(
m_Context, m_Sema.CurContext, m_DiffReq->getLocation(),
m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None);

Expand Down

0 comments on commit acb6f83

Please sign in to comment.