predict function not working in custom training loop

I am building a custom training loop for a simple LSTM classification network because I need a custom loss function (specifically, 0-1 loss). I followed a tutorial, but when I call the predict function inside my custom loss function, I get the error: "Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'."
I can successfully train the network with trainNetwork, so I am wondering what trainNetwork is doing that I am missing. If I call predict after training with trainNetwork, it works, but not in the training loop.
My input is a 50x1 sequence that is classified as either 1 or 2, depending on whether its average is positive or negative.
My network is defined as follows:
numFeatures = 1;        % one input value per time step (50 time steps per sequence)
numHiddenUnits = 100;
numClasses = 2;         % left/right decision at the end
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,"OutputMode","last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
net = layers;
My custom loss function is:
function [gradients,state,loss] = customGradients2Lay(net,dlX,Ylabel)
    [Y,state] = predict(net,dlX);
    loss = loss01(Y,Ylabel);
    gradients = dlgradient(loss,net.Learnables);
end

function loss = loss01(Y,T)
    if isequal(Y,T)
        loss = 0;
    else
        loss = 1;
    end
end
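Beyond the dlnetwork issue addressed in the answer, this gradient function has two further problems worth flagging. First, dlgradient only works when the function containing it is evaluated through dlfeval, and during training the forward pass is normally done with forward (training-mode behavior) rather than predict. Second, the 0-1 loss returns a plain constant 0 or 1 that is not connected to the network's learnables, so it yields no gradient to train on. A common workaround is to train on cross-entropy and report the 0-1 loss as an evaluation metric. The sketch below takes that approach; the random toy data, the class mapping (class 1 for a positive mean), the Adam learning rate, and the names lossFcn and gradFcn are all illustrative assumptions, not part of the original code.

```matlab
numFeatures = 1; numHiddenUnits = 100; numClasses = 2;
numObs = 32; seqLen = 50;

% Same architecture as the question, minus classificationLayer,
% because dlnetwork does not support output layers.
net = dlnetwork([ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,"OutputMode","last")
    fullyConnectedLayer(numClasses)
    softmaxLayer]);

% Toy data (assumed mapping): class 1 if the sequence mean is positive, class 2 otherwise.
Xraw = randn(numFeatures,numObs,seqLen,"single");
labels = 1 + (mean(Xraw,3) <= 0);                        % 1-by-numObs, values 1 or 2
X = dlarray(Xraw,"CBT");                                 % channel, batch, time
T = dlarray(single(labels == (1:numClasses)'),"CB");     % one-hot targets

% Differentiable training loss; the 0-1 loss is piecewise constant, so its gradient is zero.
lossFcn = @(net,X,T) crossentropy(forward(net,X),T);
% Hypothetical helper: gradient of the loss w.r.t. the learnables (must run inside dlfeval).
gradFcn = @(net,X,T) dlgradient(lossFcn(net,X,T),net.Learnables);

avgGrad = []; avgSqGrad = [];
for iteration = 1:50
    gradients = dlfeval(gradFcn,net,X,T);
    [net,avgGrad,avgSqGrad] = adamupdate(net,gradients,avgGrad,avgSqGrad,iteration,0.01);
end

% 0-1 loss used only as a metric: fraction of misclassified sequences.
[~,predictedClass] = max(extractdata(predict(net,X)),[],1);
zeroOneLoss = mean(predictedClass ~= labels);
```

Anonymous functions are used here only to keep the sketch in one script; the usual pattern in the MATLAB examples is a local model-loss function that also returns the updated network state.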
My training loop just generates a random dataset to test on and passes it to predict. The error again is: "Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'."
I am also not sure why the error refers to nnet.cnn. I even built a custom fully connected layer to try to get around this, and the error still referred to the nnet.cnn class.
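The class in the error message is the class of the layer array itself: net = layers only copies an array of layer objects, and predict is defined for network objects (SeriesNetwork, DAGNetwork, dlnetwork), not for layer arrays. That is also why predict works after trainNetwork, which returns a SeriesNetwork. Judging by the error, a custom layer concatenated with the built-in ones ends up in the same nnet.cnn.layer.Layer array. A quick check, using a toy copy of the question's layers:

```matlab
% A layer array is not a network object.
layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(100,"OutputMode","last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];
net = layers;          % copies the layer array; no network is created
disp(class(net))       % prints nnet.cnn.layer.Layer
```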
What am I missing?

Accepted Answer

James Gross 2022-10-25
Hello,
To train your network in a custom training loop, you must specify it as a dlnetwork object:
net = dlnetwork(layers);
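One detail to watch with the layers from the question: dlnetwork does not support output layers, so classificationLayer has to be removed before the conversion (the loss is then computed in your own gradient function instead). A minimal sketch, assuming a random batch of 8 sequences of 50 time steps in "CBT" (channel, batch, time) format:

```matlab
numFeatures = 1;
numHiddenUnits = 100;
numClasses = 2;

% Same architecture as in the question, minus classificationLayer.
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,"OutputMode","last")
    fullyConnectedLayer(numClasses)
    softmaxLayer];
net = dlnetwork(layers);

% Toy input: 8 random sequences, 50 time steps each.
numObs = 8;
X = dlarray(randn(numFeatures,numObs,50,"single"),"CBT");

% predict works because net is now a dlnetwork object.
Y = predict(net,X);                               % numClasses-by-numObs probabilities
[~,predictedClass] = max(extractdata(Y),[],1);    % class 1 or 2 per sequence
```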
You should then be able to train and call predict on your network as desired. For examples of how to train using a custom training loop with a dlnetwork, you can refer to one of the following:
I hope this information helps!
