Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
PhrygianGates committed Aug 17, 2023
1 parent 616708f commit b6b52be
Showing 1 changed file with 0 additions and 145 deletions.
145 changes: 0 additions & 145 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,11 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
addToCurrentBlock(S);

Stmt* fnBody = endBlock();
// llvm::errs() << "Derive: dumping fnBody:\n";
// fnBody->dumpColor();
m_Derivative->setBody(fnBody);
endScope();
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope();
// llvm::errs() << "Derive: Dumping m_Derivative:\n";
// m_Derivative->dumpColor();
return DerivativeAndOverload{m_Derivative, nullptr};
}

Expand Down Expand Up @@ -275,145 +271,4 @@ ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get();
return {newRS};
}

// StmtDiff ReverseModeForwPassVisitor::VisitDeclStmt(const DeclStmt* DS) {
// llvm::SmallVector<Decl*, 4> decls, derivedDecls;
// for (auto D : DS->decls()) {
// if (auto VD = dyn_cast<VarDecl>(D)) {
// VarDeclDiff VDDiff = DifferentiateVarDecl(VD);

// if (VDDiff.getDecl()->getDeclName() != VD->getDeclName())
// m_DeclReplacements[VD] = VDDiff.getDecl();
// decls.push_back(VDDiff.getDecl());
// derivedDecls.push_back(VDDiff.getDecl_dx());
// } else {
// diag(DiagnosticsEngine::Warning, D->getEndLoc(),
// "Unsupported declaration");
// }
// }
// }

// VarDeclDiff ReverseModeForwPassVisitor::DifferentiateVarDecl(const VarDecl*
// VD) {
// StmtDiff initDiff;
// Expr* VDDerivedInit = nullptr;
// auto VDDerivedType = VD->getType();
// bool isVDRefType = VD->getType()->isReferenceType();
// VarDecl* VDDerived = nullptr;

// if (auto VDCAT = dyn_cast<ConstantArrayType>(VD->getType())) {
// assert("Should not reach here!!!");
// // VDDerivedType =
// // GetCladArrayOfType(QualType(VDCAT->getPointeeOrArrayElementType(),
// // VDCAT->getIndexTypeCVRQualifiers()));
// // VDDerivedInit = ConstantFolder::synthesizeLiteral(
// // m_Context.getSizeType(), m_Context,
// VDCAT->getSize().getZExtValue());
// // VDDerived = BuildVarDecl(VDDerivedType, "_d_" +
// VD->getNameAsString(),
// // VDDerivedInit, false, nullptr,
// // clang::VarDecl::InitializationStyle::CallInit);
// } else {
// // If VD is a reference to a local variable, then the initial value is
// set
// // to the derived variable of the corresponding local variable.
// // If VD is a reference to a non-local variable (global variable,
// struct
// // member etc), then no derived variable is available, thus `VDDerived`
// // does not need to reference any variable, consequentially the
// // `VDDerivedType` is the corresponding non-reference type and the
// initial
// // value is set to 0.
// // Otherwise, for non-reference types, the initial value is set to 0.
// VDDerivedInit = getZeroInit(VD->getType());

// // `specialThisDiffCase` is only required for correctly differentiating
// // the following code:
// // ```
// // Class _d_this_obj;
// // Class* _d_this = &_d_this_obj;
// // ```
// // Computation of hessian requires this code to be correctly
// // differentiated.
// bool specialThisDiffCase = false;
// if (auto MD = dyn_cast<CXXMethodDecl>(m_Function)) {
// if (VDDerivedType->isPointerType() && MD->isInstance()) {
// specialThisDiffCase = true;
// }
// }

// // FIXME: Remove the special cases introduced by `specialThisDiffCase`
// // once reverse mode supports pointers. `specialThisDiffCase` is only
// // required for correctly differentiating the following code:
// // ```
// // Class _d_this_obj;
// // Class* _d_this = &_d_this_obj;
// // ```
// // Computation of hessian requires this code to be correctly
// // differentiated.
// if (isVDRefType || specialThisDiffCase) {
// VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// initDiff = Visit(VD->getInit());
// if (initDiff.getExpr_dx())
// VDDerivedInit = initDiff.getExpr_dx();
// else
// VDDerivedType = VDDerivedType.getNonReferenceType();
// }
// // Here separate behaviour for record and non-record types is only
// // necessary to preserve the old tests.
// if (VDDerivedType->isRecordType())
// VDDerived =
// BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
// VDDerivedInit, VD->isDirectInit(),
// m_Context.getTrivialTypeSourceInfo(VDDerivedType),
// VD->getInitStyle());
// else
// VDDerived = BuildVarDecl(VDDerivedType, "_d_" +
// VD->getNameAsString(),
// VDDerivedInit);
// }

// // If `VD` is a reference to a local variable, then it is already
// // differentiated and should not be differentiated again.
// // If `VD` is a reference to a non-local variable then also there's no
// // need to call `Visit` since non-local variables are not differentiated.
// if (!isVDRefType) {
// initDiff = VD->getInit() ? Visit(VD->getInit(),
// BuildDeclRef(VDDerived))
// : StmtDiff{};

// // If we are differentiating `VarDecl` corresponding to a local
// variable
// // inside a loop, then we need to reset it to 0 at each iteration.
// //
// // for example, if defined inside a loop,
// // ```
// // double localVar = i;
// // ```
// // this statement should get differentiated to,
// // ```
// // {
// // *_d_i += _d_localVar;
// // _d_localVar = 0;
// // }
// if (isInsideLoop) {
// Stmt* assignToZero = BuildOp(BinaryOperatorKind::BO_Assign,
// BuildDeclRef(VDDerived),
// getZeroInit(VDDerivedType));
// addToCurrentBlock(assignToZero, direction::reverse);
// }
// }
// VarDecl* VDClone = nullptr;
// // Here separate behaviour for record and non-record types is only
// // necessary to preserve the old tests.
// if (VD->getType()->isRecordType())
// VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(),
// initDiff.getExpr(), VD->isDirectInit(),
// VD->getTypeSourceInfo(), VD->getInitStyle());
// else
// VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(),
// initDiff.getExpr(), VD->isDirectInit());
// m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
// return VarDeclDiff(VDClone, VDDerived);
// }
} // namespace clad

0 comments on commit b6b52be

Please sign in to comment.