Skip to content

Commit

Permalink
Add location information and improve clarity of diagnostics.
Browse files Browse the repository at this point in the history
This patch is a first step towards diagnostics refactoring in the context of
non-differentiable propagators.
  • Loading branch information
vgvassilev committed Aug 6, 2024
1 parent 33d9441 commit fc64644
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 61 deletions.
15 changes: 5 additions & 10 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,14 +394,12 @@ namespace clad {
/// to avoid recomputation.
static bool UsefulToStore(clang::Expr* E);
/// A flag for silencing warnings/errors output by diag function.
bool silenceDiags = false;
/// Shorthand to issues a warning or error.
template <std::size_t N>
void diag(clang::DiagnosticsEngine::Level level, // Warning or Error
clang::SourceLocation loc, const char (&format)[N],
llvm::ArrayRef<llvm::StringRef> args = {}) {
if (!silenceDiags)
m_Builder.diag(level, loc, format, args);
m_Builder.diag(level, loc, format, args);
}

/// Creates unique identifier of the form "_nameBase<number>" that is
Expand Down Expand Up @@ -584,17 +582,14 @@ namespace clad {
clang::Expr* GetSingleArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos,
unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args);

/// Emits diagnostic messages on differentiation (or lack thereof) for
/// call expressions.
///
/// \param[in] \c funcName The name of the underlying function of the
/// call expression.
/// \param[in] \c FD - The function declaration.
/// \param[in] \c srcLoc Any associated source location information.
/// \param[in] \c isDerived A flag to determine if differentiation of the
/// call expression was successful.
void CallExprDiffDiagnostics(llvm::StringRef funcName,
clang::SourceLocation srcLoc,
bool isDerived);
void CallExprDiffDiagnostics(const clang::FunctionDecl* FD,
clang::SourceLocation srcLoc);

clang::QualType DetermineCladArrayValueType(clang::QualType T);

Expand Down
3 changes: 1 addition & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ DerivativeAndOverload
BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
assert(m_DiffReq == request && "Can't pass two different requests!");
silenceDiags = !request.VerboseDiags;
m_Functor = request.Functor;
assert(m_DiffReq.Mode == DiffMode::forward);
assert(!m_DerivativeInFlight &&
Expand Down Expand Up @@ -1331,7 +1330,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
GetSingleArgCentralDiffCall(fnCallee, CallArgs[0],
/*targetPos=*/0, /*numArgs=*/1, CallArgs);
}
CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(), callDiff);
CallExprDiffDiagnostics(FD, CE->getBeginLoc());
if (!callDiff) {
auto zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down
9 changes: 1 addition & 8 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
}
if (!NSD) {
NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist);
if (!forCustomDerv && !NSD) {
diag(DiagnosticsEngine::Warning, noLoc,
"Numerical differentiation is diabled using the "
"-DCLAD_NO_NUM_DIFF "
"flag, this means that every try to numerically differentiate a "
"function will fail! Remove the flag to revert to default "
"behaviour.");
if (!NSD)
return R;
}
}
DeclContext* DC = NSD;

Expand Down
1 change: 0 additions & 1 deletion lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ DerivativeAndOverload
ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
assert(m_DiffReq == request);
silenceDiags = !request.VerboseDiags;

assert(m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass);

