InfiniTransformer: A Revolution in Efficient Context Handling
InfiniTransformer is an unofficial PyTorch and 🤗Transformers implementation of Infini-attention, the mechanism introduced in the research paper "Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention." It supports Llama 3 and Gemma, as well as the earlier Llama 2 and Llama 1 models. Its goal is to let transformer models handle very long context windows efficiently. It does this by processing the input in segments and summarizing earlier segments in a fixed-size compressive memory, so memory use stays bounded as the context grows.
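The paper describes Infini-attention as ordinary local attention within each segment plus a compressive memory that summarizes earlier segments. The plain-Python sketch below follows the paper's linear memory rules, where σ(x) = ELU(x) + 1:

- Retrieval computes σ(Q)M / (σ(Q)z).
- The update adds σ(K)ᵀV to M and adds σ(K), summed over positions, to z.

The class name `CompressiveMemory` and the `eps` guard for an empty memory are assumptions for illustration. This is not the repository's code.

```python
import math
from typing import List

Matrix = List[List[float]]


def elu_plus_one(x: float) -> float:
    """Feature map sigma(x) = ELU(x) + 1, which keeps every entry positive."""
    return x + 1.0 if x > 0 else math.exp(x)


class CompressiveMemory:
    """Fixed-size memory: a d_k x d_v matrix M plus a normalization vector z."""

    def __init__(self, d_k: int, d_v: int, eps: float = 1e-6) -> None:
        self.d_v = d_v
        self.M: Matrix = [[0.0] * d_v for _ in range(d_k)]
        self.z: List[float] = [0.0] * d_k
        self.eps = eps  # assumption: guards the 0/0 case when memory is empty

    def retrieve(self, queries: Matrix) -> Matrix:
        """A_mem = sigma(Q) M / (sigma(Q) z), computed row by row."""
        out = []
        for q in queries:
            sq = [elu_plus_one(x) for x in q]
            denom = sum(a * b for a, b in zip(sq, self.z)) + self.eps
            out.append([
                sum(sq[i] * self.M[i][j] for i in range(len(sq))) / denom
                for j in range(self.d_v)
            ])
        return out

    def update(self, keys: Matrix, values: Matrix) -> None:
        """Linear update: M += sigma(K)^T V and z += sum over rows of sigma(K)."""
        for k, v in zip(keys, values):
            sk = [elu_plus_one(x) for x in k]
            for i, ski in enumerate(sk):
                self.z[i] += ski
                for j, vj in enumerate(v):
                    self.M[i][j] += ski * vj


# Usage: store one key/value pair, then read it back with a different query.
mem = CompressiveMemory(d_k=2, d_v=2)
print(mem.retrieve([[1.0, 0.0]]))  # empty memory reads as zeros
mem.update(keys=[[1.0, -1.0]], values=[[3.0, 5.0]])
print(mem.retrieve([[0.5, 0.5]]))  # approximately [[3.0, 5.0]]
```

However many segments are written into it, the memory stays one d_k × d_v matrix plus one d_k vector. That fixed size is what keeps memory bounded for very long inputs.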
Infini-Attention: Two Implementation Approaches
InfiniTransformer introduces two distinct methods for integrating Infini-attention into transformer models:
Type I: Model-Wise and Trainer-Wise Implementation
This approach is a comprehensive overhaul of the PyTorch modeling and configuration files and requires custom training code. It is incompatible with the standard Hugging Face Trainer, but its main advantage is efficient memory usage. For example, with this method:
- A Gemma-2B model can be trained with a sequence length of 32,768 on two H100 80GB GPUs using the AdamW optimizer, without gradient checkpointing.
- Llama-3-8B can be trained with sequence lengths of up to 1 million tokens on two H100 80GB GPUs using the Adafactor optimizer, also without gradient checkpointing.
- Dedicated scripts and configurations support 'infinite'-context training.
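Type I moves segment handling into the model and the trainer. One way to picture this rests on an assumption: the trainer feeds a long sequence to the model one segment at a time and carries only the compressive memory forward. The loop below sketches that idea.

`split_into_segments`, `process_long_sequence`, and the `step` callback are hypothetical names. `step` stands in for one forward/backward pass over a segment.

```python
from typing import Callable, Iterator, List, Optional, Sequence, Tuple, TypeVar

State = TypeVar("State")


def split_into_segments(token_ids: Sequence[int], segment_len: int) -> Iterator[List[int]]:
    """Yield consecutive, non-overlapping segments; the last may be shorter."""
    if segment_len <= 0:
        raise ValueError("segment_len must be positive")
    for start in range(0, len(token_ids), segment_len):
        yield list(token_ids[start:start + segment_len])


def process_long_sequence(
    token_ids: Sequence[int],
    segment_len: int,
    step: Callable[[List[int], Optional[State]], State],
) -> Tuple[Optional[State], int]:
    """Apply `step(segment, memory) -> memory` to each segment in order.

    Only `memory` is carried from one segment to the next, so as long as
    `step` returns a fixed-size memory, the carried state does not grow
    with len(token_ids).
    """
    memory: Optional[State] = None
    n_segments = 0
    for segment in split_into_segments(token_ids, segment_len):
        memory = step(segment, memory)
        n_segments += 1
    return memory, n_segments


# Usage with a toy `step` whose "memory" is just a running token count.
def count_tokens(segment: List[int], memory: Optional[int]) -> int:
    return (0 if memory is None else memory) + len(segment)


print(process_long_sequence(list(range(10_000)), 2_048, count_tokens))  # (10000, 5)
```

The example uses 10,000 tokens and a hypothetical segment length of 2,048. The loop runs five steps: four full segments and a final one of 1,808 tokens.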
Type II: Attention-Layer Only Implementation
This more minimal approach modifies only the attention layer in the modeling files. It remains fully compatible with Hugging Face tooling such as the Trainer. Its memory usage is similar to that of the default SDPA attention implementation. With this method:
- Gemma-2B can be trained with a sequence length of 8,192 on two H100 80GB GPUs using the Adafactor optimizer with gradient checkpointing.
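Inside a single attention layer, the paper combines two outputs through a learned scalar gate β:

- the memory read-out A_mem;
- ordinary causal dot-product attention A_dot over the current segment.

The result is A = sigmoid(β) ⊙ A_mem + (1 − sigmoid(β)) ⊙ A_dot. Because Type II changes only the attention layer, this is the kind of computation it adds.

The plain-Python sketch below shows both pieces for a single head. The function names are hypothetical, and whether the repository's layer is organized this way is an assumption.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_softmax_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Local dot-product attention within one segment; position t sees 0..t."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for t, qt in enumerate(q):
        scores = [scale * sum(a * b for a, b in zip(qt, k[s])) for s in range(t + 1)]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(x - peak) for x in scores]
        total = sum(weights)
        out.append([
            sum(weights[s] * v[s][j] for s in range(t + 1)) / total
            for j in range(len(v[0]))
        ])
    return out


def gated_combine(a_mem: Matrix, a_dot: Matrix, beta: float) -> Matrix:
    """A = sigmoid(beta) * A_mem + (1 - sigmoid(beta)) * A_dot, elementwise."""
    g = 1.0 / (1.0 + math.exp(-beta))
    return [
        [g * m + (1.0 - g) * d for m, d in zip(row_mem, row_dot)]
        for row_mem, row_dot in zip(a_mem, a_dot)
    ]


# Usage: two positions in one segment; a_mem would come from the compressive memory.
q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[5.0, 7.0], [1.0, 1.0]]
a_dot = causal_softmax_attention(q, k, v)
print(a_dot[0])  # [5.0, 7.0]: the first token attends only to itself
print(gated_combine([[2.0, 4.0]], [[0.0, 0.0]], beta=0.0))  # [[1.0, 2.0]]
```

With β = 0 the gate weights memory and local attention equally. Training can push β either way per head.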
Using the InfiniTransformer
Here is a step-by-step guide to using each approach:
Steps for Type I Implementation
- Clone the Repository:
  git clone https://github.com/Beomi/InfiniTransformer
- Install Dependencies: The project relies on a development version of 🤗Transformers, installed from source at a pinned commit.
  pip install -r requirements.txt
  pip install -e git+https://github.com/huggingface/transformers.git@b109257f4f#egg=transformers
- Test the Basics: Run a simple forward/backward pass test.
  python test_basic.infini.py
- Train with Your Data: Use the provided training scripts with datasets such as MiniPile or WikiText2. For example, to train a Llama-3 model with a sequence length of 1 million tokens:
  ./train.llama.infini.noclm.1Mseq.sh
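The training scripts' internals are not described here. Long-sequence causal LM training on corpora such as MiniPile or WikiText2 typically concatenates the tokenized documents into one stream and cuts it into fixed-length blocks. The sketch below shows that preprocessing step.

The function name `group_into_blocks`, the separator `eos_id`, and dropping the trailing remainder are assumptions. For a 1-million-token run, `block_size` would be set to the target sequence length.

```python
from itertools import chain
from typing import List


def group_into_blocks(tokenized_docs: List[List[int]], block_size: int, eos_id: int) -> List[List[int]]:
    """Join documents into one token stream (each followed by eos_id) and cut it
    into non-overlapping blocks of exactly block_size tokens.

    The trailing remainder shorter than block_size is dropped.
    """
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    stream = list(chain.from_iterable(doc + [eos_id] for doc in tokenized_docs))
    usable = (len(stream) // block_size) * block_size
    return [stream[i:i + block_size] for i in range(0, usable, block_size)]


# Usage with toy token IDs; a real run would use the tokenizer's IDs and a
# block_size equal to the target training sequence length.
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(group_into_blocks(docs, block_size=4, eos_id=0))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Dropping the remainder keeps every training example exactly `block_size` tokens long, at the cost of discarding up to `block_size - 1` tokens per pass.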
Steps for Type II Implementation
- Clone the Repository: Use the same cloning process as Type I.
- Install Dependencies: As with Type I, install the requirements and the pinned 🤗Transformers commit.
  pip install -r requirements.txt
  pip install -e git+https://github.com/huggingface/transformers.git@b109257f4f#egg=transformers
- Modify the Modeling File: Remove the original modeling file from the installed 🤗Transformers source and link the repository's Infini-attention modeling file in its place.
- Test Basic Compatibility: Run the forward/backward test for this variant.
  python test_basic.py
- Train with Your Data: Start training with the provided scripts.
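The "remove and link" step can be scripted. The sketch below assumes two things: the file to replace lives inside the installed 🤗Transformers package, and the repository provides a replacement modeling file.

The sketch moves the original aside as a `.orig` backup rather than deleting it, which is a safety choice not in the original instructions. It then creates a symbolic link in the original's place. `replace_with_symlink` is a hypothetical helper, and the demo uses temporary files rather than real paths.

```python
import tempfile
from pathlib import Path


def replace_with_symlink(installed_file: Path, replacement: Path) -> Path:
    """Move `installed_file` aside as `<name>.orig` and symlink `replacement` in its place.

    Returns the backup path. An existing symlink is simply removed.
    """
    installed_file = Path(installed_file)
    replacement = Path(replacement).resolve()
    if not replacement.is_file():
        raise FileNotFoundError(f"replacement file not found: {replacement}")
    backup = installed_file.with_name(installed_file.name + ".orig")
    if installed_file.is_symlink():
        installed_file.unlink()  # a previous link: drop it, keep any older backup
    elif installed_file.exists():
        installed_file.rename(backup)  # keep the original so it can be restored
    installed_file.symlink_to(replacement)
    return backup


# Real paths might look like this (assumption, not verified here):
#   import transformers
#   target = Path(transformers.__file__).parent / "models" / "gemma" / "modeling_gemma.py"
# Demo on temporary files instead:
if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        original = Path(tmp) / "modeling_gemma.py"
        original.write_text("# stock attention\n")
        infini = Path(tmp) / "modeling_gemma_infini.py"  # hypothetical file name
        infini.write_text("# Infini-attention version\n")
        replace_with_symlink(original, infini)
        print(original.is_symlink(), original.read_text().strip())
```

Keeping the backup makes it easy to restore the stock attention layer. To restore it, delete the link and rename the `.orig` file back.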
Inference and Sample Outputs
Trained models can also be used for inference. The project provides an example output from a model trained for one epoch on WikiText2 to illustrate its text generation.
By following this guide, users can explore InfiniTransformer and its approach to processing very long contexts with bounded memory in text generation.