Project Overview: x-transformers
The x-transformers library is a concise yet full-featured implementation of transformer models, augmented with experimental features drawn from a range of research papers. The library aims to make it easy to build and experiment with different transformer architectures, and accommodates both newcomers and experienced machine learning practitioners.
Installation
Installing x-transformers is straightforward. You can add it to your project using pip:
$ pip install x-transformers
Usage
x-transformers is flexible and supports several standard transformer configurations:
Full Encoder-Decoder Model
This configuration lets you build classic encoder-decoder transformers of the kind used in machine translation and many other sequence-to-sequence tasks. Here's a basic example:
import torch
from x_transformers import XTransformer

model = XTransformer(
    dim=512,
    enc_num_tokens=256,
    enc_depth=6,
    enc_heads=8,
    enc_max_seq_len=1024,
    dec_num_tokens=256,
    dec_depth=6,
    dec_heads=8,
    dec_max_seq_len=1024,
    tie_token_emb=True              # share token embeddings between encoder and decoder
)

src = torch.randint(0, 256, (1, 1024))   # source token ids
src_mask = torch.ones_like(src).bool()   # boolean padding mask (True = attend)
tgt = torch.randint(0, 256, (1, 1024))   # target token ids

loss = model(src, tgt, mask=src_mask)    # returns the autoregressive cross-entropy loss
loss.backward()
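Once trained, the same model can be used for sequence-to-sequence generation. Recent versions of the library expose a generate method on XTransformer with a (source, start tokens, length) signature; the sketch below assumes that interface, and the start token id of 0 is a hypothetical choice, so verify both against your installed version:

# minimal generation sketch; assumes model.generate(seq_in, seq_out_start, seq_len, mask=...)
start_tokens = torch.zeros((1, 1), dtype=torch.long)  # hypothetical BOS token id 0
generated = model.generate(src, start_tokens, seq_len=1024, mask=src_mask)
print(generated.shape)  # (1, 1024)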
Decoder-Only (GPT-like)
x-transformers also supports decoder-only models, which are akin to the architecture used in GPT models:
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=1024,
    attn_layers=Decoder(
        dim=512,
        depth=12,
        heads=8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()
logits = model(x)  # (1, 1024, 20000) -- raw next-token logits
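For language-model training and sampling, the library also ships an AutoregressiveWrapper that handles the input/target shifting and autoregressive decoding for you. A minimal sketch, assuming the wrapper's usual forward-returns-loss and generate(start_tokens, seq_len) interface (the start token id of 0 is again a hypothetical choice):

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens=20000,
    max_seq_len=1024,
    attn_layers=Decoder(dim=512, depth=12, heads=8)
)).cuda()

x = torch.randint(0, 20000, (1, 1024)).cuda()
loss = model(x)     # wrapper shifts inputs/targets and returns the cross-entropy loss
loss.backward()

start = torch.zeros((1, 1), dtype=torch.long).cuda()  # hypothetical start token id 0
sampled = model.generate(start, 512)                  # autoregressively sample 512 tokens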
Encoder-Only (BERT-like)
x-transformers can also be configured as an encoder-only model, similar to BERT:
import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=1024,
    attn_layers=Encoder(
        dim=512,
        depth=12,
        heads=8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()
mask = torch.ones_like(x).bool()  # boolean padding mask (True = attend)
logits = model(x, mask=mask)      # (1, 1024, 20000) -- per-token logits
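For BERT-style feature extraction you often want the final hidden states rather than token logits. TransformerWrapper accepts a return_embeddings flag for this (flag name per recent library versions, worth verifying against yours):

embeddings = model(x, mask=mask, return_embeddings=True)
print(embeddings.shape)  # (1, 1024, 512) -- one dim-sized vector per token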
Additional Features
x-transformers incorporates experimental features from recent research that target transformer performance and efficiency.
Flash Attention
Flash Attention computes attention in tiles so that the full attention matrix is never materialized in GPU memory, improving both speed and memory use. In x-transformers it is enabled with a single flag; this path dispatches to PyTorch's fused scaled_dot_product_attention kernel, so PyTorch 2.0 or later is required:
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=1024,
    attn_layers=Decoder(
        dim=512,
        depth=6,
        heads=8,
        attn_flash=True  # dispatch to the fused/flash attention kernel
    )
)
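The flag only changes how attention is computed, not the model's interface or (up to numerical precision) its outputs, so a forward pass looks the same as before:

x = torch.randint(0, 20000, (1, 1024))
with torch.no_grad():
    logits = model(x)  # (1, 1024, 20000), same shape as without flash attention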
Memory Transformers
This feature prepends a small set of learned memory tokens that pass through every attention layer alongside the input tokens, giving the model persistent slots it can read from and write to at each layer:
import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=1024,
    num_memory_tokens=20,  # 20 learned memory tokens prepended internally
    attn_layers=Encoder(
        dim=512,
        depth=6,
        heads=8
    )
)
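The memory tokens are added and stripped inside the wrapper, so the model's input and output shapes are unchanged:

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # still (1, 1024, 20000); memory tokens are handled internally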
Alternative Normalizations and Activation Functions
x-transformers exposes alternative normalization layers and activation functions, such as RMSNorm (which simplifies LayerNorm by dropping mean-centering) and squared ReLU (ReLU²), both of which some papers report improve training stability and convergence:
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=1024,
    attn_layers=Decoder(
        dim=512,
        depth=6,
        heads=8,
        use_rmsnorm=True  # replace LayerNorm with RMSNorm
    )
)
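The activation swap mentioned above is configured the same way. The sketch below assumes the ff_relu_squared flag, which recent versions use for the squared-ReLU feedforward; verify the name against your installed version:

model = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=1024,
    attn_layers=Decoder(
        dim=512,
        depth=6,
        heads=8,
        ff_relu_squared=True  # use ReLU^2 as the feedforward activation
    )
)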
Conclusion
x-transformers is a versatile library designed to facilitate experimentation with transformer models. It supports the standard encoder-decoder, decoder-only, and encoder-only architectures, and layers on experimental features that can improve model performance and efficiency. Whether you are working on natural language processing or vision tasks, x-transformers is a powerful addition to the machine learning toolkit.