

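Custom Training Loops

Customize deep learning training loops and loss functions

If the trainingOptions function does not provide the training options you need for your task, or custom output layers do not support the loss functions you need, then you can define a custom training loop. For models that layer graphs do not support, you can define a custom model as a function. To learn more, see Define Custom Training Loops, Loss Functions, and Networks.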
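A custom training loop repeats the same steps on every mini-batch: evaluate the model, compute the loss and its gradients, and update the parameters. The plain-Python sketch below shows that structure. It is not MATLAB code and does not use the toolbox API.

The model, the data, and every function name here are hypothetical. The model is a linear fit y = w*x + b trained with half mean squared error. Gradients are derived by hand; in MATLAB, dlfeval and dlgradient would compute them by automatic differentiation. The parameter update is a textbook SGD-with-momentum rule, which may differ in detail from sgdmupdate.

```python
# Illustrative custom training loop in plain Python (not the MATLAB API).
# Hypothetical model: y = w*x + b, trained with half mean squared error
# and a textbook stochastic gradient descent with momentum (SGDM) update.


def model(params, x):
    """Forward pass of the hypothetical linear model."""
    return params["w"] * x + params["b"]


def model_loss(params, xs, ts):
    """Return half mean squared error and hand-derived gradients for w and b."""
    n = len(xs)
    ys = [model(params, x) for x in xs]
    loss = sum((y - t) ** 2 for y, t in zip(ys, ts)) / (2 * n)
    grad_w = sum((y - t) * x for y, t, x in zip(ys, ts, xs)) / n
    grad_b = sum(y - t for y, t in zip(ys, ts)) / n
    return loss, {"w": grad_w, "b": grad_b}


def sgdm_step(params, grads, velocity, lr=0.1, momentum=0.9):
    """One SGDM update: v <- momentum*v - lr*g, then p <- p + v."""
    for k in params:
        velocity[k] = momentum * velocity[k] - lr * grads[k]
        params[k] += velocity[k]
    return params, velocity


def train(xs, ts, num_epochs=200, batch_size=2):
    """Loop over epochs and mini-batches: forward, loss, gradients, update."""
    params = {"w": 0.0, "b": 0.0}
    velocity = {"w": 0.0, "b": 0.0}
    loss = None
    for _ in range(num_epochs):
        for start in range(0, len(xs), batch_size):
            xb = xs[start:start + batch_size]
            tb = ts[start:start + batch_size]
            loss, grads = model_loss(params, xb, tb)
            params, velocity = sgdm_step(params, grads, velocity)
    return params, loss


# Usage: recover w = 2, b = 1 from noise-free data generated by t = 2x + 1.
params, final_loss = train([-1.0, 1.0, -2.0, 2.0], [-1.0, 3.0, -3.0, 5.0])
print(params, final_loss)
```

Separating the loss-and-gradient function from the update step mirrors the division of labor among the functions listed on this page. One function evaluates the model and its gradients. A separate update function then changes the parameters, so you can swap optimizers without touching the model.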

Functions


dlnetwork - Deep learning neural network (since R2019b)
resetState - Reset state parameters of neural network
plot - Plot neural network architecture
addInputLayer - Add input layer to network (since R2022b)
addLayers - Add layers to neural network
removeLayers - Remove layers from neural network
connectLayers - Connect layers in neural network
disconnectLayers - Disconnect layers in neural network
replaceLayer - Replace layer in neural network
summary - Print network summary (since R2022b)
initialize - Initialize learnable and state parameters of a dlnetwork (since R2021a)
networkDataLayout - Deep learning network data layout for learnable parameter initialization (since R2022b)
layerGraph - (Not recommended) Graph of network layers for deep learning
setL2Factor - Set L2 regularization factor of layer learnable parameter
getL2Factor - Get L2 regularization factor of layer learnable parameter
setLearnRateFactor - Set learn rate factor of layer learnable parameter
getLearnRateFactor - Get learn rate factor of layer learnable parameter
forward - Compute deep learning network output for training (since R2019b)
predict - Compute deep learning network output for inference (since R2019b)
adamupdate - Update parameters using adaptive moment estimation (Adam) (since R2019b)
rmspropupdate - Update parameters using root mean squared propagation (RMSProp) (since R2019b)
sgdmupdate - Update parameters using stochastic gradient descent with momentum (SGDM) (since R2019b)
lbfgsupdate - Update parameters using limited-memory BFGS (L-BFGS) (since R2023a)
lbfgsState - State of limited-memory BFGS (L-BFGS) solver (since R2023a)
dlupdate - Update parameters using custom function (since R2019b)
trainingProgressMonitor - Monitor and plot training progress for deep learning custom training loops (since R2022b)
updateInfo - Update information values for custom training loops (since R2022b)
recordMetrics - Record metric values for custom training loops (since R2022b)
groupSubPlot - Group metrics in training plot (since R2022b)
padsequences - Pad or truncate sequence data to same length (since R2021a)
minibatchqueue - Create mini-batches for deep learning (since R2020b)
onehotencode - Encode data labels into one-hot vectors (since R2020b)
onehotdecode - Decode probability vectors into class labels (since R2020b)
next - Obtain next mini-batch of data from minibatchqueue (since R2020b)
reset - Reset minibatchqueue to start of data (since R2020b)
shuffle - Shuffle data in minibatchqueue (since R2020b)
hasdata - Determine if minibatchqueue can return mini-batch (since R2020b)
partition - Partition minibatchqueue (since R2020b)
dlarray - Deep learning array for customization (since R2019b)
dlgradient - Compute gradients for custom training loops using automatic differentiation (since R2019b)
dlfeval - Evaluate deep learning model for custom training loops (since R2019b)
dims - Dimension labels of dlarray (since R2019b)
finddim - Find dimensions with specified label (since R2019b)
stripdims - Remove dlarray data format (since R2019b)
extractdata - Extract data from dlarray (since R2019b)
isdlarray - Check if object is dlarray (since R2020b)
crossentropy - Cross-entropy loss for classification tasks (since R2019b)
l1loss - L1 loss for regression tasks (since R2021b)
l2loss - L2 loss for regression tasks (since R2021b)
huber - Huber loss for regression tasks (since R2021a)
mse - Half mean squared error (since R2019b)
ctc - Connectionist temporal classification (CTC) loss for unaligned sequence classification (since R2021a)
dlaccelerate - Accelerate deep learning function for custom training loops (since R2021a)
AcceleratedFunction - Accelerated deep learning function (since R2021a)
clearCache - Clear accelerated deep learning function trace cache (since R2021a)
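Several entries above are one-line descriptions of standard formulas. The plain-Python sketch below spells out textbook versions of half mean squared error (as in mse), Huber loss (huber), one-hot encoding and decoding (onehotencode, onehotdecode), cross-entropy (crossentropy), and a single Adam update (adamupdate).

This is not MATLAB code and not the toolbox implementation. The function names are hypothetical. The following choices are assumptions and may not match the toolbox functions' options or defaults:

- averaging over elements or observations
- a Huber transition point of 1
- learning rate 0.001
- decay rates 0.9 and 0.999
- epsilon 1e-8

```python
import math

# Textbook formulas only (not the MATLAB toolbox implementation).
# Normalization choices and default hyperparameters are assumptions.


def half_mse(y, t):
    """Half mean squared error: sum((y - t)^2) / (2 * N)."""
    return sum((yi - ti) ** 2 for yi, ti in zip(y, t)) / (2 * len(y))


def huber(y, t, delta=1.0):
    """Huber loss averaged over elements: quadratic up to delta, linear beyond it."""
    total = 0.0
    for yi, ti in zip(y, t):
        r = abs(yi - ti)
        total += 0.5 * r * r if r <= delta else delta * (r - 0.5 * delta)
    return total / len(y)


def onehot_encode(labels, classes):
    """Map each label to a vector with 1.0 at its class index and 0.0 elsewhere."""
    return [[1.0 if label == c else 0.0 for c in classes] for label in labels]


def onehot_decode(probs, classes):
    """Return the class with the highest probability in each vector."""
    return [classes[max(range(len(p)), key=p.__getitem__)] for p in probs]


def crossentropy(probs, targets):
    """Cross-entropy -sum(t * log(p)), averaged over observations.

    Terms with t == 0 are skipped so that log(0) is never evaluated.
    """
    total = 0.0
    for p, t in zip(probs, targets):
        total -= sum(ti * math.log(pi) for pi, ti in zip(p, t) if ti > 0)
    return total / len(probs)


def adam_step(params, grads, m, v, iteration,
              lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One bias-corrected Adam update; iteration counts from 1.

    m and v (first and second moment estimates) are updated in place.
    """
    new_params = []
    for i, (p, g) in enumerate(zip(params, grads)):
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        m_hat = m[i] / (1 - beta1 ** iteration)
        v_hat = v[i] / (1 - beta2 ** iteration)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
    return new_params, m, v


# Usage: the first Adam step moves each parameter by about lr * sign(gradient).
params, m, v = adam_step([1.0], [2.0], [0.0], [0.0], iteration=1)
print(params)
```

On the first Adam iteration, the bias correction makes the step roughly lr times the sign of the gradient, whatever the gradient's magnitude. This is a commonly cited property of the method.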

Topics

Custom Training Loops

Automatic Differentiation

Deep Learning Function Acceleration
