Make trainnet pass in one observation to the network at a time

As best I can tell, trainnet vectorizes its computations across the batch dimension of every mini-batch. That is, during training it feeds each mini-batch into the network as a complete SxSxSxCxB array, where B is the number of observations in the mini-batch. This can be quite memory intensive, depending on S^3*C, and it limits the mini-batch sizes that can be selected for training.
Is there no way to force trainnet to compute the loss function and its gradients incrementally, by feeding one member of the mini-batch (of size SxSxSxCx1) to the network at a time?

Accepted Answer

Katja Mogalle 2024-5-6
Hi Matt,
The MiniBatchSize setting in trainingOptions is exactly the mechanism that lets you regulate how much data is propagated through the network at the same time. So if you run out of memory during training, the first thing to try is reducing the MiniBatchSize value. Note that if the mini-batch size is small, it might be necessary to also reduce the InitialLearnRate setting so that at each iteration the training doesn't overshoot in a gradient direction that is based on only a few data samples.
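For illustration, a minimal sketch of that adjustment might look like this (XTrain, TTrain, and layers are placeholders for your own data and network, and the specific option values are only examples to adapt, not recommendations):

% Hypothetical example: shrink the mini-batch and scale down the learning
% rate to compensate for the noisier per-iteration gradient estimates.
options = trainingOptions("sgdm", ...
    MiniBatchSize=4, ...          % fewer observations per iteration, less memory
    InitialLearnRate=1e-4, ...    % smaller steps to offset the noisier gradients
    Shuffle="every-epoch", ...
    Plots="training-progress");

net = trainnet(XTrain,TTrain,layers,"mse",options);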
You can read more about the training algorithms here. The page nicely explains how stochastic gradient descent works, if that is the solver you are using:
"[...] at each iteration the stochastic gradient descent algorithm evaluates the gradient and updates the parameters using a subset of the training data. A different subset, called a mini-batch, is used at each iteration. The full pass of the training algorithm over the entire training set using mini-batches is one epoch. Stochastic gradient descent is stochastic because the parameter updates computed using a mini-batch is a noisy estimate of the parameter update that would result from using the full data set."
"The stochastic gradient descent algorithm can oscillate along the path of steepest descent towards the optimum. Adding a momentum term to the parameter update is one way to reduce this oscillation."
So in a way, the momentum term in SGDM implements the incremental gradient that you mentioned.
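For reference, the SGDM parameter update described on that page has roughly the form

\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell}) + \gamma \left( \theta_{\ell} - \theta_{\ell-1} \right)

where \nabla E(\theta_{\ell}) is the loss gradient estimated from the current mini-batch only, \alpha is the learning rate, and \gamma is the momentum. The last term is what carries gradient information from earlier mini-batches forward into the current update.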
Hope this helps.
2 Comments
Matt J 2024-5-6 (edited)
Thanks @Katja Mogalle, but I guess you're implying that the answer is no, we cannot unvectorize the mini-batch calculations?
Momentum may help to reduce the stochastic noise in the gradient calculation, but surely using larger mini-batches is the more reliable way. If I used MiniBatchSize=1, do you still suppose I could get reliable convergence just from the momentum mechanism? When I repeat training with small mini-batch sizes, I sometimes get good results and sometimes bad results.
Katja Mogalle 2024-5-6
I don't think there is a built-in way to unvectorize the minibatch calculations. You could try implementing a custom training loop where you do this yourself. But I haven't heard about this approach before.
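If you do want to experiment with that idea, a rough sketch of gradient accumulation in a custom training loop might look like the following. This is only a sketch under several assumptions: net is a dlnetwork you have already built, X is one SxSxSxCxB numeric array holding a "virtual" mini-batch, T holds the corresponding targets, and the mse loss, learning rate, and momentum values are placeholders to swap for your own.

% Accumulate gradients one observation at a time, then take a single
% SGDM step per virtual mini-batch of B observations.
accumGrad = [];
vel = [];                      % SGDM velocity state (empty for the first update)
learnRate = 1e-3;              % placeholder value
momentum = 0.9;                % placeholder value

B = size(X,5);
for b = 1:B
    Xb = dlarray(X(:,:,:,:,b),"SSSCB");   % one SxSxSxCx1 observation
    Tb = T(:,b);                          % adapt indexing to your target layout
    [loss,grad] = dlfeval(@modelLoss,net,Xb,Tb);
    if isempty(accumGrad)
        accumGrad = grad;
    else
        accumGrad = dlupdate(@plus,accumGrad,grad);   % sum the gradients
    end
end

accumGrad = dlupdate(@(g) g/B,accumGrad);             % average over the batch
[net,vel] = sgdmupdate(net,accumGrad,vel,learnRate,momentum);

% Local function (place at the end of the script or in its own file):
function [loss,gradients] = modelLoss(net,X,T)
    Y = forward(net,X);
    loss = mse(Y,T);                      % swap in whatever loss you need
    gradients = dlgradient(loss,net.Learnables);
end

Memory-wise, this only ever holds one forward/backward pass plus one extra copy of the gradients, which seems to be the trade-off you are asking about; the cost is B separate passes per parameter update instead of one.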
I suppose if you can hold two sets of parameter gradients and one set of the parameters themselves, you're pretty close to being able to use a MiniBatchSize of 2. Some papers in the deep learning literature do seem to show decent results with MiniBatchSize of 2. E.g. https://arxiv.org/pdf/1804.07612
My intuition is still that SGD with momentum behaves like a weighted sum of previous parameter gradients. I don't know what kind of network architecture you have, but you might have to try out several training options (e.g. varying learning rates, learning rate schedules, and the optimization algorithm) to find a training scheme that works for your task. The ADAM solver might also be a good one to try; it keeps an element-wise moving average of both the parameter gradients and their squared values.
Certainly, if you can use larger batch sizes, that should make your gradient estimate more robust (up to a point). You are also likely using random shuffling of the data at each epoch, which could be one reason you are getting varying results.
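If you want to check whether the run-to-run variation comes from shuffling and random initialization rather than from the small batch size itself, one experiment (not a recommendation for final training) is to fix the random seed and shuffle only once; the values below are placeholders:

rng(0)                              % fix weight initialization and data shuffling
options = trainingOptions("sgdm", ...
    MiniBatchSize=4, ...            % placeholder value
    Shuffle="once");                % shuffle once instead of every epoch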

Release: R2023b
