Introduction to the dm-haiku Project
Overview
Haiku is a neural network library designed for use with JAX, developed by some of the authors behind Sonnet—a library for TensorFlow. It is tailored to make the process of building neural networks simpler by allowing users to take advantage of JAX’s powerful features, including automatic differentiation and first-class support for GPU/TPU.
Haiku offers object-oriented programming tools familiar to many developers while still granting full access to JAX’s pure function capabilities. The library provides a module abstraction (hk.Module) and a simple function transformation (hk.transform) to facilitate efficient neural network construction and operation.
Why Haiku?
- Tested at Scale: Haiku has been rigorously tested by researchers at DeepMind in large-scale experiments spanning image and language processing, generative models, and reinforcement learning, demonstrating its robustness and ease of use at scale.
- Library, Not a Framework: Haiku focuses on simplifying specific tasks such as managing model parameters and state. Unlike a framework, it does not define custom optimizers or checkpointing, making it easy to combine with other JAX components.
- Built on Familiar Models: Haiku builds on the concepts of Sonnet, a widely adopted library at DeepMind. It retains Sonnet’s module-based programming style while integrating seamlessly with JAX’s function transformations.
- Ease of Transition: For users familiar with TensorFlow and Sonnet, moving to JAX and Haiku is straightforward, as Haiku retains many of Sonnet’s API signatures and structure.
- Simplified JAX Operations: Haiku offers straightforward solutions for handling random numbers and module initialization, which are vital for leveraging JAX transformations effectively.
Quickstart
To illustrate Haiku’s ease of use, consider a simple example: a multi-layer perceptron and a basic training loop (dummy_images, dummy_labels, and input_dataset are placeholders for your data pipeline):
import haiku as hk
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

# hk.transform yields a pair of pure functions: init and apply.
loss_fn_t = hk.without_apply_rng(hk.transform(loss_fn))

rng = jax.random.PRNGKey(42)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)

def update_rule(param, update):
  return param - 0.01 * update  # plain SGD step

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree.map(update_rule, params, grads)
Haiku’s hk.transform function transforms impure functions into pure ones, as required for JAX operations like jax.jit, jax.grad, etc. The library simplifies parameter management within neural networks and separates the initialization (init) and application (apply) stages.
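Because the transformed functions are pure, they compose directly with other JAX transformations. A minimal sketch, reusing loss_fn_t and params from the Quickstart above (images and labels stand in for a batch of data):

fast_loss = jax.jit(loss_fn_t.apply)                  # compile the pure apply function
loss = fast_loss(params, images, labels)              # evaluate the loss
grads = jax.grad(fast_loss)(params, images, labels)   # differentiate it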
Installation
To use Haiku, follow these steps:
- Install JAX: Follow the JAX installation instructions to align with your CUDA version.
- Install Haiku: Use pip to install Haiku:
$ pip install git+https://github.com/deepmind/dm-haiku
For a PyPI installation:
$ pip install -U dm-haiku
Some examples in Haiku require additional libraries, which can be installed via:
$ pip install -r examples/requirements.txt
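As an optional sanity check (not part of the official instructions), the installed package can be imported and its version printed:
$ python -c "import haiku as hk; print(hk.__version__)"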
User Manual
Writing Custom Modules
In Haiku, modules are subclasses of hk.Module and typically implement __init__ and __call__ methods to define layers, such as a custom linear layer:
class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(stddev=1.0 / j ** 0.5)
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b
Modules manage parameters through Haiku’s API (hk.get_parameter above), which lets Haiku collect them during init and inject them during apply, keeping the transformed functions pure.
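As a rough sketch of how such a module is used (the name "my_linear" and the input shape are illustrative), it is instantiated inside a function passed to hk.transform, and its parameters appear in the returned mapping:

def forward(x):
  return MyLinear(output_size=10, name="my_linear")(x)

forward_t = hk.without_apply_rng(hk.transform(forward))

x = jnp.ones([8, 28 * 28])
params = forward_t.init(jax.random.PRNGKey(0), x)  # {'my_linear': {'w': ..., 'b': ...}}
out = forward_t.apply(params, x)                   # shape [8, 10]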
Stochastic Models and State Management
Haiku handles randomness and mutable state seamlessly. hk.next_rng_key() supplies fresh random keys for stochastic sampling, while hk.get_state() and hk.set_state() track mutable state, which is crucial for models like VAEs or layers such as batch normalization; stateful functions are transformed with hk.transform_with_state so they remain pure and transformable.
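A minimal sketch of both mechanisms together (noisy_counter is a hypothetical example, not part of Haiku):

def noisy_counter(x):
  noise = jax.random.normal(hk.next_rng_key(), x.shape)      # per-call randomness
  count = hk.get_state("count", shape=[], dtype=jnp.int32, init=jnp.zeros)
  hk.set_state("count", count + 1)                           # mutable state update
  return x + noise

f = hk.transform_with_state(noisy_counter)

rng = jax.random.PRNGKey(0)
params, state = f.init(rng, jnp.zeros([3]))
out, state = f.apply(params, state, rng, jnp.zeros([3]))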
Distributed Training
Haiku interfaces smoothly with jax.pmap for data-parallel training across multiple devices, making it well suited to training large networks in distributed environments.
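A rough sketch of the usual pattern (the axis name and batch layout are illustrative assumptions, reusing loss_fn_t from the Quickstart): gradients are computed under jax.pmap and averaged across devices with jax.lax.pmean, with the parameters replicated onto each device:

def parallel_grads(params, images, labels):
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  return jax.lax.pmean(grads, axis_name="devices")  # average across replicas

p_grads = jax.pmap(parallel_grads, axis_name="devices")

# Replicate params and feed batches shaped [num_devices, per_device_batch, ...].
replicated_params = jax.device_put_replicated(params, jax.local_devices())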
Conclusion
Haiku provides a straightforward, effective way to build neural networks in JAX, leveraging its robust research foundations and compatibility with existing libraries. Despite being in maintenance mode, it continues to offer a practical toolset for developers accustomed to its Sonnet-based model, now within the flexible JAX environment.