Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Redundant Goto Statements #615

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "clad/Differentiator/VisitorBase.h"
#include "clad/Differentiator/ReverseModeVisitorDirectionKinds.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"
Expand Down Expand Up @@ -92,6 +93,10 @@ namespace clad {
// Function to Differentiate with Enzyme as Backend
void DifferentiateWithEnzyme();

// Whether Stmt is Return and not inside any block;
bool OnlyReturn = false;
int CCount = 0;

public:
using direction = rmv::direction;
clang::Expr* dfdx() {
Expand Down
26 changes: 20 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
beginScope(Scope::DeclScope);
beginBlock(direction::forward);
beginBlock(direction::reverse);
CCount++;
for (Stmt* S : CS->body()) {
if(CCount==1&&isa<ReturnStmt>(S))
OnlyReturn=true;
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingStmtInVisitCompoundStmt();
StmtDiff SDiff = DifferentiateSingleStmt(S);
Expand All @@ -756,6 +759,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CompoundStmt* Forward = endBlock(direction::forward);
CompoundStmt* Reverse = endBlock(direction::reverse);
endScope();
CCount--;
return StmtDiff(Forward, Reverse);
}

Expand Down Expand Up @@ -1153,14 +1157,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If the original function returns at this point, some part of the reverse
// pass (corresponding to other branches that do not return here) must be
// skipped. We create a label in the reverse pass and jump to it via goto.
LabelDecl* LD = LabelDecl::Create(
m_Context, m_Sema.CurContext, noLoc, CreateUniqueIdentifier("_label"));
m_Sema.PushOnScopeChains(LD, m_DerivativeFnScope, true);
LabelDecl* LD = nullptr;
if (!OnlyReturn) {
LD = LabelDecl::Create(
m_Context, m_Sema.CurContext, noLoc, CreateUniqueIdentifier("_label"));
m_Sema.PushOnScopeChains(LD, m_DerivativeFnScope, true);
}
// Attach label to the last Stmt in the corresponding Reverse Stmt.
if (!Reverse)
Reverse = m_Sema.ActOnNullStmt(noLoc).get();
Stmt* LS = m_Sema.ActOnLabelStmt(noLoc, LD, noLoc, Reverse).get();
addToCurrentBlock(LS, direction::reverse);
if (!OnlyReturn) {
Stmt* LS = m_Sema.ActOnLabelStmt(noLoc, LD, noLoc, Reverse).get();
addToCurrentBlock(LS, direction::reverse);
}else {
addToCurrentBlock(Reverse, direction::reverse);
}
for (Stmt* S : cast<CompoundStmt>(ReturnDiff.getStmt())->body())
addToCurrentBlock(S, direction::forward);

Expand All @@ -1175,7 +1186,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

// Create goto to the label.
return m_Sema.ActOnGotoStmt(noLoc, noLoc, LD).get();
if (!OnlyReturn)
return m_Sema.ActOnGotoStmt(noLoc, noLoc, LD).get();

return nullptr;
}

StmtDiff ReverseModeVisitor::VisitParenExpr(const ParenExpr* PE) {
Expand Down
18 changes: 0 additions & 18 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ double addArr(double *arr, int n) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: ret += arr[clad::push(_t1, i)];
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_ret += _d_y;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
Expand All @@ -45,8 +43,6 @@ double f(double *arr) {
//CHECK: void f_grad(double *arr, clad::array_ref<double> _d_arr) {
//CHECK-NEXT: double *_t0;
//CHECK-NEXT: _t0 = arr;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(_t0, 3, 1, _d_arr, &_grad1);
Expand Down Expand Up @@ -82,8 +78,6 @@ float func(float* a, float* b) {
//CHECK-NEXT: _ref0 *= clad::push(_t3, b[clad::push(_t5, i)]);
//CHECK-NEXT: sum += a[clad::push(_t7, i)];
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
Expand Down Expand Up @@ -113,8 +107,6 @@ float helper(float x) {
// CHECK: void helper_pullback(float x, float _d_y, clad::array_ref<float> _d_x) {
// CHECK-NEXT: float _t0;
// CHECK-NEXT: _t0 = x;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = _d_y * _t0;
// CHECK-NEXT: float _r1 = 2 * _d_y;
Expand All @@ -141,8 +133,6 @@ float func2(float* a) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: sum += helper(clad::push(_t3, a[clad::push(_t1, i)]));
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: float _r_d0 = _d_sum;
Expand Down Expand Up @@ -175,8 +165,6 @@ float func3(float* a, float* b) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: sum += (a[clad::push(_t1, i)] += b[clad::push(_t3, i)]);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: float _r_d0 = _d_sum;
Expand Down Expand Up @@ -221,8 +209,6 @@ double func4(double x) {
//CHECK-NEXT: clad::push(_t4, arr , 3UL);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t3; _t3--) {
//CHECK-NEXT: {
Expand Down Expand Up @@ -286,8 +272,6 @@ double func5(int k) {
//CHECK-NEXT: clad::push(_t4, arr , n);
//CHECK-NEXT: sum += addArr(arr, clad::push(_t5, n));
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t3; _t3--) {
//CHECK-NEXT: {
Expand Down Expand Up @@ -339,8 +323,6 @@ double func6(double seed) {
//CHECK-NEXT: clad::push(_t3, arr , 3UL);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
Expand Down
4 changes: 0 additions & 4 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ double const_dot_product(double x, double y, double z) {
//CHECK-NEXT: _t2 = consts[1];
//CHECK-NEXT: _t5 = vars[2];
//CHECK-NEXT: _t4 = consts[2];
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 1 * _t0;
//CHECK-NEXT: _d_vars[0] += _r0;
Expand Down Expand Up @@ -193,8 +191,6 @@ double const_matmul_sum(double a, double b, double c, double d) {
//: _t15 = A[1][1];
//: _t14 = B[1][1];
//: double C[2][2] = {{[{][{]}}_t1 * _t0 + _t3 * _t2, _t5 * _t4 + _t7 * _t6}, {_t9 * _t8 + _t11 * _t10, _t13 * _t12 + _t15 * _t14}};
//: goto _label0;
//: _label0:
//: {
//: _d_C[0][0] += 1;
//: _d_C[0][1] += 1;
Expand Down
2 changes: 0 additions & 2 deletions test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ auto gauss_g = clad::gradient(gauss, "p");
//CHECK-NEXT: _t22 = _t20 * _t17;
//CHECK-NEXT: _t23 = t;
//CHECK-NEXT: _t16 = std::exp(_t23);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r8 = 1 * _t16;
//CHECK-NEXT: double _r9 = _r8 * _t17;
Expand Down
2 changes: 0 additions & 2 deletions test/Enzyme/DifferentCladEnzymeDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ double foo(double x, double y){
// CHECK-NEXT: double _t1;
// CHECK-NEXT: _t1 = x;
// CHECK-NEXT: _t0 = y;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 1 * _t0;
// CHECK-NEXT: * _d_x += _r0;
Expand Down
14 changes: 0 additions & 14 deletions test/ErrorEstimation/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ float func(float x, float y) {
//CHECK-NEXT: x = x + y;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: y = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_y += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d1 = * _d_y;
Expand Down Expand Up @@ -61,8 +59,6 @@ float func2(float x, int y) {
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: x = _t1 * _t0 + _t3 * _t2;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = * _d_x;
Expand All @@ -89,8 +85,6 @@ float func3(int x, int y) {

//CHECK: void func3_grad(int x, int y, clad::array_ref<int> _d_x, clad::array_ref<int> _d_y, double &_final_error) {
//CHECK-NEXT: x = y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_y += 1;
//CHECK-NEXT: {
//CHECK-NEXT: int _r_d0 = * _d_x;
Expand All @@ -117,8 +111,6 @@ float func4(float x, float y) {
//CHECK-NEXT: _EERepl_z0 = z;
//CHECK-NEXT: x = z + y;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = * _d_x;
Expand Down Expand Up @@ -149,8 +141,6 @@ float func5(float x, float y) {
//CHECK-NEXT: int z = 56;
//CHECK-NEXT: x = z + y;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = * _d_x;
Expand All @@ -169,8 +159,6 @@ float func5(float x, float y) {
float func6(float x) { return x; }

//CHECK: void func6_grad(float x, clad::array_ref<float> _d_x, double &_final_error) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: double _delta_x = 0;
//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}});
Expand All @@ -186,8 +174,6 @@ float func7(float x, float y) { return (x * y); }
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: _t0 = y;
//CHECK-NEXT: _ret_value0 = (_t1 * _t0);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r0 = 1 * _t0;
//CHECK-NEXT: * _d_x += _r0;
Expand Down
22 changes: 0 additions & 22 deletions test/ErrorEstimation/BasicOps.C
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ float func(float x, float y) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: float z = _t1 * _t0;
//CHECK-NEXT: _EERepl_z0 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r0 = _d_z * _t0;
Expand Down Expand Up @@ -95,8 +93,6 @@ float func2(float x, float y) {
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: float z = _t3 / _t2;
//CHECK-NEXT: _EERepl_z0 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r2 = _d_z / _t2;
Expand Down Expand Up @@ -164,8 +160,6 @@ float func3(float x, float y) {
//CHECK-NEXT: float t = _t5 * _t2;
//CHECK-NEXT: _EERepl_t0 = t;
//CHECK-NEXT: _EERepl_y1 = y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_t += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r2 = _d_t * _t2;
Expand Down Expand Up @@ -210,8 +204,6 @@ float func4(float x, float y) { return std::pow(x, y); }
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: _t1 = y;
//CHECK-NEXT: _ret_value0 = std::pow(_t0, _t1);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _grad0 = 0.F;
//CHECK-NEXT: float _grad1 = 0.F;
Expand Down Expand Up @@ -248,8 +240,6 @@ float func5(float x, float y) {
//CHECK-NEXT: _t2 = y;
//CHECK-NEXT: _t1 = y;
//CHECK-NEXT: _ret_value0 = _t2 * _t1;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r1 = 1 * _t1;
//CHECK-NEXT: * _d_y += _r1;
Expand Down Expand Up @@ -280,8 +270,6 @@ double helper(double x, double y) { return x * y; }
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: _t0 = y;
//CHECK-NEXT: _ret_value0 = _t1 * _t0;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y0 * _t0;
//CHECK-NEXT: * _d_x += _r0;
Expand Down Expand Up @@ -316,8 +304,6 @@ float func6(float x, float y) {
//CHECK-NEXT: _t4 = z;
//CHECK-NEXT: _t3 = z;
//CHECK-NEXT: _ret_value0 = _t4 * _t3;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r2 = 1 * _t3;
//CHECK-NEXT: _d_z += _r2;
Expand Down Expand Up @@ -350,8 +336,6 @@ float func7(float x) {
//CHECK: void func7_grad(float x, clad::array_ref<float> _d_x, double &_final_error) {
//CHECK-NEXT: int _d_z = 0;
//CHECK-NEXT: int z = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: _d_z += 1;
Expand All @@ -372,8 +356,6 @@ double helper2(float& x) { return x * x; }
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: _ret_value0 = _t1 * _t0;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y * _t0;
//CHECK-NEXT: * _d_x += _r0;
Expand All @@ -400,8 +382,6 @@ float func8(float x, float y) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: z = y + helper2(x);
//CHECK-NEXT: _EERepl_z1 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = _d_z;
Expand Down Expand Up @@ -449,8 +429,6 @@ float func9(float x, float y) {
//CHECK-NEXT: _t5 = helper2(y);
//CHECK-NEXT: z += _t8 * _t5;
//CHECK-NEXT: _EERepl_z1 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = _d_z;
Expand Down
6 changes: 0 additions & 6 deletions test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ float func(float x, float y) {
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: _ret_value0 = x + y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: * _d_y += 1;
Expand Down Expand Up @@ -160,8 +158,6 @@ float func3(float x, float y) { return x > 30 ? x * y : x + y; }
//CHECK-NEXT: _t0 = y;
//CHECK-NEXT: }
//CHECK-NEXT: _ret_value0 = _cond0 ? _t1 * _t0 : x + y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: float _r0 = 1 * _t0;
//CHECK-NEXT: * _d_x += _r0;
Expand Down Expand Up @@ -207,8 +203,6 @@ float func4(float x, float y) {
//CHECK-NEXT: _t3 = y;
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: _ret_value0 = _t3 / _t2;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r1 = 1 / _t2;
//CHECK-NEXT: * _d_y += _r1;
Expand Down
Loading
Loading