TFRecord Reader and Writer
The TFRecord Reader and Writer library efficiently handles TFRecord files in Python. It is especially useful for those working with PyTorch, as it includes a convenient IterableDataset reader for TFRecord files. The library supports both uncompressed and gzip-compressed TFRecords.
Installation
To start using the library, you can easily install it using pip with the following command:
pip3 install 'tfrecord[torch]'
Usage
Creating Index Files
When working with multiple workers, it is recommended to create an index file for each TFRecord file to avoid duplicate records. To generate an index file for a single TFRecord, you can use this command:
python3 -m tfrecord.tools.tfrecord2idx <tfrecord path> <index path>
To create index files for all TFRecord files in a directory, use:
tfrecord2idx <data dir>
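If you prefer to create indexes from Python, the same functionality can be invoked programmatically. The following is a minimal sketch, assuming the create_index helper that backs the command-line tool; the directory path is a placeholder:
import os
from tfrecord.tools.tfrecord2idx import create_index  # assumed helper behind the CLI
data_dir = "/tmp/records"  # placeholder directory of .tfrecord files
for name in os.listdir(data_dir):
    if name.endswith(".tfrecord"):
        path = os.path.join(data_dir, name)
        create_index(path, path[:-len(".tfrecord")] + ".index")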
Reading & Writing tf.train.Example
Reading TFRecord Files in PyTorch
To read TFRecord files in PyTorch, you can use the provided TFRecordDataset class:
import torch
from tfrecord.torch.dataset import TFRecordDataset
tfrecord_path = "/tmp/data.tfrecord"
index_path = None
description = {"image": "byte", "label": "float"}
dataset = TFRecordDataset(tfrecord_path, index_path, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
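As noted above, pass an index file when loading with multiple workers so that the records are sharded across workers rather than duplicated. A sketch continuing the previous snippet, assuming the index was generated with tfrecord2idx as shown earlier:
dataset = TFRecordDataset("/tmp/data.tfrecord", "/tmp/data.index", description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)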
For reading multiple TFRecord files, use MultiTFRecordDataset, which samples from the given TFRecords with the specified probabilities. The keys of the splits dictionary fill the {} placeholder in the file patterns, and the values give each file's sampling probability:
import torch
from tfrecord.torch.dataset import MultiTFRecordDataset
tfrecord_pattern = "/tmp/{}.tfrecord"
index_pattern = "/tmp/{}.index"
splits = {
    "dataset1": 0.8,
    "dataset2": 0.2,
}
description = {"image": "byte", "label": "int"}
dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Infinite and Finite PyTorch Dataset
By default, MultiTFRecordDataset will sample data indefinitely. To make it finite, set the infinite flag to False:
dataset = MultiTFRecordDataset(..., infinite=False)
Shuffling the Data
You can shuffle the data by setting a queue size; records are buffered into a queue of that length and drawn from it at random, so larger queues give a more thorough shuffle at the cost of memory:
dataset = TFRecordDataset(..., shuffle_queue_size=1024)
Transforming Input Data
You can apply transformations to the input data by passing a function to the transform argument:
import tfrecord
import cv2
def decode_image(features):
    # decode the raw image bytes into an image array
    features["image"] = cv2.imdecode(features["image"], -1)
    return features
description = {
    "image": "byte",
}
dataset = tfrecord.torch.TFRecordDataset("/tmp/data.tfrecord",
                                         index_path=None,
                                         description=description,
                                         transform=decode_image)
data = next(iter(dataset))
print(data)
Writing TFRecord Files
To write tf.train.Example records in Python, you can use the TFRecordWriter; each feature is given as a (value, typename) tuple, where the typename is one of "byte", "float", or "int":
import tfrecord
writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({"image": (image_bytes, "byte"), "label": (label, "float"), "index": (index, "int")})
writer.close()
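If the file will later be read with multiple workers, index it with the tool shown earlier:
python3 -m tfrecord.tools.tfrecord2idx /tmp/data.tfrecord /tmp/data.index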
Reading TFRecord Files
TFRecord files can be read using the tfrecord_loader:
import tfrecord
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None, {"image": "byte", "label": "float", "index": "int"})
for record in loader:
    print(record["label"])
Reading & Writing tf.train.SequenceExample
Similar methods can be used for reading and writing tf.train.SequenceExample data. Additional arguments are required: sequence_description for reading and sequence_datum for writing. When writing, the context features are passed as the first argument and the sequence features as the second.
Writing SequenceExamples to File
import tfrecord
writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({'length': (3, 'int'), 'label': (1, 'int')}, {'tokens': ([[0, 0, 1], [0, 1, 0], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1, 1], 'int')})
writer.write({'length': (3, 'int'), 'label': (1, 'int')}, {'tokens': ([[0, 0, 1], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1], 'int')})
writer.close()
Reading SequenceExamples in Python
Reading SequenceExamples produces a tuple with two elements: the context features and the sequence features:
import tfrecord
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None, context_description, sequence_description=sequence_description)
for context, sequence_feats in loader:
    print(context["label"])
    print(sequence_feats["seq_labels"])
Read SequenceExamples in PyTorch
When using PyTorch, sequences often require padding due to their variable length. You can transform the feature data using the transform function before batching it:
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset
PAD_WIDTH = 5
def pad_sequence_feats(data):
    context, features = data
    for k, v in features.items():
        # pad each sequence feature up to PAD_WIDTH steps
        features[k] = np.pad(v, ((0, PAD_WIDTH - len(v)), (0, 0)), 'constant')
    return (context, features)
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord", index_path=None, description=context_description, transform=pad_sequence_feats, sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Alternatively, a custom collate_fn can be used for dynamic padding, in which case no transform is needed:
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset
def collate_fn(batch):
    from torch.utils.data._utils import collate
    from torch.nn.utils import rnn
    # separate the context dicts from the sequence-feature dicts
    context, feats = zip(*batch)
    feats_ = {k: [torch.Tensor(d[k]) for d in feats] for k in feats[0]}
    # batch the contexts normally; pad each sequence feature to the longest in the batch
    return (collate.default_collate(context),
            {k: rnn.pad_sequence(f, True) for (k, f) in feats_.items()})
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord", index_path=None, description=context_description, sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
data = next(iter(loader))
print(data)
The library thus provides a robust and flexible way to work with TFRecord files in Python, with particular attention to PyTorch integration.