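Custom Training Loops

Customize deep learning training loops and loss functions

If the trainingOptions function does not provide the training options that you need for your task, or custom output layers do not support the loss functions that you need, then you can define a custom training loop. For models that cannot be specified as networks of layers, you can define the model as a function. To learn more, see Define Custom Training Loops, Loss Functions, and Networks.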
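
As a rough illustration of what a custom training loop can look like, the following minimal sketch trains a small dlnetwork with dlfeval, dlgradient, and sgdmupdate. The toy data, layer sizes, learning rate, momentum, iteration count, and the helper name modelLoss are all assumptions made for this example and do not come from this page. It assumes Deep Learning Toolbox is installed and is not meant as a definitive recipe.

```matlab
% Hypothetical toy regression data: 100 observations with 4 features each.
numObservations = 100;
X = dlarray(rand(4, numObservations, "single"), "CB");   % channel-by-batch format
T = dlarray(sum(extractdata(X), 1), "CB");                % target: sum of the features

% Small illustrative network (sizes chosen arbitrarily for the sketch).
layers = [
    featureInputLayer(4)
    fullyConnectedLayer(8)
    reluLayer
    fullyConnectedLayer(1)];
net = dlnetwork(layers);

% SGDM solver state and assumed hyperparameters.
velocity = [];
learnRate = 0.01;
momentum = 0.9;

for iteration = 1:50
    % Evaluate loss and gradients with automatic differentiation.
    [loss, gradients] = dlfeval(@modelLoss, net, X, T);

    % Update the learnable parameters using SGD with momentum.
    [net, velocity] = sgdmupdate(net, gradients, velocity, learnRate, momentum);
end

% Hypothetical helper: forward pass, loss, and gradients of the loss
% with respect to the learnable parameters.
function [loss, gradients] = modelLoss(net, X, T)
    Y = forward(net, X);
    loss = mse(Y, T);
    gradients = dlgradient(loss, net.Learnables);
end
```

The loss and gradients are computed inside a function called through dlfeval because dlgradient needs the forward computation to be traced for automatic differentiation. To use a different solver, replacing sgdmupdate with adamupdate or rmspropupdate (and keeping the solver state they require) follows the same pattern.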

Functions

dlnetwork: Deep learning neural network
imagePretrainedNetwork: Pretrained neural network for images (since R2024a)
resnetNetwork: 2-D residual neural network (since R2024a)
resnet3dNetwork: 3-D residual neural network (since R2024a)
addLayers: Add layers to neural network
removeLayers: Remove layers from neural network
replaceLayer: Replace layer in neural network
connectLayers: Connect layers in neural network
disconnectLayers: Disconnect layers in neural network
addInputLayer: Add input layer to network (since R2022b)
initialize: Initialize learnable and state parameters of a dlnetwork (since R2021a)
networkDataLayout: Deep learning network data layout for learnable parameter initialization (since R2022b)
setL2Factor: Set L2 regularization factor of layer learnable parameter
getL2Factor: Get L2 regularization factor of layer learnable parameter
setLearnRateFactor: Set learn rate factor of layer learnable parameter
getLearnRateFactor: Get learn rate factor of layer learnable parameter
plot: Plot neural network architecture
summary: Print network summary (since R2022b)
analyzeNetwork: Analyze deep learning network architecture
checkLayer: Check validity of custom or function layer
isequal: Check equality of neural networks (since R2021a)
isequaln: Check equality of neural networks ignoring NaN values (since R2021a)
forward: Compute deep learning network output for training
predict: Compute deep learning network output for inference
adamupdate: Update parameters using adaptive moment estimation (Adam)
rmspropupdate: Update parameters using root mean squared propagation (RMSProp)
sgdmupdate: Update parameters using stochastic gradient descent with momentum (SGDM)
lbfgsupdate: Update parameters using limited-memory BFGS (L-BFGS) (since R2023a)
lbfgsState: State of limited-memory BFGS (L-BFGS) solver (since R2023a)
dlupdate: Update parameters using custom function
trainingProgressMonitor: Monitor and plot training progress for deep learning custom training loops (since R2022b)
updateInfo: Update information values for custom training loops (since R2022b)
recordMetrics: Record metric values for custom training loops (since R2022b)
groupSubPlot: Group metrics in training plot (since R2022b)
deep.gpu.deterministicAlgorithms: Set determinism of deep learning operations on the GPU to get reproducible results (since R2024b)
padsequences: Pad or truncate sequence data to same length (since R2021a)
minibatchqueue: Create mini-batches for deep learning (since R2020b)
onehotencode: Encode data labels into one-hot vectors (since R2020b)
onehotdecode: Decode probability vectors into class labels (since R2020b)
next: Obtain next mini-batch of data from minibatchqueue (since R2020b)
reset: Reset minibatchqueue to start of data (since R2020b)
shuffle: Shuffle data in minibatchqueue (since R2020b)
hasdata: Determine if minibatchqueue can return mini-batch (since R2020b)
partition: Partition minibatchqueue (since R2020b)
dlarray: Deep learning array for customization
dlgradient: Compute gradients for custom training loops using automatic differentiation
dljacobian: Jacobian matrix deep learning operation (since R2024b)
dldivergence: Divergence of deep learning data (since R2024b)
dllaplacian: Laplacian of deep learning data (since R2024b)
dlfeval: Evaluate deep learning model for custom training loops
dims: Data format of dlarray object
finddim: Find dimensions with specified label
stripdims: Remove dlarray data format
extractdata: Extract data from dlarray
isdlarray: Check if object is dlarray (since R2020b)
crossentropy: Cross-entropy loss for classification tasks
indexcrossentropy: Index cross-entropy loss for classification tasks (since R2024b)
l1loss: L1 loss for regression tasks (since R2021b)
l2loss: L2 loss for regression tasks (since R2021b)
huber: Huber loss for regression tasks (since R2021a)
ctc: Connectionist temporal classification (CTC) loss for unaligned sequence classification (since R2021a)
mse: Half mean squared error
dlaccelerate: Accelerate deep learning function for custom training loops (since R2021a)
AcceleratedFunction: Accelerated deep learning function (since R2021a)
clearCache: Clear accelerated deep learning function trace cache (since R2021a)

Topics

Custom Training Loops

Automatic Differentiation

Deep Learning Function Acceleration

Related Information

Featured Examples