TextRL: Text Generation with Reinforcement Learning
Introduction
TextRL is a Python library designed to enhance text generation by leveraging reinforcement learning (RL). Built on top of Hugging Face's Transformers, PFRL, and OpenAI Gym, TextRL works with a wide variety of text generation models and offers the flexibility and customization needed to steer generation toward a user-defined reward.
Examples
TextRL provides various examples to illustrate its capabilities:
- GPT-2 Example: Demonstrates the use of the GPT-2 model in conjunction with TextRL for text generation tasks.
- FLAN-T5 Example: Shows how FLAN-T5, another powerful text generation model, can be enhanced with TextRL.
- Bigscience/BLOOMZ-7B1-MT Example: Illustrates the integration of the BLOOMZ-7B1-MT model with TextRL.
- 176B BLOOM Example: Runs TextRL against the 176B BLOOM model served by a public swarm; contributing compute to the swarm is encouraged to increase its capacity.
- Controllable Generation via RL Example: A use case in which reinforcement learning steers the model to generate text in which Elon Musk speaks ill of DOGE.
Installation
TextRL can be easily installed using pip:
pip install pfrl@git+https://github.com/voidful/pfrl.git
pip install textrl
Alternatively, it can be built from source:
git clone <repository-url>
cd <repository-path>
pip install -e .
Usage
Initialize Agent and Environment
Initialization begins by loading the model and tokenizer from a pretrained checkpoint; these are then used to set up the text generation environment.
import torch
from textrl import TextRLEnv, TextRLActor, train_agent_with_evaluation
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "bigscience/bloomz-7b1-mt"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")
model = model.cuda()
Setup Reward Function for Environment
The reward function is central to reinforcement learning. It evaluates how well the generated text meets the desired criteria:
class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):
        reward = [0]  # neutral reward for intermediate steps
        if finish:
            reward = [0]  # change this to calculate a meaningful reward from predicted_list
        return reward
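For example, the reward could score the finished generation with an off-the-shelf sentiment classifier. The sketch below is illustrative only: the sentiment pipeline, the SentimentRLEnv name, and the assumption that predicted_list[0] holds the generated tokens are not part of the TextRL code above; adapt the scoring to whatever criterion your task needs.
from transformers import pipeline

sentiment = pipeline("sentiment-analysis")  # hypothetical reward model, not part of TextRL

class SentimentRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):
        reward = [0]
        if finish:
            # Assumes predicted_list[0] is the list of generated tokens for this episode.
            predicted_text = tokenizer.convert_tokens_to_string(predicted_list[0])
            result = sentiment(input_item["input"] + predicted_text)[0]
            # Positive sentiment earns a positive reward, negative sentiment a penalty.
            score = result["score"] if result["label"] == "POSITIVE" else -result["score"]
            reward = [score]
        return reward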
Prepare for Training
The environment and the actor (agent) are then created with their key parameters ahead of training:
observation_list = [{"input":'example sentence'}]
env = MyRLEnv(model, tokenizer, observation_input=observation_list)
actor = TextRLActor(env, model, tokenizer)
agent = actor.agent_ppo(update_interval=10, minibatch_size=2000, epochs=20)
Training the Model
Training runs the agent in the environment for a fixed number of steps, periodically evaluating it and saving the best-performing checkpoint:
train_agent_with_evaluation(
    agent,
    env,
    steps=1000,                  # total number of training steps
    eval_n_steps=None,           # evaluate by episodes rather than by steps
    eval_n_episodes=1500,        # number of episodes per evaluation
    train_max_episode_len=50,    # maximum episode length during training
    eval_interval=10000,         # number of steps between evaluations
    outdir='output_directory',   # where checkpoints and the best model are saved
)
Prediction
Once training is complete, the best checkpoint can be loaded and used to generate text:
agent.load("output_directory/best")  # load the best checkpoint saved during training
actor.predict(observation_list[0])   # generate text for an observation item
Dump Trained Model to Hugging Face Format
TextRL can export a trained RL checkpoint back into a standard Hugging Face model so it can be shared and used like any other checkpoint:
textrl-dump --model ./model_path_before_rl --rl ./rl_path --dump ./output_dir
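Assuming textrl-dump writes a standard Transformers checkpoint into the --dump directory (the three paths above are placeholders), the result can then be loaded like any other pretrained model; the tokenizer is unchanged by RL training, so it can still be loaded from the original base checkpoint:
# Sketch: load the dumped model as a regular Transformers checkpoint.
# Assumes ./output_dir is the directory passed to --dump above.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("./output_dir")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1-mt")  # original base checkpoint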
Key Parameters for RL Training
Training hyperparameters such as the PPO update interval, minibatch size, number of epochs, and learning rate, together with the loop settings passed to train_agent_with_evaluation (total steps, evaluation interval, maximum episode length), play a crucial role in refining model performance. Proper tuning of these parameters can significantly improve the quality and relevance of the generated text; a sketch of where they are set follows below.
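As a rough illustration, the PPO hyperparameters above are passed to actor.agent_ppo; the lr keyword and the concrete values below are assumptions for illustration, not tuned defaults:
# Illustrative only: tuning the PPO hyperparameters named above.
agent = actor.agent_ppo(
    update_interval=10,   # environment steps between PPO updates
    minibatch_size=2000,  # samples per PPO minibatch
    epochs=20,            # PPO epochs per update
    lr=3e-6,              # learning rate (assumed keyword)
)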
TextRL offers a comprehensive framework for enhancing text generation models through reinforcement learning, demonstrating practical applications across different scenarios and models.