Skip to content

Commit

Permalink
Save VD type in a variable in DifferentiateVarDecl
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 21, 2024
1 parent 0687cee commit 4cb2d28
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2719,29 +2719,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass && !keepLocal;
QualType VDCloneType;
QualType VDDerivedType;
QualType VDType = VD->getType();
// If the cloned declaration is moved to the function global scope,
// change its type for the corresponding adjoint type.
if (promoteToFnScope) {
VDDerivedType = ComputeAdjointType(CloneType(VD->getType()));
VDDerivedType = ComputeAdjointType(CloneType(VDType));
VDCloneType = VDDerivedType;
if (isa<ArrayType>(VDCloneType) && !isa<IncompleteArrayType>(VDCloneType))
VDCloneType =
GetCladArrayOfType(m_Context.getBaseElementType(VDCloneType));
} else {
VDCloneType = CloneType(VD->getType());
VDCloneType = CloneType(VDType);
VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema);
}

bool isRefType = VD->getType()->isLValueReferenceType();
bool isRefType = VDType->isLValueReferenceType();
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();
bool isPointerType = VDType->isPointerType();
bool isInitializedByNewExpr = false;
bool initializeDerivedVar = true;

// We need to replace std::initializer_list with clad::array because the
// former is temporary by design and it's not possible to create modifiable
// adjoints.
if (m_Sema.isStdInitializerList(utils::GetValueType(VD->getType()),
if (m_Sema.isStdInitializerList(utils::GetValueType(VDType),
/*Element=*/nullptr)) {
if (const Expr* init = VD->getInit()) {
if (const auto* CXXILE =
Expand Down Expand Up @@ -2776,7 +2777,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
if (const auto* AT = dyn_cast<ArrayType>(VD->getType())) {
if (const auto* AT = dyn_cast<ArrayType>(VDType)) {
if (!isa<VariableArrayType>(AT)) {
Expr* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand All @@ -2801,7 +2802,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// value is set to 0.
// Otherwise, for non-reference types, the initial value is set to 0.
if (!VDDerivedInit)
VDDerivedInit = getZeroInit(VD->getType());
VDDerivedInit = getZeroInit(VDType);

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
Expand All @@ -2822,7 +2823,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
VDDerivedType =
ComputeAdjointType(VD->getType().getNonReferenceType());
ComputeAdjointType(VDType.getNonReferenceType());
isRefType = false;
}
if (promoteToFnScope || !isRefType)
Expand Down Expand Up @@ -2855,15 +2856,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If the pointer is const and derived expression is not available, then
// we should not create a derived variable for it. This will be useful
// for reducing number of differentiation variables in pullbacks.
bool constPointer = VD->getType()->getPointeeType().isConstQualified();
bool constPointer = VDType->getPointeeType().isConstQualified();
if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx())
initializeDerivedVar = false;
else {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// If it's a pointer to a constant type, then remove the constness.
if (constPointer) {
// first extract the pointee type
auto pointeeType = VD->getType()->getPointeeType();
auto pointeeType = VDType->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
Expand Down Expand Up @@ -3003,7 +3004,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// _d_y
// }
if ((VD->getDeclName() != VDClone->getDeclName() ||
VD->getType() != VDClone->getType()))
VDType != VDClone->getType()))
m_DeclReplacements[VD] = VDClone;

return DeclDiff<VarDecl>(VDClone, VDDerived);
Expand Down

0 comments on commit 4cb2d28

Please sign in to comment.