
Julia package for automated Bayesian inference on a factor graph with reactive message passing


blolt/RxInfer.jl

 
 


Overview

RxInfer.jl is a Julia package for automatic Bayesian inference on a factor graph with reactive message passing.

Given a probabilistic model, RxInfer enables efficient Bayesian inference based on message passing. It uses the model structure to generate an algorithm that consists of a sequence of local computations on a factor graph representation of the model.
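To make "local computations on a factor graph" concrete, here is a minimal, hypothetical sketch in plain Julia (RxInfer is not needed) for a simple coin model with a Beta prior on the bias θ and Bernoulli observations. It is not how RxInfer is implemented: as a swapped-in simplification, θ is discretized on a grid, whereas RxInfer works with analytical message forms where they are available. The names GRID, prior_message, likelihood_message and marginal are illustrative only.

```julia
# Hypothetical sketch, NOT RxInfer's implementation: sum-product message
# passing for a coin model, with θ discretized on a grid.

const GRID = range(0.001, 0.999; length = 999)  # candidate values of θ

# Local computation at the Beta(a, b) prior factor: message towards θ
prior_message(a, b) = [θ^(a - 1) * (1 - θ)^(b - 1) for θ in GRID]

# Local computation at a Bernoulli(θ) factor with observed y ∈ {0, 1}
likelihood_message(y) = [θ^y * (1 - θ)^(1 - y) for θ in GRID]

# Local computation at the variable node θ: multiply all incoming messages
function marginal(messages)
    belief = ones(length(GRID))
    for m in messages
        belief .*= m
        belief ./= sum(belief)  # renormalize each step to avoid underflow
    end
    return belief
end

observations = [1, 1, 0, 1]
messages = vcat([prior_message(2.0, 7.0)],
                [likelihood_message(y) for y in observations])
belief = marginal(messages)
posterior_mean = sum(GRID .* belief)
println("posterior mean ≈ ", round(posterior_mean; digits = 3))
```

Each factor only needs its own inputs to compute its message, and the variable node only combines what it receives. With a Beta(2, 7) prior and observations 1, 1, 0, 1 the grid result approximates the exact conjugate posterior mean 5/13 ≈ 0.385.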

Performance and scalability

RxInfer.jl has been designed with a focus on efficiency, scalability and maximum performance for running Bayesian inference with message passing. Below is a comparison between RxInfer.jl and Turing.jl on latent state estimation in a linear multivariate Gaussian state-space model. Turing.jl is a state-of-the-art, general-purpose probabilistic programming package written in Julia and can run inference in a broader class of models. Still, RxInfer.jl performs inference faster and more accurately in a variety of models. It achieves this by exploiting conjugate likelihood-prior pairings in the model, whose posteriors are available in closed form and known to RxInfer.jl. As a result, in models with conjugate pairings, RxInfer.jl often outperforms general-purpose probabilistic programming packages in computational load, speed, memory use and accuracy. Note, however, that RxInfer.jl also supports non-conjugate inference and is continually being extended to support a larger class of models.

[Figures: Turing comparison; Scalability performance]

Faster inference with better results

In conjugate models, RxInfer.jl not only outperforms general-purpose Bayesian inference methods in execution speed and scalability but also provides more accurate results. Check out the documentation for more examples!

[Figures: Inference with RxInfer; Inference with HMC]

The benchmark and accuracy experiment that generated these plots is available in the benchmarks/ folder. Note that the execution speed and accuracy of the HMC estimator depend heavily on the choice of hyperparameters. In this example, RxInfer performs exact inference and does not depend on any hyperparameters.

Installation

Install RxInfer through the Julia package manager:

] add RxInfer

Optionally, use ] test RxInfer to validate the installation by running the test suite.

Documentation

For more information about RxInfer.jl please refer to the documentation.

Note

The RxInfer.jl API changed in version 3.0.0. See the Migration Guide for more details.

Getting Started

Examples to help you get started are available in the examples/ folder. Alternatively, you can preview the same examples in the documentation.

Coin flip simulation

Here we show a simple example of how to use RxInfer.jl for Bayesian inference problems. In this example we want to estimate the bias of a coin, in the form of a probability distribution, in a coin flip simulation.

First, let's set up our environment by importing all the needed packages:

using RxInfer, Random

We start by creating a dataset. For simplicity, in this example we use a static, pre-generated dataset. Each sample can be thought of as the outcome of a single flip, which is either heads or tails (1 or 0). We assume that our virtual coin is biased and lands heads up on 75% of the trials (on average).

n = 500  # Number of coin flips
p = 0.75 # Bias of a coin
rng = MersenneTwister(42) # Fixed seed, so the dataset is reproducible

distribution = Bernoulli(p)
dataset      = rand(rng, distribution, n)

Model specification

In a Bayesian setting, the next step is to specify our probabilistic model. This amounts to specifying the joint probability of the random variables of the system.

Likelihood

We will assume that the outcome of each coin flip is governed by the Bernoulli distribution, i.e.

$$y_i \sim \mathrm{Bernoulli}(\theta)$$

where $y_i = 1$ represents "heads", $y_i = 0$ represents "tails". The underlying probability of the coin landing heads up for a single coin flip is $\theta \in [0,1]$.

Prior

We will choose the conjugate prior of the Bernoulli likelihood function defined above, namely the beta distribution, i.e.

$$\theta \sim \mathrm{Beta}(a, b)$$

where $a$ and $b$ are the hyperparameters that encode our prior beliefs about the possible values of $\theta$. We will assign values to the hyperparameters in a later step.

Joint probability

The joint probability is given by the product of the likelihood and the prior, i.e.

$$P(y_{1:N}, \theta) = P(\theta) \prod_{i=1}^N P(y_i | \theta).$$
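Because the beta prior is conjugate to the Bernoulli likelihood, the posterior can be derived by hand from the joint probability. The following worked derivation is the standard textbook result (not taken from RxInfer's source), and it is what exact inference should reproduce for this model:

$$
\begin{aligned}
P(\theta | y_{1:N}) &\propto P(\theta) \prod_{i=1}^N P(y_i | \theta) \\
&\propto \theta^{a-1}(1-\theta)^{b-1} \prod_{i=1}^N \theta^{y_i}(1-\theta)^{1-y_i} \\
&= \theta^{a + \sum_{i=1}^N y_i - 1}\,(1-\theta)^{b + N - \sum_{i=1}^N y_i - 1}
\end{aligned}
$$

This is the kernel of a beta distribution, so $\theta | y_{1:N} \sim \mathrm{Beta}\left(a + \sum_{i=1}^N y_i,\; b + N - \sum_{i=1}^N y_i\right)$. For example, with $a = 2$, $b = 7$ and three heads in four flips, the posterior is $\mathrm{Beta}(5, 8)$.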

Now let's see how to specify this model using the syntax of the GraphPPL package:

# GraphPPL.jl exports the `@model` macro for model specification
# It accepts a regular Julia function and builds a factor graph under the hood
@model function coin_model(y, a, b)
    # We endow the θ parameter of our model with a Beta prior
    θ ~ Beta(a, b)
    # We assume that the outcome of each coin flip
    # is governed by the Bernoulli distribution
    for i in eachindex(y)
        y[i] ~ Bernoulli(θ)
    end
end

In short, the @model macro converts a textual description of a probabilistic model into a corresponding factor graph (FG). In the example above, the $\theta \sim \mathrm{Beta}(a, b)$ expression creates a latent variable $\theta$ and assigns it as the output of a $\mathrm{Beta}$ node in the corresponding factor graph. The ~ operation can be read as "is modelled by". Next, we model each data point y[i] as drawn from a $\mathrm{Bernoulli}$ distribution with $\theta$ as its parameter.
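To picture what the @model macro produces, here is a hypothetical, simplified Julia representation of the coin model's factor graph. The FactorNode struct and the coin_factor_graph function are illustrative names only; they do not reflect GraphPPL.jl's internal data structures, and treating a and b as plain edges is a simplification.

```julia
# Hypothetical, simplified factor graph for coin_model.
# NOT GraphPPL's internal representation: it only records which
# variables each factor node connects.
struct FactorNode
    kind::Symbol           # e.g. :Beta or :Bernoulli
    edges::Vector{Symbol}  # attached variables, output first
end

function coin_factor_graph(n::Int)
    nodes = [FactorNode(:Beta, [:θ, :a, :b])]  # θ ~ Beta(a, b)
    for i in 1:n
        # y[i] ~ Bernoulli(θ): observation y_i is the output, θ the parameter
        push!(nodes, FactorNode(:Bernoulli, [Symbol("y_", i), :θ]))
    end
    return nodes
end

graph = coin_factor_graph(3)
# Number of factor nodes attached to θ
degree_θ = count(node -> :θ in node.edges, graph)
println(length(graph), " factor nodes, ", degree_θ, " of them connected to θ")
```

Because θ is shared by the Beta node and every Bernoulli node, inference on θ combines one message from each of those n + 1 factors.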

Tip

Alternatively, we could use the broadcasting operation:

@model function coin_model(y, a, b) 
    θ  ~ Beta(a, b)
    y .~ Bernoulli(θ) 
end

As you can see, RxInfer in combination with GraphPPL offers a model specification syntax that closely resembles the mathematical equations defined above.

Note

The GraphPPL.jl API changed in version 4.0.0. See the Migration Guide for more details.

Inference specification

Once we have defined our model, the next step is to use the RxInfer API to infer the quantities of interest. To do this, we can use the generic infer function from RxInfer.jl, which supports static datasets.

result = infer(
    model = coin_model(a = 2.0, b = 7.0),
    data  = (y = dataset, )
)

[Figure: Coin Flip]
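Because this model is conjugate, the exact posterior over θ can also be computed in closed form, which gives a convenient sanity check for the output of infer. The sketch below uses only Julia's standard library; exact_coin_posterior is a hypothetical helper, and the simulated flips use their own seed, so pass the dataset from above to compare like with like. We assume, based on the RxInfer documentation, that the inferred marginal can be read from result.posteriors[:θ]; please verify this against your installed version.

```julia
using Random

# Hypothetical helper, independent of RxInfer: exact Beta posterior
# for the coin model via the conjugate Beta-Bernoulli update.
function exact_coin_posterior(a, b, ys)
    heads = count(isone, ys)       # number of y_i == 1
    tails = length(ys) - heads     # number of y_i == 0
    α, β = a + heads, b + tails    # posterior Beta(α, β)
    μ = α / (α + β)                                # posterior mean
    σ = sqrt(α * β / ((α + β)^2 * (α + β + 1)))   # posterior standard deviation
    return (α = α, β = β, mean = μ, std = σ)
end

rng  = MersenneTwister(42)
ys   = rand(rng, 500) .< 0.75    # 500 simulated flips with bias 0.75
post = exact_coin_posterior(2.0, 7.0, ys)
println(post)
```

If RxInfer performs exact inference on the same data with a = 2.0 and b = 7.0, the mean and standard deviation of its posterior for θ should match these closed-form values.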

Roadmap

Our high-level project roadmap outlines the key milestones and focus areas for the upcoming years:

Q1/Q2 2024
  • 🧩 Nested models with GraphPPL.jl
  • 🔄 Development of ExponentialFamilyProjection.jl

Q3/Q4 2024
  • 🌐 Graph structure visualization
  • 🧠 Automated inference with ExponentialFamilyProjection.jl

2025
  • 🔀 Stochastic Processes
  • 🚀 Robustness & Memory-efficiency

For a more granular view of our progress and ongoing tasks, check out our project board or join our public meetings, held every four weeks.

Contributing

We welcome contributions from the community. If you are interested in contributing to the development of RxInfer.jl, please check out our contributing guide, the contributing guidelines, or look at the issues tagged with the good first issue label to get started.

Where to go next?

A set of examples demonstrating the more advanced features of the package is available in the RxInfer repository. Alternatively, you can head to the documentation, which provides more detailed information on how to use RxInfer to specify more complex probabilistic models.

Ecosystem

The RxInfer framework consists of three core packages developed by ReactiveBayes:

  • ReactiveMP.jl - the underlying message passing-based inference engine
  • GraphPPL.jl - model and constraints specification package
  • Rocket.jl - reactive extensions package for Julia

JuliaCon 2023 presentation

Additionally, check out our video from JuliaCon 2023 for a high-level overview of the package.

Our presentation at the Julia User Group Munich meetup

Also check out the recorded presentation at the Julia User Group Munich meetup for a more detailed overview of the package.

License

MIT License Copyright (c) 2021-2024 BIASlab, 2024-present ReactiveBayes
