To learn the basics of custom data loaders, we implement an MNIST data loader in this article. We do this for a classification task, where each image has a corresponding label. Our job is to provide the code that loads and preprocesses the data. This logic can be implemented independently of any model, which gives us reusability and abstraction.
Any data loader we implement must inherit from the torch.utils.data.Dataset class, which PyTorch provides for this very purpose. We have to override the __len__ and __getitem__ methods, which PyTorch data loaders use internally to fetch data at runtime.
Any custom dataset implementation must follow this basic structure.
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self) -> None:
        super(CustomDataset, self).__init__()

    def __len__(self):
        pass

    def __getitem__(self, index):
        pass
In the __init__ function, we load the files and metadata needed to fetch a specific instance later. For example, we may have a CSV file containing image paths and their corresponding labels. We load only the CSV file, as it is sufficient for loading each image and its label at runtime; the actual data loading happens in the __getitem__ function. This saves memory, because we never have to hold all of the data in CPU or GPU memory at once.
The __len__ function returns a single integer: the total number of data instances. We can return the total number of images, or the number of rows in a pandas data frame. The output of __len__ acts as the upper bound for the index parameter passed to __getitem__, so we never trigger an IndexError at runtime.
The __getitem__ function can return any number of values, as required by the task. For the classification task, we return both the image and the label; other tasks may need only the image, or more than two values. The loading of a specific data instance happens in __getitem__: we are given an index, and we provide the code that loads the values at that index from the data files.
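As a concrete sketch of this structure, here is a hypothetical CSV-backed dataset (the CSVImageDataset name, file contents, and column names are illustrative, not from the article): the constructor reads only the lightweight index of paths and labels, and the per-instance work is deferred to __getitem__.

```python
import csv
import tempfile

from torch.utils.data import Dataset


class CSVImageDataset(Dataset):
    """Hypothetical dataset indexed by a CSV of (path, label) rows."""

    def __init__(self, csv_path: str) -> None:
        super().__init__()
        # Load only the lightweight index; the heavy per-image work
        # is deferred to __getitem__.
        with open(csv_path, newline="") as f:
            self.samples = [(row["path"], int(row["label"]))
                            for row in csv.DictReader(f)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        # In a real loader we would open and transform the image here,
        # e.g. img = Image.open(path); this sketch returns the path instead.
        return path, label


# Build a tiny throwaway CSV to exercise the class.
with tempfile.NamedTemporaryFile("w", suffix=".csv",
                                 delete=False, newline="") as f:
    f.write("path,label\nimg_0.png,7\nimg_1.png,3\n")
    csv_file = f.name

dataset = CSVImageDataset(csv_file)
print(len(dataset), dataset[1])  # 2 ('img_1.png', 3)
```

Note that only the two-row index sits in memory; each image would be read from disk at access time.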
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from torchvision.datasets import MNIST
torchvision.datasets is a module of the Torchvision library that provides helper classes for downloading common datasets. We use the MNIST class, which gives us access to the MNIST dataset.
class MNISTDataset(Dataset):
    def __init__(self, train: bool = True, output_size=(227, 227)) -> None:
        super(MNISTDataset, self).__init__()
        self.mnist = MNIST(
            root="data",
            train=train,
            download=True,
            transform=Compose(
                [Resize(output_size), ToTensor(), Normalize(mean=(0.5,), std=(0.5,))]
            ),
        )
During initialization, the MNIST class downloads the data to the root directory we pass as a parameter. We can also define custom transforms that preprocess the data: here, we resize each image to the required size, convert it to a PyTorch tensor, and normalize it. If we were using locally stored data instead, we would do this preprocessing in the __getitem__ function.
    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, index):
        img, label = self.mnist[index]
        return img, label
PyTorch data loaders create batches of data and provide an iterator over them. A DataLoader requires a torch.utils.data.Dataset instance as a parameter; since our class inherits from Dataset, we can use it directly to create a data loader.
BATCH_SIZE = 64  # example value; tune to your hardware

dataset = MNISTDataset(train=True)
mnist_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
This gives us a clean interface for working with various datasets: each dataset class is implemented independently and can be reused with the PyTorch data loaders and any deep learning architecture.
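To see the batching behavior without downloading MNIST, we can swap in a tiny stand-in dataset (SquaresDataset is a made-up name for illustration, not part of the article's code):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Toy dataset of (x, x**2) pairs; stands in for MNISTDataset."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        x = torch.tensor(float(index))
        return x, x * x


loader = DataLoader(SquaresDataset(10), batch_size=4, shuffle=False)

# The default collate function stacks individual samples into batch tensors.
batches = [(xs.tolist(), ys.tolist()) for xs, ys in loader]
print(batches[0])    # ([0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 4.0, 9.0])
print(len(batches))  # 3 batches: sizes 4, 4, 2
```

The ten samples come out as three batches (the last one is partial), each a pair of stacked tensors, exactly the shape a training loop consumes.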
If we visualize our data, we get the following results.