
Custom Training Loops

Customize deep learning training loops and loss functions

If the trainingOptions function does not provide the training options that you need for your task, or custom output layers do not support the loss functions that you need, then you can define a custom training loop. For models that cannot be specified as networks of layers, you can define the model as a function. To learn more, see Define Custom Training Loops, Loss Functions, and Networks.
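
As a rough illustration of what such a loop can look like, the following minimal sketch trains a small classifier on random data using dlnetwork, dlfeval, dlgradient, and adamupdate. The data sizes, the layer choice, the iteration count, and the helper name modelLoss are all assumptions made for this example and do not come from this page; a real loop would typically also read mini-batches from a minibatchqueue and report progress with trainingProgressMonitor.

```matlab
% Hypothetical toy data: 10 features, 3 classes, 128 observations.
numFeatures = 10;
numClasses = 3;
numObservations = 128;

% Features as a formatted dlarray: "C" (channel) by "B" (batch).
X = dlarray(rand(numFeatures,numObservations,"single"),"CB");

% Random class labels, one-hot encoded along dimension 1 so that
% T is numClasses-by-numObservations, matching the network output.
labels = categorical(randi(numClasses,1,numObservations),1:numClasses);
T = onehotencode(labels,1);

% Small illustrative network; dlnetwork initializes the learnables.
layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(numClasses)
    softmaxLayer];
net = dlnetwork(layers);

% Adam solver state starts empty.
averageGrad = [];
averageSqGrad = [];

numIterations = 20;  % arbitrary value for the sketch
for iteration = 1:numIterations
    % dlfeval enables automatic differentiation inside modelLoss.
    [loss,gradients] = dlfeval(@modelLoss,net,X,T);

    % Update the learnable parameters using Adam.
    [net,averageGrad,averageSqGrad] = adamupdate(net,gradients, ...
        averageGrad,averageSqGrad,iteration);
end

function [loss,gradients] = modelLoss(net,X,T)
    % Forward pass in training mode, then cross-entropy loss.
    Y = forward(net,X);
    loss = crossentropy(Y,T);

    % Gradients of the loss with respect to the learnable parameters.
    gradients = dlgradient(loss,net.Learnables);
end
```

The loss and gradient computation live in a separate function because dlgradient must run inside a function evaluated by dlfeval. To use a different loss, such as huber or l2loss, only the body of modelLoss would need to change.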



dlnetwork - Deep learning neural network (Since R2019b)
imagePretrainedNetwork - Pretrained neural network for images (Since R2024a)
resnetNetwork - 2-D residual neural network (Since R2024a)
resnet3dNetwork - 3-D residual neural network (Since R2024a)
addLayers - Add layers to neural network
removeLayers - Remove layers from neural network
replaceLayer - Replace layer in neural network
connectLayers - Connect layers in neural network
disconnectLayers - Disconnect layers in neural network
addInputLayer - Add input layer to network (Since R2022b)
initialize - Initialize learnable and state parameters of a dlnetwork (Since R2021a)
networkDataLayout - Deep learning network data layout for learnable parameter initialization (Since R2022b)
setL2Factor - Set L2 regularization factor of layer learnable parameter
getL2Factor - Get L2 regularization factor of layer learnable parameter
setLearnRateFactor - Set learn rate factor of layer learnable parameter
getLearnRateFactor - Get learn rate factor of layer learnable parameter
plot - Plot neural network architecture
summary - Print network summary (Since R2022b)
analyzeNetwork - Analyze deep learning network architecture
checkLayer - Check validity of custom or function layer
isequal - Check equality of neural networks (Since R2021a)
isequaln - Check equality of neural networks ignoring NaN values (Since R2021a)
forward - Compute deep learning network output for training (Since R2019b)
predict - Compute deep learning network output for inference (Since R2019b)
adamupdate - Update parameters using adaptive moment estimation (Adam) (Since R2019b)
rmspropupdate - Update parameters using root mean squared propagation (RMSProp) (Since R2019b)
sgdmupdate - Update parameters using stochastic gradient descent with momentum (SGDM) (Since R2019b)
lbfgsupdate - Update parameters using limited-memory BFGS (L-BFGS) (Since R2023a)
lbfgsState - State of limited-memory BFGS (L-BFGS) solver (Since R2023a)
dlupdate - Update parameters using custom function (Since R2019b)
trainingProgressMonitor - Monitor and plot training progress for deep learning custom training loops (Since R2022b)
updateInfo - Update information values for custom training loops (Since R2022b)
recordMetrics - Record metric values for custom training loops (Since R2022b)
groupSubPlot - Group metrics in training plot (Since R2022b)
padsequences - Pad or truncate sequence data to same length (Since R2021a)
minibatchqueue - Create mini-batches for deep learning (Since R2020b)
onehotencode - Encode data labels into one-hot vectors (Since R2020b)
onehotdecode - Decode probability vectors into class labels (Since R2020b)
next - Obtain next mini-batch of data from minibatchqueue (Since R2020b)
reset - Reset minibatchqueue to start of data (Since R2020b)
shuffle - Shuffle data in minibatchqueue (Since R2020b)
hasdata - Determine if minibatchqueue can return mini-batch (Since R2020b)
partition - Partition minibatchqueue (Since R2020b)
dlarray - Deep learning array for customization (Since R2019b)
dlgradient - Compute gradients for custom training loops using automatic differentiation (Since R2019b)
dlfeval - Evaluate deep learning model for custom training loops (Since R2019b)
dims - Dimension labels of dlarray (Since R2019b)
finddim - Find dimensions with specified label (Since R2019b)
stripdims - Remove dlarray data format (Since R2019b)
extractdata - Extract data from dlarray (Since R2019b)
isdlarray - Check if object is dlarray (Since R2020b)
crossentropy - Cross-entropy loss for classification tasks (Since R2019b)
l1loss - L1 loss for regression tasks (Since R2021b)
l2loss - L2 loss for regression tasks (Since R2021b)
huber - Huber loss for regression tasks (Since R2021a)
mse - Half mean squared error (Since R2019b)
ctc - Connectionist temporal classification (CTC) loss for unaligned sequence classification (Since R2021a)
dlaccelerate - Accelerate deep learning function for custom training loops (Since R2021a)
AcceleratedFunction - Accelerated deep learning function (Since R2021a)
clearCache - Clear accelerated deep learning function trace cache (Since R2021a)

