Introduction to TorchTyping
TorchTyping is a Python library designed to bring clarity and precision to code involving PyTorch tensors. It aids developers in specifying the shapes, data types (dtype), and naming dimensions of tensors through type annotations, thereby enhancing code readability and reducing bugs.
Motivation Behind TorchTyping
When working with tensors, it's common to note the expected shape or dtype in a comment, for instance # x has shape (batch, x_channels), next to the tensor operation. However, such comments are never checked, so nothing enforces them. TorchTyping addresses this by letting developers describe tensor characteristics directly in function signatures, which serves as documentation and, optionally, as a runtime check on tensor properties.
Key Features
- Shape Annotations: Define the size and number of dimensions of a tensor.
- Data Type Annotations: Specify expected data types, such as floats or integers.
- Layout Specifications: Detail sparse or dense tensor layouts.
- Named Dimensions: Utilize PyTorch's named tensors feature for clearer dimension names.
- Support for Arbitrary Batch Dimensions: Use an ellipsis (...) to indicate any number of batch dimensions.
- Extensible Adjustments: TorchTyping can be customized to include additional checks or specifications pertinent to user needs.
Installation
To install TorchTyping, use pip:
pip install torchtyping
Ensure you have Python >=3.7 and PyTorch >=1.7.0. If you use typeguard, it must be a version <3.0.0 for compatibility reasons.
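The typeguard constraint can be verified from Python before relying on runtime checks. The following is a minimal sketch, assuming Python 3.8 or later (where importlib.metadata is part of the standard library); the helper names major_version and typeguard_is_compatible are hypothetical and not part of TorchTyping or typeguard.

```python
from importlib import metadata
from typing import Optional


def major_version(version_string: str) -> int:
    """Return the leading (major) component of a version string such as '2.13.3'."""
    return int(version_string.split(".")[0])


def typeguard_is_compatible() -> Optional[bool]:
    """Hypothetical helper: True if the installed typeguard is older than 3.0.0,
    False if it is 3.0.0 or newer, None if typeguard is not installed."""
    try:
        installed = metadata.version("typeguard")
    except metadata.PackageNotFoundError:
        return None
    return major_version(installed) < 3


# Reports True, False, or None depending on the local environment.
print(typeguard_is_compatible())
```

A result of None only means typeguard is absent, which is fine when the annotations are used for documentation alone.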
How to Use TorchTyping
TorchTyping integrates with PyTorch and, optionally, with Typeguard for runtime checks. Below is an example of how TorchTyping enhances a typical tensor function:
Without TorchTyping:
import torch
def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x: (batch, x_channels), y: (batch, y_channels) -> (batch, x_channels, y_channels)
    return x.unsqueeze(-1) * y.unsqueeze(-2)
With TorchTyping:
from torchtyping import TensorType

def batch_outer_product(x: TensorType["batch", "x_channels"],
                        y: TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)
With TorchTyping, shape information that would otherwise live in comments becomes part of the function signature, which makes mismatched shapes easier to spot and the code easier to maintain.
Runtime Checks with Typeguard
Using Typeguard, developers can enforce tensor shape and dtype checks at runtime. With Typeguard installed, call patch_typeguard() before applying @typechecked:
from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # call once, before @typechecked is applied

@typechecked
def func(x: TensorType["batch"], y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # Correct usage: "batch" is 3 for both arguments
func(rand(3), rand(1))  # Raises an error: "batch" cannot be both 3 and 1
If not using Typeguard, TorchTyping annotations serve primarily for documentation, leaving performance unaffected during execution.
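To make the idea of a named-dimension check concrete, here is a minimal pure-Python sketch of the mechanism. It is not TorchTyping's implementation: the decorator check_named_dims and the fake_tensor stand-in are hypothetical names introduced for illustration, and arguments are assumed to expose nothing more than a .shape attribute.

```python
import functools
import inspect
from types import SimpleNamespace


def check_named_dims(**specs):
    """Hypothetical decorator. `specs` maps an argument name to a tuple of
    dimension names; each name must have one consistent size per call."""
    def decorator(fn):
        signature = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            sizes = {}  # dimension name -> size seen so far in this call
            for arg_name, dim_names in specs.items():
                shape = tuple(bound.arguments[arg_name].shape)
                if len(shape) != len(dim_names):
                    raise TypeError(
                        f"{arg_name}: expected {len(dim_names)} dimensions, got {len(shape)}"
                    )
                for dim_name, size in zip(dim_names, shape):
                    seen = sizes.setdefault(dim_name, size)
                    if seen != size:
                        raise TypeError(
                            f"Dimension '{dim_name}' has inconsistent sizes {seen} and {size}"
                        )
            return fn(*args, **kwargs)
        return wrapper
    return decorator


def fake_tensor(*shape):
    """Stand-in for a tensor: an object that only carries a shape."""
    return SimpleNamespace(shape=shape)


@check_named_dims(x=("batch",), y=("batch",))
def batch_size(x, y):
    return x.shape[0]


print(batch_size(fake_tensor(3), fake_tensor(3)))  # prints 3
try:
    batch_size(fake_tensor(3), fake_tensor(1))
except TypeError as error:
    print(error)  # the two "batch" sizes disagree
```

The check runs only when the decorated function is called, which is why undecorated functions pay nothing for carrying annotations.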
TorchTyping API
The main construct is TensorType, which allows specifying:
- Shapes: Integers, strings, ellipses (...), or any combination of these.
- Data Types: Standard PyTorch dtypes or Python's built-in data types.
- Layouts: Either torch.strided or torch.sparse_coo.
- Additional Details: Extra specifications that customize the checks, such as torchtyping.is_float.
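How integers, strings, and ellipses combine in a single shape specification can be modeled in a few lines of plain Python. This is an illustrative sketch under stated assumptions, not TorchTyping's code: the helper matches is hypothetical, shapes are plain tuples, and at most one ellipsis per specification is assumed.

```python
def matches(spec, shape):
    """Hypothetical helper: return True if `shape` satisfies `spec`.

    Each spec entry is an int (exact size), a str (a named size that must be
    the same wherever the name repeats), or Ellipsis (any number of dimensions).
    """
    spec, shape = list(spec), list(shape)
    if Ellipsis in spec:
        position = spec.index(Ellipsis)
        tail_length = len(spec) - 1 - position
        if len(shape) < len(spec) - 1:
            return False  # not enough dimensions for the fixed entries
        # Align entries before the ellipsis from the left, the rest from the right.
        pairs = list(zip(spec[:position], shape[:position]))
        pairs += list(zip(spec[position + 1:], shape[len(shape) - tail_length:]))
    else:
        if len(spec) != len(shape):
            return False
        pairs = list(zip(spec, shape))

    named_sizes = {}
    for entry, size in pairs:
        if isinstance(entry, str):
            if named_sizes.setdefault(entry, size) != size:
                return False  # same name bound to two different sizes
        elif entry != size:
            return False  # integer entry requires an exact size
    return True


print(matches(("batch", 3), (8, 3)))          # True
print(matches((..., "channels"), (2, 5, 7)))  # True
print(matches(("n", "n"), (4, 5)))            # False
```

Aligning the trailing entries from the right is what lets a specification such as (..., "channels") accept tensors with any number of leading batch dimensions.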
Conclusion
TorchTyping makes a PyTorch codebase more robust, clear, and concise by using type annotations to document and enforce tensor specifications. It's a valuable tool for anyone who works with tensors regularly: it is compatible with testing frameworks like pytest, and its runtime checks help catch shape and dtype bugs early. For new projects, the creator suggests considering jaxtyping, a newer tool that offers improved support for static type checkers.