Muse - Pytorch
Muse - Pytorch is a PyTorch implementation of Muse, a text-to-image generation model built on masked generative transformers. It replicates the Muse model, which turns text prompts directly into detailed images. The replication is being carried out with the LAION community, and anyone interested in contributing or discussing the work is welcome to join the project's community on Discord.
Installation
Install the library with pip:
$ pip install muse-maskgit-pytorch
Usage
Training the VAE
First, train a Variational Autoencoder (VAE). The VQGanVAE is configured with a model dimension (dim = 256) and a large codebook (codebook_size = 65536), which it uses to tokenize images. The trainer reads from a folder of images. The example trains at a small image_size of 128; one approach is to start at a small resolution like this and progress to larger ones later:
import torch
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer

vae = VQGanVAE(
    dim = 256,
    codebook_size = 65536
)

# train on a folder of images
trainer = VQGanVAETrainer(
    vae = vae,
    image_size = 128,
    folder = '/path/to/images',
    batch_size = 4,
    grad_accum_every = 8,
    num_train_steps = 50000
).cuda()

trainer.train()
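To make "image tokenization" concrete: a vector-quantized autoencoder replaces each latent vector with the index of its closest codebook entry, and those indices serve as the image tokens. The following is a minimal pure-Python sketch of that lookup using squared Euclidean distance. The names `quantize`, `codebook`, and `latents` are illustrative, the 4-entry codebook stands in for the 65536 entries configured above, and the library's own quantizer may use a different distance or implementation.

```python
def quantize(vectors, codebook):
    """Return, for each vector, the index of its nearest codebook entry."""
    def sq_dist(a, b):
        # squared Euclidean distance between two equal-length vectors
        return sum((x - y) ** 2 for x, y in zip(a, b))

    return [
        min(range(len(codebook)), key=lambda k: sq_dist(v, codebook[k]))
        for v in vectors
    ]

# Toy 2-dimensional codebook with 4 entries (hypothetical values).
codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Three latent vectors, as an encoder might emit for three image patches.
latents = [[0.9, 0.1], [0.1, 0.8], [0.2, 0.1]]

token_ids = quantize(latents, codebook)
print(token_ids)  # [1, 2, 0]
```

Each image then becomes a grid of such integer ids, which is what the transformer in the next section learns to predict.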
Utilizing MaskGit for Image Generation
Once the VAE is trained, pair it with a transformer inside MaskGit. The example builds a MaskGitTransformer, wraps it together with the trained VAE, runs one training step on image-text pairs, and then generates images from text prompts:
import torch
from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer

# first instantiate the VAE and load the trained weights
vae = VQGanVAE(
    dim = 256,
    codebook_size = 65536
).cuda()

vae.load('/path/to/vae.pt')

# then build the transformer
transformer = MaskGitTransformer(
    num_tokens = 65536,       # same value as codebook_size above
    seq_len = 256,
    dim = 512,
    depth = 8,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    t5_name = 't5-small',
)

# wrap the VAE and transformer in MaskGit
base_maskgit = MaskGit(
    vae = vae,
    transformer = transformer,
    image_size = 256,
    cond_drop_prob = 0.25,
).cuda()

# one training step on a batch of images and their captions
texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 256, 256).cuda()

loss = base_maskgit(
    images,
    texts = texts
)

loss.backward()

# after training, generate images from text
images = base_maskgit.generate(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)

images.shape # (3, 3, 256, 256)
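The `cond_drop_prob` and `cond_scale` arguments match the usual pattern of classifier-free guidance: the text conditioning is randomly dropped during training, and at generation time the conditional and unconditional predictions are blended. Below is a minimal sketch of that general technique, assuming the standard formulation. The functions `maybe_drop_text` and `guided_logits` are hypothetical helpers written for illustration; they are not part of muse_maskgit_pytorch, and the library's internals may differ.

```python
import random

def maybe_drop_text(text, cond_drop_prob, rng):
    """Training side: drop the prompt (return None) with probability cond_drop_prob."""
    return None if rng.random() < cond_drop_prob else text

def guided_logits(cond_logits, uncond_logits, cond_scale):
    """Sampling side: push the unconditional logits toward the conditional ones.

    cond_scale = 1 returns the conditional logits unchanged;
    larger values weight the text conditioning more strongly.
    """
    return [u + (c - u) * cond_scale for c, u in zip(cond_logits, uncond_logits)]

# Roughly a quarter of prompts are dropped when cond_drop_prob = 0.25.
rng = random.Random(0)
dropped = sum(maybe_drop_text('a prompt', 0.25, rng) is None for _ in range(10_000))
drop_fraction = dropped / 10_000

# Toy logits over two tokens (hypothetical values).
guided = guided_logits([2.0, 0.5], [1.0, 1.0], cond_scale=3.0)
print(guided)  # [4.0, -0.5]
```

With `cond_scale = 3.`, as in the `generate` call above, the gap between the conditional and unconditional predictions is tripled in this formulation.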
Super-Resolution with MaskGit
Muse also supports super-resolution: a second MaskGit takes low-resolution images as conditioning and produces larger, more refined ones. In the example, the model is conditioned on 256 x 256 images (cond_image_size) and outputs 512 x 512 images (image_size):
import torch
import torch.nn.functional as F
from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer

# first instantiate the VAE and load the trained weights
vae = VQGanVAE(
    dim = 256,
    codebook_size = 65536
).cuda()

vae.load('./path/to/vae.pt')

# then build the transformer
transformer = MaskGitTransformer(
    num_tokens = 65536,       # same value as codebook_size above
    seq_len = 1024,
    dim = 512,
    depth = 2,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    t5_name = 't5-small',
)

# wrap the VAE and transformer in MaskGit, conditioned on lower-resolution images
superres_maskgit = MaskGit(
    vae = vae,
    transformer = transformer,
    cond_drop_prob = 0.25,
    image_size = 512,         # output resolution
    cond_image_size = 256,    # resolution of the conditioning images
).cuda()

# one training step on a batch of images and their captions
texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 512, 512).cuda()

loss = superres_maskgit(
    images,
    texts = texts
)

loss.backward()

# after training, generate from text plus low-resolution conditioning images
images = superres_maskgit.generate(
    texts = [
        'a whale breaching from afar',
        'young girl blowing out candles on her birthday cake',
        'fireworks with blue and green sparkles',
        'waking up to a psychedelic landscape'
    ],
    cond_images = F.interpolate(images, 256),  # resize the images to 256 x 256
    cond_scale = 3.
)

images.shape # (4, 3, 512, 512)
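The `seq_len` values in the two examples are consistent with the VAE shrinking each image side by a factor of 16: a 256 x 256 image gives a 16 x 16 grid of 256 tokens, and a 512 x 512 image gives a 32 x 32 grid of 1024 tokens. The factor of 16 is inferred from those numbers rather than stated in this document, so treat it as an assumption and check it against the VAE's actual configuration. The helper `token_grid` below is hypothetical and only does the arithmetic.

```python
def token_grid(image_size, downsample_factor=16):
    """Return (tokens per side, total tokens) for a square image.

    downsample_factor = 16 is an assumption inferred from the example
    configurations, not a documented constant.
    """
    side = image_size // downsample_factor
    return side, side * side

base_side, base_seq_len = token_grid(256)
superres_side, superres_seq_len = token_grid(512)

print(base_side, base_seq_len)          # 16 256
print(superres_side, superres_seq_len)  # 32 1024
```

If `image_size` or the VAE's downsampling changes, `seq_len` in the transformer presumably has to change with it.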
Unified Approach with Muse
The Muse class combines the base and super-resolution MaskGit models, so text prompts go straight to high-resolution images:
from muse_maskgit_pytorch import Muse

base_maskgit.load('./path/to/base.pt')

superres_maskgit.load('./path/to/superres.pt')

# pass in the trained base and super-resolution MaskGit models
muse = Muse(
    base = base_maskgit,
    superres = superres_maskgit
)

images = muse([
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'waking up to a psychedelic landscape'
])

images
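Conceptually, combining the two models amounts to a two-stage cascade: the base stage produces low-resolution images from text, and the super-resolution stage takes the same text plus those images as conditioning. The sketch below illustrates only that control flow. `cascade`, `fake_base`, and `fake_superres` are hypothetical stand-ins that track image shapes instead of running real models, and the actual `Muse` class may do more than this.

```python
def cascade(texts, base_generate, superres_generate):
    """Run the base stage, then feed its output to the super-resolution stage."""
    low_res = base_generate(texts)
    return superres_generate(texts, cond_images=low_res)

# Stand-in stages that only track shapes as (batch, channels, height, width).
def fake_base(texts):
    return (len(texts), 3, 256, 256)

def fake_superres(texts, cond_images):
    batch, channels, height, width = cond_images
    assert batch == len(texts)  # one conditioning image per prompt
    return (batch, channels, height * 2, width * 2)

shape = cascade(['a whale breaching from afar'], fake_base, fake_superres)
print(shape)  # (1, 3, 512, 512)
```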
Acknowledgments and Future Work
The Muse project has been developed with support from StabilityAI and Hugging Face, along with contributors from around the world. Planned work includes improving the training setup, adding self-conditioning, and possibly integrating the accelerate library for more efficient training.
The goal throughout is to keep extending what text-to-image generation with masked generative transformers can achieve.