Skip to content

Commit

Permalink
Add support for differentiating switch stmt in the reverse mode AD. (v…
Browse files Browse the repository at this point in the history
…gvassilev#339)

This commit adds support for differentiating switch statements in
the reverse mode AD.

The basic idea used to differentiate switch statement is that in the
forward pass, processing of the statements of the switch statement body
always starts from a case/default label and ends at a break statement or
at the end of the switch body.

Similarly, in the reverse pass, processing of the differentiated statements
of the switch statement body will start from the statement just above the
break statement that was hit or from the last differentiated statement in
the case when no break statement was hit.

Thus, we can keep track of which break statement was hit in the forward pass
or if no break statement was hit at all in a variable. This information is
further used by an auxiliary switch statement in the reverse pass to jump the
execution to the correct point (that is, differentiated statement of the
statement just before the break statement that was hit in the forward pass).

In this strategy, each switch case statement of the original function gets
transformed to an if condition in the reverse pass. The if condition decides
at runtime whether the processing of the differentiated statements of the switch
statement body should stop or continue. This is again based on the fact
that the processing of statements of the switch statement body always starts
at a case statement.

For an example, consider this code snippet:
```cpp
switch (count) {
    case 0: a += i; break;
    case 2: a += 4 * i; break;
    default: a += 10 * i;
}

case 0 of this code snippet gets transformed to the following in the
differentiated function:

forward pass:

{
  case 0: a += i;
}
{
  clad::push(_t0, 1UL); // this is used to keep track if this break was hit; 1UL is used to represent the case number
  break;
}

reverse pass:

case 1UL:;  // this case is selected if the corresponding break was hit in the forward pass
  {
    {
      double _r_d0 = _d_a;
      _d_a += _r_d0;
      *_d_i += _r_d0;
      _d_a -= _r_d0;
    }
    if (0 == _cond0)  // If case 0: was selected in the forward pass, we should break out of processing differentiated switch stmt body here.
      break;
  }
```
  • Loading branch information
parth-07 authored Feb 12, 2024
1 parent 03fcd38 commit 16a652b
Show file tree
Hide file tree
Showing 7 changed files with 1,100 additions and 28 deletions.
6 changes: 4 additions & 2 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ namespace clad {
llvm::SmallVectorImpl<clang::Expr*>& Exprs);

bool ContainsFunctionCalls(const clang::Stmt* E);
} // namespace utils
}

void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt);
} // namespace utils
} // namespace clad

#endif
33 changes: 29 additions & 4 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ namespace clad {
StmtDiff
VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE);
StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE);
StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS);
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
Expand Down Expand Up @@ -485,7 +488,7 @@ namespace clad {
clang::Stmt* forLoopIncDiff = nullptr,
bool isForLoop = false);

