Introduction to SwissArmyTransformer
The SwissArmyTransformer (often abbreviated as sat) is a versatile and powerful library designed to facilitate the development of customized Transformer model variants. Named after the iconic "Swiss Army knife," known for its multipurpose utility, sat uses a common core codebase and lightweight mixins to support a diverse range of models, including BERT, GPT, T5, GLM, CogView, and ViT, among others.
At its core, sat leverages DeepSpeed ZeRO and model parallelism to optimize both the pretraining and finetuning of large models ranging from 100 million to 20 billion parameters.
Installation
To install SwissArmyTransformer, simply use pip:
pip install SwissArmyTransformer
Features
Easy Integration of Model-Agnostic Components
One of the standout features of sat is its ability to add model-agnostic components using just a single line of code. For example, Prefix-tuning and P-tuning, both of which enhance finetuning by introducing trainable parameters in each attention layer, can be effortlessly implemented.
For instance, adding Prefix-tuning to a GLM classification model is straightforward:
from sat.model import GLMModel
from sat.model.mixins import MLPHeadMixin, PrefixTuningMixin

class ClassificationModel(GLMModel):
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(args, transformer=transformer, **kwargs)
        # A small MLP head on top of the hidden states for classification.
        self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
        # Add Prefix-tuning: trainable prefix vectors in every attention layer.
        self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len))
Similarly, transforming GPT or other auto-regressive models from training to inference mode is simplified by caching previous states for efficiency during text generation:

from sat import AutoModel
from sat.model.mixins import CachedAutoregressiveMixin

model, args = AutoModel.from_pretrained('glm-10b-chinese', args)
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
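With the cache in place, each generation step reuses the key/value states of previously processed tokens instead of recomputing them. As a rough sketch of what generation could then look like, the filling_sequence helper, the BaseStrategy sampling strategy, and their module paths are assumptions based on sat's generation utilities and may differ across versions:

import torch
# Assumed import paths; adjust to your installed sat version.
from sat.generation.autoregressive_sampling import filling_sequence
from sat.generation.sampling_strategies import BaseStrategy

prompt_ids = [1, 2, 3]  # placeholder token ids; produce these with the tokenizer in practice
# -1 marks the positions the model should fill in autoregressively.
seq = torch.tensor(prompt_ids + [-1] * 64, device='cuda')
strategy = BaseStrategy(temperature=0.9, top_k=40)  # assumed signature
output = filling_sequence(model, seq, batch_size=1, strategy=strategy)[0]  # completed token ids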
Building Custom Transformer Models
SwissArmyTransformer is designed to streamline the process of building custom Transformer models with minimal code. Projects like GLM, which differ from standard Transformers mainly in positional embeddings and training losses, can be implemented by focusing only on the distinct parts of the model configuration.
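As an illustration of this pattern, GLM's block position embeddings can be expressed as a single mixin attached to the base model, leaving the rest of the Transformer untouched. The sketch below mirrors the mixin style shown earlier; BaseMixin, BaseModel, and the position_embedding_forward hook follow sat's conventions, but the exact import path is an assumption and may vary by version:

import torch
from sat.model.base_model import BaseModel, BaseMixin  # assumed import path

class BlockPositionEmbeddingMixin(BaseMixin):
    """Adds GLM-style block position embeddings on top of the ordinary ones."""
    def __init__(self, max_sequence_length, hidden_size, init_method_std=0.02):
        super().__init__()
        self.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
        torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)

    # Overrides only the position-embedding hook; everything else in the core Transformer is reused.
    def position_embedding_forward(self, position_ids, **kwargs):
        position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
        position_embeddings = self.transformer.position_embeddings(position_ids)
        block_position_embeddings = self.block_position_embeddings(block_position_ids)
        return position_embeddings + block_position_embeddings

class GLMModel(BaseModel):
    def __init__(self, args, transformer=None, parallel_output=True):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output)
        self.add_mixin('block_position_embedding',
                       BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size))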
Extensive Support for Training
sat provides robust support for training, offering features like:
- Multi-GPU and multi-node training with simple configurations.
- DeepSpeed and model parallelism integration.
- Enhanced ZeRO-2 support and activation checkpointing.
- Automatic data management for training datasets.
- Successful training support for models like CogView2 and CogVideo.
- The only available open-source code for finetuning the T5-10B model on GPUs.
Quick Start Guide
Here's a simple example of how to use a pretrained BERT model for inference with sat:
from sat import get_args, get_tokenizer, AutoModel

# Parse command-line arguments and initialize the environment.
args = get_args()
# Download (if needed) and load the pretrained model; its hyperparameters are merged into args.
model, args = AutoModel.from_pretrained('bert-base-uncased', args)
# Build the matching tokenizer from the automatically set tokenizer type.
tokenizer = get_tokenizer(args)
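Inside such a script, a single forward pass might look like the following sketch; the (input_ids, position_ids, attention_mask) calling convention and the returned tuple are assumptions about sat's base Transformer interface, and the token ids are placeholders:

import torch

model = model.eval().cuda()
input_ids = torch.tensor([[101, 7592, 2088, 102]], device='cuda')  # placeholder ids; use the tokenizer in practice
position_ids = torch.arange(input_ids.size(1), device='cuda').unsqueeze(0)
attention_mask = torch.ones(1, 1, input_ids.size(1), input_ids.size(1), device='cuda')
with torch.no_grad():
    logits, *others = model(input_ids, position_ids, attention_mask)  # assumed calling convention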
Run the script via:
SAT_HOME=/path/to/download python inference_bert.py --mode inference
Additionally, finetuning or pretraining a Transformer can be set up with ease. Here's how you can prepare and execute a finetuning task:
import torch
from sat import get_args, get_tokenizer, AutoModel
from sat.model.mixins import MLPHeadMixin
from sat.training.deepspeed_training import training_main  # import path may vary across sat versions

def create_dataset_function(path, args):
    # Load the dataset from `path` here; any torch Dataset works.
    # ...
    assert isinstance(dataset, torch.utils.data.Dataset)
    return dataset

def forward_step(data_iterator, model, args, timers):
    inputs = next(data_iterator)  # one batch from the dataset built above
    loss, *others = model(inputs)
    return loss

args = get_args()
model, args = AutoModel.from_pretrained('bert-base-uncased', args)
tokenizer = get_tokenizer(args)
# Swap the pretraining head for a classification head.
model.del_mixin('bert-final')
model.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
# Train the model!
training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
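For concreteness, a minimal version of the dataset placeholder above could wrap pre-tokenized inputs and labels in a standard PyTorch dataset; the on-disk format and field names below are hypothetical, the only requirement shown in the snippet being that a torch Dataset is returned:

import torch
from torch.utils.data import TensorDataset

def create_dataset_function(path, args):
    # Hypothetical format: a torch-saved dict with 'input_ids' and 'labels' tensors.
    data = torch.load(path)
    dataset = TensorDataset(data['input_ids'], data['labels'])
    assert isinstance(dataset, torch.utils.data.Dataset)
    return dataset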
Execute the training with:
deepspeed --include localhost:0,1 finetune_bert.py --experiment-name ftbert --mode finetune --train-iters 1000 ...
You can easily extend this setup to multiple nodes by adjusting the hostfile configuration, among other parameters.
Tutorials and Resources
Several tutorials are available to help you make the most of sat:
- How to effectively leverage pretrained models available in sat.
- The rationale and methods for training models using sat.
Citation and Acknowledgement
Although SwissArmyTransformer currently lacks a formal academic paper to cite, users are encouraged to credit the project on relevant platforms with a footnote linking to the repository: https://github.com/THUDM/SwissArmyTransformer.
SwissArmyTransformer is developed on top of technologies like DeepSpeed, Megatron-LM, and Hugging Face Transformers, with gratitude for their substantial contributions to the field.