TensorDict: A Comprehensive Introduction
TensorDict is a library for PyTorch users that makes it easy to work with collections of tensors. It is a dictionary-like class: operations apply to all entries at once, sparing developers the tedium of managing each tensor individually. This makes TensorDict a natural companion for building and training machine learning models.
Key Features
TensorDict enhances codebases by making them more readable, compact, modular, and efficient. Its standout features are listed below; a short sketch of a few of them follows the list.
- 🧮 Composability: Extends the operations of torch.Tensor to collections of tensors, enabling seamless operations.
- ⚡️ Speed: Supports asynchronous device transfers and inter-node communication, and is compatible with torch.compile.
- ✂️ Shape Operations: Facilitates operations such as indexing, slicing, and concatenation on TensorDict instances.
- 🌐 Distributed/Multi-processing Capabilities: Effortlessly distribute TensorDict instances across multiple workers and devices.
- 💾 Serialization: Includes capabilities for serialization and memory mapping.
- λ Functional Programming: Compatible with functional programming concepts and torch.vmap.
- 📦 Nesting: Supports hierarchical structuring by nesting TensorDict instances.
- ⏰ Lazy Preallocation: Allows preallocation of memory without initializing the tensors.
- 📝 Specialized Dataclass: Provides a specialized dataclass for torch.Tensor fields via the @tensorclass decorator.
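A minimal sketch of a few of these features (composability, shape operations, nesting); the keys, shapes, and values are illustrative:
import torch
from tensordict import TensorDict

# Nested, batched TensorDict: every entry shares the leading batch dimensions
td = TensorDict(
    {
        "obs": torch.randn(4, 3),
        "next": {"obs": torch.randn(4, 3), "reward": torch.zeros(4, 1)},
    },
    batch_size=[4],
)
sub_td = td[:2]                         # shape ops apply to every tensor at once
stacked = torch.stack([td, td], dim=0)  # torch ops extend to collections of tensors
print(stacked.batch_size)               # torch.Size([2, 4])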
Examples
TensorDict excels in various applications, a few of which are detailed below:
Fast Copy on Device
TensorDict makes device transfers both safe and fast: copies are issued asynchronously, and the required synchronizations are handled for you.
from tensordict import TensorDict

# dict_of_tensor is a plain Python dict mapping keys to tensors
# Asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Asynchronous copy back to 'cpu'
td_cpu = td_cuda.to("cpu")
# Synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)
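For context, a small end-to-end sketch of the snippet above, assuming a CUDA device is available; dict_of_tensor is just a plain Python dict of tensors, and the explicit synchronize call is only needed if you want to wait for the non-blocking copies yourself:
import torch
from tensordict import TensorDict

dict_of_tensor = {"weights": torch.randn(1024, 1024), "bias": torch.randn(1024)}
td_cuda = TensorDict(**dict_of_tensor, device="cuda")  # copies are queued asynchronously
torch.cuda.synchronize()                               # wait for the copies to land
print(td_cuda["weights"].device)                       # cuda:0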
Coding an Optimizer
TensorDict makes it straightforward to write optimizers concisely. For instance, the Adam optimizer below works on both single tensors and TensorDict inputs, leveraging fused kernels on CUDA for speed; a brief usage sketch follows the class.
from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the TensorDict to prevent accidental key changes and speed up access
        self.weights = weights.lock_()
        self.t = 0
        # First and second moment estimates, one entry per parameter
        self._mu = weights.data.clone()
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the moving averages of the gradient and its square
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected estimates (out of place, to keep the accumulators intact)
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # Parameter update
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
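A brief usage sketch of the class above; the model, data, and learning rate are illustrative and not part of the library:
import torch
from torch import nn
from tensordict import TensorDict

model = nn.Linear(3, 1)
# Gather the model's parameters into a TensorDict, keyed by parameter name
weights = TensorDict(dict(model.named_parameters()), batch_size=[])
optim = Adam(weights, alpha=1e-2)

loss = model(torch.randn(8, 3)).pow(2).mean()
loss.backward()   # populates .grad on every parameter
optim.step()      # updates all parameters in one vectorized pass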
Training a Model
TensorDict lets you write supervised training loops generically: because the model and the loss read from and write to the same TensorDict, the loop itself does not depend on the task, which makes the code highly adaptable and reusable.
for i, data in enumerate(dataset):
    # the model reads from and writes to the same TensorDict
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
This adaptable framework can facilitate various tasks, such as classification and segmentation.
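For context, a minimal sketch of what model, loss_module, and dataset could look like in such a loop, using tensordict.nn.TensorDictModule to wrap an ordinary module; the keys, shapes, and loss choice are illustrative:
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Wrap a plain nn.Module so it reads and writes TensorDict entries
model = TensorDictModule(nn.Linear(10, 2), in_keys=["input"], out_keys=["logits"])

def loss_module(td: TensorDict) -> torch.Tensor:
    # Read what we need from the TensorDict and return a scalar loss
    return nn.functional.cross_entropy(td["logits"], td["target"])

optimizer = torch.optim.Adam(model.parameters())

# A stand-in "dataset" of TensorDict batches
dataset = [
    TensorDict({"input": torch.randn(32, 10),
                "target": torch.randint(2, (32,))}, batch_size=[32])
    for _ in range(4)
]

for i, data in enumerate(dataset):
    data = model(data)          # writes data["logits"]
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()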
Installation
TensorDict can be easily installed using either pip or conda.
With pip:
pip install tensordict
For the latest nightly features:
pip install tensordict-nightly
With conda:
conda install -c conda-forge tensordict
Citation
For those using TensorDict in academic work, a citation is provided:
@misc{bou2023torchrl,
title={TorchRL: A data-driven decision-making library for PyTorch},
author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
year={2023},
eprint={2306.00577},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Disclaimer
TensorDict is currently in beta: while it aims to remain stable, changes that break backward compatibility may still be introduced.
License
TensorDict is available under the MIT License. For more details, refer to the LICENSE file.