Weighted Classification Layer for Time Series/LSTM

Hi,
I recently came across the WeightedClassificationLayer example in the custom Deep Learning layer templates (custom layer). I was pleased, as this is exactly what I'm after for my current problem.
Unfortunately, the layer throws an error during backpropagation when I try to use it with an LSTM. Is this because the layer only works with imageInputLayer problems? In my mind it ought to work the same, except that dimensions 1-4 are used for images (height, width, channels, observations) and, as I understand it, 2 additional dimensions are used for time series (featureDimension and SequenceLength).
Could anyone guide me on altering the tutorial so that it works for time series? Or does it work for anyone else? The error message isn't very descriptive, simply:
'Error using trainNetwork (line 150)
Error using 'backwardLoss' in layer weightedClassificationLayer. The function threw an error and could not be executed
2 Comments
Conor Daly 2018-12-12
Hi Stuart,
The issue here is that the dimensions of the input that the network uses differ between image input networks and sequence networks.
As described in the custom layer page that you linked to, image classification loss layers use inputs of shape 1-by-1-by-K-by-N, whereas for sequence-to-sequence problems the shape is K-by-N-by-S. Here K is the number of classes for the classification problem, N is the number of observations (the mini-batch size), and S is the sequence length (the number of time steps).
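For concreteness, a quick sketch of the two layouts (the sizes here are arbitrary, just to illustrate the convention):

% Inputs to classification loss layers (K classes, N observations, S time steps)
% These dummy arrays only illustrate the sizes; the values are meaningless
K = 3; N = 8; S = 20;
Yimage = rand(1, 1, K, N);   % image classification: 1-by-1-by-K-by-N
Yseq   = rand(K, N, S);      % sequence-to-sequence: K-by-N-by-S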
So to make the weighted classification layer work for sequence-to-sequence problems, we need to modify the forwardLoss method as follows:
function loss = forwardLoss(layer, Y, T)
    % loss = forwardLoss(layer, Y, T) returns the weighted cross
    % entropy loss between the predictions Y and the training
    % targets T.

    % Find the observation and sequence dimensions of Y
    [~, N, S] = size(Y);

    % Replicate ClassWeights to K-by-N-by-S
    W = repmat(layer.ClassWeights(:), 1, N, S);

    % Compute the loss, normalized by the number of observations
    loss = -sum( W(:).*T(:).*log(Y(:)) )/N;
end
And then the backwardLoss method becomes:
function dLdY = backwardLoss(layer, Y, T)
    % dLdY = backwardLoss(layer, Y, T) returns the derivatives of
    % the weighted cross entropy loss with respect to the
    % predictions Y.

    % Find the observation and sequence dimensions of Y
    [~, N, S] = size(Y);

    % Replicate ClassWeights to K-by-N-by-S
    W = repmat(layer.ClassWeights(:), 1, N, S);

    % Compute the derivative of the loss with respect to Y
    dLdY = -(W.*T./Y)/N;
end
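Before retraining, it can be worth validating the modified layer with checkLayer on a sequence-shaped input. A minimal sketch, assuming the layer's constructor takes the class weights as in the template (the weights and sizes here are made up):

% Check the modified layer against a K-by-N-by-S input;
% observations are along dimension 2 for sequence data
classWeights = [0.4 1.6];                        % K = 2 classes, arbitrary weights
layer = weightedClassificationLayer(classWeights);
validInputSize = [numel(classWeights) 5 10];     % K-by-N-by-S
checkLayer(layer, validInputSize, 'ObservationDimension', 2);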
Stuart Whipp 2018-12-12
Thank you very much for the speedy response! It seems obvious reading your code, but I spent hours yesterday failing to make this work correctly, so thank you very, very much. Hugely grateful :) I believe there's a similar question on another thread, so I'll point them towards your answer.


Answers (1)

Dario Walter 2020-6-9
Dear Conor,
Would you mind explaining how this works for sequence-to-label classification?
I have imbalanced sequence data (98% is false, 2% is true). As a first step, I have to set the weights xx1 and xx2 (do you have a recommendation for this? one possible heuristic is sketched after this post).
classWeights = [xx1 xx2];
layer = myClassificationLayer(classWeights);
Afterwards I did:
numClasses = numel(classWeights);
validInputSize = [vv1 vv2 numClasses];
checkLayer(layer,validInputSize,'ObservationDimension',2);
vv1 refers to the number of classes (= 2) and vv2 to the number of observations (of one sequence, or all observations in the dataset, e.g. vv2 = 10000 observations, each with 30 time steps).
How do I have to modify the forwardLoss and backwardLoss functions of my classification layer class?
Thank you so much for your help!
classdef myClassificationLayer < nnet.layer.ClassificationLayer

    properties
        ClassWeights
    end

    methods
        function layer = myClassificationLayer(classWeights, name)
            layer.ClassWeights = classWeights;
            if nargin == 2
                layer.Name = name;
            end
            layer.Description = 'Weighted cross entropy';
        end

        function loss = forwardLoss(layer, Y, T)
            N = size(Y,4);
            Y = squeeze(Y);
            T = squeeze(T);
            W = layer.ClassWeights;
            loss = -sum(W*(T.*log(Y)))/N;
        end

        function dLdY = backwardLoss(layer, Y, T)
        end
    end
end
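Two sketches along the lines of what is asked above; both are assumptions rather than a confirmed answer. For the weights, one common heuristic is inverse class frequency, normalized across classes:

% Hypothetical inverse-frequency weights for a 98%/2% class split
counts = [9800 200];                            % observations per class
classWeights = (sum(counts) ./ counts) / numel(counts);
% gives roughly [0.51 25.0]: the rare class is up-weighted by a factor of ~49

And assuming Y and T reach the loss layer as K-by-N for sequence-to-label problems (the S dimension simply drops out), Conor's sequence-to-sequence methods would simplify to:

function loss = forwardLoss(layer, Y, T)
    % Sketch: weighted cross-entropy assuming K-by-N predictions Y
    N = size(Y, 2);
    W = repmat(layer.ClassWeights(:), 1, N);   % replicate weights to K-by-N
    loss = -sum( W(:).*T(:).*log(Y(:)) )/N;
end

function dLdY = backwardLoss(layer, Y, T)
    % Sketch: derivative of the weighted cross-entropy with respect to Y
    N = size(Y, 2);
    W = repmat(layer.ClassWeights(:), 1, N);
    dLdY = -(W.*T./Y)/N;
end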
