Fix doc mistakes (#701)
* Fix equation formatting

* Clarify JAX gradient error

* Fix punctuation + capitalization

* Fix grammar

Should not begin sentence with "i.e." in English.

* Fix math formatting error

* Fix typo

Change parallel _ensample_ chain adaptation to parallel _ensemble_ chain adaptation.

* Add SVGD citation to appear in doc

Currently the SVGD paper is only cited in the `kernel` function, which is defined _within_ the `build_kernel` function. Because of this nested function format, the SVGD paper is _not_ cited in the documentation.

To fix this, I added a citation to the SVGD paper in the `as_top_level_api` docstring.

* Fix grammar + clarify doc

* Fix typo

---------

Co-authored-by: Junpeng Lao <[email protected]>
gil2rok and junpenglao authored Jun 24, 2024
1 parent 5764a2b commit f8db9aa
Showing 4 changed files with 10 additions and 11 deletions.
4 changes: 2 additions & 2 deletions blackjax/adaptation/meads_adaptation.py
@@ -36,7 +36,7 @@ class MEADSAdaptationState(NamedTuple):
alpha
Value of the alpha parameter of the generalized HMC algorithm.
delta
-Value of the alpha parameter of the generalized HMC algorithm.
+Value of the delta parameter of the generalized HMC algorithm.
"""

@@ -60,7 +60,7 @@ def base():
with shape.
This is an implementation of Algorithm 3 of :cite:p:`hoffman2022tuning` using cross-chain
-adaptation instead of parallel ensample chain adaptation.
+adaptation instead of parallel ensemble chain adaptation.
Returns
-------
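For context, a rough sketch of how this cross-chain adaptation scheme is typically driven from the top-level API (the target density, chain count, and step budget below are placeholders, and the return structure is assumed to follow Blackjax's usual adaptation interface rather than taken verbatim from the source):

```python
import jax
import jax.numpy as jnp
import blackjax

def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)  # standard normal target, for illustration

num_chains, num_steps = 128, 1000
rng_key = jax.random.PRNGKey(0)
initial_positions = jax.random.normal(rng_key, (num_chains, 2))

# cross-chain adaptation of the generalized HMC parameters (step size, alpha, delta)
adapt = blackjax.meads_adaptation(logdensity_fn, num_chains=num_chains)
(last_states, parameters), _ = adapt.run(rng_key, initial_positions, num_steps)

# the tuned parameters are then passed on to the generalized HMC kernel
ghmc = blackjax.ghmc(logdensity_fn, **parameters)
```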
2 changes: 1 addition & 1 deletion blackjax/base.py
@@ -89,7 +89,7 @@ class SamplingAlgorithm(NamedTuple):
"""A pair of functions that represents a MCMC sampling algorithm.
Blackjax sampling algorithms are implemented as a pair of pure functions: a
-kernel, that takes a new samples starting from the current state, and an
+kernel, that generates a new sample from the current state, and an
initialization function that creates a kernel state from a chain position.
As they represent Markov kernels, the kernel functions are pure functions
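As a concrete illustration of this init/step pair (a minimal sketch; the target density and the NUTS parameters below are placeholders rather than tuned values):

```python
import jax
import jax.numpy as jnp
import blackjax

def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)  # standard normal target, for illustration

# every SamplingAlgorithm exposes the same two pure functions
nuts = blackjax.nuts(logdensity_fn, step_size=1e-2, inverse_mass_matrix=jnp.ones(2))

state = nuts.init(jnp.zeros(2))              # init: chain position -> kernel state
rng_key = jax.random.PRNGKey(0)
state, info = nuts.step(rng_key, state)      # kernel: current state -> new sample
```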
2 changes: 1 addition & 1 deletion blackjax/vi/svgd.py
@@ -135,7 +135,7 @@ def as_top_level_api(
kernel: Callable = rbf_kernel,
update_kernel_parameters: Callable = update_median_heuristic,
):
"""Implements the (basic) user interface for the svgd algorithm.
"""Implements the (basic) user interface for the svgd algorithm :cite:p:`liu2016stein`.
Parameters
----------
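For reference, a rough sketch of how this top-level interface is typically used (the log-density, optimizer, and particle initialization are placeholders, and the exact `init`/`step` call signatures are assumed rather than taken from the source):

```python
import jax
import jax.numpy as jnp
import optax
import blackjax

def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)  # standard normal target, for illustration

# rbf_kernel and the median heuristic are the defaults, as in the signature above
svgd = blackjax.svgd(jax.grad(logdensity_fn), optax.sgd(learning_rate=1e-2))

initial_particles = jax.random.normal(jax.random.PRNGKey(0), (50, 2))
state = svgd.init(initial_particles)

for _ in range(100):
    state = svgd.step(state)  # transport the whole particle set one SVGD step
```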
13 changes: 6 additions & 7 deletions docs/examples/howto_custom_gradients.md
@@ -29,10 +29,9 @@ Functions can be defined as the minimum of another one, $f(x) = min_{y} g(x,y)$.
Our example is taken from the theory of [convex conjugates](https://en.wikipedia.org/wiki/Convex_conjugate), used for example in optimal transport. Let's consider the following function:

$$
-\begin{align*}
-g(x, y) &= h(y) - \langle x, y\rangle\\
-h(x) &= \frac{1}{p}|x|^p,\qquad p > 1.\\
-\end{align*}
+\begin{equation*}
+g(x, y) = h(y) - \langle x, y\rangle,\qquad h(x) = \frac{1}{p}|x|^p,\qquad p > 1.
+\end{equation*}
$$

And define the function $f$ as $f(x) = -\min_y g(x, y)$, which can be implemented as:
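A minimal sketch of one way to implement this with `jax.scipy.optimize.minimize` (the function names and the choice of $p$ are illustrative, not necessarily the notebook's exact code):

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

p = 3.0  # any p > 1

def g(y, x):
    # g(x, y) = h(y) - <x, y>  with  h(y) = |y|^p / p
    return jnp.sum(jnp.abs(y) ** p) / p - jnp.dot(x, y)

def f(x):
    # f(x) = -min_y g(x, y); BFGS runs an internal `while` loop
    res = minimize(g, jnp.zeros_like(x), args=(x,), method="BFGS")
    return -res.fun, res.x  # also return the minimizer y(x)

value, y_star = f(jnp.array([2.0]))
```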
@@ -69,7 +68,7 @@ Note that we also return the value of $y$ where the minimum of $g$ is achieved (t

### Trying to differentiate the function with `jax.grad`

-The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops, and trying to compute it directly raises an error:
+The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops used in BFGS, and trying to compute it directly raises an error:

```{code-cell} ipython3
# We only want the gradient with respect to `x`
@@ -97,15 +96,15 @@ The first order optimality criterion
\end{equation*}
```

-Ensures that:
+ensures that

```{math}
\begin{equation*}
\frac{df}{dx} = y(x).
\end{equation*}
```

-i.e. the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.
+In other words, the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.
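
One way to see this identity (a short sketch using the definitions above, with $y(x)$ the minimizer): since $f(x) = -g(x, y(x))$ and $\partial g/\partial x = -y$,

```{math}
\begin{equation*}
\frac{df}{dx} = -\left(\frac{\partial g}{\partial x}(x, y(x)) + \frac{\partial g}{\partial y}(x, y(x))\,\frac{dy}{dx}\right) = y(x) - 0 \cdot \frac{dy}{dx} = y(x),
\end{equation*}
```

because the first-order optimality criterion makes $\partial g/\partial y$ vanish at $y(x)$.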


### Telling JAX to use a custom gradient
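As a minimal sketch of the general pattern (assuming `f` returns the pair `(value, y_star)` as in the sketch above; this is an illustration, not the notebook's exact code), the rule $df/dx = y(x)$ can be registered with `jax.custom_vjp`:

```python
import jax

@jax.custom_vjp
def f_min(x):
    value, _ = f(x)  # `f` is assumed to be the (value, argmin) function sketched earlier
    return value

def f_min_fwd(x):
    value, y_star = f(x)
    return value, y_star              # stash the minimizer y(x) for the backward pass

def f_min_bwd(y_star, cotangent):
    # df/dx = y(x): the pullback scales the stored minimizer by the incoming cotangent
    return (cotangent * y_star,)

f_min.defvjp(f_min_fwd, f_min_bwd)

# jax.grad(f_min) now returns y(x) without differentiating through the BFGS loop
```

A reverse-mode rule (`custom_vjp`) suffices here because only gradients of the scalar objective are needed.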
