Skip to content

Commit

Permalink
Add support for diff of ref return types in rev mode
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 authored and PhrygianGates committed Aug 16, 2023
1 parent 049d9db commit e2b5e5f
Show file tree
Hide file tree
Showing 19 changed files with 717 additions and 41 deletions.
20 changes: 11 additions & 9 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ jobs:
chmod +x git-clang-format
- name: Run git-clang-format
run: |
PR_BASE=$(git rev-list ${{ github.event.pull_request.head.sha }} ^${{ github.event.pull_request.base.sha }} | tail --lines 1 | xargs -I {} git rev-parse {}~1)
echo "running git clang-format against $PR_BASE commit"
git \
-c color.ui=always \
-c diff.wsErrorHighlight=all \
-c color.diff.whitespace='red reverse' \
clang-format-15 --diff --binary clang-format-15 origin/master -- demos/ include/ lib/ tools/ || \
clang-format-15 --diff --binary clang-format-15 --commit $PR_BASE -- demos/ include/ lib/ tools/ || \
(echo "Please run the following git-clang-format locally to fix the formatting: \n
git clang-format origin/master -- demos/ include/ lib/ tools/" && exit 1)
build:
Expand Down Expand Up @@ -82,12 +84,12 @@ jobs:
os: macos-latest
compiler: clang
clang-runtime: '14'

- name: osx-clang-runtime15
os: macos-latest
compiler: clang
clang-runtime: '15'

- name: osx-clang-runtime16
os: macos-latest
compiler: clang
Expand Down Expand Up @@ -414,7 +416,7 @@ jobs:
os: ubuntu-22.04
compiler: clang-15
clang-runtime: '14'

- name: ubu22-clang15-runtime15
os: ubuntu-22.04
compiler: clang-15
Expand Down Expand Up @@ -614,15 +616,15 @@ jobs:
echo "PATH_TO_LLVM_BUILD=$env:PATH_TO_LLVM_BUILD" >> $env:GITHUB_ENV
- name: Setup CUDA 8 on Linux
if: ${{ matrix.cuda == true }}
run: |
wget --no-verbose https://developer.nvidia.com/compute/cuda/8.0/Prod2/local_installers/cuda_8.0.61_375.26_linux-run
run: |
wget --no-verbose https://developer.nvidia.com/compute/cuda/8.0/Prod2/local_installers/cuda_8.0.61_375.26_linux-run
wget --no-verbose https://developer.nvidia.com/compute/cuda/8.0/Prod2/patches/2/cuda_8.0.61.2_linux-run
sh ./cuda_8.0.61_375.26_linux-run --tar mxvf
sudo cp InstallUtils.pm /usr/lib/x86_64-linux-gnu/perl-base
export $PERL5LIB
sudo sh cuda_8.0.61_375.26_linux-run --override --no-opengl-lib --silent --toolkit --kernel-source-path=/lib/modules/4.15.0-1113-azure/build
sudo sh cuda_8.0.61.2_linux-run --silent --accept-eula
export PATH=/usr/local/cuda-8.0/bin:${PATH}
sudo sh cuda_8.0.61_375.26_linux-run --override --no-opengl-lib --silent --toolkit --kernel-source-path=/lib/modules/4.15.0-1113-azure/build
sudo sh cuda_8.0.61.2_linux-run --silent --accept-eula
export PATH=/usr/local/cuda-8.0/bin:${PATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda-8.0/lib64
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV
echo "PATH=$PATH" >> $GITHUB_ENV
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.2~dev
1.3~dev
10 changes: 5 additions & 5 deletions docs/internalDocs/ReleaseNotes.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Introduction
============

This document contains the release notes for the automatic differentiation
plugin for clang Clad, release 1.2. Clad is built on top of
plugin for clang Clad, release 1.3. Clad is built on top of
[Clang](http://clang.llvm.org) and [LLVM](http://llvm.org>) compiler
infrastructure. Here we describe the status of Clad in some detail, including
major improvements from the previous release and new feature work.
Expand All @@ -11,7 +11,7 @@ Note that if you are reading this file from a git checkout,
this document applies to the *next* release, not the current one.


What's New in Clad 1.2?
What's New in Clad 1.3?
========================

Some of the major new features and improvements to Clad are listed here. Generic
Expand All @@ -21,7 +21,7 @@ described first.
External Dependencies
---------------------

* Clad now works with clang-5.0 to clang-15
* Clad now works with clang-5.0 to clang-16


Forward Mode & Reverse Mode
Expand Down Expand Up @@ -54,7 +54,7 @@ Fixed Bugs
[XXX](https://github.com/vgvassilev/clad/issues/XXX)

<!---Get release bugs
git log v1.0..master | grep 'Fixes|Closes'
git log v1.2..master | grep 'Fixes|Closes'
--->

Special Kudos
Expand All @@ -68,6 +68,6 @@ FirstName LastName (#commits)
A B (N)

<!---Find contributor list for this release
git log --pretty=format:"%an" v1.1...master | sort | uniq -c | sort -rn |\
git log --pretty=format:"%an" v1.2...master | sort | uniq -c | sort -rn |\
sed -E 's,^ *([0-9]+) (.*)$,\2 \(\1\),'
--->
4 changes: 4 additions & 0 deletions docs/userDocs/source/user/UsingVectorMode.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Using Vector Mode for Differentiation
**************************************

.. note::
This feature is still under development and may result in unexpected
behavior. Please report any issues you find.

For forward mode AD, the restriction is that the function can be only be
differentiated with respect to a single input variable. However, in many cases,
it is desirable to differentiate a function with respect to multiple input
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
channels:
- conda-forge
dependencies:
- clad=1.0
- clad=0.9
- xeus-cling
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace clad {
friend class ReverseModeVisitor;
friend class HessianModeVisitor;
friend class JacobianModeVisitor;

friend class ReverseModeForwPassVisitor;
clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enum class DiffMode {
reverse,
hessian,
jacobian,
reverse_mode_forward_pass,
error_estimation
};
}
Expand Down
6 changes: 6 additions & 0 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
#include <cstring>

namespace clad {
template<typename T, typename U>
struct ValueAndAdjoint {
T value;
U adjoint;
};

/// \returns the size of a c-style string
CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
unsigned int count;
Expand Down
38 changes: 38 additions & 0 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#ifndef CLAD_TRANSFORM_SOURCE_FN_VISITOR_H
#define CLAD_TRANSFORM_SOURCE_FN_VISITOR_H

#include "clad/Differentiator/ParseDiffArgsTypes.h"
#include "clad/Differentiator/ReverseModeVisitor.h"

#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/SmallVector.h"

namespace clad {
class ReverseModeForwPassVisitor : public ReverseModeVisitor {
private:
Stmts m_Globals;

llvm::SmallVector<clang::QualType, 8>
ComputeParamTypes(const DiffParams& diffParams);
clang::QualType ComputeReturnType();
llvm::SmallVector<clang::ParmVarDecl*, 8> BuildParams(DiffParams& diffParams);
clang::QualType GetParameterDerivativeType(clang::QualType yType,
clang::QualType xType);

public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

StmtDiff ProcessSingleStmt(const clang::Stmt* S);

StmtDiff VisitStmt(const clang::Stmt* S) override;
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // namespace clad

#endif
10 changes: 5 additions & 5 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace clad {
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>,
public VisitorBase {

private:
protected:
// FIXME: We should remove friend-dependency of the plugin classes here.
// For this we will need to separate out AST related functions in
// a separate namespace, as well as add getters/setters function of
Expand Down Expand Up @@ -321,11 +321,11 @@ namespace clad {
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL);
StmtDiff VisitForStmt(const clang::ForStmt* FS);
Expand All @@ -335,8 +335,8 @@ namespace clad {
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
StmtDiff VisitMemberExpr(const clang::MemberExpr* ME);
StmtDiff VisitParenExpr(const clang::ParenExpr* PE);
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
StmtDiff VisitStmt(const clang::Stmt* S);
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
virtual StmtDiff VisitStmt(const clang::Stmt* S);
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
/// Decl is not Stmt, so it cannot be visited directly.
Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_llvm_library(cladDifferentiator
HessianModeVisitor.cpp
JacobianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
Expand Down
5 changes: 5 additions & 0 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/VectorForwardModeVisitor.h"

Expand Down Expand Up @@ -230,6 +232,9 @@ namespace clad {
result = V.DerivePullback(FD, request);
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::hessian) {
HessianModeVisitor H(*this);
result = H.Derive(FD, request);
Expand Down
Loading

0 comments on commit e2b5e5f

Please sign in to comment.