Transfer Learning with Deep Network Designer

This example shows how to interactively prepare a network for transfer learning using the Deep Network Designer app. Transfer learning is the process of taking a pretrained deep learning network and fine-tuning it to learn a new task. Using transfer learning is usually much faster and easier than training a network from scratch because you can quickly transfer learned features to a new task using a smaller number of training images.

Perform transfer learning by following these steps:

  1. Choose a pretrained network and import it into the app.

  2. Replace the final layers with new layers adapted to the new data set:

    1. Specify the new number of classes in your training images.

    2. Set learning rates to learn faster in the new layers than in the transferred layers.

  3. Export the network for training at the command line.

Choose a Pretrained Network

Deep Learning Toolbox™ provides a selection of pretrained image classification networks that have learned rich feature representations suitable for a wide range of images. Transfer learning works best if your images are similar to the images originally used to train the network. If your training images are natural images like those in the ImageNet database, then any of the pretrained networks is suitable. To try a faster network first, use googlenet or squeezenet. For a list of available networks and how to compare them, see Pretrained Deep Neural Networks.

If your data is very different from the ImageNet data, it might be better to train a new network. For example, if you have tiny images, spectrograms, or nonimage data, then see instead Build Networks with Deep Network Designer.

Load a pretrained GoogLeNet network. If you need to download the network, then the function provides a link to Add-On Explorer.

net = googlenet;

Import Network into Deep Network Designer

To open Deep Network Designer, on the Apps tab, under Machine Learning and Deep Learning, click the app icon. Alternatively, you can open the app from the command line.
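For example, the deepNetworkDesigner function opens the app from the command line. Passing a network as an input argument loads that network directly, so you can skip the import step:

deepNetworkDesigner(net)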


Click Import and select the network to load from the workspace. Deep Network Designer displays a zoomed-out view of the whole network.

Explore the network plot. To zoom in with the mouse, use Ctrl+scroll wheel. To pan, use the arrow keys, or hold down the scroll wheel and drag the mouse. Select a layer to view its properties. Deselect all layers to view the network summary in the Properties pane.

Edit Network for Transfer Learning

The network classifies input images using the last learnable layer and the final classification layer. To retrain a pretrained network to classify new images, replace these final layers with new layers adapted to the new data set.

Change Number of Classes

To use a pretrained network for transfer learning, you must change the number of classes to match your new data set. First, find the last learnable layer in the network. For GoogLeNet, and most pretrained networks, the last learnable layer is a fully connected layer. Click the layer loss3-classifier and view its properties.

The OutputSize property defines the number of classes for classification problems. The Properties pane indicates that the pretrained network can classify images into 1000 classes. You cannot edit OutputSize.

To change the number of classes, drag a new fullyConnectedLayer from the Layer Library onto the canvas. Edit the OutputSize property to the number of classes in your data. For this example, enter 5. Delete the original loss3-classifier layer and connect your new layer in its place.

Select the last layer, the classification output layer. In the Properties pane, the OutputSize property shows 1000 classes, and the Classes property lists the first few class names.

For transfer learning, you need to replace the output layer. Scroll to the end of the Layer Library and drag a new classificationLayer onto the canvas. Delete the original output layer and connect your new layer in its place. For a new output layer, you do not need to set the OutputSize. At training time, trainNetwork automatically sets the output classes of the layer from the data.

Make New Layers Learn Faster

Edit the learning rate factors so that the new layer learns faster than the transferred layers. On your new fullyConnectedLayer, set WeightLearnRateFactor and BiasLearnRateFactor to 10.
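The same layer edits can also be sketched at the command line using replaceLayer on a layer graph. In this sketch, the layer names 'new_fc' and 'new_classoutput' are placeholder names chosen for illustration; 'loss3-classifier' and 'output' are the names of the final layers in GoogLeNet:

lgraph = layerGraph(net);

% Replace the last learnable layer with a fully connected layer
% sized for 5 classes, with increased learning rate factors.
newFc = fullyConnectedLayer(5, ...
    'Name','new_fc', ...
    'WeightLearnRateFactor',10, ...
    'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'loss3-classifier',newFc);

% Replace the output layer. trainNetwork sets its classes
% automatically from the training data.
newOutput = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newOutput);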

Check Network

To check the network and examine more details of the layers, click Analyze. The edited network is ready for training if the Deep Learning Network Analyzer reports zero errors.
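You can run the same check from the command line. Assuming the edited layer graph is in a variable named lgraph_1, the analyzeNetwork function opens the Deep Learning Network Analyzer:

analyzeNetwork(lgraph_1)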

Export Network for Training

To export the network to the workspace, return to the Deep Network Designer and click Export. The Deep Network Designer exports the network to a new variable containing the edited network layers, called lgraph_1. After exporting, you can supply the layer variable to the trainNetwork function. You can also generate MATLAB® code that recreates the network architecture and returns it as a variable in the workspace. For more information, see Generate MATLAB Code from Deep Network Designer.

Train Network Exported from Deep Network Designer

This example shows how to use a network exported from Deep Network Designer for transfer learning. After preparing the network in the app, you need to:

  • Resize images.

  • Specify training options.

  • Train the network.

Resize Images for Transfer Learning

For transfer learning, resize your images to match the input size of the pretrained network. To find the image input size of the network, in Deep Network Designer, examine the imageInputLayer. For GoogLeNet, the InputSize is 224-by-224.
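You can also query the input size programmatically. For a network in a variable net, the first layer is the image input layer, and its InputSize property gives the expected image dimensions:

inputSize = net.Layers(1).InputSize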

Unzip and load the images as an image datastore. This very small data set contains only 75 images in 5 classes. Divide the data into 70% for training and 30% for validation.

unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

If your training images are in a folder with subfolders for each class, you can create a datastore for your data by replacing MerchData with the folder location. Check the number of classes; you must prepare the network for transfer learning with a number of classes that matches your data.
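To verify the class count, count the unique labels in the training datastore:

numClasses = numel(categories(imdsTrain.Labels))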

Resize images in the image datastores to match the pretrained network GoogLeNet.

augimdsTrain = augmentedImageDatastore([224 224],imdsTrain);
augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);

You can also apply transformations to the images to help prevent the network from overfitting. For details, see imageDataAugmenter.
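A minimal sketch of such augmentation uses random reflection and translation; the pixel ranges here are illustrative choices, not values required by this example:

pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore([224 224],imdsTrain, ...
    'DataAugmentation',imageAugmenter);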

Set Training Options for Transfer Learning

Before training, specify training options.

  • For transfer learning, set InitialLearnRate to a small value to slow down learning in the transferred layers. In the app, you increased the learning rate factors for the fully connected layer to speed up learning in the new final layers. This combination of learning rate settings results in fast learning only in the new layers and slower learning in the other layers.

  • Specify a small number of epochs. An epoch is a full training cycle on the entire training data set. For transfer learning, you do not need to train for as many epochs. Shuffle the data every epoch.

  • Specify the mini-batch size, that is, how many images to use in each iteration.

  • Specify validation data and validation frequency.

  • Turn on the training plot to monitor progress while you train.

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',6, ...
    'Verbose',false, ...
    'Plots','training-progress');

Train Network

To train the network, supply the trainNetwork function with the layers you exported from the app (here named lgraph_1), the resized training images, and the training options. By default, trainNetwork uses a GPU if one is available (requires Parallel Computing Toolbox™); otherwise, it uses a CPU. Training is fast because the data set is so small.

net = trainNetwork(augimdsTrain,lgraph_1,options);

Test Trained Network by Classifying Validation Images

Use the fine-tuned network to classify the validation images, and calculate the classification accuracy.

[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 0.9000
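To see which classes the network confuses, you can also plot a confusion chart for the validation set:

confusionchart(imdsValidation.Labels,YPred)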

Display four sample validation images with predicted labels and predicted probabilities.

idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end

See Also

Related Topics