Image classification on the CIFAR10 dataset¶
The following files are also available on GitHub - https://github.com/albumentations-team/autoalbument/tree/master/examples/cifar10
dataset.py¶
import cv2
import torchvision
# Turn off OpenCV's own thread pool and its OpenCL backend for this process.
# NOTE(review): presumably done so that OpenCV does not compete with the
# DataLoader worker processes for CPU/GPU resources - confirm.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR10 dataset wrapper used for the augmentation search.

    ``__getitem__`` is overridden so that the entry stored in ``self.data`` is
    passed to ``transform`` as the ``image`` keyword argument and the result
    is read back from the ``"image"`` key of the returned mapping.
    """

    def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
        """Forward all arguments unchanged to ``torchvision.datasets.CIFAR10``."""
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        When ``self.transform`` is set, the image is replaced by
        ``self.transform(image=image)["image"]``; otherwise it is returned
        exactly as stored in ``self.data``.
        """
        image = self.data[index]
        label = self.targets[index]
        if self.transform is None:
            return image, label
        augmented = self.transform(image=image)
        return augmented["image"], label
search.yaml¶
# @package _global_
_version: 2 # An internal value that indicates a version of the config schema. This value is used by
# `autoalbument-search` and `autoalbument-migrate` to upgrade the config to the latest version if necessary.
# Please do not change it manually.
task: classification # Deep learning task. Should either be `classification` or `semantic_segmentation`.
policy_model:
# Settings for Policy Model that searches augmentation policies.
task_factor: 0.1
# Multiplier for classification loss of a model. Faster AutoAugment uses classification loss to prevent augmentations
# from transforming images of a particular class to another class. The authors of Faster AutoAugment use 0.1 as
# default value.
gp_factor: 10
# Multiplier for the gradient penalty for WGAN-GP training. 10 is the default value that was proposed in
# `Improved Training of Wasserstein GANs`.
temperature: 0.05
# Temperature for Relaxed Bernoulli distribution. The probability of applying a certain augmentation is sampled from
# Relaxed Bernoulli distribution (because Bernoulli distribution is not differentiable). With lower values of
# `temperature` Relaxed Bernoulli distribution behaves like Bernoulli distribution. In the paper, the authors
# of Faster AutoAugment used 0.05 as a default value for `temperature`.
num_sub_policies: 100
# Number of augmentation sub-policies. When an image passes through an augmentation pipeline, Faster AutoAugment
# randomly chooses one sub-policy and uses augmentations from that sub-policy to transform an input image. A larger
# number of sub-policies leads to a more diverse set of augmentations and better performance of a model trained on
# augmented images. However, an increase in the number of sub-policies leads to the exponential growth of a search
# space of augmentations, so you need more training data for Policy Model to find good augmentation policies.
num_chunks: 8
# Number of chunks in a batch. Faster AutoAugment splits each batch of images into `num_chunks` chunks. Then it
# applies the same sub-policy with the same parameters to each image in a chunk. This parameter controls the tradeoff
# between the speed of augmentation search and diversity of augmentations. Larger `num_chunks` values will lead to
# faster searching but less diverse set of augmentations. Note that this parameter is used only in the searching
# phase. When you train a model with found sub-policies, Albumentations will apply a distinct set of transformations
# to each image separately.
operation_count: 4
# Number of consecutive augmentations in each sub-policy. Faster AutoAugment will sequentially apply `operation_count`
# augmentations from a sub-policy to an image. Larger values of `operation_count` lead to better performance of
# a model trained on augmented images. Simultaneously, larger values of `operation_count` affect the speed of search
# and increase the searching time.
classification_model:
# Settings for Classification Model that is used for two purposes:
# 1. As a model that performs classification of input images.
# 2. As a Discriminator for Policy Model.
_target_: model.Cifar10ClassificationModel
# A custom classification model is used. This model is defined inside the `model.py` file which is located
# in the same directory with `search.yaml` and `dataset.py`.
# # As an alternative, you could use a built-in AutoAlbument model using the following config:
# # _target_: autoalbument.faster_autoaugment.models.ClassificationModel
#
# # Number of classes in the dataset. The dataset implementation should return an integer in the range
# # [0, num_classes - 1] as a class label of an image.
# num_classes: 10
#
# # The architecture of Classification Model. AutoAlbument uses models from
# # https://github.com/rwightman/pytorch-image-models/. Please refer to its documentation to get a list of available
# # models - https://rwightman.github.io/pytorch-image-models/#list-models-with-pretrained-weights.
# architecture: resnet18
#
# # Boolean flag that indicates whether the selected model architecture should load pretrained weights or use randomly
# # initialized weights.
# pretrained: False
data:
dataset:
_target_: dataset.Cifar10SearchDataset
root: ~/data/cifar10
train: true
download: true
# Class for the PyTorch Dataset and arguments to it. AutoAlbument will create an object of this class using
# the `instantiate` method from Hydra - https://hydra.cc/docs/next/patterns/instantiate_objects/overview/.
#
# Note that the target class value in the `_target_` argument should be located inside PYTHONPATH so Hydra could
# find it. The directory with the config file is automatically added to PYTHONPATH, so the default value
# `dataset.SearchDataset` points to the class `SearchDataset` from the `dataset.py` file. This `dataset.py` file is
# located along with the `search.yaml` file in the same directory provided by `--config-dir`.
#
# As an alternative, you could provide a path to a Python file with the dataset using the `dataset_file` parameter
# instead of the `dataset` parameter. The Python file should contain the implementation of a PyTorch dataset for
# augmentation search. The dataset class should be named `SearchDataset`. The value in `dataset_file` could either
# be a relative or an absolute path; in the case of a relative path, the path should be relative to this config
# file's location.
#
# - Example of a relative path:
# dataset_file: dataset.py
#
# - Example of an absolute path:
# dataset_file: /projects/pytorch/dataset.py
#
input_dtype: uint8
# The data type of input images. Two values are supported:
# - uint8. In that case, all input images should be NumPy arrays with the np.uint8 data type and values in the range
# [0, 255].
# - float32. In that case, all input images should be NumPy arrays with the np.float32 data type and values in the
# range [0.0, 1.0].
preprocessing: null
# A list of preprocessing augmentations that will be applied to each image before applying augmentations from
# a policy. A preprocessing augmentation should be defined as `key`: `value`, where `key` is the name of augmentation
# from Albumentations, and `value` is a dictionary with augmentation parameters. The found policy will also apply
# those preprocessing augmentations before applying the main augmentations.
#
# Here is an example of an augmentation pipeline that first pads an image to the size 512x512 pixels, then resizes
# the resulting image to the size 256x256 pixels and finally crops a random patch with the size 224x224 pixels.
#
# preprocessing:
# - PadIfNeeded:
# min_height: 512
# min_width: 512
# - Resize:
# height: 256
# width: 256
# - RandomCrop:
# height: 224
# width: 224
#
normalization:
mean: [0.4914, 0.4822, 0.4465]
std: [0.247, 0.243, 0.261]
# Normalization values for images. For each image, the search pipeline will subtract `mean` and divide by `std`.
# Normalization is applied after transforms defined in `preprocessing`. Note that regardless of `input_dtype`,
# the normalization function will always receive a `float32` input with values in the range [0.0, 1.0], so you should
# define `mean` and `std` values accordingly.
dataloader:
_target_: torch.utils.data.DataLoader
batch_size: 128
shuffle: true
num_workers: 8
pin_memory: true
drop_last: true
# Parameters for the PyTorch DataLoader. Please refer to the PyTorch documentation for the description of parameters -
# https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader.
optim:
main:
_target_: torch.optim.Adam
lr: 1e-3
betas: [0, 0.999]
# Optimizer configuration for the main (either Classification or Semantic Segmentation) Model
policy:
_target_: torch.optim.Adam
lr: 1e-3
betas: [0, 0.999]
# Optimizer configuration for Policy Model
seed: 42 # Random seed. If the value is not null, it will be passed to `seed_everything` -
# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.utilities.seed.html?highlight=seed_everything
hydra:
run:
dir: ${config_dir:}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
# Path to the directory that will contain all outputs produced by the search algorithm. `${config_dir:}` contains
# path to the directory with the `search.yaml` config file. Please refer to the Hydra documentation for more
# information - https://hydra.cc/docs/configure_hydra/workdir.
trainer:
# Configuration for PyTorch Lightning Trainer. You can read more about Trainer and its arguments at
# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html.
max_epochs: 40
# Number of epochs to search for augmentation parameters.
# More detailed description - https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#max-epochs
benchmark: true
# If true, enables cudnn.benchmark.
# More detailed description - https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#benchmark
gpus: 1
# Number of GPUs to train on. Set to `0` or `None` to use CPU for training.
# More detailed description - https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#gpus
model.py¶
"""WideResNet code from https://github.com/xternalz/WideResNet-pytorch"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from autoalbument.faster_autoaugment.models import BaseDiscriminator
class BasicBlock(nn.Module):
    """Residual block: two (BatchNorm -> ReLU -> 3x3 conv) stages plus a shortcut.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the first 3x3 convolution and, when it exists, of
            the 1x1 shortcut convolution.

    Attributes:
        equal_in_out: ``True`` when ``in_planes == out_planes``.
        conv_shortcut: 1x1 convolution used as the shortcut when the channel
            counts differ, otherwise ``None`` (the input is added back as is).

    NOTE(review): when ``in_planes == out_planes`` the shortcut is the raw
    input, so a ``stride`` other than 1 would make the two summands differ in
    spatial size. This is left as in the original code.
    """

    def __init__(self, in_planes, out_planes, stride):
        super().__init__()
        # Submodules are created in the original order so that parameter
        # names, their ordering and seeded initialisation stay the same.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.equal_in_out = in_planes == out_planes
        # Explicit conditional instead of `cond and module or None`, which
        # depends on the truthiness of the module object.
        self.conv_shortcut = (
            None
            if self.equal_in_out
            else nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
        )

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``.

        The first BatchNorm/ReLU output feeds the residual branch in both
        cases. It also feeds the 1x1 shortcut convolution when the channel
        counts differ; otherwise the untouched input ``x`` is the shortcut.
        """
        activated = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(activated)))
        out = self.conv2(out)
        if self.equal_in_out:
            shortcut = x
        else:
            shortcut = self.conv_shortcut(activated)
        return torch.add(shortcut, out)
