jaxdf - JAX-based Discretization Framework
Overview
The jaxdf project is a framework built on JAX for creating differentiable numerical simulators. These simulators can use any kind of discretization, making them versatile tools for investigating physical systems or for solving partial differential equations numerically. The core feature of jaxdf is that discretized operators are exposed as pure functions that integrate smoothly into any differentiable program written in JAX. This is particularly useful for embedding such models into neural networks as layers, or for defining physics-based loss functions in machine learning models.
Example
The snippet below constructs the non-linear operator (∇² + sin), applies it with a Fourier spectral discretization on a two-dimensional square domain, uses it to define a loss function, and then computes the gradient of that function with JAX's automatic differentiation:
from jaxdf import operators as jops
from jaxdf import FourierSeries, operator
from jaxdf.geometry import Domain
from jax import numpy as jnp
from jax import jit, grad
# Define the operator
@operator
def custom_op(u, *, params=None):
    grad_u = jops.gradient(u)
    diag_jacobian = jops.diag_jacobian(grad_u)
    laplacian = jops.sum_over_dims(diag_jacobian)
    sin_u = jops.compose(u)(jnp.sin)
    return laplacian + sin_u
# Set up the discretization
domain = Domain((128, 128), (1., 1.))
parameters = jnp.ones((128, 128, 1))
u = FourierSeries(parameters, domain)
# Define a differentiable loss function
@jit
def loss(u):
    v = custom_op(u)
    return jnp.mean(jnp.abs(v.on_grid)**2)

gradient = grad(loss)(u)  # compute the gradient of the loss
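To make concrete what jaxdf abstracts away, here is a rough pure-JAX sketch of the same loss with the Laplacian hand-rolled as a second-order finite-difference stencil on a periodic grid. This is a comparison sketch only, not jaxdf code (the Fourier spectral discretization above computes derivatives differently):

```python
from jax import grad, jit
from jax import numpy as jnp

def fd_laplacian(u, dx=1.0):
    # Second-order central finite differences with periodic boundaries,
    # accumulated over both axes of the 2D grid.
    lap = jnp.zeros_like(u)
    for axis in (0, 1):
        lap = lap + (jnp.roll(u, 1, axis) - 2 * u + jnp.roll(u, -1, axis)) / dx**2
    return lap

@jit
def fd_loss(u):
    # Same operator as above: Laplacian plus pointwise sine.
    v = fd_laplacian(u) + jnp.sin(u)
    return jnp.mean(jnp.abs(v) ** 2)

u0 = jnp.ones((128, 128))  # constant field: its Laplacian is zero
g = grad(fd_loss)(u0)      # gradient has the same shape as u0
```

Note how the stencil (and thus the discretization) is baked into this version, whereas jaxdf lets the same operator definition be reused with different discretizations.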
Installation
Before installing jaxdf, make sure JAX is set up on your system, especially if you intend to use NVIDIA GPU support. jaxdf itself can then be installed from PyPI using pip:
pip install jaxdf
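If you plan to run on an NVIDIA GPU, install a CUDA-enabled JAX build first. At the time of writing, the JAX documentation suggests a command along these lines; check the official JAX installation guide for the variant matching your CUDA setup:

```shell
# Install JAX with bundled CUDA support (the exact extra, e.g. cuda12,
# depends on your CUDA version; see the JAX installation guide).
pip install --upgrade "jax[cuda12]"

# Then install jaxdf itself
pip install jaxdf
```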
For development, the recommended approach is to clone the repository (or download and extract an archive). From the repository's root directory, run:
pip install --upgrade poetry
poetry install
This installs the package and all of its dependencies in editable mode.
Support
If you encounter problems with the code or have suggestions for new features, please open an issue on the project site. The jaxdf team is also active on Discord for discussions or casual interactions.
Contributing
Contributions to jaxdf are very welcome. The process generally starts with opening an issue for discussion or a feature request. Contributors should make sure their additions are well tested and are documented in the changelog. Detailed instructions can be found in the project's contributing guide.
Citation and Acknowledgements
jaxdf was initially presented at the Differentiable Programming workshop at NeurIPS 2021. The framework uses a template from @rochacbruno for its packaging setup, and its multiple-dispatch mechanism is based on the plum project. Related projects such as odl, deepXDE, and SciML can also serve as valuable resources for those exploring similar domains.
The project appreciates the community's contributions and support, and encourages further participation and exploration within the scientific computing landscape.