Fix doc mistakes #701

Merged · 10 commits · Jun 24, 2024
4 changes: 2 additions & 2 deletions blackjax/adaptation/meads_adaptation.py
@@ -36,7 +36,7 @@ class MEADSAdaptationState(NamedTuple):
 alpha
 Value of the alpha parameter of the generalized HMC algorithm.
 delta
-Value of the alpha parameter of the generalized HMC algorithm.
+Value of the delta parameter of the generalized HMC algorithm.

 """

@@ -60,7 +60,7 @@ def base():
 with shape.

 This is an implementation of Algorithm 3 of :cite:p:`hoffman2022tuning` using cross-chain
-adaptation instead of parallel ensample chain adaptation.
+adaptation instead of parallel ensemble chain adaptation.

 Returns
 -------
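For context, a minimal sketch of how this cross-chain adaptation is typically driven; the toy target, chain count, and the exact structure of `run`'s return value are illustrative assumptions, not part of this diff:

```python
import jax
import jax.numpy as jnp
import blackjax

def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)  # toy standard-normal target

num_chains = 128
init_key, run_key = jax.random.split(jax.random.PRNGKey(0))
# MEADS adapts across an ensemble of chains, so we provide one position per chain.
initial_positions = jax.random.normal(init_key, (num_chains, 2))

adapt = blackjax.meads_adaptation(logdensity_fn, num_chains=num_chains)
# Runs the adaptation; the tuned step size, alpha and delta are part of the
# returned results (the exact return structure depends on the blackjax version).
results = adapt.run(run_key, initial_positions, num_steps=1000)
```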
2 changes: 1 addition & 1 deletion blackjax/base.py
@@ -89,7 +89,7 @@ class SamplingAlgorithm(NamedTuple):
"""A pair of functions that represents a MCMC sampling algorithm.

Blackjax sampling algorithms are implemented as a pair of pure functions: a
kernel, that takes a new samples starting from the current state, and an
kernel, that generates a new sample from the current state, and an
initialization function that creates a kernel state from a chain position.

As they represent Markov kernels, the kernel functions are pure functions
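A minimal sketch of how this init/step pair is used in practice, here with the NUTS algorithm; the step size and inverse mass matrix are placeholder values chosen for illustration:

```python
import jax
import jax.numpy as jnp
import blackjax

def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)  # toy log-density

# The pair of pure functions described above: `init` builds a kernel state
# from a chain position, `step` generates a new sample from the current state.
nuts = blackjax.nuts(logdensity_fn, step_size=0.5, inverse_mass_matrix=jnp.ones(2))

rng_key = jax.random.PRNGKey(0)
state = nuts.init(jnp.zeros(2))
state, info = nuts.step(rng_key, state)
```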
2 changes: 1 addition & 1 deletion blackjax/vi/svgd.py
@@ -135,7 +135,7 @@ def as_top_level_api(
 kernel: Callable = rbf_kernel,
 update_kernel_parameters: Callable = update_median_heuristic,
 ):
-"""Implements the (basic) user interface for the svgd algorithm.
+"""Implements the (basic) user interface for the svgd algorithm :cite:p:`liu2016stein`.

 Parameters
 ----------
13 changes: 6 additions & 7 deletions docs/examples/howto_custom_gradients.md
@@ -29,10 +29,9 @@ Functions can be defined as the minimum of another one, $f(x) = min_{y} g(x,y)$.
 Our example is taken from the theory of [convex conjugates](https://en.wikipedia.org/wiki/Convex_conjugate), used for example in optimal transport. Let's consider the following function:

 $$
-\begin{align*}
-g(x, y) &= h(y) - \langle x, y\rangle\\
-h(x) &= \frac{1}{p}|x|^p,\qquad p > 1.\\
-\end{align*}
+\begin{equation*}
+g(x, y) = h(y) - \langle x, y\rangle,\qquad h(x) = \frac{1}{p}|x|^p,\qquad p > 1.
+\end{equation*}
 $$

 And define the function $f$ as $f(x) = -min_y g(x, y)$ which we can be implemented as:
@@ -69,7 +68,7 @@ Note the we also return the value of $y$ where the minimum of $g$ is achieved (t

 ### Trying to differentate the function with `jax.grad`

-The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops, and trying to compute it directly raises an error:
+The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops used in BFGS, and trying to compute it directly raises an error:

 ```{code-cell} ipython3
 # We only want the gradient with respect to `x`
@@ -97,15 +96,15 @@ The first order optimality criterion
 \end{equation*}
 ```

-Ensures that:
+ensures that

 ```{math}
 \begin{equation*}
 \frac{df}{dx} = y(x).
 \end{equation*}
 ```

-i.e. the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.
+In other words, the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.


 ### Telling JAX to use a custom gradient
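A self-contained sketch of the idea behind that last section: register $y(x)$ by hand as the derivative of $f$. It uses `jax.custom_vjp` and `jax.scipy.optimize.minimize` with BFGS as an illustration; the guide itself may organize the code differently:

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

p = 3.0

def g(x, y):
    return jnp.sum(jnp.abs(y) ** p / p) - jnp.dot(x, y)

@jax.custom_vjp
def f(x):
    # f(x) = -min_y g(x, y); the inner BFGS loop is not differentiable by JAX.
    res = minimize(lambda y: g(x, y), jnp.zeros_like(x), method="BFGS")
    return -res.fun

def f_fwd(x):
    res = minimize(lambda y: g(x, y), jnp.zeros_like(x), method="BFGS")
    # Save the minimizer y(x) as the residual needed by the backward pass.
    return -res.fun, res.x

def f_bwd(y_star, cotangent):
    # First-order optimality gives df/dx = y(x), so the VJP is cotangent * y(x).
    return (cotangent * y_star,)

f.defvjp(f_fwd, f_bwd)

x = jnp.array([0.5, -1.2])
print(jax.grad(f)(x))  # matches the y at which the inner minimum is attained
```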