Diffrax: Numerical Differential Equation Solvers in JAX
Introduction
Diffrax is a JAX-based library that provides numerical solvers for differential equations, including ordinary differential equations (ODEs), stochastic differential equations (SDEs), and controlled differential equations (CDEs). Rather than treating each kind of equation separately, Diffrax solves them all in a uniform way, which keeps the library small and streamlined. The library is also fully autodifferentiable and can run on GPUs.
Key Features
- Solver Diversity: Diffrax offers a wide range of solvers, including Tsit5, Dopri8, symplectic solvers, and implicit solvers, so users can choose the solver best suited to their problem.
- Vmap Everything: Every operation can be vmapped, including the region of integration, so that, for example, a batch of initial conditions or integration end times can be solved in a single vectorized call.
- PyTree Utilization: The state can be any PyTree (for example a dict or tuple of arrays) rather than a single flat array, which makes structured state straightforward to work with.
- Dense Solutions & Adjoint Methods: Diffrax can return dense solutions, which can be evaluated at any time within the integration interval, and provides several adjoint methods for backpropagating through the solver. Efficient and accurate gradients are critical for neural differential equations.
- Neural Differential Equations: Because the solvers are differentiable JAX code, the vector field can itself be a neural network, which makes the library well suited to research and applications involving neural differential equations.
Technical Excellence
Internally, Diffrax solves all kinds of differential equations (ODEs, SDEs, and CDEs) in a unified way rather than treating each kind separately. This yields a small, tightly written library that nonetheless covers a wide range of problems.
Installation
To use Diffrax, users need Python 3.9+, JAX 0.4.13+, and Equinox 0.10.11+. Diffrax can then be installed via pip:
pip install diffrax
Documentation and Resources
Comprehensive documentation is available at https://docs.kidger.site/diffrax, where users can find detailed guides and examples to maximize their use of the library.
Quick Example
The following example uses Diffrax to solve the ordinary differential equation dy/dt = -y with initial condition y(0) = (2, 3):
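from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

# Vector field for dy/dt = -y; `args` carries optional extra parameters.
def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
# Integrate from t0=0 to t1=1 with a fixed step of 0.1 (constant steps are the default).
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
# By default only the final state is saved: solution.ys has shape (1, 2).
print(solution.ys)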
In this example, Dopri5 specifies the Dormand–Prince 5(4) solver, an explicit Runge–Kutta method that is a standard choice for its efficiency and accuracy on a wide range of non-stiff problems.
Citations
If you use Diffrax in academic work, please consider citing it. A suitable BibTeX entry is:
@phdthesis{kidger2021on,
title={{O}n {N}eural {D}ifferential {E}quations},
author={Patrick Kidger},
year={2021},
school={University of Oxford},
}
Related JAX Libraries
Diffrax is part of a larger ecosystem of JAX-related libraries that enhance various aspects of scientific computing and deep learning:
- Equinox: A library for neural networks and more, working seamlessly with JAX.
- Optax: Offers optimizers like SGD and Adam.
- Optimistix: Provides tools for root finding and minimization tasks.
- Other useful libraries include BlackJAX, jaxtyping, and sympy2jax.
Explore more projects under the Awesome JAX repository for a comprehensive look at the capabilities within the JAX ecosystem.