Perform Instance Segmentation Using SOLOv2
This example shows how to segment object instances of randomly rotated machine parts in a bin using a deep learning SOLOv2 network.
Instance segmentation is a computer vision technique in which you detect and localize objects while simultaneously generating a segmentation map for each of the detected instances. For more information about instance segmentation with SOLOv2, see Get Started with SOLOv2 for Instance Segmentation.
This example first shows how to perform instance segmentation using a pretrained SOLOv2 network that can detect a single class. Then, you can optionally configure and train a SOLOv2 network using transfer learning, and evaluate prediction results.
Download Pretrained SOLOv2 Network
By default, this example downloads a pretrained version of the SOLOv2 instance segmentation network using the downloadTrainedNetwork helper function. The helper function is attached to this example as a supporting file. You can use the pretrained network to run the entire example without waiting for training to complete.
trainedSOLOv2_url = "https://ssd.mathworks.com/supportfiles/vision/data/trainedSOLOv2BinDataset.zip";
downloadTrainedNetwork(trainedSOLOv2_url,pwd);
Downloading pretrained network. This can take several minutes to download... Done.
load("trainedSOLOv2.mat");
Download Bin Picking Dataset
This example uses the bin picking data set. The data set contains 150 images of 3-D pipe connectors, generated with Simulink® software. The data consists of images of machine parts lying at random orientations inside a bin, viewed from different angles and under different lighting conditions. The data set contains instance mask information for every object in every image, and combines all types of machine parts into a single class.
Specify dataDir as the location of the data set. Download the data set using the downloadBinObjectData helper function. This function is attached to the example as a supporting file.
dataDir = fullfile(tempdir,"BinDataset");
dataset_url = "https://ssd.mathworks.com/supportfiles/vision/data/binDataset.zip";
downloadBinObjectData(dataset_url,dataDir);
Perform Instance Segmentation
Read a sample image from the data set.
sampleImage = imread("testBinDataImage.png");
Predict the masks, labels, and confidence scores for each object instance using the segmentObjects function.
[masks,labels,scores] = segmentObjects(net,sampleImage,Threshold=0.4);
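As a quick sanity check, you can confirm how many instances the network detected and the range of confidence scores. This snippet is not part of the original example, and assumes the network detects at least one instance.

% Number of detected instances (third dimension of the mask array)
numInstances = size(masks,3);
disp("Detected " + numInstances + " instances with scores from " + ...
    min(scores) + " to " + max(scores) + ".")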
Display the instance masks over the image using the insertObjectMask function. Specify a colormap using the lines function so that each object instance appears in a different color. Use the getBoxFromMask helper function to generate a bounding box for each segmented object instance, and overlay the boxes on the image with the confidence scores as labels.
maskColors = lines(numel(labels));
overlayedMasks = insertObjectMask(sampleImage,masks,MaskColor=maskColors);
imshow(overlayedMasks)
boxes = getBoxFromMask(masks);
showShape("rectangle",boxes,Label="Scores: "+num2str(scores),LabelOpacity=0.4);
Prepare Data for Training
Create a file datastore that reads the annotation data from MAT files. Use the matReaderBinData function, attached to the example as a supporting file, to parse the MAT files and return the corresponding training data as a 1-by-4 cell array containing the image data, bounding boxes, labels, and object masks.
annsDir = fullfile(dataDir,"synthetic_parts_dataset","annotations");
ds = fileDatastore(annsDir,FileExtensions=".mat",ReadFcn=@(x)matReaderBinData(x,dataDir));
Partition Data
To make the results of this example reproducible, set the global random number generator to its default state.
rng("default");
Split the data into training, validation, and test sets. Because the data set contains relatively few images, allocate a large percentage (70%) of the data for training. Allocate 15% for validation and the rest for testing.
numImages = length(ds.Files);
numTrain = floor(0.7*numImages);
numVal = floor(0.15*numImages);

shuffledIndices = randperm(numImages);
trainDS = subset(ds,shuffledIndices(1:numTrain));
valDS = subset(ds,shuffledIndices(numTrain+1:numTrain+numVal));
testDS = subset(ds,shuffledIndices(numTrain+numVal+1:end));
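As a quick check on the partition, display the number of files in each subset. With the full 150-image data set, this yields 105 training, 22 validation, and 23 test images.

% Number of images in the training, validation, and test subsets
disp([numel(trainDS.Files) numel(valDS.Files) numel(testDS.Files)])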
Visualize Training Data
Preview the ground truth data for training by reading a sample image from the training subset of the file datastore.
gtSample = preview(trainDS);
gtImg = gtSample{1};
boxes = gtSample{2};
labels = gtSample{3};
masks = gtSample{4};
Visualize the ground truth data by using the insertObjectMask function to overlay the instance masks and the corresponding bounding boxes and labels on the sample image.
overlayedMasks = insertObjectMask(gtImg,masks,Opacity=0.5);
imshow(overlayedMasks)
showShape("rectangle",boxes,Label=string(labels),Color="green");
Define SOLOv2 Network Architecture
Create the SOLOv2 instance segmentation model by using the solov2 object. Specify the name of a pretrained SOLOv2 instance segmentation network trained on the COCO data set, and specify the class name. Because SOLOv2 is an anchor-free network, you do not need to specify anchor boxes. Specify the size to which the network resizes all input images by using the optional InputSize name-value argument.
networkToTrain = solov2("resnet50-coco","Object",InputSize=[736 1280 3]);
Specify Training Options
Specify network training options using the trainingOptions (Deep Learning Toolbox) function. Train the instance segmentation network using the SGDM solver for five epochs. Specify a learning rate drop factor of 0.99 every epoch. To ensure the convergence of gradients in the initial iterations, set the GradientThreshold name-value argument to 35. Specify the ValidationData name-value argument as the validation data, valDS.
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.0005, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropPeriod=1, ...
    LearnRateDropFactor=0.99, ...
    Momentum=0.9, ...
    MaxEpochs=5, ...
    MiniBatchSize=4, ...
    ExecutionEnvironment="auto", ...
    VerboseFrequency=5, ...
    Plots="training-progress", ...
    ResetInputNormalization=false, ...
    ValidationData=valDS, ...
    ValidationFrequency=25, ...
    GradientThreshold=35, ...
    OutputNetwork="best-validation-loss");
Train SOLOv2 Network
To train the network, set the doTraining variable to true. Train the network by using the trainSOLOV2 function. To reuse the features extracted by the pretrained backbone network and optimize the detection heads for the data set, freeze the feature extraction subnetwork by specifying the FreezeSubNetwork name-value argument.
Train on one or more GPUs, if they are available. Using a GPU requires a Parallel Computing Toolbox™ license and a CUDA®-enabled NVIDIA® GPU. For more information, see GPU Computing Requirements (Parallel Computing Toolbox). Training takes about 15 minutes on an NVIDIA Titan RTX™ with 24 GB of memory.
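Optionally, check whether a supported GPU is available before you start training. This check is a convenience sketch, not part of the original example.

% Report which device the "auto" execution environment selects
if canUseGPU
    gpu = gpuDevice;
    disp("Training on GPU: " + gpu.Name)
else
    disp("No supported GPU detected; training runs on the CPU.")
end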
doTraining = false;
if doTraining
    net = trainSOLOV2(trainDS,networkToTrain,options,FreezeSubNetwork="backbone");
    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save(fullfile(tempdir,"trainedSOLOv2"+modelDateTime+".mat"),"net");
else
    load("trainedSOLOv2.mat");
end
Evaluate Trained SOLOv2 Network
Evaluate the trained SOLOv2 network by measuring the average precision. Precision quantifies the ability of the network to classify objects correctly.
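In terms of true positives (TP), false positives (FP), and false negatives (FN), precision = TP/(TP + FP) and recall = TP/(TP + FN). A predicted instance counts as a true positive when its mask overlaps a ground truth mask with an intersection over union (IoU) at or above the evaluation threshold, such as the 0.5 threshold used later in this section.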
Detect the instance masks for all test images.
resultsDS = segmentObjects(net,testDS,Threshold=0.1);
Running SoloV2 network
--------------------------
* Processed 23 images.
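Before computing aggregate metrics, you can optionally spot-check a single prediction. This sketch is not part of the original example; it assumes each read of the results datastore returns masks, labels, and scores, in that order, and it resets both datastores so that the evaluation step processes every image.

% Overlay the first predicted mask set on the corresponding test image
testSample = read(testDS);
predSample = read(resultsDS);
overlay = insertObjectMask(testSample{1},predSample{1});
figure
imshow(overlay)
% Rewind the datastores before evaluation
reset(testDS)
reset(resultsDS)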
Calculate the average precision (AP) and mean average precision (mAP) metrics by using the evaluateInstanceSegmentation function. In this example, the AP and mAP values are identical because the data set contains only one class.
metrics = evaluateInstanceSegmentation(resultsDS,testDS,0.5);
summarize(metrics)
ans=1×3 table
NumObjects mAPOverlapAvg mAP0.5
__________ _____________ _______
184 0.98784 0.98784
Display the metrics for every test image to identify which images are not performing as expected.
display(metrics.ImageMetrics)
23×3 table

          NumObjects    APOverlapAvg        AP
          __________    ____________    __________

    1         8               1         {[     1]}
    2         8               1         {[     1]}
    3         8               1         {[     1]}
    4         8               1         {[     1]}
    5         8               1         {[     1]}
    6         8               1         {[     1]}
    7         8         0.85938         {[0.8594]}
    8         8               1         {[     1]}
    9         8               1         {[     1]}
    10        8               1         {[     1]}
    11        8               1         {[     1]}
    12        8               1         {[     1]}
    13        8         0.85938         {[0.8594]}
    14        8               1         {[     1]}
    15        8               1         {[     1]}
    16        8               1         {[     1]}
    17        8               1         {[     1]}
    18        8               1         {[     1]}
    19        8               1         {[     1]}
    20        8               1         {[     1]}
    21        8               1         {[     1]}
    22        8               1         {[     1]}
    23        8               1         {[     1]}
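Rather than scanning the table by eye, you can flag underperforming images programmatically. The 0.9 cutoff in this sketch is an arbitrary review threshold, not a value from the original example.

% Indices of test images with below-threshold average precision
imageAP = metrics.ImageMetrics.APOverlapAvg;
lowAPImages = find(imageAP < 0.9)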
A precision/recall (PR) curve highlights how precise an instance segmentation model is at varying levels of recall. The ideal precision is 1 at all recall levels. Extract the precision, recall, and average precision metrics from the evaluateInstanceSegmentation function output.
[precision,recall] = precisionRecall(metrics);
ap = averagePrecision(metrics);
Plot the PR curve for the test data.
figure
plot(recall{:},precision{:})
title(sprintf("Average Precision for Single Class Instance Segmentation: %.2f",ap))
xlabel("Recall")
ylabel("Precision")
grid on
Supporting Function
The getBoxFromMask helper function converts instance masks to bounding boxes.
function boxes = getBoxFromMask(masks)
% Compute a tight bounding box [x y width height] for each mask channel.
numMasks = size(masks,3);
boxes = zeros(numMasks,4);
for idx = 1:numMasks
    mask = masks(:,:,idx);
    [ptsR,ptsC] = find(mask);   % row and column indices of mask pixels
    minR = min(ptsR);
    maxR = max(ptsR);
    minC = min(ptsC);
    maxC = max(ptsC);
    boxes(idx,:) = [minC minR maxC-minC maxR-minR];
end
end
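If you have Image Processing Toolbox, a regionprops-based variant can serve the same purpose. The following sketch, including the getBoxFromMaskAlt name, is illustrative rather than part of the original example; it merges the component bounding boxes in case a mask contains disconnected pixels.

function boxes = getBoxFromMaskAlt(masks)
% Illustrative alternative that computes one tight bounding box per mask
% channel using regionprops (requires Image Processing Toolbox).
numMasks = size(masks,3);
boxes = zeros(numMasks,4);
for idx = 1:numMasks
    stats = regionprops(masks(:,:,idx),"BoundingBox");
    allBoxes = vertcat(stats.BoundingBox);   % one row per connected component
    x1 = min(allBoxes(:,1));
    y1 = min(allBoxes(:,2));
    x2 = max(allBoxes(:,1) + allBoxes(:,3));
    y2 = max(allBoxes(:,2) + allBoxes(:,4));
    boxes(idx,:) = [x1 y1 x2-x1 y2-y1];      % [x y width height]
end
end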
See Also
solov2 | segmentObjects | trainSOLOV2 | evaluateInstanceSegmentation | insertObjectMask
Related Topics
- Get Started with SOLOv2 for Instance Segmentation
- Deep Learning in MATLAB (Deep Learning Toolbox)
- Datastores for Deep Learning (Deep Learning Toolbox)