predict function not working in custom training loop
I am building a custom training loop for a simple LSTM classification network because I need a custom loss function (specifically, 0-1 loss). I have followed a tutorial, but when I call the predict function within my custom loss function, I get the error: 'Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'.'
I can successfully train the network with trainNetwork, so I am wondering what trainNetwork is doing that I am missing. If I call predict after training with trainNetwork, it works, but not in the training loop.
My input is a 50x1 sequence that is classified as either 1 or 2, depending on whether its average is positive or negative.
My network is defined as follows:
numFeatures = 1; %input data value (50 time points in sequence)
numHiddenUnits = 100;
numClasses = 2; %left/right decision at the end
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,"OutputMode","last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
net = layers;
And my custom loss function is:
function [gradients,state,loss] = customGradients2Lay(net,dlX,Ylabel)
    [Y,state]=predict(net,dlX);
    loss=loss01(Y,Ylabel);
    gradients=dlgradient(loss,net.Learnables);
end
function loss = loss01(Y, T)
    if isequal(Y,T)
        loss = 0;
    else
        loss = 1;
    end
end
My training loop just generates a random dataset to test on and passes it to predict. The error again is: Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'.
I am also not sure why it's calling nnet.cnn. I even built a custom fully connected layer to try to get around this, and it was still calling the nnet.cnn class.
What am I missing?
Accepted Answer
James Gross
2022-10-25
Hello, convert the layer array to a dlnetwork object:
net = dlnetwork(layers);
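One caveat when applying this to the layers in the question: dlnetwork does not accept a layer array that contains an output layer such as classificationLayer, so that layer has to be dropped first and the loss computed by hand from the softmax output. The sketch below keeps the question's architecture otherwise unchanged and feeds it a hypothetical batch of four random 50-step sequences. It also shows where the original error comes from: layers is an array of layer objects whose class is nnet.cnn.layer.Layer (that package holds the built-in layer types in general, LSTM included, hence the "nnet.cnn" in the message), and predict has no method for a plain layer array. trainNetwork works because it assembles the layers into a trained network object (for example a SeriesNetwork) and returns that object, which does have a predict method.

```matlab
% Same architecture as in the question, minus classificationLayer:
% dlnetwork does not accept output layers.
numFeatures = 1;
numHiddenUnits = 100;
numClasses = 2;
layers = [
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits, "OutputMode", "last")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

disp(class(layers))          % a layer array, not a network, so predict(layers, ...) fails
net = dlnetwork(layers);     % network object with initialized learnable parameters

% Hypothetical input: 4 random sequences of 50 time steps, stored as
% channel x time x batch and labelled with the "CTB" format.
dlX = dlarray(randn(numFeatures, 50, 4, "single"), "CTB");

% One column of class probabilities per sequence (numClasses x batch).
Y = predict(net, dlX);
disp(size(Y))
```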
You should then be able to train and call predict on your network as desired. For examples of how to train using a custom training loop with a dlnetwork, you can refer to one of the following:
I hope this information helps!
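Even with a dlnetwork, the loss function in the question would still not train, for two reasons. First, dlgradient only differentiates a dlarray traced inside a dlfeval call, and loss01 returns a plain 0 or 1 from isequal, which is not traced. Second, a 0-1 loss is piecewise constant, so its gradient is zero wherever it is defined and gives the optimizer nothing to follow. A common workaround, sketched below, is to train on cross-entropy and report the 0-1 error only as a metric; this is a substitution, not a way to minimize 0-1 loss directly. Everything here is an assumption for illustration: the synthetic data (label 1 for a non-negative mean, 2 otherwise), the hypothetical helper name modelLoss, and the training settings (Adam via adamupdate, 10 epochs, mini-batches of 64, learning rate 0.01).

```matlab
% Hypothetical synthetic data matching the question: 50-step sequences,
% label 1 if the mean is non-negative, label 2 otherwise (assumed mapping).
numObs = 512;
seqLen = 50;
X = randn(1, seqLen, numObs, "single");          % channel x time x observation
labels = reshape(1 + (mean(X, 2) < 0), 1, []);   % 1 x numObs, values 1 or 2
T = zeros(2, numObs, "single");                  % one-hot targets, class x observation
T(sub2ind(size(T), labels, 1:numObs)) = 1;

layers = [
    sequenceInputLayer(1)
    lstmLayer(100, "OutputMode", "last")
    fullyConnectedLayer(2)
    softmaxLayer];                               % no classificationLayer in a dlnetwork
net = dlnetwork(layers);

numEpochs = 10;
miniBatchSize = 64;
learnRate = 0.01;
avgGrad = [];
avgSqGrad = [];
iteration = 0;

for epoch = 1:numEpochs
    idx = randperm(numObs);                      % shuffle each epoch
    for startIdx = 1:miniBatchSize:numObs
        batchIdx = idx(startIdx:min(startIdx + miniBatchSize - 1, numObs));
        dlX = dlarray(X(:, :, batchIdx), "CTB");
        iteration = iteration + 1;

        % Gradients must be computed inside dlfeval so the loss is traced.
        [loss, gradients, err01] = dlfeval(@modelLoss, net, dlX, T(:, batchIdx));
        [net, avgGrad, avgSqGrad] = adamupdate(net, gradients, ...
            avgGrad, avgSqGrad, iteration, learnRate);
    end
    % Values from the last mini-batch of the epoch only.
    fprintf("epoch %d: cross-entropy %.4f, 0-1 error %.3f\n", ...
        epoch, extractdata(loss), err01);
end

function [loss, gradients, err01] = modelLoss(net, dlX, T)
    % Training-mode forward pass. The returned LSTM state is ignored on
    % purpose: each sequence is independent, so every mini-batch starts
    % from the network's initial hidden and cell state.
    Y = forward(net, dlX);
    % Differentiable loss that actually drives the parameter updates.
    loss = crossentropy(Y, T);
    gradients = dlgradient(loss, net.Learnables);
    % 0-1 error (fraction of misclassified sequences), reported only.
    [~, predicted] = max(extractdata(Y), [], 1);
    [~, actual] = max(T, [], 1);
    err01 = mean(predicted ~= actual);
end
```

The question's function also returned a state output; it is dropped here because carrying the LSTM hidden state from one mini-batch into the next only makes sense when consecutive batches continue the same sequences, which is not the case for independent 50-step samples.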
Category: Build Deep Neural Networks