Introduction to Accelerated Scan
The Accelerated Scan project implements a fast first-order parallel associative scan for GPUs. It targets workloads such as state space models and linear recurrent neural networks (RNNs).
What is an Associative Scan?
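An associative scan (also called a parallel prefix scan) computes every running combination of a sequence under an associative operator. Accelerated Scan uses it to solve first-order recurrences of the form x[t] = gate[t] * x[t-1] + token[t]. Each step is an affine map x -> gate[t] * x + token[t], and composing such maps is associative, so all prefixes can be computed in parallel instead of one step at a time. Recurrences like this are common in signal processing and in machine learning models where each state depends on the previous one.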
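To make the associativity argument concrete, here is a CPU-only sketch in plain Python. It represents each step as a pair (a, b) meaning x -> a * x + b, composes pairs with an illustrative `combine` function, and runs a Hillis-Steele style inclusive scan: each of its log2(n) rounds consists of independent updates, the kind of work a GPU spreads across threads. This is not necessarily the algorithm the package uses. It assumes the state before the first step is zero (x[-1] = 0), the function names are hypothetical, and nothing here calls accelerated_scan.

```python
import random

def sequential_scan(gates, tokens):
    """Evaluate x[t] = gate[t] * x[t-1] + token[t] one step at a time, with x[-1] = 0."""
    x, out = 0.0, []
    for g, b in zip(gates, tokens):
        x = g * x + b
        out.append(x)
    return out

def combine(left, right):
    """Compose two affine steps, applying `left` first and then `right`.

    (a, b) stands for x -> a * x + b, so
    right(left(x)) = a_r * (a_l * x + b_l) + b_r = (a_r * a_l) * x + (a_r * b_l + b_r).
    """
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)

def log_step_scan(gates, tokens):
    """Inclusive Hillis-Steele scan: each round combines element i with element i - offset."""
    elems = list(zip(gates, tokens))
    offset = 1
    while offset < len(elems):
        # Every update in this round reads only the previous round's values,
        # so all of them could run in parallel.
        elems = [
            combine(elems[i - offset], elems[i]) if i >= offset else elems[i]
            for i in range(len(elems))
        ]
        offset *= 2
    # Element t now composes steps 0..t; applied to x[-1] = 0 it evaluates to its offset b.
    return [b for _, b in elems]

rng = random.Random(0)
n = 64
gates = [rng.random() for _ in range(n)]
tokens = [rng.random() for _ in range(n)]
diff = max(abs(r - p) for r, p in zip(sequential_scan(gates, tokens), log_step_scan(gates, tokens)))
print(f"max difference: {diff:.1e}")  # only floating-point reassociation noise
```

Because `combine` is associative, regrouping the steps changes the result only by rounding, which is what lets a parallel scan replace the loop.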
Key Features
- Speed: The accelerated scan is designed to be one of the fastest GPU implementations of this operation.
- Architecture: It uses a chunked processing algorithm that relies on the fastest communication primitive available at each level of the GPU hierarchy: warp shuffles exchange data among the 32 threads of a warp, and shared memory exchanges data between warps within a thread block. Each sequence (one per channel) is processed within a single thread block.
- Flexibility: Sequences are processed in batches, and sequence lengths must be powers of two between 32 and 65536.
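The two-level idea behind a chunked scan can be sketched on the CPU: scan each chunk as if its incoming state were zero, then correct every element with the true state carried in from the previous chunk. This is a plain-Python sketch, not the CUDA kernel. `CHUNK = 32` mirrors the warp size only by analogy, the helper names `check_seqlen`, `local_scan`, and `chunked_scan` are hypothetical, and the length check simply encodes the constraint stated above (a power of two in [32, 65536]).

```python
import random

MIN_LEN, MAX_LEN = 32, 65536
CHUNK = 32  # illustrative chunk size, chosen by analogy with a 32-thread warp

def check_seqlen(n):
    """Enforce the documented constraint: a power of two in [32, 65536]."""
    if not (MIN_LEN <= n <= MAX_LEN) or (n & (n - 1)) != 0:
        raise ValueError(f"sequence length {n} must be a power of two in [{MIN_LEN}, {MAX_LEN}]")

def sequential_scan(gates, tokens):
    """Plain loop over x[t] = gate[t] * x[t-1] + token[t], starting from x[-1] = 0."""
    x, out = 0.0, []
    for g, b in zip(gates, tokens):
        x = g * x + b
        out.append(x)
    return out

def local_scan(gates, tokens):
    """Scan one chunk as if the state entering it were 0.

    Returns (prods, states): prods[t] = gates[0] * ... * gates[t] within the chunk,
    and states[t] is the chunk-local state.
    """
    prods, states, p, x = [], [], 1.0, 0.0
    for g, b in zip(gates, tokens):
        p *= g
        x = g * x + b
        prods.append(p)
        states.append(x)
    return prods, states

def chunked_scan(gates, tokens, chunk=CHUNK):
    n = len(gates)
    check_seqlen(n)
    # Step 1: scan every chunk independently (these could all run in parallel).
    partials = [local_scan(gates[s:s + chunk], tokens[s:s + chunk]) for s in range(0, n, chunk)]
    out, carry = [], 0.0
    for prods, states in partials:
        # Step 2: by linearity, the true state is prods[t] * carry + states[t],
        # where carry is the true state at the end of the previous chunk.
        fixed = [p * carry + x for p, x in zip(prods, states)]
        out.extend(fixed)
        # The carries follow their own first-order recurrence, so a GPU could scan
        # them too; this sketch simply walks them in order.
        carry = fixed[-1]
    return out

rng = random.Random(1)
n = 256
gates = [rng.random() for _ in range(n)]
tokens = [rng.random() for _ in range(n)]
diff = max(abs(a - b) for a, b in zip(chunked_scan(gates, tokens), sequential_scan(gates, tokens)))
print(f"max difference vs. sequential loop: {diff:.1e}")
```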
Different Implementation Approaches
- accelerated_scan.warp: This version uses a CUDA kernel written in C++. It is optimized for speed and is designed to outperform even standard GPU libraries like CUB.
- accelerated_scan.triton: Uses Triton's tl.associative_scan primitive. It requires Triton 2.2 for its enable_fp_fusion flag.
- accelerated_scan.ref: A reference implementation in PyTorch. It is slower, and serves as the baseline for checking that the faster implementations are numerically equivalent.
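A reference implementation is mainly useful as a yardstick for the fast ones. The sketch below mirrors the (batch_size, dim, seqlen) layout from the usage example with nested Python lists and scans along the last axis, assuming each sequence starts from x[-1] = 0. It is a stand-in written for illustration, not the code in accelerated_scan.ref; `reference_scan`, `shape`, and `allclose` are hypothetical names, and `allclose` uses the same tolerance rule and defaults as `torch.allclose` (rtol=1e-5, atol=1e-8).

```python
import random

def reference_scan(gates, tokens):
    """Scan the last axis of (batch, dim, seqlen) nested lists; each sequence starts from x[-1] = 0."""
    out = []
    for g_batch, b_batch in zip(gates, tokens):
        rows = []
        for g_row, b_row in zip(g_batch, b_batch):
            x, row = 0.0, []
            for g, b in zip(g_row, b_row):
                x = g * x + b
                row.append(x)
            rows.append(row)
        out.append(rows)
    return out

def shape(t):
    """Shape of a (batch, dim, seqlen) nested list."""
    return (len(t), len(t[0]), len(t[0][0]))

def allclose(actual, expected, rtol=1e-5, atol=1e-8):
    """True if shapes match and |actual - expected| <= atol + rtol * |expected| everywhere."""
    if shape(actual) != shape(expected):
        return False
    return all(
        abs(a - e) <= atol + rtol * abs(e)
        for a_batch, e_batch in zip(actual, expected)
        for a_row, e_row in zip(a_batch, e_batch)
        for a, e in zip(a_row, e_row)
    )

rng = random.Random(0)
batch_size, dim, seqlen = 2, 3, 32
gates = [[[rng.random() for _ in range(seqlen)] for _ in range(dim)] for _ in range(batch_size)]
tokens = [[[rng.random() for _ in range(seqlen)] for _ in range(dim)] for _ in range(batch_size)]

expected = reference_scan(gates, tokens)
# A hypothetical "fast" result with a small injected error, to show the check failing.
candidate = [[list(row) for row in rows] for rows in expected]
candidate[1][2][-1] += 1e-3
print(allclose(expected, expected), allclose(candidate, expected))  # True False
```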
How to Get Started
Begin by installing the package via pip:
```shell
pip install accelerated-scan
```
Here's a basic example of how to use the package:
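```python
import torch
from accelerated_scan.warp import scan

# Inputs have shape (batch_size, dim, seqlen); seqlen must be a power of two in [32, 65536].
batch_size, dim, seqlen = 3, 1536, 4096
# Gates in [0.999, 1.0) make the state decay slowly along the sequence.
gates = 0.999 + 0.001 * torch.rand(batch_size, dim, seqlen, device="cuda")
tokens = torch.rand(batch_size, dim, seqlen, device="cuda")
out = scan(gates, tokens)
```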
This snippet sets up the problem dimensions and runs the scan with the accelerated_scan.warp module, independently for each of the batch_size * dim sequences of length seqlen.
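Gates close to 1 mean long memory. For a constant gate g and constant token c, starting from x[-1] = 0, the recurrence sums a geometric series: x[t] = c * (1 - g**(t + 1)) / (1 - g). An input's influence after k steps is scaled by g**k. The plain-Python check below uses that constant-input setting purely as a simplification (the example above draws gates and tokens at random), and `closed_form` and `iterate` are hypothetical helper names.

```python
import math

def closed_form(g, c, t):
    """x[t] for a constant gate g (g != 1) and constant token c, starting from x[-1] = 0."""
    return c * (1.0 - g ** (t + 1)) / (1.0 - g)

def iterate(g, c, steps):
    """Run x = g * x + c for `steps` steps; the result is x[steps - 1]."""
    x = 0.0
    for _ in range(steps):
        x = g * x + c
    return x

g, c, seqlen = 0.999, 1.0, 4096
# Number of steps until an input's weight g**k falls to one half.
half_life = math.log(0.5) / math.log(g)
print(f"half-life: {half_life:.1f} steps")  # half-life: 692.8 steps
print(iterate(g, c, seqlen), closed_form(g, c, seqlen - 1))
```

With g = 0.999, a token still carries half its weight roughly 693 steps later, so information spans a large part of a 4096-step sequence.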
Performance Benchmarking
In the project's benchmarks, accelerated_scan.warp is faster than the other implementations, especially at longer sequence lengths. At a sequence length of 65536 it reports a time of 11.3, lower than both the reference and the Triton implementations.
Notes on Precision
There is one precision caveat: when both gates and tokens are sampled uniformly from [0, 1], the limited precision of bfloat16 dominates the error relative to the reference implementation.
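The effect of bfloat16's short mantissa can be emulated in plain Python. bfloat16 keeps the top 16 bits of a float32 (sign, 8 exponent bits, 7 stored mantissa bits), so the sketch rounds the float32 bit pattern to nearest-even at that boundary. It then runs the recurrence with gates and tokens drawn uniformly from [0, 1), rounding the inputs and every intermediate to bfloat16 or to float32, and compares both against a float64 baseline. Rounding every intermediate is a simplifying assumption and says nothing about how accelerated_scan's kernels actually accumulate; `to_bf16`, `to_fp32`, and `scan_rounded` are hypothetical names.

```python
import random
import struct

def to_fp32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round to bfloat16: keep the top 16 bits of the float32 pattern, rounding to nearest even.

    The value is first rounded to float32, so this is a double rounding; close enough for a demo.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def scan_rounded(gates, tokens, rnd):
    """x[t] = gate[t] * x[t-1] + token[t] with inputs and every intermediate passed through rnd."""
    x, out = 0.0, []
    for g, b in zip(gates, tokens):
        x = rnd(rnd(rnd(g) * x) + rnd(b))
        out.append(x)
    return out

rng = random.Random(0)
n = 4096
gates = [rng.random() for _ in range(n)]   # uniform in [0, 1)
tokens = [rng.random() for _ in range(n)]  # uniform in [0, 1)

exact = scan_rounded(gates, tokens, lambda v: v)  # float64 baseline
errors = {}
for name, rnd in [("float32", to_fp32), ("bfloat16", to_bf16)]:
    approx = scan_rounded(gates, tokens, rnd)
    errors[name] = max(abs(a - e) for a, e in zip(approx, exact))
    print(f"{name}: max abs error {errors[name]:.1e}")
```

The bfloat16 error comes out orders of magnitude larger than the float32 error, because near 1.0 adjacent bfloat16 values are 2**-7 apart.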
In summary, Accelerated Scan offers a fast way to evaluate first-order recurrences on GPUs, which benefits machine learning models, such as state space models and linear RNNs, that run these recurrences over long sequences.