Skip to content
Run in Google Colab View notebook on GitHub

Migrating from torchvision to Albumentations

This notebook shows how you can use Albumentations instead of torchvision to perform data augmentation.

Import the required libraries

from PIL import Image
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

An example pipeline that uses torchvision

class TorchvisionDataset(Dataset):
    """Map-style dataset that reads images with PIL for a torchvision pipeline.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels, matched to ``file_paths`` by index.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is always decoded to 3-channel RGB so that grayscale,
        palette, or RGBA files do not break transforms that assume three
        channels (such as a 3-value ``Normalize``).
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with PIL. Image.open is lazy and keeps the file open,
        # so use it as a context manager; convert() loads the pixel data and
        # returns an independent image that stays valid after the file closes.
        with Image.open(file_path) as source:
            image = source.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


# torchvision pipeline: resize to 256x256, take a random 224x224 crop, flip
# horizontally with probability 0.5, convert to a tensor, then normalize each
# channel with the mean/std below.
torchvision_transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomCrop(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


# Three sample images labelled 1-3, augmented by the torchvision pipeline.
torchvision_dataset = TorchvisionDataset(
    transform=torchvision_transform,
    file_paths=[
        './images/image_1.jpg',
        './images/image_2.jpg',
        './images/image_3.jpg',
    ],
    labels=[1, 2, 3],
)

The same pipeline with Albumentations

class AlbumentationsDataset(Dataset):
    """Map-style dataset that reads images with OpenCV for Albumentations.

    ``__init__`` and ``__len__`` are the same as in ``TorchvisionDataset``.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels, matched to ``file_paths`` by index.
        transform: Optional Albumentations-style callable; it is invoked as
            ``transform(image=array)`` and must return a mapping with an
            ``'image'`` entry.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV
        image = cv2.imread(file_path)

        # cv2.imread does not raise on a missing or unreadable file; it returns
        # None. Fail here with the offending path instead of letting cvtColor
        # raise an opaque error.
        if image is None:
            raise FileNotFoundError(
                'Could not read image file: {!r}'.format(file_path)
            )

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label


# Albumentations counterpart of the torchvision pipeline: resize, random crop,
# horizontal flip with probability 0.5, normalize, and finally convert the
# array to a PyTorch tensor.
albumentations_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.RandomCrop(height=224, width=224),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
)


# Three sample images labelled 1-3, augmented by the Albumentations pipeline.
albumentations_dataset = AlbumentationsDataset(
    transform=albumentations_transform,
    file_paths=[
        './images/image_1.jpg',
        './images/image_2.jpg',
        './images/image_3.jpg',
    ],
    labels=[1, 2, 3],
)

Using albumentations with PIL

You can use PIL instead of OpenCV while working with Albumentations, but in that case, you need to convert a PIL image to a NumPy array before applying transformations. Then you need to convert the augmented image back from a NumPy array to a PIL image.

class AlbumentationsPilDataset(Dataset):
    """Map-style dataset that reads images with PIL but augments with Albumentations.

    ``__init__`` and ``__len__`` are the same as in ``TorchvisionDataset``.
    Samples are returned as PIL images.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels, matched to ``file_paths`` by index.
        transform: Optional Albumentations-style callable; it is invoked as
            ``transform(image=array)`` and must return a mapping whose
            ``'image'`` entry is an array accepted by ``Image.fromarray``.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is always decoded to 3-channel RGB, matching what the
        OpenCV-based dataset produces after its BGR-to-RGB conversion.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Image.open is lazy and keeps the file open, so use it as a context
        # manager; convert() loads the pixel data and returns an independent
        # image that stays valid after the file closes.
        with Image.open(file_path) as source:
            image = source.convert('RGB')

        if self.transform:
            # Convert PIL image to numpy array
            image_np = np.array(image)
            # Apply transformations
            augmented = self.transform(image=image_np)
            # Convert numpy array to PIL Image
            image = Image.fromarray(augmented['image'])
        return image, label


# Geometric augmentations only (no Normalize / ToTensorV2): the result is
# passed to Image.fromarray by AlbumentationsPilDataset.
albumentations_pil_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.RandomCrop(height=224, width=224),
        A.HorizontalFlip(p=0.5),
    ]
)


# This dataset yields PIL images, not NumPy arrays or PyTorch tensors.
albumentations_pil_dataset = AlbumentationsPilDataset(
    transform=albumentations_pil_transform,
    file_paths=[
        './images/image_1.jpg',
        './images/image_2.jpg',
        './images/image_3.jpg',
    ],
    labels=[1, 2, 3],
)

Albumentations equivalents for torchvision transforms

| torchvision transform | albumentations transform | albumentations example |
| --- | --- | --- |
| Compose | Compose | `Compose([Resize(256, 256), RandomCrop(224, 224)])` |
| CenterCrop | CenterCrop | `CenterCrop(256, 256)` |
| ColorJitter | HueSaturationValue | `HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5)` |
| Pad | PadIfNeeded | `PadIfNeeded(min_height=512, min_width=512)` |
| RandomAffine | ShiftScaleRotate | `ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.5)` |
| RandomCrop | RandomCrop | `RandomCrop(256, 256)` |
| RandomGrayscale | ToGray | `ToGray(p=0.5)` |
| RandomHorizontalFlip | HorizontalFlip | `HorizontalFlip(p=0.5)` |
| RandomRotation | Rotate | `Rotate(limit=45, p=0.5)` |
| RandomVerticalFlip | VerticalFlip | `VerticalFlip(p=0.5)` |
| Resize | Resize | `Resize(256, 256)` |
| Normalize | Normalize | `Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])` |