
Algorithmic Differentiation

This section introduces the mathematics behind AD. Even if you have worked with AD before, we recommend reading this section in order to acclimatise yourself to the perspective that Mooncake.jl takes on the subject.

Derivatives

The foundation on which all of AD is built is the derivative – we require a fairly general definition of it, which we build up to here.

Scalar-to-Scalar Functions

Consider first $f : \RR \to \RR$, which we require to be differentiable at $x \in \RR$. Its derivative at $x$ is usually thought of as the scalar $\alpha \in \RR$ such that

\[\text{d}f = \alpha \, \text{d}x .\]

Loosely speaking, by this notation we mean that for arbitrarily small changes $\text{d} x$ in the input to $f$, the change in the output $\text{d} f$ is $\alpha \, \text{d}x$. We refer readers to the first few minutes of the first lecture mentioned before for a more careful explanation.
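
To make this concrete, here is a small numerical check (our own illustration, not part of the original material): for $f(x) = \sin(x)$ we have $\alpha = \cos(x)$, and the change in the output is approximately $\alpha \, \text{d}x$.

f(x) = sin(x)
Df(x) = ẋ -> cos(x) * ẋ   # the derivative at x, viewed as "multiply by α"

x, dx = 1.0, 1e-6
isapprox(f(x + dx) - f(x), Df(x)(dx); rtol=1e-5)   # true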

Vector-to-Vector Functions

The generalisation of this to Euclidean space should be familiar: if $f : \RR^P \to \RR^Q$ is differentiable at a point $x \in \RR^P$, then the derivative of $f$ at $x$ is given by the Jacobian matrix at $x$, denoted $J[x] \in \RR^{Q \times P}$, such that

\[\text{d}f = J[x] \, \text{d}x .\]

It is possible to stop here, as all the functions we shall need to consider can in principle be written as functions on some subset $\RR^P$.

However, when we consider differentiating computer programmes, we will have to deal with complicated nested data structures, e.g. structs inside Tuples inside Vectors etc. While all of these data structures can be mapped onto a flat vector in order to make sense of the Jacobian of a computer programme, this becomes very inconvenient very quickly. To see the problem, consider the Julia function whose input is of type Tuple{Tuple{Float64, Vector{Float64}}, Vector{Float64}, Float64} and whose output is of type Tuple{Vector{Float64}, Float64}. What kind of object might be used to represent the derivative of a function mapping between these two spaces? We can certainly treat these as structured "views" into "flat" Vector{Float64}s, and then define a Jacobian, but actually finding this mapping is a tedious exercise, even if it quite obviously exists.

In fact, a more general formulation of the derivative is used all the time in the context of AD – the matrix calculus discussed in [1] and [2] (to name a couple) makes use of a generalised form of the derivative in order to work with functions which map to and from matrices (albeit with slight differences in naming conventions from text to text), without needing to "flatten" them into vectors in order to make sense of them.

In general, it will be much easier to avoid "flattening" operations wherever possible. In order to do so, we now introduce a generalised notion of the derivative.

Functions Between More General Spaces

In order to avoid the difficulties described above, we consider functions $f : \mathcal{X} \to \mathcal{Y}$, where $\mathcal{X}$ and $\mathcal{Y}$ are finite dimensional real Hilbert spaces (read: finite-dimensional vector space with an inner product, and real-valued scalars). This definition includes functions to / from $\RR$, $\RR^D$, but also real-valued matrices, and any other "container" for collections of real numbers. Furthermore, we shall see later how we can model all sorts of structured representations of data directly as such spaces.

For such spaces, the derivative of $f$ at $x \in \mathcal{X}$ is the linear operator (read: linear function) $D f [x] : \mathcal{X} \to \mathcal{Y}$ satisfying

\[\text{d}f = D f [x] \, (\text{d} x)\]

The purpose of this linear operator is to provide a linear approximation to $f$ which is accurate for arguments which are very close to $x$.

Please note that $D f [x]$ is a single mathematical object, despite the fact that 3 separate symbols are used to denote it – $D f [x] (\dot{x})$ denotes the application of the function $D f [x]$ to argument $\dot{x}$. Furthermore, the dot-notation ($\dot{x}$) does not have anything to do with time-derivatives, it is simply common notation used in the AD literature to denote the arguments of derivatives.

So, instead of thinking of the derivative as a number or a matrix, we think about it as a function. We can express the previous notions of the derivative in this language.

In the scalar case, rather than thinking of the derivative as being $\alpha$, we think of it as the linear operator $D f [x] (\dot{x}) := \alpha \dot{x}$. Put differently, rather than thinking of the derivative as the slope of the tangent to $f$ at $x$, think of it as the function describing the tangent itself. Observe that up until now we had only considered inputs to $D f [x]$ which were small ($\text{d} x$) – here we extend it to the entire space $\mathcal{X}$ and denote inputs in this space $\dot{x}$. Inputs $\dot{x}$ should be thought of as "directions", in the directional derivative sense (why this is true will be discussed later).

Similarly, if $\mathcal{X} = \RR^P$ and $\mathcal{Y} = \RR^Q$ then this operator can be specified in terms of the Jacobian matrix: $D f [x] (\dot{x}) := J[x] \dot{x}$ – brackets are used to emphasise that $D f [x]$ is a function, and is being applied to $\dot{x}$.[note_for_geometers]

To reiterate, for the rest of this document, we define the derivative to be "multiply by $\alpha$" or "multiply by $J[x]$", rather than to be $\alpha$ or $J[x]$. So whenever you see the word "derivative", you should think "linear function".
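
To make this concrete in the Euclidean case, consider the following small sketch (our own illustration, with a hand-derived Jacobian):

f(x) = [x[1] * x[2], sin(x[1])]       # f : ℝ² → ℝ²
J(x) = [x[2] x[1]; cos(x[1]) 0.0]     # the Jacobian matrix J[x]
Df(x) = ẋ -> J(x) * ẋ                 # the derivative: "multiply by J[x]"

Df([1.0, 2.0])([1.0, 0.0])            # the derivative in the direction e₁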

The Chain Rule

The chain rule is the result which makes AD work. Fortunately, it applies to this version of the derivative:

\[f = g \circ h \implies D f [x] = (D g [h(x)]) \circ (D h [x])\]

By induction, this extends to a collection of $N$ functions $f_1, \dots, f_N$:

\[f := f_N \circ \dots \circ f_1 \implies D f [x] = (D f_N [x_N]) \circ \dots \circ (D f_1 [x_1]),\]

where $x_{n+1} := f_n(x_n)$, and $x_1 := x$.
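
For example (a sketch of our own): with $h(x) = x^2$ and $g(y) = \sin(y)$, the derivative of $f = g \circ h$ is the composition of the two linear maps.

h(x) = x^2;     Dh(x) = ẋ -> 2x * ẋ
g(y) = sin(y);  Dg(y) = ẏ -> cos(y) * ẏ

f(x) = g(h(x))
Df(x) = ẋ -> Dg(h(x))(Dh(x)(ẋ))       # (D g [h(x)]) ∘ (D h [x])

Df(1.5)(1.0) ≈ cos(1.5^2) * 2 * 1.5   # true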

An aside: the definition of the Frechet Derivative

This definition of the derivative has a name: the Frechet derivative. It is a generalisation of the Total Derivative. Formally, we say that a function $f : \mathcal{X} \to \mathcal{Y}$ is differentiable at a point $x \in \mathcal{X}$ if there exists a linear operator $D f [x] : \mathcal{X} \to \mathcal{Y}$ (the derivative) satisfying

\[\lim_{\text{d} h \to 0} \frac{\| f(x + \text{d} h) - f(x) - D f [x] (\text{d} h) \|_\mathcal{Y}}{\| \text{d}h \|_\mathcal{X}} = 0,\]

where $\| \cdot \|_\mathcal{X}$ and $\| \cdot \|_\mathcal{Y}$ are the norms associated to Hilbert spaces $\mathcal{X}$ and $\mathcal{Y}$ respectively. It is a good idea to consider what this looks like when $\mathcal{X} = \mathcal{Y} = \RR$ and when $\mathcal{X} = \mathcal{Y} = \RR^D$. It is sometimes helpful to refer to this definition to e.g. verify the correctness of the derivative of a function – as with single-variable calculus, however, this is rare.
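
For instance (a quick check of our own, in the scalar case): for $f(x) = x^2$ the quantity inside the limit is exactly $\|\text{d}h\|$, which indeed tends to zero.

f(x) = x^2
Df(x) = ẋ -> 2x * ẋ

x = 1.3
[abs(f(x + h) - f(x) - Df(x)(h)) / abs(h) for h in (1e-1, 1e-2, 1e-3)]
# ≈ [0.1, 0.01, 0.001]: tends to zero as dh → 0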

Another aside: what does Forwards-Mode AD compute?

At this point we have enough machinery to discuss forwards-mode AD. Expressed in the language of linear operators and Hilbert spaces, the goal of forwards-mode AD is the following: given a function $f$ which is differentiable at a point $x$, compute $D f [x] (\dot{x})$ for a given vector $\dot{x}$. If $f : \RR^P \to \RR^Q$, this is equivalent to computing $J[x] \dot{x}$, where $J[x]$ is the Jacobian of $f$ at $x$. For the interested reader we provide a high-level explanation of how forwards-mode AD does this in How does Forwards-Mode AD work?.

Another aside: notation

You may have noticed that we typically denote the argument to a derivative with a "dot" over it, e.g. $\dot{x}$. This is something that we will do consistently, and we will use the same notation for the outputs of derivatives. Wherever you see a symbol with a "dot" over it, expect it to be an input or output of a derivative / forwards-mode AD.

Reverse-Mode AD: what does it do?

In order to explain what reverse-mode AD does, we first consider the "vector-Jacobian product" definition in Euclidean space which will be familiar to many readers. We then generalise.

Reverse-Mode AD: what does it do in Euclidean space?

In this setting, the goal of reverse-mode AD is the following: given a function $f : \RR^P \to \RR^Q$ which is differentiable at $x \in \RR^P$ with Jacobian $J[x]$ at $x$, compute $J[x]^\top \bar{y}$ for any $\bar{y} \in \RR^Q$. This is useful because we can obtain the gradient from this when $Q = 1$ by letting $\bar{y} = 1$.
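
As a small sketch (our own illustration): for $f(x) = \frac{1}{2} \|x\|^2$ the Jacobian is the row vector $x^\top$, so $J[x]^\top \bar{y} = x \bar{y}$, and setting $\bar{y} = 1$ recovers the gradient.

f(x) = sum(abs2, x) / 2    # f : ℝᴾ → ℝ, so Q = 1
vjp(x, ȳ) = x .* ȳ         # J[x]ᵀ ȳ, since J[x] = xᵀ

vjp([1.0, 2.0, 3.0], 1.0)  # the gradient of f at x: [1.0, 2.0, 3.0]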

Adjoint Operators

In order to generalise this algorithm to work with linear operators, we must first generalise the idea of multiplying a vector by the transpose of the Jacobian. The relevant concept here is that of the adjoint operator. Specifically, the adjoint $A^\ast$ of linear operator $A$ is the linear operator satisfying

\[\langle A^\ast \bar{y}, \dot{x} \rangle = \langle \bar{y}, A \dot{x} \rangle,\]

where $\langle \cdot, \cdot \rangle$ denotes the inner-product. The relationship between the adjoint and matrix transpose is: if $A (x) := J x$ for some matrix $J$, then $A^\ast (y) := J^\top y$.
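
This identity is easy to check numerically (a sketch of our own):

using LinearAlgebra: dot

J = [1.0 2.0 3.0; 4.0 5.0 6.0]
A(ẋ) = J * ẋ       # a linear operator from ℝ³ to ℝ²
Aadj(ȳ) = J' * ȳ   # its adjoint: "multiply by Jᵀ"

ẋ, ȳ = [0.1, 0.2, 0.3], [1.0, -1.0]
dot(Aadj(ȳ), ẋ) ≈ dot(ȳ, A(ẋ))   # true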

Moreover, just as $(A B)^\top = B^\top A^\top$ when $A$ and $B$ are matrices, $(A B)^\ast = B^\ast A^\ast$ when $A$ and $B$ are linear operators. This result follows in short order from the definition of the adjoint operator – proving it is a good exercise!

Reverse-Mode AD: what does it do in general?

Equipped with adjoints, we can express reverse-mode AD only in terms of linear operators, dispensing with the need to express everything in terms of Jacobians. The goal of reverse-mode AD is as follows: given a differentiable function $f : \mathcal{X} \to \mathcal{Y}$, compute $D f [x]^\ast (\bar{y})$ for some $\bar{y}$.

Notation: $D f [x]^\ast$ denotes the single mathematical object which is the adjoint of $D f [x]$. It is a linear function from $\mathcal{Y}$ to $\mathcal{X}$. We may occasionally write it as $(D f [x])^\ast$ if there is some risk of confusion.

We will explain how reverse-mode AD goes about computing this after some worked examples.

Aside: Notation

You will have noticed that arguments to adjoints have thus far always had a "bar" over them, e.g. $\bar{y}$. This notation is common in the AD literature and will be used throughout. Additionally, this "bar" notation will be used for the outputs of adjoints of derivatives. So wherever you see a symbol with a "bar" over it, think "input or output of adjoint of derivative".

Some Worked Examples

We now present some worked examples in order to prime intuition, and to introduce the important classes of problems that will be encountered when doing AD in the Julia language. We will put all of these problems in a single general framework later on.

An Example with Matrix Calculus

We have introduced some mathematical abstraction in order to simplify the calculations involved in AD. To this end, we consider differentiating $f(X) := X^\top X$. Results for this and similar operations are given by [1]. A similar operation, but which maps from matrices to $\RR$, is discussed in Lecture 4 part 2 of the MIT course mentioned previously. Both [1] and Lecture 4 part 2 provide approaches to obtaining the derivative of this function.

Following either resource will yield the derivative:

\[D f [X] (\dot{X}) = \dot{X}^\top X + X^\top \dot{X}\]

Observe that this is indeed a linear operator (i.e. it is linear in its argument, $\dot{X}$). (You can always plug it into the definition of the Frechet derivative to confirm that it is indeed the derivative.)
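
One can also sanity-check it numerically (our own sketch), using the fact that $f(X + \epsilon \dot{X}) - f(X) \approx \epsilon \, D f [X] (\dot{X})$ for small $\epsilon$:

f(X) = X' * X
Df(X) = Ẋ -> Ẋ' * X + X' * Ẋ

X, Ẋ = randn(3, 2), randn(3, 2)
ε = 1e-6
isapprox((f(X + ε .* Ẋ) - f(X)) ./ ε, Df(X)(Ẋ); rtol=1e-4)   # true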

In order to perform reverse-mode AD, we need to find the adjoint operator. Using the usual definition of the inner product between matrices,

\[\langle X, Y \rangle := \textrm{tr} (X^\top Y)\]

we can rearrange the inner product as follows:

\[\begin{align}
\langle \bar{Y}, D f [X] (\dot{X}) \rangle &= \langle \bar{Y}, \dot{X}^\top X + X^\top \dot{X} \rangle \nonumber \\
&= \textrm{tr} (\bar{Y}^\top \dot{X}^\top X) + \textrm{tr}(\bar{Y}^\top X^\top \dot{X}) \nonumber \\
&= \textrm{tr} ( [X \bar{Y}^\top]^\top \dot{X}) + \textrm{tr}( [X \bar{Y}]^\top \dot{X}) \nonumber \\
&= \langle X \bar{Y}^\top + X \bar{Y}, \dot{X} \rangle, \nonumber
\end{align}\]

from which we can read off the adjoint: $D f [X]^\ast (\bar{Y}) = X \bar{Y}^\top + X \bar{Y} = X (\bar{Y}^\top + \bar{Y})$.

To see why reverse-mode AD computes the gradient, let $\mathcal{l}(y) := \langle \bar{y}, y \rangle$ for a fixed $\bar{y}$, and consider the composition $f := \mathcal{l} \circ g$ for a differentiable function $g$. Since $\mathcal{l}$ is linear, $D \mathcal{l} [y] (\dot{y}) = \langle \bar{y}, \dot{y} \rangle$, so the chain rule gives

\[\begin{align}
D f [x] (\dot{x}) &= [(D \mathcal{l} [g(x)]) \circ (D g [x])](\dot{x}) \nonumber \\
&= \langle \bar{y}, D g [x] (\dot{x}) \rangle \nonumber \\
&= \langle D g [x]^\ast (\bar{y}), \dot{x} \rangle, \nonumber
\end{align}\]

from which we conclude that $D g [x]^\ast (\bar{y})$ is the gradient of the composition $l \circ g$ at $x$.

The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.

The above shows that if $\mathcal{Y} = \RR$ and $g$ is the function we wish to compute the gradient of, we can simply set $\bar{y} = 1$ and compute $D g [x]^\ast (\bar{y})$ to obtain the gradient of $g$ at $x$.

Summary

This document explains the core mathematical foundations of AD. It explains separately what AD does, and how it goes about it. Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices, and Julia functions.

Subsequent sections will build on these foundations, to provide a more general explanation of what AD looks like for a Julia programme.

Asides

How does Forwards-Mode AD work?

Forwards-mode AD achieves this by breaking down $f$ into the composition $f = f_N \circ \dots \circ f_1$, where each $f_n$ is a simple function whose derivative (function) $D f_n [x_n]$ we know for any given $x_n$. By the chain rule, we have that

\[D f [x] (\dot{x}) = D f_N [x_N] \circ \dots \circ D f_1 [x_1] (\dot{x})\]

which suggests the following algorithm:

  1. let $x_1 = x$, $\dot{x}_1 = \dot{x}$, and $n = 1$
  2. let $\dot{x}_{n+1} = D f_n [x_n] (\dot{x}_n)$
  3. let $x_{n+1} = f_n(x_n)$
  4. let $n = n + 1$
  5. if $n = N+1$ then return $\dot{x}_{N+1}$, otherwise go to 2.

When each function $f_n$ maps between Euclidean spaces, the applications of derivatives $D f_n [x_n] (\dot{x}_n)$ are given by $J_n \dot{x}_n$ where $J_n$ is the Jacobian of $f_n$ at $x_n$.
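
The following sketch (our own, not Mooncake.jl's implementation) expresses this recursion in code:

# fs[n] is the primitive f_n; Dfs[n](x) returns the linear map D f_n [x].
function forwards_mode(fs, Dfs, x, ẋ)
    for (f_n, Df_n) in zip(fs, Dfs)
        ẋ = Df_n(x)(ẋ)   # step 2: push the tangent through D f_n [x_n]
        x = f_n(x)       # step 3: advance the primal, x_{n+1} = f_n(x_n)
    end
    return x, ẋ
end

# Example: f = exp ∘ sin.
fs  = (sin, exp)
Dfs = (x -> ẋ -> cos(x) * ẋ, x -> ẋ -> exp(x) * ẋ)
forwards_mode(fs, Dfs, 0.5, 1.0)   # (exp(sin(0.5)), exp(sin(0.5)) * cos(0.5))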

[1]
M. Giles. An extended collection of matrix derivative results for forward and reverse mode automatic differentiation. Unpublished (2008).
[2]
T. P. Minka. Old and new matrix algebra useful for statistics. See www.stat.cmu.edu/minka/papers/matrix.html (2000).
  • note_for_geometers: in AD we only really need to discuss differentiable functions between vector spaces that are isomorphic to Euclidean space. Consequently, a variety of considerations which are usually required in differential geometry are not required here. Notably, the tangent space is assumed to be the same everywhere, and to be the same as the domain of the function. Avoiding these additional considerations helps keep the mathematics as simple as possible.

Debug Mode

The Problem

A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like

function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})
    plus_reverse_pass(dz::Real) = NoRData(), dz, dz
    return zero_fcodual(primal(x) + primal(y)), plus_reverse_pass
end

and calls

rrule!!(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))

then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, so the appropriate tangent type is also Float32. This error may cause the reverse pass to fail loudly, but it may also fail silently. It may cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule. Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.
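
As a quick sanity check (assuming Mooncake is loaded), you can ask Mooncake directly which tangent type it expects for a given primal type:

julia> Mooncake.tangent_type(Float32)
Float32

julia> Mooncake.tangent_type(Float64)
Float64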

The Solution

Check that the types of the fdata / rdata associated to arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.

This is implemented via DebugRRule:

Mooncake.DebugRRule — Type
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array is of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source
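
For example, a derived rule can be wrapped by hand (a minimal sketch; the function f and its argument are stand-ins of our own):

using Mooncake

f(x) = x * sin(x)
rule = Mooncake.build_rrule(f, 2.0)       # derive a rule for this call signature
debug_rule = Mooncake.DebugRRule(rule)    # the same rule, with added type checking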

You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:

Mooncake.build_rrule — Function
build_rrule(args...; debug_mode=false)

Helper method. Only uses static information from args.

source
build_rrule(sig::Type{<:Tuple})

Equivalent to build_rrule(Mooncake.get_interpreter(), sig).

source
build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}

Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.

If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.

source

When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg:

julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
-AutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))

When Should You Use Debug Mode?

Only use debug_mode when debugging a problem. This is because it has substantial performance implications.


Debugging and MWEs

There's a reasonable chance that you'll run into an issue with Mooncake.jl at some point. In order to debug what is going on when this happens, or to produce an MWE, it is helpful to have a convenient way to run Mooncake.jl on whatever function and arguments you have which are causing problems.

We recommend making use of Mooncake.jl's testing functionality to generate your test cases:

Mooncake.TestUtils.test_rule — Function
test_rule(
     rng, x...;
     interface_only=false,
     is_primitive::Bool=true,
     perf_flag::Symbol=:none,
     interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(),
     debug_mode::Bool=false,
)

Run standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in particular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.

This function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.

source

This approach is convenient because it can

  1. check whether AD runs at all,
  2. check whether AD produces the correct answers,
  3. check whether AD is performant, and
  4. be used without having to manually generate tangents.

Example

For example

using Mooncake, Random

f(x) = Core.bitcast(Float64, x)
Mooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)

will error. (In this particular case, it is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting. In this particular instance, Mooncake.jl just fails to compile, but in other instances other things can happen.)

In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.

Segfaults

These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.
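
Concretely, for the example above this amounts to (a sketch, reusing f from the previous block):

Mooncake.TestUtils.test_rule(
    Random.Xoshiro(123), f, 3;
    is_primitive=false, debug_mode=true,
)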


Mooncake.jl

Documentation for Mooncake.jl is on its way!

Note (03/10/2024): Various bits of utility functionality are now carefully documented. This includes how to change the code which Mooncake sees, declare that the derivative of a function is zero, make use of existing ChainRules.rrules to quickly create new rules in Mooncake, and more.

Note (02/07/2024): The first round of documentation has arrived. This is largely targeted at those who are interested in contributing to Mooncake.jl – you can find this work in the "Understanding Mooncake.jl" section of the docs. There is more still to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.

Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!


Known Limitations

Mooncake.jl has a number of known qualitative limitations, which we document here.

Mutation of Global Variables

While great care is taken in this package to prevent silent errors, this is one edge case that we have yet to provide a satisfactory solution for. Consider a function of the form:

julia> const x = Ref(1.0);
 
 julia> function foo(y::Float64)
            x[] = y
           ⋮

[...]

Mooncake.value_and_gradient!!(rule, foo, [5.0, 4.0])

# output
(4.0, (NoTangent(), [0.0, 1.0]))

The Solution

This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.
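
For instance, something along the following lines (a hypothetical sketch of our own: my_ptr_fn is made up, and we rely on the fact, noted elsewhere in these docs, that the tangent type of a Ptr{Float64} is itself a Ptr{Float64}):

using Mooncake, Random

my_ptr_fn(p::Ptr{Float64}) = unsafe_load(p)   # hypothetical function of a pointer

x  = [1.0, 2.0]   # primal memory
dx = zero(x)      # manually-constructed tangent memory
GC.@preserve x dx begin
    px = Mooncake.CoDual(pointer(x), pointer(dx))
    Mooncake.TestUtils.test_rule(
        Random.Xoshiro(123), my_ptr_fn, px;
        is_primitive=false,
    )
end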

While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.

Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.


Mooncake.jl's Rule System

Mooncake.jl's approach to AD is recursive. It has a single specification for what it means to differentiate a Julia callable, and basically two approaches to achieving this. This section of the documentation explains the former.

We take an iterative approach to this explanation, starting at a high-level and adding more depth as we go.

10,000 Foot View

A rule r(f, x) for a function f(x) "does reverse mode AD", and executes in two phases, known as the forwards pass and the reverse pass. In the forwards pass a rule executes the original function, and does some additional book-keeping in preparation for the reverse pass. On the reverse pass it undoes the computation from the forwards pass, "backpropagates" the gradient w.r.t. the output of the original function by applying the adjoint of the derivative of the original function to it, and writes the results of this computation to the correct places.
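
Schematically, the shape of a rule is as follows (a sketch of the idea only, not Mooncake.jl's actual rule signature; adjoint_of_derivative is a hypothetical helper):

function r(f, x)
    # forwards pass: run the original computation, plus any book-keeping.
    y = f(x)
    # reverse pass: apply the adjoint of the derivative of f at x to ȳ,
    # i.e. compute D f [x]^*(ȳ).
    pullback(ȳ) = adjoint_of_derivative(f, x)(ȳ)
    return y, pullback
end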

A precise mathematical model for the original function is therefore entirely crucial to this discussion, as it is needed to understand what the adjoint of its derivative is.

A Model For A Julia Function

Since Julia permits the in-place modification / mutation of many data structures, we cannot make a naive translation between a Julia function and a mathematical object. Rather, we will have to model the state of the arguments to a function both before and after execution. Moreover, since a function can allocate new memory as part of execution and return it to the calling scope, we must track that too.

Consider Only Externally-Visible Effects Of Function Evaluation

We wish to treat a given function as a black box – we care about what a function does, not how it does it – so we consider only the externally-visible results of executing it. There are two ways in which changes can be made externally visible.

Return Value

(This point hardly requires explanation, but for the sake of completeness we provide one anyway.)

The most obvious way in which a result can be made visible outside of a function is via its return value. For example, letting bar(x) = sin(x), consider the function

function foo(x)
     y = bar(x)
     z = bar(y)
     return z
end

[...]

julia> mutable struct Bar
           x::Float64
       end

you will find that it is

julia> tangent_type(Bar)
 MutableTangent{@NamedTuple{x::Float64}}

Primitive Types

We've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.

One interesting case is that of Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example

julia> tangent_type(Ptr{Float64})
Ptr{Float64}
source

FData and RData

While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.

Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this "unique gradient address" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. a value type resides in another value type.

The following docstring provides the best in-depth explanation.

Mooncake.fdata_type — Method
fdata_type(T)

Returns the type of the forwards data associated to a tangent of type T.

Extended help

Rules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable struct or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.

Given a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.

Given a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that

tangent(fdata(t), rdata(t)) === t

The need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.

Int

Ints are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore

julia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))
(NoFData, NoRData)

Float64

The tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so

julia> (fdata_type(Float64), rdata_type(Float64))
(NoFData, Float64)

Vector{Float64}

The tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so

julia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))
(Vector{Float64}, NoRData)

Tuple{Float64, Vector{Float64}, Int}

This is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interrogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:

julia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})
Tuple{Float64, Vector{Float64}, NoTangent}

julia> (fdata_type(T), rdata_type(T))
(Tuple{NoFData, Vector{Float64}, NoFData}, Tuple{Float64, NoRData, NoRData})

The zero tangent for (5.0, [5.0]) is t = (0.0, [0.0]). fdata(t) returns (NoFData(), [0.0]), where the second element is === to the second element of t. rdata(t) returns (0.0, NoRData()). In this example, t contains a mixture of data, some of which is identified by its value, and some of which is identified by its address, so there is some fdata and some rdata.

Structs

Structs are handled in more-or-less the same way as Tuples, albeit with the possibility of undefined fields needing to be explicitly handled. For example, a struct such as

julia> struct Foo
            x::Float64
            y
            z::Int
        end

has tangent type

julia> tangent_type(Foo)
Tangent{@NamedTuple{x::Float64, y, z::NoTangent}}

Its fdata and rdata are given by special FData and RData types:

julia> (fdata_type(tangent_type(Foo)), rdata_type(tangent_type(Foo)))
(Mooncake.FData{@NamedTuple{x::NoFData, y, z::NoFData}}, Mooncake.RData{@NamedTuple{x::Float64, y, z::NoRData}})

Practically speaking, FData and RData both have the same structure as Tangents and are just used in different contexts.

Mutable Structs

The fdata for a mutable struct is its tangent, and it has no rdata. This is because mutable structs have fixed memory addresses, and can therefore be incremented in-place. For example,

julia> mutable struct Bar
            x::Float64
            y
            z::Int
        end

has tangent type

julia> tangent_type(Bar)
MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}

and fdata / rdata types

julia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))
(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)

Primitive Types

As with tangents, each primitive type must specify what its fdata and rdata are. See specific examples for details.

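To make the fdata / rdata split concrete, here is the zero-tangent example from the docstring worked through as code. This is a sketch: it assumes fdata, rdata and the binary tangent described above are imported from Mooncake, and the printed output is illustrative.

julia> using Mooncake: fdata, rdata, tangent

julia> t = (0.0, [0.0])  # zero tangent for the primal (5.0, [5.0])
(0.0, [0.0])

julia> f, r = fdata(t), rdata(t)  # the Vector is fdata, the Float64 is rdata
((NoFData(), [0.0]), (0.0, NoRData()))

julia> tangent(f, r) === t  # recombining must recover the original tangent
true

Note that the array inside fdata(t) is === to the array inside t: the fdata is the address-identified part of the tangent itself, not a copy of it.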

CoDuals

CoDuals are simply used to bundle together a primal and an associated fdata. Occasionally, depending upon context, they are instead used to pair together a primal and a tangent.
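
For example, the following sketch constructs CoDuals by hand (the printed output is illustrative; CoDual and NoFData are imported from Mooncake, as in the rule examples below):

julia> using Mooncake: CoDual, NoFData

julia> CoDual(5.0, NoFData())  # a Float64 primal carries no fdata
CoDual{Float64, NoFData}(5.0, NoFData())

julia> CoDual([1.0, 2.0], [0.0, 0.0])  # a Vector primal is paired with its shadow vector
CoDual{Vector{Float64}, Vector{Float64}}([1.0, 2.0], [0.0, 0.0])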

A quick aside: Non-Differentiable Data

In the introduction to algorithmic differentiation, we assumed that the domain / range of a function are the same as those of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space – the String type, for example.

Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a $0$ tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.
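
For instance, Int is non-differentiable, and a composite type all of whose fields are non-differentiable is itself non-differentiable (a quick sketch, using examples confirmed elsewhere in these docs; tangent_type is imported from Mooncake):

julia> using Mooncake: tangent_type

julia> tangent_type(Int), tangent_type(Tuple{Int, Int})
(NoTangent, NoTangent)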

Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.

TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.

The Rule Interface (Round 2)

Now that you've seen which data structures are used to represent gradients, we can describe in more detail how fdata and rdata are used to propagate gradients backwards on the reverse pass.

Consider the function

julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])
foo (generic function with 1 method)

The fdata for x is a Tuple{NoFData, Vector{Float64}}, and its rdata is a Tuple{Float64, NoRData}. The function returns a Float64, which has no fdata, and whose rdata is Float64. So on the forwards pass there is really nothing that needs to happen with the fdata for x.
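
You can check this claim directly (a sketch; fdata_type, rdata_type and tangent_type are imported from Mooncake):

julia> using Mooncake: tangent_type, fdata_type, rdata_type

julia> T = tangent_type(Tuple{Float64, Vector{Float64}})
Tuple{Float64, Vector{Float64}}

julia> (fdata_type(T), rdata_type(T))
(Tuple{NoFData, Vector{Float64}}, Tuple{Float64, NoRData})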

Under the framework introduced above, the model for this function is

\[f(x) = (x, x_1 + \sum_{n=1}^N (x_2)_n)\]

where the vector in the second element of x is of length $N$. Now, following our usual steps, the derivative is

\[D f [x](\dot{x}) = (\dot{x}, \dot{x}_1 + \sum_{n=1}^N (\dot{x}_2)_n)\]

A gradient for this is a tuple $(\bar{y}_x, \bar{y}_a)$ where $\bar{y}_a \in \RR$ and $\bar{y}_x \in \RR \times \RR^N$. A quick derivation will show that the adjoint is

\[D f [x]^\ast(\bar{y}) = ((\bar{y}_x)_1 + \bar{y}_a, (\bar{y}_x)_2 + \bar{y}_a \mathbf{1})\]

where $\mathbf{1}$ is the vector of length $N$ in which each element is equal to $1$. (Observe that this agrees with the result we derived earlier for functions which don't mutate their arguments).
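
For completeness, here is the quick derivation, using the same inner-product argument as for non-mutating functions: for any $\dot{x}$,

\[\begin{aligned}
\langle \bar{y}, D f [x] (\dot{x}) \rangle &= \langle \bar{y}_x, \dot{x} \rangle + \bar{y}_a \left( \dot{x}_1 + \sum_{n=1}^N (\dot{x}_2)_n \right) \\
&= \left( (\bar{y}_x)_1 + \bar{y}_a \right) \dot{x}_1 + \langle (\bar{y}_x)_2 + \bar{y}_a \mathbf{1}, \dot{x}_2 \rangle \\
&= \langle D f [x]^\ast (\bar{y}), \dot{x} \rangle,
\end{aligned}\]

from which the stated form of the adjoint can be read off.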

Now that we know what the adjoint is, we'll write down the rrule!!, and then explain what is going on in terms of the adjoint. This hand-written implementation is to aid your understanding – Mooncake.jl should be relied upon to generate this code automatically in practice.

julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})
            dx_fdata = x.dx
            function dfoo_adjoint(dy::Float64)
                dx_fdata[2] .+= dy
                dx_1_rdata = dy
                dx_rdata = (dx_1_rdata, NoRData())
                return NoRData(), dx_rdata
            end
            x_p = x.x
            return CoDual(x_p[1] + sum(x_p[2]), NoFData()), dfoo_adjoint
        end;

where dy is the rdata for the output of foo. The rrule!! can be called with the appropriate CoDuals:

julia> out, pb!! = rrule!!(CoDual(foo, NoFData()), CoDual((5.0, [1.0, 2.0]), (NoFData(), [0.0, 0.0])))
(CoDual{Float64, NoFData}(8.0, NoFData()), var"#dfoo_adjoint#1"{Tuple{NoFData, Vector{Float64}}}((NoFData(), [0.0, 0.0])))

and the pullback with the appropriate rdata:

julia> pb!!(1.0)
(NoRData(), (1.0, NoRData()))

Observe that the forwards pass:

  1. computes the result of the initial function, and
  2. pulls out the fdata for the Vector{Float64} component of the argument.

As promised, the forwards pass really has nothing to do with the adjoint. It's just book-keeping and running the primal computation.

The reverse pass:

  1. increments each element of dx_fdata[2] by dy – this corresponds to $(\bar{y}_x)_2 + \bar{y}_a \mathbf{1}$ in the adjoint,
  2. sets dx_1_rdata to dy – this corresponds to $(\bar{y}_x)_1 + \bar{y}_a$ subject to the constraint that $(\bar{y}_x)_1 = 0$, and
  3. constructs the rdata for x – this is essentially just book-keeping.

Each of these items demonstrates a more general point. The first is that, upon entry into the reverse pass, all fdata values correspond to the gradients w.r.t. the arguments / output of f "upon exit" (for the components of these which are identified by their address), and once the reverse pass finishes running, they must contain the gradients w.r.t. the arguments of f "upon entry".

The second is that we always assume that the components of $\bar{y}_x$ which are identified by their value have zero rdata.

The third is that the components of the arguments of f which are identified by their value must have their rdata passed back explicitly by a rule, while the components of the arguments of f which are identified by their address get their gradients propagated back implicitly (i.e. via the in-place modification of fdata).

Reminder: the first element of the tuple returned by dfoo_adjoint is the rdata associated to foo itself, hence it is NoRData.

Testing

Mooncake.jl has an almost entirely automated system for testing rules – Mooncake.TestUtils.test_rule. You should absolutely make use of it when writing rules.

TODO: improve docstring for testing functionality.

Summary

In this section we have covered the rule system. Every callable object / function in the Julia language is differentiated using rules with this interface, whether they be hand-written rrule!!s or rules derived by Mooncake.jl.

At this point you should be equipped with enough information to understand what a rule in Mooncake.jl does, and how you can write your own. Later sections will explain how Mooncake.jl goes about deriving rules itself in a recursive manner, and introduce you to some of the internals.

Asides

Why Uniqueness of Type For Tangents / FData / RData?

Why does Mooncake.jl insist that each primal type P be paired with a single tangent type T, as opposed to being more permissive? There are a few notable reasons:

  1. To provide a precise interface. Rules pass fdata around on the forwards pass and rdata on the reverse pass – being able to make strong assumptions about the type of the fdata / rdata given the primal type makes implementing rules much easier in practice.
  2. Conditional type stability. We wish to have a high degree of confidence that if the primal code is type-stable, then the AD code will be too. It is straightforward to construct type-stable primal code which has type-unstable forwards and reverse passes if you permit more than one fdata / rdata type per primal type. So while uniqueness is certainly not sufficient on its own to guarantee conditional type stability, it is probably necessary in general.
  3. Test-case generation and coverage. A unique tangent / fdata / rdata type for each primal makes it much easier to be confident that a given rule is being tested thoroughly. For a given primal, rather than there being many possible input / output types to consider, there is just one.

This topic, in particular what goes wrong with permissive tangent type systems like those employed by ChainRules, deserves a more thorough treatment – hopefully someone will write something more expansive on it at some point.

Why Support Closures But Not Mutable Globals

First consider why closures are straightforward to support. Look at the type of the closure produced by foo:

function foo(x)
    function bar(y)
        x .+= y
        return nothing
    end
    return bar
end
bar = foo(randn(5))
typeof(bar)

# output
var"#bar#1"{Vector{Float64}}

Observe that the Vector{Float64} that we passed to foo, and closed over in bar, is present in the type. This alludes to the fact that closures are basically just callable structs whose fields are the closed-over variables. Since the function itself is an argument to its rule, everything enters the rule for bar via its arguments, and the rule system developed in this document applies straightforwardly.

On the other hand, globals do not appear in the functions that they are a part of. For example,

const a = randn(10)

function g(x)
    a .+= x
    return nothing
end
typeof(g)

# output
typeof(g) (singleton type of function g, subtype of Function)

Neither the value nor type of a are present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.

diff --git a/dev/objects.inv b/dev/objects.inv (GIT binary patch; binary payload omitted)
diff --git a/dev/running_tests_locally/index.html (new file)

Running Tests Locally · Mooncake.jl

Running Tests Locally

Mooncake.jl's test suite is fairly extensive. While you can use Pkg.test to run the test suite in the standard manner, this is usually not the best approach when developing Mooncake.jl: when editing some code, you typically only want to run the tests associated with that code, not the entire test suite.

Mooncake's tests are organised as follows:

  1. Things that are required for most / all test suites are loaded up in test/front_matter.jl.
  2. The tests for something in src are located in an identically-named file in test. For example, the unit tests for src/rrules/new.jl are located in test/rrules/new.jl.

Thus, a workflow that I (Will) find works very well is the following (a sketch REPL session is shown after the list):

  1. Ensure that you have Revise.jl and TestEnv.jl installed in your default environment.
  2. Start the REPL, dev Mooncake.jl, and navigate to the top level of the Mooncake.jl directory.
  3. using TestEnv, Revise. Better still, load both of these in your .julia/config/startup.jl file so that you don't ever forget to load them.
  4. Run the following: using Pkg; Pkg.activate("."); TestEnv.activate(); include("test/front_matter.jl"); to set up your environment.
  5. include whichever test file you want to run the tests from.
  6. Modify code, and re-include the test file to check that your change does what you need. Loop this until done.
  7. Make a PR. This runs the entire test suite on CI – I find that I almost never run it in full locally.
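
For concreteness, steps 3–6 of this workflow might look like the following REPL session (a sketch – the test file shown is illustrative):

julia> using Pkg, TestEnv, Revise

julia> Pkg.activate(".")  # from the top level of the Mooncake.jl directory

julia> TestEnv.activate()  # activates a temporary environment with the test dependencies

julia> include("test/front_matter.jl")

julia> include("test/rrules/new.jl")  # run just the tests you are working on

After editing the corresponding file in src, simply re-run the final include – Revise ensures your changes are picked up without restarting the REPL.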

The purpose of this approach is to:

  1. Avoid restarting the REPL each time you make a change, and
  2. Run the smallest bit of the test suite possible when making changes, in order to make development a fast and enjoyable process.

If you find that this strategy leaves you running more of the test suite than you would like, consider copying and pasting specific tests into the REPL, or commenting out a chunk of tests in the file that you are editing during development (try not to commit this). I find this rather crude strategy effective in practice.

diff --git a/dev/search_index.js b/dev/search_index.js index 8efe94c6b..354187c59 100644 --- a/dev/search_index.js +++ b/dev/search_index.js @@ -1,3 +1,3 @@ var documenterSearchIndex = {"docs": -[{"location":"mathematical_interpretation/#Mooncake.jl's-Rule-System","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.jl's approach to AD is recursive. It has a single specification for what it means to differentiate a Julia callable, and basically two approaches to achieving this. This section of the documentation explains the former.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We take an iterative approach to this explanation, starting at a high-level and adding more depth as we go.","category":"page"},{"location":"mathematical_interpretation/#10,000-Foot-View","page":"Mooncake.jl's Rule System","title":"10,000 Foot View","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule r(f, x) for a function f(x) \"does reverse mode AD\", and executes in two phases, known as the forwards pass and the reverse pass. In the forwards pass a rule executes the original function, and does some additional book-keeping in preparation for the reverse pass. On the reverse pass it undoes the computation from the forwards pass, \"backpropagates\" the gradient w.r.t. the output of the original function by applying the adjoint of the derivative of the original function to it, and writes the results of this computation to the correct places.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A precise mathematical model for the original function is therefore entirely crucial to this discussion, as it is needed to understand what the adjoint of its derivative is.","category":"page"},{"location":"mathematical_interpretation/#A-Model-For-A-Julia-Function","page":"Mooncake.jl's Rule System","title":"A Model For A Julia Function","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Since Julia permits the in-place modification / mutation of many data structures, we cannot make a naive translation between a Julia function and a mathematical object. Rather, we will have to model the state of the arguments to a function both before and after execution. Moreover, since a function can allocate new memory as part of execution and return it to the calling scope, we must track that too.","category":"page"},{"location":"mathematical_interpretation/#Consider-Only-Externally-Visible-Effects-Of-Function-Evaluation","page":"Mooncake.jl's Rule System","title":"Consider Only Externally-Visible Effects Of Function Evaluation","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We wish to treat a given function as a black box – we care about what a function does, not how it does it – so we consider only the externally-visible results of executing it. 
There are two ways in which changes can be made externally visible.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Return Value","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"(This point hardly requires explanation, but for the sake of completeness we do so anyway.)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The most obvious way in which a result can be made visible outside of a function is via its return value. For example, letting bar(x) = sin(x), consider the function","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n y = bar(x)\n z = bar(y)\n return z\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The communication between the two invocations of bar happen via the value it returns.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Modification of arguments","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In contrast to the above, changes made by one function can be made available to another implicitly if it modifies the values of its arguments, even if it doesn't return anything. For example, consider:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function bar(x::Vector{Float64})\n x .*= 2\n return nothing\nend\n\nfunction foo(x::Vector{Float64})\n bar(x)\n bar(x)\n return x\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second call to bar in foo sees the changes made to x by the first call to bar, despite not being explicitly returned.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"No Global Mutable State","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"functions can in principle also communicate via global mutable state. 
We make the decision to not support this.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For example, we assume functions of the following form cannot be encountered:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"const a = randn(10)\n\nfunction bar(x)\n a .+= x\n return nothing\nend\n\nfunction foo(x)\n bar(x)\n return a\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, a is modified by bar, the effect of which is visible to foo.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For a variety of reasons this is very awkward to handle well. Since it's largely considered poor practice anyway, we explicitly outlaw this mode of communication between functions. See Why Support Closures But Not Mutable Globals for more info.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Note that this does not preclude the use of closed-over values or callable structs. For example, something like","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n function bar(y)\n x .+= y\n return nothing\n end\n return bar(x)\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"is perfectly fine.","category":"page"},{"location":"mathematical_interpretation/#The-Model","page":"Mooncake.jl's Rule System","title":"The Model","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"It is helpful to have a concrete example which uses both of the permissible methods to make results externally visible. 
To this end, consider the following function:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function f(x::Vector{Float64}, y::Vector{Float64}, z::Vector{Float64}, s::Ref{Vector{Float64}})\n z .*= y .* x\n s[] = 2z\n return sum(z)\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We draw your attention to a variety of features of this function:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"z is mutated,\ns is mutated to reference freshly allocated memory,\nthe value previously pointed to by s is unmodified, and\nwe allocate a new value and return it (albeit, it is probably allocated on the stack).","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The model we adopt for any Julia function f is a function f mathcalX to mathcalX times mathcalA where mathcalX is the real finite Hilbert space associated to the arguments to f prior to execution, and mathcalA is the real finite Hilbert space associated to any newly allocated data during execution which is externally visible after execution – any newly allocated data which is not made visible is of no concern.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, mathcalX = RR^D times RR^D times RR^D times RR^S where D is the length of x / y / z, and S the length of s[] prior to running f. mathcalA = RR^D times RR, where the RR^D component corresponds to the freshly allocated memory that s references, and RR to the return value. Observe that we model Float64s as elements of RR, Vector{Float64}s as elements of RR^D (for some value of D), and Refs with whatever the model for their contents is. The keen-eyed reader will note that these choices abstract away several details which could conceivably be included in the model. In particular, Vector{Float64} is implemented via a memory buffer, a pointer to the start of this buffer, and an integer which indicates the length of this buffer – none of these details are exposed in the model.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, some of the memory allocated during execution is made externally visible by modifying one of the arguments, not just via the return value.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The argument to f is the arguments to f before execution, and the output is the 2-tuple comprising the same arguments after execution and the values associated to any newly allocated / created data. 
Crucially, observe that we distinguish between the state of the arguments before and after execution.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For our example, the exact form of f is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f((x y z s)) = ((x y x odot y s) (2 x odot y sum_d=1^D x odot y))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that f behaves a little like a transition operator, in the that the first element of the tuple returned is the updated state of the arguments.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This model is good enough for the vast majority of functions. Unfortunately it isn't sufficient to describe a function when arguments alias each other (e.g. consider the way in which this particular model is wrong if y aliases z). Fortunately this is only a problem in a small fraction of all cases of aliasing, so we defer discussion of this until later on.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider now how this approach can be used to model several additional Julia functions, and to obtain their derivatives and adjoints.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"sin(x::Float64)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"mathcalX = RR, mathcalA = RR, f(x) = (x sin(x)).","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Thus the derivative is D f x (dotx) = (dotx cos(x) dotx), and its adjoint is D f x^ast (bary) = bary_x + bary_a cos(x), where bary = (bary_x bary_a).","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that this result is slightly different to the last example we saw involving sin.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"AD With Mutable Data","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider again","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function f!(x::Vector{Float64})\n x .*= x\n return sum(x)\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Our framework is able to accomodate this function, and has essentially the same solution as the last time we saw this example:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x odot x sum_n=1^N x_n^2)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule 
System","text":"Non-Mutating Functions","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A very interesting class of functions are those which do not modify their arguments. These are interesting because they are common, and are all that many AD frameworks like ChainRules.jl / Zygote.jl support – by considering this class of functions, we highlight some key similarities between these distinct rule systems.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"As always we can model these kinds of functions with a function f mathcalX to mathcalX times mathcalA, but we additionally have that f must have the form","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x varphi(x))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"for some function varphi mathcalX to mathcalA. The derivative is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x (dotx) = (dotx D varphi x(dotx))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider the usual inner product to derive the adjoint:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"beginalign\n langle bary D f x (dotx) rangle = langle (bary_1 bary_2) (dotx D varphi x(dotx)) rangle nonumber \n = langle bary_1 dotx rangle + langle bary_2 D varphi x(dotx) rangle nonumber \n = langle bary_1 dotx rangle + langle D varphi x^ast (bary_2) dotx rangle nonumber quad text(by definition of the adjoint) \n = langle bary_1 + D varphi x^ast (bary_2) dotx rangle nonumber\nendalign","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"So the adjoint of the derivative is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x^ast (bary) = bary_1 + D varphi x^ast (bary_2)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We see the correct thing to do is to increment the gradient of the output – bary_1 – by the result of applying the adjoint of the derivative of varphi to bary_2. In a ChainRules.rrule the bary_1 term is always zero, but the D varphi x^ast (bary_2) term is essentially the same.","category":"page"},{"location":"mathematical_interpretation/#The-Rule-Interface-(Round-1)","page":"Mooncake.jl's Rule System","title":"The Rule Interface (Round 1)","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Having explained in principle what it is that a rule must do, we now take a first look at the interface we use to achieve this. 
A rule for a function foo with signature","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Tuple{typeof(foo), Float64} -> Float64","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"must have signature","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Tuple{Trule, CoDual{typeof(foo), NoFData}, CoDual{Float64, NoFData}} ->\n Tuple{CoDual{Float64, NoFData}, Trvs_pass}","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For example, if we call foo(5.0), it rules would be called as rule(CoDual(foo, NoFData()), CoDual(5.0, NoFData())). The precise definition and role of NoFData will be explained shortly, but the general scheme is that to a rule for foo you must pass foo itself, its arguments, and some additional data for book-keeping. foo and each of its arguments are paired with this additional book-keeping data via the CoDual type.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The rule returns another CoDual (it propagates book-keeping information forwards), along with a function which runs the reverse pass.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In a little more depth:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Notation: primal","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Throughout the rest of this document, we will refer to the function being differentiated as the \"primal\" computation, and its arguments as the \"primal\" arguments.","category":"page"},{"location":"mathematical_interpretation/#Forwards-Pass","page":"Mooncake.jl's Rule System","title":"Forwards Pass","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Inputs","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Each piece of each input to the primal is paired with shadow data, if it has a fixed address. For example, a Vector{Float64} argument is paired with another Vector{Float64}. The adjoint of f is accumulated into this shadow vector on the reverse pass. However, a Float64 argument gets paired with NoFData(), since it is a bits type and therefore has no fixed address.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Outputs","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule must return a Tuple of two things. The first thing must be a CoDual containing the output of the primal computation and its shadow memory (if it has any). 
The second must be a function which runs the reverse pass of AD – this will usually be a closure of some kind.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Functionality","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule must","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"ensure that the state of the primal components of all inputs / the output are as they would have been had the primal computation been run (up to differences due to finite precision arithmetic),\npropagate / construct the shadow memory associated to the output (initialised to zero), and\nconstruct the function to run the reverse pass – typically this will involve storing some quantities computed during the forwards pass.","category":"page"},{"location":"mathematical_interpretation/#Reverse-Pass","page":"Mooncake.jl's Rule System","title":"Reverse Pass","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second element of the output of a rule is a function which runs the reverse pass.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Inputs","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The \"rdata\" associated to the output of the primal.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Outputs","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The \"rdata\" associated to the inputs of the primal.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Functionality","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"undo changes made to primal state on the forwards pass.\napply adjoint of derivative of primal operation, putting the results in the correct place.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This description should leave you with (at least) a couple of questions. What is \"rdata\", and what is \"the correct place\" to put the results of applying the adjoint of the derivative? In order to address these, we need to discuss the types that Mooncake.jl uses to represent the results of AD, and to propagate results backwards on the reverse pass.","category":"page"},{"location":"mathematical_interpretation/#Representing-Gradients","page":"Mooncake.jl's Rule System","title":"Representing Gradients","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We refer to both inputs and outputs of derivatives D f x mathcalX to mathcalY as tangents, e.g. dotx or doty. Conversely, we refer to both inputs and outputs to the adjoint of this derivative D f x^ast mathcalY to mathcalX as gradients, e.g. 
bary and barx.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Note, however, that the sets involved are the same whether dealing with a derivative or its adjoint. Consequently, we use the same type to represent both.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Representing Gradients","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This package assigns to each type in Julia a unique tangent_type, the purpose of which is to contain the gradients computed during reverse mode AD. The extended docstring for tangent_type provides the best introduction to the types which are used to represent tangents / gradients.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"tangent_type(P)","category":"page"},{"location":"mathematical_interpretation/#Mooncake.tangent_type-Tuple{Any}","page":"Mooncake.jl's Rule System","title":"Mooncake.tangent_type","text":"tangent_type(P)\n\nThere must be a single type used to represents tangents of primals of type P, and it must be given by tangent_type(P).\n\nExtended help\n\nThe tangent types which Mooncake.jl uses are quite similar in spirit to ChainRules.jl. For example, tangent \"vectors\" for\n\nFloat64s are Float64s,\nVector{Float64}s are Vector{Float64}s, and\nstructs are other another (special) struct with field types specified recursively.\n\nThere are, however, some major differences. Firstly, while it is certainly true that the above tangent types are permissible in ChainRules.jl, they are not the uniquely permissible types. For example, ZeroTangent is also a permissible type of tangent for any of them, and Float32 is permissible for Float64. This is a general theme in ChainRules.jl – it intentionally declines to place restrictions on what type can be used to represent the tangent of a given type.\n\nMooncake.jl differs from this. It insists that each primal type is associated to a single tangent type. Furthermore, this type is always given by the function Mooncake.tangent_type(primal_type).\n\nConsider some more worked examples.\n\nInt\n\nInt is not a differentiable type, so its tangent type is NoTangent:\n\njulia> tangent_type(Int)\nNoTangent\n\nTuples\n\nThe tangent type of a Tuple is defined recursively based on its field types. For example\n\njulia> tangent_type(Tuple{Float64, Vector{Float64}, Int})\nTuple{Float64, Vector{Float64}, NoTangent}\n\nThere is one edge case to be aware of: if all of the field of a Tuple are non-differentiable, then the tangent type is NoTangent. For example,\n\njulia> tangent_type(Tuple{Int, Int})\nNoTangent\n\nStructs\n\nAs with Tuples, the tangent type of a struct is, by default, given recursively. In particular, the tangent type of a struct type is Tangent. This type contains a NamedTuple containing the tangent to each field in the primal struct.\n\nAs with Tuples, if all field types are non-differentiable, the tangent type of the entire struct is NoTangent.\n\nThere are a couple of additional subtleties to consider over Tuples though. Firstly, not all fields of a struct have to be defined. Fortunately, Julia makes it easy to determine how many of the fields might possibly not be defined. 
The tangent associated to any field which might possibly not be defined is wrapped in a PossiblyUninitTangent.\n\nFurthermore, structs can have fields whose static type is abstract. For example\n\njulia> struct Foo\n x\n end\n\nIf you ask for the tangent type of Foo, you will see that it is\n\njulia> tangent_type(Foo)\nTangent{@NamedTuple{x}}\n\nObserve that the field type associated to x is Any. The way to understand this result is to observe that\n\nx could have literally any type at runtime, so we know nothing about what its tangent type must be until runtime, and\nwe require that the tangent type of Foo be unique.\n\nThe consequence of these two considerations is that the tangent type of Foo must be able to contain any type of tangent in its x field. It follows that the fieldtype of the x field of Foos tangent must be Any.\n\nMutable Structs\n\nThe tangent type for mutable structs have the same set of considerations as structs. The only difference is that they must themselves be mutable. Consequently, we use a type called MutableTangent to represent their tangents. It is a mutable struct with the same structure as Tangent.\n\nFor example, if you ask for the tangent_type of\n\njulia> mutable struct Bar\n x::Float64\n end\n\nyou will find that it is\n\njulia> tangent_type(Bar)\nMutableTangent{@NamedTuple{x::Float64}}\n\nPrimitive Types\n\nWe've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.\n\nOne interesting case are Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example\n\njulia> tangent_type(Ptr{Float64})\nPtr{Float64}\n\n\n\n\n\n","category":"method"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"FData and RData","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this \"unique gradient address\" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. 
a value type resides in another value type.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The following docstring provides the best in-depth explanation.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.fdata_type(T)","category":"page"},{"location":"mathematical_interpretation/#Mooncake.fdata_type-Tuple{Any}","page":"Mooncake.jl's Rule System","title":"Mooncake.fdata_type","text":"fdata_type(T)\n\nReturns the type of the forwards data associated to a tangent of type T.\n\nExtended help\n\nRules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable structs or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.\n\nGiven a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.\n\nGiven a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that\n\ntangent(fdata(t), rdata(t)) === t\n\nThe need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.\n\nInt\n\nInts are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore\n\njulia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))\n(NoFData, NoRData)\n\nFloat64\n\nThe tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so\n\njulia> (fdata_type(Float64), rdata_type(Float64))\n(NoFData, Float64)\n\nVector{Float64}\n\nThe tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so\n\njulia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))\n(Vector{Float64}, NoRData)\n\nTuple{Float64, Vector{Float64}, Int}\n\nThis is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:\n\njulia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})\nTuple{Float64, Vector{Float64}, NoTangent}\n\njulia> (fdata_type(T), rdata_type(T))\n(Tuple{NoFData, Vector{Float64}, NoFData}, Tuple{Float64, NoRData, NoRData})\n\nThe zero tangent for (5.0, [5.0]) is t = (0.0, [0.0]). fdata(t) returns (NoFData(), [0.0]), where the second element is === to the second element of t. rdata(t) returns (0.0, NoRData()). 
In this example, t contains a mixture of data, some of which is identified by its value, and some of which is identified by its address, so there is some fdata and some rdata.\n\nStructs\n\nStructs are handled in more-or-less the same way as Tuples, albeit with the possibility of undefined fields needing to be explicitly handled. For example, a struct such as\n\njulia> struct Foo\n x::Float64\n y\n z::Int\n end\n\nhas tangent type\n\njulia> tangent_type(Foo)\nTangent{@NamedTuple{x::Float64, y, z::NoTangent}}\n\nIts fdata and rdata are given by special FData and RData types:\n\njulia> (fdata_type(tangent_type(Foo)), rdata_type(tangent_type(Foo)))\n(Mooncake.FData{@NamedTuple{x::NoFData, y, z::NoFData}}, Mooncake.RData{@NamedTuple{x::Float64, y, z::NoRData}})\n\nPractically speaking, FData and RData both have the same structure as Tangents and are just used in different contexts.\n\nMutable Structs\n\nThe fdata for a mutable structs is its tangent, and it has no rdata. This is because mutable structs have fixed memory addresses, and can therefore be incremented in-place. For example,\n\njulia> mutable struct Bar\n x::Float64\n y\n z::Int\n end\n\nhas tangent type\n\njulia> tangent_type(Bar)\nMutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}\n\nand fdata / rdata types\n\njulia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))\n(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)\n\nPrimitive Types\n\nAs with tangents, each primitive type must specify what its fdata and rdata is. See specific examples for details.\n\n\n\n\n\n","category":"method"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"CoDuals","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"CoDuals are simply used to bundle together a primal and an associated fdata, depending upon context. Occassionally, they are used to pair together a primal and a tangent.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A quick aside: Non-Differentiable Data","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In the introduction to algorithmic differentiation, we assumed that the domain / range of function are the same as that of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space. e.g. the String type.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a 0 tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. 
Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.","category":"page"},{"location":"mathematical_interpretation/#The-Rule-Interface-(Round-2)","page":"Mooncake.jl's Rule System","title":"The Rule Interface (Round 2)","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Now that you've seen what data structures are used to represent gradients, we can describe in more depth the detail of how fdata and rdata are used to propagate gradients backwards on the reverse pass.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"DocTestSetup = quote\n using Mooncake\n using Mooncake: CoDual\n import Mooncake: rrule!!\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider the function","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])\nfoo (generic function with 1 method)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The fdata for x is a Tuple{NoFData, Vector{Float64}}, and its rdata is a Tuple{Float64, NoRData}. The function returns a Float64, which has no fdata, and whose rdata is Float64. So on the forwards pass there is really nothing that needs to happen with the fdata for x.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Under the framework introduced above, the model for this function is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x x_1 + sum_n=1^N (x_2)_n)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where the vector in the second element of x is of length N. Now, following our usual steps, the derivative is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x(dotx) = (dotx dotx_1 + sum_n=1^N (dotx_2)_n)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A gradient for this is a tuple (bary_x bary_a) where bary_a in RR and bary_x in RR times RR^N. 
A quick derivation will show that the adjoint is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x^ast(bary) = ((bary_x)_1 + bary_a (bary_x)_2 + bary_a mathbf1)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where mathbf1 is the vector of length N in which each element is equal to 1. (Observe that this agrees with the result we derived earlier for functions which don't mutate their arguments).","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Now that we know what the adjoint is, we'll write down the rrule!!, and then explain what is going on in terms of the adjoint. This hand-written implementation is to aid your understanding – Mooncake.jl should be relied upon to generate this code automatically in practice.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})\n dx_fdata = x.dx\n function dfoo_adjoint(dy::Float64)\n dx_fdata[2] .+= dy\n dx_1_rdata = dy\n dx_rdata = (dx_1_rdata, NoRData())\n return NoRData(), dx_rdata\n end\n x_p = x.x\n return CoDual(x_p[1] + sum(x_p[2]), NoFData()), dfoo_adjoint\n end;\n","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where dy is the rdata for the output to foo. The rrule!! can be called with the appropriate CoDuals:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> out, pb!! = rrule!!(CoDual(foo, NoFData()), CoDual((5.0, [1.0, 2.0]), (NoFData(), [0.0, 0.0])))\n(CoDual{Float64, NoFData}(8.0, NoFData()), var\"#dfoo_adjoint#1\"{Tuple{NoFData, Vector{Float64}}}((NoFData(), [0.0, 0.0])))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"and the pullback with appropriate rdata:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> pb!!(1.0)\n(NoRData(), (1.0, NoRData()))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"DocTestSetup = nothing","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that the forwards pass:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"computes the result of the initial function, and\npulls out the fdata for the Vector{Float64} component of the argument.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"As promised, the forwards pass really has nothing to do with the adjoint. 
It's just book-keeping and running the primal computation.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The reverse pass:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"increments each element of dx_fdata[2] by dy – this corresponds to (bary_x)_2 + bary_a mathbf1 in the adjoint,\nsets dx_1_rdata to dy – this corresponds (bary_x)_1 + bary_a subject to the constraint that (bary_x)_1 = 0,\nconstructs the rdata for x – this is essentially just book-keeping.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Each of these items serve to demonstrate more general points. The first that, upon entry into the reverse pass, all fdata values correspond to gradients for the arguments / output of f \"upon exit\" (for the components of these which are identified by their address), and once the reverse-pass finishes running, they must contain the gradients w.r.t. the arguments of f \"upon entry\".","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second that we always assume that the components of bary_x which are identified by their value have zero-rdata.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The third is that the components of the arguments of f which are identified by their value must have rdata passed back explicitly by a rule, while the components of the arguments to f which are identified by their address get their gradients propagated back implicitly (i.e. via the in-place modification of fdata).","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Reminder: the first element of the tuple returned by dfoo_adjoint is the rdata associated to foo itself, hence it is NoRData.","category":"page"},{"location":"mathematical_interpretation/#Testing","page":"Mooncake.jl's Rule System","title":"Testing","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.jl has an almost entirely automated system for testing rules – Mooncake.TestUtils.test_rule. You should absolutely make use of these when writing rules.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"TODO: improve docstring for testing functionality.","category":"page"},{"location":"mathematical_interpretation/#Summary","page":"Mooncake.jl's Rule System","title":"Summary","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this section we have covered the rule system. 
{"location":"mathematical_interpretation/#Summary","page":"Mooncake.jl's Rule System","title":"Summary","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this section we have covered the rule system. Every callable object / function in the Julia language is differentiated using rules with this interface, whether they be hand-written rrule!!s, or rules derived by Mooncake.jl.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"At this point you should be equipped with enough information to understand what a rule in Mooncake.jl does, and how you can write your own. Later sections will explain how Mooncake.jl goes about deriving rules itself in a recursive manner, and introduce you to some of the internals.","category":"page"},{"location":"mathematical_interpretation/#Asides","page":"Mooncake.jl's Rule System","title":"Asides","text":"","category":"section"},{"location":"mathematical_interpretation/#Why-Uniqueness-of-Type-For-Tangents-/-FData-/-RData?","page":"Mooncake.jl's Rule System","title":"Why Uniqueness of Type For Tangents / FData / RData?","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Why does Mooncake.jl insist that each primal type P be paired with a single tangent type T, as opposed to being more permissive? There are a few notable reasons:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"To provide a precise interface. Rules pass fdata around on the forwards pass and rdata on the reverse pass – being able to make strong assumptions about the type of the fdata / rdata given the primal type makes implementing rules much easier in practice.\nConditional type stability. We wish to have a high degree of confidence that if the primal code is type-stable, then the AD code will also be. It is straightforward to construct type-stable primal code which has type-unstable forwards and reverse passes if you permit there to be more than one fdata / rdata type for a given primal. So while uniqueness is certainly not sufficient on its own to guarantee conditional type stability, it is probably necessary in general.\nTest-case generation and coverage. There being a unique tangent / fdata / rdata type for each primal makes it much easier to be confident that a given rule is being tested thoroughly. For a given primal, rather than there being many possible input / output types to consider, there is just one.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This topic, in particular what goes wrong with permissive tangent type systems like those employed by ChainRules, deserves a more thorough treatment – hopefully someone will write something more expansive on it at some point.","category":"page"},
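{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"To make the uniqueness property concrete, a small sketch of querying Mooncake.tangent_type – the outputs follow from the description above and should be read as indicative rather than normative:\n\njulia> using Mooncake\n\njulia> Mooncake.tangent_type(Float64)\nFloat64\n\njulia> Mooncake.tangent_type(Vector{Float64})\nVector{Float64}\n\njulia> Mooncake.tangent_type(Int) # non-differentiable, so no tangent\nNoTangent","category":"page"},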
{"location":"mathematical_interpretation/#Why-Support-Closures-But-Not-Mutable-Globals","page":"Mooncake.jl's Rule System","title":"Why Support Closures But Not Mutable Globals","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"First consider why closures are straightforward to support. Look at the type of the closure produced by foo:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n function bar(y)\n x .+= y\n return nothing\n end\n return bar\nend\nbar = foo(randn(5))\ntypeof(bar)\n\n# output\nvar\"#bar#1\"{Vector{Float64}}","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that the Vector{Float64} that we passed to foo, and closed over in bar, is present in the type. This alludes to the fact that closures are basically just callable structs whose fields are the closed-over variables. Since the function itself is an argument to its rule, everything enters the rule for bar via its arguments, and the rule system developed in this document applies straightforwardly.","category":"page"},
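{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Indeed, one can see the \"callable struct\" directly – a quick sketch, reusing bar from above (nothing here is Mooncake-specific):\n\njulia> fieldnames(typeof(bar))\n(:x,)\n\njulia> bar.x isa Vector{Float64}\ntrue","category":"page"},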
{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"On the other hand, globals do not appear in the functions that they are a part of. For example,","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"const a = randn(10)\n\nfunction g(x)\n a .+= x\n return nothing\nend\n\ntypeof(g)\n\n# output\ntypeof(g) (singleton type of function g, subtype of Function)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Neither the value nor the type of a is present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.","category":"page"},{"location":"understanding_intro/#Mooncake.jl-and-Reverse-Mode-AD","page":"Introduction","title":"Mooncake.jl and Reverse-Mode AD","text":"","category":"section"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically.","category":"page"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"we recap what AD is, and introduce the mathematics necessary to understand it,\nexplain how this mathematics relates to functions and data structures in Julia, and\nexplain how this is handled in Mooncake.jl.","category":"page"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"Since Mooncake.jl supports in-place operations / mutation, these will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.","category":"page"},{"location":"understanding_intro/#Who-Are-These-Docs-For?","page":"Introduction","title":"Who Are These Docs For?","text":"","category":"section"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"These docs are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone who is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to have read them in order to make use of this package.","category":"page"},{"location":"understanding_intro/#Prerequisites-and-Resources","page":"Introduction","title":"Prerequisites and Resources","text":"","category":"section"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"This introduction assumes familiarity with the differentiation of vector-valued functions – familiarity with the gradient and Jacobian matrices is a given.","category":"page"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and \"transposed Jacobian\". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the \"Matrix Calculus for Machine Learning and Beyond\" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.","category":"page"},{"location":"tools_for_rules/#Tools-for-Rules","page":"Tools for Rules","title":"Tools for Rules","text":"","category":"section"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own rrule!! from scratch. In this section, we detail some useful strategies which can help you avoid having to write rrule!!s in many situations.","category":"page"},{"location":"tools_for_rules/#Simplfiying-Code-via-Overlays","page":"Tools for Rules","title":"Simplifying Code via Overlays","text":"","category":"section"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@mooncake_overlay","category":"page"},{"location":"tools_for_rules/#Mooncake.@mooncake_overlay","page":"Tools for Rules","title":"Mooncake.@mooncake_overlay","text":"@mooncake_overlay method_expr\n\nDefine a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.\n\nFor example, suppose that you have a function\n\njulia> foo(x::Float64) = bar(x)\nfoo (generic function with 1 method)\n\nwhere Mooncake.jl fails to differentiate bar for some reason. 
If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:\n\njulia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)\n\n\nWhen looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.\n\nA Worked Example\n\nTo demonstrate how to use @mooncake_overlay in practice, we here show how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonstrate how to use overlays!\n\nFirst, consider a simple example:\n\njulia> scale(x) = 2x\nscale (generic function with 1 method)\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(10.0, (NoTangent(), 2.0))\n\nWe can use @mooncake_overlay to change the definition which Mooncake.jl sees:\n\njulia> Mooncake.@mooncake_overlay scale(x) = 3x\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(15.0, (NoTangent(), 3.0))\n\nAs can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlaid definition of the method.\n\nAdditionally, it is possible to use the usual multi-line syntax to declare an overlay:\n\njulia> Mooncake.@mooncake_overlay function scale(x)\n return 4x\n end\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(20.0, (NoTangent(), 4.0))\n\n\n\n\n\n","category":"macro"},{"location":"tools_for_rules/#Functions-with-Zero-Adjoint","page":"Tools for Rules","title":"Functions with Zero Adjoint","text":"","category":"section"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:","category":"page"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@zero_adjoint\nMooncake.zero_adjoint","category":"page"},{"location":"tools_for_rules/#Mooncake.@zero_adjoint","page":"Tools for Rules","title":"Mooncake.@zero_adjoint","text":"@zero_adjoint ctx sig\n\nDefines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.\n\nFor example:\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo(x) = 5\nfoo (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})\ntrue\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())\n(NoRData(), 0.0)\n\nLimited support for Varargs is also available. For example\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo_varargs(x...) 
= 5\nfoo_varargs (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})\ntrue\n\njulia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())\n(NoRData(), 0.0, NoRData())\n\nBe aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.\n\nWARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.\n\nAs always, you should use TestUtils.test_rule to ensure that you've not made a mistake.\n\nSignatures Unsupported By This Macro\n\nIf the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.\n\n\n\n\n\n","category":"macro"},{"location":"tools_for_rules/#Mooncake.zero_adjoint","page":"Tools for Rules","title":"Mooncake.zero_adjoint","text":"zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}\n\nUtility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.\n\nNOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.\n\nYou make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:\n\njulia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual\n\njulia> foo(x::Vararg{Int}) = 5\nfoo (generic function with 1 method)\n\njulia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;\n\njulia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())\n(NoRData(), NoRData(), NoRData())\n\nWARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.\n\n\n\n\n\n","category":"function"},{"location":"tools_for_rules/#Using-ChainRules.jl","page":"Tools for Rules","title":"Using ChainRules.jl","text":"","category":"section"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! 
by wrapping an existing ChainRulesCore.rrule.","category":"page"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"There is enough similarity between these two systems that most of the boilerplate code can be avoided.","category":"page"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@from_rrule","category":"page"},{"location":"tools_for_rules/#Mooncake.@from_rrule","page":"Tools for Rules","title":"Mooncake.@from_rrule","text":"@from_rrule ctx sig [has_kwargs=false]\n\nConvenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.\n\nArguments\n\nctx: A Mooncake context type\nsig: the signature which you wish to assert should be a primitive in Mooncake.jl, and for which an existing ChainRulesCore.rrule should be used to implement the functionality.\nhas_kwargs: a Bool stating whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.\n\nExample Usage\n\nA Basic Example\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real) = 5x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω\n return foo(x), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)\n(NoRData(), 5.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)\nTest Passed\n\nAn Example with Keyword Arguments\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω\n return foo(x; cond), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true\n\njulia> _, pb = rrule!!(\n zero_fcodual(Core.kwcall),\n zero_fcodual((cond=false, )),\n zero_fcodual(foo),\n zero_fcodual(5.0),\n );\n\njulia> pb(3.0)\n(NoRData(), NoRData(), NoRData(), 12.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(\n Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true\n )\nTest Passed\n\nNotice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.\n\nLimitations\n\nIt is your responsibility to ensure that\n\ncalls with signature sig do not mutate their arguments,\nthe output of calls with signature sig does not alias any of the inputs.\n\nAs with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.\n\nArgument Type Constraints\n\nMany methods of ChainRulesCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature\n\nTuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}\n\nThere are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.\n\nSuffice it to say, you should not write rules for this package which are so generically typed. 
Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.\n\nConversions Between Different Tangent Type Systems\n\nUnder the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.\n\n\n\n\n\n","category":"macro"},{"location":"algorithmic_differentiation/#Algorithmic-Differentiation","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This section introduces the mathematics behind AD. Even if you have worked with AD before, we recommend reading it in order to acclimatise yourself to the perspective that Mooncake.jl takes on the subject.","category":"page"},{"location":"algorithmic_differentiation/#Derivatives","page":"Algorithmic Differentiation","title":"Derivatives","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The foundation on which all of AD is built is the derivative – we require a fairly general definition of it, which we build up to here.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Scalar-to-Scalar Functions","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Consider first f RR to RR, which we require to be differentiable at x in RR. Its derivative at x is usually thought of as the scalar alpha in RR such that","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"textdf = alpha textdx ","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Loosely speaking, by this notation we mean that for arbitrary small changes textd x in the input to f, the change in the output textd f is alpha textdx. We refer readers to the first few minutes of the first lecture mentioned before for a more careful explanation.","category":"page"},
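{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This relationship is easy to check numerically with a finite difference – a small sketch using f = sin (so that alpha = cos(x)); nothing Mooncake-specific is assumed here:\n\njulia> f(x) = sin(x);\n\njulia> x, dx = 1.0, 1e-6;\n\njulia> isapprox((f(x + dx) - f(x)) / dx, cos(x); atol=1e-5)\ntrue","category":"page"},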
{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Vector-to-Vector Functions","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The generalisation of this to Euclidean space should be familiar: if f RR^P to RR^Q is differentiable at a point x in RR^P, then the derivative of f at x is given by the Jacobian matrix at x, denoted Jx in RR^Q times P, such that","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"textdf = Jx textdx ","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"It is possible to stop here, as all the functions we shall need to consider can in principle be written as functions on some subset of RR^P.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"However, when we consider differentiating computer programmes, we will have to deal with complicated nested data structures, e.g. structs inside Tuples inside Vectors etc. While all of these data structures can be mapped onto a flat vector in order to make sense of the Jacobian of a computer programme, this becomes very inconvenient very quickly. To see the problem, consider the Julia function whose input is of type Tuple{Tuple{Float64, Vector{Float64}}, Vector{Float64}, Float64} and whose output is of type Tuple{Vector{Float64}, Float64}. What kind of object might be used to represent the derivative of a function mapping between these two spaces? We certainly can treat these as structured \"views\" into \"flat\" Vector{Float64}s, and then define a Jacobian, but actually finding this mapping is a tedious exercise, even if it quite obviously exists.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In fact, a more general formulation of the derivative is used all the time in the context of AD – the matrix calculus discussed by [1] and [2] (to name a couple) makes use of a generalised form of the derivative in order to work with functions which map to and from matrices (albeit there are slight differences in naming conventions from text to text), without needing to \"flatten\" them into vectors in order to make sense of them.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In general, it will be much easier to avoid \"flattening\" operations wherever possible. 
In order to do so, we now introduce a generalised notion of the derivative.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Functions Between More General Spaces","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to avoid the difficulties described above, we consider functions f mathcalX to mathcalY, where mathcalX and mathcalY are finite dimensional real Hilbert spaces (read: finite-dimensional vector space with an inner product, and real-valued scalars). This definition includes functions to / from RR, RR^D, but also real-valued matrices, and any other \"container\" for collections of real numbers. Furthermore, we shall see later how we can model all sorts of structured representations of data directly as such spaces.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"For such spaces, the derivative of f at x in mathcalX is the linear operator (read: linear function) D f x mathcalX to mathcalY satisfying","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"textdf = D f x (textd x)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The purpose of this linear operator is to provide a linear approximation to f which is accurate for arguments which are very close to x.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Please note that D f x is a single mathematical object, despite the fact that 3 separate symbols are used to denote it – D f x (dotx) denotes the application of the function D f x to argument dotx. Furthermore, the dot-notation (dotx) does not have anything to do with time-derivatives, it is simply common notation used in the AD literature to denote the arguments of derivatives.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"So, instead of thinking of the derivative as a number or a matrix, we think about it as a function. We can express the previous notions of the derivative in this language.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In the scalar case, rather than thinking of the derivative as being alpha, we think of it as the linear operator D f x (dotx) = alpha dotx. Put differently, rather than thinking of the derivative as the slope of the tangent to f at x, think of it as the function describing the tangent itself. Observe that up until now we had only considered inputs to D f x which were small (textd x) – here we extend it to the entire space mathcalX and denote inputs in this space dotx. Inputs dotx should be thought of as \"directions\", in the directional derivative sense (why this is true will be discussed later).","category":"page"},
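{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This \"derivative as a linear function\" view is natural to write down in code. A minimal sketch for f = sin (this is plain Julia, not how Mooncake.jl represents derivatives internally):\n\njulia> D(x) = dx -> cos(x) * dx; # D f x, as a function of the direction dx\n\njulia> D(1.0)(0.5) == 0.5 * cos(1.0)\ntrue","category":"page"},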
{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Similarly, if mathcalX = RR^P and mathcalY = RR^Q then this operator can be specified in terms of the Jacobian matrix: D f x (dotx) = Jx dotx – brackets are used to emphasise that D f x is a function, and is being applied to dotx.[note_for_geometers]","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"To reiterate, for the rest of this document, we define the derivative to be \"multiply by alpha\" or \"multiply by Jx\", rather than to be alpha or Jx. So whenever you see the word \"derivative\", you should think \"linear function\".","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The Chain Rule","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The chain rule is the result which makes AD work. Fortunately, it applies to this version of the derivative:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f = g circ h implies D f x = (D g h(x)) circ (D h x)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"By induction, this extends to a collection of N functions f_1 dots f_N:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f = f_N circ dots circ f_1 implies D f x = (D f_N x_N) circ dots circ (D f_1 x_1)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where x_n+1 = f_n(x_n), and x_1 = x.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"An aside: the definition of the Frechet Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This definition of the derivative has a name: the Frechet derivative. It is a generalisation of the Total Derivative. Formally, we say that a function f mathcalX to mathcalY is differentiable at a point x in mathcalX if there exists a linear operator D f x mathcalX to mathcalY (the derivative) satisfying","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"lim_textd h to 0 frac f(x + textd h) - f(x) - D f x (textd h) _mathcalY textdh _mathcalX = 0","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where cdot _mathcalX and cdot _mathcalY are the norms associated to Hilbert spaces mathcalX and mathcalY respectively. It is a good idea to consider what this looks like when mathcalX = mathcalY = RR and when mathcalX = mathcalY = RR^D. It is sometimes helpful to refer to this definition to e.g. 
verify the correctness of the derivative of a function – as with single-variable calculus, however, this is rare.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Another aside: what does Forwards-Mode AD compute?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"At this point we have enough machinery to discuss forwards-mode AD. Expressed in the language of linear operators and Hilbert spaces, the goal of forwards-mode AD is the following: given a function f which is differentiable at a point x, compute D f x (dotx) for a given vector dotx. If f RR^P to RR^Q, this is equivalent to computing Jx dotx, where Jx is the Jacobian of f at x. For the interested reader we provide a high-level explanation of how forwards-mode AD does this in How does Forwards-Mode AD work?.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Another aside: notation","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"You may have noticed that we typically denote the argument to a derivative with a \"dot\" over it, e.g. dotx. This is something that we will do consistently, and we will use the same notation for the outputs of derivatives. Wherever you see a symbol with a \"dot\" over it, expect it to be an input or output of a derivative / forwards-mode AD.","category":"page"},{"location":"algorithmic_differentiation/#Reverse-Mode-AD:-*what*-does-it-do?","page":"Algorithmic Differentiation","title":"Reverse-Mode AD: what does it do?","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to explain what reverse-mode AD does, we first consider the \"vector-Jacobian product\" definition in Euclidean space which will be familiar to many readers. We then generalise.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Mode AD: what does it do in Euclidean space?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In this setting, the goal of reverse-mode AD is the following: given a function f RR^P to RR^Q which is differentiable at x in RR^P with Jacobian Jx at x, compute Jx^top bary for any bary in RR^Q. This is useful because we can obtain the gradient from this when Q = 1 by letting bary = 1.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Adjoint Operators","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to generalise this algorithm to work with linear operators, we must first generalise the idea of multiplying a vector by the transpose of the Jacobian. The relevant concept here is that of the adjoint operator. 
Specifically, the adjoint A^ast of a linear operator A is the linear operator satisfying","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle A^ast bary dotx rangle = langle bary A dotx rangle","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where langle cdot cdot rangle denotes the inner-product. The relationship between the adjoint and matrix transpose is: if A (x) = J x for some matrix J, then A^ast (y) = J^top y.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Moreover, just as (A B)^top = B^top A^top when A and B are matrices, (A B)^ast = B^ast A^ast when A and B are linear operators. This result follows in short order from the definition of the adjoint operator – and is a good exercise!","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Mode AD: what does it do in general?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Equipped with adjoints, we can express reverse-mode AD only in terms of linear operators, dispensing with the need to express everything in terms of Jacobians. The goal of reverse-mode AD is as follows: given a differentiable function f mathcalX to mathcalY, compute D f x^ast (bary) for some bary.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Notation: D f x^ast denotes the single mathematical object which is the adjoint of D f x. It is a linear function from mathcalY to mathcalX. We may occasionally write it as (D f x)^ast if there is some risk of confusion.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We will explain how reverse-mode AD goes about computing this after some worked examples.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Aside: Notation","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"You will have noticed that arguments to adjoints have thus far always had a \"bar\" over them, e.g. bary. This notation is common in the AD literature and will be used throughout. Additionally, this \"bar\" notation will be used for the outputs of adjoints of derivatives. So wherever you see a symbol with a \"bar\" over it, think \"input or output of adjoint of derivative\".","category":"page"},{"location":"algorithmic_differentiation/#Some-Worked-Examples","page":"Algorithmic Differentiation","title":"Some Worked Examples","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We now present some worked examples in order to prime intuition, and to introduce the important classes of problems that will be encountered when doing AD in the Julia language. We will put all of these problems in a single general framework later on.","category":"page"},
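{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Before diving in, it is worth checking the defining property of the adjoint numerically in the familiar matrix case, where A^ast is just multiplication by the transpose. A small sketch (plain Julia, nothing Mooncake-specific):\n\njulia> using LinearAlgebra\n\njulia> A = [1.0 2.0; 3.0 4.0]; xdot = [0.1, 0.2]; ybar = [0.3, 0.4];\n\njulia> dot(A' * ybar, xdot) ≈ dot(ybar, A * xdot) # langle A^ast bary, dotx rangle = langle bary, A dotx rangle\ntrue","category":"page"},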
{"location":"algorithmic_differentiation/#An-Example-with-Matrix-Calculus","page":"Algorithmic Differentiation","title":"An Example with Matrix Calculus","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We have introduced some mathematical abstraction in order to simplify the calculations involved in AD. To this end, we consider differentiating f(X) = X^top X. Results for this and similar operations are given by [1]. A similar operation, but which maps from matrices to RR, is discussed in Lecture 4 part 2 of the MIT course mentioned previously. Both [1] and Lecture 4 part 2 provide approaches to obtaining the derivative of this function.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Following either resource will yield the derivative:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f X (dotX) = dotX^top X + X^top dotX","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Observe that this is indeed a linear operator (i.e. it is linear in its argument, dotX). (You can always plug it in to the definition of the Frechet derivative to confirm that it is indeed the derivative.)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to perform reverse-mode AD, we need to find the adjoint operator. 
Using the usual definition of the inner product between matrices,","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle X Y rangle = textrmtr (X^top Y)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"we can rearrange the inner product as follows:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\n langle barY D f X (dotX) rangle = langle barY dotX^top X + X^top dotX rangle nonumber \n = textrmtr (barY^top dotX^top X) + textrmtr(barY^top X^top dotX) nonumber \n = textrmtr ( X barY^top^top dotX) + textrmtr( X barY^top dotX) nonumber \n = langle X barY^top + X barY dotX rangle nonumber\nendalign","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We can read off the adjoint operator from the first argument to the inner product:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f X^ast (barY) = X barY^top + X barY","category":"page"},
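{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"As a sanity check, the defining property of the adjoint can be verified numerically for this example – a sketch in plain Julia:\n\njulia> using LinearAlgebra\n\njulia> X = [1.0 2.0; 3.0 4.0]; Xdot = [0.5 0.0; 0.0 0.5]; Ybar = [1.0 0.0; 1.0 1.0];\n\njulia> Df_Xdot = Xdot' * X + X' * Xdot; # D f X (dotX)\n\njulia> tr(Ybar' * Df_Xdot) ≈ tr((X * Ybar' + X * Ybar)' * Xdot)\ntrue","category":"page"},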
{"location":"algorithmic_differentiation/#AD-of-a-Julia-function:-a-trivial-example","page":"Algorithmic Differentiation","title":"AD of a Julia function: a trivial example","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We now turn to differentiating Julia functions (we use function to refer to the programming language construct, and function to refer to a more general mathematical concept). The way that Mooncake.jl handles immutable data is very similar to how Zygote / ChainRules do. For example, consider the Julia function","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x::Float64) = sin(x)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"If you've previously worked with ChainRules / Zygote, without thinking too hard about the formalisms we introduced previously (perhaps by considering a variety of partial derivatives) you can probably arrive at the following adjoint for the derivative of f:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"g -> g * cos(x)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Implicitly, you have performed three steps:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"model f as a differentiable function,\ncompute its derivative, and\ncompute the adjoint of the derivative.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"It is helpful to work through this simple example in detail, as the steps involved apply more generally, and the detail becomes helpful in more complicated examples. If at any point this exercise feels pedantic, we ask you to stick with it.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Obviously, we model the Julia function f as the function f RR to RR where","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x) = sin(x)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Observe that we've made (at least) two modelling assumptions here:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"a Float64 is modelled as a real number,\nthe Julia function sin is modelled as the usual mathematical function sin.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"As promised, we're being quite pedantic. While the first assumption is obvious and will remain true, we will shortly see examples where we have to work a bit harder to obtain a correspondence between a Julia function and a mathematical object.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we have a mathematical model, we can differentiate it:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x (dotx) = cos(x) dotx","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Given the derivative, we can find its adjoint:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle barf D f x(dotx) rangle = langle barf cos(x) dotx rangle = langle cos(x) barf dotx rangle","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"From here the adjoint can be read off from the first argument to the inner product:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x^ast (barf) = cos(x) barf","category":"page"},
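{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This hand-derived adjoint lines up with what Mooncake.jl computes. A sketch, following the value_and_gradient!! pattern shown earlier in these docs (since sin(1.0) = 0.8414... and cos(1.0) = 0.5403..., the output below is what the derivation predicts):\n\njulia> using Mooncake\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(sin), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, sin, 1.0)\n(0.8414709848078965, (NoTangent(), 0.5403023058681398))","category":"page"},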
function","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x::Float64, y::Tuple{Float64, Float64}) = x + y[1] * y[2]","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Its adjoint is going to be something along the lines of","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"g -> (g, (y[2] * g, y[1] * g))","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"As before, we work through in detail.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"There are a couple of aspects of f which require thought:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"it has two arguments – we've only handled single argument functions previously, and\nthe second argument is a Tuple – we've not yet decided how to model this.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"To this end, we define a mathematical notion of a tuple. A tuple is a collection of N elements, each of which is drawn from some set mathcalX_n. We denote by mathcalX = mathcalX_1 times dots times mathcalX_N the set of all N-tuples whose nth element is drawn from mathcalX_n. 
{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We can think of multi-argument functions as single-argument functions of a tuple, so a reasonable mathematical model for f might be a function f RR times RR times RR to RR, where","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x y) = x + y_1 y_2","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Note that while the function is written with two arguments, you should treat them as a single tuple, where we've assigned the name x to the first element, and y to the second.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we have a mathematical object, we can differentiate it:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x y(dotx doty) = dotx + doty_1 y_2 + y_1 doty_2","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D fx y maps RR times RR times RR to RR, so D f x y^ast must map the other way. You should verify that the following follows quickly from the definition of the adjoint:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x y^ast (barf) = (barf (barf y_2 barf y_1))","category":"page"},{"location":"algorithmic_differentiation/#AD-with-mutable-data","page":"Algorithmic Differentiation","title":"AD with mutable data","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In the previous two examples there was an obvious mathematical model for the Julia function. Indeed, this model was sufficiently obvious that it required little explanation. 
This is not always the case, though; in particular, Julia functions which modify / mutate their inputs require a little more thought.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Consider the following Julia function:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"function f!(x::Vector{Float64})\n x .*= x\n return sum(x)\nend","category":"page"},
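{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Running it makes the mutation visible – a quick sketch:\n\njulia> x = [1.0, 2.0];\n\njulia> f!(x)\n5.0\n\njulia> x # f! has overwritten its input\n2-element Vector{Float64}:\n 1.0\n 4.0","category":"page"},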
{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This function squares each element of its input in-place, and returns the sum of the result. So what is an appropriate mathematical model for this function?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The trick is to distinguish between the state of x upon entry to / exit from f!. In particular, let phi_textf RR^N to RR^N times RR be given by","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"phi_textf(x) = (x odot x sum_n=1^N x_n^2)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where odot denotes the Hadamard / element-wise product (corresponding to the line x .*= x in the above code). The point here is that the input to phi_textf is the state of x upon entry to f!, and the value returned from phi_textf is a tuple containing both the state of x upon exit from f! and the value returned by f!.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The remaining steps are straightforward now that we have the model.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The derivative of phi_textf is","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D phi_textf x(dotx) = (2 x odot dotx 2 sum_n=1^N x_n dotx_n)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The argument to the adjoint of the derivative must be an element of RR^N times RR, i.e. a 2-tuple whose first element lives in RR^N and whose second element lives in RR. Denote such a tuple as (bary_1 bary_2). Plugging this into an inner product with the derivative and rearranging yields","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\n langle (bary_1 bary_2) D phi_textf x (dotx) rangle = langle (bary_1 bary_2) (2 x odot dotx 2 sum_n=1^N x_n dotx_n) rangle nonumber \n = langle bary_1 2 x odot dotx rangle + langle bary_2 2 sum_n=1^N x_n dotx_n rangle nonumber \n = langle 2x odot bary_1 dotx rangle + langle 2 bary_2 x dotx rangle nonumber \n = langle 2 (x odot bary_1 + bary_2 x) dotx rangle nonumber\nendalign","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"So we can read off the adjoint to be","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D phi_textf x^ast (bary) = 2 (x odot bary_1 + bary_2 x)","category":"page"},{"location":"algorithmic_differentiation/#Reverse-Mode-AD:-*how*-does-it-do-it?","page":"Algorithmic Differentiation","title":"Reverse-Mode AD: how does it do it?","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we know what it is that AD computes, we need a rough understanding of how it computes it.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In short: reverse-mode AD breaks down a \"complicated\" function f into the composition of a collection of \"simple\" functions f_1 dots f_N, applies the chain rule, and takes the adjoint.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Specifically, we assume that we can express any function f as f = f_N circ dots circ f_1, and that we can compute the adjoint of the derivative for each f_n. From this, we can obtain the adjoint of f by applying the chain rule to the derivatives and taking the adjoint:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\nD f x^ast = (D f_N x_N circ dots circ D f_1 x_1)^ast nonumber \n = D f_1 x_1^ast circ dots circ D f_N x_N^ast nonumber\nendalign","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"For example, suppose that f(X) = sin(cos(texttr(X^top X))). One option to compute its adjoint is to figure it out by hand directly (probably using the chain rule somewhere). Instead, we could notice that f = f_4 circ f_3 circ f_2 circ f_1 where f_4 = sin, f_3 = cos, f_2 = texttr and f_1(X) = X^top X. We could derive the adjoint for each of these functions (a fairly straightforward task), and then compute","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x^ast (bary) = (D f_1 x_1^ast circ D f_2 x_2^ast circ D f_3 x_3^ast circ D f_4 x_4^ast)(1)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"in order to obtain the gradient of f. Reverse-mode AD essentially just does this. 
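To see the composition-of-adjoints idea in code, here is a small standalone sketch for the simpler composition f = sin circ cos, with a hand-written adjoint for each piece (plain Julia – this illustrates the idea, and is not Mooncake.jl's implementation):\n\njulia> dcos_adj = x -> (dy -> -sin(x) * dy); # adjoint of D cos x\n\njulia> dsin_adj = x -> (dy -> cos(x) * dy); # adjoint of D sin x\n\njulia> x1 = 0.5; x2 = cos(x1); # forwards-pass: record the x_n\n\njulia> dcos_adj(x1)(dsin_adj(x2)(1.0)) ≈ -cos(cos(0.5)) * sin(0.5) # reverse-pass\ntrue\n\n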
Modern systems have hand-written adjoints for (hopefully!) all of the \"simple\" functions you may wish to build a function such as f from (often there are hundreds of these), and compose them to compute the adjoint of f. A sketch of a more generic algorithm is as follows.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Forwards-Pass:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"x_1 = x, n = 1\nconstruct D f_n x_n^ast\nlet x_n+1 = f_n (x_n)\nlet n = n + 1\nif n < N + 1 then go to step 2.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Pass:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"let barx_N+1 = bary\nlet n = n - 1\nlet barx_n = D f_n x_n^ast (barx_n+1)\nif n = 1 return barx_1 else go to step 2.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"How does this relate to vector-Jacobian products?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In Euclidean space, each derivative D f_n x_n(dotx_n) = J_nx_n dotx_n. Applying the chain rule to D f x and substituting this in yields","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Jx = J_Nx_N dots J_1x_1 ","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Taking the transpose and applying it to bary yields","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Jx^top bary = J_1x_1^top dots J_Nx_N^top bary ","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Comparing this with the expression in terms of adjoints and operators, we see that composition of adjoints of derivatives has been replaced with multiplying by transposed Jacobian matrices. This \"vector-Jacobian product\" expression is commonly used to explain AD, and is likely familiar to many readers.","category":"page"},{"location":"algorithmic_differentiation/#Directional-Derivatives-and-Gradients","page":"Algorithmic Differentiation","title":"Directional Derivatives and Gradients","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now we turn to using reverse-mode AD to compute the gradient of a function. In short, given a function g mathcalX to RR with derivative D g x at x, its gradient is equal to D g x^ast (1). We explain why in this section.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The derivative discussed here can be used to compute directional derivatives. Consider a function f mathcalX to RR with Frechet derivative D f x mathcalX to RR at x in mathcalX. 
Then D fx(dotx) returns the directional derivative in direction dotx.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Gradients are closely related to the adjoint of the derivative. Recall that the gradient of f at x is defined to be the vector nabla f (x) in mathcalX such that langle nabla f (x) dotx rangle gives the directional derivative of f at x in direction dotx. Having noted that D fx(dotx) is exactly this directional derivative, we can equivalently say that","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D fx(dotx) = langle nabla f (x) dotx rangle ","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The role of the adjoint is revealed when we consider f = mathcall circ g, where g mathcalX to mathcalY, mathcall(y) = langle bary y rangle, and bary in mathcalY is some fixed vector. Noting that D mathcall y(doty) = langle bary doty rangle, we apply the chain rule to obtain","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\nD f x (dotx) = (D mathcall g(x)) circ (D g x)(dotx) nonumber \n = langle bary D g x (dotx) rangle nonumber \n = langle D g x^ast (bary) dotx rangle nonumber\nendalign","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"from which we conclude that D g x^ast (bary) is the gradient of the composition l circ g at x.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The above shows that if mathcalY = RR and g is the function we wish to compute the gradient of, we can simply set bary = 1 and compute D g x^ast (bary) to obtain the gradient of g at x.","category":"page"},{"location":"algorithmic_differentiation/#Summary","page":"Algorithmic Differentiation","title":"Summary","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This document explains the core mathematical foundations of AD. It explains separately what it does, and how it goes about it. 
Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices, and Julia functions.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Subsequent sections will build on these foundations, to provide a more general explanation of what AD looks like for a Julia programme.","category":"page"},{"location":"algorithmic_differentiation/#Asides","page":"Algorithmic Differentiation","title":"Asides","text":"","category":"section"},{"location":"algorithmic_differentiation/#*How*-does-Forwards-Mode-AD-work?","page":"Algorithmic Differentiation","title":"How does Forwards-Mode AD work?","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Forwards-mode AD achieves this by breaking down f into the composition f = f_N circ dots circ f_1, where each f_n is a simple function whose derivative (function) D f_n x_n we know for any given x_n. By the chain rule, we have that","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x (dotx) = D f_N x_N circ dots circ D f_1 x_1 (dotx)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"which suggests the following algorithm:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"let x_1 = x, dotx_1 = dotx, and n = 1\nlet dotx_n+1 = D f_n x_n (dotx_n)\nlet x_n+1 = f_n(x_n)\nlet n = n + 1\nif n = N+1 then return dotx_N+1, otherwise go to 2.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"When each function f_n maps between Euclidean spaces, the applications of derivatives D f_n x_n (dotx_n) are given by J_n dotx_n where J_n is the Jacobian of f_n at x_n.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"M. Giles. An extended collection of matrix derivative results for forward and reverse mode automatic differentiation. Unpublished (2008).\n\n\n\nT. P. Minka. Old and new matrix algebra useful for statistics. See www.stat.cmu.edu/minka/papers/matrix.html 4 (2000).\n\n\n\n","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"[note_for_geometers]: in AD we only really need to discuss differentiable functions between vector spaces that are isomorphic to Euclidean space. Consequently, a variety of considerations which are usually required in differential geometry are not required here. Notably, the tangent space is assumed to be the same everywhere, and to be the same as the domain of the function. Avoiding these additional considerations helps keep the mathematics as simple as possible.","category":"page"},{"location":"debugging_and_mwes/#Debugging-and-MWEs","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"","category":"section"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"There's a reasonable chance that you'll run into an issue with Mooncake.jl at some point. 
In order to debug what is going on when this happens, or to produce an MWE, it is helpful to have a convenient way to run Mooncake.jl on whatever function and arguments you have which are causing problems.","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"We recommend making use of Mooncake.jl's testing functionality to generate your test cases:","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"Mooncake.TestUtils.test_rule","category":"page"},{"location":"debugging_and_mwes/#Mooncake.TestUtils.test_rule","page":"Debugging and MWEs","title":"Mooncake.TestUtils.test_rule","text":"test_rule(\n    rng, x...;\n    interface_only=false,\n    is_primitive::Bool=true,\n    perf_flag::Symbol=:none,\n    interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(),\n    debug_mode::Bool=false,\n)\n\nRun standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in particular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.\n\nThis function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.\n\n\n\n\n\n","category":"function"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"This approach is convenient because it can","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"check whether AD runs at all,\ncheck whether AD produces the correct answers,\ncheck whether AD is performant, and\nbe used without having to manually generate tangents.","category":"page"},{"location":"debugging_and_mwes/#Example","page":"Debugging and MWEs","title":"Example","text":"","category":"section"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"DocTestSetup = quote\n    using Random, Mooncake\nend","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"For example","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"f(x) = Core.bitcast(Float64, x)\nMooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"will error. (In this particular case, it is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting. 
In this particular instance, Mooncake.jl just fails to compile, but in other instances other things can happen.)","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.","category":"page"},{"location":"debugging_and_mwes/#Segfaults","page":"Debugging and MWEs","title":"Segfaults","text":"","category":"section"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.","category":"page"},{"location":"debug_mode/#Debug-Mode","page":"Debug Mode","title":"Debug Mode","text":"","category":"section"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"DocTestSetup = quote\n    using Mooncake, ADTypes\nend","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"The Problem","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})\n    plus_reverse_pass(dz::Real) = NoRData(), dz, dz\n    return zero_fcodual(primal(x) + primal(y)), plus_reverse_pass\nend","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"and calls","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"rrule!!(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, so the appropriate tangent type is also Float32. This error may cause the reverse pass to fail loudly, but it may fail silently. It may cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule. 
Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"The Solution","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Check that the types of the fdata / rdata associated to arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"This is implemented via DebugRRule:","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Mooncake.DebugRRule","category":"page"},{"location":"debug_mode/#Mooncake.DebugRRule","page":"Debug Mode","title":"Mooncake.DebugRRule","text":"DebugRRule(rule)\n\nConstruct a callable which is equivalent to rule, but inserts additional type checking. In particular:\n\ncheck that the fdata in each argument is of the correct type for the primal\ncheck that the fdata in the CoDual returned from the rule is of the correct type for the primal.\n\nThis happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.\n\nSome additional dynamic checks are also performed (e.g. that an fdata array is of the same size as its primal).\n\nLet rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.\n\nNote: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.\n\n\n\n\n\n","category":"type"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Mooncake.build_rrule","category":"page"},{"location":"debug_mode/#Mooncake.build_rrule","page":"Debug Mode","title":"Mooncake.build_rrule","text":"build_rrule(args...; debug_mode=false)\n\nHelper method. Only uses static information from args.\n\n\n\n\n\nbuild_rrule(sig::Type{<:Tuple})\n\nEquivalent to build_rrule(Mooncake.get_interpreter(), sig).\n\n\n\n\n\nbuild_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}\n\nReturns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! 
for more info.\n\nIf debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.\n\n\n\n\n\n","category":"function"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg:","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))\nAutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))","category":"page"},{"location":"debug_mode/#When-Should-You-Use-Debug-Mode?","page":"Debug Mode","title":"When Should You Use Debug Mode?","text":"","category":"section"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Only use debug_mode when debugging a problem. This is because it has substantial performance implications.","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"DocTestSetup = nothing","category":"page"},{"location":"#Mooncake.jl","page":"Mooncake.jl","title":"Mooncake.jl","text":"","category":"section"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Documentation for Mooncake.jl is on its way!","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (03/10/2024): Various bits of utility functionality are now carefully documented. This includes how to change the code which Mooncake sees, declare that the derivative of a function is zero, make use of existing ChainRules.rrules to quickly create new rules in Mooncake, and more.","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (02/07/2024): The first round of documentation has arrived. This is largely targeted at those who are interested in contributing to Mooncake.jl – you can find this work in the \"Understanding Mooncake.jl\" section of the docs. There is more to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!","category":"page"},{"location":"known_limitations/#Known-Limitations","page":"Known Limitations","title":"Known Limitations","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Mooncake.jl has a number of known qualitative limitations, which we document here.","category":"page"},{"location":"known_limitations/#Mutation-of-Global-Variables","page":"Known Limitations","title":"Mutation of Global Variables","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"DocTestSetup = quote\n    using Mooncake\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"While great care is taken in this package to prevent silent errors, this is one edge case that we have yet to provide a satisfactory solution for. 
Consider a function of the form:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> const x = Ref(1.0);\n\njulia> function foo(y::Float64)\n           x[] = y\n           return x[]\n       end\nfoo (generic function with 1 method)","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"x is a global variable (if you refer to it in your code, it appears as a GlobalRef in the AST or lowered code). For some technical reasons that are beyond the scope of this section, this package cannot propagate gradient information through x. foo is the identity function, so it should have gradient 1.0. However, if you differentiate this example, you'll see:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> rule = Mooncake.build_rrule(foo, 2.0);\n\njulia> Mooncake.value_and_gradient!!(rule, foo, 2.0)\n(2.0, (NoTangent(), 0.0))","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Observe that while it has correctly computed the identity function, the gradient is zero.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The take-home: do not attempt to differentiate functions which modify global state. Uses of globals which do not involve mutating them are fine though.","category":"page"},{"location":"known_limitations/#Circular-References","page":"Known Limitations","title":"Circular References","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"To a large extent, Mooncake.jl does not presently support circular references in an automatic fashion. It is generally possible to hand-write solutions, so we explain some of the problems here, and the general approach to resolving them.","category":"page"},{"location":"known_limitations/#Tangent-Types","page":"Known Limitations","title":"Tangent Types","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Suppose that you have a type such as:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct A\n    x::Float64\n    a::A\n    function A(x::Float64)\n        a = new(x)\n        a.a = a\n        return a\n    end\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is a fairly canonical example of a self-referential type. There are a couple of things which will not work with it out-of-the-box. tangent_type(A) will produce a stack overflow error. To see this, note that it will in effect try to produce a tangent of type Tangent{Tuple{tangent_type(A)}} – the circular dependency on the tangent_type function causes real problems here.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In order to resolve this, you need to produce a tangent type by hand. 
You might go with something like","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct TangentForA\n    x::Float64 # tangent type for Float64 is Float64\n    a::TangentForA\n    function TangentForA(x::Float64)\n        a = new(x)\n        a.a = a\n        return a\n    end\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The point here is that you can manually resolve the circular dependency using a data structure which mimics the primal type. You will, however, need to implement similar methods for zero_tangent, randn_tangent, etc., and presumably need to implement additional getfield and setfield rules which are specific to this type.","category":"page"},{"location":"known_limitations/#Circular-References-in-General","page":"Known Limitations","title":"Circular References in General","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Consider a type of the form","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct Foo\n    x\n    Foo() = new()\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In this instance, tangent_type will work fine because Foo does not directly reference itself in its definition. Moreover, general uses of Foo will be fine.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"However, it's possible to construct an instance of Foo with a circular reference:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"f = Foo()\nf.x = f","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is actually fine provided we never attempt to call zero_tangent / randn_tangent / similar functionality on f once we've set its x field to itself. If we attempt to call such a function, we'll find ourselves with a stack overflow.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution This is a little trickier to handle. You could specialise zero_tangent etc for Foo, but this is something of a pain. Fortunately, it seems to be incredibly rare that this is ever a problem in practice. If we gain evidence that this is often a problem in practice, we'll look into supporting zero_tangent etc automatically for this case.","category":"page"},{"location":"known_limitations/#Tangent-Generation-and-Pointers","page":"Known Limitations","title":"Tangent Generation and Pointers","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In many use cases, a pointer provides the address of the start of a block of memory which has been allocated to e.g. store an array. 
However, we cannot get any of this context from the pointer itself – by just looking at a pointer, I cannot know whether its purpose is to refer to the start of a large block of memory, some proportion of the way through a block of memory, or even to keep track of a single address.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Recall that the tangent to a pointer is another pointer:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> Mooncake.tangent_type(Ptr{Float64})\nPtr{Float64}","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Plainly I cannot implement a method of zero_tangent for Ptr{Float64} because I don't know how much memory to allocate.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is, however, fine if a pointer appears half way through a function, having been derived from another data structure. e.g.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"function foo(x::Vector{Float64})\n p = pointer(x, 2)\n return unsafe_load(p)\nend\n\nrule = build_rrule(Tuple{typeof(foo), Vector{Float64}})\nMooncake.value_and_gradient!!(rule, foo, [5.0, 4.0])\n\n# output\n(4.0, (NoTangent(), [0.0, 1.0]))","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"DocTestSetup = nothing","category":"page"}] +[{"location":"mathematical_interpretation/#Mooncake.jl's-Rule-System","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.jl's approach to AD is recursive. It has a single specification for what it means to differentiate a Julia callable, and basically two approaches to achieving this. 
This section of the documentation explains the former.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We take an iterative approach to this explanation, starting at a high-level and adding more depth as we go.","category":"page"},{"location":"mathematical_interpretation/#10,000-Foot-View","page":"Mooncake.jl's Rule System","title":"10,000 Foot View","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule r(f, x) for a function f(x) \"does reverse mode AD\", and executes in two phases, known as the forwards pass and the reverse pass. In the forwards pass a rule executes the original function, and does some additional book-keeping in preparation for the reverse pass. On the reverse pass it undoes the computation from the forwards pass, \"backpropagates\" the gradient w.r.t. the output of the original function by applying the adjoint of the derivative of the original function to it, and writes the results of this computation to the correct places.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A precise mathematical model for the original function is therefore entirely crucial to this discussion, as it is needed to understand what the adjoint of its derivative is.","category":"page"},{"location":"mathematical_interpretation/#A-Model-For-A-Julia-Function","page":"Mooncake.jl's Rule System","title":"A Model For A Julia Function","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Since Julia permits the in-place modification / mutation of many data structures, we cannot make a naive translation between a Julia function and a mathematical object. Rather, we will have to model the state of the arguments to a function both before and after execution. Moreover, since a function can allocate new memory as part of execution and return it to the calling scope, we must track that too.","category":"page"},{"location":"mathematical_interpretation/#Consider-Only-Externally-Visible-Effects-Of-Function-Evaluation","page":"Mooncake.jl's Rule System","title":"Consider Only Externally-Visible Effects Of Function Evaluation","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We wish to treat a given function as a black box – we care about what a function does, not how it does it – so we consider only the externally-visible results of executing it. There are two ways in which changes can be made externally visible.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Return Value","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"(This point hardly requires explanation, but for the sake of completeness we do so anyway.)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The most obvious way in which a result can be made visible outside of a function is via its return value. 
For example, letting bar(x) = sin(x), consider the function","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n    y = bar(x)\n    z = bar(y)\n    return z\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The communication between the two invocations of bar happens via the value it returns.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Modification of arguments","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In contrast to the above, changes made by one function can be made available to another implicitly if it modifies the values of its arguments, even if it doesn't return anything. For example, consider:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function bar(x::Vector{Float64})\n    x .*= 2\n    return nothing\nend\n\nfunction foo(x::Vector{Float64})\n    bar(x)\n    bar(x)\n    return x\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second call to bar in foo sees the changes made to x by the first call to bar, despite those changes not being explicitly returned.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"No Global Mutable State","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Functions can in principle also communicate via global mutable state. We make the decision to not support this.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For example, we assume functions of the following form cannot be encountered:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"const a = randn(10)\n\nfunction bar(x)\n    a .+= x\n    return nothing\nend\n\nfunction foo(x)\n    bar(x)\n    return a\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, a is modified by bar, the effect of which is visible to foo.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For a variety of reasons this is very awkward to handle well. Since it's largely considered poor practice anyway, we explicitly outlaw this mode of communication between functions. See Why Support Closures But Not Mutable Globals for more info.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Note that this does not preclude the use of closed-over values or callable structs. 
For example, something like","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n function bar(y)\n x .+= y\n return nothing\n end\n return bar(x)\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"is perfectly fine.","category":"page"},{"location":"mathematical_interpretation/#The-Model","page":"Mooncake.jl's Rule System","title":"The Model","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"It is helpful to have a concrete example which uses both of the permissible methods to make results externally visible. To this end, consider the following function:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function f(x::Vector{Float64}, y::Vector{Float64}, z::Vector{Float64}, s::Ref{Vector{Float64}})\n z .*= y .* x\n s[] = 2z\n return sum(z)\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We draw your attention to a variety of features of this function:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"z is mutated,\ns is mutated to reference freshly allocated memory,\nthe value previously pointed to by s is unmodified, and\nwe allocate a new value and return it (albeit, it is probably allocated on the stack).","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The model we adopt for any Julia function f is a function f mathcalX to mathcalX times mathcalA where mathcalX is the real finite Hilbert space associated to the arguments to f prior to execution, and mathcalA is the real finite Hilbert space associated to any newly allocated data during execution which is externally visible after execution – any newly allocated data which is not made visible is of no concern.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, mathcalX = RR^D times RR^D times RR^D times RR^S where D is the length of x / y / z, and S the length of s[] prior to running f. mathcalA = RR^D times RR, where the RR^D component corresponds to the freshly allocated memory that s references, and RR to the return value. Observe that we model Float64s as elements of RR, Vector{Float64}s as elements of RR^D (for some value of D), and Refs with whatever the model for their contents is. The keen-eyed reader will note that these choices abstract away several details which could conceivably be included in the model. 
In particular, Vector{Float64} is implemented via a memory buffer, a pointer to the start of this buffer, and an integer which indicates the length of this buffer – none of these details are exposed in the model.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, some of the memory allocated during execution is made externally visible by modifying one of the arguments, not just via the return value.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The argument to f is the arguments to f before execution, and the output is the 2-tuple comprising the same arguments after execution and the values associated to any newly allocated / created data. Crucially, observe that we distinguish between the state of the arguments before and after execution.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For our example, the exact form of f is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f((x y z s)) = ((x y z odot y odot x s) (2 z odot y odot x sum_d=1^D (z odot y odot x)_d))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that f behaves a little like a transition operator, in that the first element of the tuple returned is the updated state of the arguments.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This model is good enough for the vast majority of functions. Unfortunately it isn't sufficient to describe a function when arguments alias each other (e.g. consider the way in which this particular model is wrong if y aliases z). 
Fortunately this is only a problem in a small fraction of all cases of aliasing, so we defer discussion of this until later on.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider now how this approach can be used to model several additional Julia functions, and to obtain their derivatives and adjoints.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"sin(x::Float64)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"mathcalX = RR, mathcalA = RR, f(x) = (x sin(x)).","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Thus the derivative is D f x (dotx) = (dotx cos(x) dotx), and its adjoint is D f x^ast (bary) = bary_x + bary_a cos(x), where bary = (bary_x bary_a).","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that this result is slightly different to the last example we saw involving sin.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"AD With Mutable Data","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider again","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function f!(x::Vector{Float64})\n    x .*= x\n    return sum(x)\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Our framework is able to accommodate this function, and has essentially the same solution as the last time we saw this example:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x odot x sum_n=1^N x_n^2)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Non-Mutating Functions","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A very interesting class of functions is those which do not modify their arguments. These are interesting because they are common, and are all that many AD frameworks like ChainRules.jl / Zygote.jl support – by considering this class of functions, we highlight some key similarities between these distinct rule systems.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"As always we can model these kinds of functions with a function f mathcalX to mathcalX times mathcalA, but we additionally have that f must have the form","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x varphi(x))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"for some function varphi mathcalX to mathcalA. 
The derivative is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x (dotx) = (dotx D varphi x(dotx))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider the usual inner product to derive the adjoint:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"beginalign\n langle bary D f x (dotx) rangle = langle (bary_1 bary_2) (dotx D varphi x(dotx)) rangle nonumber \n = langle bary_1 dotx rangle + langle bary_2 D varphi x(dotx) rangle nonumber \n = langle bary_1 dotx rangle + langle D varphi x^ast (bary_2) dotx rangle nonumber quad text(by definition of the adjoint) \n = langle bary_1 + D varphi x^ast (bary_2) dotx rangle nonumber\nendalign","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"So the adjoint of the derivative is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x^ast (bary) = bary_1 + D varphi x^ast (bary_2)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We see the correct thing to do is to increment the gradient of the output – bary_1 – by the result of applying the adjoint of the derivative of varphi to bary_2. In a ChainRules.rrule the bary_1 term is always zero, but the D varphi x^ast (bary_2) term is essentially the same.","category":"page"},{"location":"mathematical_interpretation/#The-Rule-Interface-(Round-1)","page":"Mooncake.jl's Rule System","title":"The Rule Interface (Round 1)","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Having explained in principle what it is that a rule must do, we now take a first look at the interface we use to achieve this. A rule for a function foo with signature","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Tuple{typeof(foo), Float64} -> Float64","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"must have signature","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Tuple{Trule, CoDual{typeof(foo), NoFData}, CoDual{Float64, NoFData}} ->\n    Tuple{CoDual{Float64, NoFData}, Trvs_pass}","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For example, if we call foo(5.0), its rule would be called as rule(CoDual(foo, NoFData()), CoDual(5.0, NoFData())). The precise definition and role of NoFData will be explained shortly, but the general scheme is that to a rule for foo you must pass foo itself, its arguments, and some additional data for book-keeping. 
foo and each of its arguments are paired with this additional book-keeping data via the CoDual type.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The rule returns another CoDual (it propagates book-keeping information forwards), along with a function which runs the reverse pass.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In a little more depth:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Notation: primal","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Throughout the rest of this document, we will refer to the function being differentiated as the \"primal\" computation, and its arguments as the \"primal\" arguments.","category":"page"},{"location":"mathematical_interpretation/#Forwards-Pass","page":"Mooncake.jl's Rule System","title":"Forwards Pass","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Inputs","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Each piece of each input to the primal is paired with shadow data, if it has a fixed address. For example, a Vector{Float64} argument is paired with another Vector{Float64}. The adjoint of f is accumulated into this shadow vector on the reverse pass. However, a Float64 argument gets paired with NoFData(), since it is a bits type and therefore has no fixed address.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Outputs","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule must return a Tuple of two things. The first thing must be a CoDual containing the output of the primal computation and its shadow memory (if it has any). 
The second must be a function which runs the reverse pass of AD – this will usually be a closure of some kind.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Functionality","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule must","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"ensure that the state of the primal components of all inputs / the output are as they would have been had the primal computation been run (up to differences due to finite precision arithmetic),\npropagate / construct the shadow memory associated to the output (initialised to zero), and\nconstruct the function to run the reverse pass – typically this will involve storing some quantities computed during the forwards pass.","category":"page"},{"location":"mathematical_interpretation/#Reverse-Pass","page":"Mooncake.jl's Rule System","title":"Reverse Pass","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second element of the output of a rule is a function which runs the reverse pass.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Inputs","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The \"rdata\" associated to the output of the primal.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Outputs","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The \"rdata\" associated to the inputs of the primal.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Functionality","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"undo changes made to primal state on the forwards pass.\napply adjoint of derivative of primal operation, putting the results in the correct place.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This description should leave you with (at least) a couple of questions. What is \"rdata\", and what is \"the correct place\" to put the results of applying the adjoint of the derivative? In order to address these, we need to discuss the types that Mooncake.jl uses to represent the results of AD, and to propagate results backwards on the reverse pass.","category":"page"},{"location":"mathematical_interpretation/#Representing-Gradients","page":"Mooncake.jl's Rule System","title":"Representing Gradients","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We refer to both inputs and outputs of derivatives D f x mathcalX to mathcalY as tangents, e.g. dotx or doty. Conversely, we refer to both inputs and outputs to the adjoint of this derivative D f x^ast mathcalY to mathcalX as gradients, e.g. 
bary and barx.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Note, however, that the sets involved are the same whether dealing with a derivative or its adjoint. Consequently, we use the same type to represent both.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Representing Gradients","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This package assigns to each type in Julia a unique tangent_type, the purpose of which is to contain the gradients computed during reverse mode AD. The extended docstring for tangent_type provides the best introduction to the types which are used to represent tangents / gradients.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"tangent_type(P)","category":"page"},{"location":"mathematical_interpretation/#Mooncake.tangent_type-Tuple{Any}","page":"Mooncake.jl's Rule System","title":"Mooncake.tangent_type","text":"tangent_type(P)\n\nThere must be a single type used to represent tangents of primals of type P, and it must be given by tangent_type(P).\n\nExtended help\n\nThe tangent types which Mooncake.jl uses are quite similar in spirit to ChainRules.jl. For example, tangent \"vectors\" for\n\nFloat64s are Float64s,\nVector{Float64}s are Vector{Float64}s, and\nstructs are another (special) struct with field types specified recursively.\n\nThere are, however, some major differences. Firstly, while it is certainly true that the above tangent types are permissible in ChainRules.jl, they are not the uniquely permissible types. For example, ZeroTangent is also a permissible type of tangent for any of them, and Float32 is permissible for Float64. This is a general theme in ChainRules.jl – it intentionally declines to place restrictions on what type can be used to represent the tangent of a given type.\n\nMooncake.jl differs from this. It insists that each primal type is associated to a single tangent type. Furthermore, this type is always given by the function Mooncake.tangent_type(primal_type).\n\nConsider some more worked examples.\n\nInt\n\nInt is not a differentiable type, so its tangent type is NoTangent:\n\njulia> tangent_type(Int)\nNoTangent\n\nTuples\n\nThe tangent type of a Tuple is defined recursively based on its field types. For example\n\njulia> tangent_type(Tuple{Float64, Vector{Float64}, Int})\nTuple{Float64, Vector{Float64}, NoTangent}\n\nThere is one edge case to be aware of: if all of the fields of a Tuple are non-differentiable, then the tangent type is NoTangent. For example,\n\njulia> tangent_type(Tuple{Int, Int})\nNoTangent\n\nStructs\n\nAs with Tuples, the tangent type of a struct is, by default, given recursively. In particular, the tangent type of a struct type is Tangent. This type contains a NamedTuple containing the tangent to each field in the primal struct.\n\nAs with Tuples, if all field types are non-differentiable, the tangent type of the entire struct is NoTangent.\n\nThere are a couple of additional subtleties to consider over Tuples though. Firstly, not all fields of a struct have to be defined. Fortunately, Julia makes it easy to determine how many of the fields might possibly not be defined. 
Mutable Structs\n\nThe tangent type for mutable structs has the same set of considerations as structs. The only difference is that they must themselves be mutable. Consequently, we use a type called MutableTangent to represent their tangents. It is a mutable struct with the same structure as Tangent.\n\nFor example, if you ask for the tangent_type of\n\njulia> mutable struct Bar\n x::Float64\n end\n\nyou will find that it is\n\njulia> tangent_type(Bar)\nMutableTangent{@NamedTuple{x::Float64}}\n\nPrimitive Types\n\nWe've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.\n\nOne interesting case is Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example\n\njulia> tangent_type(Ptr{Float64})\nPtr{Float64}\n\n\n\n\n\n","category":"method"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"FData and RData","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this \"unique gradient address\" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. a value type resides in another value type.","category":"page"},
{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The following docstring provides the best in-depth explanation.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.fdata_type(T)","category":"page"},{"location":"mathematical_interpretation/#Mooncake.fdata_type-Tuple{Any}","page":"Mooncake.jl's Rule System","title":"Mooncake.fdata_type","text":"fdata_type(T)\n\nReturns the type of the forwards data associated to a tangent of type T.\n\nExtended help\n\nRules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, which we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable struct or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.\n\nGiven a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.\n\nGiven a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that\n\ntangent(fdata(t), rdata(t)) === t\n\nThe need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.\n\nInt\n\nInts are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore\n\njulia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))\n(NoFData, NoRData)\n\nFloat64\n\nThe tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so\n\njulia> (fdata_type(Float64), rdata_type(Float64))\n(NoFData, Float64)\n\nVector{Float64}\n\nThe tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so\n\njulia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))\n(Vector{Float64}, NoRData)\n\nTuple{Float64, Vector{Float64}, Int}\n\nThis is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interrogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:\n\njulia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})\nTuple{Float64, Vector{Float64}, NoTangent}\n\njulia> (fdata_type(T), rdata_type(T))\n(Tuple{NoFData, Vector{Float64}, NoFData}, Tuple{Float64, NoRData, NoRData})\n\nThe zero tangent for (5.0, [5.0]) is t = (0.0, [0.0]). fdata(t) returns (NoFData(), [0.0]), where the second element is === to the second element of t. rdata(t) returns (0.0, NoRData()). In this example, t contains a mixture of data, some of which is identified by its value, and some of which is identified by its address, so there is some fdata and some rdata.\n\n
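A hypothetical REPL session, mirroring the claims above (the printed output is illustrative), that splits a tangent into its fdata and rdata and recombines them:

```julia
julia> using Mooncake: zero_tangent, fdata, rdata, tangent

julia> t = zero_tangent((5.0, [1.0, 2.0]))
(0.0, [0.0, 0.0])

julia> f, r = fdata(t), rdata(t)
((NoFData(), [0.0, 0.0]), (0.0, NoRData()))

julia> tangent(f, r) === t  # the documented recombination identity
true
```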
Structs\n\nStructs are handled in more-or-less the same way as Tuples, albeit with the possibility of undefined fields needing to be explicitly handled. For example, a struct such as\n\njulia> struct Foo\n x::Float64\n y\n z::Int\n end\n\nhas tangent type\n\njulia> tangent_type(Foo)\nTangent{@NamedTuple{x::Float64, y, z::NoTangent}}\n\nIts fdata and rdata are given by special FData and RData types:\n\njulia> (fdata_type(tangent_type(Foo)), rdata_type(tangent_type(Foo)))\n(Mooncake.FData{@NamedTuple{x::NoFData, y, z::NoFData}}, Mooncake.RData{@NamedTuple{x::Float64, y, z::NoRData}})\n\nPractically speaking, FData and RData both have the same structure as Tangents and are just used in different contexts.\n\nMutable Structs\n\nThe fdata for a mutable struct is its tangent, and it has no rdata. This is because mutable structs have fixed memory addresses, and can therefore be incremented in-place. For example,\n\njulia> mutable struct Bar\n x::Float64\n y\n z::Int\n end\n\nhas tangent type\n\njulia> tangent_type(Bar)\nMutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}\n\nand fdata / rdata types\n\njulia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))\n(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)\n\nPrimitive Types\n\nAs with tangents, each primitive type must specify what its fdata and rdata are. See specific examples for details.\n\n\n\n\n\n","category":"method"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"CoDuals","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"CoDuals are simply used to bundle together a primal and an associated fdata, depending upon context. Occasionally, they are used to pair together a primal and a tangent.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A quick aside: Non-Differentiable Data","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In the introduction to algorithmic differentiation, we assumed that the domain / range of a function are the same as those of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space – e.g. the String type.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a 0 tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. 
Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.","category":"page"},{"location":"mathematical_interpretation/#The-Rule-Interface-(Round-2)","page":"Mooncake.jl's Rule System","title":"The Rule Interface (Round 2)","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Now that you've seen what data structures are used to represent gradients, we can describe in more detail how fdata and rdata are used to propagate gradients backwards on the reverse pass.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"DocTestSetup = quote\n using Mooncake\n using Mooncake: CoDual\n import Mooncake: rrule!!\nend","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider the function","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])\nfoo (generic function with 1 method)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The fdata for x is a Tuple{NoFData, Vector{Float64}}, and its rdata is a Tuple{Float64, NoRData}. The function returns a Float64, which has no fdata, and whose rdata is Float64. So on the forwards pass there is really nothing that needs to happen with the fdata for x.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Under the framework introduced above, the model for this function is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = \\left(x, x_1 + \\sum_{n=1}^N (x_2)_n\\right)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where the vector in the second element of x is of length N. Now, following our usual steps, the derivative is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f [x](\\dot{x}) = \\left(\\dot{x}, \\dot{x}_1 + \\sum_{n=1}^N (\\dot{x}_2)_n\\right)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A gradient for this is a tuple $(\\bar{y}_x, \\bar{y}_a)$ where $\\bar{y}_a \\in \\RR$ and $\\bar{y}_x \\in \\RR \\times \\RR^N$. 
A quick derivation will show that the adjoint is","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f [x]^\\ast(\\bar{y}) = \\left((\\bar{y}_x)_1 + \\bar{y}_a, \\, (\\bar{y}_x)_2 + \\bar{y}_a \\mathbf{1}\\right)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where $\\mathbf{1}$ is the vector of length N in which each element is equal to 1. (Observe that this agrees with the result we derived earlier for functions which don't mutate their arguments.)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Now that we know what the adjoint is, we'll write down the rrule!!, and then explain what is going on in terms of the adjoint. This hand-written implementation is to aid your understanding – Mooncake.jl should be relied upon to generate this code automatically in practice.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})\n dx_fdata = x.dx\n function dfoo_adjoint(dy::Float64)\n dx_fdata[2] .+= dy\n dx_1_rdata = dy\n dx_rdata = (dx_1_rdata, NoRData())\n return NoRData(), dx_rdata\n end\n x_p = x.x\n return CoDual(x_p[1] + sum(x_p[2]), NoFData()), dfoo_adjoint\n end;\n","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where dy is the rdata for the output of foo. The rrule!! can be called with the appropriate CoDuals:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> out, pb!! = rrule!!(CoDual(foo, NoFData()), CoDual((5.0, [1.0, 2.0]), (NoFData(), [0.0, 0.0])))\n(CoDual{Float64, NoFData}(8.0, NoFData()), var\"#dfoo_adjoint#1\"{Tuple{NoFData, Vector{Float64}}}((NoFData(), [0.0, 0.0])))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"and the pullback with appropriate rdata:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> pb!!(1.0)\n(NoRData(), (1.0, NoRData()))","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"DocTestSetup = nothing","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that the forwards pass:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"computes the result of the initial function, and\npulls out the fdata for the Vector{Float64} component of the argument.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"As promised, the forwards pass really has nothing to do with the adjoint. 
It's just book-keeping and running the primal computation.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The reverse pass:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"increments each element of dx_fdata[2] by dy – this corresponds to $(\\bar{y}_x)_2 + \\bar{y}_a \\mathbf{1}$ in the adjoint,\nsets dx_1_rdata to dy – this corresponds to $(\\bar{y}_x)_1 + \\bar{y}_a$ subject to the constraint that $(\\bar{y}_x)_1 = 0$,\nconstructs the rdata for x – this is essentially just book-keeping.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Each of these items serves to demonstrate a more general point. The first is that, upon entry into the reverse pass, all fdata values correspond to gradients for the arguments / output of f \"upon exit\" (for the components of these which are identified by their address), and once the reverse-pass finishes running, they must contain the gradients w.r.t. the arguments of f \"upon entry\".","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second is that we always assume that the components of $\\bar{y}_x$ which are identified by their value have zero rdata.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The third is that the components of the arguments of f which are identified by their value must have rdata passed back explicitly by a rule, while the components of the arguments to f which are identified by their address get their gradients propagated back implicitly (i.e. via the in-place modification of fdata).","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Reminder: the first element of the tuple returned by dfoo_adjoint is the rdata associated to foo itself, hence it is NoRData.","category":"page"},{"location":"mathematical_interpretation/#Testing","page":"Mooncake.jl's Rule System","title":"Testing","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.jl has an almost entirely automated system for testing rules – Mooncake.TestUtils.test_rule. You should absolutely make use of it when writing rules.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"TODO: improve docstring for testing functionality.","category":"page"},
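A hypothetical invocation, following the pattern used in the docstrings elsewhere in these docs, testing a rule for the foo function from above. With is_primitive=false the rule Mooncake.jl derives for foo is tested; pass is_primitive=true only once the signature has been registered as a primitive with a hand-written rrule!!:

```julia
julia> using Mooncake, Random

julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2]);

julia> Mooncake.TestUtils.test_rule(Xoshiro(123), foo, (5.0, [1.0, 2.0]); is_primitive=false)
Test Passed
```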
{"location":"mathematical_interpretation/#Summary","page":"Mooncake.jl's Rule System","title":"Summary","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this section we have covered the rule system. Every callable object / function in the Julia language is differentiated using rules with this interface, whether they be hand-written rrule!!s, or rules derived by Mooncake.jl.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"At this point you should be equipped with enough information to understand what a rule in Mooncake.jl does, and how you can write your own. Later sections will explain how Mooncake.jl goes about deriving rules itself in a recursive manner, and introduce you to some of the internals.","category":"page"},{"location":"mathematical_interpretation/#Asides","page":"Mooncake.jl's Rule System","title":"Asides","text":"","category":"section"},{"location":"mathematical_interpretation/#Why-Uniqueness-of-Type-For-Tangents-/-FData-/-RData?","page":"Mooncake.jl's Rule System","title":"Why Uniqueness of Type For Tangents / FData / RData?","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Why does Mooncake.jl insist that each primal type P be paired with a single tangent type T, as opposed to being more permissive? There are a few notable reasons:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"To provide a precise interface. Rules pass fdata around on the forwards pass and rdata on the reverse pass – being able to make strong assumptions about the type of the fdata / rdata given the primal type makes implementing rules much easier in practice.\nConditional type stability. We wish to have a high degree of confidence that if the primal code is type-stable, then the AD code will also be. It is straightforward to construct type-stable primal code which has type-unstable forwards and reverse passes if you permit there to be more than one fdata / rdata type for a given primal. So while uniqueness is certainly not sufficient on its own to guarantee conditional type stability, it is probably necessary in general.\nTest-case generation and coverage. There being a unique tangent / fdata / rdata type for each primal makes being confident that a given rule is being tested thoroughly much easier. For a given primal, rather than there being many possible input / output types to consider, there is just one.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This topic, in particular what goes wrong with permissive tangent type systems like those employed by ChainRules, deserves a more thorough treatment – hopefully someone will write something more expansive on this topic at some point.","category":"page"},{"location":"mathematical_interpretation/#Why-Support-Closures-But-Not-Mutable-Globals","page":"Mooncake.jl's Rule System","title":"Why Support Closures But Not Mutable Globals","text":"","category":"section"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"First consider why closures are straightforward to support. 
Look at the type of the closure produced by foo:","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n function bar(y)\n x .+= y\n return nothing\n end\n return bar\nend\nbar = foo(randn(5))\ntypeof(bar)\n\n# output\nvar\"#bar#1\"{Vector{Float64}}","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that the Vector{Float64} that we passed to foo, and closed over in bar, is present in the type. This alludes to the fact that closures are basically just callable structs whose fields are the closed-over variables. Since the function itself is an argument to its rule, everything enters the rule for bar via its arguments, and the rule system developed in this document applies straightforwardly.","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"On the other hand, globals do not appear in the functions that they are a part of. For example,","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"const a = randn(10)\n\nfunction g(x)\n a .+= x\n return nothing\nend\n\ntypeof(g)\n\n# output\ntypeof(g) (singleton type of function g, subtype of Function)","category":"page"},{"location":"mathematical_interpretation/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Neither the value nor type of a are present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.","category":"page"},{"location":"running_tests_locally/#Running-Tests-Locally","page":"Running Tests Locally","title":"Running Tests Locally","text":"","category":"section"},{"location":"running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Mooncake.jl's test suite is fairly extensive. While you can use Pkg.test to run the test suite in the standard manner, this is usually not the best approach when developing Mooncake.jl. When editing some code, you typically only want to run the tests associated with it, not the entire test suite.","category":"page"},{"location":"running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Mooncake's tests are organised as follows:","category":"page"},{"location":"running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Things that are required for most / all test suites are loaded up in test/front_matter.jl.\nThe tests for something in src are located in an identically-named file in test. e.g. the unit tests for src/rrules/new.jl are located in test/rrules/new.jl.","category":"page"},{"location":"running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Thus, a workflow that I (Will) find works very well is the following:","category":"page"},{"location":"running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Ensure that you have Revise.jl and TestEnv.jl installed in your default environment.\nStart the REPL, dev Mooncake.jl, and navigate to the top level of the Mooncake.jl directory.\nusing TestEnv, Revise. 
Better still, load both of these in your .julia/config/startup.jl file so that you don't ever forget to load them.\nRun the following: using Pkg; Pkg.activate(\".\"); TestEnv.activate(); include(\"test/front_matter.jl\"); to set up your environment.\ninclude whichever test file you want to run the tests from.\nModify code, and re-include tests to check it does what you need. Loop this until done.\nMake a PR. This runs the entire test suite – I find that I almost never run the entire test suite locally.","category":"page"},{"location":"running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"The purpose of this approach is to:","category":"page"},{"location":"running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Avoid restarting the REPL each time you make a change, and\nRun the smallest bit of the test suite possible when making changes, in order to make development a fast and enjoyable process.","category":"page"},{"location":"running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"If you find that this strategy leaves you running more of the test suite than you would like, consider copy + pasting specific tests into the REPL, or commenting out a chunk of tests in the file that you are editing during development (try not to commit this). I find this rather crude strategy effective in practice.","category":"page"},
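Put together, a REPL session following this workflow might look as follows (the test file name is the illustrative example from above; substitute whichever tests you are working on):

```julia
using TestEnv, Revise  # ideally loaded via .julia/config/startup.jl

using Pkg
Pkg.activate(".")                # from the top level of the Mooncake.jl directory
TestEnv.activate()               # make the test dependencies available
include("test/front_matter.jl")  # shared setup for most / all test suites

include("test/rrules/new.jl")    # run only the tests you are editing; re-include after changes
```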
{"location":"understanding_intro/#Mooncake.jl-and-Reverse-Mode-AD","page":"Introduction","title":"Mooncake.jl and Reverse-Mode AD","text":"","category":"section"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically.","category":"page"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"we recap what AD is, and introduce the mathematics necessary to understand it,\nexplain how this mathematics relates to functions and data structures in Julia, and\nexplain how this is handled in Mooncake.jl.","category":"page"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"Since Mooncake.jl supports in-place operations / mutation, these will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.","category":"page"},{"location":"understanding_intro/#Who-Are-These-Docs-For?","page":"Introduction","title":"Who Are These Docs For?","text":"","category":"section"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone who is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to have read them in order to make use of this package.","category":"page"},{"location":"understanding_intro/#Prerequisites-and-Resources","page":"Introduction","title":"Prerequisites and Resources","text":"","category":"section"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"This introduction assumes familiarity with the differentiation of vector-valued functions – familiarity with the gradient and Jacobian matrices is a given.","category":"page"},{"location":"understanding_intro/","page":"Introduction","title":"Introduction","text":"In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and \"transposed Jacobian\". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the \"Matrix Calculus for Machine Learning and Beyond\" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.","category":"page"},{"location":"tools_for_rules/#Tools-for-Rules","page":"Tools for Rules","title":"Tools for Rules","text":"","category":"section"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own rrule!! from scratch. In this section, we detail some useful strategies which can help you avoid having to write rrule!!s in many situations.","category":"page"},{"location":"tools_for_rules/#Simplfiying-Code-via-Overlays","page":"Tools for Rules","title":"Simplifying Code via Overlays","text":"","category":"section"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@mooncake_overlay","category":"page"},{"location":"tools_for_rules/#Mooncake.@mooncake_overlay","page":"Tools for Rules","title":"Mooncake.@mooncake_overlay","text":"@mooncake_overlay method_expr\n\nDefine a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.\n\nFor example, suppose that you have a function\n\njulia> foo(x::Float64) = bar(x)\nfoo (generic function with 1 method)\n\nwhere Mooncake.jl fails to differentiate bar for some reason. 
If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:\n\njulia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)\n\n\nWhen looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.\n\nA Worked Example\n\nTo demonstrate how to use @mooncake_overlays in practice, we show how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonstrate how to use overlays!\n\nFirst, consider a simple example:\n\njulia> scale(x) = 2x\nscale (generic function with 1 method)\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(10.0, (NoTangent(), 2.0))\n\nWe can use @mooncake_overlay to change the definition which Mooncake.jl sees:\n\njulia> Mooncake.@mooncake_overlay scale(x) = 3x\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(15.0, (NoTangent(), 3.0))\n\nAs can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.\n\nAdditionally, it is possible to use the usual multi-line syntax to declare an overlay:\n\njulia> Mooncake.@mooncake_overlay function scale(x)\n return 4x\n end\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(20.0, (NoTangent(), 4.0))\n\n\n\n\n\n","category":"macro"},{"location":"tools_for_rules/#Functions-with-Zero-Adjoint","page":"Tools for Rules","title":"Functions with Zero Adjoint","text":"","category":"section"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:","category":"page"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@zero_adjoint\nMooncake.zero_adjoint","category":"page"},{"location":"tools_for_rules/#Mooncake.@zero_adjoint","page":"Tools for Rules","title":"Mooncake.@zero_adjoint","text":"@zero_adjoint ctx sig\n\nDefines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.\n\nFor example:\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo(x) = 5\nfoo (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})\ntrue\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())\n(NoRData(), 0.0)\n\nLimited support for Varargs is also available. For example\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo_varargs(x...) 
= 5\nfoo_varargs (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})\ntrue\n\njulia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())\n(NoRData(), 0.0, NoRData())\n\nBe aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.\n\nWARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.\n\nAs always, you should use TestUtils.test_rule to ensure that you've not made a mistake.\n\nSignatures Unsupported By This Macro\n\nIf the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.\n\n\n\n\n\n","category":"macro"},{"location":"tools_for_rules/#Mooncake.zero_adjoint","page":"Tools for Rules","title":"Mooncake.zero_adjoint","text":"zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}\n\nUtility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.\n\nNOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.\n\nYou make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:\n\njulia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual\n\njulia> foo(x::Vararg{Int}) = 5\nfoo (generic function with 1 method)\n\njulia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;\n\njulia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())\n(NoRData(), NoRData(), NoRData())\n\nWARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.\n\n\n\n\n\n","category":"function"},{"location":"tools_for_rules/#Using-ChainRules.jl","page":"Tools for Rules","title":"Using ChainRules.jl","text":"","category":"section"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! 
by wrapping an existing ChainRulesCore.rrule.","category":"page"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"There is enough similarity between these two systems that most of the boilerplate code can be avoided.","category":"page"},{"location":"tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@from_rrule","category":"page"},{"location":"tools_for_rules/#Mooncake.@from_rrule","page":"Tools for Rules","title":"Mooncake.@from_rrule","text":"@from_rrule ctx sig [has_kwargs=false]\n\nConvenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.\n\nArguments\n\nctx: A Mooncake context type\nsig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.\nhas_kwargs: a Bool stating whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.\n\nExample Usage\n\nA Basic Example\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real) = 5x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω\n return foo(x), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)\n(NoRData(), 5.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)\nTest Passed\n\nAn Example with Keyword Arguments\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω\n return foo(x; cond), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true\n\njulia> _, pb = rrule!!(\n zero_fcodual(Core.kwcall),\n zero_fcodual((cond=false, )),\n zero_fcodual(foo),\n zero_fcodual(5.0),\n );\n\njulia> pb(3.0)\n(NoRData(), NoRData(), NoRData(), 12.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(\n Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true\n )\nTest Passed\n\nNotice that, in order to access the kwarg method, we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.\n\nLimitations\n\nIt is your responsibility to ensure that\n\ncalls with signature sig do not mutate their arguments,\nthe output of calls with signature sig does not alias any of the inputs.\n\nAs with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.\n\nArgument Type Constraints\n\nMany methods of ChainRulesCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature\n\nTuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}\n\nThere are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.\n\nSuffice it to say, you should not write rules for this package which are so generically typed. 
Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.\n\nConversions Between Different Tangent Type Systems\n\nUnder the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.\n\n\n\n\n\n","category":"macro"},{"location":"algorithmic_differentiation/#Algorithmic-Differentiation","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This section introduces the mathematics behind AD. Even if you have worked with AD before, we recommend reading in order to acclimatise yourself to the perspective that Mooncake.jl takes on the subject.","category":"page"},{"location":"algorithmic_differentiation/#Derivatives","page":"Algorithmic Differentiation","title":"Derivatives","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"A foundation on which all of AD is built is the derivative – we require a fairly general definition of it, which we build up to here.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Scalar-to-Scalar Functions","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Consider first $f : \\RR \\to \\RR$, which we require to be differentiable at $x \\in \\RR$. Its derivative at $x$ is usually thought of as the scalar $\\alpha \\in \\RR$ such that","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"\\text{d}f = \\alpha \\, \\text{d}x .","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Loosely speaking, by this notation we mean that for arbitrary small changes $\\text{d}x$ in the input to $f$, the change in the output $\\text{d}f$ is $\\alpha \\, \\text{d}x$. 
We refer readers to the first few minutes of the first lecture mentioned before for a more careful explanation.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Vector-to-Vector Functions","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The generalisation of this to Euclidean space should be familiar: if $f : \\RR^P \\to \\RR^Q$ is differentiable at a point $x \\in \\RR^P$, then the derivative of $f$ at $x$ is given by the Jacobian matrix at $x$, denoted $J[x] \\in \\RR^{Q \\times P}$, such that","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"\\text{d}f = J[x] \\, \\text{d}x .","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"It is possible to stop here, as all the functions we shall need to consider can in principle be written as functions on some subset of $\\RR^P$.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"However, when we consider differentiating computer programmes, we will have to deal with complicated nested data structures, e.g. structs inside Tuples inside Vectors etc. While all of these data structures can be mapped onto a flat vector in order to make sense of the Jacobian of a computer programme, this becomes very inconvenient very quickly. To see the problem, consider the Julia function whose input is of type Tuple{Tuple{Float64, Vector{Float64}}, Vector{Float64}, Float64} and whose output is of type Tuple{Vector{Float64}, Float64}. What kind of object might be used to represent the derivative of a function mapping between these two spaces? We can certainly treat these as structured \"views\" into a \"flat\" Vector{Float64}, and then define a Jacobian, but actually finding this mapping is a tedious exercise, even if it quite obviously exists.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In fact, a more general formulation of the derivative is used all the time in the context of AD – the matrix calculus discussed by [1] and [2] (to name a couple) makes use of a generalised form of the derivative in order to work with functions which map to and from matrices (albeit there are slight differences in naming conventions from text to text), without needing to \"flatten\" them into vectors in order to make sense of them.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In general, it will be much easier to avoid \"flattening\" operations wherever possible. 
In order to do so, we now introduce a generalised notion of the derivative.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Functions Between More General Spaces","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to avoid the difficulties described above, we consider functions $f : \\mathcal{X} \\to \\mathcal{Y}$, where $\\mathcal{X}$ and $\\mathcal{Y}$ are finite dimensional real Hilbert spaces (read: finite-dimensional vector spaces with an inner product, and real-valued scalars). This definition includes functions to / from $\\RR$ and $\\RR^D$, but also real-valued matrices, and any other \"container\" for collections of real numbers. Furthermore, we shall see later how we can model all sorts of structured representations of data directly as such spaces.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"For such spaces, the derivative of $f$ at $x \\in \\mathcal{X}$ is the linear operator (read: linear function) $D f [x] : \\mathcal{X} \\to \\mathcal{Y}$ satisfying","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"\\text{d}f = D f [x] (\\text{d}x)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The purpose of this linear operator is to provide a linear approximation to $f$ which is accurate for arguments which are very close to $x$.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Please note that $D f [x]$ is a single mathematical object, despite the fact that 3 separate symbols are used to denote it – $D f [x] (\\dot{x})$ denotes the application of the function $D f [x]$ to argument $\\dot{x}$. Furthermore, the dot-notation ($\\dot{x}$) does not have anything to do with time-derivatives, it is simply common notation used in the AD literature to denote the arguments of derivatives.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"So, instead of thinking of the derivative as a number or a matrix, we think about it as a function. We can express the previous notions of the derivative in this language.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In the scalar case, rather than thinking of the derivative as being $\\alpha$, we think of it as the linear operator $D f [x] (\\dot{x}) = \\alpha \\dot{x}$. Put differently, rather than thinking of the derivative as the slope of the tangent to $f$ at $x$, think of it as the function describing the tangent itself. Observe that up until now we had only considered inputs to $D f [x]$ which were small ($\\text{d}x$) – here we extend it to the entire space $\\mathcal{X}$ and denote inputs in this space $\\dot{x}$. 
Inputs $\\dot{x}$ should be thought of as \"directions\", in the directional derivative sense (why this is true will be discussed later).","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Similarly, if $\\mathcal{X} = \\RR^P$ and $\\mathcal{Y} = \\RR^Q$ then this operator can be specified in terms of the Jacobian matrix: $D f [x] (\\dot{x}) = J[x] \\dot{x}$ – brackets are used to emphasise that $D f [x]$ is a function, and is being applied to $\\dot{x}$.[note_for_geometers]","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"To reiterate, for the rest of this document, we define the derivative to be \"multiply by $\\alpha$\" or \"multiply by $J[x]$\", rather than to be $\\alpha$ or $J[x]$. So whenever you see the word \"derivative\", you should think \"linear function\".","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The Chain Rule","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The chain rule is the result which makes AD work. Fortunately, it applies to this version of the derivative:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f = g \\circ h \\implies D f [x] = (D g [h(x)]) \\circ (D h [x])","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"By induction, this extends to a collection of $N$ functions $f_1, \\dots, f_N$:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f = f_N \\circ \\dots \\circ f_1 \\implies D f [x] = (D f_N [x_N]) \\circ \\dots \\circ (D f_1 [x_1])","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where $x_{n+1} = f_n(x_n)$, and $x_1 = x$.","category":"page"},
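To make the "derivatives are linear operators" viewpoint concrete, here is a toy illustration in plain Julia (not Mooncake.jl code): each derivative is represented as a closure, and the chain rule is literally function composition. All names here are made up for the example.

```julia
# Derivatives represented as linear operators (closures), composed via ∘.
f(x) = sin.(x)
Df(x) = ẋ -> cos.(x) .* ẋ        # D f [x]

g(y) = sum(abs2, y)
Dg(y) = ẏ -> 2 * sum(y .* ẏ)     # D g [y]

h(x) = g(f(x))                    # h = g ∘ f
Dh(x) = Dg(f(x)) ∘ Df(x)          # chain rule: D h [x] = D g [f(x)] ∘ D f [x]

x, ẋ = [0.1, 0.2], [1.0, 0.0]
Dh(x)(ẋ) ≈ 2 * sin(0.1) * cos(0.1)  # directional derivative of h at x along ẋ
```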
{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"An aside: the definition of the Frechet Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This definition of the derivative has a name: the Frechet derivative. It is a generalisation of the Total Derivative. Formally, we say that a function $f : \\mathcal{X} \\to \\mathcal{Y}$ is differentiable at a point $x \\in \\mathcal{X}$ if there exists a linear operator $D f [x] : \\mathcal{X} \\to \\mathcal{Y}$ (the derivative) satisfying","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"\\lim_{\\text{d}h \\to 0} \\frac{\\| f(x + \\text{d}h) - f(x) - D f [x] (\\text{d}h) \\|_{\\mathcal{Y}}}{\\| \\text{d}h \\|_{\\mathcal{X}}} = 0 ,","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where $\\| \\cdot \\|_{\\mathcal{X}}$ and $\\| \\cdot \\|_{\\mathcal{Y}}$ are the norms associated to Hilbert spaces $\\mathcal{X}$ and $\\mathcal{Y}$ respectively. It is a good idea to consider what this looks like when $\\mathcal{X} = \\mathcal{Y} = \\RR$ and when $\\mathcal{X} = \\mathcal{Y} = \\RR^D$. It is sometimes helpful to refer to this definition to e.g. verify the correctness of the derivative of a function – as with single-variable calculus, however, this is rare.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Another aside: what does Forwards-Mode AD compute?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"At this point we have enough machinery to discuss forwards-mode AD. Expressed in the language of linear operators and Hilbert spaces, the goal of forwards-mode AD is the following: given a function $f$ which is differentiable at a point $x$, compute $D f [x] (\\dot{x})$ for a given vector $\\dot{x}$. If $f : \\RR^P \\to \\RR^Q$, this is equivalent to computing $J[x] \\dot{x}$, where $J[x]$ is the Jacobian of $f$ at $x$. For the interested reader we provide a high-level explanation of how forwards-mode AD does this in How does Forwards-Mode AD work?.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Another aside: notation","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"You may have noticed that we typically denote the argument to a derivative with a \"dot\" over it, e.g. $\\dot{x}$. This is something that we will do consistently, and we will use the same notation for the outputs of derivatives. Wherever you see a symbol with a \"dot\" over it, expect it to be an input or output of a derivative / forwards-mode AD.","category":"page"},{"location":"algorithmic_differentiation/#Reverse-Mode-AD:-*what*-does-it-do?","page":"Algorithmic Differentiation","title":"Reverse-Mode AD: what does it do?","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to explain what reverse-mode AD does, we first consider the \"vector-Jacobian product\" definition in Euclidean space which will be familiar to many readers. We then generalise.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Mode AD: what does it do in Euclidean space?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In this setting, the goal of reverse-mode AD is the following: given a function $f : \\RR^P \\to \\RR^Q$ which is differentiable at $x \\in \\RR^P$ with Jacobian $J[x]$ at $x$, compute $J[x]^\\top \\bar{y}$ for any $\\bar{y} \\in \\RR^Q$. This is useful because we can obtain the gradient from this when $Q = 1$ by letting $\\bar{y} = 1$.","category":"page"},
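A small plain-Julia check of this picture, using a function made up for this example, f(x) = sum(abs2, x), whose Jacobian at x is the 1×P matrix 2xᵀ:

```julia
x = [1.0, 2.0, 3.0]
J = reshape(2 .* x, 1, :)  # Jacobian of f(x) = sum(abs2, x) at x, a 1×P matrix
ȳ = [1.0]                  # Q = 1, so ȳ = 1 recovers the gradient
J' * ȳ == 2 .* x           # true: Jᵀȳ is the gradient of f at x
```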
Specifically, the adjoint A^ast of a linear operator A is the linear operator satisfying","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle A^ast bary dotx rangle = langle bary A dotx rangle","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where langle cdot cdot rangle denotes the inner-product. The relationship between the adjoint and matrix transpose is: if A (x) = J x for some matrix J, then A^ast (y) = J^top y.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Moreover, just as (A B)^top = B^top A^top when A and B are matrices, (A B)^ast = B^ast A^ast when A and B are linear operators. This result follows in short order from the definition of the adjoint operator (and is a good exercise!)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Mode AD: what does it do in general?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Equipped with adjoints, we can express reverse-mode AD only in terms of linear operators, dispensing with the need to express everything in terms of Jacobians. The goal of reverse-mode AD is as follows: given a differentiable function f mathcalX to mathcalY, compute D f x^ast (bary) for some bary.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Notation: D f x^ast denotes the single mathematical object which is the adjoint of D f x. It is a linear function from mathcalY to mathcalX. We may occasionally write it as (D f x)^ast if there is some risk of confusion.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We will explain how reverse-mode AD goes about computing this after some worked examples.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Aside: Notation","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"You will have noticed that arguments to adjoints have thus far always had a \"bar\" over them, e.g. bary. This notation is common in the AD literature and will be used throughout. Additionally, this \"bar\" notation will be used for the outputs of adjoints of derivatives. So wherever you see a symbol with a \"bar\" over it, think \"input or output of adjoint of derivative\".","category":"page"},{"location":"algorithmic_differentiation/#Some-Worked-Examples","page":"Algorithmic Differentiation","title":"Some Worked Examples","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We now present some worked examples in order to prime intuition, and to introduce the important classes of problems that will be encountered when doing AD in the Julia language.
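(As a brief aside before the examples: the adjoint-transpose relationship above is easy to check numerically. A minimal sketch of such a check, assuming only the standard-library LinearAlgebra package; the variable names are ours:)

julia> using LinearAlgebra

julia> J = randn(3, 2); ẋ = randn(2); ȳ = randn(3);

julia> dot(ȳ, J * ẋ) ≈ dot(J' * ȳ, ẋ)  # ⟨ȳ, A(ẋ)⟩ == ⟨A*(ȳ), ẋ⟩, where A(ẋ) = J ẋ and A*(ȳ) = Jᵀȳ
true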
We will put all of these problems in a single general framework later on.","category":"page"},{"location":"algorithmic_differentiation/#An-Example-with-Matrix-Calculus","page":"Algorithmic Differentiation","title":"An Example with Matrix Calculus","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We have introduced some mathematical abstraction in order to simplify the calculations involved in AD. To this end, we consider differentiating f(X) = X^top X. Results for this and similar operations are given by [1]. A similar operation, but one which maps from matrices to RR, is discussed in Lecture 4 part 2 of the MIT course mentioned previously. Both [1] and Lecture 4 part 2 provide approaches to obtaining the derivative of this function.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Following either resource will yield the derivative:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f X (dotX) = dotX^top X + X^top dotX","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Observe that this is indeed a linear operator (i.e. it is linear in its argument, dotX). (You can always plug it into the definition of the Frechet derivative to confirm that it is indeed the derivative.)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to perform reverse-mode AD, we need to find the adjoint operator.
Using the usual definition of the inner product between matrices,","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle X Y rangle = textrmtr (X^top Y)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"we can rearrange the inner product as follows:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\n langle barY D f X (dotX) rangle = langle barY dotX^top X + X^top dotX rangle nonumber \n = textrmtr (barY^top dotX^top X) + textrmtr(barY^top X^top dotX) nonumber \n = textrmtr ( X barY^top^top dotX) + textrmtr( X barY^top dotX) nonumber \n = langle X barY^top + X barY dotX rangle nonumber\nendalign","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We can read off the adjoint operator from the first argument to the inner product:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f X^ast (barY) = X barY^top + X barY","category":"page"},{"location":"algorithmic_differentiation/#AD-of-a-Julia-function:-a-trivial-example","page":"Algorithmic Differentiation","title":"AD of a Julia function: a trivial example","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We now turn to differentiating Julia functions (we use function to refer to the programming language construct, and function to refer to a more general mathematical concept). The way that Mooncake.jl handles immutable data is very similar to how Zygote / ChainRules do. For example, consider the Julia function","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x::Float64) = sin(x)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"If you've previously worked with ChainRules / Zygote, without thinking too hard about the formalisms we introduced previously (perhaps by considering a variety of partial derivatives), you can probably arrive at the following adjoint for the derivative of f:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"g -> g * cos(x)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Implicitly, you have performed three steps:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"model f as a differentiable function,\ncompute its derivative, and\ncompute the adjoint of the derivative.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"It is helpful to work through this simple example in detail: spelling out each step here becomes genuinely useful in the more complicated examples which follow, as the same steps apply quite generally.
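(Returning briefly to the matrix example above: that adjoint is easy to sanity-check numerically. A minimal sketch, assuming only LinearAlgebra; the names are ours:)

julia> using LinearAlgebra

julia> X, Ẋ, Ȳ = randn(3, 3), randn(3, 3), randn(3, 3);

julia> df = Ẋ' * X + X' * Ẋ;  # D f [X](Ẋ)

julia> tr(Ȳ' * df) ≈ tr((X * Ȳ' + X * Ȳ)' * Ẋ)  # ⟨Ȳ, D f [X](Ẋ)⟩ == ⟨D f [X]*(Ȳ), Ẋ⟩
true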
If at any point this exercise feels pedantic, we ask you to stick with it.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Obviously, we model the Julia function f as the function f RR to RR where","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x) = sin(x)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Observe that we've made (at least) two modelling assumptions here:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"a Float64 is modelled as a real number,\nthe Julia function sin is modelled as the usual mathematical function sin.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"As promised, we're being quite pedantic. While the first assumption is obvious and will remain true, we will shortly see examples where we have to work a bit harder to obtain a correspondence between a Julia function and a mathematical object.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we have a mathematical model, we can differentiate it:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x (dotx) = cos(x) dotx","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Given the derivative, we can find its adjoint:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle barf D f x(dotx) rangle = langle barf cos(x) dotx rangle = langle cos(x) barf dotx rangle","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"From here the adjoint can be read off from the first argument to the inner product:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x^ast (barf) = cos(x) barf","category":"page"},{"location":"algorithmic_differentiation/#AD-of-a-Julia-function:-a-slightly-less-trivial-example","page":"Algorithmic Differentiation","title":"AD of a Julia function: a slightly less trivial example","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now consider the Julia
function","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x::Float64, y::Tuple{Float64, Float64}) = x + y[1] * y[2]","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Its adjoint is going to be something along the lines of","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"g -> (g, (y[2] * g, y[1] * g))","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"As before, we work through in detail.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"There are a couple of aspects of f which require thought:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"it has two arguments – we've only handled single argument functions previously, and\nthe second argument is a Tuple – we've not yet decided how to model this.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"To this end, we define a mathematical notion of a tuple. A tuple is a collection of N elements, each of which is drawn from some set mathcalX_n. We denote by mathcalX = mathcalX_1 times dots times mathcalX_N the set of all N-tuples whose nth element is drawn from mathcalX_n. 
Provided that each mathcalX_n forms a finite-dimensional Hilbert space, mathcalX forms a Hilbert space with","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"alpha x = (alpha x_1 dots alpha x_N),\nx + y = (x_1 + y_1 dots x_N + y_N), and\nlangle x y rangle = sum_n=1^N langle x_n y_n rangle.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We can think of multi-argument functions as single-argument functions of a tuple, so a reasonable mathematical model for f might be a function f RR times RR times RR to RR, where","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x y) = x + y_1 y_2","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Note that while the function is written with two arguments, you should treat them as a single tuple, where we've assigned the name x to the first element, and y to the second.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we have a mathematical object, we can differentiate it:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x y(dotx doty) = dotx + doty_1 y_2 + y_1 doty_2","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D fx y maps RR times RR times RR to RR, so D f x y^ast must map the other way. You should verify that the following follows quickly from the definition of the adjoint:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x y^ast (barf) = (barf (barf y_2 barf y_1))","category":"page"},{"location":"algorithmic_differentiation/#AD-with-mutable-data","page":"Algorithmic Differentiation","title":"AD with mutable data","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In the previous two examples there was an obvious mathematical model for the Julia function. Indeed, this model was sufficiently obvious that it required little explanation.
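(Before moving on, a quick numerical check of the tuple-argument adjoint just derived, using the tuple inner product defined above. This is a small sketch of ours; the values below are arbitrary:)

julia> x, y = 0.3, (1.2, -0.7); ẋ, ẏ = 0.5, (-1.0, 2.0); f̄ = 2.0;

julia> df = ẋ + ẏ[1] * y[2] + y[1] * ẏ[2];  # D f [x, y](ẋ, ẏ)

julia> x̄, ȳ = f̄, (f̄ * y[2], f̄ * y[1]);  # D f [x, y]*(f̄)

julia> f̄ * df ≈ x̄ * ẋ + ȳ[1] * ẏ[1] + ȳ[2] * ẏ[2]  # ⟨f̄, D f(ẋ, ẏ)⟩ == ⟨D f*(f̄), (ẋ, ẏ)⟩
true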
This is not always the case, though: in particular, Julia functions which modify / mutate their inputs require a little more thought.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Consider the following Julia function:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"function f!(x::Vector{Float64})\n x .*= x\n return sum(x)\nend","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This function squares each element of its input in-place, and returns the sum of the result. So what is an appropriate mathematical model for this function?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The trick is to distinguish between the state of x upon entry to / exit from f!. In particular, let phi_textf RR^N to RR^N times RR be given by","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"phi_textf(x) = (x odot x sum_n=1^N x_n^2)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where odot denotes the Hadamard / element-wise product (this corresponds to the line x .*= x in the above code). The point here is that the input to phi_textf is the state of x upon entry to f!, and the value returned from phi_textf is a tuple containing both the state of x upon exit from f! and the value returned by f!.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The remaining steps are straightforward now that we have the model.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The derivative of phi_textf is","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D phi_textf x(dotx) = (2 x odot dotx 2 sum_n=1^N x_n dotx_n)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The argument to the adjoint of the derivative must be a 2-tuple whose elements are drawn from RR^N times RR . Denote such a tuple as (bary_1 bary_2).
Plugging this into an inner product with the derivative and rearranging yields","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\n langle (bary_1 bary_2) D phi_textf x (dotx) rangle = langle (bary_1 bary_2) (2 x odot dotx 2 sum_n=1^N x_n dotx_n) rangle nonumber \n = langle bary_1 2 x odot dotx rangle + langle bary_2 2 sum_n=1^N x_n dotx_n rangle nonumber \n = langle 2x odot bary_1 dotx rangle + langle 2 bary_2 x dotx rangle nonumber \n = langle 2 (x odot bary_1 + bary_2 x) dotx rangle nonumber\nendalign","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"So we can read off the adjoint to be","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D phi_textf x^ast (bary) = 2 (x odot bary_1 + bary_2 x)","category":"page"},{"location":"algorithmic_differentiation/#Reverse-Mode-AD:-*how*-does-it-do-it?","page":"Algorithmic Differentiation","title":"Reverse-Mode AD: how does it do it?","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we know what it is that AD computes, we need a rough understanding of how it computes it.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In short: reverse-mode AD breaks down a \"complicated\" function f into the composition of a collection of \"simple\" functions f_1 dots f_N, applies the chain rule, and takes the adjoint.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Specifically, we assume that we can express any function f as f = f_N circ dots circ f_1, and that we can compute the adjoint of the derivative for each f_n. From this, we can obtain the adjoint of f by applying the chain rule to the derivatives and taking the adjoint:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\nD f x^ast = (D f_N x_N circ dots circ D f_1 x_1)^ast nonumber \n = D f_1 x_1^ast circ dots circ D f_N x_N^ast nonumber\nendalign","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"For example, suppose that f(X) = sin(cos(texttr(X^top X))). One option to compute its adjoint is to figure it out by hand directly (probably using the chain rule somewhere). Instead, we could notice that f = f_4 circ f_3 circ f_2 circ f_1 where f_4 = sin, f_3 = cos, f_2 = texttr and f_1(X) = X^top X. We could derive the adjoint for each of these functions (a fairly straightforward task), and then compute","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f X^ast (1) = (D f_1 x_1^ast circ D f_2 x_2^ast circ D f_3 x_3^ast circ D f_4 x_4^ast)(1)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"in order to obtain the gradient of f. Reverse-mode AD essentially just does this.
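(To make this concrete, here is a hand-written sketch of exactly this composition; it is our own illustration, assuming only LinearAlgebra, and Mooncake automates precisely this kind of bookkeeping:)

julia> using LinearAlgebra

julia> f(X) = sin(cos(tr(X' * X)));

julia> function gradient_by_hand(X)
           # forwards-pass: x2 = f1(X), x3 = f2(x2), x4 = f3(x3)
           x2 = X' * X; x3 = tr(x2); x4 = cos(x3)
           # reverse-pass: apply the adjoints in reverse order, seeded with ȳ = 1
           x̄4 = cos(x4) * 1.0                # adjoint of f4 = sin at x4
           x̄3 = -sin(x3) * x̄4               # adjoint of f3 = cos at x3
           X̄2 = x̄3 * Matrix(I, size(X)...)  # adjoint of f2 = tr at x2
           return X * (X̄2 + X̄2')            # adjoint of f1(X) = XᵀX at X
       end;

julia> X = randn(2, 2); V = randn(2, 2); ε = 1e-6;

julia> fd = (f(X + ε * V) - f(X - ε * V)) / (2ε);  # central finite difference in direction V

julia> abs(fd - dot(gradient_by_hand(X), V)) < 1e-6  # matches the directional derivative
true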
Modern systems have hand-written adjoints for (hopefully!) all of the \"simple\" functions you may wish to build a function such as f from (often there are hundreds of these), and compose them to compute the adjoint of f. A sketch of a more generic algorithm is as follows.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Forwards-Pass:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"x_1 = x, n = 1\nconstruct D f_n x_n^ast\nlet x_n+1 = f_n (x_n)\nlet n = n + 1\nif n < N + 1 then go to step 2.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Pass:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"let barx_N+1 = bary\nlet n = n - 1\nlet barx_n = D f_n x_n^ast (barx_n+1)\nif n = 1 return barx_1 else go to step 2.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"How does this relate to vector-Jacobian products?","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In Euclidean space, each derivative D f_n x_n(dotx_n) = J_nx_n dotx_n. Applying the chain rule to D f x and substituting this in yields","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Jx = J_Nx_N dots J_1x_1 ","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Taking the transpose and applying it to bary yields","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Jx^top bary = J_1x_1^top dots J_Nx_N^top bary ","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Comparing this with the expression in terms of adjoints and operators, we see that composition of adjoints of derivatives has been replaced with multiplying by transposed Jacobian matrices. This \"vector-Jacobian product\" expression is commonly used to explain AD, and is likely familiar to many readers.","category":"page"},{"location":"algorithmic_differentiation/#Directional-Derivatives-and-Gradients","page":"Algorithmic Differentiation","title":"Directional Derivatives and Gradients","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now we turn to using reverse-mode AD to compute the gradient of a function. In short, given a function g mathcalX to RR with derivative D g x at x, its gradient is equal to D g x^ast (1). We explain why in this section.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The derivative discussed here can be used to compute directional derivatives. Consider a function f mathcalX to RR with Frechet derivative D f x mathcalX to RR at x in mathcalX.
Then D fx(dotx) returns the directional derivative in direction dotx.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Gradients are closely related to the adjoint of the derivative. Recall that the gradient of f at x is defined to be the vector nabla f (x) in mathcalX such that langle nabla f (x) dotx rangle gives the directional derivative of f at x in direction dotx. Having noted that D fx(dotx) is exactly this directional derivative, we can equivalently say that","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D fx(dotx) = langle nabla f (x) dotx rangle ","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The role of the adjoint is revealed when we consider f = mathcall circ g, where g mathcalX to mathcalY, mathcall(y) = langle bary y rangle, and bary in mathcalY is some fixed vector. Noting that D mathcall y(doty) = langle bary doty rangle, we apply the chain rule to obtain","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\nD f x (dotx) = (D mathcall g(x)) circ (D g x)(dotx) nonumber \n = langle bary D g x (dotx) rangle nonumber \n = langle D g x^ast (bary) dotx rangle nonumber\nendalign","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"from which we conclude that D g x^ast (bary) is the gradient of the composition mathcall circ g at x.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The above shows that if mathcalY = RR and g is the function we wish to compute the gradient of, we can simply set bary = 1 and compute D g x^ast (bary) to obtain the gradient of g at x.","category":"page"},{"location":"algorithmic_differentiation/#Summary","page":"Algorithmic Differentiation","title":"Summary","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This document explains the core mathematical foundations of AD. It explains separately what it does, and how it goes about it.
Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices, and Julia functions.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Subsequent sections will build on these foundations, to provide a more general explanation of what AD looks like for a Julia programme.","category":"page"},{"location":"algorithmic_differentiation/#Asides","page":"Algorithmic Differentiation","title":"Asides","text":"","category":"section"},{"location":"algorithmic_differentiation/#*How*-does-Forwards-Mode-AD-work?","page":"Algorithmic Differentiation","title":"How does Forwards-Mode AD work?","text":"","category":"section"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Forwards-mode AD achieves this by breaking down f into the composition f = f_N circ dots circ f_1, where each f_n is a simple function whose derivative (function) D f_n x_n we know for any given x_n. By the chain rule, we have that","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x (dotx) = D f_N x_N circ dots circ D f_1 x_1 (dotx)","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"which suggests the following algorithm:","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"let x_1 = x, dotx_1 = dotx, and n = 1\nlet dotx_n+1 = D f_n x_n (dotx_n)\nlet x_n+1 = f_n(x_n)\nlet n = n + 1\nif n = N+1 then return dotx_N+1, otherwise go to 2.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"When each function f_n maps between Euclidean spaces, the applications of derivatives D f_n x_n (dotx_n) are given by J_n dotx_n where J_n is the Jacobian of f_n at x_n.","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"M. Giles. An extended collection of matrix derivative results for forward and reverse mode automatic differentiation. Unpublished (2008).\n\n\n\nT. P. Minka. Old and new matrix algebra useful for statistics. See www.stat.cmu.edu/minka/papers/matrix.html 4 (2000).\n\n\n\n","category":"page"},{"location":"algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"[note_for_geometers]: in AD we only really need to discuss differentiable functions between vector spaces that are isomorphic to Euclidean space. Consequently, a variety of considerations which are usually required in differential geometry are not required here. Notably, the tangent space is assumed to be the same everywhere, and to be the same as the domain of the function. Avoiding these additional considerations helps keep the mathematics as simple as possible.","category":"page"},{"location":"debugging_and_mwes/#Debugging-and-MWEs","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"","category":"section"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"There's a reasonable chance that you'll run into an issue with Mooncake.jl at some point.
In order to debug what is going on when this happens, or to produce an MWE, it is helpful to have a convenient way to run Mooncake.jl on whatever function and arguments you have which are causing problems.","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"We recommend making use of Mooncake.jl's testing functionality to generate your test cases:","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"Mooncake.TestUtils.test_rule","category":"page"},{"location":"debugging_and_mwes/#Mooncake.TestUtils.test_rule","page":"Debugging and MWEs","title":"Mooncake.TestUtils.test_rule","text":"test_rule(\n rng, x...;\n interface_only=false,\n is_primitive::Bool=true,\n perf_flag::Symbol=:none,\n interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(),\n debug_mode::Bool=false,\n)\n\nRun standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in particular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.\n\nThis function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.\n\n\n\n\n\n","category":"function"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"This approach is convenient because it can","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"check whether AD runs at all,\ncheck whether AD produces the correct answers,\ncheck whether AD is performant, and\nbe used without having to manually generate tangents.","category":"page"},{"location":"debugging_and_mwes/#Example","page":"Debugging and MWEs","title":"Example","text":"","category":"section"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"DocTestSetup = quote\n using Random, Mooncake\nend","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"For example","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"f(x) = Core.bitcast(Float64, x)\nMooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"will error. (In this particular case, it is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting.
In this particular instance, Mooncake.jl just fails to compile, but in other instances other things can happen.)","category":"page"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.","category":"page"},{"location":"debugging_and_mwes/#Segfaults","page":"Debugging and MWEs","title":"Segfaults","text":"","category":"section"},{"location":"debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.","category":"page"},{"location":"debug_mode/#Debug-Mode","page":"Debug Mode","title":"Debug Mode","text":"","category":"section"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"DocTestSetup = quote\n using Mooncake, ADTypes\nend","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"The Problem","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})\n plus_reverse_pass(dz::Real) = NoRData(), dz, dz\n return zero_fcodual(primal(x) + primal(y)), plus_reverse_pass\nend","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"and calls","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"rrule!!(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, so the appropriate tangent type is also Float32. This error may cause the reverse pass to fail loudly, but it may also fail silently. It may cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule.
Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"The Solution","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Check that the types of the fdata / rdata associated with arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"This is implemented via DebugRRule:","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Mooncake.DebugRRule","category":"page"},{"location":"debug_mode/#Mooncake.DebugRRule","page":"Debug Mode","title":"Mooncake.DebugRRule","text":"DebugRRule(rule)\n\nConstruct a callable which is equivalent to rule, but inserts additional type checking. In particular:\n\ncheck that the fdata in each argument is of the correct type for the primal\ncheck that the fdata in the CoDual returned from the rule is of the correct type for the primal.\n\nThis happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.\n\nSome additional dynamic checks are also performed (e.g. that an fdata array is of the same size as its primal).\n\nIf rule returns y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.\n\nNote: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.\n\n\n\n\n\n","category":"type"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Mooncake.build_rrule","category":"page"},{"location":"debug_mode/#Mooncake.build_rrule","page":"Debug Mode","title":"Mooncake.build_rrule","text":"build_rrule(args...; debug_mode=false)\n\nHelper method. Only uses static information from args.\n\n\n\n\n\nbuild_rrule(sig::Type{<:Tuple})\n\nEquivalent to build_rrule(Mooncake.get_interpreter(), sig).\n\n\n\n\n\nbuild_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}\n\nReturns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!!
for more info.\n\nIf debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.\n\n\n\n\n\n","category":"function"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg:","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))\nAutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))","category":"page"},{"location":"debug_mode/#When-Should-You-Use-Debug-Mode?","page":"Debug Mode","title":"When Should You Use Debug Mode?","text":"","category":"section"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Only use debug_mode when debugging a problem. This is because it has substantial performance implications.","category":"page"},{"location":"debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"DocTestSetup = nothing","category":"page"},{"location":"#Mooncake.jl","page":"Mooncake.jl","title":"Mooncake.jl","text":"","category":"section"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Documentation for Mooncake.jl is on its way!","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (03/10/2024): Various bits of utility functionality are now carefully documented. This includes how to change the code which Mooncake sees, declare that the derivative of a function is zero, make use of existing ChainRules.rrules to quickly create new rules in Mooncake, and more.","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (02/07/2024): The first round of documentation has arrived. This is largely targeted at those who are interested in contributing to Mooncake.jl – you can find this work in the \"Understanding Mooncake.jl\" section of the docs. There is more to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!","category":"page"},{"location":"known_limitations/#Known-Limitations","page":"Known Limitations","title":"Known Limitations","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Mooncake.jl has a number of known qualitative limitations, which we document here.","category":"page"},{"location":"known_limitations/#Mutation-of-Global-Variables","page":"Known Limitations","title":"Mutation of Global Variables","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"DocTestSetup = quote\n using Mooncake\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"While great care is taken in this package to prevent silent errors, this is one edge case that we have yet to provide a satisfactory solution for.
Consider a function of the form:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> const x = Ref(1.0);\n\njulia> function foo(y::Float64)\n x[] = y\n return x[]\n end\nfoo (generic function with 1 method)","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"x is a global variable (if you refer to it in your code, it appears as a GlobalRef in the AST or lowered code). For some technical reasons that are beyond the scope of this section, this package cannot propagate gradient information through x. foo is the identity function, so it should have gradient 1.0. However, if you differentiate this example, you'll see:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> rule = Mooncake.build_rrule(foo, 2.0);\n\njulia> Mooncake.value_and_gradient!!(rule, foo, 2.0)\n(2.0, (NoTangent(), 0.0))","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Observe that while it has correctly computed the identity function, the gradient is zero.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The take-home message: do not attempt to differentiate functions which modify global state. Uses of globals which do not involve mutating them are fine though.","category":"page"},{"location":"known_limitations/#Circular-References","page":"Known Limitations","title":"Circular References","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"To a large extent, Mooncake.jl does not presently support circular references in an automatic fashion. It is generally possible to hand-write solutions, so we explain some of the problems here, and the general approach to resolving them.","category":"page"},{"location":"known_limitations/#Tangent-Types","page":"Known Limitations","title":"Tangent Types","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Suppose that you have a type such as:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct A\n x::Float64\n a::A\n function A(x::Float64)\n a = new(x)\n a.a = a\n return a\n end\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is a fairly canonical example of a self-referential type. There are a couple of things which will not work with it out-of-the-box. tangent_type(A) will produce a stack overflow error. To see this, note that it will in effect try to produce a tangent of type Tangent{Tuple{tangent_type(A)}} – the circular dependency on the tangent_type function causes real problems here.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In order to resolve this, you need to produce a tangent type by hand.
You might go with something like","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct TangentForA\n x::Float64 # tangent type for Float64 is Float64\n a::TangentForA\n function TangentForA(x::Float64)\n a = new(x)\n a.a = a\n return a\n end\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The point here is that you can manually resolve the circular dependency using a data structure which mimics the primal type. You will, however, need to implement similar methods for zero_tangent, randn_tangent, etc., and presumably need to implement additional getfield and setfield rules which are specific to this type.","category":"page"},{"location":"known_limitations/#Circular-References-in-General","page":"Known Limitations","title":"Circular References in General","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Consider a type of the form","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct Foo\n x\n Foo() = new()\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In this instance, tangent_type will work fine because Foo does not directly reference itself in its definition. Moreover, general uses of Foo will be fine.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"However, it's possible to construct an instance of Foo with a circular reference:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"f = Foo()\nf.x = f","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is actually fine provided we never attempt to call zero_tangent / randn_tangent / similar functionality on f once we've set its x field to itself. If we attempt to call such a function, we'll find ourselves with a stack overflow.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution. This is a little trickier to handle. You could specialise zero_tangent etc. for Foo, but this is something of a pain. Fortunately, it seems to be incredibly rare that this is ever a problem in practice. If we gain evidence that this is often a problem in practice, we'll look into supporting zero_tangent etc. automatically for this case.","category":"page"},{"location":"known_limitations/#Tangent-Generation-and-Pointers","page":"Known Limitations","title":"Tangent Generation and Pointers","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In many use cases, a pointer provides the address of the start of a block of memory which has been allocated to e.g. store an array.
However, we cannot get any of this context from the pointer itself – by just looking at a pointer, I cannot know whether its purpose is to refer to the start of a large block of memory, some proportion of the way through a block of memory, or even to keep track of a single address.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Recall that the tangent to a pointer is another pointer:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> Mooncake.tangent_type(Ptr{Float64})\nPtr{Float64}","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Plainly I cannot implement a method of zero_tangent for Ptr{Float64} because I don't know how much memory to allocate.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is, however, fine if a pointer appears halfway through a function, having been derived from another data structure, e.g.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"function foo(x::Vector{Float64})\n p = pointer(x, 2)\n return unsafe_load(p)\nend\n\nrule = build_rrule(Tuple{typeof(foo), Vector{Float64}})\nMooncake.value_and_gradient!!(rule, foo, [5.0, 4.0])\n\n# output\n(4.0, (NoTangent(), [0.0, 1.0]))","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"DocTestSetup = nothing","category":"page"}] } diff --git a/dev/tools_for_rules/index.html b/dev/tools_for_rules/index.html index 02f84c04c..dffc83370 100644 --- a/dev/tools_for_rules/index.html +++ b/dev/tools_for_rules/index.html @@ -1,5 +1,5 @@ -Tools for Rules · Mooncake.jl

+Tools for Rules · Mooncake.jl

Tools for Rules

Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own rrule!! from scratch. In this section, we detail some useful strategies which can help you avoid having to write rrule!!s in many situations.

Simplifying Code via Overlays

Mooncake.@mooncake_overlay (Macro)
@mooncake_overlay method_expr

Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.

For example, suppose that you have a function

julia> foo(x::Float64) = bar(x)
 foo (generic function with 1 method)

where Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:

julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)
 

When looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.

A Worked Example

To demonstrate how to use @mooncake_overlay in practice, we show how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonstrate how to use overlays!

First, consider a simple example:

julia> scale(x) = 2x
 scale (generic function with 1 method)
@@ -19,7 +19,7 @@
 julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
 
 julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
-(20.0, (NoTangent(), 4.0))
source

Functions with Zero Adjoint

If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:

Mooncake.@zero_adjoint – Macro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
+(20.0, (NoTangent(), 4.0))
source

Functions with Zero Adjoint

If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:

Mooncake.@zero_adjoint – Macro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
 
 julia> foo(x) = 5
 foo (generic function with 1 method)
@@ -41,7 +41,7 @@
 true
 
 julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
-(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.zero_adjoint – Function
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions whose adjoints always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
+(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.
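
To make the aliasing condition concrete, consider the following sketch (the type and function names here are ours, purely illustrative, not part of Mooncake.jl):

julia> struct Container
           data::Vector{Float64}
       end

julia> first_field(c::Container) = c.data  # output aliases a field of the input
first_field (generic function with 1 method)

The Vector returned by first_field is the very object stored inside c, so declaring a zero adjoint for it would silently discard any gradient flowing through the returned value.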

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.
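
For instance, continuing the foo example above, a minimal check might look like the following sketch (we borrow the Xoshiro rng and the is_primitive keyword from the test_rule call shown later on this page; treat the printed result as the expected outcome rather than verified output):

julia> using Mooncake, Random

julia> Mooncake.TestUtils.test_rule(Random.Xoshiro(123), foo, 5.0; is_primitive=true)
Test Passed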

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.zero_adjoint – Function
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions whose adjoints always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
 
 julia> foo(x::Vararg{Int}) = 5
 foo (generic function with 1 method)
@@ -51,7 +51,7 @@
 julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);
 
 julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
-(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.

source

Using ChainRules.jl

ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.

Mooncake.@from_rrule – Macro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, using an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool stating whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
+(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.

source

Using ChainRules.jl

ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.

Mooncake.@from_rrule – Macro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, using an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool stating whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
 
 julia> using ChainRulesCore
 
@@ -96,4 +96,4 @@
        TestUtils.test_rule(
            Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
        )
-Test Passed

Notice that, in order to access the kwarg method, we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRulesCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source
+Test Passed

Notice that, in order to access the kwarg method, we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.
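
If it helps to see the shape of such a call spelled out, the following sketch (our own construction, with the argument order assumed from Julia's Core.kwcall convention) shows how the derived rule for the kwarg method would be invoked directly, reusing foo and the helpers imported in the example above:

julia> rrule!!(
           zero_fcodual(Core.kwcall),
           zero_fcodual((cond=false,)),
           zero_fcodual(foo),
           zero_fcodual(5.0),
       );

In practice, you should prefer higher-level entry points such as TestUtils.test_rule, as above.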

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.
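
As a concrete illustration of the first point (this example is ours, included purely to show the failure mode), a function like the following mutates its argument, so @from_rrule must not be applied to its signature:

julia> function scale_in_place!(x::Vector{Float64})
           x .*= 2.0  # mutates the input,
           return x   # and the output aliases it, violating both conditions
       end
scale_in_place! (generic function with 1 method)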

Argument Type Constraints

Many methods of ChainRulesCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.
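
As a sketch of what such a restriction looks like (reusing foo and its ChainRulesCore.rrule from the basic example above; the precise signature here is our assumption, not taken verbatim from the docstring):

julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}

This restricts the wrapped rule to Float16, Float32, and Float64 arguments, leaving Mooncake.jl to derive rules for any other subtypes of Real it encounters.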

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.
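
If you want to check what these conversions do for your types before relying on them, you can call to_cr_tangent directly. For a simple array tangent we would expect behaviour along the following lines (a sketch – the comment records our expectation, not verified output):

julia> using Mooncake

julia> Mooncake.to_cr_tangent(ones(3))  # expected: the same 3-element Vector{Float64}, passed through unchanged

If the method you need is missing, that is precisely the situation in which to open an issue, as described above.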

source
diff --git a/dev/understanding_intro/index.html b/dev/understanding_intro/index.html index 561d7cf5a..e560c1924 100644 --- a/dev/understanding_intro/index.html +++ b/dev/understanding_intro/index.html @@ -1,2 +1,2 @@ -Introduction · Mooncake.jl

Mooncake.jl and Reverse-Mode AD

The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically. In particular:

  1. we recap what AD is, and introduce the mathematics necessary to understand it,
  2. we explain how this mathematics relates to functions and data structures in Julia, and
  3. we describe how all of this is handled in Mooncake.jl.

Since Mooncake.jl supports in-place operations / mutation, these docs will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.

Who Are These Docs For?

These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone who is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to read them in order to make use of this package.

Prerequisites and Resources

This introduction assumes familiarity with the differentiation of vector-valued functions – in particular, with gradients and Jacobian matrices.

In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and "transposed Jacobian". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the "Matrix Calculus for Machine Learning and Beyond" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.

+Introduction · Mooncake.jl

Mooncake.jl and Reverse-Mode AD

The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically. In particular:

  1. we recap what AD is, and introduce the mathematics necessary to understand it,
  2. we explain how this mathematics relates to functions and data structures in Julia, and
  3. we describe how all of this is handled in Mooncake.jl.

Since Mooncake.jl supports in-place operations / mutation, these docs will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.

Who Are These Docs For?

These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone who is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to read them in order to make use of this package.

Prerequisites and Resources

This introduction assumes familiarity with the differentiation of vector-valued functions – in particular, with gradients and Jacobian matrices.

In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and "transposed Jacobian". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the "Matrix Calculus for Machine Learning and Beyond" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.