attorch Project Introduction
attorch is a subset of PyTorch's neural network module (torch.nn), written purely in Python using OpenAI's Triton. It aims to be a hackable, self-contained, and easily readable collection of neural network modules that maintains, or even surpasses, the efficiency of the original PyTorch implementation. attorch is designed to be simple and intuitive. It serves as an entry point for developers who want to create custom deep learning operations, find pure PyTorch implementations too slow, and lack the expertise to write CUDA kernels. attorch aims to bridge that gap.
Existing frameworks like kernl, xFormers, Unsloth, and fla mostly focus on Transformers and NLP applications. attorch, however, seeks to be more inclusive by offering a variety of layers for areas beyond NLP, such as computer vision. Unlike many inference-only packages, attorch supports both forward and backward passes, making it suitable for both training and inference, though it may not match the performance of dedicated inference engines.
Installation
To use attorch, ensure you have the dependencies torch==2.4.0 and triton==3.0.0 installed, then clone the repository to get started.
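Version mismatches are a common source of confusing Triton errors, so it can help to confirm the pinned versions before running anything. The helper below is a minimal sketch and not part of attorch; PINNED, matches_pin, and check_pins are hypothetical names. It assumes PyTorch wheels may carry a local suffix such as "+cu121", which it ignores when comparing.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Versions pinned in the installation instructions above.
PINNED = {"torch": "2.4.0", "triton": "3.0.0"}


def matches_pin(installed: str, pinned: str) -> bool:
    # PyTorch wheels often carry a local suffix such as "+cu121";
    # compare only the public version before the "+".
    return installed.split("+", 1)[0] == pinned


def check_pins(pins: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Return {package: installed version, or None if missing} for every unmet pin."""
    problems: Dict[str, Optional[str]] = {}
    for package, pinned in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            problems[package] = None  # not installed at all
            continue
        if not matches_pin(installed, pinned):
            problems[package] = installed
    return problems


if __name__ == "__main__":
    for package, installed in check_pins(PINNED).items():
        print(f"{package}: expected {PINNED[package]}, found {installed or 'nothing'}")
```

An empty dictionary from check_pins(PINNED) means both dependencies match the pinned versions.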
Layers
attorch currently implements a variety of layers with automatic mixed precision (AMP) support, including:
- Convolution layers (Conv1d, Conv2d).
- Activation functions (Hardsigmoid, Hardswish, LeakyReLU, GELU, ReLU, SiLU, etc.).
- Attention mechanisms (MultiheadAttention).
- Normalization techniques (BatchNorm1d, BatchNorm2d, LayerNorm).
- Loss functions (L1Loss, MSELoss, CrossEntropyLoss, NLLLoss).
These layers generally behave like their PyTorch equivalents, offering an easy transition for users familiar with PyTorch.
Math Functions
In Triton, a kernel typically performs two tasks: loading and storing the relevant tensors, and performing math operations on them. attorch's math functions cover the second task and are meant to ease the implementation of custom kernels and operation fusion. These functions only implement forward passes. Because they are pure and perform no loads or stores, their gradients can be derived automatically with the triton-autodiff library. This makes it possible to refactor parts of attorch's kernels in terms of these math transformations.
PyTorch Fallback
attorch integrates with PyTorch through the attorch.nn interface: if attorch does not implement a requested layer, attorch.nn falls back to the PyTorch equivalent. For example:
from attorch import nn
# Uses attorch's linear layer
lin = nn.Linear(10, 20)
# Uses PyTorch's global pooling since it is not available in attorch
gap = nn.AdaptiveAvgPool2d(1)
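How attorch wires up this fallback is not described here. The following is one possible way to build such a namespace in plain Python, using a module-level __getattr__ (PEP 562) that is consulted only when a name is not defined locally. The names with_fallback, reference_nn, RefLinear, RefAdaptiveAvgPool2d, and FastLinear are hypothetical. Stand-in modules replace torch.nn so the sketch runs without PyTorch.

```python
import types
from typing import Any, Mapping


def with_fallback(name: str, own: Mapping[str, Any],
                  fallback: types.ModuleType) -> types.ModuleType:
    """Build a module that serves names from `own` first and `fallback` otherwise."""
    module = types.ModuleType(name)
    # Names implemented locally are stored directly on the module.
    module.__dict__.update(own)

    def __getattr__(attr: str) -> Any:
        # Called only when normal lookup fails (PEP 562), i.e. the name is
        # not implemented locally; defer to the fallback module.
        return getattr(fallback, attr)

    module.__getattr__ = __getattr__
    return module


# Hypothetical stand-ins so the sketch runs without torch or attorch.
class RefLinear: ...
class RefAdaptiveAvgPool2d: ...
class FastLinear: ...


reference_nn = types.ModuleType("reference_nn")
reference_nn.Linear = RefLinear
reference_nn.AdaptiveAvgPool2d = RefAdaptiveAvgPool2d

nn = with_fallback("my_nn", {"Linear": FastLinear}, reference_nn)
lin_cls = nn.Linear              # FastLinear: implemented locally
gap_cls = nn.AdaptiveAvgPool2d   # RefAdaptiveAvgPool2d: falls back
```

With real libraries, the call would look like with_fallback("my_nn", {"Linear": TritonLinear}, torch.nn), where TritonLinear is again hypothetical. A simpler alternative is a module that runs from torch.nn import * and then imports its own layers over those names.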
Tests
Every attorch module can be tested against its PyTorch counterpart to verify that it behaves correctly. The tests live in the tests/ directory and can be run with pytest. Occasional failures may occur due to numerical precision discrepancies, but these are generally not a concern for most practical applications.
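A parity test of this kind usually compares outputs within dtype-dependent tolerances rather than exactly. The following is a minimal pytest-style sketch, not one of attorch's actual tests. custom_layer_norm is a hypothetical pure-PyTorch stand-in for the implementation under test, and the tolerance values are illustrative assumptions.

```python
import pytest
import torch
import torch.nn.functional as F

# Looser tolerances for lower-precision dtypes absorb rounding differences
# between two mathematically equivalent implementations.
TOLERANCES = {
    torch.float32: {"rtol": 1e-4, "atol": 1e-4},
    torch.float64: {"rtol": 1e-7, "atol": 1e-7},
}


def custom_layer_norm(x, weight, bias, eps=1e-5):
    # Hypothetical stand-in for the implementation under test.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, correction=0, keepdim=True)  # biased variance, as in LayerNorm
    return (x - mean) / torch.sqrt(var + eps) * weight + bias


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize("shape", [(4, 16), (2, 3, 32)])
def test_layer_norm_matches_pytorch(shape, dtype):
    torch.manual_seed(0)
    x = torch.randn(shape, dtype=dtype)
    weight = torch.randn(shape[-1], dtype=dtype)
    bias = torch.randn(shape[-1], dtype=dtype)

    expected = F.layer_norm(x, (shape[-1],), weight, bias, eps=1e-5)
    actual = custom_layer_norm(x, weight, bias, eps=1e-5)

    # Compare within dtype-dependent tolerances rather than exactly.
    torch.testing.assert_close(actual, expected, **TOLERANCES[dtype])
```

Placed in a file under tests/, pytest collects the function and runs it once for each shape and dtype combination.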