Skip to content

Commit

Permalink
Desugar the type before analyzing it
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 19, 2024
1 parent 6cc83ee commit 26d6a64
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
14 changes: 8 additions & 6 deletions lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E,
return EData;
}

TBRAnalyzer::VarData::VarData(QualType QT, bool forceNonRefType) {
TBRAnalyzer::VarData::VarData(QualType QT, const ASTContext& C,
bool forceNonRefType) {
QT = QT.getDesugaredType(C);
if (forceNonRefType && QT->isReferenceType())
QT = QT->getPointeeType();

Expand All @@ -205,7 +207,7 @@ TBRAnalyzer::VarData::VarData(QualType QT, bool forceNonRefType) {
elemType = QT->getArrayElementTypeNoTypeQual();
ProfileID nonConstIdxID;
auto& idxData = (*m_Val.m_ArrData)[nonConstIdxID];
idxData = VarData(QualType::getFromOpaquePtr(elemType));
idxData = VarData(QualType::getFromOpaquePtr(elemType), C);
} else if (QT->isBuiltinType()) {
m_Type = VarData::FUND_TYPE;
m_Val.m_FundData = false;
Expand All @@ -216,7 +218,7 @@ TBRAnalyzer::VarData::VarData(QualType QT, bool forceNonRefType) {
newArrMap = std::unique_ptr<ArrMap>(new ArrMap());
for (const auto* field : recordDecl->fields()) {
const auto varType = field->getType();
(*newArrMap)[getProfileID(field)] = VarData(varType);
(*newArrMap)[getProfileID(field)] = VarData(varType, C);
}
}
}
Expand Down Expand Up @@ -287,11 +289,11 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD, bool forceNonRefType) {
if (const auto* const pointerType = dyn_cast<clang::PointerType>(varType)) {
const auto* elemType = pointerType->getPointeeType().getTypePtrOrNull();
if (elemType && elemType->isRecordType()) {
curBranch[VD] = VarData(QualType::getFromOpaquePtr(elemType));
curBranch[VD] = VarData(QualType::getFromOpaquePtr(elemType), m_Context);
return;
}
}
curBranch[VD] = VarData(varType, forceNonRefType);
curBranch[VD] = VarData(varType, m_Context, forceNonRefType);
}

void TBRAnalyzer::markLocation(const clang::Expr* E) {
Expand Down Expand Up @@ -331,7 +333,7 @@ void TBRAnalyzer::Analyze(const FunctionDecl* FD) {
if (MD && !MD->isStatic()) {
const Type* recordType = MD->getParent()->getTypeForDecl();
getCurBlockVarsData()[nullptr] =
VarData(QualType::getFromOpaquePtr(recordType));
VarData(QualType::getFromOpaquePtr(recordType), m_Context);
}
auto paramsRef = FD->parameters();
for (std::size_t i = 0; i < FD->getNumParams(); ++i)
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {
/// reference type (it will store TBR information itself without referring
/// to other VarData's). This is necessary for reference-type parameters,
/// when the referenced expressions are out of the function's scope.
VarData(QualType QT, bool forceNonRefType = false);
VarData(QualType QT, const ASTContext& C, bool forceNonRefType = false);

/// Erases all children VarData's of this VarData.
~VarData() {
Expand Down
25 changes: 25 additions & 0 deletions test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,29 @@ double fn9(Tangent t, dcomplex c) {
// CHECK-NEXT: }
// CHECK-NEXT: }

template <typename T>
struct A {
using PtrType = T*;
};

double fn10(double x, double y) {
A<double>::PtrType ptr = &x;
ptr[0] += 6;
return *ptr;
}

// CHECK: void fn10_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: A<double>::PtrType _d_ptr = &*_d_x;
// CHECK-NEXT: A<double>::PtrType ptr = &x;
// CHECK-NEXT: double _t0 = ptr[0];
// CHECK-NEXT: ptr[0] += 6;
// CHECK-NEXT: *_d_ptr += 1;
// CHECK-NEXT: {
// CHECK-NEXT: ptr[0] = _t0;
// CHECK-NEXT: double _r_d0 = _d_ptr[0];
// CHECK-NEXT: }
// CHECK-NEXT: }

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand All @@ -351,6 +374,7 @@ int main() {
INIT_GRADIENT(fn7);
INIT_GRADIENT(fn8);
INIT_GRADIENT(fn9);
INIT_GRADIENT(fn10);

TEST_GRADIENT(fn1, /*numOfDerivativeArgs=*/2, p, i, &d_p, &d_i); // CHECK-EXEC: {1.00, 2.00, 3.00}
TEST_GRADIENT(fn2, /*numOfDerivativeArgs=*/2, t, i, &d_t, &d_i); // CHECK-EXEC: {4.00, 2.00, 2.00, 2.00, 2.00, 1.00}
Expand All @@ -364,6 +388,7 @@ int main() {
TEST_GRADIENT(fn7, /*numOfDerivativeArgs=*/2, c1, c2, &d_c1, &d_c2);// CHECK-EXEC: {0.00, 3.00, 5.00, 1.00}
TEST_GRADIENT(fn8, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {0.00, 0.00, 0.00, 0.00, 0.00, 5.00, 0.00}
TEST_GRADIENT(fn9, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {1.00, 1.00, 1.00, 1.00, 1.00, 5.00, 10.00}
TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 5, 10, &d_i, &d_j); // CHECK-EXEC: {1.00, 0.00}
}

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
Expand Down

0 comments on commit 26d6a64

Please sign in to comment.