Tab Transformer Project Overview
Tab Transformer is a neural network architecture for tabular data, implemented in PyTorch. Tabular data, meaning any data that fits in a spreadsheet, is one of the most common data types across industries. Deep learning models excel on images and text, but on tabular data they have traditionally trailed Gradient Boosted Decision Trees (GBDT), which are widely known for their effectiveness in this domain. Tab Transformer, an attention-based model, was created to close that gap.
Key Features
- Near-GBDT performance: this simple architecture comes close to GBDT performance. Amazon AI has since reported beating GBDT with attention on real-world tabular data.
- PyTorch implementation: built on PyTorch, a popular deep learning framework, so it is easy to adopt for researchers and practitioners already working in that environment.
Installation
Install Tab Transformer with pip:
$ pip install tab-transformer-pytorch
Usage
The example below shows the core usage: define the architecture's dimensions, then pass in a batch of categorical and continuous features:
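import torch
from tab_transformer_pytorch import TabTransformer

# per-column (mean, std) for the 10 continuous features; random here for illustration
cont_mean_std = torch.randn(10, 2)

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),      # number of unique values in each categorical column
    num_continuous = 10,                # number of continuous columns
    dim = 32,                           # embedding dimension
    dim_out = 1,                        # output size, e.g. 1 for binary prediction
    depth = 6,                          # number of transformer layers
    heads = 8,                          # number of attention heads
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed-forward dropout
    mlp_hidden_mults = (4, 2),          # hidden sizes of the final MLP, as multiples of its input size
    mlp_act = torch.nn.ReLU(),          # activation for the final MLP
    continuous_mean_std = cont_mean_std # optional: normalize continuous values before layer norm
)

x_categ = torch.randint(0, 5, (1, 5))   # one id per categorical column, in the order given in categories
x_cont = torch.randn(1, 10)             # continuous values

pred = model(x_categ, x_cont)           # shape (1, 1)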
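Both inputs must already be numeric: x_categ holds one integer id per categorical column, in the same order as categories, with each id below that column's cardinality, and continuous_mean_std is a (num_continuous, 2) tensor, assumed here to hold the mean in column 0 and the standard deviation in column 1. The stdlib-only sketch below shows one way to derive both from raw rows. The helper names build_vocabs, encode_categorical and mean_std, and the sample data, are hypothetical and not part of the library.

```python
import math
from typing import Dict, List, Sequence, Tuple


def build_vocabs(rows: Sequence[Sequence[str]]) -> List[Dict[str, int]]:
    """Assign a contiguous integer id to each distinct value, per column."""
    n_cols = len(rows[0])
    vocabs: List[Dict[str, int]] = [{} for _ in range(n_cols)]
    for row in rows:
        for col, value in enumerate(row):
            # len() is evaluated before insertion, so ids run 0, 1, 2, ...
            vocabs[col].setdefault(value, len(vocabs[col]))
    return vocabs


def encode_categorical(rows: Sequence[Sequence[str]],
                       vocabs: List[Dict[str, int]]) -> List[List[int]]:
    """Map raw values to ids; each id lies in [0, len(vocab)) for its column."""
    return [[vocabs[col][value] for col, value in enumerate(row)] for row in rows]


def mean_std(rows: Sequence[Sequence[float]]) -> List[Tuple[float, float]]:
    """Per-column (mean, population std) of the continuous features."""
    n = len(rows)
    stats = []
    for col in zip(*rows):
        mu = sum(col) / n
        sd = math.sqrt(sum((v - mu) ** 2 for v in col) / n)
        # Guard against constant columns, whose std would be 0.
        stats.append((mu, sd if sd > 0 else 1.0))
    return stats


# Hypothetical raw data: two categorical and two continuous columns.
cat_rows = [["red", "small"], ["blue", "large"], ["red", "large"]]
cont_rows = [[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]]

vocabs = build_vocabs(cat_rows)
categories = tuple(len(v) for v in vocabs)  # value for the categories argument
x_categ_ids = encode_categorical(cat_rows, vocabs)
cont_stats = mean_std(cont_rows)            # shape (num_continuous, 2)

print(categories)   # (2, 2)
print(x_categ_ids)  # [[0, 0], [1, 1], [0, 1]]
```

With PyTorch available, torch.tensor(x_categ_ids) gives an int64 tensor and torch.tensor(cont_stats) a float tensor of shape (num_continuous, 2). Fit the vocabularies and statistics on the training split only; values unseen at training time raise a KeyError in this sketch, so a real pipeline needs an explicit policy for them.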
FT Transformer: An Enhancement
FT Transformer, from "Revisiting Deep Learning Models for Tabular Data", improves on Tab Transformer by using a simpler scheme to embed continuous numerical values. It is included for comparison, so you can try both and pick the one that works best for your dataset and application:
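import torch
from tab_transformer_pytorch import FTTransformer

model = FTTransformer(
    categories = (10, 5, 6, 5, 8),      # number of unique values in each categorical column
    num_continuous = 10,                # number of continuous columns
    dim = 32,                           # embedding dimension
    dim_out = 1,                        # output size, e.g. 1 for binary prediction
    depth = 6,                          # number of transformer layers
    heads = 8,                          # number of attention heads
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1                    # feed-forward dropout
)

x_categ = torch.randint(0, 5, (1, 5))   # one id per categorical column
x_numer = torch.randn(1, 10)            # numerical values

pred = model(x_categ, x_numer)          # shape (1, 1)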
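The simpler scheme, as described in the FT-Transformer paper, gives every numerical feature its own token: the scalar x_j becomes x_j * W_j + b_j, where W_j and b_j are learned vectors of size dim. These tokens, together with the categorical tokens and a [CLS] token whose final state feeds the prediction head, all pass through the transformer; in Tab Transformer, by contrast, only the categorical embeddings go through the transformer and the continuous features join afterwards. The pure-Python sketch below illustrates the tokenizer, with random numbers standing in for learned parameters. tokenize_numeric is a hypothetical name, not the library's code, and the sequence-length arithmetic assumes one token per column plus one [CLS] token.

```python
import random
from typing import List


def tokenize_numeric(x: List[float], weights: List[List[float]],
                     biases: List[List[float]]) -> List[List[float]]:
    """FT-Transformer-style numerical tokenizer: token_j = x_j * W_j + b_j."""
    assert len(x) == len(weights) == len(biases)
    return [
        [xj * w + b for w, b in zip(w_j, b_j)]
        for xj, w_j, b_j in zip(x, weights, biases)
    ]


dim, num_numeric = 4, 3
rng = random.Random(0)
# Random values stand in for parameters that would be learned during training.
W = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_numeric)]
B = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_numeric)]

tokens = tokenize_numeric([0.5, -1.2, 2.0], W, B)
print(len(tokens), len(tokens[0]))  # 3 4: one dim-sized token per feature

# Assumed sequence length for the example above: one token per categorical
# column, one per numerical column, plus a [CLS] token.
num_categorical, num_continuous = 5, 10
print(num_categorical + num_continuous + 1)  # 16
```

Because each feature has its own W_j and b_j, this amounts to a separate 1-to-dim linear map per feature. In PyTorch the same computation could be written as x.unsqueeze(-1) * W + b, with W and b of shape (num_numeric, dim) broadcasting over a (batch, num_numeric, 1) input.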
Unsupervised Training
To reproduce the unsupervised pre-training described in the TabTransformer paper, first convert the category tokens to unique ids, then train model.transformer with an ELECTRA-style objective.
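Converting category tokens to unique ids means that value k of one column must not share an id with value k of another; one common way is to shift each column by the total cardinality of the columns before it. ELECTRA-style training then corrupts some tokens and teaches a discriminator to flag which ones were replaced. The stdlib-only sketch below shows both steps under stated assumptions: a uniform random sampler stands in for ELECTRA's learned generator, and the helper names and the 0.15 replacement rate are illustrative choices, not values from this repository.

```python
import random
from itertools import accumulate
from typing import List, Sequence, Tuple


def column_offsets(categories: Sequence[int]) -> List[int]:
    """Start id of each column in a shared vocabulary: [0, c0, c0+c1, ...]."""
    return [0] + list(accumulate(categories))[:-1]


def to_unique_ids(row: Sequence[int], offsets: Sequence[int]) -> List[int]:
    """Shift each per-column id so ids never collide across columns."""
    return [v + off for v, off in zip(row, offsets)]


def corrupt(row: Sequence[int], categories: Sequence[int], rate: float,
            rng: random.Random) -> Tuple[List[int], List[int]]:
    """Replace about `rate` of the per-column ids with a different value from
    the same column; label 1 marks replaced positions (the RTD targets)."""
    out, labels = [], []
    for v, card in zip(row, categories):
        if card > 1 and rng.random() < rate:
            # Uniform sampler instead of a learned generator; draw from
            # card - 1 values and skip over the original one.
            new = rng.randrange(card - 1)
            new = new + 1 if new >= v else new
            out.append(new)
            labels.append(1)
        else:
            out.append(v)
            labels.append(0)
    return out, labels


categories = (10, 5, 6, 5, 8)
offsets = column_offsets(categories)
print(offsets)                                  # [0, 10, 15, 21, 26]
print(to_unique_ids([3, 4, 0, 2, 7], offsets))  # [3, 14, 15, 23, 33]

rng = random.Random(42)
corrupted, labels = corrupt([3, 4, 0, 2, 7], categories, 0.15, rng)
```

A discriminator for this objective would be a per-token binary classifier trained with binary cross-entropy against labels, applied to the transformer's outputs for the corrupted row.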
Future Enhancements and Research
Planned work includes evaluating further optimizations proposed in recent research to improve model efficiency and performance.
References
The project builds on the following research papers:
- "TabTransformer: Tabular Data Modeling Using Contextual Embeddings" by Xin Huang et al. (2020)
- "Revisiting Deep Learning Models for Tabular Data" by Yury Gorishniy et al. (2021)
They give further detail on the architectures and their benchmark results.