LLaMA Classification: A Simple Guide to Text Classification
The LLaMA Classification repository is a simple yet robust codebase for text classification with the LLaMA model. It implements several inference strategies, including direct, calibrated, channel, and generation-based classification, to improve accuracy and performance.
Development Environment
The LLaMA Classification project was developed and tested on the following hardware:
- Device: 1x Nvidia V100 GPU
- Device memory: 34 GB
- Host memory: 252 GB
If you need more detailed hardware information, the project encourages you to open an issue on the repository.
Getting Started
Setting Up Your Environment
- Obtain Checkpoints: To get started, download the required checkpoint from the official LLaMA repository. The checkpoint is expected to be organized in the project root as follows (a layout sanity check is sketched after these setup steps):

```
checkpoints
├── llama
│   ├── 7B
│   │   ├── checklist.chk
│   │   ├── consolidated.00.pth
│   │   └── params.json
│   └── tokenizer.model
```
- Prepare Python Environment: The project recommends Anaconda for managing the Python environment. Set it up as follows (a quick GPU check is also sketched below):

```bash
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit  # to verify the installed CUDA version
pip install -r requirements.txt
```
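Before running anything, it can help to verify that the checkpoint files are where the scripts expect them. This is a minimal sketch of my own, not part of the repository:

```python
from pathlib import Path

# Check the checkpoint layout described in the setup steps above.
ckpt_dir = Path("checkpoints/llama")
required = [
    ckpt_dir / "7B" / "checklist.chk",
    ckpt_dir / "7B" / "consolidated.00.pth",
    ckpt_dir / "7B" / "params.json",
    ckpt_dir / "tokenizer.model",
]
missing = [str(p) for p in required if not p.exists()]
if missing:
    raise FileNotFoundError(f"Missing checkpoint files: {missing}")
print("Checkpoint layout looks correct.")
```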
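Similarly, a short check (again my own suggestion, not a repository script) confirms that PyTorch can see the GPU before you launch a long evaluation:

```python
import torch

# Should report the GPU name (e.g., "Tesla V100") if CUDA is configured correctly.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA not available; check the cudatoolkit installation.")
```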
Classification Methods
The project supports several methods for classifying text with LLaMA, each suited to different use cases:
Direct Method
The Direct Method compares the conditional probability p(y|x) of each candidate label y given the input text x and predicts the highest-scoring label:

- Data Preprocessing: Preprocess the data, such as the ag_news dataset, using the provided script.

```bash
python run_preprocess_direct_ag_news.py
```

- Inference and Prediction: Run inference to compute label probabilities and predict classes; a sketch of the scoring logic appears after this list.

```bash
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py ...
```
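To make the idea concrete, here is a minimal sketch of direct scoring with a generic Hugging Face causal LM. The model name (gpt2 as a small stand-in), the prompt template, and the ag_news label verbalizers are illustrative assumptions; the repository itself runs the official LLaMA checkpoints through its own scripts.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative stand-ins; the repository uses the official LLaMA checkpoints.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
labels = ["World", "Sports", "Business", "Technology"]  # assumed ag_news verbalizers

def direct_score(text: str, label: str) -> float:
    """Sum of log p(label tokens | prompt), i.e., an estimate of log p(y|x)."""
    prompt_ids = tokenizer(f"{text}\nTopic: ", return_tensors="pt").input_ids
    label_ids = tokenizer(label, add_special_tokens=False, return_tensors="pt").input_ids
    input_ids = torch.cat([prompt_ids, label_ids], dim=1)
    with torch.no_grad():
        logits = model(input_ids).logits
    # Log-probabilities of each label token, conditioned on all preceding tokens.
    log_probs = torch.log_softmax(logits[0, prompt_ids.shape[1] - 1 : -1], dim=-1)
    return log_probs.gather(1, label_ids[0].unsqueeze(1)).sum().item()

text = "The stock market rallied after a strong earnings report."
print(max(labels, key=lambda y: direct_score(text, y)))
```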
Calibration
This method augments the Direct Method with calibration, adjusting the predicted probabilities to counteract the model's bias toward particular labels and thereby improve accuracy. Run the calibrated evaluation as follows; a sketch of one common calibration scheme appears after the command.

```bash
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py ...
```
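The overview does not spell out which calibration technique is used, so as an assumption, here is a sketch of contextual calibration in the style of Zhao et al. (2021): estimate the model's label prior from a content-free input such as "N/A", then divide it out of the predicted distribution.

```python
import numpy as np

def calibrate(label_log_probs: np.ndarray, content_free_log_probs: np.ndarray) -> np.ndarray:
    """Reweight class probabilities by the inverse of the model's prior,
    where the prior is estimated from a content-free input such as "N/A"."""
    p = np.exp(label_log_probs - label_log_probs.max())
    p /= p.sum()
    prior = np.exp(content_free_log_probs - content_free_log_probs.max())
    prior /= prior.sum()
    calibrated = p / prior
    return calibrated / calibrated.sum()

# Toy example: the raw model favors class 0 regardless of the input.
raw = np.log(np.array([0.5, 0.2, 0.2, 0.1]))    # scores for a real input
prior = np.log(np.array([0.4, 0.2, 0.2, 0.2]))  # scores for "N/A"
print(calibrate(raw, prior))  # class 0 is down-weighted after calibration
```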
Channel Method
The Channel Method reverses the direction of conditioning, scoring the probability p(x|y) of the input text x given each candidate label y:

- Data Preprocessing: Preprocess the data as in the Direct Method, using the channel variant of the script.

```bash
python run_preprocess_channel_ag_news.py
```

- Inference: Run inference with the Channel Method; a sketch of channel scoring appears after this list.

```bash
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py ...
```
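Channel scoring swaps the roles of input and label: the label goes into the prompt and the model scores the input tokens under each label hypothesis. A minimal sketch under the same illustrative assumptions as the direct example:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative stand-in
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
labels = ["World", "Sports", "Business", "Technology"]  # assumed verbalizers

def channel_score(text: str, label: str) -> float:
    """Sum of log p(text tokens | label prompt), i.e., an estimate of log p(x|y)."""
    prompt_ids = tokenizer(f"Topic: {label}\n", return_tensors="pt").input_ids
    text_ids = tokenizer(text, add_special_tokens=False, return_tensors="pt").input_ids
    input_ids = torch.cat([prompt_ids, text_ids], dim=1)
    with torch.no_grad():
        logits = model(input_ids).logits
    log_probs = torch.log_softmax(logits[0, prompt_ids.shape[1] - 1 : -1], dim=-1)
    return log_probs.gather(1, text_ids[0].unsqueeze(1)).sum().item()

text = "The stock market rallied after a strong earnings report."
print(max(labels, key=lambda y: channel_score(text, y)))
```

Note that every label hypothesis here scores the same input tokens, so differences in label-string length do not skew the comparison.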
Pure Generation
This approach evaluates the preprocessed data in generation mode: rather than scoring a fixed set of label strings, the model generates the answer as free text. A sketch follows the command below.

```bash
torchrun --nproc_per_node 1 run_evaluate_generate_llama.py ...
```
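Finally, a minimal sketch of generation-based classification under the same illustrative setup. Mapping the generated continuation back onto a label name is my assumption about how outputs would be evaluated, not a detail confirmed by the project:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative stand-in
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
labels = ["World", "Sports", "Business", "Technology"]  # assumed verbalizers

def generate_label(text: str) -> str:
    input_ids = tokenizer(f"{text}\nTopic:", return_tensors="pt").input_ids
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=5,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
        )
    completion = tokenizer.decode(output[0, input_ids.shape[1]:])
    # Map the free-form completion onto the closest known label name.
    for label in labels:
        if label.lower() in completion.lower():
            return label
    return labels[0]  # fall back if no label name appears in the output

print(generate_label("The stock market rallied after a strong earnings report."))
```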
Experimental Results
The project's experiments have yielded the following results:
| Dataset | Num Examples | K | Method | Accuracy | Inference Time |
|---------|--------------|---|--------|----------|----------------|
| ag_news | 7600 | 1 | direct | 76.82% | 00:38:40 |
| ag_news | 7600 | 1 | direct+calibrated | 85.67% | 00:38:40 |
| ag_news | 7600 | 1 | channel | 78.25% | 00:38:37 |
Next Steps
The project outlines several goals for future development:
- Implementing additional calibration methods
- Supporting more datasets from the Hugging Face library
- Integrating LLM.int8 for further efficiency
- Developing diverse evaluation metrics
Final Thoughts
The LLaMA Classification project is built on the foundation of the official LLaMA repository, and its team expresses gratitude for their contributions. Users are encouraged to engage with the project by submitting issues or pull requests for new features, implementation details, or research directions.
For those using the codebase in research, citing the work is appreciated:
```bibtex
@software{Lee_Simple_Text_Classification_2023,
  author = {Lee, Seonghyeon},
  month = {3},
  title = {{Simple Text Classification Codebase using LLaMA}},
  url = {https://github.com/sh0416/llama-classification},
  version = {1.1.0},
  year = {2023}
}
```