
kfoldfun

Cross-validate function

Syntax

vals = kfoldfun(obj,fun)

Description

vals = kfoldfun(obj,fun) cross-validates the function fun by applying fun to the data stored in the cross-validated model obj. You must pass fun as a function handle.
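As a minimal illustration, the sketch below builds a cross-validated tree and passes a function handle that ignores the model and simply returns the number of test observations in each fold. It assumes the Statistics and Machine Learning Toolbox and the imports-85 sample data used under Examples; countFun is a hypothetical name. fun must still accept all seven inputs, so the unused ones are replaced with ~.

```matlab
% Assumes Statistics and Machine Learning Toolbox and the imports-85 sample data
load imports-85
cv = crossval(fitrtree(X(:,[4 5]),X(:,16)));   % 10-fold cross-validation by default

% Hypothetical fun: ignore the model and training data, count test observations
countFun = @(~,~,~,~,~,Ytest,~) numel(Ytest);
nTest = kfoldfun(cv,countFun)   % one scalar per fold, stacked into a column
```

Because each call returns a scalar, vals here is a KFold-by-1 column vector.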

Input Arguments

obj

Object of class RegressionPartitionedModel or RegressionPartitionedEnsemble. Create obj with fitrtree or fitrensemble along with one of the cross-validation options: 'CrossVal', 'KFold', 'Holdout', 'Leaveout', or 'CVPartition'. Alternatively, create obj from a regression tree or regression ensemble with crossval.
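The sketch below shows the creation routes named above: requesting cross-validation directly from fitrtree or fitrensemble, or partitioning an already trained tree with crossval. It assumes the Statistics and Machine Learning Toolbox and the imports-85 sample data from the Examples section; the choice of 5 folds is arbitrary.

```matlab
% Assumes Statistics and Machine Learning Toolbox and the imports-85 sample data
load imports-85
Xp = X(:,[4 5]);   % length and width
Y  = X(:,16);      % price

% Route 1: ask fitrtree for 5-fold cross-validation directly
cv1 = fitrtree(Xp,Y,'KFold',5);

% Route 2: train a full tree first, then partition it with crossval
t   = fitrtree(Xp,Y);
cv2 = crossval(t,'KFold',5);

% Cross-validated ensemble (returns a RegressionPartitionedEnsemble)
cvEns = fitrensemble(Xp,Y,'KFold',5);
```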

fun

A function handle for a cross-validation function. fun must have the syntax:

testvals = fun(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)
  • CMP is a compact model stored in one element of the obj.Trained property.

  • Xtrain is the training matrix of predictor values.

  • Ytrain is the training array of response values.

  • Wtrain is the vector of observation weights for the training data.

  • Xtest and Ytest are the test data, with associated weights Wtest.

  • The returned value testvals must have the same size across all folds.
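A sketch of a fun that actually uses the fold's compact model: it predicts on Xtest with predict(CMP,Xtest) and returns the weighted mean squared error using Wtest. It assumes the Statistics and Machine Learning Toolbox and the imports-85 data from Examples; wmse is a hypothetical name, and the training inputs are ignored with ~.

```matlab
% Assumes Statistics and Machine Learning Toolbox and the imports-85 sample data
load imports-85
cv = crossval(fitrtree(X(:,[4 5]),X(:,16)));

% Hypothetical fun: weighted MSE of this fold's compact tree on its test data
wmse = @(CMP,~,~,~,Xtest,Ytest,Wtest) ...
    sum(Wtest.*(Ytest - predict(CMP,Xtest)).^2) / sum(Wtest);

perFold = kfoldfun(cv,wmse);   % KFold-by-1, one weighted MSE per fold
mean(perFold)                  % average over folds
```

For reference, kfoldLoss(cv) reports the toolbox's own cross-validated MSE for the same model.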

Output Arguments

vals

The testvals outputs from all folds, concatenated vertically. For example, if testvals from every fold is a numeric vector of length N, kfoldfun returns a KFold-by-N numeric matrix with one row per fold.
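To see the KFold-by-N shape, the sketch below has fun return a 1-by-2 row vector per fold containing the unweighted MSE and mean absolute error. It assumes the Statistics and Machine Learning Toolbox and the imports-85 data from Examples; errFun is a hypothetical name.

```matlab
% Assumes Statistics and Machine Learning Toolbox and the imports-85 sample data
load imports-85
cv = crossval(fitrtree(X(:,[4 5]),X(:,16)));

% Hypothetical fun: return [MSE, MAE] (unweighted) as a 1-by-2 row per fold
errFun = @(CMP,~,~,~,Xtest,Ytest,~) ...
    [mean((Ytest - predict(CMP,Xtest)).^2), ...
     mean(abs(Ytest - predict(CMP,Xtest)))];

vals = kfoldfun(cv,errFun);
size(vals)    % KFold-by-2: one row per fold, one column per statistic
mean(vals)    % column means: average MSE and average MAE over folds
```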

Examples

Cross-validate a regression tree and obtain its mean squared error (see kfoldLoss):

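load imports-85
% Predict price (column 16) from car length and width (columns 4 and 5)
t = fitrtree(X(:,[4 5]),X(:,16),...
    'PredictorNames',{'length' 'width'},...
    'ResponseName','price');
cv = crossval(t);    % 10-fold cross-validation by default
L = kfoldLoss(cv)    % cross-validated mean squared error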

L =
  1.5489e+007

For comparison, compute the cross-validated mean squared error of a baseline that ignores the predictors and predicts every test response as the mean of the training responses:

% Unweighted squared error of predicting each test response by the training mean
f = @(cmp,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)...
    mean((Ytest-mean(Ytrain)).^2);
mean(kfoldfun(cv,f))

ans =
  6.3497e+007