Main Content

Bootstrap Aggregation (Bagging) of Classification Trees Using TreeBagger

Statistics and Machine Learning Toolbox™ offers two objects that support bootstrap aggregation (bagging) of classification trees: TreeBagger created by using TreeBagger and ClassificationBaggedEnsemble created by using fitcensemble. See Comparison of TreeBagger and Bagged Ensembles for differences between TreeBagger and ClassificationBaggedEnsemble.

This example shows the workflow for classification using the features in TreeBagger only.

Use ionosphere data with 351 observations and 34 real-valued predictors. The response variable is categorical with two levels:

  • 'g' represents good radar returns.

  • 'b' represents bad radar returns.

The goal is to predict good or bad returns using a set of 34 measurements.

Fix the initial random seed, grow 50 trees, inspect how the ensemble error changes with accumulation of trees, and estimate feature importance. For classification, it is best to set the minimal leaf size to 1 and select the square root of the total number of features for each decision split at random. These settings are defaults for TreeBagger used for classification.

load ionosphere
rng(1945,'twister')
b = TreeBagger(50,X,Y,'OOBPredictorImportance','On');
figure
plot(oobError(b))
xlabel('Number of Grown Trees')
ylabel('Out-of-Bag Classification Error')

Figure contains an axes object. The axes object with xlabel Number of Grown Trees, ylabel Out-of-Bag Classification Error contains an object of type line.

The method trains ensembles with few trees on observations that are in bag for all trees. For such observations, it is impossible to compute the true out-of-bag prediction, and TreeBagger returns the most probable class for classification and the sample mean for regression. You can change the default value returned for in-bag observations using the DefaultYfit property. If you set the default value to an empty character vector for classification, the method excludes in-bag observations from computation of the out-of-bag error. In this case, the curve is more variable when the number of trees is small, either because some observations are never out of bag (and are therefore excluded) or because their predictions are based on few trees.

b.DefaultYfit = '';
figure
plot(oobError(b))
xlabel('Number of Grown Trees')
ylabel('Out-of-Bag Error Excluding In-Bag Observations')

Figure contains an axes object. The axes object with xlabel Number of Grown Trees, ylabel Out-of-Bag Error Excluding In-Bag Observations contains an object of type line.

The OOBIndices property of TreeBagger tracks which observations are out of bag for what trees. Using this property, you can monitor the fraction of observations in the training data that are in bag for all trees. The curve starts at approximately 2/3, which is the fraction of unique observations selected by one bootstrap replica, and goes down to 0 at approximately 10 trees.

finbag = zeros(1,b.NumTrees);
for t=1:b.NTrees
    finbag(t) = sum(all(~b.OOBIndices(:,1:t),2));
end
finbag = finbag / size(X,1);
figure
plot(finbag)
xlabel('Number of Grown Trees')
ylabel('Fraction of In-Bag Observations')

Figure contains an axes object. The axes object with xlabel Number of Grown Trees, ylabel Fraction of In-Bag Observations contains an object of type line.

Estimate feature importance.

figure
bar(b.OOBPermutedPredictorDeltaError)
xlabel('Feature Index')
ylabel('Out-of-Bag Feature Importance')

Figure contains an axes object. The axes object with xlabel Feature Index, ylabel Out-of-Bag Feature Importance contains an object of type bar.

Select the features yielding an importance measure greater than 0.75. This threshold is chosen arbitrarily.

idxvar = find(b.OOBPermutedPredictorDeltaError>0.75)
idxvar = 1×5

     3     5     7     8    27

Having selected the most important features, grow a larger ensemble on the reduced feature set. Save time by not permuting out-of-bag observations to obtain new estimates of feature importance for the reduced feature set (set OOBVarImp to 'off'). You would still be interested in obtaining out-of-bag estimates of classification error (set OOBPred to 'on').

b5v = TreeBagger(100,X(:,idxvar),Y,'OOBPredictorImportance','off','OOBPrediction','on');
figure
plot(oobError(b5v))
xlabel('Number of Grown Trees')
ylabel('Out-of-Bag Classification Error')

Figure contains an axes object. The axes object with xlabel Number of Grown Trees, ylabel Out-of-Bag Classification Error contains an object of type line.

For classification ensembles, in addition to classification error (fraction of misclassified observations), you can also monitor the average classification margin. For each observation, the margin is defined as the difference between the score for the true class and the maximal score for other classes predicted by this tree. The cumulative classification margin uses the scores averaged over all trees and the mean cumulative classification margin is the cumulative margin averaged over all observations. The oobMeanMargin method with the 'mode' argument set to 'cumulative' (default) shows how the mean cumulative margin changes as the ensemble grows: every new element in the returned array represents the cumulative margin obtained by including a new tree in the ensemble. If training is successful, you would expect to see a gradual increase in the mean classification margin.

The method trains ensembles with few trees on observations that are in bag for all trees. For such observations, it is impossible to compute the true out-of-bag prediction, and TreeBagger returns the most probable class for classification and the sample mean for regression.

