How to create a custom PyTorch Dataset for an image dataset

With the advances in deep learning, working with huge datasets has increasingly become part of the data science pipeline. In computer vision problems especially, visual datasets take a huge amount of memory to process and augment, loading them is slow, and programs often bog down. To handle such datasets we need an efficient way to process them, with better memory and time management.

The PyTorch framework provides a well-defined schema for processing a dataset efficiently so that the GPU stays fully utilized. It formalizes data loading and preprocessing, letting us process data in real time across multiple cores and feed it directly to a deep learning model. This not only makes our code more efficient in execution time and memory, but also more readable and manageable, which is essential for large projects.

In this article, we will learn how to code a custom Dataset class for a computer vision project using PyTorch’s torch.utils.data.Dataset class. Computer vision datasets can be organised in many different ways; we will work with three of them.

  • Images organised in class folders.
  • Image data organised in a CSV file.
  • Images organised in a folder with the class encoded in the image name.

We will be using example cat and dog datasets to code custom datasets.

Images organised in class folders

Before actually jumping into the DataLoader, let’s first take a look at how our folder structure should be organised for this part. As you can see below, the dataset is divided into class folders.

─── train
    ├── cat
    │    ├── cat1.png
    │    ├── cat2.png
    │    ├── ......
    │    └── cat100.png
    └── dog
         ├── dog1.png
         ├── dog2.png
         ├── ......
         └── dog100.png
─── test
    ├── cat
    │    ├── cat1.png
    │    ├── cat2.png
    │    ├── ......
    │    └── cat10.png
    └── dog
         ├── dog1.png
         ├── dog2.png
         ├── ......
         └── dog10.png

PyTorch provides two data primitives, torch.utils.data.DataLoader and torch.utils.data.Dataset, that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples. Let us first write the code for the custom Dataset class and then explore each section:

import os
import torch
from PIL import Image
from torchvision import transforms

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir):
        self.classnames = ["cat", "dog"]   # Class list
        self.root_dir = root_dir           # Train or test folder
        self.img_label_tuple = []          # List of (image_path, label) tuples

        # Loop over class folders to collect image paths and labels
        for i in range(len(self.classnames)):
            class_dir = os.path.join(root_dir, self.classnames[i])
            all_files = [os.path.join(class_dir, f) for f in os.listdir(class_dir)
                         if os.path.isfile(os.path.join(class_dir, f))]
            for j in all_files:
                self.img_label_tuple.append((j, i))

        # Simple transform
        self.transform = transforms.Compose([transforms.Resize((32, 32)),
                                             transforms.ToTensor()])

    def __len__(self):
        # Returns the length of the dataset
        return len(self.img_label_tuple)

    # Loads one sample, transforms it and returns (image, label)
    def __getitem__(self, idx):
        path, class_id = self.img_label_tuple[idx]
        img = Image.open(path).convert('RGB')
        img = self.transform(img)

        return (img, class_id)

In the above code, __init__() is the constructor, which runs whenever an instance of the class is created. In our CustomDataset, we initialise every instance with the class list and the img_label_tuple list, and then extract the dataset information from the given folder; this extracted information is used later to process the image dataset. The custom class inherits from torch.utils.data.Dataset, so we can leverage features provided by this PyTorch class, such as multiprocessing.
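To see the path-collection step in isolation, here is a minimal sketch, using only the standard library and a throwaway temp folder standing in for the cat/dog layout above (the file names are made up), of how the (image_path, label) list gets built:

```python
import os
import tempfile

# Build a tiny throwaway copy of the class-folder layout (no real images).
root = tempfile.mkdtemp()
for cls in ("cat", "dog"):
    os.makedirs(os.path.join(root, cls))
    for i in (1, 2):
        open(os.path.join(root, cls, f"{cls}{i}.png"), "w").close()

# The same scan __init__ performs: one (path, label_index) tuple per file.
classnames = ["cat", "dog"]
img_label_tuple = []
for label, cls in enumerate(classnames):
    class_dir = os.path.join(root, cls)
    for f in sorted(os.listdir(class_dir)):
        full = os.path.join(class_dir, f)
        if os.path.isfile(full):
            img_label_tuple.append((full, label))

print(len(img_label_tuple))                 # 4
print([lbl for _, lbl in img_label_tuple])  # [0, 0, 1, 1]
```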

We also defined two other methods, __len__() and __getitem__(). Here __len__() returns the length of the dataset, whereas __getitem__() is a little more involved: it is where we define the specific pre-processing applied to each image before it is passed to the machine learning model.
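The role of these two methods can be seen with a toy stand-in (no torch, no images): any object that implements __len__ and __getitem__ supports len() and square-bracket indexing, which is exactly the protocol a DataLoader relies on to pull samples.

```python
# Toy stand-in for a Dataset: just the two-method protocol, no images.
class ToyDataset:
    def __init__(self, samples):
        self.samples = samples      # list of (value, label) tuples

    def __len__(self):
        return len(self.samples)    # dataset size

    def __getitem__(self, idx):
        return self.samples[idx]    # one (value, label) sample

toy = ToyDataset([(10, 0), (20, 1), (30, 0)])
print(len(toy))   # 3
print(toy[1])     # (20, 1)
```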

Now that we have created the custom Dataset class for the dataset, we can move ahead with creating an iterable using torch.utils.data.DataLoader. To create the iterable we need our custom Dataset class and a few arguments, described below (the remaining arguments are covered in the PyTorch documentation):

  • batch_size(int): Number of samples in each generated batch.
  • shuffle(bool): If set to True, the data is reshuffled every time batches are generated. Generally True for training.
  • num_workers(int): The number of subprocesses that generate batches in parallel.

bs = 64 	# Batch size

root_dir = "Dataset folder path"
dataset = CustomDataset(root_dir)  # Dataset class initiation 
dataloader = torch.utils.data.DataLoader(dataset, bs, shuffle=True)
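To make the batching concrete, here is a small self-contained sketch that feeds random tensors (stand-ins for the 3x32x32 images our transform produces) through a DataLoader; no image files are needed:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten fake 3x32x32 "images" with random binary labels (illustrative only).
imgs = torch.randn(10, 3, 32, 32)
labels = torch.randint(0, 2, (10,))
loader = DataLoader(TensorDataset(imgs, labels), batch_size=4, shuffle=True)

for batch_imgs, batch_labels in loader:
    print(batch_imgs.shape, batch_labels.shape)
# Batches of 4, 4 and 2: the last batch holds the remainder.
```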

Image data organised in a CSV file

In many Kaggle competitions, we are given a data folder and one CSV file. This file often contains metadata for the dataset; for this section, we assume only two columns are given: 1.) img_path 2.) class_label.
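Before writing the Dataset class, the label mapping such a file implies can be sketched with only the standard csv module (the file contents below are made up for illustration):

```python
import csv
import io

# A hypothetical two-column metadata file like the one assumed above.
csv_text = """img_path,class_label
images/cat1.png,cat
images/dog1.png,dog
images/cat2.png,cat
"""

rows = list(csv.DictReader(io.StringIO(csv_text)))
# Unique class labels in first-seen order, mapped to integer ids by position.
class_list = list(dict.fromkeys(r["class_label"] for r in rows))
labels = [class_list.index(r["class_label"]) for r in rows]
print(class_list)  # ['cat', 'dog']
print(labels)      # [0, 1, 0]
```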

import torch
import pandas as pd
from PIL import Image
from torchvision import transforms

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, csv_path):

        self.dataframe = pd.read_csv(csv_path)   # Data frame from the CSV file
        # .unique() returns a NumPy array; convert to a list so .index() works
        self.class_list = self.dataframe['class_label'].unique().tolist()
        # Simple transform
        self.transform = transforms.Compose([transforms.Resize((32, 32)),
                                             transforms.ToTensor()])

    def __len__(self):
        # Returns the length of the dataset
        return len(self.dataframe)

    # Loads one sample, transforms it and returns (image, label)
    def __getitem__(self, idx):
        path = self.dataframe['img_path'].iloc[idx]
        class_id = self.class_list.index(self.dataframe['class_label'].iloc[idx])
        img = Image.open(path).convert('RGB')
        img = self.transform(img)

        return (img, class_id)

Similar to the first section, torch.utils.data.DataLoader is used to create the dataloader, as shown in the code below:

bs = 64 	# Batch size

csv_path = "Path to CSV file"
dataset = CustomDataset(csv_path)  # Dataset class initialisation
dataloader = torch.utils.data.DataLoader(dataset, bs, shuffle=True)

Images organised in a folder with the class encoded in the image name

In cases where the image name contains its class, we can extract the class label directly from the file name. Below is the folder structure for such a scenario.

─── train
        ├── cat_1.png
        ├── dog_2.png
        ├── cat_3.png
        ├── cat_4.png 
        ├── dog_5.png 
        ├── cat_6.png
        ├── ......
        └── dog_100.png

Similar to the above approach we will be creating a CustomDataset for the above folder structure.

import os
import torch
from PIL import Image
from torchvision import transforms

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir):
        self.classnames = ["cat", "dog"]   # Class list
        self.root_dir = root_dir           # Train or test folder
        self.img_path_list = []            # List of image paths

        # Loop to extract image paths
        all_files = [os.path.join(self.root_dir, f) for f in os.listdir(self.root_dir)
                     if os.path.isfile(os.path.join(self.root_dir, f))]
        for item in all_files:
            self.img_path_list.append(item)

        # Simple transform
        self.transform = transforms.Compose([transforms.Resize((32, 32)),
                                             transforms.ToTensor()])

    def __len__(self):
        # Returns the length of the dataset
        return len(self.img_path_list)

    # Loads one sample, transforms it and returns (image, label)
    def __getitem__(self, idx):
        path = self.img_path_list[idx]
        # Split on the basename so the folder name does not pollute the label
        class_id = self.classnames.index(os.path.basename(path).split("_")[0])
        img = Image.open(path).convert('RGB')
        img = self.transform(img)

        return (img, class_id)
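The file-name parsing used in __getitem__ can be checked in isolation; note that the basename is taken first so a directory name like train/ does not leak into the split (the paths here are made up):

```python
import os

classnames = ["cat", "dog"]

def label_from_path(path):
    name = os.path.basename(path)        # e.g. "cat_3.png"
    return classnames.index(name.split("_")[0])

print(label_from_path("train/cat_3.png"))  # 0
print(label_from_path("train/dog_5.png"))  # 1
```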

In __getitem__(), we extract the class label by splitting the file name of the image. With the custom Dataset class created, we are ready to create a DataLoader for the third time:

bs = 64 	#Batch Size

root_dir = "Dataset folder path"
dataset = CustomDataset(root_dir)  # Dataset class initiation 
dataloader = torch.utils.data.DataLoader(dataset, bs, shuffle=True)

Conclusion

In this article, we learned about custom Datasets and DataLoaders and how to create them in PyTorch. We also explored three different ways in which our data could be organised and how to create a custom Dataset for each scenario. Hope you liked the article and learned something; do share it with anyone who needs such content.
