Introduction to EGNN - Pytorch
The EGNN-Pytorch project implements E(n)-Equivariant Graph Neural Networks in PyTorch. The approach relies on simple invariant features and is reported to outperform earlier techniques, such as SE3 Transformer and Lie Conv, in both accuracy and performance. It has reported state-of-the-art results on tasks including dynamical system modeling and molecular activity prediction.
Installation
To install the EGNN-Pytorch library, simply use pip:
$ pip install egnn-pytorch
Basic Usage
EGNN-Pytorch allows users to create neural network layers that can process both graph features and their coordinates. Here's a basic example:
import torch
from egnn_pytorch import EGNN
layer1 = EGNN(dim=512)
layer2 = EGNN(dim=512)
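feats = torch.randn(1, 16, 512)  # node features: (batch, nodes, dim)
coors = torch.randn(1, 16, 3)  # node coordinates: (batch, nodes, 3)
feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors)  # shapes are preserved: (1, 16, 512), (1, 16, 3)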
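To make the term "equivariant" concrete, the following is a minimal sketch of a coordinate update in the spirit of the cited paper, not the library's implementation. It uses only the standard library. The names coordinate_update, weight_fn, and transform are hypothetical, and weight_fn is a fixed stand-in for the learned network. The check confirms that rotating and translating the input, then updating, gives the same result as updating first and transforming afterwards.

```python
import math

def coordinate_update(coords, weight_fn):
    """Move each point along its relative vectors, weighted by an invariant scalar."""
    n = len(coords)
    updated = []
    for i in range(n):
        delta = [0.0, 0.0, 0.0]
        for j in range(n):
            if i == j:
                continue
            rel = [a - b for a, b in zip(coords[i], coords[j])]
            sq_dist = sum(r * r for r in rel)  # unchanged by rotation or translation
            w = weight_fn(sq_dist)             # scalar computed from the invariant only
            delta = [d + w * r for d, r in zip(delta, rel)]
        # Average over the other points (an assumed normalization)
        updated.append([x + d / (n - 1) for x, d in zip(coords[i], delta)])
    return updated

def weight_fn(sq_dist):
    # Hypothetical stand-in for a learned function of invariant inputs
    return 1.0 / (1.0 + sq_dist)

def rotate_z(point, angle):
    """Rotate a 3D point about the z-axis."""
    c, s = math.cos(angle), math.sin(angle)
    x, y, z = point
    return [c * x - s * y, s * x + c * y, z]

def transform(points, angle=0.7, shift=(0.5, -1.0, 2.0)):
    """Apply a rotation followed by a translation to every point."""
    return [[p + t for p, t in zip(rotate_z(pt, angle), shift)] for pt in points]

coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.0, 1.0, 3.0]]

update_then_transform = transform(coordinate_update(coords, weight_fn))
transform_then_update = coordinate_update(transform(coords), weight_fn)

max_diff = max(
    abs(a - b)
    for p, q in zip(update_then_transform, transform_then_update)
    for a, b in zip(p, q)
)
print(max_diff < 1e-9)
```

The property holds because the scalar weights depend only on distances, while the direction of each update comes from relative vectors that rotate together with the input.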
Including Edge Features
EGNN can also accommodate graphs with edge features:
import torch
from egnn_pytorch import EGNN
layer1 = EGNN(dim=512, edge_dim=4)
layer2 = EGNN(dim=512, edge_dim=4)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)  # edge features: (batch, nodes, nodes, edge_dim)
feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges)
Full Network Configuration
To create a full EGNN network, users can specify the number of layers, dimensions, and other parameters. An example of such configuration is shown below:
import torch
from egnn_pytorch import EGNN_Network
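net = EGNN_Network(
    num_tokens=21,  # vocabulary size; feats are integer token ids
    num_positions=1024,  # maximum sequence length; omit for an unordered set
    dim=32,
    depth=3,
    num_nearest_neighbors=8,
    coor_weights_clamp_value=2.  # clamp on coordinate weights, for stability with more neighbors
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
feats_out, coors_out = net(feats, coors, mask=mask)  # (1, 1024, 32), (1, 1024, 3)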
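The num_nearest_neighbors setting limits how many neighbors each node exchanges messages with. The sketch below illustrates the idea with the standard library only. The function nearest_neighbors is hypothetical, and details such as excluding the node itself and breaking ties by index are assumptions rather than documented library behavior.

```python
def nearest_neighbors(coords, k):
    """For each point, return the indices of its k closest other points."""
    result = []
    for i, p in enumerate(coords):
        dists = []
        for j, q in enumerate(coords):
            if i == j:
                continue  # assumption: a node is not its own neighbor
            sq_dist = sum((a - b) ** 2 for a, b in zip(p, q))
            dists.append((sq_dist, j))
        dists.sort()  # closest first; ties fall back to the lower index
        result.append([j for _, j in dists[:k]])
    return result

# Four points on a line at x = 0, 1, 3, 7
coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0], [7.0, 0.0, 0.0]]
neighbors = nearest_neighbors(coords, k=2)
print(neighbors)  # [[1, 2], [0, 2], [1, 0], [2, 1]]
```

Capping the neighbor count keeps the cost per node fixed at k messages instead of growing with the total number of nodes.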
Sparse Neighbors and Adjacency Matrix
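EGNN can restrict message passing to the neighbors given by an adjacency matrix. Set only_sparse_neighbors=True and pass the matrix as adj_mat. This is useful when connectivity is known in advance, as in a sequence where each element is linked to the next: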
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
    num_tokens=21,
    dim=32,
    depth=3,
    only_sparse_neighbors=True
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
# Create a naive adjacency matrix for a chain: node i is linked to i - 1 and i + 1 (the diagonal is also True); shape (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, mask=mask, adj_mat=adj_mat)
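The comparison used for adj_mat above is equivalent to the condition |i - j| <= 1. The sketch below rebuilds the same pattern for five nodes with plain Python lists so the result can be inspected directly. The helper chain_adjacency is hypothetical and exists only for illustration.

```python
def chain_adjacency(n):
    """Boolean n x n matrix whose entry [i][j] is True when |i - j| <= 1."""
    return [[(i >= j - 1) and (i <= j + 1) for j in range(n)] for i in range(n)]

adj = chain_adjacency(5)
for row in adj:
    print("".join("1" if connected else "." for connected in row))

# Number of neighbors per node, not counting the node itself
neighbor_counts = [sum(row) - 1 for row in adj]
print(neighbor_counts)  # [1, 2, 2, 2, 1]
```

The matrix is symmetric and tridiagonal: interior nodes have two neighbors and the two chain ends have one. The diagonal is True as well, so each node is marked as adjacent to itself.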
Continuous Edges
Continuous edge features can also be incorporated, enhancing the model's capacity to handle complex graph data:
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
    num_tokens=21,
    dim=32,
    depth=3,
    edge_dim=4,
    num_nearest_neighbors=3
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
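continuous_edges = torch.randn(1, 1024, 1024, 4)
# Naive chain adjacency matrix, shape (1024, 1024), needed for the adj_mat argument
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, edges=continuous_edges, mask=mask, adj_mat=adj_mat)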
Stability Improvements
To combat instability issues that arise with numerous neighbors, the library provides solutions like coordinate normalization and clamping of coordinate weights:
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
    num_tokens=21,
    dim=32,
    depth=3,
    num_nearest_neighbors=32,
    norm_coors=True,  # normalize the relative coordinates
    coor_weights_clamp_value=2.  # clamp the coordinate weights to an absolute value of 2
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
feats_out, coors_out = net(feats, coors, mask=mask)
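A rough sketch of why these two options bound the size of a coordinate update is shown below, using only the standard library. It assumes that each neighbor contributes its relative coordinate vector multiplied by a predicted scalar weight. The function names are hypothetical, and the library's actual normalization may differ in detail, for example by including a learned scale.

```python
import math

def clamp(value, limit):
    """Restrict a scalar to the interval [-limit, limit]."""
    return max(-limit, min(limit, value))

def normalize(vec, eps=1e-8):
    """Scale a vector to unit length, guarding against division by zero."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]

def length(vec):
    """Euclidean length of a vector."""
    return math.sqrt(sum(v * v for v in vec))

def neighbor_contribution(rel_coors, weight, clamp_value=None, norm_coors=False):
    """Contribution of one neighbor to a coordinate update: weight * rel_coors."""
    if norm_coors:
        rel_coors = normalize(rel_coors)
    if clamp_value is not None:
        weight = clamp(weight, clamp_value)
    return [weight * r for r in rel_coors]

rel = [30.0, 40.0, 0.0]  # a distant neighbor, 50 units away
raw_weight = 10.0        # an unusually large predicted weight

unstabilized = neighbor_contribution(rel, raw_weight)
stabilized = neighbor_contribution(rel, raw_weight, clamp_value=2.0, norm_coors=True)

print(length(unstabilized))  # about 500
print(length(stabilized))    # about 2
```

With both options enabled, a single neighbor can move a node by at most the clamp value, so the total update grows at most linearly with the number of neighbors rather than with their distances and raw weights.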
Detailed Parameter Configuration
The library allows users to fine-tune a variety of parameters within the EGNN model:
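import torch
from egnn_pytorch import EGNN
dim = 512  # example value; use the feature dimension of your data
model = EGNN(
    dim=dim,  # input feature dimension
    edge_dim=0,  # edge feature dimension; set above 0 when edge features are passed
    m_dim=16,  # hidden dimension of the messages
    fourier_features=0,  # number of Fourier features for encoding relative distance
    num_nearest_neighbors=0,  # cap on the neighbors used for message passing, chosen by distance
    dropout=0.0,  # dropout rate
    norm_feats=False,  # whether to apply layer normalization to the features
    norm_coors=False,  # whether to normalize the coordinates
    update_feats=True,  # whether to update the features
    update_coors=True,  # whether to update the coordinates
    only_sparse_neighbors=False,  # if True, pass messages only along the given adjacency matrix
    valid_radius=float('inf'),  # radius within which a node considers neighbors
    m_pool_method='sum',  # pool messages with 'sum' or 'mean'
    soft_edges=False,  # extra gating on the edges, intended to help stability
    coor_weights_clamp_value=None  # clamp value for the coordinate weights, for stability
)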
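As an illustration of the m_pool_method option, the sketch below contrasts sum and mean pooling of the messages arriving at one node. It assumes the option selects how incoming messages are aggregated. The function pool_messages is hypothetical and is not part of the library.

```python
def pool_messages(messages, method="sum"):
    """Aggregate a list of equal-length message vectors into one vector."""
    if method not in ("sum", "mean"):
        raise ValueError("method must be 'sum' or 'mean'")
    dim = len(messages[0])
    totals = [sum(m[d] for m in messages) for d in range(dim)]
    if method == "mean":
        return [t / len(messages) for t in totals]
    return totals

# Three incoming messages for a single node, each of dimension 2
messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

summed = pool_messages(messages, "sum")
averaged = pool_messages(messages, "mean")
print(summed)    # [9.0, 12.0]
print(averaged)  # [3.0, 4.0]
```

Sum pooling grows with the number of neighbors, while mean pooling keeps the aggregate on the same scale as a single message, which can matter when the neighbor count varies between nodes.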
Examples
To see EGNN in action, users can run the protein backbone denoising example. It depends on sidechainnet, which must be installed first:
$ pip install sidechainnet
$ python denoise_sparse.py
Testing
To run the tests for EGNN-Pytorch, first make sure PyTorch Geometric is installed locally:
$ python setup.py test
Citations
For those looking to cite this work, here is the relevant citation in Bibtex format:
@misc{satorras2021en,
title = {E(n) Equivariant Graph Neural Networks},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
EGNN-Pytorch gives researchers and developers both single equivariant layers and a full network, with options for edge features, sparse neighbors, and stabilized coordinate updates.