Use a Bayesian CNN on the MNIST dataset #14

Open
rlouf opened this issue Jun 29, 2022 · 5 comments
Labels
help wanted (Extra attention is needed), Model (Add a new model to the book)

Comments

@rlouf
Member

rlouf commented Jun 29, 2022

Blackjax already has an example that uses SGLD to sample the weights of a 3-layer MLP; it reaches a very decent accuracy when the uncertainties are used to discard ambiguous predictions. We can use the CNN architecture from the Flax documentation:

from flax import linen as nn  

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)

    return x
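For a sense of the dimension the sampler has to explore, here is the parameter count of the CNN above. This is a back-of-the-envelope sketch assuming 28x28 single-channel MNIST inputs, Flax's default 'SAME' padding for nn.Conv and default 'VALID' padding for nn.avg_pool; the helper names are illustrative only.

```python
def conv_params(kernel_h, kernel_w, in_channels, out_channels):
    """Kernel weights plus one bias per output channel."""
    return kernel_h * kernel_w * in_channels * out_channels + out_channels


def dense_params(in_features, out_features):
    """Weight matrix plus one bias per output unit."""
    return in_features * out_features + out_features


# 'SAME' convolutions keep the spatial size; each 2x2 / stride-2 average pool
# halves it, so a 28x28 image becomes 14x14 and then 7x7 before flattening.
side = 28 // 2 // 2
layers = {
    "conv1": conv_params(3, 3, 1, 32),
    "conv2": conv_params(3, 3, 32, 64),
    "dense1": dense_params(side * side * 64, 256),
    "dense2": dense_params(256, 10),
}
total = sum(layers.values())
print(layers)
print(total)  # 824458
```

Almost all of the weights sit in the first dense layer (803,072 of 824,458).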

And the log-probability function as follows (not tested):

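import distrax
from jax.flatten_util import ravel_pytree

def logpdf(params, images, categories, model):
    # (batch_size, 10) logits: one unnormalized score per digit class
    logits = model.apply(params, images)
    # Standard Normal prior on every weight, flattened into a single vector
    flat_params, _ = ravel_pytree(params)
    log_prior = distrax.Normal(0.0, 1.0).log_prob(flat_params).sum()
    # MNIST has 10 classes, so the likelihood is Categorical (not Bernoulli);
    # `categories` holds the integer labels 0 to 9
    log_likelihood = distrax.Categorical(logits=logits).log_prob(categories).sum()

    return log_prior + log_likelihood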
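The Blackjax example mentioned above relies on SGLD, which only needs the gradient of the log-density estimated on a minibatch. Note that logpdf as written sums the likelihood over whatever images it receives, so on a minibatch that term would need rescaling by N / batch_size to estimate the full-data sum. The following is a hand-written sketch of the SGLD update in plain JAX, not Blackjax's API. To keep it fast and checkable, a toy Gaussian-mean model (a hypothetical stand-in for the CNN) replaces the network, and the helper names minibatch_logdensity, sgld_step and run_sgld are illustrative.

```python
import jax
import jax.numpy as jnp


def minibatch_logdensity(theta, batch_x, data_size):
    """Minibatch estimate of log p(theta | all data) for a toy model.

    Hypothetical stand-in for the CNN: x_i ~ Normal(theta, 1), theta ~ Normal(0, 1).
    The batch log-likelihood is rescaled by data_size / batch_size so that it is
    an unbiased estimate of the full-data log-likelihood.
    """
    log_prior = -0.5 * jnp.sum(theta ** 2)
    log_lik = -0.5 * jnp.sum((batch_x - theta) ** 2)
    return log_prior + (data_size / batch_x.shape[0]) * log_lik


def sgld_step(rng_key, theta, batch_x, data_size, step_size):
    """Langevin update: half a gradient step plus N(0, step_size) noise."""
    grad = jax.grad(minibatch_logdensity)(theta, batch_x, data_size)
    noise = jax.random.normal(rng_key, theta.shape)
    return theta + 0.5 * step_size * grad + jnp.sqrt(step_size) * noise


def run_sgld(rng_key, data, num_steps, batch_size, step_size):
    """Run SGLD from theta = 0, drawing a fresh minibatch at every step."""
    data_size = data.shape[0]

    def one_step(theta, key):
        batch_key, noise_key = jax.random.split(key)
        batch = jax.random.choice(batch_key, data, shape=(batch_size,), replace=False)
        theta = sgld_step(noise_key, theta, batch, data_size, step_size)
        return theta, theta  # carry the state and also record it as a sample

    keys = jax.random.split(rng_key, num_steps)
    _, samples = jax.lax.scan(one_step, jnp.zeros(()), keys)
    return samples


data_key, sgld_key = jax.random.split(jax.random.PRNGKey(0))
data = 2.0 + jax.random.normal(data_key, (1000,))
samples = run_sgld(sgld_key, data, num_steps=3000, batch_size=100, step_size=1e-4)
# Compare the post-burn-in sample mean with the exact posterior mean sum(x) / (N + 1)
print(samples[1000:].mean(), data.sum() / (data.shape[0] + 1))
```

Because this toy posterior is conjugate, the sample mean can be checked against the exact value. For the CNN, theta would be the flattened weights and the gradient would come from a rescaled version of logpdf instead.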

We should look at:

  • Comparison between SGLD and SGHMC (#211)
  • Raw accuracy compared to a solution that uses SGD (with Optax)
  • The distribution of "confidence" in predictions
  • Accuracy once we've removed examples where the model is not sure
  • Examples where the model is not sure, and the proportion of examples where it is not sure
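Several of the points above (the distribution of confidence, accuracy after discarding uncertain predictions, and the proportion of uncertain examples) can all be computed from the logits produced under each posterior sample. A minimal sketch, assuming the logits are stacked into an array of shape (num_samples, num_examples, 10), and taking "confidence" to mean the largest averaged class probability, which is one choice among several (predictive entropy is another). The function names and the threshold value are hypothetical.

```python
import jax
import jax.numpy as jnp


def posterior_predictive(sampled_logits):
    """Average class probabilities over posterior samples.

    sampled_logits: shape (num_samples, num_examples, num_classes), e.g. the
    CNN's logits evaluated under each SGLD or SGHMC sample.
    """
    return jax.nn.softmax(sampled_logits, axis=-1).mean(axis=0)


def confidence_report(sampled_logits, labels, threshold):
    """Accuracy before and after discarding examples below `threshold` confidence."""
    probs = posterior_predictive(sampled_logits)
    confidence = probs.max(axis=-1)  # one "confidence" score per example
    predictions = probs.argmax(axis=-1)
    correct = predictions == labels
    keep = confidence >= threshold
    return {
        "raw_accuracy": correct.mean(),
        "proportion_unsure": 1.0 - keep.mean(),
        # Mean over kept examples only (NaN if nothing is kept)
        "accuracy_when_sure": jnp.sum(correct & keep) / jnp.sum(keep),
        "confidence": confidence,  # histogram this to see the distribution
    }


# Two posterior samples, three examples, two classes (toy values).
logits = jnp.array([
    [[10.0, 0.0], [0.0, 10.0], [3.0, 0.0]],
    [[10.0, 0.0], [0.0, 10.0], [0.0, 1.0]],
])
labels = jnp.array([0, 1, 1])
report = confidence_report(logits, labels, threshold=0.9)
print({k: v for k, v in report.items() if k != "confidence"})
```

In the toy call, the samples disagree on the third example, so its averaged confidence is low; it is the one that gets discarded, and it also happens to be the misclassified one.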
@rlouf rlouf added the documentation and help wanted (Extra attention is needed) labels Jun 29, 2022
@cgarciae

cgarciae commented Jun 30, 2022

Hey @rlouf, love the example! Inside the logpdf function the y variable doesn't exist; I am guessing it should be categories instead?

@rlouf
Member Author

rlouf commented Jun 30, 2022

Yes, I made the change, thank you! I have no guarantee that this will work, though.

@gerdm

gerdm commented Jun 30, 2022

Hi @rlouf, I’ll work on this issue!

@rlouf
Member Author

rlouf commented Aug 29, 2022

Hey @gerdm, do you still intend to work on this?

@gerdm

gerdm commented Aug 30, 2022

Hey @rlouf. Yes, still planning to work on it. Expect updates in September.

@rlouf rlouf transferred this issue from blackjax-devs/blackjax Jan 13, 2023
@rlouf rlouf added the Model (Add a new model to the book) label and removed the documentation label Jan 14, 2023

3 participants