Introducing TorchGeo: Bringing Machine Learning and Geospatial Data Together
TorchGeo is a PyTorch domain library, similar to torchvision, that provides datasets, samplers, transforms, and pre-trained models specific to geospatial data. It has two goals: to make it simple for machine learning experts to work with geospatial data, and for remote sensing experts to explore machine learning solutions.
Community and Collaboration
The TorchGeo community is active and welcoming, with discussion channels on Slack and Hugging Face. The project is also affiliated with OSGeo and is part of the PyTorch ecosystem, which widens the circle of potential collaborators.
Easy Installation
TorchGeo is simple to install; the recommended method is pip:
$ pip install torchgeo
Detailed instructions for installing with conda and Spack are available in the TorchGeo documentation.
Comprehensive Documentation
TorchGeo's documentation covers everything from the API reference to getting-started tutorials. A paper, a podcast episode, a tutorial, and a blog post also describe how to use TorchGeo for geospatial deep learning.
Works Seamlessly with Geospatial Datasets
Geospatial datasets pose challenges that typical computer vision datasets do not: each satellite captures its own set of spectral bands at its own spatial resolution, and layers from different sources often use different coordinate reference systems. TorchGeo handles this by letting users combine datasets with simple operators and then sample from the result by location. The following example combines Landsat 7 and Landsat 8 imagery and pairs it with Cropland Data Layer (CDL) labels:
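from torchgeo.datasets import CDL, Landsat7, Landsat8
# Band lists and data directories are elided here; see each dataset's documentation.
# Newer TorchGeo releases name the directory argument paths= instead of root=.
landsat7 = Landsat7(root="...", bands=["B1", ..., "B7"])
landsat8 = Landsat8(root="...", bands=["B2", ..., "B8"])
landsat = landsat7 | landsat8  # union: imagery from either Landsat 7 or Landsat 8
cdl = CDL(root="...", download=True, checksum=True)  # download CDL and verify its checksum
dataset = landsat & cdl  # intersection: only areas covered by both Landsat and CDL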
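The `|` and `&` operators combine datasets by spatial extent rather than by position in a list. As a rough mental model, the following stdlib-only sketch treats each dataset as a list of rectangular footprints in one shared coordinate reference system (an assumption made for simplicity). The `Box` class and both functions are hypothetical illustrations, not TorchGeo APIs; TorchGeo's real index also tracks time and is considerably more involved.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Box:
    """Hypothetical footprint (minx, maxx, miny, maxy) in one shared CRS."""

    minx: float
    maxx: float
    miny: float
    maxy: float

    def intersection(self, other: Box) -> Optional[Box]:
        # Two footprints overlap only if they overlap along both axes.
        minx, maxx = max(self.minx, other.minx), min(self.maxx, other.maxx)
        miny, maxy = max(self.miny, other.miny), min(self.maxy, other.maxy)
        if minx >= maxx or miny >= maxy:
            return None
        return Box(minx, maxx, miny, maxy)


def union_index(a: list[Box], b: list[Box]) -> list[Box]:
    # Like "|": a region is available if either dataset covers it.
    return a + b


def intersection_index(a: list[Box], b: list[Box]) -> list[Box]:
    # Like "&": keep only regions that both datasets cover.
    overlaps = []
    for box_a in a:
        for box_b in b:
            overlap = box_a.intersection(box_b)
            if overlap is not None:
                overlaps.append(overlap)
    return overlaps


imagery = [Box(0, 10, 0, 10), Box(20, 30, 0, 10)]  # e.g. two image tiles
labels = [Box(5, 25, 5, 15)]  # e.g. one label raster
print(intersection_index(imagery, labels))
# [Box(minx=5, maxx=10, miny=5, maxy=10), Box(minx=20, maxx=25, miny=5, maxy=10)]
```

In this picture, pairing imagery with labels via `&` guarantees that every sampled region has both an image and a target.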
Efficient Data Loading and Sampling
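Instead of integer indices, TorchGeo's geospatial datasets are indexed by spatiotemporal bounding boxes. Samplers such as RandomGeoSampler generate these bounding boxes, so a standard PyTorch DataLoader can draw small patches from large scenes without first cutting the scenes into tiles. In the example below, each sample is a 256×256-pixel patch, one epoch contains 10,000 patches, and stack_samples collates the per-sample dictionaries into batches.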
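from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import RandomGeoSampler
sampler = RandomGeoSampler(dataset, size=256, length=10000)  # random 256x256-pixel patches
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)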
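Conceptually, a random geospatial sampler picks windows by coordinates rather than by index. The sketch below is a hypothetical, stdlib-only illustration of that idea, not TorchGeo's implementation: `random_windows` converts a patch size in pixels to map units using an assumed resolution, then draws `length` windows that lie entirely inside the dataset's extent.

```python
import random
from typing import Iterator, Tuple

Window = Tuple[float, float, float, float]  # (minx, maxx, miny, maxy) in map units


def random_windows(
    bounds: Window, size_px: int, res: float, length: int, seed: int = 0
) -> Iterator[Window]:
    """Hypothetical sketch: yield `length` square windows inside `bounds`.

    `size_px` is the window edge in pixels and `res` the (assumed) pixel size
    in map units, so each window is size_px * res map units on a side.
    """
    minx, maxx, miny, maxy = bounds
    side = size_px * res
    if side > maxx - minx or side > maxy - miny:
        raise ValueError("window is larger than the dataset extent")
    rng = random.Random(seed)  # seeded so the draws are reproducible
    for _ in range(length):
        # Choose the minimum corner so the whole window stays in bounds.
        x0 = rng.uniform(minx, maxx - side)
        y0 = rng.uniform(miny, maxy - side)
        yield (x0, x0 + side, y0, y0 + side)


windows = list(random_windows((0.0, 30000.0, 0.0, 30000.0), size_px=256, res=30.0, length=3))
print(len(windows))  # 3
```

With an assumed 30 m pixel size, each 256-pixel window is 7,680 m on a side; because windows are defined in map units, the same request works for any raster that covers the area.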
Diverse Benchmark Datasets
TorchGeo also includes benchmark datasets for tasks ranging from image classification and semantic segmentation to object detection. Unlike the coordinate-indexed datasets above, these are indexed by integer, and each sample pairs an input image with a target (a label, a mask, or a set of boxes), so they behave much like torchvision datasets.
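To make the contrast with coordinate-based sampling concrete, here is a minimal, hypothetical sketch of the integer-indexed interface that benchmark datasets share with torchvision: `dataset[i]` returns one sample as a dictionary. The class name and toy data are invented for illustration; real TorchGeo datasets return PyTorch tensors, and for classification datasets the keys are typically `image` and `label`.

```python
from typing import Dict, List, Sequence

# A toy multispectral image stored as bands x height x width nested lists.
Image = List[List[List[float]]]


class ToyBenchmarkDataset:
    """Hypothetical stand-in for an integer-indexed benchmark dataset."""

    def __init__(self, images: Sequence[Image], labels: Sequence[int]) -> None:
        if len(images) != len(labels):
            raise ValueError("every image needs exactly one label")
        self.images = list(images)
        self.labels = list(labels)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> Dict[str, object]:
        # One sample = one dict pairing the input image with its target.
        return {"image": self.images[index], "label": self.labels[index]}


dataset = ToyBenchmarkDataset(images=[[[[0.1, 0.2]]], [[[0.3, 0.4]]]], labels=[0, 1])
print(len(dataset), dataset[1]["label"])  # 2 1
```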
Leveraging Pre-trained Weights
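TorchGeo provides model weights pre-trained on imagery from multispectral sensors such as Sentinel-2, rather than only on 3-channel RGB datasets such as ImageNet. Because the input layer of these models accepts all of the bands the weights were trained on, they can be applied directly to multispectral remote sensing data. The example below loads a ResNet-18 pre-trained with MoCo on all Sentinel-2 bands and attaches a new 10-class head: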
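import timm
from torchgeo.models import ResNet18_Weights
# ResNet-18 weights pre-trained with MoCo on all Sentinel-2 bands
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
# Size the input layer to the number of bands the weights expect; add a new 10-class head
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
# strict=False tolerates keys present on only one side, such as the new head
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)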
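The effect of `strict=False` can be pictured with plain dictionaries. In this hypothetical sketch, parameters whose names appear in both the model and the checkpoint are copied over, and the remaining names are reported, mirroring the `missing_keys` and `unexpected_keys` that PyTorch's `load_state_dict` returns. `LoadResult` and `load_non_strict` are invented names; real PyTorch also checks tensor shapes and raises an error on a size mismatch even when `strict=False`, which this sketch omits.

```python
from typing import Any, Dict, List, NamedTuple


class LoadResult(NamedTuple):
    missing_keys: List[str]  # in the model but not in the checkpoint
    unexpected_keys: List[str]  # in the checkpoint but not in the model


def load_non_strict(model: Dict[str, Any], checkpoint: Dict[str, Any]) -> LoadResult:
    """Hypothetical sketch of non-strict loading with plain dicts (no shape checks)."""
    shared = model.keys() & checkpoint.keys()
    for name in shared:
        model[name] = checkpoint[name]  # overwrite with the pre-trained value
    return LoadResult(
        missing_keys=sorted(model.keys() - checkpoint.keys()),
        unexpected_keys=sorted(checkpoint.keys() - model.keys()),
    )


model = {"conv1.weight": "random", "fc.weight": "random", "fc.bias": "random"}
checkpoint = {"conv1.weight": "pretrained"}
print(load_non_strict(model, checkpoint))
# LoadResult(missing_keys=['fc.bias', 'fc.weight'], unexpected_keys=[])
```

Here the backbone parameter receives its pre-trained value while the new classification head keeps its random initialization, ready for fine-tuning.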
Ensuring Reproducibility with Lightning
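For reproducible, directly comparable experiments, TorchGeo integrates with PyTorch Lightning: datamodules wrap benchmark datasets in standard splits and data loaders, and tasks in torchgeo.trainers bundle a model, loss, and metrics for common problems such as classification and semantic segmentation. The following example trains a U-Net with a ResNet-50 backbone on the Inria Aerial Image Labeling dataset: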
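from lightning.pytorch import Trainer
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.trainers import SemanticSegmentationTask
# Wraps the Inria Aerial Image Labeling dataset in Lightning data loaders
datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet50",
    weights=True,  # start from pre-trained backbone weights
    in_channels=3,  # Inria imagery is RGB
    num_classes=2,  # building vs. background
)
trainer = Trainer(default_root_dir="...")  # checkpoints and logs are written here
trainer.fit(model=task, datamodule=datamodule)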
Citation and Contribution
If you use TorchGeo in your research, please cite the accompanying paper, "TorchGeo: Deep Learning With Geospatial Data". Contributions are welcome, and contributors are expected to follow the project's code of conduct.
In summary, TorchGeo bridges modern deep learning and the complexities of geospatial data: it combines multi-sensor datasets, samples them by location, offers multispectral pre-trained weights, and supports reproducible training, which makes it valuable to both machine learning and remote sensing practitioners.