Expand Down
5 changes: 1 addition & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const DiffRequest& request) {
if (m_ExternalSource)
m_ExternalSource->ActOnStartOfDerive();
silenceDiags = !request.VerboseDiags;
assert(m_DiffReq == request);

// FIXME: reverse mode plugins may have request mode other than
Expand Down Expand Up @@ -479,7 +478,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// for the two 'Derive's being different functions.
if (m_ExternalSource)
m_ExternalSource->ActOnStartOfDerive();
silenceDiags = !request.VerboseDiags;
// FIXME: We should not use const_cast to get the decl request here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<DiffRequest&>(m_DiffReq) = request;
Expand Down Expand Up @@ -1943,8 +1941,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts,
DerivedCallArgs, CallArgDx);
}
CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(),
OverloadedDerivedFn);
CallExprDiffDiagnostics(FD, CE->getBeginLoc());
if (!OverloadedDerivedFn) {
Stmts& block = getCurrentBlock(direction::reverse);
block.insert(block.begin(), PreCallStmts.begin(),
Expand Down
36 changes: 21 additions & 15 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Overload.h"
#include "clang/Sema/Scope.h"
Expand Down Expand Up @@ -733,23 +734,28 @@ namespace clad {
/*namespaceShouldExist=*/false);
}

void VisitorBase::CallExprDiffDiagnostics(llvm::StringRef funcName,
SourceLocation srcLoc, bool isDerived){
if (!isDerived) {
// Function was not derived => issue a warning.
diag(DiagnosticsEngine::Warning,
srcLoc,
"function '%0' was not differentiated because clad failed to "
"differentiate it and no suitable overload was found in "
"namespace 'custom_derivatives', and function may not be "
"eligible for numerical differentiation.",
void VisitorBase::CallExprDiffDiagnostics(const clang::FunctionDecl* FD,
SourceLocation srcLoc) {
bool NumDiffEnabled =
!m_Sema.getPreprocessor().isMacroDefined("CLAD_NO_NUM_DIFF");
// FIXME: Switch to the real diagnostics engine and pass FD directly.
std::string funcName = FD->getNameAsString();
diag(DiagnosticsEngine::Warning, srcLoc,
"function '%0' was not differentiated because clad failed to "
"differentiate it and no suitable overload was found in "
"namespace 'custom_derivatives'",
{funcName});
if (NumDiffEnabled) {
diag(DiagnosticsEngine::Note, srcLoc,
"falling back to numerical differentiation for '%0' since no "
"suitable overload was found and clad could not derive it; "
"to disable this feature, compile your programs with "
"-DCLAD_NO_NUM_DIFF",
{funcName});
} else {
diag(DiagnosticsEngine::Warning, noLoc,
"Falling back to numerical differentiation for '%0' since no "
"suitable overload was found and clad could not derive it. "
"To disable this feature, compile your programs with "
"-DCLAD_NO_NUM_DIFF.",
diag(DiagnosticsEngine::Note, srcLoc,
"fallback to numerical differentiation is disabled by the "
"'CLAD_NO_NUM_DIFF' macro; considering '%0' as 0",
{funcName});
}
}
Expand Down
7 changes: 7 additions & 0 deletions test/Gradient/NonDifferentiableError.C
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ non_differentiable double fn_s2_mem_fn(double i, double j) {
return obj.mem_fn(i, j) + i * j;
}

double no_body(double x);

double fn1(double x) { return no_body(x); } //expected-warning {{function 'no_body' was not differentiated}}
//expected-note@34 {{fallback to numerical differentiation is disabled}}
double fn2(double x) { return fn1(x); }

#define INIT_EXPR(classname) \
classname expr_1(2, 3); \
classname expr_2(3, 5);
Expand All @@ -48,4 +54,5 @@ int main() {
INIT_EXPR(SimpleFunctions2);
TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5);
TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}}
auto fn2_grad = clad::gradient(fn2);
}
8 changes: 4 additions & 4 deletions test/NumericalDiff/GradientMultiArg.C
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: %cladnumdiffclang %s -I%S/../../include -oGradientMultiArg.out 2>&1 | FileCheck -check-prefix=CHECK %s
// RUN: %cladnumdiffclang %s -I%S/../../include -oGradientMultiArg.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s
// RUN: ./GradientMultiArg.out | %filecheck_exec %s
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oGradientMultiArg.out
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oGradientMultiArg.out -Xclang -verify
// RUN: ./GradientMultiArg.out | %filecheck_exec %s

//CHECK-NOT: {{.*error|warning|note:.*}}
Expand All @@ -11,9 +11,9 @@
#include <algorithm>

double test_1(double x, double y){
return std::hypot(x, y);
return std::hypot(x, y); // expected-warning {{function 'hypot' was not differentiated}}
// expected-note@14 {{falling back to numerical differentiation}}
}
// CHECK: warning: Falling back to numerical differentiation for 'hypot' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF.
// CHECK: void test_1_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
Expand Down
11 changes: 5 additions & 6 deletions test/NumericalDiff/NoNumDiff.C
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
// RUN: %cladclang %s -I%S/../../include -oNoNumDiff.out 2>&1 | FileCheck -check-prefix=CHECK %s
// RUN: %cladclang %s -I%S/../../include -oNoNumDiff.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s

//CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"

#include <cmath>

double func(double x) { return std::tanh(x); }
double func(double x) { return std::tanh(x); } // expected-warning 2{{function 'tanh' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}}
// expected-note@9 2{{fallback to numerical differentiation is disabled by the 'CLAD_NO_NUM_DIFF' macro}}

//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour.
//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour.
//CHECK: double func_darg0(double x) {
//CHECK-NEXT: double _d_x = 1;
//CHECK-NEXT: return 0;
Expand All @@ -24,6 +23,6 @@ double func(double x) { return std::tanh(x); }


int main(){
clad::differentiate(func, "x");
clad::gradient(func);
clad::differentiate(func, "x");
clad::gradient(func);
}
15 changes: 8 additions & 7 deletions test/NumericalDiff/NumDiff.C
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
// RUN: %cladnumdiffclang %s -I%S/../../include -oNumDiff.out 2>&1 | FileCheck -check-prefix=CHECK %s
// RUN: %cladnumdiffclang %s -I%S/../../include -oNumDiff.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s
// RUN: ./NumDiff.out | %filecheck_exec %s
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oNumDiff.out
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr -Xclang -verify %s -I%S/../../include -oNumDiff.out
// RUN: ./NumDiff.out | %filecheck_exec %s
//CHECK-NOT: {{.*error|warning|note:.*}}
#include "clad/Differentiator/Differentiator.h"

double test_1(double x){
return tanh(x);
return tanh(x); // expected-warning {{function 'tanh' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}}
// expected-note@9 {{falling back to numerical differentiation for 'tanh'}}
}
//CHECK: warning: Falling back to numerical differentiation for 'tanh' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF.
//CHECK: warning: Falling back to numerical differentiation for 'log10' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF.

//CHECK: void test_1_grad(double x, double *_d_x) {
//CHECK-NEXT: {
Expand All @@ -21,7 +20,8 @@ double test_1(double x){


double test_2(double x){
return std::log10(x);
return std::log10(x);// expected-warning {{function 'log10' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}}
// expected-note@23 {{falling back to numerical differentiation for 'log10'}}
}
//CHECK: double test_2_darg0(double x) {
//CHECK-NEXT: double _d_x = 1;
Expand All @@ -32,7 +32,8 @@ double test_2(double x){
double test_3(double x) {
if (x > 0) {
double constant = 11.;
return std::hypot(x, constant);
return std::hypot(x, constant); // expected-warning {{function 'hypot' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}}
// expected-note@35 {{falling back to numerical differentiation for 'hypot'}}
}
return 0;
}
Expand Down
8 changes: 4 additions & 4 deletions test/NumericalDiff/PrintErrorNumDiff.C
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors %s -I%S/../../include -oPrintErrorNumDiff.out 2>&1 | FileCheck -check-prefix=CHECK %s
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors %s -I%S/../../include -oPrintErrorNumDiff.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s
// RUN: ./PrintErrorNumDiff.out | %filecheck_exec %s
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oPrintErrorNumDiff.out
// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oPrintErrorNumDiff.out -Xclang -verify
// RUN: ./PrintErrorNumDiff.out | %filecheck_exec %s

//CHECK-NOT: {{.*error|warning|note:.*}}
Expand All @@ -12,10 +12,10 @@
extern "C" int printf(const char* fmt, ...);

double test_1(double x){
return tanh(x);
return tanh(x); // expected-warning {{function 'tanh' was not differentiated because}}
// expected-note@15 {{falling back to numerical differentiation for 'tanh}}
}

//CHECK: warning: Falling back to numerical differentiation for 'tanh' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF.
//CHECK: void test_1_grad(double x, double *_d_x) {
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0;
Expand Down

0 comments on commit fc64644

Please sign in to comment.