Merge pull request #85 from liesel-devs/fix-linreg-tutorial
Correct sampling for sigma_sq
hriebl authored Sep 7, 2023
2 parents e3eec6e + 2ed4609 commit 332c9de
Showing 2 changed files with 23 additions and 20 deletions.
41 changes: 22 additions & 19 deletions docs/source/tutorials/qmd/01-lin-reg.qmd
@@ -129,15 +129,16 @@ beta = lsl.Param(value=np.array([0.0, 0.0]), distribution=beta_dist,name="beta")

#### The standard deviation

-The second branch of the tree contains the residual standard deviation. We build it in a similar way, but this time using the weakly informative prior $\sigma \sim \text{InverseGamma}(0.01, 0.01)$. Again, we use the parameter names based on TFP.
+The second branch of the tree contains the residual standard deviation. We build it in a similar way, but this time using the weakly informative prior $\sigma^2 \sim \text{InverseGamma}(0.01, 0.01)$ on the squared standard deviation, i.e. the variance. Again, we use the parameter names based on TFP.

```{python}
#| label: standard-deviation-node
-sigma_a = lsl.Var(0.01, name="a")
-sigma_b = lsl.Var(0.01, name="b")
+a = lsl.Var(0.01, name="a")
+b = lsl.Var(0.01, name="b")
-sigma_dist = lsl.Dist(tfd.InverseGamma, concentration=sigma_a, scale=sigma_b)
-sigma = lsl.Param(value=10.0, distribution=sigma_dist, name="sigma")
+sigma_sq_dist = lsl.Dist(tfd.InverseGamma, concentration=a, scale=b)
+sigma_sq = lsl.Param(value=10.0, distribution=sigma_sq_dist, name="sigma_sq")
+sigma = lsl.Var(lsl.Calc(jnp.sqrt, sigma_sq), name="sigma").update()
```
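
For a quick numerical sense of this prior, the same distribution can also be evaluated directly with TFP, outside the model graph. This is a self-contained sketch, not part of the tutorial's model; the JAX substrate import matches the `tfd` alias used above.

```{python}
#| label: prior-check-sketch
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd

prior = tfd.InverseGamma(concentration=0.01, scale=0.01)
print(prior.log_prob(10.0))  # log-density of the initial variance value
print(jnp.sqrt(10.0))        # standard deviation implied by that value
```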

#### Design matrix, fitted values, and response
@@ -215,7 +216,7 @@ The individual nodes also have a `log_prob` property. In fact, because of the co

```{python}
#| label: log-prob-demo
-beta.log_prob.sum() + sigma.log_prob + y.log_prob.sum()
+beta.log_prob.sum() + sigma_sq.log_prob + y.log_prob.sum()
```
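
As a sanity check, this manual sum should coincide with the joint log-probability of the model. A small sketch, assuming the `model` object built earlier in the tutorial and no further priors in the graph:

```{python}
#| label: log-prob-check-sketch
import jax.numpy as jnp

manual = beta.log_prob.sum() + sigma_sq.log_prob + y.log_prob.sum()
print(jnp.allclose(model.log_prob, manual))  # expected: True
```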

Nodes without a probability distribution return a log-probability of zero.
@@ -225,18 +226,20 @@
beta_loc.log_prob
```

-The log-probability of a node depends on its value and its inputs. Thus, if we change the standard deviation of the response from 10 to 1, the log-probability of the corresponding node, the log-probability of the response node, and the log-probability of the model change as well.
+The log-probability of a node depends on its value and its inputs. Thus, if we change the variance of the response from 10 to 1, the log-probability of the corresponding node, the log-probability of the response node, and the log-probability of the model change as well.
+Note that, since the actual input to the response distribution is the standard deviation $\sigma$, we have to update its value after changing the value of $\sigma^2$.

```{python}
#| label: log-prob-updating-demo
print(f"Old value of sigma: {sigma.value}")
print(f"Old log-prob of sigma: {sigma.log_prob}")
print(f"Old value of sigma_sq: {sigma_sq.value}")
print(f"Old log-prob of sigma_sq: {sigma_sq.log_prob}")
print(f"Old log-prob of y: {y.log_prob.sum()}\n")
sigma.value = 1.0
sigma_sq.value = 1.0
sigma.update()
print(f"New value of sigma: {sigma.value}")
print(f"New log-prob of sigma: {sigma.log_prob}")
print(f"New value of sigma_sq: {sigma_sq.value}")
print(f"New log-prob of sigma_sq: {sigma_sq.log_prob}")
print(f"New log-prob of y: {y.log_prob.sum()}\n")
print(f"New model log-prob: {model.log_prob}")
@@ -252,7 +255,7 @@ We start with a very simple sampling scheme, keeping $\sigma$ fixed at the true

```{python}
#| label: goose-MCMC-engine-setup
-sigma.value = true_sigma
+sigma_sq.value = true_sigma**2
builder = gs.EngineBuilder(seed=1337, num_chains=4)
@@ -291,14 +294,14 @@ No compilation is required at this point, so this is pretty fast.

### Using a Gibbs kernel

-So far, we have not sampled our standard deviation parameter `sigma`; we simply fixed it to zero. Now we extend our model with a Gibbs sampler for `sigma`. Using a Gibbs kernel is a bit more complicated, because Goose doesn't automatically derive the full conditional from the model graph. Hence, the user needs to provide a function to sample from the full conditional. The function needs to accept a PRNG state and a model state as arguments, and it needs to return a dictionary with the node name as the key and the new node value as the value. We could also update multiple parameters with one Gibbs kernel if we returned a dictionary of length two or more.
+So far, we have not sampled our variance parameter `sigma_sq`; we simply fixed it to the true value of one. Now we extend our model with a Gibbs sampler for `sigma_sq`. Using a Gibbs kernel is a bit more complicated, because Goose doesn't automatically derive the full conditional from the model graph. Hence, the user needs to provide a function to sample from the full conditional. The function needs to accept a PRNG state and a model state as arguments, and it needs to return a dictionary with the node name as the key and the new node value as the value. We could also update multiple parameters with one Gibbs kernel if we returned a dictionary of length two or more.

To retrieve the values of our nodes from the `model_state`, we need to append the suffix `_value` to the nodes' names. Likewise, the node name in the returned dictionary needs to have the `_value` suffix added.

```{python}
#| label: define-transition-function
-def draw_sigma(prng_key, model_state):
+def draw_sigma_sq(prng_key, model_state):
    a_prior = model_state["a_value"].value
    b_prior = model_state["b_value"].value
    n = len(model_state["y_value"].value)
@@ -308,7 +311,7 @@ def draw_sigma(prng_key, model_state):
    a_gibbs = a_prior + n / 2
    b_gibbs = b_prior + jnp.sum(resid**2) / 2
    draw = b_gibbs / jax.random.gamma(prng_key, a_gibbs)
-    return {"sigma_value": draw}
+    return {"sigma_sq_value": draw}
```
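
The updates in `draw_sigma_sq` are the standard conjugate result: with the prior $\sigma^2 \sim \text{InverseGamma}(a, b)$ and Gaussian residuals, the full conditional is again inverse gamma,

$$
\sigma^2 \mid \cdot \sim \text{InverseGamma}\left(a + \frac{n}{2},\; b + \frac{1}{2} \sum_{i=1}^n (y_i - \hat{y}_i)^2\right).
$$

The last line of the function uses the fact that if $X \sim \text{Gamma}(a_\text{gibbs}, 1)$, then $b_\text{gibbs} / X \sim \text{InverseGamma}(a_\text{gibbs}, b_\text{gibbs})$, so no dedicated inverse gamma sampler is needed.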

We build the engine in a similar way as before, but this time adding the Gibbs kernel as well.
@@ -321,7 +324,7 @@ builder.set_model(lsl.GooseModel(model))
builder.set_initial_values(model.state)
builder.add_kernel(gs.NUTSKernel(["beta"]))
builder.add_kernel(gs.GibbsKernel(["sigma"], draw_sigma))
builder.add_kernel(gs.GibbsKernel(["sigma_sq"], draw_sigma_sq))
builder.set_duration(warmup_duration=1000, posterior_duration=1000)
@@ -334,7 +337,7 @@ Goose provides a couple of convenient numerical and graphical summary tools. The
```{python results="asis"}
#| label: results-fo-sampling-with-gibbs-sampler
results = engine.get_results()
-gs.Summary.from_result(results)
+gs.Summary(results)
```
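
Beyond the rendered summary table, the raw posterior samples can be extracted from the results object for custom post-processing. A sketch, assuming a `get_posterior_samples()` accessor on the results object (not shown in this diff):

```{python}
#| label: posterior-samples-sketch
samples = results.get_posterior_samples()       # assumed accessor
print(samples["beta"].shape)                    # (num_chains, num_draws, num_coefficients)
print(samples["beta"].mean(axis=(0, 1)))        # posterior means of the coefficients
```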

We can plot the trace plots of the chains with {func}`.plot_trace()`.
@@ -352,5 +355,5 @@ gs.plot_param(results, param="beta", param_index=0)
```
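
For completeness, the trace plots mentioned above could be produced along these lines, a minimal sketch based on the {func}`.plot_trace()` function referenced in the text:

```{python}
#| label: trace-plot-sketch
gs.plot_trace(results)  # trace plots for all sampled parameters
```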


-Here, we end this first tutorial. We have learned about a lot of different classes and
+Here, we end this first tutorial. We have learned about a lot of different classes and we have
seen how we can flexibly use different kernels for drawing MCMC samples - that is quite a bit for the start. Now, have fun modelling with Liesel!
2 changes: 1 addition & 1 deletion docs/source/tutorials/qmd/01a-transform.qmd
@@ -67,7 +67,7 @@ y_dist = lsl.Dist(tfd.Normal, loc=y_hat, scale=sigma)
y = lsl.Var(y_vec, distribution=y_dist, name="y")
```

-Now let's try to sample the full parameter vector $(\boldsymbol{\beta}', \sigma)'$ with a single NUTS kernel instead of using a NUTS kernel for $\boldsymbol{\beta}$ and a Gibbs kernel for $\sigma$. Since the standard deviation is a positive-valued parameter, we need to log-transform it to sample it with a NUTS kernel. The {class}`.GraphBuilder` class provides the {meth}`.transform_parameter` method for this purpose.
+Now let's try to sample the full parameter vector $(\boldsymbol{\beta}', \sigma)'$ with a single NUTS kernel instead of using a NUTS kernel for $\boldsymbol{\beta}$ and a Gibbs kernel for $\sigma^2$. Since the standard deviation is a positive-valued parameter, we need to log-transform it to sample it with a NUTS kernel. The {class}`.GraphBuilder` class provides the {meth}`.transform_parameter` method for this purpose.
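
The log-transformation works because of the change-of-variables theorem: writing $\tau = \log \sigma$, the density of the transformed parameter is

$$
p_\tau(\tau) = p_\sigma(e^\tau) \, e^\tau,
$$

so the NUTS kernel can explore $\tau$ on the whole real line while the prior on $\sigma$ is preserved, the Jacobian factor $e^\tau$ simply adding $\tau$ to the log-posterior.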

```{python}
#| label: graph-and-transformation
