BlackJAX: A Comprehensive Introduction
What is BlackJAX?
BlackJAX is a library of samplers for JAX, a high-performance numerical computing library, and it runs on both CPU and GPU. It is not a probabilistic programming library itself. Instead, it works with any probabilistic programming language (PPL) that can provide a log-probability density function written in JAX.
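To make "written in JAX" concrete, here is a minimal sketch that assumes a hypothetical one-parameter model with a standard normal density on `mu`. It shows the expected shape of a log-density: a pure function built from `jax.numpy` operations that takes a position (here a dict, i.e. a pytree) and returns a scalar. That is what lets JAX JIT-compile and differentiate it.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats


def logdensity_fn(position):
    # Hypothetical toy model: standard normal density on "mu".
    return stats.norm.logpdf(position["mu"], 0.0, 1.0)


position = {"mu": jnp.array(0.5)}

# Because the function is written with jax.numpy, JAX can trace it:
# it can be JIT-compiled and differentiated. Gradient-based samplers
# such as NUTS rely on the gradient.
value = jax.jit(logdensity_fn)(position)
grad = jax.grad(logdensity_fn)(position)
print(value, grad["mu"])  # d/dmu of log N(mu | 0, 1) is -mu
```

The gradient comes back with the same dict structure as the position.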
Ideal Users of BlackJAX
BlackJAX is a good fit for:
- Users who have a log-probability density function and need an efficient sampler for it.
- Those who need more than a general-purpose sampler.
- Users who want to run sampling on GPU.
- Researchers who want to build on robust, well-tested components.
- Developers of new probabilistic programming languages.
- Learners who want to understand how sampling algorithms work.
Quickstart Guide
Installation
Installing BlackJAX is straightforward. If you use pip, the command is:
pip install blackjax
Alternatively, you can install it via conda-forge:
conda install -c conda-forge blackjax
Note that BlackJAX is written in pure Python but depends on XLA through JAX. By default, the JAX version installed alongside BlackJAX runs on CPU only. To use a GPU or TPU, follow the JAX installation instructions to install a JAX build with the relevant hardware support.
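After installing, you can check which hardware JAX will use. This short sketch uses only standard JAX calls. The exact device list depends on your installation, and a default pip or conda install typically reports `cpu`.

```python
import jax

# Name of the default backend JAX is using (for example "cpu" or "gpu")
backend = jax.default_backend()
# Devices visible to that backend
devices = jax.devices()
print(backend, devices)
```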
Example Usage
The following example uses BlackJAX's NUTS sampler to sample the `loc` and `scale` parameters of a normal model for simulated data:
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax

# Simulated data: 1,000 draws from a normal with mean 10 and standard deviation 20
observed = np.random.normal(10, 20, size=1_000)

def logdensity_fn(x):
    # Log-likelihood of the data under Normal(loc, scale); x is a dict (pytree)
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)

# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])  # one diagonal entry per parameter
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.PRNGKey(0)
step = jax.jit(nuts.step)
for i in range(100):
    # Derive a distinct key for each iteration from the base key
    nuts_key = jax.random.fold_in(rng_key, i)
    state, _ = step(nuts_key, state)
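The Python `for` loop in the example calls the compiled step once per iteration and keeps only the final state. A common alternative is to run the whole loop inside `jax.lax.scan`, so it compiles once and stores every state. The sketch below assumes the step function has the `(rng_key, state) -> (new_state, info)` signature used by `nuts.step`. `inference_loop` and `toy_step` are hypothetical helpers, and the toy step is a Gaussian random walk that stands in for a real sampler so the block runs without BlackJAX.

```python
import jax
import jax.numpy as jnp


def inference_loop(rng_key, step_fn, initial_state, num_samples):
    """Run `step_fn` for `num_samples` iterations inside jax.lax.scan.

    Assumes step_fn(key, state) -> (new_state, info), like nuts.step.
    """
    def one_step(state, key):
        state, _ = step_fn(key, state)
        # Emit the state so scan stacks all states along a new leading axis.
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


# Hypothetical stand-in for a sampler step: a Gaussian random walk.
def toy_step(key, x):
    return x + 0.1 * jax.random.normal(key), None


states = inference_loop(jax.random.PRNGKey(0), toy_step, jnp.float32(0.0), 100)
print(states.shape)  # (100,)
```

With the NUTS example, `inference_loop(rng_key, nuts.step, state, 100)` should return a stacked state whose `position["loc"]` has shape `(100,)`. This sketch derives per-step keys with `jax.random.split` rather than `fold_in`. Either approach gives each step its own key.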
For more examples and detailed guides, see the BlackJAX documentation.
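In the quickstart example, `scale` is sampled on the whole real line. Nothing stops the sampler from proposing `scale <= 0`, which is not a valid standard deviation. A common remedy is to sample `log_scale` instead and add the log-Jacobian of the transformation. The sketch below is a general technique, not a BlackJAX API. Like the example, it assumes a flat prior on `loc` and `scale`. The `log_scale` key and the `logdensity_unconstrained` name are hypothetical.

```python
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

observed = np.random.normal(10, 20, size=1_000)


def logdensity_unconstrained(x):
    """Log-density over (loc, log_scale) instead of (loc, scale).

    scale = exp(log_scale) is always positive, so the sampler can move
    freely on the real line. The change of variables adds the log-Jacobian
    log|d scale / d log_scale| = log_scale.
    """
    scale = jnp.exp(x["log_scale"])
    loglik = jnp.sum(stats.norm.logpdf(observed, x["loc"], scale))
    return loglik + x["log_scale"]


position = {"loc": 1.0, "log_scale": jnp.log(2.0)}
print(logdensity_unconstrained(position))
```

You would pass this function to `blackjax.nuts` in place of `logdensity_fn` and recover `scale` samples with `jnp.exp` on the `log_scale` draws.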
Philosophy Behind BlackJAX
Concept
BlackJAX aims to bridge the gap between "one-liner" inference frameworks and modular, customizable libraries. Users who just want results get efficient, well-tested samplers with minimal code. Users who want more can access the inner workings of the inference algorithms and reuse them to experiment with and develop new sampling schemes.
Purpose
Sampling algorithms are usually tightly integrated into PPLs, which makes them hard to use on their own. BlackJAX decouples the samplers from any particular PPL and exposes their algorithmic components for reuse in custom algorithms. This modularity is meant to accelerate research on sampling algorithms.
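One example of a reusable algorithmic part is the leapfrog integrator, the numerical core shared by HMC-type samplers such as NUTS. The sketch below is a hypothetical stand-alone function, not BlackJAX's integrator API. It assumes an identity mass matrix and flat-array positions. A dict position like the one in the quickstart would need `jax.tree_util` handling.

```python
import jax
import jax.numpy as jnp


def leapfrog(logdensity_fn, position, momentum, step_size, num_steps):
    """Leapfrog (velocity Verlet) integration of Hamiltonian dynamics.

    Assumes an identity mass matrix and array-valued positions. The force
    is the gradient of the log-density (minus the potential energy).
    """
    grad_fn = jax.grad(logdensity_fn)

    def one_step(carry, _):
        q, p = carry
        p = p + 0.5 * step_size * grad_fn(q)  # half step on momentum
        q = q + step_size * p                 # full step on position
        p = p + 0.5 * step_size * grad_fn(q)  # half step on momentum
        return (q, p), None

    (q, p), _ = jax.lax.scan(one_step, (position, momentum), None, length=num_steps)
    return q, p


def logdensity(q):
    # Standard Gaussian target, used only to exercise the integrator.
    return -0.5 * jnp.sum(q ** 2)


def energy(q, p):
    # Hamiltonian: potential (-log-density) plus kinetic energy
    return -logdensity(q) + 0.5 * jnp.sum(p ** 2)


q0 = jnp.array([1.0, 0.0])
p0 = jnp.array([0.0, 1.0])
q1, p1 = leapfrog(logdensity, q0, p0, step_size=0.1, num_steps=10)
print(energy(q0, p0), energy(q1, p1))  # nearly equal: leapfrog conserves energy approximately
```

The integrator knows nothing about acceptance or step-size adaptation, so any HMC-style sampler can reuse it.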
Operational Framework
BlackJAX is built around a simple pattern: any function that takes a state and returns a new state is a "kernel". Kernels are stateless functions with a uniform API, which makes them easy to swap and compose when designing algorithms. In the quickstart example, `nuts.step` is such a kernel: it takes a PRNG key and a state and returns a new state along with additional information.
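Here is a pure-JAX sketch of the kernel pattern: an illustrative random-walk Metropolis kernel, not BlackJAX's implementation. `RWState`, `build_rwm_kernel`, and `compose` are hypothetical names. Every kernel shares the `(rng_key, state) -> (new_state, info)` signature, so `compose` can chain any two of them into a new kernel.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    """Hypothetical state: current position and its log-density."""
    position: jax.Array
    logdensity: jax.Array


def build_rwm_kernel(logdensity_fn, scale):
    """Return a stateless random-walk Metropolis kernel (illustrative only)."""
    def kernel(rng_key, state):
        key_prop, key_accept = jax.random.split(rng_key)
        # Symmetric Gaussian proposal around the current position
        proposal = state.position + scale * jax.random.normal(
            key_prop, state.position.shape)
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis acceptance: accept with probability min(1, ratio)
        log_ratio = proposal_logdensity - state.logdensity
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_state = RWState(
            jnp.where(accept, proposal, state.position),
            jnp.where(accept, proposal_logdensity, state.logdensity),
        )
        return new_state, accept

    return kernel


def compose(kernel_a, kernel_b):
    """Chain two kernels with the same signature into a new kernel."""
    def kernel(rng_key, state):
        key_a, key_b = jax.random.split(rng_key)
        state, info_a = kernel_a(key_a, state)
        state, info_b = kernel_b(key_b, state)
        return state, (info_a, info_b)

    return kernel


def logdensity(x):
    return -0.5 * jnp.sum(x ** 2)


x0 = jnp.zeros(2)
state = RWState(x0, logdensity(x0))
# A small-step and a large-step kernel combined into one kernel
kernel = jax.jit(compose(build_rwm_kernel(logdensity, 0.1),
                         build_rwm_kernel(logdensity, 1.0)))
state, info = kernel(jax.random.PRNGKey(0), state)
```

All randomness comes from the key argument, so the same key and state always produce the same result. This is what keeps kernels stateless and safe to combine.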
Contributions and Citations
Contributors should follow the guidelines in the contribution guide. To cite BlackJAX in academic work, use the reference provided in the documentation.
Acknowledgements
BlackJAX's NUTS implementation draws significantly on NumPyro, a probabilistic programming library that, like BlackJAX, is built on JAX.