I want to train a network and save the weights, then resume training on a new data set starting from the weights learned on the previous one.
It seems train always randomizes the weights when called. I can't seem to maintain the weights from the first training set.
Accepted Answer
Sonam Gupta
2018-3-28
You can continue training from the weights obtained on the previous data set by extracting the layers from the network's "Layers" property and passing them to "trainNetwork", as follows:
% Train a network on the first data set
net = trainNetwork(XTrain, YTrain, layers, options);
% Extract the layers (including their learned weights) from the trained network
newLayers = net.Layers;
% Retrain the network, but start from where we left off.
% XTrainNew and YTrainNew are placeholder names for the new data set.
newNet = trainNetwork(XTrainNew, YTrainNew, newLayers, options);
'trainNetwork' will always use the weights that are stored in the layers which you pass in for training.
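Since the question also asks about saving the weights, here is a minimal sketch of the full round trip: train, save the network to a MAT-file, reload it, and resume on a second data set. The synthetic data, the tiny layer array, the variable names (X1, Y1, X2, Y2) and the file name trainedNet.mat are all assumptions made for illustration and do not come from the original post. It requires Deep Learning Toolbox.

```matlab
% Sketch only: the synthetic data, layer sizes, and file name below are
% illustrative assumptions, not taken from the original post.
rng(0);                                      % reproducible random data
X1 = rand(8, 8, 1, 40);                      % first data set: 40 tiny images
Y1 = categorical(randi(2, 40, 1), [1 2]);    % two classes
X2 = rand(8, 8, 1, 40);                      % second ("new") data set
Y2 = categorical(randi(2, 40, 1), [1 2]);    % same two classes

layers = [
    imageInputLayer([8 8 1])
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];
options = trainingOptions('sgdm', 'MaxEpochs', 2, ...
    'MiniBatchSize', 10, 'Verbose', false);

% Train on the first data set and save the trained network to disk
net = trainNetwork(X1, Y1, layers, options);
save('trainedNet.mat', 'net');

% Later, possibly in a new session: reload and resume on the second data set
loaded = load('trainedNet.mat', 'net');
newNet = trainNetwork(X2, Y2, loaded.net.Layers, options);
```

The second data set has to use the same input size and the same classes as the first, because the reused layers keep their original dimensions. When resuming, you may also want to pass a different options object, for example one with a lower 'InitialLearnRate'.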
3 Comments
Moh. Saadat
2022-8-29
There is a small caveat to this: check whether your output 'net' is a SeriesNetwork or a DAGNetwork. If it is a DAGNetwork, use layerGraph(net) instead of net.Layers, because net.Layers returns only the array of layers and drops the connections between them.
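The caveat above can be sketched as a simple type check before retraining. The small network with a skip connection, the synthetic data, and the variable name startFrom are assumptions chosen only to produce a DAGNetwork for illustration. It requires Deep Learning Toolbox.

```matlab
% Sketch only: the data and the network architecture are illustrative assumptions.
rng(0);
X = rand(8, 8, 1, 40);                        % 40 tiny synthetic images
Y = categorical(randi(2, 40, 1), [1 2]);      % two classes

layers = [
    imageInputLayer([8 8 1], 'Name', 'in')
    convolution2dLayer(3, 1, 'Padding', 'same', 'Name', 'conv')
    additionLayer(2, 'Name', 'add')
    fullyConnectedLayer(2, 'Name', 'fc')
    softmaxLayer('Name', 'sm')
    classificationLayer('Name', 'out')];
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, 'in', 'add/in2');   % skip connection

options = trainingOptions('sgdm', 'MaxEpochs', 2, ...
    'MiniBatchSize', 10, 'Verbose', false);
net = trainNetwork(X, Y, lgraph, options);    % training a layer graph yields a DAGNetwork

% Pick the right starting point for the second round of training
if isa(net, 'DAGNetwork')
    startFrom = layerGraph(net);   % keeps layers, weights, and connections
else
    startFrom = net.Layers;        % SeriesNetwork: the layer array is enough
end
newNet = trainNetwork(X, Y, startFrom, options);
```

In practice the second call would receive the new data set rather than X and Y again; the same arrays are reused here only to keep the sketch short.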