From acb6f83678630630bfa281dd6fab046eddce4b05 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Tue, 11 Jun 2024 13:14:24 +0000 Subject: [PATCH] Address complaints from clang-tidy. --- include/clad/Differentiator/HessianModeVisitor.h | 2 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 9 ++++++++- lib/Differentiator/HessianModeVisitor.cpp | 13 +++++++------ lib/Differentiator/ReverseModeVisitor.cpp | 12 ++++++++++++ lib/Differentiator/VectorForwardModeVisitor.cpp | 12 ++++++++---- lib/Differentiator/VisitorBase.cpp | 2 +- 6 files changed, 37 insertions(+), 13 deletions(-) diff --git a/include/clad/Differentiator/HessianModeVisitor.h b/include/clad/Differentiator/HessianModeVisitor.h index 2740add92..d1ca65c1c 100644 --- a/include/clad/Differentiator/HessianModeVisitor.h +++ b/include/clad/Differentiator/HessianModeVisitor.h @@ -34,7 +34,7 @@ namespace clad { public: HessianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request); - ~HessianModeVisitor(); + ~HessianModeVisitor() = default; ///\brief Produces the hessian second derivative columns of a given /// function. diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index fd3b7f70c..5d641832e 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -166,7 +166,10 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, DeclarationNameInfo name(II, validLoc); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope()); - DeclContext* DC = const_cast(m_DiffReq->getDeclContext()); + + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType()); @@ -392,6 +395,8 @@ void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, const DiffRequest& request) { + // FIXME: We must not reset the diff request here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(m_DiffReq) = request; m_Functor = request.Functor; m_DerivativeOrder = request.CurrentDerivativeOrder; @@ -437,6 +442,8 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, llvm::SaveAndRestore saveContext(m_Sema.CurContext); llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 2521b312f..ce11303db 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -31,11 +31,9 @@ HessianModeVisitor::HessianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request) : VisitorBase(builder, request) {} -HessianModeVisitor::~HessianModeVisitor() {} - /// Converts the string str into a StringLiteral static const StringLiteral* CreateStringLiteral(ASTContext& C, - std::string str) { + const std::string& str) { QualType CharTyConst = C.CharTy.withConst(); QualType StrTy = clad_compat::getConstantArrayType( C, CharTyConst, llvm::APInt(/*numBits=*/32, str.size() + 1), @@ -241,7 +239,8 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C, paramTypes.back() = m_Context.getPointerType(m_DiffReq->getReturnType()); - auto originalFnProtoType = cast(m_DiffReq->getType()); + const auto* originalFnProtoType = + cast(m_DiffReq->getType()); QualType hessianFunctionType = m_Context.getFunctionType( m_Context.VoidTy, llvm::ArrayRef(paramTypes.data(), paramTypes.size()), @@ -249,7 +248,9 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C, originalFnProtoType->getExtProtoInfo()); // Create the gradient function declaration. - DeclContext* DC = const_cast(m_DiffReq->getDeclContext()); + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto* DC = const_cast(m_DiffReq->getDeclContext()); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); @@ -341,7 +342,7 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C, // FIXME: Add support for class type in the hessian matrix. For this, we // need to add a way to represent hessian matrix when class type objects // are involved. - if (auto MD = dyn_cast(m_DiffReq.Function)) { + if (const auto* MD = dyn_cast(m_DiffReq.Function)) { const CXXRecordDecl* RD = MD->getParent(); if (MD->isInstance() && !RD->isLambda()) { QualType thisObjectType = diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 7a1f2e3dc..dffd119f1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -155,6 +155,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Cast to function pointer. gradFuncOverloadEPI); + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext gradientOverloadFDWC = @@ -363,6 +365,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction( @@ -477,6 +481,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerive(); silenceDiags = !request.VerboseDiags; + // FIXME: We should not use const_cast to get the decl request here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(m_DiffReq) = request; m_Mode = DiffMode::experimental_pullback; assert(m_DiffReq.Function && "Must not be null."); @@ -514,6 +520,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore saveContext(m_Sema.CurContext); llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) m_Sema.CurContext = const_cast(m_DiffReq->getDeclContext()); SourceLocation validLoc{m_DiffReq->getLocation()}; @@ -663,8 +671,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector enzymeRealParamsDerived; // First add the function itself as a parameter/argument + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) enzymeArgs.push_back( BuildDeclRef(const_cast(m_DiffReq.Function))); + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* fdDeclContext = const_cast(m_DiffReq->getDeclContext()); enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( fdDeclContext, noLoc, m_DiffReq->getType())); diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index 5e08a6b84..f5efbc3dd 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -81,9 +81,9 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, // Generate the function type for the derivative. llvm::SmallVector paramTypes; paramTypes.reserve(m_DiffReq->getNumParams() + args.size()); - for (auto PVD : m_DiffReq->parameters()) + for (auto* PVD : m_DiffReq->parameters()) paramTypes.push_back(PVD->getType()); - for (auto PVD : m_DiffReq->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) { auto it = std::find(std::begin(args), std::end(args), PVD); if (it == std::end(args)) continue; // This parameter is not in the diff list. @@ -108,7 +108,9 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, dyn_cast(m_DiffReq->getType())->getExtProtoInfo()); // Create the function declaration for the derivative. - DeclContext* DC = const_cast(m_DiffReq->getDeclContext()); + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction( m_DiffReq.Function, *this, DC, loc, name, vectorDiffFunctionType); @@ -283,6 +285,8 @@ clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() { vectorModeFuncOverloadEPI); // Create the function declaration for the derivative. + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = @@ -400,7 +404,7 @@ VectorForwardModeVisitor::BuildVectorModeParams(DiffParams& diffParams) { // differentiation. size_t nonArrayIndVarCount = 0; - for (auto PVD : m_DiffReq->parameters()) { + for (auto* PVD : m_DiffReq->parameters()) { auto newPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index ec0a8883c..780184274 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -115,7 +115,7 @@ namespace clad { VarDecl::InitializationStyle IS) { // add namespace specifier in variable declaration if needed. Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type); - auto VD = VarDecl::Create( + auto* VD = VarDecl::Create( m_Context, m_Sema.CurContext, m_DiffReq->getLocation(), m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None);