/// This class modifies forward and reverse blocks of the loop
/// This class modifies forward and reverse blocks of the loop/switch
/// body so that `break` and `continue` statements are correctly
/// handled. `break` and `continue` statements are handled by
/// enclosing entire reverse block loop body in a switch statement
Expand Down Expand Up @@ -528,6 +531,7 @@ namespace clad {

ReverseModeVisitor& m_RMV;

const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
/// `value` as value.
clang::Expr* CreateSizeTLiteralExpr(std::size_t value);
Expand All @@ -542,7 +546,8 @@ namespace clad {
clang::Expr* CreateCFTapePushExpr(std::size_t value);

public:
BreakContStmtHandler(ReverseModeVisitor& RMV) : m_RMV(RMV) {}
BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false)
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {}

/// Begins control flow switch statement scope.
/// Control flow switch statement is used to refer to the
Expand Down Expand Up @@ -574,8 +579,8 @@ namespace clad {
BreakContStmtHandler* GetActiveBreakContStmtHandler() {
return &m_BreakContStmtHandlers.back();
}
BreakContStmtHandler* PushBreakContStmtHandler() {
m_BreakContStmtHandlers.emplace_back(*this);
BreakContStmtHandler* PushBreakContStmtHandler(bool forSwitchStmt = false) {
m_BreakContStmtHandlers.emplace_back(*this, forSwitchStmt);
return &m_BreakContStmtHandlers.back();
}
void PopBreakContStmtHandler() {
Expand Down Expand Up @@ -609,6 +614,26 @@ namespace clad {

clang::QualType ComputeAdjointType(clang::QualType T);
clang::QualType ComputeParamType(clang::QualType T);
/// Stores data required for differentiating a switch statement.
struct SwitchStmtInfo {
llvm::SmallVector<clang::SwitchCase*, 16> cases;
clang::Expr* switchStmtCond = nullptr;
clang::IfStmt* defaultIfBreakExpr = nullptr;
};

/// Maintains a stack of `SwitchStmtInfo`.
llvm::SmallVector<SwitchStmtInfo, 4> m_SwitchStmtsData;

SwitchStmtInfo* GetActiveSwitchStmtInfo() {
return &m_SwitchStmtsData.back();
}

SwitchStmtInfo* PushSwitchStmtInfo() {
m_SwitchStmtsData.emplace_back();
return &m_SwitchStmtsData.back();
}

void PopSwitchStmtInfo() { m_SwitchStmtsData.pop_back(); }
};
} // end namespace clad

Expand Down
14 changes: 2 additions & 12 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1592,16 +1592,6 @@ static SwitchCase* getContainedSwitchCaseStmt(const CompoundStmt* CS) {
return nullptr;
}

static void setSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) {
if (auto caseStmt = dyn_cast<CaseStmt>(SC)) {
caseStmt->setSubStmt(subStmt);
} else if (auto defaultStmt = dyn_cast<DefaultStmt>(SC)) {
defaultStmt->setSubStmt(subStmt);
} else {
assert(0 && "Unsupported switch case statement");
}
}

/// Returns top switch statement in the `SwitchStack` of the given
/// Function Scope.
static SwitchStmt* getTopSwitchStmtOfSwitchStack(sema::FunctionScopeInfo* FSI) {
Expand Down Expand Up @@ -1674,7 +1664,7 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
// been processed aka when all the statments in switch statement body
// have been processed.
if (activeSC) {
setSwitchCaseSubStmt(activeSC, endBlock());
utils::SetSwitchCaseSubStmt(activeSC, endBlock());
endScope();
activeSC = nullptr;
}
Expand Down Expand Up @@ -1702,7 +1692,7 @@ BaseForwardModeVisitor::DeriveSwitchStmtBodyHelper(const Stmt* stmt,
// corresponding to the active switch case label, and update its
// substatement.
if (activeSC) {
setSwitchCaseSubStmt(activeSC, endBlock());
utils::SetSwitchCaseSubStmt(activeSC, endBlock());
endScope();
}
// sub statement will be updated later, either when the corresponding
Expand Down
7 changes: 7 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,5 +632,12 @@ namespace clad {
finder.TraverseStmt(const_cast<Stmt*>(S));
return finder.hasCallExpr;
}

void SetSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) {
if (auto* caseStmt = dyn_cast<CaseStmt>(SC))
caseStmt->setSubStmt(subStmt);
else
cast<DefaultStmt>(SC)->setSubStmt(subStmt);
}
} // namespace utils
} // namespace clad
219 changes: 209 additions & 10 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ConstantFolder.h"

#include "TBRAnalyzer.h"
#include "clad/Differentiator/DerivativeBuilder.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/ExternalRMVSource.h"
Expand All @@ -17,7 +18,9 @@

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Basic/TokenKinds.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Overload.h"
#include "clang/Sema/Scope.h"
Expand Down Expand Up @@ -3258,6 +3261,211 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {forwardDS, reverseBlock};
}

// Basic idea used for differentiating switch statement is that in the reverse
// pass, processing of the differentiated statments of the switch statement
// body should start either from a `break` statement or from the last
// statement of the switch statement body and always end at a switch
// case/default statement.
//
// Therefore, here we keep track of which `break` was hit in the forward pass,
// or if we no `break` statement was hit at all in a variable or clad tape.
// This information is further used by an auxilliary switch statement in the
// reverse pass to jump the execution to the correct point (that is,
// differentiated statement of the statement just before the `break` statement
// that was hit in the forward pass)
StmtDiff ReverseModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
// Scope and blocks for the compound statement that encloses the switch
// statement in both the forward and the reverse pass. Block is required
// for handling condition variable and switch-init statement.
beginScope(Scope::DeclScope);
beginBlock(direction::forward);
beginBlock(direction::reverse);

