Skip to content

Commit

Permalink
Use original types for adjoint parameters in RMFPV.
Browse files Browse the repository at this point in the history
This commit simplifies how we pass adjoints to ``_forw`` functions. The types used to be converted from ref-types to the corresponding pointer types. There's no need to change the type because the function represents the forward pass. Moreover, non-ref types, which require ``_forw``, were not handed properly.
  • Loading branch information
PetroZarytskyi committed Aug 19, 2024
1 parent 3f5bfd0 commit 17cdf75
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 53 deletions.
2 changes: 0 additions & 2 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
ComputeParamTypes(const DiffParams& diffParams);
clang::QualType ComputeReturnType();
llvm::SmallVector<clang::ParmVarDecl*, 8> BuildParams(DiffParams& diffParams);
clang::QualType GetParameterDerivativeType(clang::QualType yType,
clang::QualType xType);

public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder,
Expand Down
25 changes: 3 additions & 22 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,45 +97,26 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
return DerivativeAndOverload{m_Derivative, nullptr};
}

// FIXME: This function is copied from ReverseModeVisitor. Find a suitable place
// for it.
QualType
ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {
QualType xValueType = utils::GetValueType(xType);
// derivative variables should always be of non-const type.
xValueType.removeLocalConst();
QualType nonRefXValueType = xValueType.getNonReferenceType();
if (nonRefXValueType->isRealType())
return m_Context.getPointerType(yType);
return m_Context.getPointerType(nonRefXValueType);
}

llvm::SmallVector<clang::QualType, 8>
ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) {
llvm::SmallVector<clang::QualType, 8> paramTypes;
paramTypes.reserve(m_DiffReq->getNumParams() * 2);
for (auto* PVD : m_DiffReq->parameters())
paramTypes.push_back(PVD->getType());

QualType effectiveReturnType =
m_DiffReq->getReturnType().getNonReferenceType();

if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
const CXXRecordDecl* RD = MD->getParent();
if (MD->isInstance() && !RD->isLambda()) {
QualType thisType = MD->getThisType();
paramTypes.push_back(
GetParameterDerivativeType(effectiveReturnType, thisType));
paramTypes.push_back(thisType);
}
}

