PyTorch Tabular: Making Deep Learning Accessible for Tabular Data
PyTorch Tabular is a library that makes deep learning on tabular data (for example, data loaded from CSV files) easy and accessible for both research and real-world applications. It is built on top of PyTorch and PyTorch Lightning, which helps keep it scalable and ready for deployment.
Core Principles
The design of PyTorch Tabular follows three main principles:
- Low Resistance Usability: Making it easy for anyone to get started quickly without having to deal with complex setups.
- Easy Customization: Allowing users to adjust and extend functionalities according to their specific needs.
- Scalability and Deployability: Ensuring models created are scalable and can be easily deployed in production settings.
Installation
PyTorch Tabular lists PyTorch as a dependency, but it is recommended to install PyTorch separately first so that the build matches your system's CUDA version. Once PyTorch is installed, you can install PyTorch Tabular with:
pip install -U "pytorch_tabular[extra]"
This installs the complete library along with optional extra dependencies such as Weights & Biases and Plotly. For just the core components, use:
pip install -U "pytorch_tabular"
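After installing, it can help to confirm from Python that the PyTorch build matches your CUDA setup and that PyTorch Tabular is importable. The sketch below is not part of PyTorch Tabular: the helper name `environment_report` is hypothetical, and it relies only on the standard library plus the public `torch.__version__`, `torch.version.cuda` and `torch.cuda.is_available()` attributes.

```python
import importlib.util
from importlib import metadata


def environment_report() -> dict:
    """Report installed PyTorch / PyTorch Tabular versions and CUDA status.

    Hypothetical helper: each package is probed before use, so the function
    still runs (reporting None) when a package is missing.
    """
    report = {
        "torch": None,
        "cuda": None,
        "cuda_available": False,
        "pytorch_tabular": None,
    }
    if importlib.util.find_spec("torch") is not None:
        import torch

        report["torch"] = torch.__version__
        report["cuda"] = torch.version.cuda  # None for CPU-only builds
        report["cuda_available"] = bool(torch.cuda.is_available())
    try:
        # Look up the installed distribution by its PyPI name.
        report["pytorch_tabular"] = metadata.version("pytorch_tabular")
    except metadata.PackageNotFoundError:
        pass  # not installed; leave as None
    return report


for name, value in environment_report().items():
    print(f"{name}: {value}")
```

If `cuda` is `None` on a machine with a GPU, the installed PyTorch is probably a CPU-only build, which is exactly the mismatch that installing PyTorch first is meant to avoid.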
Documentation
Comprehensive documentation, including tutorials, is available on ReadTheDocs.
Models Available
This library comes equipped with a variety of deep learning models specifically designed for tabular data:
- FeedForward Network with Category Embedding: A simple network that uses embeddings for categorical data.
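- Neural Oblivious Decision Ensembles (NODE): A model presented at ICLR 2020 that, according to its authors, beats well-tuned gradient boosting models on many datasets.
- TabNet: A model from Google Research that uses sparse attention over multiple decision steps to choose which features to use, which also makes its predictions easier to interpret.
- Mixture Density Networks: A regression model that approximates the target with a mixture of Gaussian components, providing probabilistic predictions out of the box.
- AutoInt: Uses self-attentive neural networks to learn interactions between features automatically and build better feature representations.
- TabTransformer: Adapts the Transformer architecture to tabular data, creating contextual embeddings for categorical features.
- Gated Additive Tree Ensemble (GATE): Combines a GRU-inspired gated feature-learning unit, which has built-in feature selection, with an ensemble of differentiable, non-linear decision trees.
- GANDALF and DANETs: GANDALF (Gated Adaptive Network for Deep Automated Learning of Features) is an improved version of GATE, while DANETs (Deep Abstract Networks) use abstract layers that group correlated features to build higher-level representations.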
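The first model in the list relies on category embeddings: each categorical value is mapped to an integer index, that index selects a row of a lookup table, and the resulting vectors are concatenated with the continuous features before entering the feed-forward layers. The plain-Python sketch below illustrates only that idea and is not PyTorch Tabular's implementation; the helper names, the reserved `<unk>` index and the embedding sizes are hypothetical, and in a real model the tables are learned by backpropagation rather than left random.

```python
import random
from typing import Dict, List, Sequence


def build_vocab(values: Sequence[str]) -> Dict[str, int]:
    """Map each distinct category to an index; index 0 is reserved for unseen values."""
    vocab = {"<unk>": 0}
    for v in values:
        vocab.setdefault(v, len(vocab))
    return vocab


def make_embedding_table(num_rows: int, dim: int, rng: random.Random) -> List[List[float]]:
    """Randomly initialised lookup table (learned during training in a real model)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_rows)]


def encode_row(row, cat_cols, cont_cols, vocabs, tables) -> List[float]:
    """Concatenate one embedding per categorical column with the continuous values."""
    features: List[float] = []
    for col in cat_cols:
        idx = vocabs[col].get(row[col], 0)  # unseen categories fall back to <unk>
        features.extend(tables[col][idx])
    features.extend(float(row[col]) for col in cont_cols)
    return features


rows = [
    {"cat1": "a", "cat2": "x", "num1": 0.5, "num2": 1.5},
    {"cat1": "b", "cat2": "y", "num1": -1.0, "num2": 2.0},
]
cat_cols, cont_cols = ["cat1", "cat2"], ["num1", "num2"]
emb_dims = {"cat1": 3, "cat2": 2}  # hypothetical embedding sizes
rng = random.Random(0)
vocabs = {c: build_vocab([r[c] for r in rows]) for c in cat_cols}
tables = {c: make_embedding_table(len(vocabs[c]), emb_dims[c], rng) for c in cat_cols}
vec = encode_row(rows[0], cat_cols, cont_cols, vocabs, tables)
print(len(vec))  # 3 + 2 + 2 = 7
```

The concatenated vector is what the hidden layers of the feed-forward network would consume; compared with one-hot encoding, embeddings keep the input width small even for high-cardinality columns.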
Additionally, for semi-supervised learning, PyTorch Tabular includes a Denoising AutoEncoder, which learns feature representations that are robust to noise in the data.
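A denoising autoencoder learns its representation by reconstructing clean inputs from deliberately corrupted ones. The sketch below shows one common corruption for tabular data, swap noise, in plain Python; it illustrates the general technique rather than PyTorch Tabular's own code, and the function name `swap_noise` and the probability `p` are hypothetical.

```python
import random
from typing import List, Sequence


def swap_noise(rows: Sequence[Sequence[float]], p: float, rng: random.Random) -> List[List[float]]:
    """Corrupt a table with swap noise.

    With probability p, each cell is replaced by the value from the same
    column in a randomly chosen row, so corrupted values remain plausible
    for that column.
    """
    n = len(rows)
    corrupted = []
    for row in rows:
        new_row = list(row)
        for j in range(len(row)):
            if rng.random() < p:
                new_row[j] = rows[rng.randrange(n)][j]
        corrupted.append(new_row)
    return corrupted


# An autoencoder would be trained to map `noisy` back to `clean`; its encoder
# output then serves as the learned feature representation.
clean = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
noisy = swap_noise(clean, p=0.3, rng=random.Random(42))
print(noisy)
```

Swapping within a column, rather than adding Gaussian noise, keeps categorical and skewed columns inside their observed range, which is why it is a popular choice for tabular denoising.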
Usage Example
Using PyTorch Tabular involves defining configurations for your data, model, optimizer, and training process, then fitting the model to your data:
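from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
)
# train_df, val_df and test_df are assumed to be pandas DataFrames
# containing the columns named in DataConfig.
data_config = DataConfig(
    target=["target"],  # target must always be a list
    continuous_cols=["num1", "num2"],
    categorical_cols=["cat1", "cat2"],
)
trainer_config = TrainerConfig(
    batch_size=1024,  # illustrative values
    max_epochs=100,
)
optimizer_config = OptimizerConfig()  # default optimizer settings
model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # number of nodes in each hidden layer
    activation="LeakyReLU",  # activation between layers
    learning_rate=1e-3,
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
tabular_model.fit(train=train_df, validation=val_df)
results = tabular_model.evaluate(test_df)  # evaluation metrics on the test set
predictions = tabular_model.predict(test_df)  # predictions for each row of test_df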
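The `layers` argument of `CategoryEmbeddingModelConfig` is a hyphen-separated list of hidden-layer widths, so `"1024-512-512"` means three hidden layers. To get a feel for how large such a network is, the arithmetic below counts the weights and biases of those fully connected layers. It is a back-of-the-envelope sketch, not a PyTorch Tabular utility: the helper names are hypothetical, and it assumes a plain stack of linear layers with a hypothetical input width of 7, ignoring embeddings, normalisation layers and the output head.

```python
from typing import List


def parse_layers(spec: str) -> List[int]:
    """Turn a width string such as "1024-512-512" into [1024, 512, 512]."""
    return [int(width) for width in spec.split("-")]


def linear_param_count(input_dim: int, spec: str) -> int:
    """Count weights and biases of the hidden linear layers only.

    A linear layer mapping `fan_in` inputs to `width` outputs has
    fan_in * width weights plus width biases.
    """
    total, fan_in = 0, input_dim
    for width in parse_layers(spec):
        total += fan_in * width + width
        fan_in = width
    return total


print(parse_layers("1024-512-512"))  # [1024, 512, 512]
print(linear_param_count(7, "1024-512-512"))  # 8192 + 524800 + 262656 = 795648
```

Most of the parameters sit in the 1024-to-512 layer, so shrinking the first width is the quickest way to make the backbone lighter if the dataset is small.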
Community and Contribution
PyTorch Tabular is open-source and welcomes contributions from the community. It is actively developed with a roadmap that includes integration with Optuna for hyperparameter tuning and further expansion of model architectures.
Citation
If you find PyTorch Tabular useful in your research or project, the developers encourage you to cite their work in scientific publications. They provide citation formats for both the paper and the software.
With its wide range of models and user-friendly design, PyTorch Tabular makes it easier to apply deep learning to structured data. It brings recent research architectures within reach of practical applications, making it a valuable tool for data scientists and engineers alike.