// Handles switch init statement
if (SS->getInit()) {
StmtDiff switchInitDiff = DifferentiateSingleStmt(SS->getInit());
addToCurrentBlock(switchInitDiff.getStmt(), direction::forward);
addToCurrentBlock(switchInitDiff.getStmt_dx(), direction::reverse);
}

// Handles condition variable
if (SS->getConditionVariable()) {
StmtDiff condVarDiff =
DifferentiateSingleStmt(SS->getConditionVariableDeclStmt());
addToCurrentBlock(condVarDiff.getStmt(), direction::forward);
addToCurrentBlock(condVarDiff.getStmt_dx(), direction::reverse);
}

StmtDiff condDiff = DifferentiateSingleStmt(SS->getCond());
addToCurrentBlock(condDiff.getStmt(), direction::forward);
addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse);
Expr* condExpr = nullptr;
clad_compat::llvm_Optional<CladTapeResult> condTape;

if (isInsideLoop) {
// If we are inside a loop, condition will be stored and used as follows:
//
// forward block:
// switch (clad::push(..., cond)) { ... }
//
// reverse block:
// switch (...) { ... }
// clad::pop(...);
condTape.emplace(MakeCladTapeFor(condDiff.getExpr(), "_cond"));
condExpr = condTape->Push;
} else {
condExpr = GlobalStoreAndRef(condDiff.getExpr(), "_cond").getExpr();
}

auto* activeBreakContHandler = PushBreakContStmtHandler(
/*forSwitchStmt=*/true);
activeBreakContHandler->BeginCFSwitchStmtScope();
auto* SSData = PushSwitchStmtInfo();

if (isInsideLoop)
SSData->switchStmtCond = condTape->Last();
else
SSData->switchStmtCond = condExpr;

// scope for the switch statement body.
beginScope(Scope::DeclScope);

const Stmt* body = SS->getBody();
StmtDiff bodyDiff = nullptr;
if (isa<CompoundStmt>(body))
bodyDiff = Visit(body);
else
bodyDiff = DifferentiateSingleStmt(body);

// Each switch case statement of the original function gets transformed to
// an if condition in the reverse pass. The if condition decides at runtime
// whether the processing of the differentiated statements of the switch
// statement body should stop or continue. This is based on the fact that
// processing of statements of switch statement body always starts at a case
// statement. For example,
// ```
// case 3:
// ```
// gets transformed to,
//
// ```
// if (3 == _cond)
// break;
// ```
//
// This kind of if expression cannot by easily formed for the default
// statement, thus, we instead compare value of the switch condition with
// the values of all the case statements to determine if the default
// statement was selected in the forward pass.
// Therefore,
//
// ```
// default:
// ```
//
// will get transformed to something like,
//
// ```
// if (_cond != 1 && _cond != 2 && _cond != 3)
// break;
// ```
if (SSData->defaultIfBreakExpr) {
Expr* breakCond = nullptr;
for (auto* SC : SSData->cases) {
if (auto* CS = dyn_cast<CaseStmt>(SC)) {
if (breakCond) {
breakCond = BuildOp(BinaryOperatorKind::BO_LAnd, breakCond,
BuildOp(BinaryOperatorKind::BO_NE,
SSData->switchStmtCond, CS->getLHS()));
} else {
breakCond = BuildOp(BinaryOperatorKind::BO_NE,
SSData->switchStmtCond, CS->getLHS());
}
}
}
if (!breakCond)
breakCond = m_Sema.ActOnCXXBoolLiteral(noLoc, tok::kw_true).get();
SSData->defaultIfBreakExpr->setCond(breakCond);
}

activeBreakContHandler->EndCFSwitchStmtScope();

// If switch statement contains no cases, then, no statement of the switch
// statement body will be processed in both the forward and the reverse
// pass. Thus, we do not need to add them in the differentiated function.
if (!(SSData->cases.empty())) {
Sema::ConditionResult condRes = m_Sema.ActOnCondition(
getCurrentScope(), noLoc, condExpr, Sema::ConditionKind::Switch);
SwitchStmt* forwardSS =
clad_compat::Sema_ActOnStartOfSwitchStmt(m_Sema, nullptr, condRes)
.getAs<SwitchStmt>();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);