for (auto* PVD : m_DiffReq->parameters()) {
const auto* it =
std::find(std::begin(diffParams), std::end(diffParams), PVD);
if (it != std::end(diffParams)) {
paramTypes.push_back(
GetParameterDerivativeType(effectiveReturnType, PVD->getType()));
paramTypes.push_back(PVD->getType());
}
}
return paramTypes;
Expand Down Expand Up @@ -204,7 +185,7 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
m_Sema.PushOnScopeChains(dPVD, getCurrentScope(),
/*AddToContext=*/false);
m_Variables[*it] =
BuildOp(UO_Deref, BuildDeclRef(dPVD), m_DiffReq->getLocation());
BuildDeclRef(dPVD), m_DiffReq->getLocation();
}
}
params.insert(params.end(), paramDerivatives.begin(), paramDerivatives.end());
Expand Down
15 changes: 1 addition & 14 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2047,21 +2047,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
const ParmVarDecl* PVD =
FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff = Visit(arg);
if ((argDiff.getExpr_dx() != nullptr) &&
PVD->getType()->isReferenceType()) {
Expr* derivedArg = argDiff.getExpr_dx();
// FIXME: We may need this if-block once we support pointers, and
// passing pointers-by-reference if
// (isCladArrayType(derivedArg->getType()))
// CallArgs.push_back(derivedArg);
// else
CallArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, Loc));
} else
CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(Loc).get());
CallArgs.push_back(argDiff.getExpr_dx());
}
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
Expand Down
20 changes: 11 additions & 9 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -256,25 +256,23 @@ double fn7(double i, double j) {

// CHECK: void identity_pullback(double &i, double _d_y, double *_d_i);

// CHECK: clad::ValueAndAdjoint<double &, double &> identity_forw(double &i, double *_d_i);
// CHECK: clad::ValueAndAdjoint<double &, double &> identity_forw(double &i, double &_d_i);

// CHECK: void custom_identity_pullback(double &i, double _d_y, double *_d_i);

// CHECK: clad::ValueAndAdjoint<double &, double &> custom_identity_forw(double &i, double *d_i) {
// CHECK-NEXT: return {i, *d_i};
// CHECK-NEXT: }
// CHECK: clad::ValueAndAdjoint<double &, double &> custom_identity_forw(double &i, double &_d_i);

// CHECK: void fn7_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _t0 = i;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = identity_forw(i, &*_d_i);
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = identity_forw(i, *_d_i);
// CHECK-NEXT: double &_d_k = _t1.adjoint;
// CHECK-NEXT: double &k = _t1.value;
// CHECK-NEXT: double _t2 = j;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t3 = identity_forw(j, &*_d_j);
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t3 = identity_forw(j, *_d_j);
// CHECK-NEXT: double &_d_l = _t3.adjoint;
// CHECK-NEXT: double &l = _t3.value;
// CHECK-NEXT: double _t4 = i;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t5 = custom_identity_forw(i, &*_d_i);
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t5 = custom_identity_forw(i, *_d_i);
// CHECK-NEXT: double &_d_temp = _t5.adjoint;
// CHECK-NEXT: double &temp = _t5.value;
// CHECK-NEXT: double _t6 = k;
Expand Down Expand Up @@ -927,19 +925,23 @@ double sq_defined_later(double x) {
// CHECK-NEXT: *_d_i += _d__d_i;
// CHECK-NEXT: }

// CHECK: clad::ValueAndAdjoint<double &, double &> identity_forw(double &i, double *_d_i) {
// CHECK: clad::ValueAndAdjoint<double &, double &> identity_forw(double &i, double &_d_i) {
// CHECK-NEXT: MyStruct::myFunction();
// CHECK-NEXT: double _d__d_i = 0;
// CHECK-NEXT: double _d_i0 = i;
// CHECK-NEXT: double _t0 = _d_i0;
// CHECK-NEXT: _d_i0 += 1;
// CHECK-NEXT: return {i, *_d_i};
// CHECK-NEXT: return {i, _d_i};
// CHECK-NEXT: }

// CHECK: void custom_identity_pullback(double &i, double _d_y, double *_d_i) {
// CHECK-NEXT: *_d_i += _d_y;
// CHECK-NEXT: }

// CHECK: clad::ValueAndAdjoint<double &, double &> custom_identity_forw(double &i, double &_d_i) {
// CHECK-NEXT: return {i, _d_i};
// CHECK-NEXT: }

// CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, double *_d_x, char *_d_c, char *_d_s) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: double _d_cond0;
Expand Down
12 changes: 6 additions & 6 deletions test/Gradient/MemberFunctions.C
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,11 @@ double fn2(SimpleFunctions& sf, double i) {

// CHECK: void ref_mem_fn_pullback(double i, double _d_y, SimpleFunctions *_d_this, double *_d_i);

// CHECK: clad::ValueAndAdjoint<double &, double &> ref_mem_fn_forw(double i, SimpleFunctions *_d_this, double *_d_i);
// CHECK: clad::ValueAndAdjoint<double &, double &> ref_mem_fn_forw(double i, SimpleFunctions *_d_this, double _d_i);

// CHECK: void fn2_grad(SimpleFunctions &sf, double i, SimpleFunctions *_d_sf, double *_d_i) {
// CHECK-NEXT: SimpleFunctions _t0 = sf;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = _t0.ref_mem_fn_forw(i, &(*_d_sf), nullptr);
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = _t0.ref_mem_fn_forw(i, &(*_d_sf), *_d_i);
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: _t0.ref_mem_fn_pullback(i, 1, &(*_d_sf), &_r0);
Expand All @@ -456,11 +456,11 @@ double fn5(SimpleFunctions& v, double value) {

// CHECK: void operator_plus_equal_pullback(double value, SimpleFunctions _d_y, SimpleFunctions *_d_this, double *_d_value);

// CHECK: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> operator_plus_equal_forw(double value, SimpleFunctions *_d_this, SimpleFunctions *_d_value);
// CHECK: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> operator_plus_equal_forw(double value, SimpleFunctions *_d_this, double _d_value);

// CHECK: void fn5_grad(SimpleFunctions &v, double value, SimpleFunctions *_d_v, double *_d_value) {
// CHECK-NEXT: SimpleFunctions _t0 = v;
// CHECK-NEXT: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> _t1 = _t0.operator_plus_equal_forw(value, &(*_d_v), nullptr);
// CHECK-NEXT: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> _t1 = _t0.operator_plus_equal_forw(value, &(*_d_v), *_d_value);
// CHECK-NEXT: (*_d_v).x += 1;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
Expand Down Expand Up @@ -602,7 +602,7 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: clad::ValueAndAdjoint<double &, double &> ref_mem_fn_forw(double i, SimpleFunctions *_d_this, double *_d_i) {
// CHECK: clad::ValueAndAdjoint<double &, double &> ref_mem_fn_forw(double i, SimpleFunctions *_d_this, double _d_i) {
// CHECK-NEXT: double _t0 = this->x;
// CHECK-NEXT: this->x = +i;
// CHECK-NEXT: double _t1 = this->x;
Expand All @@ -620,7 +620,7 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> operator_plus_equal_forw(double value, SimpleFunctions *_d_this, SimpleFunctions *_d_value) {
// CHECK: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> operator_plus_equal_forw(double value, SimpleFunctions *_d_this, double _d_value) {
// CHECK-NEXT: double _t0 = this->x;
// CHECK-NEXT: this->x += value;
// CHECK-NEXT: return {*this, (*_d_this)};
Expand Down

0 comments on commit 17cdf75

Please sign in to comment.