class NetworkBlock(nn.Module):
    """A sequential stack of ``nb_layers`` blocks built by the ``block`` factory.

    Args:
        nb_layers: Number of blocks to stack. Converted with ``int()``, so an
            integral float is accepted.
        in_planes: Number of input channels of the first block.
        out_planes: Number of output channels of every block (and the input
            channels of every block after the first).
        block: Callable invoked as ``block(in_planes, out_planes, stride)``
            that returns an ``nn.Module``.
        stride: Stride passed to the first block; later blocks get stride 1.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride):
        super().__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride):
        """Build the ``nn.Sequential`` holding the stacked blocks."""
        # Conditional expressions (rather than `cond and a or b`) forward the
        # caller's values for the first block unchanged, even when falsy.
        layers = [
            block(
                in_planes if i == 0 else out_planes,
                out_planes,
                stride if i == 0 else 1,
            )
            for i in range(int(nb_layers))
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked blocks to ``x`` in order."""
        return self.layer(x)
class WideResNet(nn.Module):
    """Wide residual network: a 3x3 stem, three ``NetworkBlock`` stages and a linear classifier.

    Args:
        depth: Network depth; ``(depth - 4)`` must be divisible by 6. Each of
            the three stages gets ``(depth - 4) // 6`` blocks.
        num_classes: Size of the output of the final linear layer.
        widen_factor: Multiplier applied to the channel widths (16, 32, 64)
            of the three stages.

    Raises:
        AssertionError: If ``(depth - 4) % 6 != 0``.

    Note:
        ``forward_features`` pools with a fixed 8x8 window. The stages use
        strides 1, 2 and 2, so the pooled map is 1x1 only for 32x32 inputs.
    """

    def __init__(self, depth, num_classes, widen_factor=1):
        super().__init__()
        n_channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, "depth must be of the form 6 * n + 4"
        # Floor division keeps the block count an int; `/` gave a float in
        # Python 3 that only worked because NetworkBlock calls int() on it.
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, n_channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, n_channels[0], n_channels[1], block, 1)
        # 2nd block
        self.block2 = NetworkBlock(n, n_channels[1], n_channels[2], block, 2)
        # 3rd block
        self.block3 = NetworkBlock(n, n_channels[2], n_channels[3], block, 2)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(n_channels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(n_channels[3], num_classes)
        self.n_channels = n_channels[3]

        # Re-initialise weights: Kaiming-normal for convolutions, identity
        # affine for batch norm, zero bias for the linear layer (its weight
        # keeps the default initialisation).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward_features(self, x):
        """Return pooled features reshaped to ``(-1, self.n_channels)``."""
        x = self.conv1(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.relu(self.bn1(x))
        # Fixed 8x8 average pooling with stride 1 and no padding.
        x = F.avg_pool2d(x, 8, 1, 0)
        x = x.view(-1, self.n_channels)
        return x

    def forward_classifier(self, x):
        """Apply the final linear layer to the features from ``forward_features``."""
        return self.fc(x)

    def forward(self, x):
        """Run feature extraction followed by the linear classifier."""
        x = self.forward_features(x)
        x = self.forward_classifier(x)
        return x
def wide_resnet_28x10(num_classes):
    """Build a ``WideResNet`` with depth 28 and widen factor 10 for ``num_classes`` outputs."""
    architecture = {"depth": 28, "widen_factor": 10}
    return WideResNet(num_classes=num_classes, **architecture)
class Cifar10ClassificationModel(BaseDiscriminator):
    """WideResNet-28x10 with an extra one-output head on the shared features.

    ``forward`` returns two tensors computed from the same features: the
    output of the base model's classifier and the flattened output of the
    ``discriminator`` head.
    """

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used.
        super().__init__()
        self.base_model = wide_resnet_28x10(num_classes=10)
        width = self.base_model.fc.in_features
        # Two-layer head mapping the feature vector to a single value.
        head_layers = [
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        ]
        self.discriminator = nn.Sequential(*head_layers)

    def forward(self, input):
        """Return ``(classifier_output, discriminator_output.view(-1))`` for ``input``."""
        features = self.base_model.forward_features(input)
        class_logits = self.base_model.forward_classifier(features)
        critic_scores = self.discriminator(features).view(-1)
        return class_logits, critic_scores