Image GPT Project Introduction
Image GPT is a PyTorch implementation of generative pretraining on image data, based on the paper "Generative Pretraining from Pixels" by Chen et al. It adapts the GPT architecture, originally designed for text, to images.
Project Overview
The project trains models to generate images and to complete partially observed ones. It follows the language-modeling recipe with pixels in place of words: each image is read as a sequence of pixels (row by row, as in the paper), and the model learns to predict every pixel from the pixels that precede it. A model trained this way can generate coherent image content and fill in missing regions consistently with the visible context.
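As a minimal sketch of the next-pixel objective, assume the image has already been quantized into an (H, W) grid of integer tokens and that a start-of-sequence token (the hypothetical `sos` below) is prepended so that the first pixel also has something to condition on. NumPy stands in for PyTorch tensors, and `to_training_pair` is an illustrative name, not a function from this repository.

```python
import numpy as np

def to_training_pair(tokens: np.ndarray, sos: int):
    """Turn an (H, W) grid of token ids into (inputs, targets) for next-pixel prediction.

    Position i of `inputs` is what the model sees; position i of `targets`
    is the pixel it must predict at that step.
    """
    seq = tokens.reshape(-1)                     # raster order: row 0, then row 1, ...
    inputs = np.concatenate([[sos], seq[:-1]])   # shift right by one, prepend the start token
    targets = seq                                # predict every pixel, including the first
    return inputs, targets

inputs, targets = to_training_pair(np.array([[1, 2], [3, 0]]), sos=4)
print(inputs.tolist(), targets.tolist())  # [4, 1, 2, 3] [1, 2, 3, 0]
```

With a causally masked transformer, a single forward pass over `inputs` yields predictions for all positions of `targets` at once.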
The project demonstrates this by completing partial images from datasets such as MNIST and CIFAR-10. In each row of example completions, the first image is the input with part of it missing, the last image is the original, complete image, and the images in between are completions generated by the model.
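Completion works by keeping the known pixel tokens and sampling the rest one at a time, each conditioned on everything before it. The sketch below assumes a hypothetical `next_token_probs(prefix)` callable standing in for the trained model; it is an illustration, not the sampling code in this repository. To complete the bottom half of an H×W image from its top half, for example, one would pass `n_known = (H // 2) * W`.

```python
import numpy as np

def complete(seq, n_known, next_token_probs, rng):
    """Fill seq[n_known:] autoregressively; seq[:n_known] is kept as given.

    `next_token_probs(prefix)` is a hypothetical stand-in for the model: it must
    return a probability vector over the token vocabulary for the next position.
    """
    out = np.array(seq, copy=True)
    for i in range(n_known, len(out)):
        probs = next_token_probs(out[:i])          # condition on all tokens so far
        out[i] = rng.choice(len(probs), p=probs)   # sample the next pixel token
    return out

# Toy "model" over 3 tokens that always predicts (last token + 1) mod 3.
def toy_model(prefix):
    probs = np.zeros(3)
    probs[(prefix[-1] + 1) % 3] = 1.0
    return probs

rng = np.random.default_rng(0)
print(complete([0, 0, 0, 0, 0, 0], 2, toy_model, rng).tolist())  # [0, 0, 1, 2, 0, 1]
```

Sampling several times with different random states gives the multiple completions shown between the masked input and the original.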
Current Progress and Future Work
The project is still a work in progress (WIP). Planned developments include:
- Batched k-means on GPU: quantizing larger datasets with batched k-means running on the GPU. Quantization is currently handled by sklearn.cluster.MiniBatchKMeans.
- BERT-style pretraining: adding BERT-style pretraining, in which the model predicts masked pixels, alongside the current generative pretraining.
- Pre-trained OpenAI models: loading the pre-trained models released by OpenAI.
- iGPT-S results: reproducing at least the results reported for iGPT-S.
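The batched k-means item above concerns clustering pixel values in mini-batches rather than over the whole dataset at once. As an illustration only, the sketch below implements the mini-batch k-means update (the same family of algorithm as sklearn.cluster.MiniBatchKMeans) with vectorized array operations. NumPy stands in for the GPU tensors a planned version would use, and `minibatch_kmeans` is a hypothetical helper, not code from this repository.

```python
import numpy as np

def minibatch_kmeans(x, k, batch_size=1024, n_iters=100, seed=0):
    """Cluster rows of x (N, C) into k centers using mini-batch updates.

    Each center is kept as the running mean of every point ever assigned to it,
    which is the per-center 1/count learning rate of mini-batch k-means.
    """
    rng = np.random.default_rng(seed)
    centers = x[rng.choice(len(x), size=k, replace=False)].astype(np.float64)
    counts = np.zeros(k)
    for _ in range(n_iters):
        batch = x[rng.choice(len(x), size=min(batch_size, len(x)), replace=False)]
        # Squared distances for the whole batch at once: (B, k).
        d2 = ((batch[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
        assign = d2.argmin(1)
        # Per-center sums and counts for this batch.
        batch_counts = np.bincount(assign, minlength=k)
        batch_sums = np.zeros_like(centers)
        np.add.at(batch_sums, assign, batch)
        new_counts = counts + batch_counts
        hit = batch_counts > 0
        # Merge the batch into each center's running mean.
        centers[hit] = (counts[hit][:, None] * centers[hit] + batch_sums[hit]) / new_counts[hit][:, None]
        counts = new_counts
    return centers
```

On a GPU, the distance matrix, `argmin`, and scatter-add map directly onto batched tensor operations, which is what makes this form attractive for large datasets.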
The iGPT-L model described in the OpenAI blog post has 1.4 billion parameters and was trained for about 2500 V100-days. This project instead trains a much smaller model, about 26,000 parameters, on Fashion-MNIST in under two hours on a single NVIDIA 2070 GPU.
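To see why shrinking the model has such a large effect, it helps to count parameters. The sketch below assumes a GPT-2-style decoder (learned token and position embeddings, blocks with two LayerNorms, a 4× MLP, biases on every linear layer, a final LayerNorm and an untied output head without bias); iGPT's exact layout may differ, and any configuration values you plug in are hypothetical. Each block contributes 12·d² + 13·d parameters, so the total grows quadratically with the width d and linearly with the number of layers, while the position embeddings grow with the number of pixels, i.e. quadratically in the image side length.

```python
def gpt_param_count(n_layer: int, d: int, vocab: int, seq_len: int) -> int:
    """Parameter count for the GPT-2-style layout assumed in the text."""
    ln = 2 * d                                    # LayerNorm scale + shift
    attn = (3 * d * d + 3 * d) + (d * d + d)      # fused q/k/v projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)   # d -> 4d -> d, with biases
    block = 2 * ln + attn + mlp                   # = 12*d*d + 13*d
    embed = vocab * d + seq_len * d               # token + position embeddings
    head = d * vocab                              # untied output projection, no bias
    return n_layer * block + embed + ln + head    # ln here is the final LayerNorm

print(gpt_param_count(n_layer=1, d=2, vocab=3, seq_len=4))  # 98
```

In this layout the number of attention heads does not change the count, because the heads split the same d-dimensional projections.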
Usage Guide
The implementation supports the following workflows:
Pre-trained Models
Some pre-trained models are stored in the models directory. The iGPT-S model pre-trained on CIFAR-10 can be downloaded by running ./download.sh.
Compute Centroids
Before images reach the model, they are quantized: each pixel is mapped to the index of its nearest k-means centroid, so the model works with a sequence of discrete tokens from a vocabulary of num_clusters entries. The centroids are computed with:
python src/compute_centroids.py --dataset mnist --num_clusters=8
This creates the file data/<dataset>_centroids.npy, which is needed to quantize the images.
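A minimal sketch of how the saved centroids could be applied, assuming the .npy file holds a (K, C) array with one row per cluster (C = 1 for grayscale MNIST, 3 for RGB) and that pixel values are on the same scale the centroids were fitted on. `quantize` and `dequantize` are illustrative names, not functions from src/.

```python
import numpy as np

def quantize(images: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Map pixels of shape (..., C) to the index of the nearest centroid, shape (...)."""
    d2 = ((images[..., None, :] - centroids) ** 2).sum(-1)  # (..., K) squared distances
    return d2.argmin(-1)

def dequantize(tokens: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Map token ids back to centroid values, shape (..., C)."""
    return centroids[tokens]

centroids = np.array([[0.0], [0.5], [1.0]])          # toy 3-cluster grayscale palette
image = np.array([[[0.1], [0.6]], [[0.9], [0.4]]])   # 2x2 image with 1 channel
print(quantize(image, centroids).tolist())           # [[0, 1], [2, 1]]
```

With real data, the centroids would be loaded with something like np.load("data/mnist_centroids.npy"), following the file-name pattern above.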
Training
The src/run.py script handles training and supports both generative pretraining and classification fine-tuning:
- Generative pretraining: the model is trained from scratch on the chosen dataset to predict each pixel from the pixels before it.
python src/run.py --dataset mnist train configs/xxs_gen.yml
- Classification fine-tuning: starting from pre-trained weights (given with --pretrained), the model is trained further to classify images.
python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt
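The two objectives can be sketched as follows. This illustrates the ideas under stated assumptions and is not the code in src/run.py: the generative loss is taken as the mean negative log-likelihood of each next pixel token, and the classification head follows the approach described in the paper, average-pooling the transformer's features over all pixel positions before a linear layer. The arrays stand in for model outputs, and the function names are hypothetical.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max(-1, keepdims=True)   # subtract the max for numerical stability
    return z - np.log(np.exp(z).sum(-1, keepdims=True))

def generative_loss(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean next-token NLL; logits (T, K) over K pixel tokens, targets (T,) token ids."""
    lp = log_softmax(logits)
    return float(-lp[np.arange(len(targets)), targets].mean())

def classification_logits(hidden: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Average-pool features (T, d) over positions, then apply a linear head w (d, n_classes)."""
    return hidden.mean(axis=0) @ w + b

# With uniform logits over 3 tokens, the loss is log(3), about 1.0986.
print(generative_loss(np.zeros((4, 3)), np.array([0, 1, 2, 0])))
```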
Sampling
The sampling script uses a trained model to complete images from the test set, producing figures like the example completions described above:
python src/sample.py models/mnist_gen.ckpt
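As a rough sketch of assembling one row of such a figure (masked input, then completions, then the original), assume grayscale (H, W) arrays, completions that are already dequantized, and the bottom rows as the hidden region. `completion_row` is a hypothetical helper, not the code in src/sample.py.

```python
import numpy as np

def completion_row(original, samples, n_known_rows, fill=0.0):
    """Concatenate [masked input, *samples, original] side by side along the width."""
    masked = original.astype(float).copy()
    masked[n_known_rows:] = fill   # hide the rows the model had to generate
    return np.concatenate([masked, *samples, original], axis=1)

row = completion_row(np.ones((2, 2)), [np.full((2, 2), 0.5)], n_known_rows=1)
print(row.shape)  # (2, 6)
```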
A GIF can be generated with:
python src/gif.py models/mnist_gen.ckpt
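One way to animate a completion is to reveal the generated pixels in the order they are sampled. The sketch below is an assumption about what such an animation could contain, not a description of src/gif.py: `progressive_frames` (a hypothetical helper) turns a completed token sequence into frames, each showing a longer prefix in raster order. The frames could then be written to a GIF with an image library such as Pillow.

```python
import numpy as np

def progressive_frames(seq, h, w, n_known, step, fill=0):
    """Frames of an (h, w) image showing seq[:t] for t = n_known, n_known + step, ..., len(seq)."""
    seq = np.asarray(seq)
    stops = list(range(n_known, len(seq), step)) + [len(seq)]  # always end on the full image
    frames = []
    for t in stops:
        frame = np.full(h * w, fill, dtype=seq.dtype)
        frame[:t] = seq[:t]            # pixels available so far, in raster order
        frames.append(frame.reshape(h, w))
    return frames

frames = progressive_frames([1, 2, 3, 4], h=2, w=2, n_known=2, step=1)
print(len(frames))  # 3
```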
Image GPT applies generative pretraining, a technique from language modeling, to image generation, completion, and classification, and shows that the approach can be explored at small scale on a single GPU.