Skip to content

Commit

Permalink
Do not assume index of derivedFn and code parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
MihailMihov committed Oct 3, 2024
1 parent a91260f commit 293b44d
Showing 1 changed file with 43 additions and 27 deletions.
70 changes: 43 additions & 27 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,21 @@ namespace clad {

FunctionDecl* replacementFD = OverloadedFD ? OverloadedFD : FD;

auto codeArgIdx = -1;
auto derivedFnArgIdx = -1;
auto idx = 0;
for(auto* arg : call->arguments()) {
if(auto *default_arg_expr = dyn_cast<CXXDefaultArgExpr>(arg)) {
std::string argName = default_arg_expr->getParam()->getNameAsString();
if(argName == "derivedFn") {
derivedFnArgIdx = idx;
} else if(argName == "code") {
codeArgIdx = idx;
}
}
++idx;
}

// Index of "CUDAkernel" parameter:
int numArgs = static_cast<int>(call->getNumArgs());
if (numArgs > 4) {
Expand All @@ -203,8 +218,6 @@ namespace clad {
call->setArg(kernelArgIdx, cudaKernelFlag);
numArgs--;
}
auto codeArgIdx = numArgs - 1;
auto derivedFnArgIdx = numArgs - 2;

// Create ref to generated FD.
DeclRefExpr* DRE =
Expand All @@ -220,31 +233,34 @@ namespace clad {
if (isa<CXXMethodDecl>(DRE->getDecl()))
DRE->setValueKind(CLAD_COMPAT_ExprValueKind_R_or_PR_Value);

// Add the "&" operator
auto newUnOp =
SemaRef.BuildUnaryOp(nullptr, noLoc, UnaryOperatorKind::UO_AddrOf, DRE)
.get();
call->setArg(derivedFnArgIdx, newUnOp);

// Update the code parameter.
if (CXXDefaultArgExpr* Arg
= dyn_cast<CXXDefaultArgExpr>(call->getArg(codeArgIdx))) {
clang::LangOptions LangOpts;
LangOpts.CPlusPlus = true;
clang::PrintingPolicy Policy(LangOpts);
Policy.Bool = true;

std::string s;
llvm::raw_string_ostream Out(s);
FD->print(Out, Policy);
Out.flush();

StringLiteral* SL = utils::CreateStringLiteral(C, Out.str());
Expr* newArg =
SemaRef.ImpCastExprToType(SL,
Arg->getType(),
CK_ArrayToPointerDecay).get();
call->setArg(codeArgIdx, newArg);
if(derivedFnArgIdx != -1) {
// Add the "&" operator
auto newUnOp =
SemaRef.BuildUnaryOp(nullptr, noLoc, UnaryOperatorKind::UO_AddrOf, DRE)
.get();
call->setArg(derivedFnArgIdx, newUnOp);
}

// Update the code parameter if it was found.
if (codeArgIdx != -1) {
if (auto* Arg = dyn_cast<CXXDefaultArgExpr>(call->getArg(codeArgIdx))) {
clang::LangOptions LangOpts;
LangOpts.CPlusPlus = true;
clang::PrintingPolicy Policy(LangOpts);
Policy.Bool = true;

std::string s;
llvm::raw_string_ostream Out(s);
FD->print(Out, Policy);
Out.flush();

StringLiteral* SL = utils::CreateStringLiteral(C, Out.str());
Expr* newArg =
SemaRef.ImpCastExprToType(SL,
Arg->getType(),
CK_ArrayToPointerDecay).get();
call->setArg(codeArgIdx, newArg);
}
}
}

Expand Down

0 comments on commit 293b44d

Please sign in to comment.