// Registers all the cases to the switch statement.
for (auto* SC : SSData->cases)
forwardSS->addSwitchCase(SC);

forwardSS =
m_Sema.ActOnFinishSwitchStmt(noLoc, forwardSS, bodyDiff.getStmt())
.getAs<SwitchStmt>();

addToCurrentBlock(forwardSS, direction::forward);
if (isInsideLoop)
addToCurrentBlock(condTape->Pop, direction::reverse);
addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse);
}

PopBreakContStmtHandler();
PopSwitchStmtInfo();
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::VisitCaseStmt(const CaseStmt* CS) {
beginBlock(direction::forward);
beginBlock(direction::reverse);
SwitchStmtInfo* SSData = GetActiveSwitchStmtInfo();

Expr* lhsClone = (CS->getLHS() ? Clone(CS->getLHS()) : nullptr);
Expr* rhsClone = (CS->getRHS() ? Clone(CS->getRHS()) : nullptr);

auto* newSC = clad_compat::CaseStmt_Create(m_Sema.getASTContext(), lhsClone,
rhsClone, noLoc, noLoc, noLoc);

Expr* ifCond = BuildOp(BinaryOperatorKind::BO_EQ, newSC->getLHS(),
SSData->switchStmtCond);
Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();
Stmt* ifBreakExpr = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, ifCond, noLoc, noLoc, ifThen,
noLoc, nullptr);
SSData->cases.push_back(newSC);
addToCurrentBlock(ifBreakExpr, direction::reverse);
addToCurrentBlock(newSC, direction::forward);
auto diff = DifferentiateSingleStmt(CS->getSubStmt());
utils::SetSwitchCaseSubStmt(newSC, diff.getStmt());
addToCurrentBlock(diff.getStmt_dx(), direction::reverse);
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::VisitDefaultStmt(const DefaultStmt* DS) {
beginBlock(direction::reverse);
beginBlock(direction::forward);
auto* SSData = GetActiveSwitchStmtInfo();
auto* newDefaultStmt =
new (m_Sema.getASTContext()) DefaultStmt(noLoc, noLoc, nullptr);
Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();
Stmt* ifBreakExpr = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, nullptr, noLoc, noLoc,
ifThen, noLoc, nullptr);
SSData->cases.push_back(newDefaultStmt);
SSData->defaultIfBreakExpr = cast<IfStmt>(ifBreakExpr);
addToCurrentBlock(ifBreakExpr, direction::reverse);
addToCurrentBlock(newDefaultStmt, direction::forward);
auto diff = DifferentiateSingleStmt(DS->getSubStmt());
utils::SetSwitchCaseSubStmt(newDefaultStmt, diff.getStmt());
addToCurrentBlock(diff.getStmt_dx(), direction::reverse);
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body,
LoopCounter& loopCounter,
Stmt* condVarDiff,
Expand Down Expand Up @@ -3401,10 +3609,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

CaseStmt* ReverseModeVisitor::BreakContStmtHandler::GetNextCFCaseStmt() {
// End scope for currenly active case statement, if any.
if (!m_SwitchCases.empty())
m_RMV.endScope();

++m_CaseCounter;
auto* counterLiteral = CreateSizeTLiteralExpr(m_CaseCounter);
CaseStmt* CS = clad_compat::CaseStmt_Create(m_RMV.m_Context, counterLiteral,
Expand All @@ -3418,8 +3622,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// corresponding next statements.
CS->setSubStmt(m_RMV.m_Sema.ActOnNullStmt(noLoc).get());

// begin scope for the new active switch case statement.
m_RMV.beginScope(Scope::DeclScope);
m_SwitchCases.push_back(CS);
return CS;
}
Expand All @@ -3433,12 +3635,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks(
StmtDiff& bodyDiff) {
if (m_SwitchCases.empty())
if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt)
return;

// end scope for last switch case.
m_RMV.endScope();

// Add case statement in the beginning of the reverse block
// and corresponding push expression for this case statement
// at the end of the forward block to cover the case when no
Expand Down
Loading

0 comments on commit 16a652b

Please sign in to comment.