From 4cb2d286843487680997e17a23b0e565f43f0b0a Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 21 Aug 2024 18:13:58 +0300 Subject: [PATCH] Save VD type in a variable in DifferentiateVarDecl --- lib/Differentiator/ReverseModeVisitor.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index aeb7a49ee..478db7c19 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2719,29 +2719,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass && !keepLocal; QualType VDCloneType; QualType VDDerivedType; + QualType VDType = VD->getType(); // If the cloned declaration is moved to the function global scope, // change its type for the corresponding adjoint type. if (promoteToFnScope) { - VDDerivedType = ComputeAdjointType(CloneType(VD->getType())); + VDDerivedType = ComputeAdjointType(CloneType(VDType)); VDCloneType = VDDerivedType; if (isa(VDCloneType) && !isa(VDCloneType)) VDCloneType = GetCladArrayOfType(m_Context.getBaseElementType(VDCloneType)); } else { - VDCloneType = CloneType(VD->getType()); + VDCloneType = CloneType(VDType); VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema); } - bool isRefType = VD->getType()->isLValueReferenceType(); + bool isRefType = VDType->isLValueReferenceType(); VarDecl* VDDerived = nullptr; - bool isPointerType = VD->getType()->isPointerType(); + bool isPointerType = VDType->isPointerType(); bool isInitializedByNewExpr = false; bool initializeDerivedVar = true; // We need to replace std::initializer_list with clad::array because the // former is temporary by design and it's not possible to create modifiable // adjoints. - if (m_Sema.isStdInitializerList(utils::GetValueType(VD->getType()), + if (m_Sema.isStdInitializerList(utils::GetValueType(VDType), /*Element=*/nullptr)) { if (const Expr* init = VD->getInit()) { if (const auto* CXXILE = @@ -2776,7 +2777,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // VDDerivedInit now serves two purposes -- as the initial derivative value // or the size of the derivative array -- depending on the primal type. - if (const auto* AT = dyn_cast(VD->getType())) { + if (const auto* AT = dyn_cast(VDType)) { if (!isa(AT)) { Expr* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); @@ -2801,7 +2802,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // value is set to 0. // Otherwise, for non-reference types, the initial value is set to 0. if (!VDDerivedInit) - VDDerivedInit = getZeroInit(VD->getType()); + VDDerivedInit = getZeroInit(VDType); // `specialThisDiffCase` is only required for correctly differentiating // the following code: @@ -2822,7 +2823,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, initDiff = Visit(VD->getInit()); if (!initDiff.getForwSweepExpr_dx()) { VDDerivedType = - ComputeAdjointType(VD->getType().getNonReferenceType()); + ComputeAdjointType(VDType.getNonReferenceType()); isRefType = false; } if (promoteToFnScope || !isRefType) @@ -2855,7 +2856,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If the pointer is const and derived expression is not available, then // we should not create a derived variable for it. This will be useful // for reducing number of differentiation variables in pullbacks. - bool constPointer = VD->getType()->getPointeeType().isConstQualified(); + bool constPointer = VDType->getPointeeType().isConstQualified(); if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx()) initializeDerivedVar = false; else { @@ -2863,7 +2864,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If it's a pointer to a constant type, then remove the constness. if (constPointer) { // first extract the pointee type - auto pointeeType = VD->getType()->getPointeeType(); + auto pointeeType = VDType->getPointeeType(); // then remove the constness pointeeType.removeLocalConst(); // then create a new pointer type with the new pointee type @@ -3003,7 +3004,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // _d_y // } if ((VD->getDeclName() != VDClone->getDeclName() || - VD->getType() != VDClone->getType())) + VDType != VDClone->getType())) m_DeclReplacements[VD] = VDClone; return DeclDiff(VDClone, VDDerived);