LSTM - Set special loss function
9 次查看(过去 30 天)
显示 更早的评论
Based on this great MatLab-example I would like to build a neural network classifying each timestep of a timeseries (x_i,y_i) (i=1:N) as 1 or 2. For training purpose I created 500 different timeseries and the corresponding target-vectors. In reality, about 85 % of a timeseries is in state 1 and the rest (15 %) in state 2. The training-process lookes like this:
It stagnates at about 85%, because having the Mean Squared Error as a loss-function, a policy classifying every timestep as 1 results in a "good" accuracy of 85 %. So nearly every timestep is classified as 1. I am quite sure this can be avoided by using another loss function, but unfortunately I do not know how I can create an arbitrary loss-function.
I would like to adapt the loss function in a way, that if it falsely classifies a true state 2 as 1, then the loss is weighted by a factor f.
How can this be done?
This is how the training is set up:
featureDimension = 2;
numHiddenUnits = 100;
numClasses = 2;
layers = [ ...
sequenceInputLayer(featureDimension)
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];
options = trainingOptions('adam', ...
'GradientThreshold',1, ...
'InitialLearnRate',0.01, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',20, ...
'Verbose',0, ...
'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,layers,options);
0 个评论
回答(2 个)
sii
2018-5-18
If it is still relevant, check following documentation. As far as I know, it is only available in the R2018a release.
0 个评论
Stuart Whipp
2018-12-10
I think what is needed is a weighted classification output so you can account for the imbalance in your classes. A custom layer tutorial exists for this, but it only works on image classification problems seemingly.
1 个评论
Stuart Whipp
2018-12-12
编辑:Stuart Whipp
2018-12-12
Conor Daly (staff) kindly answered my question this morning with a custom output layer that I can confirm has worked for my Use Case. Please take a look at this link as I believe it also answers your question.
Regards Stuart
另请参阅
类别
在 Help Center 和 File Exchange 中查找有关 Sequence and Numeric Feature Data Workflows 的更多信息
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!