Diffusion Reinforcement Learning X (DRLX)
Diffusion Reinforcement Learning X, or DRLX, is a library for distributed training of diffusion models with reinforcement learning. Its primary goal is to extend Hugging Face's Diffusers library with efficient multi-GPU and multi-node training, although multi-node support is still experimental.
Key Features
- Reinforcement Learning for Diffusion Models: DRLX fine-tunes diffusion models with reinforcement learning, using a reward model to score the images they generate.
- Integration with Hugging Face Tools: DRLX is built on Hugging Face's Diffusers, so it fits into existing Diffusers workflows.
- Accelerated Training with Hugging Face's Accelerate: DRLX uses Accelerate for multi-GPU training (multi-node support is experimental), making training faster and more resource-efficient.
- Latest Experiments and Findings: The project team regularly shares updates and new experiments, for example in their blog post.
Getting Started
To start using DRLX, first install OpenCLIP. DRLX can then be installed via pip:
pip install drlx
Alternatively, one can install it directly from the source:
pip install git+https://github.com/CarperAI/DRLX.git
How to Apply DRLX
DRLX has been tested with Stable Diffusion 1.4, 1.5, and 2.1, but its abstractions should work with most denoising pipelines. Training uses the DDPO (Denoising Diffusion Policy Optimization) algorithm (see the DDPO paper), and trained models remain compatible with their original pipelines, so they can be loaded with standard Diffusers code.
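DDPO treats each denoising step as an action in a multi-step decision process and applies policy gradients, with the reward assigned to the final image. The sketch below illustrates a PPO-style clipped objective in that spirit, working in plain Python on scalar log-probabilities. It is an illustration of the technique, not DRLX's implementation: the function names, the batch-level reward standardization (a simplification), and the `clip_range` default are all assumptions.

```python
import math
from statistics import mean, pstdev
from typing import List, Sequence, Tuple

# One trajectory = (new log-probs per denoising step, old log-probs per step).
# Names and shapes here are hypothetical, chosen for illustration only.
Trajectory = Tuple[Sequence[float], Sequence[float]]


def normalize_rewards(rewards: Sequence[float], eps: float = 1e-8) -> List[float]:
    """Standardize final-image rewards within the batch to get advantages."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def clipped_step_loss(logp_new: float, logp_old: float, advantage: float,
                      clip_range: float = 0.2) -> float:
    """PPO-style clipped surrogate loss for a single denoising step."""
    ratio = math.exp(logp_new - logp_old)  # importance ratio new/old policy
    unclipped = -advantage * ratio
    clipped = -advantage * min(max(ratio, 1 - clip_range), 1 + clip_range)
    # Taking the max of the negated terms equals -min(...) in the PPO objective.
    return max(unclipped, clipped)


def ddpo_loss(trajectories: Sequence[Trajectory], rewards: Sequence[float],
              clip_range: float = 0.2) -> float:
    """Average the clipped loss over every step of every sampled image.

    Each step of a trajectory shares the advantage of that trajectory's
    final image, since the reward is only observed at the end.
    """
    advantages = normalize_rewards(rewards)
    losses = []
    for (logps_new, logps_old), adv in zip(trajectories, advantages):
        for lp_new, lp_old in zip(logps_new, logps_old):
            losses.append(clipped_step_loss(lp_new, lp_old, adv, clip_range))
    return mean(losses)
```

Clipping keeps each update close to the policy that generated the samples, which is what allows the same batch of denoising trajectories to be reused for several gradient steps.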
Here is a basic setup for training:
from drlx.reward_modelling.aesthetics import Aesthetics
from drlx.pipeline.pickapic_prompts import PickAPicPrompts
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig
# Prompt pipeline, training config, and DDPO trainer
pipe = PickAPicPrompts()
config = DRLXConfig.load_yaml("configs/my_cfg.yml")
trainer = DDPOTrainer(config)
# Train against the aesthetic-score reward model
trainer.train(pipe, Aesthetics())
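In the setup above, Aesthetics() plays the role of the reward model: it assigns each generated image a scalar score. As a dependency-free illustration, the sketch below defines a toy compressibility reward in the spirit of the JPEG-compressibility reward used in the DDPO paper, with zlib standing in for a JPEG encoder. It is not DRLX's reward-model interface; the function name, the raw-bytes image format, and the reward scale are hypothetical.

```python
import random
import zlib
from typing import List


def compressibility_rewards(images: List[bytes], level: int = 9) -> List[float]:
    """Toy reward (hypothetical): images that compress smaller score higher.

    Each image is a raw pixel buffer (e.g. flattened RGB bytes). zlib is a
    stand-in for the JPEG encoder used by the DDPO paper's compressibility reward.
    """
    rewards = []
    for raw in images:
        compressed_kb = len(zlib.compress(raw, level)) / 1024
        rewards.append(-compressed_kb)  # smaller output -> larger (less negative) reward
    return rewards


if __name__ == "__main__":
    size = 64 * 64 * 3  # one 64x64 RGB image
    flat = bytes(size)  # all-black image: highly compressible
    rng = random.Random(0)
    noisy = bytes(rng.randrange(256) for _ in range(size))  # random noise
    print(compressibility_rewards([flat, noisy]))
```

Any function that maps a batch of images to one number per image can drive the same kind of policy-gradient update; the reward determines what the fine-tuned model learns to favor.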
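To run inference with a trained model:
from diffusers import StableDiffusionPipeline
# Load the fine-tuned model from the training output directory
pipe = StableDiffusionPipeline.from_pretrained("out/ddpo_exp")
prompt = "A mad panda scientist"
image = pipe(prompt).images[0]  # the pipeline returns a list of PIL images
image.save("test.jpeg")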
Enhanced Training Capabilities
For scaled-up training, configure Accelerate once and then launch your module from the command line:
accelerate config
accelerate launch -m [your module]
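For multi-GPU training, Accelerate runs one process per GPU, and in data-parallel training each process works on a different slice of the data. The sketch below shows one common way to split prompts across processes. It assumes the launcher exports the RANK and WORLD_SIZE environment variables, as torch.distributed launchers do, and falls back to a single process otherwise. The helper name is hypothetical; inside Accelerate code, `Accelerator().process_index` and `Accelerator().num_processes` expose the same information.

```python
import os
from typing import List, Optional, Sequence


def shard_for_process(items: Sequence[str], rank: Optional[int] = None,
                      world_size: Optional[int] = None) -> List[str]:
    """Return the slice of `items` this process should handle (hypothetical helper).

    If rank/world_size are not given, read them from the RANK and WORLD_SIZE
    environment variables (assumed to be set by the launcher), defaulting to
    a single process.
    """
    if rank is None:
        rank = int(os.environ.get("RANK", "0"))
    if world_size is None:
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
    # Strided slicing: process r takes items r, r + world_size, r + 2*world_size, ...
    return list(items[rank::world_size])


if __name__ == "__main__":
    prompts = ["a cat", "a dog", "a fox", "an owl", "a panda"]
    print(shard_for_process(prompts))
```

Strided slicing keeps the shards disjoint and nearly equal in size, and together they cover every prompt exactly once.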
Future Plans
DRLX is under active development. Current roadmap status:
- Initial DDPO launch (completed).
- Models fine-tuned with PickScore (completed).
- DPO integration (planned).
- SDXL support (planned).
For more detailed information and updates, refer to the documentation provided by DRLX's developers.