how to define custom loss function like tf.train.A​​damOptimi​z​er().min​im​ize(los​s) ?

3 次查看(过去 30 天)
I put X=[4,100] into LSTM and received the last hidden state[4,1]. I use this state vector to estimate Q. I want to use Adam Optimizer to minimize a custom loss function. But options offered by official tool cannot do it. So I want to define a loss function and use Adam Optimizer to min the loss. How can I do it like python:train = tf.train.AdamOptimizer().minimize(loss)?
z=[0.01,0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,0.99];
z_normal=norminv(z);
numFeatures = 4;
A=4;
X = reshape(X_test.',4,100,[]);
Y = reshape(Y_test.',1,[]);
dlX = dlarray(X,'CBT');
dlY = dlarray(Y,'BT');
numHiddenUnits = 16;
outputdim=4;
H0 = zeros(numHiddenUnits,1);
C0 = zeros(numHiddenUnits,1);
weights = dlarray(randn(4*numHiddenUnits,numFeatures),'CU');
recurrentWeights = dlarray(randn(4*numHiddenUnits,numHiddenUnits),'CU');
bias = dlarray(randn(4*numHiddenUnits,1),'C');
[outputs,hiddenState,cellState] = lstm(dlX,H0,C0,weights,recurrentWeights,bias);
state_h=hiddenState(:,end);
w_out=normrnd(0,1,[outputdim,numHiddenUnits]);
b_out=normrnd(0,1,[outputdim,1]);
out=w_out*state_h+b_out;
params=out+[0;1;1;1];
mu=params(1,:);
sig=params(2,:);
utail=params(3,:);
vtail=params(4,:);
factor1=exp(z_normal.'*utail)/A+1;
factor2=exp(z_normal.'*utail)/A+1;
factor=factor1.*factor2;
Q=factor.*z_normal.'.*sig+mu;
error=dlY-Q;
error1=z.'.*error;
error2=(z.'-1).*error;
loss=mean(max(error1,error2));

回答(1 个)

Sai Bhargav Avula
Sai Bhargav Avula 2020-2-18
Hi,
You can create custom loss function by creating a function of the form loss = myLoss(Y,T), where Y is the network predictions, T are the targets. The loss can be used to update the gradients in the modelGradient function.
This link explains how to create custom layers and loss functions in matlab
Hope this helps!

类别

Help CenterFile Exchange 中查找有关 Deep Learning Toolbox 的更多信息

产品


版本

R2019b

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by