
Train Vision Transformer Network for Image Classification

This example shows how to fine-tune a pretrained vision transformer (ViT) neural network to perform classification on a new collection of images.

ViT [1] is a neural network model that uses the transformer architecture to encode image inputs into feature vectors. The network consists of two main components: the backbone and the head. The backbone performs the encoding step: it takes the input images and outputs a vector of features. The head makes the predictions: it maps the encoded feature vectors to the prediction scores.

In this example, the pretrained ViT network has learned a strong feature representation for images. You can fine-tune the model for specific tasks using transfer learning. To transfer this feature representation and fine-tune it for a new data set, replace the head of the network with a new head that classifies data for your task and then fine-tune the network on the new data set.

This diagram outlines the architecture of a ViT network that makes predictions for K classes and how to edit the network to enable transfer learning for a new data set that has K* classes.

In this example, you fine-tune the base-sized ViT model (86.8 million parameters) with a patch size of 16. This pretrained model was itself fine-tuned on the ImageNet 2012 data set at a resolution of 384-by-384.
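To make the patch size and resolution concrete, the following sketch works out how many patches the network sees per image. It assumes non-overlapping square patches, RGB input, and a single class token prepended to the patch sequence, as in the original ViT design [1]. The variable names are illustrative only.

```matlab
% Arithmetic sketch: token count for a 384-by-384 RGB image with 16-by-16 patches.
imageSize = 384;      % input resolution (pixels per side)
patchSize = 16;       % patch side length (pixels)
numChannels = 3;      % RGB input (assumption)

patchesPerSide = imageSize/patchSize;         % 24
numPatches = patchesPerSide^2;                % 576
valuesPerPatch = patchSize^2*numChannels;     % 768 pixel values per flattened patch
sequenceLength = numPatches + 1;              % 577, assuming one class token

fprintf("%d patches of %d values each, sequence length %d\n", ...
    numPatches,valuesPerPatch,sequenceLength)
```

Because the number of patches grows quadratically with the image side length, the 384-by-384 resolution is one reason this model needs a lot of memory during training.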

Load Pretrained ViT Network

Load a pretrained ViT network using the visionTransformer function. This function requires a Deep Learning Toolbox™ license and the Computer Vision Toolbox™ Model for Vision Transformer Network support package. You can download this support package from the Add-On Explorer. If you do not have the support package installed, then the function provides a download link.

net = visionTransformer
net = 
  dlnetwork with properties:

         Layers: [143×1 nnet.cnn.layer.Layer]
    Connections: [167×2 table]
     Learnables: [200×3 table]
          State: [0×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

View the input size of the network.

inputSize = net.Layers(1).InputSize
inputSize = 1×3

   384   384     3

To fine-tune a ViT network, you can usually fine-tune the attention layers only and freeze the other learnable parameters [2]. Freeze the network weights using the freezeNetwork function, which is attached to this example as a supporting file. To access this function, open the example as a live script.

net = freezeNetwork(net,LayersToIgnore="SelfAttentionLayer");
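The freezeNetwork supporting file is not reproduced here. As a rough illustration of the general idea only, the following sketch freezes parameters in a small toy network by setting their learning rate factors to zero with setLearnRateFactor. The toy network and the layer names "fc_frozen" and "fc_trainable" are hypothetical, and the supporting function may work differently.

```matlab
% Toy network standing in for the ViT (hypothetical layer names).
layers = [
    featureInputLayer(4)
    fullyConnectedLayer(8,Name="fc_frozen")
    reluLayer
    fullyConnectedLayer(3,Name="fc_trainable")];
toyNet = dlnetwork(layers);

% Set the learning rate factor to 0 for every learnable parameter
% except those that belong to the layer that should keep training.
learnables = toyNet.Learnables;
for i = 1:height(learnables)
    layerName = learnables.Layer(i);
    parameterName = learnables.Parameter(i);
    if layerName ~= "fc_trainable"
        toyNet = setLearnRateFactor(toyNet,layerName,parameterName,0);
    end
end

% Check the result: 0 means frozen, 1 is the default factor.
frozenFactor = getLearnRateFactor(toyNet,"fc_frozen","Weights")
trainableFactor = getLearnRateFactor(toyNet,"fc_trainable","Weights")
```

A parameter with a learning rate factor of 0 keeps its pretrained value during training, which reduces the number of parameters that the optimizer updates.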

Load Training Data

Download and extract the Flowers data set [3]. The data set has a size of about 218 MB and contains 3670 images of flowers belonging to five classes: daisy, dandelion, roses, sunflowers, and tulips.

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~datasetExists(imageFolder)
    disp("Downloading Flowers data set (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

Create an image datastore containing the images.

imds = imageDatastore(imageFolder,IncludeSubfolders=true,LabelSource="foldernames");

View the number of classes.

classNames = categories(imds.Labels);
numClasses = numel(classNames)
numClasses = 5

Split the datastore into training, validation, and test partitions using the splitEachLabel function. Use 80% of the images for training and set aside 10% for validation and 10% for testing.

[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.8,0.1);
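To see how splitEachLabel divides each class, the following sketch builds a small throwaway data set of random images and splits it with the same proportions. The folder layout, class names, and image contents are made up for illustration. The sketch also passes the "randomized" option, which the example above does not use, so that files are assigned to each partition at random rather than in file order.

```matlab
% Create a temporary toy data set: 2 classes with 10 random images each.
rootFolder = tempname;
toyClasses = ["classA" "classB"];   % hypothetical class names
for c = toyClasses
    mkdir(fullfile(rootFolder,c));
    for k = 1:10
        img = uint8(randi(255,[8 8 3]));
        imwrite(img,fullfile(rootFolder,c,sprintf("img%02d.png",k)));
    end
end

imdsToy = imageDatastore(rootFolder,IncludeSubfolders=true,LabelSource="foldernames");

% 80% training, 10% validation, and the remaining 10% for testing, per class.
[toyTrain,toyValidation,toyTest] = splitEachLabel(imdsToy,0.8,0.1,"randomized");

% Each class contributes 8 training, 1 validation, and 1 test image.
countEachLabel(toyTrain)
countEachLabel(toyValidation)
countEachLabel(toyTest)
```

Because the proportions apply to each label separately, every partition keeps the class balance of the full data set.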

To improve training, augment the training data to include random rotation, scaling, and horizontal flipping. Resize the images to match the network input size.

augmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandRotation=[-90 90], ...
    RandScale=[1 2]);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,DataAugmentation=augmenter);
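To get a feel for what these augmentations do, you can apply the same augmenter settings to a single image. The following sketch is not part of the original example. It uses the peppers.png image that ships with MATLAB and the augment function of the imageDataAugmenter object, and it draws four random transformations.

```matlab
% Preview four random augmentations of one sample image.
I = imread("peppers.png");

previewAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandRotation=[-90 90], ...
    RandScale=[1 2]);

figure
tiledlayout(2,2)
for k = 1:4
    J = augment(previewAugmenter,I);   % new random transformation on each call
    nexttile
    imshow(J)
end
```

If the previews look unrealistic for your data, for example because strong rotations never occur in practice, consider narrowing the rotation or scale ranges.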

Create augmented image datastores that resize the validation and testing images to match the network input size. Do not apply any additional augmentations to the validation and testing data.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);

Replace Network Classification Head

The ViT network has two main components. The backbone of the network extracts features from the input images. The classification head maps the extracted features to probability vectors that represent the prediction scores for each class. To train the neural network to classify images over a new set of classes, replace the classification head with a new classification head that maps the extracted features to prediction scores for the new set of classes.

View the network architecture using the analyzeNetwork function. Locate the layers at the end of the network that map the extracted features to vectors of prediction scores. In this case, the fully connected layer with the name "head" maps the extracted features to vectors with a length of 1000, the number of classes that the network is trained to predict. The softmax layer with the name "softmax" maps those vectors to probability vectors.

analyzeNetwork(net)

Create a new fully connected layer:

  • Set the output size to the number of classes of the training data.

  • Set the layer name to "head".

layer = fullyConnectedLayer(numClasses,Name="head");

Replace the fully connected layer with the new layer using the replaceLayer (Deep Learning Toolbox) function. You do not need to replace the softmax layer because it does not have any learnable parameters.

net = replaceLayer(net,"head",layer);
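The effect of swapping the head can be checked on a small stand-in network. The following sketch assumes a toy network with a 1000-class head, replaces that head with a 5-class one, and confirms the output size. The layer sizes and the name "backbone" are hypothetical and do not reflect the ViT architecture.

```matlab
% Toy network with a 1000-class head (stand-in for the pretrained model).
layers = [
    featureInputLayer(8)
    fullyConnectedLayer(16,Name="backbone")
    reluLayer
    fullyConnectedLayer(1000,Name="head")
    softmaxLayer(Name="softmax")];
toyNet = dlnetwork(layers);

% Replace the head with a new 5-class head, then initialize the new parameters.
toyNet = replaceLayer(toyNet,"head",fullyConnectedLayer(5,Name="head"));
toyNet = initialize(toyNet);

% Predict on 3 random observations. The "CB" format means channel-by-batch.
X = dlarray(rand(8,3,"single"),"CB");
scores = predict(toyNet,X);
outputSize = size(scores)   % 5 classes by 3 observations
```

In this example you do not need to call initialize yourself, because the trainnet function initializes any uninitialized parameters before training.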

Specify Training Options

Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager (Deep Learning Toolbox) app.

  • Train using the Adam optimizer.

  • For fine-tuning, lower the learning rate to 0.0001.

  • Train for four epochs.

  • Use a mini-batch size of 12. Training a ViT network typically requires lots of memory. If you run out of memory, try using a smaller mini-batch size. Alternatively, try using a smaller model, such as the tiny-sized ViT model (5.7 million parameters), by specifying "tiny-16-imagenet-384" as the model name in the visionTransformer function.

  • Once per epoch, validate the network using the validation data.

  • Output the network that results in the lowest validation loss.

  • Monitor the training progress in a plot and monitor the accuracy metric.

  • Disable verbose output.

miniBatchSize = 12;

numObservationsTrain = numel(augimdsTrain.Files);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);

options = trainingOptions("adam", ...
    MaxEpochs=4, ...
    InitialLearnRate=0.0001, ...
    MiniBatchSize=miniBatchSize, ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=numIterationsPerEpoch, ...
    OutputNetwork="best-validation", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);
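As a rough check on these settings, the following sketch estimates how many iterations training runs for. It uses the total image count and split proportion stated earlier. The training set size is an approximation, because splitEachLabel rounds the number of images separately for each class.

```matlab
% Approximate iteration counts for the options above.
numImages = 3670;        % size of the Flowers data set
trainFraction = 0.8;     % proportion used for training
batchSize = 12;          % mini-batch size
maxEpochs = 4;

numTrainApprox = round(trainFraction*numImages);      % 2936
iterationsPerEpoch = floor(numTrainApprox/batchSize); % 244
totalIterations = iterationsPerEpoch*maxEpochs;       % 976

fprintf("About %d iterations per epoch, %d in total\n", ...
    iterationsPerEpoch,totalIterations)
```

Setting ValidationFrequency to the number of iterations per epoch is what makes validation run once per epoch.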

Train Neural Network

Train the neural network using the trainnet (Deep Learning Toolbox) function. For classification, use cross-entropy loss. By default, the trainnet function uses a GPU if one is available. Otherwise, the function uses the CPU. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To specify the execution environment, use the ExecutionEnvironment training option.

This example trains the network using an NVIDIA Titan RTX GPU with 24 GB RAM. The training takes about 37 minutes to run.

net = trainnet(augimdsTrain,net,"crossentropy",options);
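If you prefer to choose the hardware explicitly rather than rely on the default, one possible approach is to query GPU availability with canUseGPU and pass the result to the ExecutionEnvironment option. The variable names below are illustrative, and this options object omits the other settings used in this example for brevity.

```matlab
% Choose the execution environment explicitly.
if canUseGPU
    executionEnvironment = "gpu";
else
    executionEnvironment = "cpu";
end

% Minimal options object. Add the remaining options from this example as needed.
optionsExplicit = trainingOptions("adam", ...
    ExecutionEnvironment=executionEnvironment, ...
    Verbose=false);

fprintf("Training would run on: %s\n",executionEnvironment)
```

Training a network of this size on a CPU is likely to take much longer than the GPU time reported above.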

Test Neural Network

Evaluate the accuracy of the network using the test data.

Make predictions using the test data. To convert the prediction scores to class labels, use the onehotdecode (Deep Learning Toolbox) function.

YTest = minibatchpredict(net,augimdsTest);
YTest = onehotdecode(YTest,classNames,2);
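The decoding step can be illustrated on a small hand-written score matrix. In the following sketch, the scores and the shortened class list are made up. The third argument, 2, tells onehotdecode that the classes lie along the second dimension, which matches the observation-by-class layout of the prediction scores used above.

```matlab
% Toy scores: 2 observations (rows) by 3 classes (columns).
toyClassNames = ["daisy" "dandelion" "roses"];
toyScores = [0.1 0.7 0.2
             0.6 0.3 0.1];

% Pick the most probable class for each row.
toyLabels = onehotdecode(toyScores,toyClassNames,2)
% First observation decodes to dandelion, second to daisy.
```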

Display the test classification results in a confusion matrix.

figure
TTest = imdsTest.Labels;
confusionchart(TTest,YTest)

Evaluate the test accuracy.

accuracy = mean(YTest == TTest)
accuracy = 0.9564
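Overall accuracy can hide weak performance on individual classes. As a supplementary sketch, the following code computes the overall accuracy and the per-class recall for a small set of made-up labels. The same pattern applies to TTest and YTest.

```matlab
% Toy true and predicted labels (made up for illustration).
TToy = categorical(["daisy";"daisy";"roses";"roses";"tulips"]);
YToy = categorical(["daisy";"roses";"roses";"roses";"tulips"]);

% Overall accuracy: fraction of matching labels.
accuracyToy = mean(YToy == TToy)    % 0.8

% Per-class recall: fraction of each true class that is predicted correctly.
classesToy = categories(TToy);
recallToy = zeros(numel(classesToy),1);
for i = 1:numel(classesToy)
    isClass = TToy == classesToy{i};
    recallToy(i) = mean(YToy(isClass) == classesToy{i});
end
table(classesToy,recallToy,VariableNames=["Class" "Recall"])
```

Here one of the two daisy images is misclassified as roses, so the recall for daisy is 0.5 even though the overall accuracy is 0.8.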

Make Prediction Using New Data

Use the trained neural network to make a prediction using the first image in the test data.

Read the image from the first file of the test data.

idx = 1;
testData = readByIndex(augimdsTest,idx);
I = testData.input{1};

Make a prediction using the image.

Y = minibatchpredict(net,single(I));

Get the label with the highest probability using the onehotdecode function.

label = onehotdecode(Y,classNames,2);

Display the image and the predicted label.

imshow(I)
title(label)

fprintf("Image Credit: %s\n",flowerCredit(augimdsTest.Files(idx)))
Image Credit: CC-BY by mikeyskatie - https://www.flickr.com/photos/mikeyskatie/5948835387/

References

  1. Dosovitskiy, Alexey, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani et al. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." Preprint, submitted June 3, 2021. https://doi.org/10.48550/arXiv.2010.11929.

  2. Touvron, Hugo, Matthieu Cord, Alaaeldin El-Nouby, Jakob Verbeek, and Hervé Jégou. "Three things everyone should know about vision transformers." In Computer Vision–ECCV 2022, edited by Shai Avidan, Gabriel Brostow, Moustapha Cissé, Giovanni Maria Farinella, and Tal Hassner, 13684: 497-515. Cham: Springer Nature Switzerland, 2022. https://doi.org/10.1007/978-3-031-20053-3_29.

  3. TensorFlow. "Tf_flowers | TensorFlow Datasets." Accessed June 16, 2023. https://www.tensorflow.org/datasets/catalog/tf_flowers.
