Skip to content

Commit

Permalink
Remove DerivePullback from ReverseModeVisitor and generate pullbacks …
Browse files Browse the repository at this point in the history
…with Derive.
  • Loading branch information
PetroZarytskyi committed Jul 22, 2024
1 parent a0eba51 commit 5fb5536
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 176 deletions.
2 changes: 0 additions & 2 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,6 @@ namespace clad {
/// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'.
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request);
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
Expand Down
16 changes: 6 additions & 10 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,17 +419,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
result = V.DerivePushforward(FD, request);
} else if (request.Mode == DiffMode::reverse) {
ReverseModeVisitor V(*this, request);
if (request.CallUpdateRequired) {
result = V.Derive(FD, request);
} else {
if (!m_ErrorEstHandler.empty()) {
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
V.AddExternalSource(*m_ErrorEstHandler.back());
}
result = V.DerivePullback(FD, request);
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty()) {
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
V.AddExternalSource(*m_ErrorEstHandler.back());
}
result = V.Derive(FD, request);
if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this, request);
result = V.Derive(FD, request);
Expand Down
187 changes: 26 additions & 161 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto gradientParams = m_Derivative->parameters();
std::string name =
m_DiffReq.BaseFunctionName + "_grad" + diffParamsPostfix(m_DiffReq);
if (m_DiffReq.use_enzyme)
name += "_enzyme";
IdentifierInfo* II = &m_Context.Idents.get(name);
DeclarationNameInfo DNI(II, noLoc);
// Calculate the total number of parameters that would be required for
Expand All @@ -131,6 +133,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_DiffReq->getNumParams() * 2 + numExtraParams;
std::size_t numOfDerivativeParams =
m_DiffReq->getNumParams() + numExtraParams;
// "Pullback parameter" here means the middle _d_y parameters used in
// pullbacks to represent the adjoint of the corresponding function call.
// Only Enzyme gradients and pullbacks of void functions don't have it.
bool hasPullbackParam =
!m_DiffReq.use_enzyme && !m_DiffReq->getReturnType()->isVoidType();
// Account for the this pointer.
if (isa<CXXMethodDecl>(m_DiffReq.Function) &&
!utils::IsStaticMethod(m_DiffReq.Function))
Expand Down Expand Up @@ -189,10 +196,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
callArgs.push_back(BuildDeclRef(VD));
}

std::size_t firstAdjParamIdx = m_DiffReq->getNumParams();
if (hasPullbackParam)
++firstAdjParamIdx;
for (std::size_t i = 0; i < numOfDerivativeParams; ++i) {
IdentifierInfo* II = nullptr;
StorageClass SC = StorageClass::SC_None;
std::size_t effectiveGradientIndex = m_DiffReq->getNumParams() + i + 1;
std::size_t effectiveGradientIndex = firstAdjParamIdx + i;
// `effectiveGradientIndex < gradientParams.size()` implies that this
// parameter represents an actual derivative of one of the function
// original parameters.
Expand Down Expand Up @@ -229,11 +239,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Build derivatives to be used in the call to the actual derived function.
// These are initialised by effectively casting the derivative parameters of
// overloaded derived function to the correct type.
for (std::size_t i = m_DiffReq->getNumParams() + 1;
i < gradientParams.size(); ++i) {
// Overloads don't have the _d_y parameter like pullbacks.
// Therefore, we have to shift the parameter index by 1.
auto* overloadParam = overloadParams[i - 1];
for (std::size_t i = firstAdjParamIdx; i < gradientParams.size(); ++i) {
// Overloads don't have the _d_y parameter like most pullbacks.
// Therefore, we have to shift the parameter index by 1 if the pullback
// has it.
auto* overloadParam = overloadParams[i - hasPullbackParam];
auto* gradientParam = gradientParams[i];
TypeSourceInfo* typeInfo =
m_Context.getTrivialTypeSourceInfo(gradientParam->getType());
Expand Down Expand Up @@ -283,15 +293,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
assert(m_DiffReq.Function && "Must not be null.");

DiffParams args{};
DiffInputVarsInfo DVI;
if (request.Args) {
DVI = request.DVI;
for (const auto& dParam : DVI)
if (!request.DVI.empty())
for (const auto& dParam : request.DVI)
args.push_back(dParam.param);
}
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));
if (args.empty())
if (args.empty() && (!isa<CXXMethodDecl>(FD) || utils::IsStaticMethod(FD)))
return {};

if (m_ExternalSource)
Expand All @@ -303,10 +310,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
outputArrayStr = m_DiffReq->getParamDecl(lastArgN)->getNameAsString();
}

auto derivativeBaseName = request.BaseFunctionName;
std::string gradientName = derivativeBaseName + funcPostfix(m_DiffReq);
std::string derivativeBaseName = request.BaseFunctionName;
std::string derivativeName = derivativeBaseName + funcPostfix(m_DiffReq);

IdentifierInfo* II = &m_Context.Idents.get(gradientName);
IdentifierInfo* II = &m_Context.Idents.get(derivativeName);
DeclarationNameInfo name(II, noLoc);

// If we are in error estimation mode, we have an extra `double&`
Expand All @@ -323,7 +330,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If reverse mode differentiates only part of the arguments it needs to
// generate an overload that can take in all the diff variables
bool shouldCreateOverload = false;
if (request.Mode != DiffMode::jacobian)
if (request.Mode != DiffMode::jacobian && m_DiffReq.CallUpdateRequired)
shouldCreateOverload = true;
if (!request.DeclarationOnly && !request.DerivedFDPrototypes.empty())
// If the overload is already created, we don't need to create it again.
Expand All @@ -347,7 +354,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
gradientName, DC, gradientFunctionType)) {
derivativeName, DC, gradientFunctionType)) {
// Set m_Derivative for creating the overload.
m_Derivative = customDerivative;
FunctionDecl* gradientOverloadFD = nullptr;
Expand Down Expand Up @@ -459,137 +466,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return DerivativeAndOverload{result.first, gradientOverloadFD};
}

