Overview of jax-triton
The jax-triton project integrates JAX and Triton, allowing users to call high-performance GPU kernels written in Triton directly from JAX code. Note that jax-triton is not an officially supported Google product.
Quickstart Guide
At the heart of jax-triton is the jax_triton.triton_call function, which applies Triton kernels to JAX arrays and can be used inside JAX's jit-compiled functions. To illustrate, consider a simple kernel from Triton's vector addition tutorial:
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    length,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  # Each program instance processes one block_size-wide slice of the inputs.
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  # The mask keeps the final block from reading or writing past `length`.
  mask = offsets < length
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
This Triton kernel can be applied to JAX arrays using jax_triton.triton_call, as shown in the following example:
import jax
import jax.numpy as jnp
import jax_triton as jt

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  # Describe the output buffer that jax-triton should allocate for the kernel.
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  # Positional arguments (x, y, and the scalar length) are passed to the kernel
  # in order, followed by the output buffer; keyword arguments such as
  # block_size are forwarded to the kernel as meta-parameters.
  return jt.triton_call(
      x,
      y,
      x.size,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
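Both print calls produce the same elementwise sum; the second simply runs the kernel from inside a jit-compiled function. Note that the grid above, x.size // block_size, assumes the input length is a multiple of the block size. As a minimal sketch (not part of the original example, and assuming the imports and add_kernel above are in scope), arbitrary lengths can be handled by rounding the grid up and relying on the mask inside the kernel:

def add_any_length(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  # Round up so a final, partially full block is still launched; the mask in
  # add_kernel keeps the extra lanes from reading or writing out of bounds.
  grid = ((x.size + block_size - 1) // block_size,)
  return jt.triton_call(
      x,
      y,
      x.size,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=grid,
      block_size=block_size)

print(add_any_length(jnp.arange(10), jnp.arange(10, 20)))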
For more practical applications, one can explore the examples directory in the jax-triton repository, which includes scripts like fused_attention.py and detailed notebooks such as the JAX + Triton Flash Attention example.
Installation
Installing jax-triton is straightforward with Python’s package manager:
$ pip install jax-triton
Users can install either a stable version of triton or the more experimental nightly release. Additionally, ensure that JAX is installed with CUDA support to take full advantage of hardware acceleration:
$ pip install "jax[cuda12]"
Development Setup
For developers interested in contributing to jax-triton or experimenting with its source code, the process begins with cloning the project’s repository:
$ git clone https://github.com/jax-ml/jax-triton.git
After cloning, an editable installation can be achieved with the commands:
$ cd jax-triton
$ pip install -e .
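As a quick, optional check (not part of the original instructions), verify that the editable install is importable:
$ python -c "import jax_triton"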
Before running tests, install pytest with:
$ pip install pytest
Run the tests to ensure that the installation and any modifications work as expected:
$ pytest tests/
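While iterating on a change, pytest's -k flag can restrict the run to tests whose names match a pattern; the pattern below is only illustrative:
$ pytest -k triton_call tests/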
With this foundation, developers and users alike can explore the capabilities of jax-triton and its potential to accelerate machine learning workflows.