Data is an integral part of machine learning, but not all data sources are publicly available. For custom datasets, we need data loaders that are implemented separately from the training code.

Photo by Lukas from Pexels

Introduction

To learn the basics of custom data loaders, we implement an MNIST data loader in this article. We do this for a classification task, where each image has a corresponding label. We need to provide the code that loads and preprocesses the data. This code can be implemented independently of any model, which allows reuse and abstraction.

Dataset Abstract Class

Any custom dataset we implement should inherit from the torch.utils.data.Dataset class, which PyTorch provides for this very purpose. We override the __len__ and __getitem__ methods, which PyTorch's DataLoader uses internally to fetch data at runtime.

Any custom dataset implementation must follow this basic structure.

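```python
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self) -> None:
        super(CustomDataset, self).__init__()
        # Load lightweight metadata here (e.g. file paths and labels).

    def __len__(self):
        # Return the total number of data instances.
        pass

    def __getitem__(self, index):
        # Load and return the data instance at `index`.
        pass
```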

In the __init__ function, we load the files and metadata needed to locate each data instance. For example, we may have a CSV file containing image paths and their corresponding labels. We load only the CSV file, since it is sufficient for loading each image and its label at runtime. The images themselves are loaded on demand in the __getitem__ function, which saves memory because the whole dataset never has to be held in CPU or GPU memory at once.

The __len__ function returns a single integer: the total number of data instances, for example the number of images or the number of rows in a pandas data frame. The DataLoader's sampler draws indices from range(len(dataset)), so the output of __len__ acts as the upper bound for the index passed to __getitem__, and we do not run into an IndexError at runtime.
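To illustrate how `__len__` bounds the indices, here is a small sketch using a hypothetical toy `RangeDataset`. A `DataLoader` uses a `SequentialSampler` when `shuffle=False` and a `RandomSampler` when `shuffle=True`; both only produce indices in `range(len(dataset))`.

```python
from torch.utils.data import Dataset, RandomSampler, SequentialSampler


class RangeDataset(Dataset):
    """Toy dataset whose i-th item is simply i (illustrative only)."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int) -> int:
        # Python's sequence protocol signals a bad index with IndexError.
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index


ds = RangeDataset(10)
# Both samplers draw from range(len(ds)), so __getitem__ never sees an out-of-range index.
seq_indices = list(SequentialSampler(ds))
rand_indices = list(RandomSampler(ds))
print(seq_indices)                         # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(sorted(rand_indices) == seq_indices)  # True
```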

The __getitem__ function can return as many values as the task requires. For classification, we return both the image and its label; other tasks may need only the image, or more than two values. The __getitem__ function is where a specific data instance is loaded: we are given an index, and we write the code that loads the values at that index from the data files.
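When `__getitem__` returns a tuple, the DataLoader's default collation batches each position of the tuple separately: tensors are stacked along a new first dimension and Python integers become a tensor. The sketch below uses a hypothetical `PairDataset` filled with random stand-in images.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical classification dataset returning (image, label) pairs."""

    def __init__(self, num_items: int = 6) -> None:
        super().__init__()
        self.images = torch.rand(num_items, 1, 28, 28)  # random stand-in images
        self.labels = [i % 10 for i in range(num_items)]

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int):
        # Two values for classification; other tasks could return one or more.
        return self.images[index], self.labels[index]


loader = DataLoader(PairDataset(), batch_size=4, shuffle=False)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([4, 1, 28, 28])
print(labels)        # tensor([0, 1, 2, 3])
```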

Implementation: MNIST Custom Dataset

Imports

```python
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from torchvision.datasets import MNIST
```

torchvision.datasets is a module of torchvision, PyTorch's companion computer vision library, that provides helper classes for downloading common datasets. We use its MNIST class to access the MNIST dataset.

Initialization

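```python
class MNISTDataset(Dataset):
    def __init__(self, train: bool = True, output_size=(227, 227)) -> None:
        super(MNISTDataset, self).__init__()
        # Download MNIST to ./data (if needed) and attach the preprocessing pipeline.
        self.mnist = MNIST(
            root="data",
            train=train,
            download=True,
            transform=Compose([
                Resize(output_size),                 # resize the PIL image
                ToTensor(),                          # convert to a float tensor in [0, 1]
                Normalize(mean=(0.5,), std=(0.5,)),  # rescale to [-1, 1]
            ]),
        )
```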

During initialization, the MNIST class downloads the data to the root directory we pass as a parameter. We can also define custom transforms that preprocess the data; the MNIST class applies them to each image when it is fetched. If we were using locally stored data without such a helper class, we would apply the transforms ourselves in the __getitem__ function. Here, we resize each image to the required size, convert it to a PyTorch tensor, and normalize it.

Loading Data

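```python
    def __len__(self):
        # Number of samples in the underlying MNIST split.
        return len(self.mnist)

    def __getitem__(self, index):
        # torchvision's MNIST returns an already-transformed (image, label) pair.
        img, label = self.mnist[index]
        return img, label
```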

Dataloader

PyTorch's DataLoader groups samples into batches and provides an iterator over them. It takes a torch.utils.data.Dataset instance as its argument. Since our class inherits from Dataset, we can pass an instance of it directly to a DataLoader.

```python
BATCH_SIZE = 64  # example value; choose one that suits your model and hardware

dataset = MNISTDataset(train=True)
mnist_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
```
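A sketch of how such a DataLoader is typically consumed in a training loop. So that it runs without downloading MNIST, it uses a `TensorDataset` of 1,000 random 1x28x28 stand-in images instead of `MNISTDataset`; `BATCH_SIZE = 64` is again only an example value. Without `drop_last=True`, the number of batches is the ceiling of `len(dataset) / BATCH_SIZE`, and the last batch holds the remainder.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 64  # example value

# Stand-in for MNISTDataset: random images with integer labels in [0, 10).
dataset = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print(len(loader), math.ceil(len(dataset) / BATCH_SIZE))  # 16 16

batch_sizes = []
for images, labels in loader:
    # A model's forward and backward pass would go here.
    assert images.shape[1:] == (1, 28, 28)
    batch_sizes.append(labels.shape[0])

print(batch_sizes[0], batch_sizes[-1])  # 64 40
```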

Conclusion

This provides a clean interface for working with different datasets. Each dataset class is implemented independently and can be reused with PyTorch's DataLoader and any deep learning architecture.

Visualizing a batch of our data gives the following result.

Image by Author