DerivativeAndOverload
ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request) {
// FIXME: Duplication of external source here is a workaround
// for the two 'Derive's being different functions.
if (m_ExternalSource)
m_ExternalSource->ActOnStartOfDerive();
silenceDiags = !request.VerboseDiags;
// FIXME: We should not use const_cast to get the decl request here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<DiffRequest&>(m_DiffReq) = request;
assert(m_DiffReq.Function && "Must not be null.");

DiffParams args{};
if (!request.DVI.empty())
for (const auto& dParam : request.DVI)
args.push_back(dParam.param);
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));
#ifndef NDEBUG
bool isStaticMethod = utils::IsStaticMethod(FD);
assert((!args.empty() || !isStaticMethod) &&
"Cannot generate pullback function of a function "
"with no differentiable arguments");
#endif

if (m_ExternalSource)
m_ExternalSource->ActAfterParsingDiffArgs(request, args);

auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) +
funcPostfix(m_DiffReq);
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
const auto* originalFnType =
dyn_cast<FunctionProtoType>(m_DiffReq->getType());

if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes);

QualType pullbackFnType = m_Context.getFunctionType(
m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo());

// Check if the function is already declared as a custom derivative.
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
derivativeName, DC, pullbackFnType))
return DerivativeAndOverload{customDerivative, nullptr};

llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
m_Sema.CurContext = const_cast<DeclContext*>(m_DiffReq->getDeclContext());

SourceLocation validLoc{m_DiffReq->getLocation()};
DeclWithContext fnBuildRes =
m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext,
validLoc, DNI, pullbackFnType);
m_Derivative = fnBuildRes.first;

if (m_ExternalSource)
m_ExternalSource->ActBeforeCreatingDerivedFnScope();

beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnScope();

auto params = BuildParams(args);
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

m_Derivative->setParams(params);
m_Derivative->setBody(nullptr);

if (!request.DeclarationOnly) {
if (m_ExternalSource)
m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope();

beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();

beginBlock();
if (m_ExternalSource)
m_ExternalSource->ActOnStartOfDerivedFnBody(request);

StmtDiff bodyDiff = Visit(m_DiffReq->getBody());
Stmt* forward = bodyDiff.getStmt();
Stmt* reverse = bodyDiff.getStmt_dx();

// Create the body of the function.
// Firstly, all "global" Stmts are put into fn's body.
for (Stmt* S : m_Globals)
addToCurrentBlock(S, direction::forward);
// Forward pass.
if (auto* CS = dyn_cast<CompoundStmt>(forward))
for (Stmt* S : CS->body())
addToCurrentBlock(S, direction::forward);

// Reverse pass.
if (auto* RCS = dyn_cast<CompoundStmt>(reverse))
for (Stmt* S : RCS->body())
addToCurrentBlock(S, direction::forward);

if (m_ExternalSource)
m_ExternalSource->ActOnEndOfDerivedFnBody();

Stmt* fnBody = endBlock();
m_Derivative->setBody(fnBody);
endScope(); // Function body scope

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope

return DerivativeAndOverload{fnBuildRes.first, nullptr};
}

void ReverseModeVisitor::DifferentiateWithClad() {
llvm::ArrayRef<ParmVarDecl*> paramsRef = m_Derivative->parameters();

Expand Down Expand Up @@ -3897,18 +3773,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_DiffReq.Mode == DiffMode::reverse) {
QualType effectiveReturnType =
m_DiffReq->getReturnType().getNonReferenceType();
// FIXME: Generally, we use the function's return type as the argument's
// derivative type. We cannot follow this strategy for `void` function
// return type. Thus, temporarily use `double` type as the placeholder
// type for argument derivatives. We should think of a more uniform and
// consistent solution to this problem. One effective strategy that may
// hold well: If we are differentiating a variable of type Y with
// respect to variable of type X, then the derivative should be of type
// X. Check this related issue for more details:
// https://github.com/vgvassilev/clad/issues/385
if (effectiveReturnType->isVoidType())
effectiveReturnType = m_Context.DoubleTy;
else
if (!effectiveReturnType->isVoidType() && !m_DiffReq.use_enzyme)
paramTypes.push_back(effectiveReturnType);

if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
Expand Down Expand Up @@ -3945,7 +3810,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::size_t dParamTypesIdx = m_DiffReq->getNumParams();

if (m_DiffReq.Mode == DiffMode::reverse &&
!m_DiffReq->getReturnType()->isVoidType()) {
!m_DiffReq->getReturnType()->isVoidType() && !m_DiffReq.use_enzyme) {
++dParamTypesIdx;
}

Expand Down
6 changes: 3 additions & 3 deletions test/Enzyme/DifferentCladEnzymeDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ double foo(double x, double y){
return x*y;
}

// CHECK: void foo_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK: void foo_pullback(double x, double y, double _d_y0, double *_d_x, double *_d_y) {
// CHECK-NEXT: {
// CHECK-NEXT: *_d_x += 1 * y;
// CHECK-NEXT: *_d_y += x * 1;
// CHECK-NEXT: *_d_x += _d_y0 * y;
// CHECK-NEXT: *_d_y += x * _d_y0;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down

0 comments on commit 5fb5536

Please sign in to comment.