For decision trees, a classification score is the probability of observing an instance of this class in this tree leaf. For example, if the leaf of a grown decision tree has five 'good' and three 'bad' training observations in it, the scores returned by this decision tree for any observation fallen on this leaf are 5/8 for the 'good' class and 3/8 for the 'bad' class. These probabilities are called 'scores' for consistency with other classifiers that might not have an obvious interpretation for numeric values of returned predictions.

figure
plot(oobMeanMargin(b5v));
xlabel('Number of Grown Trees')
ylabel('Out-of-Bag Mean Classification Margin')

Figure contains an axes object. The axes object with xlabel Number of Grown Trees, ylabel Out-of-Bag Mean Classification Margin contains an object of type line.

Compute the matrix of proximities and examine the distribution of outlier measures. Unlike regression, outlier measures for classification ensembles are computed within each class separately.

b5v = fillProximities(b5v);
figure
histogram(b5v.OutlierMeasure)
xlabel('Outlier Measure')
ylabel('Number of Observations')

Figure contains an axes object. The axes object with xlabel Outlier Measure, ylabel Number of Observations contains an object of type histogram.

Find the class of the extreme outliers.

extremeOutliers = b5v.Y(b5v.OutlierMeasure>40)
extremeOutliers = 6x1 cell
    {'g'}
    {'g'}
    {'g'}
    {'g'}
    {'g'}
    {'g'}

percentGood = 100*sum(strcmp(extremeOutliers,'g'))/numel(extremeOutliers)
percentGood = 100

All of the extreme outliers are labeled 'good'.

As for regression, you can plot scaled coordinates, displaying the two classes in different colors using the 'Colors' name-value pair argument of mdsProx. This argument takes a character vector in which every character represents a color. The software does not rank class names. Therefore, it is best practice to determine the position of the classes in the ClassNames property of the ensemble.

gPosition = find(strcmp('g',b5v.ClassNames))
gPosition = 2

The 'bad' class is first and the 'good' class is second. Display scaled coordinates using red for the 'bad' class and blue for the 'good' class observations.

figure
[s,e] = mdsProx(b5v,'Colors','rb');
xlabel('First Scaled Coordinate')
ylabel('Second Scaled Coordinate')

Figure contains an axes object. The axes object with xlabel First Scaled Coordinate, ylabel Second Scaled Coordinate contains 2 objects of type scatter. These objects represent b, g.

Plot the first 20 eigenvalues obtained by scaling. The first eigenvalue clearly dominates and the first scaled coordinate is most important.

figure
bar(e(1:20))
xlabel('Scaled Coordinate Index')
ylabel('Eigenvalue')

Figure contains an axes object. The axes object with xlabel Scaled Coordinate Index, ylabel Eigenvalue contains an object of type bar.

Another way of exploring the performance of a classification ensemble is to plot its receiver operating characteristic (ROC) curve or another performance curve suitable for the current problem. Obtain predictions for out-of-bag observations. For a classification ensemble, the oobPredict method returns a cell array of classification labels as the first output argument and a numeric array of scores as the second output argument. The returned array of scores has two columns, one for each class. In this case, the first column is for the 'bad' class and the second column is for the 'good' class. One column in the score matrix is redundant because the scores represent class probabilities in tree leaves and by definition add up to 1.

[Yfit,Sfit] = oobPredict(b5v);

Use rocmetrics to compute a performance curve. By default, rocmetrics computes true positive rates and false positive rates for a ROC curve.

rocObj = rocmetrics(b5v.Y,Sfit(:,gPosition),'g');

Plot the ROC curve for the 'good' class by using the plot function of rocmetrics.

plot(rocObj)

Figure contains an axes object. The axes object with title ROC Curve, xlabel False Positive Rate, ylabel True Positive Rate contains 3 objects of type roccurve, scatter, line. These objects represent g (AUC = 0.9689), g Model Operating Point.

Instead of the standard ROC curve, you might want to plot, for example, ensemble accuracy versus threshold on the score for the 'good' class. Compute accuracy by using the addMetrics function of rocmetrics. Accuracy is the fraction of correctly classified observations, or equivalently, 1 minus the classification error.

rocObj = addMetrics(rocObj,'Accuracy');

Create a plot of ensemble accuracy versus threshold.

thre = rocObj.Metrics.Threshold;
accu = rocObj.Metrics.Accuracy;
plot(thre,accu)
xlabel('Threshold for ''good'' Returns')
ylabel('Classification Accuracy')

Figure contains an axes object. The axes object with xlabel Threshold for 'good' Returns, ylabel Classification Accuracy contains an object of type line.

The curve shows a flat region indicating that any threshold from 0.2 to 0.6 is a reasonable choice. By default, a classification model assigns classification labels using 0.5 as the boundary between the two classes. You can find exactly what accuracy this corresponds to.

[~,idx] = min(abs(thre-0.5));
accu(idx)
ans = 0.9316

Find the maximal accuracy.

[maxaccu,iaccu] = max(accu)
maxaccu = 0.9345
iaccu = 99

The maximal accuracy is a little higher than the default one. The optimal threshold is therefore.

thre(iaccu)
ans = 0.5278

See Also

| | | | | | |

Related Topics